From 8ae9bb23ad9c61a92ab1a0ac6be65cd787c4fe5b Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Tue, 23 Aug 2022 23:39:54 +0300 Subject: add memory efficient backward --- bitsandbytes/autograd/_functions.py | 39 ++++++++++++++++++------------------- bitsandbytes/nn/modules.py | 13 +++++++++---- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 4dbf129..63e8ad5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -245,11 +245,10 @@ class MatMul8bitLt(torch.autograd.Function): subA = A[:, idx] state.subB = B[:, idx].t().contiguous() state.idx = idx - else: - if state.CxB is None: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + elif state.CxB is None: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) @@ -280,12 +279,6 @@ class MatMul8bitLt(torch.autograd.Function): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( (outliers * state.SCB.view(-1, 1) / 127.0) @@ -343,12 +336,9 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert ( - state.has_fp16_weights - ), "Backprop only supported for fp16 weights." 
if len(grad_output.shape) == 3: - grad_output = grad_output.view( + grad_output = grad_output.reshape( -1, grad_output.shape[-1] ).contiguous() @@ -365,11 +355,20 @@ class MatMul8bitLt(torch.autograd.Function): if req_gradA: C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + if state.CxBt is None and state.has_fp16_weights: + CBt = state.CBt + elif state.CxBt is None: + assert state.CBt is None + CB = state.CB.half() + SCB = state.SCB.unsquezee(1).half() + SCBt = state.SCBt.unsquezee(1).half() + Bt = (CB * SCB).t().contiguous() + CBt = (Bt / SCBt).t().to(torch.int8) + + CxBt, SBt = F.transform( + CBt, to_order=formatB, transpose=True + ) + gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) if req_gradBias: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index b222f54..ef7fefc 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -148,10 +148,12 @@ class Int8Params(torch.nn.Parameter): has_fp16_weights=False, CB=None, SCB=None, + SCBt=None, ): cls.has_fp16_weights = has_fp16_weights cls.CB = None cls.SCB = None + cls.SCBt = None if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) @@ -165,10 +167,10 @@ class Int8Params(torch.nn.Parameter): B = self.data.contiguous().half().cuda(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) del CBt - del SCBt self.data = CB setattr(self, "CB", CB) setattr(self, "SCB", SCB) + setattr(self, "SCBt", SCBt) return self @@ -210,6 +212,7 @@ class Int8Params(torch.nn.Parameter): ) new_param.CB = self.CB new_param.SCB = self.SCB + new_param.SCB = self.SCBt return new_param @@ -240,8 +243,10 @@ class Linear8bitLt(nn.Linear): def init_8bit_state(self): self.state.CB = self.weight.CB self.state.SCB = self.weight.SCB + self.state.SCBt = self.weight.SCBt self.weight.CB = None self.weight.SCB = None + self.weight.SCBt = None def forward(self, x): self.state.is_training = self.training @@ -255,11 +260,11 @@ class Linear8bitLt(nn.Linear): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights and self.state.CB is not None: + # if not self.state.has_fp16_weights and self.state.CB is not None: # we converted 8-bit row major to turing/ampere format in the first inference pass # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB + # del self.state.CB + # self.weight.data = self.state.CxB return out -- cgit v1.2.3 From 1753aa04185b10a3bb52f7289ed4af15cf2502a7 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Tue, 23 Aug 2022 23:51:00 +0300 Subject: refactoring --- bitsandbytes/autograd/_functions.py | 40 +++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 63e8ad5..8ce1e60 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -245,10 +245,11 @@ class MatMul8bitLt(torch.autograd.Function): subA = A[:, idx] state.subB = B[:, idx].t().contiguous() state.idx = idx - elif state.CxB is None: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - 
state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + else: + if state.CxB is None: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) @@ -355,19 +356,24 @@ class MatMul8bitLt(torch.autograd.Function): if req_gradA: C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None and state.has_fp16_weights: - CBt = state.CBt - elif state.CxBt is None: - assert state.CBt is None - CB = state.CB.half() - SCB = state.SCB.unsquezee(1).half() - SCBt = state.SCBt.unsquezee(1).half() - Bt = (CB * SCB).t().contiguous() - CBt = (Bt / SCBt).t().to(torch.int8) - - CxBt, SBt = F.transform( - CBt, to_order=formatB, transpose=True - ) + if state.CxBt is None: + if state.has_fp16_weights: + CBt = state.CBt + else: + # Restore CBt from CB + assert state.CBt is None, "CBt should not be stored in state" + CB = state.CB.half() + SCB = state.SCB.unsquezee(1).half() + SCBt = state.SCBt.unsquezee(1).half() + Bt = (CB * SCB).t().contiguous() + CBt = (Bt / SCBt).t().to(torch.int8) + + # intentionally, do not store CxBt into state + CxBt, SBt = F.transform( + CBt, to_order=formatB, transpose=True + ) + else: + CxBt = state.CxBt gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) -- cgit v1.2.3 From 656de8ed110fce4e94b4f9d48494ecc5f8e04970 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Tue, 23 Aug 2022 23:53:43 +0300 Subject: minor fixes --- bitsandbytes/autograd/_functions.py | 2 +- bitsandbytes/nn/modules.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 8ce1e60..641a779 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -368,7 +368,7 @@ class MatMul8bitLt(torch.autograd.Function): Bt = (CB * SCB).t().contiguous() CBt = (Bt / SCBt).t().to(torch.int8) - # intentionally, do not store CxBt into state + # intentionally, do not store CxBt in state CxBt, SBt = F.transform( CBt, to_order=formatB, transpose=True ) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ef7fefc..360a182 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -212,7 +212,7 @@ class Int8Params(torch.nn.Parameter): ) new_param.CB = self.CB new_param.SCB = self.SCB - new_param.SCB = self.SCBt + new_param.SCBt = self.SCBt return new_param -- cgit v1.2.3 From 876387dc0c1c71ad9cd827d4aecc31190313c7ab Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Wed, 24 Aug 2022 01:12:48 +0300 Subject: minor fixes --- bitsandbytes/autograd/_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 641a779..7cf4999 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -363,8 +363,8 @@ class MatMul8bitLt(torch.autograd.Function): # Restore CBt from CB assert state.CBt is None, "CBt should not be stored in state" CB = state.CB.half() - SCB = state.SCB.unsquezee(1).half() - SCBt = state.SCBt.unsquezee(1).half() + SCB = state.SCB.unsqueeze(1).half() + SCBt = state.SCBt.unsqueeze(1).half() Bt = (CB * SCB).t().contiguous() CBt = (Bt / SCBt).t().to(torch.int8) -- 
cgit v1.2.3 From ef2936a90d903d0f9a27e16ecb7f839f2c4d9ba1 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Wed, 24 Aug 2022 01:33:04 +0300 Subject: delete CxB from state --- bitsandbytes/nn/modules.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 360a182..03ffd3b 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -260,11 +260,10 @@ class Linear8bitLt(nn.Linear): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - # if not self.state.has_fp16_weights and self.state.CB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - # del self.state.CB - # self.weight.data = self.state.CxB + if not self.state.has_fp16_weights and self.state.CxB is not None: + # In this version, we convert 8-bit row major to turing/ampere format at each inference pass + # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place. + del self.state.CxB return out -- cgit v1.2.3 From 4d6174bc6336fb6fba712f1d2c903de1de677747 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Thu, 25 Aug 2022 19:09:23 +0300 Subject: memory efficient fp16 backward --- bitsandbytes/autograd/_functions.py | 40 +++++-------------------------------- bitsandbytes/nn/modules.py | 7 +------ 2 files changed, 6 insertions(+), 41 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 7cf4999..52e56d0 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -196,7 +196,6 @@ class MatmulLtState: self.CxBt = None self.SBt = None - self.CBt = None class MatMul8bitLt(torch.autograd.Function): @@ -327,15 +326,12 @@ class MatMul8bitLt(torch.autograd.Function): #clone_func = torch.clone return clone_func(output.view(output_shape)) - @staticmethod def backward(ctx, grad_output): if ctx.is_empty: bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None req_gradA, req_gradB, req_gradBias = ctx.req_grads - CAt, subA = ctx.tensors - SCAt, idx = ctx.tensor_states - formatB = ctx.formatB + assert not req_gradB, "TODO: support weight updates as well" state = ctx.state if len(grad_output.shape) == 3: @@ -345,37 +341,11 @@ class MatMul8bitLt(torch.autograd.Function): grad_A = grad_B = grad_bias = None - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) - if req_gradB: - CxAt, SAt = F.transform(CAt, formatB, transpose=True) - C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) - gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) - grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) - if state.threshold > 0.0 and subA is not None: - grad_B[:, idx] += torch.matmul(grad_output.t(), subA) - if req_gradA: - C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - if state.has_fp16_weights: - CBt = state.CBt - else: - # Restore CBt from CB - assert state.CBt is None, "CBt should not be stored in state" - CB = state.CB.half() - SCB = state.SCB.unsqueeze(1).half() - SCBt = state.SCBt.unsqueeze(1).half() - Bt = (CB * SCB).t().contiguous() - CBt = (Bt / SCBt).t().to(torch.int8) - - # intentionally, do not store CxBt in state - CxBt, SBt = F.transform( - CBt, to_order=formatB, transpose=True - ) - else: - CxBt = state.CxBt - gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt) - grad_A = 
F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + CB = state.CB.half() + SCB = state.SCB.unsqueeze(1).half() + B = (CB * SCB) / 127.0 + grad_A = torch.mm(grad_output, B).view(ctx.grad_shape) if req_gradBias: grad_bias = grad_output.sum(0) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 03ffd3b..3e32c8e 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -148,12 +148,10 @@ class Int8Params(torch.nn.Parameter): has_fp16_weights=False, CB=None, SCB=None, - SCBt=None, ): cls.has_fp16_weights = has_fp16_weights cls.CB = None cls.SCB = None - cls.SCBt = None if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) @@ -167,10 +165,10 @@ class Int8Params(torch.nn.Parameter): B = self.data.contiguous().half().cuda(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) del CBt + del SCBt self.data = CB setattr(self, "CB", CB) setattr(self, "SCB", SCB) - setattr(self, "SCBt", SCBt) return self @@ -212,7 +210,6 @@ class Int8Params(torch.nn.Parameter): ) new_param.CB = self.CB new_param.SCB = self.SCB - new_param.SCBt = self.SCBt return new_param @@ -243,10 +240,8 @@ class Linear8bitLt(nn.Linear): def init_8bit_state(self): self.state.CB = self.weight.CB self.state.SCB = self.weight.SCB - self.state.SCBt = self.weight.SCBt self.weight.CB = None self.weight.SCB = None - self.weight.SCBt = None def forward(self, x): self.state.is_training = self.training -- cgit v1.2.3 From b3fee1ed6a357f70eb72c5c75c93349bb8c6fcdd Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Fri, 26 Aug 2022 04:11:40 +0300 Subject: add dtype <-> fp16 cast --- bitsandbytes/autograd/_functions.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 52e56d0..e266d69 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -213,6 +213,10 @@ class MatMul8bitLt(torch.autograd.Function): else: return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device) + # Cast A to fp16 + A_dtype = A.dtype + A = A.to(torch.float16) + # 1. Quantize A # 2. Quantize B # 3. 
Matmul @@ -322,14 +326,21 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) + # Cast fp16 output back to A.dtype + output = output.to(A_dtype) + clone_func = torch.clone if len(output_shape) == 3 else lambda x : x - #clone_func = torch.clone return clone_func(output.view(output_shape)) def backward(ctx, grad_output): if ctx.is_empty: bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + + # Cast grad_output to fp16 + grad_output_dtype = grad_output.dtype + grad_output.to(torch.float16) + req_gradA, req_gradB, req_gradBias = ctx.req_grads assert not req_gradB, "TODO: support weight updates as well" state = ctx.state @@ -350,6 +361,9 @@ class MatMul8bitLt(torch.autograd.Function): if req_gradBias: grad_bias = grad_output.sum(0) + # Cast grad_A back to grad_output_dtype + grad_output.to(grad_output_dtype) + return grad_A, grad_B, None, grad_bias, None -- cgit v1.2.3 From 8d34d36f150b0fd4914cdb56d4e3bda34c029ccc Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Mon, 29 Aug 2022 00:56:08 +0300 Subject: req_gradA for casted & more efficient and accurate fp16 backward --- bitsandbytes/autograd/_functions.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index e266d69..3bd39a9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -213,10 +213,6 @@ class MatMul8bitLt(torch.autograd.Function): else: return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device) - # Cast A to fp16 - A_dtype = A.dtype - A = A.to(torch.float16) - # 1. Quantize A # 2. Quantize B # 3. Matmul @@ -229,6 +225,11 @@ class MatMul8bitLt(torch.autograd.Function): input_shape = A.shape if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() + + # Cast A to fp16 + A_dtype = A.dtype + A = A.to(torch.float16) + assert ( A.dtype == torch.float16 ), f"The input data type needs to be fp16 but {A.dtype} was found!" 
@@ -337,14 +338,14 @@ class MatMul8bitLt(torch.autograd.Function): bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - # Cast grad_output to fp16 - grad_output_dtype = grad_output.dtype - grad_output.to(torch.float16) - req_gradA, req_gradB, req_gradBias = ctx.req_grads assert not req_gradB, "TODO: support weight updates as well" state = ctx.state + # Cast grad_output to fp16 + grad_output_dtype = grad_output.dtype + grad_output = grad_output.to(torch.float16) + if len(grad_output.shape) == 3: grad_output = grad_output.reshape( -1, grad_output.shape[-1] @@ -354,9 +355,9 @@ class MatMul8bitLt(torch.autograd.Function): if req_gradA: CB = state.CB.half() - SCB = state.SCB.unsqueeze(1).half() - B = (CB * SCB) / 127.0 - grad_A = torch.mm(grad_output, B).view(ctx.grad_shape) + SCB = (state.SCB.unsqueeze(1) / 127.0).half() + CB *= SCB + grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) if req_gradBias: grad_bias = grad_output.sum(0) -- cgit v1.2.3 From 42b5fc9acc4b59a6d90c662eb26099ac25907c7f Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Sun, 11 Sep 2022 05:51:29 +0300 Subject: add memory effcient backward option --- bitsandbytes/autograd/_functions.py | 46 ++++++++++++++++++++++++++++++++----- bitsandbytes/nn/modules.py | 16 +++++++++---- 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 226cbb5..271c690 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,5 +1,6 @@ import operator import torch +import bitsandbytes as bnb import bitsandbytes.functional as F from dataclasses import dataclass @@ -187,6 +188,8 @@ class MatmulLtState: use_pool = False formatB = F.get_special_format_str() + memory_efficient_backward = False + def reset_grads(self): self.CB = None self.CxB = None @@ -283,6 +286,12 @@ class MatMul8bitLt(torch.autograd.Function): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + # else: + # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( (outliers * state.SCB.view(-1, 1) / 127.0) @@ -332,13 +341,15 @@ class MatMul8bitLt(torch.autograd.Function): clone_func = torch.clone if len(output_shape) == 3 else lambda x : x return clone_func(output.view(output_shape)) + @staticmethod def backward(ctx, grad_output): if ctx.is_empty: bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, req_gradB, req_gradBias = ctx.req_grads - assert not req_gradB, "TODO: support weight updates as well" + CAt, subA = ctx.tensors + SCAt, idx = ctx.tensor_states + formatB = ctx.formatB state = ctx.state # Cast grad_output to fp16 @@ -352,11 +363,31 @@ class MatMul8bitLt(torch.autograd.Function): grad_A = grad_B = grad_bias = None + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) + if req_gradB: + CxAt, SAt = F.transform(CAt, formatB, transpose=True) + C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) + gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) + grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, 
SCAt) + if state.threshold > 0.0 and subA is not None: + grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + if req_gradA: - CB = state.CB.half() - SCB = (state.SCB.unsqueeze(1) / 127.0).half() - CB *= SCB - grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) + if state.CBt: + C32grad, Sgrad = F.transform(Cgrad, "col32") + if state.CxBt is None: + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) + gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + elif state.CB: + CB = state.CB.half() + SCB = (state.SCB.unsqueeze(1) / 127.0).half() + CB *= SCB + grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) + else: + raise Exception('State must contain either CBt or CB matrix') if req_gradBias: grad_bias = grad_output.sum(0) @@ -367,6 +398,9 @@ class MatMul8bitLt(torch.autograd.Function): return grad_A, grad_B, None, grad_bias, None +matmul = MatMul8bitLt.apply + + def matmul( A: tensor, B: tensor, diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 3e32c8e..00d0c61 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -223,6 +223,7 @@ class Linear8bitLt(nn.Linear): has_fp16_weights=True, threshold=0.0, index=None, + memory_efficient_backward=False ): super(Linear8bitLt, self).__init__( input_features, output_features, bias @@ -232,6 +233,7 @@ class Linear8bitLt(nn.Linear): self.state.threshold = threshold self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True @@ -255,10 +257,16 @@ class Linear8bitLt(nn.Linear): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights and self.state.CxB is not None: - # In this version, we convert 8-bit row major to turing/ampere format at each inference pass - # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place. - del self.state.CxB + if not self.state.has_fp16_weights: + if not self.state.memory_efficient_backward and self.state.CB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + elif self.state.memory_efficient_backward and self.state.CxB is not None: + # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass. + # Thus, we delete CxB from the state. 
+ del self.state.CxB return out -- cgit v1.2.3 From ee325f02157cd23b37059e3dce5fb17cb1c1b137 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Sun, 11 Sep 2022 06:18:44 +0300 Subject: clarified an exception message --- bitsandbytes/autograd/_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 271c690..008655d 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -373,7 +373,7 @@ class MatMul8bitLt(torch.autograd.Function): grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: - if state.CBt: + if state.CBt is not None: C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: state.CxBt, state.SBt = F.transform( @@ -381,13 +381,13 @@ class MatMul8bitLt(torch.autograd.Function): ) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) - elif state.CB: + elif state.CB is not None: CB = state.CB.half() SCB = (state.SCB.unsqueeze(1) / 127.0).half() CB *= SCB grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) else: - raise Exception('State must contain either CBt or CB matrix') + raise Exception('State must contain either CBt or CB matrix for backward') if req_gradBias: grad_bias = grad_output.sum(0) -- cgit v1.2.3 From d358999e9e2d98a834aaa38ffec1bef983d73fe6 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Sun, 11 Sep 2022 06:26:15 +0300 Subject: refactoring --- bitsandbytes/autograd/_functions.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 008655d..642e516 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -185,11 +185,10 @@ class MatmulLtState: idx = None is_training = True has_fp16_weights = True + memory_efficient_backward = False use_pool = False formatB = F.get_special_format_str() - memory_efficient_backward = False - def reset_grads(self): self.CB = None self.CxB = None @@ -198,6 +197,7 @@ class MatmulLtState: self.CxBt = None self.SBt = None + self.CBt = None class MatMul8bitLt(torch.autograd.Function): @@ -232,10 +232,6 @@ class MatMul8bitLt(torch.autograd.Function): A_dtype = A.dtype A = A.to(torch.float16) - assert ( - A.dtype == torch.float16 - ), f"The input data type needs to be fp16 but {A.dtype} was found!" - # 1. 
Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() @@ -398,9 +394,6 @@ class MatMul8bitLt(torch.autograd.Function): return grad_A, grad_B, None, grad_bias, None -matmul = MatMul8bitLt.apply - - def matmul( A: tensor, B: tensor, -- cgit v1.2.3 From 4dd475ced4adcbb31f6e1c42225f6d9b1e3be9f2 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Sun, 11 Sep 2022 06:28:17 +0300 Subject: refactoring --- bitsandbytes/autograd/_functions.py | 1 - bitsandbytes/nn/modules.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 642e516..48f867f 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,6 +1,5 @@ import operator import torch -import bitsandbytes as bnb import bitsandbytes.functional as F from dataclasses import dataclass diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 00d0c61..e7e759d 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -221,9 +221,9 @@ class Linear8bitLt(nn.Linear): output_features, bias=True, has_fp16_weights=True, + memory_efficient_backward=False, threshold=0.0, index=None, - memory_efficient_backward=False ): super(Linear8bitLt, self).__init__( input_features, output_features, bias -- cgit v1.2.3 From e2a75769f22bdc5465240c3f6701a1b002e8ab59 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Sun, 11 Sep 2022 21:41:46 +0300 Subject: bug fix --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 48f867f..bdcbec5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -388,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function): grad_bias = grad_output.sum(0) # Cast grad_A back to grad_output_dtype - grad_output.to(grad_output_dtype) + grad_output = grad_output.to(grad_output_dtype) return grad_A, grad_B, None, grad_bias, None -- cgit v1.2.3 From cc4858c2fd48ef17a888b9d45bb35bb00e373eb8 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 20:46:04 +0300 Subject: some kind of warning or something when this is first executed to make people aware that a cast happens and the operation quantization is performed in fp16. --- bitsandbytes/autograd/_functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index bdcbec5..6d473e9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,4 +1,6 @@ import operator +import warnings + import torch import bitsandbytes.functional as F @@ -229,6 +231,8 @@ class MatMul8bitLt(torch.autograd.Function): # Cast A to fp16 A_dtype = A.dtype + if A_dtype != torch.float16: + warnings.warn(f"MatMul8bitLt: temporarily casting input matrix from {A_dtype} to float16") A = A.to(torch.float16) # 1. 
Quantize A -- cgit v1.2.3 From 469d5a631d77d135f055d3aa012ac852a0ef0856 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:06:57 +0300 Subject: test_bf16 --- tests/test_autograd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index bae26de..05da6ed 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -253,7 +253,7 @@ for c in req_grad: transpose = [(False, True), (False, False)] str_transpose = ["NT", "NN"] -dtype = [torch.float16] +dtype = [torch.float16, torch.bfloat16] has_fp16_weights = [True, False] has_bias = [True, False] values = list( -- cgit v1.2.3 From a9c7953e0a68a934a18a9495b20deeed9665b2a6 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:10:21 +0300 Subject: cast to half before double_quant --- tests/test_autograd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 05da6ed..636fe86 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -354,7 +354,7 @@ def test_matmullt( state.SCB, SCBt, coo_tensorB, - ) = bnb.functional.double_quant(B2) + ) = bnb.functional.double_quant(B2.half()) B2 = state.CB if not transpose[0] and transpose[1]: -- cgit v1.2.3 From 140cdbe8767247bb9b8ea510755cceaa304b6859 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:12:58 +0300 Subject: check dtypes first --- tests/test_autograd.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 636fe86..083d465 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -354,7 +354,7 @@ def test_matmullt( state.SCB, SCBt, coo_tensorB, - ) = bnb.functional.double_quant(B2.half()) + ) = bnb.functional.double_quant(B2.to(torch.float16)) B2 = state.CB if not transpose[0] and transpose[1]: @@ -367,6 +367,8 @@ def test_matmullt( if has_bias: out_torch += bias + assert out_bnb.dtype == torch.dtype + n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).mean().item() # print(f'abs error {err:.4f}') -- cgit v1.2.3 From 9379df85d223dff18f0fa4adbaf60770700b262a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:13:23 +0300 Subject: check dtypes first --- tests/test_autograd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 083d465..c47754b 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -367,7 +367,7 @@ def test_matmullt( if has_bias: out_torch += bias - assert out_bnb.dtype == torch.dtype + assert out_bnb.dtype == out_torch.dtype n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).mean().item() -- cgit v1.2.3 From e29c5f5c41627668c650a2849e29599cd4f0bf3a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:22:04 +0300 Subject: clearer assertions --- tests/test_autograd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index c47754b..5171c4f 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -367,7 +367,7 @@ def test_matmullt( if has_bias: out_torch += bias - assert out_bnb.dtype == out_torch.dtype + assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).mean().item() -- cgit v1.2.3 From fc4a135ed1604d1f6190af725bea912e19e8a88a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:24:26 
+0300 Subject: clearer assertions --- bitsandbytes/autograd/_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 6d473e9..f4a6d57 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -232,8 +232,8 @@ class MatMul8bitLt(torch.autograd.Function): # Cast A to fp16 A_dtype = A.dtype if A_dtype != torch.float16: - warnings.warn(f"MatMul8bitLt: temporarily casting input matrix from {A_dtype} to float16") - A = A.to(torch.float16) + warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16") + A = A.to(torch.float16) # 1. Quantize A if len(A.shape) == 3: -- cgit v1.2.3 From a9fe0ff98c3293d972eb7a638b9887df0bc0d30d Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:34:22 +0300 Subject: recast to fp16 --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index f4a6d57..dc79bb1 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -275,7 +275,7 @@ class MatMul8bitLt(torch.autograd.Function): state.SCB, state.SCBt, coo_tensorB, - ) = F.double_quant(B) + ) = F.double_quant(B.to(torch.float16)) state.CxB, state.SB = F.transform(CB, to_order=formatB) else: has_grad = False -- cgit v1.2.3 From eac9aca460ee7afb6d0cbc61ae43a95120d34f29 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:38:09 +0300 Subject: cast bias too --- bitsandbytes/autograd/_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index dc79bb1..6d9229b 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -234,6 +234,8 @@ class MatMul8bitLt(torch.autograd.Function): if A_dtype != torch.float16: warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16") A = A.to(torch.float16) + if bias is not None: + bias = bias.to(torch.float16) # 1. Quantize A if len(A.shape) == 3: -- cgit v1.2.3 From 7facedda38da928843e9ed0de1810d45ce1b9224 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:41:40 +0300 Subject: copypaste tolerances --- tests/test_autograd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 5171c4f..28d9259 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -427,4 +427,4 @@ def test_matmullt( ) if req_grad[2]: - torch.testing.assert_allclose(gradBias1, gradBias2) + torch.testing.assert_allclose(gradBias1, gradBias2, atol=0.18, rtol=0.3) -- cgit v1.2.3 From d9ca0ed9051a21295e9be80ec08a6589ebd98222 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:44:28 +0300 Subject: un-fuse bias --- bitsandbytes/autograd/_functions.py | 8 +++++--- tests/test_autograd.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 6d9229b..540d1ec 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -234,8 +234,6 @@ class MatMul8bitLt(torch.autograd.Function): if A_dtype != torch.float16: warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16") A = A.to(torch.float16) - if bias is not None: - bias = bias.to(torch.float16) # 1. 
Quantize A if len(A.shape) == 3: @@ -315,7 +313,11 @@ class MatMul8bitLt(torch.autograd.Function): C32A, SA = F.transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) # we apply the fused bias here - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + + fused_bias = bias if bias.dtype == torch.float16 else None + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias) + if fused_bias is None and bias is not None: + output.add_(bias.to(output.dtype)) # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 28d9259..5171c4f 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -427,4 +427,4 @@ def test_matmullt( ) if req_grad[2]: - torch.testing.assert_allclose(gradBias1, gradBias2, atol=0.18, rtol=0.3) + torch.testing.assert_allclose(gradBias1, gradBias2) -- cgit v1.2.3 From 56a074f6dc50ae923e7a810b7c2ca53cd2f6129e Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:46:37 +0300 Subject: un-fuse bias --- bitsandbytes/autograd/_functions.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 540d1ec..7293637 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -314,10 +314,13 @@ class MatMul8bitLt(torch.autograd.Function): out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) # we apply the fused bias here - fused_bias = bias if bias.dtype == torch.float16 else None - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias) - if fused_bias is None and bias is not None: - output.add_(bias.to(output.dtype)) + if bias is None or bias.dtype == torch.float16: + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + output = output.to(A_dtype) + + else: # apply bias separately + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) + output = output.to(A_dtype).add_(bias) # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: @@ -338,8 +341,6 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - # Cast fp16 output back to A.dtype - output = output.to(A_dtype) clone_func = torch.clone if len(output_shape) == 3 else lambda x : x return clone_func(output.view(output_shape)) -- cgit v1.2.3 From e9b87112eeaabe3dfb51bdf553abbb94d9093870 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:51:28 +0300 Subject: un-fuse bias --- bitsandbytes/autograd/_functions.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 7293637..538267b 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -316,15 +316,14 @@ class MatMul8bitLt(torch.autograd.Function): if bias is None or bias.dtype == torch.float16: output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) - output = output.to(A_dtype) - + delayed_bias = None else: # apply bias separately output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) - output = output.to(A_dtype).add_(bias) + delayed_bias = bias # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: - output += torch.matmul(subA, state.subB) + output.addmm_(subA, state.subB) # 5. 
Save state ctx.state = state @@ -341,6 +340,9 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) + output = output.to(A_dtype) + if delayed_bias is not None: + output.add_(delayed_bias) clone_func = torch.clone if len(output_shape) == 3 else lambda x : x return clone_func(output.view(output_shape)) -- cgit v1.2.3 From 0de1a4494bd9246e5b1b3f2c7a0e4d4181fc644a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:53:49 +0300 Subject: change order --- bitsandbytes/autograd/_functions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 538267b..34b27d9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -357,6 +357,11 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state + grad_A = grad_B = grad_bias = None + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0) # Cast grad_output to fp16 grad_output_dtype = grad_output.dtype @@ -367,8 +372,6 @@ class MatMul8bitLt(torch.autograd.Function): -1, grad_output.shape[-1] ).contiguous() - grad_A = grad_B = grad_bias = None - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) if req_gradB: CxAt, SAt = F.transform(CAt, formatB, transpose=True) @@ -395,9 +398,6 @@ class MatMul8bitLt(torch.autograd.Function): else: raise Exception('State must contain either CBt or CB matrix for backward') - if req_gradBias: - grad_bias = grad_output.sum(0) - # Cast grad_A back to grad_output_dtype grad_output = grad_output.to(grad_output_dtype) -- cgit v1.2.3 From 647c976a74249d284b31e8403dfcbcbfa3e203a3 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:59:36 +0300 Subject: change order --- bitsandbytes/autograd/_functions.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 34b27d9..25ff1a5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -316,10 +316,10 @@ class MatMul8bitLt(torch.autograd.Function): if bias is None or bias.dtype == torch.float16: output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) - delayed_bias = None + output = output.to(A_dtype) else: # apply bias separately output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) - delayed_bias = bias + output = output.to(A_dtype).add_(bias) # 4. 
Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: @@ -340,9 +340,6 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - output = output.to(A_dtype) - if delayed_bias is not None: - output.add_(delayed_bias) clone_func = torch.clone if len(output_shape) == 3 else lambda x : x return clone_func(output.view(output_shape)) -- cgit v1.2.3 From 210b9ed9cef6a053d783c2d3926171b30ce6a969 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:00:45 +0300 Subject: debug assert --- bitsandbytes/autograd/_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 25ff1a5..7441a22 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -323,6 +323,7 @@ class MatMul8bitLt(torch.autograd.Function): # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: + assert subA.dtype == state.subB.dtype == output.dtype output.addmm_(subA, state.subB) # 5. Save state -- cgit v1.2.3 From 85bf5294a60ceba84b85f0634b349bc486cec635 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:01:25 +0300 Subject: debug assert --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 7441a22..2aada07 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -323,7 +323,7 @@ class MatMul8bitLt(torch.autograd.Function): # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: - assert subA.dtype == state.subB.dtype == output.dtype + assert subA.dtype == state.subB.dtype == output.dtype, (subA.dtype, state.subB.dtype, output.dtype) output.addmm_(subA, state.subB) # 5. Save state -- cgit v1.2.3 From e2b523d071c1dfe70c274a7ff945e859bc8f9e02 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:07:05 +0300 Subject: change typecast behavior --- bitsandbytes/autograd/_functions.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 2aada07..6868b75 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -230,16 +230,14 @@ class MatMul8bitLt(torch.autograd.Function): state.outlier_pool = GlobalOutlierPooler.get_instance() # Cast A to fp16 - A_dtype = A.dtype - if A_dtype != torch.float16: - warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16") - A = A.to(torch.float16) + if A.dtype != torch.float16: + warnings.warn(f"MatMul8bitLt: input matrix will be cast from {A.dtype} to float16") # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( - A, threshold=state.threshold + A.to(torch.float16), threshold=state.threshold ) if state.threshold > 0.0 and coo_tensorA is not None: @@ -316,10 +314,10 @@ class MatMul8bitLt(torch.autograd.Function): if bias is None or bias.dtype == torch.float16: output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) - output = output.to(A_dtype) + output = output.to(A.dtype) else: # apply bias separately output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) - output = output.to(A_dtype).add_(bias) + output = output.to(A.dtype).add_(bias) # 4. 
Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: -- cgit v1.2.3 From d6e25b5f5ea36c1565145da773fbf0f842b1c235 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:15:18 +0300 Subject: change typecast behavior --- bitsandbytes/autograd/_functions.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 6868b75..0e594a5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -321,7 +321,6 @@ class MatMul8bitLt(torch.autograd.Function): # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: - assert subA.dtype == state.subB.dtype == output.dtype, (subA.dtype, state.subB.dtype, output.dtype) output.addmm_(subA, state.subB) # 5. Save state @@ -330,6 +329,7 @@ class MatMul8bitLt(torch.autograd.Function): ctx.formatB = formatB ctx.grad_shape = input_shape ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias] + ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype if requires_gradA or requires_gradB: ctx.tensors = (CAt, subA) @@ -348,7 +348,7 @@ class MatMul8bitLt(torch.autograd.Function): if ctx.is_empty: bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, req_gradB, req_gradBias = ctx.req_grads + req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA = ctx.tensors SCAt, idx = ctx.tensor_states formatB = ctx.formatB @@ -357,25 +357,22 @@ class MatMul8bitLt(torch.autograd.Function): if req_gradBias: # compute grad_bias first before changing grad_output dtype - grad_bias = grad_output.sum(0) + grad_bias = grad_output.sum(0).to(ctx.bias_dtype) # Cast grad_output to fp16 - grad_output_dtype = grad_output.dtype - grad_output = grad_output.to(torch.float16) - if len(grad_output.shape) == 3: grad_output = grad_output.reshape( -1, grad_output.shape[-1] ).contiguous() - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: CxAt, SAt = F.transform(CAt, formatB, transpose=True) C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) - grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt).to(ctx.B_dtype) if state.threshold > 0.0 and subA is not None: - grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + grad_B[:, idx].addmm_(grad_output.t(), subA) if req_gradA: if state.CBt is not None: @@ -385,18 +382,16 @@ class MatMul8bitLt(torch.autograd.Function): state.CBt, to_order=formatB, transpose=True ) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.A_dtype) + elif state.CB is not None: - CB = state.CB.half() + CB = state.CB.to(ctx.B_dtype) SCB = (state.SCB.unsqueeze(1) / 127.0).half() CB *= SCB - grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) + grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape).to(ctx.A_dtype) else: raise Exception('State must contain either CBt or CB matrix for backward') - # Cast grad_A back to 
grad_output_dtype - grad_output = grad_output.to(grad_output_dtype) - return grad_A, grad_B, None, grad_bias, None -- cgit v1.2.3 From 1145589f84d2ba4eb3b4a18fa33423298f5747c0 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:15:57 +0300 Subject: change typecast behavior --- bitsandbytes/autograd/_functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 0e594a5..b54ac24 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -328,7 +328,6 @@ class MatMul8bitLt(torch.autograd.Function): ctx.formatB = formatB ctx.grad_shape = input_shape - ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias] ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype if requires_gradA or requires_gradB: @@ -357,7 +356,7 @@ class MatMul8bitLt(torch.autograd.Function): if req_gradBias: # compute grad_bias first before changing grad_output dtype - grad_bias = grad_output.sum(0).to(ctx.bias_dtype) + grad_bias = grad_output.sum(0).to(ctx.dtype_bias) # Cast grad_output to fp16 if len(grad_output.shape) == 3: -- cgit v1.2.3 From 1da4880262ab5febbc55aa690e72e446e6b1eb42 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:19:22 +0300 Subject: change typecast behavior --- bitsandbytes/autograd/_functions.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index b54ac24..5499db9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -356,7 +356,7 @@ class MatMul8bitLt(torch.autograd.Function): if req_gradBias: # compute grad_bias first before changing grad_output dtype - grad_bias = grad_output.sum(0).to(ctx.dtype_bias) + grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) # Cast grad_output to fp16 if len(grad_output.shape) == 3: @@ -385,9 +385,8 @@ class MatMul8bitLt(torch.autograd.Function): elif state.CB is not None: CB = state.CB.to(ctx.B_dtype) - SCB = (state.SCB.unsqueeze(1) / 127.0).half() - CB *= SCB - grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape).to(ctx.A_dtype) + CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(ctx.B_dtype)) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.A_dtype) else: raise Exception('State must contain either CBt or CB matrix for backward') -- cgit v1.2.3 From 5b169f18e4894a82b8681139727d45a4dd61c4b1 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:21:15 +0300 Subject: change typecast behavior --- bitsandbytes/autograd/_functions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 5499db9..93304f9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -369,7 +369,7 @@ class MatMul8bitLt(torch.autograd.Function): CxAt, SAt = F.transform(CAt, formatB, transpose=True) C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) - grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt).to(ctx.B_dtype) + grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt).to(ctx.dtype_B) if state.threshold > 0.0 and subA is not None: grad_B[:, idx].addmm_(grad_output.t(), subA) @@ -381,12 +381,12 @@ class MatMul8bitLt(torch.autograd.Function): state.CBt, to_order=formatB, transpose=True ) 
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.A_dtype) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: - CB = state.CB.to(ctx.B_dtype) - CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(ctx.B_dtype)) - grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.A_dtype) + CB = state.CB.to(ctx.dtype_B) + CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype)) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: raise Exception('State must contain either CBt or CB matrix for backward') -- cgit v1.2.3 From 14048a3c16bf3e60754b4218ec40a01e6a7f213c Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:24:20 +0300 Subject: safer cast --- bitsandbytes/autograd/_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 93304f9..6a225b7 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -371,7 +371,7 @@ class MatMul8bitLt(torch.autograd.Function): gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt).to(ctx.dtype_B) if state.threshold > 0.0 and subA is not None: - grad_B[:, idx].addmm_(grad_output.t(), subA) + grad_B[:, idx] += torch.mm(grad_output.t(), subA) if req_gradA: if state.CBt is not None: @@ -384,6 +384,7 @@ class MatMul8bitLt(torch.autograd.Function): grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: + raise NotImplementedError("WIP") CB = state.CB.to(ctx.dtype_B) CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) -- cgit v1.2.3 From a214824f930e26580304192c5cb8c4242c7889c5 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:24:59 +0300 Subject: matmul -1- addmm --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 6a225b7..91eec4a 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -321,7 +321,7 @@ class MatMul8bitLt(torch.autograd.Function): # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: - output.addmm_(subA, state.subB) + output += torch.matmul(subA, state.subB) # 5. Save state ctx.state = state -- cgit v1.2.3 From 702cc72018eaa177b94a276043a6c069ff0da32b Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:26:46 +0300 Subject: debug asset --- bitsandbytes/autograd/_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 91eec4a..c3c2bf8 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -321,6 +321,7 @@ class MatMul8bitLt(torch.autograd.Function): # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: + assert subA.dtype == state.subB.dtype, (subA.dtype, state.subB.dtype) output += torch.matmul(subA, state.subB) # 5. 
Save state -- cgit v1.2.3 From 45dc1983e92f9c3125948f416aafc6b96b3a6c15 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:28:03 +0300 Subject: cast properly --- bitsandbytes/autograd/_functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index c3c2bf8..03949de 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -294,7 +294,7 @@ class MatMul8bitLt(torch.autograd.Function): (outliers * state.SCB.view(-1, 1) / 127.0) .t() .contiguous() - .half() + .to(B.dtype) ) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 @@ -321,7 +321,6 @@ class MatMul8bitLt(torch.autograd.Function): # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: - assert subA.dtype == state.subB.dtype, (subA.dtype, state.subB.dtype) output += torch.matmul(subA, state.subB) # 5. Save state -- cgit v1.2.3 From 577275bd8c1b4191284c4fb34799d252ae8667a1 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:30:57 +0300 Subject: cast properly --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 03949de..5a83dfd 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -294,7 +294,7 @@ class MatMul8bitLt(torch.autograd.Function): (outliers * state.SCB.view(-1, 1) / 127.0) .t() .contiguous() - .to(B.dtype) + .to(A.dtype) ) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 -- cgit v1.2.3 From e35e2c665a69647d829c48e22fba0230180c11e7 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:35:03 +0300 Subject: cast properly --- bitsandbytes/autograd/_functions.py | 2 +- tests/test_autograd.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 5a83dfd..36c392b 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -231,7 +231,7 @@ class MatMul8bitLt(torch.autograd.Function): # Cast A to fp16 if A.dtype != torch.float16: - warnings.warn(f"MatMul8bitLt: input matrix will be cast from {A.dtype} to float16") + warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") # 1. 
Quantize A if len(A.shape) == 3: diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 5171c4f..4e4282a 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -372,8 +372,10 @@ def test_matmullt( n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).mean().item() # print(f'abs error {err:.4f}') + out_error_rate = 0.0175 if dtype == torch.float16 else 0.02 + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx == 0).sum().item() <= n * 0.0175 + assert (idx == 0).sum().item() <= n * out_error_rate idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) assert (idx == 0).sum().item() <= n * 0.001 -- cgit v1.2.3 From cbfdf0b5efe4923ba4533c274ce83072b7e502b5 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:35:42 +0300 Subject: cast edge case --- bitsandbytes/autograd/_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 36c392b..d0e48b7 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -212,9 +212,9 @@ class MatMul8bitLt(torch.autograd.Function): ctx.B = B ctx.bias = bias if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device) + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) else: - return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device) + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) # 1. Quantize A # 2. Quantize B -- cgit v1.2.3 From ab9dee062d791ef343ff5f9e8c2c85dc094219ed Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:36:46 +0300 Subject: cast edge case --- bitsandbytes/autograd/_functions.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index d0e48b7..1d0002c 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -221,9 +221,6 @@ class MatMul8bitLt(torch.autograd.Function): # 3. Matmul # 4. Mixed-precision decomposition matmul # 5. 
Save state - requires_gradA = A.requires_grad - requires_gradB = B.requires_grad - requires_gradBias = bias is not None and bias.requires_grad formatB = state.formatB input_shape = A.shape if state.outlier_pool is None: @@ -330,7 +327,7 @@ class MatMul8bitLt(torch.autograd.Function): ctx.grad_shape = input_shape ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype - if requires_gradA or requires_gradB: + if any(ctx.needs_input_grad[:2]): ctx.tensors = (CAt, subA) ctx.tensor_states = (SCAt, state.idx) else: -- cgit v1.2.3 From fa8e07c7c5186e18d9e2d45042814fe6e8d76d5a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:38:02 +0300 Subject: more lenient threshold --- tests/test_autograd.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 4e4282a..0150fbb 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -372,10 +372,9 @@ def test_matmullt( n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).mean().item() # print(f'abs error {err:.4f}') - out_error_rate = 0.0175 if dtype == torch.float16 else 0.02 idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx == 0).sum().item() <= n * out_error_rate + assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.02) idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) assert (idx == 0).sum().item() <= n * 0.001 -- cgit v1.2.3 From f6670329fb9b26dc5547cfef6da73bea75c548ca Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:42:23 +0300 Subject: bump threshold to 0.21 --- tests/test_autograd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 0150fbb..cb82898 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -374,7 +374,7 @@ def test_matmullt( # print(f'abs error {err:.4f}') idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.02) + assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021) idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) assert (idx == 0).sum().item() <= n * 0.001 -- cgit v1.2.3 From 18f142e268603849d1756df87d35e2f94b5e4853 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:43:02 +0300 Subject: addmm_ --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 1d0002c..55bedee 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -318,7 +318,7 @@ class MatMul8bitLt(torch.autograd.Function): # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: - output += torch.matmul(subA, state.subB) + output.addmm_(output, subA, state.subB) # 5. Save state ctx.state = state -- cgit v1.2.3 From 76ece2c126b5255fe973615adf986c4331f521ff Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:43:56 +0300 Subject: rollback --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 55bedee..1d0002c 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -318,7 +318,7 @@ class MatMul8bitLt(torch.autograd.Function): # 4. 
Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: - output.addmm_(output, subA, state.subB) + output += torch.matmul(subA, state.subB) # 5. Save state ctx.state = state -- cgit v1.2.3 From 579b8c782f5240d589ca65ef950054734db97ae1 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:47:58 +0300 Subject: reduce diff --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 1d0002c..6674a82 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -368,7 +368,7 @@ class MatMul8bitLt(torch.autograd.Function): gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt).to(ctx.dtype_B) if state.threshold > 0.0 and subA is not None: - grad_B[:, idx] += torch.mm(grad_output.t(), subA) + grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: if state.CBt is not None: -- cgit v1.2.3 From 591f60395a1e9c62f291e23c91af45cc699f072c Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:52:53 +0300 Subject: add memory efficient backward --- bitsandbytes/autograd/_functions.py | 1 - tests/test_modules.py | 24 +++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 6674a82..daf9ba0 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -381,7 +381,6 @@ class MatMul8bitLt(torch.autograd.Function): grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: - raise NotImplementedError("WIP") CB = state.CB.to(ctx.dtype_B) CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) diff --git a/tests/test_modules.py b/tests/test_modules.py index c0b3311..53a675f 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -14,13 +14,15 @@ class MockArgs(object): class MLP8bit(torch.nn.Module): - def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0): + def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0): super(MLP8bit, self).__init__() self.fc1 = bnb.nn.Linear8bitLt( - dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold + dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, + threshold=threshold ) self.fc2 = bnb.nn.Linear8bitLt( - dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold + dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, + threshold=threshold ) def forward(self, x): @@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values] @pytest.mark.parametrize("threshold", values, ids=names) -def test_linear8bitlt_no_fp16_weights(threshold): +@pytest.mark.parametrize("memory_efficient_backward", [True, False]) +def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): l1 = ( - bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False) + bnb.nn.Linear8bitLt( + 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward + ) .cuda() .half() ) @@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold): assert mlp.fc2.weight.dtype == 
torch.int8 mlp = ( - MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + MLP8bit( + 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward + ) .half() .to("cuda") ) @@ -532,7 +539,9 @@ def test_linear8bitlt_no_fp16_weights(threshold): assert mlp.fc2.weight.device.type == "cuda" mlp = ( - MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + MLP8bit( + 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward + ) .to(torch.float16) .to("cuda") ) @@ -551,6 +560,7 @@ def test_linear8bitlt_no_fp16_weights(threshold): assert mlp.fc2.weight.device.type == "cuda" + def test_linear8bitlt_fp32_bias(): # casts model to fp16 -> int8 automatically l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda() -- cgit v1.2.3 From 2cd047e35da3a421c4b491ff1a137e19b9c6c919 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:55:53 +0300 Subject: run backward --- tests/test_modules.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_modules.py b/tests/test_modules.py index 53a675f..d3992a9 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -554,11 +554,22 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc1.state.idx is not None if threshold > 0: assert mlp.fc2.state.idx is not None + assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc1.weight.device.type == "cuda" assert mlp.fc2.weight.device.type == "cuda" + if memory_efficient_backward: + b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half) + o1 = mlp(b1) + assert o1.dtype == torch.float16 + assert o1.requires_grad + grad_proj = torch.randn_like(o1) + + (o1 * grad_proj).sum().backward() + + def test_linear8bitlt_fp32_bias(): -- cgit v1.2.3 From 7906dc4c9a6d741699a561e03956f1f7ee4b8abc Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:57:26 +0300 Subject: debugpritn --- bitsandbytes/autograd/_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index daf9ba0..88a33a5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -368,6 +368,7 @@ class MatMul8bitLt(torch.autograd.Function): gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt).to(ctx.dtype_B) if state.threshold > 0.0 and subA is not None: + assert False, idx grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: -- cgit v1.2.3 From 4b4a9effd1fa88bc30bb1bc1e732d74e034c9d66 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:02:13 +0300 Subject: debugprint --- bitsandbytes/autograd/_functions.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 88a33a5..9928fbd 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -366,9 +366,8 @@ class MatMul8bitLt(torch.autograd.Function): CxAt, SAt = F.transform(CAt, formatB, transpose=True) C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) - grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt).to(ctx.dtype_B) + grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) if state.threshold > 0.0 and subA is not None: - assert False, idx 
grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: @@ -382,8 +381,7 @@ class MatMul8bitLt(torch.autograd.Function): grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: - CB = state.CB.to(ctx.dtype_B) - CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype)) + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).div(127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: raise Exception('State must contain either CBt or CB matrix for backward') -- cgit v1.2.3 From 4da2227fcbc3803d680dff113403aecac1827bc3 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:03:21 +0300 Subject: debug --- bitsandbytes/autograd/_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 9928fbd..407f14b 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -370,6 +370,8 @@ class MatMul8bitLt(torch.autograd.Function): if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + raise NotImplementedError("!!") + if req_gradA: if state.CBt is not None: C32grad, Sgrad = F.transform(Cgrad, "col32") -- cgit v1.2.3 From 5d658171017473b54825dfeac21718f4e4be4eca Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:09:24 +0300 Subject: debug --- bitsandbytes/autograd/_functions.py | 2 -- bitsandbytes/nn/modules.py | 4 +++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 407f14b..9928fbd 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -370,8 +370,6 @@ class MatMul8bitLt(torch.autograd.Function): if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) - raise NotImplementedError("!!") - if req_gradA: if state.CBt is not None: C32grad, Sgrad = F.transform(Cgrad, "col32") diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index e7e759d..9250fec 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -237,7 +237,9 @@ class Linear8bitLt(nn.Linear): if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights) + self.weight = Int8Params( + self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights + ) def init_8bit_state(self): self.state.CB = self.weight.CB -- cgit v1.2.3 From d9b8789818191f9992733394d7ccfa00a63d4dba Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:13:58 +0300 Subject: debug --- tests/test_modules.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_modules.py b/tests/test_modules.py index d3992a9..c6e7f85 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -545,6 +545,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): .to(torch.float16) .to("cuda") ) + w1, w2 = mlp.fc1.weight.clone(), mlp.fc2.weight.clone() for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -567,8 +568,15 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert o1.requires_grad grad_proj = torch.randn_like(o1) + mlp.zero_grad() (o1 * grad_proj).sum().backward() + grad_ref = grad_proj.flatten(2) @ w2 @ w1 + assert 
torch.allclose(b1.grad, grad_ref) + + + + -- cgit v1.2.3 From 6a826c41a6e4b9d8e6d2b8c768d769587cc85672 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:20:34 +0300 Subject: pre-cast --- tests/test_modules.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index c6e7f85..01c9389 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -538,14 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc1.weight.device.type == "cuda" assert mlp.fc2.weight.device.type == "cuda" - mlp = ( - MLP8bit( + mlp = MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward ) - .to(torch.float16) - .to("cuda") - ) w1, w2 = mlp.fc1.weight.clone(), mlp.fc2.weight.clone() + mlp = mlp.cuda().half() for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() -- cgit v1.2.3 From 37f805bb44cd577422b792ae5bd1110f3eec69f6 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:21:12 +0300 Subject: debug --- tests/test_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index 01c9389..8108b35 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -567,7 +567,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): mlp.zero_grad() (o1 * grad_proj).sum().backward() - + assert False, (w1, w2) grad_ref = grad_proj.flatten(2) @ w2 @ w1 assert torch.allclose(b1.grad, grad_ref) -- cgit v1.2.3 From 95dafc6475bc36490e213269d1028adfd4f75363 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:22:31 +0300 Subject: cast before allclose --- tests/test_modules.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index 8108b35..dbadea9 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -541,8 +541,8 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): mlp = MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward ) - w1, w2 = mlp.fc1.weight.clone(), mlp.fc2.weight.clone() - mlp = mlp.cuda().half() + w1, w2 = mlp.fc1.weight.clone(), mlp.fc2.weight.clone() # note: we grad original weights before quantization, + mlp = mlp.cuda().half() # and this line triggers quantization for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -567,8 +567,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): mlp.zero_grad() (o1 * grad_proj).sum().backward() - assert False, (w1, w2) - grad_ref = grad_proj.flatten(2) @ w2 @ w1 + grad_ref = grad_proj.flatten(2) @ w2.to(grad_proj.device) @ w1.to(grad_proj.device) assert torch.allclose(b1.grad, grad_ref) -- cgit v1.2.3 From 28a9313ddcf09c40d6cea75b3fd932ef09b4c715 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:24:27 +0300 Subject: cast before allclose --- tests/test_modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index dbadea9..bb65edb 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -541,7 +541,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): mlp = MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward ) - w1, w2 = mlp.fc1.weight.clone(), 
mlp.fc2.weight.clone() # note: we grad original weights before quantization, + w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization, mlp = mlp.cuda().half() # and this line triggers quantization for i in range(100): @@ -567,7 +567,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): mlp.zero_grad() (o1 * grad_proj).sum().backward() - grad_ref = grad_proj.flatten(2) @ w2.to(grad_proj.device) @ w1.to(grad_proj.device) + grad_ref = grad_proj.flatten(2) @ w2.to() @ w1.to(grad_proj.device) assert torch.allclose(b1.grad, grad_ref) -- cgit v1.2.3 From 725cc729931e21fd57377caba702da1ebecaa2ff Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:24:44 +0300 Subject: cast device --- tests/test_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index bb65edb..8e009b4 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -567,7 +567,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): mlp.zero_grad() (o1 * grad_proj).sum().backward() - grad_ref = grad_proj.flatten(2) @ w2.to() @ w1.to(grad_proj.device) + grad_ref = grad_proj.flatten(2) @ w2 @ w1 assert torch.allclose(b1.grad, grad_ref) -- cgit v1.2.3 From e4086a2758c171993f47b46cf0980030afe6db4a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:24:57 +0300 Subject: cast device --- tests/test_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index 8e009b4..049858c 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -567,7 +567,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): mlp.zero_grad() (o1 * grad_proj).sum().backward() - grad_ref = grad_proj.flatten(2) @ w2 @ w1 + grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half() assert torch.allclose(b1.grad, grad_ref) -- cgit v1.2.3 From 01b4c6a048abad182fc7c40038c232ce1493c54f Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:25:56 +0300 Subject: cast device --- tests/test_modules.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index 049858c..d2ef856 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -568,7 +568,8 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): mlp.zero_grad() (o1 * grad_proj).sum().backward() grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half() - assert torch.allclose(b1.grad, grad_ref) + scale = grad_ref.abs().mean() + assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.1 * scale) -- cgit v1.2.3 From 32a9a88f987e26c5b891ce1f881f008307b4548c Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:26:12 +0300 Subject: cast device --- tests/test_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index d2ef856..163edf6 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -569,7 +569,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): (o1 * grad_proj).sum().backward() grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half() scale = grad_ref.abs().mean() - assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.1 * scale) + assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.01 * scale) -- cgit v1.2.3 From cff3a7159943369841675dbc1076e555ffb2260b Mon Sep 17 00:00:00 2001 
From: justheuristic Date: Sun, 18 Sep 2022 01:26:25 +0300 Subject: cast device --- tests/test_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index 163edf6..faf91b8 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -569,7 +569,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): (o1 * grad_proj).sum().backward() grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half() scale = grad_ref.abs().mean() - assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.01 * scale) + assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) -- cgit v1.2.3 From 9b7d307b8cc9d88310fe0c0548e4a0fb094f45d3 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 20 Sep 2022 06:36:32 +0300 Subject: review --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 9928fbd..2ddb406 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -381,7 +381,7 @@ class MatMul8bitLt(torch.autograd.Function): grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: - CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).div(127.0)) + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: raise Exception('State must contain either CBt or CB matrix for backward') -- cgit v1.2.3 From a07825ac31eb5585bd75f9788880536d5fc77f3a Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 20 Sep 2022 06:40:36 +0300 Subject: review --- tests/test_modules.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index faf91b8..235acde 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -569,12 +569,10 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): (o1 * grad_proj).sum().backward() grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half() scale = grad_ref.abs().mean() - assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) - - - - + torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) + idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1) + assert (idx == 0).sum().item() <= b1.numel() * 0.0 def test_linear8bitlt_fp32_bias(): -- cgit v1.2.3 From 292a47871603cc1ebe620221358d571a8f5c6d8f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 20 Sep 2022 06:42:05 +0300 Subject: set threshold --- tests/test_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index 235acde..2879846 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -572,7 +572,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1) - assert (idx == 0).sum().item() <= b1.numel() * 0.0 + assert (idx == 0).sum().item() <= b1.numel() * 0.005 def test_linear8bitlt_fp32_bias(): -- cgit v1.2.3 From 76ce9aa6da7d68d2463f0f3e99532ab5b6db58a8 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 20 Sep 2022 06:51:25 +0300 Subject: try fp32 --- tests/test_autograd.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index cb82898..40bb441 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -253,7 +253,7 @@ for c in req_grad: transpose = [(False, True), (False, False)] str_transpose = ["NT", "NN"] -dtype = [torch.float16, torch.bfloat16] +dtype = [torch.float16, torch.bfloat16, torch.float32] has_fp16_weights = [True, False] has_bias = [True, False] values = list( -- cgit v1.2.3
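Taken together, the series above converges on a memory-efficient backward for Linear8bitLt: with has_fp16_weights=False, only the row-major int8 weights (state.CB) and their per-row scales (state.SCB) are retained, and grad_A is obtained by dequantizing CB on the fly (CB * SCB / 127) and multiplying with grad_output, instead of materializing a transposed turing/ampere copy (CBt/CxBt). The sketch below illustrates that dequantize-then-matmul path in plain PyTorch; it is a minimal illustration, not code from the patches — quantize_rowwise is a hypothetical stand-in for the row-wise part of bnb.functional.double_quant, and the tolerance mirrors the 0.05 * scale check added to tests/test_modules.py.

import torch

def quantize_rowwise(W):
    # Hypothetical helper (not the bitsandbytes API): per-row absmax scales (SCB)
    # and int8 weights (CB) for a weight of shape [out_features, in_features].
    SCB = W.abs().max(dim=1).values                            # [out_features]
    CB = torch.round(W / SCB.unsqueeze(1) * 127.0).to(torch.int8)
    return CB, SCB

torch.manual_seed(0)
W = torch.randn(64, 32)                                        # weight before quantization
CB, SCB = quantize_rowwise(W)                                  # all that is kept when has_fp16_weights=False
grad_output = torch.randn(16, 64)                              # gradient w.r.t. the layer output

# Backward path in the final form of the patches: dequantize CB instead of
# keeping CBt/CxBt, then a plain matmul to get the gradient w.r.t. the input.
CB_deq = CB.to(grad_output.dtype) * SCB.unsqueeze(1).mul(1.0 / 127.0)
grad_A = torch.matmul(grad_output, CB_deq)                     # [16, 32]

# Reference gradient through the unquantized weight, checked with the same
# loose tolerance the new test uses (atol = 0.05 * mean |grad_ref|).
grad_ref = torch.matmul(grad_output, W)
scale = grad_ref.abs().mean().item()
assert torch.allclose(grad_A, grad_ref, rtol=0, atol=0.05 * scale)

The trade-off made here is that grad_A no longer needs transposed copies of the quantized weight, at the cost of one dequantize-and-matmul per backward pass plus a small extra quantization error in grad_A — which is what the relaxed gradient check in tests/test_modules.py (atol = 0.05 * scale, up to 0.5% mismatched elements) accounts for.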