summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r--bitsandbytes/autograd/_functions.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 14f2660..a5446b7 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,10 +1,14 @@
-from dataclasses import dataclass
-
+import operator
import torch
-import math
import bitsandbytes as bnb
import bitsandbytes.functional as F
+from dataclasses import dataclass
+from functools import reduce # Required in Python 3
+
+def prod(iterable):
+ return reduce(operator.mul, iterable, 1)
+
tensor = torch.Tensor
"""
@@ -12,8 +16,6 @@ tensor = torch.Tensor
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""
-
-
class GlobalOutlierPooler(object):
_instance = None
@@ -201,7 +203,7 @@ class MatMul8bitLt(torch.autograd.Function):
def forward(ctx, A, B, out=None, state=MatmulLtState()):
# default to pytorch behavior if inputs are empty
ctx.is_empty = False
- if math.prod(A.shape) == 0:
+ if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B