Train for virtual trajectory

This commit is contained in:
Nuwan Yapa 2025-04-29 20:30:36 -05:00
parent e68370cf7d
commit a761c6d0de
1 changed files with 9 additions and 1 deletions

View File

@ -43,21 +43,29 @@ enforce_ep()
subdivisions = 2
c_steps = np.concatenate([np.linspace(start, stop, subdivisions, endpoint=False) for (start, stop) in zip(train_cs, train_cs[1:])])
c_steps = np.append(c_steps, train_cs[-1])
c_steps = np.delete(c_steps, 0) # remove the first point (c=0)
lr = 0.01
epochs = 100000
epochs = 200000
for epoch in range(epochs):
ks = torch.empty(len(train_data), dtype=torch.complex128)
kvs = torch.empty(len(train_data), dtype=torch.complex128)
current_k = 0.0 # start at the threshold
current_kv = 0.0 # start at the threshold
for c in c_steps:
H = H0 + c * H1
evals = torch.linalg.eigvals(H)
current_k = evals[torch.argmin(torch.abs(evals - current_k))]
evals = evals[evals != current_k] # remove selected k
current_kv = evals[torch.argmin(torch.abs(evals - current_kv))]
if np.any(c == train_cs):
index = np.where(c == train_cs)[0][0]
ks[index] = current_k
kvs[index] = current_kv
loss = ((ks - train_ks).abs() ** 2).sum()
if epoch/epochs < 0.25:
loss += ((kvs + train_ks).abs() ** 2).sum()
if epoch % 1000 == 0:
print(f"Training {(epoch+1)/epochs:.1%} \t Loss: {loss}")