Simplification of overloaded function

This commit is contained in:
ysyapa 2023-04-07 00:00:22 -04:00
parent 6bbf238f4a
commit 1f67dea4c9
2 changed files with 30 additions and 13 deletions

View File

@ -122,14 +122,8 @@ function LinearAlgebra.mul!(out::CuArray{Complex{T}}, H::HOperator{T}, v::CuArra
return out_t.data return out_t.data
end end
"Apply 'H' on 'v' and return the result using the 'cpu_tensor' backend" "Apply 'H' on 'v' and return the result"
function (H::HOperator{T})(v::Array{Complex{T}})::Array{Complex{T}} where {T<:Float} function (H::HOperator)(v)
out = similar(v)
return mul!(out, H, v)
end
"Apply 'H' on 'v' and return the result using the 'gpu_cutensor' backend"
function (H::HOperator{T})(v::CuArray{Complex{T}})::CuArray{Complex{T}} where {T<:Float}
out = similar(v) out = similar(v)
return mul!(out, H, v) return mul!(out, H, v)
end end

View File

@ -2,9 +2,23 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"ename": "LoadError",
"evalue": "LoadError: invalid redefinition of constant cpu_tensor\nin expression starting at c:\\Users\\yapan\\DVR.jl\\HOperator.jl:4",
"output_type": "error",
"traceback": [
"LoadError: invalid redefinition of constant cpu_tensor\n",
"in expression starting at c:\\Users\\yapan\\DVR.jl\\HOperator.jl:4\n",
"\n",
"Stacktrace:\n",
" [1] top-level scope\n",
" @ Enums.jl:204"
]
}
],
"source": [ "source": [
"# prerequisite packages: KrylovKit, TensorOperations, LinearAlgebra, CUDA#tb/cutensor, Plots\n", "# prerequisite packages: KrylovKit, TensorOperations, LinearAlgebra, CUDA#tb/cutensor, Plots\n",
"include(\"HOperator.jl\")\n", "include(\"HOperator.jl\")\n",
@ -14,9 +28,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 4.377701 seconds (5.09 M allocations: 764.743 MiB, 6.09% gc time, 71.04% compilation time: 99% of which was recompilation)\n",
"114 operations : ComplexF32[-7.6208663f0 + 0.0f0im, -3.551723f0 + 0.0f0im, -3.5371912f0 + 0.0f0im, -3.5240355f0 + 0.0f0im, -3.5159583f0 + 0.0f0im, -3.4865863f0 + 0.0f0im, -3.1896422f0 + 0.0f0im, -2.9661055f0 + 0.0f0im]\n"
]
}
],
"source": [ "source": [
"V_gauss(r2) =\n", "V_gauss(r2) =\n",
" -4 * exp(-r2 / 4)\n", " -4 * exp(-r2 / 4)\n",
@ -37,7 +60,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [