summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd/_functions.py
diff options
context:
space:
mode:
authordbaranchuk <dmitrybaranchuk@gmail.com>2022-09-11 06:26:15 +0300
committerdbaranchuk <dmitrybaranchuk@gmail.com>2022-09-11 06:26:15 +0300
commitd358999e9e2d98a834aaa38ffec1bef983d73fe6 (patch)
treee81da8dd7032209c9f16be83124697ba3fd2c6b3 /bitsandbytes/autograd/_functions.py
parentee325f02157cd23b37059e3dce5fb17cb1c1b137 (diff)
refactoring
Diffstat (limited to 'bitsandbytes/autograd/_functions.py')
-rw-r--r--bitsandbytes/autograd/_functions.py11
1 files changed, 2 insertions, 9 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 008655d..642e516 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -185,11 +185,10 @@ class MatmulLtState:
idx = None
is_training = True
has_fp16_weights = True
+ memory_efficient_backward = False
use_pool = False
formatB = F.get_special_format_str()
- memory_efficient_backward = False
-
def reset_grads(self):
self.CB = None
self.CxB = None
@@ -198,6 +197,7 @@ class MatmulLtState:
self.CxBt = None
self.SBt = None
+ self.CBt = None
class MatMul8bitLt(torch.autograd.Function):
@@ -232,10 +232,6 @@ class MatMul8bitLt(torch.autograd.Function):
A_dtype = A.dtype
A = A.to(torch.float16)
- assert (
- A.dtype == torch.float16
- ), f"The input data type needs to be fp16 but {A.dtype} was found!"
-
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
@@ -398,9 +394,6 @@ class MatMul8bitLt(torch.autograd.Function):
return grad_A, grad_B, None, grad_bias, None
-matmul = MatMul8bitLt.apply
-
-
def matmul(
A: tensor,
B: tensor,