Train for virtual trajectory
This commit is contained in:
parent
e68370cf7d
commit
a761c6d0de
|
|
@ -43,21 +43,29 @@ enforce_ep()
|
||||||
subdivisions = 2
|
subdivisions = 2
|
||||||
c_steps = np.concatenate([np.linspace(start, stop, subdivisions, endpoint=False) for (start, stop) in zip(train_cs, train_cs[1:])])
|
c_steps = np.concatenate([np.linspace(start, stop, subdivisions, endpoint=False) for (start, stop) in zip(train_cs, train_cs[1:])])
|
||||||
c_steps = np.append(c_steps, train_cs[-1])
|
c_steps = np.append(c_steps, train_cs[-1])
|
||||||
|
c_steps = np.delete(c_steps, 0) # remove the first point (c=0)
|
||||||
|
|
||||||
lr = 0.01
|
lr = 0.01
|
||||||
epochs = 100000
|
epochs = 200000
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
ks = torch.empty(len(train_data), dtype=torch.complex128)
|
ks = torch.empty(len(train_data), dtype=torch.complex128)
|
||||||
|
kvs = torch.empty(len(train_data), dtype=torch.complex128)
|
||||||
current_k = 0.0 # start at the threshold
|
current_k = 0.0 # start at the threshold
|
||||||
|
current_kv = 0.0 # start at the threshold
|
||||||
for c in c_steps:
|
for c in c_steps:
|
||||||
H = H0 + c * H1
|
H = H0 + c * H1
|
||||||
evals = torch.linalg.eigvals(H)
|
evals = torch.linalg.eigvals(H)
|
||||||
current_k = evals[torch.argmin(torch.abs(evals - current_k))]
|
current_k = evals[torch.argmin(torch.abs(evals - current_k))]
|
||||||
|
evals = evals[evals != current_k] # remove selected k
|
||||||
|
current_kv = evals[torch.argmin(torch.abs(evals - current_kv))]
|
||||||
if np.any(c == train_cs):
|
if np.any(c == train_cs):
|
||||||
index = np.where(c == train_cs)[0][0]
|
index = np.where(c == train_cs)[0][0]
|
||||||
ks[index] = current_k
|
ks[index] = current_k
|
||||||
|
kvs[index] = current_kv
|
||||||
|
|
||||||
loss = ((ks - train_ks).abs() ** 2).sum()
|
loss = ((ks - train_ks).abs() ** 2).sum()
|
||||||
|
if epoch/epochs < 0.25:
|
||||||
|
loss += ((kvs + train_ks).abs() ** 2).sum()
|
||||||
|
|
||||||
if epoch % 1000 == 0:
|
if epoch % 1000 == 0:
|
||||||
print(f"Training {(epoch+1)/epochs:.1%} \t Loss: {loss}")
|
print(f"Training {(epoch+1)/epochs:.1%} \t Loss: {loss}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue