summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
authordbaranchuk <dmitrybaranchuk@gmail.com>2022-08-23 23:39:54 +0300
committerdbaranchuk <dmitrybaranchuk@gmail.com>2022-08-23 23:39:54 +0300
commit8ae9bb23ad9c61a92ab1a0ac6be65cd787c4fe5b (patch)
treeb0b17700aad3ac18a1265e078c0ea6b1ada8b87f /bitsandbytes
parent9d60b3c5279641ba936facd710c722ebe52fcf40 (diff)
add memory efficient backward
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/autograd/_functions.py39
-rw-r--r--bitsandbytes/nn/modules.py13
2 files changed, 28 insertions, 24 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 4dbf129..63e8ad5 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -245,11 +245,10 @@ class MatMul8bitLt(torch.autograd.Function):
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
- else:
- if state.CxB is None:
- # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
- # we also need to convert it to the turing/ampere format
- state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+ elif state.CxB is None:
+ # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
+ # we also need to convert it to the turing/ampere format
+ state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@@ -280,12 +279,6 @@ class MatMul8bitLt(torch.autograd.Function):
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
- # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
- # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
- # # do not use pool for 2nd FFN layer
- # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
- # else:
- # state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0)
@@ -343,12 +336,9 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
- assert (
- state.has_fp16_weights
- ), "Backprop only supported for fp16 weights."
if len(grad_output.shape) == 3:
- grad_output = grad_output.view(
+ grad_output = grad_output.reshape(
-1, grad_output.shape[-1]
).contiguous()
@@ -365,11 +355,20 @@ class MatMul8bitLt(torch.autograd.Function):
if req_gradA:
C32grad, Sgrad = F.transform(Cgrad, "col32")
- if state.CxBt is None:
- state.CxBt, state.SBt = F.transform(
- state.CBt, to_order=formatB, transpose=True
- )
- gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+ if state.CxBt is None and state.has_fp16_weights:
+ CBt = state.CBt
+ elif state.CxBt is None:
+ assert state.CBt is None
+ CB = state.CB.half()
+ SCB = state.SCB.unsquezee(1).half()
+ SCBt = state.SCBt.unsquezee(1).half()
+ Bt = (CB * SCB).t().contiguous()
+ CBt = (Bt / SCBt).t().to(torch.int8)
+
+ CxBt, SBt = F.transform(
+ CBt, to_order=formatB, transpose=True
+ )
+ gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
if req_gradBias:
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index b222f54..ef7fefc 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -148,10 +148,12 @@ class Int8Params(torch.nn.Parameter):
has_fp16_weights=False,
CB=None,
SCB=None,
+ SCBt=None,
):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
cls.SCB = None
+ cls.SCBt = None
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -165,10 +167,10 @@ class Int8Params(torch.nn.Parameter):
B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
- del SCBt
self.data = CB
setattr(self, "CB", CB)
setattr(self, "SCB", SCB)
+ setattr(self, "SCBt", SCBt)
return self
@@ -210,6 +212,7 @@ class Int8Params(torch.nn.Parameter):
)
new_param.CB = self.CB
new_param.SCB = self.SCB
+ new_param.SCB = self.SCBt
return new_param
@@ -240,8 +243,10 @@ class Linear8bitLt(nn.Linear):
def init_8bit_state(self):
self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB
+ self.state.SCBt = self.weight.SCBt
self.weight.CB = None
self.weight.SCB = None
+ self.weight.SCBt = None
def forward(self, x):
self.state.is_training = self.training
@@ -255,11 +260,11 @@ class Linear8bitLt(nn.Linear):
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
- if not self.state.has_fp16_weights and self.state.CB is not None:
+ # if not self.state.has_fp16_weights and self.state.CB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
- del self.state.CB
- self.weight.data = self.state.CxB
+ # del self.state.CB
+ # self.weight.data = self.state.CxB
return out