-rw-r--r-- | bitsandbytes/autograd/_functions.py | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 52e56d0..e266d69 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -213,6 +213,10 @@ class MatMul8bitLt(torch.autograd.Function):
         else:
             return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
 
+        # Cast A to fp16
+        A_dtype = A.dtype
+        A = A.to(torch.float16)
+
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
@@ -322,14 +326,21 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
 
+        # Cast fp16 output back to A.dtype
+        output = output.to(A_dtype)
+
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
-        #clone_func = torch.clone
         return clone_func(output.view(output_shape))
 
     def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+
+        # Cast grad_output to fp16
+        grad_output_dtype = grad_output.dtype
+        grad_output = grad_output.to(torch.float16)
+
         req_gradA, req_gradB, req_gradBias = ctx.req_grads
         assert not req_gradB, "TODO: support weight updates as well"
         state = ctx.state
@@ -350,6 +361,9 @@ class MatMul8bitLt(torch.autograd.Function):
 
         if req_gradBias:
             grad_bias = grad_output.sum(0)
 
+        # Cast grad_A back to grad_output_dtype
+        grad_A = grad_A.to(grad_output_dtype)
+
         return grad_A, grad_B, None, grad_bias, None
 
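The patch applies one pattern in both forward and backward: remember the caller's dtype, run the core math in fp16, and cast the result (and, in backward, the incoming gradient) back before returning. Below is a minimal, self-contained sketch of that dtype round trip in a custom torch.autograd.Function; it is not bitsandbytes code, and the 8-bit quantize/matmul path of MatMul8bitLt is replaced by a plain fp16 matmul so the example stays runnable on its own.

# Illustrative sketch only: dtype round trip in a custom autograd Function.
# The real 8-bit quantize/matmul path is replaced by a plain fp16 matmul.
import torch


class Fp16MatMulSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B):
        # Remember the caller's dtype, then do the core math in fp16.
        A_dtype = A.dtype
        A_fp16 = A.to(torch.float16)
        B_fp16 = B.to(torch.float16)
        ctx.save_for_backward(A_fp16, B_fp16)
        output = A_fp16 @ B_fp16
        # Cast the fp16 result back so the caller keeps its original dtype.
        return output.to(A_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        A_fp16, B_fp16 = ctx.saved_tensors
        # Same round trip for the incoming gradient. Note that .to() is not
        # in-place, so its result must be reassigned.
        grad_output_dtype = grad_output.dtype
        grad_output = grad_output.to(torch.float16)
        grad_A = grad_output @ B_fp16.t()
        grad_B = A_fp16.t() @ grad_output
        # Cast the gradients back to the dtype autograd handed us.
        return grad_A.to(grad_output_dtype), grad_B.to(grad_output_dtype)


if __name__ == "__main__" and torch.cuda.is_available():
    # fp32 inputs go in, fp16 is used internally, fp32 comes back out.
    A = torch.randn(4, 8, device="cuda", requires_grad=True)
    B = torch.randn(8, 2, device="cuda", requires_grad=True)
    out = Fp16MatMulSketch.apply(A, B)
    out.sum().backward()
    print(out.dtype, A.grad.dtype)  # torch.float32 torch.float32

Keeping the casts at the boundaries of forward and backward means callers can feed bf16 or fp32 activations without the fp16 requirement of the internal kernel leaking out of the Function.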