diff --git a/GPU.jl b/GPU.jl index 8099f18..8e731d4 100644 --- a/GPU.jl +++ b/GPU.jl @@ -29,17 +29,8 @@ vectorDims(H::HOperator)::Dims = tuple(fill(H.N, H.d * (H.n - 1))...) "cuTENSOR contraction and accumulation (C = A * B + C)" function contract_accumulate!(C::CuTensor, A::CuTensor, B::CuTensor)::CuTensor - # https://docs.nvidia.com/cuda/cutensor/api/cutensor.html#cutensorcontraction - compute_type = if eltype(C) == ComplexF32 - CUTENSOR.CUTENSOR_COMPUTE_TF32 - elseif eltype(C) == ComplexF64 - CUTENSOR.CUTENSOR_COMPUTE_64F - else - eltype(C) - end CUTENSOR.contraction!(one(eltype(C)), A.data, A.inds, CUTENSOR.CUTENSOR_OP_IDENTITY, B.data, B.inds, CUTENSOR.CUTENSOR_OP_IDENTITY, - one(eltype(C)), C.data, C.inds, CUTENSOR.CUTENSOR_OP_IDENTITY, CUTENSOR.CUTENSOR_OP_IDENTITY, - compute_type=compute_type) + one(eltype(C)), C.data, C.inds, CUTENSOR.CUTENSOR_OP_IDENTITY, CUTENSOR.CUTENSOR_OP_IDENTITY) return C end