summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bitsandbytes/autograd/_functions.py54
-rw-r--r--bitsandbytes/nn/modules.py9
2 files changed, 26 insertions, 37 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index be975f6..226cbb5 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -195,7 +195,6 @@ class MatmulLtState:
self.CxBt = None
self.SBt = None
- self.CBt = None
class MatMul8bitLt(torch.autograd.Function):
@@ -225,6 +224,11 @@ class MatMul8bitLt(torch.autograd.Function):
input_shape = A.shape
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()
+
+ # Cast A to fp16
+ A_dtype = A.dtype
+ A = A.to(torch.float16)
+
assert (
A.dtype == torch.float16
), f"The input data type needs to be fp16 but {A.dtype} was found!"
@@ -279,12 +283,6 @@ class MatMul8bitLt(torch.autograd.Function):
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
- # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
- # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
- # # do not use pool for 2nd FFN layer
- # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
- # else:
- # state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0)
@@ -328,52 +326,44 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
+ # Cast fp16 output back to A.dtype
+ output = output.to(A_dtype)
+
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
- #clone_func = torch.clone
return clone_func(output.view(output_shape))
- @staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+
req_gradA, req_gradB, req_gradBias = ctx.req_grads
- CAt, subA = ctx.tensors
- SCAt, idx = ctx.tensor_states
- formatB = ctx.formatB
+ assert not req_gradB, "TODO: support weight updates as well"
state = ctx.state
- assert (
- state.has_fp16_weights
- ), "Backprop only supported for fp16 weights."
+
+ # Cast grad_output to fp16
+ grad_output_dtype = grad_output.dtype
+ grad_output = grad_output.to(torch.float16)
if len(grad_output.shape) == 3:
- grad_output = grad_output.view(
+ grad_output = grad_output.reshape(
-1, grad_output.shape[-1]
).contiguous()
grad_A = grad_B = grad_bias = None
- Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
- if req_gradB:
- CxAt, SAt = F.transform(CAt, formatB, transpose=True)
- C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
- gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
- grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
- if state.threshold > 0.0 and subA is not None:
- grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
-
if req_gradA:
- C32grad, Sgrad = F.transform(Cgrad, "col32")
- if state.CxBt is None:
- state.CxBt, state.SBt = F.transform(
- state.CBt, to_order=formatB, transpose=True
- )
- gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
- grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+ CB = state.CB.half()
+ SCB = (state.SCB.unsqueeze(1) / 127.0).half()
+ CB *= SCB
+ grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
if req_gradBias:
grad_bias = grad_output.sum(0)
+ # Cast grad_A back to grad_output_dtype
+ grad_output.to(grad_output_dtype)
+
return grad_A, grad_B, None, grad_bias, None
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index b222f54..3e32c8e 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -255,11 +255,10 @@ class Linear8bitLt(nn.Linear):
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
- if not self.state.has_fp16_weights and self.state.CB is not None:
- # we converted 8-bit row major to turing/ampere format in the first inference pass
- # we no longer need the row-major weight
- del self.state.CB
- self.weight.data = self.state.CxB
+ if not self.state.has_fp16_weights and self.state.CxB is not None:
+ # In this version, we convert 8-bit row major to turing/ampere format at each inference pass
+ # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place.
+ del self.state.CxB
return out