summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bitsandbytes/autograd/_functions.py52
-rw-r--r--bitsandbytes/functional.py26
-rw-r--r--csrc/kernels.cu78
-rw-r--r--csrc/kernels.cuh2
-rw-r--r--csrc/ops.cu26
-rw-r--r--csrc/ops.cuh2
-rw-r--r--csrc/pythonInterface.c6
-rw-r--r--tests/test_functional.py26
8 files changed, 194 insertions, 24 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 815a4f1..5503749 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -191,24 +191,24 @@ class MatMul8bitLt(torch.autograd.Function):
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
- if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
- # generate outlier index and subB
- outlier_idx = torch.unique(coo_tensorA.colidx).long()
- state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
- if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
- # do not use pool for 2nd FFN layer
- state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
- else:
- state.idx = outlier_idx
- state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
-
- if state.idx is not None:
- # extract outliers
- CA[:, state.idx] = 0
- CAt[:, state.idx] = 0
- subA = A[:, state.idx]
- else:
- subA = None
+ #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
+ # # generate outlier index and subB
+ # outlier_idx = torch.unique(coo_tensorA.colidx).long()
+ # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+ # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+ # # do not use pool for 2nd FFN layer
+ # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
+ # else:
+ # state.idx = outlier_idx
+ # state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
+
+ #if state.idx is not None:
+ # # extract outliers
+ # CA[:, state.idx] = 0
+ # CAt[:, state.idx] = 0
+ # subA = A[:, state.idx]
+ #else:
+ # subA = None
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@@ -229,6 +229,22 @@ class MatMul8bitLt(torch.autograd.Function):
else:
has_grad = False
+ if coo_tensorA is not None and not state.has_fp16_weights:
+ # extract outliers
+
+ outlier_idx = torch.unique(coo_tensorA.colidx)
+ state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+ if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+ # do not use pool for 2nd FFN layer
+ state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
+ else:
+ state.idx = outlier_idx
+ outliers = F.extract_outliers(state.CxB, state.SB, outlier_idx).half()
+ state.subB = (outliers*state.SCB.view(-1, 1).half()/127.0).t().contiguous()
+ CA[:, state.idx.long()] = 0
+ CAt[:, state.idx.long()] = 0
+ subA = A[:, state.idx.long()]
+
shapeB = state.SB[0]
if len(input_shape) == 3:
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 0190a7e..ac85f88 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1404,3 +1404,29 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
x *= SA[1]/127
x +=offset
return x.to(dtype)
+
+def extract_outliers(A, SA, idx):
+ shapeA = SA[0]
+ formatA = SA[1]
+ assert formatA in ['col_turing', 'col_ampere']
+ assert A.device.type == 'cuda'
+
+ out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
+
+ idx_size = ct.c_int32(idx.numel())
+ rows = ct.c_int32(shapeA[0])
+ cols = ct.c_int32(shapeA[1])
+ ptrA = get_ptr(A)
+ ptrIdx = get_ptr(idx)
+ ptrOut = get_ptr(out)
+
+ if formatA == 'col_turing':
+ lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
+ elif formatA == 'col_ampere':
+ lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
+
+ return out
+
+
+
+
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 6eca3aa..d4eb56c 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2591,16 +2591,82 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
}
}
+template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
+{
+ int local_colidx = idx[blockIdx.x];
+
+ if(FORMAT==COL_TURING)
+ {
+ // TURING FORMAT:
+ // 8*32 tiles with 4*4 subtiles
+ // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements)
+ // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
+ // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
+ // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
+ // index increases by 32
+ // columns are grouped in increments of 4, meaning that one has the following rows and columns
+ // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
+ // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...]
+
+ // each thread reads 1 element = 1 row
+ for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
+ {
+ int offset_per_col_tile = ((rowsA+7)/8)*32*8;
+ int tile_offset_rows = (row/8)*32*8;
+ int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
+ int offset = 0;
+ int subtile_col_idx = local_colidx%32;
+ int subtile_row_idx = row % 8;
+ if(row % 2 == 1)
+ offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2);
+ else
+ // even
+ offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2);
+
+ offset += tile_offset_rows + tile_offset_cols;
+
+ char val = A[offset];
+
+ int out_idx = (row*idx_size) + blockIdx.x;
+ out[out_idx] = val;
+ }
+ }
+ else if(FORMAT == COL_AMPERE)
+ {
+
+ for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
+ {
+ // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element
+ // within each tile.
+ int offset_per_col_tile = ((rowsA+31)/32)*32*32;
+ int tile_offset_rows = (row/32)*32*32;
+ int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
+ int subtile_col_idx = local_colidx%32;
+ int subtile_row_idx = row % 32;
+ // this magic is taken from the cublasLt doc (search for COL32)
+ int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
+ offset += tile_offset_cols + tile_offset_rows;
+
+ char val = A[offset];
+ int out_idx = (row*idx_size) + blockIdx.x;
+ out[out_idx] = val;
+ }
+ }
+}
+
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
-template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+
+template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 4e65e96..2447494 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -118,6 +118,8 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+
#endif
diff --git a/csrc/ops.cu b/csrc/ops.cu
index c430d55..952894c 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -598,10 +598,36 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
+
+template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
+{
+ int threads = 256;
+ // we load 128 column values per warp
+ int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32);
+ int tiledRows = 0;
+
+ int num_blocks = idx_size;
+
+ if(FORMAT == COL_TURING)
+ {
+ tiledRows = fill_up_to_nearest_multiple(rows, 8);
+ }
+ else if(FORMAT == COL_AMPERE)
+ {
+ tiledRows = fill_up_to_nearest_multiple(rows, 32);
+ }
+
+ kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
+template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
+template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
+
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 4e719df..4b09ecf 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -174,4 +174,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
+
#endif
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index a6a4b13..7356c11 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -105,6 +105,9 @@ void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRo
void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
+void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
+void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
+
int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
@@ -280,6 +283,9 @@ extern "C"
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+ void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
+ void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
+
#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
diff --git a/tests/test_functional.py b/tests/test_functional.py
index d80a4f9..bfc3e28 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1859,3 +1859,29 @@ def test_zp():
print(err1, err2, err3, err4, err5, err6)
+
+def test_extract_outliers():
+ for i in range(k):
+ shapeA = (4096, 4096*4)
+ idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
+ #idx = torch.Tensor([0]).int().cuda()
+ A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
+ outliers1 = A[:, idx.long()]
+
+ CA, SA = F.transform(A, 'col_turing')
+
+ outliers2 = F.extract_outliers(CA, SA, idx)
+
+ assert outliers2.shape[0] == shapeA[0]
+ assert outliers2.shape[1] == idx.numel()
+
+ torch.testing.assert_allclose(outliers1, outliers2)
+
+ CA, SA = F.transform(A, 'col_ampere')
+
+ outliers2 = F.extract_outliers(CA, SA, idx)
+
+ assert outliers2.shape[0] == shapeA[0]
+ assert outliers2.shape[1] == idx.numel()
+
+ torch.testing.assert_allclose(outliers1, outliers2)