summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2021-10-05 19:16:20 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2021-10-05 19:16:20 -0700
commit7439924891496025edf60c9da6a782f362a50c70 (patch)
tree90476984d2c267f89232577a2ea40eb172387475
Initial commit
-rw-r--r--.buckconfig0
-rw-r--r--.gitignore135
-rw-r--r--BUCK25
-rw-r--r--CHANGELOG.md23
-rw-r--r--CODE_OF_CONDUCT.md80
-rw-r--r--CONTRIBUTING.md31
-rw-r--r--LICENSE21
-rw-r--r--Makefile60
-rw-r--r--NOTICE.md3
-rw-r--r--README.md106
-rw-r--r--bitsandbytes/__init__.py10
-rw-r--r--bitsandbytes/functional.py531
-rw-r--r--bitsandbytes/nn/__init__.py5
-rw-r--r--bitsandbytes/nn/modules.py44
-rw-r--r--bitsandbytes/optim/__init__.py10
-rw-r--r--bitsandbytes/optim/adam.py28
-rw-r--r--bitsandbytes/optim/lamb.py29
-rw-r--r--bitsandbytes/optim/lars.py115
-rw-r--r--bitsandbytes/optim/optimizer.py460
-rw-r--r--bitsandbytes/optim/rmsprop.py37
-rw-r--r--bitsandbytes/optim/sgd.py32
-rw-r--r--csrc/kernels.cu1846
-rw-r--r--csrc/kernels.cuh111
-rw-r--r--csrc/ops.cu355
-rw-r--r--csrc/ops.cuh81
-rw-r--r--csrc/pythonInterface.c149
-rw-r--r--deploy.sh13
-rw-r--r--deploy_from_slurm.sh86
-rw-r--r--include/AAlloc.h86
-rw-r--r--include/Algo-Direct-Common.h341
-rw-r--r--include/Algo-Direct2.h305
-rw-r--r--include/AlgoXCodes.h23
-rw-r--r--include/BinAlgo.h77
-rw-r--r--include/BinSearch.h11
-rw-r--r--include/Portable.h151
-rw-r--r--include/SIMD.h562
-rw-r--r--include/Type.h221
-rw-r--r--pyproject.toml6
-rw-r--r--requirements.txt1
-rw-r--r--setup.py32
-rw-r--r--tests/test_functional.py213
-rw-r--r--tests/test_optim.py362
42 files changed, 6817 insertions, 0 deletions
diff --git a/.buckconfig b/.buckconfig
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/.buckconfig
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f8ebf71
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,135 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# vim
+*.swp
+
+dependencies
+cuda_build
diff --git a/BUCK b/BUCK
new file mode 100644
index 0000000..8ba6ac3
--- /dev/null
+++ b/BUCK
@@ -0,0 +1,25 @@
+prebuilt_python_library(
+ name = 'bnb-cuda102',
+ binary_src = ':bnb-cuda102-wheel',
+)
+
+
+remote_file(
+ name = 'bnb-cuda102-wheel',
+ url = 'https://test-files.pythonhosted.org/packages/4e/69/025b08bf1b7e777ca3800dc79ebe9dfd7309931f0a5f3de132d1433076ff/bitsandbytes_cuda102-0.0.22-py3-none-any.whl',
+ sha1 = '8c89e640afab18cdc6b7c5924c70e25036811686',
+ )
+
+
+prebuilt_python_library(
+ name = 'bnb-cuda111',
+ binary_src = ':bnb-cuda111-wheel',
+)
+
+
+remote_file(
+ name = 'bnb-cuda111-wheel',
+ url = 'https://test-files.pythonhosted.org/packages/f9/38/2179701c80ae2aa9606bce7d498f397bd94e7bb2ff7e7c30ed032a3a39c2/bitsandbytes_cuda111-0.0.22-py3-none-any.whl',
+ sha1 = '433f534b225bc29391782c8a9d82635bc0eb9d33',
+ )
+
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..132e7ec
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,23 @@
+v0.0.21
+- Ampere, RTX 30 series GPUs now compatible with the library.
+
+v0.0.22:
+
+- Fixed an error where a `reset_parameters()` call on the `StableEmbedding` would lead to an error in older PyTorch versions (from 1.7.0).
+
+v0.0.23:
+
+Bugs:
+ - Unified quantization API: each quantization function now returns `Q, S` where `Q` is the quantized tensor and `S` the quantization state which may hold absolute max values, a quantization map or more. For dequantization all functions now accept the inputs `Q, S` so that `Q` is dequantized with the quantization state `S`.
+ - Fixed an issue where the CUDA 11.1 binary was not compiled with the right headers
+
+API changes:
+ - Block-wise quantization for optimizers now enabled by default
+
+Features:
+ - Block-wise quantization routines now support CPU Tensors.
+
+
+v0.0.24:
+
+- Fixed a bug where a float/half conversion led to a compilation error for CUDA 11.1 on Turning GPUs.
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000..08b500a
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,80 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when there is a
+reasonable belief that an individual's behavior may have a negative impact on
+the project or its community.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <opensource-conduct@fb.com>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..7996343
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,31 @@
+# Contributing to bitsandbytes
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+By contributing to bitsandbytes, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree. \ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..b96dcb0
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Facebook, Inc. and its affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..5f5efed
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,60 @@
+MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST)))
+ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH)))
+
+GPP:= /usr/bin/g++
+NVCC := $(CUDA_HOME)/bin/nvcc
+###########################################
+
+CSRC := $(ROOT_DIR)/csrc
+BUILD_DIR:= $(ROOT_DIR)/cuda_build
+
+FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
+FILES_CPP := $(CSRC)/pythonInterface.c
+
+INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
+LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
+
+# NVIDIA NVCC compilation flags
+COMPUTE_CAPABILITY := -gencode arch=compute_50,code=sm_50 # Maxwell
+COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
+COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
+COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
+COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+
+all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+ $(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+
+cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+ $(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+
+cuda10x: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -gencode arch=compute_75,code=sm_75 -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+ $(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+
+cuda110: $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -gencode arch=compute_80,code=sm_80 -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+ $(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+
+cuda11x: $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+ $(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+
+$(BUILD_DIR):
+ mkdir -p cuda_build
+ mkdir -p dependencies
+
+$(ROOT_DIR)/dependencies/cub:
+ git clone https://github.com/NVlabs/cub $(ROOT_DIR)/dependencies/cub
+
+clean:
+ rm cuda_build/* ./bitsandbytes/libbitsandbytes.so
+
+cleaneggs:
+ rm -rf *.egg*
diff --git a/NOTICE.md b/NOTICE.md
new file mode 100644
index 0000000..660658b
--- /dev/null
+++ b/NOTICE.md
@@ -0,0 +1,3 @@
+The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license.
+
+We thank Fabio Cannizzo for this work on FastBinarySearch which is included in this project.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e3e9209
--- /dev/null
+++ b/README.md
@@ -0,0 +1,106 @@
+# bitsandbytes
+
+bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers and quantization functions.
+
+## Features
+- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB
+- Percentile clipping: A gradient clipping technique that adjusts dynamically for each weight-tensor during training
+- Stable Embedding Layer: Improved stability through better initialization, and normalization
+- Fast quantile estimation: Up to 100x faster than other algorithms
+- 8-bit quantization: Quantile, Linear, and Dynamic quantization
+
+#### Details
+- **8-bit Optimizers** use an 8-bit instead of 32-bit state and thus save 75% of memory.
+- **Percentile Clipping** is an adaptive gradient clipping technique that adapts the clipping threshold automatically during training for each weight-tensor. It tracks a history of the past 100 gradient norms, and the gradient is clipped at a certain percentile p. For most tasks, p=5 works well and provides improved stability and, in some cases, even better performance (ResNet-50 ImageNet).
+- The **Stable Embedding Layer** uses a less variable initialization coupled with layer norm for stability. Usually, dense optimizers are used in conjunction with sparse BPE/word embeddings, and these dense optimizers perform incorrect updates, leading to instability. The Stable Embedding Layer fixes this problem by performing sparse updates by default for any chosen bnb optimizer.
+- Fast quantile estimation via **SRAM-Quantiles** algorithm, which is up to 100x faster than previous algorithms to estimate quantiles.
+- Various **8-bit Quantization** schemes which are useful to compress data. For example, gradient communication or Mixture of Experts token routing can be improved by using 8-bit quantization before communication followed by decompression to 16/32-bit.
+
+## Requirements & Installation
+
+Requirements: anaconda, cudatoolkit, pytorch
+Hardware requirements: NVIDIA Maxwell GPU or newer (>=GTX 9XX)
+
+The requirements can best be fulfilled by installing pytorch via anaconda. You can install PyTorch by following the ["Get Started"](https://pytorch.org/get-started/locally/) instructions on the official website.
+
+bitsandbytes is compatible with all major PyTorch releases and cudatoolkit versions, but for now, you need to select the right version manually. To do this run:
+
+```conda list | grep cudatoolkit```
+
+and take note of the Cuda version that you have installed. Then you can install bitsandbytes via:
+```bash
+# choices: {cuda92, cuda 100, cuda101, cuda102, cuda110, cuda111, cuda113}
+# replace XXX with the respective number
+pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX
+```
+
+To check if your installation was successful, you can execute the following command, which runs a single bnb Adam update.
+```
+wget https://gist.githubusercontent.com/TimDettmers/1f5188c6ee6ed69d211b7fe4e381e713/raw/4d17c3d09ccdb57e9ab7eca0171f2ace6e4d2858/check_bnb_install.py && python check_bnb_install.py
+```
+
+## Using bitsandbytes
+
+### Using the 8-bit Optimizers
+
+With bitsandbytes 8-bit optimizers can be used by changing a single line of code in your codebase. For NLP models we recommend also to use the StableEmbedding layers (see below) which improves results and helps with stable 8-bit optimization. To get started with 8-bit optimizers, it is sufficient to replace your old optimizer with the 8-bit optimizer in the following way:
+```python
+import bitsandbytes as bnb
+
+# adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer
+adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # add bnb optimizer
+adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=8) # equivalent
+
+# use 32-bit Adam with 5th percentile clipping
+adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995),
+ optim_bits=32, percentile_clipping=5)
+```
+
+Note that by default all parameter tensors with less than 4096 elements are kept at 32-bit even if you initialize those parameters with 8-bit optimizers. This is done since such small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm).
+
+### Change Bits and other Hyperparameters for Individual Parameters
+
+If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, with can use the `GlobalOptimManager`. With this, we can also configure specific parameters for sparse optimization, such as embedding layers. To do that, we need two things: (1) register the parameter while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere).
+
+```python
+import torch
+import bitsandbytes as bnb
+
+mng = bnb.optim.GlobalOptimManager.get_instance()
+
+model = MyModel()
+mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU
+
+model = model.cuda()
+# use 8-bit optimizer states for all parameters
+adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
+
+# 2a. override: the parameter model.fc1.weight now uses 32-bit Adam
+mng.override_config(model.fc1.weight, 'optim_bits', 32)
+
+# 2b. override: the two special layers use
+# sparse optimization + different learning rate + different Adam betas
+mng.override_config([model.special.weight, model.also_special.weight],
+ key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)})
+```
+
+### Stable Embedding Layer
+
+To use the stable embedding layer, simply replace the PyTorch embedding layer with `bnb.nn.StableEmbedding`. By default, this layer is sparsely optimized.
+
+### Fairseq Users
+
+To use the Stable Embedding Layer, override the respective `build_embedding(...)` function of your model. Make sure to also use the `--no-scale-embedding` flag to disable scaling of the word embedding layer (nor replaced with layer norm). You can use the optimizers by replacing the optimizer in the respective file (`adam.py` etc.).
+
+## Release and Feature History
+
+Last release: v0.0.22:
+- Fixed an error where a `reset_parameters()` call on the `StableEmbedding` would lead to an error in older PyTorch versions (from 1.7.0).
+
+For upcoming features and changes and full history see [Patch Notes](PATCH_NOTES.md).
+
+## License
+
+The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license.
+
+We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization.
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
new file mode 100644
index 0000000..6e29322
--- /dev/null
+++ b/bitsandbytes/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .optim import adam
+from .nn import modules
+__pdoc__ = {'libBitsNBytes' : False,
+ 'optim.optimizer.Optimizer8bit': False,
+ 'optim.optimizer.MockArgs': False
+ }
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
new file mode 100644
index 0000000..65c697d
--- /dev/null
+++ b/bitsandbytes/functional.py
@@ -0,0 +1,531 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import random
+import math
+import ctypes as ct
+import torch
+from torch import Tensor
+from typing import Tuple
+
+lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
+name2qmap = {}
+
+''' C FUNCTIONS FOR OPTIMIZERS '''
+
+str2optimizer32bit = {}
+str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
+str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
+str2optimizer32bit['rmsprop'] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
+str2optimizer32bit['lars'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
+str2optimizer32bit['lamb'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
+
+str2optimizer8bit = {}
+str2optimizer8bit['adam'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+str2optimizer8bit['momentum'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)
+str2optimizer8bit['rmsprop'] = (lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16)
+str2optimizer8bit['lamb'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+str2optimizer8bit['lars'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)
+
+str2optimizer8bit_blockwise = {}
+str2optimizer8bit_blockwise['adam'] = (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16)
+str2optimizer8bit_blockwise['momentum'] = (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16)
+str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16)
+
+optimal_normal = [-0.9939730167388916, -0.8727636337280273, -0.8097418546676636, -0.7660024166107178, -0.7318882346153259, -0.6793879270553589, -0.657649040222168, -0.6385974884033203, -0.6211113333702087, -0.5901028513908386, -0.5762918591499329, -0.5630806684494019, -0.5509274005889893, -0.5394591689109802, -0.5283197164535522, -0.517780065536499, -0.5074946284294128, -0.4980469048023224, -0.48867011070251465, -0.48003149032592773, -0.47125306725502014, -0.4629971981048584, -0.4547359049320221, -0.446626216173172, -0.43902668356895447, -0.43158355355262756, -0.4244747757911682, -0.4173796474933624, -0.41038978099823, -0.4055633544921875, -0.4035947024822235, -0.39701032638549805, -0.39057496190071106, -0.38439232110977173, -0.3782760500907898, -0.3721940815448761, -0.3661896586418152, -0.3604033589363098, -0.354605108499527, -0.34892538189888, -0.34320303797721863, -0.3376772701740265, -0.3323028087615967, -0.3269782066345215, -0.32166096568107605, -0.316457599401474, -0.3112771809101105, -0.3061025142669678, -0.30106794834136963, -0.2961243987083435, -0.2912728488445282, -0.28644347190856934, -0.28165507316589355, -0.2769731283187866, -0.2722635865211487, -0.26779335737228394, -0.26314786076545715, -0.2586647868156433, -0.2541804611682892, -0.2496625930070877, -0.24527113139629364, -0.24097171425819397, -0.23659978806972504, -0.23218469321727753, -0.22799566388130188, -0.22380566596984863, -0.21965542435646057, -0.2154538631439209, -0.2113603949546814, -0.20735277235507965, -0.20334717631340027, -0.19932441413402557, -0.19530178606510162, -0.19136647880077362, -0.18736697733402252, -0.18337111175060272, -0.17951400578022003, -0.1757056713104248, -0.17182783782482147, -0.1680615097284317, -0.16431649029254913, -0.16053077578544617, -0.15685945749282837, -0.15298527479171753, -0.1493264138698578, -0.14566898345947266, -0.14188314974308014, -0.13819937407970428, -0.1344561129808426, -0.1306886374950409, -0.1271020770072937, -0.12346585839986801, -0.11981867253780365, -0.11614970862865448, -0.11256207525730133, -0.10889036953449249, -0.10525048524141312, -0.1016591489315033, -0.09824034571647644, -0.09469068050384521, -0.0911419615149498, -0.08773849159479141, -0.08416644483804703, -0.08071305602788925, -0.07720902562141418, -0.07371306419372559, -0.07019119709730148, -0.06673648208379745, -0.06329209357500076, -0.059800852090120316, -0.0564190037548542, -0.05296570807695389, -0.049522045999765396, -0.04609023034572601, -0.04262964054942131, -0.039246633648872375, -0.03577171266078949, -0.03236335143446922, -0.028855687007308006, -0.02542758360505104, -0.022069433704018593, -0.018754752352833748, -0.015386369079351425, -0.01194947212934494, -0.008439815603196621, -0.004995611496269703, -0.0016682245768606663, 0.0, 0.0015510577941313386, 0.005062474869191647, 0.008417150937020779, 0.011741090565919876, 0.015184164978563786, 0.018582714721560478, 0.02204744517803192, 0.025471193715929985, 0.02889077737927437, 0.0323684960603714, 0.03579240292310715, 0.039281025528907776, 0.0427563451230526, 0.04619763046503067, 0.04968220740556717, 0.05326594039797783, 0.05679265409708023, 0.060245808213949203, 0.06372645497322083, 0.06721872836351395, 0.0706876739859581, 0.0742349922657013, 0.07774098962545395, 0.08123527467250824, 0.08468879014253616, 0.08810535818338394, 0.09155989438295364, 0.09498448669910431, 0.0985206812620163, 0.10206405073404312, 0.10563778132200241, 0.10921968519687653, 0.11284469068050385, 0.11653254181146622, 0.12008969485759735, 0.12368203699588776, 0.1272617131471634, 0.13089501857757568, 0.134552001953125, 0.1382799744606018, 0.14194637537002563, 0.14563234150409698, 0.14930322766304016, 0.15303383767604828, 0.1567956507205963, 0.16050070524215698, 0.16431072354316711, 0.16813558340072632, 0.17204202711582184, 0.1758781224489212, 0.17973239719867706, 0.1836014688014984, 0.18753431737422943, 0.19138391315937042, 0.19535475969314575, 0.19931404292583466, 0.20333819091320038, 0.20738255977630615, 0.21152682602405548, 0.21568812429904938, 0.21978361904621124, 0.22393859922885895, 0.22814159095287323, 0.23241068422794342, 0.23675410449504852, 0.24123944342136383, 0.24569889903068542, 0.2500703036785126, 0.25904011726379395, 0.26349544525146484, 0.2682226300239563, 0.272907555103302, 0.2774306833744049, 0.28220856189727783, 0.2869136929512024, 0.2916390895843506, 0.29649388790130615, 0.30142995715141296, 0.3065022826194763, 0.3114383816719055, 0.31648796796798706, 0.3216581642627716, 0.32700115442276, 0.3322487473487854, 0.33778008818626404, 0.3431521952152252, 0.3487405776977539, 0.3543166518211365, 0.3601346015930176, 0.36605337262153625, 0.37217751145362854, 0.378179669380188, 0.3843980133533478, 0.3906566798686981, 0.39714935421943665, 0.40357843041419983, 0.4104187488555908, 0.4171563684940338, 0.42418959736824036, 0.43136918544769287, 0.4389212429523468, 0.44673123955726624, 0.45457619428634644, 0.4627031683921814, 0.47130417823791504, 0.4798591434955597, 0.48897242546081543, 0.4979848861694336, 0.5, 0.5076631307601929, 0.5177803635597229, 0.5282770991325378, 0.5392990112304688, 0.5506287813186646, 0.5632893443107605, 0.5764452815055847, 0.5903191566467285, 0.6051878333091736, 0.6209936141967773, 0.6382884979248047, 0.6573970913887024, 0.6795773506164551, 0.7037051916122437, 0.7327037453651428, 0.7677436470985413, 0.8111193776130676, 0.875165581703186, 1.0]
+
+optimal_half_normal = [0.0025565922260284424, 0.005811259150505066, 0.00961565226316452, 0.010822802782058716, 0.013123787939548492, 0.014242202043533325, 0.0143156498670578, 0.016469404101371765, 0.017666727304458618, 0.01773911714553833, 0.0199756920337677, 0.0210941880941391, 0.021161124110221863, 0.02451971173286438, 0.024580076336860657, 0.02685210108757019, 0.028012827038764954, 0.030198264867067337, 0.0302925705909729, 0.03136435151100159, 0.03374280035495758, 0.03487399220466614, 0.035243816673755646, 0.037192340940237045, 0.03822284936904907, 0.04164902865886688, 0.04173608124256134, 0.04401407018303871, 0.04508155584335327, 0.047482021152973175, 0.04756556823849678, 0.050963032990694046, 0.05196474492549896, 0.055417388677597046, 0.05793146416544914, 0.05799369141459465, 0.05887940526008606, 0.05895659327507019, 0.062420234084129333, 0.06493274495005608, 0.06499008461833, 0.06935599446296692, 0.07197384163737297, 0.07201516255736351, 0.07276943325996399, 0.07283210754394531, 0.07550075277686119, 0.07975354790687561, 0.07980883121490479, 0.08257630094885826, 0.0867777168750763, 0.08682405948638916, 0.08967285975813866, 0.09323835000395775, 0.09386616945266724, 0.09735457599163055, 0.09739077091217041, 0.10092401504516602, 0.10444298386573792, 0.10447832942008972, 0.10770941898226738, 0.10803905129432678, 0.11161200702190399, 0.1151546835899353, 0.11520349979400635, 0.11875157058238983, 0.11879390478134155, 0.1222602017223835, 0.122351735830307, 0.12240418791770935, 0.12594850733876228, 0.12597402930259705, 0.12602100148797035, 0.12960633635520935, 0.1296597123146057, 0.12966342642903328, 0.13227657973766327, 0.13325360417366028, 0.1333133578300476, 0.13691483438014984, 0.1371927298605442, 0.14066261053085327, 0.14088113978505135, 0.1447291411459446, 0.14805573225021362, 0.148526418954134, 0.15170684456825256, 0.15178103744983673, 0.15225710347294807, 0.1554398238658905, 0.15609459951519966, 0.15618794038891792, 0.1592724472284317, 0.1629735231399536, 0.16382690146565437, 0.16676269471645355, 0.16873238794505596, 0.17066434025764465, 0.17068277299404144, 0.1717144437134266, 0.17558929696679115, 0.17827065289020538, 0.17835864424705505, 0.18222273886203766, 0.18353315070271492, 0.18604370951652527, 0.18611834943294525, 0.1876586265861988, 0.18996606767177582, 0.19170701876282692, 0.19398853182792664, 0.19786442816257477, 0.19795633852481842, 0.20195159316062927, 0.2058800607919693, 0.2099103182554245, 0.2122517265379429, 0.21410366892814636, 0.21819619834423065, 0.22221362590789795, 0.22233009338378906, 0.22500130906701088, 0.2251257635653019, 0.22638091444969177, 0.23067741096019745, 0.23368822410702705, 0.2348879873752594, 0.2382080741226673, 0.2390350103378296, 0.2391497790813446, 0.24253453686833382, 0.24265171959996223, 0.2470107562839985, 0.24764248728752136, 0.24777774512767792, 0.2516774423420429, 0.256104726344347, 0.2564055472612381, 0.2607169933617115, 0.265461727976799, 0.26985861361026764, 0.2701106257736683, 0.2702729292213917, 0.274574413895607, 0.2750340588390827, 0.27919672429561615, 0.283704474568367, 0.28386808931827545, 0.28953738883137703, 0.2896753139793873, 0.29320384562015533, 0.29451676085591316, 0.295327290892601, 0.29802779853343964, 0.29818175733089447, 0.29972871020436287, 0.30290623009204865, 0.30305664241313934, 0.30486901476979256, 0.31299956142902374, 0.31518544629216194, 0.31790371239185333, 0.3205283172428608, 0.3230419009923935, 0.32595496252179146, 0.32612212374806404, 0.3282426446676254, 0.3283906430006027, 0.33146094158291817, 0.3316439874470234, 0.33365286886692047, 0.33723779395222664, 0.3390095978975296, 0.3427443392574787, 0.34853987768292427, 0.34869300201535225, 0.35457711294293404, 0.35537679493427277, 0.3604113645851612, 0.36124424636363983, 0.3665340431034565, 0.36667295172810555, 0.3727492541074753, 0.3729033060371876, 0.37888188660144806, 0.37907837703824043, 0.3792510814964771, 0.38557394221425056, 0.38573457673192024, 0.39108292758464813, 0.39911722019314766, 0.40589402988553047, 0.40604450181126595, 0.410498782992363, 0.4106704741716385, 0.4129834659397602, 0.4131447561085224, 0.4172855168581009, 0.4202354736626148, 0.4204071946442127, 0.43538858368992805, 0.4355536885559559, 0.4432900734245777, 0.44603554904460907, 0.4461968094110489, 0.451409537345171, 0.4598204083740711, 0.46002377942204475, 0.46178819239139557, 0.46868549659848213, 0.46995367109775543, 0.4868385046720505, 0.48702501133084297, 0.4958047419786453, 0.4960057884454727, 0.5051481872797012, 0.506847757846117, 0.5148334950208664, 0.5150565356016159, 0.5174009390175343, 0.5249751061201096, 0.5283288545906544, 0.5355450958013535, 0.539984006434679, 0.5467876642942429, 0.5522958822548389, 0.5584012717008591, 0.5706631988286972, 0.5836620181798935, 0.5836880058050156, 0.5942088551819324, 0.5975865572690964, 0.6102624125778675, 0.6124880760908127, 0.6286389082670212, 0.646102175116539, 0.6471664495766163, 0.665437325835228, 0.6687244363129139, 0.687017485499382, 0.6932839937508106, 0.7115348428487778, 0.7218200154602528, 0.7219699807465076, 0.7747527211904526, 0.7749756425619125, 0.8192005604505539, 0.8194110840559006, 0.8830635994672775, 0.9217727445065975, 0.9245667457580566, 0.947742685675621, 0.9674464613199234, 0.9890814647078514, 0.9891453236341476, 0.9925699159502983]
+
+def create_linear_map(signed=True):
+ if signed:
+ return torch.linspace(-1.0, 1.0, 256)
+ else:
+ return torch.linspace(0.0, 1.0, 256)
+
+def create_dynamic_map(signed=True, n=7):
+ '''
+ Creates the dynamic quantiztion map.
+
+ The dynamic data type is made up of a dynamic exponent and
+ fraction. As the exponent increase from 0 to -7 the number
+ of bits available for the fraction shrinks.
+
+ This is a generalization of the dynamic type where a certain
+ number of the bits and be reserved for the linear quantization
+ region (the fraction). n determines the maximum number of
+ exponent bits.
+
+ For more details see
+ (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]
+ '''
+
+ data = []
+ # these are additional items that come from the case
+ # where all the exponent bits are zero and no
+ # indicator bit is present
+ additional_items = 2**(7-n)-1
+ if not signed: additional_items = 2*additional_items
+ for i in range(n):
+ fraction_items = 2**(i+7-n)+1 if signed else 2**(i+7-n+1)+1
+ boundaries = torch.linspace(0.1, 1, fraction_items)
+ means = (boundaries[:-1]+boundaries[1:])/2.0
+ data += ((10**(-(n-1)+i))*means).tolist()
+ if signed:
+ data += (-(10**(-(n-1)+i))*means).tolist()
+
+ if additional_items > 0:
+ boundaries = torch.linspace(0.1, 1, additional_items+1)
+ means = (boundaries[:-1]+boundaries[1:])/2.0
+ data += ((10**(-(n-1)+i))*means).tolist()
+ if signed:
+ data += (-(10**(-(n-1)+i))*means).tolist()
+
+ data.append(0)
+ data.append(1.0)
+ data.sort()
+ return Tensor(data)
+
+def get_ptr(A: Tensor) -> ct.c_void_p:
+ '''
+ Get the ctypes pointer from a PyTorch Tensor.
+
+ Parameters
+ ----------
+ A : torch.tensor
+ The PyTorch tensor.
+
+ Returns
+ -------
+ ctypes.c_void_p
+ '''
+ if A is None: return None
+ else: return ct.c_void_p(A.data.storage().data_ptr())
+
+def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tensor:
+ '''
+ Estimates 256 equidistant quantiles on the input tensor eCDF.
+
+ Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
+ via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
+ and the extreme quantiles close to 0 and 1 have high variance / large estimation
+ errors. These large errors can be avoided by using the offset variable which trims
+ the distribution. The default offset value of 1/512 ensures minimum entropy encoding -- it
+ trims 1/512 = 0.2% from each side of the distrivution. An offset value of 0.01 to 0.02
+ usually has a much lower error but is not a minimum entropy encoding. Given an offset
+ of 0.02 equidistance points in the range [0.02, 0.98] are used for the quantiles.
+
+ Parameters
+ ----------
+ A : torch.Tensor
+ The input tensor. Any shape.
+ out : torch.Tensor
+ Tensor with the 256 estimated quantiles.
+ offset : float
+ The offset for the first and last quantile from 0 and 1. Default: 1/512
+
+ Returns
+ -------
+ torch.Tensor:
+ The 256 quantiles in float32 datatype.
+ '''
+ if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+ if A.dtype == torch.float32:
+ lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
+ elif A.dtype == torch.float16:
+ lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
+ else:
+ raise NotImplementError(f'Not supported data type {A.dtype}')
+ return out
+
+def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=None, out: Tensor=None) -> Tensor:
+ '''
+ Quantize tensor A in blocks of size 4096 values.
+
+ Quantizes tensor A by dividing it into blocks of 4096 values.
+ Then the absolute maximum value within these blocks is calculated
+ for the non-linear quantization.
+
+ Parameters
+ ----------
+ A : torch.Tensor
+ The input tensor.
+ code : torch.Tensor
+ The quantization map.
+ absmax : torch.Tensor
+ The absmax values.
+ rand : torch.Tensor
+ The tensor for stochastic rounding.
+ out : torch.Tensor
+ The output tensor (8-bit).
+
+ Returns
+ -------
+ torch.Tensor:
+ The 8-bit tensor.
+ tuple(torch.Tensor, torch.Tensor):
+ The quantization state to undo the quantization.
+ '''
+
+ if code is None:
+ if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
+ code = name2qmap['dynamic']
+ code = code.to(A.device)
+
+ if absmax is None:
+ n = A.numel()
+ num_blocks = 4096
+ blocks = n//num_blocks
+ blocks += 1 if n % num_blocks > 0 else 0
+ absmax = torch.zeros((blocks,), device=A.device)
+
+ if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
+
+
+ if A.device.type != 'cpu':
+ if rand is not None:
+ assert rand.numel() >= 1024
+ rand_offset = random.randint(0, 1023)
+ if A.dtype == torch.float32:
+ lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
+ elif A.dtype == torch.float16:
+ lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
+ else:
+ raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
+ else:
+ if A.dtype == torch.float32:
+ lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
+ elif A.dtype == torch.float16:
+ lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
+ else:
+ raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
+ else:
+ # cpu
+ assert rand is None
+ lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
+
+ return out, (absmax, code)
+
+def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
+ absmax: Tensor=None, code: Tensor=None, out: Tensor=None,
+ blocksize: int=4096) -> Tensor:
+ '''
+ Dequantizes blockwise quantized values.
+
+ Dequantizes the tensor A with maximum absolute values absmax in
+ blocks of size 4096.
+
+ Parameters
+ ----------
+ A : torch.Tensor
+ The input 8-bit tensor.
+ quant_state : tuple(torch.Tensor, torch.Tensor)
+ Tuple of code and absmax values.
+ absmax : torch.Tensor
+ The absmax values.
+ code : torch.Tensor
+ The quantization map.
+ out : torch.Tensor
+ Dequantized output tensor (default: float32)
+
+
+ Returns
+ -------
+ torch.Tensor:
+ Dequantized tensor (default: float32)
+ '''
+ assert quant_state is not None or absmax is not None
+ if code is None and quant_state is None:
+ if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
+ code = name2qmap['dynamic']
+ code = code.to(A.device)
+
+ if out is None: out = torch.zeros_like(A, dtype=torch.float32)
+ if quant_state is None: quant_state = (absmax, code)
+
+ if blocksize not in [2048, 4096]:
+ raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]')
+
+ if A.device.type != 'cpu':
+ if out.dtype == torch.float32:
+ lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+ elif out.dtype == torch.float16:
+ lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+ else:
+ raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
+ else:
+ lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(A.numel()))
+
+
+ return out
+
+
+def quantize(A: Tensor, code: Tensor=None, out: Tensor=None) -> Tensor:
+ if code is None:
+ if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
+ code = name2qmap['dynamic']
+ code = code.to(A.device)
+
+ absmax = torch.abs(A).max()
+ inp = A/absmax
+ out = quantize_no_absmax(inp, code, out)
+ return out, (absmax, code)
+
+def dequantize(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, absmax: Tensor=None, code: Tensor=None, out: Tensor=None) -> Tensor:
+ assert quant_state is not None or absmax is not None
+ if code is None and quant_state is None:
+ if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
+ code = name2qmap['dynamic']
+ code = code.to(A.device)
+
+ if quant_state is None: quant_state = (absmax, code)
+ out = dequantize_no_absmax(A, quant_state[1], out)
+ return out*quant_state[0]
+
+def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
+ '''
+ Quantizes input tensor to 8-bit.
+
+ Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
+ `out` using the quantization map `code`.
+
+ Parameters
+ ----------
+ A : torch.Tensor
+ The input tensor.
+ code : torch.Tensor
+ The quantization map.
+ out : torch.Tensor, optional
+ The output tensor. Needs to be of type byte.
+
+ Returns
+ -------
+ torch.Tensor:
+ Quantized 8-bit tensor.
+ '''
+ if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
+ lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+ return out
+
+def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
+ '''
+ Dequantizes the 8-bit tensor to 32-bit.
+
+ Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
+ the quantization map `code`.
+
+ Parameters
+ ----------
+ A : torch.Tensor
+ The 8-bit input tensor.
+ code : torch.Tensor
+ The quantization map.
+ out : torch.Tensor
+ The 32-bit output tensor.
+
+ Returns
+ -------
+ torch.Tensor:
+ 32-bit output tensor.
+ '''
+ if out is None: out = torch.zeros_like(A, dtype=torch.float32)
+ lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+ return out
+
+def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Tensor,
+ beta1: float, eps: float, step: int, lr: float,
+ state2: Tensor=None, beta2: float=0.0,
+ weight_decay: float=0.0, gnorm_scale: float=1.0,
+ unorm_vec: Tensor=None, max_unorm: float=0.0) -> None:
+ '''
+ Performs an inplace optimizer update with one or two optimizer states.
+
+ Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.
+
+ Parameters
+ ----------
+ optimizer_name : str
+ The name of the optimizer: {adam}.
+ g : torch.Tensor
+ Gradient tensor.
+ p : torch.Tensor
+ Parameter tensor.
+ state1 : torch.Tensor
+ Optimizer state 1.
+ beta1 : float
+ Optimizer beta1.
+ eps : float
+ Optimizer epsilon.
+ weight_decay : float
+ Weight decay.
+ step : int
+ Current optimizer step.
+ lr : float
+ The learning rate.
+ state2 : torch.Tensor
+ Optimizer state 2.
+ beta2 : float
+ Optimizer beta2.
+ gnorm_scale : float
+ The factor to rescale the gradient to the max clip value.
+ '''
+
+ param_norm = 0.0
+ if max_unorm > 0.0:
+ param_norm = torch.norm(p.data.float())
+
+ if optimizer_name not in str2optimizer32bit:
+ raise NotImplementError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}')
+
+ if g.dtype == torch.float32 and state1.dtype == torch.float32:
+ str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
+ ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
+ ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+ elif g.dtype == torch.float16 and state1.dtype == torch.float32:
+ str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
+ ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
+ ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+ else:
+ raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
+
+def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor,
+ beta1: float, beta2: float, eps: float,
+ step: int, lr: float, qmap1: Tensor, qmap2: Tensor,
+ max1: Tensor, max2: Tensor, new_max1: Tensor, new_max2: Tensor,
+ weight_decay: float=0.0, gnorm_scale: float=1.0,
+ unorm_vec: Tensor=None, max_unorm: float=0.0) -> None:
+ '''
+ Performs an inplace Adam update.
+
+ Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights.
+ Uses AdamW formulation if weight decay > 0.0.
+
+ Parameters
+ ----------
+ optimizer_name : str
+ The name of the optimizer. Choices {adam, momentum}
+ g : torch.Tensor
+ Gradient tensor.
+ p : torch.Tensor
+ Parameter tensor.
+ state1 : torch.Tensor
+ Adam state 1.
+ state2 : torch.Tensor
+ Adam state 2.
+ beta1 : float
+ Adam beta1.
+ beta2 : float
+ Adam beta2.
+ eps : float
+ Adam epsilon.
+ weight_decay : float
+ Weight decay.
+ step : int
+ Current optimizer step.
+ lr : float
+ The learning rate.
+ qmap1 : torch.Tensor
+ Quantization map for first Adam state.
+ qmap2 : torch.Tensor
+ Quantization map for second Adam state.
+ max1 : torch.Tensor
+ Max value for first Adam state update.
+ max2 : torch.Tensor
+ Max value for second Adam state update.
+ new_max1 : torch.Tensor
+ Max value for the next Adam update of the first state.
+ new_max2 : torch.Tensor
+ Max value for the next Adam update of the second state.
+ gnorm_scale : float
+ The factor to rescale the gradient to the max clip value.
+ '''
+
+ param_norm = 0.0
+ if max_unorm > 0.0:
+ param_norm = torch.norm(p.data.float())
+
+ if g.dtype == torch.float32 and state1.dtype == torch.uint8:
+ str2optimizer8bit[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
+ get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm),
+ ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
+ ct.c_int32(step), ct.c_float(lr),
+ get_ptr(qmap1), get_ptr(qmap2),
+ get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2),
+ ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+ elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
+ str2optimizer8bit[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
+ get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm),
+ ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
+ ct.c_int32(step), ct.c_float(lr),
+ get_ptr(qmap1), get_ptr(qmap2),
+ get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2),
+ ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+ else:
+ raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
+
+
+def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor,
+ beta1: float, beta2: float, eps: float,
+ step: int, lr: float, qmap1: Tensor, qmap2: Tensor,
+ absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0) -> None:
+
+
+ if g.dtype == torch.float32 and state1.dtype == torch.uint8:
+ str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
+ ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
+ ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
+ get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+ elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
+ str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
+ ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
+ ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
+ get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+ else:
+ raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
+
+
+def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int=5):
+ """Applies percentile clipping
+
+ grad: torch.Tensor
+ The gradient tensor.
+ gnorm_vec: torch.Tensor
+ Vector of gradient norms. 100 elements expected.
+ step: int
+ The current optimiation steps (number of past gradient norms).
+
+ """
+ if grad.dtype == torch.float32:
+ lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
+ elif grad.dtype == torch.float16:
+ lib.cpercentile_clipping_g16(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
+ else:
+ raise ValueError(f'Gradient type {grad.dtype} not supported!')
+
+ current_gnorm = torch.sqrt(gnorm_vec[step % 100])
+ vals, idx = torch.sort(gnorm_vec)
+ clip_value = torch.sqrt(vals[percentile])
+ gnorm_scale = 1.0
+
+ if current_gnorm > clip_value:
+ gnorm_scale = clip_value/current_gnorm
+
+ return current_gnorm, clip_value, gnorm_scale
+
+
+def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
+ assert len(histogram.shape) == 2
+ assert histogram.dtype == torch.float32
+ assert source.dtype == torch.float32
+ assert index1.dtype == torch.int32
+ assert index2.dtype == torch.int32
+
+ assert histogram.device.type == 'cuda'
+ assert index1.device.type == 'cuda'
+ assert index2.device.type == 'cuda'
+ assert source.device.type == 'cuda'
+
+ maxdim1 = ct.c_int32(histogram.shape[0])
+ n = ct.c_int32(index1.numel())
+ lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
new file mode 100644
index 0000000..177540f
--- /dev/null
+++ b/bitsandbytes/nn/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .modules import StableEmbedding
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
new file mode 100644
index 0000000..bf0945c
--- /dev/null
+++ b/bitsandbytes/nn/modules.py
@@ -0,0 +1,44 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from typing import Optional
+
+from torch import Tensor
+from torch.nn.parameter import Parameter
+import torch.nn.functional as F
+
+from bitsandbytes.optim import GlobalOptimManager
+
+class StableEmbedding(torch.nn.Embedding):
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
+ sparse: bool = True, _weight: Optional[Tensor] = None) -> None:
+ super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, False, _weight)
+ self.norm = torch.nn.LayerNorm(embedding_dim)
+ GlobalOptimManager.get_instance().register_parameters(self.weight)
+ GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)
+
+ def reset_parameters(self) -> None:
+ torch.nn.init.xavier_uniform_(self.weight)
+ self._fill_padding_idx_with_zero()
+
+ ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+ to make the Layer compatible with Pytorch < 1.9.
+ This means that if this changes in future PyTorch releases this need to change too
+ which is cumbersome. However, with this we can ensure compatibility with previous
+ PyTorch releases.
+ '''
+ def _fill_padding_idx_with_zero(self) -> None:
+ if self.padding_idx is not None:
+ with torch.no_grad():
+ self.weight[self.padding_idx].fill_(0)
+
+ def forward(self, input: Tensor) -> Tensor:
+ emb = F.embedding(
+ input, self.weight, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+ return self.norm(emb)
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
new file mode 100644
index 0000000..92c83b1
--- /dev/null
+++ b/bitsandbytes/optim/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .adam import Adam, Adam8bit, Adam32bit
+from .sgd import SGD, SGD8bit, SGD32bit
+from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
+from .lamb import LAMB, LAMB8bit, LAMB32bit
+from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+from .optimizer import GlobalOptimManager
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
new file mode 100644
index 0000000..99a6d10
--- /dev/null
+++ b/bitsandbytes/optim/adam.py
@@ -0,0 +1,28 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from bitsandbytes.optim.optimizer import Optimizer2State
+
+class Adam(Optimizer2State):
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ super(Adam, self).__init__('adam', params, lr, betas, eps,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adam8bit(Optimizer2State):
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ super(Adam8bit, self).__init__('adam', params, lr, betas, eps,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adam32bit(Optimizer2State):
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ super(Adam32bit, self).__init__('adam', params, lr, betas, eps,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+
+
diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py
new file mode 100644
index 0000000..b8d4b1e
--- /dev/null
+++ b/bitsandbytes/optim/lamb.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import apex
+from bitsandbytes.optim.optimizer import Optimizer2State
+
+class LAMB(Optimizer2State):
+ def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
+ super(LAMB, self).__init__('lamb', params, lr, betas, eps,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+
+class LAMB8bit(Optimizer2State):
+ def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
+ super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+
+class LAMB32bit(Optimizer2State):
+ def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
+ super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+
+
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
new file mode 100644
index 0000000..40dede7
--- /dev/null
+++ b/bitsandbytes/optim/lars.py
@@ -0,0 +1,115 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from torch.optim import Optimizer
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+class LARS(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ if momentum == 0:
+ raise NotImplementError(f'LARS without momentum is not supported!')
+ super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+
+class LARS8bit(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ if momentum == 0:
+ raise NotImplementError(f'LARS without momentum is not supported!')
+ super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+
+class LARS32bit(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ if momentum == 0:
+ raise NotImplementError(f'LARS without momentum is not supported!')
+ super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+
+
+class PytorchLARS(Optimizer):
+ def __init__(self, params, lr=0.01, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, max_unorm=0.02):
+ if lr < 0.0:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if momentum < 0.0:
+ raise ValueError("Invalid momentum value: {}".format(momentum))
+ if weight_decay < 0.0:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+ defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
+ weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm)
+ if nesterov and (momentum <= 0 or dampening != 0):
+ raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+ super(PytorchLARS, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(PytorchLARS, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('nesterov', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ params_with_grad = []
+ d_p_list = []
+ momentum_buffer_list = []
+ weight_decay = group['weight_decay']
+ momentum = group['momentum']
+ dampening = group['dampening']
+ nesterov = group['nesterov']
+ max_unorm = group['max_unorm']
+ lr = group['lr']
+
+ for p in group['params']:
+ if p.grad is None: continue
+
+ state = self.state[p]
+ d_p = p.grad
+ if weight_decay != 0:
+ d_p = d_p.add(param, alpha=weight_decay)
+
+ if momentum != 0:
+ buf = state.get('momentum_buffer', None)
+
+ if buf is None:
+ buf = torch.clone(d_p).detach()
+ state['momentum_buffer']= buf
+ else:
+ buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+
+ if nesterov:
+ update = d_p + buf*momentum
+ else:
+ update = buf
+
+ update_scale = 1.0
+ if max_unorm > 0.0:
+ assert p.dtype == torch.float32
+ pnorm = torch.norm(p.detach())
+ unorm = torch.norm(update)
+ if unorm > max_unorm*pnorm:
+ update_scale = max_unorm*pnorm/unorm
+
+ p.add_(update, alpha=-lr*update_scale)
+
+ return loss
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
new file mode 100644
index 0000000..6743c15
--- /dev/null
+++ b/bitsandbytes/optim/optimizer.py
@@ -0,0 +1,460 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+import bitsandbytes.functional as F
+
+from copy import deepcopy
+from itertools import chain
+from collections import defaultdict, abc as container_abcs
+
+class MockArgs(object):
+ def __init__(self, initial_data):
+ for key in initial_data:
+ setattr(self, key, initial_data[key])
+
+
+class GlobalOptimManager(object):
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError('Call get_instance() instead')
+
+ def initialize(self):
+ self.pid2config = {}
+ self.index2config = {}
+ self.optimizer = None
+ self.uses_config_override = False
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
+
+ def register_parameters(self, params):
+ param_groups = list(params)
+ if not isinstance(param_groups[0], dict):
+ param_groups = [{'params': param_groups}]
+
+ for group_index, group in enumerate(param_groups):
+ for p_index, p in enumerate(group['params']):
+ if id(p) in self.pid2config:
+ self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
+
+ def override_config(self, parameters, key=None, value=None, key_value_dict=None):
+ '''
+ Overrides initial optimizer config for specific parameters.
+
+ The key-values of the optimizer config for the input parameters are overidden
+ This can be both, optimizer parameters like "betas", or "lr" or it can be
+ 8-bit specific paramters like "optim_bits", "percentile_clipping".
+
+ Parameters
+ ----------
+ parameters : torch.Tensor or list(torch.Tensors)
+ The input parameters.
+ key : str
+ The hyperparamter to override.
+ value : object
+ The value for the hyperparamters.
+ key_value_dict : dict
+ A dictionary with multiple key-values to override.
+ '''
+ self.uses_config_override = True
+ if isinstance(parameters, torch.nn.Parameter):
+ parameters = [parameters]
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ if key is not None and value is not None:
+ assert key_value_dict is None
+ key_value_dict = {key: value}
+
+ if key_value_dict is not None:
+ for p in parameters:
+ if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict)
+ else: self.pid2config[id(p)] = key_value_dict
+
+
+class Optimizer8bit(torch.optim.Optimizer):
+
+ def __init__(self, params, defaults, optim_bits=32):
+ super(Optimizer8bit, self).__init__(params, defaults)
+ self.checked_if_on_gpu = False
+ self.name2qmap = {}
+
+ self.mng = GlobalOptimManager.get_instance()
+ self.non_castable_tensor_keys = set(
+ ['qmap1', 'qmap2',
+ 'max1', 'max2',
+ 'new_max1', 'new_max2',
+ 'state1', 'state2',
+ 'gnorm_vec', 'absmax1', 'absmax2',
+ 'unorm_vec'])
+
+ if optim_bits == 8: self.fill_qmap()
+
+ def fill_qmap(self):
+ self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True)
+ self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False)
+
+ def __setstate__(self, state):
+ super(Optimizer8bit, self).__setstate__(state)
+
+
+ def load_state_dict(self, state_dict):
+ r"""Loads the optimizer state.
+
+ Args:
+ state_dict (dict): optimizer state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+ state_dict = deepcopy(state_dict)
+ # Validate the state_dict
+ groups = self.param_groups
+ saved_groups = state_dict['param_groups']
+
+ if len(groups) != len(saved_groups):
+ raise ValueError("loaded state dict has a different number of "
+ "parameter groups")
+ param_lens = (len(g['params']) for g in groups)
+ saved_lens = (len(g['params']) for g in saved_groups)
+ if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+ raise ValueError("loaded state dict contains a parameter group "
+ "that doesn't match the size of optimizer's group")
+
+ # Update the state
+ id_map = {old_id: p for old_id, p in
+ zip(chain.from_iterable((g['params'] for g in saved_groups)),
+ chain.from_iterable((g['params'] for g in groups)))}
+
+ def cast(param, value):
+ r"""Make a deep copy of value, casting all tensors to device of param."""
+ if isinstance(value, torch.Tensor):
+ # Floating-point types are a bit special here. They are the only ones
+ # that are assumed to always match the type of params.
+ if param.is_floating_point() and value.dtype != torch.uint8:
+ value = value.to(param.dtype)
+ return value
+ elif isinstance(value, dict):
+ for k, v in value.items():
+ if k in self.non_castable_tensor_keys:
+ value[k] = v.to(param.device)
+ else:
+ value[k] = cast(param, v)
+
+ return value
+ elif isinstance(value, container_abcs.Iterable):
+ return type(value)(cast(param, v) for v in value)
+ else:
+ return value
+
+ # Copy state assigned to params (and cast tensors to appropriate types).
+ # State that is not assigned to params is copied as is (needed for
+ # backward compatibility).
+ state = defaultdict(dict)
+ for k, v in state_dict['state'].items():
+ if k in id_map:
+ param = id_map[k]
+ state[param] = cast(param, v)
+ else:
+ state[k] = v
+
+ # Update parameter groups, setting their 'params' value
+ def update_group(group, new_group):
+ new_group['params'] = group['params']
+ return new_group
+ param_groups = [
+ update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ self.__setstate__({'state': state, 'param_groups': param_groups})
+
+ def to_gpu(self):
+ self.checked_if_on_gpu = True
+ for gindex, group in enumerate(self.param_groups):
+ for pindex, p in enumerate(group['params']):
+ if p in self.state:
+ values = self.state[p]
+ for k, v in values.items():
+ if isinstance(v, torch.Tensor):
+ self.state[p][k] = v.to(p.device)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ overflows = []
+
+ if not self.checked_if_on_gpu: self.to_gpu() # needed for fairseq pure fp16 training
+ for gindex, group in enumerate(self.param_groups):
+ for pindex, p in enumerate(group['params']):
+ if p.grad is None:
+ continue
+ state = self.state[p]
+ if len(state) == 0:
+ self.init_state(group, p, gindex, pindex)
+
+ self.update_step(group, p, gindex, pindex)
+
+ return loss
+
+ def get_config(self, gindex, pindex, group):
+ config = {}
+ config['betas'] = group['betas']
+ config['eps'] = group['eps']
+ config['weight_decay'] = group['weight_decay']
+ config['lr'] = group['lr']
+ config['optim_bits'] = self.args.optim_bits
+ config['min_8bit_size'] = self.args.min_8bit_size
+ config['percentile_clipping'] = self.args.percentile_clipping
+ config['block_wise'] = self.args.block_wise
+ config['max_unorm'] = self.args.max_unorm
+
+ if (gindex, pindex) in self.mng.index2config:
+ config.update(self.mng.index2config[(gindex, pindex)])
+ return config
+
+ def init_state(self, group, p, gindex, pindex):
+ raise NotImplementedError(f'init_state method needs to be overidden')
+
+ def update_step(self, group, p, gindex, pindex):
+ raise NotImplementedError(f'The update_step method needs to be overidden')
+
+class Optimizer2State(Optimizer8bit):
+ def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0.0, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if isinstance(betas, str):
+ betas = eval(betas)
+ print(betas, 'parsed')
+ for i in range(len(betas)):
+ if not 0.0 <= betas[i] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay)
+ super(Optimizer2State, self).__init__(params, defaults, optim_bits)
+
+ if args is None:
+ args = {}
+ args['optim_bits'] = optim_bits
+ args['percentile_clipping'] = 100
+ args['min_8bit_size'] = min_8bit_size
+ args['percentile_clipping'] = percentile_clipping
+ args['block_wise'] = block_wise
+ args['max_unorm'] = max_unorm
+
+ self.args = MockArgs(args)
+ else:
+ self.args = args
+
+ self.optimizer_name = optimizer_name
+
+ @torch.no_grad()
+ def init_state(self, group, p, gindex, pindex):
+ config = self.get_config(gindex, pindex, group)
+
+ if config['optim_bits'] == 32:
+ dtype = torch.float32
+ elif config['optim_bits'] == 8:
+ dtype = torch.uint8
+ else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+
+ if p.numel() < config['min_8bit_size']: dtype = torch.float32
+
+ state = self.state[p]
+ state['step'] = 0
+
+ if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+ state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ elif dtype == torch.uint8:
+ if state['step'] == 0:
+ if 'dynamic' not in self.name2qmap: self.fill_qmap()
+ self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
+ self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)
+
+ state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
+ state['qmap1'] = self.name2qmap['dynamic']
+
+ state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
+ state['qmap2'] = self.name2qmap['udynamic']
+
+ if config['block_wise']:
+ n = p.numel()
+ blocks = n//2048
+ blocks += 1 if n % 2048 > 0 else 0
+
+ state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ else:
+ state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+
+ if config['percentile_clipping'] < 100:
+ state['gnorm_vec'] = torch.zeros((100,), device=p.device)
+
+ if config['max_unorm'] > 0.0:
+ state['unorm_vec'] = torch.zeros((1,), device=p.device)
+
+ @torch.no_grad()
+ def update_step(self, group, p, gindex, pindex):
+ state = self.state[p]
+ grad = p.grad
+
+ config = self.get_config(gindex, pindex, group)
+
+ state['step'] += 1
+ step = state['step']
+
+ if config['percentile_clipping'] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ else:
+ gnorm_scale = 1.0
+
+ if state['state1'].dtype == torch.float:
+ F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
+ state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
+ state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+
+ elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
+ F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
+ config['eps'], step, config['lr'],
+ state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'],
+ config['weight_decay'], gnorm_scale=gnorm_scale,
+ unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+
+ # swap maxes
+ state['max1'], state['new_max1'] = state['new_max1'], state['max1']
+ state['max2'], state['new_max2'] = state['new_max2'], state['max2']
+ elif state['state1'].dtype == torch.uint8 and config['block_wise']:
+ F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
+ config['eps'], step, config['lr'],
+ state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
+ config['weight_decay'], gnorm_scale=gnorm_scale)
+
+
+class Optimizer1State(Optimizer8bit):
+ def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
+ weight_decay=0.0, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ for i in range(len(betas)):
+ if not 0.0 <= betas[i] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay)
+ super(Optimizer1State, self).__init__(params, defaults, optim_bits)
+
+ if args is None:
+ args = {}
+ args['optim_bits'] = optim_bits
+ args['percentile_clipping'] = 100
+ args['min_8bit_size'] = min_8bit_size
+ args['percentile_clipping'] = percentile_clipping
+ args['block_wise'] = block_wise
+ args['max_unorm'] = max_unorm
+
+ self.args = MockArgs(args)
+ else:
+ self.args = args
+
+ self.optimizer_name = optimizer_name
+
+ @torch.no_grad()
+ def init_state(self, group, p, gindex, pindex):
+ config = self.get_config(gindex, pindex, group)
+
+ if config['optim_bits'] == 32:
+ dtype = torch.float32
+ elif config['optim_bits'] == 8:
+ dtype = torch.uint8
+ else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+
+ if p.numel() < config['min_8bit_size']: dtype = torch.float32
+
+ state = self.state[p]
+ state['step'] = 0
+
+ if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+ state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ elif dtype == torch.uint8:
+ if state['step'] == 0:
+ if 'dynamic' not in self.name2qmap: self.fill_qmap()
+ self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
+
+ state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
+ state['qmap1'] = self.name2qmap['dynamic']
+
+ if config['block_wise']:
+ n = p.numel()
+ blocks = n//2048
+ blocks += 1 if n % 2048 > 0 else 0
+
+ state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ else:
+ state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+
+ if config['percentile_clipping'] < 100:
+ state['gnorm_vec'] = torch.zeros((100,), device=p.device)
+
+ if config['max_unorm'] > 0.0:
+ state['unorm_vec'] = torch.zeros((1,), device=p.device)
+
+
+ @torch.no_grad()
+ def update_step(self, group, p, gindex, pindex):
+ state = self.state[p]
+ grad = p.grad
+
+ config = self.get_config(gindex, pindex, group)
+
+ state['step'] += 1
+ step = state['step']
+
+ if config['percentile_clipping'] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ else:
+ gnorm_scale = 1.0
+
+ if state['state1'].dtype == torch.float:
+ F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
+ None, 0.0, config['weight_decay'], gnorm_scale,
+ state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+
+ elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
+ F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
+ config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None,
+ config['weight_decay'], gnorm_scale,
+ state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+
+ state['max1'], state['new_max1'] = state['new_max1'], state['max1']
+ elif state['state1'].dtype == torch.uint8 and config['block_wise']:
+ F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
+ config['eps'], step, config['lr'],
+ state['qmap1'], None, state['absmax1'], None,
+ config['weight_decay'], gnorm_scale=gnorm_scale)
diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py
new file mode 100644
index 0000000..99b718e
--- /dev/null
+++ b/bitsandbytes/optim/rmsprop.py
@@ -0,0 +1,37 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+class RMSprop(Optimizer1State):
+ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if alpha == 0:
+ raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
+ if centered:
+ raise NotImplementError(f'Centered RMSprop is not supported!')
+ super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class RMSprop8bit(Optimizer1State):
+ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if alpha == 0:
+ raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
+ if centered:
+ raise NotImplementError(f'Centered RMSprop is not supported!')
+ super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class RMSprop32bit(Optimizer1State):
+ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+
+ if alpha == 0:
+ raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
+ if centered:
+ raise NotImplementError(f'Centered RMSprop is not supported!')
+ super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py
new file mode 100644
index 0000000..926d804
--- /dev/null
+++ b/bitsandbytes/optim/sgd.py
@@ -0,0 +1,32 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+class SGD(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if momentum == 0:
+ raise NotImplementError(f'SGD without momentum is not supported!')
+ super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class SGD8bit(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if momentum == 0:
+ raise NotImplementError(f'SGD without momentum is not supported!')
+ super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class SGD32bit(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if momentum == 0:
+ raise NotImplementError(f'SGD without momentum is not supported!')
+ super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
new file mode 100644
index 0000000..66a2c99
--- /dev/null
+++ b/csrc/kernels.cu
@@ -0,0 +1,1846 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <kernels.cuh>
+#include <cub/block/block_radix_sort.cuh>
+#include <cub/warp/warp_reduce.cuh>
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_discontinuity.cuh>
+#include <cub/block/block_store.cuh>
+#include <cub/block/block_reduce.cuh>
+#include <cub/cub.cuh>
+#include <math_constants.h>
+
+#define HLF_MAX 65504
+#define TH 1024
+#define NUM 4
+#define NUM_BLOCK 4096
+
+// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
+__device__ float atomicMax(float* address, float val) {
+ int* address_as_i = reinterpret_cast<int*>(address);
+ int old = *address_as_i, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(
+ reinterpret_cast<int*>(address), assumed,
+ __float_as_int(fmaxf(val, __int_as_float(assumed))));
+ } while (assumed != old);
+ return __int_as_float(old);
+}
+
+__device__ float atomicMin(float* address, float val) {
+ int* address_as_i = reinterpret_cast<int*>(address);
+ int old = *address_as_i, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(
+ reinterpret_cast<int*>(address), assumed,
+ __float_as_int(fminf(val, __int_as_float(assumed))));
+ } while (assumed != old);
+ return __int_as_float(old);
+}
+
+template <int STOCHASTIC>
+__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
+{
+ int pivot = 127;
+ int upper_pivot = 255;
+ int lower_pivot = 0;
+
+ float lower = -1.0f;
+ float upper = 1.0f;
+
+ float val = smem_code[pivot];
+ // i>>=1 = {32, 16, 8, 4, 2, 1}
+ for(int i = 64; i > 0; i>>=1)
+ {
+ if(x > val)
+ {
+ lower_pivot = pivot;
+ lower = val;
+ pivot+=i;
+ }
+ else
+ {
+ upper_pivot = pivot;
+ upper = val;
+ pivot-=i;
+ }
+ val = smem_code[pivot];
+ }
+
+ if(upper_pivot == 255)
+ upper = smem_code[upper_pivot];
+ if(lower_pivot == 0)
+ lower = smem_code[lower_pivot];
+
+ if(!STOCHASTIC)
+ {
+ if(x > val)
+ {
+ float midpoint = (upper+val)*0.5f;
+ if(x > midpoint)
+ {
+ return upper_pivot;
+ }
+ else
+ return pivot;
+ }
+ else
+ {
+ float midpoint = (lower+val)*0.5f;
+ if(x < midpoint)
+ return lower_pivot;
+ else
+ return pivot;
+ }
+ }
+ else
+ {
+ if(x > val)
+ {
+ float dist_to_upper = fabsf(upper-x);
+ float dist_full = upper-val;
+ if(rand >= dist_to_upper/dist_full) return upper_pivot;
+ else return pivot;
+ }
+ else
+ {
+ float dist_to_lower = fabsf(lower-x);
+ float dist_full = val-lower;
+ if(rand >= dist_to_lower/dist_full) return lower_pivot;
+ else return pivot;
+ }
+ }
+}
+
+template <int SIGNED>
+__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x)
+{
+ int pivot = 127;
+ int upper_pivot = 255;
+ int lower_pivot = 0;
+
+ float lower = SIGNED ? -1.0f : 0.0f;
+ float upper = 1.0f;
+ float midpoint;
+ float val = quadrants[1];
+ int local_pivot = 1;
+ int offset = 1;
+
+ // i>>=1 = {32, 16, 8, 4, 2, 1}
+ for(int i = 64; i > 0; i>>=1)
+ {
+ if(x > val)
+ {
+ lower_pivot = pivot;
+ lower = val;
+ pivot+=i;
+ //val = i == 64 ? quadrants[2] : smem_code[pivot];
+ local_pivot += offset;
+ }
+ else
+ {
+ upper_pivot = pivot;
+ upper = val;
+ pivot-=i;
+ //val = i == 64 ? quadrants[0] : smem_code[pivot];
+ local_pivot -= offset;
+ }
+ val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot];
+ offset -= 1;
+ }
+
+ if(x > val)
+ {
+ midpoint = (upper+val)*0.5f;
+ if(x > midpoint)
+ return upper_pivot;
+ else
+ return pivot;
+ }
+ else
+ {
+ midpoint = (lower+val)*0.5f;
+ if(x < midpoint)
+ return lower_pivot;
+ else
+ return pivot;
+ }
+}
+
+template <int SIGNED>
+__device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper)
+{
+ int lower_pivot = QUADRANT*16-1 - 0;
+ int pivot = QUADRANT*16-1 + 16;
+ int upper_pivot = QUADRANT*16-1 + 31;
+
+ float val = midpoint;
+
+ // i>>=1 = {32, 16, 8, 4, 2, 1}
+ for(int i = 16; i > 0; i>>=1)
+ {
+ if(x > val)
+ {
+ lower_pivot = pivot;
+ lower = val;
+ pivot+=i;
+ }
+ else
+ {
+ upper_pivot = pivot;
+ upper = val;
+ pivot-=i;
+ }
+ val = smem_code[pivot];
+ }
+
+ if(x > val)
+ {
+ midpoint = (upper+val)*0.5f;
+ if(x > midpoint)
+ return upper_pivot;
+ else
+ return pivot;
+ }
+ else
+ {
+ midpoint = (lower+val)*0.5f;
+ if(x < midpoint)
+ return lower_pivot;
+ else
+ return pivot;
+ }
+}
+
+__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
+{
+ const int tid = threadIdx.x + (blockDim.x*blockIdx.x);
+ const int numThreads = blockDim.x*gridDim.x;
+
+ for(int i = tid; i < n; i+=numThreads)
+ {
+ int idx = (index1[i]*maxidx1) + index2[i];
+ atomicAdd(&histogram[idx], src[i]);
+ }
+}
+
+template<typename T, int BLOCK_SIZE, int NUM_MAX>
+__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n)
+{
+ typedef cub::WarpReduce<T> WarpReduce;
+ __shared__ typename WarpReduce::TempStorage temp_storage;
+ typedef cub::BlockLoad<T, BLOCK_SIZE/8 , 8, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+ __shared__ typename LoadT::TempStorage loadt;
+
+ const int warp_idx = threadIdx.x/32;
+ const int valid_items = n - (blockIdx.x*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (blockIdx.x*BLOCK_SIZE);
+
+ // BLOCK_SIZE/32 == number of warps
+ __shared__ int smem_max_indices[8*BLOCK_SIZE/32];
+ __shared__ float smem_max_values[8*BLOCK_SIZE/32];
+
+ T values[8];
+ T max1 = -64000.0f;
+ T max2 = -64000.0f;
+ int max_idx1 = -1;
+ int max_idx2 = -1;
+ int sign1 = -1;
+ int sign2 = -1;
+
+ // 1. load 8 values per thread
+ // 2. compute 2-max in registers (64 max per warp)
+ // 3. do warp reduction + broadcast back
+ // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
+ // 5. Repeat (3) 8 times for top 8 values in 256
+ // 6. store with byte index
+
+ LoadT(loadt).Load(&(A[(blockIdx.x*BLOCK_SIZE)]), values, valid_items, (T)0.0f);
+ #pragma unroll 8
+ for(int i = 0; i < 8; i++)
+ {
+ T absval = fabsf(values[i]);
+ if(absval > max1)
+ {
+ max1 = values[i];
+ sign1 = signbit(values[i]);
+ max_idx1 = 8*threadIdx.x + i;
+ }
+ else if(absval > max2)
+ {
+ max2 = values[i];
+ sign2 = signbit(values[i]);
+ max_idx2 = 8*threadIdx.x + i;
+ }
+ }
+
+ float warp_max;
+ for(int i = 0; i < 8; i++)
+ {
+ // 3. do warp reduction + broadcast back
+ warp_max = WarpReduce(temp_storage).Reduce(max1, cub::Max());
+ warp_max = cub::ShuffleIndex<32>(warp_max, 0, 0xffffffff);
+
+ // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
+ if(warp_max == max1)
+ {
+ smem_max_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1;
+ smem_max_indices[warp_idx*8 + i] = max_idx1;
+
+ sign1 = sign2;
+ max1 = max2;
+ max_idx1 = max_idx2;
+
+ max2 = -64000.0f;
+ }
+ __syncwarp();
+ }
+
+ if(threadIdx.x % 32 < 8)
+ {
+ // offset: 8 values per 256 input values
+ //
+ int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8;
+ }
+
+}
+
+#define THREADS_ESTIMATE 512
+#define NUM_ESTIMATE 8
+#define BLOCK_ESTIMATE 4096
+
+template<typename T>
+__launch_bounds__(THREADS_ESTIMATE, 1)
+__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n)
+{
+ const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE);
+ int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE;
+ const int base_idx = (blockIdx.x * BLOCK_ESTIMATE);
+ const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE));
+
+ T vals[NUM_ESTIMATE];
+
+ typedef cub::BlockRadixSort<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::NullType, 4, true, cub::BLOCK_SCAN_RAKING> BlockRadixSort;
+ typedef cub::BlockLoad<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+
+ __shared__ union {
+ typename LoadFloat::TempStorage loadf;
+ typename BlockRadixSort::TempStorage sort;
+ int smem_qidx[BLOCK_ESTIMATE];
+ } temp_storage;
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE)
+ {
+ valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i;
+
+ // do not process half-blocks
+ if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; }
+
+ #pragma unroll 4
+ for(int j = 0; j < NUM_ESTIMATE; j++)
+ vals[j] = max_val;
+
+ __syncthreads();
+ LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items);
+
+ #pragma unroll 4
+ for(int j = 0; j < NUM_ESTIMATE; j++)
+ vals[j] = ((float)vals[j]) * reciprocal_num_blocks;
+
+
+ __syncthreads();
+ // sort into striped pattern to mitigate bank conflicts
+ // striped pattern index for thread 0 [0, 1024, 2048, 3096]
+ // striped pattern index for thread 1 [1, 1025, 2049, 3097]
+ BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals);
+
+ __syncthreads();
+ for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x)
+ temp_storage.smem_qidx[j] = -1;
+
+ if(threadIdx.x < 256)
+ {
+ float q_interval = (1.0f-(2.0f*offset))/255.0f;
+ int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1)));
+ temp_storage.smem_qidx[local_idx] = threadIdx.x;
+ }
+
+ __syncthreads();
+
+ for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x)
+ {
+ if(temp_storage.smem_qidx[i] != -1)
+ atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]);
+ }
+ }
+}
+
+
+__launch_bounds__(TH, 4)
+__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n)
+{
+ const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK);
+ int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK;
+ const int base_idx = (blockIdx.x * NUM_BLOCK);
+
+ float vals[NUM];
+ unsigned char qvals[NUM];
+ //const int lane_id = threadIdx.x % 2;
+
+ typedef cub::BlockLoad<float, TH, NUM, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+ typedef cub::BlockStore<unsigned char, TH, NUM, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+
+ __shared__ typename LoadFloat::TempStorage loadf;
+ __shared__ typename StoreChar::TempStorage storec;
+ __shared__ float smem_code[256];
+ //__shared__ float smem_code[2][257];
+
+ if(threadIdx.x < 256)
+ {
+ smem_code[threadIdx.x] = code[threadIdx.x];
+ //smem_code[0][threadIdx.x] = code[threadIdx.x];
+ //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x];
+ }
+
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK)
+ {
+ // number of values already processed in blocks +
+ // number of values already processed in this block +
+ // rand_offset % mod value
+ valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i;
+
+ __syncthreads();
+ LoadFloat(loadf).Load(&(A[i]), vals, valid_items);
+
+
+ #pragma unroll 4
+ for(int j = 0; j < NUM; j++)
+ qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]);
+
+ __syncthreads();
+ StoreChar(storec).Store(&(out[i]), qvals, valid_items);
+ }
+}
+
+template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC>
+__launch_bounds__(TH, 4)
+__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
+{
+ const int n_full = gridDim.x * BLOCK_SIZE;
+ int valid_items = 0;
+ const int base_idx = (blockIdx.x * BLOCK_SIZE);
+
+ T vals[NUM];
+ float rand_vals[NUM];
+ unsigned char qvals[NUM];
+ //float local_abs_max = -FLT_MAX;
+ float local_abs_max = 0.0f;
+ int local_rand_idx = 0;
+
+ typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+ typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+ typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
+ typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+
+ __shared__ typename LoadT::TempStorage loadt;
+ __shared__ typename LoadFloat::TempStorage loadf;
+ __shared__ typename StoreChar::TempStorage storec;
+ __shared__ typename BlockReduce::TempStorage reduce;
+ __shared__ float smem_code[256];
+ __shared__ float smem_absmax_value[1];
+
+ if(threadIdx.x < 256)
+ smem_code[threadIdx.x] = code[threadIdx.x];
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
+ {
+ valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
+ local_abs_max = -FLT_MAX;
+
+ __syncthreads();
+ LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);
+
+ // 1. compute local max
+ // 2. broadcast local max
+ // 3. normalize inputs and quantize
+
+ #pragma unroll NUM_PER_TH
+ for(int j = 0; j < NUM_PER_TH; j++)
+ local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
+
+ local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
+
+ if(threadIdx.x == 0)
+ smem_absmax_value[0] = local_abs_max;
+
+ __syncthreads();
+
+ if(threadIdx.x == 0)
+ absmax[i/BLOCK_SIZE] = local_abs_max;
+ else
+ local_abs_max = smem_absmax_value[0];
+
+ __syncwarp();
+
+ local_abs_max = 1.0f/local_abs_max;
+
+ if(STOCHASTIC)
+ {
+ local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4);
+ LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
+ }
+
+ #pragma unroll NUM_PER_TH
+ for(int j = 0; j < NUM_PER_TH; j++)
+ {
+ if(!STOCHASTIC)
+ qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
+ else
+ qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
+ }
+
+ __syncthreads();
+ StoreChar(storec).Store(&(out[i]), qvals, valid_items);
+ }
+}
+
+template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH>
+__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n)
+{
+
+ const int n_full = gridDim.x * BLOCK_SIZE;
+ int valid_items = 0;
+ const int base_idx = (blockIdx.x * BLOCK_SIZE);
+
+ T vals[NUM];
+ unsigned char qvals[NUM];
+ float local_abs_max = -FLT_MAX;
+
+ typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
+ typedef cub::BlockStore<T, THREADS, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
+
+ __shared__ typename LoadChar::TempStorage loadchar;
+ __shared__ typename StoreT::TempStorage storet;
+ __shared__ float smem_code[256];
+
+ if(threadIdx.x < 256)
+ smem_code[threadIdx.x] = code[threadIdx.x];
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
+ {
+ valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
+ local_abs_max = absmax[i/BLOCK_SIZE];
+
+ __syncthreads();
+ LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128);
+
+ #pragma unroll NUM_PER_TH
+ for(int j = 0; j < NUM_PER_TH; j++)
+ vals[j] = smem_code[qvals[j]]*local_abs_max;
+
+ __syncthreads();
+ StoreT(storet).Store(&(out[i]), vals, valid_items);
+ }
+}
+
+
+__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n)
+{
+ const unsigned int numThreads = blockDim.x * gridDim.x;
+ const int idx = (blockIdx.x * blockDim.x) + threadIdx.x;
+
+ __shared__ float smem_code[256];
+ if(threadIdx.x < 256)
+ {
+ smem_code[threadIdx.x] = code[threadIdx.x];
+ }
+
+ __syncthreads();
+
+ for (int i = idx;i < n; i += numThreads)
+ {
+ out[i] = smem_code[A[i]];
+ }
+}
+
+
+
+template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
+__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
+__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
+ float* state1, float* state2, float *unorm,
+ const float beta1, const float beta2, const float eps, const float weight_decay,
+ const int step, const float lr, const float gnorm_scale, const int n)
+{
+
+ const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
+ const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS);
+ int valid_items = 0;
+
+ T g_vals[NUM_VALS];
+
+ float s1_vals[NUM_VALS];
+ float s2_vals[NUM_VALS];
+
+ const float correction1 = 1.0f/(1.0f - powf(beta1, step));
+ const float correction2 = 1.0f/(1.0f - powf(beta2, step));
+
+ typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
+ typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+ typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
+
+ __shared__ union {
+ typename Load::TempStorage load;
+ typename LoadFloat::TempStorage loadf;
+ typename BlockReduce::TempStorage reduce;
+ } temp_storage;
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
+ {
+ valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;
+
+ __syncthreads();
+ Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f);
+ __syncthreads();
+ LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f);
+ __syncthreads();
+ LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f);
+
+ # pragma unroll NUM_VALS
+ for(unsigned int j = 0; j < NUM_VALS; j++)
+ g_vals[j] = gnorm_scale*((float)g_vals[j]);
+
+ # pragma unroll NUM_VALS
+ for(unsigned int j = 0; j < NUM_VALS; j++)
+ {
+ switch(OPTIMIZER)
+ {
+ case ADAM:
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
+ s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
+ s1_vals[j] *= correction1;
+ s2_vals[j] *= correction2;
+ s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update
+ s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update)
+ break;
+ }
+ }
+
+ # pragma unroll NUM_VALS-1
+ for(unsigned int j = 1; j < NUM_VALS; j++)
+ s1_vals[0] += s1_vals[j];
+
+ __syncthreads();
+ s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]);
+
+ if(threadIdx.x == 0)
+ atomicAdd(&unorm[0], s1_vals[0]);
+
+ __syncwarp();
+ }
+}
+
+
+
+#define NUM_PER_THREAD 4
+
+template<typename T, int OPTIMIZER>
+__launch_bounds__(TH, 1)
+__global__ void kOptimizer32bit2State(T* g, T* p,
+ float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+ const float beta1, const float beta2, const float eps, const float weight_decay,
+ const int step, const float lr, const float gnorm_scale, const int n)
+{
+
+ const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
+ const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
+ int valid_items = 0;
+ float update_scale = 0.0f;
+ T g_vals[NUM_PER_THREAD];
+ T p_vals[NUM_PER_THREAD];
+
+ float s1_vals[NUM_PER_THREAD];
+ float s2_vals[NUM_PER_THREAD];
+
+ const float correction1 = 1.0f - powf(beta1, step);
+ const float correction2 = sqrtf(1.0f - powf(beta2, step));
+ const float step_size = -lr*correction2/correction1;
+
+ if(max_unorm > 0.0f)
+ {
+ update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
+ if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
+ else{ update_scale = 1.0f; }
+ }
+ else{ update_scale = 1.0f; }
+
+ typedef cub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
+ typedef cub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
+
+ typedef cub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+ typedef cub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
+
+ __shared__ union {
+ typename Load::TempStorage load;
+ typename Store::TempStorage store;
+ typename LoadFloat::TempStorage loadf;
+ typename StoreFloat::TempStorage storef;
+ } temp_storage;
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD)
+ {
+ valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
+
+ __syncthreads();
+ Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
+ __syncthreads();
+ LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
+ __syncthreads();
+ LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items);
+ __syncthreads();
+ Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
+
+ # pragma unroll 4
+ for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
+ g_vals[j] = gnorm_scale*((float)g_vals[j]);
+
+ # pragma unroll 4
+ for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
+ {
+ switch(OPTIMIZER)
+ {
+ case ADAM:
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
+ s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
+ p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+ break;
+ }
+ }
+
+ __syncthreads();
+ Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
+ __syncthreads();
+ StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
+ __syncthreads();
+ StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items);
+ }
+}
+
+template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
+__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
+__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
+ float* state1, float *unorm,
+ const float beta1, const float eps, const float weight_decay,
+ const int step, const float lr, const float gnorm_scale, const int n)
+{
+
+ const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
+ const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS);
+ int valid_items = 0;
+
+ T g_vals[NUM_VALS];
+
+ float s1_vals[NUM_VALS];
+
+ typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
+ typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+ typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
+
+ __shared__ union {
+ typename Load::TempStorage load;
+ typename LoadFloat::TempStorage loadf;
+ typename BlockReduce::TempStorage reduce;
+ } temp_storage;
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
+ {
+ valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;
+
+ __syncthreads();
+ Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f);
+ __syncthreads();
+ LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f);
+
+ # pragma unroll NUM_VALS
+ for(unsigned int j = 0; j < NUM_VALS; j++)
+ g_vals[j] = gnorm_scale*((float)g_vals[j]);
+
+ # pragma unroll NUM_VALS
+ for(unsigned int j = 0; j < NUM_VALS; j++)
+ {
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ if(step == 1)
+ s1_vals[j] = (float)g_vals[j]; // state update
+ else
+ s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
+ s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
+ break;
+ case RMSPROP:
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
+ s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
+ s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
+ break;
+ }
+ }
+
+ # pragma unroll
+ for(unsigned int j = 1; j < NUM_VALS; j++)
+ s1_vals[0] += s1_vals[j];
+
+ __syncthreads();
+ s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items);
+
+ if(threadIdx.x == 0)
+ atomicAdd(&unorm[0], s1_vals[0]);
+
+ __syncwarp();
+ }
+}
+
+template<typename T, int OPTIMIZER>
+__launch_bounds__(TH, 1)
+__global__ void kOptimizer32bit1State(T *g, T *p,
+ float *state1, float *unorm, const float max_unorm, const float param_norm,
+ const float beta1, const float eps, const float weight_decay,
+ const int step, const float lr, const float gnorm_scale, const int n)
+{
+
+ const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
+ const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
+ int valid_items = 0;
+ float update_scale = 0.0f;
+
+ if(max_unorm > 0.0f)
+ {
+ update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
+ if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; }
+ else{ update_scale = 1.0f; }
+ }
+ else{ update_scale = 1.0f; }
+
+ T g_vals[NUM_PER_THREAD];
+ T p_vals[NUM_PER_THREAD];
+
+ float s1_vals[NUM_PER_THREAD];
+
+ typedef cub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
+ typedef cub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
+
+ typedef cub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+ typedef cub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
+
+ __shared__ union {
+ typename Load::TempStorage load;
+ typename Store::TempStorage store;
+ typename LoadFloat::TempStorage loadf;
+ typename StoreFloat::TempStorage storef;
+ } temp_storage;
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD)
+ {
+ valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
+
+ __syncthreads();
+ Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
+ __syncthreads();
+ LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
+ __syncthreads();
+ Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
+
+ # pragma unroll 4
+ for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
+ {
+ g_vals[j] = gnorm_scale*((float)g_vals[j]);
+ if(weight_decay > 0.0f)
+ g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay);
+ }
+
+ # pragma unroll 4
+ for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
+ {
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ if(step == 1)
+ s1_vals[j] = (float)g_vals[j];
+ else
+ s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
+
+ p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
+ break;
+ case RMSPROP:
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
+ p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
+ break;
+ }
+ }
+
+ __syncthreads();
+ Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
+ __syncthreads();
+ StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
+ }
+}
+
+
+#define NUM8BIT 16
+#define NUM_THREADS 256
+#define NUM_PER_BLOCK 4096
+
+template<typename T, int OPTIMIZER>
+__global__ void
+__launch_bounds__(NUM_THREADS, 2)
+kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2,
+ float *unorm,
+ const float beta1, const float beta2,
+ const float eps, const int step,
+ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
+ float* max1, float* max2, float* new_max1, float* new_max2,
+ const float gnorm_scale, const int n)
+{
+ const int n_full = gridDim.x * NUM_PER_BLOCK;
+ const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
+ int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK);
+ float g_val = 0.0f;
+ float local_max_s1 = -FLT_MAX;
+ float local_max_s2 = -FLT_MAX;
+ float local_unorm = 0.0f;
+
+ float s2_vals[NUM8BIT];
+ float s1_vals[NUM8BIT];
+ T g_vals[NUM8BIT];
+ unsigned char m_c1[NUM8BIT];
+ unsigned char r_c2[NUM8BIT];
+
+ typedef cub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+ typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
+ typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;
+
+
+ __shared__ union {
+ typename LoadT::TempStorage loadh;
+ typename LoadUInt8::TempStorage loadc;
+ typename BlockReduce::TempStorage reduce;
+ } temp_storage;
+
+ __shared__ float smem_quantiles1[256];
+ __shared__ float smem_quantiles2[256];
+
+ if(threadIdx.x < 256)
+ {
+ smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
+ smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x];
+ }
+
+ __syncthreads();
+
+ for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT)
+ {
+ valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
+
+ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
+ __syncthreads();
+ LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128);
+ __syncthreads();
+ LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128);
+ __syncthreads();
+
+ #pragma unroll 16
+ for(int j = 0; j < NUM8BIT; j++)
+ {
+ g_val = g_vals[j];
+ g_val *= gnorm_scale;
+ s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1;
+ s1_vals[j] += (1.0f-beta1)*g_val;
+ local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j]));
+ }
+
+ #pragma unroll 16
+ for(int j = 0; j < NUM8BIT; j++)
+ {
+ g_val = g_vals[j];
+ g_val *= gnorm_scale;
+ s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2;
+ s2_vals[j] += (1.0f-beta2)*g_val*g_val;
+ local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j]));
+ }
+
+ if(unorm != NULL)
+ {
+ #pragma unroll 16
+ for(int j = 0; j < NUM8BIT; j++)
+ {
+ float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step));
+ float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step));
+ s1_vals[j] *= correction1;
+ s2_vals[j] *= correction2;
+ float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update
+ local_unorm += update_val*update_val;
+ }
+ }
+ }
+
+ __syncthreads();
+ local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
+ __syncthreads();
+ local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items);
+ if(unorm != NULL)
+ {
+ __syncthreads();
+ local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
+ }
+
+ if(threadIdx.x == 0)
+ {
+ atomicMax(&new_max1[0], local_max_s1);
+ atomicMax(&new_max2[0], local_max_s2);
+ if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); }
+ }
+}
+
+#define NUM_PER_THREAD2 4
+#define NUM_THREADS2 1024
+#define NUM_PER_BLOCK2 4096
+
+template<typename T, int OPTIMIZER>
+__global__ void
+__launch_bounds__(NUM_THREADS2, 1)
+kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
+ const float *unorm, const float max_unorm, const float param_norm, \
+ const float beta1, const float beta2,
+ const float eps, const int step, const float lr,
+ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
+ float* max1, float* max2, float* new_max1, float* new_max2,
+ float weight_decay,
+ const float gnorm_scale, const int n)
+{
+
+ const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
+ const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
+ int valid_items = 0;
+ float g_val = 0.0f;
+ float s1_vals[NUM_PER_THREAD2];
+ float s2_vals[NUM_PER_THREAD2];
+ const float correction1 = 1.0f - powf(beta1, step);
+ const float correction2 = sqrtf(1.0f - powf(beta2, step));
+ const float step_size = -lr*correction2/correction1;
+ //const float step_size = -lr*correction2/correction1;
+ float new_max_val1 = 1.0f/new_max1[0];
+ float new_max_val2 = 1.0f/new_max2[0];
+ float update_scale = 1.0f;
+
+ if(max_unorm > 0.0f)
+ {
+ update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
+ if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
+ else{ update_scale = 1.0f; }
+ }
+ else{ update_scale = 1.0f; }
+
+ unsigned char c1s[NUM_PER_THREAD2];
+ unsigned char c2s[NUM_PER_THREAD2];
+ T p_vals[NUM_PER_THREAD2];
+ T g_vals[NUM_PER_THREAD2];
+ typedef cub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+ typedef cub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
+
+ typedef cub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+ typedef cub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
+
+ __shared__ float smem_quantiles1[256];
+ __shared__ float smem_quantiles2[256];
+
+ __shared__ union {
+ typename LoadT::TempStorage loadh;
+ typename LoadChar::TempStorage loadc;
+ typename StoreChar::TempStorage storec;
+ typename StoreT::TempStorage storeh;
+ } temp_storage;
+
+ if(threadIdx.x < 512)
+ {
+ if(threadIdx.x < 256)
+ smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
+ else
+ smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256];
+ }
+
+ __syncthreads();
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2)
+ {
+ valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
+ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
+ __syncthreads();
+ LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
+ __syncthreads();
+ LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
+ __syncthreads();
+ LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
+
+ if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
+
+ # pragma unroll 4
+ for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
+ {
+ g_val = float(g_vals[j]);
+ g_val *= gnorm_scale;
+ s1_vals[j] = smem_quantiles1[c1s[j]];
+ s1_vals[j] = s1_vals[j]*max1[0];
+
+ s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
+
+ c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1);
+
+ // make sure state1 term has still the same sign after quantization
+ // (not needed for state2 term which has only positive values)
+ if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j]))
+ {
+ if(s1_vals[j] > 0.0f)
+ c1s[j] += 1;
+ else
+ c1s[j] -= 1;
+ }
+
+ s2_vals[j] = smem_quantiles2[c2s[j]];
+ s2_vals[j] = s2_vals[j]*max2[0];
+ s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
+ c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2);
+ }
+
+ # pragma unroll 4
+ for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
+ {
+ p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps))))));
+ if(weight_decay > 0.0f)
+ p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay));
+ }
+
+ StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+ __syncthreads();
+ StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
+ __syncthreads();
+ StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items);
+ __syncthreads();
+ }
+}
+
+
+template<typename T, int OPTIMIZER>
+__global__ void
+__launch_bounds__(NUM_THREADS, 2)
+kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
+ float *unorm,
+ const float beta1,
+ const float eps, const int step,
+ float* __restrict__ const quantiles1,
+ float* max1, float* new_max1,
+ const float weight_decay,
+ const float gnorm_scale, const int n)
+{
+ const int n_full = gridDim.x * NUM_PER_BLOCK;
+ const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
+ int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK);
+ float g_val = 0.0f;
+ float local_max_s1 = -FLT_MAX;
+ float local_unorm = 0.0f;
+
+ float s1_vals[NUM8BIT];
+ T g_vals[NUM8BIT];
+ unsigned char m_c1[NUM8BIT];
+
+ typedef cub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+ typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
+ typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;
+
+
+ __shared__ union {
+ typename LoadT::TempStorage loadh;
+ typename LoadUInt8::TempStorage loadc;
+ typename BlockReduce::TempStorage reduce;
+ } temp_storage;
+
+ __shared__ float smem_quantiles1[256];
+
+ if(threadIdx.x < 256)
+ smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
+
+ __syncthreads();
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT)
+ {
+ valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
+
+ __syncthreads();
+ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
+ __syncthreads();
+ LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128);
+
+ #pragma unroll 16
+ for(int j = 0; j < NUM8BIT; j++)
+ {
+ g_val = g_vals[j];
+ g_val *= gnorm_scale;
+ s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0];
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ if(step == 1)
+ s1_vals[j] = (float)g_vals[j];
+ else
+ s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
+ if(unorm != NULL)
+ local_unorm += s1_vals[j]*s1_vals[j];
+ break;
+ case RMSPROP:
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
+ break;
+ }
+
+ local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j]));
+ }
+ }
+
+ __syncthreads();
+ local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
+ if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); }
+ if(unorm != NULL)
+ {
+ __syncthreads();
+ local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
+ if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); }
+ }
+
+}
+
+template<typename T, int OPTIMIZER>
+__global__ void
+kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
+ const float *unorm, const float max_unorm, const float param_norm,
+ const float beta1,
+ const float eps, const int step, const float lr,
+ float* __restrict__ const quantiles1,
+ float* max1, float* new_max1,
+ float weight_decay,
+ const float gnorm_scale, const int n)
+{
+
+ const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
+ const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
+ int valid_items = 0;
+ float g_val = 0.0f;
+ float s1_vals[NUM_PER_THREAD2];
+ float new_max_val1 = 1.0f/new_max1[0];
+ float update_scale = 1.0f;
+
+ if(max_unorm > 0.0f)
+ {
+ update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
+ if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
+ else{ update_scale = 1.0f; }
+ }
+ else{ update_scale = 1.0f; }
+
+ unsigned char c1s[NUM_PER_THREAD2];
+ T p_vals[NUM_PER_THREAD2];
+ T g_vals[NUM_PER_THREAD2];
+ typedef cub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+ typedef cub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
+
+ typedef cub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+ typedef cub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
+
+ __shared__ float smem_quantiles1[256];
+
+ __shared__ union {
+ typename LoadT::TempStorage loadh;
+ typename LoadChar::TempStorage loadc;
+ typename StoreChar::TempStorage storec;
+ typename StoreT::TempStorage storeh;
+ } temp_storage;
+
+ if(threadIdx.x < 256)
+ smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
+
+ __syncthreads();
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2)
+ {
+ valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
+ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
+ __syncthreads();
+ LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
+ __syncthreads();
+ LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
+
+ if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
+
+ # pragma unroll 4
+ for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
+ {
+ g_val = float(g_vals[j]);
+ g_val *= gnorm_scale;
+ if(weight_decay > 0.0f)
+ g_val += ((float)p_vals[j])*weight_decay;
+ s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
+
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ if(step == 1)
+ s1_vals[j] = g_vals[j];
+ else
+ s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
+
+ p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
+ break;
+ case RMSPROP:
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
+ p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
+ break;
+ }
+
+ c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1);
+
+ // make sure state1 term has still the same sign after quantization
+ if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j]))
+ {
+ if(s1_vals[j] > 0.0f)
+ c1s[j] += 1;
+ else
+ c1s[j] -= 1;
+ }
+ }
+
+ StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+ __syncthreads();
+ StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
+ __syncthreads();
+ }
+}
+
+
+template<typename T, int BLOCK_SIZE, int NUM_VALS>
+__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n)
+{
+ const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
+ int valid_items = 0;
+
+ typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
+ typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+
+ __shared__ typename BlockReduce::TempStorage reduce;
+
+ __shared__ typename LoadT::TempStorage loadT;
+ T vals[NUM_VALS];
+ float local_sum = 0.0f;
+
+ for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE)
+ {
+ valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
+ local_sum = 0.0f;
+
+ __syncthreads();
+ LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f);
+
+ #pragma unroll NUM_VALS
+ for(int j = 0; j < NUM_VALS; j++)
+ local_sum += ((float)vals[j])*((float)vals[j]);
+
+ local_sum = BlockReduce(reduce).Sum(local_sum, valid_items);
+ if(threadIdx.x == 0)
+ {
+ if(step == 1)
+ {
+ // initialize with the same norm for all positions
+ //#pragma unroll 10
+ for(int j = 0; j < 100; j++)
+ atomicAdd(&gnorm_vec[j], local_sum);
+ }
+ else
+ atomicAdd(&gnorm_vec[step % 100], local_sum);
+ }
+
+ }
+}
+
+
+#define LANES 2
+#define QUAD 3
+template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
+__launch_bounds__(256, 3)
+__global__ void
+kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
+ const float beta1, const float beta2,
+ const float eps, const int step, const float lr,
+ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
+ float* absmax1, float* absmax2,
+ float weight_decay,
+ const float gnorm_scale, const int n)
+{
+
+ //const int n_full = n + (n%BLOCK_SIZE);
+ const int n_full = gridDim.x * BLOCK_SIZE;
+ const int base_idx = (blockIdx.x * BLOCK_SIZE);
+ int valid_items = 0;
+ float g_val = 0.0f;
+ float s1_vals[N_PER_TH];
+ float s2_vals[N_PER_TH];
+ // 2-5%
+ const float correction1 = 1.0f - __powf(beta1, step);
+ const float correction2 = sqrtf(1.0f -__powf(beta2, step));
+ const float step_size = __fdividef(-lr*correction2,correction1);
+ const int lane_id = threadIdx.x % LANES;
+ float new_local_abs_max1 = -FLT_MAX;
+ float new_local_abs_max2 = -FLT_MAX;
+ float quadrants1[QUAD];
+ float quadrants2[QUAD];
+
+ unsigned char c1s[N_PER_TH];
+ unsigned char c2s[N_PER_TH];
+ T g_vals[N_PER_TH];
+ typedef cub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+ typedef cub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
+
+ typedef cub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+ typedef cub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
+
+ __shared__ float smem_quantiles1[LANES][257];
+ __shared__ float smem_quantiles2[LANES][257];
+ typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
+ typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce2;
+ __shared__ typename BlockReduce1::TempStorage reduce1;
+ __shared__ typename BlockReduce2::TempStorage reduce2;
+ __shared__ float smem_exchange1[1];
+ __shared__ float smem_exchange2[1];
+
+ __shared__ union {
+ typename LoadT::TempStorage loadh;
+ typename LoadChar::TempStorage loadc;
+ typename StoreChar::TempStorage storec;
+ typename StoreT::TempStorage storeh;
+ } temp_storage;
+ // init: 0.2 -> 0.23
+
+ // 0.23 -> 0.23
+ smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
+ smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x];
+ # pragma unroll
+ for(unsigned int j = 1; j < LANES; j++)
+ {
+ smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];
+ smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x];
+ }
+
+ __syncthreads();
+
+ #pragma unroll
+ for(int k = 0; k < QUAD; k++)
+ {
+ quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
+ quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
+ }
+
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
+ {
+ // loads: 0.23 -> 0.85/1.44
+ valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
+ __syncthreads();
+ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
+ __syncthreads();
+ LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
+ __syncthreads();
+ LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
+
+ new_local_abs_max1 = -FLT_MAX;
+ new_local_abs_max2 = -FLT_MAX;
+
+ // update: 2.48/1.57 -> 2.51/1.60
+ # pragma unroll N_PER_TH
+ for(unsigned int j = 0; j < N_PER_TH; j++)
+ {
+ g_val = float(g_vals[j]);
+ g_val *= gnorm_scale;
+ s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
+ s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
+
+ s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
+ s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
+
+ new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
+ new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j]));
+ }
+
+
+ // reduce: 2.51/1.60 -> 2.67/1.69
+ new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
+ new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max());
+
+ if(threadIdx.x == 0)
+ {
+ smem_exchange1[0] = new_local_abs_max1;
+ smem_exchange2[0] = new_local_abs_max2;
+ }
+
+ __syncthreads();
+
+ if(threadIdx.x == 0)
+ {
+ absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
+ absmax2[i/BLOCK_SIZE] = new_local_abs_max2;
+ }
+ else
+ {
+ new_local_abs_max1 = smem_exchange1[0];
+ new_local_abs_max2 = smem_exchange2[0];
+ }
+
+ __syncthreads();
+ LoadT(temp_storage.loadh).Load(&(p[i]), g_vals, valid_items, (T)0.0f);
+ // reduce: 2.67/1.69 -> 2.67/1.70
+ # pragma unroll N_PER_TH
+ for(unsigned int j = 0; j < N_PER_TH; j++)
+ {
+ g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
+ if(weight_decay > 0.0f)
+ g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay));
+ }
+
+ // store: 0.85/1.44 -> 2.48/1.57
+ __syncthreads();
+ StoreT(temp_storage.storeh).Store(&(p[i]), g_vals, valid_items);
+
+ // quantizaztion: 2.67/1.70 -> 3.4/3.3
+ # pragma unroll N_PER_TH
+ for(unsigned int j = 0; j < N_PER_TH; j++)
+ {
+ c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
+ c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2));
+
+ // make sure state1 term has still the same sign after quantization
+ // (not needed for state2 term which has only positive values)
+ if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j]))
+ {
+ if(s1_vals[j] > 0.0f)
+ c1s[j] += 1;
+ else
+ c1s[j] -= 1;
+ }
+ }
+
+ __syncthreads();
+ StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
+ __syncthreads();
+ StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items);
+ }
+}
+
+
+#define LANES 2
+#define QUAD 3
+template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
+__launch_bounds__(256, 3)
+__global__ void
+kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
+ const float beta1, const float beta2,
+ const float eps, const int step, const float lr,
+ float* __restrict__ const quantiles1,
+ float* absmax1,
+ float weight_decay,
+ const float gnorm_scale, const int n)
+{
+
+ //const int n_full = n + (n%BLOCK_SIZE);
+ const int n_full = gridDim.x * BLOCK_SIZE;
+ const int base_idx = (blockIdx.x * BLOCK_SIZE);
+ int valid_items = 0;
+ float g_val = 0.0f;
+ float s1_vals[N_PER_TH];
+ // 2-5%
+ const int lane_id = threadIdx.x % LANES;
+ float new_local_abs_max1 = -FLT_MAX;
+ float quadrants1[QUAD];
+
+ unsigned char c1s[N_PER_TH];
+ T g_vals[N_PER_TH];
+ T p_vals[N_PER_TH];
+
+ typedef cub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+ typedef cub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
+
+ typedef cub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+ typedef cub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
+
+ __shared__ float smem_quantiles1[LANES][257];
+ typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
+ __shared__ typename BlockReduce1::TempStorage reduce1;
+ __shared__ float smem_exchange1[1];
+
+ __shared__ union {
+ typename LoadT::TempStorage loadh;
+ typename LoadChar::TempStorage loadc;
+ typename StoreChar::TempStorage storec;
+ typename StoreT::TempStorage storeh;
+ } temp_storage;
+ // init: 0.2 -> 0.23
+
+ // 0.23 -> 0.23
+ smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
+ # pragma unroll
+ for(unsigned int j = 1; j < LANES; j++)
+ smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];
+
+ __syncthreads();
+
+ #pragma unroll
+ for(int k = 0; k < QUAD; k++)
+ quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
+
+ for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
+ {
+ // loads: 0.23 -> 0.85/1.44
+ valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
+ __syncthreads();
+ LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
+ __syncthreads();
+ LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
+ __syncthreads();
+ LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
+
+ new_local_abs_max1 = -FLT_MAX;
+
+ // update: 2.48/1.57 -> 2.51/1.60
+ # pragma unroll N_PER_TH
+ for(unsigned int j = 0; j < N_PER_TH; j++)
+ {
+ g_val = float(g_vals[j]);
+ g_val *= gnorm_scale;
+ if(weight_decay > 0.0f)
+ g_val += ((float)p_vals[j])*weight_decay;
+
+ s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
+
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ if(step == 1)
+ s1_vals[j] = g_val;
+ else
+ s1_vals[j] = (s1_vals[j]*beta1) + g_val;
+ break;
+ case RMSPROP:
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
+ break;
+ }
+
+ new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
+ }
+
+
+ // reduce: 2.51/1.60 -> 2.67/1.69
+ new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
+
+ if(threadIdx.x == 0)
+ smem_exchange1[0] = new_local_abs_max1;
+
+ __syncthreads();
+
+ if(threadIdx.x == 0)
+ absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
+ else
+ new_local_abs_max1 = smem_exchange1[0];
+
+ // reduce: 2.67/1.69 -> 2.67/1.70
+ # pragma unroll N_PER_TH
+ for(unsigned int j = 0; j < N_PER_TH; j++)
+ {
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
+ break;
+ case RMSPROP:
+ g_val = g_vals[j];
+ p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
+ break;
+ }
+ }
+
+ // store: 0.85/1.44 -> 2.48/1.57
+ __syncthreads();
+ StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+
+ // quantizaztion: 2.67/1.70 -> 3.4/3.3
+ # pragma unroll N_PER_TH
+ for(unsigned int j = 0; j < N_PER_TH; j++)
+ {
+ c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
+
+ // make sure state1 term has still the same sign after quantization
+ // (not needed for state2 term which has only positive values)
+ if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j]))
+ {
+ if(s1_vals[j] > 0.0f)
+ c1s[j] += 1;
+ else
+ c1s[j] -= 1;
+ }
+ }
+
+ __syncthreads();
+ StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
+ }
+}
+
+//==============================================================
+// TEMPLATE DEFINITIONS
+//==============================================================
+
+template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
+template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
+
+template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n);
+template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n);
+
+#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
+template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
+ float* state1, float *unorm, \
+ const float beta1, const float eps, const float weight_decay, \
+ const int step, const float lr, const float gnorm_scale, const int n); \
+
+MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
+MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
+MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
+MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
+
+#define MAKE_Optimizer32bit1State(oname, gtype) \
+template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
+ const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n); \
+
+MAKE_Optimizer32bit1State(MOMENTUM, half)
+MAKE_Optimizer32bit1State(MOMENTUM, float)
+MAKE_Optimizer32bit1State(RMSPROP, half)
+MAKE_Optimizer32bit1State(RMSPROP, float)
+
+#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
+template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
+ float* state1, float* state2, float *unorm, \
+ const float beta1, const float beta2, const float eps, const float weight_decay, \
+ const int step, const float lr, const float gnorm_scale, const int n); \
+
+MAKE_PreconditionOptimizer32bit2State(ADAM, half)
+MAKE_PreconditionOptimizer32bit2State(ADAM, float)
+
+template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+ const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n);
+template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+ const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n);
+
+#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
+template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
+ float *unorm, \
+ const float beta1, \
+ const float eps, const int step, \
+ float* __restrict__ const quantiles1, \
+ float* max1, float* new_max1, \
+ const float weight_decay, \
+ const float gnorm_scale, \
+ const int n); \
+
+MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
+MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
+MAKE_PreconditionStatic8bit1State(RMSPROP, half)
+MAKE_PreconditionStatic8bit1State(RMSPROP, float)
+
+#define MAKE_optimizerStatic8bit1State(oname, gtype) \
+template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
+ const float *unorm, const float max_unorm, const float param_norm, \
+ const float beta1, \
+ const float eps, const int step, const float lr, \
+ float* __restrict__ const quantiles1, \
+ float* max1, float* new_max1, \
+ float weight_decay, \
+ const float gnorm_scale, \
+ const int n); \
+
+MAKE_optimizerStatic8bit1State(MOMENTUM, half)
+MAKE_optimizerStatic8bit1State(MOMENTUM, float)
+MAKE_optimizerStatic8bit1State(RMSPROP, half)
+MAKE_optimizerStatic8bit1State(RMSPROP, float)
+
+#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
+template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
+ float *unorm, \
+ const float beta1, const float beta2, \
+ const float eps, const int step, \
+ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
+ float* max1, float* max2, float* new_max1, float* new_max2, \
+ const float gnorm_scale, \
+ const int n); \
+
+MAKE_PreconditionStatic8bit2State(ADAM, half)
+MAKE_PreconditionStatic8bit2State(ADAM, float)
+
+#define MAKE_optimizerStatic8bit2State(oname, gtype) \
+template __global__ void kOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \
+ const float *unorm, const float max_unorm, const float param_norm, \
+ const float beta1, const float beta2, \
+ const float eps, const int step, const float lr, \
+ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
+ float* max1, float* max2, float* new_max1, float* new_max2, \
+ float weight_decay, \
+ const float gnorm_scale, \
+ const int n); \
+
+MAKE_optimizerStatic8bit2State(ADAM, half)
+MAKE_optimizerStatic8bit2State(ADAM, float)
+
+template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
+template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
+
+template __global__ void kQuantizeBlockwise<half, 4096, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+
+template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
+template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
+template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
+template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
+
+
+
+#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
+template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
+ const float beta1, const float beta2, \
+ const float eps, const int step, const float lr, \
+ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
+ float* absmax1, float* absmax2, \
+ float weight_decay, \
+ const float gnorm_scale, const int n); \
+
+MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
+
+#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
+template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
+ gtype* p, gtype* __restrict__ const g, unsigned char* state1, \
+ const float beta1, const float beta2, \
+ const float eps, const int step, const float lr, \
+ float* __restrict__ const quantiles1, \
+ float* absmax1, \
+ float weight_decay, \
+ const float gnorm_scale, const int n); \
+
+MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
new file mode 100644
index 0000000..06ae1e4
--- /dev/null
+++ b/csrc/kernels.cuh
@@ -0,0 +1,111 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <float.h>
+#include <ops.cuh>
+
+#ifndef kernels
+#define kernels
+
+template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
+
+__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
+__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
+
+template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n);
+
+template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
+__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
+ float* state1, float* state2, float *unorm,
+ const float beta1, const float beta2, const float eps, const float weight_decay,
+ const int step, const float lr, const float gnorm_scale, const int n);
+
+template<typename T, int OPTIMIZER>
+__global__ void kOptimizer32bit2State(T* g, T* p,
+ float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+ const float beta1, const float beta2, const float eps, const float weight_decay,
+ const int step, const float lr, const float gnorm_scale, const int n);
+
+template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
+__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
+ float* state1, float *unorm,
+ const float beta1, const float eps, const float weight_decay,
+ const int step, const float lr, const float gnorm_scale, const int n);
+
+template<typename T, int OPTIMIZER>
+__global__ void kOptimizer32bit1State(T* g, T* p,
+ float* state1, float *unorm, const float max_unorm, const float param_norm,
+ const float beta1, const float eps, const float weight_decay,
+ const int step, const float lr, const float gnorm_scale, const int n);
+
+template<typename T, int OPTIMIZER>
+__global__ void
+kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
+ float *unorm,
+ const float beta1,
+ const float eps, const int step,
+ float* __restrict__ const quantiles1,
+ float* max1, float* new_max1,
+ const float weight_decay,
+ const float gnorm_scale, const int n);
+
+
+template<typename T, int OPTIMIZER>
+__global__ void
+kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
+ const float *unorm, const float max_unorm, const float param_norm,
+ const float beta1,
+ const float eps, const int step, const float lr,
+ float* __restrict__ const quantiles1,
+ float* max1, float* new_max1,
+ float weight_decay, const float gnorm_scale, const int n);
+
+
+
+template<typename T, int OPTIMIZER>
+__global__ void
+kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2,
+ float *unorm,
+ const float beta1, const float beta2,
+ const float eps, const int step,
+ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
+ float* max1, float* max2, float* new_max1, float* new_max2,
+ const float gnorm_scale, const int n);
+
+
+template<typename T, int OPTIMIZER>
+__global__ void
+kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
+ const float *unorm, const float max_unorm, const float param_norm,
+ const float beta1, const float beta2,
+ const float eps, const int step, const float lr,
+ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
+ float* max1, float* max2, float* new_max1, float* new_max2,
+ float weight_decay, const float gnorm_scale, const int n);
+
+template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit2StateBlockwise(
+ T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
+ const float beta1, const float beta2, const float eps, const int step, const float lr,
+ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
+ float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const int n);
+
+template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise(
+ T* p, T* __restrict__ const g, unsigned char* state1,
+ const float beta1, const float beta2,
+ const float eps, const int step, const float lr,
+ float* __restrict__ const quantiles1,
+ float* absmax1,
+ float weight_decay,
+ const float gnorm_scale, const int n);
+
+
+template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);
+
+__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
+
+#endif
+
+
diff --git a/csrc/ops.cu b/csrc/ops.cu
new file mode 100644
index 0000000..d460ab1
--- /dev/null
+++ b/csrc/ops.cu
@@ -0,0 +1,355 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <ops.cuh>
+#include <kernels.cuh>
+#include <cub/device/device_scan.cuh>
+#include <limits>
+#include <BinSearch.h>
+
+
+using namespace BinSearch;
+using std::cout;
+using std::endl;
+
+#define BLOCK_SIZE 4096
+
+struct quantize_block_args
+{
+ BinAlgo<Scalar, float, Direct2> *bin_searcher;
+ float *code;
+ float *A;
+ float *absmax;
+ unsigned char *out;
+ int block_end;
+ int block_idx;
+ int threadidx;
+};
+
+void *quantize_block(void *arguments)
+{
+ // 1. find absmax in block
+ // 2. divide input value by absmax to normalize into [-1.0, 1.0]
+ // 3. do binary search to find the closest value
+ // 4. check minimal distance
+ // 5. store index
+
+ struct quantize_block_args *args = (quantize_block_args*)arguments;
+
+ // 1. find absmax in block
+ float absmax_block = -FLT_MAX;
+ for (int i = args->block_idx; i < args->block_end; i++)
+ absmax_block = fmax(absmax_block, fabs(args->A[i]));
+
+ args->absmax[args->block_idx/BLOCK_SIZE] = absmax_block;
+
+ for (int i = args->block_idx; i < args->block_end; i++)
+ {
+ // 2. divide input value by absmax to normalize into [-1.0, 1.0]
+ // 3. do binary search to find the closest value
+ float normed_value = args->A[i]/absmax_block;
+ int idx = args->bin_searcher->scalar(normed_value);
+
+ // 4. check minimal distance
+ // The binary search returns always the value to the left, which might not be the closest value
+ if(idx < 255)
+ {
+ float dist_left = fabs(normed_value-(args->code[idx]));
+ float dist_right = fabs(normed_value-(args->code[idx+1]));
+ if(dist_right < dist_left){ idx+=1; }
+ }
+
+ // 5. store index
+ args->out[i] = (unsigned char)idx;
+ }
+
+ return NULL;
+}
+
+void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n)
+{
+
+ // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
+ code[0] = -1.0f;
+
+ int num_blocks = n/BLOCK_SIZE;
+ num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
+
+ pthread_t *threads = (pthread_t*)malloc(sizeof(pthread_t)*num_blocks);
+ struct quantize_block_args **args = (quantize_block_args**)malloc(num_blocks*sizeof(quantize_block_args*));
+
+ for(int i = 0; i < num_blocks; i++)
+ args[i] = (quantize_block_args*)malloc(sizeof(quantize_block_args));
+
+ const uint32 elements_code = 256;
+ BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
+
+ for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
+ {
+ int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
+ int block_end = block_idx + valid_items;
+
+ struct quantize_block_args *arg = args[block_idx/BLOCK_SIZE];
+ arg->bin_searcher = &bin_searcher;
+ arg->code = code;
+ arg->A = A;
+ arg->absmax = absmax;
+ arg->out = out;
+ arg->block_end = block_end;
+ arg->block_idx = block_idx;
+ arg->threadidx = block_idx/BLOCK_SIZE;
+
+ pthread_create(&threads[block_idx/BLOCK_SIZE], NULL, &quantize_block, (void *)arg);
+ }
+
+ for(int i = 0; i < num_blocks; i++)
+ int err = pthread_join(threads[i], NULL);
+
+ free(threads);
+ for(int i = 0; i < num_blocks; i++)
+ free(args[i]);
+ free(args);
+}
+
+
+void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n)
+{
+ for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
+ {
+ int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
+ int block_end = block_idx + valid_items;
+ for (int i = block_idx; i < block_end; i++)
+ out[i] = code[A[i]]*absmax[block_idx/BLOCK_SIZE];
+ }
+}
+
+void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
+{
+ int threads = 512;
+ int blocks = n/threads;
+ blocks = n % threads == 0 ? blocks : blocks + 1;
+ kHistogramScatterAdd2D<<<blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
+{
+ int blocks = n/4096;
+ blocks = n % 4096 == 0 ? blocks : blocks + 1;
+ CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
+ kEstimateQuantiles<T><<<blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+void quantize(float *code, float *A, unsigned char *out, int n)
+{
+ int blocks = n/1024;
+ blocks = n % 1024 == 0 ? blocks : blocks + 1;
+ kQuantize<<<blocks, 1024>>>(code, A, out, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+void dequantize(float *code, unsigned char *A, float *out, int n)
+{
+ int blocks = n/1024;
+ blocks = n % 1024 == 0 ? blocks : blocks + 1;
+ kDequantize<<<blocks, 1024>>>(code, A, out, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
+{
+ int blocks = n/4096;
+ blocks = n % 4096 == 0 ? blocks : blocks + 1;
+ kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
+{
+ int blocks = n/blocksize;
+ blocks = n % blocksize == 0 ? blocks : blocks + 1;
+ if(blocksize == 4096)
+ kDequantizeBlockwise<T, 4096, 1024, 4><<<blocks, 4096/4>>>(code, A, absmax, out, n);
+ else if(blocksize == 2048)
+ kDequantizeBlockwise<T, 2048, 512, 4><<<blocks, 2048/4>>>(code, A, absmax, out, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
+ float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
+ const float beta1, const float beta2, const float eps, const float weight_decay,
+ const int step, const float lr, const float gnorm_scale, const int n)
+{
+ int blocks = n/4096;
+ blocks = n % 4096 == 0 ? blocks : blocks + 1;
+ switch(OPTIMIZER)
+ {
+ case ADAM:
+ if(max_unorm > 0.0f)
+ {
+ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
+ kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+ }
+ kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+ break;
+ case MOMENTUM:
+ case RMSPROP:
+ if(max_unorm > 0.0f)
+ {
+ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
+ kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+ }
+
+ kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+ break;
+ }
+}
+
+template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
+ unsigned char* state1, unsigned char* state2,
+ float *unorm, float max_unorm, float param_norm,
+ float beta1, float beta2,
+ float eps, int step, float lr,
+ float* quantiles1, float* quantiles2,
+ float* max1, float* max2, float* new_max1, float* new_max2,
+ float weight_decay,
+ const float gnorm_scale, int n)
+{
+ int blocks = n/4096;
+ blocks = n % 4096 == 0 ? blocks : blocks + 1;
+
+ if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
+
+ switch(OPTIMIZER)
+ {
+ case ADAM:
+ CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
+ CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
+ kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+ kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+ quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+ break;
+ case MOMENTUM:
+ case RMSPROP:
+ CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
+ kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+ kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
+ quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+ break;
+ default:
+ break;
+ }
+}
+
+#define BLOCKSIZE_2STATE 2048
+#define NUM_2STATE 8
+#define BLOCKSIZE_1STATE 2048
+#define NUM_1STATE 8
+
+template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
+ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
+ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)
+{
+
+ int blocks = 0;
+ switch(OPTIMIZER)
+ {
+ case ADAM:
+ blocks = n/BLOCKSIZE_2STATE;
+ blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
+ kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
+ quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+ break;
+ case MOMENTUM:
+ case RMSPROP:
+ blocks = n/BLOCKSIZE_1STATE;
+ blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
+ kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
+ quantiles1, absmax1, weight_decay, gnorm_scale, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+ break;
+ }
+}
+
+
+
+template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
+{
+ int blocks = n/2048;
+ blocks = n % 2048 == 0 ? blocks : blocks + 1;
+ CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
+ kPercentileClipping<T, 2048, 4><<<blocks, 512>>>(g, gnorm_vec, step, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+
+//==============================================================
+// TEMPLATE DEFINITIONS
+//==============================================================
+
+template void estimateQuantiles(half *A, float *code, float offset, int n);
+template void estimateQuantiles(float *A, float *code, float offset, int n);
+
+template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
+template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
+template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
+template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
+template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
+template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
+
+#define MAKE_optimizer32bit(name, gtype) \
+template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
+ float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
+ const float beta1, const float beta2, const float eps, const float weight_decay, \
+ const int step, const float lr, const float gnorm_scale, const int n);
+
+MAKE_optimizer32bit(ADAM, half)
+MAKE_optimizer32bit(ADAM, float)
+MAKE_optimizer32bit(MOMENTUM, half)
+MAKE_optimizer32bit(MOMENTUM, float)
+MAKE_optimizer32bit(RMSPROP, half)
+MAKE_optimizer32bit(RMSPROP, float)
+
+#define MAKE_optimizerStatic8bit(name, gtype) \
+template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
+ float *unorm, float max_unorm, float param_norm, \
+ float beta1, float beta2, \
+ float eps, int step, float lr, \
+ float* quantiles1, float* quantiles2, \
+ float* max1, float* max2, float* new_max1, float* new_max2, \
+ float weight_decay, \
+ const float gnorm_scale, int n); \
+
+MAKE_optimizerStatic8bit(ADAM, half)
+MAKE_optimizerStatic8bit(ADAM, float)
+MAKE_optimizerStatic8bit(MOMENTUM, half)
+MAKE_optimizerStatic8bit(MOMENTUM, float)
+MAKE_optimizerStatic8bit(RMSPROP, half)
+MAKE_optimizerStatic8bit(RMSPROP, float)
+
+#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
+template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
+ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
+ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n); \
+
+MAKE_optimizerStatic8bitBlockwise(half, ADAM);
+MAKE_optimizerStatic8bitBlockwise(float, ADAM);
+MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
+MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
+MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
+MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
+
+template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
+template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
new file mode 100644
index 0000000..e6033cb
--- /dev/null
+++ b/csrc/ops.cuh
@@ -0,0 +1,81 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+
+#ifndef ops_H
+#define ops_H
+
+#include <stdio.h>
+#include <iostream>
+#include <unistd.h>
+#include <assert.h>
+
+#include <cuda_runtime_api.h>
+#include <cuda_fp16.h>
+
+#define CUDA_CHECK_RETURN(value) { \
+ cudaError_t _m_cudaStat = value; \
+ if (_m_cudaStat != cudaSuccess) { \
+ fprintf(stderr, "Error %s at line %d in file %s\n", \
+ cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
+ exit(1); \
+ } }
+
+#define THREADS_PER_BLOCKS (512)
+
+typedef enum Operations_t
+{
+ ksmul = 0,
+} Operations_t;
+
+typedef enum Optimizer_t
+{
+ ADAM = 0,
+ MOMENTUM = 1,
+ RMSPROP = 2,
+ LARS = 3,
+} Optimizer_t;
+
+
+template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
+
+void quantize(float *code, float *A, unsigned char *out, int n);
+void dequantize(float *code, unsigned char *A, float *out, int n);
+template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
+template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
+
+template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
+ float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
+ float beta1, float beta2, float eps, float weight_decay,
+ int step, float lr, const float gnorm_scale, int n);
+
+template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
+ float *unorm, float max_unorm, float param_norm,
+ float beta1, float beta2,
+ float eps, int step, float lr,
+ float* quantiles1, float* quantiles2,
+ float* max1, float* max2, float* new_max1, float* new_max2,
+ float weight_decay,
+ const float gnorm_scale, int n);
+
+template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
+ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
+ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n);
+
+template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
+
+void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
+void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
+
+void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
+
+#endif
+
+
+
+
+
+
+
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
new file mode 100644
index 0000000..eacb849
--- /dev/null
+++ b/csrc/pythonInterface.c
@@ -0,0 +1,149 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <ops.cuh>
+
+// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
+// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to
+// maintain all that boilerplate
+//===================================================================================
+// UNMANGLED CALLS
+//===================================================================================
+
+void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles<float>(A, code, offset, n); }
+void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
+
+
+#define MAKE_FUNC32(fname, oname, gtype, gbits) \
+void fname##32bit_g##gbits(gtype *g, gtype *p, \
+ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
+ const float beta1, const float beta2, const float eps, const float weight_decay, \
+ const int step, const float lr, float gnorm_scale, const int n) \
+{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
+
+MAKE_FUNC32(momentum, MOMENTUM, float, 32)
+MAKE_FUNC32(momentum, MOMENTUM, half, 16)
+MAKE_FUNC32(adam, ADAM, float, 32)
+MAKE_FUNC32(adam, ADAM, half, 16)
+MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
+MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
+
+#define MAKE_FUNC8(fname, oname, gtype, gbits) \
+void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
+ float *unorm, float max_unorm, float param_norm, \
+ float beta1, float beta2, \
+ float eps, int step, float lr, \
+ float* quantiles1, float* quantiles2, \
+ float* max1, float* max2, float* new_max1, float* new_max2, \
+ float weight_decay, float gnorm_scale, int n) \
+{ \
+ optimizerStatic8bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
+ quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
+} \
+
+MAKE_FUNC8(adam, ADAM, float, 32)
+MAKE_FUNC8(adam, ADAM, half, 16)
+MAKE_FUNC8(momentum, MOMENTUM, float, 32)
+MAKE_FUNC8(momentum, MOMENTUM, half, 16)
+MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
+MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
+
+#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
+void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
+ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
+ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)\
+{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); }\
+
+MAKE_BLOCKWISE8(adam, ADAM, half, 16)
+MAKE_BLOCKWISE8(adam, ADAM, float, 32)
+MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
+MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
+MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
+MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
+
+
+void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
+void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
+
+void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, n); }
+void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, n); }
+void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, n); }
+void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, n); }
+
+void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
+void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
+
+extern "C"
+{
+ void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
+ void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
+ void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
+ void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
+ void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); }
+ void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); }
+ void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
+ void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
+
+ void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
+ void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
+
+ #define MAKE_CFUNC32(name, gtype, gbits) \
+ void c##name##32bit_g##gbits(gtype *g, gtype *p, \
+ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
+ const float beta1, const float beta2, const float eps, const float weight_decay, \
+ const int step, const float lr, const float gnorm_scale, const int n) \
+ { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
+
+ MAKE_CFUNC32(adam, float, 32)
+ MAKE_CFUNC32(adam, half, 16)
+ MAKE_CFUNC32(momentum, float, 32)
+ MAKE_CFUNC32(momentum, half, 16)
+ MAKE_CFUNC32(rmsprop, float, 32)
+ MAKE_CFUNC32(rmsprop, half, 16)
+
+ #define MAKE_CFUNC8(name, gtype, gbits) \
+ void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
+ float *unorm, float max_unorm, float param_norm, \
+ float beta1, float beta2, \
+ float eps, int step, float lr, \
+ float* quantiles1, float* quantiles2, \
+ float* max1, float* max2, float* new_max1, float* new_max2, \
+ float weight_decay, float gnorm_scale, int n) \
+ { \
+ name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
+ quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
+ } \
+
+ MAKE_CFUNC8(adam, float, 32)
+ MAKE_CFUNC8(adam, half, 16)
+ MAKE_CFUNC8(momentum, float, 32)
+ MAKE_CFUNC8(momentum, half, 16)
+ MAKE_CFUNC8(rmsprop, float, 32)
+ MAKE_CFUNC8(rmsprop, half, 16)
+
+ #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
+ void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
+ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
+ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n) \
+ { fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); } \
+
+ MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
+ MAKE_CBLOCKWISE8(adam, ADAM, float, 32)
+ MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16)
+ MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
+ MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
+ MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
+
+
+ void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
+ void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
+
+ void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
+ void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
+
+ void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
+}
+
+
diff --git a/deploy.sh b/deploy.sh
new file mode 100644
index 0000000..a08351e
--- /dev/null
+++ b/deploy.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+rm -rf dist build
+make clean
+CUDA_HOME=/usr/local/cuda-10.2 make
+CUDA_VERSION=102 python -m build
+python -m twine upload --repository testpypi dist/* --verbose
+
+rm -rf dist build
+make clean
+CUDA_HOME=/usr/local/cuda-11.1 make
+CUDA_VERSION=111 python -m build
+python -m twine upload --repository testpypi dist/* --verbose
diff --git a/deploy_from_slurm.sh b/deploy_from_slurm.sh
new file mode 100644
index 0000000..e21f2e0
--- /dev/null
+++ b/deploy_from_slurm.sh
@@ -0,0 +1,86 @@
+#!/bin/bash
+module unload cuda
+module unload gcc
+
+rm -rf dist build
+make clean
+make cleaneggs
+module load cuda/9.2
+module load gcc/7.3.0
+CUDA_HOME=/public/apps/cuda/9.2
+make
+CUDA_VERSION=92 python -m build
+python -m twine upload --repository testpypi dist/* --verbose
+module unload cuda
+
+
+rm -rf dist build
+make clean
+make cleaneggs
+module load cuda/10.0
+CUDA_HOME=/public/apps/cuda/10.0
+make cuda10x
+CUDA_VERSION=100 python -m build
+python -m twine upload --repository testpypi dist/* --verbose
+module unload cuda
+module unload gcc
+module load gcc/8.4
+
+rm -rf dist build
+make clean
+make cleaneggs
+module load cuda/10.1
+CUDA_HOME=/public/apps/cuda/10.1
+make cuda10x
+CUDA_VERSION=101 python -m build
+python -m twine upload --repository testpypi dist/* --verbose
+module unload cuda
+
+rm -rf dist build
+make clean
+make cleaneggs
+module load cuda/10.2
+CUDA_HOME=/public/apps/cuda/10.2/
+make cuda10x
+CUDA_VERSION=102 python -m build
+python -m twine upload --repository testpypi dist/* --verbose
+module unload cuda
+
+
+rm -rf dist build
+make clean
+make cleaneggs
+module load cuda/11.0
+CUDA_HOME=/public/apps/cuda/11.0
+make cuda110
+CUDA_VERSION=110 python -m build
+python -m twine upload --repository testpypi dist/* --verbose
+module unload cuda
+
+rm -rf dist build
+make clean
+make cleaneggs
+module load cuda/11.1
+CUDA_HOME=/public/apps/cuda/11.1
+make cuda11x
+CUDA_VERSION=111 python -m build
+python -m twine upload --repository testpypi dist/* --verbose
+module unload cuda
+
+rm -rf dist build
+make clean
+make cleaneggs
+module load cuda/11.2
+CUDA_HOME=/public/apps/cuda/11.2
+make cuda11x
+CUDA_VERSION=112 python -m build
+python -m twine upload --repository testpypi dist/* --verbose
+module unload cuda
+
+rm -rf dist build
+make clean
+make cleaneggs
+CUDA_HOME=/private/home/timdettmers/git/autoswap/local/cuda-11.3 make cuda11x
+CUDA_VERSION=113 python -m build
+python -m twine upload --repository testpypi dist/* --verbose
+module unload cuda
diff --git a/include/AAlloc.h b/include/AAlloc.h
new file mode 100644
index 0000000..6c2ae41
--- /dev/null
+++ b/include/AAlloc.h
@@ -0,0 +1,86 @@
+#pragma once
+
+#include "Portable.h"
+
+namespace BinSearch {
+namespace Details {
+
+template <typename T>
+bool isAligned(const T *p, size_t A)
+{
+ return (reinterpret_cast<size_t>(p) % A) == 0;
+}
+
+template <class T, size_t A=64>
+struct AlignedVec
+{
+ AlignedVec()
+ : m_storage(0)
+ , m_data(0)
+ , m_sz(0)
+ {
+ }
+
+ static size_t nBytes(size_t sz)
+ {
+ return sz * sizeof(T) + A;
+ }
+
+ static size_t shiftAmt(char *p)
+ {
+ return A>1? (A - (reinterpret_cast<size_t>(p) % A)) % A: 0;
+ }
+
+ void setPtr(char *p, size_t sz)
+ {
+ m_sz = sz;
+ m_data = reinterpret_cast<T *>(p + shiftAmt(p));
+ }
+
+ //void setPtr(T *p, size_t sz)
+ //{
+ // m_sz = sz;
+ // if (A>1)
+ // myassert(((reinterpret_cast<size_t>(p) % A) == 0), "bad alignment");
+ // m_data = p;
+ //}
+
+ // internal allocation
+ void resize(size_t sz)
+ {
+ m_storage = new char[nBytes(sz)];
+ setPtr(m_storage, sz);
+ }
+
+ // external allocation
+ void set(char *storage, size_t sz)
+ {
+ setPtr(storage, sz);
+ }
+
+ ~AlignedVec()
+ {
+ if (m_storage)
+ delete [] m_storage;
+ }
+
+ size_t size() const { return m_sz; }
+ T& operator[](size_t i) { return m_data[i]; }
+ const T& operator[](size_t i) const { return m_data[i]; }
+ T* begin() { return m_data; }
+ T* end() { return m_data+m_sz; }
+ const T* begin() const { return m_data; }
+ const T* end() const { return m_data+m_sz; }
+ T& front() { return m_data[0]; }
+ T& back() { return m_data[m_sz-1]; }
+ const T& front() const { return m_data[0]; }
+ const T& back() const { return m_data[m_sz - 1]; }
+
+private:
+ char *m_storage;
+ T *m_data;
+ size_t m_sz;
+};
+
+} // namespace Details
+} // namespace BinSearch
diff --git a/include/Algo-Direct-Common.h b/include/Algo-Direct-Common.h
new file mode 100644
index 0000000..cf5f0c9
--- /dev/null
+++ b/include/Algo-Direct-Common.h
@@ -0,0 +1,341 @@
+#pragma once
+
+#include <algorithm>
+#include <limits>
+#include <type_traits>
+#include "AAlloc.h"
+
+namespace BinSearch {
+namespace Details {
+
+namespace DirectAux {
+
+#define SAFETY_MULTI_PASS true
+
+template <typename T>
+struct HResults
+{
+ HResults(T h, double ratio, size_t n) : H(h), hRatio(ratio), nInc(n) {}
+ T H;
+ double hRatio;
+ size_t nInc;
+};
+
+
+#ifdef USE_FMA
+template <Algos A> struct IsDirect { static const bool value = (A == Direct) || (A == DirectFMA); };
+template <Algos A> struct IsDirect2 { static const bool value = (A == Direct2) || (A == Direct2FMA); };
+template <Algos A> struct IsDirectCache { static const bool value = (A == DirectCache) || (A == DirectCacheFMA); };
+#else
+template <Algos A> struct IsDirect { static const bool value = (A == Direct); };
+template <Algos A> struct IsDirect2 { static const bool value = (A == Direct2); };
+template <Algos A> struct IsDirectCache { static const bool value = (A == DirectCache); };
+#endif
+
+// general definition
+template <Algos A, typename T, typename Enable = void>
+struct BucketElem
+{
+ FORCE_INLINE void set( uint32 b, const T *)
+ {
+ m_b = b;
+ }
+
+ FORCE_INLINE uint32 index() const { return m_b; }
+
+private:
+ uint32 m_b;
+};
+
+// specialization for DirectCache methods
+
+template <typename T> struct MatchingIntType;
+template <> struct MatchingIntType<double> { typedef uint64 type; };
+template <> struct MatchingIntType<float> { typedef uint32 type; };
+
+template <Algos A, typename T>
+struct BucketElem<A, T, typename std::enable_if< IsDirectCache<A>::value >::type >
+{
+ typedef typename MatchingIntType<T>::type I;
+
+ void set(uint32 b, const T *xi)
+ {
+ u.u.x = xi[b];
+ u.u.b = b;
+ }
+
+ FORCE_INLINE I index() const { return u.u.b; }
+ FORCE_INLINE T x() const { return u.u.x; }
+
+private:
+ union {
+ double dummy;
+ struct
+ {
+ T x;
+ I b;
+ } u;
+ } u;
+};
+
+
+template <bool UseFMA, unsigned char Gap, typename T>
+struct DirectTraits
+{
+ static void checkH(T scaler, T x0, T xN)
+ {
+ T Dn = xN - x0;
+ T ifmax = Dn * scaler;
+ myassert((ifmax < std::numeric_limits<uint32>::max() - (Gap - 1)),
+ "Problem unfeasible: index size exceeds uint32 capacity:"
+ << " D[N] =" << Dn
+ << ", H =" << scaler
+ << ", H D[n] =" << ifmax << "\n"
+ );
+ }
+
+ FORCE_INLINE static uint32 f(T scaler, T x0, T z)
+ {
+ T tmp = scaler * (z - x0);
+#ifdef USE_SSE2
+ return ftoi(FVec1<SSE,T>(tmp));
+#else
+ return static_cast<uint32>(tmp);
+#endif
+ }
+
+ template <InstrSet I>
+ FORCE_INLINE static typename FTOITraits<I, T>::vec_t f(const FVec<I, T>& scaler, const FVec<I, T>& x0, const FVec<I, T>& z)
+ {
+ return ftoi(scaler*(z-x0));
+ }
+
+ static T cst0(T scaler, T x0)
+ {
+ return x0;
+ }
+};
+
+#ifdef USE_FMA
+template <unsigned char Gap, typename T>
+struct DirectTraits<true,Gap,T>
+{
+ typedef FVec1<SSE, T> fVec1;
+
+ static void checkH(T scaler, T H_Times_x0, T xN)
+ {
+ union {
+ typename FVec1<SSE, T>::vec_t v;
+ T s;
+ } ifmax;
+ ifmax.v = mulSub(fVec1(scaler), fVec1(xN), fVec1(H_Times_x0));
+ myassert((ifmax.s < std::numeric_limits<uint32>::max() - (Gap - 1)),
+ "Problem unfeasible: index size exceeds uint32 capacity:"
+ << " H X[0] =" << H_Times_x0
+ << ", H =" << scaler
+ << ", X[N] =" << xN
+ << ", H X[N] - H X[0] =" << ifmax.s << "\n"
+ );
+ }
+
+ FORCE_INLINE static uint32 f(T scaler, T Hx0, T xi)
+ {
+ return ftoi(mulSub(fVec1(scaler), fVec1(xi), fVec1(Hx0)));
+ }
+
+ template <InstrSet I>
+ FORCE_INLINE static typename FTOITraits<I,T>::vec_t f(const FVec<I,T>& scaler, const FVec<I, T>& H_Times_X0, const FVec<I, T>& z)
+ {
+ return ftoi(mulSub(scaler, z, H_Times_X0));
+ }
+
+ static T cst0(T scaler, T x0)
+ {
+ return scaler*x0;
+ }
+};
+#endif
+
+template <unsigned char Gap, typename T, Algos A>
+struct DirectInfo
+{
+ static const bool UseFMA = (A == DirectFMA) || (A == Direct2FMA) || (A == DirectCacheFMA);
+ typedef DirectTraits<UseFMA, Gap, T> fun_t;
+ typedef BucketElem<A,T> bucket_t;
+ typedef AlignedVec<bucket_t> bucketvec_t;
+
+ struct Data {
+ Data() : buckets(0), xi(0), scaler(0), cst0(0) {}
+ Data( const T *x // for Direct must persist if xws=NULL
+ , uint32 n
+ , T H
+ , bucket_t *bws // assumed to gave size nb, as computed below
+ , T *xws = NULL // assumed to have size (n+Gap-1). Optional for Direct, unused for DirectCache, required for DirectGap
+ )
+ : buckets(bws)
+ , scaler(H)
+ , cst0(fun_t::cst0(H, x[0]))
+ {
+ myassert(((bws != NULL) && (isAligned(bws,64))), "bucket pointer not allocated or incorrectly aligned");
+
+ uint32 nb = 1 + fun_t::f(H, cst0, x[n-1]);
+
+ const uint32 npad = Gap-1;
+ const uint32 n_sz = n + npad; // size of padded vector
+
+ if (xws) {
+ myassert(isAligned(xws,8), "x pointer not allocated or incorrectly aligned");
+ std::fill_n(xws, npad, x[0]); // pad in front with x[0]
+ std::copy(x, x+n, xws + npad);
+ xi = xws;
+ }
+ else {
+ myassert(Gap==1, "if Gap>1 then X workspace must be provided");
+ xi = x;
+ }
+
+ populateIndex(bws, nb, xi, n_sz, scaler, cst0);
+ }
+
+ const bucket_t *buckets;
+ const T *xi;
+ T scaler;
+ T cst0; // could be x0 or (scaler*x0), depending if we are using FMA or not
+ } data;
+
+ static T growStep(T H)
+ {
+ T step;
+ T P = next(H);
+ while ((step = P - H) == 0)
+ P = next(P);
+ return step;
+ }
+
+ static HResults<T> computeH(const T *px, uint32 nx)
+ {
+ myassert((nx > Gap), "Array X too small");
+ myassert(((Gap == 1) || (Gap == 2)), "Only tested for these values of Gap");
+
+ const T x0 = px[0];
+ const T xN = px[nx-1];
+
+ const T range = xN - x0;
+ myassert((range < std::numeric_limits<T>::max()), "range too large");
+
+ // check that D_i are strictly increasing and compute minimum value D_{i+Offset}-D_i
+ T deltaDMin = range;
+ for (uint32 i = Gap; i < nx; ++i) {
+ T Dnew = px[i] - x0;
+ T Dold = px[i - Gap] - x0;
+ myassert((Dnew > Dold),
+ "Problem unfeasible: D_i sequence not strictly increasing"
+ << " X[" << 0 << "]=" << x0
+ << " X[" << i - Gap << "]=" << px[i - Gap]
+ << " X[" << i << "]=" << px[i]
+ << "\n"
+ );
+ T deltaD = Dnew - Dold;
+ if (deltaD < deltaDMin)
+ deltaDMin = deltaD;
+ }
+
+ // initial guess for H
+ const T H0 = T(1.0) / deltaDMin;
+ T H = H0;
+
+ T cst0 = fun_t::cst0(H, x0);
+ fun_t::checkH(H, cst0, xN);
+
+ // adjust H by trial and error until succeed
+ size_t nInc = 0;
+ bool modified = false;
+ size_t npasses = 0;
+ T step = growStep(H);
+ uint32 seg_already_checked_from = nx;
+ do {
+ myassert((npasses++ < 2), "verification failed\n");
+ // if there has been an increase, then check only up to that point
+ uint32 last_seg_to_be_checked = seg_already_checked_from - 1;
+ modified = false;
+ uint32 inew = 0;
+ for (uint32 i = Gap; i <= last_seg_to_be_checked; ++i) {
+ uint32 iold = fun_t::f(H, cst0, px[i-Gap]);
+ uint32 inew = fun_t::f(H, cst0, px[i]);
+ while (inew == iold) {
+ seg_already_checked_from = i;
+ last_seg_to_be_checked = nx-1; // everything needs to be checked
+ modified = true;
+ H = H + step;
+ step *= 2;
+ // recalculate all constants and indices
+ cst0 = fun_t::cst0(H, x0);
+ fun_t::checkH(H, cst0, xN);
+ iold = fun_t::f(H, cst0, px[i - Gap]);
+ inew = fun_t::f(H, cst0, px[i]);
+ }
+ }
+ } while (SAFETY_MULTI_PASS && modified);
+
+ return HResults<T>(H, (((double)H) / H0) - 1.0, nInc);
+ }
+
+ static void populateIndex(BucketElem<A, T> *buckets, uint32 index_size, const T *px, uint32 x_size, T scaler, T cst0)
+ {
+ for (uint32 i = x_size-1, b = index_size-1, j=0; ; --i) {
+ uint32 idx = fun_t::f(scaler, cst0, px[i]);
+ while (b > idx) { // in the 1st iteration it is j=0 but this condition is always false
+ buckets[b].set( j, px );
+ --b;
+ }
+ if (Gap==1 || b == idx) { // if Gap==1, which is known at compile time, the check b==idx is redundant
+ j = i - (Gap-1); // subtracting (Gap-1) points to the index of the first X-element to check
+ buckets[b].set(j, px);
+ if (b-- == 0)
+ break;
+ }
+ }
+ }
+
+ DirectInfo(const Data& d)
+ : data(d)
+ {
+ }
+
+ DirectInfo(const T* px, const uint32 n)
+ {
+ HResults<T> res = computeH(px, n);
+
+#ifdef PAPER_TEST
+ nInc = res.nInc;
+ hRatio = res.hRatio;
+#endif
+ const uint32 npad = Gap-1;
+ const uint32 n_sz = n + npad; // size of padded vector
+
+ if (npad)
+ xi.resize(n_sz);
+
+ T H = res.H;
+ T cst0 = fun_t::cst0(H, px[0]);
+ const uint32 maxIndex = fun_t::f(H, cst0, px[n-1]);
+ buckets.resize(maxIndex + 1);
+
+ data = Data(px, n, H, buckets.begin(), (npad? xi.begin(): NULL));
+ }
+
+private:
+ bucketvec_t buckets;
+ AlignedVec<T,8> xi;
+
+#ifdef PAPER_TEST
+public:
+ double hRatio;
+ size_t nInc;
+#endif
+};
+
+
+} // namespace DirectAux
+} // namespace Details
+} // namespace BinSearch
diff --git a/include/Algo-Direct2.h b/include/Algo-Direct2.h
new file mode 100644
index 0000000..d5fa58d
--- /dev/null
+++ b/include/Algo-Direct2.h
@@ -0,0 +1,305 @@
+#pragma once
+
+#include "Algo-Direct-Common.h"
+
+namespace BinSearch {
+namespace Details {
+
+template <typename T, Algos A>
+struct AlgoScalarBase<T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : DirectAux::DirectInfo<2, T, A>
+{
+private:
+ typedef DirectAux::DirectInfo<2, T, A> base_t;
+ static const size_t Offset=2;
+
+public:
+ AlgoScalarBase(const T* x, const uint32 n)
+ : base_t(x, n)
+ {
+ }
+
+ FORCE_INLINE uint32 scalar(T z) const
+ {
+ const T* px = base_t::data.xi;
+ const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
+ uint32 bidx = base_t::fun_t::f(base_t::data.scaler, base_t::data.cst0, z);
+ uint32 iidx = buckets[bidx];
+ px += iidx;
+ if (z < *px)
+ --iidx;
+ if (z < *(px+1))
+ --iidx;
+ return iidx;
+ }
+};
+
+
+template <InstrSet I, typename T, Algos A>
+struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : AlgoScalarBase<T, A>
+{
+ static const uint32 nElem = sizeof(typename InstrFloatTraits<I, T>::vec_t) / sizeof(T);
+
+ typedef FVec<I, T> fVec;
+ typedef IVec<SSE, T> i128;
+
+ struct Constants
+ {
+ fVec vscaler;
+ fVec vcst0;
+ IVec<I, T> one;
+ };
+
+private:
+ typedef AlgoScalarBase<T, A> base_t;
+
+ FORCE_INLINE
+ //NO_INLINE
+ void resolve(const FVec<SSE, float>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
+ {
+ union U {
+ __m128i vec;
+ uint32 ui32[4];
+ } u;
+
+ const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
+ const float *xi = base_t::data.xi;
+
+ // read indices t
+ const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
+ const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
+ const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
+ const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
+
+#if 0
+ // read pairs ( X(t-1), X(t) )
+ __m128 xp3 = _mm_castpd_ps(_mm_load_sd(p3));
+ __m128 xp2 = _mm_castpd_ps(_mm_load_sd(p2));
+ __m128 xp1 = _mm_castpd_ps(_mm_load_sd(p1));
+ __m128 xp0 = _mm_castpd_ps(_mm_load_sd(p0));
+
+ // build:
+ // { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
+ // { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
+ __m128 h13 = _mm_shuffle_ps(xp1, xp3, (1 << 2) + (1 << 6));
+ __m128 h02 = _mm_shuffle_ps(xp0, xp2, (1 << 2) + (1 << 6));
+ __m128 u01 = _mm_unpacklo_ps(h02, h13);
+ __m128 u23 = _mm_unpackhi_ps(h02, h13);
+ __m128 vxm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6));
+ __m128 vxp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6));
+#else
+ __m128 xp23 = _mm_castpd_ps(_mm_set_pd(*p3, *p2));
+ __m128 xp01 = _mm_castpd_ps(_mm_set_pd(*p1, *p0));
+ __m128 vxm = _mm_shuffle_ps(xp01, xp23, (0) + (2 << 2) + (0 << 4) + (2 << 6));
+ __m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
+#endif
+ IVec<SSE, float> i(u.vec);
+ IVec<SSE, float> vlem = vz < vxm;
+ IVec<SSE, float> vlep = vz < vxp;
+ i = i + vlem + vlep;
+ i.store(pr);
+ }
+
+ FORCE_INLINE
+ //NO_INLINE
+ void resolve(const FVec<SSE, double>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
+ {
+ const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
+ const double *xi = base_t::data.xi;
+
+ uint32 b1 = buckets[bidx.get1()];
+ uint32 b0 = buckets[bidx.get0()];
+
+ const double *p1 = &xi[b1];
+ const double *p0 = &xi[b0];
+
+ // read pairs ( X(t-1), X(t) )
+ __m128d vx1 = _mm_loadu_pd(p1);
+ __m128d vx0 = _mm_loadu_pd(p0);
+
+ // build:
+ // { X(t(0)-1), X(t(1)-1) }
+ // { X(t(0)), X(t(1)) }
+ __m128d vxm = _mm_shuffle_pd(vx0, vx1, 0);
+ __m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);
+
+ IVec<SSE, double> i(b1, b0);
+ IVec<SSE, double> vlem = (vz < vxm);
+ IVec<SSE, double> vlep = (vz < vxp);
+ i = i + vlem + vlep;
+
+ union {
+ __m128i vec;
+ uint32 ui32[4];
+ } u;
+ u.vec = i;
+ pr[0] = u.ui32[0];
+ pr[1] = u.ui32[2];
+ }
+
+#ifdef USE_AVX
+
+ FORCE_INLINE
+ //NO_INLINE
+ void resolve(const FVec<AVX, float>& vz, const IVec<AVX, float>& bidx, uint32 *pr) const
+ {
+ const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
+ const float *xi = base_t::data.xi;
+
+#if 0 // use gather instructions
+
+ IVec<AVX,float> idxm;
+ idxm.setidx(buckets, bidx);
+ __m256i z = _mm256_setzero_si256();
+ IVec<AVX,float> minusone = _mm256_cmpeq_epi32(z,z);
+ IVec<AVX,float> idxp = idxm - minusone;
+
+ FVec<AVX, float> vxm = _mm256_i32gather_ps(xi, idxm, sizeof(float));
+ FVec<AVX, float> vxp = _mm256_i32gather_ps(xi, idxp, sizeof(float));
+ IVec<AVX, float> ip = idxm;
+
+#else // do not use gather instrucions
+
+ union U {
+ __m256i vec;
+ uint32 ui32[8];
+ } u;
+
+ // read indices t
+
+ const double *p7 = reinterpret_cast<const double *>(&xi[(u.ui32[7] = buckets[bidx.get7()])]);
+ const double *p6 = reinterpret_cast<const double *>(&xi[(u.ui32[6] = buckets[bidx.get6()])]);
+ const double *p5 = reinterpret_cast<const double *>(&xi[(u.ui32[5] = buckets[bidx.get5()])]);
+ const double *p4 = reinterpret_cast<const double *>(&xi[(u.ui32[4] = buckets[bidx.get4()])]);
+ const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
+ const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
+ const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
+ const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
+
+#if 0 // perform 8 loads in double precision
+
+ // read pairs ( X(t-1), X(t) )
+ __m128 xp7 = _mm_castpd_ps(_mm_load_sd(p7));
+ __m128 xp6 = _mm_castpd_ps(_mm_load_sd(p6));
+ __m128 xp5 = _mm_castpd_ps(_mm_load_sd(p5));
+ __m128 xp4 = _mm_castpd_ps(_mm_load_sd(p4));
+ __m128 xp3 = _mm_castpd_ps(_mm_load_sd(p3));
+ __m128 xp2 = _mm_castpd_ps(_mm_load_sd(p2));
+ __m128 xp1 = _mm_castpd_ps(_mm_load_sd(p1));
+ __m128 xp0 = _mm_castpd_ps(_mm_load_sd(p0));
+
+ // build:
+ // { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
+ // { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
+ __m128 h57 = _mm_shuffle_ps(xp5, xp7, (1 << 2) + (1 << 6)); // F- F+ H- H+
+ __m128 h46 = _mm_shuffle_ps(xp4, xp6, (1 << 2) + (1 << 6)); // E- E+ G- G+
+ __m128 h13 = _mm_shuffle_ps(xp1, xp3, (1 << 2) + (1 << 6)); // B- B+ D- D+
+ __m128 h02 = _mm_shuffle_ps(xp0, xp2, (1 << 2) + (1 << 6)); // A- A+ C- C+
+
+ __m128 u01 = _mm_unpacklo_ps(h02, h13); // A- B- A+ B+
+ __m128 u23 = _mm_unpackhi_ps(h02, h13); // C- D- C+ D+
+ __m128 u45 = _mm_unpacklo_ps(h46, h57); // E- F- E+ F+
+ __m128 u67 = _mm_unpackhi_ps(h46, h57); // G- H- G+ H+
+
+ __m128 abcdm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6)); // A- B- C- D-
+ __m128 abcdp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6)); // A+ B+ C+ D+
+ __m128 efghm = _mm_shuffle_ps(u45, u67, (0) + (1 << 2) + (0 << 4) + (1 << 6)); // E- F- G- H-
+ __m128 efghp = _mm_shuffle_ps(u45, u67, (2) + (3 << 2) + (2 << 4) + (3 << 6)); // E+ F+ G+ H+
+
+ FVec<AVX, float> vxp = _mm256_insertf128_ps(_mm256_castps128_ps256(abcdm), efghm, 1);
+ FVec<AVX, float> vxm = _mm256_insertf128_ps(_mm256_castps128_ps256(abcdp), efghp, 1);
+
+ IVec<AVX, float> ip(u.vec);
+
+#else // use __mm256_set_pd
+
+ // read pairs ( X(t-1), X(t) )
+ __m256 x0145 = _mm256_castpd_ps(_mm256_set_pd(*p5, *p4, *p1, *p0)); // { x0(t-1), x0(t), x1(t-1), x1(t), x4(t-1), x4(t), x5(t-1), x5(t) }
+ __m256 x2367 = _mm256_castpd_ps(_mm256_set_pd(*p7, *p6, *p3, *p2)); // { x2(t-1), x2(t), x3(t-1), x3(t), x6(t-1), x6(t), x7(t-1), x7(t) }
+
+ // { x0(t-1), x1(t-1), x2(t-1), 3(t-1, x4(t-1), x5(t-1), x6(t-1), xt(t-1) }
+ FVec<AVX, float> vxm = _mm256_shuffle_ps(x0145, x2367, 0 + (2 << 2) + (0 << 4) + (2 << 6) );
+ // { x0(t), x1(t), x2(t), 3(t, x4(t), x5(t), x6(t), xt(t) }
+ FVec<AVX, float> vxp = _mm256_shuffle_ps(x0145, x2367, 1 + (3 << 2) + (1 << 4) + (3 << 6) );
+
+ IVec<AVX, float> ip(u.vec);
+
+#endif
+
+#endif
+
+ IVec<AVX, float> vlem = vz < vxm;
+ IVec<AVX, float> vlep = vz < vxp;
+ ip = ip + vlem + vlep;
+
+ ip.store(pr);
+ }
+
+
+
+ FORCE_INLINE
+ //NO_INLINE
+ void resolve(const FVec<AVX, double>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
+ {
+ union {
+ __m256i vec;
+ uint64 ui64[4];
+ } u;
+
+ const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
+ const double *xi = base_t::data.xi;
+
+ // read indices t
+ const double *p3 = &xi[(u.ui64[3] = buckets[bidx.get3()])];
+ const double *p2 = &xi[(u.ui64[2] = buckets[bidx.get2()])];
+ const double *p1 = &xi[(u.ui64[1] = buckets[bidx.get1()])];
+ const double *p0 = &xi[(u.ui64[0] = buckets[bidx.get0()])];
+
+ // read pairs ( X(t-1), X(t) )
+ __m128d xp3 = _mm_loadu_pd(p3);
+ __m128d xp2 = _mm_loadu_pd(p2);
+ __m128d xp1 = _mm_loadu_pd(p1);
+ __m128d xp0 = _mm_loadu_pd(p0);
+
+ // build:
+ // { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
+ // { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
+ __m256d x02 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp0), xp2, 1);
+ __m256d x13 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp1), xp3, 1);
+ FVec<AVX, double> vxm = _mm256_unpacklo_pd(x02,x13);
+ FVec<AVX, double> vxp = _mm256_unpackhi_pd(x02,x13);
+
+
+// __m128d h01m = _mm_shuffle_pd(xp0, xp1, 0);
+// __m128d h23m = _mm_shuffle_pd(xp2, xp3, 0);
+// __m128d h01p = _mm_shuffle_pd(xp0, xp1, 3);
+// __m128d h23p = _mm_shuffle_pd(xp2, xp3, 3);
+// FVec<AVX, double> vxm = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01m), h23m, 1);
+// FVec<AVX, double> vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);
+
+ IVec<AVX, double> i(u.vec);
+ IVec<AVX, double> vlem = vz < vxm;
+ IVec<AVX, double> vlep = vz < vxp;
+ i = i + vlem + vlep;
+ i.extractLo32s().store(pr);
+ }
+#endif
+
+public:
+
+ AlgoVecBase(const T* x, const uint32 n) : base_t(x, n) {}
+
+ void initConstants(Constants& cst) const
+ {
+ cst.vscaler.setN(base_t::data.scaler);
+ cst.vcst0.setN(base_t::data.cst0);
+ cst.one.setN(uint32(1));
+ }
+
+ void vectorial(uint32 *pr, const T *pz, const Constants& cst) const
+ {
+ fVec vz(pz);
+ resolve(vz, base_t::fun_t::f(cst.vscaler, cst.vcst0, vz), pr);
+ }
+};
+} // namespace Details
+} // namespace BinSearch
diff --git a/include/AlgoXCodes.h b/include/AlgoXCodes.h
new file mode 100644
index 0000000..bdc9b00
--- /dev/null
+++ b/include/AlgoXCodes.h
@@ -0,0 +1,23 @@
+ALGOENUM(DirectCacheFMA, 5)
+ALGOENUM(DirectFMA, 15)
+ALGOENUM(Direct2FMA, 25)
+ALGOENUM(DirectCache, 10)
+ALGOENUM(Direct, 20)
+ALGOENUM(Direct2, 30)
+ALGOENUM(Nonary, 40)
+ALGOENUM(Pentary, 50)
+ALGOENUM(Ternary, 60)
+ALGOENUM(Eytzinger, 70)
+ALGOENUM(BitSet, 80)
+ALGOENUM(ClassicOffset, 90)
+#ifdef PAPER_TEST
+ALGOENUM(MorinOffset, 100)
+ALGOENUM(BitSetNoPad, 110)
+ALGOENUM(ClassicMod, 120)
+ALGOENUM(MorinBranchy, 130)
+ALGOENUM(Classic, 140)
+ALGOENUM(LowerBound, 145)
+#ifdef USE_MKL
+ALGOENUM(MKL, 150)
+#endif
+#endif
diff --git a/include/BinAlgo.h b/include/BinAlgo.h
new file mode 100644
index 0000000..aac67a0
--- /dev/null
+++ b/include/BinAlgo.h
@@ -0,0 +1,77 @@
+#pragma once
+
+#include "Type.h"
+#include <algorithm>
+
+namespace BinSearch {
+
+template <InstrSet I, typename T, Algos A, bool L=false, bool R=false>
+struct BinAlgo : Details::BinAlgoBase<I,T,A>
+{
+ typedef Details::BinAlgoBase<I,T,A> base_t;
+
+ BinAlgo(const T* px, const uint32 n) : base_t(px, n), x0(px[0]), xN(px[n-1]), N(n) {}
+ BinAlgo(const T* px, const uint32 n, const typename base_t::Data& d) : base_t(d), x0(px[0]), xN(px[n-1]), N(n) {}
+
+ FORCE_INLINE
+ uint32 scalar(T z) const
+ {
+ if (!L || z >= x0)
+ if (!R || z < xN)
+ return base_t::scalar(z);
+ else
+ return N;
+ else
+ return std::numeric_limits<uint32>::max();
+ }
+
+
+ FORCE_INLINE
+ void vectorial(uint32 *pr, const T *pz, uint32 n) const
+ {
+ if (!L && !R) {
+ Details::Loop<T,base_t>::loop(*this, pr, pz, n);
+ }
+ else {
+ const uint32 nElem = base_t::nElem;
+ const uint32 idealbufsize = 256;
+ const uint32 bufsize = nElem * (idealbufsize / nElem + ((idealbufsize % nElem) ? 1 : 0));
+ T databuf[bufsize];
+ uint32 resbuf[bufsize];
+ uint32 indexbuf[bufsize];
+
+ uint32 *prend = pr + n;
+ while(pr != prend) {
+ uint32 cnt = 0;
+ uint32 niter = std::min(bufsize, (uint32)std::distance(pr,prend));
+ for (uint32 j = 0; j < niter; ++j) {
+ T z = pz[j];
+ // FIXME: use SSE2?
+ if (!L || z >= x0)
+ if (!R || z < xN) {
+ databuf[cnt] = z;
+ indexbuf[cnt] = j;
+ ++cnt;
+ }
+ else
+ pr[j] = N;
+ else
+ pr[j] = std::numeric_limits<uint32>::max();
+ }
+ // FIXME: merge these two loops
+ Details::Loop<T,base_t>::loop(*this, resbuf, databuf, cnt);
+ for (uint32 j = 0; j < cnt; ++j)
+ pr[indexbuf[j]] = resbuf[j];
+ pr += niter;
+ pz += niter;
+ }
+ }
+ }
+
+ Details::CondData<T,L> x0;
+ Details::CondData<T,R> xN;
+ Details::CondData<uint32,R> N;
+};
+
+
+} // namespace BinSearch
diff --git a/include/BinSearch.h b/include/BinSearch.h
new file mode 100644
index 0000000..336f529
--- /dev/null
+++ b/include/BinSearch.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "AAlloc.h"
+#include "BinAlgo.h"
+#include "SIMD.h"
+
+#include <algorithm>
+#include <limits>
+
+
+#include "Algo-Direct2.h"
diff --git a/include/Portable.h b/include/Portable.h
new file mode 100644
index 0000000..1710b05
--- /dev/null
+++ b/include/Portable.h
@@ -0,0 +1,151 @@
+#pragma once
+#include <limits>
+#include <cmath>
+#include <stdexcept>
+#include <sstream>
+
+#ifdef __FMA__
+#define USE_FMA
+#endif
+
+#ifdef __AVX2__
+#define USE_AVX2
+#endif
+
+#ifdef __AVX__
+#define USE_AVX
+#endif
+
+
+#ifdef __SSE4_1__
+#define USE_SSE41
+#endif
+
+#ifdef __SSE4_2__
+#define USE_SSE42
+#endif
+
+
+#ifndef _MSC_VER
+#include <stdint.h>
+#endif
+
+namespace BinSearch {
+
+#ifndef _MSC_VER
+typedef int8_t int8;
+typedef uint8_t uint8;
+typedef int32_t int32;
+typedef uint32_t uint32;
+typedef int64_t int64;
+typedef uint64_t uint64;
+#else
+typedef __int8 int8;
+typedef unsigned __int8 uint8;
+typedef __int32 int32;
+typedef unsigned __int32 uint32;
+typedef __int64 int64;
+typedef unsigned __int64 uint64;
+#endif
+
+namespace Details {
+
+#define myassert(cond, msg) if (!cond){ std::ostringstream os; os << "\nassertion failed: " << #cond << ", " << msg << "\n"; throw std::invalid_argument(os.str()); }
+
+// log2 is not defined in VS2008
+#if defined(_MSC_VER)
+inline uint32 log2 (uint32 val) {
+ if (val == 1) return 0;
+ uint32 ret = 0;
+ do {
+ ret++;
+ val >>= 1;
+ } while (val > 1);
+ return ret;
+}
+#endif
+
+#ifdef _DEBUG
+#define DEBUG
+#endif
+
+#ifdef _MSC_VER
+# define FORCE_INLINE __forceinline
+# define NO_INLINE __declspec(noinline)
+#else
+# define NO_INLINE __attribute__((noinline))
+# ifdef DEBUG
+# define FORCE_INLINE NO_INLINE
+# else
+# define FORCE_INLINE __attribute__((always_inline)) inline
+# endif
+#endif
+
+#ifdef USE_AVX
+#define COMISS "vcomiss"
+#define COMISD "vcomisd"
+#else
+#define COMISS "comiss"
+#define COMISD "comisd"
+#endif
+
+// nextafter is not defined in VS2008
+#if defined(_MSC_VER) && (_MSC_VER <= 1500)
+#include <float.h>
+inline float mynext(float x)
+{
+ return _nextafterf(x, std::numeric_limits<float>::max());
+}
+
+inline double mynext(double x)
+{
+ return _nextafter(x, std::numeric_limits<double>::max());
+}
+inline float myprev(float x)
+{
+ return _nextafterf(x, -std::numeric_limits<float>::max());
+}
+
+inline double myprev(double x)
+{
+ return _nextafter(x, -std::numeric_limits<double>::max());
+}
+#else
+inline float mynext(float x)
+{
+ return std::nextafterf(x, std::numeric_limits<float>::max());
+}
+
+inline double mynext(double x)
+{
+ return std::nextafter(x, std::numeric_limits<double>::max());
+}
+inline float myprev(float x)
+{
+ return std::nextafterf(x, -std::numeric_limits<float>::max());
+}
+
+inline double myprev(double x)
+{
+ return std::nextafter(x, -std::numeric_limits<double>::max());
+}
+#endif
+
+template <typename T>
+inline T next(T x)
+{
+ for (int i = 0; i < 4; ++i)
+ x = mynext(x);
+ return x;
+}
+
+template <typename T>
+inline T prev(T x)
+{
+ for (int i = 0; i < 4; ++i)
+ x = myprev(x);
+ return x;
+}
+
+} // namepsace Details
+} // namespace BinSearch
diff --git a/include/SIMD.h b/include/SIMD.h
new file mode 100644
index 0000000..642b80a
--- /dev/null
+++ b/include/SIMD.h
@@ -0,0 +1,562 @@
+#pragma once
+
+#include "Portable.h"
+
+#ifdef USE_SSE42
+#ifndef _MSC_VER
+#include <popcntintrin.h>
+#define popcnt32 _mm_popcnt_u32
+#else
+#include <intrin.h>
+#define popcnt32 __popcnt
+#endif
+#else // USE_SSE42
+namespace BinSearch {
+FORCE_INLINE int popcnt32(int x32)
+{
+ // strictly speaking this is not correct, as it ignores higher order bits
+ // however, this is only used on the resuot of movemask on a 128-bit register, which is 8 at most, so it is ok
+ // with 256-bit registers, SSE42 is defined, and we do not use this function
+ uint8 x = static_cast<uint8>(x32);
+ x = (x & 0x55) + (x >> 1 & 0x55);
+ x = (x & 0x33) + (x >> 2 & 0x33);
+ x = (x & 0x0f) + (x >> 4 & 0x0f);
+ return x;
+}
+} // namespace
+#endif
+
+#if defined(USE_AVX) || defined(USE_AVX2)
+#include <immintrin.h>
+#else
+#include <emmintrin.h>
+#ifdef USE_SSE41
+#include <smmintrin.h>
+#endif
+#endif
+
+#include "Type.h"
+
+namespace BinSearch {
+namespace Details {
+
+template <InstrSet I, class T>
+struct FVec;
+
+template <InstrSet I, class T>
+struct IVec;
+
+template <InstrSet I, class T>
+struct FVec1;
+
+template <> struct InstrIntTraits<SSE>
+{
+ typedef __m128i vec_t;
+};
+
+template <> struct InstrFloatTraits<SSE, float>
+{
+ typedef __m128 vec_t;
+};
+
+template <> struct InstrFloatTraits<SSE, double>
+{
+ typedef __m128d vec_t;
+};
+
+template <InstrSet I, typename T>
+struct FTOITraits
+{
+ typedef IVec<SSE, float> vec_t;
+};
+
+#ifdef USE_AVX
+
+template <>
+struct FTOITraits<AVX, float>
+{
+ typedef IVec<AVX, float> vec_t;
+};
+
+template <> struct InstrIntTraits<AVX>
+{
+ typedef __m256i vec_t;
+};
+
+template <> struct InstrFloatTraits<AVX, float>
+{
+ typedef __m256 vec_t;
+};
+
+template <> struct InstrFloatTraits<AVX, double>
+{
+ typedef __m256d vec_t;
+};
+
+#endif
+
+
+template <typename TR>
+struct VecStorage
+{
+ typedef typename TR::vec_t vec_t;
+
+ FORCE_INLINE operator vec_t&() { return vec; }
+ FORCE_INLINE operator const vec_t&() const { return vec; }
+
+protected:
+ FORCE_INLINE VecStorage() {}
+ FORCE_INLINE VecStorage(const vec_t& v) : vec( v ) {}
+
+ vec_t vec;
+};
+
+template <InstrSet>
+struct IVecBase;
+
+template <>
+struct IVecBase<SSE> : VecStorage<InstrIntTraits<SSE>>
+{
+protected:
+ FORCE_INLINE IVecBase() {}
+ FORCE_INLINE IVecBase( const vec_t& v) : VecStorage<InstrIntTraits<SSE>>( v ) {}
+public:
+ FORCE_INLINE static vec_t zero() { return _mm_setzero_si128(); }
+
+ FORCE_INLINE int32 get0() const { return _mm_cvtsi128_si32( vec ); }
+
+ FORCE_INLINE void assignIf( const vec_t& val, const vec_t& mask )
+ {
+#ifdef USE_SSE41
+ vec = _mm_blendv_epi8(vec, val, mask);
+#else
+ vec = _mm_or_si128(_mm_andnot_si128(mask,vec), _mm_and_si128(mask,val));
+#endif
+ }
+ FORCE_INLINE void orIf(const vec_t& val, const vec_t& mask)
+ {
+ vec = _mm_or_si128(vec, _mm_and_si128(val,mask));
+ }
+};
+
+template <>
+struct IVec<SSE, float> : IVecBase<SSE>
+{
+ FORCE_INLINE IVec() {}
+ FORCE_INLINE IVec( int32 i ) : IVecBase<SSE>( _mm_set1_epi32( i ) ) {}
+ FORCE_INLINE IVec( const vec_t& v) : IVecBase<SSE>( v ) {}
+ FORCE_INLINE IVec( uint32 u3, uint32 u2, uint32 u1, uint32 u0) : IVecBase<SSE>( _mm_set_epi32( u3, u2, u1, u0 ) ) {}
+
+ void setN( int32 i ) { vec = _mm_set1_epi32( i ); }
+
+#ifdef USE_SSE41
+ FORCE_INLINE int32 get1() const { return _mm_extract_epi32(vec, 1); }
+ FORCE_INLINE int32 get2() const { return _mm_extract_epi32(vec, 2); }
+ FORCE_INLINE int32 get3() const { return _mm_extract_epi32(vec, 3); }
+#else
+ FORCE_INLINE int32 get1() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 1 ) ); }
+ FORCE_INLINE int32 get2() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 2 ) ); }
+ FORCE_INLINE int32 get3() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 3 ) ); }
+#endif
+
+ FORCE_INLINE void store( uint32 *pi ) const { _mm_storeu_si128( reinterpret_cast<vec_t*>(pi), vec ); }
+
+ FORCE_INLINE int countbit()
+ {
+ return popcnt32(_mm_movemask_ps(_mm_castsi128_ps(vec)));
+ }
+};
+
+template <>
+struct IVec<SSE, double> : IVecBase<SSE>
+{
+ FORCE_INLINE IVec() {}
+ FORCE_INLINE IVec( int32 i ) : IVecBase<SSE>( _mm_set1_epi64x( i ) ) {}
+ FORCE_INLINE IVec( const vec_t& v) : IVecBase<SSE>( v ) {}
+ FORCE_INLINE IVec( uint64 u1, uint64 u0 ) : IVecBase<SSE>( _mm_set_epi64x(u1, u0) ) {}
+
+ void setN( int32 i ) { vec = _mm_set1_epi64x( i ); }
+
+ FORCE_INLINE int32 get1() const
+ {
+#ifdef USE_SSE41
+ return _mm_extract_epi32(vec, 2);
+#else
+ return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 2 ) );
+#endif
+ }
+
+ // extract the 2 32 bits integers no. 0, 2 and store them in a __m128i
+ FORCE_INLINE IVec<SSE,float> extractLo32s() const
+ {
+ return _mm_shuffle_epi32(vec, ((2 << 2) | 0));
+ }
+
+ FORCE_INLINE void store( uint32 *pi ) const
+ {
+ pi[0] = get0();
+ pi[1] = get1();
+ }
+
+ FORCE_INLINE int countbit()
+ {
+#if 1
+ // takes 4 cycles
+ __m128i hi = _mm_shuffle_epi32(vec, 2); // 1 cycle
+ __m128i s = _mm_add_epi32(vec, hi);
+ int32 x = _mm_cvtsi128_si32(s);
+ return -x;
+#else
+ // takes 6 cycles
+ return popcnt32(_mm_movemask_pd(_mm_castsi128_pd(vec)));
+#endif
+ }
+};
+
+template <typename T>
+FORCE_INLINE IVec<SSE,T> operator>> (const IVec<SSE,T>& a, unsigned n) { return _mm_srli_epi32(a, n); }
+template <typename T>
+FORCE_INLINE IVec<SSE,T> operator<< (const IVec<SSE,T>& a, unsigned n) { return _mm_slli_epi32(a, n); }
+template <typename T>
+FORCE_INLINE IVec<SSE,T> operator& (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_and_si128( a, b ); }
+template <typename T>
+FORCE_INLINE IVec<SSE,T> operator| (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_or_si128( a, b ); }
+template <typename T>
+FORCE_INLINE IVec<SSE,T> operator^ (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_xor_si128( a, b ); }
+template <typename T>
+FORCE_INLINE IVec<SSE,T> operator+ (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_add_epi32( a, b ); }
+template <typename T>
+FORCE_INLINE IVec<SSE,T> operator- (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_sub_epi32( a, b ); }
+#ifdef USE_SSE41
+template <typename T>
+FORCE_INLINE IVec<SSE,T> min (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_min_epi32( a, b ); }
+#endif
+
+typedef VecStorage<InstrFloatTraits<SSE,float>> FVec128Float;
+
+template <>
+struct FVec1<SSE, float> : FVec128Float
+{
+ FORCE_INLINE FVec1() {}
+ FORCE_INLINE FVec1( float f ) : FVec128Float( _mm_load_ss( &f ) ) {}
+ FORCE_INLINE FVec1( const vec_t& v ): FVec128Float( v ) {}
+
+ FORCE_INLINE float get0() const { return _mm_cvtss_f32( vec ); }
+};
+
+template <>
+struct FVec<SSE, float> : FVec128Float
+{
+ FORCE_INLINE FVec() {}
+ FORCE_INLINE FVec( float f ) : FVec128Float( _mm_set1_ps( f ) ) {}
+ FORCE_INLINE FVec( const float *v ) : FVec128Float( _mm_loadu_ps( v ) ) {}
+ FORCE_INLINE FVec( const vec_t& v) : FVec128Float(v) {}
+ FORCE_INLINE FVec( float f3, float f2, float f1, float f0 ) : FVec128Float( _mm_set_ps(f3, f2, f1, f0) ) {}
+
+ void set0( float f ) { vec = _mm_load_ss( &f ); }
+ void setN( float f ) { vec = _mm_set1_ps( f ); }
+
+ FORCE_INLINE void setidx( const float *xi, const IVec<SSE,float>& idx )
+ {
+ uint32 i0 = idx.get0();
+ uint32 i1 = idx.get1();
+ uint32 i2 = idx.get2();
+ uint32 i3 = idx.get3();
+ vec = _mm_set_ps( xi[i3], xi[i2], xi[i1], xi[i0] );
+ }
+
+ FORCE_INLINE float get0() const { return _mm_cvtss_f32( vec ); }
+ FORCE_INLINE float get1() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 1 ) ); }
+ FORCE_INLINE float get2() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 2 ) ); }
+ FORCE_INLINE float get3() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 3 ) ); }
+};
+
+FORCE_INLINE FVec1<SSE,float> operator+ (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_add_ss( a, b ); }
+FORCE_INLINE FVec1<SSE,float> operator- (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_sub_ss( a, b ); }
+FORCE_INLINE FVec1<SSE,float> operator* (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_mul_ss( a, b ); }
+FORCE_INLINE FVec1<SSE,float> operator/ (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_div_ss( a, b ); }
+FORCE_INLINE int ftoi (const FVec1<SSE,float>& a) { return _mm_cvttss_si32(a); }
+FORCE_INLINE IVec<SSE,float> operator> (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_castps_si128( _mm_cmpgt_ss( a, b ) ); }
+#ifdef USE_FMA
+FORCE_INLINE FVec1<SSE, float> mulSub(const FVec1<SSE, float>& a, const FVec1<SSE, float>& b, const FVec1<SSE, float>& c) { return _mm_fmsub_ss(a, b, c); }
+#endif
+
+FORCE_INLINE FVec<SSE,float> operator- (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_sub_ps( a, b ); }
+FORCE_INLINE FVec<SSE,float> operator* (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_mul_ps( a, b ); }
+FORCE_INLINE FVec<SSE,float> operator/ (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_div_ps( a, b ); }
+FORCE_INLINE IVec<SSE,float> ftoi (const FVec<SSE,float>& a) { return _mm_cvttps_epi32(a); }
+FORCE_INLINE IVec<SSE,float> operator<= (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_castps_si128( _mm_cmple_ps( a, b ) ); }
+FORCE_INLINE IVec<SSE,float> operator>= (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_castps_si128( _mm_cmpge_ps( a, b ) ); }
+FORCE_INLINE IVec<SSE,float> operator< (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_castps_si128(_mm_cmplt_ps(a, b)); }
+#ifdef USE_FMA
+FORCE_INLINE FVec<SSE, float> mulSub(const FVec<SSE, float>& a, const FVec<SSE, float>& b, const FVec<SSE, float>& c) { return _mm_fmsub_ps(a, b, c); }
+#endif
+
+typedef VecStorage<InstrFloatTraits<SSE,double>> FVec128Double;
+
+template <>
+struct FVec1<SSE, double> : FVec128Double
+{
+ FORCE_INLINE FVec1() {}
+ FORCE_INLINE FVec1( double f ) : FVec128Double( _mm_load_sd( &f ) ) {}
+ FORCE_INLINE FVec1( const vec_t& v ) : FVec128Double( v ) {}
+
+ FORCE_INLINE double get0() const { return _mm_cvtsd_f64( vec ); }
+};
+
+template <>
+struct FVec<SSE, double> : FVec128Double
+{
+ FORCE_INLINE FVec() {}
+ FORCE_INLINE FVec( double d ) : FVec128Double( _mm_set1_pd( d ) ) {}
+ FORCE_INLINE FVec( const double *v ) : FVec128Double( _mm_loadu_pd( v ) ) {}
+ FORCE_INLINE FVec( const vec_t& v) : FVec128Double( v ) {}
+ FORCE_INLINE FVec( double f1, double f0 ) : FVec128Double( _mm_set_pd(f1, f0) ) {}
+
+ void set0( double f ) { vec = _mm_load_sd( &f ); }
+ void setN( double f ) { vec = _mm_set1_pd( f ); }
+
+ FORCE_INLINE void setidx( const double *xi, const IVec<SSE,double>& idx )
+ {
+ vec = _mm_set_pd( xi[idx.get1()], xi[idx.get0()] );
+ }
+
+ FORCE_INLINE double get0() const { return _mm_cvtsd_f64( vec ); }
+ FORCE_INLINE double get1() const { return _mm_cvtsd_f64( _mm_shuffle_pd( vec, vec, 1 ) ); };
+};
+
+FORCE_INLINE FVec1<SSE,double> operator+ (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_add_sd( a, b ); }
+FORCE_INLINE FVec1<SSE,double> operator- (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_sub_sd( a, b ); }
+FORCE_INLINE FVec1<SSE,double> operator* (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_mul_sd( a, b ); }
+FORCE_INLINE FVec1<SSE,double> operator/ (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_div_sd( a, b ); }
+FORCE_INLINE int ftoi (const FVec1<SSE,double>& a) { return _mm_cvttsd_si32(a); }
+FORCE_INLINE IVec<SSE,double> operator> (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_castpd_si128( _mm_cmpgt_sd( a, b ) ); }
+#ifdef USE_FMA
+FORCE_INLINE FVec1<SSE, double> mulSub(const FVec1<SSE, double>& a, const FVec1<SSE, double>& b, const FVec1<SSE, double>& c) { return _mm_fmsub_sd(a, b, c); }
+#endif
+
+FORCE_INLINE FVec<SSE,double> operator- (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_sub_pd( a, b ); }
+FORCE_INLINE FVec<SSE,double> operator* (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_mul_pd( a, b ); }
+FORCE_INLINE FVec<SSE,double> operator/ (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_div_pd( a, b ); }
+FORCE_INLINE IVec<SSE,float> ftoi (const FVec<SSE,double>& a) { return _mm_cvttpd_epi32(a); }
+FORCE_INLINE IVec<SSE,double> operator<= (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_castpd_si128( _mm_cmple_pd( a, b ) ); }
+FORCE_INLINE IVec<SSE,double> operator< (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_castpd_si128(_mm_cmplt_pd(a, b)); }
+FORCE_INLINE IVec<SSE,double> operator>= (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_castpd_si128( _mm_cmpge_pd( a, b ) ); }
+#ifdef USE_FMA
+FORCE_INLINE FVec<SSE, double> mulSub(const FVec<SSE, double>& a, const FVec<SSE, double>& b, const FVec<SSE, double>& c ) { return _mm_fmsub_pd(a, b, c); }
+#endif
+
+#ifdef USE_AVX
+
+template <>
+struct IVecBase<AVX> : VecStorage<InstrIntTraits<AVX>>
+{
+protected:
+ FORCE_INLINE IVecBase() {}
+ FORCE_INLINE IVecBase( const vec_t& v) : VecStorage<InstrIntTraits<AVX>>( v ) {}
+public:
+ FORCE_INLINE static vec_t zero() { return _mm256_setzero_si256(); }
+
+ FORCE_INLINE int32 get0() const { return _mm_cvtsi128_si32(_mm256_castsi256_si128(vec)); }
+
+ FORCE_INLINE void assignIf( const vec_t& val, const vec_t& mask ) { vec = _mm256_blendv_epi8(vec, val, mask); }
+ FORCE_INLINE void orIf(const vec_t& val, const vec_t& mask)
+ {
+ vec = _mm256_blendv_epi8(vec, val, mask);
+ //vec = _mm256_or_si256(vec, _mm256_and_si256(val,mask));
+ }
+
+ FORCE_INLINE __m128i lo128() const { return _mm256_castsi256_si128(vec); }
+ FORCE_INLINE __m128i hi128() const { return _mm256_extractf128_si256(vec, 1); }
+};
+
+template <>
+struct IVec<AVX, float> : IVecBase<AVX>
+{
+ FORCE_INLINE IVec() {}
+ FORCE_INLINE IVec( int32 i ) : IVecBase<AVX>( _mm256_set1_epi32( i ) ) {}
+ FORCE_INLINE IVec( const vec_t& v) : IVecBase<AVX>( v ) {}
+ FORCE_INLINE IVec(uint32 u7, uint32 u6, uint32 u5, uint32 u4, uint32 u3, uint32 u2, uint32 u1, uint32 u0) : IVecBase<AVX>(_mm256_set_epi32(u7, u6, u5, u4, u3, u2, u1, u0)) {}
+
+ void setN( int32 i ) { vec = _mm256_set1_epi32( i ); }
+
+ FORCE_INLINE int32 get1() const { return _mm256_extract_epi32(vec, 1); }
+ FORCE_INLINE int32 get2() const { return _mm256_extract_epi32(vec, 2); }
+ FORCE_INLINE int32 get3() const { return _mm256_extract_epi32(vec, 3); }
+ FORCE_INLINE int32 get4() const { return _mm256_extract_epi32(vec, 4); }
+ FORCE_INLINE int32 get5() const { return _mm256_extract_epi32(vec, 5); }
+ FORCE_INLINE int32 get6() const { return _mm256_extract_epi32(vec, 6); }
+ FORCE_INLINE int32 get7() const { return _mm256_extract_epi32(vec, 7); }
+
+ FORCE_INLINE void setidx( const uint32 *bi, const IVec<AVX,float>& idx )
+ {
+ vec = _mm256_i32gather_epi32(reinterpret_cast<const int32 *>(bi), idx, sizeof(uint32));
+ }
+
+ FORCE_INLINE void store( uint32 *pi ) const { _mm256_storeu_si256( reinterpret_cast<vec_t*>(pi), vec ); }
+
+ FORCE_INLINE int countbit()
+ {
+ return popcnt32(_mm256_movemask_ps(_mm256_castsi256_ps(vec)));
+ }
+};
+
+template <>
+struct IVec<AVX, double> : IVecBase<AVX>
+{
+ FORCE_INLINE IVec() {}
+ FORCE_INLINE IVec( int32 i ) : IVecBase<AVX>( _mm256_set1_epi64x( i ) ) {}
+ FORCE_INLINE IVec( const vec_t& v) : IVecBase<AVX>( v ) {}
+ FORCE_INLINE IVec(uint64 u3, uint64 u2, uint64 u1, uint64 u0) : IVecBase<AVX>(_mm256_set_epi64x(u3, u2, u1, u0)) {}
+
+ void setN( int32 i ) { vec = _mm256_set1_epi64x( i ); }
+
+ // extract the 4 32 bits integers no. 0, 2, 4, 6 and store them in a __m128i
+ FORCE_INLINE IVec<SSE,float> extractLo32s() const
+ {
+ union {
+ uint32 u32[4];
+ __m128i u;
+ } mask = {0,2,4,6};
+ //__m256 ps256 = _mm256_castsi256_ps(vec);
+ //__m128 lo128 = _mm256_castps256_ps128(ps256);
+ //__m128 hi128 = _mm256_extractf128_ps(ps256, 1);
+ //__m128 blend = _mm_shuffle_ps(lo128, hi128, 0 + (2<<2) + (0<<4) + (2<<6));
+ __m256i blend = _mm256_permutevar8x32_epi32(vec, _mm256_castsi128_si256(mask.u));
+ return _mm256_castsi256_si128(blend);
+ }
+
+ //int32 get1() const { return _mm256_cvtsi256_si32( _mm256_shuffle_epi32( vec, 2 ) ); };
+ FORCE_INLINE int32 get1() const { return _mm256_extract_epi32(vec, 2); }
+
+ FORCE_INLINE void store( uint32 *pi ) const
+ {
+ extractLo32s().store(pi);
+ }
+
+ FORCE_INLINE int countbit()
+ {
+ return popcnt32(_mm256_movemask_pd(_mm256_castsi256_pd(vec)));
+ }
+};
+
+template <typename T>
+FORCE_INLINE IVec<AVX,T> operator>> (const IVec<AVX,T>& a, unsigned n) { return _mm256_srli_epi32(a, n); }
+template <typename T>
+FORCE_INLINE IVec<AVX,T> operator<< (const IVec<AVX,T>& a, unsigned n) { return _mm256_slli_epi32(a, n); }
+template <typename T>
+FORCE_INLINE IVec<AVX,T> operator& (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_and_si256( a, b ); }
+template <typename T>
+FORCE_INLINE IVec<AVX,T> operator| (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_or_si256( a, b ); }
+template <typename T>
+FORCE_INLINE IVec<AVX,T> operator^ (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_xor_si256( a, b ); }
+template <typename T>
+FORCE_INLINE IVec<AVX,T> min (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_min_epi32( a, b ); }
+
+FORCE_INLINE IVec<AVX,float> operator+ (const IVec<AVX,float>& a, const IVec<AVX,float>& b ) { return _mm256_add_epi32( a, b ); }
+FORCE_INLINE IVec<AVX,float> operator- (const IVec<AVX,float>& a, const IVec<AVX,float>& b ) { return _mm256_sub_epi32( a, b ); }
+FORCE_INLINE IVec<AVX,double> operator+ (const IVec<AVX,double>& a, const IVec<AVX,double>& b ) { return _mm256_add_epi64( a, b ); }
+FORCE_INLINE IVec<AVX,double> operator- (const IVec<AVX,double>& a, const IVec<AVX,double>& b ) { return _mm256_sub_epi64( a, b ); }
+
+
+typedef VecStorage<InstrFloatTraits<AVX,float>> FVec256Float;
+
+template <>
+struct FVec<AVX, float> : FVec256Float
+{
+ FORCE_INLINE FVec() {}
+ FORCE_INLINE FVec( float f ) : FVec256Float( _mm256_set1_ps( f ) ) {}
+ FORCE_INLINE FVec( const float *v ) : FVec256Float( _mm256_loadu_ps( v ) ) {}
+ FORCE_INLINE FVec( const vec_t& v) : FVec256Float(v) {}
+ FORCE_INLINE FVec(float f7, float f6, float f5, float f4, float f3, float f2, float f1, float f0) : FVec256Float(_mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0)) {}
+
+ //void set0( float f ) { vec = _mm256_load_ss( &f ); }
+ void setN( float f ) { vec = _mm256_set1_ps( f ); }
+
+ FORCE_INLINE void setidx( const float *xi, const IVec<AVX,float>& idx )
+ {
+#if 1 // use gather primitives
+ vec = _mm256_i32gather_ps (xi, idx, 4);
+#elif 0
+ uint32 i0 = idx.get0();
+ uint32 i1 = idx.get1();
+ uint32 i2 = idx.get2();
+ uint32 i3 = idx.get3();
+ uint32 i4 = idx.get4();
+ uint32 i5 = idx.get5();
+ uint32 i6 = idx.get6();
+ uint32 i7 = idx.get7();
+ vec = _mm256_set_ps( xi[i7], xi[i6], xi[i5], xi[i4], xi[i3], xi[i2], xi[i1], xi[i0] );
+#else
+ union {
+ __m256i vec;
+ uint32 ui32[8];
+ } i;
+ i.vec = static_cast<const __m256i&>(idx);
+ vec = _mm256_set_ps(xi[i.ui32[7]], xi[i.ui32[6]], xi[i.ui32[5]], xi[i.ui32[4]], xi[i.ui32[3]], xi[i.ui32[2]], xi[i.ui32[1]], xi[i.ui32[0]]);
+#endif
+ }
+
+ FORCE_INLINE FVec<SSE, float> lo128() const { return _mm256_castps256_ps128(vec); }
+ FORCE_INLINE FVec<SSE, float> hi128() const { return _mm256_extractf128_ps(vec, 1); }
+
+ //FORCE_INLINE float get0() const { return _mm256_cvtss_f32( vec ); }
+ //FORCE_INLINE float get1() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 1 ) ); }
+ //FORCE_INLINE float get2() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 2 ) ); }
+ //FORCE_INLINE float get3() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 3 ) ); }
+};
+
+FORCE_INLINE FVec<AVX,float> operator- (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_sub_ps( a, b ); }
+FORCE_INLINE FVec<AVX,float> operator* (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_mul_ps( a, b ); }
+FORCE_INLINE FVec<AVX,float> operator/ (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_div_ps( a, b ); }
+FORCE_INLINE IVec<AVX,float> ftoi (const FVec<AVX,float>& a) { return _mm256_cvttps_epi32(a); }
+FORCE_INLINE IVec<AVX,float> operator<= (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_castps_si256( _mm256_cmp_ps( a, b, _CMP_LE_OS) ); }
+FORCE_INLINE IVec<AVX,float> operator>= (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_castps_si256( _mm256_cmp_ps( a, b, _CMP_GE_OS ) ); }
+FORCE_INLINE IVec<AVX,float> operator< (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_castps_si256(_mm256_cmp_ps(a, b, _CMP_LT_OS )); }
+#ifdef USE_FMA
+FORCE_INLINE FVec<AVX, float> mulSub(const FVec<AVX, float>& a, const FVec<AVX, float>& b, const FVec<AVX, float>& c) { return _mm256_fmsub_ps(a, b, c); }
+#endif
+
+typedef VecStorage<InstrFloatTraits<AVX,double>> FVec256Double;
+
+template <>
+struct FVec<AVX, double> : FVec256Double
+{
+ FORCE_INLINE FVec() {}
+ FORCE_INLINE FVec( double d ) : FVec256Double( _mm256_set1_pd( d ) ) {}
+ FORCE_INLINE FVec( const double *v ) : FVec256Double( _mm256_loadu_pd( v ) ) {}
+ FORCE_INLINE FVec( const vec_t& v) : FVec256Double( v ) {}
+ FORCE_INLINE FVec(double d3, double d2, double d1, double d0) : FVec256Double(_mm256_set_pd(d3, d2, d1, d0)) {}
+
+ //void set0( double f ) { vec = _mm256_load_sd( &f ); }
+ void setN( double f ) { vec = _mm256_set1_pd( f ); }
+
+ FORCE_INLINE void setidx( const double *xi, const IVec<SSE,float>& idx )
+ {
+ vec = _mm256_i32gather_pd(xi, idx, 8);
+ }
+
+ FORCE_INLINE void setidx( const double *xi, const IVec<AVX,double>& idx )
+ {
+ vec = _mm256_i64gather_pd(xi, idx, 8);
+ }
+
+// FORCE_INLINE double get0() const { return _mm256_cvtsd_f64( vec ); }
+// FORCE_INLINE double get1() const { return _mm256_cvtsd_f64( _mm256_shuffle_pd( vec, vec, 1 ) ); };
+};
+
+FORCE_INLINE FVec<AVX,double> operator- (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_sub_pd( a, b ); }
+FORCE_INLINE FVec<AVX,double> operator* (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_mul_pd( a, b ); }
+FORCE_INLINE FVec<AVX,double> operator/ (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_div_pd( a, b ); }
+FORCE_INLINE IVec<SSE,float> ftoi (const FVec<AVX,double>& a) { return _mm256_cvttpd_epi32(a); }
+FORCE_INLINE IVec<AVX,double> operator<= (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_castpd_si256(_mm256_cmp_pd( a, b, _CMP_LE_OS ) ); }
+FORCE_INLINE IVec<AVX,double> operator< (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_castpd_si256(_mm256_cmp_pd(a, b, _CMP_LT_OS)); }
+FORCE_INLINE IVec<AVX,double> operator>= (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_castpd_si256(_mm256_cmp_pd( a, b, _CMP_GE_OS ) ); }
+#ifdef USE_FMA
+FORCE_INLINE FVec<AVX, double> mulSub(const FVec<AVX, double>& a, const FVec<AVX, double>& b, const FVec<AVX, double>& c) { return _mm256_fmsub_pd(a, b, c); }
+#endif
+
+#endif
+
+} // namepsace Details
+} // namespace BinSearch
diff --git a/include/Type.h b/include/Type.h
new file mode 100644
index 0000000..720bfb8
--- /dev/null
+++ b/include/Type.h
@@ -0,0 +1,221 @@
+ #pragma once
+
+#include <stddef.h>
+#include <vector>
+#include <limits>
+
+#include "Portable.h"
+
+using std::size_t;
+
+namespace BinSearch {
+
+enum InstrSet { Scalar, SSE, AVX };
+
+#define ALGOENUM(x, b) x,
+enum Algos
+ {
+#include "AlgoXCodes.h"
+ };
+#undef ALGOENUM
+
+namespace Details {
+
+ template <InstrSet I>
+ struct InstrIntTraits;
+
+ template <InstrSet I, typename T>
+ struct InstrFloatTraits;
+
+ // base class for algorithm supporting the method:
+ // uint32 scalar(T z) const
+ template <typename T, Algos A, typename Enable=void>
+ struct AlgoScalarBase;
+
+ // base class for algorithm supporting the following methods, constants and definitions:
+ // static const uint32 nElem
+ // struct Constants;
+ // void initConstants(Constants& cst) const
+ // void vectorial(uint32 *pr, const T *pz, const Constants& cst) const
+ // The function vectorial processes nElem items
+ template <InstrSet I, typename T, Algos A, typename Enable=void>
+ struct AlgoVecBase;
+
+ template <typename T> struct IntTraits;
+
+ template <> struct IntTraits<float>
+ {
+ typedef uint32 itype;
+ };
+ template <> struct IntTraits<double>
+ {
+ typedef uint64 itype;
+ };
+
+ template <int N>
+ struct Body
+ {
+ template <uint32 D, typename T, typename Expr>
+ FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const typename Expr::Constants& cst)
+ {
+ e.vectorial(ri, zi, cst);
+ Body<N - 1>::template iteration<D>(e, ri + D, zi + D, cst);
+ }
+
+ };
+
+ template <>
+ struct Body<0>
+ {
+ template <uint32 D, typename T, typename Expr, typename H>
+ FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const H&)
+ {
+ }
+ };
+
+ template <typename T, typename Algo>
+ struct Loop
+ {
+ typedef Algo algo_type;
+ static const uint32 M = 4;
+ static const uint32 D = algo_type::nElem;
+
+ FORCE_INLINE static void loop(const algo_type& e, uint32 *ri, const T* zi, uint32 n)
+ {
+ typename algo_type::Constants cst;
+ e.initConstants(cst);
+
+ uint32 j = 0;
+ while (j + (D*M) <= n) {
+ Details::Body<M>::template iteration<D>(e, ri + j, zi + j, cst);
+ j += (D*M);
+ }
+ while (j + D <= n) {
+ e.vectorial(ri + j, zi + j, cst);
+ j += D;
+ }
+ while (D > 1 && j < n) {
+ ri[j] = e.scalar(zi[j]);
+ j += 1;
+ }
+ }
+ };
+
+ template <uint32 nIterTot, uint32 nIterLeft>
+ struct _Pipeliner
+ {
+ template <typename Expr, typename Data>
+ FORCE_INLINE static void go(const Expr& e, Data* d)
+ {
+ e.template run<nIterTot - nIterLeft>(d);
+ _Pipeliner<nIterTot, nIterLeft - 1>::go(e, d);
+ }
+ };
+
+ template <uint32 nIterTot>
+ struct _Pipeliner<nIterTot, 0>
+ {
+ template <typename Expr, typename Data>
+ FORCE_INLINE static void go(const Expr& e, Data* d)
+ {
+ }
+ };
+
+ template <uint32 nIter>
+ struct Pipeliner
+ {
+ template <typename Expr, typename Data>
+ FORCE_INLINE static void go(const Expr& e, Data* d)
+ {
+ _Pipeliner<nIter, nIter>::go(e, d);
+ }
+ };
+
+
+#if 1
+ template <class T>
+ char is_complete_impl(char (*)[sizeof(T)]);
+
+ template <class>
+ long is_complete_impl(...);
+
+ template <class T>
+ struct IsComplete
+ {
+ static const bool value = sizeof(is_complete_impl<T>(0)) == sizeof(char);
+ };
+#else
+ template <class T, std::size_t = sizeof(T)>
+ std::true_type is_complete_impl(T *);
+
+ std::false_type is_complete_impl(...);
+
+ template <class T>
+ struct IsComplete : decltype(is_complete_impl(std::declval<T*>())) {};
+#endif
+
+template <typename T, Algos A>
+struct AlgoScalarToVec : AlgoScalarBase<T,A>
+{
+ typedef AlgoScalarBase<T, A> base_t;
+
+ AlgoScalarToVec(const typename base_t::Data& d) : base_t(d) {}
+ AlgoScalarToVec(const T* px, const uint32 n) : base_t(px, n) {}
+
+ static const uint32 nElem = 1;
+
+ struct Constants
+ {
+ };
+
+ void initConstants(Constants& cst) const
+ {
+ }
+
+ FORCE_INLINE
+ void vectorial(uint32 *pr, const T *pz, const Constants& cst) const
+ {
+ *pr = base_t::scalar(*pz);
+ }
+};
+
+template<bool B, class T, class F>
+struct conditional { typedef T type; };
+
+template<class T, class F>
+struct conditional<false, T, F> { typedef F type; };
+
+template <typename T, bool C>
+struct CondData
+{
+ FORCE_INLINE CondData(T x) : v(x) {}
+ FORCE_INLINE operator const T&() const { return v;}
+private:
+ T v;
+};
+
+template <typename T>
+struct CondData<T,false>
+{
+ FORCE_INLINE CondData(T) {}
+ FORCE_INLINE operator const T() const { return 0;}
+};
+
+template <InstrSet I, typename T, Algos A, bool L=false>
+struct BinAlgoBase : Details::conditional< Details::IsComplete<Details::AlgoVecBase<I, T, A>>::value
+ , Details::AlgoVecBase<I, T, A>
+ , Details::AlgoScalarToVec<T,A>
+ >::type
+{
+ typedef typename Details::conditional< Details::IsComplete<Details::AlgoVecBase<I, T, A>>::value
+ , Details::AlgoVecBase<I, T, A>
+ , Details::AlgoScalarToVec<T,A>
+ >::type base_t;
+
+ BinAlgoBase(const T* px, const uint32 n) : base_t(px, n) {}
+ BinAlgoBase(const typename base_t::Data& d) : base_t(d) {}
+};
+
+} // namespace Details
+
+} // namespace BinSearch
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..374b58c
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,6 @@
+[build-system]
+requires = [
+ "setuptools>=42",
+ "wheel"
+]
+build-backend = "setuptools.build_meta"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e079f8a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+pytest
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..dc1eb60
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,32 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+from setuptools import setup, find_packages
+
+
+
+def read(fname):
+ return open(os.path.join(os.path.dirname(__file__), fname)).read()
+
+
+setup(
+ name = f"bitsandbytes-cuda{os.environ['CUDA_VERSION']}",
+ version = "0.0.23",
+ author = "Tim Dettmers",
+ author_email = "tim.dettmers@gmail.com",
+ description = ("Numpy-like library for GPUs."),
+ license = "MIT",
+ keywords = "gpu",
+ url = "http://packages.python.org/bitsandbytes",
+ packages=find_packages(),
+ package_data={'': ['libbitsandbytes.so']},
+ long_description=read('README.md'),
+ long_description_content_type = 'text/markdown',
+ classifiers=[
+ "Development Status :: 1 - Planning",
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence'
+ ],
+)
+
diff --git a/tests/test_functional.py b/tests/test_functional.py
new file mode 100644
index 0000000..2a7d308
--- /dev/null
+++ b/tests/test_functional.py
@@ -0,0 +1,213 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import pytest
+import torch
+import bitsandbytes as bnb
+
+from itertools import product
+
+from bitsandbytes import functional as F
+
+def setup():
+ pass
+
+def teardown():
+ pass
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['float', 'half'])
+def test_estimate_quantiles(dtype):
+ A = torch.rand(1024, 1024, device='cuda')
+ A = A.to(dtype)
+ code = F.estimate_quantiles(A)
+
+ percs = torch.linspace(1/512, 511/512, 256, device=A.device)
+ torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)
+
+ A = torch.randn(1024, 1024, device='cuda')
+ A = A.to(dtype)
+ code = F.estimate_quantiles(A)
+
+ quantiles = torch.quantile(A.float(), percs)
+ diff = torch.abs(code-quantiles)
+ assert (diff > 5e-02).sum().item() == 0
+
+
+def test_quantile_quantization():
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device='cuda')
+ code = F.estimate_quantiles(A1)
+ C = F.quantize_no_absmax(A1, code)
+ A2 = F.dequantize_no_absmax(C, code)
+ diff = torch.abs(A1-A2).mean().item()
+ assert diff < 0.0075
+
+ A1 = torch.rand(1024, 1024, device='cuda')
+ code = F.estimate_quantiles(A1)
+ C = F.quantize_no_absmax(A1, code)
+ A2 = F.dequantize_no_absmax(C, code)
+ diff = torch.abs(A1-A2).mean().item()
+ torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
+ assert diff < 0.001
+
+
+def test_dynamic_quantization():
+ diffs = []
+ reldiffs = []
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device='cuda')
+ C, S = F.quantize(A1)
+ A2 = F.dequantize(C, S)
+ diff = torch.abs(A1-A2)
+ reldiff = diff/torch.abs(A1+1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ assert diff.mean().item() < 0.0135
+ print(sum(diffs)/len(diffs))
+ print(sum(reldiffs)/len(reldiffs))
+
+ for i in range(100):
+ A1 = torch.rand(1024, 1024, device='cuda')
+ C, S = F.quantize(A1)
+ A2 = F.dequantize(C, S)
+ diff = torch.abs(A1-A2).mean().item()
+ torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
+ assert diff < 0.004
+
+
+def test_dynamic_blockwise_quantization():
+ diffs = []
+ reldiffs = []
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device='cuda')
+ C, S = F.quantize_blockwise(A1)
+ A2 = F.dequantize_blockwise(C, S)
+ diff = torch.abs(A1-A2)
+ reldiff = diff/torch.abs(A1+1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ assert diffs[-1] < 0.011
+ print(sum(diffs)/len(diffs))
+ print(sum(reldiffs)/len(reldiffs))
+
+ diffs = []
+ for i in range(100):
+ A1 = torch.rand(1024, 1024, device='cuda')
+ C, S = F.quantize_blockwise(A1)
+ A2 = F.dequantize_blockwise(C, S)
+ diff = torch.abs(A1-A2).mean().item()
+ assert diff < 0.0033
+ diffs.append(diff)
+ torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
+ #print(sum(diffs)/len(diffs))
+
+def test_dynamic_blockwise_stochastic_quantization():
+ diffs = []
+ reldiffs = []
+ rand = torch.rand(1024).cuda()
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device='cuda')
+ C1, S1 = F.quantize_blockwise(A1, rand=rand)
+ C2, S2 = F.quantize_blockwise(A1)
+ # a maximunm distance of quantized values of 1
+ torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
+ fraction_smaller = (C1<C2).float().sum()/C1.numel()
+ fraction_larger = (C1>C2).float().sum()/C1.numel()
+ torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0)
+
+
+
+@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=['float', 'half'])
+def test_percentile_clipping(gtype):
+ gnorm_vec1 = torch.zeros(100, device='cuda')
+ gnorm_vec2 = torch.zeros(100, device='cuda')
+ n = 4
+ step = 0
+ percentile=5
+ for i in range(1000):
+ step += 1
+ g = torch.randn(n, n, dtype=gtype, device='cuda')
+ gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
+ assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2/gnorm1
+
+ gnorm2 = torch.norm(g.float())
+ if step == 1:
+ gnorm_vec1[:] = gnorm2
+ else:
+ gnorm_vec1[step % 100] = gnorm2
+
+ vals, idx = torch.sort(gnorm_vec1)
+ clip1 = vals[percentile]
+
+ torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
+ torch.testing.assert_allclose(clip1, clip2)
+ torch.testing.assert_allclose(gnorm1, gnorm2)
+
+
+def test_stable_embedding():
+ layer = bnb.nn.StableEmbedding(1024, 1024)
+ layer.reset_parameters()
+
+
+def test_dynamic_blockwise_quantization_cpu():
+ #A1 = torch.randn(1024, 1024, device='cpu')
+ #code = F.create_dynamic_map()
+ #for i in range(1000):
+ # C, S = F.quantize_blockwise(A1, code=code)
+ # A2 = F.dequantize_blockwise(C, S)
+
+ for i in range(10):
+ # equivalence with GPU blockwise quantization
+ A1 = torch.randn(1024, 1024, device='cpu')
+ C1, S1 = F.quantize_blockwise(A1)
+ C2, S2 = F.quantize_blockwise(A1.cuda())
+ torch.testing.assert_allclose(S1[0], S2[0].cpu())
+ # there seems to be some issues with precision in CUDA vs CPU
+ # not all elements are usually close, with couple off elements in a million
+ idx = torch.isclose(C1, C2.cpu())
+ assert (idx==0).sum().item() < 15
+
+
+ diffs = []
+ reldiffs = []
+ for i in range(10):
+ A1 = torch.randn(1024, 1024, device='cpu')
+ C, S = F.quantize_blockwise(A1)
+ A2 = F.dequantize_blockwise(C, S)
+ diff = torch.abs(A1-A2)
+ reldiff = diff/torch.abs(A1+1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ assert diffs[-1] < 0.011
+ #print(sum(diffs)/len(diffs))
+ #print(sum(reldiffs)/len(reldiffs))
+
+ diffs = []
+ for i in range(10):
+ A1 = torch.rand(1024, 1024, device='cpu')
+ C, S = F.quantize_blockwise(A1)
+ A2 = F.dequantize_blockwise(C, S)
+ diff = torch.abs(A1-A2).mean().item()
+ assert diff < 0.0033
+ diffs.append(diff)
+ torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
+ #print(sum(diffs)/len(diffs))
+
+
+def test_histogram():
+ dim1, dim2 = 32, 32
+ source = torch.rand(dim1, dim2, device='cuda')
+ idx1 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
+ idx2 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
+ histogram1 = torch.zeros((256, 256)).cuda()
+ histogram2 = torch.zeros((256, 256)).cuda()
+
+ F.histogram_scatter_add_2d(histogram2, idx1, idx2, source)
+
+ for i in range(dim1):
+ for j in range(dim2):
+ histogram1[idx1[i, j].item(), idx2[i, j].item()] += source[i, j]
+
+ torch.testing.assert_allclose(histogram1, histogram2)
+ torch.testing.assert_allclose(histogram1.sum(), source.sum())
diff --git a/tests/test_optim.py b/tests/test_optim.py
new file mode 100644
index 0000000..4d67b08
--- /dev/null
+++ b/tests/test_optim.py
@@ -0,0 +1,362 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import time
+import shutil
+import uuid
+import pytest
+import ctypes
+import torch
+import bitsandbytes as bnb
+import bitsandbytes.functional as F
+
+from os.path import join
+from itertools import product
+
+import apex
+
+def get_temp_dir():
+ path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
+ os.makedirs(path, exist_ok=True)
+ return path
+
+def rm_path(path):
+ shutil.rmtree(path)
+
+str2optimizers = {}
+str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
+str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
+str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
+str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
+
+str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
+str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
+str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
+str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
+str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
+str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
+str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
+str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
+str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
+str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
+
+str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
+str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
+str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
+
+str2statenames = {}
+str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
+str2statenames['momentum'] = [('momentum_buffer', 'state1')]
+str2statenames['lars'] = [('momentum_buffer', 'state1')]
+str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
+str2statenames['rmsprop'] = [('square_avg', 'state1')]
+str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
+str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
+str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
+str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
+str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
+str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
+str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
+str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
+
+dim1 = [1024]
+dim2 = [32, 1024, 4097, 1]
+gtype = [torch.float32, torch.float16]
+optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
+values = list(product(dim1,dim2, gtype, optimizer_names))
+names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
+def test_optimizer32bit(dim1, dim2, gtype, optim_name):
+ if dim1 == 1 and dim2 == 1: return
+ p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
+ p2 = p1.clone()
+ p1 = p1.float()
+
+
+ torch_optimizer = str2optimizers[optim_name][0]([p1])
+ bnb_optimizer = str2optimizers[optim_name][1]([p2])
+
+ if gtype == torch.float32:
+ atol, rtol = 1e-6, 1e-5
+ else:
+ atol, rtol = 1e-4, 1e-3
+
+
+ for i in range(50):
+ g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
+ p1.grad = g.clone().float()
+ p2.grad = g.clone()
+
+ bnb_optimizer.step()
+ torch_optimizer.step()
+
+ for name1, name2 in str2statenames[optim_name]:
+ torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
+
+ torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
+
+ if i % 10 == 0 and i > 0:
+ path = get_temp_dir()
+ torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
+ del bnb_optimizer
+ bnb_optimizer = None
+ bnb_optimizer = str2optimizers[optim_name][1]([p2])
+ bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
+ rm_path(path)
+ torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
+ for name1, name2 in str2statenames[optim_name]:
+ torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
+
+ if gtype == torch.float16:
+ # the adam buffers should also be close because they are 32-bit
+ # but the paramters can diverge because they are 16-bit
+ # the difference grow larger and larger with each update
+ # --> copy the state to keep weights close
+ p1.data = p1.data.half().float()
+ p2.copy_(p1.data)
+ torch.testing.assert_allclose(p1.half(), p2)
+ if optim_name in ['lars', 'lamb']:
+ assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0
+
+dim1 = [1024]
+dim2 = [32, 1024, 4097]
+gtype = [torch.float32, torch.float16]
+values = list(product(dim1,dim2, gtype))
+names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
+def test_global_config(dim1, dim2, gtype):
+ if dim1 == 1 and dim2 == 1: return
+ p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
+ p2 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
+ p3 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
+ mask = torch.rand_like(p2) < 0.1
+ beta1 = 0.9
+ beta2 = 0.999
+ lr = 0.001
+ eps = 1e-8
+
+ bnb.optim.GlobalOptimManager.get_instance().initialize()
+ bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
+
+ bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
+ p1 = p1.cuda()
+ p2 = p2.cuda()
+ p3 = p3.cuda()
+
+ adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)
+
+ if gtype == torch.float32:
+ atol, rtol = 1e-6, 1e-5
+ else:
+ atol, rtol = 1e-4, 1e-3
+
+ for i in range(50):
+ g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
+ g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
+ g3 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
+ p1.grad = g1
+ p2.grad = g2
+ p3.grad = g3
+
+ adam2.step()
+
+ assert adam2.state[p3]['state1'].dtype == torch.uint8
+ assert adam2.state[p3]['state2'].dtype == torch.uint8
+
+
+
+dim1 = [1024]
+dim2 = [32, 1024, 4097]
+gtype = [torch.float32, torch.float16]
+optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
+values = list(product(dim1,dim2, gtype, optimizer_names))
+names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
+def test_optimizer8bit(dim1, dim2, gtype, optim_name):
+ if dim1 == 1 and dim2 == 1: return
+ p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
+ p2 = p1.clone()
+ p1 = p1.float()
+ blocksize = 2048
+
+ torch_optimizer = str2optimizers[optim_name][0]([p1])
+ bnb_optimizer = str2optimizers[optim_name][1]([p2])
+
+ if gtype == torch.float32:
+ atol, rtol = 3e-3, 1e-3
+ patol, prtol = 1e-5, 1e-3
+
+ else:
+ atol, rtol = 3e-3, 1e-3
+ patol, prtol = 1e-5, 1e-3
+
+ errors = []
+ relerrors = []
+
+ for i in range(50):
+ g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
+ p1.grad = g.clone().float()
+ p2.grad = g.clone()
+
+ bnb_optimizer.step()
+ torch_optimizer.step()
+
+ torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
+
+ dequant_states = []
+ for name1, name2, qmap, max_val in str2statenames[optim_name]:
+ #print(bnb_optimizer.state[p2][max_val], name1)
+ if 'blockwise' in optim_name:
+ s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
+ else:
+ s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
+ num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
+ assert num_not_close.sum().item() < 20
+ dequant_states.append(s1.clone())
+
+ err = torch.abs(p1-p2)
+ relerr = err/torch.abs(p1)
+ assert err.mean() < 0.0001
+ assert relerr.mean() < 0.001
+
+ errors.append(err.mean().item())
+ relerrors.append(relerr.mean().item())
+
+ if i % 10 == 0 and i > 0:
+ for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
+ s1cpy = s.clone()
+ raws1cpy = bnb_optimizer.state[p2][name2].clone()
+ qmap1 = bnb_optimizer.state[p2][qmap].clone()
+
+ path = get_temp_dir()
+ torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
+ del bnb_optimizer
+ bnb_optimizer = None
+ bnb_optimizer = str2optimizers[optim_name][1]([p2])
+ bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
+ rm_path(path)
+ torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
+ torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
+
+ if 'blockwise' in optim_name:
+ s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
+ else:
+ s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
+ torch.testing.assert_allclose(s1cpy, s1)
+
+ num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
+ assert num_not_close.sum().item() < 20
+ torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
+
+ # the parameters diverge quickly. Here we keep them close
+ # together so we can test against the Adam error
+ p1.data = p1.data.to(gtype).float()
+ p2.copy_(p1.data)
+ torch.testing.assert_allclose(p1.to(gtype), p2)
+ for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
+ torch_optimizer.state[p1][name1].copy_(s.data)
+
+ #print(sum(errors)/len(errors))
+ #print(sum(relerrors)/len(relerrors))
+
+
+
+dim1 = [1024]
+dim2 = [32, 1024, 4097]
+gtype = [torch.float32]
+optim_bits = [32, 8]
+values = list(product(dim1,dim2, gtype, optim_bits))
+names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
+def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
+ if dim1 == 1 and dim2 == 1: return
+ p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
+ beta1 = 0.9
+ beta2 = 0.999
+ lr = 0.001
+ eps = 1e-8
+ p1 = p1.cuda()
+ p2 = p1.clone()
+ adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
+ adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
+
+ gnorm_vec = torch.zeros(100).cuda()
+ step = 0
+
+ for i in range(50):
+ step += 1
+ g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + (0.01*i)
+ g2 = g1.clone()
+ p2.grad = g2
+
+ current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
+ g1 = (g1.float()*gnorm_scale).to(gtype)
+ p1.grad = g1
+
+ adam1.step()
+ adam2.step()
+
+ # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
+ if optim_bits == 32:
+ torch.testing.assert_allclose(p1, p2)
+ torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=5e-5, rtol=1e-4)
+ torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=5e-5, rtol=1e-4)
+ elif optim_bits == 8:
+ torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
+ torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=2, rtol=1e-3)
+ torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=2, rtol=1e-3)
+ adam1.state[p1]['state1'].copy_(adam2.state[p2]['state1'])
+ adam1.state[p1]['state2'].copy_(adam2.state[p2]['state2'])
+ if i % 10 == 0 and i > 0:
+ path = get_temp_dir()
+ torch.save(adam2.state_dict(),join(path, 'opt.pt'))
+ del adam2
+ adam2 = None
+ adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
+ adam2.load_state_dict(torch.load(join(path, 'opt.pt')))
+
+
+
+
+dim1 = [4096]
+dim2 = [4096]
+gtype = [torch.float32, torch.float16]
+#optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
+#optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
+#optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
+#optimizer_names = ['lamb_apex', 'lamb8bit']
+#optimizer_names = ['lars_apex', 'lars8bit']
+optimizer_names = ['adam8bit_blockwise']
+values = list(product(dim1,dim2, gtype, optimizer_names))
+names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
+def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
+ if dim1 == 1 and dim2 == 1: return
+ p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
+
+
+ bnb_optimizer = str2optimizers[optim_name][1]([p1])
+
+ g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
+ p1.grad = g
+ for i in range(5000):
+ if i == 500:
+ # 100 iterations for burn-in
+ torch.cuda.synchronize()
+ t0 = time.time()
+
+ bnb_optimizer.step()
+
+ torch.cuda.synchronize()
+ s = time.time()-t0
+ print('')
+ params = 4500*4096*4096
+ print(optim_name, gtype, s/params)
+ #assert s < 3.9
+
+