From 451fd9506e215aa25643e9782cb7d8aed2a266cc Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 3 Aug 2022 11:54:01 -0700 Subject: Added fixes for the case that matmullt dim A is zero, e.g. [0, 768]. --- bitsandbytes/functional.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) (limited to 'bitsandbytes/functional.py') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 0a2d557..494de1b 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -4,9 +4,10 @@ # LICENSE file in the root directory of this source tree. import ctypes as ct import random -from typing import Tuple - +import math import torch + +from typing import Tuple from torch import Tensor from .cextension import lib, COMPILED_WITH_CUDA @@ -919,15 +920,22 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) + assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0]*shapeA[1] - if dimsB == 2: - rows = n = shapeB[0] - elif dimsB == 3: - rows = n = shapeB[0]*shapeB[1] + rows = n = shapeB[0] + assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + print(shapeA, math.prod(shapeA), math.prod(list(shapeA))) + print('aaa') + + # if the tensor is empty, return a transformed empty tensor with the right dimensions + if shapeA[0] == 0 and dimsA == 2: + return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) + elif shapeA[1] == 0 and dimsA == 3: + return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) if dimsA == 2 and out is None: out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row') @@ -984,6 +992,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 1: + print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') raise Exception('cublasLt ran into an error!') torch.cuda.set_device(prev_device) -- cgit v1.2.3