From dede343033991c32735f01a94019e13fb4968b3c Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 16 Aug 2022 11:12:09 -0700 Subject: Added fused bias in dequant_mm. --- csrc/kernels.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'csrc/kernels.cu') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 0a1bf79..f01b4e1 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1951,6 +1951,7 @@ template __global__ void kd // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. float colStat = col >= numCols ? 0.0f : colStats[col]; + float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]); // no block loads for rows for now -- keep it simple for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) { @@ -1989,7 +1990,7 @@ template __global__ void kd #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) - local_output[j] = __float2half(local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat); + local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue); //absmax_col = fmax(fabsf(local_output[j]), absmax_col); // we store data in row major -- cgit v1.2.3