summary | refs | log | tree | commit | diff
path: root/bitsandbytes/functional.py
diff options
context:
space:
mode:
author: Tim Dettmers <tim.dettmers@gmail.com> 2022-08-04 07:47:22 -0700
committer: Tim Dettmers <tim.dettmers@gmail.com> 2022-08-04 07:47:22 -0700
commit: ab72a1294fda03a0fd4ec297562fdab806349752 (patch)
tree: e40697e0db3c6969c0ce126f42802ea2876e444d /bitsandbytes/functional.py
parent: cc5b323876392658b1d91655f30840d24be6d821 (diff)
Added pre/post device call for extract outliers.
Diffstat (limited to 'bitsandbytes/functional.py')
-rw-r--r-- bitsandbytes/functional.py | 6
1 files changed, 4 insertions, 2 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 08c108c..ad85f53 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1198,6 +1198,7 @@ def get_special_format_str():
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+ prev_device = pre_call(A.device)
if state is None: state = (A.shape, from_order)
else: from_order = state[1]
if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
@@ -1214,7 +1215,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
is_on_gpu([A, out])
- prev_device = pre_call(A.device)
if to_order == 'col32':
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
@@ -1237,8 +1237,8 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
- post_call(prev_device)
+ post_call(prev_device)
return out, new_state
@@ -1451,10 +1451,12 @@ def extract_outliers(A, SA, idx):
ptrIdx = get_ptr(idx)
ptrOut = get_ptr(out)
+ prev_device = pre_call(A.device)
if formatA == 'col_turing':
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == 'col_ampere':
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
+ post_call(prev_device)
return out