diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-25 17:27:57 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-25 17:27:57 -0700 |
commit | 1e88edd8c096bde5202dd61411d3c8d7eda56645 (patch) | |
tree | 84e514d4538d113fe6f78985808d26cc5c677b62 /bitsandbytes | |
parent | 8b1fd32e3e4f5073fd055cb5f9261ec585f8cc2c (diff) |
Removed the row_scale argument from igemmlt and the int8 rowscale kernel dispatch (rowscale segfaults on Ampere).
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/functional.py | 13 |
1 file changed, 4 insertions, 9 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 806c254..0190a7e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) return out -def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32): +def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeA = SA[0] shapeB = SB[0] dimsA = len(shapeA) @@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 elif dimsA == 3 and out is None: out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row') - if row_scale is not None: assert row_scale.numel() == out.shape[0] assert dimsB != 3, 'len(B.shape)==3 not supported' assert A.device.type == 'cuda' assert B.device.type == 'cuda' @@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 ptrA = get_ptr(A) ptrB = get_ptr(B) ptrC = get_ptr(out) - ptrRowScale = get_ptr(row_scale) k = shapeA[-1] lda = ct.c_int32(m*32) @@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 k = ct.c_int32(k) has_error = 0 + ptrRowScale = get_ptr(None) if formatB == 'col_turing': if dtype == torch.int32: has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif row_scale is None: - has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) elif formatB == 'col_ampere': if dtype == torch.int32: has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, 
ldc) - elif row_scale is None: - has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 1: raise Exception('cublasLt ran into an error!') |