summaryrefslogtreecommitdiff
path: root/csrc
diff options
context:
space:
mode:
Diffstat (limited to 'csrc')
-rw-r--r--csrc/kernels.cu12
1 files changed, 6 insertions, 6 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index d1fa253..d8dfee1 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -715,7 +715,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
switch(OPTIMIZER)
{
case ADAM:
- if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
@@ -868,7 +868,7 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
- if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
switch(OPTIMIZER)
{
@@ -1475,7 +1475,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
- if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
@@ -1518,7 +1518,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
- if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
if(weight_decay > 0.0f)
@@ -1635,7 +1635,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
- if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
if(weight_decay > 0.0f)
g_val += ((float)p_vals[j])*weight_decay;
@@ -1677,7 +1677,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
- if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
switch(OPTIMIZER)
{