author    dbaranchuk <dmitrybaranchuk@gmail.com>  2022-08-23 23:39:54 +0300
committer dbaranchuk <dmitrybaranchuk@gmail.com>  2022-08-23 23:39:54 +0300
commit    8ae9bb23ad9c61a92ab1a0ac6be65cd787c4fe5b (patch)
tree      b0b17700aad3ac18a1265e078c0ea6b1ada8b87f /bitsandbytes/autograd/_functions.py
parent    9d60b3c5279641ba936facd710c722ebe52fcf40 (diff)
add memory efficient backward
Diffstat (limited to 'bitsandbytes/autograd/_functions.py')
-rw-r--r--  bitsandbytes/autograd/_functions.py  |  39
1 file changed, 19 insertions(+), 20 deletions(-)
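What this commit does: the backward pass previously asserted state.has_fp16_weights and relied on a cached, transformed copy of the transposed weight (state.CxBt); it now rebuilds the transposed int8 weight on demand from state.CB and the quantization scales, so gradients can flow through layers whose weights exist only in int8. A hedged illustration of the user-visible effect (Linear8bitLt and its arguments come from the same library, not from this diff; exact dtype handling at this commit may differ):

import torch
import bitsandbytes as bnb

# frozen int8 weights: before this commit, backward() raised
# "Backprop only supported for fp16 weights."
linear = bnb.nn.Linear8bitLt(64, 64, bias=False, has_fp16_weights=False, threshold=6.0).cuda()
x = torch.randn(8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
linear(x).sum().backward()
print(x.grad.shape)  # torch.Size([8, 64]): the input gradient now exists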
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 4dbf129..63e8ad5 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -245,11 +245,10 @@ class MatMul8bitLt(torch.autograd.Function):
                 subA = A[:, idx]
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
-            else:
-                if state.CxB is None:
-                    # B is in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
-                    # we also need to convert it to the turing/ampere format
-                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+            elif state.CxB is None:
+                # B is in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
+                # we also need to convert it to the turing/ampere format
+                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@@ -280,12 +279,6 @@ class MatMul8bitLt(torch.autograd.Function):
             outlier_idx = torch.unique(coo_tensorA.colidx)
             state.idx = outlier_idx
-            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
-            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
-            #     # do not use pool for 2nd FFN layer
-            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
-            # else:
-            #     state.idx = outlier_idx
             outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
             state.subB = (
                 (outliers * state.SCB.view(-1, 1) / 127.0)
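The state.subB expression above turns the int8 outlier columns back into fp16 so they can be multiplied in 16-bit against subA. With symmetric row-wise quantization the stored weight is recovered as code * scale / 127. A minimal self-contained sketch of that dequantization (toy values, illustrative names):

import torch

q = torch.tensor([[127, -64], [32, 0]], dtype=torch.int8)  # int8 codes for two rows
scb = torch.tensor([2.0, 0.5])                             # per-row absmax scales, like state.SCB
subB = (q.float() * scb.view(-1, 1) / 127.0).t().contiguous().half()
print(subB)  # fp16, transposed and contiguous, ready for the fp16 matmul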
@@ -343,12 +336,9 @@ class MatMul8bitLt(torch.autograd.Function):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert (
-            state.has_fp16_weights
-        ), "Backprop only supported for fp16 weights."
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(
+            grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()
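The .view() to .reshape() switch matters because view() requires the flattening to be expressible over the existing strides, while reshape() falls back to a copy when it is not; a non-contiguous 3-D grad_output (e.g. one coming out of a transpose) would make view() raise. A small standalone illustration:

import torch

g = torch.randn(2, 3, 4).transpose(0, 1)   # non-contiguous tensor of shape (3, 2, 4)
try:
    g.view(-1, g.shape[-1])                # RuntimeError: incompatible size/stride
except RuntimeError as err:
    print("view failed:", err)
print(g.reshape(-1, g.shape[-1]).shape)    # torch.Size([6, 4]); reshape copies as needed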
@@ -365,11 +355,20 @@ class MatMul8bitLt(torch.autograd.Function):
         if req_gradA:
             C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(
-                    state.CBt, to_order=formatB, transpose=True
-                )
-            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+            if state.CxBt is None and state.has_fp16_weights:
+                CBt = state.CBt
+            elif state.CxBt is None:
+                assert state.CBt is None
+                CB = state.CB.half()
+                SCB = state.SCB.unsqueeze(1).half()
+                SCBt = state.SCBt.unsqueeze(1).half()
+                Bt = (CB * SCB).t().contiguous()
+                CBt = (Bt / SCBt).t().to(torch.int8)
+
+            CxBt, SBt = F.transform(
+                CBt, to_order=formatB, transpose=True
+            )
+            gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
         if req_gradBias:
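The elif branch above is the heart of the memory-efficient path: when only int8 weights exist (state.CBt is None), the transposed quantized matrix is rebuilt from the row-wise quantized state.CB by dequantizing, transposing, and requantizing with the column-wise scales; the two factors of 127 cancel, which is why neither appears in the code. A self-contained sketch of that round trip, with assumed shapes and names (float32 here keeps it CPU-friendly; the real code uses .half()):

import torch

torch.manual_seed(0)
B = torch.randn(4, 8)                                        # weight, shape (rows, cols)
SCB = B.abs().amax(dim=1)                                    # per-row absmax scales
CB = (B * 127.0 / SCB.unsqueeze(1)).round().to(torch.int8)   # row-wise int8, like state.CB
SCBt = B.abs().amax(dim=0)                                   # per-column absmax scales, like state.SCBt

# as in the backward: dequantize rows, transpose, requantize with column scales
Bt = (CB.float() * SCB.unsqueeze(1)).t().contiguous()        # roughly 127 * B.t()
CBt = (Bt / SCBt.unsqueeze(1)).t().to(torch.int8)

# compare with quantizing column-wise from scratch
ref = (B * 127.0 / SCBt.unsqueeze(0)).round().to(torch.int8)
print((CBt.int() - ref.int()).abs().max())                   # small: rounding vs. truncation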