summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r--bitsandbytes/autograd/_functions.py25
1 files changed, 10 insertions, 15 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 6868b75..0e594a5 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -321,7 +321,6 @@ class MatMul8bitLt(torch.autograd.Function):
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
- assert subA.dtype == state.subB.dtype == output.dtype, (subA.dtype, state.subB.dtype, output.dtype)
output.addmm_(subA, state.subB)
# 5. Save state
@@ -330,6 +329,7 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.formatB = formatB
ctx.grad_shape = input_shape
ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
+ ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
if requires_gradA or requires_gradB:
ctx.tensors = (CAt, subA)
@@ -348,7 +348,7 @@ class MatMul8bitLt(torch.autograd.Function):
if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
- req_gradA, req_gradB, req_gradBias = ctx.req_grads
+ req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA = ctx.tensors
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
@@ -357,25 +357,22 @@ class MatMul8bitLt(torch.autograd.Function):
if req_gradBias:
# compute grad_bias first before changing grad_output dtype
- grad_bias = grad_output.sum(0)
+ grad_bias = grad_output.sum(0).to(ctx.bias_dtype)
# Cast grad_output to fp16
- grad_output_dtype = grad_output.dtype
- grad_output = grad_output.to(torch.float16)
-
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(
-1, grad_output.shape[-1]
).contiguous()
- Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+ Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
- grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+ grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt).to(ctx.B_dtype)
if state.threshold > 0.0 and subA is not None:
- grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
+ grad_B[:, idx].addmm_(grad_output.t(), subA)
if req_gradA:
if state.CBt is not None:
@@ -385,18 +382,16 @@ class MatMul8bitLt(torch.autograd.Function):
state.CBt, to_order=formatB, transpose=True
)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
- grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+ grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.A_dtype)
+
elif state.CB is not None:
- CB = state.CB.half()
+ CB = state.CB.to(ctx.B_dtype)
SCB = (state.SCB.unsqueeze(1) / 127.0).half()
CB *= SCB
- grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
+ grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape).to(ctx.A_dtype)
else:
raise Exception('State must contain either CBt or CB matrix for backward')
- # Cast grad_A back to grad_output_dtype
- grad_output = grad_output.to(grad_output_dtype)
-
return grad_A, grad_B, None, grad_bias, None