diff options
author | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-09-11 06:26:15 +0300 |
---|---|---|
committer | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-09-11 06:26:15 +0300 |
commit | d358999e9e2d98a834aaa38ffec1bef983d73fe6 (patch) | |
tree | e81da8dd7032209c9f16be83124697ba3fd2c6b3 /bitsandbytes/autograd | |
parent | ee325f02157cd23b37059e3dce5fb17cb1c1b137 (diff) |
refactoring
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 11 |
1 files changed, 2 insertions, 9 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 008655d..642e516 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -185,11 +185,10 @@ class MatmulLtState: idx = None is_training = True has_fp16_weights = True + memory_efficient_backward = False use_pool = False formatB = F.get_special_format_str() - memory_efficient_backward = False - def reset_grads(self): self.CB = None self.CxB = None @@ -198,6 +197,7 @@ class MatmulLtState: self.CxBt = None self.SBt = None + self.CBt = None class MatMul8bitLt(torch.autograd.Function): @@ -232,10 +232,6 @@ class MatMul8bitLt(torch.autograd.Function): A_dtype = A.dtype A = A.to(torch.float16) - assert ( - A.dtype == torch.float16 - ), f"The input data type needs to be fp16 but {A.dtype} was found!" - # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() @@ -398,9 +394,6 @@ class MatMul8bitLt(torch.autograd.Function): return grad_A, grad_B, None, grad_bias, None -matmul = MatMul8bitLt.apply - - def matmul( A: tensor, B: tensor, |