From 1ed2fa2f218d8dac401f3315420ffec92014c124 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 16 Aug 2022 10:56:17 -0700 Subject: Removed storage() from get_ptr; added boilerplate for bias dequant_mm. --- bitsandbytes/functional.py | 11 ++++++++--- csrc/kernels.cu | 4 ++-- csrc/kernels.cuh | 2 +- csrc/ops.cu | 5 ++--- csrc/ops.cuh | 2 +- csrc/pythonInterface.c | 4 ++-- tests/test_functional.py | 20 +++++++++++--------- 7 files changed, 27 insertions(+), 21 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b4409e4..23e5464 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -218,7 +218,7 @@ def get_ptr(A: Tensor) -> ct.c_void_p: if A is None: return None else: - return ct.c_void_p(A.data.storage().data_ptr()) + return ct.c_void_p(A.data.data_ptr()) def pre_call(device): @@ -1407,8 +1407,10 @@ def mm_dequant( out=None, new_row_stats=None, new_col_stats=None, + bias=None ): assert A.dtype == torch.int32 + if bias is not None: assert bias.dtype == torch.float16 out_shape = quant_state[0] if len(out_shape) == 3: out_shape = (out_shape[0] * out_shape[1], out_shape[2]) @@ -1430,17 +1432,20 @@ def mm_dequant( new_col_stats.shape[0] == col_stats.shape[0] ), f"{new_col_stats.shape} vs {col_stats.shape}" + prev_device = pre_call(A.device) ptrA = get_ptr(A) ptrOut = get_ptr(out) ptrRowStats = get_ptr(row_stats) ptrColStats = get_ptr(col_stats) ptrNewRowStats = get_ptr(new_row_stats) ptrNewColStats = get_ptr(new_col_stats) + ptrBias = get_ptr(bias) numRows = ct.c_int32(out_shape[0]) numCols = ct.c_int32(out_shape[1]) - is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols) + is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) + lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) + post_call(prev_device) return out diff --git a/csrc/kernels.cu b/csrc/kernels.cu index d4eb56c..0a1bf79 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1889,7 +1889,7 @@ template __global__ void kgetColRowStats(half * __rest #define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) -template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n) +template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n) { // Strategy: To dequantize we need to load col/row statistics. This can be very expensive @@ -2675,7 +2675,7 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>( template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); -template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n); +template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 2447494..bdf61b2 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -111,7 +111,7 @@ template __global__ void kspmm_coo_very_s template __global__ void kdequant_mm_int32_fp16( int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, - half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n); + half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); diff --git a/csrc/ops.cu b/csrc/ops.cu index c16dd96..ed32828 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -435,7 +435,7 @@ int fill_up_to_nearest_multiple(int value, int multiple) return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); } -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols) +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols) { int threads = 512; int tileCols = fill_up_to_nearest_multiple(numCols, 32); @@ -447,7 +447,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, num_blocks = num_blocks*(tileCols/32); assert(threads <= tilesize); - kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n); + kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -465,7 +465,6 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r col_tiles = col_tiles > 0 ? col_tiles : 1; int num_blocks = row_tiles * col_tiles; - if(nnz_threshold == 0.0) kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); else if(nnz_threshold != 0.0) diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 4b09ecf..acfdb06 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -163,7 +163,7 @@ template int igemmlt(cublasLtHandle template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols); void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 7356c11..0707674 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -248,8 +248,8 @@ extern "C" MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) - void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols) - { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols); } + void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols) + { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); } void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); } diff --git a/tests/test_functional.py b/tests/test_functional.py index ab7d672..65bf092 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -961,20 +961,24 @@ dim4 = torch.randint(64, 1024, size=(n,)).tolist() dims = (2,) # ldb = list(range(256, 1*1024, 256)) formatB = ["col_turing", "col_ampere"] -values = list(product(dim1, dim4, dims, formatB)) +has_bias = [True, False] +values = list(product(dim1, dim4, dims, formatB, has_bias)) names = [ - "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values + "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values ] -@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names) -def test_dequant_mm(dim1, dim4, dims, formatB): +@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names) +def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): inner = torch.randint(1, 128, size=(1,)).item() + bias = None + if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16) formatB = F.get_special_format_str() for i in range(k): A = torch.randn(dim1, inner, device="cuda") B = torch.randn(dim4, inner, device="cuda") C1 = torch.matmul(A.half(), B.t().half()) + if has_bias: C1 += bias A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) @@ -985,17 +989,15 @@ def test_dequant_mm(dim1, dim4, dims, formatB): C3, S = F.nvidia_transform(C2, "row", state=SC) C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + if has_bias: C4 += bias count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item() n = C1.numel() p = 0.06 - assert ( - count / n < p - ), f"error in more than {p} of elements: {count}/{n}={count/n}" + assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" - C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten()) + C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) torch.testing.assert_allclose(C5, C4) - # print(C2) n = 2 -- cgit v1.2.3