diff options
Diffstat (limited to 'bitsandbytes/autograd/_functions.py')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index b56b2ee..14f2660 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,7 +1,7 @@ from dataclasses import dataclass import torch - +import math import bitsandbytes as bnb import bitsandbytes.functional as F @@ -199,6 +199,17 @@ class MatmulLtState: class MatMul8bitLt(torch.autograd.Function): @staticmethod def forward(ctx, A, B, out=None, state=MatmulLtState()): + # default to pytorch behavior if inputs are empty + ctx.is_empty = False + if math.prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + if A.shape[-1] == B.shape[0]: + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device) + else: + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device) + # 1. Quantize A # 2. Quantize B # 3. Matmul @@ -339,6 +350,8 @@ class MatMul8bitLt(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None req_gradA, req_gradB = ctx.req_grads CAt, subA = ctx.tensors SCAt, idx = ctx.tensor_states @@ -375,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function): ctx.grad_shape ) - return grad_A, grad_B, None, None, None, None, None + return grad_A, grad_B, None, None matmul = MatMul8bitLt.apply |