summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordbaranchuk <dmitrybaranchuk@gmail.com>2022-09-11 05:51:29 +0300
committerdbaranchuk <dmitrybaranchuk@gmail.com>2022-09-11 05:51:29 +0300
commit42b5fc9acc4b59a6d90c662eb26099ac25907c7f (patch)
treedf0f65f65e2f1aae25462da1be9c65ca3fe45580
parent843ad0631c65eabc7f64e80906ecf5482cc1a036 (diff)
add memory effcient backward option
-rw-r--r--bitsandbytes/autograd/_functions.py46
-rw-r--r--bitsandbytes/nn/modules.py16
2 files changed, 52 insertions, 10 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 226cbb5..271c690 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,5 +1,6 @@
import operator
import torch
+import bitsandbytes as bnb
import bitsandbytes.functional as F
from dataclasses import dataclass
@@ -187,6 +188,8 @@ class MatmulLtState:
use_pool = False
formatB = F.get_special_format_str()
+ memory_efficient_backward = False
+
def reset_grads(self):
self.CB = None
self.CxB = None
@@ -283,6 +286,12 @@ class MatMul8bitLt(torch.autograd.Function):
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
+ # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+ # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+ # # do not use pool for 2nd FFN layer
+ # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
+ # else:
+ # state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0)
@@ -332,13 +341,15 @@ class MatMul8bitLt(torch.autograd.Function):
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
return clone_func(output.view(output_shape))
+ @staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-
req_gradA, req_gradB, req_gradBias = ctx.req_grads
- assert not req_gradB, "TODO: support weight updates as well"
+ CAt, subA = ctx.tensors
+ SCAt, idx = ctx.tensor_states
+ formatB = ctx.formatB
state = ctx.state
# Cast grad_output to fp16
@@ -352,11 +363,31 @@ class MatMul8bitLt(torch.autograd.Function):
grad_A = grad_B = grad_bias = None
+ Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+ if req_gradB:
+ CxAt, SAt = F.transform(CAt, formatB, transpose=True)
+ C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
+ gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
+ grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+ if state.threshold > 0.0 and subA is not None:
+ grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
+
if req_gradA:
- CB = state.CB.half()
- SCB = (state.SCB.unsqueeze(1) / 127.0).half()
- CB *= SCB
- grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
+ if state.CBt:
+ C32grad, Sgrad = F.transform(Cgrad, "col32")
+ if state.CxBt is None:
+ state.CxBt, state.SBt = F.transform(
+ state.CBt, to_order=formatB, transpose=True
+ )
+ gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+ grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+ elif state.CB:
+ CB = state.CB.half()
+ SCB = (state.SCB.unsqueeze(1) / 127.0).half()
+ CB *= SCB
+ grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
+ else:
+ raise Exception('State must contain either CBt or CB matrix')
if req_gradBias:
grad_bias = grad_output.sum(0)
@@ -367,6 +398,9 @@ class MatMul8bitLt(torch.autograd.Function):
return grad_A, grad_B, None, grad_bias, None
+matmul = MatMul8bitLt.apply
+
+
def matmul(
A: tensor,
B: tensor,
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 3e32c8e..00d0c61 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -223,6 +223,7 @@ class Linear8bitLt(nn.Linear):
has_fp16_weights=True,
threshold=0.0,
index=None,
+ memory_efficient_backward=False
):
super(Linear8bitLt, self).__init__(
input_features, output_features, bias
@@ -232,6 +233,7 @@ class Linear8bitLt(nn.Linear):
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
+ self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
@@ -255,10 +257,16 @@ class Linear8bitLt(nn.Linear):
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
- if not self.state.has_fp16_weights and self.state.CxB is not None:
- # In this version, we convert 8-bit row major to turing/ampere format at each inference pass
- # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place.
- del self.state.CxB
+ if not self.state.has_fp16_weights:
+ if not self.state.memory_efficient_backward and self.state.CB is not None:
+ # we converted 8-bit row major to turing/ampere format in the first inference pass
+ # we no longer need the row-major weight
+ del self.state.CB
+ self.weight.data = self.state.CxB
+ elif self.state.memory_efficient_backward and self.state.CxB is not None:
+ # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+ # Thus, we delete CxB from the state.
+ del self.state.CxB
return out