summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
authorTim Dettmers <TimDettmers@users.noreply.github.com>2022-09-19 21:09:25 -0700
committerGitHub <noreply@github.com>2022-09-19 21:09:25 -0700
commit439f2b0c10abd3e9aade386d92810b074c69e9ec (patch)
tree75454081c86ba1c96c07e83defc9fc5f4de840cf /bitsandbytes
parent9b5f2eda8fbd3f042c4af7ed1b870525d4668f2a (diff)
parent76ce9aa6da7d68d2463f0f3e99532ab5b6db58a8 (diff)
Merge pull request #33 from dbaranchuk/memory-efficient-backward
Memory efficient backward
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/autograd/_functions.py76
-rw-r--r--bitsandbytes/nn/modules.py21
2 files changed, 59 insertions, 38 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index be975f6..2ddb406 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,4 +1,6 @@
import operator
+import warnings
+
import torch
import bitsandbytes.functional as F
@@ -184,6 +186,7 @@ class MatmulLtState:
idx = None
is_training = True
has_fp16_weights = True
+ memory_efficient_backward = False
use_pool = False
formatB = F.get_special_format_str()
@@ -209,31 +212,29 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.B = B
ctx.bias = bias
if A.shape[-1] == B.shape[0]:
- return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+ return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
else:
- return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+ return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
# 4. Mixed-precision decomposition matmul
# 5. Save state
- requires_gradA = A.requires_grad
- requires_gradB = B.requires_grad
- requires_gradBias = bias is not None and bias.requires_grad
formatB = state.formatB
input_shape = A.shape
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()
- assert (
- A.dtype == torch.float16
- ), f"The input data type needs to be fp16 but {A.dtype} was found!"
+
+ # Cast A to fp16
+ if A.dtype != torch.float16:
+ warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
- A, threshold=state.threshold
+ A.to(torch.float16), threshold=state.threshold
)
if state.threshold > 0.0 and coo_tensorA is not None:
@@ -269,7 +270,7 @@ class MatMul8bitLt(torch.autograd.Function):
state.SCB,
state.SCBt,
coo_tensorB,
- ) = F.double_quant(B)
+ ) = F.double_quant(B.to(torch.float16))
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
has_grad = False
@@ -290,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
(outliers * state.SCB.view(-1, 1) / 127.0)
.t()
.contiguous()
- .half()
+ .to(A.dtype)
)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
@@ -307,7 +308,13 @@ class MatMul8bitLt(torch.autograd.Function):
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
# we apply the fused bias here
- output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+
+ if bias is None or bias.dtype == torch.float16:
+ output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+ output = output.to(A.dtype)
+ else: # apply bias separately
+ output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+ output = output.to(A.dtype).add_(bias)
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
@@ -318,9 +325,9 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.formatB = formatB
ctx.grad_shape = input_shape
- ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
+ ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
- if requires_gradA or requires_gradB:
+ if any(ctx.needs_input_grad[:2]):
ctx.tensors = (CAt, subA)
ctx.tensor_states = (SCAt, state.idx)
else:
@@ -328,8 +335,8 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
+
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
- #clone_func = torch.clone
return clone_func(output.view(output_shape))
@staticmethod
@@ -337,23 +344,24 @@ class MatMul8bitLt(torch.autograd.Function):
if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
- req_gradA, req_gradB, req_gradBias = ctx.req_grads
+ req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA = ctx.tensors
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
- assert (
- state.has_fp16_weights
- ), "Backprop only supported for fp16 weights."
+ grad_A = grad_B = grad_bias = None
+
+ if req_gradBias:
+ # compute grad_bias first before changing grad_output dtype
+ grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
+ # Cast grad_output to fp16
if len(grad_output.shape) == 3:
- grad_output = grad_output.view(
+ grad_output = grad_output.reshape(
-1, grad_output.shape[-1]
).contiguous()
- grad_A = grad_B = grad_bias = None
-
- Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+ Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -363,16 +371,20 @@ class MatMul8bitLt(torch.autograd.Function):
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA:
- C32grad, Sgrad = F.transform(Cgrad, "col32")
- if state.CxBt is None:
- state.CxBt, state.SBt = F.transform(
- state.CBt, to_order=formatB, transpose=True
- )
- gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
- grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+ if state.CBt is not None:
+ C32grad, Sgrad = F.transform(Cgrad, "col32")
+ if state.CxBt is None:
+ state.CxBt, state.SBt = F.transform(
+ state.CBt, to_order=formatB, transpose=True
+ )
+ gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+ grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
- if req_gradBias:
- grad_bias = grad_output.sum(0)
+ elif state.CB is not None:
+ CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+ grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+ else:
+ raise Exception('State must contain either CBt or CB matrix for backward')
return grad_A, grad_B, None, grad_bias, None
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index b222f54..9250fec 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -221,6 +221,7 @@ class Linear8bitLt(nn.Linear):
output_features,
bias=True,
has_fp16_weights=True,
+ memory_efficient_backward=False,
threshold=0.0,
index=None,
):
@@ -232,10 +233,13 @@ class Linear8bitLt(nn.Linear):
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
+ self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
- self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+ self.weight = Int8Params(
+ self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
+ )
def init_8bit_state(self):
self.state.CB = self.weight.CB
@@ -255,11 +259,16 @@ class Linear8bitLt(nn.Linear):
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
- if not self.state.has_fp16_weights and self.state.CB is not None:
- # we converted 8-bit row major to turing/ampere format in the first inference pass
- # we no longer need the row-major weight
- del self.state.CB
- self.weight.data = self.state.CxB
+ if not self.state.has_fp16_weights:
+ if not self.state.memory_efficient_backward and self.state.CB is not None:
+ # we converted 8-bit row major to turing/ampere format in the first inference pass
+ # we no longer need the row-major weight
+ del self.state.CB
+ self.weight.data = self.state.CxB
+ elif self.state.memory_efficient_backward and self.state.CxB is not None:
+ # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+ # Thus, we delete CxB from the state.
+ del self.state.CxB
return out