diff --git a/calculations/PMM.py b/calculations/PMM.py index fd319a6..e75f5df 100644 --- a/calculations/PMM.py +++ b/calculations/PMM.py @@ -64,8 +64,12 @@ for epoch in range(epochs): kvs[index] = current_kv loss = ((ks - train_ks).abs() ** 2).sum() + + # push virtual states towards (-)bound and then to the imaginary axis if epoch/epochs < 0.25: loss += ((kvs + train_ks).abs() ** 2).sum() + else: + loss += (kvs.imag ** 2).sum() if epoch % 1000 == 0: print(f"Training {(epoch+1)/epochs:.1%} \t Loss: {loss}")