From cbb901ac51bd6c41e4243ffb936ef0e2f7ca8ada Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 26 Jul 2022 12:12:38 -0700 Subject: Boilerplate and test for extract_outliers. --- bitsandbytes/functional.py | 26 ++++++++++++++++++++++++++ csrc/kernels.cu | 7 +++++++ csrc/kernels.cuh | 2 ++ csrc/ops.cu | 27 +++++++++++++++++++++++++++ csrc/ops.cuh | 2 ++ csrc/pythonInterface.c | 6 ++++++ tests/test_functional.py | 17 +++++++++++++++++ 7 files changed, 87 insertions(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 806c254..a9233e2 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1409,3 +1409,29 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): x *= SA[1]/127 x +=offset return x.to(dtype) + +def extract_outliers(A, SA, idx): + shapeA = SA[0] + formatA = SA[1] + assert formatA in ['col_turing', 'col_ampere'] + assert A.device.type == 'cuda' + + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + + idx_size = ct.c_int32(idx.numel()) + rows = ct.c_int32(shapeA[0]) + cols = ct.c_int32(shapeA[1]) + ptrA = get_ptr(A) + ptrIdx = get_ptr(idx) + ptrOut = get_ptr(out) + + if formatA == 'col_turing': + lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + elif formatA == 'col_ampere': + lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + + return out + + + + diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 1c3e723..78170d0 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2592,10 +2592,17 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o } } +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA) +{ +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index cbfbeba..ec2068e 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -118,6 +118,8 @@ template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 8946015..fe2d7fe 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -578,10 +578,37 @@ template void spmm_coo_very_sparse_naive(int *max_count, CUDA_CHECK_RETURN(cudaPeekAtLastError()); } + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols) +{ + int threads = 256; + // we load 128 column values per warp + int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32); + int tiledRows = 0; + + int elements = idx_size*cols; // matrix A is transposed, so we extract columns + int num_blocks = (elements+threads-1)/threads; + + if(FORMAT == COL_TURING) + { + tiledRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + tiledRows = fill_up_to_nearest_multiple(rows, 32); + } + + kExtractOutliers<<>>(A, idx, out, rows, cols, tiledRows, tiledCols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 4e719df..4b09ecf 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -174,4 +174,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + #endif diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 03c8d92..2ecbaae 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -106,6 +106,9 @@ void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRo void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } +void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } + int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } @@ -280,6 +283,9 @@ extern "C" void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } + void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); } + void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 6cbe58f..b508367 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1856,3 +1856,20 @@ def test_zp(): print(err1, err2, err3, err4, err5, err6) + +def test_extract_outliers(): + shapeA = (128, 128) + idx = torch.randint(0, shapeA[1], size=(10,)).int() + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + outliers1 = A[:, idx.long()] + + CA, SA = F.transform(A, 'col_turing') + + outliers2 = F.extract_outliers(CA, SA, idx) + + assert outliers2.shape[0] == shapeA[0] + assert outliers2.shape[1] == idx.numel() + + + + torch.testing.assert_allclose(outliers1, outliers2) -- cgit v1.2.3 From bcab99ec877ba063543bd7c03ba1cdd1b06e8078 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 26 Jul 2022 17:39:30 -0700 Subject: Working outlier extraction for Turing. --- csrc/kernels.cu | 61 +++++++++++++++++++++++++++++++++++++++++++++--- csrc/kernels.cuh | 2 +- csrc/ops.cu | 5 ++-- tests/test_functional.py | 23 ++++++++++-------- 4 files changed, 74 insertions(+), 17 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 78170d0..bb36d9b 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2592,16 +2592,71 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o } } -template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA) +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA) { + int local_colidx = idx[blockIdx.x]; + + if(FORMAT==COL_TURING) + { + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // columns are grouped in increments of 4, meaning that one has the following rows and columns + // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] + + // each thread reads 1 element = 1 row + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + int offset_per_col_tile = ((rowsA+7)/8)*32*8; + int tile_offset_rows = (row/8)*32*8; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int offset = 0; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 8; + if(row % 2 == 1) + offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2); + else + // even + offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2); + + offset += tile_offset_rows + tile_offset_cols; + + + char val = 0; + //printf("(%i (%i %i) (%i %i))\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx); + if(offset > tiledColsA*tiledRowsA) + printf("(%i (%i %i) (%i %i)\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx); + else + val = A[offset]; + + int out_idx = (row*idx_size) + blockIdx.x; + + //if(out_idx > colsA*idx_size) + if(val != 0) + { + //printf("(%i %i) = (%i) = %i\n", row, local_colidx, out_idx, (int) val); + out[out_idx] = val; + } + else + { + out[out_idx] = val; + } + } + + } } //============================================================== // TEMPLATE DEFINITIONS //============================================================== -template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); -template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index ec2068e..eda5ba0 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -118,7 +118,7 @@ template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); -template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index fe2d7fe..e6227ae 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -586,8 +586,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32); int tiledRows = 0; - int elements = idx_size*cols; // matrix A is transposed, so we extract columns - int num_blocks = (elements+threads-1)/threads; + int num_blocks = idx_size; if(FORMAT == COL_TURING) { @@ -598,7 +597,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id tiledRows = fill_up_to_nearest_multiple(rows, 32); } - kExtractOutliers<<>>(A, idx, out, rows, cols, tiledRows, tiledCols); + kExtractOutliers<<>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } diff --git a/tests/test_functional.py b/tests/test_functional.py index b508367..4d06447 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1858,18 +1858,21 @@ def test_zp(): def test_extract_outliers(): - shapeA = (128, 128) - idx = torch.randint(0, shapeA[1], size=(10,)).int() - A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) - outliers1 = A[:, idx.long()] - - CA, SA = F.transform(A, 'col_turing') + for i in range(k): + shapeA = (4096, 4*4096) + idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() + #idx = torch.Tensor([32]).int().cuda() + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + outliers1 = A[:, idx.long()] - outliers2 = F.extract_outliers(CA, SA, idx) + CA, SA = F.transform(A, 'col_turing') - assert outliers2.shape[0] == shapeA[0] - assert outliers2.shape[1] == idx.numel() + outliers2 = F.extract_outliers(CA, SA, idx) + assert outliers2.shape[0] == shapeA[0] + assert outliers2.shape[1] == idx.numel() + #print(outliers1) + #print(outliers2) - torch.testing.assert_allclose(outliers1, outliers2) + torch.testing.assert_allclose(outliers1, outliers2) -- cgit v1.2.3 From 32fa459ed7c812c79e847145004061f21b7ac0d9 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 26 Jul 2022 18:15:51 -0700 Subject: Added col_ampere outlier extraction kernel. --- csrc/kernels.cu | 40 ++++++++++++++++++++++------------------ tests/test_functional.py | 14 ++++++++++---- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index bb36d9b..79ad5de 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2626,28 +2626,32 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * offset += tile_offset_rows + tile_offset_cols; - - char val = 0; - //printf("(%i (%i %i) (%i %i))\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx); - if(offset > tiledColsA*tiledRowsA) - printf("(%i (%i %i) (%i %i)\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx); - else - val = A[offset]; + char val = A[offset]; int out_idx = (row*idx_size) + blockIdx.x; - - //if(out_idx > colsA*idx_size) - if(val != 0) - { - //printf("(%i %i) = (%i) = %i\n", row, local_colidx, out_idx, (int) val); - out[out_idx] = val; - } - else - { - out[out_idx] = val; - } + out[out_idx] = val; } + } + else if(FORMAT == COL_AMPERE) + { + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element + // within each tile. + int offset_per_col_tile = ((rowsA+31)/32)*32*32; + int tile_offset_rows = (row/32)*32*32; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 32; + // this magic is taken from the cublasLt doc (search for COL32) + int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; + offset += tile_offset_cols + tile_offset_rows; + + char val = A[offset]; + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } } } diff --git a/tests/test_functional.py b/tests/test_functional.py index 4d06447..2d58fac 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1859,9 +1859,9 @@ def test_zp(): def test_extract_outliers(): for i in range(k): - shapeA = (4096, 4*4096) + shapeA = (4096, 4096*4) idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() - #idx = torch.Tensor([32]).int().cuda() + #idx = torch.Tensor([0]).int().cuda() A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) outliers1 = A[:, idx.long()] @@ -1872,7 +1872,13 @@ def test_extract_outliers(): assert outliers2.shape[0] == shapeA[0] assert outliers2.shape[1] == idx.numel() - #print(outliers1) - #print(outliers2) + torch.testing.assert_allclose(outliers1, outliers2) + + CA, SA = F.transform(A, 'col_ampere') + + outliers2 = F.extract_outliers(CA, SA, idx) + + assert outliers2.shape[0] == shapeA[0] + assert outliers2.shape[1] == idx.numel() torch.testing.assert_allclose(outliers1, outliers2) -- cgit v1.2.3 From 47a73d94c3d3284f6073b0ff189ed5bc9e3a8762 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 26 Jul 2022 19:15:35 -0700 Subject: Matmullt with direct outlier extraction for 8-bit inference. --- bitsandbytes/autograd/_functions.py | 52 ++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 815a4f1..5503749 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -191,24 +191,24 @@ class MatMul8bitLt(torch.autograd.Function): # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) - if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: - # generate outlier index and subB - outlier_idx = torch.unique(coo_tensorA.colidx).long() - state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # do not use pool for 2nd FFN layer - state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - else: - state.idx = outlier_idx - state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half() - - if state.idx is not None: - # extract outliers - CA[:, state.idx] = 0 - CAt[:, state.idx] = 0 - subA = A[:, state.idx] - else: - subA = None + #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: + # # generate outlier index and subB + # outlier_idx = torch.unique(coo_tensorA.colidx).long() + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + # else: + # state.idx = outlier_idx + # state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half() + + #if state.idx is not None: + # # extract outliers + # CA[:, state.idx] = 0 + # CAt[:, state.idx] = 0 + # subA = A[:, state.idx] + #else: + # subA = None else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) @@ -229,6 +229,22 @@ class MatMul8bitLt(torch.autograd.Function): else: has_grad = False + if coo_tensorA is not None and not state.has_fp16_weights: + # extract outliers + + outlier_idx = torch.unique(coo_tensorA.colidx) + state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # do not use pool for 2nd FFN layer + state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + else: + state.idx = outlier_idx + outliers = F.extract_outliers(state.CxB, state.SB, outlier_idx).half() + state.subB = (outliers*state.SCB.view(-1, 1).half()/127.0).t().contiguous() + CA[:, state.idx.long()] = 0 + CAt[:, state.idx.long()] = 0 + subA = A[:, state.idx.long()] + shapeB = state.SB[0] if len(input_shape) == 3: -- cgit v1.2.3 From a40921365639b7e4c292ca344e6109a7ccd7cc63 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 26 Jul 2022 19:38:17 -0700 Subject: Fixed make default to compile with cublaslt. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 10f267a..195009f 100644 --- a/Makefile +++ b/Makefile @@ -51,7 +51,7 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT + $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) -- cgit v1.2.3 From 389f66ca5a737eb7f22f22fed420274ff622623e Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 27 Jul 2022 01:46:35 -0700 Subject: Fixed direct extraction masking. --- bitsandbytes/autograd/_functions.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 5503749..e641583 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -191,6 +191,7 @@ class MatMul8bitLt(torch.autograd.Function): # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: # # generate outlier index and subB # outlier_idx = torch.unique(coo_tensorA.colidx).long() @@ -214,7 +215,6 @@ class MatMul8bitLt(torch.autograd.Function): state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None - C32A, SA = F.transform(CA, 'col32') # 2. Quantize B if state.has_fp16_weights: @@ -233,14 +233,15 @@ class MatMul8bitLt(torch.autograd.Function): # extract outliers outlier_idx = torch.unique(coo_tensorA.colidx) - state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # do not use pool for 2nd FFN layer - state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - else: - state.idx = outlier_idx - outliers = F.extract_outliers(state.CxB, state.SB, outlier_idx).half() - state.subB = (outliers*state.SCB.view(-1, 1).half()/127.0).t().contiguous() + state.idx = outlier_idx + #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + #else: + # state.idx = outlier_idx + outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half() CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] @@ -253,11 +254,12 @@ class MatMul8bitLt(torch.autograd.Function): output_shape = (input_shape[0], shapeB[0]) # 3. Matmul + C32A, SA = F.transform(CA, 'col32') out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) output = F.mm_dequant(out32, Sout32, SCA, state.SCB) # 4. Mixed-precision decomposition matmul - if state.threshold > 0.0 and coo_tensorA is not None and subA is not None: + if coo_tensorA is not None and subA is not None: output += torch.matmul(subA, state.subB) # 5. Save state -- cgit v1.2.3 From bd515328d70f344f935075f359c5aefc616878d5 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 27 Jul 2022 05:57:50 -0700 Subject: Fixed deployment script to check for LD_LIBRARY_PATH. --- deploy_from_slurm.sh | 66 +++++++++++++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/deploy_from_slurm.sh b/deploy_from_slurm.sh index 664d40e..37311bc 100644 --- a/deploy_from_slurm.sh +++ b/deploy_from_slurm.sh @@ -1,28 +1,37 @@ #!/bin/bash BASE_PATH=$1 +echo "MAKE SURE LD_LIBRARY_PATH IS EMPTY!" +echo $LD_LIBRARY_PATH + +if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + + module unload cuda module unload gcc -#rm -rf dist build -#make clean -#make cleaneggs -#export CUDA_HOME= -#make cpuonly -# -#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then -# # Control will enter here if $DIRECTORY doesn't exist. -# echo "Compilation unsuccessul!" 1>&2 -# exit 64 -#fi -#CUDA_VERSION=cpu python -m build -#python -m twine upload dist/* --verbose --repository testpypi +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME= +make cpuonly + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +CUDA_VERSION=cpu python -m build +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean make cleaneggs export CUDA_HOME=$BASE_PATH/cuda-11.0 -make cuda110 +make cuda110 if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. @@ -102,20 +111,20 @@ fi CUDA_VERSION=115 python -m build python -m twine upload dist/* --verbose --repository testpypi -#rm -rf dist build -#make clean -#make cleaneggs -#export CUDA_HOME=$BASE_PATH/cuda-11.6 -# -#make cuda11x -#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then -# # Control will enter here if $DIRECTORY doesn't exist. -# echo "Compilation unsuccessul!" 1>&2 -# exit 64 -#fi -#CUDA_VERSION=116 python -m build -#python -m twine upload dist/* --verbose --repository testpypi -# +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME=$BASE_PATH/cuda-11.6 + +make cuda11x +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +CUDA_VERSION=116 python -m build +python -m twine upload dist/* --verbose --repository testpypi + rm -rf dist build make clean make cleaneggs @@ -257,5 +266,4 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=117-nomatmul python -m build -python -m twine upload dist/* --verbose python -m twine upload dist/* --verbose --repository testpypi -- cgit v1.2.3