summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bitsandbytes/autograd/_functions.py22
1 files changed, 12 insertions, 10 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 5503749..e641583 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -191,6 +191,7 @@ class MatMul8bitLt(torch.autograd.Function):
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+ #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
#if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
# outlier_idx = torch.unique(coo_tensorA.colidx).long()
@@ -214,7 +215,6 @@ class MatMul8bitLt(torch.autograd.Function):
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
- C32A, SA = F.transform(CA, 'col32')
# 2. Quantize B
if state.has_fp16_weights:
@@ -233,14 +233,15 @@ class MatMul8bitLt(torch.autograd.Function):
# extract outliers
outlier_idx = torch.unique(coo_tensorA.colidx)
- state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
- if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
- # do not use pool for 2nd FFN layer
- state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
- else:
- state.idx = outlier_idx
- outliers = F.extract_outliers(state.CxB, state.SB, outlier_idx).half()
- state.subB = (outliers*state.SCB.view(-1, 1).half()/127.0).t().contiguous()
+ state.idx = outlier_idx
+ #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+ #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+ # # do not use pool for 2nd FFN layer
+ # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
+ #else:
+ # state.idx = outlier_idx
+ outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
+ state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half()
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
@@ -253,11 +254,12 @@ class MatMul8bitLt(torch.autograd.Function):
output_shape = (input_shape[0], shapeB[0])
# 3. Matmul
+ C32A, SA = F.transform(CA, 'col32')
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
# 4. Mixed-precision decomposition matmul
- if state.threshold > 0.0 and coo_tensorA is not None and subA is not None:
+ if coo_tensorA is not None and subA is not None:
output += torch.matmul(subA, state.subB)
# 5. Save state