From 1ed2fa2f218d8dac401f3315420ffec92014c124 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 16 Aug 2022 10:56:17 -0700 Subject: Removed storage() from get_ptr; added boilerplate for bias dequant_mm. --- bitsandbytes/functional.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b4409e4..23e5464 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -218,7 +218,7 @@ def get_ptr(A: Tensor) -> ct.c_void_p: if A is None: return None else: - return ct.c_void_p(A.data.storage().data_ptr()) + return ct.c_void_p(A.data.data_ptr()) def pre_call(device): @@ -1407,8 +1407,10 @@ def mm_dequant( out=None, new_row_stats=None, new_col_stats=None, + bias=None ): assert A.dtype == torch.int32 + if bias is not None: assert bias.dtype == torch.float16 out_shape = quant_state[0] if len(out_shape) == 3: out_shape = (out_shape[0] * out_shape[1], out_shape[2]) @@ -1430,17 +1432,20 @@ def mm_dequant( new_col_stats.shape[0] == col_stats.shape[0] ), f"{new_col_stats.shape} vs {col_stats.shape}" + prev_device = pre_call(A.device) ptrA = get_ptr(A) ptrOut = get_ptr(out) ptrRowStats = get_ptr(row_stats) ptrColStats = get_ptr(col_stats) ptrNewRowStats = get_ptr(new_row_stats) ptrNewColStats = get_ptr(new_col_stats) + ptrBias = get_ptr(bias) numRows = ct.c_int32(out_shape[0]) numCols = ct.c_int32(out_shape[1]) - is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols) + is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) + lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) + post_call(prev_device) return out -- cgit v1.2.3