From 1e88edd8c096bde5202dd61411d3c8d7eda56645 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 25 Jul 2022 17:27:57 -0700 Subject: Removed rowscale (segfaults on ampere). --- bitsandbytes/functional.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) (limited to 'bitsandbytes/functional.py') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 806c254..0190a7e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) return out -def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32): +def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeA = SA[0] shapeB = SB[0] dimsA = len(shapeA) @@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 elif dimsA == 3 and out is None: out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row') - if row_scale is not None: assert row_scale.numel() == out.shape[0] assert dimsB != 3, 'len(B.shape)==3 not supported' assert A.device.type == 'cuda' assert B.device.type == 'cuda' @@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 ptrA = get_ptr(A) ptrB = get_ptr(B) ptrC = get_ptr(out) - ptrRowScale = get_ptr(row_scale) k = shapeA[-1] lda = ct.c_int32(m*32) @@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 k = ct.c_int32(k) has_error = 0 + ptrRowScale = get_ptr(None) if formatB == 'col_turing': if dtype == torch.int32: has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif row_scale is None: - has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) elif formatB == 'col_ampere': if dtype == torch.int32: has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif row_scale is None: - has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 1: raise Exception('cublasLt ran into an error!') -- cgit v1.2.3