summary | refs | log | tree | commit | diff
path: root/csrc/ops.cu
diff options
context:
space:
mode:
author: Tim Dettmers <tim.dettmers@gmail.com> 2022-11-20 14:18:15 -0800
committer: Tim Dettmers <tim.dettmers@gmail.com> 2022-11-20 14:18:15 -0800
commit: c059bd284832d09bc51cf82c377642b26a48ef28 (patch)
tree: de2790613f5b14d5b6ad9a615cece8045b781b01 /csrc/ops.cu
parent: eb028e6ebcddc78c7921c2524d361b23b1a1007b (diff)
Added additional blocksizes: {64, 128, 256}.
Diffstat (limited to 'csrc/ops.cu')
-rw-r--r-- csrc/ops.cu 12
1 file changed, 12 insertions, 0 deletions
diff --git a/csrc/ops.cu b/csrc/ops.cu
index b121fc2..30079e6 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -65,6 +65,12 @@ template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A,
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 512)
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
+ else if(blocksize == 256)
+ kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
+ else if(blocksize == 128)
+ kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
+ else if(blocksize == 64)
+ kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
@@ -82,6 +88,12 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
else if(blocksize == 512)
kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
+ else if(blocksize == 256)
+ kDequantizeBlockwise<T, 256, 128, 2><<<num_blocks, 256/2>>>(code, A, absmax, out, n);
+ else if(blocksize == 128)
+ kDequantizeBlockwise<T, 128, 64, 2><<<num_blocks, 128/2>>>(code, A, absmax, out, n);
+ else if(blocksize == 64)
+ kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}