Remove third-party root finding package in favor of bisection (implemented by AI)

This commit is contained in:
Nuwan Yapa 2025-02-18 18:59:04 -05:00
parent f49207fc42
commit 5fced75f41
4 changed files with 71 additions and 9 deletions

View File

@ -3,4 +3,3 @@ DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PolyLog = "85e3b03c-9856-11eb-0374-4dc1f8670e7f" PolyLog = "85e3b03c-9856-11eb-0374-4dc1f8670e7f"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"

62
bisection.jl Normal file
View File

@ -0,0 +1,62 @@
"Bisection method: Finds a zero in the interval [a, b]"
function bisection(f::Function, a::Float64, b::Float64; tol::Float64 = 1e-8, max_iter::Int = 1000)::Float64
fa::Float64 = f(a)
fb::Float64 = f(b)
# Ensure the endpoints bracket a root.
if fa * fb > 0.0
error("No sign change detected in the interval: f(a)*f(b) > 0")
end
c::Float64 = a
for _ in 1:max_iter
c = (a + b) / 2.0
fc::Float64 = f(c)
# Check for convergence using the function value or the interval width.
if abs(fc) < tol || abs(b - a) < tol
return c
elseif fa * fc < 0.0
b = c
fb = fc
else
a = c
fa = fc
end
end
return c # Return the current approximation if max iterations reached.
end
"Function to find all zeros in [x_start, x_end] by scanning for sign changes."
function find_all_zeros(f::Function, x_start::Float64, x_end::Float64; partitions::Int = 1000, tol::Float64 = 1e-8)::Vector{Float64}
zeros_list::Vector{Float64} = Float64[]
Δ::Float64 = (x_end - x_start) / partitions
x_prev::Float64 = x_start
f_prev::Float64 = f(x_prev)
for i in 1:partitions
x_curr::Float64 = x_start + i * Δ
f_curr::Float64 = f(x_curr)
# Check for a sign change or an exact zero.
if f_prev * f_curr < 0.0 || f_prev == 0.0
try
root::Float64 = bisection(f, x_prev, x_curr; tol=tol)
# Add the root if it's new (avoid duplicates from neighboring intervals)
if isempty(zeros_list) || abs(root - zeros_list[end]) > tol
push!(zeros_list, root)
end
catch err
# Skip this interval if bisection fails (should not occur if a sign change exists)
end
elseif f_curr == 0.0
if isempty(zeros_list) || abs(x_curr - zeros_list[end]) > tol
push!(zeros_list, x_curr)
end
end
x_prev = x_curr
f_prev = f_curr
end
return zeros_list
end

View File

@ -1,4 +1,5 @@
using DifferentialEquations, Roots using DifferentialEquations
include("bisection.jl")
const ħc = 197.33 # MeVfm const ħc = 197.33 # MeVfm
const M_n = 939.0 # MeV/c2 const M_n = 939.0 # MeV/c2
@ -87,18 +88,18 @@ function boundaryValueFunc(κ, p, s::system; dtype=Float64, algo=Tsit5())
return func return func
end end
"Find all bound energies between E_min (=850) and E_max (=938) where "Find all bound energies between E_min (=850.0) and E_max (=938.0) where
tol_digits determines the precision for root finding and the threshold for identifying duplicates, tol_digits determines the precision for root finding and the threshold for identifying duplicates,
the other parameters are the same from dirac!(...)." the other parameters are the same from dirac!(...)."
function findEs(κ, p, s::system, E_min=850, E_max=938, tol_digits=5) function findEs(κ, p, s::system, E_min=850.0, E_max=938.0, tol_digits=5)
func = boundaryValueFunc(κ, p, s) func = boundaryValueFunc(κ, p, s)
Es = find_zeros(func, (E_min, E_max); xatol=1/10^tol_digits) Es = find_all_zeros(func, E_min, E_max; partitions=20, tol=1/10^tol_digits)
return unique(E -> round(E; digits=tol_digits), Es) return unique(E -> round(E; digits=tol_digits), Es)
end end
"Find all orbitals and return two lists containing κ values and corresponding energies for a single species where "Find all orbitals and return two lists containing κ values and corresponding energies for a single species where
the other parameters are defined above" the other parameters are defined above"
function findAllOrbitals(p, s::system, E_min=850, E_max=938) function findAllOrbitals(p, s::system, E_min=850.0, E_max=938.0)
κs = Int[] κs = Int[]
Es = Float64[] Es = Float64[]
# start from κ=-1 and go both up and down # start from κ=-1 and go both up and down
@ -160,7 +161,7 @@ end
"Solve the Dirac equation and calculate scalar and vector densities of a nucleon species where "Solve the Dirac equation and calculate scalar and vector densities of a nucleon species where
the other parameters are defined above" the other parameters are defined above"
function solveNucleonDensity(p, s::system, E_min=850, E_max=938) function solveNucleonDensity(p, s::system, E_min=850.0, E_max=938.0)
κs, Es = findAllOrbitals(p, s, E_min, E_max) κs, Es = findAllOrbitals(p, s, E_min, E_max)
occs = fillNucleons(Z_or_N(s, p), κs, Es) occs = fillNucleons(Z_or_N(s, p), κs, Es)
return calculateNucleonDensity(κs, Es, occs, p, s) return calculateNucleonDensity(κs, Es, occs, p, s)

View File

@ -21,8 +21,8 @@ s.W0 = Vs
s.B0 = Rs s.B0 = Rs
s.A0 = As s.A0 = As
E_min = 850 E_min = 850.0
E_max = 938 E_max = 938.0
boundEs = findEs(κ, p, s, E_min, E_max) boundEs = findEs(κ, p, s, E_min, E_max)