-rw-r--r--  bitsandbytes/autograd/_functions.py | 16 +++++++++++++++-
1 file changed, 15 insertions(+), 1 deletion(-)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 52e56d0..e266d69 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -213,6 +213,10 @@ class MatMul8bitLt(torch.autograd.Function):
             else:
                 return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
 
+        # Cast A to fp16
+        A_dtype = A.dtype
+        A = A.to(torch.float16)
+
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
@@ -322,14 +326,21 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
+        # Cast fp16 output back to A.dtype
+        output = output.to(A_dtype)
+
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
-        #clone_func = torch.clone
         return clone_func(output.view(output_shape))
 
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+
+        # Cast grad_output to fp16 (Tensor.to() is out-of-place, so assign the result)
+        grad_output_dtype = grad_output.dtype
+        grad_output = grad_output.to(torch.float16)
+
         req_gradA, req_gradB, req_gradBias = ctx.req_grads
         assert not req_gradB, "TODO: support weight updates as well"
         state = ctx.state
@@ -350,6 +361,9 @@ class MatMul8bitLt(torch.autograd.Function):
         if req_gradBias:
             grad_bias = grad_output.sum(0)
+        # Cast grad_A back to grad_output_dtype (grad_A is None when A needs no grad)
+        grad_A = grad_A.to(grad_output_dtype) if req_gradA else None
+
         return grad_A, grad_B, None, grad_bias, None
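
Taken together, the patch is a dtype round trip: record the incoming dtype, run the 8-bit path in fp16, and cast the output and the gradient back on the way out. Below is a minimal, self-contained sketch of that pattern, not bitsandbytes code: CastedMatMul is a hypothetical stand-in, and a plain fp16 torch.matmul replaces the quantize/int8-matmul/dequantize pipeline. The detail the sketch highlights is that Tensor.to() returns a new tensor rather than mutating in place, so every cast must be assigned.

    import torch

    class CastedMatMul(torch.autograd.Function):
        # Sketch of the dtype round trip: compute in fp16 internally,
        # return results in the caller's dtype. A plain fp16 matmul
        # stands in for the quantized int8 path.

        @staticmethod
        def forward(ctx, A, B):
            # Cast A to fp16, remembering the original dtype
            A_dtype = A.dtype
            A = A.to(torch.float16)
            B = B.to(torch.float16)
            ctx.save_for_backward(B)
            output = torch.matmul(A, B)
            # Cast the fp16 output back to A's original dtype
            return output.to(A_dtype)

        @staticmethod
        def backward(ctx, grad_output):
            (B,) = ctx.saved_tensors
            # Same round trip for the gradient; note the assignments,
            # since Tensor.to() does not modify the tensor in place
            grad_output_dtype = grad_output.dtype
            grad_output = grad_output.to(torch.float16)
            grad_A = torch.matmul(grad_output, B.t())
            # Cast grad_A back to the dtype autograd passed in
            return grad_A.to(grad_output_dtype), None

    # fp32 tensors go in; the output and gradient come back as fp32
    # even though the inner matmul ran in fp16. (fp16 matmul on CPU
    # needs a recent PyTorch; on older versions run this on CUDA.)
    A = torch.randn(4, 8, requires_grad=True)   # torch.float32
    B = torch.randn(8, 16)
    out = CastedMatMul.apply(A, B)
    out.sum().backward()
    print(out.dtype, A.grad.dtype)              # torch.float32 torch.float32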