summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd
diff options
context:
space:
mode:
authorTitus von Koeller <titus@vonkoeller.com>2022-08-01 09:32:47 -0700
committerTitus von Koeller <titus@vonkoeller.com>2022-08-01 09:32:47 -0700
commitea7c14f8ef64924f2d0ff80df3cdabf2c7299848 (patch)
tree3b9ec443a259cf36d87627a8e2cc7d13513f6a21 /bitsandbytes/autograd
parent3fd06fb6206f46b6d18fbb8a512da63832dea98b (diff)
reran black with linelength 80 for greater readability
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r--bitsandbytes/autograd/_functions.py45
1 files changed, 36 insertions, 9 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index a08b560..b56b2ee 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -111,7 +111,9 @@ class MatMul8bit(torch.autograd.Function):
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
- qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
+ qA, S2 = F.vectorwise_quant(
+ A, dim=dims, quant_type=quant_type
+ )
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
grad_B = F.vectorwise_mm_dequant(
igrad_B,
@@ -146,7 +148,11 @@ class MatMul8bit(torch.autograd.Function):
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
grad_A = F.vectorwise_mm_dequant(
- igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
+ igrad_A,
+ S1,
+ S3.permute(permute_dim),
+ grad_output.dtype,
+ quant_type,
)
return grad_A, grad_B, None, None, None
@@ -211,7 +217,9 @@ class MatMul8bitLt(torch.autograd.Function):
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
- CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
+ CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
+ A, threshold=state.threshold
+ )
if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
@@ -225,7 +233,9 @@ class MatMul8bitLt(torch.autograd.Function):
if state.CxB is None:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
- state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+ state.CxB, state.SB = F.transform(
+ state.CB, to_order=formatB
+ )
# state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
# if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
@@ -259,7 +269,13 @@ class MatMul8bitLt(torch.autograd.Function):
if (state.is_training and not has_grad) or state.CxB is None:
state.reset_grads()
- CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
+ (
+ CB,
+ state.CBt,
+ state.SCB,
+ state.SCBt,
+ coo_tensorB,
+ ) = F.double_quant(B)
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
has_grad = False
@@ -277,7 +293,10 @@ class MatMul8bitLt(torch.autograd.Function):
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (
- (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+ (outliers * state.SCB.view(-1, 1) / 127.0)
+ .t()
+ .contiguous()
+ .half()
)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
@@ -325,10 +344,14 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
- assert state.has_fp16_weights, "Backprop only supported for fp16 weights."
+ assert (
+ state.has_fp16_weights
+ ), "Backprop only supported for fp16 weights."
if len(grad_output.shape) == 3:
- grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
+ grad_output = grad_output.view(
+ -1, grad_output.shape[-1]
+ ).contiguous()
grad_A = grad_B = None
@@ -359,7 +382,11 @@ matmul = MatMul8bitLt.apply
def matmul(
- A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
+ A: tensor,
+ B: tensor,
+ out: tensor = None,
+ state: MatmulLtState = None,
+ threshold=0.0,
):
state = state or MatmulLtState()
if threshold > 0.0: