From a761c6d0de3585f8dfc5d51112a3cd86d0a1ed1b Mon Sep 17 00:00:00 2001 From: Nuwan Yapa Date: Tue, 29 Apr 2025 20:30:36 -0500 Subject: [PATCH] Train for virtual trajectory --- calculations/PMM.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/calculations/PMM.py b/calculations/PMM.py index 397cd5e..fd319a6 100644 --- a/calculations/PMM.py +++ b/calculations/PMM.py @@ -43,21 +43,29 @@ enforce_ep() subdivisions = 2 c_steps = np.concatenate([np.linspace(start, stop, subdivisions, endpoint=False) for (start, stop) in zip(train_cs, train_cs[1:])]) c_steps = np.append(c_steps, train_cs[-1]) +c_steps = np.delete(c_steps, 0) # remove the first point (c=0) lr = 0.01 -epochs = 100000 +epochs = 200000 for epoch in range(epochs): ks = torch.empty(len(train_data), dtype=torch.complex128) + kvs = torch.empty(len(train_data), dtype=torch.complex128) current_k = 0.0 # start at the threshold + current_kv = 0.0 # start at the threshold for c in c_steps: H = H0 + c * H1 evals = torch.linalg.eigvals(H) current_k = evals[torch.argmin(torch.abs(evals - current_k))] + evals = evals[evals != current_k] # remove selected k + current_kv = evals[torch.argmin(torch.abs(evals - current_kv))] if np.any(c == train_cs): index = np.where(c == train_cs)[0][0] ks[index] = current_k + kvs[index] = current_kv loss = ((ks - train_ks).abs() ** 2).sum() + if epoch/epochs < 0.25: + loss += ((kvs + train_ks).abs() ** 2).sum() if epoch % 1000 == 0: print(f"Training {(epoch+1)/epochs:.1%} \t Loss: {loss}")