From 8ae9bb23ad9c61a92ab1a0ac6be65cd787c4fe5b Mon Sep 17 00:00:00 2001
From: dbaranchuk
Date: Tue, 23 Aug 2022 23:39:54 +0300
Subject: add memory efficient backward

---
 bitsandbytes/nn/modules.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

(limited to 'bitsandbytes/nn')

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index b222f54..ef7fefc 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -148,10 +148,12 @@ class Int8Params(torch.nn.Parameter):
         has_fp16_weights=False,
         CB=None,
         SCB=None,
+        SCBt=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
         cls.SCB = None
+        cls.SCBt = None
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -165,10 +167,10 @@ class Int8Params(torch.nn.Parameter):
         B = self.data.contiguous().half().cuda(device)
         CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
         del CBt
-        del SCBt
         self.data = CB
         setattr(self, "CB", CB)
         setattr(self, "SCB", SCB)
+        setattr(self, "SCBt", SCBt)

         return self

@@ -210,6 +212,7 @@ class Int8Params(torch.nn.Parameter):
             )
             new_param.CB = self.CB
             new_param.SCB = self.SCB
+            new_param.SCB = self.SCBt

             return new_param

@@ -240,8 +243,10 @@ class Linear8bitLt(nn.Linear):
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
+        self.state.SCBt = self.weight.SCBt
         self.weight.CB = None
         self.weight.SCB = None
+        self.weight.SCBt = None

     def forward(self, x):
         self.state.is_training = self.training
@@ -255,11 +260,11 @@ class Linear8bitLt(nn.Linear):

         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

-        if not self.state.has_fp16_weights and self.state.CB is not None:
+        # if not self.state.has_fp16_weights and self.state.CB is not None:
             # we converted 8-bit row major to turing/ampere format in the first inference pass
             # we no longer need the row-major weight
-            del self.state.CB
-            self.weight.data = self.state.CxB
+            # del self.state.CB
+            # self.weight.data = self.state.CxB

         return out

--
cgit v1.2.3


From 656de8ed110fce4e94b4f9d48494ecc5f8e04970 Mon Sep 17 00:00:00 2001
From: dbaranchuk
Date: Tue, 23 Aug 2022 23:53:43 +0300
Subject: minor fixes

---
 bitsandbytes/nn/modules.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'bitsandbytes/nn')

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ef7fefc..360a182 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -212,7 +212,7 @@ class Int8Params(torch.nn.Parameter):
             )
             new_param.CB = self.CB
             new_param.SCB = self.SCB
-            new_param.SCB = self.SCBt
+            new_param.SCBt = self.SCBt

             return new_param

--
cgit v1.2.3


From ef2936a90d903d0f9a27e16ecb7f839f2c4d9ba1 Mon Sep 17 00:00:00 2001
From: dbaranchuk
Date: Wed, 24 Aug 2022 01:33:04 +0300
Subject: delete CxB from state

---
 bitsandbytes/nn/modules.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

(limited to 'bitsandbytes/nn')

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 360a182..03ffd3b 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -260,11 +260,10 @@ class Linear8bitLt(nn.Linear):

         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

-        # if not self.state.has_fp16_weights and self.state.CB is not None:
-            # we converted 8-bit row major to turing/ampere format in the first inference pass
-            # we no longer need the row-major weight
-            # del self.state.CB
-            # self.weight.data = self.state.CxB
+        if not self.state.has_fp16_weights and self.state.CxB is not None:
+            # In this version, we convert 8-bit row major to turing/ampere format at each inference pass
+            # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place.
+            del self.state.CxB

         return out

--
cgit v1.2.3


From 4d6174bc6336fb6fba712f1d2c903de1de677747 Mon Sep 17 00:00:00 2001
From: dbaranchuk
Date: Thu, 25 Aug 2022 19:09:23 +0300
Subject: memory efficient fp16 backward

---
 bitsandbytes/nn/modules.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

(limited to 'bitsandbytes/nn')

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 03ffd3b..3e32c8e 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -148,12 +148,10 @@ class Int8Params(torch.nn.Parameter):
         has_fp16_weights=False,
         CB=None,
         SCB=None,
-        SCBt=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
         cls.SCB = None
-        cls.SCBt = None
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -167,10 +165,10 @@ class Int8Params(torch.nn.Parameter):
         B = self.data.contiguous().half().cuda(device)
         CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
         del CBt
+        del SCBt
         self.data = CB
         setattr(self, "CB", CB)
         setattr(self, "SCB", SCB)
-        setattr(self, "SCBt", SCBt)

         return self

@@ -212,7 +210,6 @@ class Int8Params(torch.nn.Parameter):
             )
             new_param.CB = self.CB
             new_param.SCB = self.SCB
-            new_param.SCBt = self.SCBt

             return new_param

@@ -243,10 +240,8 @@ class Linear8bitLt(nn.Linear):
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
-        self.state.SCBt = self.weight.SCBt
         self.weight.CB = None
         self.weight.SCB = None
-        self.weight.SCBt = None

     def forward(self, x):
         self.state.is_training = self.training

--
cgit v1.2.3


From 42b5fc9acc4b59a6d90c662eb26099ac25907c7f Mon Sep 17 00:00:00 2001
From: dbaranchuk
Date: Sun, 11 Sep 2022 05:51:29 +0300
Subject: add memory efficient backward option

---
 bitsandbytes/nn/modules.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

(limited to 'bitsandbytes/nn')

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 3e32c8e..00d0c61 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -223,6 +223,7 @@ class Linear8bitLt(nn.Linear):
         has_fp16_weights=True,
         threshold=0.0,
         index=None,
+        memory_efficient_backward=False
     ):
         super(Linear8bitLt, self).__init__(
             input_features, output_features, bias
         )
@@ -232,6 +233,7 @@

         self.state.threshold = threshold
         self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward

         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
@@ -255,10 +257,16 @@ class Linear8bitLt(nn.Linear):

         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

-        if not self.state.has_fp16_weights and self.state.CxB is not None:
-            # In this version, we convert 8-bit row major to turing/ampere format at each inference pass
-            # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place.
- del self.state.CxB + if not self.state.has_fp16_weights: + if not self.state.memory_efficient_backward and self.state.CB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + elif self.state.memory_efficient_backward and self.state.CxB is not None: + # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass. + # Thus, we delete CxB from the state. + del self.state.CxB return out -- cgit v1.2.3 From 4dd475ced4adcbb31f6e1c42225f6d9b1e3be9f2 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Sun, 11 Sep 2022 06:28:17 +0300 Subject: refactoring --- bitsandbytes/nn/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'bitsandbytes/nn') diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 00d0c61..e7e759d 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -221,9 +221,9 @@ class Linear8bitLt(nn.Linear): output_features, bias=True, has_fp16_weights=True, + memory_efficient_backward=False, threshold=0.0, index=None, - memory_efficient_backward=False ): super(Linear8bitLt, self).__init__( input_features, output_features, bias -- cgit v1.2.3 From 5d658171017473b54825dfeac21718f4e4be4eca Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 01:09:24 +0300 Subject: debug --- bitsandbytes/nn/modules.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'bitsandbytes/nn') diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index e7e759d..9250fec 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -237,7 +237,9 @@ class Linear8bitLt(nn.Linear): if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights) + self.weight = Int8Params( + self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights + ) def init_8bit_state(self): self.state.CB = self.weight.CB -- cgit v1.2.3
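Editor's note, not part of the patches above: after this series, Linear8bitLt exposes a memory_efficient_backward flag, keeps its int8 weight frozen (requires_grad=has_fp16_weights), and, on the memory-efficient path, drops state.CxB after every forward call. The snippet below is a minimal usage sketch under stated assumptions: a CUDA GPU, a bitsandbytes build that contains this whole change set (including the matching autograd changes outside bitsandbytes/nn), and purely illustrative layer sizes.

    import torch
    import bitsandbytes as bnb

    # Frozen int8 weights (has_fp16_weights=False) with the option introduced
    # in commit 42b5fc9; the 1024 -> 4096 dimensions are example values only.
    layer = bnb.nn.Linear8bitLt(
        1024, 4096,
        bias=True,
        has_fp16_weights=False,
        memory_efficient_backward=True,
        threshold=6.0,
    ).cuda()  # Int8Params.cuda() runs double_quant() and keeps CB/SCB

    x = torch.randn(8, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
    out = layer(x)                 # this code path deletes state.CxB after the matmul
    out.float().sum().backward()   # the gradient reaches x; the int8 weight stays frozen
    print(x.grad.shape)            # torch.Size([8, 1024])

The design choice the series settles on: without the flag, the layer caches the converted weight (CxB) and frees the row-major CB after the first pass, which is cheapest for pure inference; with the flag set, it keeps CB and discards CxB on every pass, so both copies never have to stay alive across a forward/backward through the frozen layer.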