summaryrefslogtreecommitdiff
path: root/csrc
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-08-16 11:12:09 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-08-16 11:12:09 -0700
commitdede343033991c32735f01a94019e13fb4968b3c (patch)
tree6e9cc9263a0a52a6991904083f642d053f8b6e79 /csrc
parent111b8764492fd1f9921caae64ce7d7d3ac7ef183 (diff)
Added fused bias in dequant_mm.
Diffstat (limited to 'csrc')
-rw-r--r--csrc/kernels.cu3
1 files changed, 2 insertions, 1 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 0a1bf79..f01b4e1 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1951,6 +1951,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
float colStat = col >= numCols ? 0.0f : colStats[col];
+ float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]);
// no block loads for rows for now -- keep it simple
for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
{
@@ -1989,7 +1990,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
- local_output[j] = __float2half(local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat);
+ local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue);
//absmax_col = fmax(fabsf(local_output[j]), absmax_col);
// we store data in row major