Fixed GPU related bug
This commit is contained in:
parent
b21242ad49
commit
7cc20a9c27
|
|
@ -68,8 +68,8 @@ end
|
|||
|
||||
"cuTENSOR contraction and accumulation (C = A * B + C)"
|
||||
function contract_accumulate!(C::CuTensor, A::CuTensor, B::CuTensor)::CuTensor
|
||||
CUTENSOR.contraction!(one(eltype(C)), A.data, A.inds, CUTENSOR.CUTENSOR_OP_IDENTITY, B.data, B.inds, CUTENSOR.CUTENSOR_OP_IDENTITY,
|
||||
one(eltype(C)), C.data, C.inds, CUTENSOR.CUTENSOR_OP_IDENTITY, CUTENSOR.CUTENSOR_OP_IDENTITY)
|
||||
cuTENSOR.contraction!(one(eltype(C)), A.data, A.inds, cuTENSOR.CUTENSOR_OP_IDENTITY, B.data, B.inds, cuTENSOR.CUTENSOR_OP_IDENTITY,
|
||||
one(eltype(C)), C.data, C.inds, cuTENSOR.CUTENSOR_OP_IDENTITY, cuTENSOR.CUTENSOR_OP_IDENTITY)
|
||||
return C
|
||||
end
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue