summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile48
-rw-r--r--bitsandbytes/functional.py4
-rw-r--r--bitsandbytes/optim/optimizer.py8
-rw-r--r--csrc/kernels.cu122
-rw-r--r--csrc/pythonInterface.c2
5 files changed, 102 insertions, 82 deletions
diff --git a/Makefile b/Makefile
index 1a5f17f..4fbe918 100644
--- a/Makefile
+++ b/Makefile
@@ -15,29 +15,31 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
# NVIDIA NVCC compilation flags
-COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-
-# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
-CC_CUDA92 := -gencode arch=compute_30,code=sm_30
-
-# Later versions of CUDA support the new architectures
-CC_CUDA10x := -gencode arch=compute_30,code=sm_30
-CC_CUDA10x += -gencode arch=compute_75,code=sm_75
-
-CC_CUDA110 := -gencode arch=compute_75,code=sm_75
-CC_CUDA110 += -gencode arch=compute_80,code=sm_80
-
-CC_CUDA11x := -gencode arch=compute_75,code=sm_75
-CC_CUDA11x += -gencode arch=compute_80,code=sm_80
-CC_CUDA11x += -gencode arch=compute_86,code=sm_86
+#COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#
+## CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
+#CC_CUDA92 := -gencode arch=compute_30,code=sm_30
+#
+## Later versions of CUDA support the new architectures
+#CC_CUDA10x := -gencode arch=compute_30,code=sm_30
+#CC_CUDA10x += -gencode arch=compute_75,code=sm_75
+#
+#CC_CUDA110 := -gencode arch=compute_75,code=sm_75
+#CC_CUDA110 += -gencode arch=compute_80,code=sm_80
+#
+#CC_CUDA11x := -gencode arch=compute_75,code=sm_75
+#CC_CUDA11x += -gencode arch=compute_80,code=sm_80
+#CC_CUDA11x += -gencode arch=compute_86,code=sm_86
+
+COMPUTE_CAPABILITY := -gencode arch=compute_70,code=sm_70 # Volta
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 48ab40c..9fe1345 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -486,13 +486,13 @@ def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, s
str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+ get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+ get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
else:
raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 25512b1..4b70b5c 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -336,7 +336,7 @@ class Optimizer2State(Optimizer8bit):
if state['state1'].dtype == torch.float:
F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+ state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])
elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
@@ -352,7 +352,7 @@ class Optimizer2State(Optimizer8bit):
F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
config['eps'], step, config['lr'],
state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
- config['weight_decay'], gnorm_scale=gnorm_scale)
+ config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
class Optimizer1State(Optimizer8bit):
@@ -450,7 +450,7 @@ class Optimizer1State(Optimizer8bit):
F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
None, 0.0, config['weight_decay'], gnorm_scale,
state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
- skip_zeros=False)
+ skip_zeros=config['skip_zeros'])
elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
@@ -463,4 +463,4 @@ class Optimizer1State(Optimizer8bit):
F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
config['eps'], step, config['lr'],
state['qmap1'], None, state['absmax1'], None,
- config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False)
+ config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index f8f7b62..d1fa253 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -715,9 +715,12 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
switch(OPTIMIZER)
{
case ADAM:
- s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
- s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
- p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
+ s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
+ p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+ }
break;
}
}
@@ -865,21 +868,24 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
- switch(OPTIMIZER)
- {
- case MOMENTUM:
- if(step == 1)
- s1_vals[j] = (float)g_vals[j];
- else
- s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
-
- p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
- break;
- case RMSPROP:
- s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
- p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
- break;
- }
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ if(step == 1)
+ s1_vals[j] = (float)g_vals[j];
+ else
+ s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
+
+ p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
+ break;
+ case RMSPROP:
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
+ p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
+ break;
+ }
+ }
}
__syncthreads();
@@ -1469,11 +1475,14 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
- s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
- s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
+ s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
- s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
- s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
+ s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
+ s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
+ }
new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j]));
@@ -1509,9 +1518,12 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
- g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
- if(weight_decay > 0.0f)
- g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay));
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
+ if(weight_decay > 0.0f)
+ g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay));
+ }
}
// store: 0.85/1.44 -> 2.48/1.57
@@ -1623,23 +1635,26 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
- if(weight_decay > 0.0f)
- g_val += ((float)p_vals[j])*weight_decay;
-
- s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
-
- switch(OPTIMIZER)
- {
- case MOMENTUM:
- if(step == 1)
- s1_vals[j] = g_val;
- else
- s1_vals[j] = (s1_vals[j]*beta1) + g_val;
- break;
- case RMSPROP:
- s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
- break;
- }
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ if(weight_decay > 0.0f)
+ g_val += ((float)p_vals[j])*weight_decay;
+
+ s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
+
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ if(step == 1)
+ s1_vals[j] = g_val;
+ else
+ s1_vals[j] = (s1_vals[j]*beta1) + g_val;
+ break;
+ case RMSPROP:
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
+ break;
+ }
+ }
new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
}
@@ -1662,16 +1677,19 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
- switch(OPTIMIZER)
- {
- case MOMENTUM:
- p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
- break;
- case RMSPROP:
- g_val = g_vals[j];
- p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
- break;
- }
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
+ break;
+ case RMSPROP:
+ g_val = g_vals[j];
+ p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
+ break;
+ }
+ }
}
// store: 0.85/1.44 -> 2.48/1.57
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 67bf2e5..7d5e654 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -110,7 +110,7 @@ extern "C"
float eps, int step, float lr, \
float* quantiles1, float* quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
- float weight_decay, float gnorm_scale, bool skip_zeros, int n) \
+ float weight_decay, float gnorm_scale, int n) \
{ \
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \