diff options
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 17 | ||||
-rw-r--r-- | bitsandbytes/functional.py | 217 | ||||
-rw-r--r-- | csrc/ops.cu | 126 | ||||
-rw-r--r-- | tests/test_autograd.py | 17 |
5 files changed, 180 insertions, 199 deletions
@@ -58,7 +58,7 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index b56b2ee..14f2660 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,7 +1,7 @@ from dataclasses import dataclass import torch - +import math import bitsandbytes as bnb import bitsandbytes.functional as F @@ -199,6 +199,17 @@ class MatmulLtState: class MatMul8bitLt(torch.autograd.Function): @staticmethod def forward(ctx, A, B, out=None, state=MatmulLtState()): + # default to pytorch behavior if inputs are empty + ctx.is_empty = False + if math.prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + if A.shape[-1] == B.shape[0]: + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device) + else: + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device) + # 1. Quantize A # 2. Quantize B # 3. Matmul @@ -339,6 +350,8 @@ class MatMul8bitLt(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None req_gradA, req_gradB = ctx.req_grads CAt, subA = ctx.tensors SCAt, idx = ctx.tensor_states @@ -375,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function): ctx.grad_shape ) - return grad_A, grad_B, None, None, None, None, None + return grad_A, grad_B, None, None matmul = MatMul8bitLt.apply diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 236ef39..b4409e4 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -4,9 +4,10 @@ # LICENSE file in the root directory of this source tree. import ctypes as ct import random -from typing import Tuple - +import math import torch + +from typing import Tuple from torch import Tensor from .cextension import COMPILED_WITH_CUDA, lib @@ -193,6 +194,14 @@ def get_special_format_str(): return "col_turing" + +def is_on_gpu(tensors): + on_gpu = True + for t in tensors: + if t is None: continue # NULL pointers are fine + on_gpu &= t.device.type == 'cuda' + return on_gpu + def get_ptr(A: Tensor) -> ct.c_void_p: """ Get the ctypes pointer from a PyTorch Tensor. @@ -336,7 +345,7 @@ def nvidia_transform( def estimate_quantiles( A: Tensor, out: Tensor = None, offset: float = 1 / 512 ) -> Tensor: - """ + ''' Estimates 256 equidistant quantiles on the input tensor eCDF. Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles @@ -361,9 +370,9 @@ def estimate_quantiles( ------- torch.Tensor: The 256 quantiles in float32 datatype. - """ - if out is None: - out = torch.zeros((256,), dtype=torch.float32, device=A.device) + ''' + if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) + is_on_gpu([A, out]) if A.dtype == torch.float32: lib.cestimate_quantiles_fp32( get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) @@ -428,7 +437,8 @@ def quantize_blockwise( if out is None: out = torch.zeros_like(A, dtype=torch.uint8) - if A.device.type != "cpu": + if A.device.type != 'cpu': + is_on_gpu([code, A, absmax, out, rand]) if rand is not None: assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) @@ -541,7 +551,8 @@ def dequantize_blockwise( f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]" ) - if A.device.type != "cpu": + if A.device.type != 'cpu': + is_on_gpu([A, out]) if out.dtype == torch.float32: lib.cdequantize_blockwise_fp32( get_ptr(quant_state[1]), @@ -610,7 +621,7 @@ def dequantize( def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: - """ + ''' Quantizes input tensor to 8-bit. Quantizes the 32-bit input tensor `A` to the 8-bit output tensor @@ -629,15 +640,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ------- torch.Tensor: Quantized 8-bit tensor. - """ - if out is None: - out = torch.zeros_like(A, dtype=torch.uint8) + ''' + if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + is_on_gpu([A, out]) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: - """ + ''' Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via @@ -656,12 +667,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ------- torch.Tensor: 32-bit output tensor. - """ - if out is None: - out = torch.zeros_like(A, dtype=torch.float32) - lib.cdequantize( - get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()) - ) + ''' + if out is None: out = torch.zeros_like(A, dtype=torch.float32) + is_on_gpu([code, A, out]) + lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out @@ -983,6 +992,7 @@ def percentile_clipping( The current optimiation steps (number of past gradient norms). """ + is_on_gpu([grad, gnorm_vec]) if grad.dtype == torch.float32: lib.cpercentile_clipping_g32( get_ptr(grad), @@ -1027,21 +1037,11 @@ def histogram_scatter_add_2d( maxdim1 = ct.c_int32(histogram.shape[0]) n = ct.c_int32(index1.numel()) - lib.chistogram_scatter_add_2d( - get_ptr(histogram), - get_ptr(index1), - get_ptr(index2), - get_ptr(source), - maxdim1, - n, - ) + is_on_gpu([histogram, index1, index2d, source]) + lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) - -def check_matmul( - A, B, out, transposed_A, transposed_B, expected_type=torch.int8 -): - if not torch.cuda.is_initialized(): - torch.cuda.init() +def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): + if not torch.cuda.is_initialized(): torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: raise TypeError( f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" @@ -1212,21 +1212,10 @@ def igemm( ptr = CUBLAS_Context.get_instance().get_context(A.device) # B^T @ A^T = C^T - # [km, nk -> mn] - lib.cigemm( - ptr, - ct.c_bool(transposed_B), - ct.c_bool(transposed_A), - ct.c_int32(m), - ct.c_int32(n), - ct.c_int32(k), - get_ptr(B), - get_ptr(A), - get_ptr(out), - ct.c_int32(lda), - ct.c_int32(ldb), - ct.c_int32(ldc), - ) + # [km, nk -> mn] + is_on_gpu([B, A, out]) + lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), + get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) return out @@ -1306,24 +1295,10 @@ def batched_igemm( ptr = CUBLAS_Context.get_instance().get_context(A.device) - lib.cbatched_igemm( - ptr, - ct.c_bool(transposed_B), - ct.c_bool(transposed_A), - ct.c_int32(m), - ct.c_int32(n), - ct.c_int32(k), - get_ptr(B), - get_ptr(A), - get_ptr(out), - ct.c_int32(lda), - ct.c_int32(ldb), - ct.c_int32(ldc), - ct.c_long(strideA), - ct.c_long(strideB), - ct.c_long(strideC), - ct.c_uint32(num_batch), - ) + is_on_gpu([B, A, out]) + lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), + get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), + ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) return out @@ -1332,15 +1307,20 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) + assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0] * shapeA[1] - if dimsB == 2: - rows = n = shapeB[0] - elif dimsB == 3: - rows = n = shapeB[0] * shapeB[1] + rows = n = shapeB[0] + assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + + # if the tensor is empty, return a transformed empty tensor with the right dimensions + if shapeA[0] == 0 and dimsA == 2: + return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) + elif shapeA[1] == 0 and dimsA == 3: + return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) if dimsA == 2 and out is None: out, Sout = get_transform_buffer( @@ -1390,7 +1370,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = 0 ptrRowScale = get_ptr(None) - if formatB == "col_turing": + is_on_gpu([A, B, out]) + if formatB == 'col_turing': if dtype == torch.int32: has_error = lib.cigemmlt_turing_32( ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc @@ -1410,7 +1391,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): ) if has_error == 1: - raise Exception("cublasLt ran into an error!") + print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') + raise Exception('cublasLt ran into an error!') torch.cuda.set_device(prev_device) @@ -1457,16 +1439,8 @@ def mm_dequant( numRows = ct.c_int32(out_shape[0]) numCols = ct.c_int32(out_shape[1]) - lib.cdequant_mm_int32_fp16( - ptrA, - ptrRowStats, - ptrColStats, - ptrOut, - ptrNewRowStats, - ptrNewColStats, - numRows, - numCols, - ) + is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats]) + lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols) return out @@ -1507,15 +1481,8 @@ def get_colrow_absmax( cols = ct.c_int32(cols) prev_device = pre_call(A.device) - lib.cget_col_row_stats( - ptrA, - ptrRowStats, - ptrColStats, - ptrNnzrows, - ct.c_float(threshold), - rows, - cols, - ) + is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) + lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) post_call(prev_device) if threshold > 0.0: @@ -1642,6 +1609,7 @@ def double_quant( ptrOutCol = get_ptr(out_col) ptrOutRow = get_ptr(out_row) + is_on_gpu([A, col_stats, row_stats, out_col, out_row]) if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: @@ -1714,33 +1682,19 @@ def get_special_format_str(): ) assert major >= 7 - if major == 7: - return "col_turing" - elif major == 8: - return "col_ampere" - else: - return "col_turing" + if major == 7: return 'col_turing' + elif major == 8: return 'col_ampere' + else: return 'col_turing' -def transform( - A, - to_order, - from_order="row", - out=None, - transpose=False, - state=None, - ld=None, -): - if state is None: - state = (A.shape, from_order) - else: - from_order = state[1] - if out is None: - out, new_state = get_transform_buffer( - state[0], A.dtype, A.device, to_order, state[1], transpose - ) - else: - new_state = (state[0], to_order) # (shape, order) + + +def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + prev_device = pre_call(A.device) + if state is None: state = (A.shape, from_order) + else: from_order = state[1] + if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: @@ -1752,7 +1706,8 @@ def transform( ptrA = get_ptr(A) ptrOut = get_ptr(out) - if to_order == "col32": + is_on_gpu([A, out]) + if to_order == 'col32': if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: @@ -1773,9 +1728,9 @@ def transform( elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: - raise NotImplementedError( - f"Transform function not implemented: From {from_order} to {to_order}" - ) + raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') + + post_call(prev_device) return out, new_state @@ -1810,21 +1765,8 @@ def spmm_coo(cooA, B, out=None): cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) - lib.cspmm_coo( - ptr, - ptrRowidx, - ptrColidx, - ptrValues, - cnnz, - crowsA, - ccolsA, - ccolsB, - cldb, - ptrB, - cldc, - ptrC, - ct.c_bool(transposed_B), - ) + is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) + lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) return out @@ -1875,6 +1817,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): # print(cooA.rowidx[:64]) # print(cooA.colidx[:64].sort()[0]) + is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) if B.dtype == torch.float16: lib.cspmm_coo_very_sparse_naive_fp16( ptrMaxCount, @@ -2061,9 +2004,11 @@ def extract_outliers(A, SA, idx): ptrIdx = get_ptr(idx) ptrOut = get_ptr(out) - if formatA == "col_turing": + prev_device = pre_call(A.device) + if formatA == 'col_turing': lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + post_call(prev_device) return out diff --git a/csrc/ops.cu b/csrc/ops.cu index 952894c..9b75e69 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -19,53 +19,59 @@ using std::endl; void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) { int threads = 512; - int blocks = n/threads; - blocks = n % threads == 0 ? blocks : blocks + 1; - kHistogramScatterAdd2D<<<blocks, 512>>>(histogram, index1, index2, src, maxidx1, n); + int num_blocks = n/threads; + num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n) { - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float))); - kEstimateQuantiles<T><<<blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n); + kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } void quantize(float *code, float *A, unsigned char *out, int n) { - int blocks = n/1024; - blocks = n % 1024 == 0 ? blocks : blocks + 1; - kQuantize<<<blocks, 1024>>>(code, A, out, n); + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kQuantize<<<num_blocks, 1024>>>(code, A, out, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } void dequantize(float *code, unsigned char *A, float *out, int n) { - int blocks = n/1024; - blocks = n % 1024 == 0 ? blocks : blocks + 1; - kDequantize<<<blocks, 1024>>>(code, A, out, n); + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kDequantize<<<num_blocks, 1024>>>(code, A, out, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) { - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; - kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n); + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) { - int blocks = n/blocksize; - blocks = n % blocksize == 0 ? blocks : blocks + 1; + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); if(blocksize == 4096) - kDequantizeBlockwise<T, 4096, 1024, 4><<<blocks, 4096/4>>>(code, A, absmax, out, n); + kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n); else if(blocksize == 2048) - kDequantizeBlockwise<T, 2048, 512, 4><<<blocks, 2048/4>>>(code, A, absmax, out, n); + kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -74,18 +80,19 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) { - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); switch(OPTIMIZER) { case ADAM: if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: @@ -95,11 +102,11 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); + kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; } @@ -115,8 +122,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, float weight_decay, const float gnorm_scale, int n) { - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); } @@ -125,9 +133,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, case ADAM: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; @@ -135,9 +143,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, case RMSPROP: case ADAGRAD: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, + kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; @@ -156,22 +164,24 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) { - int blocks = 0; + int num_blocks = 0; switch(OPTIMIZER) { case ADAM: - blocks = n/BLOCKSIZE_2STATE; - blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1; - kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr, + num_blocks = n/BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: case RMSPROP: case ADAGRAD: - blocks = n/BLOCKSIZE_1STATE; - blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; - kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr, + num_blocks = n/BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; @@ -182,10 +192,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n) { - int blocks = n/2048; - blocks = n % 2048 == 0 ? blocks : blocks + 1; + int num_blocks = n/2048; + num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); - kPercentileClipping<T, 2048, 4><<<blocks, 512>>>(g, gnorm_vec, step, n); + kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -445,10 +456,9 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, int num_blocks = numRows/subtile_rows; num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; num_blocks = num_blocks*(tileCols/32); + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); assert(threads <= tilesize); - //cout << num_blocks << " blocks" << endl; - kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -461,7 +471,13 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r int tile_cols = STATS_THREADS*STATS_ITEMS; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS); + int row_tiles = (tiledRows/STATS_ROWS); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); if(nnz_threshold == 0.0) kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); @@ -479,12 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col int tile_rows = 16; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; - //cout << cols << " " << tiledCols << " " << tiledRows << endl; - //cout << "num blocks " << num_blocks << endl; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); - //cout << A << " " << out_col_normed << endl; if(threshold > 0.0f) kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); else @@ -502,7 +520,13 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o int tile_rows = 32; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); int outCols = fill_up_to_nearest_multiple(cols, 32); int outRows = fill_up_to_nearest_multiple(rows, 32); if(FORMAT == COL_TURING) @@ -528,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o } } - //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl; - //cout << "num blocks " << num_blocks << endl; - - //cout << A << " " << out_col_normed << endl; kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } diff --git a/tests/test_autograd.py b/tests/test_autograd.py index fc7a0e1..8ebe8c8 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -40,7 +40,8 @@ names = [ ids=names, ) def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): - dim2 = dim2 - (dim2 % 16) + if dim2 > 0: + dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) dim4 = dim4 - (dim4 % 16) for i in range(k): @@ -234,10 +235,7 @@ dim2 = torch.randint(32, 96, size=(n,)).tolist() dim3 = torch.randint(32, 96, size=(n,)).tolist() dim4 = torch.randint(32, 96, size=(n,)).tolist() -# dim1 = (17,) -# dim2 = (7,) -# dim3 = (37,) -# dim4 = (23,) +dim2.append(0) decomp = [0.0, 6.0] funcs = [(torch.matmul, bnb.matmul)] @@ -385,9 +383,14 @@ def test_matmullt( ) if req_grad[1]: n = gradB1.numel() - assert torch.abs(gradB1).sum() > 0.0 - assert torch.abs(gradB2).sum() > 0.0 + if dim2 > 0: + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + else: + assert torch.abs(gradB1).sum() == 0.0 + assert torch.abs(gradB2).sum() == 0.0 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 |