Diffstat (limited to 'tests/test_optim.py')
-rw-r--r--  tests/test_optim.py | 87
1 file changed, 19 insertions(+), 68 deletions(-)
diff --git a/tests/test_optim.py b/tests/test_optim.py
index c80fe51..b173eaa 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -1,12 +1,9 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
import os
import time
import shutil
import uuid
import pytest
+import ctypes
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
@@ -14,7 +11,9 @@ import bitsandbytes.functional as F
from os.path import join
from itertools import product
-import apex
+#import apex
+
+k = 20
def get_temp_dir():
path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
@@ -26,55 +25,47 @@ def rm_path(path):
str2optimizers = {}
str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
-str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
-str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
+#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
-str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
-str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
+#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
+#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
-str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW)
-str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
-str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
+#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False))
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
+#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
-str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True))
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
-str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))
str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
-str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames['lars'] = [('momentum_buffer', 'state1')]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['rmsprop'] = [('square_avg', 'state1')]
-str2statenames['adagrad'] = [('sum', 'state1')]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
-str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
-str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
+optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -89,12 +80,12 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2])
if gtype == torch.float32:
- atol, rtol = 2e-6, 1e-5
+ atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 1e-4, 1e-3
- for i in range(50):
+ for i in range(k):
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
@@ -107,7 +98,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
- if i % 10 == 0 and i > 0:
+ if i % (k//5) == 0 and i > 0:
path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
del bnb_optimizer
@@ -148,7 +139,6 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize()
- bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
@@ -163,8 +153,6 @@ def test_global_config(dim1, dim2, gtype):
else:
atol, rtol = 1e-4, 1e-3
- original_p2 = p2[mask].clone()
-
for i in range(50):
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
@@ -173,38 +161,17 @@ def test_global_config(dim1, dim2, gtype):
p2.grad = g2
p3.grad = g3
- if i > 30 and i % 10 == 0:
- g1.data[mask] = 0.0
- g2.data[mask] = 0.0
- p1.grad = g1
- p2.grad = g2
- original_p1 = p1[mask].clone()
- original_p2 = p2[mask].clone()
- og_s1 = adam2.state[p2]['state1'][mask].clone()
- og_s2 = adam2.state[p2]['state2'][mask].clone()
- og_s11 = adam2.state[p1]['state1'][mask].clone()
- og_s21 = adam2.state[p1]['state2'][mask].clone()
-
adam2.step()
assert adam2.state[p3]['state1'].dtype == torch.uint8
assert adam2.state[p3]['state2'].dtype == torch.uint8
- if i > 30 and i % 10 == 0:
- torch.testing.assert_allclose(original_p2, p2[mask])
- torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
- torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
- assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel()
- assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0
- assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0
-
-
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
+optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -370,13 +337,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
-
bnb_optimizer = str2optimizers[optim_name][1]([p1])
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
p1.grad = g
- for i in range(5000):
- if i == 500:
+ for i in range(k):
+ if i == k//5:
# 100 iterations for burn-in
torch.cuda.synchronize()
t0 = time.time()
@@ -386,23 +352,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
torch.cuda.synchronize()
s = time.time()-t0
print('')
- params = 4500*4096*4096
+ params = (k-k//5)*dim1*dim2
print(optim_name, gtype, s/params)
#assert s < 3.9
-
-def test_str_betas():
- betas = (0.80, 0.95)
- strbetas = '(0.80, 0.95)'
-
- layer = torch.nn.Linear(10, 10)
-
- base = bnb.optim.Adam(layer.parameters(), betas=betas)
- strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas)
- assert base.defaults['betas'][0] == 0.8
- assert base.defaults['betas'][1] == 0.95
- assert strbase.defaults['betas'][0] == 0.8
- assert strbase.defaults['betas'][1] == 0.95
-
-
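A minimal sketch (not part of this commit; the guarded-import pattern is only an assumption about one way to keep the apex baselines optional): instead of commenting out the apex lines, the import could be wrapped so the FusedAdam reference optimizer is registered only when apex is installed, and the pure torch/bitsandbytes entries are used otherwise.

import torch
import bitsandbytes as bnb

try:
    import apex  # optional dependency, only needed for the fused reference optimizers
except ImportError:
    apex = None

str2optimizers = {}
str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
if apex is not None:
    # register the apex-based baseline only when apex is available
    str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)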