-rw-r--r--   Makefile                              2
-rw-r--r--   bitsandbytes/autograd/_functions.py  16
-rw-r--r--   bitsandbytes/functional.py           47
-rw-r--r--   csrc/ops.cu                         126
-rw-r--r--   tests/test_autograd.py               20
5 files changed, 141 insertions, 70 deletions
diff --git a/Makefile b/Makefile
index 195009f..3e95b35 100644
--- a/Makefile
+++ b/Makefile
@@ -51,7 +51,7 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
- $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index e641583..607d868 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,4 +1,5 @@
import torch
+import math
import bitsandbytes as bnb
import bitsandbytes.functional as F
@@ -162,6 +163,17 @@ class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
+ # default to pytorch behavior if inputs are empty
+ ctx.is_empty = False
+ if math.prod(A.shape) == 0:
+ ctx.is_empty = True
+ ctx.A = A
+ ctx.B = B
+ if A.shape[-1] == B.shape[0]:
+ return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+ else:
+ return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
@@ -283,6 +295,8 @@ class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
+ if ctx.is_empty:
+ return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
req_gradA, req_gradB = ctx.req_grads
CAt, subA = ctx.tensors
SCAt, idx = ctx.tensor_states
@@ -311,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function):
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
- return grad_A, grad_B, None, None, None, None, None
+ return grad_A, grad_B, None, None
matmul = MatMul8bitLt.apply
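
The empty-input fallback added to MatMul8bitLt above reduces to a small piece of shape logic: when A has no elements, skip quantization entirely and return an empty fp16 tensor with the shape a regular matmul would produce. A standalone sketch of that logic (the function name and shapes are illustrative, not part of the library):

import torch

def empty_matmul_result(A, B):
    # Mirrors the early return added to MatMul8bitLt.forward: no quantization,
    # just an empty fp16 tensor with the expected output shape.
    if A.numel() == 0:
        if A.shape[-1] == B.shape[0]:
            out_shape = A.shape[:-1] + B.shape[1:]   # ordinary A @ B shape
        else:
            out_shape = A.shape[:-1] + B.shape[:1]   # shape-mismatch branch from the hunk above
        return torch.empty(out_shape, dtype=torch.float16, device=A.device)
    return None  # non-empty inputs go through the regular quantized path

# Example: a (0, 8) @ (8, 4) product yields an empty (0, 4) fp16 tensor.
A = torch.empty(0, 8)
B = torch.empty(8, 4)
print(empty_matmul_result(A, B).shape)  # torch.Size([0, 4])
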
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index ac85f88..08c108c 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -4,9 +4,10 @@
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import random
-from typing import Tuple
-
+import math
import torch
+
+from typing import Tuple
from torch import Tensor
from .cextension import lib, COMPILED_WITH_CUDA
@@ -141,6 +142,14 @@ def get_special_format_str():
elif major == 8: return 'col_ampere'
else: return 'col_turing'
+
+def is_on_gpu(tensors):
+ on_gpu = True
+ for t in tensors:
+ if t is None: continue # NULL pointers are fine
+ on_gpu &= t.device.type == 'cuda'
+ return on_gpu
+
def get_ptr(A: Tensor) -> ct.c_void_p:
'''
Get the ctypes pointer from a PyTorch Tensor.
@@ -284,6 +293,7 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens
The 256 quantiles in float32 datatype.
'''
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+ is_on_gpu([A, out])
if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
@@ -337,6 +347,7 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N
if A.device.type != 'cpu':
+ is_on_gpu([code, A, absmax, out, rand])
if rand is not None:
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
@@ -401,6 +412,7 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]')
if A.device.type != 'cpu':
+ is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16:
@@ -458,6 +470,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
Quantized 8-bit tensor.
'''
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
+ is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
@@ -483,6 +496,7 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
32-bit output tensor.
'''
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
+ is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
@@ -662,6 +676,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
The current optimiation steps (number of past gradient norms).
"""
+ is_on_gpu([grad, gnorm_vec])
if grad.dtype == torch.float32:
lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
elif grad.dtype == torch.float16:
@@ -694,6 +709,7 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor,
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
+ is_on_gpu([histogram, index1, index2d, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
@@ -820,6 +836,7 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# B^T @ A^T = C^T
# [km, nk -> mn]
+ is_on_gpu([B, A, out])
lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
return out
@@ -892,6 +909,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ptr = CUBLAS_Context.get_instance().get_context(A.device)
+ is_on_gpu([B, A, out])
lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
@@ -902,15 +920,20 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeB = SB[0]
dimsA = len(shapeA)
dimsB = len(shapeB)
+ assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
if dimsA == 2:
m = shapeA[0]
elif dimsA == 3:
m = shapeA[0]*shapeA[1]
- if dimsB == 2:
- rows = n = shapeB[0]
- elif dimsB == 3:
- rows = n = shapeB[0]*shapeB[1]
+ rows = n = shapeB[0]
+ assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
+
+ # if the tensor is empty, return a transformed empty tensor with the right dimensions
+ if shapeA[0] == 0 and dimsA == 2:
+ return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
+ elif shapeA[1] == 0 and dimsA == 3:
+ return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
if dimsA == 2 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
@@ -954,6 +977,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error = 0
ptrRowScale = get_ptr(None)
+ is_on_gpu([A, B, out])
if formatB == 'col_turing':
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
@@ -966,6 +990,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
if has_error == 1:
+ print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
raise Exception('cublasLt ran into an error!')
torch.cuda.set_device(prev_device)
@@ -994,6 +1019,7 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
numRows = ct.c_int32(out_shape[0])
numCols = ct.c_int32(out_shape[1])
+ is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
return out
@@ -1024,6 +1050,7 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr
cols = ct.c_int32(cols)
prev_device = pre_call(A.device)
+ is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
post_call(prev_device)
@@ -1133,6 +1160,7 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
ptrOutCol = get_ptr(out_col)
ptrOutRow = get_ptr(out_row)
+ is_on_gpu([A, col_stats, row_stats, out_col, out_row])
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
@@ -1185,6 +1213,8 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
+ is_on_gpu([A, out])
+ prev_device = pre_call(A.device)
if to_order == 'col32':
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
@@ -1207,8 +1237,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
-
-
+ post_call(prev_device)
return out, new_state
@@ -1240,6 +1269,7 @@ def spmm_coo(cooA, B, out=None):
cldb = ct.c_int32(ldb)
cldc = ct.c_int32(ldc)
+ is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
return out
@@ -1285,6 +1315,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
#print(cooA.rowidx[:64])
#print(cooA.colidx[:64].sort()[0])
+ is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
if B.dtype == torch.float16:
lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
elif B.dtype == torch.int8:
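
The is_on_gpu calls added throughout functional.py above all follow the same pattern: gather the tensors that are about to be handed to the C library as raw pointers and confirm they live on a CUDA device, with None entries treated as NULL pointers and skipped. A standalone sketch of the check plus a hypothetical call site (the CPU tensors are only for illustration):

import torch

def is_on_gpu(tensors):
    # Accumulate the device check across all non-None tensors,
    # exactly as in the helper added above.
    on_gpu = True
    for t in tensors:
        if t is None:
            continue  # NULL pointers are fine
        on_gpu &= (t.device.type == 'cuda')
    return on_gpu

A = torch.zeros(4)      # CPU tensor, so the check fails
out = torch.zeros(4)
print(is_on_gpu([A, None, out]))  # False: unsafe to pass to the CUDA kernels
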
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 952894c..9b75e69 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -19,53 +19,59 @@ using std::endl;
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
{
int threads = 512;
- int blocks = n/threads;
- blocks = n % threads == 0 ? blocks : blocks + 1;
- kHistogramScatterAdd2D<<<blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
+ int num_blocks = n/threads;
+ num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
{
- int blocks = n/4096;
- blocks = n % 4096 == 0 ? blocks : blocks + 1;
+ int num_blocks = n/4096;
+ num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
- kEstimateQuantiles<T><<<blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
+ kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void quantize(float *code, float *A, unsigned char *out, int n)
{
- int blocks = n/1024;
- blocks = n % 1024 == 0 ? blocks : blocks + 1;
- kQuantize<<<blocks, 1024>>>(code, A, out, n);
+ int num_blocks = n/1024;
+ num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void dequantize(float *code, unsigned char *A, float *out, int n)
{
- int blocks = n/1024;
- blocks = n % 1024 == 0 ? blocks : blocks + 1;
- kDequantize<<<blocks, 1024>>>(code, A, out, n);
+ int num_blocks = n/1024;
+ num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
{
- int blocks = n/4096;
- blocks = n % 4096 == 0 ? blocks : blocks + 1;
- kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+ int num_blocks = n/4096;
+ num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
{
- int blocks = n/blocksize;
- blocks = n % blocksize == 0 ? blocks : blocks + 1;
+ int num_blocks = n/blocksize;
+ num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
if(blocksize == 4096)
- kDequantizeBlockwise<T, 4096, 1024, 4><<<blocks, 4096/4>>>(code, A, absmax, out, n);
+ kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
else if(blocksize == 2048)
- kDequantizeBlockwise<T, 2048, 512, 4><<<blocks, 2048/4>>>(code, A, absmax, out, n);
+ kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -74,18 +80,19 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
{
- int blocks = n/4096;
- blocks = n % 4096 == 0 ? blocks : blocks + 1;
+ int num_blocks = n/4096;
+ num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
switch(OPTIMIZER)
{
case ADAM:
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
- kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+ kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
- kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+ kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
@@ -95,11 +102,11 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
- kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+ kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
- kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+ kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
}
@@ -115,8 +122,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
float weight_decay,
const float gnorm_scale, int n)
{
- int blocks = n/4096;
- blocks = n % 4096 == 0 ? blocks : blocks + 1;
+ int num_blocks = n/4096;
+ num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
@@ -125,9 +133,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case ADAM:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
- kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
+ kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
- kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+ kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@@ -135,9 +143,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case RMSPROP:
case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
- kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+ kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
- kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
+ kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@@ -156,22 +164,24 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
{
- int blocks = 0;
+ int num_blocks = 0;
switch(OPTIMIZER)
{
case ADAM:
- blocks = n/BLOCKSIZE_2STATE;
- blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
- kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
+ num_blocks = n/BLOCKSIZE_2STATE;
+ num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
- blocks = n/BLOCKSIZE_1STATE;
- blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
- kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
+ num_blocks = n/BLOCKSIZE_1STATE;
+ num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@@ -182,10 +192,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
{
- int blocks = n/2048;
- blocks = n % 2048 == 0 ? blocks : blocks + 1;
+ int num_blocks = n/2048;
+ num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
- kPercentileClipping<T, 2048, 4><<<blocks, 512>>>(g, gnorm_vec, step, n);
+ kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -445,10 +456,9 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
int num_blocks = numRows/subtile_rows;
num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
num_blocks = num_blocks*(tileCols/32);
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
assert(threads <= tilesize);
- //cout << num_blocks << " blocks" << endl;
-
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -461,7 +471,13 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
int tile_cols = STATS_THREADS*STATS_ITEMS;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
- int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
+ int row_tiles = (tiledRows/STATS_ROWS);
+ int col_tiles = (tiledCols/tile_cols);
+ row_tiles = row_tiles > 0 ? row_tiles : 1;
+ col_tiles = col_tiles > 0 ? col_tiles : 1;
+ int num_blocks = row_tiles * col_tiles;
+
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
if(nnz_threshold == 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
@@ -479,12 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
int tile_rows = 16;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
- int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+ int row_tiles = (tiledRows/tile_rows);
+ int col_tiles = (tiledCols/tile_cols);
+ row_tiles = row_tiles > 0 ? row_tiles : 1;
+ col_tiles = col_tiles > 0 ? col_tiles : 1;
+ int num_blocks = row_tiles * col_tiles;
- //cout << cols << " " << tiledCols << " " << tiledRows << endl;
- //cout << "num blocks " << num_blocks << endl;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
- //cout << A << " " << out_col_normed << endl;
if(threshold > 0.0f)
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
else
@@ -502,7 +520,13 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
int tile_rows = 32;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
- int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+ int row_tiles = (tiledRows/tile_rows);
+ int col_tiles = (tiledCols/tile_cols);
+ row_tiles = row_tiles > 0 ? row_tiles : 1;
+ col_tiles = col_tiles > 0 ? col_tiles : 1;
+ int num_blocks = row_tiles * col_tiles;
+
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
int outCols = fill_up_to_nearest_multiple(cols, 32);
int outRows = fill_up_to_nearest_multiple(rows, 32);
if(FORMAT == COL_TURING)
@@ -528,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
}
}
- //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
- //cout << "num blocks " << num_blocks << endl;
-
- //cout << A << " " << out_col_normed << endl;
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
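
The ops.cu changes above all apply the same launch-size recipe: ceiling-divide the element count by the per-block workload and assert that the block count stays at or below 65535, the bound the new asserts enforce. The same arithmetic, written in Python for clarity (the 4096 items-per-block value matches quantizeBlockwise; other kernels use their own constants):

import math

def grid_size(n, items_per_block, max_blocks=65535):
    # Ceiling division, written the same way as in the CUDA code above.
    num_blocks = n // items_per_block
    if n % items_per_block != 0:
        num_blocks += 1
    assert num_blocks <= max_blocks, 'CUDA ERROR: Maximum number of blocks for kernel exceeded'
    return num_blocks

print(grid_size(10_000, 4096))                               # 3
print(grid_size(10_000, 4096) == math.ceil(10_000 / 4096))   # True
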
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index d2b5d59..1b6c2ab 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -23,7 +23,8 @@ str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, st
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
- dim2 = dim2 - (dim2 % 16)
+ if dim2 > 0:
+ dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
@@ -179,6 +180,7 @@ dim2 = torch.randint(32,96, size=(n,)).tolist()
dim3 = torch.randint(32,96, size=(n,)).tolist()
dim4 = torch.randint(32,96, size=(n,)).tolist()
+dim2.append(0)
#dim1 = (17,)
#dim2 = (7,)
#dim3 = (37,)
@@ -234,9 +236,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
err = torch.abs(out_bnb-out_torch).mean().item()
#print(f'abs error {err:.4f}')
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx==0).sum().item() < n*0.0175
+ assert (idx==0).sum().item() <= n*0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx==0).sum().item() < n*0.001
+ assert (idx==0).sum().item() <= n*0.001
if has_fp16_weights:
if any(req_grad):
@@ -260,11 +262,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
- assert torch.abs(gradB1).sum() > 0.0
- assert torch.abs(gradB2).sum() > 0.0
+ if dim2 > 0:
+ assert torch.abs(gradB1).sum() > 0.0
+ assert torch.abs(gradB2).sum() > 0.0
+ else:
+ assert torch.abs(gradB1).sum() == 0.0
+ assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx==0).sum().item() < n*0.1
+ assert (idx==0).sum().item() <= n*0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx==0).sum().item() < n*0.02
+ assert (idx==0).sum().item() <= n*0.02
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
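
The test changes above add dim2 == 0 to the parameter grid so test_matmullt also covers empty inputs, where the forward pass returns an empty tensor and the backward pass returns zero gradients. A minimal sketch of that case outside the parametrized test (bnb.matmul follows the naming in the test file; a CUDA device and these particular shapes are assumptions for illustration):

import torch
import bitsandbytes as bnb

A = torch.randn(0, 32, dtype=torch.float16, device='cuda', requires_grad=True)
B = torch.randn(64, 32, dtype=torch.float16, device='cuda', requires_grad=True)

out = bnb.matmul(A, B.t())   # hits the empty-input fallback
assert out.shape == (0, 64)

out.sum().backward()         # backward returns zeros for empty inputs
assert torch.abs(A.grad).sum() == 0.0
assert torch.abs(B.grad).sum() == 0.0
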