Fixed GPU related bug

This commit is contained in:
ysyapa 2023-08-27 23:02:51 -04:00
parent b21242ad49
commit 7cc20a9c27
1 changed files with 2 additions and 2 deletions

View File

@ -68,8 +68,8 @@ end
"cuTENSOR contraction and accumulation (C = A * B + C)" "cuTENSOR contraction and accumulation (C = A * B + C)"
function contract_accumulate!(C::CuTensor, A::CuTensor, B::CuTensor)::CuTensor function contract_accumulate!(C::CuTensor, A::CuTensor, B::CuTensor)::CuTensor
CUTENSOR.contraction!(one(eltype(C)), A.data, A.inds, CUTENSOR.CUTENSOR_OP_IDENTITY, B.data, B.inds, CUTENSOR.CUTENSOR_OP_IDENTITY, cuTENSOR.contraction!(one(eltype(C)), A.data, A.inds, cuTENSOR.CUTENSOR_OP_IDENTITY, B.data, B.inds, cuTENSOR.CUTENSOR_OP_IDENTITY,
one(eltype(C)), C.data, C.inds, CUTENSOR.CUTENSOR_OP_IDENTITY, CUTENSOR.CUTENSOR_OP_IDENTITY) one(eltype(C)), C.data, C.inds, cuTENSOR.CUTENSOR_OP_IDENTITY, cuTENSOR.CUTENSOR_OP_IDENTITY)
return C return C
end end