summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/autograd/_functions.py17
-rw-r--r--bitsandbytes/functional.py217
2 files changed, 96 insertions, 138 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index b56b2ee..14f2660 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass
import torch
-
+import math
import bitsandbytes as bnb
import bitsandbytes.functional as F
@@ -199,6 +199,17 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
+ # default to pytorch behavior if inputs are empty
+ ctx.is_empty = False
+ if math.prod(A.shape) == 0:
+ ctx.is_empty = True
+ ctx.A = A
+ ctx.B = B
+ if A.shape[-1] == B.shape[0]:
+ return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+ else:
+ return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
@@ -339,6 +350,8 @@ class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
+ if ctx.is_empty:
+ return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
req_gradA, req_gradB = ctx.req_grads
CAt, subA = ctx.tensors
SCAt, idx = ctx.tensor_states
@@ -375,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.grad_shape
)
- return grad_A, grad_B, None, None, None, None, None
+ return grad_A, grad_B, None, None
matmul = MatMul8bitLt.apply
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 236ef39..b4409e4 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -4,9 +4,10 @@
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import random
-from typing import Tuple
-
+import math
import torch
+
+from typing import Tuple
from torch import Tensor
from .cextension import COMPILED_WITH_CUDA, lib
@@ -193,6 +194,14 @@ def get_special_format_str():
return "col_turing"
+
+def is_on_gpu(tensors):
+ on_gpu = True
+ for t in tensors:
+ if t is None: continue # NULL pointers are fine
+ on_gpu &= t.device.type == 'cuda'
+ return on_gpu
+
def get_ptr(A: Tensor) -> ct.c_void_p:
"""
Get the ctypes pointer from a PyTorch Tensor.
@@ -336,7 +345,7 @@ def nvidia_transform(
def estimate_quantiles(
A: Tensor, out: Tensor = None, offset: float = 1 / 512
) -> Tensor:
- """
+ '''
Estimates 256 equidistant quantiles on the input tensor eCDF.
Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
@@ -361,9 +370,9 @@ def estimate_quantiles(
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
- """
- if out is None:
- out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+ '''
+ if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+ is_on_gpu([A, out])
if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32(
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
@@ -428,7 +437,8 @@ def quantize_blockwise(
if out is None:
out = torch.zeros_like(A, dtype=torch.uint8)
- if A.device.type != "cpu":
+ if A.device.type != 'cpu':
+ is_on_gpu([code, A, absmax, out, rand])
if rand is not None:
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
@@ -541,7 +551,8 @@ def dequantize_blockwise(
f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
)
- if A.device.type != "cpu":
+ if A.device.type != 'cpu':
+ is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(
get_ptr(quant_state[1]),
@@ -610,7 +621,7 @@ def dequantize(
def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
- """
+ '''
Quantizes input tensor to 8-bit.
Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
@@ -629,15 +640,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
-------
torch.Tensor:
Quantized 8-bit tensor.
- """
- if out is None:
- out = torch.zeros_like(A, dtype=torch.uint8)
+ '''
+ if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
+ is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
- """
+ '''
Dequantizes the 8-bit tensor to 32-bit.
Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
@@ -656,12 +667,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
-------
torch.Tensor:
32-bit output tensor.
- """
- if out is None:
- out = torch.zeros_like(A, dtype=torch.float32)
- lib.cdequantize(
- get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
- )
+ '''
+ if out is None: out = torch.zeros_like(A, dtype=torch.float32)
+ is_on_gpu([code, A, out])
+ lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
@@ -983,6 +992,7 @@ def percentile_clipping(
The current optimiation steps (number of past gradient norms).
"""
+ is_on_gpu([grad, gnorm_vec])
if grad.dtype == torch.float32:
lib.cpercentile_clipping_g32(
get_ptr(grad),
@@ -1027,21 +1037,11 @@ def histogram_scatter_add_2d(
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
- lib.chistogram_scatter_add_2d(
- get_ptr(histogram),
- get_ptr(index1),
- get_ptr(index2),
- get_ptr(source),
- maxdim1,
- n,
- )
+ is_on_gpu([histogram, index1, index2d, source])
+ lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
-
-def check_matmul(
- A, B, out, transposed_A, transposed_B, expected_type=torch.int8
-):
- if not torch.cuda.is_initialized():
- torch.cuda.init()
+def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
+ if not torch.cuda.is_initialized(): torch.cuda.init()
if A.dtype != expected_type or B.dtype != expected_type:
raise TypeError(
f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
@@ -1212,21 +1212,10 @@ def igemm(
ptr = CUBLAS_Context.get_instance().get_context(A.device)
# B^T @ A^T = C^T
- # [km, nk -> mn]
- lib.cigemm(
- ptr,
- ct.c_bool(transposed_B),
- ct.c_bool(transposed_A),
- ct.c_int32(m),
- ct.c_int32(n),
- ct.c_int32(k),
- get_ptr(B),
- get_ptr(A),
- get_ptr(out),
- ct.c_int32(lda),
- ct.c_int32(ldb),
- ct.c_int32(ldc),
- )
+ # [km, nk -> mn]
+ is_on_gpu([B, A, out])
+ lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
+ get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
return out
@@ -1306,24 +1295,10 @@ def batched_igemm(
ptr = CUBLAS_Context.get_instance().get_context(A.device)
- lib.cbatched_igemm(
- ptr,
- ct.c_bool(transposed_B),
- ct.c_bool(transposed_A),
- ct.c_int32(m),
- ct.c_int32(n),
- ct.c_int32(k),
- get_ptr(B),
- get_ptr(A),
- get_ptr(out),
- ct.c_int32(lda),
- ct.c_int32(ldb),
- ct.c_int32(ldc),
- ct.c_long(strideA),
- ct.c_long(strideB),
- ct.c_long(strideC),
- ct.c_uint32(num_batch),
- )
+ is_on_gpu([B, A, out])
+ lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
+ get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
+ ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
return out
@@ -1332,15 +1307,20 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeB = SB[0]
dimsA = len(shapeA)
dimsB = len(shapeB)
+ assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
if dimsA == 2:
m = shapeA[0]
elif dimsA == 3:
m = shapeA[0] * shapeA[1]
- if dimsB == 2:
- rows = n = shapeB[0]
- elif dimsB == 3:
- rows = n = shapeB[0] * shapeB[1]
+ rows = n = shapeB[0]
+ assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
+
+ # if the tensor is empty, return a transformed empty tensor with the right dimensions
+ if shapeA[0] == 0 and dimsA == 2:
+ return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
+ elif shapeA[1] == 0 and dimsA == 3:
+ return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
if dimsA == 2 and out is None:
out, Sout = get_transform_buffer(
@@ -1390,7 +1370,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error = 0
ptrRowScale = get_ptr(None)
- if formatB == "col_turing":
+ is_on_gpu([A, B, out])
+ if formatB == 'col_turing':
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
@@ -1410,7 +1391,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
)
if has_error == 1:
- raise Exception("cublasLt ran into an error!")
+ print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
+ raise Exception('cublasLt ran into an error!')
torch.cuda.set_device(prev_device)
@@ -1457,16 +1439,8 @@ def mm_dequant(
numRows = ct.c_int32(out_shape[0])
numCols = ct.c_int32(out_shape[1])
- lib.cdequant_mm_int32_fp16(
- ptrA,
- ptrRowStats,
- ptrColStats,
- ptrOut,
- ptrNewRowStats,
- ptrNewColStats,
- numRows,
- numCols,
- )
+ is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
+ lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
return out
@@ -1507,15 +1481,8 @@ def get_colrow_absmax(
cols = ct.c_int32(cols)
prev_device = pre_call(A.device)
- lib.cget_col_row_stats(
- ptrA,
- ptrRowStats,
- ptrColStats,
- ptrNnzrows,
- ct.c_float(threshold),
- rows,
- cols,
- )
+ is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
+ lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
post_call(prev_device)
if threshold > 0.0:
@@ -1642,6 +1609,7 @@ def double_quant(
ptrOutCol = get_ptr(out_col)
ptrOutRow = get_ptr(out_row)
+ is_on_gpu([A, col_stats, row_stats, out_col, out_row])
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
@@ -1714,33 +1682,19 @@ def get_special_format_str():
)
assert major >= 7
- if major == 7:
- return "col_turing"
- elif major == 8:
- return "col_ampere"
- else:
- return "col_turing"
+ if major == 7: return 'col_turing'
+ elif major == 8: return 'col_ampere'
+ else: return 'col_turing'
-def transform(
- A,
- to_order,
- from_order="row",
- out=None,
- transpose=False,
- state=None,
- ld=None,
-):
- if state is None:
- state = (A.shape, from_order)
- else:
- from_order = state[1]
- if out is None:
- out, new_state = get_transform_buffer(
- state[0], A.dtype, A.device, to_order, state[1], transpose
- )
- else:
- new_state = (state[0], to_order) # (shape, order)
+
+
+def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+ prev_device = pre_call(A.device)
+ if state is None: state = (A.shape, from_order)
+ else: from_order = state[1]
+ if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
+ else: new_state = (state[0], to_order) # (shape, order)
shape = state[0]
if len(shape) == 2:
@@ -1752,7 +1706,8 @@ def transform(
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
- if to_order == "col32":
+ is_on_gpu([A, out])
+ if to_order == 'col32':
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
else:
@@ -1773,9 +1728,9 @@ def transform(
elif from_order == "col_ampere":
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
- raise NotImplementedError(
- f"Transform function not implemented: From {from_order} to {to_order}"
- )
+ raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
+
+ post_call(prev_device)
return out, new_state
@@ -1810,21 +1765,8 @@ def spmm_coo(cooA, B, out=None):
cldb = ct.c_int32(ldb)
cldc = ct.c_int32(ldc)
- lib.cspmm_coo(
- ptr,
- ptrRowidx,
- ptrColidx,
- ptrValues,
- cnnz,
- crowsA,
- ccolsA,
- ccolsB,
- cldb,
- ptrB,
- cldc,
- ptrC,
- ct.c_bool(transposed_B),
- )
+ is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
+ lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
return out
@@ -1875,6 +1817,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
# print(cooA.rowidx[:64])
# print(cooA.colidx[:64].sort()[0])
+ is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
if B.dtype == torch.float16:
lib.cspmm_coo_very_sparse_naive_fp16(
ptrMaxCount,
@@ -2061,9 +2004,11 @@ def extract_outliers(A, SA, idx):
ptrIdx = get_ptr(idx)
ptrOut = get_ptr(out)
- if formatA == "col_turing":
+ prev_device = pre_call(A.device)
+ if formatA == 'col_turing':
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
+ post_call(prev_device)
return out