From 6101a8fb9f76c2cc4018452b4420dd52e946d52b Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Thu, 4 Aug 2022 07:28:12 -0700 Subject: Added pre and post device call to transform. --- bitsandbytes/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 334bdd9..e7261bc 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1214,6 +1214,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ptrA = get_ptr(A) ptrOut = get_ptr(out) is_on_gpu([A, out]) + prev_device = pre_call(A.device) if to_order == 'col32': if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) @@ -1236,8 +1237,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') - - + post_call(prev_device) return out, new_state -- cgit v1.2.3