 Makefile                   |  1 -
 bitsandbytes/functional.py | 13 ++++---------
 tests/test_functional.py   |  2 ++
 3 files changed, 6 insertions(+), 10 deletions(-)
diff --git a/Makefile b/Makefile
index 728c8e1..2e1d265 100644
--- a/Makefile
+++ b/Makefile
@@ -27,7 +27,6 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
CC_CUDA92 := -gencode arch=compute_30,code=sm_30
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 806c254..0190a7e 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
return out
-def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32):
+def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeA = SA[0]
shapeB = SB[0]
dimsA = len(shapeA)
@@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
elif dimsA == 3 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')
- if row_scale is not None: assert row_scale.numel() == out.shape[0]
assert dimsB != 3, 'len(B.shape)==3 not supported'
assert A.device.type == 'cuda'
assert B.device.type == 'cuda'
@@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
ptrA = get_ptr(A)
ptrB = get_ptr(B)
ptrC = get_ptr(out)
- ptrRowScale = get_ptr(row_scale)
k = shapeA[-1]
lda = ct.c_int32(m*32)
@@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
k = ct.c_int32(k)
has_error = 0
+ ptrRowScale = get_ptr(None)
if formatB == 'col_turing':
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
- elif row_scale is None:
- has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
- has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
elif formatB == 'col_ampere':
if dtype == torch.int32:
has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
- elif row_scale is None:
- has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
- has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
if has_error == 1:
raise Exception('cublasLt ran into an error!')
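With row_scale gone, the int8 matmul entry point takes only the two operands and their layout states. Below is a minimal usage sketch of the simplified call, following the transform-then-igemmlt pattern used in tests/test_functional.py; the shapes are arbitrary and a CUDA GPU with the compiled bitsandbytes extension is assumed.

import torch
import bitsandbytes.functional as F

# Arbitrary int8 operands on the GPU; the inner dimension (64) must match.
A = torch.randint(-128, 127, size=(32, 64), device='cuda').to(torch.int8)
B = torch.randint(-128, 127, size=(48, 64), device='cuda').to(torch.int8)

formatB = F.get_special_format_str()         # 'col_turing' or 'col_ampere'
C32A, SA = F.transform(A, 'col32')           # A in col32 layout plus its state
CxB, SB = F.transform(B, to_order=formatB)   # B in the GPU-specific layout
C, SC = F.igemmlt(C32A, CxB, SA, SB)         # int32 result; no row_scale argument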
diff --git a/tests/test_functional.py b/tests/test_functional.py
index dcb0255..d80a4f9 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -992,6 +992,7 @@ inner = torch.randint(1,4*1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
formatB = F.get_special_format_str()
err1, err2, err3 = [], [], []
@@ -1064,6 +1065,7 @@ dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
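For reference while these tests stay skipped: judging from the removed assert (row_scale.numel() == out.shape[0]) and the dropped cigemmlt_*_8_rowscale kernel calls, row scaling presumably applied one factor per output row when producing int8 results. A rough pure-PyTorch (CPU) sketch of that idea follows; it is an assumption for illustration, not the removed kernel's actual implementation.

import torch

def igemmlt_rowscale_reference(A8, B8, row_scale):
    # A8: (m, k) int8, B8: (n, k) int8, row_scale: (m,) float -- hypothetical reference only.
    C32 = A8.int() @ B8.int().t()                 # int32 accumulation, as in igemmlt
    scaled = C32.float() * row_scale.view(-1, 1)  # one scale per output row
    return scaled.round().clamp(-128, 127).to(torch.int8)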