summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/functional.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index b4409e4..23e5464 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -218,7 +218,7 @@ def get_ptr(A: Tensor) -> ct.c_void_p:
if A is None:
return None
else:
- return ct.c_void_p(A.data.storage().data_ptr())
+ return ct.c_void_p(A.data.data_ptr())
def pre_call(device):
@@ -1407,8 +1407,10 @@ def mm_dequant(
out=None,
new_row_stats=None,
new_col_stats=None,
+ bias=None
):
assert A.dtype == torch.int32
+ if bias is not None: assert bias.dtype == torch.float16
out_shape = quant_state[0]
if len(out_shape) == 3:
out_shape = (out_shape[0] * out_shape[1], out_shape[2])
@@ -1430,17 +1432,20 @@ def mm_dequant(
new_col_stats.shape[0] == col_stats.shape[0]
), f"{new_col_stats.shape} vs {col_stats.shape}"
+ prev_device = pre_call(A.device)
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
ptrRowStats = get_ptr(row_stats)
ptrColStats = get_ptr(col_stats)
ptrNewRowStats = get_ptr(new_row_stats)
ptrNewColStats = get_ptr(new_col_stats)
+ ptrBias = get_ptr(bias)
numRows = ct.c_int32(out_shape[0])
numCols = ct.c_int32(out_shape[1])
- is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
- lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
+ is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
+ lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols)
+ post_call(prev_device)
return out