author    Titus von Koeller <titus@vonkoeller.com>    2022-08-01 03:31:48 -0700
committer Titus von Koeller <titus@vonkoeller.com>    2022-08-01 03:31:48 -0700
commit    bfa0e33294f2b1dc25e65a33be2397f989824298 (patch)
tree      396b5d722fdd79da068882ca7376e3636fcb3bb8 /tests/test_modules.py
parent    597a8521b29e90958c31e47421016494da998648 (diff)
ran black and isort for coherent code formatting
Diffstat (limited to 'tests/test_modules.py')
-rw-r--r--  tests/test_modules.py  297
1 file changed, 175 insertions(+), 122 deletions(-)
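The change is purely mechanical: the file is reflowed by black and the imports are regrouped by isort, with no behavioral edits. A minimal sketch of how such a formatting pass could be reproduced locally, assuming both tools are installed and run with their default settings (the project's own tool configuration, if any, is not shown here):

    # Hypothetical reproduction of this commit's formatting pass.
    # Assumes black and isort are installed and default configs apply.
    import subprocess

    for cmd in (["black", "tests/test_modules.py"], ["isort", "tests/test_modules.py"]):
        subprocess.run(cmd, check=True)  # raises if either tool fails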
diff --git a/tests/test_modules.py b/tests/test_modules.py
index a2c950b..6b8d641 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -1,21 +1,27 @@
+from itertools import product
+
import pytest
import torch
-
-from itertools import product
from torch import nn
import bitsandbytes as bnb
+
class MockArgs(object):
def __init__(self, initial_data):
for key in initial_data:
setattr(self, key, initial_data[key])
+
class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
super(MLP8bit, self).__init__()
- self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
- self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)
+ self.fc1 = bnb.nn.Linear8bitLt(
+ dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+ )
+ self.fc2 = bnb.nn.Linear8bitLt(
+ dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+ )
def forward(self, x):
x = self.fc1(x)
@@ -25,108 +31,114 @@ class MLP8bit(torch.nn.Module):
def get_args():
args = MockArgs([])
- args.quant_type = 'vector'
- args.use_8bit_training = 'full'
+ args.quant_type = "vector"
+ args.use_8bit_training = "full"
args.clip_freq = 9999
return args
+
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
idx = torch.isclose(a, b, rtol, atol)
- sumval = (idx==0).sum().item()
+ sumval = (idx == 0).sum().item()
if sumval > count:
- print(f'Too many values not close: assert {sumval} < {count}')
+ print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_allclose(a, b, rtol, atol)
-class LinearFunction(torch.autograd.Function):
+class LinearFunction(torch.autograd.Function):
@staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
- norm = math.sqrt(math.pi)/math.sqrt(2.0)
- #std = torch.abs(x).mean()*norm
+ norm = math.sqrt(math.pi) / math.sqrt(2.0)
+ # std = torch.abs(x).mean()*norm
std = torch.std(x)
- max1 = std*trim_value
- x = x/max1*127
+ max1 = std * trim_value
+ x = x / max1 * 127
x = round_func(x)
x[x > 127] = 127
x[x < -127] = -127
- x = x/127*max1
+ x = x / 127 * max1
return x
def quant(x, quant_type, dim=1):
- if quant_type == 'linear':
+ if quant_type == "linear":
max1 = torch.abs(x).max().float()
- xq = torch.round(x/max1*127).to(torch.int8)
+ xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
- elif quant_type == 'vector':
+ elif quant_type == "vector":
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
- xq = torch.round(x/max1*127).to(torch.int8)
+ xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
- elif quant_type == 'min-max':
+ elif quant_type == "min-max":
maxA = torch.amax(x, dim=dim, keepdim=True).float()
minA = torch.amin(x, dim=dim, keepdim=True).float()
- scale = (maxA-minA)/2.0
- xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8)
+ scale = (maxA - minA) / 2.0
+ xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
return xq, (minA.float(), scale.float())
- else: return None
+ else:
+ return None
def dequant(xq, S1, S2, dtype, quant_type):
- if quant_type == 'linear':
- norm = S1*S2/(127*127)
+ if quant_type == "linear":
+ norm = S1 * S2 / (127 * 127)
# double cast needed to prevent overflows
- return (xq.float()*norm).to(dtype)
- elif quant_type == 'vector':
+ return (xq.float() * norm).to(dtype)
+ elif quant_type == "vector":
x = xq.float()
- if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
- if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
- #print(x.shape, S1.shape, S2.shape)
+ if len(xq.shape) == 2 and len(S1.shape) == 3:
+ S1 = S1.squeeze(0)
+ if len(xq.shape) == 2 and len(S2.shape) == 3:
+ S2 = S2.squeeze(0)
+ # print(x.shape, S1.shape, S2.shape)
if len(S1.shape) == 2:
- x *= S1.t()/127
+ x *= S1.t() / 127
else:
- x *= S1/127
- x *= S2/127
+ x *= S1 / 127
+ x *= S2 / 127
return x.to(dtype)
- else: return None
+ else:
+ return None
def dequant_min_max(xq, A, B, SA, SB, dtype):
- offset = B.float().t().sum(0)*(SA[0]+SA[1])
+ offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float()
- if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
- if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
+ if len(xq.shape) == 2 and len(SB.shape) == 3:
+ SB = SB.squeeze(0)
+ if len(xq.shape) == 2 and len(SA.shape) == 3:
+ SA = SA.squeeze(0)
if len(SB.shape) == 2:
- x *= SB.t()/127
+ x *= SB.t() / 127
else:
- x *= SB/127
- x *= SA[1]/127
- x +=offset
+ x *= SB / 127
+ x *= SA[1] / 127
+ x += offset
return x.to(dtype)
-
def get_8bit_linear(x, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.abs(x).max()
- x = x/max1*127
- x = round_func(x)/127*max1
- #x = torch.round(x)/128*max1
+ x = x / max1 * 127
+ x = round_func(x) / 127 * max1
+ # x = torch.round(x)/128*max1
return x
@staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
- max1[max1==0] = 1.0
- x = (x*127)/max1
- x = round_func(x)/127*max1
+ max1[max1 == 0] = 1.0
+ x = (x * 127) / max1
+ x = round_func(x) / 127 * max1
return x
@staticmethod
def round_stoachastic(x):
sign = torch.sign(x)
absx = torch.abs(x)
- decimal = absx-torch.floor(absx)
+ decimal = absx - torch.floor(absx)
rdm = torch.rand_like(decimal)
- return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype))
+ return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))
@staticmethod
def fake_8bit_storage(w, exponent_bits):
@@ -140,10 +152,10 @@ class LinearFunction(torch.autograd.Function):
@staticmethod
def fake_8bit_storage_quantile(w, args):
code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
- #C = bnb.functional.quantize_no_absmax(code, w)
- #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
- #print(out)
- #out = out.half()
+ # C = bnb.functional.quantize_no_absmax(code, w)
+ # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
+ # print(out)
+ # out = out.half()
code /= torch.max(torch.abs(code))
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code)
@@ -162,7 +174,7 @@ class LinearFunction(torch.autograd.Function):
@staticmethod
def fake_8bit_storage_with_max(w, topk=8):
- blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
+ blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
idx = idx[:, :topk]
max_val = max_val[:, :topk]
@@ -191,22 +203,21 @@ class LinearFunction(torch.autograd.Function):
w.copy_(unblocked_w)
return unblocked_w
-
@staticmethod
def forward(ctx, x, weight, bias=None, args=None):
- if args.use_8bit_training != 'off':
+ if args.use_8bit_training != "off":
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t())
output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
- #if torch.rand(1) < 0.01:
- #output32 = torch.matmul(x, weight.t())
- #err = torch.abs(output-output32).float()
- #relerr = err/(torch.abs(output32).float()+1e-8)
- #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
+ # if torch.rand(1) < 0.01:
+ # output32 = torch.matmul(x, weight.t())
+ # err = torch.abs(output-output32).float()
+ # relerr = err/(torch.abs(output32).float()+1e-8)
+ # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
else:
- #output = torch.matmul(x, weight.t())
- output = torch.einsum('bsi,oi->bso', x, weight)
+ # output = torch.matmul(x, weight.t())
+ output = torch.einsum("bsi,oi->bso", x, weight)
ctx.save_for_backward(x, weight, bias)
ctx.args = args
@@ -221,37 +232,49 @@ class LinearFunction(torch.autograd.Function):
args = ctx.args
stochastic = False
grad_input = grad_weight = grad_bias = None
- if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)
+ if bias is not None and ctx.needs_input_grad[2]:
+ grad_bias = grad_output.sum(0)
# weight and x are already 8bit
# -> transform grad_output to 8-bit
- if args.use_8bit_training == 'forward+wgrad':
- grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+ if args.use_8bit_training == "forward+wgrad":
+ grad_output8, S1 = LinearFunction.quant(
+ grad_output, args.quant_type, dim=[0, 1]
+ )
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = bnb.functional.igemm(grad_output8, x8)
- grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+ grad_weight = LinearFunction.dequant(
+ grad_weight8, S1, S2, grad_output.dtype, args.quant_type
+ )
- #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
+ # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input = grad_output.matmul(weight)
- elif args.use_8bit_training == 'full':
- grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+ elif args.use_8bit_training == "full":
+ grad_output8, S1 = LinearFunction.quant(
+ grad_output, args.quant_type, dim=[0, 1]
+ )
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
- grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+ grad_weight = LinearFunction.dequant(
+ grad_weight8, S1, S2, grad_output.dtype, args.quant_type
+ )
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8)
- grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
+ grad_input = LinearFunction.dequant(
+ grad_input8, S1, S3, grad_output.dtype, args.quant_type
+ )
else:
grad_input = grad_output.matmul(weight)
- grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)
+ grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)
return grad_input, grad_weight, grad_bias, None
+
class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None):
super(Linear8bit, self).__init__()
@@ -263,7 +286,7 @@ class Linear8bit(nn.Module):
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
- self.register_parameter('bias', None)
+ self.register_parameter("bias", None)
torch.nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
@@ -275,12 +298,11 @@ class Linear8bit(nn.Module):
return LinearFunction.apply(x, self.weight, self.bias, self.args)
-
def test_linear8bit():
l0 = torch.nn.Linear(32, 64).cuda().half()
- l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half()
+ l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
- l3 = bnb.nn.Linear8bitLt(32,64).cuda().half()
+ l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()
l0.weight.data = l2.weight.data.clone()
l0.bias.data = l2.bias.data.clone()
@@ -292,8 +314,8 @@ def test_linear8bit():
l3.bias.data = l2.bias.data.clone()
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
- t = torch.randn(16, 8, 64, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
+ t = torch.randn(16, 8, 64, device="cuda").half()
b2 = b1.clone()
b3 = b1.clone()
b0 = b1.clone()
@@ -318,16 +340,20 @@ def test_linear8bit():
assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
- assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
- assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
+ assert_all_approx_close(
+ l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
+ )
+ assert_all_approx_close(
+ l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
+ )
- err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
- err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
- err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()
+ err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
+ err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
+ err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()
- assert err1*0.8 < err2
- assert err2*0.8 < err3
- assert err3*0.8 < err1
+ assert err1 * 0.8 < err2
+ assert err2 * 0.8 < err3
+ assert err3 * 0.8 < err1
l0.weight.grad = None
l1.weight.grad = None
@@ -341,23 +367,28 @@ def test_linear8bit():
threshold = [0.0, 3.0]
values = threshold
-names = ['threshold_{0}'.format(vals) for vals in values]
+names = ["threshold_{0}".format(vals) for vals in values]
+
+
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
- l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half()
- assert l1.weight.device.type == 'cuda'
+ l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
+ assert l1.weight.device.type == "cuda"
assert l1.weight.dtype == torch.float16
l1.eval()
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
if i == 1:
assert l1.state.CxB is not None
+
def test_linear8bitlt_accumulated_gradient():
- l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)])
- l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)])
+ l1 = torch.nn.Sequential(
+ *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
+ )
+ l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
@@ -367,9 +398,8 @@ def test_linear8bitlt_accumulated_gradient():
acc_steps = 10
-
for i in range(10):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
o2 = l2(b1)
loss1 = o1.mean()
@@ -385,8 +415,12 @@ def test_linear8bitlt_accumulated_gradient():
opt1.zero_grad(True)
opt2.step()
opt2.zero_grad(True)
- assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
- assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
+ assert_all_approx_close(
+ l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
+ )
+ assert_all_approx_close(
+ l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
+ )
# we do this copy because otherwise we have small divergences over time that add up
l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data)
@@ -397,15 +431,21 @@ def test_linear8bitlt_accumulated_gradient():
threshold = [0.0, 2.0]
values = threshold
-names = ['threshold_{0}'.format(vals) for vals in values]
+names = ["threshold_{0}".format(vals) for vals in values]
+
+
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
- l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half()
+ l1 = (
+ bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+ .cuda()
+ .half()
+ )
assert l1.weight.dtype == torch.int8
l1.eval()
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
assert o1.dtype == torch.float16
@@ -414,57 +454,70 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0: assert mlp.fc1.state.idx is not None
- if threshold > 0: assert mlp.fc2.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc1.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0: assert mlp.fc1.state.idx is not None
- if threshold > 0: assert mlp.fc2.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc1.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0: assert mlp.fc1.state.idx is not None
- if threshold > 0: assert mlp.fc2.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc1.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
-
- mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0: assert mlp.fc1.state.idx is not None
- if threshold > 0: assert mlp.fc2.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc1.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
- assert mlp.fc1.weight.device.type == 'cuda'
- assert mlp.fc2.weight.device.type == 'cuda'
+ assert mlp.fc1.weight.device.type == "cuda"
+ assert mlp.fc2.weight.device.type == "cuda"
- mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')
+ mlp = (
+ MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+ .to(torch.float16)
+ .to("cuda")
+ )
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0: assert mlp.fc1.state.idx is not None
- if threshold > 0: assert mlp.fc2.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc1.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
- assert mlp.fc1.weight.device.type == 'cuda'
- assert mlp.fc2.weight.device.type == 'cuda'
+ assert mlp.fc1.weight.device.type == "cuda"
+ assert mlp.fc2.weight.device.type == "cuda"
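For orientation after the diff: the int8 inference pattern that test_linear8bitlt_no_fp16_weights exercises reduces to the short sketch below. This is only an illustration, assuming bitsandbytes is installed and a CUDA device is available; the shapes and threshold value simply mirror the test.

    import torch
    import bitsandbytes as bnb

    # Minimal sketch of the int8 inference path exercised by the tests above.
    # Dimensions and threshold are illustrative, taken from the test values.
    layer = (
        bnb.nn.Linear8bitLt(32, 64, threshold=0.0, has_fp16_weights=False)
        .cuda()
        .half()
    )
    assert layer.weight.dtype == torch.int8  # weights are stored as int8

    layer.eval()
    x = torch.randn(16, 8, 32, device="cuda").half()
    with torch.no_grad():
        out = layer(x)
    assert out.dtype == torch.float16  # activations stay in fp16, as asserted in the tests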