diff options
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 11 |
1 files changed, 2 insertions, 9 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 008655d..642e516 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -185,11 +185,10 @@ class MatmulLtState: idx = None is_training = True has_fp16_weights = True + memory_efficient_backward = False use_pool = False formatB = F.get_special_format_str() - memory_efficient_backward = False - def reset_grads(self): self.CB = None self.CxB = None @@ -198,6 +197,7 @@ class MatmulLtState: self.CxBt = None self.SBt = None + self.CBt = None class MatMul8bitLt(torch.autograd.Function): @@ -232,10 +232,6 @@ class MatMul8bitLt(torch.autograd.Function): A_dtype = A.dtype A = A.to(torch.float16) - assert ( - A.dtype == torch.float16 - ), f"The input data type needs to be fp16 but {A.dtype} was found!" - # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() @@ -398,9 +394,6 @@ class MatMul8bitLt(torch.autograd.Function): return grad_A, grad_B, None, grad_bias, None -matmul = MatMul8bitLt.apply - - def matmul( A: tensor, B: tensor, |