# Fit two complex symmetric N×N matrices H0 and H1 by gradient descent so that one
# eigenvalue of H0 + c*H1 reproduces the reference energies `data_E` at the training
# couplings, then compare the fitted model with the reference at further couplings.
using LinearAlgebra, Random, Plots
include("../p_space.jl")

# NOTE(review): μ is not referenced in this file; presumably read as a global by
# the routines in p_space.jl — confirm before renaming or making it `const`.
μ = 0.5

# Two-argument potential for coupling strength `c`, built from `g0` (p_space.jl).
V_system(c) = (p, q) -> c*(-5*g0(sqrt(3), p, q) + 2*g0(sqrt(10), p, q)) # ResonanceEC: Eq. (20)

training_c = range(1.2, 0.9, 9) # original: range(1.35, 0.9, 5)
extrapolating_c = range(0.78, 0.45, 7) # original: range(0.75, 0.40, 8)

# calculate training data
data_c = training_c
data_E = [quick_pole_E(V_system(c)) for c in data_c]

# hyperparameters
N = 9

# initialize random Hamiltonians
# NOTE(review): no RNG seed is set in this file, so each run starts from different
# matrices unless a seed is set elsewhere — consider `Random.seed!` if runs must be
# reproducible.
H0 = randn(ComplexF64, N, N)
H0 = H0 + transpose(H0) # symmetric
H1 = randn(ComplexF64, N, N)
H1 = H1 + transpose(H1) # symmetric

# define the inner product for vectors

"""
    inner(ψ1, ψ2)

Return the bilinear product `transpose(ψ1) * ψ2` (no complex conjugation).
"""
inner(ψ1, ψ2) = only(transpose(ψ1) * ψ2)

"""
    normalized(ψ)

Return `ψ` rescaled so that `inner(ψ, ψ) == 1` up to rounding.
"""
normalized(ψ) = ψ ./ sqrt(inner(ψ, ψ))

"""
    overlap(ψ1_hat, ψ2)

Return `inner(ψ1_hat, ψ2)` divided by `sqrt(inner(ψ2, ψ2))`; `ψ1_hat` is used as
given, only `ψ2` is normalized.
"""
overlap(ψ1_hat, ψ2) = inner(ψ1_hat, ψ2) / sqrt(inner(ψ2, ψ2))

"""
    pmm_gradient(data_c, data_E, ψs, Es, N, c_order=0)

Accumulate `c^c_order * conj(E - E_target) * ψ * transpose(ψ)` over all training
points and return twice the real part of the sum as an `N×N` real matrix.
`c_order = 0` gives the step direction used for `H0`, `c_order = 1` the one for `H1`.

NOTE(review): because only the real part is returned, the updates applied to
`H0`/`H1` are real and their imaginary parts keep their initial values. If a full
complex gradient step was intended this may need to be `2 .* conj.(out)` — confirm
against the derivation before changing.
"""
function pmm_gradient(data_c, data_E, ψs, Es, N, c_order=0)
    out = zeros(ComplexF64, N, N)
    for (c, E_target, ψ, E) in zip(data_c, data_E, ψs, Es)
        out .+= (c^c_order * conj(E - E_target)) .* (ψ * transpose(ψ))
    end
    return 2 .* real.(out)
end

"""
    pmm_train!(H0, H1, Es, ψs, data_c, data_E; lr, epochs)

Run `epochs` gradient-descent steps with step size `lr`, updating `H0` and `H1` in
place. `Es` and `ψs` are emptied and refilled every epoch with the selected
eigenvalue and normalized eigenvector for each training point; after the call they
hold the values computed in the last epoch (before that epoch's update).

The training loop lives in a function, rather than at top level, so that it does not
iterate over non-constant global variables.

Returns `(H0, H1)`.
"""
function pmm_train!(H0, H1, Es, ψs, data_c, data_E; lr, epochs)
    N = size(H0, 1)
    for epoch in 1:epochs
        # Start tracking from the first training point's eigenvector of the
        # previous epoch (nothing is available in the very first epoch).
        last_ψ = isempty(ψs) ? nothing : first(ψs)
        empty!(Es)
        empty!(ψs)

        for (c, E) in zip(data_c, data_E)
            evals, evecs = eigen(H0 + c * H1)

            # identification of the eigenstate
            if isnothing(last_ψ)
                # no reference vector yet: pick by eigenvalue
                i = nearestIndex(evals, E)
            else
                # follow the state by maximal overlap with the previous eigenvector;
                # `eachcol` yields views, so no column copies are made here
                overlaps = [abs(overlap(last_ψ, col)) for col in eachcol(evecs)]
                i = argmax(overlaps)
                # check agreement with eigenvalues after 1% of epochs
                if epoch > 0.01epochs && i ≠ nearestIndex(evals, E)
                    @warn("Identification via overlap contradicts eigenvalues")
                end
            end

            last_ψ = normalized(evecs[:, i])
            push!(ψs, last_ψ)
            push!(Es, evals[i])
        end

        if epoch % 1000 == 0
            loss = sum(abs2, Es .- data_E)
            println("Epoch:$epoch/$epochs \t Loss: $loss")
        end

        # Both steps depend only on the stored ψs/Es, not on H0/H1 themselves.
        H0 .-= lr .* pmm_gradient(data_c, data_E, ψs, Es, N, 0) # update H0
        H1 .-= lr .* pmm_gradient(data_c, data_E, ψs, Es, N, 1) # update H1
    end
    return H0, H1
end

# training
Es = ComplexF64[]
ψs = Vector{ComplexF64}[]
lr = 0.05
epochs = 100000
pmm_train!(H0, H1, Es, ψs, data_c, data_E; lr=lr, epochs=epochs)

# evaluate for all points
all_c = vcat(training_c, extrapolating_c)
exact_E = [quick_pole_E(V_system(c)) for c in all_c]
extrapolated_E = ComplexF64[]
for (c, ref) in zip(all_c, exact_E)
    evals = eigen(H0 + c * H1).values
    evals = vcat(evals, conj.(evals)) # include complex conjugates
    # the candidate is chosen relative to the reference value `ref`
    push!(extrapolated_E, nearest(evals, ref))
end

# plot results
scatter(real.(data_E), imag.(data_E), label="training", title="PMM", xlabel="Re E", ylabel="Im E")
scatter!(real.(exact_E), imag.(exact_E), label="exact", m=:+)
scatter!(real.(extrapolated_E), imag.(extrapolated_E), label="predicted", m=:x)