summary | refs | log | tree | commit | diff
path: root/bitsandbytes/optim/sgd.py
diff options
context:
space:
mode:
author: Tim Dettmers <tim.dettmers@gmail.com> 2021-10-05 19:16:20 -0700
committer: Tim Dettmers <tim.dettmers@gmail.com> 2021-10-05 19:16:20 -0700
commit: 7439924891496025edf60c9da6a782f362a50c70 (patch)
tree: 90476984d2c267f89232577a2ea40eb172387475 /bitsandbytes/optim/sgd.py
Initial commit
Diffstat (limited to 'bitsandbytes/optim/sgd.py')
-rw-r--r-- bitsandbytes/optim/sgd.py 32
1 files changed, 32 insertions, 0 deletions
diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py
new file mode 100644
index 0000000..926d804
--- /dev/null
+++ b/bitsandbytes/optim/sgd.py
@@ -0,0 +1,32 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+class SGD(Optimizer1State):
+    """SGD with momentum, implemented on top of ``Optimizer1State``.
+
+    The constructor only validates ``momentum`` and forwards its arguments to
+    the base class; ``optim_bits`` is passed through unchanged.
+
+    NOTE(review): ``nesterov`` is accepted but never used here (it is not
+    forwarded to the base class) -- presumably kept for signature parity with
+    ``torch.optim.SGD``; confirm before relying on it.
+    """
+
+    def __init__(self, params, lr, momentum=0, dampening=0,
+                 weight_decay=0, nesterov=False, optim_bits=32, args=None,
+                 min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        """Configure the base optimizer for the ``'momentum'`` update rule.
+
+        ``momentum`` and ``dampening`` are handed to the base class together
+        as a ``(momentum, dampening)`` tuple; every other argument except
+        ``nesterov`` is forwarded as given.
+
+        Raises:
+            NotImplementedError: if ``momentum == 0``.
+        """
+        if momentum == 0:
+            # The original raised the undefined name ``NotImplementError``,
+            # which surfaced as a NameError instead of the intended exception.
+            raise NotImplementedError('SGD without momentum is not supported!')
+        super().__init__('momentum', params, lr, (momentum, dampening), 0.0,
+                         weight_decay, optim_bits, args, min_8bit_size,
+                         percentile_clipping, block_wise)
+
+class SGD8bit(Optimizer1State):
+    """SGD with momentum, implemented on top of ``Optimizer1State``.
+
+    Passes the constant ``8`` to the base class in the argument slot that
+    follows ``weight_decay`` (presumably the optimizer-state bit width --
+    confirm against ``Optimizer1State``).
+
+    NOTE(review): ``nesterov`` is accepted but never used here (it is not
+    forwarded to the base class) -- presumably kept for signature parity with
+    ``torch.optim.SGD``; confirm before relying on it.
+    """
+
+    def __init__(self, params, lr, momentum=0, dampening=0,
+                 weight_decay=0, nesterov=False, args=None,
+                 min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        """Configure the base optimizer for the ``'momentum'`` update rule.
+
+        ``momentum`` and ``dampening`` are handed to the base class together
+        as a ``(momentum, dampening)`` tuple; every other argument except
+        ``nesterov`` is forwarded as given.
+
+        Raises:
+            NotImplementedError: if ``momentum == 0``.
+        """
+        if momentum == 0:
+            # The original raised the undefined name ``NotImplementError``,
+            # which surfaced as a NameError instead of the intended exception.
+            raise NotImplementedError('SGD without momentum is not supported!')
+        super().__init__('momentum', params, lr, (momentum, dampening), 0.0,
+                         weight_decay, 8, args, min_8bit_size,
+                         percentile_clipping, block_wise)
+
+class SGD32bit(Optimizer1State):
+    """SGD with momentum, implemented on top of ``Optimizer1State``.
+
+    Passes the constant ``32`` to the base class in the argument slot that
+    follows ``weight_decay`` (presumably the optimizer-state bit width --
+    confirm against ``Optimizer1State``).
+
+    NOTE(review): ``nesterov`` is accepted but never used here (it is not
+    forwarded to the base class) -- presumably kept for signature parity with
+    ``torch.optim.SGD``; confirm before relying on it.
+    """
+
+    def __init__(self, params, lr, momentum=0, dampening=0,
+                 weight_decay=0, nesterov=False, args=None,
+                 min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        """Configure the base optimizer for the ``'momentum'`` update rule.
+
+        ``momentum`` and ``dampening`` are handed to the base class together
+        as a ``(momentum, dampening)`` tuple; every other argument except
+        ``nesterov`` is forwarded as given.
+
+        Raises:
+            NotImplementedError: if ``momentum == 0``.
+        """
+        if momentum == 0:
+            # The original raised the undefined name ``NotImplementError``,
+            # which surfaced as a NameError instead of the intended exception.
+            raise NotImplementedError('SGD without momentum is not supported!')
+        super().__init__('momentum', params, lr, (momentum, dampening), 0.0,
+                         weight_decay, 32, args, min_8bit_size,
+                         percentile_clipping, block_wise)