diff options
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 58 | ||||
-rw-r--r-- | bitsandbytes/functional.py | 26 |
2 files changed, 64 insertions, 20 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 370ca83..607d868 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -203,30 +203,30 @@ class MatMul8bitLt(torch.autograd.Function): # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) - if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: - # generate outlier index and subB - outlier_idx = torch.unique(coo_tensorA.colidx).long() - state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # do not use pool for 2nd FFN layer - state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - else: - state.idx = outlier_idx - state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half() - - if state.idx is not None: - # extract outliers - CA[:, state.idx] = 0 - CAt[:, state.idx] = 0 - subA = A[:, state.idx] - else: - subA = None + #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() + #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: + # # generate outlier index and subB + # outlier_idx = torch.unique(coo_tensorA.colidx).long() + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + # else: + # state.idx = outlier_idx + # state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half() + + #if state.idx is not None: + # # extract outliers + # CA[:, state.idx] = 0 + # CAt[:, state.idx] = 0 + # subA = A[:, state.idx] + #else: + # subA = None else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None - C32A, SA = F.transform(CA, 'col32') # 2. Quantize B if state.has_fp16_weights: @@ -241,6 +241,23 @@ class MatMul8bitLt(torch.autograd.Function): else: has_grad = False + if coo_tensorA is not None and not state.has_fp16_weights: + # extract outliers + + outlier_idx = torch.unique(coo_tensorA.colidx) + state.idx = outlier_idx + #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + #else: + # state.idx = outlier_idx + outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half() + CA[:, state.idx.long()] = 0 + CAt[:, state.idx.long()] = 0 + subA = A[:, state.idx.long()] + shapeB = state.SB[0] if len(input_shape) == 3: @@ -249,11 +266,12 @@ class MatMul8bitLt(torch.autograd.Function): output_shape = (input_shape[0], shapeB[0]) # 3. Matmul + C32A, SA = F.transform(CA, 'col32') out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) output = F.mm_dequant(out32, Sout32, SCA, state.SCB) # 4. Mixed-precision decomposition matmul - if state.threshold > 0.0 and coo_tensorA is not None and subA is not None: + if coo_tensorA is not None and subA is not None: output += torch.matmul(subA, state.subB) # 5. Save state diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index e7261bc..08c108c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1435,3 +1435,29 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): x *= SA[1]/127 x +=offset return x.to(dtype) + +def extract_outliers(A, SA, idx): + shapeA = SA[0] + formatA = SA[1] + assert formatA in ['col_turing', 'col_ampere'] + assert A.device.type == 'cuda' + + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + + idx_size = ct.c_int32(idx.numel()) + rows = ct.c_int32(shapeA[0]) + cols = ct.c_int32(shapeA[1]) + ptrA = get_ptr(A) + ptrIdx = get_ptr(idx) + ptrOut = get_ptr(out) + + if formatA == 'col_turing': + lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + elif formatA == 'col_ampere': + lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + + return out + + + + |