summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd/_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/autograd/_functions.py')
-rw-r--r--bitsandbytes/autograd/_functions.py55
1 files changed, 20 insertions, 35 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 01e7073..4dbf129 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -201,13 +201,14 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function):
@staticmethod
- def forward(ctx, A, B, out=None, state=MatmulLtState()):
+ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
# default to pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
+ ctx.bias = bias
if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
else:
@@ -220,6 +221,7 @@ class MatMul8bitLt(torch.autograd.Function):
# 5. Save state
requires_gradA = A.requires_grad
requires_gradB = B.requires_grad
+ requires_gradBias = bias is not None and bias.requires_grad
formatB = state.formatB
input_shape = A.shape
if state.outlier_pool is None:
@@ -247,28 +249,7 @@ class MatMul8bitLt(torch.autograd.Function):
if state.CxB is None:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
- state.CxB, state.SB = F.transform(
- state.CB, to_order=formatB
- )
- # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
- # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
- # # generate outlier index and subB
- # outlier_idx = torch.unique(coo_tensorA.colidx).long()
- # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
- # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
- # # do not use pool for 2nd FFN layer
- # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
- # else:
- # state.idx = outlier_idx
- # state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
-
- # if state.idx is not None:
- # # extract outliers
- # CA[:, state.idx] = 0
- # CAt[:, state.idx] = 0
- # subA = A[:, state.idx]
- # else:
- # subA = None
+ state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@@ -326,7 +307,8 @@ class MatMul8bitLt(torch.autograd.Function):
# 3. Matmul
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
- output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
+ # we apply the fused bias here
+ output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
@@ -337,7 +319,7 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.formatB = formatB
ctx.grad_shape = input_shape
- ctx.req_grads = [requires_gradA, requires_gradB]
+ ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
if requires_gradA or requires_gradB:
ctx.tensors = (CAt, subA)
@@ -347,15 +329,16 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
- # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
- clone_func = torch.clone
+ clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+ #clone_func = torch.clone
return clone_func(output.view(output_shape))
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
- return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
- req_gradA, req_gradB = ctx.req_grads
+ bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
+ return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+ req_gradA, req_gradB, req_gradBias = ctx.req_grads
CAt, subA = ctx.tensors
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
@@ -369,7 +352,7 @@ class MatMul8bitLt(torch.autograd.Function):
-1, grad_output.shape[-1]
).contiguous()
- grad_A = grad_B = None
+ grad_A = grad_B = grad_bias = None
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
if req_gradB:
@@ -387,11 +370,12 @@ class MatMul8bitLt(torch.autograd.Function):
state.CBt, to_order=formatB, transpose=True
)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
- grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
- ctx.grad_shape
- )
+ grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+
+ if req_gradBias:
+ grad_bias = grad_output.sum(0)
- return grad_A, grad_B, None, None
+ return grad_A, grad_B, None, grad_bias, None
matmul = MatMul8bitLt.apply
@@ -403,8 +387,9 @@ def matmul(
out: tensor = None,
state: MatmulLtState = None,
threshold=0.0,
+ bias=None
):
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
- return MatMul8bitLt.apply(A, B, out, state)
+ return MatMul8bitLt.apply(A, B, out, bias, state)