Major bug fix
This commit is contained in:
parent
b30ee6e1fb
commit
336a6a1020
|
|
@ -50,7 +50,7 @@ c_steps = np.delete(c_steps, 0) # remove the first point (c=0)
|
|||
optimizer = torch.optim.Adam([H0, H1])
|
||||
|
||||
# Training loop
|
||||
epochs = 200000
|
||||
epochs = 50000
|
||||
for epoch in range(epochs):
|
||||
ks = torch.empty(len(train_data), dtype=torch.complex128)
|
||||
kvs = torch.empty(len(train_data), dtype=torch.complex128)
|
||||
|
|
@ -70,10 +70,10 @@ for epoch in range(epochs):
|
|||
loss = ((ks - train_ks).abs() ** 2).sum()
|
||||
|
||||
# push virtual states towards (-)bound and then to the imaginary axis
|
||||
if epoch/epochs < 0.25:
|
||||
if epoch/epochs < 0.75:
|
||||
loss += ((kvs + train_ks).abs() ** 2).sum()
|
||||
else:
|
||||
loss += (kvs.imag ** 2).sum()
|
||||
loss += (kvs.real ** 2).sum()
|
||||
|
||||
if epoch % 1000 == 0:
|
||||
print(f"Training {(epoch+1)/epochs:.1%} \t Loss: {loss}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue