From ab72a1294fda03a0fd4ec297562fdab806349752 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Thu, 4 Aug 2022 07:47:22 -0700 Subject: Added pre/post device call for extract outliers. --- bitsandbytes/functional.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'bitsandbytes/functional.py') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 08c108c..ad85f53 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1198,6 +1198,7 @@ def get_special_format_str(): def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + prev_device = pre_call(A.device) if state is None: state = (A.shape, from_order) else: from_order = state[1] if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) @@ -1214,7 +1215,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ptrA = get_ptr(A) ptrOut = get_ptr(out) is_on_gpu([A, out]) - prev_device = pre_call(A.device) if to_order == 'col32': if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) @@ -1237,8 +1237,8 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') - post_call(prev_device) + post_call(prev_device) return out, new_state @@ -1451,10 +1451,12 @@ def extract_outliers(A, SA, idx): ptrIdx = get_ptr(idx) ptrOut = get_ptr(out) + prev_device = pre_call(A.device) if formatA == 'col_turing': lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == 'col_ampere': lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + post_call(prev_device) return out -- cgit v1.2.3