diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-03 11:54:01 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-03 11:54:01 -0700 |
commit | 451fd9506e215aa25643e9782cb7d8aed2a266cc (patch) | |
tree | a95aac44018b664dcae503918bb551728f8147c3 /bitsandbytes/autograd | |
parent | 2f01865a2ff4ad3345c156f7a2f76fe79ec4ed9a (diff) |
Added fixes for the case that matmullt dim A is zero, e.g. [0, 768].
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 815a4f1..370ca83 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,4 +1,5 @@ import torch +import math import bitsandbytes as bnb import bitsandbytes.functional as F @@ -162,6 +163,17 @@ class MatMul8bitLt(torch.autograd.Function): @staticmethod def forward(ctx, A, B, out=None, state=MatmulLtState()): + # default to pytorch behavior if inputs are empty + ctx.is_empty = False + if math.prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + if A.shape[-1] == B.shape[0]: + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device) + else: + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device) + # 1. Quantize A # 2. Quantize B # 3. Matmul @@ -265,6 +277,8 @@ class MatMul8bitLt(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None req_gradA, req_gradB = ctx.req_grads CAt, subA = ctx.tensors SCAt, idx = ctx.tensor_states @@ -293,7 +307,7 @@ class MatMul8bitLt(torch.autograd.Function): gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) - return grad_A, grad_B, None, None, None, None, None + return grad_A, grad_B, None, None matmul = MatMul8bitLt.apply |