summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/autograd/_functions.py16
-rw-r--r--bitsandbytes/functional.py47
2 files changed, 54 insertions, 9 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index e641583..607d868 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,4 +1,5 @@
import torch
+import math
import bitsandbytes as bnb
import bitsandbytes.functional as F
@@ -162,6 +163,17 @@ class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
+ # default to pytorch behavior if inputs are empty
+ ctx.is_empty = False
+ if math.prod(A.shape) == 0:
+ ctx.is_empty = True
+ ctx.A = A
+ ctx.B = B
+ if A.shape[-1] == B.shape[0]:
+ return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+ else:
+ return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
@@ -283,6 +295,8 @@ class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
+ if ctx.is_empty:
+ return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
req_gradA, req_gradB = ctx.req_grads
CAt, subA = ctx.tensors
SCAt, idx = ctx.tensor_states
@@ -311,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function):
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
- return grad_A, grad_B, None, None, None, None, None
+ return grad_A, grad_B, None, None
matmul = MatMul8bitLt.apply
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index ac85f88..08c108c 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -4,9 +4,10 @@
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import random
-from typing import Tuple
-
+import math
import torch
+
+from typing import Tuple
from torch import Tensor
from .cextension import lib, COMPILED_WITH_CUDA
@@ -141,6 +142,14 @@ def get_special_format_str():
elif major == 8: return 'col_ampere'
else: return 'col_turing'
+
+def is_on_gpu(tensors):
+ on_gpu = True
+ for t in tensors:
+ if t is None: continue # NULL pointers are fine
+ on_gpu &= t.device.type == 'cuda'
+ return on_gpu
+
def get_ptr(A: Tensor) -> ct.c_void_p:
'''
Get the ctypes pointer from a PyTorch Tensor.
@@ -284,6 +293,7 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens
The 256 quantiles in float32 datatype.
'''
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+ is_on_gpu([A, out])
if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
@@ -337,6 +347,7 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N
if A.device.type != 'cpu':
+ is_on_gpu([code, A, absmax, out, rand])
if rand is not None:
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
@@ -401,6 +412,7 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]')
if A.device.type != 'cpu':
+ is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16:
@@ -458,6 +470,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
Quantized 8-bit tensor.
'''
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
+ is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
@@ -483,6 +496,7 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
32-bit output tensor.
'''
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
+ is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
@@ -662,6 +676,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
The current optimiation steps (number of past gradient norms).
"""
+ is_on_gpu([grad, gnorm_vec])
if grad.dtype == torch.float32:
lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
elif grad.dtype == torch.float16:
@@ -694,6 +709,7 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor,
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
+ is_on_gpu([histogram, index1, index2d, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
@@ -820,6 +836,7 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# B^T @ A^T = C^T
# [km, nk -> mn]
+ is_on_gpu([B, A, out])
lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
return out
@@ -892,6 +909,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ptr = CUBLAS_Context.get_instance().get_context(A.device)
+ is_on_gpu([B, A, out])
lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
@@ -902,15 +920,20 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeB = SB[0]
dimsA = len(shapeA)
dimsB = len(shapeB)
+ assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
if dimsA == 2:
m = shapeA[0]
elif dimsA == 3:
m = shapeA[0]*shapeA[1]
- if dimsB == 2:
- rows = n = shapeB[0]
- elif dimsB == 3:
- rows = n = shapeB[0]*shapeB[1]
+ rows = n = shapeB[0]
+ assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
+
+ # if the tensor is empty, return a transformed empty tensor with the right dimensions
+ if shapeA[0] == 0 and dimsA == 2:
+ return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
+ elif shapeA[1] == 0 and dimsA == 3:
+ return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
if dimsA == 2 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
@@ -954,6 +977,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error = 0
ptrRowScale = get_ptr(None)
+ is_on_gpu([A, B, out])
if formatB == 'col_turing':
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
@@ -966,6 +990,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
if has_error == 1:
+ print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
raise Exception('cublasLt ran into an error!')
torch.cuda.set_device(prev_device)
@@ -994,6 +1019,7 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
numRows = ct.c_int32(out_shape[0])
numCols = ct.c_int32(out_shape[1])
+ is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
return out
@@ -1024,6 +1050,7 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr
cols = ct.c_int32(cols)
prev_device = pre_call(A.device)
+ is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
post_call(prev_device)
@@ -1133,6 +1160,7 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
ptrOutCol = get_ptr(out_col)
ptrOutRow = get_ptr(out_row)
+ is_on_gpu([A, col_stats, row_stats, out_col, out_row])
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
@@ -1185,6 +1213,8 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
+ is_on_gpu([A, out])
+ prev_device = pre_call(A.device)
if to_order == 'col32':
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
@@ -1207,8 +1237,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
-
-
+ post_call(prev_device)
return out, new_state
@@ -1240,6 +1269,7 @@ def spmm_coo(cooA, B, out=None):
cldb = ct.c_int32(ldb)
cldc = ct.c_int32(ldc)
+ is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
return out
@@ -1285,6 +1315,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
#print(cooA.rowidx[:64])
#print(cooA.colidx[:64].sort()[0])
+ is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
if B.dtype == torch.float16:
lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
elif B.dtype == torch.int8: