diff options
author | Titus von Koeller <titus@vonkoeller.com> | 2022-08-01 09:32:47 -0700 |
---|---|---|
committer | Titus von Koeller <titus@vonkoeller.com> | 2022-08-01 09:32:47 -0700 |
commit | ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 (patch) | |
tree | 3b9ec443a259cf36d87627a8e2cc7d13513f6a21 /bitsandbytes/autograd | |
parent | 3fd06fb6206f46b6d18fbb8a512da63832dea98b (diff) |
reran black with linelength 80 for greater readability
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 45 |
1 files changed, 36 insertions, 9 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index a08b560..b56b2ee 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -111,7 +111,9 @@ class MatMul8bit(torch.autograd.Function): qgrad_output, S1 = F.vectorwise_quant( grad_output, dim=dims, quant_type=quant_type ) - qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) + qA, S2 = F.vectorwise_quant( + A, dim=dims, quant_type=quant_type + ) igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) grad_B = F.vectorwise_mm_dequant( igrad_B, @@ -146,7 +148,11 @@ class MatMul8bit(torch.autograd.Function): qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) grad_A = F.vectorwise_mm_dequant( - igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type + igrad_A, + S1, + S3.permute(permute_dim), + grad_output.dtype, + quant_type, ) return grad_A, grad_B, None, None, None @@ -211,7 +217,9 @@ class MatMul8bitLt(torch.autograd.Function): # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( + A, threshold=state.threshold + ) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -225,7 +233,9 @@ class MatMul8bitLt(torch.autograd.Function): if state.CxB is None: # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + state.CxB, state.SB = F.transform( + state.CB, to_order=formatB + ) # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: # # generate outlier index and subB @@ -259,7 +269,13 @@ class MatMul8bitLt(torch.autograd.Function): if (state.is_training and not has_grad) or state.CxB is None: state.reset_grads() - CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B) + ( + CB, + state.CBt, + state.SCB, + state.SCBt, + coo_tensorB, + ) = F.double_quant(B) state.CxB, state.SB = F.transform(CB, to_order=formatB) else: has_grad = False @@ -277,7 +293,10 @@ class MatMul8bitLt(torch.autograd.Function): # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( - (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half() + (outliers * state.SCB.view(-1, 1) / 127.0) + .t() + .contiguous() + .half() ) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 @@ -325,10 +344,14 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert state.has_fp16_weights, "Backprop only supported for fp16 weights." + assert ( + state.has_fp16_weights + ), "Backprop only supported for fp16 weights." if len(grad_output.shape) == 3: - grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous() + grad_output = grad_output.view( + -1, grad_output.shape[-1] + ).contiguous() grad_A = grad_B = None @@ -359,7 +382,11 @@ matmul = MatMul8bitLt.apply def matmul( - A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0 + A: tensor, + B: tensor, + out: tensor = None, + state: MatmulLtState = None, + threshold=0.0, ): state = state or MatmulLtState() if threshold > 0.0: |