summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r--bitsandbytes/autograd/_functions.py40
1 files changed, 5 insertions, 35 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 7cf4999..52e56d0 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -196,7 +196,6 @@ class MatmulLtState:
self.CxBt = None
self.SBt = None
- self.CBt = None
class MatMul8bitLt(torch.autograd.Function):
@@ -327,15 +326,12 @@ class MatMul8bitLt(torch.autograd.Function):
#clone_func = torch.clone
return clone_func(output.view(output_shape))
- @staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, req_gradBias = ctx.req_grads
- CAt, subA = ctx.tensors
- SCAt, idx = ctx.tensor_states
- formatB = ctx.formatB
+ assert not req_gradB, "TODO: support weight updates as well"
state = ctx.state
if len(grad_output.shape) == 3:
@@ -345,37 +341,11 @@ class MatMul8bitLt(torch.autograd.Function):
grad_A = grad_B = grad_bias = None
- Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
- if req_gradB:
- CxAt, SAt = F.transform(CAt, formatB, transpose=True)
- C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
- gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
- grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
- if state.threshold > 0.0 and subA is not None:
- grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
-
if req_gradA:
- C32grad, Sgrad = F.transform(Cgrad, "col32")
- if state.CxBt is None:
- if state.has_fp16_weights:
- CBt = state.CBt
- else:
- # Restore CBt from CB
- assert state.CBt is None, "CBt should not be stored in state"
- CB = state.CB.half()
- SCB = state.SCB.unsqueeze(1).half()
- SCBt = state.SCBt.unsqueeze(1).half()
- Bt = (CB * SCB).t().contiguous()
- CBt = (Bt / SCBt).t().to(torch.int8)
-
- # intentionally, do not store CxBt in state
- CxBt, SBt = F.transform(
- CBt, to_order=formatB, transpose=True
- )
- else:
- CxBt = state.CxBt
- gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
- grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+ CB = state.CB.half()
+ SCB = state.SCB.unsqueeze(1).half()
+ B = (CB * SCB) / 127.0
+ grad_A = torch.mm(grad_output, B).view(ctx.grad_shape)
if req_gradBias:
grad_bias = grad_output.sum(0)