summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd/_functions.py
diff options
context:
space:
mode:
authorTitus von Koeller <titus@vonkoeller.com>2022-08-01 03:31:48 -0700
committerTitus von Koeller <titus@vonkoeller.com>2022-08-01 03:31:48 -0700
commitbfa0e33294f2b1dc25e65a33be2397f989824298 (patch)
tree396b5d722fdd79da068882ca7376e3636fcb3bb8 /bitsandbytes/autograd/_functions.py
parent597a8521b29e90958c31e47421016494da998648 (diff)
ran black and isort for coherent code formatting
Diffstat (limited to 'bitsandbytes/autograd/_functions.py')
-rw-r--r--bitsandbytes/autograd/_functions.py134
1 files changed, 88 insertions, 46 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index e641583..a08b560 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,21 +1,24 @@
+from dataclasses import dataclass
+
import torch
+
import bitsandbytes as bnb
import bitsandbytes.functional as F
-from dataclasses import dataclass
-
tensor = torch.Tensor
-'''
+"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
-'''
+"""
+
+
class GlobalOutlierPooler(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.outliers = set()
@@ -29,25 +32,29 @@ class GlobalOutlierPooler(object):
return cls._instance
def add_outliers(self, outlier_idx, feature_dim):
- if self.model_dim is None: self.model_dim = feature_dim
- if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer
+ if self.model_dim is None:
+ self.model_dim = feature_dim
+ if feature_dim != self.model_dim:
+ return # we do not encode outliers for the 2nd FFN layer
self.outliers.update(outlier_idx.tolist())
def get_current_outlier_idx(self):
return torch.Tensor(list(self.outliers)).to(torch.int64)
-class MatMul8bit(torch.autograd.Function):
+class MatMul8bit(torch.autograd.Function):
@staticmethod
- def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
+ def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):
if precision[0] != 8:
with torch.no_grad():
output = torch.matmul(A, B)
else:
- if len(B.shape) == 2: dim = 0
- else: dim = 1
+ if len(B.shape) == 2:
+ dim = 0
+ else:
+ dim = 1
qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
iout = F.igemm(qA, qB)
@@ -84,21 +91,41 @@ class MatMul8bit(torch.autograd.Function):
else:
if len(B.shape) == 2 and len(A.shape) == 3:
grad_output = grad_output.contiguous()
- if not grad_output.is_contiguous(): grad_output.contiguous()
- qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
- if not A.is_contiguous(): A = A.contiguous()
- qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
+ if not grad_output.is_contiguous():
+ grad_output.contiguous()
+ qgrad_output, S1 = F.vectorwise_quant(
+ grad_output.view(-1, grad_output.shape[2]),
+ dim=0,
+ quant_type=quant_type,
+ )
+ if not A.is_contiguous():
+ A = A.contiguous()
+ qA, S2 = F.vectorwise_quant(
+ A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
+ )
igrad_B = F.igemm(qA.t(), qgrad_output)
- grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
+ grad_B = F.vectorwise_mm_dequant(
+ igrad_B, S2.t(), S1, grad_output.dtype, quant_type
+ )
else:
- qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+ qgrad_output, S1 = F.vectorwise_quant(
+ grad_output, dim=dims, quant_type=quant_type
+ )
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
- grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
+ grad_B = F.vectorwise_mm_dequant(
+ igrad_B,
+ S2.permute(permute_dim),
+ S1,
+ grad_output.dtype,
+ quant_type,
+ )
if A.requires_grad:
- if len(grad_output.shape) == 3: dims = [2]
- else: dims = [1]
+ if len(grad_output.shape) == 3:
+ dims = [2]
+ else:
+ dims = [1]
if len(B.shape) == 3:
# bio -> boi
@@ -113,10 +140,14 @@ class MatMul8bit(torch.autograd.Function):
with torch.no_grad():
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
else:
- qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+ qgrad_output, S1 = F.vectorwise_quant(
+ grad_output, dim=dims, quant_type=quant_type
+ )
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
- grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
+ grad_A = F.vectorwise_mm_dequant(
+ igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
+ )
return grad_A, grad_B, None, None, None
@@ -125,6 +156,7 @@ mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
+
@dataclass
class MatmulLtState:
CB = None
@@ -159,7 +191,6 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function):
-
@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
# 1. Quantize A
@@ -171,11 +202,15 @@ class MatMul8bitLt(torch.autograd.Function):
requires_gradB = B.requires_grad
formatB = state.formatB
input_shape = A.shape
- if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
- assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
+ if state.outlier_pool is None:
+ state.outlier_pool = GlobalOutlierPooler.get_instance()
+ assert (
+ A.dtype == torch.float16
+ ), f"The input data type needs to be fp16 but {A.dtype} was found!"
# 1. Quantize A
- if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
+ if len(A.shape) == 3:
+ A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None:
@@ -191,8 +226,8 @@ class MatMul8bitLt(torch.autograd.Function):
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
- #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
- #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
+ # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
+ # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
# outlier_idx = torch.unique(coo_tensorA.colidx).long()
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
@@ -203,24 +238,24 @@ class MatMul8bitLt(torch.autograd.Function):
# state.idx = outlier_idx
# state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
- #if state.idx is not None:
+ # if state.idx is not None:
# # extract outliers
# CA[:, state.idx] = 0
# CAt[:, state.idx] = 0
# subA = A[:, state.idx]
- #else:
+ # else:
# subA = None
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
-
# 2. Quantize B
if state.has_fp16_weights:
- has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
+ has_grad = True if (getattr(B, "grad", None) is not None) else False
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
- if is_transposed: B = B.contiguous()
+ if is_transposed:
+ B = B.contiguous()
if (state.is_training and not has_grad) or state.CxB is None:
state.reset_grads()
@@ -234,14 +269,16 @@ class MatMul8bitLt(torch.autograd.Function):
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
- #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
- #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+ # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+ # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
- #else:
+ # else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
- state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half()
+ state.subB = (
+ (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+ )
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
@@ -254,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
output_shape = (input_shape[0], shapeB[0])
# 3. Matmul
- C32A, SA = F.transform(CA, 'col32')
+ C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
@@ -277,7 +314,7 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
- #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+ # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
clone_func = torch.clone
return clone_func(output.view(output_shape))
@@ -288,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
- assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
+ assert state.has_fp16_weights, "Backprop only supported for fp16 weights."
if len(grad_output.shape) == 3:
grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
@@ -298,18 +335,22 @@ class MatMul8bitLt(torch.autograd.Function):
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
- C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
+ C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA:
- C32grad, Sgrad = F.transform(Cgrad, 'col32')
+ C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
- state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
+ state.CxBt, state.SBt = F.transform(
+ state.CBt, to_order=formatB, transpose=True
+ )
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
- grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+ grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
+ ctx.grad_shape
+ )
return grad_A, grad_B, None, None, None, None, None
@@ -317,9 +358,10 @@ class MatMul8bitLt(torch.autograd.Function):
matmul = MatMul8bitLt.apply
-def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0):
+def matmul(
+ A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
+):
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
return MatMul8bitLt.apply(A, B, out, state)
-