Compare commits

...

2 Commits

Author SHA1 Message Date
Nuwan Yapa 29e34d54c6 Bug fix in extrapolation 2025-04-15 15:16:15 -04:00
Nuwan Yapa 665016fbb6 Can train, but cannot extrapolate resonances 2025-04-15 14:12:34 -04:00
1 changed files with 18 additions and 14 deletions

View File

@ -15,26 +15,30 @@ data_E = [quick_pole_E(V_system(c)) for c in data_c]
N = 9 N = 9
# initialize random Hamiltonians # initialize random Hamiltonians
H0 = randn(ComplexF64, N, N) H0 = randn(N, N)
H0 = H0 + transpose(H0) # symmetric H1 = randn(N, N)
H1 = randn(ComplexF64, N, N)
H1 = H1 + transpose(H1) # symmetric
# training # training
Es = ComplexF64[] Es = ComplexF64[]
ψs = Vector{ComplexF64}[] ψrs = Vector{ComplexF64}[]
ψls = Vector{ComplexF64}[]
lr = 0.05 lr = 0.05
epochs = 100000 epochs = 100000
for epoch in 1:epochs for epoch in 1:epochs
empty!(Es) empty!(Es)
empty!(ψs) empty!(ψrs)
empty!(ψls)
for (c, E) in zip(data_c, data_E) for (c, E) in zip(data_c, data_E)
H = H0 + c * H1 H = H0 + c * H1
evals, evecs = eigen(H) r_evals, r_evecs = eigen(H)
i = nearestIndex(evals, E) # TODO: more robust way to identify the eigenvector l_evals, l_evecs = eigen(transpose(H))
push!(Es, evals[i]) @assert all(r_evals .≈ l_evals) "Right/left eigenvalues do not match"
push!(ψs, evecs[:, i]) i = nearestIndex(r_evals, E) # TODO: more robust way to identify the eigenvector
push!(Es, r_evals[i])
push!(ψrs, r_evecs[:, i])
push!(ψls, l_evecs[:, i])
end end
if epoch % 1000 == 0 if epoch % 1000 == 0
@ -45,8 +49,8 @@ for epoch in 1:epochs
# gradient of the loss function # gradient of the loss function
function grad(c_order=0) function grad(c_order=0)
out = zeros(ComplexF64, N, N) out = zeros(ComplexF64, N, N)
for (c, E_target, ψ, E) in zip(data_c, data_E, ψs, Es) for (c, E_target, ψr, ψl, E) in zip(data_c, data_E, ψrs, ψls, Es)
out .+= (c^c_order * conj(E - E_target)) .* (ψ * transpose(ψ)) out .+= (c^c_order * conj(E - E_target)) .* (ψl * transpose(ψr))
end end
return 2 .* real.(out) return 2 .* real.(out)
end end
@ -58,11 +62,11 @@ end
all_c = vcat(training_c, extrapolating_c) all_c = vcat(training_c, extrapolating_c)
exact_E = [quick_pole_E(V_system(c)) for c in all_c] exact_E = [quick_pole_E(V_system(c)) for c in all_c]
extrapolated_E = ComplexF64[] extrapolated_E = ComplexF64[]
for c in all_c for (c, ref) in zip(all_c, exact_E)
H = H0 + c * H1 H = H0 + c * H1
evals, evecs = eigen(H) evals, evecs = eigen(H)
evals = vcat(evals, conj.(evals)) # include complex conjugates evals = vcat(evals, conj.(evals)) # include complex conjugates
push!(extrapolated_E, nearest(evals, exact_E[end])) push!(extrapolated_E, nearest(evals, ref))
end end
# plot results # plot results