summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-07-26 12:12:38 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-07-26 12:12:38 -0700
commitcbb901ac51bd6c41e4243ffb936ef0e2f7ca8ada (patch)
treef02615b5588aa6ed94a51c1e66297595b802c0a1 /bitsandbytes
parentc771b3a75a6ebbfbfc398a028a477246b0799cf0 (diff)
Boilerplate and test for extract_outliers.
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/functional.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 806c254..a9233e2 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1409,3 +1409,29 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
x *= SA[1]/127
x +=offset
return x.to(dtype)
+
+def extract_outliers(A, SA, idx):
+ shapeA = SA[0]
+ formatA = SA[1]
+ assert formatA in ['col_turing', 'col_ampere']
+ assert A.device.type == 'cuda'
+
+ out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
+
+ idx_size = ct.c_int32(idx.numel())
+ rows = ct.c_int32(shapeA[0])
+ cols = ct.c_int32(shapeA[1])
+ ptrA = get_ptr(A)
+ ptrIdx = get_ptr(idx)
+ ptrOut = get_ptr(out)
+
+ if formatA == 'col_turing':
+ lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
+ elif formatA == 'col_ampere':
+ lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
+
+ return out
+
+
+
+