summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-08-16 10:57:10 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-08-16 10:57:10 -0700
commit111b8764492fd1f9921caae64ce7d7d3ac7ef183 (patch)
tree5e2f62b52708cb17e30acd26e74743d840afdbd7 /bitsandbytes/autograd
parent1ed2fa2f218d8dac401f3315420ffec92014c124 (diff)
parent1ced47c5043ed88b78c288f55f43ec3e66a0f765 (diff)
Merge branch 'cuda-bin-switch-and-cli' of github.com:TimDettmers/bitsandbytes into cuda-bin-switch-and-cli
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r--bitsandbytes/autograd/_functions.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 14f2660..01e7073 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,10 +1,15 @@
-from dataclasses import dataclass
-
+import operator
import torch
-import math
import bitsandbytes as bnb
import bitsandbytes.functional as F
+from dataclasses import dataclass
+from functools import reduce # Required in Python 3
+
+# math.prod not compatible with python < 3.8
+def prod(iterable):
+ return reduce(operator.mul, iterable, 1)
+
tensor = torch.Tensor
"""
@@ -12,8 +17,6 @@ tensor = torch.Tensor
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""
-
-
class GlobalOutlierPooler(object):
_instance = None
@@ -201,7 +204,7 @@ class MatMul8bitLt(torch.autograd.Function):
def forward(ctx, A, B, out=None, state=MatmulLtState()):
# default to pytorch behavior if inputs are empty
ctx.is_empty = False
- if math.prod(A.shape) == 0:
+ if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B