summaryrefslogtreecommitdiff
path: root/csrc
diff options
context:
space:
mode:
authorTim Dettmers <TimDettmers@users.noreply.github.com>2022-07-18 09:51:37 -0700
committerGitHub <noreply@github.com>2022-07-18 09:51:37 -0700
commit4cd7ea62b2f51c68aacde2f62e7141765e476111 (patch)
tree548b2e77c62acd152330e898a6e17ea949a156d1 /csrc
parent3418cd390e952a7752fb6b2544c25e25af7c0371 (diff)
parentfd750cd2370b3b12e216a9148b23aaae63a80989 (diff)
Merge pull request #3 from TimDettmers/cpuonly
Add a CPU-only build option
Diffstat (limited to 'csrc')
-rw-r--r--csrc/common.cpp39
-rw-r--r--csrc/common.h23
-rw-r--r--csrc/cpu_ops.cpp57
-rw-r--r--csrc/cpu_ops.h9
-rw-r--r--csrc/ops.cu122
-rw-r--r--csrc/ops.cuh10
-rw-r--r--csrc/pythonInterface.c11
7 files changed, 142 insertions, 129 deletions
diff --git a/csrc/common.cpp b/csrc/common.cpp
new file mode 100644
index 0000000..972602b
--- /dev/null
+++ b/csrc/common.cpp
@@ -0,0 +1,39 @@
+#include <common.h>
+#include <float.h>
+
+void *quantize_block(void *arguments) {
+ // 1. find absmax in block
+ // 2. divide input value by absmax to normalize into [-1.0, 1.0]
+ // 3. do binary search to find the closest value
+ // 4. check minimal distance
+ // 5. store index
+
+ struct quantize_block_args *args = (quantize_block_args *) arguments;
+
+ // 1. find absmax in block
+ float absmax_block = -FLT_MAX;
+ for (int i = args->block_idx; i < args->block_end; i++)
+ absmax_block = fmax(absmax_block, fabs(args->A[i]));
+
+ args->absmax[args->block_idx / BLOCK_SIZE] = absmax_block;
+
+ for (int i = args->block_idx; i < args->block_end; i++) {
+ // 2. divide input value by absmax to normalize into [-1.0, 1.0]
+ // 3. do binary search to find the closest value
+ float normed_value = args->A[i] / absmax_block;
+ int idx = args->bin_searcher->scalar(normed_value);
+
+ // 4. check minimal distance
+ // The binary search returns always the value to the left, which might not be the closest value
+ if (idx < 255) {
+ float dist_left = fabs(normed_value - (args->code[idx]));
+ float dist_right = fabs(normed_value - (args->code[idx + 1]));
+ if (dist_right < dist_left) { idx += 1; }
+ }
+
+ // 5. store index
+ args->out[i] = (unsigned char) idx;
+ }
+
+ return NULL;
+}
diff --git a/csrc/common.h b/csrc/common.h
new file mode 100644
index 0000000..2f25a58
--- /dev/null
+++ b/csrc/common.h
@@ -0,0 +1,23 @@
+#include <BinSearch.h>
+
+#ifndef common
+#define common
+
+using namespace BinSearch;
+
+struct quantize_block_args {
+ BinAlgo<Scalar, float, Direct2> *bin_searcher;
+ float *code;
+ float *A;
+ float *absmax;
+ unsigned char *out;
+ int block_end;
+ int block_idx;
+ int threadidx;
+};
+
+#define BLOCK_SIZE 4096
+
+void *quantize_block(void *arguments);
+
+#endif
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp
new file mode 100644
index 0000000..89de52d
--- /dev/null
+++ b/csrc/cpu_ops.cpp
@@ -0,0 +1,57 @@
+#include <BinSearch.h>
+#include <pthread.h>
+#include <common.h>
+
+using namespace BinSearch;
+
+void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) {
+ for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
+ int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
+ int block_end = block_idx + valid_items;
+ for (int i = block_idx; i < block_end; i++)
+ out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE];
+ }
+}
+
+void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) {
+
+ // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
+ code[0] = -1.0f;
+
+ int num_blocks = n / BLOCK_SIZE;
+ num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
+
+ pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks);
+ struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *));
+
+ for (int i = 0; i < num_blocks; i++)
+ args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
+
+ const uint32 elements_code = 256;
+ BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
+
+ for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
+ int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
+ int block_end = block_idx + valid_items;
+
+ struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE];
+ arg->bin_searcher = &bin_searcher;
+ arg->code = code;
+ arg->A = A;
+ arg->absmax = absmax;
+ arg->out = out;
+ arg->block_end = block_end;
+ arg->block_idx = block_idx;
+ arg->threadidx = block_idx / BLOCK_SIZE;
+
+ pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg);
+ }
+
+ for (int i = 0; i < num_blocks; i++)
+ int err = pthread_join(threads[i], NULL);
+
+ free(threads);
+ for (int i = 0; i < num_blocks; i++)
+ free(args[i]);
+ free(args);
+}
diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h
new file mode 100644
index 0000000..57145a9
--- /dev/null
+++ b/csrc/cpu_ops.h
@@ -0,0 +1,9 @@
+#ifndef BITSANDBYTES_CPU_OPS_H
+#define BITSANDBYTES_CPU_OPS_H
+
+
+void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
+
+void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
+
+#endif
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 9691241..40c185c 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -8,123 +8,13 @@
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h>
+#include <common.h>
using namespace BinSearch;
using std::cout;
using std::endl;
-#define BLOCK_SIZE 4096
-
-struct quantize_block_args
-{
- BinAlgo<Scalar, float, Direct2> *bin_searcher;
- float *code;
- float *A;
- float *absmax;
- unsigned char *out;
- int block_end;
- int block_idx;
- int threadidx;
-};
-
-void *quantize_block(void *arguments)
-{
- // 1. find absmax in block
- // 2. divide input value by absmax to normalize into [-1.0, 1.0]
- // 3. do binary search to find the closest value
- // 4. check minimal distance
- // 5. store index
-
- struct quantize_block_args *args = (quantize_block_args*)arguments;
-
- // 1. find absmax in block
- float absmax_block = -FLT_MAX;
- for (int i = args->block_idx; i < args->block_end; i++)
- absmax_block = fmax(absmax_block, fabs(args->A[i]));
-
- args->absmax[args->block_idx/BLOCK_SIZE] = absmax_block;
-
- for (int i = args->block_idx; i < args->block_end; i++)
- {
- // 2. divide input value by absmax to normalize into [-1.0, 1.0]
- // 3. do binary search to find the closest value
- float normed_value = args->A[i]/absmax_block;
- int idx = args->bin_searcher->scalar(normed_value);
-
- // 4. check minimal distance
- // The binary search returns always the value to the left, which might not be the closest value
- if(idx < 255)
- {
- float dist_left = fabs(normed_value-(args->code[idx]));
- float dist_right = fabs(normed_value-(args->code[idx+1]));
- if(dist_right < dist_left){ idx+=1; }
- }
-
- // 5. store index
- args->out[i] = (unsigned char)idx;
- }
-
- return NULL;
-}
-
-void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n)
-{
-
- // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
- code[0] = -1.0f;
-
- int num_blocks = n/BLOCK_SIZE;
- num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
-
- pthread_t *threads = (pthread_t*)malloc(sizeof(pthread_t)*num_blocks);
- struct quantize_block_args **args = (quantize_block_args**)malloc(num_blocks*sizeof(quantize_block_args*));
-
- for(int i = 0; i < num_blocks; i++)
- args[i] = (quantize_block_args*)malloc(sizeof(quantize_block_args));
-
- const uint32 elements_code = 256;
- BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
-
- for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
- {
- int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
- int block_end = block_idx + valid_items;
-
- struct quantize_block_args *arg = args[block_idx/BLOCK_SIZE];
- arg->bin_searcher = &bin_searcher;
- arg->code = code;
- arg->A = A;
- arg->absmax = absmax;
- arg->out = out;
- arg->block_end = block_end;
- arg->block_idx = block_idx;
- arg->threadidx = block_idx/BLOCK_SIZE;
-
- pthread_create(&threads[block_idx/BLOCK_SIZE], NULL, &quantize_block, (void *)arg);
- }
-
- for(int i = 0; i < num_blocks; i++)
- int err = pthread_join(threads[i], NULL);
-
- free(threads);
- for(int i = 0; i < num_blocks; i++)
- free(args[i]);
- free(args);
-}
-
-
-void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n)
-{
- for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
- {
- int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
- int block_end = block_idx + valid_items;
- for (int i = block_idx; i < block_end; i++)
- out[i] = code[A[i]]*absmax[block_idx/BLOCK_SIZE];
- }
-}
-
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
{
int threads = 512;
@@ -178,7 +68,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
-template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
+template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
@@ -189,7 +79,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
{
case ADAM:
if(max_unorm > 0.0f)
- {
+ {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
@@ -202,7 +92,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
case ADAGRAD:
if(max_unorm > 0.0f)
- {
+ {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
@@ -218,7 +108,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
unsigned char* state1, unsigned char* state2,
float *unorm, float max_unorm, float param_norm,
float beta1, float beta2,
- float eps, int step, float lr,
+ float eps, int step, float lr,
float* quantiles1, float* quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay,
@@ -261,7 +151,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
#define NUM_1STATE 8
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
- unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
+ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
{
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 1bc13fb..8fb4cec 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -68,16 +68,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
-void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
-void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
-
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
#endif
-
-
-
-
-
-
-
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index e0b0d59..c2fed6b 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -3,7 +3,10 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
+#if BUILD_CUDA
#include <ops.cuh>
+#endif
+#include <cpu_ops.h>
// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to
@@ -12,6 +15,7 @@
// UNMANGLED CALLS
//===================================================================================
+#if BUILD_CUDA
void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles<float>(A, code, offset, n); }
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
@@ -78,9 +82,11 @@ void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, un
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
+#endif
extern "C"
{
+ #if BUILD_CUDA
void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
@@ -147,11 +153,10 @@ extern "C"
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
+ void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
+ #endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
-
- void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
}
-