Major bug fix

This commit is contained in:
Nuwan Yapa 2025-04-30 01:05:53 -05:00
parent b30ee6e1fb
commit 336a6a1020
1 changed files with 3 additions and 3 deletions

View File

@ -50,7 +50,7 @@ c_steps = np.delete(c_steps, 0) # remove the first point (c=0)
optimizer = torch.optim.Adam([H0, H1])
# Training loop
epochs = 200000
epochs = 50000
for epoch in range(epochs):
ks = torch.empty(len(train_data), dtype=torch.complex128)
kvs = torch.empty(len(train_data), dtype=torch.complex128)
@ -70,10 +70,10 @@ for epoch in range(epochs):
loss = ((ks - train_ks).abs() ** 2).sum()
# push virtual states towards (-)bound and then to the imaginary axis
if epoch/epochs < 0.25:
if epoch/epochs < 0.75:
loss += ((kvs + train_ks).abs() ** 2).sum()
else:
loss += (kvs.imag ** 2).sum()
loss += (kvs.real ** 2).sum()
if epoch % 1000 == 0:
print(f"Training {(epoch+1)/epochs:.1%} \t Loss: {loss}")