Better constraint virtual state

This commit is contained in:
Nuwan Yapa 2025-04-29 20:55:50 -05:00
parent a761c6d0de
commit c1eb71c231
1 changed files with 4 additions and 0 deletions

View File

@ -64,8 +64,12 @@ for epoch in range(epochs):
kvs[index] = current_kv kvs[index] = current_kv
loss = ((ks - train_ks).abs() ** 2).sum() loss = ((ks - train_ks).abs() ** 2).sum()
# push virtual states towards (-)bound and then to the imaginary axis
if epoch/epochs < 0.25: if epoch/epochs < 0.25:
loss += ((kvs + train_ks).abs() ** 2).sum() loss += ((kvs + train_ks).abs() ** 2).sum()
else:
loss += (kvs.imag ** 2).sum()
if epoch % 1000 == 0: if epoch % 1000 == 0:
print(f"Training {(epoch+1)/epochs:.1%} \t Loss: {loss}") print(f"Training {(epoch+1)/epochs:.1%} \t Loss: {loss}")