diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-26 12:12:38 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-26 12:12:38 -0700 |
commit | cbb901ac51bd6c41e4243ffb936ef0e2f7ca8ada (patch) | |
tree | f02615b5588aa6ed94a51c1e66297595b802c0a1 /bitsandbytes | |
parent | c771b3a75a6ebbfbfc398a028a477246b0799cf0 (diff) |
Boilerplate and test for extract_outliers.
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/functional.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 806c254..a9233e2 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1409,3 +1409,29 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): x *= SA[1]/127 x +=offset return x.to(dtype) + +def extract_outliers(A, SA, idx): + shapeA = SA[0] + formatA = SA[1] + assert formatA in ['col_turing', 'col_ampere'] + assert A.device.type == 'cuda' + + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + + idx_size = ct.c_int32(idx.numel()) + rows = ct.c_int32(shapeA[0]) + cols = ct.c_int32(shapeA[1]) + ptrA = get_ptr(A) + ptrIdx = get_ptr(idx) + ptrOut = get_ptr(out) + + if formatA == 'col_turing': + lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + elif formatA == 'col_ampere': + lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + + return out + + + + |