From 2f01865a2ff4ad3345c156f7a2f76fe79ec4ed9a Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 3 Aug 2022 09:05:37 -0700 Subject: Added CUDA block assert and is_on_gpu check. --- bitsandbytes/functional.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'bitsandbytes') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 0190a7e..0a2d557 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -141,6 +141,14 @@ def get_special_format_str(): elif major == 8: return 'col_ampere' else: return 'col_turing' + +def is_on_gpu(tensors): + on_gpu = True + for t in tensors: + if t is None: continue # NULL pointers are fine + on_gpu &= t.device.type == 'cuda' + return on_gpu + def get_ptr(A: Tensor) -> ct.c_void_p: ''' Get the ctypes pointer from a PyTorch Tensor. @@ -284,6 +292,7 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens The 256 quantiles in float32 datatype. ''' if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) + is_on_gpu([A, out]) if A.dtype == torch.float32: lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) elif A.dtype == torch.float16: @@ -337,6 +346,7 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N if A.device.type != 'cpu': + is_on_gpu([code, A, absmax, out, rand]) if rand is not None: assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) @@ -401,6 +411,7 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]') if A.device.type != 'cpu': + is_on_gpu([A, out]) if out.dtype == torch.float32: lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: @@ -458,6 +469,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: Quantized 8-bit tensor. ''' if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + is_on_gpu([A, out]) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out @@ -483,6 +495,7 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: 32-bit output tensor. ''' if out is None: out = torch.zeros_like(A, dtype=torch.float32) + is_on_gpu([code, A, out]) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out @@ -662,6 +675,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: The current optimiation steps (number of past gradient norms). """ + is_on_gpu([grad, gnorm_vec]) if grad.dtype == torch.float32: lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel())) elif grad.dtype == torch.float16: @@ -694,6 +708,7 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, maxdim1 = ct.c_int32(histogram.shape[0]) n = ct.c_int32(index1.numel()) + is_on_gpu([histogram, index1, index2d, source]) lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): @@ -820,6 +835,7 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed # B^T @ A^T = C^T # [km, nk -> mn] + is_on_gpu([B, A, out]) lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) return out @@ -892,6 +908,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr ptr = CUBLAS_Context.get_instance().get_context(A.device) + is_on_gpu([B, A, out]) lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) @@ -954,6 +971,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = 0 ptrRowScale = get_ptr(None) + is_on_gpu([A, B, out]) if formatB == 'col_turing': if dtype == torch.int32: has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) @@ -994,6 +1012,7 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non numRows = ct.c_int32(out_shape[0]) numCols = ct.c_int32(out_shape[1]) + is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats]) lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols) return out @@ -1024,6 +1043,7 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr cols = ct.c_int32(cols) prev_device = pre_call(A.device) + is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) post_call(prev_device) @@ -1133,6 +1153,7 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, ptrOutCol = get_ptr(out_col) ptrOutRow = get_ptr(out_row) + is_on_gpu([A, col_stats, row_stats, out_col, out_row]) if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: @@ -1185,6 +1206,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ptrA = get_ptr(A) ptrOut = get_ptr(out) + is_on_gpu([A, out]) if to_order == 'col32': if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) @@ -1240,6 +1262,7 @@ def spmm_coo(cooA, B, out=None): cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) + is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) return out @@ -1285,6 +1308,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): #print(cooA.rowidx[:64]) #print(cooA.colidx[:64].sort()[0]) + is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) if B.dtype == torch.float16: lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB) elif B.dtype == torch.int8: -- cgit v1.2.3 From 451fd9506e215aa25643e9782cb7d8aed2a266cc Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 3 Aug 2022 11:54:01 -0700 Subject: Added fixes for the case that matmullt dim A is zero, e.g. [0, 768]. --- bitsandbytes/autograd/_functions.py | 16 +++++++++++++++- bitsandbytes/functional.py | 21 +++++++++++++++------ 2 files changed, 30 insertions(+), 7 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 815a4f1..370ca83 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,4 +1,5 @@ import torch +import math import bitsandbytes as bnb import bitsandbytes.functional as F @@ -162,6 +163,17 @@ class MatMul8bitLt(torch.autograd.Function): @staticmethod def forward(ctx, A, B, out=None, state=MatmulLtState()): + # default to pytorch behavior if inputs are empty + ctx.is_empty = False + if math.prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + if A.shape[-1] == B.shape[0]: + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device) + else: + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device) + # 1. Quantize A # 2. Quantize B # 3. Matmul @@ -265,6 +277,8 @@ class MatMul8bitLt(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None req_gradA, req_gradB = ctx.req_grads CAt, subA = ctx.tensors SCAt, idx = ctx.tensor_states @@ -293,7 +307,7 @@ class MatMul8bitLt(torch.autograd.Function): gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) - return grad_A, grad_B, None, None, None, None, None + return grad_A, grad_B, None, None matmul = MatMul8bitLt.apply diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 0a2d557..494de1b 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -4,9 +4,10 @@ # LICENSE file in the root directory of this source tree. import ctypes as ct import random -from typing import Tuple - +import math import torch + +from typing import Tuple from torch import Tensor from .cextension import lib, COMPILED_WITH_CUDA @@ -919,15 +920,22 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) + assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0]*shapeA[1] - if dimsB == 2: - rows = n = shapeB[0] - elif dimsB == 3: - rows = n = shapeB[0]*shapeB[1] + rows = n = shapeB[0] + assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + print(shapeA, math.prod(shapeA), math.prod(list(shapeA))) + print('aaa') + + # if the tensor is empty, return a transformed empty tensor with the right dimensions + if shapeA[0] == 0 and dimsA == 2: + return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) + elif shapeA[1] == 0 and dimsA == 3: + return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) if dimsA == 2 and out is None: out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row') @@ -984,6 +992,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 1: + print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') raise Exception('cublasLt ran into an error!') torch.cuda.set_device(prev_device) -- cgit v1.2.3 From 320eacb4c23adeaaf4a54166f19eac950aa631f1 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 3 Aug 2022 14:17:54 -0700 Subject: Removed print statement. --- bitsandbytes/functional.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 494de1b..334bdd9 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -928,8 +928,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): rows = n = shapeB[0] assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' - print(shapeA, math.prod(shapeA), math.prod(list(shapeA))) - print('aaa') # if the tensor is empty, return a transformed empty tensor with the right dimensions if shapeA[0] == 0 and dimsA == 2: -- cgit v1.2.3 From 6101a8fb9f76c2cc4018452b4420dd52e946d52b Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Thu, 4 Aug 2022 07:28:12 -0700 Subject: Added pre and post device call to transform. --- bitsandbytes/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 334bdd9..e7261bc 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1214,6 +1214,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ptrA = get_ptr(A) ptrOut = get_ptr(out) is_on_gpu([A, out]) + prev_device = pre_call(A.device) if to_order == 'col32': if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) @@ -1236,8 +1237,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') - - + post_call(prev_device) return out, new_state -- cgit v1.2.3 From ab72a1294fda03a0fd4ec297562fdab806349752 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Thu, 4 Aug 2022 07:47:22 -0700 Subject: Added pre/post device call for extract outliers. --- bitsandbytes/functional.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 08c108c..ad85f53 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1198,6 +1198,7 @@ def get_special_format_str(): def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + prev_device = pre_call(A.device) if state is None: state = (A.shape, from_order) else: from_order = state[1] if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) @@ -1214,7 +1215,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ptrA = get_ptr(A) ptrOut = get_ptr(out) is_on_gpu([A, out]) - prev_device = pre_call(A.device) if to_order == 'col32': if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) @@ -1237,8 +1237,8 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') - post_call(prev_device) + post_call(prev_device) return out, new_state @@ -1451,10 +1451,12 @@ def extract_outliers(A, SA, idx): ptrIdx = get_ptr(idx) ptrOut = get_ptr(out) + prev_device = pre_call(A.device) if formatA == 'col_turing': lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == 'col_ampere': lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + post_call(prev_device) return out -- cgit v1.2.3