From 1507f53f50fa4cf5c1bc4bbe6340751dbd3abd2e Mon Sep 17 00:00:00 2001 From: Nuwan Yapa Date: Tue, 23 Apr 2024 12:35:20 -0400 Subject: [PATCH] V matrix elements caching --- ho_basis.jl | 12 +++++++++--- ho_basis_3body.jl | 12 ++++++++---- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/ho_basis.jl b/ho_basis.jl index b97576e..3944eae 100644 --- a/ho_basis.jl +++ b/ho_basis.jl @@ -71,13 +71,19 @@ function sp_T_matrix(ns, ls; mask=trues(length(ns),length(ns)), μω_gen=1.0, μ return (μω_gen / μ) .* mat end -function sp_V_matrix(V_l, ns, ls; mask=trues(length(ns),length(ns)), dtype=Float64) +function sp_V_matrix(V_l, ns, ls; mask=trues(length(ns),length(ns)), dtype=Float64, cache=fill(NaN, 1+maximum(ls), 1+maximum(ns), 1+maximum(ns))) mat = zeros(dtype, length(ns), length(ns)) - Threads.@threads for idx in CartesianIndices(mat) + Threads.@threads for idx in CartesianIndices(mat) if !mask[idx]; continue; end (i, j) = Tuple(idx) if ls[i] == ls[j] - mat[idx] = V_l(ls[i], ns[i], ns[j]) + l = ls[i] + n1, n2 = minmax(ns[i], ns[j]) # assuming transpose symmetry + if isnan(cache[1+l, 1+n1, 1+n2]) + cache[1+l, 1+n1, 1+n2] = V_l(l, n1, n2) # hopefully no race condition + @assert !isnan(cache[1+l, 1+n1, 1+n2]) "V matrix element returned NaN" + end + mat[idx] = cache[1+l, 1+n1, 1+n2] end end return sparse(mat) diff --git a/ho_basis_3body.jl b/ho_basis_3body.jl index 44cad58..47882a4 100644 --- a/ho_basis_3body.jl +++ b/ho_basis_3body.jl @@ -18,8 +18,10 @@ E_max = 40 println("No of threads = ", Threads.nthreads()) -@time "Basis" Es, n1s, l1s, n2s, l2s = get_2p_basis(E_max) -@time "Masks" begin +@time "Basis" begin + Es, n1s, l1s, n2s, l2s = get_2p_basis(E_max) + l_max = max(maximum(l1s), maximum(l2s)) + n_max = max(maximum(n1s), maximum(n2s)) mask1 = (n2s .== n2s') .&& (l2s .== l2s') mask2 = (n1s .== n1s') .&& (l1s .== l1s') end @@ -33,8 +35,10 @@ println("Constructing KE matrices") println("Constructing PE matrices") V1_elem(l, n1, n2) = Va * V_Gaussian(Ra, l, n1, n2; μω_gen=μ1ω1) V_relative_elem(l, n1, n2) = Va * V_Gaussian(Ra, l, n1, n2; μω_gen=μω_global) -@time "V1" V1 = sp_V_matrix(V1_elem, n1s, l1s; mask=mask1) -@time "V relative" V_relative = sp_V_matrix(V_relative_elem, n1s, l1s; mask=mask1) + sp_V_matrix(V_relative_elem, n2s, l2s; mask=mask2) +V1_cache = fill(NaN, 1+l_max, 1+n_max, 1+n_max) +V_relative_cache = fill(NaN, 1+l_max, 1+n_max, 1+n_max) +@time "V1" V1 = sp_V_matrix(V1_elem, n1s, l1s; mask=mask1, cache=V1_cache) +@time "V relative" V_relative = sp_V_matrix(V_relative_elem, n1s, l1s; mask=mask1, cache=V_relative_cache) + sp_V_matrix(V_relative_elem, n2s, l2s; mask=mask2, cache=V_relative_cache) @time "Moshinsky brackets" U = Moshinsky_transform(Es, n1s, l1s, n2s, l2s, Λ) @time "V2" V2 = U' * V_relative * U