summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/__init__.py20
-rw-r--r--bitsandbytes/autograd/_functions.py134
-rw-r--r--bitsandbytes/cextension.py23
-rw-r--r--bitsandbytes/cuda_setup.py78
-rw-r--r--bitsandbytes/debug_cli.py1
-rw-r--r--bitsandbytes/functional.py1395
-rw-r--r--bitsandbytes/nn/__init__.py8
-rw-r--r--bitsandbytes/nn/modules.py197
-rw-r--r--bitsandbytes/optim/__init__.py6
-rw-r--r--bitsandbytes/optim/adagrad.py114
-rw-r--r--bitsandbytes/optim/adam.py179
-rw-r--r--bitsandbytes/optim/adamw.py104
-rw-r--r--bitsandbytes/optim/lamb.py117
-rw-r--r--bitsandbytes/optim/lars.py167
-rw-r--r--bitsandbytes/optim/optimizer.py565
-rw-r--r--bitsandbytes/optim/rmsprop.py115
-rw-r--r--bitsandbytes/optim/sgd.py109
-rw-r--r--bitsandbytes/utils.py4
18 files changed, 2389 insertions, 947 deletions
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 3c3affa..7ca017d 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -1,16 +1,18 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .nn import modules
-from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
+from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,
+ matmul_cublas, mm_cublas)
from .cextension import COMPILED_WITH_CUDA
+from .nn import modules
if COMPILED_WITH_CUDA:
from .optim import adam
-__pdoc__ = {'libbitsandbytes': False,
- 'optim.optimizer.Optimizer8bit': False,
- 'optim.optimizer.MockArgs': False
- }
+__pdoc__ = {
+ "libbitsandbytes": False,
+ "optim.optimizer.Optimizer8bit": False,
+ "optim.optimizer.MockArgs": False,
+}
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index e641583..a08b560 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,21 +1,24 @@
+from dataclasses import dataclass
+
import torch
+
import bitsandbytes as bnb
import bitsandbytes.functional as F
-from dataclasses import dataclass
-
tensor = torch.Tensor
-'''
+"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
-'''
+"""
+
+
class GlobalOutlierPooler(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.outliers = set()
@@ -29,25 +32,29 @@ class GlobalOutlierPooler(object):
return cls._instance
def add_outliers(self, outlier_idx, feature_dim):
- if self.model_dim is None: self.model_dim = feature_dim
- if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer
+ if self.model_dim is None:
+ self.model_dim = feature_dim
+ if feature_dim != self.model_dim:
+ return # we do not encode outliers for the 2nd FFN layer
self.outliers.update(outlier_idx.tolist())
def get_current_outlier_idx(self):
return torch.Tensor(list(self.outliers)).to(torch.int64)
-class MatMul8bit(torch.autograd.Function):
+class MatMul8bit(torch.autograd.Function):
@staticmethod
- def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
+ def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):
if precision[0] != 8:
with torch.no_grad():
output = torch.matmul(A, B)
else:
- if len(B.shape) == 2: dim = 0
- else: dim = 1
+ if len(B.shape) == 2:
+ dim = 0
+ else:
+ dim = 1
qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
iout = F.igemm(qA, qB)
@@ -84,21 +91,41 @@ class MatMul8bit(torch.autograd.Function):
else:
if len(B.shape) == 2 and len(A.shape) == 3:
grad_output = grad_output.contiguous()
- if not grad_output.is_contiguous(): grad_output.contiguous()
- qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
- if not A.is_contiguous(): A = A.contiguous()
- qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
+ if not grad_output.is_contiguous():
+ grad_output.contiguous()
+ qgrad_output, S1 = F.vectorwise_quant(
+ grad_output.view(-1, grad_output.shape[2]),
+ dim=0,
+ quant_type=quant_type,
+ )
+ if not A.is_contiguous():
+ A = A.contiguous()
+ qA, S2 = F.vectorwise_quant(
+ A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
+ )
igrad_B = F.igemm(qA.t(), qgrad_output)
- grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
+ grad_B = F.vectorwise_mm_dequant(
+ igrad_B, S2.t(), S1, grad_output.dtype, quant_type
+ )
else:
- qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+ qgrad_output, S1 = F.vectorwise_quant(
+ grad_output, dim=dims, quant_type=quant_type
+ )
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
- grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
+ grad_B = F.vectorwise_mm_dequant(
+ igrad_B,
+ S2.permute(permute_dim),
+ S1,
+ grad_output.dtype,
+ quant_type,
+ )
if A.requires_grad:
- if len(grad_output.shape) == 3: dims = [2]
- else: dims = [1]
+ if len(grad_output.shape) == 3:
+ dims = [2]
+ else:
+ dims = [1]
if len(B.shape) == 3:
# bio -> boi
@@ -113,10 +140,14 @@ class MatMul8bit(torch.autograd.Function):
with torch.no_grad():
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
else:
- qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+ qgrad_output, S1 = F.vectorwise_quant(
+ grad_output, dim=dims, quant_type=quant_type
+ )
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
- grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
+ grad_A = F.vectorwise_mm_dequant(
+ igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
+ )
return grad_A, grad_B, None, None, None
@@ -125,6 +156,7 @@ mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
+
@dataclass
class MatmulLtState:
CB = None
@@ -159,7 +191,6 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function):
-
@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
# 1. Quantize A
@@ -171,11 +202,15 @@ class MatMul8bitLt(torch.autograd.Function):
requires_gradB = B.requires_grad
formatB = state.formatB
input_shape = A.shape
- if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
- assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
+ if state.outlier_pool is None:
+ state.outlier_pool = GlobalOutlierPooler.get_instance()
+ assert (
+ A.dtype == torch.float16
+ ), f"The input data type needs to be fp16 but {A.dtype} was found!"
# 1. Quantize A
- if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
+ if len(A.shape) == 3:
+ A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None:
@@ -191,8 +226,8 @@ class MatMul8bitLt(torch.autograd.Function):
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
- #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
- #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
+ # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
+ # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
# outlier_idx = torch.unique(coo_tensorA.colidx).long()
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
@@ -203,24 +238,24 @@ class MatMul8bitLt(torch.autograd.Function):
# state.idx = outlier_idx
# state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
- #if state.idx is not None:
+ # if state.idx is not None:
# # extract outliers
# CA[:, state.idx] = 0
# CAt[:, state.idx] = 0
# subA = A[:, state.idx]
- #else:
+ # else:
# subA = None
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
-
# 2. Quantize B
if state.has_fp16_weights:
- has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
+ has_grad = True if (getattr(B, "grad", None) is not None) else False
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
- if is_transposed: B = B.contiguous()
+ if is_transposed:
+ B = B.contiguous()
if (state.is_training and not has_grad) or state.CxB is None:
state.reset_grads()
@@ -234,14 +269,16 @@ class MatMul8bitLt(torch.autograd.Function):
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
- #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
- #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+ # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+ # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
- #else:
+ # else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
- state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half()
+ state.subB = (
+ (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+ )
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
@@ -254,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
output_shape = (input_shape[0], shapeB[0])
# 3. Matmul
- C32A, SA = F.transform(CA, 'col32')
+ C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
@@ -277,7 +314,7 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
- #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+ # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
clone_func = torch.clone
return clone_func(output.view(output_shape))
@@ -288,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
- assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
+ assert state.has_fp16_weights, "Backprop only supported for fp16 weights."
if len(grad_output.shape) == 3:
grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
@@ -298,18 +335,22 @@ class MatMul8bitLt(torch.autograd.Function):
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
- C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
+ C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA:
- C32grad, Sgrad = F.transform(Cgrad, 'col32')
+ C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
- state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
+ state.CxBt, state.SBt = F.transform(
+ state.CBt, to_order=formatB, transpose=True
+ )
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
- grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+ grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
+ ctx.grad_shape
+ )
return grad_A, grad_B, None, None, None, None, None
@@ -317,9 +358,10 @@ class MatMul8bitLt(torch.autograd.Function):
matmul = MatMul8bitLt.apply
-def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0):
+def matmul(
+ A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
+):
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
return MatMul8bitLt.apply(A, B, out, state)
-
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 4bc7bf7..bc11474 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -1,6 +1,7 @@
import ctypes as ct
import os
from warnings import warn
+
from bitsandbytes.cuda_setup import evaluate_cuda_setup
@@ -8,17 +9,21 @@ class CUDALibrary_Singleton(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.context = {}
binary_name = evaluate_cuda_setup()
- if not os.path.exists(os.path.dirname(__file__) + f'/{binary_name}'):
- print(f'TODO: compile library for specific version: {binary_name}')
- print('defaulting to libbitsandbytes.so')
- self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
+ if not os.path.exists(os.path.dirname(__file__) + f"/{binary_name}"):
+ print(f"TODO: compile library for specific version: {binary_name}")
+ print("defaulting to libbitsandbytes.so")
+ self.lib = ct.cdll.LoadLibrary(
+ os.path.dirname(__file__) + "/libbitsandbytes.so"
+ )
else:
- self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + f'/{binary_name}')
+ self.lib = ct.cdll.LoadLibrary(
+ os.path.dirname(__file__) + f"/{binary_name}"
+ )
@classmethod
def get_instance(cls):
@@ -35,6 +40,8 @@ try:
lib.get_cusparse.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError:
- warn("The installed version of bitsandbytes was compiled without GPU support. "
- "8-bit optimizers and GPU quantization are unavailable.")
+ warn(
+ "The installed version of bitsandbytes was compiled without GPU support. "
+ "8-bit optimizers and GPU quantization are unavailable."
+ )
COMPILED_WITH_CUDA = False
diff --git a/bitsandbytes/cuda_setup.py b/bitsandbytes/cuda_setup.py
index 5ed0c89..0dd53c5 100644
--- a/bitsandbytes/cuda_setup.py
+++ b/bitsandbytes/cuda_setup.py
@@ -18,31 +18,36 @@ evaluation:
- based on that set the default path
"""
+import ctypes
+import shlex
+import subprocess
from os import environ as env
from pathlib import Path
from typing import Set, Union
-from .utils import warn_of_missing_prerequisite, print_err
-import ctypes
-import shlex
-import subprocess
+from .utils import print_err, warn_of_missing_prerequisite
+
def execute_and_return(strCMD):
- proc = subprocess.Popen(shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ proc = subprocess.Popen(
+ shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
out, err = proc.communicate()
out, err = out.decode("UTF-8").strip(), err.decode("UTF-8").strip()
return out, err
+
def check_cuda_result(cuda, result_val):
if result_val != 0:
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
print(f"Count not initialize CUDA - failure!")
- raise Exception('CUDA exception!')
+ raise Exception("CUDA exception!")
return result_val
+
# taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
def get_compute_capability():
- libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
+ libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll")
for libname in libnames:
try:
cuda = ctypes.CDLL(libname)
@@ -51,8 +56,7 @@ def get_compute_capability():
else:
break
else:
- raise OSError("could not load any of: " + ' '.join(libnames))
-
+ raise OSError("could not load any of: " + " ".join(libnames))
nGpus = ctypes.c_int()
cc_major = ctypes.c_int()
@@ -69,39 +73,43 @@ def get_compute_capability():
ccs = []
for i in range(nGpus.value):
result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
- result = check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device))
- ccs.append(f'{cc_major.value}.{cc_minor.value}')
+ result = check_cuda_result(
+ cuda,
+ cuda.cuDeviceComputeCapability(
+ ctypes.byref(cc_major), ctypes.byref(cc_minor), device
+ ),
+ )
+ ccs.append(f"{cc_major.value}.{cc_minor.value}")
- #TODO: handle different compute capabilities; for now, take the max
+ # TODO: handle different compute capabilities; for now, take the max
ccs.sort()
- return ccs[-1]
+ # return ccs[-1]
+ return ccs
+
CUDA_RUNTIME_LIB: str = "libcudart.so"
+
def tokenize_paths(paths: str) -> Set[Path]:
- return {
- Path(ld_path) for ld_path in paths.split(':')
- if ld_path
- }
+ return {Path(ld_path) for ld_path in paths.split(":") if ld_path}
+
def get_cuda_runtime_lib_path(
# TODO: replace this with logic for all paths in env vars
LD_LIBRARY_PATH: Union[str, None] = env.get("LD_LIBRARY_PATH")
) -> Union[Path, None]:
- """ # TODO: add doc-string
- """
+ """# TODO: add doc-string"""
if not LD_LIBRARY_PATH:
warn_of_missing_prerequisite(
- 'LD_LIBRARY_PATH is completely missing from environment!'
+ "LD_LIBRARY_PATH is completely missing from environment!"
)
return None
ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH)
- non_existent_directories: Set[Path] = {
- path for path in ld_library_paths
- if not path.exists()
+ non_existent_directories: Set[Path] = {
+ path for path in ld_library_paths if not path.exists()
}
if non_existent_directories:
@@ -111,7 +119,8 @@ def get_cuda_runtime_lib_path(
)
cuda_runtime_libs: Set[Path] = {
- path / CUDA_RUNTIME_LIB for path in ld_library_paths
+ path / CUDA_RUNTIME_LIB
+ for path in ld_library_paths
if (path / CUDA_RUNTIME_LIB).is_file()
} - non_existent_directories
@@ -126,26 +135,31 @@ def get_cuda_runtime_lib_path(
single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
return single_cuda_runtime_lib_dir
+
def evaluate_cuda_setup():
cuda_path = get_cuda_runtime_lib_path()
cc = get_compute_capability()
- binary_name = 'libbitsandbytes_cpu.so'
+ binary_name = "libbitsandbytes_cpu.so"
if not (has_gpu := bool(cc)):
- print('WARNING: No GPU detected! Check our CUDA paths. Processing to load CPU-only library...')
+ print(
+ "WARNING: No GPU detected! Check our CUDA paths. Processing to load CPU-only library..."
+ )
return binary_name
- has_cublaslt = cc in ['7.5', '8.0', '8.6']
+ has_cublaslt = cc in ["7.5", "8.0", "8.6"]
- # TODO:
+ # TODO:
# (1) Model missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
# (2) Multiple CUDA versions installed
cuda_home = str(Path(cuda_path).parent.parent)
- ls_output, err = execute_and_return(f'{cuda_home}/bin/nvcc --version')
- cuda_version = ls_output.split('\n')[3].split(',')[-1].strip().lower().replace('v', '')
- major, minor, revision = cuda_version.split('.')
- cuda_version_string = f'{major}{minor}'
+ ls_output, err = execute_and_return(f"{cuda_home}/bin/nvcc --version")
+ cuda_version = (
+ ls_output.split("\n")[3].split(",")[-1].strip().lower().replace("v", "")
+ )
+ major, minor, revision = cuda_version.split(".")
+ cuda_version_string = f"{major}{minor}"
binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so'
diff --git a/bitsandbytes/debug_cli.py b/bitsandbytes/debug_cli.py
index 88307a6..4306bc0 100644
--- a/bitsandbytes/debug_cli.py
+++ b/bitsandbytes/debug_cli.py
@@ -1,6 +1,5 @@
import typer
-
cli = typer.Typer()
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index ac85f88..2e86958 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1,6 +1,6 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import random
@@ -9,47 +9,68 @@ from typing import Tuple
import torch
from torch import Tensor
-from .cextension import lib, COMPILED_WITH_CUDA
+from .cextension import COMPILED_WITH_CUDA, lib
name2qmap = {}
if COMPILED_WITH_CUDA:
- ''' C FUNCTIONS FOR OPTIMIZERS '''
+ """C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {}
- str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
- str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
- str2optimizer32bit['rmsprop'] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
- str2optimizer32bit['adagrad'] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
- str2optimizer32bit['lars'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
- str2optimizer32bit['lamb'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
+ str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
+ str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
+ str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
+ str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
+ str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
+ str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer8bit = {}
- str2optimizer8bit['adam'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
- str2optimizer8bit['momentum'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)
- str2optimizer8bit['rmsprop'] = (lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16)
- str2optimizer8bit['lamb'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
- str2optimizer8bit['lars'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)
+ str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+ str2optimizer8bit["momentum"] = (
+ lib.cmomentum_static_8bit_g32,
+ lib.cmomentum_static_8bit_g16,
+ )
+ str2optimizer8bit["rmsprop"] = (
+ lib.crmsprop_static_8bit_g32,
+ lib.crmsprop_static_8bit_g16,
+ )
+ str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+ str2optimizer8bit["lars"] = (
+ lib.cmomentum_static_8bit_g32,
+ lib.cmomentum_static_8bit_g16,
+ )
str2optimizer8bit_blockwise = {}
- str2optimizer8bit_blockwise['adam'] = (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16)
- str2optimizer8bit_blockwise['momentum'] = (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16)
- str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16)
- str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16)
+ str2optimizer8bit_blockwise["adam"] = (
+ lib.cadam_8bit_blockwise_fp32,
+ lib.cadam_8bit_blockwise_fp16,
+ )
+ str2optimizer8bit_blockwise["momentum"] = (
+ lib.cmomentum_8bit_blockwise_fp32,
+ lib.cmomentum_8bit_blockwise_fp16,
+ )
+ str2optimizer8bit_blockwise["rmsprop"] = (
+ lib.crmsprop_8bit_blockwise_fp32,
+ lib.crmsprop_8bit_blockwise_fp16,
+ )
+ str2optimizer8bit_blockwise["adagrad"] = (
+ lib.cadagrad_8bit_blockwise_fp32,
+ lib.cadagrad_8bit_blockwise_fp16,
+ )
class CUBLAS_Context(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.context = {}
- #prev_device = torch.cuda.current_device()
- #for i in range(torch.cuda.device_count()):
+ # prev_device = torch.cuda.current_device()
+ # for i in range(torch.cuda.device_count()):
# torch.cuda.set_device(torch.device('cuda', i))
# self.context.append(ct.c_void_p(lib.get_context()))
- #torch.cuda.set_device(prev_device)
+ # torch.cuda.set_device(prev_device)
@classmethod
def get_instance(cls):
@@ -66,11 +87,12 @@ class CUBLAS_Context(object):
torch.cuda.set_device(prev_device)
return self.context[device.index]
+
class Cusparse_Context(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.context = ct.c_void_p(lib.get_cusparse())
@@ -82,14 +104,16 @@ class Cusparse_Context(object):
cls._instance.initialize()
return cls._instance
+
def create_linear_map(signed=True):
if signed:
return torch.linspace(-1.0, 1.0, 256)
else:
return torch.linspace(0.0, 1.0, 256)
+
def create_dynamic_map(signed=True, n=7):
- '''
+ """
Creates the dynamic quantiztion map.
The dynamic data type is made up of a dynamic exponent and
@@ -103,46 +127,54 @@ def create_dynamic_map(signed=True, n=7):
For more details see
(8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]
- '''
+ """
data = []
# these are additional items that come from the case
# where all the exponent bits are zero and no
# indicator bit is present
- additional_items = 2**(7-n)-1
- if not signed: additional_items = 2*additional_items
+ additional_items = 2 ** (7 - n) - 1
+ if not signed:
+ additional_items = 2 * additional_items
for i in range(n):
- fraction_items = 2**(i+7-n)+1 if signed else 2**(i+7-n+1)+1
+ fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
boundaries = torch.linspace(0.1, 1, fraction_items)
- means = (boundaries[:-1]+boundaries[1:])/2.0
- data += ((10**(-(n-1)+i))*means).tolist()
+ means = (boundaries[:-1] + boundaries[1:]) / 2.0
+ data += ((10 ** (-(n - 1) + i)) * means).tolist()
if signed:
- data += (-(10**(-(n-1)+i))*means).tolist()
+ data += (-(10 ** (-(n - 1) + i)) * means).tolist()
if additional_items > 0:
- boundaries = torch.linspace(0.1, 1, additional_items+1)
- means = (boundaries[:-1]+boundaries[1:])/2.0
- data += ((10**(-(n-1)+i))*means).tolist()
+ boundaries = torch.linspace(0.1, 1, additional_items + 1)
+ means = (boundaries[:-1] + boundaries[1:]) / 2.0
+ data += ((10 ** (-(n - 1) + i)) * means).tolist()
if signed:
- data += (-(10**(-(n-1)+i))*means).tolist()
+ data += (-(10 ** (-(n - 1) + i)) * means).tolist()
data.append(0)
data.append(1.0)
data.sort()
return Tensor(data)
+
def get_special_format_str():
major, minor = torch.cuda.get_device_capability()
if major < 7:
- print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!')
+ print(
+ f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
+ )
assert major >= 7
- if major == 7: return 'col_turing'
- elif major == 8: return 'col_ampere'
- else: return 'col_turing'
+ if major == 7:
+ return "col_turing"
+ elif major == 8:
+ return "col_ampere"
+ else:
+ return "col_turing"
+
def get_ptr(A: Tensor) -> ct.c_void_p:
- '''
+ """
Get the ctypes pointer from a PyTorch Tensor.
Parameters
@@ -153,31 +185,39 @@ def get_ptr(A: Tensor) -> ct.c_void_p:
Returns
-------
ctypes.c_void_p
- '''
- if A is None: return None
- else: return ct.c_void_p(A.data.storage().data_ptr())
+ """
+ if A is None:
+ return None
+ else:
+ return ct.c_void_p(A.data.storage().data_ptr())
+
def pre_call(device):
prev_device = torch.cuda.current_device()
torch.cuda.set_device(device)
return prev_device
+
def post_call(prev_device):
torch.cuda.set_device(prev_device)
+
def get_transform_func(dtype, orderA, orderOut, transpose=False):
name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
if not hasattr(lib, name):
print(name)
- raise ValueError(f'Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}')
+ raise ValueError(
+ f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}"
+ )
else:
return getattr(lib, name)
+
class GlobalData(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.data = {}
@@ -190,15 +230,17 @@ class GlobalData(object):
return cls._instance
-def get_transform_buffer(shape, dtype, device, to_order, from_order='row', transpose=False):
- #init_func = torch.empty
+def get_transform_buffer(
+ shape, dtype, device, to_order, from_order="row", transpose=False
+):
+ # init_func = torch.empty
init_func = torch.zeros
dims = len(shape)
if dims == 2:
rows = shape[0]
elif dims == 3:
- rows = shape[0]*shape[1]
+ rows = shape[0] * shape[1]
cols = shape[-1]
state = (shape, to_order)
@@ -209,30 +251,39 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order='row', trans
cols = tmp
state = (shape[::-1], to_order)
- if to_order == 'row' or to_order == 'col':
+ if to_order == "row" or to_order == "col":
return init_func(shape, dtype=dtype, device=device), state
- elif to_order == 'col32':
+ elif to_order == "col32":
# blocks of 32 columns (padded)
- cols = 32*((cols+31)//32)
+ cols = 32 * ((cols + 31) // 32)
return init_func((rows, cols), dtype=dtype, device=device), state
- elif to_order == 'col_turing':
+ elif to_order == "col_turing":
# blocks of 32 columns and 8 rows
- cols = 32*((cols+31)//32)
- rows = 8*((rows+7)//8)
+ cols = 32 * ((cols + 31) // 32)
+ rows = 8 * ((rows + 7) // 8)
return init_func((rows, cols), dtype=dtype, device=device), state
- elif to_order == 'col_ampere':
+ elif to_order == "col_ampere":
# blocks of 32 columns and 32 rows
- cols = 32*((cols+31)//32)
- rows = 32*((rows+31)//32)
+ cols = 32 * ((cols + 31) // 32)
+ rows = 32 * ((rows + 31) // 32)
return init_func((rows, cols), dtype=dtype, device=device), state
else:
- raise NotImplementedError(f'To_order not supported: {to_order}')
+ raise NotImplementedError(f"To_order not supported: {to_order}")
+
-def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
- if state is None: state = (A.shape, from_order)
- else: from_order = state[1]
- if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
- else: new_state = (state[1], to_order)
+def nvidia_transform(
+ A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+):
+ if state is None:
+ state = (A.shape, from_order)
+ else:
+ from_order = state[1]
+ if out is None:
+ out, new_state = get_transform_buffer(
+ state[0], A.dtype, A.device, to_order, state[1]
+ )
+ else:
+ new_state = (state[1], to_order)
func = get_transform_func(A.dtype, from_order, to_order, transpose)
shape = state[0]
@@ -242,10 +293,10 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s
elif ld is not None:
n = math.prod(shape)
dim1 = math.prod([shape[i] for i in ld])
- dim2 = ct.c_int32(n//dim1)
+ dim2 = ct.c_int32(n // dim1)
dim1 = ct.c_int32(dim1)
else:
- dim1 = ct.c_int32(shape[0]*shape[1])
+ dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])
ptr = CUBLAS_Context.get_instance().get_context(A.device)
@@ -253,11 +304,13 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s
ptrOut = get_ptr(out)
func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)
-
return out, new_state
-def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tensor:
- '''
+
+def estimate_quantiles(
+ A: Tensor, out: Tensor = None, offset: float = 1 / 512
+) -> Tensor:
+ """
Estimates 256 equidistant quantiles on the input tensor eCDF.
Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
@@ -282,18 +335,26 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
- '''
- if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+ """
+ if out is None:
+ out = torch.zeros((256,), dtype=torch.float32, device=A.device)
if A.dtype == torch.float32:
- lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
+ lib.cestimate_quantiles_fp32(
+ get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
+ )
elif A.dtype == torch.float16:
- lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
+ lib.cestimate_quantiles_fp16(
+ get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
+ )
else:
- raise NotImplementedError(f'Not supported data type {A.dtype}')
+ raise NotImplementedError(f"Not supported data type {A.dtype}")
return out
-def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=None, out: Tensor=None) -> Tensor:
- '''
+
+def quantize_blockwise(
+ A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None
+) -> Tensor:
+ """
Quantize tensor A in blocks of size 4096 values.
Quantizes tensor A by dividing it into blocks of 4096 values.
@@ -319,51 +380,96 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N
The 8-bit tensor.
tuple(torch.Tensor, torch.Tensor):
The quantization state to undo the quantization.
- '''
+ """
if code is None:
- if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
- code = name2qmap['dynamic']
+ if "dynamic" not in name2qmap:
+ name2qmap["dynamic"] = create_dynamic_map().to(A.device)
+ code = name2qmap["dynamic"]
code = code.to(A.device)
if absmax is None:
n = A.numel()
num_blocks = 4096
- blocks = n//num_blocks
+ blocks = n // num_blocks
blocks += 1 if n % num_blocks > 0 else 0
absmax = torch.zeros((blocks,), device=A.device)
- if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
+ if out is None:
+ out = torch.zeros_like(A, dtype=torch.uint8)
-
- if A.device.type != 'cpu':
+ if A.device.type != "cpu":
if rand is not None:
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
if A.dtype == torch.float32:
- lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
+ lib.cquantize_blockwise_stochastic_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ get_ptr(rand),
+ ct.c_int32(rand_offset),
+ ct.c_int(A.numel()),
+ )
elif A.dtype == torch.float16:
- lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
+ lib.cquantize_blockwise_stochastic_fp16(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ get_ptr(rand),
+ ct.c_int32(rand_offset),
+ ct.c_int(A.numel()),
+ )
else:
- raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
+ raise ValueError(
+ f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
+ )
else:
if A.dtype == torch.float32:
- lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
+ lib.cquantize_blockwise_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int(A.numel()),
+ )
elif A.dtype == torch.float16:
- lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
+ lib.cquantize_blockwise_fp16(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int(A.numel()),
+ )
else:
- raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
+ raise ValueError(
+ f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
+ )
else:
# cpu
assert rand is None
- lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
+ lib.cquantize_blockwise_cpu_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int(A.numel()),
+ )
return out, (absmax, code)
-def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
- absmax: Tensor=None, code: Tensor=None, out: Tensor=None,
- blocksize: int=4096) -> Tensor:
- '''
+
+def dequantize_blockwise(
+ A: Tensor,
+ quant_state: Tuple[Tensor, Tensor] = None,
+ absmax: Tensor = None,
+ code: Tensor = None,
+ out: Tensor = None,
+ blocksize: int = 4096,
+) -> Tensor:
+ """
Dequantizes blockwise quantized values.
Dequantizes the tensor A with maximum absolute values absmax in
@@ -374,7 +480,7 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
A : torch.Tensor
The input 8-bit tensor.
quant_state : tuple(torch.Tensor, torch.Tensor)
- Tuple of code and absmax values.
+ Tuple of code and absmax values.
absmax : torch.Tensor
The absmax values.
code : torch.Tensor
@@ -387,57 +493,94 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
-------
torch.Tensor:
Dequantized tensor (default: float32)
- '''
+ """
assert quant_state is not None or absmax is not None
if code is None and quant_state is None:
- if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
- code = name2qmap['dynamic']
+ if "dynamic" not in name2qmap:
+ name2qmap["dynamic"] = create_dynamic_map().to(A.device)
+ code = name2qmap["dynamic"]
code = code.to(A.device)
- if out is None: out = torch.zeros_like(A, dtype=torch.float32)
- if quant_state is None: quant_state = (absmax, code)
+ if out is None:
+ out = torch.zeros_like(A, dtype=torch.float32)
+ if quant_state is None:
+ quant_state = (absmax, code)
if blocksize not in [2048, 4096]:
- raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]')
+ raise ValueError(
+ f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
+ )
- if A.device.type != 'cpu':
+ if A.device.type != "cpu":
if out.dtype == torch.float32:
- lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+ lib.cdequantize_blockwise_fp32(
+ get_ptr(quant_state[1]),
+ get_ptr(A),
+ get_ptr(quant_state[0]),
+ get_ptr(out),
+ ct.c_int(blocksize),
+ ct.c_int(A.numel()),
+ )
elif out.dtype == torch.float16:
- lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+ lib.cdequantize_blockwise_fp16(
+ get_ptr(quant_state[1]),
+ get_ptr(A),
+ get_ptr(quant_state[0]),
+ get_ptr(out),
+ ct.c_int(blocksize),
+ ct.c_int(A.numel()),
+ )
else:
- raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
+ raise ValueError(
+ f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
+ )
else:
- lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(A.numel()))
-
+ lib.cdequantize_blockwise_cpu_fp32(
+ get_ptr(quant_state[1]),
+ get_ptr(A),
+ get_ptr(quant_state[0]),
+ get_ptr(out),
+ ct.c_int(A.numel()),
+ )
return out
-def quantize(A: Tensor, code: Tensor=None, out: Tensor=None) -> Tensor:
+def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
if code is None:
- if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
- code = name2qmap['dynamic']
+ if "dynamic" not in name2qmap:
+ name2qmap["dynamic"] = create_dynamic_map().to(A.device)
+ code = name2qmap["dynamic"]
code = code.to(A.device)
absmax = torch.abs(A).max()
- inp = A/absmax
+ inp = A / absmax
out = quantize_no_absmax(inp, code, out)
return out, (absmax, code)
-def dequantize(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, absmax: Tensor=None, code: Tensor=None, out: Tensor=None) -> Tensor:
+
+def dequantize(
+ A: Tensor,
+ quant_state: Tuple[Tensor, Tensor] = None,
+ absmax: Tensor = None,
+ code: Tensor = None,
+ out: Tensor = None,
+) -> Tensor:
assert quant_state is not None or absmax is not None
if code is None and quant_state is None:
- if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
- code = name2qmap['dynamic']
+ if "dynamic" not in name2qmap:
+ name2qmap["dynamic"] = create_dynamic_map().to(A.device)
+ code = name2qmap["dynamic"]
code = code.to(A.device)
- if quant_state is None: quant_state = (absmax, code)
+ if quant_state is None:
+ quant_state = (absmax, code)
out = dequantize_no_absmax(A, quant_state[1], out)
- return out*quant_state[0]
+ return out * quant_state[0]
-def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
- '''
+
+def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
+ """
Quantizes input tensor to 8-bit.
Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
@@ -456,13 +599,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
-------
torch.Tensor:
Quantized 8-bit tensor.
- '''
- if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
+ """
+ if out is None:
+ out = torch.zeros_like(A, dtype=torch.uint8)
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
-def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
- '''
+
+def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
+ """
Dequantizes the 8-bit tensor to 32-bit.
Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
@@ -481,17 +626,31 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
-------
torch.Tensor:
32-bit output tensor.
- '''
- if out is None: out = torch.zeros_like(A, dtype=torch.float32)
+ """
+ if out is None:
+ out = torch.zeros_like(A, dtype=torch.float32)
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
-def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Tensor,
- beta1: float, eps: float, step: int, lr: float,
- state2: Tensor=None, beta2: float=0.0,
- weight_decay: float=0.0, gnorm_scale: float=1.0,
- unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None:
- '''
+
+def optimizer_update_32bit(
+ optimizer_name: str,
+ g: Tensor,
+ p: Tensor,
+ state1: Tensor,
+ beta1: float,
+ eps: float,
+ step: int,
+ lr: float,
+ state2: Tensor = None,
+ beta2: float = 0.0,
+ weight_decay: float = 0.0,
+ gnorm_scale: float = 1.0,
+ unorm_vec: Tensor = None,
+ max_unorm: float = 0.0,
+ skip_zeros=False,
+) -> None:
+ """
Performs an inplace optimizer update with one or two optimizer states.
Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.
@@ -528,33 +687,84 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
The maximum update norm relative to the weight norm.
skip_zeros : bool
Whether to skip zero-valued gradients or not (default: False).
- '''
+ """
param_norm = 0.0
if max_unorm > 0.0:
param_norm = torch.norm(p.data.float())
if optimizer_name not in str2optimizer32bit:
- raise NotImplementedError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}')
+ raise NotImplementedError(
+ f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
+ )
if g.dtype == torch.float32 and state1.dtype == torch.float32:
- str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
- ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
- ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
+ str2optimizer32bit[optimizer_name][0](
+ get_ptr(g),
+ get_ptr(p),
+ get_ptr(state1),
+ get_ptr(state2),
+ get_ptr(unorm_vec),
+ ct.c_float(max_unorm),
+ ct.c_float(param_norm),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_float(weight_decay),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ ct.c_float(gnorm_scale),
+ ct.c_bool(skip_zeros),
+ ct.c_int32(g.numel()),
+ )
elif g.dtype == torch.float16 and state1.dtype == torch.float32:
- str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
- ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
- ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
+ str2optimizer32bit[optimizer_name][1](
+ get_ptr(g),
+ get_ptr(p),
+ get_ptr(state1),
+ get_ptr(state2),
+ get_ptr(unorm_vec),
+ ct.c_float(max_unorm),
+ ct.c_float(param_norm),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_float(weight_decay),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ ct.c_float(gnorm_scale),
+ ct.c_bool(skip_zeros),
+ ct.c_int32(g.numel()),
+ )
else:
- raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
-
-def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor,
- beta1: float, beta2: float, eps: float,
- step: int, lr: float, qmap1: Tensor, qmap2: Tensor,
- max1: Tensor, max2: Tensor, new_max1: Tensor, new_max2: Tensor,
- weight_decay: float=0.0, gnorm_scale: float=1.0,
- unorm_vec: Tensor=None, max_unorm: float=0.0) -> None:
- '''
+ raise ValueError(
+ f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
+ )
+
+
+def optimizer_update_8bit(
+ optimizer_name: str,
+ g: Tensor,
+ p: Tensor,
+ state1: Tensor,
+ state2: Tensor,
+ beta1: float,
+ beta2: float,
+ eps: float,
+ step: int,
+ lr: float,
+ qmap1: Tensor,
+ qmap2: Tensor,
+ max1: Tensor,
+ max2: Tensor,
+ new_max1: Tensor,
+ new_max2: Tensor,
+ weight_decay: float = 0.0,
+ gnorm_scale: float = 1.0,
+ unorm_vec: Tensor = None,
+ max_unorm: float = 0.0,
+) -> None:
+ """
Performs an inplace Adam update.
Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights.
@@ -602,56 +812,135 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten
The tensor for the update norm.
max_unorm : float
The maximum update norm relative to the weight norm.
- '''
+ """
param_norm = 0.0
if max_unorm > 0.0:
param_norm = torch.norm(p.data.float())
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
- str2optimizer8bit[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
- get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm),
- ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
- ct.c_int32(step), ct.c_float(lr),
- get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2),
- ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+ str2optimizer8bit[optimizer_name][0](
+ get_ptr(p),
+ get_ptr(g),
+ get_ptr(state1),
+ get_ptr(state2),
+ get_ptr(unorm_vec),
+ ct.c_float(max_unorm),
+ ct.c_float(param_norm),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ get_ptr(qmap1),
+ get_ptr(qmap2),
+ get_ptr(max1),
+ get_ptr(max2),
+ get_ptr(new_max1),
+ get_ptr(new_max2),
+ ct.c_float(weight_decay),
+ ct.c_float(gnorm_scale),
+ ct.c_int32(g.numel()),
+ )
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
- str2optimizer8bit[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
- get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm),
- ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
- ct.c_int32(step), ct.c_float(lr),
- get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2),
- ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+ str2optimizer8bit[optimizer_name][1](
+ get_ptr(p),
+ get_ptr(g),
+ get_ptr(state1),
+ get_ptr(state2),
+ get_ptr(unorm_vec),
+ ct.c_float(max_unorm),
+ ct.c_float(param_norm),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ get_ptr(qmap1),
+ get_ptr(qmap2),
+ get_ptr(max1),
+ get_ptr(max2),
+ get_ptr(new_max1),
+ get_ptr(new_max2),
+ ct.c_float(weight_decay),
+ ct.c_float(gnorm_scale),
+ ct.c_int32(g.numel()),
+ )
else:
- raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
-
-
-def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor,
- beta1: float, beta2: float, eps: float,
- step: int, lr: float, qmap1: Tensor, qmap2: Tensor,
- absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0,
- skip_zeros=False) -> None:
-
+ raise ValueError(
+ f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
+ )
+
+
+def optimizer_update_8bit_blockwise(
+ optimizer_name: str,
+ g: Tensor,
+ p: Tensor,
+ state1: Tensor,
+ state2: Tensor,
+ beta1: float,
+ beta2: float,
+ eps: float,
+ step: int,
+ lr: float,
+ qmap1: Tensor,
+ qmap2: Tensor,
+ absmax1: Tensor,
+ absmax2: Tensor,
+ weight_decay: float = 0.0,
+ gnorm_scale: float = 1.0,
+ skip_zeros=False,
+) -> None:
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
- str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
- ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
- ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
- ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
+ str2optimizer8bit_blockwise[optimizer_name][0](
+ get_ptr(p),
+ get_ptr(g),
+ get_ptr(state1),
+ get_ptr(state2),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ get_ptr(qmap1),
+ get_ptr(qmap2),
+ get_ptr(absmax1),
+ get_ptr(absmax2),
+ ct.c_float(weight_decay),
+ ct.c_float(gnorm_scale),
+ ct.c_bool(skip_zeros),
+ ct.c_int32(g.numel()),
+ )
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
- str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
- ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
- ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
- ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
+ str2optimizer8bit_blockwise[optimizer_name][1](
+ get_ptr(p),
+ get_ptr(g),
+ get_ptr(state1),
+ get_ptr(state2),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ get_ptr(qmap1),
+ get_ptr(qmap2),
+ get_ptr(absmax1),
+ get_ptr(absmax2),
+ ct.c_float(weight_decay),
+ ct.c_float(gnorm_scale),
+ ct.c_bool(skip_zeros),
+ ct.c_int32(g.numel()),
+ )
else:
- raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
+ raise ValueError(
+ f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
+ )
-def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int=5):
+def percentile_clipping(
+ grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
+):
"""Applies percentile clipping
grad: torch.Tensor
@@ -663,11 +952,21 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
"""
if grad.dtype == torch.float32:
- lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
+ lib.cpercentile_clipping_g32(
+ get_ptr(grad),
+ get_ptr(gnorm_vec),
+ ct.c_int32(step),
+ ct.c_int32(grad.numel()),
+ )
elif grad.dtype == torch.float16:
- lib.cpercentile_clipping_g16(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
+ lib.cpercentile_clipping_g16(
+ get_ptr(grad),
+ get_ptr(gnorm_vec),
+ ct.c_int32(step),
+ ct.c_int32(grad.numel()),
+ )
else:
- raise ValueError(f'Gradient type {grad.dtype} not supported!')
+ raise ValueError(f"Gradient type {grad.dtype} not supported!")
current_gnorm = torch.sqrt(gnorm_vec[step % 100])
vals, idx = torch.sort(gnorm_vec)
@@ -675,31 +974,44 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
gnorm_scale = 1.0
if current_gnorm > clip_value:
- gnorm_scale = clip_value/current_gnorm
+ gnorm_scale = clip_value / current_gnorm
return current_gnorm, clip_value, gnorm_scale
-def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
+def histogram_scatter_add_2d(
+ histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor
+):
assert len(histogram.shape) == 2
assert histogram.dtype == torch.float32
assert source.dtype == torch.float32
assert index1.dtype == torch.int32
assert index2.dtype == torch.int32
- assert histogram.device.type == 'cuda'
- assert index1.device.type == 'cuda'
- assert index2.device.type == 'cuda'
- assert source.device.type == 'cuda'
+ assert histogram.device.type == "cuda"
+ assert index1.device.type == "cuda"
+ assert index2.device.type == "cuda"
+ assert source.device.type == "cuda"
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
- lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
+ lib.chistogram_scatter_add_2d(
+ get_ptr(histogram),
+ get_ptr(index1),
+ get_ptr(index2),
+ get_ptr(source),
+ maxdim1,
+ n,
+ )
+
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
- if not torch.cuda.is_initialized(): torch.cuda.init()
+ if not torch.cuda.is_initialized():
+ torch.cuda.init()
if A.dtype != expected_type or B.dtype != expected_type:
- raise TypeError(f'Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}')
+ raise TypeError(
+ f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
+ )
sA = A.shape
sB = B.shape
@@ -709,64 +1021,101 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
correct = True
if len(sA) == 2 and len(sB) == 2:
- if not tA and not tB and A.shape[1] != B.shape[0]: correct = False
- elif tA and not tB and A.shape[0] != B.shape[0]: correct = False
- elif tA and tB and A.shape[0] != B.shape[1]: correct = False
- elif not tA and tB and A.shape[1] != B.shape[1]: correct = False
+ if not tA and not tB and A.shape[1] != B.shape[0]:
+ correct = False
+ elif tA and not tB and A.shape[0] != B.shape[0]:
+ correct = False
+ elif tA and tB and A.shape[0] != B.shape[1]:
+ correct = False
+ elif not tA and tB and A.shape[1] != B.shape[1]:
+ correct = False
elif len(sA) == 3 and len(sB) == 2:
- if not tA and not tB and A.shape[2] != B.shape[0]: correct = False
- elif tA and not tB and A.shape[1] != B.shape[0]: correct = False
- elif tA and tB and A.shape[1] != B.shape[1]: correct = False
- elif not tA and tB and A.shape[2] != B.shape[1]: correct = False
+ if not tA and not tB and A.shape[2] != B.shape[0]:
+ correct = False
+ elif tA and not tB and A.shape[1] != B.shape[0]:
+ correct = False
+ elif tA and tB and A.shape[1] != B.shape[1]:
+ correct = False
+ elif not tA and tB and A.shape[2] != B.shape[1]:
+ correct = False
elif len(sA) == 3 and len(sB) == 3:
- if not tA and not tB and A.shape[2] != B.shape[1]: correct = False
- elif tA and not tB and A.shape[1] != B.shape[1]: correct = False
- elif tA and tB and A.shape[1] != B.shape[2]: correct = False
- elif not tA and tB and A.shape[2] != B.shape[2]: correct = False
+ if not tA and not tB and A.shape[2] != B.shape[1]:
+ correct = False
+ elif tA and not tB and A.shape[1] != B.shape[1]:
+ correct = False
+ elif tA and tB and A.shape[1] != B.shape[2]:
+ correct = False
+ elif not tA and tB and A.shape[2] != B.shape[2]:
+ correct = False
if out is not None:
sout = out.shape
# special case common in backprop
if not correct and len(sA) == 3 and len(sB) == 3:
- if (sout[0] == sA[2] and sout[1] == sB[2] and
- sA[0] == sB[0] and sA[1] == sB[1]):
+ if (
+ sout[0] == sA[2]
+ and sout[1] == sB[2]
+ and sA[0] == sB[0]
+ and sA[1] == sB[1]
+ ):
correct = True
else:
if len(sA) == 2 and len(sB) == 2:
- if not tA and not tB: sout = (sA[0], sB[1])
- elif tA and tB: sout = (sA[1], sB[0])
- elif tA and not tB: sout = (sA[1], sB[1])
- elif not tA and tB: sout = (sA[0], sB[0])
+ if not tA and not tB:
+ sout = (sA[0], sB[1])
+ elif tA and tB:
+ sout = (sA[1], sB[0])
+ elif tA and not tB:
+ sout = (sA[1], sB[1])
+ elif not tA and tB:
+ sout = (sA[0], sB[0])
elif len(sA) == 3 and len(sB) == 2:
- if not tA and not tB: sout = (sA[0], sA[1], sB[1])
- elif tA and tB: sout = (sA[0], sA[2], sB[0])
- elif tA and not tB: sout = (sA[0], sA[2], sB[1])
- elif not tA and tB: sout = (sA[0], sA[1], sB[0])
+ if not tA and not tB:
+ sout = (sA[0], sA[1], sB[1])
+ elif tA and tB:
+ sout = (sA[0], sA[2], sB[0])
+ elif tA and not tB:
+ sout = (sA[0], sA[2], sB[1])
+ elif not tA and tB:
+ sout = (sA[0], sA[1], sB[0])
elif len(sA) == 3 and len(sB) == 3:
- if not tA and not tB: sout = (sA[0], sA[1], sB[2])
- elif tA and tB: sout = (sA[0], sA[2], sB[1])
- elif tA and not tB: sout = (sA[0], sA[2], sB[2])
- elif not tA and tB: sout = (sA[0], sA[1], sB[1])
-
+ if not tA and not tB:
+ sout = (sA[0], sA[1], sB[2])
+ elif tA and tB:
+ sout = (sA[0], sA[2], sB[1])
+ elif tA and not tB:
+ sout = (sA[0], sA[2], sB[2])
+ elif not tA and tB:
+ sout = (sA[0], sA[1], sB[1])
if not correct:
- raise ValueError(f'Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.')
+ raise ValueError(
+ f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
+ )
return sout
-def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
+
+def igemm(
+ A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+):
sout = check_matmul(A, B, out, transposed_A, transposed_B)
- if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
+ if out is None:
+ out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
if len(A.shape) == 3 and len(B.shape) == 3:
if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
return batched_igemm(A, B, out)
sA = A.shape
sB = B.shape
- if transposed_A and len(sA) == 2: sA = (sA[1], sA[0])
- elif transposed_A and len(sA) == 3: sA = (sA[0], sA[2], sA[0])
- if transposed_B and len(sB) == 2: sB = (sB[1], sB[0])
- elif transposed_B and len(sB) == 3: sB = (sB[0], sB[2], sB[0])
+ if transposed_A and len(sA) == 2:
+ sA = (sA[1], sA[0])
+ elif transposed_A and len(sA) == 3:
+ sA = (sA[0], sA[2], sA[0])
+ if transposed_B and len(sB) == 2:
+ sB = (sB[1], sB[0])
+ elif transposed_B and len(sB) == 3:
+ sB = (sB[0], sB[2], sB[0])
# this is a mess: cuBLAS expect column major, but PyTorch is row major.
# So to perform the matrix multiplication, we have to treat A, B, and C matrices
# (transpose of row major is column major)
@@ -777,23 +1126,28 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
# column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
if len(sB) == 2:
- if B.stride()[0] == B.shape[1]: transposed_B = False
- elif B.stride()[1] == B.shape[0]: transposed_B = True
+ if B.stride()[0] == B.shape[1]:
+ transposed_B = False
+ elif B.stride()[1] == B.shape[0]:
+ transposed_B = True
if len(A.shape) == 2:
- if A.stride()[0] == A.shape[1]: transposed_A = False
- elif A.stride()[1] == A.shape[0]: transposed_A = True
+ if A.stride()[0] == A.shape[1]:
+ transposed_A = False
+ elif A.stride()[1] == A.shape[0]:
+ transposed_A = True
else:
- if A.stride()[1] == A.shape[2]: transposed_A = False
- elif A.stride()[2] == A.shape[1]: transposed_A = True
+ if A.stride()[1] == A.shape[2]:
+ transposed_A = False
+ elif A.stride()[2] == A.shape[1]:
+ transposed_A = True
if len(sA) == 2:
n = sA[0]
ldb = A.stride()[1 if transposed_A else 0]
elif len(sA) == 3 and len(sB) == 2:
- n = sA[0]*sA[1]
+ n = sA[0] * sA[1]
ldb = sA[2]
-
m = sB[1]
k = sB[0]
lda = B.stride()[(1 if transposed_B else 0)]
@@ -802,34 +1156,52 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# special case
assert len(sA) == 3
if not (sA[0] == sB[0] and sA[1] == sB[1]):
- raise ValueError(f'Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}')
+ raise ValueError(
+ f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
+ )
transposed_A = True
transposed_B = False
m = sB[2]
n = sA[2]
- k = sB[0]*sB[1]
+ k = sB[0] * sB[1]
lda = m
ldb = sA[2]
ldc = m
-
ptr = CUBLAS_Context.get_instance().get_context(A.device)
# B^T @ A^T = C^T
- # [km, nk -> mn]
- lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
- get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
+ # [km, nk -> mn]
+ lib.cigemm(
+ ptr,
+ ct.c_bool(transposed_B),
+ ct.c_bool(transposed_A),
+ ct.c_int32(m),
+ ct.c_int32(n),
+ ct.c_int32(k),
+ get_ptr(B),
+ get_ptr(A),
+ get_ptr(out),
+ ct.c_int32(lda),
+ ct.c_int32(ldb),
+ ct.c_int32(ldc),
+ )
return out
-def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
+def batched_igemm(
+ A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+):
if not len(A.shape) == 3 or not len(B.shape) == 3:
- raise ValueError(f'Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}')
+ raise ValueError(
+ f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
+ )
sout = check_matmul(A, B, out, transposed_A, transposed_B)
- if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
+ if out is None:
+ out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
if B.is_contiguous():
lda = B.stride()[1]
@@ -886,17 +1258,33 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ldc = m
- strideA = B.shape[1]*B.shape[2]
- strideB = A.shape[1]*A.shape[2]
- strideC = A.shape[1]*B.shape[2]
+ strideA = B.shape[1] * B.shape[2]
+ strideB = A.shape[1] * A.shape[2]
+ strideC = A.shape[1] * B.shape[2]
ptr = CUBLAS_Context.get_instance().get_context(A.device)
- lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
- get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
- ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
+ lib.cbatched_igemm(
+ ptr,
+ ct.c_bool(transposed_B),
+ ct.c_bool(transposed_A),
+ ct.c_int32(m),
+ ct.c_int32(n),
+ ct.c_int32(k),
+ get_ptr(B),
+ get_ptr(A),
+ get_ptr(out),
+ ct.c_int32(lda),
+ ct.c_int32(ldb),
+ ct.c_int32(ldc),
+ ct.c_long(strideA),
+ ct.c_long(strideB),
+ ct.c_long(strideC),
+ ct.c_uint32(num_batch),
+ )
return out
+
def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeA = SA[0]
shapeB = SB[0]
@@ -905,28 +1293,34 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
if dimsA == 2:
m = shapeA[0]
elif dimsA == 3:
- m = shapeA[0]*shapeA[1]
+ m = shapeA[0] * shapeA[1]
if dimsB == 2:
rows = n = shapeB[0]
elif dimsB == 3:
- rows = n = shapeB[0]*shapeB[1]
+ rows = n = shapeB[0] * shapeB[1]
if dimsA == 2 and out is None:
- out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
+ out, Sout = get_transform_buffer(
+ (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
+ )
elif dimsA == 3 and out is None:
- out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')
+ out, Sout = get_transform_buffer(
+ (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
+ )
- assert dimsB != 3, 'len(B.shape)==3 not supported'
- assert A.device.type == 'cuda'
- assert B.device.type == 'cuda'
+ assert dimsB != 3, "len(B.shape)==3 not supported"
+ assert A.device.type == "cuda"
+ assert B.device.type == "cuda"
assert A.dtype == torch.int8
assert B.dtype == torch.int8
assert out.dtype == dtype
- assert SA[1] == 'col32'
- assert SB[1] in ['col_turing', 'col_ampere']
- assert Sout[1] == 'col32'
- assert shapeA[-1] == shapeB[-1], f'Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}'
+ assert SA[1] == "col32"
+ assert SB[1] in ["col_turing", "col_ampere"]
+ assert Sout[1] == "col32"
+ assert (
+ shapeA[-1] == shapeB[-1]
+ ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
formatB = SB[1]
prev_device = A.device
torch.cuda.set_device(A.device)
@@ -937,53 +1331,76 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
ptrC = get_ptr(out)
k = shapeA[-1]
- lda = ct.c_int32(m*32)
- if formatB == 'col_turing':
+ lda = ct.c_int32(m * 32)
+ if formatB == "col_turing":
# turing: tiles with rows filled up to multiple of 8 rows by 32 columns
# n = rows
- ldb = ct.c_int32(((rows+7)//8)*8*32)
+ ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
else:
# ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
# n = rows
- ldb = ct.c_int32(((rows+31)//32)*32*32)
+ ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
- ldc = ct.c_int32(m*32)
+ ldc = ct.c_int32(m * 32)
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
has_error = 0
ptrRowScale = get_ptr(None)
- if formatB == 'col_turing':
+ if formatB == "col_turing":
if dtype == torch.int32:
- has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_turing_32(
+ ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
+ )
else:
- has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
- elif formatB == 'col_ampere':
+ has_error = lib.cigemmlt_turing_8(
+ ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
+ )
+ elif formatB == "col_ampere":
if dtype == torch.int32:
- has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_ampere_32(
+ ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
+ )
else:
- has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_ampere_8(
+ ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
+ )
if has_error == 1:
- raise Exception('cublasLt ran into an error!')
+ raise Exception("cublasLt ran into an error!")
torch.cuda.set_device(prev_device)
-
return out, Sout
-def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None):
+def mm_dequant(
+ A,
+ quant_state,
+ row_stats,
+ col_stats,
+ out=None,
+ new_row_stats=None,
+ new_col_stats=None,
+):
assert A.dtype == torch.int32
out_shape = quant_state[0]
- if len(out_shape) == 3: out_shape = (out_shape[0]*out_shape[1], out_shape[2])
-
- if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
- if new_row_stats is None: new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
- if new_col_stats is None: new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
- assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
- assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
+ if len(out_shape) == 3:
+ out_shape = (out_shape[0] * out_shape[1], out_shape[2])
+
+ if out is None:
+ out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
+ if new_row_stats is None:
+ new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
+ if new_col_stats is None:
+ new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+ assert (
+ new_row_stats.shape[0] == row_stats.shape[0]
+ ), f"{new_row_stats.shape} vs {row_stats.shape}"
+ assert (
+ new_col_stats.shape[0] == col_stats.shape[0]
+ ), f"{new_col_stats.shape} vs {col_stats.shape}"
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
@@ -994,27 +1411,47 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
numRows = ct.c_int32(out_shape[0])
numCols = ct.c_int32(out_shape[1])
- lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
+ lib.cdequant_mm_int32_fp16(
+ ptrA,
+ ptrRowStats,
+ ptrColStats,
+ ptrOut,
+ ptrNewRowStats,
+ ptrNewColStats,
+ numRows,
+ numCols,
+ )
return out
-def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
+def get_colrow_absmax(
+ A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
+):
assert A.dtype == torch.float16
device = A.device
cols = A.shape[-1]
if len(A.shape) == 3:
- rows = A.shape[0]*A.shape[1]
+ rows = A.shape[0] * A.shape[1]
else:
rows = A.shape[0]
- col_tiles = (cols+255)//256
- tiled_rows = ((rows+15)//16)*16
- if row_stats is None: row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
- if col_stats is None: col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)
-
- if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros(((tiled_rows*col_tiles)+1,), dtype=torch.int32, device=device)
+ col_tiles = (cols + 255) // 256
+ tiled_rows = ((rows + 15) // 16) * 16
+ if row_stats is None:
+ row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(
+ -50000.0
+ )
+ if col_stats is None:
+ col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(
+ -50000.0
+ )
+
+ if nnz_block_ptr is None and threshold > 0.0:
+ nnz_block_ptr = torch.zeros(
+ ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
+ )
ptrA = get_ptr(A)
ptrRowStats = get_ptr(row_stats)
@@ -1024,16 +1461,17 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr
cols = ct.c_int32(cols)
prev_device = pre_call(A.device)
- lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
+ lib.cget_col_row_stats(
+ ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols
+ )
post_call(prev_device)
-
if threshold > 0.0:
nnz_block_ptr.cumsum_(0)
-
return row_stats, col_stats, nnz_block_ptr
+
class COOSparseTensor(object):
def __init__(self, rows, cols, nnz, rowidx, colidx, values):
assert rowidx.dtype == torch.int32
@@ -1050,6 +1488,7 @@ class COOSparseTensor(object):
self.colidx = colidx
self.values = values
+
class CSRSparseTensor(object):
def __init__(self, rows, cols, nnz, rowptr, colidx, values):
assert rowptr.dtype == torch.int32
@@ -1057,7 +1496,7 @@ class CSRSparseTensor(object):
assert values.dtype == torch.float16
assert values.numel() == nnz
assert colidx.numel() == nnz
- assert rowptr.numel() == rows+1
+ assert rowptr.numel() == rows + 1
self.rows = rows
self.cols = cols
@@ -1066,6 +1505,7 @@ class CSRSparseTensor(object):
self.colidx = colidx
self.values = values
+
class CSCSparseTensor(object):
def __init__(self, rows, cols, nnz, colptr, rowidx, values):
assert colptr.dtype == torch.int32
@@ -1073,7 +1513,7 @@ class CSCSparseTensor(object):
assert values.dtype == torch.float16
assert values.numel() == nnz
assert rowidx.numel() == nnz
- assert colptr.numel() == cols+1
+ assert colptr.numel() == cols + 1
self.rows = rows
self.cols = cols
@@ -1082,13 +1522,17 @@ class CSCSparseTensor(object):
self.rowidx = rowidx
self.values = values
+
def coo2csr(cooA):
values, counts = torch.unique(cooA.rowidx, return_counts=True)
values.add_(1)
- rowptr = torch.zeros((cooA.rows+1, ), dtype=torch.int32, device=cooA.rowidx.device)
+ rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
rowptr.cumsum_(0)
- return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values)
+ return CSRSparseTensor(
+ cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values
+ )
+
def coo2csc(cooA):
val, col2rowidx = torch.sort(cooA.colidx)
@@ -1096,11 +1540,12 @@ def coo2csc(cooA):
values = cooA.values[col2rowidx]
colvalues, counts = torch.unique(val, return_counts=True)
colvalues.add_(1)
- colptr = torch.zeros((cooA.cols+1, ), dtype=torch.int32, device=cooA.colidx.device)
+ colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
colptr.cumsum_(0)
return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
+
def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
@@ -1108,23 +1553,27 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
-def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
+def double_quant(
+ A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
+):
device = A.device
assert A.dtype == torch.half
- assert device.type == 'cuda'
+ assert device.type == "cuda"
prev_device = pre_call(A.device)
cols = A.shape[-1]
if len(A.shape) == 3:
- rows = A.shape[0]*A.shape[1]
+ rows = A.shape[0] * A.shape[1]
else:
rows = A.shape[0]
if row_stats is None or col_stats is None:
row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
- if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
- if out_row is None: out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
+ if out_col is None:
+ out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
+ if out_row is None:
+ out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
coo_tensor = None
ptrA = get_ptr(A)
@@ -1136,21 +1585,62 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
- coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
+ coo_tensor = coo_zeros(
+ A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device
+ )
ptrRowIdx = get_ptr(coo_tensor.rowidx)
ptrColIdx = get_ptr(coo_tensor.colidx)
ptrVal = get_ptr(coo_tensor.values)
ptrRowPtr = get_ptr(nnz_row_ptr)
- lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, ptrRowIdx, ptrColIdx, ptrVal, ptrRowPtr, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
+ lib.cdouble_rowcol_quant(
+ ptrA,
+ ptrRowStats,
+ ptrColStats,
+ ptrOutCol,
+ ptrOutRow,
+ ptrRowIdx,
+ ptrColIdx,
+ ptrVal,
+ ptrRowPtr,
+ ct.c_float(threshold),
+ ct.c_int32(rows),
+ ct.c_int32(cols),
+ )
val, idx = torch.sort(coo_tensor.rowidx)
coo_tensor.rowidx = val
coo_tensor.colidx = coo_tensor.colidx[idx]
coo_tensor.values = coo_tensor.values[idx]
else:
- lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(0.0), ct.c_int32(rows), ct.c_int32(cols))
+ lib.cdouble_rowcol_quant(
+ ptrA,
+ ptrRowStats,
+ ptrColStats,
+ ptrOutCol,
+ ptrOutRow,
+ None,
+ None,
+ None,
+ None,
+ ct.c_float(0.0),
+ ct.c_int32(rows),
+ ct.c_int32(cols),
+ )
else:
- lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
+ lib.cdouble_rowcol_quant(
+ ptrA,
+ ptrRowStats,
+ ptrColStats,
+ ptrOutCol,
+ ptrOutRow,
+ None,
+ None,
+ None,
+ None,
+ ct.c_float(threshold),
+ ct.c_int32(rows),
+ ct.c_int32(cols),
+ )
post_call(prev_device)
return out_row, out_col, row_stats, col_stats, coo_tensor
@@ -1159,69 +1649,81 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
def get_special_format_str():
major, minor = torch.cuda.get_device_capability()
if major < 7:
- print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!')
+ print(
+ f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
+ )
assert major >= 7
- if major == 7: return 'col_turing'
- elif major == 8: return 'col_ampere'
- else: return 'col_turing'
-
-
+ if major == 7:
+ return "col_turing"
+ elif major == 8:
+ return "col_ampere"
+ else:
+ return "col_turing"
-def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
- if state is None: state = (A.shape, from_order)
- else: from_order = state[1]
- if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
- else: new_state = (state[0], to_order) # (shape, order)
+def transform(
+ A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+):
+ if state is None:
+ state = (A.shape, from_order)
+ else:
+ from_order = state[1]
+ if out is None:
+ out, new_state = get_transform_buffer(
+ state[0], A.dtype, A.device, to_order, state[1], transpose
+ )
+ else:
+ new_state = (state[0], to_order) # (shape, order)
shape = state[0]
if len(shape) == 2:
dim1 = ct.c_int32(shape[0])
dim2 = ct.c_int32(shape[1])
else:
- dim1 = ct.c_int32(shape[0]*shape[1])
+ dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
- if to_order == 'col32':
+ if to_order == "col32":
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
- elif to_order == 'col_turing':
+ elif to_order == "col_turing":
if transpose:
lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
- elif to_order == 'col_ampere':
+ elif to_order == "col_ampere":
if transpose:
lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
- elif to_order == 'row':
- if from_order == 'col_turing':
+ elif to_order == "row":
+ if from_order == "col_turing":
lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
- elif from_order == 'col_ampere':
+ elif from_order == "col_ampere":
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
- raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
-
-
-
+ raise NotImplementedError(
+ f"Transform function not implemented: From {from_order} to {to_order}"
+ )
return out, new_state
+
def spmm_coo(cooA, B, out=None):
- if out is None: out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
+ if out is None:
+ out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
nnz = cooA.nnz
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
assert cooA.values.numel() == nnz
assert cooA.cols == B.shape[0]
- transposed_B = (False if B.is_contiguous() else True)
+ transposed_B = False if B.is_contiguous() else True
ldb = B.stride()[(1 if transposed_B else 0)]
ldc = B.shape[1]
@@ -1240,19 +1742,37 @@ def spmm_coo(cooA, B, out=None):
cldb = ct.c_int32(ldb)
cldc = ct.c_int32(ldc)
- lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
+ lib.cspmm_coo(
+ ptr,
+ ptrRowidx,
+ ptrColidx,
+ ptrValues,
+ cnnz,
+ crowsA,
+ ccolsA,
+ ccolsB,
+ cldb,
+ ptrB,
+ cldc,
+ ptrC,
+ ct.c_bool(transposed_B),
+ )
return out
+
def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
- if out is None: out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
+ if out is None:
+ out = torch.zeros(
+ (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
+ )
nnz = cooA.nnz
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
assert cooA.values.numel() == nnz
- assert cooA.cols == B.shape[0], f'{cooA.cols} vs {B.shape}'
+ assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}"
- transposed_B = (False if B.is_contiguous() else True)
+ transposed_B = False if B.is_contiguous() else True
ldb = B.stride()[(1 if transposed_B else 0)]
ldc = B.shape[1]
@@ -1262,7 +1782,9 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
max_count, max_idx = torch.sort(counts, descending=True)
max_idx = max_idx.int()
max_count = max_count.int()
- assert max_count[0] <= 32, f'Current max count per row is 8 but found {max_count[0]}.'
+ assert (
+ max_count[0] <= 32
+ ), f"Current max count per row is 8 but found {max_count[0]}."
assert B.dtype in [torch.float16, torch.int8]
ptrOffset = get_ptr(offset)
ptrMaxCount = get_ptr(max_count)
@@ -1282,134 +1804,183 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
ccolsB = ct.c_int32(B.shape[1])
cldb = ct.c_int32(ldb)
cldc = ct.c_int32(ldc)
- #print(cooA.rowidx[:64])
- #print(cooA.colidx[:64].sort()[0])
+ # print(cooA.rowidx[:64])
+ # print(cooA.colidx[:64].sort()[0])
if B.dtype == torch.float16:
- lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
+ lib.cspmm_coo_very_sparse_naive_fp16(
+ ptrMaxCount,
+ ptrMaxIdx,
+ ptrOffset,
+ ptrRowidx,
+ ptrColidx,
+ ptrValues,
+ ptrB,
+ ptrC,
+ ptrDequantStats,
+ cnnz_rows,
+ cnnz,
+ crowsA,
+ crowsB,
+ ccolsB,
+ )
elif B.dtype == torch.int8:
- lib.cspmm_coo_very_sparse_naive_int8(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
- #else: assertion error
+ lib.cspmm_coo_very_sparse_naive_int8(
+ ptrMaxCount,
+ ptrMaxIdx,
+ ptrOffset,
+ ptrRowidx,
+ ptrColidx,
+ ptrValues,
+ ptrB,
+ ptrC,
+ ptrDequantStats,
+ cnnz_rows,
+ cnnz,
+ crowsA,
+ crowsB,
+ ccolsB,
+ )
+ # else: assertion error
return out
C = 127.0
-def vectorwise_quant(x, dim=1, quant_type='vector'):
- if quant_type == 'linear':
+
+def vectorwise_quant(x, dim=1, quant_type="vector"):
+ if quant_type == "linear":
max1 = torch.abs(x).max().float()
- xq = torch.round(x/max1*127).to(torch.int8)
+ xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
- elif quant_type in ['vector', 'row']:
+ elif quant_type in ["vector", "row"]:
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
- xq = torch.round(x*(C/max1)).to(torch.int8)
+ xq = torch.round(x * (C / max1)).to(torch.int8)
return xq, max1
- elif quant_type == 'zeropoint':
+ elif quant_type == "zeropoint":
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
- if dyna == 0: dyna = 1
- qx = 255./dyna
+ if dyna == 0:
+ dyna = 1
+ qx = 255.0 / dyna
minx = x.min()
- zpx = torch.round(minx* qx)
- x = torch.round(qx*x - zpx) + zpx
+ zpx = torch.round(minx * qx)
+ x = torch.round(qx * x - zpx) + zpx
return x, qx
- elif quant_type in ['vector-zeropoint', 'row-zeropoint']:
+ elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
dtype = x.dtype
x = x.float()
- dyna = (torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True))
- dyna[dyna==0] = 1
- qx = 255./dyna
+ dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(
+ x, dim=dim, keepdim=True
+ )
+ dyna[dyna == 0] = 1
+ qx = 255.0 / dyna
minx = torch.amin(x, dim=dim, keepdim=True)
- zpx = torch.round(minx* qx)
- x = torch.round(qx*x - zpx) + zpx
+ zpx = torch.round(minx * qx)
+ x = torch.round(qx * x - zpx) + zpx
return x, qx
- elif quant_type == 'truncated-vector':
+ elif quant_type == "truncated-vector":
with torch.no_grad():
absx = torch.abs(x)
max1 = torch.amax(absx, dim=dim, keepdim=True)
- max1 = max1*0.7
- idx = (absx > max1.expand_as(absx))
+ max1 = max1 * 0.7
+ idx = absx > max1.expand_as(absx)
sign = torch.sign(x[idx])
- x[idx] = max1.expand_as(absx)[idx]*sign
- xq = torch.round(x/max1*C).to(torch.int8)
+ x[idx] = max1.expand_as(absx)[idx] * sign
+ xq = torch.round(x / max1 * C).to(torch.int8)
return xq, max1
- else: return None
+ else:
+ return None
+
-def vectorwise_dequant(xq, max1, quant_type='vector'):
- if quant_type == 'vector':
- x = (xq/C*max1).to(torch.float32)
+def vectorwise_dequant(xq, max1, quant_type="vector"):
+ if quant_type == "vector":
+ x = (xq / C * max1).to(torch.float32)
return x
- else: return None
+ else:
+ return None
-def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type='vector'):
- if quant_type == 'linear':
- norm = S1*S2/(C*C)
+
+def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
+ if quant_type == "linear":
+ norm = S1 * S2 / (C * C)
# double cast needed to prevent overflows
- return (xq.float()*norm).to(dtype)
- elif quant_type == 'zeropoint':
- norm = 1.0/(S1*S2)
- return (xq.float()*norm).to(dtype)
- elif quant_type == 'row-zeropoint':
- norm = 1.0/(S1*S2)
+ return (xq.float() * norm).to(dtype)
+ elif quant_type == "zeropoint":
+ norm = 1.0 / (S1 * S2)
+ return (xq.float() * norm).to(dtype)
+ elif quant_type == "row-zeropoint":
+ norm = 1.0 / (S1 * S2)
x = xq.float()
- if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
- if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 3 and len(x.shape) == 2:
+ S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2:
+ S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= norm
else:
x *= norm
return x.to(dtype)
- elif quant_type == 'vector-zeropoint':
+ elif quant_type == "vector-zeropoint":
x = xq.float()
- if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
- if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 3 and len(x.shape) == 2:
+ S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2:
+ S2 = S2.squeeze(0)
if len(S1.shape) == 2:
- x *= 1.0/S1
+ x *= 1.0 / S1
else:
- x *= 1.0/S1
- x *= 1.0/S2.t()
+ x *= 1.0 / S1
+ x *= 1.0 / S2.t()
return x.to(dtype)
- elif quant_type == 'row':
+ elif quant_type == "row":
x = xq.float()
- if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
- if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 3 and len(x.shape) == 2:
+ S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2:
+ S2 = S2.squeeze(0)
if len(S1.shape) == 2:
- x *= S1*S2/(C*C)
+ x *= S1 * S2 / (C * C)
else:
- x *= S1*S2/(C*C)
+ x *= S1 * S2 / (C * C)
return x.to(dtype)
- elif quant_type in ['truncated-vector', 'vector']:
+ elif quant_type in ["truncated-vector", "vector"]:
x = xq.float()
- if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
- if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 3 and len(x.shape) == 2:
+ S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2:
+ S2 = S2.squeeze(0)
if len(S1.shape) == 2:
- x *= S1/C
+ x *= S1 / C
else:
- x *= S1/C
- x *= S2/C
+ x *= S1 / C
+ x *= S2 / C
return x.to(dtype)
- else: return None
+ else:
+ return None
def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
- offset = B.float().t().sum(0)*(SA[0]+SA[1])
+ offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float()
- if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
+ if len(xq.shape) == 2 and len(SB.shape) == 3:
+ SB = SB.squeeze(0)
if len(SB.shape) == 2:
- x *= SB.t()/127
+ x *= SB.t() / 127
else:
- x *= SB/127
- x *= SA[1]/127
- x +=offset
+ x *= SB / 127
+ x *= SA[1] / 127
+ x += offset
return x.to(dtype)
+
def extract_outliers(A, SA, idx):
shapeA = SA[0]
formatA = SA[1]
- assert formatA in ['col_turing', 'col_ampere']
- assert A.device.type == 'cuda'
+ assert formatA in ["col_turing", "col_ampere"]
+ assert A.device.type == "cuda"
out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
@@ -1420,13 +1991,9 @@ def extract_outliers(A, SA, idx):
ptrIdx = get_ptr(idx)
ptrOut = get_ptr(out)
- if formatA == 'col_turing':
+ if formatA == "col_turing":
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
- elif formatA == 'col_ampere':
+ elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
return out
-
-
-
-
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 03b4655..98d4aa0 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -1,5 +1,5 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params
+from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 5013d0b..9ce3ac8 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -1,39 +1,59 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import torch
-import bitsandbytes as bnb
+from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
+ Tuple, TypeVar, Union, overload)
-from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
-
-from torch import Tensor, device, dtype
-from torch import nn
-from torch.nn.parameter import Parameter
+import torch
import torch.nn.functional as F
+from torch import Tensor, device, dtype, nn
+from torch.nn.parameter import Parameter
+import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager
-T = TypeVar('T', bound='torch.nn.Module')
+T = TypeVar("T", bound="torch.nn.Module")
+
class StableEmbedding(torch.nn.Embedding):
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
- sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
- super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None,
+ norm_type: float = 2.0,
+ scale_grad_by_freq: bool = False,
+ sparse: bool = False,
+ _weight: Optional[Tensor] = None,
+ ) -> None:
+ super(StableEmbedding, self).__init__(
+ num_embeddings,
+ embedding_dim,
+ padding_idx,
+ max_norm,
+ norm_type,
+ scale_grad_by_freq,
+ sparse,
+ _weight,
+ )
self.norm = torch.nn.LayerNorm(embedding_dim)
- GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
+ GlobalOptimManager.get_instance().register_module_override(
+ self, "weight", {"optim_bits": 32}
+ )
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
self._fill_padding_idx_with_zero()
- ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+ """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
to make the Layer compatible with Pytorch < 1.9.
This means that if this changes in future PyTorch releases this need to change too
which is cumbersome. However, with this we can ensure compatibility with previous
PyTorch releases.
- '''
+ """
+
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
@@ -41,29 +61,55 @@ class StableEmbedding(torch.nn.Embedding):
def forward(self, input: Tensor) -> Tensor:
emb = F.embedding(
- input, self.weight, self.padding_idx, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.sparse)
+ input,
+ self.weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ )
return self.norm(emb)
class Embedding(torch.nn.Embedding):
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
- sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
- super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
- GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None,
+ norm_type: float = 2.0,
+ scale_grad_by_freq: bool = False,
+ sparse: bool = False,
+ _weight: Optional[Tensor] = None,
+ ) -> None:
+ super(Embedding, self).__init__(
+ num_embeddings,
+ embedding_dim,
+ padding_idx,
+ max_norm,
+ norm_type,
+ scale_grad_by_freq,
+ sparse,
+ _weight,
+ )
+ GlobalOptimManager.get_instance().register_module_override(
+ self, "weight", {"optim_bits": 32}
+ )
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
self._fill_padding_idx_with_zero()
- ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+ """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
to make the Layer compatible with Pytorch < 1.9.
This means that if this changes in future PyTorch releases this need to change too
which is cumbersome. However, with this we can ensure compatibility with previous
PyTorch releases.
- '''
+ """
+
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
@@ -71,13 +117,22 @@ class Embedding(torch.nn.Embedding):
def forward(self, input: Tensor) -> Tensor:
emb = F.embedding(
- input, self.weight, self.padding_idx, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.sparse)
+ input,
+ self.weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ )
return emb
+
class Int8Params(torch.nn.Parameter):
- def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
+ def __new__(
+ cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
+ ):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
cls.SCB = None
@@ -96,14 +151,18 @@ class Int8Params(torch.nn.Parameter):
del CBt
del SCBt
self.data = CB
- setattr(self, 'CB', CB)
- setattr(self, 'SCB', SCB)
+ setattr(self, "CB", CB)
+ setattr(self, "SCB", SCB)
return self
@overload
- def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
- non_blocking: bool = ...) -> T:
+ def to(
+ self: T,
+ device: Optional[Union[int, device]] = ...,
+ dtype: Optional[Union[dtype, str]] = ...,
+ non_blocking: bool = ...,
+ ) -> T:
...
@overload
@@ -115,23 +174,41 @@ class Int8Params(torch.nn.Parameter):
...
def to(self, *args, **kwargs):
- device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
-
- if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device)
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
+ *args, **kwargs
+ )
+
+ if (
+ device is not None
+ and device.type == "cuda"
+ and self.data.device.type == "cpu"
+ ):
+ return self.cuda(device)
else:
- new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights)
+ new_param = Int8Params(
+ super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+ requires_grad=self.requires_grad,
+ has_fp16_weights=self.has_fp16_weights,
+ )
new_param.CB = self.CB
new_param.SCB = self.SCB
return new_param
-
class Linear8bitLt(nn.Linear):
- def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None):
+ def __init__(
+ self,
+ input_features,
+ output_features,
+ bias=True,
+ has_fp16_weights=True,
+ threshold=0.0,
+ index=None,
+ ):
super(Linear8bitLt, self).__init__(input_features, output_features, bias)
self.state = bnb.MatmulLtState()
- self.index=index
+ self.index = index
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
@@ -149,9 +226,10 @@ class Linear8bitLt(nn.Linear):
def forward(self, x):
self.state.is_training = self.training
- if self.weight.CB is not None: self.init_8bit_state()
- #assert not self.state.has_fp16_weights
- #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
+ if self.weight.CB is not None:
+ self.init_8bit_state()
+ # assert not self.state.has_fp16_weights
+ # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
out = bnb.matmul(x, self.weight, state=self.state)
@@ -166,8 +244,18 @@ class Linear8bitLt(nn.Linear):
return out
+
class Linear8bit(nn.Linear):
- def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False):
+ def __init__(
+ self,
+ input_features,
+ output_features,
+ bias=True,
+ quant_type="vector",
+ index=None,
+ args=None,
+ sparse_decomp=False,
+ ):
super(Linear8bit, self).__init__(input_features, output_features, bias)
self.quant_type = quant_type
self.index = index
@@ -178,15 +266,24 @@ class Linear8bit(nn.Linear):
self.iter += 1
if self.iter % self.args.clip_freq == 0:
with torch.no_grad():
- maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx)
+ maxval, maxidx = torch.topk(
+ torch.abs(self.weight.flatten()), k=self.args.clip_idx
+ )
if not dist.is_initialized() or dist.get_rank() == 0:
- print('clip', maxval[-1].item())
+ print("clip", maxval[-1].item())
self.weight.clip_(-maxval[-1], maxval[-1])
-
if self.args is not None:
- out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type)
+ out = bnb.nn.functional.sparse_decomposed_linear8bit(
+ x,
+ self.weight,
+ self.bias,
+ qval=self.args.sparse_decomp_val,
+ quant_type=self.args.quant_type,
+ )
else:
- out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type)
+ out = bnb.nn.functional.linear8bit(
+ x, self.weight, self.bias, quant_type=self.args.quant_type
+ )
return out
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index 42b5bc0..a76d717 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -1,6 +1,6 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.cextension import COMPILED_WITH_CUDA
diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py
index 4f51250..43e3973 100644
--- a/bitsandbytes/optim/adagrad.py
+++ b/bitsandbytes/optim/adagrad.py
@@ -1,12 +1,25 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
+
class Adagrad(Optimizer1State):
- def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
- optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ lr_decay=0,
+ weight_decay=0,
+ initial_accumulator_value=0,
+ eps=1e-10,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
@@ -14,15 +27,39 @@ class Adagrad(Optimizer1State):
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
- raise ValueError('Initial accumulator value != 0.0 not supported!')
+ raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
- raise ValueError('Lr Decay != 0.0 not supported!')
- super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ raise ValueError("Lr Decay != 0.0 not supported!")
+ super(Adagrad, self).__init__(
+ "adagrad",
+ params,
+ lr,
+ (0.0, 0.0),
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adagrad8bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
- optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ lr_decay=0,
+ weight_decay=0,
+ initial_accumulator_value=0,
+ eps=1e-10,
+ optim_bits=8,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
@@ -30,16 +67,40 @@ class Adagrad8bit(Optimizer1State):
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
- raise ValueError('Initial accumulator value != 0.0 not supported!')
+ raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
- raise ValueError('Lr Decay != 0.0 not supported!')
+ raise ValueError("Lr Decay != 0.0 not supported!")
assert block_wise
- super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ super(Adagrad8bit, self).__init__(
+ "adagrad",
+ params,
+ lr,
+ (0.0, 0.0),
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adagrad32bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
- optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ lr_decay=0,
+ weight_decay=0,
+ initial_accumulator_value=0,
+ eps=1e-10,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
@@ -47,8 +108,19 @@ class Adagrad32bit(Optimizer1State):
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
- raise ValueError('Initial accumulator value != 0.0 not supported!')
+ raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
- raise ValueError('Lr Decay != 0.0 not supported!')
- super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ raise ValueError("Lr Decay != 0.0 not supported!")
+ super(Adagrad32bit, self).__init__(
+ "adagrad",
+ params,
+ lr,
+ (0.0, 0.0),
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index ed1b9f0..5cfaa28 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -1,6 +1,6 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
@@ -8,29 +8,97 @@ import os
import torch
import torch.distributed as dist
-from bitsandbytes.optim.optimizer import Optimizer2State
+
import bitsandbytes.functional as F
+from bitsandbytes.optim.optimizer import Optimizer2State
+
class Adam(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam, self).__init__('adam', params, lr, betas, eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adam8bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam8bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam8bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adam32bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam32bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam32bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
class AnalysisAdam(torch.optim.Optimizer):
@@ -68,8 +136,8 @@ class AnalysisAdam(torch.optim.Optimizer):
eps=1e-8,
weight_decay=0,
amsgrad=False,
- bnb_analysis='dynamic-blockwise',
- savedir=None
+ bnb_analysis="dynamic-blockwise",
+ savedir=None,
):
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
@@ -124,9 +192,13 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = torch.zeros_like(p_data_fp32)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
- state['abserrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
- state['relerrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
- state['counts'] = torch.zeros((256, 256), device=p_data_fp32.device)
+ state["abserrors"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
+ state["relerrors"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
+ state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -143,9 +215,9 @@ class AnalysisAdam(torch.optim.Optimizer):
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
- e = state['abserrors']
- rele = state['relerrors']
- counts = state['counts']
+ e = state["abserrors"]
+ rele = state["relerrors"]
+ counts = state["counts"]
if group["weight_decay"] != 0:
p_data_fp32.add_(
@@ -156,77 +228,84 @@ class AnalysisAdam(torch.optim.Optimizer):
if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"]
-
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
- update_fp32 = exp_avg/denom
+ update_fp32 = exp_avg / denom
- if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000*1000:
+ if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
# embedding layer or too small
- p_data_fp32 += -step_size*update_fp32
+ p_data_fp32 += -step_size * update_fp32
else:
- if self.analysis == 'dynamic-blockwise':
+ if self.analysis == "dynamic-blockwise":
code1 = F.create_dynamic_map(signed=True).to(p.device)
code2 = F.create_dynamic_map(signed=False).to(p.device)
C1, S1 = F.quantize_blockwise(exp_avg, code=code1)
state1 = F.dequantize_blockwise(C1, S1)
C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2)
state2 = F.dequantize_blockwise(C2, S2)
- elif self.analysis == 'dynamic':
+ elif self.analysis == "dynamic":
code1 = F.create_dynamic_map(signed=True).to(p.device)
code2 = F.create_dynamic_map(signed=False).to(p.device)
C1, S1 = F.quantize(exp_avg, code=code1)
state1 = F.dequantize(C1, S1)
C2, S2 = F.quantize(exp_avg_sq, code=code2)
state2 = F.dequantize(C2, S2)
- elif self.analysis == 'linear':
+ elif self.analysis == "linear":
code1 = F.create_linear_map(signed=True).to(p.device)
code2 = F.create_linear_map(signed=False).to(p.device)
C1, S1 = F.quantize(exp_avg, code=code1)
state1 = F.dequantize(C1, S1)
C2, S2 = F.quantize(exp_avg_sq, code=code2)
state2 = F.dequantize(C2, S2)
- elif self.analysis == 'quantile':
+ elif self.analysis == "quantile":
code1 = F.estimate_quantiles(exp_avg)
code2 = F.estimate_quantiles(exp_avg_sq)
C1 = F.quantize_no_absmax(exp_avg, code=code1)
state1 = F.dequantize_no_absmax(C1, code1)
C2 = F.quantize_no_absmax(exp_avg_sq, code=code2)
state2 = F.dequantize_no_absmax(C2, code2)
- elif self.analysis == 'my-quantization-routine':
+ elif self.analysis == "my-quantization-routine":
pass
# 1. get code
# 2. quantize
# 3. dequantize
# Error will be calculated automatically!
else:
- raise ValueError(f'Invalid analysis value: {self.analysis}!')
+ raise ValueError(f"Invalid analysis value: {self.analysis}!")
denom = state2.sqrt().add_(group["eps"])
- update_8bit = state1/denom
+ update_8bit = state1 / denom
- abserr = torch.abs(update_8bit-update_fp32)
- relerr = abserr/torch.abs(update_fp32+1e-6)
+ abserr = torch.abs(update_8bit - update_fp32)
+ relerr = abserr / torch.abs(update_fp32 + 1e-6)
C1, C2 = C1.int(), C2.int()
F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
- F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))
-
- p_data_fp32 += -step_size*update_fp32
+ F.histogram_scatter_add_2d(
+ counts, C1.int(), C2.int(), torch.ones_like(abserr)
+ )
+ p_data_fp32 += -step_size * update_fp32
if not dist.is_initialized() or dist.get_rank() == 0:
- if self.savedir != '' and state['step'] % 100 == 0:
- if not os.path.exists(self.savedir): os.makedirs(self.savedir)
- shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape])
- pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
- pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
- pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
+ if self.savedir != "" and state["step"] % 100 == 0:
+ if not os.path.exists(self.savedir):
+ os.makedirs(self.savedir)
+ shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+ pathe = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
+ )
+ pathrele = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
+ )
+ pathcounts = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_counts.pkl"
+ )
torch.save(e, pathe)
torch.save(rele, pathrele)
torch.save(counts, pathcounts)
@@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer):
if p.data.dtype in {torch.float16, torch.bfloat16}:
p.data.copy_(p_data_fp32)
-
-
return loss
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index c4f0355..d0b3bde 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -1,27 +1,93 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
+
class AdamW(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(AdamW, self).__init__('adam', params, lr, betas, eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2,
+ amsgrad=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(AdamW, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class AdamW8bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=1e-2, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(AdamW8bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
-class AdamW32bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=1e-2, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+class AdamW32bit(Optimizer2State):
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(AdamW32bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py
index 58cc13d..8f365f7 100644
--- a/bitsandbytes/optim/lamb.py
+++ b/bitsandbytes/optim/lamb.py
@@ -1,28 +1,105 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
+
class LAMB(Optimizer2State):
- def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
- super(LAMB, self).__init__('lamb', params, lr, betas, eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ adam_w_mode=True,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=False,
+ max_unorm=1.0,
+ ):
+ super(LAMB, self).__init__(
+ "lamb",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ max_unorm=1.0,
+ )
-class LAMB8bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
- super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
-class LAMB32bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
- super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+class LAMB8bit(Optimizer2State):
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ adam_w_mode=True,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=False,
+ max_unorm=1.0,
+ ):
+ super(LAMB8bit, self).__init__(
+ "lamb",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ max_unorm=1.0,
+ )
+class LAMB32bit(Optimizer2State):
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ adam_w_mode=True,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=False,
+ max_unorm=1.0,
+ ):
+ super(LAMB32bit, self).__init__(
+ "lamb",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ max_unorm=1.0,
+ )
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
index 912520d..c6cf5c6 100644
--- a/bitsandbytes/optim/lars.py
+++ b/bitsandbytes/optim/lars.py
@@ -1,43 +1,121 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
-
from torch.optim import Optimizer
+
from bitsandbytes.optim.optimizer import Optimizer1State
+
class LARS(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(f"LARS without momentum is not supported!")
+ super(LARS, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
+
class LARS8bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(f"LARS without momentum is not supported!")
+ super(LARS8bit, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
+
class LARS32bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(f"LARS without momentum is not supported!")
+ super(LARS32bit, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
class PytorchLARS(Optimizer):
- def __init__(self, params, lr=0.01, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr=0.01,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ max_unorm=0.02,
+ ):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
@@ -45,8 +123,14 @@ class PytorchLARS(Optimizer):
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
- weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm)
+ defaults = dict(
+ lr=lr,
+ momentum=momentum,
+ dampening=dampening,
+ weight_decay=weight_decay,
+ nesterov=nesterov,
+ max_unorm=max_unorm,
+ )
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(PytorchLARS, self).__init__(params, defaults)
@@ -54,7 +138,7 @@ class PytorchLARS(Optimizer):
def __setstate__(self, state):
super(PytorchLARS, self).__setstate__(state)
for group in self.param_groups:
- group.setdefault('nesterov', False)
+ group.setdefault("nesterov", False)
@torch.no_grad()
def step(self, closure=None):
@@ -73,15 +157,16 @@ class PytorchLARS(Optimizer):
params_with_grad = []
d_p_list = []
momentum_buffer_list = []
- weight_decay = group['weight_decay']
- momentum = group['momentum']
- dampening = group['dampening']
- nesterov = group['nesterov']
- max_unorm = group['max_unorm']
- lr = group['lr']
+ weight_decay = group["weight_decay"]
+ momentum = group["momentum"]
+ dampening = group["dampening"]
+ nesterov = group["nesterov"]
+ max_unorm = group["max_unorm"]
+ lr = group["lr"]
- for p in group['params']:
- if p.grad is None: continue
+ for p in group["params"]:
+ if p.grad is None:
+ continue
state = self.state[p]
d_p = p.grad
@@ -89,16 +174,16 @@ class PytorchLARS(Optimizer):
d_p = d_p.add(param, alpha=weight_decay)
if momentum != 0:
- buf = state.get('momentum_buffer', None)
+ buf = state.get("momentum_buffer", None)
if buf is None:
buf = torch.clone(d_p).detach()
- state['momentum_buffer']= buf
+ state["momentum_buffer"] = buf
else:
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
- update = d_p + buf*momentum
+ update = d_p + buf * momentum
else:
update = buf
@@ -107,9 +192,9 @@ class PytorchLARS(Optimizer):
assert p.dtype == torch.float32
pnorm = torch.norm(p.detach())
unorm = torch.norm(update)
- if unorm > max_unorm*pnorm:
- update_scale = max_unorm*pnorm/unorm
+ if unorm > max_unorm * pnorm:
+ update_scale = max_unorm * pnorm / unorm
- p.add_(update, alpha=-lr*update_scale)
+ p.add_(update, alpha=-lr * update_scale)
return loss
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 5a5bb1e..b942e34 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -1,13 +1,16 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from collections import abc as container_abcs
+from collections import defaultdict
+from copy import deepcopy
+from itertools import chain
+
import torch
+
import bitsandbytes.functional as F
-from copy import deepcopy
-from itertools import chain
-from collections import defaultdict, abc as container_abcs
class MockArgs(object):
def __init__(self, initial_data):
@@ -19,7 +22,7 @@ class GlobalOptimManager(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.pid2config = {}
@@ -38,15 +41,15 @@ class GlobalOptimManager(object):
def register_parameters(self, params):
param_groups = list(params)
if not isinstance(param_groups[0], dict):
- param_groups = [{'params': param_groups}]
+ param_groups = [{"params": param_groups}]
for group_index, group in enumerate(param_groups):
- for p_index, p in enumerate(group['params']):
+ for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
def override_config(self, parameters, key=None, value=None, key_value_dict=None):
- '''
+ """
Overrides initial optimizer config for specific parameters.
The key-values of the optimizer config for the input parameters are overidden
@@ -63,7 +66,7 @@ class GlobalOptimManager(object):
The value for the hyperparamters.
key_value_dict : dict
A dictionary with multiple key-values to override.
- '''
+ """
self.uses_config_override = True
if isinstance(parameters, torch.nn.Parameter):
parameters = [parameters]
@@ -75,16 +78,16 @@ class GlobalOptimManager(object):
if key_value_dict is not None:
for p in parameters:
- if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict)
- else: self.pid2config[id(p)] = key_value_dict
+ if id(p) in self.pid2config:
+ self.pid2config[id(p)].update(key_value_dict)
+ else:
+ self.pid2config[id(p)] = key_value_dict
def register_module_override(self, module, param_name, config):
self.module_weight_config_triple.append((module, param_name, config))
-
class Optimizer8bit(torch.optim.Optimizer):
-
def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults)
self.initialized = False
@@ -92,23 +95,32 @@ class Optimizer8bit(torch.optim.Optimizer):
self.mng = GlobalOptimManager.get_instance()
self.non_castable_tensor_keys = set(
- ['qmap1', 'qmap2',
- 'max1', 'max2',
- 'new_max1', 'new_max2',
- 'state1', 'state2',
- 'gnorm_vec', 'absmax1', 'absmax2',
- 'unorm_vec'])
-
- if optim_bits == 8: self.fill_qmap()
+ [
+ "qmap1",
+ "qmap2",
+ "max1",
+ "max2",
+ "new_max1",
+ "new_max2",
+ "state1",
+ "state2",
+ "gnorm_vec",
+ "absmax1",
+ "absmax2",
+ "unorm_vec",
+ ]
+ )
+
+ if optim_bits == 8:
+ self.fill_qmap()
def fill_qmap(self):
- self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True)
- self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False)
+ self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
+ self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
def __setstate__(self, state):
super(Optimizer8bit, self).__setstate__(state)
-
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
@@ -120,21 +132,28 @@ class Optimizer8bit(torch.optim.Optimizer):
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
- saved_groups = state_dict['param_groups']
+ saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups):
- raise ValueError("loaded state dict has a different number of "
- "parameter groups")
- param_lens = (len(g['params']) for g in groups)
- saved_lens = (len(g['params']) for g in saved_groups)
+ raise ValueError(
+ "loaded state dict has a different number of " "parameter groups"
+ )
+ param_lens = (len(g["params"]) for g in groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
- raise ValueError("loaded state dict contains a parameter group "
- "that doesn't match the size of optimizer's group")
+ raise ValueError(
+ "loaded state dict contains a parameter group "
+ "that doesn't match the size of optimizer's group"
+ )
# Update the state
- id_map = {old_id: p for old_id, p in
- zip(chain.from_iterable((g['params'] for g in saved_groups)),
- chain.from_iterable((g['params'] for g in groups)))}
+ id_map = {
+ old_id: p
+ for old_id, p in zip(
+ chain.from_iterable((g["params"] for g in saved_groups)),
+ chain.from_iterable((g["params"] for g in groups)),
+ )
+ }
def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param."""
@@ -161,7 +180,7 @@ class Optimizer8bit(torch.optim.Optimizer):
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
- for k, v in state_dict['state'].items():
+ for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
@@ -170,15 +189,15 @@ class Optimizer8bit(torch.optim.Optimizer):
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
- new_group['params'] = group['params']
+ new_group["params"] = group["params"]
return new_group
- param_groups = [
- update_group(g, ng) for g, ng in zip(groups, saved_groups)]
- self.__setstate__({'state': state, 'param_groups': param_groups})
+
+ param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
for gindex, group in enumerate(self.param_groups):
- for pindex, p in enumerate(group['params']):
+ for pindex, p in enumerate(group["params"]):
if p in self.state:
values = self.state[p]
for k, v in values.items():
@@ -189,17 +208,23 @@ class Optimizer8bit(torch.optim.Optimizer):
for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr)
assert pmodule is not None
- assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
+ assert isinstance(pmodule, torch.Tensor) or isinstance(
+ pmodule, torch.Parameter
+ )
found = False
for gindex, group in enumerate(self.param_groups):
- if found: break
- for pindex, p in enumerate(group['params']):
- if found: break
+ if found:
+ break
+ for pindex, p in enumerate(group["params"]):
+ if found:
+ break
if id(p) == id(pmodule):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
- self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
+ self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
+ id(p)
+ ]
found = True
@torch.no_grad()
@@ -219,11 +244,11 @@ class Optimizer8bit(torch.optim.Optimizer):
if not self.initialized:
self.check_overrides()
- self.to_gpu() # needed for fairseq pure fp16 training
+ self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
for gindex, group in enumerate(self.param_groups):
- for pindex, p in enumerate(group['params']):
+ for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
@@ -236,58 +261,70 @@ class Optimizer8bit(torch.optim.Optimizer):
def get_config(self, gindex, pindex, group):
config = {}
- config['betas'] = group['betas']
- config['eps'] = group['eps']
- config['weight_decay'] = group['weight_decay']
- config['lr'] = group['lr']
- config['optim_bits'] = self.args.optim_bits
- config['min_8bit_size'] = self.args.min_8bit_size
- config['percentile_clipping'] = self.args.percentile_clipping
- config['block_wise'] = self.args.block_wise
- config['max_unorm'] = self.args.max_unorm
- config['skip_zeros'] = self.args.skip_zeros
+ config["betas"] = group["betas"]
+ config["eps"] = group["eps"]
+ config["weight_decay"] = group["weight_decay"]
+ config["lr"] = group["lr"]
+ config["optim_bits"] = self.args.optim_bits
+ config["min_8bit_size"] = self.args.min_8bit_size
+ config["percentile_clipping"] = self.args.percentile_clipping
+ config["block_wise"] = self.args.block_wise
+ config["max_unorm"] = self.args.max_unorm
+ config["skip_zeros"] = self.args.skip_zeros
if (gindex, pindex) in self.mng.index2config:
config.update(self.mng.index2config[(gindex, pindex)])
return config
def init_state(self, group, p, gindex, pindex):
- raise NotImplementedError(f'init_state method needs to be overidden')
+ raise NotImplementedError(f"init_state method needs to be overidden")
def update_step(self, group, p, gindex, pindex):
- raise NotImplementedError(f'The update_step method needs to be overidden')
+ raise NotImplementedError(f"The update_step method needs to be overidden")
+
class Optimizer2State(Optimizer8bit):
- def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0.0, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
- skip_zeros=False):
+ def __init__(
+ self,
+ optimizer_name,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0.0,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ max_unorm=0.0,
+ skip_zeros=False,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if isinstance(betas, str):
# format: '(beta1, beta2)'
- betas = betas.replace('(', '').replace(')', '').strip().split(',')
+ betas = betas.replace("(", "").replace(")", "").strip().split(",")
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay)
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
- args['optim_bits'] = optim_bits
- args['percentile_clipping'] = 100
- args['min_8bit_size'] = min_8bit_size
- args['percentile_clipping'] = percentile_clipping
- args['block_wise'] = block_wise
- args['max_unorm'] = max_unorm
- args['skip_zeros'] = skip_zeros
+ args["optim_bits"] = optim_bits
+ args["percentile_clipping"] = 100
+ args["min_8bit_size"] = min_8bit_size
+ args["percentile_clipping"] = percentile_clipping
+ args["block_wise"] = block_wise
+ args["max_unorm"] = max_unorm
+ args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
@@ -299,50 +336,83 @@ class Optimizer2State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
- if config['optim_bits'] == 32:
+ if config["optim_bits"] == 32:
dtype = torch.float32
- elif config['optim_bits'] == 8:
+ elif config["optim_bits"] == 8:
dtype = torch.uint8
- else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+ else:
+ raise NotImplementedError(
+ f'Amount of optimizer bits not supported: {config["optim_bits"]}'
+ )
- if p.numel() < config['min_8bit_size']: dtype = torch.float32
+ if p.numel() < config["min_8bit_size"]:
+ dtype = torch.float32
state = self.state[p]
- state['step'] = 0
+ state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
- state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
+ state["state2"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
elif dtype == torch.uint8:
- if state['step'] == 0:
- if 'dynamic' not in self.name2qmap: self.fill_qmap()
- self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
- self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)
-
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap1'] = self.name2qmap['dynamic']
-
- state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap2'] = self.name2qmap['udynamic']
-
- if config['block_wise']:
+ if state["step"] == 0:
+ if "dynamic" not in self.name2qmap:
+ self.fill_qmap()
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+ self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
+
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap1"] = self.name2qmap["dynamic"]
+
+ state["state2"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap2"] = self.name2qmap["udynamic"]
+
+ if config["block_wise"]:
n = p.numel()
- blocks = n//2048
+ blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
- state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
- state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ state["absmax1"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
+ state["absmax2"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
else:
- state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["new_max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+ state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["new_max2"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
- if config['percentile_clipping'] < 100:
- state['gnorm_vec'] = torch.zeros((100,), device=p.device)
+ if config["percentile_clipping"] < 100:
+ state["gnorm_vec"] = torch.zeros((100,), device=p.device)
- if config['max_unorm'] > 0.0:
- state['unorm_vec'] = torch.zeros((1,), device=p.device)
+ if config["max_unorm"] > 0.0:
+ state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
@@ -351,41 +421,101 @@ class Optimizer2State(Optimizer8bit):
config = self.get_config(gindex, pindex, group)
- state['step'] += 1
- step = state['step']
+ state["step"] += 1
+ step = state["step"]
- if config['percentile_clipping'] < 100:
- current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ if config["percentile_clipping"] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
+ grad, state["gnorm_vec"], step, config["percentile_clipping"]
+ )
else:
gnorm_scale = 1.0
- if state['state1'].dtype == torch.float:
- F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
- state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])
-
- elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
- F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'],
- config['weight_decay'], gnorm_scale=gnorm_scale,
- unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+ if state["state1"].dtype == torch.float:
+ F.optimizer_update_32bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ config["betas"][0],
+ config["eps"],
+ step,
+ config["lr"],
+ state["state2"],
+ config["betas"][1],
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ skip_zeros=config["skip_zeros"],
+ )
+
+ elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
+ F.optimizer_update_8bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ state["state2"],
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ state["qmap2"],
+ state["max1"],
+ state["max2"],
+ state["new_max1"],
+ state["new_max2"],
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ )
# swap maxes
- state['max1'], state['new_max1'] = state['new_max1'], state['max1']
- state['max2'], state['new_max2'] = state['new_max2'], state['max2']
- elif state['state1'].dtype == torch.uint8 and config['block_wise']:
- F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
- config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
+ state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
+ state["max2"], state["new_max2"] = state["new_max2"], state["max2"]
+ elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
+ F.optimizer_update_8bit_blockwise(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ state["state2"],
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ state["qmap2"],
+ state["absmax1"],
+ state["absmax2"],
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ skip_zeros=config["skip_zeros"],
+ )
class Optimizer1State(Optimizer8bit):
- def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
- weight_decay=0.0, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
- skip_zeros=False):
+ def __init__(
+ self,
+ optimizer_name,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.0),
+ eps=1e-8,
+ weight_decay=0.0,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ max_unorm=0.0,
+ skip_zeros=False,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -395,19 +525,18 @@ class Optimizer1State(Optimizer8bit):
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay)
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
- args['optim_bits'] = optim_bits
- args['percentile_clipping'] = 100
- args['min_8bit_size'] = min_8bit_size
- args['percentile_clipping'] = percentile_clipping
- args['block_wise'] = block_wise
- args['max_unorm'] = max_unorm
- args['skip_zeros'] = skip_zeros
+ args["optim_bits"] = optim_bits
+ args["percentile_clipping"] = 100
+ args["min_8bit_size"] = min_8bit_size
+ args["percentile_clipping"] = percentile_clipping
+ args["block_wise"] = block_wise
+ args["max_unorm"] = max_unorm
+ args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
@@ -419,43 +548,61 @@ class Optimizer1State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
- if config['optim_bits'] == 32:
+ if config["optim_bits"] == 32:
dtype = torch.float32
- elif config['optim_bits'] == 8:
+ elif config["optim_bits"] == 8:
dtype = torch.uint8
- else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+ else:
+ raise NotImplementedError(
+ f'Amount of optimizer bits not supported: {config["optim_bits"]}'
+ )
- if p.numel() < config['min_8bit_size']: dtype = torch.float32
+ if p.numel() < config["min_8bit_size"]:
+ dtype = torch.float32
state = self.state[p]
- state['step'] = 0
+ state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
elif dtype == torch.uint8:
- if state['step'] == 0:
- if 'dynamic' not in self.name2qmap: self.fill_qmap()
- self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
-
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap1'] = self.name2qmap['dynamic']
-
- if config['block_wise']:
+ if state["step"] == 0:
+ if "dynamic" not in self.name2qmap:
+ self.fill_qmap()
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap1"] = self.name2qmap["dynamic"]
+
+ if config["block_wise"]:
n = p.numel()
- blocks = n//2048
+ blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
- state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ state["absmax1"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
else:
- state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
-
- if config['percentile_clipping'] < 100:
- state['gnorm_vec'] = torch.zeros((100,), device=p.device)
+ state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["new_max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
- if config['max_unorm'] > 0.0:
- state['unorm_vec'] = torch.zeros((1,), device=p.device)
+ if config["percentile_clipping"] < 100:
+ state["gnorm_vec"] = torch.zeros((100,), device=p.device)
+ if config["max_unorm"] > 0.0:
+ state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
@@ -464,29 +611,77 @@ class Optimizer1State(Optimizer8bit):
config = self.get_config(gindex, pindex, group)
- state['step'] += 1
- step = state['step']
+ state["step"] += 1
+ step = state["step"]
- if config['percentile_clipping'] < 100:
- current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ if config["percentile_clipping"] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
+ grad, state["gnorm_vec"], step, config["percentile_clipping"]
+ )
else:
gnorm_scale = 1.0
- if state['state1'].dtype == torch.float:
- F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
- None, 0.0, config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
- skip_zeros=config['skip_zeros'])
-
- elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
- F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None,
- config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
-
- state['max1'], state['new_max1'] = state['new_max1'], state['max1']
- elif state['state1'].dtype == torch.uint8 and config['block_wise']:
- F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], None, state['absmax1'], None,
- config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
+ if state["state1"].dtype == torch.float:
+ F.optimizer_update_32bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ config["betas"][0],
+ config["eps"],
+ step,
+ config["lr"],
+ None,
+ 0.0,
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ skip_zeros=config["skip_zeros"],
+ )
+
+ elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
+ F.optimizer_update_8bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ None,
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ None,
+ state["max1"],
+ None,
+ state["new_max1"],
+ None,
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ )
+
+ state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
+ elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
+ F.optimizer_update_8bit_blockwise(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ None,
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ None,
+ state["absmax1"],
+ None,
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ skip_zeros=config["skip_zeros"],
+ )
diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py
index 0f1ffaa..679f783 100644
--- a/bitsandbytes/optim/rmsprop.py
+++ b/bitsandbytes/optim/rmsprop.py
@@ -1,36 +1,109 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
+
class RMSprop(Optimizer1State):
- def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ alpha=0.99,
+ eps=1e-8,
+ weight_decay=0,
+ momentum=0,
+ centered=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if alpha == 0:
- raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
+ raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
if centered:
- raise NotImplementedError(f'Centered RMSprop is not supported!')
- super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"Centered RMSprop is not supported!")
+ super(RMSprop, self).__init__(
+ "rmsprop",
+ params,
+ lr,
+ (alpha, momentum),
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class RMSprop8bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ alpha=0.99,
+ eps=1e-8,
+ weight_decay=0,
+ momentum=0,
+ centered=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if alpha == 0:
- raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
+ raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
if centered:
- raise NotImplementedError(f'Centered RMSprop is not supported!')
- super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"Centered RMSprop is not supported!")
+ super(RMSprop8bit, self).__init__(
+ "rmsprop",
+ params,
+ lr,
+ (alpha, momentum),
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class RMSprop32bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ alpha=0.99,
+ eps=1e-8,
+ weight_decay=0,
+ momentum=0,
+ centered=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if alpha == 0:
- raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
+ raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
if centered:
- raise NotImplementedError(f'Centered RMSprop is not supported!')
- super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"Centered RMSprop is not supported!")
+ super(RMSprop32bit, self).__init__(
+ "rmsprop",
+ params,
+ lr,
+ (alpha, momentum),
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py
index 0529879..f7b8934 100644
--- a/bitsandbytes/optim/sgd.py
+++ b/bitsandbytes/optim/sgd.py
@@ -1,32 +1,99 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
+
class SGD(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if momentum == 0:
- raise NotImplementedError(f'SGD without momentum is not supported!')
- super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"SGD without momentum is not supported!")
+ super(SGD, self).__init__(
+ "momentum",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class SGD8bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if momentum == 0:
- raise NotImplementedError(f'SGD without momentum is not supported!')
- super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"SGD without momentum is not supported!")
+ super(SGD8bit, self).__init__(
+ "momentum",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class SGD32bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if momentum == 0:
- raise NotImplementedError(f'SGD without momentum is not supported!')
- super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"SGD without momentum is not supported!")
+ super(SGD32bit, self).__init__(
+ "momentum",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index a9eddf9..29b9c90 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -1,7 +1,9 @@
import sys
+
def print_err(s: str) -> None:
print(s, file=sys.stderr)
+
def warn_of_missing_prerequisite(s: str) -> None:
- print_err('WARNING, missing pre-requisite: ' + s)
+ print_err("WARNING, missing pre-requisite: " + s)