summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/functional.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 334bdd9..e7261bc 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1214,6 +1214,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
is_on_gpu([A, out])
+ prev_device = pre_call(A.device)
if to_order == 'col32':
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
@@ -1236,8 +1237,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
-
-
+ post_call(prev_device)
return out, new_state