author    Tim Dettmers <tim.dettmers@gmail.com>    2021-11-29 09:32:13 -0800
committer Tim Dettmers <tim.dettmers@gmail.com>    2021-11-29 09:32:13 -0800
commit    20e1677dfdc4495038fd780807c8cbc253adf921 (patch)
tree      42011169e55eab3f4226ff171d84edac84ec6f8f /howto_config_override.md
parent    3cff6795fb70dd99b4802593f3c70d291e0cd1dc (diff)
Added module override, bnb.nn.Embedding #13 #15 #19
Diffstat (limited to 'howto_config_override.md')
-rw-r--r--  howto_config_override.md  14
1 file changed, 14 insertions(+), 0 deletions(-)
diff --git a/howto_config_override.md b/howto_config_override.md
index 11e9d49..4680776 100644
--- a/howto_config_override.md
+++ b/howto_config_override.md
@@ -2,6 +2,7 @@
If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameters while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details.
+For global overrides that you use in many different places in your code, you can do the following:
```python
import torch
import bitsandbytes as bnb
@@ -24,3 +25,16 @@ mng.override_config([model.special.weight, model.also_special.weight],
key_value_dict={'is_sparse': True, 'lr': 1e-5, 'betas': (0.9, 0.98)})
```
Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm`
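+
+Any of these keys can be combined in a single `key_value_dict`. As a minimal sketch (the parameter `model.emb.weight` and the values are purely illustrative), assuming `mng` is the `GlobalOptimManager` instance from the example above:
+```python
+# illustrative values only; any subset of the keys listed above can be combined
+mng.override_config(model.emb.weight,
+                    key_value_dict={'optim_bits': 32,
+                                    'percentile_clipping': 5,
+                                    'block_wise': False,
+                                    'max_unorm': 0.02})
+```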
+
+For overrides of particular layers, we recommend overriding locally in each module. You can do this by passing the module, the attribute name of the parameter, and the override config to the GlobalOptimManager:
+```python
+import torch
+from bitsandbytes.optim import GlobalOptimManager
+
+class MyModule(torch.nn.Module):
+  def __init__(self, din, dout):
+    super(MyModule, self).__init__()
+    self.linear = torch.nn.Linear(din, dout)
+    # register an override for the linear layer's weight parameter:
+    # its optimizer state is kept in 32-bit and its learning rate is
+    # set to 0.0001, independent of the main learning rate
+    config = {'optim_bits': 32, 'lr': 0.0001}
+    GlobalOptimManager.get_instance().register_module_override(self.linear, 'weight', config)
+```
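+
+The override should then be picked up automatically when a bitsandbytes optimizer initializes its state. A minimal usage sketch (the sizes, data, and optimizer settings below are illustrative only): every parameter without an override follows the optimizer's own settings, while `linear.weight` uses 32-bit state and its own learning rate.
+```python
+import torch
+import bitsandbytes as bnb
+
+# illustrative sizes and data; MyModule is the class defined above
+model = MyModule(784, 10).cuda()
+
+# parameters without an override use these settings (8-bit Adam states, lr=0.001)
+adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
+
+loss = model.linear(torch.randn(4, 784).cuda()).sum()
+loss.backward()
+adam.step()       # linear.weight is updated with 32-bit state and lr=0.0001
+adam.zero_grad()
+```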