diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-04 07:28:12 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-04 07:28:12 -0700 |
commit | 6101a8fb9f76c2cc4018452b4420dd52e946d52b (patch) | |
tree | 706df0bd4f6f9d156304c99294b20b41b3858b29 /bitsandbytes | |
parent | 320eacb4c23adeaaf4a54166f19eac950aa631f1 (diff) |
Added pre and post device call to transform.
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/functional.py | 4 |
1 file changed, 2 insertions, 2 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 334bdd9..e7261bc 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1214,6 +1214,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ptrA = get_ptr(A) ptrOut = get_ptr(out) is_on_gpu([A, out]) + prev_device = pre_call(A.device) if to_order == 'col32': if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) @@ -1236,8 +1237,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') - - + post_call(prev_device) return out, new_state |