author     Tim Dettmers <tim.dettmers@gmail.com>  2022-08-03 11:54:01 -0700
committer  Tim Dettmers <tim.dettmers@gmail.com>  2022-08-03 11:54:01 -0700
commit     451fd9506e215aa25643e9782cb7d8aed2a266cc (patch)
tree       a95aac44018b664dcae503918bb551728f8147c3 /bitsandbytes
parent     2f01865a2ff4ad3345c156f7a2f76fe79ec4ed9a (diff)
Added fixes for the case where matmullt dim A is zero, e.g. [0, 768].
Diffstat (limited to 'bitsandbytes')
-rw-r--r--  bitsandbytes/autograd/_functions.py | 16
-rw-r--r--  bitsandbytes/functional.py          | 21
2 files changed, 30 insertions(+), 7 deletions(-)
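In effect, the patch makes an empty input short-circuit the 8-bit path before any quantization or CUDA kernel runs. A minimal sketch of the intended behavior, assuming bnb.matmul is exported at the package top level and dispatches to MatMul8bitLt (the shapes are illustrative):

    import torch
    import bitsandbytes as bnb

    # assumes bnb.matmul dispatches to MatMul8bitLt (see diff below)
    A = torch.empty(0, 768, dtype=torch.float16, device='cuda')   # empty dim A, as in the commit message
    B = torch.randn(768, 256, dtype=torch.float16, device='cuda')

    # With this patch, forward() returns early for empty inputs instead of
    # crashing inside the quantization/igemmlt machinery.
    out = bnb.matmul(A, B)
    print(out.shape, out.dtype)  # torch.Size([0, 256]) torch.float16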
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 815a4f1..370ca83 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,4 +1,5 @@
 import torch
+import math
 import bitsandbytes as bnb
 import bitsandbytes.functional as F
@@ -162,6 +163,17 @@ class MatMul8bitLt(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B, out=None, state=MatmulLtState()):
+        # default to pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if math.prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+            if A.shape[-1] == B.shape[0]:
+                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=torch.float16, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=torch.float16, device=A.device)
+
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
@@ -265,6 +277,8 @@ class MatMul8bitLt(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
+        if ctx.is_empty:
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
         req_gradA, req_gradB = ctx.req_grads
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
@@ -293,7 +307,7 @@ class MatMul8bitLt(torch.autograd.Function):
             gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
 
-        return grad_A, grad_B, None, None, None, None, None
+        return grad_A, grad_B, None, None
 
 matmul = MatMul8bitLt.apply
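The early return above has to infer the output shape from B's storage layout, presumably because B can be kept transposed. The rule can be sketched standalone; empty_out_shape is a hypothetical helper, not part of the library:

    import math
    import torch

    def empty_out_shape(A, B):
        # hypothetical helper for illustration: mirrors the early return in
        # forward(). If the inner dims line up, B is stored as (in, out) and
        # contributes its trailing dims; otherwise B is assumed transposed
        # as (out, in) and only its leading dim survives.
        if A.shape[-1] == B.shape[0]:
            return tuple(A.shape[:-1]) + tuple(B.shape[1:])
        return tuple(A.shape[:-1]) + tuple(B.shape[:1])

    A = torch.empty(0, 768)
    assert math.prod(A.shape) == 0
    print(empty_out_shape(A, torch.empty(768, 256)))  # (0, 256)
    print(empty_out_shape(A, torch.empty(256, 768)))  # (0, 256), transposed B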
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 0a2d557..494de1b 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -4,9 +4,10 @@
 # LICENSE file in the root directory of this source tree.
 import ctypes as ct
 import random
-from typing import Tuple
-
+import math
 import torch
+
+from typing import Tuple
 from torch import Tensor
 
 from .cextension import lib, COMPILED_WITH_CUDA
@@ -919,15 +920,22 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
     shapeB = SB[0]
     dimsA = len(shapeA)
     dimsB = len(shapeB)
+    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
     if dimsA == 2:
         m = shapeA[0]
     elif dimsA == 3:
         m = shapeA[0]*shapeA[1]
-    if dimsB == 2:
-        rows = n = shapeB[0]
-    elif dimsB == 3:
-        rows = n = shapeB[0]*shapeB[1]
+    rows = n = shapeB[0]
+
+    # if the tensor is empty, return a transformed empty tensor with the
+    # right dimensions instead of calling into cuBLASLt
+    if shapeA[0] == 0 and dimsA == 2:
+        return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
+    elif shapeA[1] == 0 and dimsA == 3:
+        return torch.empty(tuple(shapeA[:2]) + (shapeB[0],), device=A.device, dtype=torch.float16)
+    # any other empty shape is unsupported
+    assert math.prod(shapeA) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
 
     if dimsA == 2 and out is None:
         out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
@@ -984,6 +992,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
         has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
 
     if has_error == 1:
+        print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
         raise Exception('cublasLt ran into an error!')
 
     torch.cuda.set_device(prev_device)
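For completeness, the new igemmlt guard never reaches cuBLASLt for an empty A, so its behavior can be checked without a GPU. A small sketch of the same logic; empty_igemmlt_result is a hypothetical stand-in, not a library function:

    import torch

    def empty_igemmlt_result(shapeA, shapeB, device='cpu'):
        # hypothetical stand-in for the guard in F.igemmlt: an empty A
        # short-circuits to an empty float16 tensor whose trailing dim is
        # B's row count.
        dimsA = len(shapeA)
        if dimsA == 2 and shapeA[0] == 0:
            return torch.empty((0, shapeB[0]), device=device, dtype=torch.float16)
        if dimsA == 3 and shapeA[1] == 0:
            return torch.empty(tuple(shapeA[:2]) + (shapeB[0],), device=device, dtype=torch.float16)
        return None  # non-empty: fall through to the cuBLASLt kernels

    print(empty_igemmlt_result((0, 768), (256, 768)).shape)     # torch.Size([0, 256])
    print(empty_igemmlt_result((4, 0, 768), (256, 768)).shape)  # torch.Size([4, 0, 256])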