summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-08-03 11:54:01 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-08-03 11:54:01 -0700
commit451fd9506e215aa25643e9782cb7d8aed2a266cc (patch)
treea95aac44018b664dcae503918bb551728f8147c3
parent2f01865a2ff4ad3345c156f7a2f76fe79ec4ed9a (diff)
Added fixes for the case that matmullt dim A is zero, e.g. [0, 768].
-rw-r--r--bitsandbytes/autograd/_functions.py16
-rw-r--r--bitsandbytes/functional.py21
-rw-r--r--csrc/ops.cu31
-rw-r--r--tests/test_autograd.py20
4 files changed, 60 insertions, 28 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 815a4f1..370ca83 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,4 +1,5 @@
import torch
+import math
import bitsandbytes as bnb
import bitsandbytes.functional as F
@@ -162,6 +163,17 @@ class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
+ # default to pytorch behavior if inputs are empty
+ ctx.is_empty = False
+ if math.prod(A.shape) == 0:
+ ctx.is_empty = True
+ ctx.A = A
+ ctx.B = B
+ if A.shape[-1] == B.shape[0]:
+ return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+ else:
+ return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
@@ -265,6 +277,8 @@ class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
+ if ctx.is_empty:
+ return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
req_gradA, req_gradB = ctx.req_grads
CAt, subA = ctx.tensors
SCAt, idx = ctx.tensor_states
@@ -293,7 +307,7 @@ class MatMul8bitLt(torch.autograd.Function):
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
- return grad_A, grad_B, None, None, None, None, None
+ return grad_A, grad_B, None, None
matmul = MatMul8bitLt.apply
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 0a2d557..494de1b 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -4,9 +4,10 @@
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import random
-from typing import Tuple
-
+import math
import torch
+
+from typing import Tuple
from torch import Tensor
from .cextension import lib, COMPILED_WITH_CUDA
@@ -919,15 +920,22 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeB = SB[0]
dimsA = len(shapeA)
dimsB = len(shapeB)
+ assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
if dimsA == 2:
m = shapeA[0]
elif dimsA == 3:
m = shapeA[0]*shapeA[1]
- if dimsB == 2:
- rows = n = shapeB[0]
- elif dimsB == 3:
- rows = n = shapeB[0]*shapeB[1]
+ rows = n = shapeB[0]
+ assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
+ print(shapeA, math.prod(shapeA), math.prod(list(shapeA)))
+ print('aaa')
+
+ # if the tensor is empty, return a transformed empty tensor with the right dimensions
+ if shapeA[0] == 0 and dimsA == 2:
+ return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
+ elif shapeA[1] == 0 and dimsA == 3:
+ return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
if dimsA == 2 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
@@ -984,6 +992,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
if has_error == 1:
+ print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
raise Exception('cublasLt ran into an error!')
torch.cuda.set_device(prev_device)
diff --git a/csrc/ops.cu b/csrc/ops.cu
index b3d07c6..cfc9605 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -459,8 +459,6 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
assert(threads <= tilesize);
- //cout << num_blocks << " blocks" << endl;
-
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -473,11 +471,14 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
int tile_cols = STATS_THREADS*STATS_ITEMS;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
- int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
+ int row_tiles = (tiledRows/STATS_ROWS);
+ int col_tiles = (tiledCols/tile_cols);
+ row_tiles = row_tiles > 0 ? row_tiles : 1;
+ col_tiles = col_tiles > 0 ? col_tiles : 1;
+ int num_blocks = row_tiles * col_tiles;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
-
if(nnz_threshold == 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
else if(nnz_threshold != 0.0)
@@ -494,13 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
int tile_rows = 16;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
- int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
- assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ int row_tiles = (tiledRows/tile_rows);
+ int col_tiles = (tiledCols/tile_cols);
+ row_tiles = row_tiles > 0 ? row_tiles : 1;
+ col_tiles = col_tiles > 0 ? col_tiles : 1;
+ int num_blocks = row_tiles * col_tiles;
- //cout << cols << " " << tiledCols << " " << tiledRows << endl;
- //cout << "num blocks " << num_blocks << endl;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
- //cout << A << " " << out_col_normed << endl;
if(threshold > 0.0f)
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
else
@@ -518,7 +520,12 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
int tile_rows = 32;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
- int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+ int row_tiles = (tiledRows/tile_rows);
+ int col_tiles = (tiledCols/tile_cols);
+ row_tiles = row_tiles > 0 ? row_tiles : 1;
+ col_tiles = col_tiles > 0 ? col_tiles : 1;
+ int num_blocks = row_tiles * col_tiles;
+
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
int outCols = fill_up_to_nearest_multiple(cols, 32);
int outRows = fill_up_to_nearest_multiple(rows, 32);
@@ -545,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
}
}
- //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
- //cout << "num blocks " << num_blocks << endl;
-
- //cout << A << " " << out_col_normed << endl;
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index d2b5d59..1b6c2ab 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -23,7 +23,8 @@ str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, st
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
- dim2 = dim2 - (dim2 % 16)
+ if dim2 > 0:
+ dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
@@ -179,6 +180,7 @@ dim2 = torch.randint(32,96, size=(n,)).tolist()
dim3 = torch.randint(32,96, size=(n,)).tolist()
dim4 = torch.randint(32,96, size=(n,)).tolist()
+dim2.append(0)
#dim1 = (17,)
#dim2 = (7,)
#dim3 = (37,)
@@ -234,9 +236,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
err = torch.abs(out_bnb-out_torch).mean().item()
#print(f'abs error {err:.4f}')
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx==0).sum().item() < n*0.0175
+ assert (idx==0).sum().item() <= n*0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx==0).sum().item() < n*0.001
+ assert (idx==0).sum().item() <= n*0.001
if has_fp16_weights:
if any(req_grad):
@@ -260,11 +262,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
- assert torch.abs(gradB1).sum() > 0.0
- assert torch.abs(gradB2).sum() > 0.0
+ if dim2 > 0:
+ assert torch.abs(gradB1).sum() > 0.0
+ assert torch.abs(gradB2).sum() > 0.0
+ else:
+ assert torch.abs(gradB1).sum() == 0.0
+ assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx==0).sum().item() < n*0.1
+ assert (idx==0).sum().item() <= n*0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx==0).sum().item() < n*0.02
+ assert (idx==0).sum().item() <= n*0.02
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)