summaryrefslogtreecommitdiff
path: root/bitsandbytes/functional.py
diff options
context:
space:
mode:
author Tim Dettmers <tim.dettmers@gmail.com> 2022-08-04 07:40:48 -0700
committer Tim Dettmers <tim.dettmers@gmail.com> 2022-08-04 07:40:48 -0700
commit cc5b323876392658b1d91655f30840d24be6d821 (patch)
tree 8e23e961709a3cc082a707ebc8ea0f52baee6923 /bitsandbytes/functional.py
parent 6101a8fb9f76c2cc4018452b4420dd52e946d52b (diff)
parent bd515328d70f344f935075f359c5aefc616878d5 (diff)
Merge branch 'extract_outliers' into debug
Diffstat (limited to 'bitsandbytes/functional.py')
-rw-r--r-- bitsandbytes/functional.py 26
1 files changed, 26 insertions, 0 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index e7261bc..08c108c 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1435,3 +1435,29 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
x *= SA[1]/127
x +=offset
return x.to(dtype)
+
+def extract_outliers(A, SA, idx):
+    """Run the native outlier-extraction routine on ``A`` and return its output.
+
+    Args:
+        A: tensor on a CUDA device, stored in the layout named by ``SA[1]``.
+        SA: indexable state where ``SA[0]`` is the (rows, cols) shape and
+            ``SA[1]`` is the layout name; only 'col_turing' and 'col_ampere'
+            are accepted.
+        idx: tensor whose ``numel()`` sets the width of the result
+            (presumably the column indices to pull out of ``A`` -- confirm
+            against the native kernel).
+
+    Returns:
+        An int8 tensor of shape ``(SA[0][0], idx.numel())`` on ``A.device``.
+        It is zero-initialised here and handed to the native routine as its
+        output buffer.
+
+    Raises:
+        AssertionError: if the layout is unsupported or ``A`` is not on CUDA.
+    """
+    shape, layout = SA[0], SA[1]
+    assert layout in ['col_turing', 'col_ampere']
+    assert A.device.type == 'cuda'
+
+    n_idx = idx.numel()
+    out = torch.zeros((shape[0], n_idx), dtype=torch.int8, device=A.device)
+
+    # Both native entry points take the same arguments: three raw pointers
+    # followed by three 32-bit sizes (index count, rows, cols).
+    sizes = (ct.c_int32(n_idx), ct.c_int32(shape[0]), ct.c_int32(shape[1]))
+    pointers = (get_ptr(A), get_ptr(idx), get_ptr(out))
+
+    if layout == 'col_turing':
+        lib.cextractOutliers_turing(*pointers, *sizes)
+    elif layout == 'col_ampere':
+        lib.cextractOutliers_ampere(*pointers, *sizes)
+
+    return out
+
+
+
+