summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd/_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/autograd/_functions.py')
-rw-r--r--bitsandbytes/autograd/_functions.py17
1 files changed, 15 insertions, 2 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index b56b2ee..14f2660 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass
import torch
-
+import math
import bitsandbytes as bnb
import bitsandbytes.functional as F
@@ -199,6 +199,17 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
+ # default to pytorch behavior if inputs are empty
+ ctx.is_empty = False
+ if math.prod(A.shape) == 0:
+ ctx.is_empty = True
+ ctx.A = A
+ ctx.B = B
+ if A.shape[-1] == B.shape[0]:
+ return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+ else:
+ return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
@@ -339,6 +350,8 @@ class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
+ if ctx.is_empty:
+ return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
req_gradA, req_gradB = ctx.req_grads
CAt, subA = ctx.tensors
SCAt, idx = ctx.tensor_states
@@ -375,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.grad_shape
)
- return grad_A, grad_B, None, None, None, None, None
+ return grad_A, grad_B, None, None
matmul = MatMul8bitLt.apply