author     Titus von Koeller <titus@vonkoeller.com>    2022-08-01 03:31:48 -0700
committer  Titus von Koeller <titus@vonkoeller.com>    2022-08-01 03:31:48 -0700
commit     bfa0e33294f2b1dc25e65a33be2397f989824298 (patch)
tree       396b5d722fdd79da068882ca7376e3636fcb3bb8 /tests/test_optim.py
parent     597a8521b29e90958c31e47421016494da998648 (diff)
ran black and isort for coherent code formatting
Diffstat (limited to 'tests/test_optim.py')
-rw-r--r--    tests/test_optim.py    397
1 file changed, 265 insertions(+), 132 deletions(-)
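Note: a minimal sketch (not part of the commit itself) of how the isort/black pass recorded in the diff below can be reproduced locally; it assumes both formatters are installed and is illustrative only:

    import subprocess

    # Run isort first (import ordering), then black (code style), per the
    # commit message "ran black and isort for coherent code formatting".
    for tool in ("isort", "black"):
        subprocess.run([tool, "tests/test_optim.py"], check=True)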
diff --git a/tests/test_optim.py b/tests/test_optim.py
index b173eaa..b84425e 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -1,81 +1,132 @@
+import ctypes
import os
-import time
import shutil
+import time
import uuid
+from itertools import product
+from os.path import join
+
import pytest
-import ctypes
import torch
+
import bitsandbytes as bnb
import bitsandbytes.functional as F
-from os.path import join
-from itertools import product
-
-#import apex
+# import apex
k = 20
+
def get_temp_dir():
- path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
+ path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
os.makedirs(path, exist_ok=True)
return path
+
def rm_path(path):
shutil.rmtree(path)
+
str2optimizers = {}
-str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
-#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
-#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
-str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
-#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
-#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
-
-str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
-#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
-str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
-#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
-str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
-str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
-#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
-str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
-
-str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
-str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
-str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
+str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
+# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
+# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+str2optimizers["momentum_pytorch"] = (
+ None,
+ lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
+ bnb.optim.Adam,
+)
+# str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
+# str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
+
+str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
+# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+str2optimizers["momentum"] = (
+ lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
+)
+str2optimizers["lars"] = (
+ lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
+)
+# str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
+str2optimizers["rmsprop"] = (
+ lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
+)
+str2optimizers["adam8bit"] = (
+ torch.optim.Adam,
+ lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
+)
+str2optimizers["momentum8bit"] = (
+ lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
+)
+str2optimizers["rmsprop8bit"] = (
+ lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
+)
+# str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
+str2optimizers["lars8bit"] = (
+ lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
+)
+
+str2optimizers["adam8bit_blockwise"] = (
+ torch.optim.Adam,
+ lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
+)
+str2optimizers["momentum8bit_blockwise"] = (
+ lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
+)
+str2optimizers["rmsprop8bit_blockwise"] = (
+ lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
+)
str2statenames = {}
-str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
-str2statenames['momentum'] = [('momentum_buffer', 'state1')]
-str2statenames['lars'] = [('momentum_buffer', 'state1')]
-str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
-str2statenames['rmsprop'] = [('square_avg', 'state1')]
-str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
-str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
-str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
-str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
-str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
-str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
-str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
-str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
+str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
+str2statenames["momentum"] = [("momentum_buffer", "state1")]
+str2statenames["lars"] = [("momentum_buffer", "state1")]
+str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
+str2statenames["rmsprop"] = [("square_avg", "state1")]
+str2statenames["adam8bit"] = [
+ ("exp_avg", "state1", "qmap1", "max1"),
+ ("exp_avg_sq", "state2", "qmap2", "max2"),
+]
+str2statenames["lamb8bit"] = [
+ ("exp_avg", "state1", "qmap1", "max1"),
+ ("exp_avg_sq", "state2", "qmap2", "max2"),
+]
+str2statenames["adam8bit_blockwise"] = [
+ ("exp_avg", "state1", "qmap1", "absmax1"),
+ ("exp_avg_sq", "state2", "qmap2", "absmax2"),
+]
+str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
+str2statenames["momentum8bit_blockwise"] = [
+ ("momentum_buffer", "state1", "qmap1", "absmax1")
+]
+str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
+str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
+str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
-values = list(product(dim1,dim2, gtype, optimizer_names))
-names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
+optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
+values = list(product(dim1, dim2, gtype, optimizer_names))
+names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
- if dim1 == 1 and dim2 == 1: return
- p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
+ if dim1 == 1 and dim2 == 1:
+ return
+ p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
-
torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2])
@@ -84,9 +135,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
else:
atol, rtol = 1e-4, 1e-3
-
for i in range(k):
- g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
+ g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
@@ -94,21 +144,31 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
torch_optimizer.step()
for name1, name2 in str2statenames[optim_name]:
- torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
+ torch.testing.assert_allclose(
+ torch_optimizer.state[p1][name1],
+ bnb_optimizer.state[p2][name2],
+ atol=atol,
+ rtol=rtol,
+ )
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
- if i % (k//5) == 0 and i > 0:
+ if i % (k // 5) == 0 and i > 0:
path = get_temp_dir()
- torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
+ torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
- bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
+ bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
for name1, name2 in str2statenames[optim_name]:
- torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
+ torch.testing.assert_allclose(
+ torch_optimizer.state[p1][name1],
+ bnb_optimizer.state[p2][name2],
+ atol=atol,
+ rtol=rtol,
+ )
if gtype == torch.float16:
# the adam buffers should also be close because they are 32-bit
@@ -118,20 +178,24 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
p1.data = p1.data.half().float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.half(), p2)
- if optim_name in ['lars', 'lamb']:
- assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0
+ if optim_name in ["lars", "lamb"]:
+ assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
+
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
-values = list(product(dim1,dim2, gtype))
-names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values]
+values = list(product(dim1, dim2, gtype))
+names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
- if dim1 == 1 and dim2 == 1: return
- p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
- p2 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
- p3 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
+ if dim1 == 1 and dim2 == 1:
+ return
+ p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
+ p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
+ p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
mask = torch.rand_like(p2) < 0.1
beta1 = 0.9
beta2 = 0.999
@@ -139,7 +203,7 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize()
- bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
+ bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
p1 = p1.cuda()
@@ -154,30 +218,41 @@ def test_global_config(dim1, dim2, gtype):
atol, rtol = 1e-4, 1e-3
for i in range(50):
- g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
- g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
- g3 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
+ g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
+ g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
+ g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
p1.grad = g1
p2.grad = g2
p3.grad = g3
adam2.step()
- assert adam2.state[p3]['state1'].dtype == torch.uint8
- assert adam2.state[p3]['state2'].dtype == torch.uint8
-
+ assert adam2.state[p3]["state1"].dtype == torch.uint8
+ assert adam2.state[p3]["state2"].dtype == torch.uint8
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
-values = list(product(dim1,dim2, gtype, optimizer_names))
-names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
+optimizer_names = [
+ "adam8bit",
+ "momentum8bit",
+ "rmsprop8bit",
+ "adam8bit_blockwise",
+ "lamb8bit",
+ "lars8bit",
+ "momentum8bit_blockwise",
+ "rmsprop8bit_blockwise",
+]
+values = list(product(dim1, dim2, gtype, optimizer_names))
+names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
- if dim1 == 1 and dim2 == 1: return
- p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
+ if dim1 == 1 and dim2 == 1:
+ return
+ p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
blocksize = 2048
@@ -197,7 +272,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors = []
for i in range(50):
- g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
+ g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
@@ -208,17 +283,31 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
- #print(bnb_optimizer.state[p2][max_val], name1)
- if 'blockwise' in optim_name:
- s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
+ # print(bnb_optimizer.state[p2][max_val], name1)
+ if "blockwise" in optim_name:
+ s1 = F.dequantize_blockwise(
+ code=bnb_optimizer.state[p2][qmap],
+ absmax=bnb_optimizer.state[p2][max_val],
+ A=bnb_optimizer.state[p2][name2],
+ blocksize=blocksize,
+ )
else:
- s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
- num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
+ s1 = F.dequantize(
+ code=bnb_optimizer.state[p2][qmap],
+ absmax=bnb_optimizer.state[p2][max_val],
+ A=bnb_optimizer.state[p2][name2],
+ )
+ num_not_close = (
+ torch.isclose(
+ torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
+ )
+ == 0
+ )
assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone())
- err = torch.abs(p1-p2)
- relerr = err/torch.abs(p1)
+ err = torch.abs(p1 - p2)
+ relerr = err / torch.abs(p1)
assert err.mean() < 0.0001
assert relerr.mean() < 0.001
@@ -226,28 +315,44 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors.append(relerr.mean().item())
if i % 10 == 0 and i > 0:
- for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
+ for (name1, name2, qmap, max_val), s in zip(
+ str2statenames[optim_name], dequant_states
+ ):
s1cpy = s.clone()
raws1cpy = bnb_optimizer.state[p2][name2].clone()
qmap1 = bnb_optimizer.state[p2][qmap].clone()
path = get_temp_dir()
- torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
+ torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
- bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
+ bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
- if 'blockwise' in optim_name:
- s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
+ if "blockwise" in optim_name:
+ s1 = F.dequantize_blockwise(
+ code=bnb_optimizer.state[p2][qmap],
+ absmax=bnb_optimizer.state[p2][max_val],
+ A=bnb_optimizer.state[p2][name2],
+ blocksize=blocksize,
+ )
else:
- s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
+ s1 = F.dequantize(
+ code=bnb_optimizer.state[p2][qmap],
+ absmax=bnb_optimizer.state[p2][max_val],
+ A=bnb_optimizer.state[p2][name2],
+ )
torch.testing.assert_allclose(s1cpy, s1)
- num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
+ num_not_close = (
+ torch.isclose(
+ torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
+ )
+ == 0
+ )
assert num_not_close.sum().item() < 20
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
@@ -256,24 +361,28 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
p1.data = p1.data.to(gtype).float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.to(gtype), p2)
- for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
+ for (name1, name2, qmap, max_val), s in zip(
+ str2statenames[optim_name], dequant_states
+ ):
torch_optimizer.state[p1][name1].copy_(s.data)
- #print(sum(errors)/len(errors))
- #print(sum(relerrors)/len(relerrors))
-
+ # print(sum(errors)/len(errors))
+ # print(sum(relerrors)/len(relerrors))
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
-values = list(product(dim1,dim2, gtype, optim_bits))
-names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}'.format(*vals) for vals in values]
+values = list(product(dim1, dim2, gtype, optim_bits))
+names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
- if dim1 == 1 and dim2 == 1: return
- p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
+ if dim1 == 1 and dim2 == 1:
+ return
+ p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
beta1 = 0.9
beta2 = 0.999
lr = 0.001
@@ -281,19 +390,23 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
p1 = p1.cuda()
p2 = p1.clone()
adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
- adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
+ adam2 = bnb.optim.Adam(
+ [p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5
+ )
gnorm_vec = torch.zeros(100).cuda()
step = 0
for i in range(50):
step += 1
- g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + (0.01*i)
+ g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
g2 = g1.clone()
p2.grad = g2
- current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
- g1 = (g1.float()*gnorm_scale).to(gtype)
+ current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
+ g1, gnorm_vec, step, 5
+ )
+ g1 = (g1.float() * gnorm_scale).to(gtype)
p1.grad = g1
adam1.step()
@@ -302,47 +415,69 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
# gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
if optim_bits == 32:
torch.testing.assert_allclose(p1, p2)
- torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=5e-5, rtol=1e-4)
- torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=5e-5, rtol=1e-4)
+ torch.testing.assert_allclose(
+ adam1.state[p1]["state1"],
+ adam2.state[p2]["state1"],
+ atol=5e-5,
+ rtol=1e-4,
+ )
+ torch.testing.assert_allclose(
+ adam1.state[p1]["state2"],
+ adam2.state[p2]["state2"],
+ atol=5e-5,
+ rtol=1e-4,
+ )
elif optim_bits == 8:
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
- torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=2, rtol=1e-3)
- torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=2, rtol=1e-3)
- adam1.state[p1]['state1'].copy_(adam2.state[p2]['state1'])
- adam1.state[p1]['state2'].copy_(adam2.state[p2]['state2'])
+ torch.testing.assert_allclose(
+ adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3
+ )
+ torch.testing.assert_allclose(
+ adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3
+ )
+ adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
+ adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
if i % 10 == 0 and i > 0:
path = get_temp_dir()
- torch.save(adam2.state_dict(),join(path, 'opt.pt'))
+ torch.save(adam2.state_dict(), join(path, "opt.pt"))
del adam2
adam2 = None
- adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
- adam2.load_state_dict(torch.load(join(path, 'opt.pt')))
-
-
+ adam2 = bnb.optim.Adam(
+ [p2],
+ lr,
+ (beta1, beta2),
+ eps,
+ optim_bits=optim_bits,
+ percentile_clipping=5,
+ )
+ adam2.load_state_dict(torch.load(join(path, "opt.pt")))
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
-#optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
-#optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
-#optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
-#optimizer_names = ['lamb_apex', 'lamb8bit']
-#optimizer_names = ['lars_apex', 'lars8bit']
-optimizer_names = ['adam8bit_blockwise']
-values = list(product(dim1,dim2, gtype, optimizer_names))
-names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
+# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
+# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
+# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
+# optimizer_names = ['lamb_apex', 'lamb8bit']
+# optimizer_names = ['lars_apex', 'lars8bit']
+optimizer_names = ["adam8bit_blockwise"]
+values = list(product(dim1, dim2, gtype, optimizer_names))
+names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
- if dim1 == 1 and dim2 == 1: return
- p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
+ if dim1 == 1 and dim2 == 1:
+ return
+ p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
bnb_optimizer = str2optimizers[optim_name][1]([p1])
- g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
+ g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g
for i in range(k):
- if i == k//5:
+ if i == k // 5:
# 100 iterations for burn-in
torch.cuda.synchronize()
t0 = time.time()
@@ -350,10 +485,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch.cuda.synchronize()
- s = time.time()-t0
- print('')
- params = (k-k//5)*dim1*dim2
- print(optim_name, gtype, s/params)
- #assert s < 3.9
-
-
+ s = time.time() - t0
+ print("")
+ params = (k - k // 5) * dim1 * dim2
+ print(optim_name, gtype, s / params)
+ # assert s < 3.9