author     Tim Dettmers <tim.dettmers@gmail.com>  2022-07-22 14:41:05 -0700
committer  Tim Dettmers <tim.dettmers@gmail.com>  2022-07-22 14:41:05 -0700
commit     c771b3a75a6ebbfbfc398a028a477246b0799cf0 (patch)
tree       158353d531766ed133be34d3c5085da6e8a4d01e /tests/test_modules.py
parent     4cd7ea62b2f51c68aacde2f62e7141765e476111 (diff)
Most tests passing.
Diffstat (limited to 'tests/test_modules.py')
-rw-r--r--  tests/test_modules.py  478
1 file changed, 453 insertions(+), 25 deletions(-)
diff --git a/tests/test_modules.py b/tests/test_modules.py
index a0379cb..a2c950b 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -1,42 +1,470 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
import pytest
import torch
+
+import math
+from itertools import product
+
+import einops
+from torch import nn
+
import bitsandbytes as bnb
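+
+
+# Simple attribute container standing in for parsed command-line arguments.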
+class MockArgs(object):
+ def __init__(self, initial_data):
+ for key in initial_data:
+ setattr(self, key, initial_data[key])
+
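+# Two-layer feed-forward block built from bnb.nn.Linear8bitLt, used by the no-fp16-weights tests below.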
+class MLP8bit(torch.nn.Module):
+ def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+ super(MLP8bit, self).__init__()
+ self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
+ self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.fc2(x)
+ return x
+
+
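+# Default quantization settings consumed by LinearFunction/Linear8bit below.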
+def get_args():
+ args = MockArgs([])
+ args.quant_type = 'vector'
+ args.use_8bit_training = 'full'
+ args.clip_freq = 9999
+ return args
+
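+# Like torch.testing.assert_allclose, but tolerates up to `count` mismatched elements before failing.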
+def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
+ idx = torch.isclose(a, b, rtol, atol)
+ sumval = (idx==0).sum().item()
+ if sumval > count:
+ print(f'Too many values not close: assert {sumval} < {count}')
+ torch.testing.assert_allclose(a, b, rtol, atol)
+
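+# Reference implementation of an 8-bit linear layer (quantize -> int8 matmul -> dequantize),
+# used as a baseline to compare the bitsandbytes modules against.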
+class LinearFunction(torch.autograd.Function):
+
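+ # Simulate 8-bit quantization with the representable range clipped to trim_value standard deviations.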
+ @staticmethod
+ def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
+ round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+ norm = math.sqrt(math.pi)/math.sqrt(2.0)
+ #std = torch.abs(x).mean()*norm
+ std = torch.std(x)
+ max1 = std*trim_value
+ x = x/max1*127
+ x = round_func(x)
+ x[x > 127] = 127
+ x[x < -127] = -127
+ x = x/127*max1
+
+ return x
+
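+ # Quantize x to int8 with per-tensor ('linear'), per-dimension ('vector'), or min-max scaling;
+ # returns the int8 tensor plus the scaling state needed to dequantize.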
+ def quant(x, quant_type, dim=1):
+ if quant_type == 'linear':
+ max1 = torch.abs(x).max().float()
+ xq = torch.round(x/max1*127).to(torch.int8)
+ return xq, max1
+ elif quant_type == 'vector':
+ max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+ xq = torch.round(x/max1*127).to(torch.int8)
+ return xq, max1
+ elif quant_type == 'min-max':
+ maxA = torch.amax(x, dim=dim, keepdim=True).float()
+ minA = torch.amin(x, dim=dim, keepdim=True).float()
+ scale = (maxA-minA)/2.0
+ xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8)
+ return xq, (minA.float(), scale.float())
+ else: return None
+
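+ # Undo the quantization of an int32 igemm result using the scaling states of both operands.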
+ def dequant(xq, S1, S2, dtype, quant_type):
+ if quant_type == 'linear':
+ norm = S1*S2/(127*127)
+ # double cast needed to prevent overflows
+ return (xq.float()*norm).to(dtype)
+ elif quant_type == 'vector':
+ x = xq.float()
+ if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
+ if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
+ #print(x.shape, S1.shape, S2.shape)
+ if len(S1.shape) == 2:
+ x *= S1.t()/127
+ else:
+ x *= S1/127
+ x *= S2/127
+ return x.to(dtype)
+ else: return None
+
+ def dequant_min_max(xq, A, B, SA, SB, dtype):
+ offset = B.float().t().sum(0)*(SA[0]+SA[1])
+ x = xq.float()
+ if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
+ if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
+ if len(SB.shape) == 2:
+ x *= SB.t()/127
+ else:
+ x *= SB/127
+ x *= SA[1]/127
+ x += offset
+ return x.to(dtype)
+
+
+ def get_8bit_linear(x, stochastic=False):
+ round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+ max1 = torch.abs(x).max()
+ x = x/max1*127
+ x = round_func(x)/127*max1
+ #x = torch.round(x)/128*max1
+ return x
+
+ @staticmethod
+ def get_8bit_vector_wise(x, dim, stochastic=False):
+ round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+ max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+ max1[max1==0] = 1.0
+ x = (x*127)/max1
+ x = round_func(x)/127*max1
+ return x
+
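+ # Stochastic rounding: round the magnitude up with probability equal to its fractional part.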
+ @staticmethod
+ def round_stoachastic(x):
+ sign = torch.sign(x)
+ absx = torch.abs(x)
+ decimal = absx-torch.floor(absx)
+ rdm = torch.rand_like(decimal)
+ return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype))
+
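+ # Simulate 8-bit weight storage: blockwise-quantize w with a dynamic map, dequantize, and write the result back into w.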
+ @staticmethod
+ def fake_8bit_storage(w, exponent_bits):
+ code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
+ absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
+ out = bnb.functional.dequantize_blockwise(absmax, C, code)
+ out = out.half()
+ w.copy_(out)
+ return out
+
+ @staticmethod
+ def fake_8bit_storage_quantile(w, args):
+ code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
+ #C = bnb.functional.quantize_no_absmax(code, w)
+ #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
+ #print(out)
+ #out = out.half()
+ code /= torch.max(torch.abs(code))
+ absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
+ out = bnb.functional.dequantize_blockwise(absmax, C, code)
+ out = out.half()
+ w.copy_(out)
+ return out
+
+ @staticmethod
+ def fake_8bit_storage_stoachstic(w):
+ rand = torch.rand(1024, device=w.device)
+ absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
+ out = bnb.functional.dequantize_blockwise(absmax, C)
+ out = out.half()
+ w.copy_(out)
+ return out
+
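+ # Like fake_8bit_storage, but keeps the topk largest-magnitude values of each 256-element block in full precision.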
+ @staticmethod
+ def fake_8bit_storage_with_max(w, topk=8):
+ blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
+ max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
+ idx = idx[:, :topk]
+ max_val = max_val[:, :topk]
+
+ mask = torch.zeros_like(blocked_w)
+ mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
+ mask = mask.bool()
+
+ # 1. zero out max values
+ # 2. quantize + dequantize
+ # 3. write back max values
+ # 4. copy matrix back to weight
+
+ values = blocked_w[mask]
+ blocked_w[mask] = 0
+
+ code = bnb.functional.create_dynamic_map()
+ code = code.to(w.device)
+ absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
+ bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)
+
+ blocked_w[mask] = values
+
+ unblocked_w = blocked_w.flatten().view(w.shape)
+
+ w.copy_(unblocked_w)
+ return unblocked_w
+
+
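+ # Quantize the activations and weight, multiply with int8 igemm, and dequantize the result;
+ # falls back to a regular fp16 matmul when 8-bit training is disabled.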
+ @staticmethod
+ def forward(ctx, x, weight, bias=None, args=None):
+ if args.use_8bit_training != 'off':
+ weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
+ x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
+ outputq = bnb.functional.igemm(x8, weight8.t())
+ output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
+ #if torch.rand(1) < 0.01:
+ #output32 = torch.matmul(x, weight.t())
+ #err = torch.abs(output-output32).float()
+ #relerr = err/(torch.abs(output32).float()+1e-8)
+ #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
+ else:
+ #output = torch.matmul(x, weight.t())
+ output = torch.einsum('bsi,oi->bso', x, weight)
+
+ ctx.save_for_backward(x, weight, bias)
+ ctx.args = args
+
+ if bias is not None:
+ output += bias.unsqueeze(0).expand_as(output)
+ return output
+
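+ # Compute weight/input gradients with int8 igemm according to args.use_8bit_training
+ # ('full', 'forward+wgrad', or fp16 fallback otherwise).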
+ @staticmethod
+ def backward(ctx, grad_output):
+ x, weight, bias = ctx.saved_tensors
+ args = ctx.args
+ stochastic = False
+ grad_input = grad_weight = grad_bias = None
+ if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)
+
+ # weight and x are already 8bit
+ # -> transform grad_output to 8-bit
+ if args.use_8bit_training == 'forward+wgrad':
+ grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+ x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
+ grad_weight8 = bnb.functional.igemm(grad_output8, x8)
+ grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+
+ #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
+
+ grad_input = grad_output.matmul(weight)
+ elif args.use_8bit_training == 'full':
+ grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+ x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
+ grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
+ bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
+ grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+
+ grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
+ weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
+ grad_input8 = bnb.functional.igemm(grad_output8, weight8)
+ grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
+
+ else:
+ grad_input = grad_output.matmul(weight)
+ grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)
-@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
-def test_embeddings(embcls):
- bnb.optim.GlobalOptimManager.get_instance().initialize()
- emb1 = torch.nn.Embedding(100, 512).cuda()
- emb2 = embcls(100, 512).cuda()
+ return grad_input, grad_weight, grad_bias, None
- adam1 = bnb.optim.Adam8bit(emb1.parameters())
- adam2 = bnb.optim.Adam8bit(emb2.parameters())
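+# nn.Module wrapper around LinearFunction, serving as the reference 8-bit layer in test_linear8bit.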
+class Linear8bit(nn.Module):
+ def __init__(self, input_features, output_features, bias=True, args=None):
+ super(Linear8bit, self).__init__()
+ self.input_features = input_features
+ self.output_features = output_features
+ self.args = args
- batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()
+ self.weight = nn.Parameter(torch.empty(output_features, input_features))
+ if bias:
+ self.bias = nn.Parameter(torch.empty(output_features))
+ else:
+ self.register_parameter('bias', None)
+ torch.nn.init.xavier_uniform_(self.weight)
+ if self.bias is not None:
+ torch.nn.init.zeros_(self.bias)
+
+ def forward(self, x):
+ self.args.training = self.training
+
+ return LinearFunction.apply(x, self.weight, self.bias, self.args)
+
+
+
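+# Compare bnb.nn.Linear8bit and bnb.nn.Linear8bitLt against the reference Linear8bit above and an fp16 nn.Linear.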
+def test_linear8bit():
+ l0 = torch.nn.Linear(32, 64).cuda().half()
+ l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
+ l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
+ l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()
+
+ l0.weight.data = l2.weight.data.clone()
+ l0.bias.data = l2.bias.data.clone()
+
+ l1.weight.data = l2.weight.data.clone()
+ l1.bias.data = l2.bias.data.clone()
+
+ l3.weight.data = l2.weight.data.clone()
+ l3.bias.data = l2.bias.data.clone()
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ t = torch.randn(16, 8, 64, device='cuda').half()
+ b2 = b1.clone()
+ b3 = b1.clone()
+ b0 = b1.clone()
+
+ o0 = l0(b0)
+ o1 = l1(b1)
+ o2 = l2(b2)
+ o3 = l3(b3)
+
+ assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
+ assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)
+
+ loss0 = torch.nn.functional.mse_loss(o0, t)
+ loss1 = torch.nn.functional.mse_loss(o1, t)
+ loss2 = torch.nn.functional.mse_loss(o2, t)
+ loss3 = torch.nn.functional.mse_loss(o3, t)
+
+ loss0.backward()
+ loss1.backward()
+ loss2.backward()
+ loss3.backward()
+
+ assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+ assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+ assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
+ assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
+
+ err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
+ err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
+ err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()
+
+ assert err1*0.8 < err2
+ assert err2*0.8 < err3
+ assert err3*0.8 < err1
+
+ l0.weight.grad = None
+ l1.weight.grad = None
+ l2.weight.grad = None
+ l3.weight.grad = None
+ l0.bias.grad = None
+ l1.bias.grad = None
+ l2.bias.grad = None
+ l3.bias.grad = None
+
+
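+# Inference-only check: after the first forward pass, Linear8bitLt should have cached its CxB weight state.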
+threshold = [0.0, 3.0]
+values = threshold
+names = ['threshold_{0}'.format(vals) for vals in values]
+@pytest.mark.parametrize("threshold", values, ids=names)
+def test_linear8bitlt_inference(threshold):
+ l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
+ assert l1.weight.device.type == 'cuda'
+ assert l1.weight.dtype == torch.float16
+
+ l1.eval()
for i in range(100):
- batch = batches[i]
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = l1(b1)
+ if i == 1:
+ assert l1.state.CxB is not None
+
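+# Gradient-accumulation check: Linear8bitLt should track a plain fp16 nn.Linear across accumulated optimizer steps.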
+def test_linear8bitlt_accumulated_gradient():
+ l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
+ l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
+ l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
+ l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
+ l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
+ l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
+ opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
+ opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+
+ acc_steps = 10
+
- embedded1 = emb1(batch)
- embedded2 = emb2(batch)
+ for i in range(10):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = l1(b1)
+ o2 = l2(b1)
+ loss1 = o1.mean()
+ loss2 = o2.mean()
+ loss1.backward()
+ loss2.backward()
+ if i == 2:
+ assert l1[0].state.CxB is not None
+ assert l1[1].state.CxB is not None
- l1 = embedded1.mean()
- l2 = embedded2.mean()
+ if i > 0 and i % acc_steps == 0:
+ opt1.step()
+ opt1.zero_grad(True)
+ opt2.step()
+ opt2.zero_grad(True)
+ assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
+ assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
+ # we do this copy because otherwise we have small divergences over time that add up
+ l1[0].weight.data.copy_(l2[0].weight.data)
+ l1[1].weight.data.copy_(l2[1].weight.data)
+ else:
+ torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
+ torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
- l1.backward()
- l2.backward()
- adam1.step()
- adam2.step()
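+# With has_fp16_weights=False the weights should stay int8 and outputs fp16, regardless of the .cuda()/.half()/.to() ordering.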
+threshold = [0.0, 2.0]
+values = threshold
+names = ['threshold_{0}'.format(vals) for vals in values]
+@pytest.mark.parametrize("threshold", values, ids=names)
+def test_linear8bitlt_no_fp16_weights(threshold):
+ l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+ assert l1.weight.dtype == torch.int8
- adam1.zero_grad()
- adam2.zero_grad()
+ l1.eval()
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = l1(b1)
+ assert o1.dtype == torch.float16
+
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
- assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8
- assert adam2.state[emb2.weight]['state1'].dtype == torch.float32
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+
+
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+ assert mlp.fc1.weight.device.type == 'cuda'
+ assert mlp.fc2.weight.device.type == 'cuda'
+
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+ assert mlp.fc1.weight.device.type == 'cuda'
+ assert mlp.fc2.weight.device.type == 'cuda'