summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/functional.py13
1 files changed, 4 insertions, 9 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index a9233e2..ac85f88 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
return out
-def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32):
+def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeA = SA[0]
shapeB = SB[0]
dimsA = len(shapeA)
@@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
elif dimsA == 3 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')
- if row_scale is not None: assert row_scale.numel() == out.shape[0]
assert dimsB != 3, 'len(B.shape)==3 not supported'
assert A.device.type == 'cuda'
assert B.device.type == 'cuda'
@@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
ptrA = get_ptr(A)
ptrB = get_ptr(B)
ptrC = get_ptr(out)
- ptrRowScale = get_ptr(row_scale)
k = shapeA[-1]
lda = ct.c_int32(m*32)
@@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
k = ct.c_int32(k)
has_error = 0
+ ptrRowScale = get_ptr(None)
if formatB == 'col_turing':
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
- elif row_scale is None:
- has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
- has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
elif formatB == 'col_ampere':
if dtype == torch.int32:
has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
- elif row_scale is None:
- has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
- has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
if has_error == 1:
raise Exception('cublasLt ran into an error!')