diff --git a/HOperator.jl b/HOperator.jl index 07efdff..10b0259 100644 --- a/HOperator.jl +++ b/HOperator.jl @@ -122,14 +122,8 @@ function LinearAlgebra.mul!(out::CuArray{Complex{T}}, H::HOperator{T}, v::CuArra return out_t.data end -"Apply 'H' on 'v' and return the result using the 'cpu_tensor' backend" -function (H::HOperator{T})(v::Array{Complex{T}})::Array{Complex{T}} where {T<:Float} - out = similar(v) - return mul!(out, H, v) -end - -"Apply 'H' on 'v' and return the result using the 'gpu_cutensor' backend" -function (H::HOperator{T})(v::CuArray{Complex{T}})::CuArray{Complex{T}} where {T<:Float} +"Apply 'H' on 'v' and return the result" +function (H::HOperator)(v) out = similar(v) return mul!(out, H, v) end diff --git a/example.ipynb b/example.ipynb index a8caef4..8b5103f 100644 --- a/example.ipynb +++ b/example.ipynb @@ -2,9 +2,23 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "LoadError", + "evalue": "LoadError: invalid redefinition of constant cpu_tensor\nin expression starting at c:\\Users\\yapan\\DVR.jl\\HOperator.jl:4", + "output_type": "error", + "traceback": [ + "LoadError: invalid redefinition of constant cpu_tensor\n", + "in expression starting at c:\\Users\\yapan\\DVR.jl\\HOperator.jl:4\n", + "\n", + "Stacktrace:\n", + " [1] top-level scope\n", + " @ Enums.jl:204" + ] + } + ], "source": [ "# prerequisite packages: KrylovKit, TensorOperations, LinearAlgebra, CUDA#tb/cutensor, Plots\n", "include(\"HOperator.jl\")\n", @@ -14,9 +28,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 4.377701 seconds (5.09 M allocations: 764.743 MiB, 6.09% gc time, 71.04% compilation time: 99% of which was recompilation)\n", + "114 operations : ComplexF32[-7.6208663f0 + 0.0f0im, -3.551723f0 + 0.0f0im, -3.5371912f0 + 0.0f0im, -3.5240355f0 + 0.0f0im, -3.5159583f0 + 0.0f0im, -3.4865863f0 + 0.0f0im, -3.1896422f0 + 0.0f0im, -2.9661055f0 + 0.0f0im]\n" + ] + } + ], "source": [ "V_gauss(r2) =\n", " -4 * exp(-r2 / 4)\n", @@ -37,7 +60,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [