From 2f01865a2ff4ad3345c156f7a2f76fe79ec4ed9a Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 3 Aug 2022 09:05:37 -0700 Subject: Added CUDA block assert and is_on_gpu check. --- Makefile | 2 +- bitsandbytes/functional.py | 24 +++++++++++ csrc/ops.cu | 99 +++++++++++++++++++++++++++------------------- 3 files changed, 83 insertions(+), 42 deletions(-) diff --git a/Makefile b/Makefile index 10f267a..3e95b35 100644 --- a/Makefile +++ b/Makefile @@ -51,7 +51,7 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT + $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 0190a7e..0a2d557 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -141,6 +141,14 @@ def get_special_format_str(): elif major == 8: return 'col_ampere' else: return 'col_turing' + +def is_on_gpu(tensors): + on_gpu = True + for t in tensors: + if t is None: continue # NULL pointers are fine + on_gpu &= t.device.type == 'cuda' + return on_gpu + def get_ptr(A: Tensor) -> ct.c_void_p: ''' Get the ctypes pointer from a PyTorch Tensor. @@ -284,6 +292,7 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens The 256 quantiles in float32 datatype. ''' if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) + is_on_gpu([A, out]) if A.dtype == torch.float32: lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) elif A.dtype == torch.float16: @@ -337,6 +346,7 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N if A.device.type != 'cpu': + is_on_gpu([code, A, absmax, out, rand]) if rand is not None: assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) @@ -401,6 +411,7 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]') if A.device.type != 'cpu': + is_on_gpu([A, out]) if out.dtype == torch.float32: lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: @@ -458,6 +469,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: Quantized 8-bit tensor. ''' if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + is_on_gpu([A, out]) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out @@ -483,6 +495,7 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: 32-bit output tensor. ''' if out is None: out = torch.zeros_like(A, dtype=torch.float32) + is_on_gpu([code, A, out]) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out @@ -662,6 +675,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: The current optimiation steps (number of past gradient norms). """ + is_on_gpu([grad, gnorm_vec]) if grad.dtype == torch.float32: lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel())) elif grad.dtype == torch.float16: @@ -694,6 +708,7 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, maxdim1 = ct.c_int32(histogram.shape[0]) n = ct.c_int32(index1.numel()) + is_on_gpu([histogram, index1, index2d, source]) lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): @@ -820,6 +835,7 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed # B^T @ A^T = C^T # [km, nk -> mn] + is_on_gpu([B, A, out]) lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) return out @@ -892,6 +908,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr ptr = CUBLAS_Context.get_instance().get_context(A.device) + is_on_gpu([B, A, out]) lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) @@ -954,6 +971,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = 0 ptrRowScale = get_ptr(None) + is_on_gpu([A, B, out]) if formatB == 'col_turing': if dtype == torch.int32: has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) @@ -994,6 +1012,7 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non numRows = ct.c_int32(out_shape[0]) numCols = ct.c_int32(out_shape[1]) + is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats]) lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols) return out @@ -1024,6 +1043,7 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr cols = ct.c_int32(cols) prev_device = pre_call(A.device) + is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) post_call(prev_device) @@ -1133,6 +1153,7 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, ptrOutCol = get_ptr(out_col) ptrOutRow = get_ptr(out_row) + is_on_gpu([A, col_stats, row_stats, out_col, out_row]) if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: @@ -1185,6 +1206,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ptrA = get_ptr(A) ptrOut = get_ptr(out) + is_on_gpu([A, out]) if to_order == 'col32': if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) @@ -1240,6 +1262,7 @@ def spmm_coo(cooA, B, out=None): cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) + is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) return out @@ -1285,6 +1308,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): #print(cooA.rowidx[:64]) #print(cooA.colidx[:64].sort()[0]) + is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) if B.dtype == torch.float16: lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB) elif B.dtype == torch.int8: diff --git a/csrc/ops.cu b/csrc/ops.cu index c430d55..b3d07c6 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -19,53 +19,59 @@ using std::endl; void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) { int threads = 512; - int blocks = n/threads; - blocks = n % threads == 0 ? blocks : blocks + 1; - kHistogramScatterAdd2D<<>>(histogram, index1, index2, src, maxidx1, n); + int num_blocks = n/threads; + num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kHistogramScatterAdd2D<<>>(histogram, index1, index2, src, maxidx1, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } template void estimateQuantiles(T *A, float *code, float offset, int n) { - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float))); - kEstimateQuantiles<<>>(A, code, offset, std::numeric_limits::max(), n); + kEstimateQuantiles<<>>(A, code, offset, std::numeric_limits::max(), n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } void quantize(float *code, float *A, unsigned char *out, int n) { - int blocks = n/1024; - blocks = n % 1024 == 0 ? blocks : blocks + 1; - kQuantize<<>>(code, A, out, n); + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kQuantize<<>>(code, A, out, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } void dequantize(float *code, unsigned char *A, float *out, int n) { - int blocks = n/1024; - blocks = n % 1024 == 0 ? blocks : blocks + 1; - kDequantize<<>>(code, A, out, n); + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kDequantize<<>>(code, A, out, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) { - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) { - int blocks = n/blocksize; - blocks = n % blocksize == 0 ? blocks : blocks + 1; + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); if(blocksize == 4096) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); + kDequantizeBlockwise<<>>(code, A, absmax, out, n); else if(blocksize == 2048) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); + kDequantizeBlockwise<<>>(code, A, absmax, out, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -74,18 +80,19 @@ template void optimizer32bit(T* g, T* p, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) { - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); switch(OPTIMIZER) { case ADAM: if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: @@ -95,11 +102,11 @@ template void optimizer32bit(T* g, T* p, if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; } @@ -115,8 +122,9 @@ template void optimizerStatic8bit(T* p, T* g, float weight_decay, const float gnorm_scale, int n) { - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); } @@ -125,9 +133,9 @@ template void optimizerStatic8bit(T* p, T* g, case ADAM: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + kPreconditionOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + kOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; @@ -135,9 +143,9 @@ template void optimizerStatic8bit(T* p, T* g, case RMSPROP: case ADAGRAD: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; @@ -156,22 +164,24 @@ template void optimizerStatic8bitBlockwise(T* p, T* g float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) { - int blocks = 0; + int num_blocks = 0; switch(OPTIMIZER) { case ADAM: - blocks = n/BLOCKSIZE_2STATE; - blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1; - kOptimizerStatic8bit2StateBlockwise<<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, + num_blocks = n/BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kOptimizerStatic8bit2StateBlockwise<<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: case RMSPROP: case ADAGRAD: - blocks = n/BLOCKSIZE_1STATE; - blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; - kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, + num_blocks = n/BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; @@ -182,10 +192,11 @@ template void optimizerStatic8bitBlockwise(T* p, T* g template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) { - int blocks = n/2048; - blocks = n % 2048 == 0 ? blocks : blocks + 1; + int num_blocks = n/2048; + num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); - kPercentileClipping<<>>(g, gnorm_vec, step, n); + kPercentileClipping<<>>(g, gnorm_vec, step, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -445,6 +456,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, int num_blocks = numRows/subtile_rows; num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; num_blocks = num_blocks*(tileCols/32); + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); assert(threads <= tilesize); //cout << num_blocks << " blocks" << endl; @@ -463,6 +475,9 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS); + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + + if(nnz_threshold == 0.0) kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); else if(nnz_threshold != 0.0) @@ -480,6 +495,7 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); //cout << cols << " " << tiledCols << " " << tiledRows << endl; //cout << "num blocks " << num_blocks << endl; @@ -503,6 +519,7 @@ template void transformRowToFormat(char * A, char *o int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); int outCols = fill_up_to_nearest_multiple(cols, 32); int outRows = fill_up_to_nearest_multiple(rows, 32); if(FORMAT == COL_TURING) -- cgit v1.2.3