From 1e88edd8c096bde5202dd61411d3c8d7eda56645 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 25 Jul 2022 17:27:57 -0700 Subject: Removed rowscale (segfaults on ampere). --- Makefile | 1 - bitsandbytes/functional.py | 13 ++++--------- tests/test_functional.py | 2 ++ 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index 728c8e1..2e1d265 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,6 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta -COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta # CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not CC_CUDA92 := -gencode arch=compute_30,code=sm_30 diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 806c254..0190a7e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) return out -def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32): +def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeA = SA[0] shapeB = SB[0] dimsA = len(shapeA) @@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 elif dimsA == 3 and out is None: out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row') - if row_scale is not None: assert row_scale.numel() == out.shape[0] assert dimsB != 3, 'len(B.shape)==3 not supported' assert A.device.type == 'cuda' assert B.device.type == 'cuda' @@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 ptrA = get_ptr(A) ptrB = get_ptr(B) ptrC = get_ptr(out) - ptrRowScale = get_ptr(row_scale) k = shapeA[-1] lda = ct.c_int32(m*32) @@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 k = ct.c_int32(k) has_error = 0 + ptrRowScale = get_ptr(None) if formatB == 'col_turing': if dtype == torch.int32: has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif row_scale is None: - has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) elif formatB == 'col_ampere': if dtype == torch.int32: has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif row_scale is None: - has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 1: raise Exception('cublasLt ran into an error!') diff --git a/tests/test_functional.py b/tests/test_functional.py index dcb0255..d80a4f9 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -992,6 +992,7 @@ inner = torch.randint(1,4*1024, size=(n,)).tolist() values = list(zip(dim1, dim4, inner)) names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.skip("Row scale has some bugs for ampere") def test_igemmlt_row_scale(dim1, dim4, inner): formatB = F.get_special_format_str() err1, err2, err3 = [], [], [] @@ -1064,6 +1065,7 @@ dim4 = [12288, 4096] values = list(zip(dim1, dim4, inner)) names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.skip("Row scale has some bugs for ampere") def test_row_scale_bench(dim1, dim4, inner): err1, err2, err3 = [], [], [] relerr1, relerr2 = [], [] -- cgit v1.2.3