Major bug fix
This commit is contained in:
parent
b30ee6e1fb
commit
336a6a1020
|
|
@ -50,7 +50,7 @@ c_steps = np.delete(c_steps, 0) # remove the first point (c=0)
|
||||||
optimizer = torch.optim.Adam([H0, H1])
|
optimizer = torch.optim.Adam([H0, H1])
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
epochs = 200000
|
epochs = 50000
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
ks = torch.empty(len(train_data), dtype=torch.complex128)
|
ks = torch.empty(len(train_data), dtype=torch.complex128)
|
||||||
kvs = torch.empty(len(train_data), dtype=torch.complex128)
|
kvs = torch.empty(len(train_data), dtype=torch.complex128)
|
||||||
|
|
@ -70,10 +70,10 @@ for epoch in range(epochs):
|
||||||
loss = ((ks - train_ks).abs() ** 2).sum()
|
loss = ((ks - train_ks).abs() ** 2).sum()
|
||||||
|
|
||||||
# push virtual states towards (-)bound and then to the imaginary axis
|
# push virtual states towards (-)bound and then to the imaginary axis
|
||||||
if epoch/epochs < 0.25:
|
if epoch/epochs < 0.75:
|
||||||
loss += ((kvs + train_ks).abs() ** 2).sum()
|
loss += ((kvs + train_ks).abs() ** 2).sum()
|
||||||
else:
|
else:
|
||||||
loss += (kvs.imag ** 2).sum()
|
loss += (kvs.real ** 2).sum()
|
||||||
|
|
||||||
if epoch % 1000 == 0:
|
if epoch % 1000 == 0:
|
||||||
print(f"Training {(epoch+1)/epochs:.1%} \t Loss: {loss}")
|
print(f"Training {(epoch+1)/epochs:.1%} \t Loss: {loss}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue