summaryrefslogtreecommitdiff
path: root/bitsandbytes/functional.py
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-08-08 09:13:22 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-08-08 09:13:22 -0700
commitf9cbe2fe99c805dcca934c66677951f428d3b3e2 (patch)
treefaa06c4f066517312221e3fad8694fdae890daba /bitsandbytes/functional.py
parent62441815bc733c9e75d32dd65305a16aaebd317a (diff)
Fixed prod Python < 3.7 compatibility in function.py.
Diffstat (limited to 'bitsandbytes/functional.py')
-rw-r--r--bitsandbytes/functional.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index b4409e4..1bddb52 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
+import operator
import random
import math
import torch
@@ -11,6 +12,11 @@ from typing import Tuple
from torch import Tensor
from .cextension import COMPILED_WITH_CUDA, lib
+from functools import reduce # Required in Python 3
+
+# math.prod not compatible with python < 3.8
+def prod(iterable):
+ return reduce(operator.mul, iterable, 1)
name2qmap = {}
@@ -326,8 +332,8 @@ def nvidia_transform(
dim1 = ct.c_int32(shape[0])
dim2 = ct.c_int32(shape[1])
elif ld is not None:
- n = math.prod(shape)
- dim1 = math.prod([shape[i] for i in ld])
+ n = prod(shape)
+ dim1 = prod([shape[i] for i in ld])
dim2 = ct.c_int32(n // dim1)
dim1 = ct.c_int32(dim1)
else:
@@ -1314,7 +1320,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
m = shapeA[0] * shapeA[1]
rows = n = shapeB[0]
- assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
+ assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
# if the tensor is empty, return a transformed empty tensor with the right dimensions
if shapeA[0] == 0 and dimsA == 2: