diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-04 07:47:22 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-04 07:47:22 -0700 |
commit | ab72a1294fda03a0fd4ec297562fdab806349752 (patch) | |
tree | e40697e0db3c6969c0ce126f42802ea2876e444d /bitsandbytes | |
parent | cc5b323876392658b1d91655f30840d24be6d821 (diff) |
Added pre/post device call for extract outliers.
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/functional.py | 6 |
1 files changed, 4 insertions, 2 deletions
```diff
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 08c108c..ad85f53 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1198,6 +1198,7 @@ def get_special_format_str():
 
 
 def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+    prev_device = pre_call(A.device)
     if state is None: state = (A.shape, from_order)
     else: from_order = state[1]
     if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
@@ -1214,7 +1215,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
     ptrA = get_ptr(A)
     ptrOut = get_ptr(out)
     is_on_gpu([A, out])
-    prev_device = pre_call(A.device)
     if to_order == 'col32':
         if transpose:
             lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
@@ -1237,8 +1237,8 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
             lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
     else:
         raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
-    post_call(prev_device)
 
+    post_call(prev_device)
 
     return out, new_state
 
@@ -1451,10 +1451,12 @@ def extract_outliers(A, SA, idx):
     ptrIdx = get_ptr(idx)
     ptrOut = get_ptr(out)
 
+    prev_device = pre_call(A.device)
     if formatA == 'col_turing':
         lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
     elif formatA == 'col_ampere':
         lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
+    post_call(prev_device)
 
     return out
 
```