summaryrefslogtreecommitdiff
path: root/bitsandbytes/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/functional.py')
-rw-r--r--bitsandbytes/functional.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 0190a7e..0a2d557 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -141,6 +141,14 @@ def get_special_format_str():
elif major == 8: return 'col_ampere'
else: return 'col_turing'
+
+def is_on_gpu(tensors):
+ on_gpu = True
+ for t in tensors:
+ if t is None: continue # NULL pointers are fine
+ on_gpu &= t.device.type == 'cuda'
+ return on_gpu
+
def get_ptr(A: Tensor) -> ct.c_void_p:
'''
Get the ctypes pointer from a PyTorch Tensor.
@@ -284,6 +292,7 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens
The 256 quantiles in float32 datatype.
'''
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+ is_on_gpu([A, out])
if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
@@ -337,6 +346,7 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N
if A.device.type != 'cpu':
+ is_on_gpu([code, A, absmax, out, rand])
if rand is not None:
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
@@ -401,6 +411,7 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]')
if A.device.type != 'cpu':
+ is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16:
@@ -458,6 +469,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
Quantized 8-bit tensor.
'''
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
+ is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
@@ -483,6 +495,7 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
32-bit output tensor.
'''
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
+ is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
@@ -662,6 +675,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
The current optimiation steps (number of past gradient norms).
"""
+ is_on_gpu([grad, gnorm_vec])
if grad.dtype == torch.float32:
lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
elif grad.dtype == torch.float16:
@@ -694,6 +708,7 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor,
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
+ is_on_gpu([histogram, index1, index2d, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
@@ -820,6 +835,7 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# B^T @ A^T = C^T
# [km, nk -> mn]
+ is_on_gpu([B, A, out])
lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
return out
@@ -892,6 +908,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ptr = CUBLAS_Context.get_instance().get_context(A.device)
+ is_on_gpu([B, A, out])
lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
@@ -954,6 +971,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error = 0
ptrRowScale = get_ptr(None)
+ is_on_gpu([A, B, out])
if formatB == 'col_turing':
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
@@ -994,6 +1012,7 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
numRows = ct.c_int32(out_shape[0])
numCols = ct.c_int32(out_shape[1])
+ is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
return out
@@ -1024,6 +1043,7 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr
cols = ct.c_int32(cols)
prev_device = pre_call(A.device)
+ is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
post_call(prev_device)
@@ -1133,6 +1153,7 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
ptrOutCol = get_ptr(out_col)
ptrOutRow = get_ptr(out_row)
+ is_on_gpu([A, col_stats, row_stats, out_col, out_row])
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
@@ -1185,6 +1206,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
+ is_on_gpu([A, out])
if to_order == 'col32':
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
@@ -1240,6 +1262,7 @@ def spmm_coo(cooA, B, out=None):
cldb = ct.c_int32(ldb)
cldc = ct.c_int32(ldc)
+ is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
return out
@@ -1285,6 +1308,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
#print(cooA.rowidx[:64])
#print(cooA.colidx[:64].sort()[0])
+ is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
if B.dtype == torch.float16:
lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
elif B.dtype == torch.int8: