author    Tim Dettmers <tim.dettmers@gmail.com>    2022-08-23 16:00:26 -0700
committer Tim Dettmers <tim.dettmers@gmail.com>    2022-08-23 16:00:26 -0700
commit    ee5b947e63c2340405f25e4e83066f39292bc0ed
tree      70fa3b2041e1cedf257feeff4e5b071703c43c2a
parent    7e0fb655e1e040221054886fbee9d5682aa6e4e2
Fixed issue where Pascal was not displaying the proper error.
-rw-r--r--    bitsandbytes/functional.py    23
-rw-r--r--    csrc/ops.cu                    6
2 files changed, 7 insertions, 22 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 745e7e4..75d083b 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -185,14 +185,9 @@ def create_dynamic_map(signed=True, n=7):
 def get_special_format_str():
+    if not torch.cuda.is_available(): return 'col_turing'
     major, minor = torch.cuda.get_device_capability()
-    if major < 7:
-        print(
-            f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
-        )
-        assert major >= 7
-
-    if major == 7:
+    if major <= 7:
         return "col_turing"
     elif major == 8:
         return "col_ampere"
@@ -1685,20 +1680,6 @@ def double_quant(
     return out_row, out_col, row_stats, col_stats, coo_tensor
-def get_special_format_str():
-    if not torch.cuda.is_available(): return 'col_turning'
-    major, minor = torch.cuda.get_device_capability()
-    if major < 7:
-        print(f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!")
-        assert major >= 7
-
-    if major == 7: return 'col_turing'
-    elif major == 8: return 'col_ampere'
-    else: return 'col_turing'
-
-
-
-
 def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
     prev_device = pre_call(A.device)
     if state is None: state = (A.shape, from_order)
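
Taken together, the two functional.py hunks delete the duplicate definition (the copy carrying the 'col_turning' typo) and keep a single version that no longer asserts on pre-Turing devices. For reference, here is a reconstruction of the post-patch function; the final else branch falls outside the hunk context and is assumed here to mirror the removed duplicate:

import torch

def get_special_format_str():
    # CPU-only builds: default to the Turing layout.
    if not torch.cuda.is_available(): return 'col_turing'
    major, minor = torch.cuda.get_device_capability()
    if major <= 7:    # Pascal/Volta/Turing
        return "col_turing"
    elif major == 8:  # Ampere
        return "col_ampere"
    else:             # assumed fallback, mirroring the removed duplicate
        return "col_turing"
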
diff --git a/csrc/ops.cu b/csrc/ops.cu
index c0ec3cb..e49c94b 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -371,7 +371,11 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl
 template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
 {
 #ifdef NO_CUBLASLT
-    printf("ERROR: Your GPU does not support Int8 Matmul!");
+    cout << "" << endl;
+    cout << "=============================================" << endl;
+    cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
+    cout << "=============================================" << endl;
+    cout << "" << endl;
     assert(false);
     return 0;
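
Note that the new block uses unqualified cout/endl, which compiles only if <iostream> is included and the std names are visible in ops.cu. On the Python side, a caller can avoid tripping this assert on NO_CUBLASLT builds with a cheap capability check; a minimal sketch (supports_int8_matmul is a hypothetical helper, not part of bitsandbytes), using the same capability-7 threshold as the error message removed from functional.py:

import torch

def supports_int8_matmul():
    # Hypothetical guard: devices below compute capability 7 have no
    # tensor cores and would hit the NO_CUBLASLT error path in igemmlt.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 7
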