summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile2
-rw-r--r--bitsandbytes/functional.py24
-rw-r--r--csrc/ops.cu99
3 files changed, 83 insertions, 42 deletions
diff --git a/Makefile b/Makefile
index 10f267a..3e95b35 100644
--- a/Makefile
+++ b/Makefile
@@ -51,7 +51,7 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
- $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
+ $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 0190a7e..0a2d557 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -141,6 +141,14 @@ def get_special_format_str():
elif major == 8: return 'col_ampere'
else: return 'col_turing'
+
+def is_on_gpu(tensors):
+ on_gpu = True
+ for t in tensors:
+ if t is None: continue # NULL pointers are fine
+ on_gpu &= t.device.type == 'cuda'
+ return on_gpu
+
def get_ptr(A: Tensor) -> ct.c_void_p:
'''
Get the ctypes pointer from a PyTorch Tensor.
@@ -284,6 +292,7 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens
The 256 quantiles in float32 datatype.
'''
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+ is_on_gpu([A, out])
if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
@@ -337,6 +346,7 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N
if A.device.type != 'cpu':
+ is_on_gpu([code, A, absmax, out, rand])
if rand is not None:
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
@@ -401,6 +411,7 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]')
if A.device.type != 'cpu':
+ is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16:
@@ -458,6 +469,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
Quantized 8-bit tensor.
'''
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
+ is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
@@ -483,6 +495,7 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
32-bit output tensor.
'''
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
+ is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
@@ -662,6 +675,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
The current optimiation steps (number of past gradient norms).
"""
+ is_on_gpu([grad, gnorm_vec])
if grad.dtype == torch.float32:
lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
elif grad.dtype == torch.float16:
@@ -694,6 +708,7 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor,
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
+ is_on_gpu([histogram, index1, index2d, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
@@ -820,6 +835,7 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# B^T @ A^T = C^T
# [km, nk -> mn]
+ is_on_gpu([B, A, out])
lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
return out
@@ -892,6 +908,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ptr = CUBLAS_Context.get_instance().get_context(A.device)
+ is_on_gpu([B, A, out])
lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
@@ -954,6 +971,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error = 0
ptrRowScale = get_ptr(None)
+ is_on_gpu([A, B, out])
if formatB == 'col_turing':
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
@@ -994,6 +1012,7 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
numRows = ct.c_int32(out_shape[0])
numCols = ct.c_int32(out_shape[1])
+ is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
return out
@@ -1024,6 +1043,7 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr
cols = ct.c_int32(cols)
prev_device = pre_call(A.device)
+ is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
post_call(prev_device)
@@ -1133,6 +1153,7 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
ptrOutCol = get_ptr(out_col)
ptrOutRow = get_ptr(out_row)
+ is_on_gpu([A, col_stats, row_stats, out_col, out_row])
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
@@ -1185,6 +1206,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
+ is_on_gpu([A, out])
if to_order == 'col32':
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
@@ -1240,6 +1262,7 @@ def spmm_coo(cooA, B, out=None):
cldb = ct.c_int32(ldb)
cldc = ct.c_int32(ldc)
+ is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
return out
@@ -1285,6 +1308,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
#print(cooA.rowidx[:64])
#print(cooA.colidx[:64].sort()[0])
+ is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
if B.dtype == torch.float16:
lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
elif B.dtype == torch.int8:
diff --git a/csrc/ops.cu b/csrc/ops.cu
index c430d55..b3d07c6 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -19,53 +19,59 @@ using std::endl;
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
{
int threads = 512;
- int blocks = n/threads;
- blocks = n % threads == 0 ? blocks : blocks + 1;
- kHistogramScatterAdd2D<<<blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
+ int num_blocks = n/threads;
+ num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
{
- int blocks = n/4096;
- blocks = n % 4096 == 0 ? blocks : blocks + 1;
+ int num_blocks = n/4096;
+ num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
- kEstimateQuantiles<T><<<blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
+ kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void quantize(float *code, float *A, unsigned char *out, int n)
{
- int blocks = n/1024;
- blocks = n % 1024 == 0 ? blocks : blocks + 1;
- kQuantize<<<blocks, 1024>>>(code, A, out, n);
+ int num_blocks = n/1024;
+ num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void dequantize(float *code, unsigned char *A, float *out, int n)
{
- int blocks = n/1024;
- blocks = n % 1024 == 0 ? blocks : blocks + 1;
- kDequantize<<<blocks, 1024>>>(code, A, out, n);
+ int num_blocks = n/1024;
+ num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
{
- int blocks = n/4096;
- blocks = n % 4096 == 0 ? blocks : blocks + 1;
- kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+ int num_blocks = n/4096;
+ num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
{
- int blocks = n/blocksize;
- blocks = n % blocksize == 0 ? blocks : blocks + 1;
+ int num_blocks = n/blocksize;
+ num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
if(blocksize == 4096)
- kDequantizeBlockwise<T, 4096, 1024, 4><<<blocks, 4096/4>>>(code, A, absmax, out, n);
+ kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
else if(blocksize == 2048)
- kDequantizeBlockwise<T, 2048, 512, 4><<<blocks, 2048/4>>>(code, A, absmax, out, n);
+ kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -74,18 +80,19 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
{
- int blocks = n/4096;
- blocks = n % 4096 == 0 ? blocks : blocks + 1;
+ int num_blocks = n/4096;
+ num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
switch(OPTIMIZER)
{
case ADAM:
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
- kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+ kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
- kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+ kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
@@ -95,11 +102,11 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
- kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+ kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
- kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+ kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
}
@@ -115,8 +122,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
float weight_decay,
const float gnorm_scale, int n)
{
- int blocks = n/4096;
- blocks = n % 4096 == 0 ? blocks : blocks + 1;
+ int num_blocks = n/4096;
+ num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
@@ -125,9 +133,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case ADAM:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
- kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
+ kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
- kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+ kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@@ -135,9 +143,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case RMSPROP:
case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
- kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+ kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
- kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
+ kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@@ -156,22 +164,24 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
{
- int blocks = 0;
+ int num_blocks = 0;
switch(OPTIMIZER)
{
case ADAM:
- blocks = n/BLOCKSIZE_2STATE;
- blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
- kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
+ num_blocks = n/BLOCKSIZE_2STATE;
+ num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
- blocks = n/BLOCKSIZE_1STATE;
- blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
- kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
+ num_blocks = n/BLOCKSIZE_1STATE;
+ num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@@ -182,10 +192,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
{
- int blocks = n/2048;
- blocks = n % 2048 == 0 ? blocks : blocks + 1;
+ int num_blocks = n/2048;
+ num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
- kPercentileClipping<T, 2048, 4><<<blocks, 512>>>(g, gnorm_vec, step, n);
+ kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -445,6 +456,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
int num_blocks = numRows/subtile_rows;
num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
num_blocks = num_blocks*(tileCols/32);
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
assert(threads <= tilesize);
//cout << num_blocks << " blocks" << endl;
@@ -463,6 +475,9 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+
+
if(nnz_threshold == 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
else if(nnz_threshold != 0.0)
@@ -480,6 +495,7 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
//cout << cols << " " << tiledCols << " " << tiledRows << endl;
//cout << "num blocks " << num_blocks << endl;
@@ -503,6 +519,7 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
int outCols = fill_up_to_nearest_multiple(cols, 32);
int outRows = fill_up_to_nearest_multiple(rows, 32);
if(FORMAT == COL_TURING)