PMM implemented
This commit is contained in:
parent
7f902cda92
commit
56914be36b
|
|
@ -0,0 +1,59 @@
|
||||||
|
using LinearAlgebra, Random, Plots
|
||||||
|
include("../p_space.jl")
|
||||||
|
|
||||||
|
μ = 0.5
|
||||||
|
V_system(c) = (p, q) -> c*(-5*g0(sqrt(3), p, q) + 2*g0(sqrt(10), p, q)) # ResonanceEC: Eq. (20)
|
||||||
|
|
||||||
|
training_c = range(1.2, 0.9, 9) # original: range(1.35, 0.9, 5)
|
||||||
|
extrapolating_c = range(0.78, 0.45, 7) # original: range(0.75, 0.40, 8)
|
||||||
|
|
||||||
|
# calculate training data
|
||||||
|
data_c = vcat(training_c, extrapolating_c)
|
||||||
|
data_E = [quick_pole_E(V_system(c)) for c in data_c]
|
||||||
|
|
||||||
|
# hyperparameters
|
||||||
|
N = 9
|
||||||
|
|
||||||
|
# initialize random Hamiltonians
|
||||||
|
H0 = randn(ComplexF64, N, N)
|
||||||
|
#H0 = H0 + H0' # symmetric
|
||||||
|
H1 = randn(ComplexF64, N, N)
|
||||||
|
#H1 = H1 + H1' # symmetric
|
||||||
|
|
||||||
|
# training
|
||||||
|
Es = ComplexF64[]
|
||||||
|
ψs = Vector{ComplexF64}[]
|
||||||
|
|
||||||
|
lr = 0.05
|
||||||
|
epochs = 100000
|
||||||
|
for epoch in 1:epochs
|
||||||
|
empty!(Es)
|
||||||
|
empty!(ψs)
|
||||||
|
for (c, E) in zip(data_c, data_E)
|
||||||
|
H = H0 + c * H1
|
||||||
|
evals, evecs = eigen(H)
|
||||||
|
i = nearestIndex(evals, E) # TODO: more robust way to identify the eigenvector
|
||||||
|
push!(Es, evals[i])
|
||||||
|
push!(ψs, evecs[:, i])
|
||||||
|
end
|
||||||
|
|
||||||
|
if epoch % 1000 == 0
|
||||||
|
loss = sum(abs2, Es .- data_E)
|
||||||
|
println("Epoch:$epoch/$epochs \t Loss: $loss")
|
||||||
|
end
|
||||||
|
|
||||||
|
# gradient of the loss function
|
||||||
|
function grad(c_order=0)
|
||||||
|
out = zeros(ComplexF64, N, N)
|
||||||
|
for (c, E_target, ψ, E) in zip(data_c, data_E, ψs, Es)
|
||||||
|
out .+= (c^c_order * (E - E_target)) .* (ψ * ψ')
|
||||||
|
end
|
||||||
|
return out
|
||||||
|
end
|
||||||
|
H0 .-= lr .* grad(0) # update H0
|
||||||
|
H1 .-= lr .* grad(1) # update H1
|
||||||
|
end
|
||||||
|
|
||||||
|
# plot the results
|
||||||
|
scatter(real.(data_E), imag.(data_E), label="Target", title="PMM", xlabel="Re E", ylabel="Im E")
|
||||||
|
scatter!(real.(Es), imag.(Es), label="Predicted")
|
||||||
Loading…
Reference in New Issue