summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile30
-rw-r--r--bitsandbytes/__init__.py28
-rw-r--r--bitsandbytes/__main__.py96
-rw-r--r--bitsandbytes/autograd/_functions.py170
-rw-r--r--bitsandbytes/cextension.py39
-rw-r--r--bitsandbytes/cuda_setup/__init__.py0
-rw-r--r--bitsandbytes/cuda_setup/compute_capability.py79
-rw-r--r--bitsandbytes/cuda_setup/env_vars.py51
-rw-r--r--bitsandbytes/cuda_setup/main.py127
-rw-r--r--bitsandbytes/cuda_setup/paths.py126
-rw-r--r--bitsandbytes/debug_cli.py26
-rw-r--r--bitsandbytes/functional.py1303
-rw-r--r--bitsandbytes/nn/__init__.py8
-rw-r--r--bitsandbytes/nn/modules.py223
-rw-r--r--bitsandbytes/optim/__init__.py6
-rw-r--r--bitsandbytes/optim/adagrad.py126
-rw-r--r--bitsandbytes/optim/adam.py198
-rw-r--r--bitsandbytes/optim/adamw.py104
-rw-r--r--bitsandbytes/optim/lamb.py117
-rw-r--r--bitsandbytes/optim/lars.py183
-rw-r--r--bitsandbytes/optim/optimizer.py622
-rw-r--r--bitsandbytes/optim/rmsprop.py121
-rw-r--r--bitsandbytes/optim/sgd.py109
-rw-r--r--bitsandbytes/utils.py32
-rw-r--r--deploy_from_slurm.sh148
-rw-r--r--environment.yml14
-rw-r--r--install_cuda.sh5
-rw-r--r--quicktest.py90
-rw-r--r--setup.py32
-rw-r--r--tests/test_autograd.py273
-rw-r--r--tests/test_cuda_setup_evaluator.py150
-rw-r--r--tests/test_functional.py1574
-rw-r--r--tests/test_modules.py349
-rw-r--r--tests/test_optim.py446
-rw-r--r--to_be_fixed__complaints_by_linter.log149
35 files changed, 5019 insertions, 2135 deletions
diff --git a/Makefile b/Makefile
index 3e95b35..6194fe3 100644
--- a/Makefile
+++ b/Makefile
@@ -5,6 +5,13 @@ GPP:= /usr/bin/g++
ifeq ($(CUDA_HOME),)
CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev)
endif
+
+ifndef CUDA_VERSION
+$(warning WARNING: CUDA_VERSION not set. Call make with CUDA string, for example: make cuda11x CUDA_VERSION=115 or make cpuonly CUDA_VERSION=CPU)
+CUDA_VERSION:=
+endif
+
+
NVCC := $(CUDA_HOME)/bin/nvcc
###########################################
@@ -53,44 +60,46 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
- $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+ $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
- $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+ $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
- $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+ $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda110_nomatmul: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
- $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+ $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda11x_nomatmul: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
- $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+ $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda110: $(BUILD_DIR) env
$(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
- $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+ $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
cuda11x: $(BUILD_DIR) env
$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
- $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+ $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
cpuonly: $(BUILD_DIR) env
- $(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so
+ $(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so
env:
@echo "ENVIRONMENT"
@echo "============================"
+ @echo "CUDA_VERSION: $(CUDA_VERSION)"
+ @echo "============================"
@echo "NVCC path: $(NVCC)"
@echo "GPP path: $(GPP) VERSION: `$(GPP) --version | head -n 1`"
@echo "CUDA_HOME: $(CUDA_HOME)"
@@ -108,7 +117,10 @@ $(ROOT_DIR)/dependencies/cub:
cd dependencies/cub; git checkout 1.11.0
clean:
- rm cuda_build/* ./bitsandbytes/libbitsandbytes.so
+ rm build/*
cleaneggs:
rm -rf *.egg*
+
+cleanlibs:
+ rm ./bitsandbytes/libbitsandbytes*.so
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 3c3affa..7901f96 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -1,16 +1,26 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .nn import modules
-from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
+from .autograd._functions import (
+ MatmulLtState,
+ bmm_cublas,
+ matmul,
+ matmul_cublas,
+ mm_cublas,
+)
from .cextension import COMPILED_WITH_CUDA
+from .nn import modules
+from . import cuda_setup
if COMPILED_WITH_CUDA:
from .optim import adam
-__pdoc__ = {'libbitsandbytes': False,
- 'optim.optimizer.Optimizer8bit': False,
- 'optim.optimizer.MockArgs': False
- }
+__pdoc__ = {
+ "libbitsandbytes": False,
+ "optim.optimizer.Optimizer8bit": False,
+ "optim.optimizer.MockArgs": False,
+}
+
+PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes"
diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py
new file mode 100644
index 0000000..7f3d24c
--- /dev/null
+++ b/bitsandbytes/__main__.py
@@ -0,0 +1,96 @@
+# from bitsandbytes.debug_cli import cli
+
+# cli()
+import os
+import sys
+import torch
+
+
+HEADER_WIDTH = 60
+
+
+def print_header(
+ txt: str, width: int = HEADER_WIDTH, filler: str = "+"
+) -> None:
+ txt = f" {txt} " if txt else ""
+ print(txt.center(width, filler))
+
+
+def print_debug_info() -> None:
+ print(
+ "\nAbove we output some debug information. Please provide this info when "
+ f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n"
+ )
+
+
+print_header("")
+print_header("DEBUG INFORMATION")
+print_header("")
+print()
+
+
+from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
+from .cuda_setup.main import get_compute_capabilities
+from .cuda_setup.env_vars import to_be_ignored
+from .utils import print_stderr
+
+
+print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
+for k, v in os.environ.items():
+ if "/" in v and not to_be_ignored(k, v):
+ print(f"'{k}': '{v}'")
+print_header("")
+
+print(
+ "\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n"
+)
+
+print_header("OTHER")
+print(f"{COMPILED_WITH_CUDA = }")
+print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}")
+print_header("")
+print_header("DEBUG INFO END")
+print_header("")
+print(
+ """
+Running a quick check that:
+ + library is importable
+ + CUDA function is callable
+"""
+)
+
+try:
+ from bitsandbytes.optim import Adam
+
+ p = torch.nn.Parameter(torch.rand(10, 10).cuda())
+ a = torch.rand(10, 10).cuda()
+
+ p1 = p.data.sum().item()
+
+ adam = Adam([p])
+
+ out = a * p
+ loss = out.sum()
+ loss.backward()
+ adam.step()
+
+ p2 = p.data.sum().item()
+
+ assert p1 != p2
+ print("SUCCESS!")
+ print("Installation was successful!")
+ sys.exit(0)
+
+except ImportError:
+ print()
+ print_stderr(
+ f"WARNING: {__package__} is currently running as CPU-only!\n"
+ "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
+ f"If you think that this is so erroneously,\nplease report an issue!"
+ )
+ print_debug_info()
+ sys.exit(0)
+except Exception as e:
+ print(e)
+ print_debug_info()
+ sys.exit(1)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 607d868..14f2660 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,22 +1,24 @@
+from dataclasses import dataclass
+
import torch
import math
import bitsandbytes as bnb
import bitsandbytes.functional as F
-from dataclasses import dataclass
-
tensor = torch.Tensor
-'''
+"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
-'''
+"""
+
+
class GlobalOutlierPooler(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.outliers = set()
@@ -30,25 +32,29 @@ class GlobalOutlierPooler(object):
return cls._instance
def add_outliers(self, outlier_idx, feature_dim):
- if self.model_dim is None: self.model_dim = feature_dim
- if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer
+ if self.model_dim is None:
+ self.model_dim = feature_dim
+ if feature_dim != self.model_dim:
+ return # we do not encode outliers for the 2nd FFN layer
self.outliers.update(outlier_idx.tolist())
def get_current_outlier_idx(self):
return torch.Tensor(list(self.outliers)).to(torch.int64)
-class MatMul8bit(torch.autograd.Function):
+class MatMul8bit(torch.autograd.Function):
@staticmethod
- def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
+ def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):
if precision[0] != 8:
with torch.no_grad():
output = torch.matmul(A, B)
else:
- if len(B.shape) == 2: dim = 0
- else: dim = 1
+ if len(B.shape) == 2:
+ dim = 0
+ else:
+ dim = 1
qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
iout = F.igemm(qA, qB)
@@ -85,21 +91,43 @@ class MatMul8bit(torch.autograd.Function):
else:
if len(B.shape) == 2 and len(A.shape) == 3:
grad_output = grad_output.contiguous()
- if not grad_output.is_contiguous(): grad_output.contiguous()
- qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
- if not A.is_contiguous(): A = A.contiguous()
- qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
+ if not grad_output.is_contiguous():
+ grad_output.contiguous()
+ qgrad_output, S1 = F.vectorwise_quant(
+ grad_output.view(-1, grad_output.shape[2]),
+ dim=0,
+ quant_type=quant_type,
+ )
+ if not A.is_contiguous():
+ A = A.contiguous()
+ qA, S2 = F.vectorwise_quant(
+ A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
+ )
igrad_B = F.igemm(qA.t(), qgrad_output)
- grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
+ grad_B = F.vectorwise_mm_dequant(
+ igrad_B, S2.t(), S1, grad_output.dtype, quant_type
+ )
else:
- qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
- qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
+ qgrad_output, S1 = F.vectorwise_quant(
+ grad_output, dim=dims, quant_type=quant_type
+ )
+ qA, S2 = F.vectorwise_quant(
+ A, dim=dims, quant_type=quant_type
+ )
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
- grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
+ grad_B = F.vectorwise_mm_dequant(
+ igrad_B,
+ S2.permute(permute_dim),
+ S1,
+ grad_output.dtype,
+ quant_type,
+ )
if A.requires_grad:
- if len(grad_output.shape) == 3: dims = [2]
- else: dims = [1]
+ if len(grad_output.shape) == 3:
+ dims = [2]
+ else:
+ dims = [1]
if len(B.shape) == 3:
# bio -> boi
@@ -114,10 +142,18 @@ class MatMul8bit(torch.autograd.Function):
with torch.no_grad():
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
else:
- qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+ qgrad_output, S1 = F.vectorwise_quant(
+ grad_output, dim=dims, quant_type=quant_type
+ )
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
- grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
+ grad_A = F.vectorwise_mm_dequant(
+ igrad_A,
+ S1,
+ S3.permute(permute_dim),
+ grad_output.dtype,
+ quant_type,
+ )
return grad_A, grad_B, None, None, None
@@ -126,6 +162,7 @@ mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
+
@dataclass
class MatmulLtState:
CB = None
@@ -160,7 +197,6 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function):
-
@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
# default to pytorch behavior if inputs are empty
@@ -183,12 +219,18 @@ class MatMul8bitLt(torch.autograd.Function):
requires_gradB = B.requires_grad
formatB = state.formatB
input_shape = A.shape
- if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
- assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
+ if state.outlier_pool is None:
+ state.outlier_pool = GlobalOutlierPooler.get_instance()
+ assert (
+ A.dtype == torch.float16
+ ), f"The input data type needs to be fp16 but {A.dtype} was found!"
# 1. Quantize A
- if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
- CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
+ if len(A.shape) == 3:
+ A = A.view(-1, A.shape[-1]).contiguous()
+ CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
+ A, threshold=state.threshold
+ )
if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
@@ -202,9 +244,11 @@ class MatMul8bitLt(torch.autograd.Function):
if state.CxB is None:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
- state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
- #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
- #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
+ state.CxB, state.SB = F.transform(
+ state.CB, to_order=formatB
+ )
+ # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
+ # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
# outlier_idx = torch.unique(coo_tensorA.colidx).long()
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
@@ -215,28 +259,34 @@ class MatMul8bitLt(torch.autograd.Function):
# state.idx = outlier_idx
# state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
- #if state.idx is not None:
+ # if state.idx is not None:
# # extract outliers
# CA[:, state.idx] = 0
# CAt[:, state.idx] = 0
# subA = A[:, state.idx]
- #else:
+ # else:
# subA = None
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
-
# 2. Quantize B
if state.has_fp16_weights:
- has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
+ has_grad = True if (getattr(B, "grad", None) is not None) else False
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
- if is_transposed: B = B.contiguous()
+ if is_transposed:
+ B = B.contiguous()
if (state.is_training and not has_grad) or state.CxB is None:
state.reset_grads()
- CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
+ (
+ CB,
+ state.CBt,
+ state.SCB,
+ state.SCBt,
+ coo_tensorB,
+ ) = F.double_quant(B)
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
has_grad = False
@@ -246,14 +296,19 @@ class MatMul8bitLt(torch.autograd.Function):
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
- #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
- #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+ # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+ # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
- #else:
+ # else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
- state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half()
+ state.subB = (
+ (outliers * state.SCB.view(-1, 1) / 127.0)
+ .t()
+ .contiguous()
+ .half()
+ )
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
@@ -266,7 +321,7 @@ class MatMul8bitLt(torch.autograd.Function):
output_shape = (input_shape[0], shapeB[0])
# 3. Matmul
- C32A, SA = F.transform(CA, 'col32')
+ C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
@@ -289,7 +344,7 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
- #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+ # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
clone_func = torch.clone
return clone_func(output.view(output_shape))
@@ -302,28 +357,36 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
- assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
+ assert (
+ state.has_fp16_weights
+ ), "Backprop only supported for fp16 weights."
if len(grad_output.shape) == 3:
- grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
+ grad_output = grad_output.view(
+ -1, grad_output.shape[-1]
+ ).contiguous()
grad_A = grad_B = None
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
- C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
+ C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA:
- C32grad, Sgrad = F.transform(Cgrad, 'col32')
+ C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
- state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
+ state.CxBt, state.SBt = F.transform(
+ state.CBt, to_order=formatB, transpose=True
+ )
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
- grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+ grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
+ ctx.grad_shape
+ )
return grad_A, grad_B, None, None
@@ -331,9 +394,14 @@ class MatMul8bitLt(torch.autograd.Function):
matmul = MatMul8bitLt.apply
-def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0):
+def matmul(
+ A: tensor,
+ B: tensor,
+ out: tensor = None,
+ state: MatmulLtState = None,
+ threshold=0.0,
+):
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
return MatMul8bitLt.apply(A, B, out, state)
-
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 2374c35..66c79d8 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -1,15 +1,46 @@
import ctypes as ct
-import os
+from pathlib import Path
from warnings import warn
-lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
+from .cuda_setup.main import evaluate_cuda_setup
+
+class CUDALibrary_Singleton(object):
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError("Call get_instance() instead")
+
+ def initialize(self):
+ binary_name = evaluate_cuda_setup()
+ package_dir = Path(__file__).parent
+ binary_path = package_dir / binary_name
+
+ if not binary_path.exists():
+ print(f"TODO: compile library for specific version: {binary_name}")
+ legacy_binary_name = "libbitsandbytes.so"
+ print(f"Defaulting to {legacy_binary_name}...")
+ self.lib = ct.cdll.LoadLibrary(package_dir / legacy_binary_name)
+ else:
+ self.lib = ct.cdll.LoadLibrary(package_dir / binary_name)
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
+
+
+lib = CUDALibrary_Singleton.get_instance().lib
try:
lib.cadam32bit_g32
lib.get_context.restype = ct.c_void_p
lib.get_cusparse.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError:
- warn("The installed version of bitsandbytes was compiled without GPU support. "
- "8-bit optimizers and GPU quantization are unavailable.")
+ warn(
+ "The installed version of bitsandbytes was compiled without GPU support. "
+ "8-bit optimizers and GPU quantization are unavailable."
+ )
COMPILED_WITH_CUDA = False
diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/cuda_setup/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/bitsandbytes/cuda_setup/__init__.py
diff --git a/bitsandbytes/cuda_setup/compute_capability.py b/bitsandbytes/cuda_setup/compute_capability.py
new file mode 100644
index 0000000..7a3f463
--- /dev/null
+++ b/bitsandbytes/cuda_setup/compute_capability.py
@@ -0,0 +1,79 @@
+import ctypes
+from dataclasses import dataclass, field
+
+
+@dataclass
+class CudaLibVals:
+ # code bits taken from
+ # https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
+
+ nGpus: ctypes.c_int = field(default=ctypes.c_int())
+ cc_major: ctypes.c_int = field(default=ctypes.c_int())
+ cc_minor: ctypes.c_int = field(default=ctypes.c_int())
+ device: ctypes.c_int = field(default=ctypes.c_int())
+ error_str: ctypes.c_char_p = field(default=ctypes.c_char_p())
+ cuda: ctypes.CDLL = field(init=False, repr=False)
+ ccs: List[str, ...] = field(init=False)
+
+ def _initialize_driver_API(self):
+ self.check_cuda_result(self.cuda.cuInit(0))
+
+ def _load_cuda_lib(self):
+ """
+ 1. find libcuda.so library (GPU driver) (/usr/lib)
+ init_device -> init variables -> call function by reference
+ """
+ libnames = "libcuda.so"
+ for libname in libnames:
+ try:
+ self.cuda = ctypes.CDLL(libname)
+ except OSError:
+ continue
+ else:
+ break
+ else:
+ raise OSError("could not load any of: " + " ".join(libnames))
+
+ def call_cuda_func(self, function_obj, **kwargs):
+ CUDA_SUCCESS = 0 # constant taken from cuda.h
+ pass
+ # if (CUDA_SUCCESS := function_obj(
+
+ def _error_handle(cuda_lib_call_return_value):
+ """
+ 2. call extern C function to determine CC
+ (see https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
+ """
+ CUDA_SUCCESS = 0 # constant taken from cuda.h
+
+ if cuda_lib_call_return_value != CUDA_SUCCESS:
+ self.cuda.cuGetErrorString(
+ cuda_lib_call_return_value,
+ ctypes.byref(self.error_str),
+ )
+ print("Count not initialize CUDA - failure!")
+ raise Exception("CUDA exception!")
+ return cuda_lib_call_return_value
+
+ def __post_init__(self):
+ self._load_cuda_lib()
+ self._initialize_driver_API()
+ self.check_cuda_result(
+ self.cuda, self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus))
+ )
+ tmp_ccs = []
+ for gpu_index in range(self.nGpus.value):
+ check_cuda_result(
+ self.cuda,
+ self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index),
+ )
+ check_cuda_result(
+ self.cuda,
+ self.cuda.cuDeviceComputeCapability(
+ ctypes.byref(self.cc_major),
+ ctypes.byref(self.cc_minor),
+ self.device,
+ ),
+ )
+ tmp_ccs.append(f"{self.cc_major.value}.{self.cc_minor.value}")
+ self.ccs = sorted(tmp_ccs, reverse=True)
diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/cuda_setup/env_vars.py
new file mode 100644
index 0000000..536a7d8
--- /dev/null
+++ b/bitsandbytes/cuda_setup/env_vars.py
@@ -0,0 +1,51 @@
+import os
+from typing import Dict
+
+
+def to_be_ignored(env_var: str, value: str) -> bool:
+ ignorable = {
+ "PWD", # PWD: this is how the shell keeps track of the current working dir
+ "OLDPWD",
+ "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated
+ "SSH_TTY",
+ "HOME", # Linux shell default
+ "TMUX", # Terminal Multiplexer
+ "XDG_DATA_DIRS", # XDG: Desktop environment stuff
+ "XDG_RUNTIME_DIR",
+ "MAIL", # something related to emails
+ "SHELL", # binary for currently invoked shell
+ "DBUS_SESSION_BUS_ADDRESS", # hardware related
+ "PATH", # this is for finding binaries, not libraries
+ "LESSOPEN", # related to the `less` command
+ "LESSCLOSE",
+ "_", # current Python interpreter
+ }
+ return env_var in ignorable
+
+
+def might_contain_a_path(candidate: str) -> bool:
+ return "/" in candidate
+
+
+def is_active_conda_env(env_var: str) -> bool:
+ return "CONDA_PREFIX" == env_var
+
+
+def is_other_conda_env_var(env_var: str) -> bool:
+ return "CONDA" in env_var
+
+
+def is_relevant_candidate_env_var(env_var: str, value: str) -> bool:
+ return is_active_conda_env(env_var) or (
+ might_contain_a_path(value) and not
+ is_other_conda_env_var(env_var) and not
+ to_be_ignored(env_var, value)
+ )
+
+
+def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]:
+ return {
+ env_var: value
+ for env_var, value in os.environ.items()
+ if is_relevant_candidate_env_var(env_var, value)
+ }
diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py
new file mode 100644
index 0000000..e96ac70
--- /dev/null
+++ b/bitsandbytes/cuda_setup/main.py
@@ -0,0 +1,127 @@
+"""
+extract factors the build is dependent on:
+[X] compute capability
+ [ ] TODO: Q - What if we have multiple GPUs of different makes?
+- CUDA version
+- Software:
+ - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl)
+ - CuBLAS-LT: full-build 8-bit optimizer
+ - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
+
+evaluation:
+ - if paths faulty, return meaningful error
+ - else:
+ - determine CUDA version
+ - determine capabilities
+ - based on that set the default path
+"""
+
+import ctypes
+from pathlib import Path
+
+from ..utils import execute_and_return
+from .paths import determine_cuda_runtime_lib_path
+
+
+def check_cuda_result(cuda, result_val):
+ # 3. Check for CUDA errors
+ if result_val != 0:
+ error_str = ctypes.c_char_p()
+ cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
+ raise Exception(f"CUDA exception! ERROR: {error_str}")
+
+
+def get_compute_capabilities():
+ """
+ 1. find libcuda.so library (GPU driver) (/usr/lib)
+ init_device -> init variables -> call function by reference
+ 2. call extern C function to determine CC
+ (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
+ 3. Check for CUDA errors
+ https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
+ # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
+ """
+
+ # 1. find libcuda.so library (GPU driver) (/usr/lib)
+ try:
+ cuda = ctypes.CDLL("libcuda.so")
+ except OSError:
+ # TODO: shouldn't we error or at least warn here?
+ return None
+
+ nGpus = ctypes.c_int()
+ cc_major = ctypes.c_int()
+ cc_minor = ctypes.c_int()
+
+ result = ctypes.c_int()
+ device = ctypes.c_int()
+
+ check_cuda_result(cuda, cuda.cuInit(0))
+
+ check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
+ ccs = []
+ for i in range(nGpus.value):
+ check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
+ ref_major = ctypes.byref(cc_major)
+ ref_minor = ctypes.byref(cc_minor)
+ # 2. call extern C function to determine CC
+ check_cuda_result(
+ cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
+ )
+ ccs.append(f"{cc_major.value}.{cc_minor.value}")
+
+ return ccs.sort()
+
+
+# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
+def get_compute_capability():
+ """
+ Extracts the highest compute capbility from all available GPUs, as compute
+ capabilities are downwards compatible. If no GPUs are detected, it returns
+ None.
+ """
+ if ccs := get_compute_capabilities() is not None:
+ # TODO: handle different compute capabilities; for now, take the max
+ return ccs[-1]
+ return None
+
+
+def evaluate_cuda_setup():
+ cuda_path = determine_cuda_runtime_lib_path()
+ print(f"CUDA SETUP: CUDA path found: {cuda_path}")
+ cc = get_compute_capability()
+ binary_name = "libbitsandbytes_cpu.so"
+
+ # FIXME: has_gpu is still unused
+ if not (has_gpu := bool(cc)):
+ print(
+ "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
+ )
+ return binary_name
+
+ # 7.5 is the minimum CC vor cublaslt
+ has_cublaslt = cc in ["7.5", "8.0", "8.6"]
+
+ # TODO:
+ # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
+ # (2) Multiple CUDA versions installed
+
+ # FIXME: cuda_home is still unused
+ cuda_home = str(Path(cuda_path).parent.parent)
+ # we use ls -l instead of nvcc to determine the cuda version
+ # since most installations will have the libcudart.so installed, but not the compiler
+ ls_output, err = execute_and_return(f"ls -l {cuda_path}")
+ major, minor, revision = (
+ ls_output.split(" ")[-1].replace("libcudart.so.", "").split(".")
+ )
+ cuda_version_string = f"{major}{minor}"
+
+ def get_binary_name():
+ "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
+ bin_base_name = "libbitsandbytes_cuda"
+ if has_cublaslt:
+ return f"{bin_base_name}{cuda_version_string}.so"
+ else:
+ return f"{bin_base_name}_nocublaslt.so"
+
+ return binary_name
diff --git a/bitsandbytes/cuda_setup/paths.py b/bitsandbytes/cuda_setup/paths.py
new file mode 100644
index 0000000..c4a7465
--- /dev/null
+++ b/bitsandbytes/cuda_setup/paths.py
@@ -0,0 +1,126 @@
+from pathlib import Path
+from typing import Set, Union
+from warnings import warn
+
+from ..utils import print_stderr
+from .env_vars import get_potentially_lib_path_containing_env_vars
+
+
+CUDA_RUNTIME_LIB: str = "libcudart.so"
+
+
+def purge_unwanted_semicolon(tentative_path: Path) -> Path:
+ """
+ Special function to handle the following exception:
+ __LMOD_REF_COUNT_PATH=/sw/cuda/11.6.2/bin:2;/mmfs1/home/dettmers/git/sched/bin:1;/mmfs1/home/dettmers/data/anaconda3/bin:1;/mmfs1/home/dettmers/data/anaconda3/condabin:1;/mmfs1/home/dettmers/.local/bin:1;/mmfs1/home/dettmers/bin:1;/usr/local/bin:1;/usr/bin:1;/usr/local/sbin:1;/usr/sbin:1;/mmfs1/home/dettmers/.fzf/bin:1;/mmfs1/home/dettmers/data/local/cuda-11.4/bin:1
+ """
+ # if ';' in str(tentative_path):
+ # path_as_str, _ = str(tentative_path).split(';')
+ pass
+
+
+def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
+ return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path}
+
+
+def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
+ non_existent_directories: Set[Path] = {
+ path for path in candidate_paths if not path.exists()
+ }
+
+ if non_existent_directories:
+ print_stderr(
+ "WARNING: The following directories listed in your path were found to "
+ f"be non-existent: {non_existent_directories}"
+ )
+
+ return candidate_paths - non_existent_directories
+
+
+def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
+ return {
+ path / CUDA_RUNTIME_LIB
+ for path in candidate_paths
+ if (path / CUDA_RUNTIME_LIB).is_file()
+ }
+
+
+def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
+ """
+ Searches a given environmental var for the CUDA runtime library,
+ i.e. `libcudart.so`.
+ """
+ return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate))
+
+
+def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
+ return get_cuda_runtime_lib_paths(
+ resolve_paths_list(paths_list_candidate)
+ )
+
+
+def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
+ if len(results_paths) > 1:
+ warning_msg = (
+ f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. "
+ "We'll flip a coin and try one of these, in order to fail forward.\n"
+ "Either way, this might cause trouble in the future:\n"
+ "If you get `CUDA error: invalid device function` errors, the above "
+ "might be the cause and the solution is to make sure only one "
+ f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env."
+ )
+ warn(warning_msg)
+
+
+def determine_cuda_runtime_lib_path() -> Union[Path, None]:
+ """
+ Searches for a cuda installations, in the following order of priority:
+ 1. active conda env
+ 2. LD_LIBRARY_PATH
+ 3. any other env vars, while ignoring those that
+ - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`)
+ - don't contain the path separator `/`
+
+ If multiple libraries are found in part 3, we optimistically try one,
+ while giving a warning message.
+ """
+ candidate_env_vars = get_potentially_lib_path_containing_env_vars()
+
+ if "CONDA_PREFIX" in candidate_env_vars:
+ conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib"
+
+ conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path))
+ warn_in_case_of_duplicates(conda_cuda_libs)
+
+ if conda_cuda_libs:
+ return next(iter(conda_cuda_libs))
+
+ warn(
+ f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
+ f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...'
+ )
+
+ if "LD_LIBRARY_PATH" in candidate_env_vars:
+ lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
+
+ if lib_ld_cuda_libs:
+ return next(iter(lib_ld_cuda_libs))
+ warn_in_case_of_duplicates(lib_ld_cuda_libs)
+
+ warn(
+ f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
+ f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...'
+ )
+
+ remaining_candidate_env_vars = {
+ env_var: value for env_var, value in candidate_env_vars.items()
+ if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"}
+ }
+
+ cuda_runtime_libs = set()
+ for env_var, value in remaining_candidate_env_vars:
+ cuda_runtime_libs.update(find_cuda_lib_in(value))
+
+ warn_in_case_of_duplicates(cuda_runtime_libs)
+
+ return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else set()
diff --git a/bitsandbytes/debug_cli.py b/bitsandbytes/debug_cli.py
new file mode 100644
index 0000000..4306bc0
--- /dev/null
+++ b/bitsandbytes/debug_cli.py
@@ -0,0 +1,26 @@
+import typer
+
+cli = typer.Typer()
+
+
+@cli.callback()
+def callback():
+ """
+ Awesome Portal Gun
+ """
+
+
+@cli.command()
+def shoot():
+ """
+ Shoot the portal gun
+ """
+ typer.echo("Shooting portal gun")
+
+
+@cli.command()
+def load():
+ """
+ Load the portal gun
+ """
+ typer.echo("Loading portal gun")
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index ad85f53..b4409e4 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1,6 +1,6 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import random
@@ -10,47 +10,86 @@ import torch
from typing import Tuple
from torch import Tensor
-from .cextension import lib, COMPILED_WITH_CUDA
+from .cextension import COMPILED_WITH_CUDA, lib
name2qmap = {}
if COMPILED_WITH_CUDA:
- ''' C FUNCTIONS FOR OPTIMIZERS '''
+ """C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {}
- str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
- str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
- str2optimizer32bit['rmsprop'] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
- str2optimizer32bit['adagrad'] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
- str2optimizer32bit['lars'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
- str2optimizer32bit['lamb'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
+ str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
+ str2optimizer32bit["momentum"] = (
+ lib.cmomentum32bit_g32,
+ lib.cmomentum32bit_g16,
+ )
+ str2optimizer32bit["rmsprop"] = (
+ lib.crmsprop32bit_g32,
+ lib.crmsprop32bit_g16,
+ )
+ str2optimizer32bit["adagrad"] = (
+ lib.cadagrad32bit_g32,
+ lib.cadagrad32bit_g16,
+ )
+ str2optimizer32bit["lars"] = (
+ lib.cmomentum32bit_g32,
+ lib.cmomentum32bit_g16,
+ )
+ str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer8bit = {}
- str2optimizer8bit['adam'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
- str2optimizer8bit['momentum'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)
- str2optimizer8bit['rmsprop'] = (lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16)
- str2optimizer8bit['lamb'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
- str2optimizer8bit['lars'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)
+ str2optimizer8bit["adam"] = (
+ lib.cadam_static_8bit_g32,
+ lib.cadam_static_8bit_g16,
+ )
+ str2optimizer8bit["momentum"] = (
+ lib.cmomentum_static_8bit_g32,
+ lib.cmomentum_static_8bit_g16,
+ )
+ str2optimizer8bit["rmsprop"] = (
+ lib.crmsprop_static_8bit_g32,
+ lib.crmsprop_static_8bit_g16,
+ )
+ str2optimizer8bit["lamb"] = (
+ lib.cadam_static_8bit_g32,
+ lib.cadam_static_8bit_g16,
+ )
+ str2optimizer8bit["lars"] = (
+ lib.cmomentum_static_8bit_g32,
+ lib.cmomentum_static_8bit_g16,
+ )
str2optimizer8bit_blockwise = {}
- str2optimizer8bit_blockwise['adam'] = (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16)
- str2optimizer8bit_blockwise['momentum'] = (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16)
- str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16)
- str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16)
+ str2optimizer8bit_blockwise["adam"] = (
+ lib.cadam_8bit_blockwise_fp32,
+ lib.cadam_8bit_blockwise_fp16,
+ )
+ str2optimizer8bit_blockwise["momentum"] = (
+ lib.cmomentum_8bit_blockwise_fp32,
+ lib.cmomentum_8bit_blockwise_fp16,
+ )
+ str2optimizer8bit_blockwise["rmsprop"] = (
+ lib.crmsprop_8bit_blockwise_fp32,
+ lib.crmsprop_8bit_blockwise_fp16,
+ )
+ str2optimizer8bit_blockwise["adagrad"] = (
+ lib.cadagrad_8bit_blockwise_fp32,
+ lib.cadagrad_8bit_blockwise_fp16,
+ )
class CUBLAS_Context(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.context = {}
- #prev_device = torch.cuda.current_device()
- #for i in range(torch.cuda.device_count()):
+ # prev_device = torch.cuda.current_device()
+ # for i in range(torch.cuda.device_count()):
# torch.cuda.set_device(torch.device('cuda', i))
# self.context.append(ct.c_void_p(lib.get_context()))
- #torch.cuda.set_device(prev_device)
+ # torch.cuda.set_device(prev_device)
@classmethod
def get_instance(cls):
@@ -67,11 +106,12 @@ class CUBLAS_Context(object):
torch.cuda.set_device(prev_device)
return self.context[device.index]
+
class Cusparse_Context(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.context = ct.c_void_p(lib.get_cusparse())
@@ -83,14 +123,16 @@ class Cusparse_Context(object):
cls._instance.initialize()
return cls._instance
+
def create_linear_map(signed=True):
if signed:
return torch.linspace(-1.0, 1.0, 256)
else:
return torch.linspace(0.0, 1.0, 256)
+
def create_dynamic_map(signed=True, n=7):
- '''
+ """
Creates the dynamic quantiztion map.
The dynamic data type is made up of a dynamic exponent and
@@ -104,43 +146,53 @@ def create_dynamic_map(signed=True, n=7):
For more details see
(8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]
- '''
+ """
data = []
# these are additional items that come from the case
# where all the exponent bits are zero and no
# indicator bit is present
- additional_items = 2**(7-n)-1
- if not signed: additional_items = 2*additional_items
+ additional_items = 2 ** (7 - n) - 1
+ if not signed:
+ additional_items = 2 * additional_items
for i in range(n):
- fraction_items = 2**(i+7-n)+1 if signed else 2**(i+7-n+1)+1
+ fraction_items = (
+ 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+ )
boundaries = torch.linspace(0.1, 1, fraction_items)
- means = (boundaries[:-1]+boundaries[1:])/2.0
- data += ((10**(-(n-1)+i))*means).tolist()
+ means = (boundaries[:-1] + boundaries[1:]) / 2.0
+ data += ((10 ** (-(n - 1) + i)) * means).tolist()
if signed:
- data += (-(10**(-(n-1)+i))*means).tolist()
+ data += (-(10 ** (-(n - 1) + i)) * means).tolist()
if additional_items > 0:
- boundaries = torch.linspace(0.1, 1, additional_items+1)
- means = (boundaries[:-1]+boundaries[1:])/2.0
- data += ((10**(-(n-1)+i))*means).tolist()
+ boundaries = torch.linspace(0.1, 1, additional_items + 1)
+ means = (boundaries[:-1] + boundaries[1:]) / 2.0
+ data += ((10 ** (-(n - 1) + i)) * means).tolist()
if signed:
- data += (-(10**(-(n-1)+i))*means).tolist()
+ data += (-(10 ** (-(n - 1) + i)) * means).tolist()
data.append(0)
data.append(1.0)
data.sort()
return Tensor(data)
+
def get_special_format_str():
major, minor = torch.cuda.get_device_capability()
if major < 7:
- print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!')
+ print(
+ f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
+ )
assert major >= 7
- if major == 7: return 'col_turing'
- elif major == 8: return 'col_ampere'
- else: return 'col_turing'
+ if major == 7:
+ return "col_turing"
+ elif major == 8:
+ return "col_ampere"
+ else:
+ return "col_turing"
+
def is_on_gpu(tensors):
@@ -151,7 +203,7 @@ def is_on_gpu(tensors):
return on_gpu
def get_ptr(A: Tensor) -> ct.c_void_p:
- '''
+ """
Get the ctypes pointer from a PyTorch Tensor.
Parameters
@@ -162,31 +214,39 @@ def get_ptr(A: Tensor) -> ct.c_void_p:
Returns
-------
ctypes.c_void_p
- '''
- if A is None: return None
- else: return ct.c_void_p(A.data.storage().data_ptr())
+ """
+ if A is None:
+ return None
+ else:
+ return ct.c_void_p(A.data.storage().data_ptr())
+
def pre_call(device):
prev_device = torch.cuda.current_device()
torch.cuda.set_device(device)
return prev_device
+
def post_call(prev_device):
torch.cuda.set_device(prev_device)
+
def get_transform_func(dtype, orderA, orderOut, transpose=False):
name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
if not hasattr(lib, name):
print(name)
- raise ValueError(f'Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}')
+ raise ValueError(
+ f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}"
+ )
else:
return getattr(lib, name)
+
class GlobalData(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.data = {}
@@ -199,15 +259,17 @@ class GlobalData(object):
return cls._instance
-def get_transform_buffer(shape, dtype, device, to_order, from_order='row', transpose=False):
- #init_func = torch.empty
+def get_transform_buffer(
+ shape, dtype, device, to_order, from_order="row", transpose=False
+):
+ # init_func = torch.empty
init_func = torch.zeros
dims = len(shape)
if dims == 2:
rows = shape[0]
elif dims == 3:
- rows = shape[0]*shape[1]
+ rows = shape[0] * shape[1]
cols = shape[-1]
state = (shape, to_order)
@@ -218,30 +280,45 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order='row', trans
cols = tmp
state = (shape[::-1], to_order)
- if to_order == 'row' or to_order == 'col':
+ if to_order == "row" or to_order == "col":
return init_func(shape, dtype=dtype, device=device), state
- elif to_order == 'col32':
+ elif to_order == "col32":
# blocks of 32 columns (padded)
- cols = 32*((cols+31)//32)
+ cols = 32 * ((cols + 31) // 32)
return init_func((rows, cols), dtype=dtype, device=device), state
- elif to_order == 'col_turing':
+ elif to_order == "col_turing":
# blocks of 32 columns and 8 rows
- cols = 32*((cols+31)//32)
- rows = 8*((rows+7)//8)
+ cols = 32 * ((cols + 31) // 32)
+ rows = 8 * ((rows + 7) // 8)
return init_func((rows, cols), dtype=dtype, device=device), state
- elif to_order == 'col_ampere':
+ elif to_order == "col_ampere":
# blocks of 32 columns and 32 rows
- cols = 32*((cols+31)//32)
- rows = 32*((rows+31)//32)
+ cols = 32 * ((cols + 31) // 32)
+ rows = 32 * ((rows + 31) // 32)
return init_func((rows, cols), dtype=dtype, device=device), state
else:
- raise NotImplementedError(f'To_order not supported: {to_order}')
-
-def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
- if state is None: state = (A.shape, from_order)
- else: from_order = state[1]
- if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
- else: new_state = (state[1], to_order)
+ raise NotImplementedError(f"To_order not supported: {to_order}")
+
+
+def nvidia_transform(
+ A,
+ to_order,
+ from_order="row",
+ out=None,
+ transpose=False,
+ state=None,
+ ld=None,
+):
+ if state is None:
+ state = (A.shape, from_order)
+ else:
+ from_order = state[1]
+ if out is None:
+ out, new_state = get_transform_buffer(
+ state[0], A.dtype, A.device, to_order, state[1]
+ )
+ else:
+ new_state = (state[1], to_order)
func = get_transform_func(A.dtype, from_order, to_order, transpose)
shape = state[0]
@@ -251,10 +328,10 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s
elif ld is not None:
n = math.prod(shape)
dim1 = math.prod([shape[i] for i in ld])
- dim2 = ct.c_int32(n//dim1)
+ dim2 = ct.c_int32(n // dim1)
dim1 = ct.c_int32(dim1)
else:
- dim1 = ct.c_int32(shape[0]*shape[1])
+ dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])
ptr = CUBLAS_Context.get_instance().get_context(A.device)
@@ -262,10 +339,12 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s
ptrOut = get_ptr(out)
func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)
-
return out, new_state
-def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tensor:
+
+def estimate_quantiles(
+ A: Tensor, out: Tensor = None, offset: float = 1 / 512
+) -> Tensor:
'''
Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -295,15 +374,26 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
is_on_gpu([A, out])
if A.dtype == torch.float32:
- lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
+ lib.cestimate_quantiles_fp32(
+ get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
+ )
elif A.dtype == torch.float16:
- lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
+ lib.cestimate_quantiles_fp16(
+ get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
+ )
else:
- raise NotImplementedError(f'Not supported data type {A.dtype}')
+ raise NotImplementedError(f"Not supported data type {A.dtype}")
return out
-def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=None, out: Tensor=None) -> Tensor:
- '''
+
+def quantize_blockwise(
+ A: Tensor,
+ code: Tensor = None,
+ absmax: Tensor = None,
+ rand=None,
+ out: Tensor = None,
+) -> Tensor:
+ """
Quantize tensor A in blocks of size 4096 values.
Quantizes tensor A by dividing it into blocks of 4096 values.
@@ -329,22 +419,23 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N
The 8-bit tensor.
tuple(torch.Tensor, torch.Tensor):
The quantization state to undo the quantization.
- '''
+ """
if code is None:
- if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
- code = name2qmap['dynamic']
+ if "dynamic" not in name2qmap:
+ name2qmap["dynamic"] = create_dynamic_map().to(A.device)
+ code = name2qmap["dynamic"]
code = code.to(A.device)
if absmax is None:
n = A.numel()
num_blocks = 4096
- blocks = n//num_blocks
+ blocks = n // num_blocks
blocks += 1 if n % num_blocks > 0 else 0
absmax = torch.zeros((blocks,), device=A.device)
- if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
-
+ if out is None:
+ out = torch.zeros_like(A, dtype=torch.uint8)
if A.device.type != 'cpu':
is_on_gpu([code, A, absmax, out, rand])
@@ -352,29 +443,73 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
if A.dtype == torch.float32:
- lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
+ lib.cquantize_blockwise_stochastic_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ get_ptr(rand),
+ ct.c_int32(rand_offset),
+ ct.c_int(A.numel()),
+ )
elif A.dtype == torch.float16:
- lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
+ lib.cquantize_blockwise_stochastic_fp16(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ get_ptr(rand),
+ ct.c_int32(rand_offset),
+ ct.c_int(A.numel()),
+ )
else:
- raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
+ raise ValueError(
+ f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
+ )
else:
if A.dtype == torch.float32:
- lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
+ lib.cquantize_blockwise_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int(A.numel()),
+ )
elif A.dtype == torch.float16:
- lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
+ lib.cquantize_blockwise_fp16(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int(A.numel()),
+ )
else:
- raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
+ raise ValueError(
+ f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
+ )
else:
# cpu
assert rand is None
- lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
+ lib.cquantize_blockwise_cpu_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int(A.numel()),
+ )
return out, (absmax, code)
-def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
- absmax: Tensor=None, code: Tensor=None, out: Tensor=None,
- blocksize: int=4096) -> Tensor:
- '''
+
+def dequantize_blockwise(
+ A: Tensor,
+ quant_state: Tuple[Tensor, Tensor] = None,
+ absmax: Tensor = None,
+ code: Tensor = None,
+ out: Tensor = None,
+ blocksize: int = 4096,
+) -> Tensor:
+ """
Dequantizes blockwise quantized values.
Dequantizes the tensor A with maximum absolute values absmax in
@@ -385,7 +520,7 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
A : torch.Tensor
The input 8-bit tensor.
quant_state : tuple(torch.Tensor, torch.Tensor)
- Tuple of code and absmax values.
+ Tuple of code and absmax values.
absmax : torch.Tensor
The absmax values.
code : torch.Tensor
@@ -398,57 +533,94 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None,
-------
torch.Tensor:
Dequantized tensor (default: float32)
- '''
+ """
assert quant_state is not None or absmax is not None
if code is None and quant_state is None:
- if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
- code = name2qmap['dynamic']
+ if "dynamic" not in name2qmap:
+ name2qmap["dynamic"] = create_dynamic_map().to(A.device)
+ code = name2qmap["dynamic"]
code = code.to(A.device)
- if out is None: out = torch.zeros_like(A, dtype=torch.float32)
- if quant_state is None: quant_state = (absmax, code)
+ if out is None:
+ out = torch.zeros_like(A, dtype=torch.float32)
+ if quant_state is None:
+ quant_state = (absmax, code)
if blocksize not in [2048, 4096]:
- raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]')
+ raise ValueError(
+ f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
+ )
if A.device.type != 'cpu':
is_on_gpu([A, out])
if out.dtype == torch.float32:
- lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+ lib.cdequantize_blockwise_fp32(
+ get_ptr(quant_state[1]),
+ get_ptr(A),
+ get_ptr(quant_state[0]),
+ get_ptr(out),
+ ct.c_int(blocksize),
+ ct.c_int(A.numel()),
+ )
elif out.dtype == torch.float16:
- lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+ lib.cdequantize_blockwise_fp16(
+ get_ptr(quant_state[1]),
+ get_ptr(A),
+ get_ptr(quant_state[0]),
+ get_ptr(out),
+ ct.c_int(blocksize),
+ ct.c_int(A.numel()),
+ )
else:
- raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}')
+ raise ValueError(
+ f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
+ )
else:
- lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(A.numel()))
-
+ lib.cdequantize_blockwise_cpu_fp32(
+ get_ptr(quant_state[1]),
+ get_ptr(A),
+ get_ptr(quant_state[0]),
+ get_ptr(out),
+ ct.c_int(A.numel()),
+ )
return out
-def quantize(A: Tensor, code: Tensor=None, out: Tensor=None) -> Tensor:
+def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
if code is None:
- if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
- code = name2qmap['dynamic']
+ if "dynamic" not in name2qmap:
+ name2qmap["dynamic"] = create_dynamic_map().to(A.device)
+ code = name2qmap["dynamic"]
code = code.to(A.device)
absmax = torch.abs(A).max()
- inp = A/absmax
+ inp = A / absmax
out = quantize_no_absmax(inp, code, out)
return out, (absmax, code)
-def dequantize(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, absmax: Tensor=None, code: Tensor=None, out: Tensor=None) -> Tensor:
+
+def dequantize(
+ A: Tensor,
+ quant_state: Tuple[Tensor, Tensor] = None,
+ absmax: Tensor = None,
+ code: Tensor = None,
+ out: Tensor = None,
+) -> Tensor:
assert quant_state is not None or absmax is not None
if code is None and quant_state is None:
- if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device)
- code = name2qmap['dynamic']
+ if "dynamic" not in name2qmap:
+ name2qmap["dynamic"] = create_dynamic_map().to(A.device)
+ code = name2qmap["dynamic"]
code = code.to(A.device)
- if quant_state is None: quant_state = (absmax, code)
+ if quant_state is None:
+ quant_state = (absmax, code)
out = dequantize_no_absmax(A, quant_state[1], out)
- return out*quant_state[0]
+ return out * quant_state[0]
+
-def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
+def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
'''
Quantizes input tensor to 8-bit.
@@ -474,7 +646,8 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
-def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
+
+def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
'''
Dequantizes the 8-bit tensor to 32-bit.
@@ -500,12 +673,25 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor:
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
-def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Tensor,
- beta1: float, eps: float, step: int, lr: float,
- state2: Tensor=None, beta2: float=0.0,
- weight_decay: float=0.0, gnorm_scale: float=1.0,
- unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None:
- '''
+
+def optimizer_update_32bit(
+ optimizer_name: str,
+ g: Tensor,
+ p: Tensor,
+ state1: Tensor,
+ beta1: float,
+ eps: float,
+ step: int,
+ lr: float,
+ state2: Tensor = None,
+ beta2: float = 0.0,
+ weight_decay: float = 0.0,
+ gnorm_scale: float = 1.0,
+ unorm_vec: Tensor = None,
+ max_unorm: float = 0.0,
+ skip_zeros=False,
+) -> None:
+ """
Performs an inplace optimizer update with one or two optimizer states.
Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.
@@ -542,33 +728,84 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
The maximum update norm relative to the weight norm.
skip_zeros : bool
Whether to skip zero-valued gradients or not (default: False).
- '''
+ """
param_norm = 0.0
if max_unorm > 0.0:
param_norm = torch.norm(p.data.float())
if optimizer_name not in str2optimizer32bit:
- raise NotImplementedError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}')
+ raise NotImplementedError(
+ f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
+ )
if g.dtype == torch.float32 and state1.dtype == torch.float32:
- str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
- ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
- ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
+ str2optimizer32bit[optimizer_name][0](
+ get_ptr(g),
+ get_ptr(p),
+ get_ptr(state1),
+ get_ptr(state2),
+ get_ptr(unorm_vec),
+ ct.c_float(max_unorm),
+ ct.c_float(param_norm),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_float(weight_decay),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ ct.c_float(gnorm_scale),
+ ct.c_bool(skip_zeros),
+ ct.c_int32(g.numel()),
+ )
elif g.dtype == torch.float16 and state1.dtype == torch.float32:
- str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
- ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
- ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
+ str2optimizer32bit[optimizer_name][1](
+ get_ptr(g),
+ get_ptr(p),
+ get_ptr(state1),
+ get_ptr(state2),
+ get_ptr(unorm_vec),
+ ct.c_float(max_unorm),
+ ct.c_float(param_norm),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_float(weight_decay),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ ct.c_float(gnorm_scale),
+ ct.c_bool(skip_zeros),
+ ct.c_int32(g.numel()),
+ )
else:
- raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
-
-def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor,
- beta1: float, beta2: float, eps: float,
- step: int, lr: float, qmap1: Tensor, qmap2: Tensor,
- max1: Tensor, max2: Tensor, new_max1: Tensor, new_max2: Tensor,
- weight_decay: float=0.0, gnorm_scale: float=1.0,
- unorm_vec: Tensor=None, max_unorm: float=0.0) -> None:
- '''
+ raise ValueError(
+ f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
+ )
+
+
+def optimizer_update_8bit(
+ optimizer_name: str,
+ g: Tensor,
+ p: Tensor,
+ state1: Tensor,
+ state2: Tensor,
+ beta1: float,
+ beta2: float,
+ eps: float,
+ step: int,
+ lr: float,
+ qmap1: Tensor,
+ qmap2: Tensor,
+ max1: Tensor,
+ max2: Tensor,
+ new_max1: Tensor,
+ new_max2: Tensor,
+ weight_decay: float = 0.0,
+ gnorm_scale: float = 1.0,
+ unorm_vec: Tensor = None,
+ max_unorm: float = 0.0,
+) -> None:
+ """
Performs an inplace Adam update.
Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights.
@@ -616,56 +853,135 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten
The tensor for the update norm.
max_unorm : float
The maximum update norm relative to the weight norm.
- '''
+ """
param_norm = 0.0
if max_unorm > 0.0:
param_norm = torch.norm(p.data.float())
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
- str2optimizer8bit[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
- get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm),
- ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
- ct.c_int32(step), ct.c_float(lr),
- get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2),
- ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+ str2optimizer8bit[optimizer_name][0](
+ get_ptr(p),
+ get_ptr(g),
+ get_ptr(state1),
+ get_ptr(state2),
+ get_ptr(unorm_vec),
+ ct.c_float(max_unorm),
+ ct.c_float(param_norm),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ get_ptr(qmap1),
+ get_ptr(qmap2),
+ get_ptr(max1),
+ get_ptr(max2),
+ get_ptr(new_max1),
+ get_ptr(new_max2),
+ ct.c_float(weight_decay),
+ ct.c_float(gnorm_scale),
+ ct.c_int32(g.numel()),
+ )
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
- str2optimizer8bit[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
- get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm),
- ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
- ct.c_int32(step), ct.c_float(lr),
- get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2),
- ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+ str2optimizer8bit[optimizer_name][1](
+ get_ptr(p),
+ get_ptr(g),
+ get_ptr(state1),
+ get_ptr(state2),
+ get_ptr(unorm_vec),
+ ct.c_float(max_unorm),
+ ct.c_float(param_norm),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ get_ptr(qmap1),
+ get_ptr(qmap2),
+ get_ptr(max1),
+ get_ptr(max2),
+ get_ptr(new_max1),
+ get_ptr(new_max2),
+ ct.c_float(weight_decay),
+ ct.c_float(gnorm_scale),
+ ct.c_int32(g.numel()),
+ )
else:
- raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
-
-
-def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor,
- beta1: float, beta2: float, eps: float,
- step: int, lr: float, qmap1: Tensor, qmap2: Tensor,
- absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0,
- skip_zeros=False) -> None:
-
+ raise ValueError(
+ f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
+ )
+
+
+def optimizer_update_8bit_blockwise(
+ optimizer_name: str,
+ g: Tensor,
+ p: Tensor,
+ state1: Tensor,
+ state2: Tensor,
+ beta1: float,
+ beta2: float,
+ eps: float,
+ step: int,
+ lr: float,
+ qmap1: Tensor,
+ qmap2: Tensor,
+ absmax1: Tensor,
+ absmax2: Tensor,
+ weight_decay: float = 0.0,
+ gnorm_scale: float = 1.0,
+ skip_zeros=False,
+) -> None:
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
- str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
- ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
- ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
- ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
+ str2optimizer8bit_blockwise[optimizer_name][0](
+ get_ptr(p),
+ get_ptr(g),
+ get_ptr(state1),
+ get_ptr(state2),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ get_ptr(qmap1),
+ get_ptr(qmap2),
+ get_ptr(absmax1),
+ get_ptr(absmax2),
+ ct.c_float(weight_decay),
+ ct.c_float(gnorm_scale),
+ ct.c_bool(skip_zeros),
+ ct.c_int32(g.numel()),
+ )
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
- str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
- ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
- ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
- ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
+ str2optimizer8bit_blockwise[optimizer_name][1](
+ get_ptr(p),
+ get_ptr(g),
+ get_ptr(state1),
+ get_ptr(state2),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(eps),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ get_ptr(qmap1),
+ get_ptr(qmap2),
+ get_ptr(absmax1),
+ get_ptr(absmax2),
+ ct.c_float(weight_decay),
+ ct.c_float(gnorm_scale),
+ ct.c_bool(skip_zeros),
+ ct.c_int32(g.numel()),
+ )
else:
- raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
+ raise ValueError(
+ f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
+ )
-def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int=5):
+def percentile_clipping(
+ grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
+):
"""Applies percentile clipping
grad: torch.Tensor
@@ -678,11 +994,21 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
"""
is_on_gpu([grad, gnorm_vec])
if grad.dtype == torch.float32:
- lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
+ lib.cpercentile_clipping_g32(
+ get_ptr(grad),
+ get_ptr(gnorm_vec),
+ ct.c_int32(step),
+ ct.c_int32(grad.numel()),
+ )
elif grad.dtype == torch.float16:
- lib.cpercentile_clipping_g16(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel()))
+ lib.cpercentile_clipping_g16(
+ get_ptr(grad),
+ get_ptr(gnorm_vec),
+ ct.c_int32(step),
+ ct.c_int32(grad.numel()),
+ )
else:
- raise ValueError(f'Gradient type {grad.dtype} not supported!')
+ raise ValueError(f"Gradient type {grad.dtype} not supported!")
current_gnorm = torch.sqrt(gnorm_vec[step % 100])
vals, idx = torch.sort(gnorm_vec)
@@ -690,22 +1016,24 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
gnorm_scale = 1.0
if current_gnorm > clip_value:
- gnorm_scale = clip_value/current_gnorm
+ gnorm_scale = clip_value / current_gnorm
return current_gnorm, clip_value, gnorm_scale
-def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
+def histogram_scatter_add_2d(
+ histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor
+):
assert len(histogram.shape) == 2
assert histogram.dtype == torch.float32
assert source.dtype == torch.float32
assert index1.dtype == torch.int32
assert index2.dtype == torch.int32
- assert histogram.device.type == 'cuda'
- assert index1.device.type == 'cuda'
- assert index2.device.type == 'cuda'
- assert source.device.type == 'cuda'
+ assert histogram.device.type == "cuda"
+ assert index1.device.type == "cuda"
+ assert index2.device.type == "cuda"
+ assert source.device.type == "cuda"
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
@@ -715,7 +1043,9 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor,
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
if not torch.cuda.is_initialized(): torch.cuda.init()
if A.dtype != expected_type or B.dtype != expected_type:
- raise TypeError(f'Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}')
+ raise TypeError(
+ f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
+ )
sA = A.shape
sB = B.shape
@@ -725,64 +1055,105 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
correct = True
if len(sA) == 2 and len(sB) == 2:
- if not tA and not tB and A.shape[1] != B.shape[0]: correct = False
- elif tA and not tB and A.shape[0] != B.shape[0]: correct = False
- elif tA and tB and A.shape[0] != B.shape[1]: correct = False
- elif not tA and tB and A.shape[1] != B.shape[1]: correct = False
+ if not tA and not tB and A.shape[1] != B.shape[0]:
+ correct = False
+ elif tA and not tB and A.shape[0] != B.shape[0]:
+ correct = False
+ elif tA and tB and A.shape[0] != B.shape[1]:
+ correct = False
+ elif not tA and tB and A.shape[1] != B.shape[1]:
+ correct = False
elif len(sA) == 3 and len(sB) == 2:
- if not tA and not tB and A.shape[2] != B.shape[0]: correct = False
- elif tA and not tB and A.shape[1] != B.shape[0]: correct = False
- elif tA and tB and A.shape[1] != B.shape[1]: correct = False
- elif not tA and tB and A.shape[2] != B.shape[1]: correct = False
+ if not tA and not tB and A.shape[2] != B.shape[0]:
+ correct = False
+ elif tA and not tB and A.shape[1] != B.shape[0]:
+ correct = False
+ elif tA and tB and A.shape[1] != B.shape[1]:
+ correct = False
+ elif not tA and tB and A.shape[2] != B.shape[1]:
+ correct = False
elif len(sA) == 3 and len(sB) == 3:
- if not tA and not tB and A.shape[2] != B.shape[1]: correct = False
- elif tA and not tB and A.shape[1] != B.shape[1]: correct = False
- elif tA and tB and A.shape[1] != B.shape[2]: correct = False
- elif not tA and tB and A.shape[2] != B.shape[2]: correct = False
+ if not tA and not tB and A.shape[2] != B.shape[1]:
+ correct = False
+ elif tA and not tB and A.shape[1] != B.shape[1]:
+ correct = False
+ elif tA and tB and A.shape[1] != B.shape[2]:
+ correct = False
+ elif not tA and tB and A.shape[2] != B.shape[2]:
+ correct = False
if out is not None:
sout = out.shape
# special case common in backprop
if not correct and len(sA) == 3 and len(sB) == 3:
- if (sout[0] == sA[2] and sout[1] == sB[2] and
- sA[0] == sB[0] and sA[1] == sB[1]):
+ if (
+ sout[0] == sA[2]
+ and sout[1] == sB[2]
+ and sA[0] == sB[0]
+ and sA[1] == sB[1]
+ ):
correct = True
else:
if len(sA) == 2 and len(sB) == 2:
- if not tA and not tB: sout = (sA[0], sB[1])
- elif tA and tB: sout = (sA[1], sB[0])
- elif tA and not tB: sout = (sA[1], sB[1])
- elif not tA and tB: sout = (sA[0], sB[0])
+ if not tA and not tB:
+ sout = (sA[0], sB[1])
+ elif tA and tB:
+ sout = (sA[1], sB[0])
+ elif tA and not tB:
+ sout = (sA[1], sB[1])
+ elif not tA and tB:
+ sout = (sA[0], sB[0])
elif len(sA) == 3 and len(sB) == 2:
- if not tA and not tB: sout = (sA[0], sA[1], sB[1])
- elif tA and tB: sout = (sA[0], sA[2], sB[0])
- elif tA and not tB: sout = (sA[0], sA[2], sB[1])
- elif not tA and tB: sout = (sA[0], sA[1], sB[0])
+ if not tA and not tB:
+ sout = (sA[0], sA[1], sB[1])
+ elif tA and tB:
+ sout = (sA[0], sA[2], sB[0])
+ elif tA and not tB:
+ sout = (sA[0], sA[2], sB[1])
+ elif not tA and tB:
+ sout = (sA[0], sA[1], sB[0])
elif len(sA) == 3 and len(sB) == 3:
- if not tA and not tB: sout = (sA[0], sA[1], sB[2])
- elif tA and tB: sout = (sA[0], sA[2], sB[1])
- elif tA and not tB: sout = (sA[0], sA[2], sB[2])
- elif not tA and tB: sout = (sA[0], sA[1], sB[1])
-
+ if not tA and not tB:
+ sout = (sA[0], sA[1], sB[2])
+ elif tA and tB:
+ sout = (sA[0], sA[2], sB[1])
+ elif tA and not tB:
+ sout = (sA[0], sA[2], sB[2])
+ elif not tA and tB:
+ sout = (sA[0], sA[1], sB[1])
if not correct:
- raise ValueError(f'Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.')
+ raise ValueError(
+ f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
+ )
return sout
-def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
+
+def igemm(
+ A: Tensor,
+ B: Tensor,
+ out: Tensor = None,
+ transposed_A=False,
+ transposed_B=False,
+):
sout = check_matmul(A, B, out, transposed_A, transposed_B)
- if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
+ if out is None:
+ out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
if len(A.shape) == 3 and len(B.shape) == 3:
if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
return batched_igemm(A, B, out)
sA = A.shape
sB = B.shape
- if transposed_A and len(sA) == 2: sA = (sA[1], sA[0])
- elif transposed_A and len(sA) == 3: sA = (sA[0], sA[2], sA[0])
- if transposed_B and len(sB) == 2: sB = (sB[1], sB[0])
- elif transposed_B and len(sB) == 3: sB = (sB[0], sB[2], sB[0])
+ if transposed_A and len(sA) == 2:
+ sA = (sA[1], sA[0])
+ elif transposed_A and len(sA) == 3:
+ sA = (sA[0], sA[2], sA[0])
+ if transposed_B and len(sB) == 2:
+ sB = (sB[1], sB[0])
+ elif transposed_B and len(sB) == 3:
+ sB = (sB[0], sB[2], sB[0])
# this is a mess: cuBLAS expect column major, but PyTorch is row major.
# So to perform the matrix multiplication, we have to treat A, B, and C matrices
# (transpose of row major is column major)
@@ -793,23 +1164,28 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
# column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
if len(sB) == 2:
- if B.stride()[0] == B.shape[1]: transposed_B = False
- elif B.stride()[1] == B.shape[0]: transposed_B = True
+ if B.stride()[0] == B.shape[1]:
+ transposed_B = False
+ elif B.stride()[1] == B.shape[0]:
+ transposed_B = True
if len(A.shape) == 2:
- if A.stride()[0] == A.shape[1]: transposed_A = False
- elif A.stride()[1] == A.shape[0]: transposed_A = True
+ if A.stride()[0] == A.shape[1]:
+ transposed_A = False
+ elif A.stride()[1] == A.shape[0]:
+ transposed_A = True
else:
- if A.stride()[1] == A.shape[2]: transposed_A = False
- elif A.stride()[2] == A.shape[1]: transposed_A = True
+ if A.stride()[1] == A.shape[2]:
+ transposed_A = False
+ elif A.stride()[2] == A.shape[1]:
+ transposed_A = True
if len(sA) == 2:
n = sA[0]
ldb = A.stride()[1 if transposed_A else 0]
elif len(sA) == 3 and len(sB) == 2:
- n = sA[0]*sA[1]
+ n = sA[0] * sA[1]
ldb = sA[2]
-
m = sB[1]
k = sB[0]
lda = B.stride()[(1 if transposed_B else 0)]
@@ -818,20 +1194,21 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
# special case
assert len(sA) == 3
if not (sA[0] == sB[0] and sA[1] == sB[1]):
- raise ValueError(f'Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}')
+ raise ValueError(
+ f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
+ )
transposed_A = True
transposed_B = False
m = sB[2]
n = sA[2]
- k = sB[0]*sB[1]
+ k = sB[0] * sB[1]
lda = m
ldb = sA[2]
ldc = m
-
ptr = CUBLAS_Context.get_instance().get_context(A.device)
# B^T @ A^T = C^T
@@ -842,11 +1219,20 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed
return out
-def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
+def batched_igemm(
+ A: Tensor,
+ B: Tensor,
+ out: Tensor = None,
+ transposed_A=False,
+ transposed_B=False,
+):
if not len(A.shape) == 3 or not len(B.shape) == 3:
- raise ValueError(f'Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}')
+ raise ValueError(
+ f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
+ )
sout = check_matmul(A, B, out, transposed_A, transposed_B)
- if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
+ if out is None:
+ out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
if B.is_contiguous():
lda = B.stride()[1]
@@ -903,9 +1289,9 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ldc = m
- strideA = B.shape[1]*B.shape[2]
- strideB = A.shape[1]*A.shape[2]
- strideC = A.shape[1]*B.shape[2]
+ strideA = B.shape[1] * B.shape[2]
+ strideB = A.shape[1] * A.shape[2]
+ strideC = A.shape[1] * B.shape[2]
ptr = CUBLAS_Context.get_instance().get_context(A.device)
@@ -915,6 +1301,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
return out
+
def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeA = SA[0]
shapeB = SB[0]
@@ -924,7 +1311,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
if dimsA == 2:
m = shapeA[0]
elif dimsA == 3:
- m = shapeA[0]*shapeA[1]
+ m = shapeA[0] * shapeA[1]
rows = n = shapeB[0]
assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
@@ -936,20 +1323,26 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
if dimsA == 2 and out is None:
- out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
+ out, Sout = get_transform_buffer(
+ (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
+ )
elif dimsA == 3 and out is None:
- out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')
+ out, Sout = get_transform_buffer(
+ (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
+ )
- assert dimsB != 3, 'len(B.shape)==3 not supported'
- assert A.device.type == 'cuda'
- assert B.device.type == 'cuda'
+ assert dimsB != 3, "len(B.shape)==3 not supported"
+ assert A.device.type == "cuda"
+ assert B.device.type == "cuda"
assert A.dtype == torch.int8
assert B.dtype == torch.int8
assert out.dtype == dtype
- assert SA[1] == 'col32'
- assert SB[1] in ['col_turing', 'col_ampere']
- assert Sout[1] == 'col32'
- assert shapeA[-1] == shapeB[-1], f'Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}'
+ assert SA[1] == "col32"
+ assert SB[1] in ["col_turing", "col_ampere"]
+ assert Sout[1] == "col32"
+ assert (
+ shapeA[-1] == shapeB[-1]
+ ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
formatB = SB[1]
prev_device = A.device
torch.cuda.set_device(A.device)
@@ -960,17 +1353,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
ptrC = get_ptr(out)
k = shapeA[-1]
- lda = ct.c_int32(m*32)
- if formatB == 'col_turing':
+ lda = ct.c_int32(m * 32)
+ if formatB == "col_turing":
# turing: tiles with rows filled up to multiple of 8 rows by 32 columns
# n = rows
- ldb = ct.c_int32(((rows+7)//8)*8*32)
+ ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
else:
# ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
# n = rows
- ldb = ct.c_int32(((rows+31)//32)*32*32)
+ ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
- ldc = ct.c_int32(m*32)
+ ldc = ct.c_int32(m * 32)
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
@@ -980,14 +1373,22 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
is_on_gpu([A, B, out])
if formatB == 'col_turing':
if dtype == torch.int32:
- has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_turing_32(
+ ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
+ )
else:
- has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
- elif formatB == 'col_ampere':
+ has_error = lib.cigemmlt_turing_8(
+ ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
+ )
+ elif formatB == "col_ampere":
if dtype == torch.int32:
- has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_ampere_32(
+ ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
+ )
else:
- has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_ampere_8(
+ ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
+ )
if has_error == 1:
print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
@@ -995,20 +1396,39 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
torch.cuda.set_device(prev_device)
-
return out, Sout
-def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None):
+def mm_dequant(
+ A,
+ quant_state,
+ row_stats,
+ col_stats,
+ out=None,
+ new_row_stats=None,
+ new_col_stats=None,
+):
assert A.dtype == torch.int32
out_shape = quant_state[0]
- if len(out_shape) == 3: out_shape = (out_shape[0]*out_shape[1], out_shape[2])
-
- if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
- if new_row_stats is None: new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
- if new_col_stats is None: new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
- assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
- assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
+ if len(out_shape) == 3:
+ out_shape = (out_shape[0] * out_shape[1], out_shape[2])
+
+ if out is None:
+ out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
+ if new_row_stats is None:
+ new_row_stats = torch.empty(
+ out_shape[0], dtype=torch.float32, device=A.device
+ )
+ if new_col_stats is None:
+ new_col_stats = torch.empty(
+ out_shape[1], dtype=torch.float32, device=A.device
+ )
+ assert (
+ new_row_stats.shape[0] == row_stats.shape[0]
+ ), f"{new_row_stats.shape} vs {row_stats.shape}"
+ assert (
+ new_col_stats.shape[0] == col_stats.shape[0]
+ ), f"{new_col_stats.shape} vs {col_stats.shape}"
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
@@ -1025,22 +1445,33 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
return out
-def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
+def get_colrow_absmax(
+ A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
+):
assert A.dtype == torch.float16
device = A.device
cols = A.shape[-1]
if len(A.shape) == 3:
- rows = A.shape[0]*A.shape[1]
+ rows = A.shape[0] * A.shape[1]
else:
rows = A.shape[0]
- col_tiles = (cols+255)//256
- tiled_rows = ((rows+15)//16)*16
- if row_stats is None: row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
- if col_stats is None: col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)
-
- if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros(((tiled_rows*col_tiles)+1,), dtype=torch.int32, device=device)
+ col_tiles = (cols + 255) // 256
+ tiled_rows = ((rows + 15) // 16) * 16
+ if row_stats is None:
+ row_stats = torch.empty(
+ (rows,), dtype=torch.float32, device=device
+ ).fill_(-50000.0)
+ if col_stats is None:
+ col_stats = torch.empty(
+ (cols,), dtype=torch.float32, device=device
+ ).fill_(-50000.0)
+
+ if nnz_block_ptr is None and threshold > 0.0:
+ nnz_block_ptr = torch.zeros(
+ ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
+ )
ptrA = get_ptr(A)
ptrRowStats = get_ptr(row_stats)
@@ -1054,13 +1485,12 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr
lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
post_call(prev_device)
-
if threshold > 0.0:
nnz_block_ptr.cumsum_(0)
-
return row_stats, col_stats, nnz_block_ptr
+
class COOSparseTensor(object):
def __init__(self, rows, cols, nnz, rowidx, colidx, values):
assert rowidx.dtype == torch.int32
@@ -1077,6 +1507,7 @@ class COOSparseTensor(object):
self.colidx = colidx
self.values = values
+
class CSRSparseTensor(object):
def __init__(self, rows, cols, nnz, rowptr, colidx, values):
assert rowptr.dtype == torch.int32
@@ -1084,7 +1515,7 @@ class CSRSparseTensor(object):
assert values.dtype == torch.float16
assert values.numel() == nnz
assert colidx.numel() == nnz
- assert rowptr.numel() == rows+1
+ assert rowptr.numel() == rows + 1
self.rows = rows
self.cols = cols
@@ -1093,6 +1524,7 @@ class CSRSparseTensor(object):
self.colidx = colidx
self.values = values
+
class CSCSparseTensor(object):
def __init__(self, rows, cols, nnz, colptr, rowidx, values):
assert colptr.dtype == torch.int32
@@ -1100,7 +1532,7 @@ class CSCSparseTensor(object):
assert values.dtype == torch.float16
assert values.numel() == nnz
assert rowidx.numel() == nnz
- assert colptr.numel() == cols+1
+ assert colptr.numel() == cols + 1
self.rows = rows
self.cols = cols
@@ -1109,13 +1541,19 @@ class CSCSparseTensor(object):
self.rowidx = rowidx
self.values = values
+
def coo2csr(cooA):
values, counts = torch.unique(cooA.rowidx, return_counts=True)
values.add_(1)
- rowptr = torch.zeros((cooA.rows+1, ), dtype=torch.int32, device=cooA.rowidx.device)
+ rowptr = torch.zeros(
+ (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
+ )
rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
rowptr.cumsum_(0)
- return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values)
+ return CSRSparseTensor(
+ cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values
+ )
+
def coo2csc(cooA):
val, col2rowidx = torch.sort(cooA.colidx)
@@ -1123,10 +1561,15 @@ def coo2csc(cooA):
values = cooA.values[col2rowidx]
colvalues, counts = torch.unique(val, return_counts=True)
colvalues.add_(1)
- colptr = torch.zeros((cooA.cols+1, ), dtype=torch.int32, device=cooA.colidx.device)
+ colptr = torch.zeros(
+ (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
+ )
colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
colptr.cumsum_(0)
- return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
+ return CSCSparseTensor(
+ cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
+ )
+
def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
@@ -1135,23 +1578,29 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
-def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
+def double_quant(
+ A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
+):
device = A.device
assert A.dtype == torch.half
- assert device.type == 'cuda'
+ assert device.type == "cuda"
prev_device = pre_call(A.device)
cols = A.shape[-1]
if len(A.shape) == 3:
- rows = A.shape[0]*A.shape[1]
+ rows = A.shape[0] * A.shape[1]
else:
rows = A.shape[0]
if row_stats is None or col_stats is None:
- row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+ row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
+ A, threshold=threshold
+ )
- if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
- if out_row is None: out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
+ if out_col is None:
+ out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
+ if out_row is None:
+ out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
coo_tensor = None
ptrA = get_ptr(A)
@@ -1164,21 +1613,62 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
- coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
+ coo_tensor = coo_zeros(
+ A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device
+ )
ptrRowIdx = get_ptr(coo_tensor.rowidx)
ptrColIdx = get_ptr(coo_tensor.colidx)
ptrVal = get_ptr(coo_tensor.values)
ptrRowPtr = get_ptr(nnz_row_ptr)
- lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, ptrRowIdx, ptrColIdx, ptrVal, ptrRowPtr, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
+ lib.cdouble_rowcol_quant(
+ ptrA,
+ ptrRowStats,
+ ptrColStats,
+ ptrOutCol,
+ ptrOutRow,
+ ptrRowIdx,
+ ptrColIdx,
+ ptrVal,
+ ptrRowPtr,
+ ct.c_float(threshold),
+ ct.c_int32(rows),
+ ct.c_int32(cols),
+ )
val, idx = torch.sort(coo_tensor.rowidx)
coo_tensor.rowidx = val
coo_tensor.colidx = coo_tensor.colidx[idx]
coo_tensor.values = coo_tensor.values[idx]
else:
- lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(0.0), ct.c_int32(rows), ct.c_int32(cols))
+ lib.cdouble_rowcol_quant(
+ ptrA,
+ ptrRowStats,
+ ptrColStats,
+ ptrOutCol,
+ ptrOutRow,
+ None,
+ None,
+ None,
+ None,
+ ct.c_float(0.0),
+ ct.c_int32(rows),
+ ct.c_int32(cols),
+ )
else:
- lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
+ lib.cdouble_rowcol_quant(
+ ptrA,
+ ptrRowStats,
+ ptrColStats,
+ ptrOutCol,
+ ptrOutRow,
+ None,
+ None,
+ None,
+ None,
+ ct.c_float(threshold),
+ ct.c_int32(rows),
+ ct.c_int32(cols),
+ )
post_call(prev_device)
return out_row, out_col, row_stats, col_stats, coo_tensor
@@ -1187,7 +1677,9 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None,
def get_special_format_str():
major, minor = torch.cuda.get_device_capability()
if major < 7:
- print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!')
+ print(
+ f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
+ )
assert major >= 7
if major == 7: return 'col_turing'
@@ -1209,7 +1701,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
dim1 = ct.c_int32(shape[0])
dim2 = ct.c_int32(shape[1])
else:
- dim1 = ct.c_int32(shape[0]*shape[1])
+ dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])
ptrA = get_ptr(A)
@@ -1220,20 +1712,20 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
- elif to_order == 'col_turing':
+ elif to_order == "col_turing":
if transpose:
lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
- elif to_order == 'col_ampere':
+ elif to_order == "col_ampere":
if transpose:
lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
- elif to_order == 'row':
- if from_order == 'col_turing':
+ elif to_order == "row":
+ if from_order == "col_turing":
lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
- elif from_order == 'col_ampere':
+ elif from_order == "col_ampere":
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
@@ -1242,15 +1734,19 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
return out, new_state
+
def spmm_coo(cooA, B, out=None):
- if out is None: out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
+ if out is None:
+ out = torch.empty(
+ (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
+ )
nnz = cooA.nnz
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
assert cooA.values.numel() == nnz
assert cooA.cols == B.shape[0]
- transposed_B = (False if B.is_contiguous() else True)
+ transposed_B = False if B.is_contiguous() else True
ldb = B.stride()[(1 if transposed_B else 0)]
ldc = B.shape[1]
@@ -1274,15 +1770,19 @@ def spmm_coo(cooA, B, out=None):
return out
+
def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
- if out is None: out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
+ if out is None:
+ out = torch.zeros(
+ (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
+ )
nnz = cooA.nnz
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
assert cooA.values.numel() == nnz
- assert cooA.cols == B.shape[0], f'{cooA.cols} vs {B.shape}'
+ assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}"
- transposed_B = (False if B.is_contiguous() else True)
+ transposed_B = False if B.is_contiguous() else True
ldb = B.stride()[(1 if transposed_B else 0)]
ldc = B.shape[1]
@@ -1292,7 +1792,9 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
max_count, max_idx = torch.sort(counts, descending=True)
max_idx = max_idx.int()
max_count = max_count.int()
- assert max_count[0] <= 32, f'Current max count per row is 8 but found {max_count[0]}.'
+ assert (
+ max_count[0] <= 32
+ ), f"Current max count per row is 8 but found {max_count[0]}."
assert B.dtype in [torch.float16, torch.int8]
ptrOffset = get_ptr(offset)
ptrMaxCount = get_ptr(max_count)
@@ -1312,137 +1814,188 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
ccolsB = ct.c_int32(B.shape[1])
cldb = ct.c_int32(ldb)
cldc = ct.c_int32(ldc)
- #print(cooA.rowidx[:64])
- #print(cooA.colidx[:64].sort()[0])
+ # print(cooA.rowidx[:64])
+ # print(cooA.colidx[:64].sort()[0])
is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
if B.dtype == torch.float16:
- lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
+ lib.cspmm_coo_very_sparse_naive_fp16(
+ ptrMaxCount,
+ ptrMaxIdx,
+ ptrOffset,
+ ptrRowidx,
+ ptrColidx,
+ ptrValues,
+ ptrB,
+ ptrC,
+ ptrDequantStats,
+ cnnz_rows,
+ cnnz,
+ crowsA,
+ crowsB,
+ ccolsB,
+ )
elif B.dtype == torch.int8:
- lib.cspmm_coo_very_sparse_naive_int8(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
- #else: assertion error
+ lib.cspmm_coo_very_sparse_naive_int8(
+ ptrMaxCount,
+ ptrMaxIdx,
+ ptrOffset,
+ ptrRowidx,
+ ptrColidx,
+ ptrValues,
+ ptrB,
+ ptrC,
+ ptrDequantStats,
+ cnnz_rows,
+ cnnz,
+ crowsA,
+ crowsB,
+ ccolsB,
+ )
+ # else: assertion error
return out
C = 127.0
-def vectorwise_quant(x, dim=1, quant_type='vector'):
- if quant_type == 'linear':
+
+def vectorwise_quant(x, dim=1, quant_type="vector"):
+ if quant_type == "linear":
max1 = torch.abs(x).max().float()
- xq = torch.round(x/max1*127).to(torch.int8)
+ xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
- elif quant_type in ['vector', 'row']:
+ elif quant_type in ["vector", "row"]:
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
- xq = torch.round(x*(C/max1)).to(torch.int8)
+ xq = torch.round(x * (C / max1)).to(torch.int8)
return xq, max1
- elif quant_type == 'zeropoint':
+ elif quant_type == "zeropoint":
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
- if dyna == 0: dyna = 1
- qx = 255./dyna
+ if dyna == 0:
+ dyna = 1
+ qx = 255.0 / dyna
minx = x.min()
- zpx = torch.round(minx* qx)
- x = torch.round(qx*x - zpx) + zpx
+ zpx = torch.round(minx * qx)
+ x = torch.round(qx * x - zpx) + zpx
return x, qx
- elif quant_type in ['vector-zeropoint', 'row-zeropoint']:
+ elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
dtype = x.dtype
x = x.float()
- dyna = (torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True))
- dyna[dyna==0] = 1
- qx = 255./dyna
+ dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(
+ x, dim=dim, keepdim=True
+ )
+ dyna[dyna == 0] = 1
+ qx = 255.0 / dyna
minx = torch.amin(x, dim=dim, keepdim=True)
- zpx = torch.round(minx* qx)
- x = torch.round(qx*x - zpx) + zpx
+ zpx = torch.round(minx * qx)
+ x = torch.round(qx * x - zpx) + zpx
return x, qx
- elif quant_type == 'truncated-vector':
+ elif quant_type == "truncated-vector":
with torch.no_grad():
absx = torch.abs(x)
max1 = torch.amax(absx, dim=dim, keepdim=True)
- max1 = max1*0.7
- idx = (absx > max1.expand_as(absx))
+ max1 = max1 * 0.7
+ idx = absx > max1.expand_as(absx)
sign = torch.sign(x[idx])
- x[idx] = max1.expand_as(absx)[idx]*sign
- xq = torch.round(x/max1*C).to(torch.int8)
+ x[idx] = max1.expand_as(absx)[idx] * sign
+ xq = torch.round(x / max1 * C).to(torch.int8)
return xq, max1
- else: return None
+ else:
+ return None
+
-def vectorwise_dequant(xq, max1, quant_type='vector'):
- if quant_type == 'vector':
- x = (xq/C*max1).to(torch.float32)
+def vectorwise_dequant(xq, max1, quant_type="vector"):
+ if quant_type == "vector":
+ x = (xq / C * max1).to(torch.float32)
return x
- else: return None
+ else:
+ return None
+
-def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type='vector'):
- if quant_type == 'linear':
- norm = S1*S2/(C*C)
+def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
+ if quant_type == "linear":
+ norm = S1 * S2 / (C * C)
# double cast needed to prevent overflows
- return (xq.float()*norm).to(dtype)
- elif quant_type == 'zeropoint':
- norm = 1.0/(S1*S2)
- return (xq.float()*norm).to(dtype)
- elif quant_type == 'row-zeropoint':
- norm = 1.0/(S1*S2)
+ return (xq.float() * norm).to(dtype)
+ elif quant_type == "zeropoint":
+ norm = 1.0 / (S1 * S2)
+ return (xq.float() * norm).to(dtype)
+ elif quant_type == "row-zeropoint":
+ norm = 1.0 / (S1 * S2)
x = xq.float()
- if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
- if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 3 and len(x.shape) == 2:
+ S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2:
+ S2 = S2.squeeze(0)
if len(S1.shape) == 2:
x *= norm
else:
x *= norm
return x.to(dtype)
- elif quant_type == 'vector-zeropoint':
+ elif quant_type == "vector-zeropoint":
x = xq.float()
- if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
- if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 3 and len(x.shape) == 2:
+ S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2:
+ S2 = S2.squeeze(0)
if len(S1.shape) == 2:
- x *= 1.0/S1
+ x *= 1.0 / S1
else:
- x *= 1.0/S1
- x *= 1.0/S2.t()
+ x *= 1.0 / S1
+ x *= 1.0 / S2.t()
return x.to(dtype)
- elif quant_type == 'row':
+ elif quant_type == "row":
x = xq.float()
- if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
- if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 3 and len(x.shape) == 2:
+ S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2:
+ S2 = S2.squeeze(0)
if len(S1.shape) == 2:
- x *= S1*S2/(C*C)
+ x *= S1 * S2 / (C * C)
else:
- x *= S1*S2/(C*C)
+ x *= S1 * S2 / (C * C)
return x.to(dtype)
- elif quant_type in ['truncated-vector', 'vector']:
+ elif quant_type in ["truncated-vector", "vector"]:
x = xq.float()
- if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
- if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 3 and len(x.shape) == 2:
+ S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2:
+ S2 = S2.squeeze(0)
if len(S1.shape) == 2:
- x *= S1/C
+ x *= S1 / C
else:
- x *= S1/C
- x *= S2/C
+ x *= S1 / C
+ x *= S2 / C
return x.to(dtype)
- else: return None
+ else:
+ return None
def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
- offset = B.float().t().sum(0)*(SA[0]+SA[1])
+ offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float()
- if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
+ if len(xq.shape) == 2 and len(SB.shape) == 3:
+ SB = SB.squeeze(0)
if len(SB.shape) == 2:
- x *= SB.t()/127
+ x *= SB.t() / 127
else:
- x *= SB/127
- x *= SA[1]/127
- x +=offset
+ x *= SB / 127
+ x *= SA[1] / 127
+ x += offset
return x.to(dtype)
+
def extract_outliers(A, SA, idx):
shapeA = SA[0]
formatA = SA[1]
- assert formatA in ['col_turing', 'col_ampere']
- assert A.device.type == 'cuda'
+ assert formatA in ["col_turing", "col_ampere"]
+ assert A.device.type == "cuda"
- out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
+ out = torch.zeros(
+ (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
+ )
idx_size = ct.c_int32(idx.numel())
rows = ct.c_int32(shapeA[0])
@@ -1454,12 +2007,8 @@ def extract_outliers(A, SA, idx):
prev_device = pre_call(A.device)
if formatA == 'col_turing':
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
- elif formatA == 'col_ampere':
+ elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
post_call(prev_device)
return out
-
-
-
-
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 03b4655..98d4aa0 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -1,5 +1,5 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params
+from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 5013d0b..454dba5 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -1,39 +1,70 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import torch
-import bitsandbytes as bnb
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+ overload,
+)
-from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
-
-from torch import Tensor, device, dtype
-from torch import nn
-from torch.nn.parameter import Parameter
+import torch
import torch.nn.functional as F
+from torch import Tensor, device, dtype, nn
+from torch.nn.parameter import Parameter
+import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager
-T = TypeVar('T', bound='torch.nn.Module')
+T = TypeVar("T", bound="torch.nn.Module")
+
class StableEmbedding(torch.nn.Embedding):
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
- sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
- super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None,
+ norm_type: float = 2.0,
+ scale_grad_by_freq: bool = False,
+ sparse: bool = False,
+ _weight: Optional[Tensor] = None,
+ ) -> None:
+ super(StableEmbedding, self).__init__(
+ num_embeddings,
+ embedding_dim,
+ padding_idx,
+ max_norm,
+ norm_type,
+ scale_grad_by_freq,
+ sparse,
+ _weight,
+ )
self.norm = torch.nn.LayerNorm(embedding_dim)
- GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
+ GlobalOptimManager.get_instance().register_module_override(
+ self, "weight", {"optim_bits": 32}
+ )
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
self._fill_padding_idx_with_zero()
- ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+ """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
to make the Layer compatible with Pytorch < 1.9.
This means that if this changes in future PyTorch releases this need to change too
which is cumbersome. However, with this we can ensure compatibility with previous
PyTorch releases.
- '''
+ """
+
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
@@ -41,29 +72,55 @@ class StableEmbedding(torch.nn.Embedding):
def forward(self, input: Tensor) -> Tensor:
emb = F.embedding(
- input, self.weight, self.padding_idx, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.sparse)
+ input,
+ self.weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ )
return self.norm(emb)
class Embedding(torch.nn.Embedding):
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
- sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
- super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
- GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None,
+ norm_type: float = 2.0,
+ scale_grad_by_freq: bool = False,
+ sparse: bool = False,
+ _weight: Optional[Tensor] = None,
+ ) -> None:
+ super(Embedding, self).__init__(
+ num_embeddings,
+ embedding_dim,
+ padding_idx,
+ max_norm,
+ norm_type,
+ scale_grad_by_freq,
+ sparse,
+ _weight,
+ )
+ GlobalOptimManager.get_instance().register_module_override(
+ self, "weight", {"optim_bits": 32}
+ )
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
self._fill_padding_idx_with_zero()
- ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+ """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
to make the Layer compatible with Pytorch < 1.9.
This means that if this changes in future PyTorch releases this need to change too
which is cumbersome. However, with this we can ensure compatibility with previous
PyTorch releases.
- '''
+ """
+
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
@@ -71,13 +128,27 @@ class Embedding(torch.nn.Embedding):
def forward(self, input: Tensor) -> Tensor:
emb = F.embedding(
- input, self.weight, self.padding_idx, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.sparse)
+ input,
+ self.weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ )
return emb
+
class Int8Params(torch.nn.Parameter):
- def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
+ def __new__(
+ cls,
+ data=None,
+ requires_grad=True,
+ has_fp16_weights=False,
+ CB=None,
+ SCB=None,
+ ):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
cls.SCB = None
@@ -96,14 +167,18 @@ class Int8Params(torch.nn.Parameter):
del CBt
del SCBt
self.data = CB
- setattr(self, 'CB', CB)
- setattr(self, 'SCB', SCB)
+ setattr(self, "CB", CB)
+ setattr(self, "SCB", SCB)
return self
@overload
- def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
- non_blocking: bool = ...) -> T:
+ def to(
+ self: T,
+ device: Optional[Union[int, device]] = ...,
+ dtype: Optional[Union[dtype, str]] = ...,
+ non_blocking: bool = ...,
+ ) -> T:
...
@overload
@@ -115,30 +190,54 @@ class Int8Params(torch.nn.Parameter):
...
def to(self, *args, **kwargs):
- device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
-
- if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device)
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
+ *args, **kwargs
+ )
+
+ if (
+ device is not None
+ and device.type == "cuda"
+ and self.data.device.type == "cpu"
+ ):
+ return self.cuda(device)
else:
- new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights)
+ new_param = Int8Params(
+ super().to(
+ device=device, dtype=dtype, non_blocking=non_blocking
+ ),
+ requires_grad=self.requires_grad,
+ has_fp16_weights=self.has_fp16_weights,
+ )
new_param.CB = self.CB
new_param.SCB = self.SCB
return new_param
-
class Linear8bitLt(nn.Linear):
- def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None):
- super(Linear8bitLt, self).__init__(input_features, output_features, bias)
+ def __init__(
+ self,
+ input_features,
+ output_features,
+ bias=True,
+ has_fp16_weights=True,
+ threshold=0.0,
+ index=None,
+ ):
+ super(Linear8bitLt, self).__init__(
+ input_features, output_features, bias
+ )
self.state = bnb.MatmulLtState()
- self.index=index
+ self.index = index
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
- self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+ self.weight = Int8Params(
+ self.weight.data, has_fp16_weights=has_fp16_weights
+ )
def init_8bit_state(self):
self.state.CB = self.weight.CB
@@ -149,9 +248,10 @@ class Linear8bitLt(nn.Linear):
def forward(self, x):
self.state.is_training = self.training
- if self.weight.CB is not None: self.init_8bit_state()
- #assert not self.state.has_fp16_weights
- #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
+ if self.weight.CB is not None:
+ self.init_8bit_state()
+ # assert not self.state.has_fp16_weights
+ # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
out = bnb.matmul(x, self.weight, state=self.state)
@@ -166,8 +266,18 @@ class Linear8bitLt(nn.Linear):
return out
+
class Linear8bit(nn.Linear):
- def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False):
+ def __init__(
+ self,
+ input_features,
+ output_features,
+ bias=True,
+ quant_type="vector",
+ index=None,
+ args=None,
+ sparse_decomp=False,
+ ):
super(Linear8bit, self).__init__(input_features, output_features, bias)
self.quant_type = quant_type
self.index = index
@@ -178,15 +288,24 @@ class Linear8bit(nn.Linear):
self.iter += 1
if self.iter % self.args.clip_freq == 0:
with torch.no_grad():
- maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx)
+ maxval, maxidx = torch.topk(
+ torch.abs(self.weight.flatten()), k=self.args.clip_idx
+ )
if not dist.is_initialized() or dist.get_rank() == 0:
- print('clip', maxval[-1].item())
+ print("clip", maxval[-1].item())
self.weight.clip_(-maxval[-1], maxval[-1])
-
if self.args is not None:
- out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type)
+ out = bnb.nn.functional.sparse_decomposed_linear8bit(
+ x,
+ self.weight,
+ self.bias,
+ qval=self.args.sparse_decomp_val,
+ quant_type=self.args.quant_type,
+ )
else:
- out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type)
+ out = bnb.nn.functional.linear8bit(
+ x, self.weight, self.bias, quant_type=self.args.quant_type
+ )
return out
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index 42b5bc0..a76d717 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -1,6 +1,6 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.cextension import COMPILED_WITH_CUDA
diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py
index 4f51250..7e2f566 100644
--- a/bitsandbytes/optim/adagrad.py
+++ b/bitsandbytes/optim/adagrad.py
@@ -1,54 +1,132 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
+
class Adagrad(Optimizer1State):
- def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
- optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ lr_decay=0,
+ weight_decay=0,
+ initial_accumulator_value=0,
+ eps=1e-10,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
- raise ValueError('Initial accumulator value != 0.0 not supported!')
+ raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
- raise ValueError('Lr Decay != 0.0 not supported!')
- super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ raise ValueError("Lr Decay != 0.0 not supported!")
+ super(Adagrad, self).__init__(
+ "adagrad",
+ params,
+ lr,
+ (0.0, 0.0),
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adagrad8bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
- optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ lr_decay=0,
+ weight_decay=0,
+ initial_accumulator_value=0,
+ eps=1e-10,
+ optim_bits=8,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
- raise ValueError('Initial accumulator value != 0.0 not supported!')
+ raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
- raise ValueError('Lr Decay != 0.0 not supported!')
+ raise ValueError("Lr Decay != 0.0 not supported!")
assert block_wise
- super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ super(Adagrad8bit, self).__init__(
+ "adagrad",
+ params,
+ lr,
+ (0.0, 0.0),
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adagrad32bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
- optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ lr_decay=0,
+ weight_decay=0,
+ initial_accumulator_value=0,
+ eps=1e-10,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
- raise ValueError('Initial accumulator value != 0.0 not supported!')
+ raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
- raise ValueError('Lr Decay != 0.0 not supported!')
- super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ raise ValueError("Lr Decay != 0.0 not supported!")
+ super(Adagrad32bit, self).__init__(
+ "adagrad",
+ params,
+ lr,
+ (0.0, 0.0),
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index ed1b9f0..3634971 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -1,6 +1,6 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
@@ -8,29 +8,97 @@ import os
import torch
import torch.distributed as dist
-from bitsandbytes.optim.optimizer import Optimizer2State
+
import bitsandbytes.functional as F
+from bitsandbytes.optim.optimizer import Optimizer2State
+
class Adam(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam, self).__init__('adam', params, lr, betas, eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adam8bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam8bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam8bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adam32bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam32bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam32bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
class AnalysisAdam(torch.optim.Optimizer):
@@ -68,11 +136,15 @@ class AnalysisAdam(torch.optim.Optimizer):
eps=1e-8,
weight_decay=0,
amsgrad=False,
- bnb_analysis='dynamic-blockwise',
- savedir=None
+ bnb_analysis="dynamic-blockwise",
+ savedir=None,
):
defaults = dict(
- lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ amsgrad=amsgrad,
)
super(AnalysisAdam, self).__init__(params, defaults)
self.analysis = bnb_analysis
@@ -124,9 +196,15 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = torch.zeros_like(p_data_fp32)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
- state['abserrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
- state['relerrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
- state['counts'] = torch.zeros((256, 256), device=p_data_fp32.device)
+ state["abserrors"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
+ state["relerrors"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
+ state["counts"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -142,10 +220,12 @@ class AnalysisAdam(torch.optim.Optimizer):
beta1, beta2 = group["betas"]
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
- step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
- e = state['abserrors']
- rele = state['relerrors']
- counts = state['counts']
+ step_size = (
+ group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+ )
+ e = state["abserrors"]
+ rele = state["relerrors"]
+ counts = state["counts"]
if group["weight_decay"] != 0:
p_data_fp32.add_(
@@ -156,77 +236,91 @@ class AnalysisAdam(torch.optim.Optimizer):
if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"]
-
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
- update_fp32 = exp_avg/denom
+ update_fp32 = exp_avg / denom
- if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000*1000:
+ if (
+ p_data_fp32.numel() <= 8192
+ or p_data_fp32.numel() > 50000 * 1000
+ ):
# embedding layer or too small
- p_data_fp32 += -step_size*update_fp32
+ p_data_fp32 += -step_size * update_fp32
else:
- if self.analysis == 'dynamic-blockwise':
+ if self.analysis == "dynamic-blockwise":
code1 = F.create_dynamic_map(signed=True).to(p.device)
code2 = F.create_dynamic_map(signed=False).to(p.device)
C1, S1 = F.quantize_blockwise(exp_avg, code=code1)
state1 = F.dequantize_blockwise(C1, S1)
C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2)
state2 = F.dequantize_blockwise(C2, S2)
- elif self.analysis == 'dynamic':
+ elif self.analysis == "dynamic":
code1 = F.create_dynamic_map(signed=True).to(p.device)
code2 = F.create_dynamic_map(signed=False).to(p.device)
C1, S1 = F.quantize(exp_avg, code=code1)
state1 = F.dequantize(C1, S1)
C2, S2 = F.quantize(exp_avg_sq, code=code2)
state2 = F.dequantize(C2, S2)
- elif self.analysis == 'linear':
+ elif self.analysis == "linear":
code1 = F.create_linear_map(signed=True).to(p.device)
code2 = F.create_linear_map(signed=False).to(p.device)
C1, S1 = F.quantize(exp_avg, code=code1)
state1 = F.dequantize(C1, S1)
C2, S2 = F.quantize(exp_avg_sq, code=code2)
state2 = F.dequantize(C2, S2)
- elif self.analysis == 'quantile':
+ elif self.analysis == "quantile":
code1 = F.estimate_quantiles(exp_avg)
code2 = F.estimate_quantiles(exp_avg_sq)
C1 = F.quantize_no_absmax(exp_avg, code=code1)
state1 = F.dequantize_no_absmax(C1, code1)
C2 = F.quantize_no_absmax(exp_avg_sq, code=code2)
state2 = F.dequantize_no_absmax(C2, code2)
- elif self.analysis == 'my-quantization-routine':
+ elif self.analysis == "my-quantization-routine":
pass
# 1. get code
# 2. quantize
# 3. dequantize
# Error will be calculated automatically!
else:
- raise ValueError(f'Invalid analysis value: {self.analysis}!')
+ raise ValueError(
+ f"Invalid analysis value: {self.analysis}!"
+ )
denom = state2.sqrt().add_(group["eps"])
- update_8bit = state1/denom
+ update_8bit = state1 / denom
- abserr = torch.abs(update_8bit-update_fp32)
- relerr = abserr/torch.abs(update_fp32+1e-6)
+ abserr = torch.abs(update_8bit - update_fp32)
+ relerr = abserr / torch.abs(update_fp32 + 1e-6)
C1, C2 = C1.int(), C2.int()
F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
- F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))
-
- p_data_fp32 += -step_size*update_fp32
+ F.histogram_scatter_add_2d(
+ counts, C1.int(), C2.int(), torch.ones_like(abserr)
+ )
+ p_data_fp32 += -step_size * update_fp32
if not dist.is_initialized() or dist.get_rank() == 0:
- if self.savedir != '' and state['step'] % 100 == 0:
- if not os.path.exists(self.savedir): os.makedirs(self.savedir)
- shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape])
- pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
- pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
- pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
+ if self.savedir != "" and state["step"] % 100 == 0:
+ if not os.path.exists(self.savedir):
+ os.makedirs(self.savedir)
+ shapestr = "_".join(
+ [str(dim) for dim in p_data_fp32.shape]
+ )
+ pathe = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
+ )
+ pathrele = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
+ )
+ pathcounts = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_counts.pkl"
+ )
torch.save(e, pathe)
torch.save(rele, pathrele)
torch.save(counts, pathcounts)
@@ -234,6 +328,4 @@ class AnalysisAdam(torch.optim.Optimizer):
if p.data.dtype in {torch.float16, torch.bfloat16}:
p.data.copy_(p_data_fp32)
-
-
return loss
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index c4f0355..d0b3bde 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -1,27 +1,93 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
+
class AdamW(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(AdamW, self).__init__('adam', params, lr, betas, eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2,
+ amsgrad=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(AdamW, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class AdamW8bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=1e-2, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(AdamW8bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
-class AdamW32bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=1e-2, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+class AdamW32bit(Optimizer2State):
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(AdamW32bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py
index 58cc13d..8f365f7 100644
--- a/bitsandbytes/optim/lamb.py
+++ b/bitsandbytes/optim/lamb.py
@@ -1,28 +1,105 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
+
class LAMB(Optimizer2State):
- def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
- super(LAMB, self).__init__('lamb', params, lr, betas, eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ adam_w_mode=True,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=False,
+ max_unorm=1.0,
+ ):
+ super(LAMB, self).__init__(
+ "lamb",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ max_unorm=1.0,
+ )
-class LAMB8bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
- super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
-class LAMB32bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
- super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+class LAMB8bit(Optimizer2State):
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ adam_w_mode=True,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=False,
+ max_unorm=1.0,
+ ):
+ super(LAMB8bit, self).__init__(
+ "lamb",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ max_unorm=1.0,
+ )
+class LAMB32bit(Optimizer2State):
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ adam_w_mode=True,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=False,
+ max_unorm=1.0,
+ ):
+ super(LAMB32bit, self).__init__(
+ "lamb",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ max_unorm=1.0,
+ )
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
index 912520d..8a89fb0 100644
--- a/bitsandbytes/optim/lars.py
+++ b/bitsandbytes/optim/lars.py
@@ -1,60 +1,154 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
-
from torch.optim import Optimizer
+
from bitsandbytes.optim.optimizer import Optimizer1State
+
class LARS(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
+ super(LARS, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
+
class LARS8bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
+ super(LARS8bit, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
+
class LARS32bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
+ super(LARS32bit, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
class PytorchLARS(Optimizer):
- def __init__(self, params, lr=0.01, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr=0.01,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ max_unorm=0.02,
+ ):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
-
- defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
- weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm)
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
+
+ defaults = dict(
+ lr=lr,
+ momentum=momentum,
+ dampening=dampening,
+ weight_decay=weight_decay,
+ nesterov=nesterov,
+ max_unorm=max_unorm,
+ )
if nesterov and (momentum <= 0 or dampening != 0):
- raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+ raise ValueError(
+ "Nesterov momentum requires a momentum and zero dampening"
+ )
super(PytorchLARS, self).__init__(params, defaults)
def __setstate__(self, state):
super(PytorchLARS, self).__setstate__(state)
for group in self.param_groups:
- group.setdefault('nesterov', False)
+ group.setdefault("nesterov", False)
@torch.no_grad()
def step(self, closure=None):
@@ -73,15 +167,16 @@ class PytorchLARS(Optimizer):
params_with_grad = []
d_p_list = []
momentum_buffer_list = []
- weight_decay = group['weight_decay']
- momentum = group['momentum']
- dampening = group['dampening']
- nesterov = group['nesterov']
- max_unorm = group['max_unorm']
- lr = group['lr']
+ weight_decay = group["weight_decay"]
+ momentum = group["momentum"]
+ dampening = group["dampening"]
+ nesterov = group["nesterov"]
+ max_unorm = group["max_unorm"]
+ lr = group["lr"]
- for p in group['params']:
- if p.grad is None: continue
+ for p in group["params"]:
+ if p.grad is None:
+ continue
state = self.state[p]
d_p = p.grad
@@ -89,16 +184,16 @@ class PytorchLARS(Optimizer):
d_p = d_p.add(param, alpha=weight_decay)
if momentum != 0:
- buf = state.get('momentum_buffer', None)
+ buf = state.get("momentum_buffer", None)
if buf is None:
buf = torch.clone(d_p).detach()
- state['momentum_buffer']= buf
+ state["momentum_buffer"] = buf
else:
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
- update = d_p + buf*momentum
+ update = d_p + buf * momentum
else:
update = buf
@@ -107,9 +202,9 @@ class PytorchLARS(Optimizer):
assert p.dtype == torch.float32
pnorm = torch.norm(p.detach())
unorm = torch.norm(update)
- if unorm > max_unorm*pnorm:
- update_scale = max_unorm*pnorm/unorm
+ if unorm > max_unorm * pnorm:
+ update_scale = max_unorm * pnorm / unorm
- p.add_(update, alpha=-lr*update_scale)
+ p.add_(update, alpha=-lr * update_scale)
return loss
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 5a5bb1e..4fb30cd 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -1,13 +1,16 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from collections import abc as container_abcs
+from collections import defaultdict
+from copy import deepcopy
+from itertools import chain
+
import torch
+
import bitsandbytes.functional as F
-from copy import deepcopy
-from itertools import chain
-from collections import defaultdict, abc as container_abcs
class MockArgs(object):
def __init__(self, initial_data):
@@ -19,7 +22,7 @@ class GlobalOptimManager(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.pid2config = {}
@@ -38,15 +41,19 @@ class GlobalOptimManager(object):
def register_parameters(self, params):
param_groups = list(params)
if not isinstance(param_groups[0], dict):
- param_groups = [{'params': param_groups}]
+ param_groups = [{"params": param_groups}]
for group_index, group in enumerate(param_groups):
- for p_index, p in enumerate(group['params']):
+ for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
- self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
+ self.index2config[(group_index, p_index)] = self.pid2config[
+ id(p)
+ ]
- def override_config(self, parameters, key=None, value=None, key_value_dict=None):
- '''
+ def override_config(
+ self, parameters, key=None, value=None, key_value_dict=None
+ ):
+ """
Overrides initial optimizer config for specific parameters.
The key-values of the optimizer config for the input parameters are overidden
@@ -63,7 +70,7 @@ class GlobalOptimManager(object):
The value for the hyperparamters.
key_value_dict : dict
A dictionary with multiple key-values to override.
- '''
+ """
self.uses_config_override = True
if isinstance(parameters, torch.nn.Parameter):
parameters = [parameters]
@@ -75,16 +82,16 @@ class GlobalOptimManager(object):
if key_value_dict is not None:
for p in parameters:
- if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict)
- else: self.pid2config[id(p)] = key_value_dict
+ if id(p) in self.pid2config:
+ self.pid2config[id(p)].update(key_value_dict)
+ else:
+ self.pid2config[id(p)] = key_value_dict
def register_module_override(self, module, param_name, config):
self.module_weight_config_triple.append((module, param_name, config))
-
class Optimizer8bit(torch.optim.Optimizer):
-
def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults)
self.initialized = False
@@ -92,23 +99,32 @@ class Optimizer8bit(torch.optim.Optimizer):
self.mng = GlobalOptimManager.get_instance()
self.non_castable_tensor_keys = set(
- ['qmap1', 'qmap2',
- 'max1', 'max2',
- 'new_max1', 'new_max2',
- 'state1', 'state2',
- 'gnorm_vec', 'absmax1', 'absmax2',
- 'unorm_vec'])
-
- if optim_bits == 8: self.fill_qmap()
+ [
+ "qmap1",
+ "qmap2",
+ "max1",
+ "max2",
+ "new_max1",
+ "new_max2",
+ "state1",
+ "state2",
+ "gnorm_vec",
+ "absmax1",
+ "absmax2",
+ "unorm_vec",
+ ]
+ )
+
+ if optim_bits == 8:
+ self.fill_qmap()
def fill_qmap(self):
- self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True)
- self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False)
+ self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
+ self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
def __setstate__(self, state):
super(Optimizer8bit, self).__setstate__(state)
-
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
@@ -120,21 +136,29 @@ class Optimizer8bit(torch.optim.Optimizer):
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
- saved_groups = state_dict['param_groups']
+ saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups):
- raise ValueError("loaded state dict has a different number of "
- "parameter groups")
- param_lens = (len(g['params']) for g in groups)
- saved_lens = (len(g['params']) for g in saved_groups)
+ raise ValueError(
+ "loaded state dict has a different number of "
+ "parameter groups"
+ )
+ param_lens = (len(g["params"]) for g in groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
- raise ValueError("loaded state dict contains a parameter group "
- "that doesn't match the size of optimizer's group")
+ raise ValueError(
+ "loaded state dict contains a parameter group "
+ "that doesn't match the size of optimizer's group"
+ )
# Update the state
- id_map = {old_id: p for old_id, p in
- zip(chain.from_iterable((g['params'] for g in saved_groups)),
- chain.from_iterable((g['params'] for g in groups)))}
+ id_map = {
+ old_id: p
+ for old_id, p in zip(
+ chain.from_iterable((g["params"] for g in saved_groups)),
+ chain.from_iterable((g["params"] for g in groups)),
+ )
+ }
def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param."""
@@ -161,7 +185,7 @@ class Optimizer8bit(torch.optim.Optimizer):
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
- for k, v in state_dict['state'].items():
+ for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
@@ -170,15 +194,17 @@ class Optimizer8bit(torch.optim.Optimizer):
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
- new_group['params'] = group['params']
+ new_group["params"] = group["params"]
return new_group
+
param_groups = [
- update_group(g, ng) for g, ng in zip(groups, saved_groups)]
- self.__setstate__({'state': state, 'param_groups': param_groups})
+ update_group(g, ng) for g, ng in zip(groups, saved_groups)
+ ]
+ self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
for gindex, group in enumerate(self.param_groups):
- for pindex, p in enumerate(group['params']):
+ for pindex, p in enumerate(group["params"]):
if p in self.state:
values = self.state[p]
for k, v in values.items():
@@ -189,17 +215,23 @@ class Optimizer8bit(torch.optim.Optimizer):
for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr)
assert pmodule is not None
- assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
+ assert isinstance(pmodule, torch.Tensor) or isinstance(
+ pmodule, torch.Parameter
+ )
found = False
for gindex, group in enumerate(self.param_groups):
- if found: break
- for pindex, p in enumerate(group['params']):
- if found: break
+ if found:
+ break
+ for pindex, p in enumerate(group["params"]):
+ if found:
+ break
if id(p) == id(pmodule):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
- self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
+ self.mng.index2config[
+ (gindex, pindex)
+ ] = self.mng.pid2config[id(p)]
found = True
@torch.no_grad()
@@ -219,11 +251,11 @@ class Optimizer8bit(torch.optim.Optimizer):
if not self.initialized:
self.check_overrides()
- self.to_gpu() # needed for fairseq pure fp16 training
+ self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
for gindex, group in enumerate(self.param_groups):
- for pindex, p in enumerate(group['params']):
+ for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
@@ -236,58 +268,76 @@ class Optimizer8bit(torch.optim.Optimizer):
def get_config(self, gindex, pindex, group):
config = {}
- config['betas'] = group['betas']
- config['eps'] = group['eps']
- config['weight_decay'] = group['weight_decay']
- config['lr'] = group['lr']
- config['optim_bits'] = self.args.optim_bits
- config['min_8bit_size'] = self.args.min_8bit_size
- config['percentile_clipping'] = self.args.percentile_clipping
- config['block_wise'] = self.args.block_wise
- config['max_unorm'] = self.args.max_unorm
- config['skip_zeros'] = self.args.skip_zeros
+ config["betas"] = group["betas"]
+ config["eps"] = group["eps"]
+ config["weight_decay"] = group["weight_decay"]
+ config["lr"] = group["lr"]
+ config["optim_bits"] = self.args.optim_bits
+ config["min_8bit_size"] = self.args.min_8bit_size
+ config["percentile_clipping"] = self.args.percentile_clipping
+ config["block_wise"] = self.args.block_wise
+ config["max_unorm"] = self.args.max_unorm
+ config["skip_zeros"] = self.args.skip_zeros
if (gindex, pindex) in self.mng.index2config:
config.update(self.mng.index2config[(gindex, pindex)])
return config
def init_state(self, group, p, gindex, pindex):
- raise NotImplementedError(f'init_state method needs to be overidden')
+ raise NotImplementedError(f"init_state method needs to be overidden")
def update_step(self, group, p, gindex, pindex):
- raise NotImplementedError(f'The update_step method needs to be overidden')
+ raise NotImplementedError(
+ f"The update_step method needs to be overidden"
+ )
+
class Optimizer2State(Optimizer8bit):
- def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0.0, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
- skip_zeros=False):
+ def __init__(
+ self,
+ optimizer_name,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0.0,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ max_unorm=0.0,
+ skip_zeros=False,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if isinstance(betas, str):
# format: '(beta1, beta2)'
- betas = betas.replace('(', '').replace(')', '').strip().split(',')
+ betas = betas.replace("(", "").replace(")", "").strip().split(",")
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
- raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ raise ValueError(
+ f"Invalid beta parameter at index {i}: {betas[i]}"
+ )
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay)
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
- args['optim_bits'] = optim_bits
- args['percentile_clipping'] = 100
- args['min_8bit_size'] = min_8bit_size
- args['percentile_clipping'] = percentile_clipping
- args['block_wise'] = block_wise
- args['max_unorm'] = max_unorm
- args['skip_zeros'] = skip_zeros
+ args["optim_bits"] = optim_bits
+ args["percentile_clipping"] = 100
+ args["min_8bit_size"] = min_8bit_size
+ args["percentile_clipping"] = percentile_clipping
+ args["block_wise"] = block_wise
+ args["max_unorm"] = max_unorm
+ args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
@@ -299,50 +349,93 @@ class Optimizer2State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
- if config['optim_bits'] == 32:
+ if config["optim_bits"] == 32:
dtype = torch.float32
- elif config['optim_bits'] == 8:
+ elif config["optim_bits"] == 8:
dtype = torch.uint8
- else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+ else:
+ raise NotImplementedError(
+ f'Amount of optimizer bits not supported: {config["optim_bits"]}'
+ )
- if p.numel() < config['min_8bit_size']: dtype = torch.float32
+ if p.numel() < config["min_8bit_size"]:
+ dtype = torch.float32
state = self.state[p]
- state['step'] = 0
-
- if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
- state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ state["step"] = 0
+
+ if dtype == torch.float32 or (
+ dtype == torch.uint8 and p.numel() < 4096
+ ):
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
+ state["state2"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
elif dtype == torch.uint8:
- if state['step'] == 0:
- if 'dynamic' not in self.name2qmap: self.fill_qmap()
- self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
- self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)
-
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap1'] = self.name2qmap['dynamic']
-
- state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap2'] = self.name2qmap['udynamic']
-
- if config['block_wise']:
+ if state["step"] == 0:
+ if "dynamic" not in self.name2qmap:
+ self.fill_qmap()
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+ p.device
+ )
+ self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
+ p.device
+ )
+
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap1"] = self.name2qmap["dynamic"]
+
+ state["state2"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap2"] = self.name2qmap["udynamic"]
+
+ if config["block_wise"]:
n = p.numel()
- blocks = n//2048
+ blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
- state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
- state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ state["absmax1"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
+ state["absmax2"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
else:
- state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
-
- if config['percentile_clipping'] < 100:
- state['gnorm_vec'] = torch.zeros((100,), device=p.device)
-
- if config['max_unorm'] > 0.0:
- state['unorm_vec'] = torch.zeros((1,), device=p.device)
+ state["max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+ state["new_max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+ state["max2"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+ state["new_max2"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+
+ if config["percentile_clipping"] < 100:
+ state["gnorm_vec"] = torch.zeros((100,), device=p.device)
+
+ if config["max_unorm"] > 0.0:
+ state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
@@ -351,63 +444,128 @@ class Optimizer2State(Optimizer8bit):
config = self.get_config(gindex, pindex, group)
- state['step'] += 1
- step = state['step']
+ state["step"] += 1
+ step = state["step"]
- if config['percentile_clipping'] < 100:
- current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ if config["percentile_clipping"] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
+ grad, state["gnorm_vec"], step, config["percentile_clipping"]
+ )
else:
gnorm_scale = 1.0
- if state['state1'].dtype == torch.float:
- F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
- state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])
-
- elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
- F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'],
- config['weight_decay'], gnorm_scale=gnorm_scale,
- unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+ if state["state1"].dtype == torch.float:
+ F.optimizer_update_32bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ config["betas"][0],
+ config["eps"],
+ step,
+ config["lr"],
+ state["state2"],
+ config["betas"][1],
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ skip_zeros=config["skip_zeros"],
+ )
+
+ elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
+ F.optimizer_update_8bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ state["state2"],
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ state["qmap2"],
+ state["max1"],
+ state["max2"],
+ state["new_max1"],
+ state["new_max2"],
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ unorm_vec=state["unorm_vec"]
+ if config["max_unorm"] > 0.0
+ else None,
+ max_unorm=config["max_unorm"],
+ )
# swap maxes
- state['max1'], state['new_max1'] = state['new_max1'], state['max1']
- state['max2'], state['new_max2'] = state['new_max2'], state['max2']
- elif state['state1'].dtype == torch.uint8 and config['block_wise']:
- F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
- config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
+ state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
+ state["max2"], state["new_max2"] = state["new_max2"], state["max2"]
+ elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
+ F.optimizer_update_8bit_blockwise(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ state["state2"],
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ state["qmap2"],
+ state["absmax1"],
+ state["absmax2"],
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ skip_zeros=config["skip_zeros"],
+ )
class Optimizer1State(Optimizer8bit):
- def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
- weight_decay=0.0, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
- skip_zeros=False):
+ def __init__(
+ self,
+ optimizer_name,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.0),
+ eps=1e-8,
+ weight_decay=0.0,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ max_unorm=0.0,
+ skip_zeros=False,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
- raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ raise ValueError(
+ f"Invalid beta parameter at index {i}: {betas[i]}"
+ )
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay)
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
- args['optim_bits'] = optim_bits
- args['percentile_clipping'] = 100
- args['min_8bit_size'] = min_8bit_size
- args['percentile_clipping'] = percentile_clipping
- args['block_wise'] = block_wise
- args['max_unorm'] = max_unorm
- args['skip_zeros'] = skip_zeros
+ args["optim_bits"] = optim_bits
+ args["percentile_clipping"] = 100
+ args["min_8bit_size"] = min_8bit_size
+ args["percentile_clipping"] = percentile_clipping
+ args["block_wise"] = block_wise
+ args["max_unorm"] = max_unorm
+ args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
@@ -419,43 +577,67 @@ class Optimizer1State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
- if config['optim_bits'] == 32:
+ if config["optim_bits"] == 32:
dtype = torch.float32
- elif config['optim_bits'] == 8:
+ elif config["optim_bits"] == 8:
dtype = torch.uint8
- else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+ else:
+ raise NotImplementedError(
+ f'Amount of optimizer bits not supported: {config["optim_bits"]}'
+ )
- if p.numel() < config['min_8bit_size']: dtype = torch.float32
+ if p.numel() < config["min_8bit_size"]:
+ dtype = torch.float32
state = self.state[p]
- state['step'] = 0
-
- if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ state["step"] = 0
+
+ if dtype == torch.float32 or (
+ dtype == torch.uint8 and p.numel() < 4096
+ ):
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
elif dtype == torch.uint8:
- if state['step'] == 0:
- if 'dynamic' not in self.name2qmap: self.fill_qmap()
- self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
-
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap1'] = self.name2qmap['dynamic']
-
- if config['block_wise']:
+ if state["step"] == 0:
+ if "dynamic" not in self.name2qmap:
+ self.fill_qmap()
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+ p.device
+ )
+
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap1"] = self.name2qmap["dynamic"]
+
+ if config["block_wise"]:
n = p.numel()
- blocks = n//2048
+ blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
- state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ state["absmax1"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
else:
- state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
-
- if config['percentile_clipping'] < 100:
- state['gnorm_vec'] = torch.zeros((100,), device=p.device)
+ state["max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+ state["new_max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
- if config['max_unorm'] > 0.0:
- state['unorm_vec'] = torch.zeros((1,), device=p.device)
+ if config["percentile_clipping"] < 100:
+ state["gnorm_vec"] = torch.zeros((100,), device=p.device)
+ if config["max_unorm"] > 0.0:
+ state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
@@ -464,29 +646,77 @@ class Optimizer1State(Optimizer8bit):
config = self.get_config(gindex, pindex, group)
- state['step'] += 1
- step = state['step']
+ state["step"] += 1
+ step = state["step"]
- if config['percentile_clipping'] < 100:
- current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ if config["percentile_clipping"] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
+ grad, state["gnorm_vec"], step, config["percentile_clipping"]
+ )
else:
gnorm_scale = 1.0
- if state['state1'].dtype == torch.float:
- F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
- None, 0.0, config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
- skip_zeros=config['skip_zeros'])
-
- elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
- F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None,
- config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
-
- state['max1'], state['new_max1'] = state['new_max1'], state['max1']
- elif state['state1'].dtype == torch.uint8 and config['block_wise']:
- F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], None, state['absmax1'], None,
- config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
+ if state["state1"].dtype == torch.float:
+ F.optimizer_update_32bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ config["betas"][0],
+ config["eps"],
+ step,
+ config["lr"],
+ None,
+ 0.0,
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ skip_zeros=config["skip_zeros"],
+ )
+
+ elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
+ F.optimizer_update_8bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ None,
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ None,
+ state["max1"],
+ None,
+ state["new_max1"],
+ None,
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ )
+
+ state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
+ elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
+ F.optimizer_update_8bit_blockwise(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ None,
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ None,
+ state["absmax1"],
+ None,
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ skip_zeros=config["skip_zeros"],
+ )
diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py
index 0f1ffaa..7ddb12c 100644
--- a/bitsandbytes/optim/rmsprop.py
+++ b/bitsandbytes/optim/rmsprop.py
@@ -1,36 +1,115 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
+
class RMSprop(Optimizer1State):
- def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ alpha=0.99,
+ eps=1e-8,
+ weight_decay=0,
+ momentum=0,
+ centered=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if alpha == 0:
- raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
- raise NotImplementedError(f'Centered RMSprop is not supported!')
- super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"Centered RMSprop is not supported!")
+ super(RMSprop, self).__init__(
+ "rmsprop",
+ params,
+ lr,
+ (alpha, momentum),
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class RMSprop8bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ alpha=0.99,
+ eps=1e-8,
+ weight_decay=0,
+ momentum=0,
+ centered=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if alpha == 0:
- raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
- raise NotImplementedError(f'Centered RMSprop is not supported!')
- super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"Centered RMSprop is not supported!")
+ super(RMSprop8bit, self).__init__(
+ "rmsprop",
+ params,
+ lr,
+ (alpha, momentum),
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class RMSprop32bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ alpha=0.99,
+ eps=1e-8,
+ weight_decay=0,
+ momentum=0,
+ centered=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if alpha == 0:
- raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
- raise NotImplementedError(f'Centered RMSprop is not supported!')
- super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"Centered RMSprop is not supported!")
+ super(RMSprop32bit, self).__init__(
+ "rmsprop",
+ params,
+ lr,
+ (alpha, momentum),
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py
index 0529879..f7b8934 100644
--- a/bitsandbytes/optim/sgd.py
+++ b/bitsandbytes/optim/sgd.py
@@ -1,32 +1,99 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
+
class SGD(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if momentum == 0:
- raise NotImplementedError(f'SGD without momentum is not supported!')
- super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"SGD without momentum is not supported!")
+ super(SGD, self).__init__(
+ "momentum",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class SGD8bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if momentum == 0:
- raise NotImplementedError(f'SGD without momentum is not supported!')
- super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"SGD without momentum is not supported!")
+ super(SGD8bit, self).__init__(
+ "momentum",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class SGD32bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if momentum == 0:
- raise NotImplementedError(f'SGD without momentum is not supported!')
- super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"SGD without momentum is not supported!")
+ super(SGD32bit, self).__init__(
+ "momentum",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
new file mode 100644
index 0000000..4256a87
--- /dev/null
+++ b/bitsandbytes/utils.py
@@ -0,0 +1,32 @@
+import shlex
+import subprocess
+import sys
+from typing import Tuple
+
+
+def execute_and_return(command_string: str) -> Tuple[str, str]:
+ def _decode(subprocess_err_out_tuple):
+ return tuple(
+ to_decode.decode("UTF-8").strip()
+ for to_decode in subprocess_err_out_tuple
+ )
+
+ def execute_and_return_decoded_std_streams(command_string):
+ return _decode(
+ subprocess.Popen(
+ shlex.split(command_string),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ ).communicate()
+ )
+
+ std_out, std_err = execute_and_return_decoded_std_streams(command_string)
+ return std_out, std_err
+
+
+def print_stderr(s: str) -> None:
+ print(s, file=sys.stderr)
+
+
+def warn_of_missing_prerequisite(s: str) -> None:
+ print_stderr("WARNING, missing pre-requisite: " + s)
diff --git a/deploy_from_slurm.sh b/deploy_from_slurm.sh
index 37311bc..c6ee84d 100644
--- a/deploy_from_slurm.sh
+++ b/deploy_from_slurm.sh
@@ -14,256 +14,192 @@ module unload cuda
module unload gcc
rm -rf dist build
-make clean
make cleaneggs
+make cleanlibs
+
+make clean
export CUDA_HOME=
-make cpuonly
+export CUDA_VERSION=
+make cpuonly CUDA_VERSION="CPU"
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cpu.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=cpu python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.0
-make cuda110
+make cuda110 CUDA_VERSION=110
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=110 python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.1
-make cuda11x
+make cuda11x CUDA_VERSION=111
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=111 python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.2
-make cuda11x
+make cuda11x CUDA_VERSION=112
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda112.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=112 python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.3
-make cuda11x
+make cuda11x CUDA_VERSION=113
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda113.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=113 python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.4
-make cuda11x
+make cuda11x CUDA_VERSION=114
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=114 python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.5
-make cuda11x
+make cuda11x CUDA_VERSION=115
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=115 python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.6
-make cuda11x
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+make cuda11x CUDA_VERSION=116
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda116.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=116 python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.7
-make cuda11x
+make cuda11x CUDA_VERSION=117
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=117 python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-10.2
-make cuda10x_nomatmul
+make cuda10x_nomatmul CUDA_VERSION=102
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=102-nomatmul python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.0
-make cuda110_nomatmul
+make cuda110_nomatmul CUDA_VERSION=110
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=110-nomatmul python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.1
-make cuda11x_nomatmul
+make cuda11x_nomatmul CUDA_VERSION=111
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=111-nomatmul python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.2
-make cuda11x_nomatmul
+make cuda11x_nomatmul CUDA_VERSION=112
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda112_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=112-nomatmul python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.3
-make cuda11x_nomatmul
+make cuda11x_nomatmul CUDA_VERSION=113
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda113_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=113-nomatmul python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.4
-make cuda11x_nomatmul
+make cuda11x_nomatmul CUDA_VERSION=114
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=114-nomatmul python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.5
-make cuda11x_nomatmul
+make cuda11x_nomatmul CUDA_VERSION=115
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=115-nomatmul python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.6
-make cuda11x_nomatmul
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+make cuda11x_nomatmul CUDA_VERSION=116
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda116_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=116-nomatmul python -m build
-python -m twine upload dist/* --verbose --repository testpypi
-rm -rf dist build
make clean
-make cleaneggs
export CUDA_HOME=$BASE_PATH/cuda-11.7
-make cuda11x_nomatmul
+make cuda11x_nomatmul CUDA_VERSION=117
-if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
-CUDA_VERSION=117-nomatmul python -m build
+
+python -m build
python -m twine upload dist/* --verbose --repository testpypi
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..6bc6f9a
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,14 @@
+name: 8-bit
+channels:
+ - conda-forge
+dependencies:
+ - python=3.9
+ - pytest
+ - pytorch
+ - torchaudio
+ - torchvision
+ - cudatoolkit=11.1
+ - typer
+ - ca-certificates
+ - certifi
+ - openssl
diff --git a/install_cuda.sh b/install_cuda.sh
deleted file mode 100644
index 6a4ff0c..0000000
--- a/install_cuda.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-wget https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
-bash cuda_11.1.1_455.32.00_linux.run --no-drm --no-man-page --override --installpath=~/local --librarypath=~/local/lib --toolkitpath=~/local/cuda-11.1/ --toolkit --silent
-echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/local/cuda-11.1/lib64/" >> ~/.bashrc
-echo "export PATH=$PATH:~/local/cuda-11.1/bin/" >> ~/.bashrc
-source ~/.bashrc
diff --git a/quicktest.py b/quicktest.py
deleted file mode 100644
index 2db6afa..0000000
--- a/quicktest.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import torch
-import bitsandbytes as bnb
-import bitsandbytes.functional as F
-
-from itertools import product
-
-def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
- k = 25
- for i in range(k):
- if dims == 2:
- A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
- elif dims == 3:
- A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
- C1 = torch.matmul(A.float(), B.t().float())
-
- A2, SA = F.transform(A, 'col32')
- B2, SB = F.transform(B, 'colx')
- if dims == 2:
- C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
- else:
- C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
- F.igemmlt(A2, B2, C2, SA, SB, SC)
- C3, S = F.transform(C2, 'row', state=SC)
- #torch.testing.assert_allclose(C1, C3.float())
- #print(C1)
- #print(C2)
- #print(C3)
- allclose = torch.allclose(C1, C3.float())
- if allclose:
- print(C1)
- print(C2)
- print(C3)
-
- ## transposed
- #A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
- #if dims == 2:
- # B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
- # C1 = torch.matmul(A.float(), B.float().t())
- #elif dims == 3:
- # B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
- # C1 = torch.matmul(B.float(), A.t().float())
- # C1 = C1.permute([2, 0, 1])
-
- #A2, SA = F.transform(A, 'col32')
- #B2, SB = F.transform(B, 'colx')
- #if dims == 2:
- # C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
- #else:
- # C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
- # state = (C2.shape, 'row', A.shape[0])
- # C2, SC = F.transform(C2, 'col32', state=state)
- #F.igemmlt(A2, B2, C2, SA, SB, SC)
- #C3, S = F.transform(C2, 'row', state=SC, ld=[0])
- #torch.testing.assert_allclose(C1, C3.float())
-
- ## weight update
- #if dims == 3:
- # A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
- # B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
- # C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
-
- # A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx')
- # B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32')
- # C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda')
- # C2, SC = F.transform(C2, 'col32')
- # F.igemmlt(B2, A2, C2, SB, SA, SC)
- # C3, S = F.transform(C2, 'row', state=SC)
- # torch.testing.assert_allclose(C1, C3.float())
-
-
-dims = (2, 3)
-ldb = [0]
-
-n = 2
-dim1 = torch.randint(1,256, size=(n,)).tolist()
-dim2 = torch.randint(32,512, size=(n,)).tolist()
-dim3 = torch.randint(32,1024, size=(n,)).tolist()
-dim4 = torch.randint(32,1024, size=(n,)).tolist()
-values = list(product(dim1,dim2,dim3,dim4,dims, ldb))
-
-for ldb in range(32, 4096, 32):
-#for ldb in [None]:
- val = test_igemmlt(2, 2, 2, 2, 2, ldb)
- if val:
- print(val, ldb)
- else:
- print('nope', ldb)
-#for val in values:
- #test_igemmlt(*val)
diff --git a/setup.py b/setup.py
index 6275ddd..cec4982 100644
--- a/setup.py
+++ b/setup.py
@@ -1,21 +1,24 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+import glob
import os
-from setuptools import setup, find_packages
+
+from setuptools import find_packages, setup
+
+libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so"))
+libs = [os.path.basename(p) for p in libs]
+print("libs:", libs)
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
-version = os.getenv("CUDA_VERSION", "cpu")
-prefix = '' if version == 'cpu' else 'cuda'
-
setup(
- name=f"bitsandbytes-{prefix}{version}",
- version=f"0.30.0",
+ name=f"bitsandbytes",
+ version=f"0.31.1",
author="Tim Dettmers",
author_email="dettmers@cs.washington.edu",
description="8-bit optimizers and matrix multiplication routines.",
@@ -23,11 +26,14 @@ setup(
keywords="gpu optimizers optimization 8-bit quantization compression",
url="http://packages.python.org/bitsandbytes",
packages=find_packages(),
- package_data={'': ['libbitsandbytes.so']},
- long_description=read('README.md'),
- long_description_content_type='text/markdown',
+ entry_points={
+ "console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"],
+ },
+ package_data={"": libs},
+ long_description=read("README.md"),
+ long_description_content_type="text/markdown",
classifiers=[
"Development Status :: 4 - Beta",
- 'Topic :: Scientific/Engineering :: Artificial Intelligence'
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index 1b6c2ab..8ebe8c8 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -1,27 +1,44 @@
-import pytest
+from itertools import product
+import pytest
import torch
-import bitsandbytes as bnb
-from itertools import product
+import bitsandbytes as bnb
n = 1
k = 25
-dim1 = torch.randint(16,64, size=(n,)).tolist()
-dim2 = torch.randint(32,96, size=(n,)).tolist()
-dim3 = torch.randint(32,96, size=(n,)).tolist()
-dim4 = torch.randint(32,96, size=(n,)).tolist()
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
-str_funcs = ['bmm', 'matmul']
+str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
-req_grad_str = ['FF', 'TF', 'TT', 'FT']
+req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
-str_transpose = ['FF', 'FT', 'TT', 'TF']
+str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
-values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose))
-str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
-@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
+values = list(
+ product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
+)
+str_values = list(
+ product(
+ dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
+ )
+)
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
+ *vals
+ )
+ for vals in str_values
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
+ values,
+ ids=names,
+)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if dim2 > 0:
dim2 = dim2 - (dim2 % 16)
@@ -33,9 +50,11 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if funcs[0] in [torch.mm, torch.matmul]:
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
- A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0])
- B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
- target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
+ B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
+ target = torch.randn(
+ size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
+ )
torch.nn.init.xavier_uniform_(B)
if not transpose[0] and not transpose[1]:
@@ -53,9 +72,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx==0).sum().item() < n*0.0175
+ assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx==0).sum().item() < n*0.001
+ assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
@@ -67,7 +86,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -75,20 +96,36 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx==0).sum().item() < n*0.1
+ assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx==0).sum().item() < n*0.02
- torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+ assert (idx == 0).sum().item() < n * 0.02
+ torch.testing.assert_allclose(
+ gradB1, gradB2, atol=0.18, rtol=0.3
+ )
# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
- A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
- B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1])
- target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ A = torch.randn(
+ size=(dim1, dim2, dim3),
+ device="cuda",
+ requires_grad=req_grad[0],
+ )
+ B = torch.randn(
+ size=(dim1, dim3, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
+ )
+ target = torch.randn(
+ size=(dim1, dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
+ )
torch.nn.init.xavier_uniform_(B)
out_torch = funcs[0](A, B)
@@ -96,8 +133,10 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx==0).sum().item() < n*0.01
- torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
+ assert (idx == 0).sum().item() < n * 0.01
+ torch.testing.assert_allclose(
+ out_bnb, out_torch, atol=0.027, rtol=0.2
+ )
if any(req_grad):
out_bnb.data.copy_(out_torch)
@@ -109,7 +148,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -117,20 +158,30 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx==0).sum().item() < n*0.1
+ assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx==0).sum().item() < n*0.02
+ assert (idx == 0).sum().item() < n * 0.02
if funcs[0] in [torch.matmul]:
dim1 = dim1 - (dim1 % 16)
- A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
+ A = torch.randn(
+ size=(dim1, dim2, dim3),
+ device="cuda",
+ requires_grad=req_grad[0],
+ )
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
- B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
- target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
+ target = torch.randn(
+ size=(dim1, dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
+ )
torch.nn.init.xavier_uniform_(B)
if transpose[1]:
@@ -142,9 +193,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx==0).sum().item() < n*0.0175
+ assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx==0).sum().item() < n*0.001
+ assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
@@ -156,7 +207,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -164,56 +217,111 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx==0).sum().item() < n*0.1
+ assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx==0).sum().item() < n*0.02
+ assert (idx == 0).sum().item() < n * 0.02
n = 1
k = 3
-dim1 = torch.randint(16,64, size=(n,)).tolist()
-dim2 = torch.randint(32,96, size=(n,)).tolist()
-dim3 = torch.randint(32,96, size=(n,)).tolist()
-dim4 = torch.randint(32,96, size=(n,)).tolist()
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2.append(0)
-#dim1 = (17,)
-#dim2 = (7,)
-#dim3 = (37,)
-#dim4 = (23,)
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
-str_funcs = ['matmul']
+str_funcs = ["matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
-req_grad_str = ['FF', 'TF', 'TT', 'FT']
+req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, True), (False, False)]
-str_transpose = ['NT', 'NN']
+str_transpose = ["NT", "NN"]
dtype = [torch.float16]
has_fp16_weights = [True, False]
-values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose, decomp, has_fp16_weights))
-str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values]
-@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names)
-def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights):
+values = list(
+ product(
+ dim1,
+ dim2,
+ dim3,
+ dim4,
+ funcs,
+ dtype,
+ req_grad,
+ transpose,
+ decomp,
+ has_fp16_weights,
+ )
+)
+str_values = list(
+ product(
+ dim1,
+ dim2,
+ dim3,
+ dim4,
+ str_funcs,
+ dtype,
+ req_grad_str,
+ str_transpose,
+ decomp,
+ has_fp16_weights,
+ )
+)
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
+ *vals
+ )
+ for vals in str_values
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
+ values,
+ ids=names,
+)
+def test_matmullt(
+ dim1,
+ dim2,
+ dim3,
+ dim4,
+ funcs,
+ dtype,
+ req_grad,
+ transpose,
+ decomp,
+ has_fp16_weights,
+):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
- outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda')
+ outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
for i in range(k):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
- A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype)
+ A = torch.randn(
+ size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
+ )
if decomp == 6.0:
with torch.no_grad():
A[:, outlier_dim] = 6.0
- B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype)
- target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype)
+ B = torch.randn(
+ size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
+ )
+ target = torch.randn(
+ size=(dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
+ dtype=dtype,
+ )
torch.nn.init.xavier_uniform_(B)
B2 = B.clone()
@@ -221,8 +329,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
state.threshold = decomp
state.has_fp16_weights = has_fp16_weights
if not has_fp16_weights:
- if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous()
- state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2)
+ if not transpose[0] and not transpose[1]:
+ B2 = B2.t().contiguous()
+ (
+ state.CB,
+ CBt,
+ state.SCB,
+ SCBt,
+ coo_tensorB,
+ ) = bnb.functional.double_quant(B2)
B2 = state.CB
if not transpose[0] and transpose[1]:
@@ -233,25 +348,29 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
out_bnb = funcs[1](A, B2.t(), state=state)
n = out_bnb.numel()
- err = torch.abs(out_bnb-out_torch).mean().item()
- #print(f'abs error {err:.4f}')
+ err = torch.abs(out_bnb - out_torch).mean().item()
+ # print(f'abs error {err:.4f}')
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx==0).sum().item() <= n*0.0175
+ assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx==0).sum().item() <= n*0.001
+ assert (idx == 0).sum().item() < n * 0.001
if has_fp16_weights:
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
- loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb = torch.nn.functional.mse_loss(
+ out_bnb, target
+ ).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -259,7 +378,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
if dim2 > 0:
@@ -269,8 +390,10 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
assert torch.abs(gradB1).sum() == 0.0
assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx==0).sum().item() <= n*0.1
- idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx==0).sum().item() <= n*0.02
- torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+ assert (idx == 0).sum().item() < n * 0.1
+ idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+ assert (idx == 0).sum().item() < n * 0.02
+ torch.testing.assert_allclose(
+ gradB1, gradB2, atol=0.18, rtol=0.3
+ )
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py
new file mode 100644
index 0000000..3d34c29
--- /dev/null
+++ b/tests/test_cuda_setup_evaluator.py
@@ -0,0 +1,150 @@
+import os
+import pytest
+import bitsandbytes as bnb
+
+from typing import List, NamedTuple
+
+from bitsandbytes.cuda_setup import (
+ CUDA_RUNTIME_LIB,
+ evaluate_cuda_setup,
+ determine_cuda_runtime_lib_path,
+ extract_candidate_paths,
+)
+
+"""
+'LD_LIBRARY_PATH': ':/mnt/D/titus/local/cuda-11.1/lib64/'
+'CONDA_EXE': '/mnt/D/titus/miniconda/bin/conda'
+'LESSCLOSE': '/usr/bin/lesspipe %s %s'
+'OLDPWD': '/mnt/D/titus/src'
+'CONDA_PREFIX': '/mnt/D/titus/miniconda/envs/8-bit'
+'SSH_AUTH_SOCK': '/mnt/D/titus/.ssh/ssh-agent.tim-uw.sock'
+'CONDA_PREFIX_1': '/mnt/D/titus/miniconda'
+'PWD': '/mnt/D/titus/src/8-bit'
+'HOME': '/mnt/D/titus'
+'CONDA_PYTHON_EXE': '/mnt/D/titus/miniconda/bin/python'
+'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
+'TMUX': '/tmp/tmux-1007/default,59286,1'
+'XDG_DATA_DIRS': '/usr/local/share:/usr/share:/var/lib/snapd/desktop'
+'SSH_TTY': '/dev/pts/0'
+'MAIL': '/var/mail/titus'
+'SHELL': '/bin/bash'
+'DBUS_SESSION_BUS_ADDRESS': 'unix:path=/run/user/1007/bus'
+'XDG_RUNTIME_DIR': '/run/user/1007'
+'PATH': '/mnt/D/titus/miniconda/envs/8-bit/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/mnt/D/titus/local/cuda-11.1/bin'
+'LESSOPEN': '| /usr/bin/lesspipe %s'
+'_': '/mnt/D/titus/miniconda/envs/8-bit/bin/python'
+# any that include 'CONDA' that are not 'CONDA_PREFIX'
+
+# we search for
+'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
+"""
+
+
+class InputAndExpectedOutput(NamedTuple):
+ input: str
+ output: str
+
+
+HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
+ (
+ f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
+ (
+ f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
+ (
+ f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
+ (
+ f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
+ (
+ f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
+ (
+ f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
+]
+
+
+@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
+def happy_path_path_string(tmpdir, request):
+ for path in extract_candidate_paths(request.param):
+ test_dir.mkdir()
+ if CUDA_RUNTIME_LIB in path:
+ (test_input / CUDA_RUNTIME_LIB).touch()
+
+
+@pytest.mark.parametrize("test_input, expected", HAPPY_PATH__LD_LIB_TEST_PATHS)
+def test_determine_cuda_runtime_lib_path__happy_path(
+ tmp_path, test_input: str, expected: str
+):
+ for path in extract_candidate_paths(test_input):
+ path.mkdir()
+ (path / CUDA_RUNTIME_LIB).touch()
+ assert determine_cuda_runtime_lib_path(test_input) == expected
+
+
+UNHAPPY_PATH__LD_LIB_TEST_PATHS = [
+ f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}",
+ f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}:g/h/j/{CUDA_RUNTIME_LIB}",
+]
+
+
+@pytest.mark.parametrize("test_input", UNHAPPY_PATH__LD_LIB_TEST_PATHS)
+def test_determine_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str):
+ test_input = tmp_path / test_input
+ (test_input / CUDA_RUNTIME_LIB).touch()
+ with pytest.raises(FileNotFoundError) as err_info:
+ determine_cuda_runtime_lib_path(test_input)
+ assert all(match in err_info for match in {"duplicate", CUDA_RUNTIME_LIB})
+
+
+def test_determine_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
+ existent_dir = tmp_path / "a/b"
+ existent_dir.mkdir()
+ non_existent_dir = tmp_path / "c/d" # non-existent dir
+ test_input = ":".join([str(existent_dir), str(non_existent_dir)])
+
+ determine_cuda_runtime_lib_path(test_input)
+ std_err = capsys.readouterr().err
+
+ assert all(match in std_err for match in {"WARNING", "non-existent"})
+
+
+def test_full_system():
+ ## this only tests the cuda version and not compute capability
+
+ # if CONDA_PREFIX exists, it has priority before all other env variables
+ # but it does not contain the library directly, so we need to look at the a sub-folder
+ version = ""
+ if "CONDA_PREFIX" in os.environ:
+ ls_output, err = bnb.utils.execute_and_return(
+ f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so'
+ )
+ major, minor, revision = (
+ ls_output.split(" ")[-1].replace("libcudart.so.", "").split(".")
+ )
+ version = float(f"{major}.{minor}")
+
+ if version == "" and "LD_LIBRARY_PATH":
+ ld_path = os.environ["LD_LIBRARY_PATH"]
+ paths = ld_path.split(":")
+ version = ""
+ for p in paths:
+ if "cuda" in p:
+ idx = p.rfind("cuda-")
+ version = p[idx + 5 : idx + 5 + 4].replace("/", "")
+ version = float(version)
+ break
+
+ assert version > 0
+ binary_name = evaluate_cuda_setup()
+ binary_name = binary_name.replace("libbitsandbytes_cuda", "")
+ assert binary_name.startswith(str(version).replace(".", ""))
diff --git a/tests/test_functional.py b/tests/test_functional.py
index bfc3e28..ab7d672 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,25 +1,29 @@
-import pytest
import math
import random
import time
-import torch
-import bitsandbytes as bnb
-import einops
-
from itertools import product
+import einops
+import pytest
+import torch
+
+import bitsandbytes as bnb
from bitsandbytes import functional as F
-torch.set_printoptions(precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
+torch.set_printoptions(
+ precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
+)
k = 20
+
def assert_all_approx_close(a, b, rtol, atol, count):
idx = torch.isclose(a, b, rtol, atol)
- sumval = (idx==0).sum().item()
+ sumval = (idx == 0).sum().item()
if sumval > count:
- print(f'Too many values not close: assert {sumval} < {count}')
+ print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_allclose(a, b, rtol, atol)
+
class FFN(torch.nn.Module):
def __init__(self, input_features, hidden_size, bias=True):
super(FFN, self).__init__()
@@ -35,13 +39,14 @@ class FFN(torch.nn.Module):
x = self.fc2(x)
return x
+
class Timer(object):
def __init__(self):
self.starts = {}
self.ends = {}
self.agg = {}
- def tick(self, name='default'):
+ def tick(self, name="default"):
if name not in self.starts:
self.starts[name] = torch.cuda.Event(enable_timing=True)
self.ends[name] = torch.cuda.Event(enable_timing=True)
@@ -49,66 +54,72 @@ class Timer(object):
else:
ms = self.tock(name, evict=True, print_ms=False)
- def tock(self, name='default', evict=True, print_ms=True):
+ def tock(self, name="default", evict=True, print_ms=True):
if name in self.ends:
self.ends[name].record()
torch.cuda.synchronize()
ms = self.starts[name].elapsed_time(self.ends[name])
- if name not in self.agg: self.agg[name] = 0.0
+ if name not in self.agg:
+ self.agg[name] = 0.0
self.agg[name] += ms
if evict:
self.starts.pop(name)
self.ends.pop(name)
if print_ms and name in self.agg:
- print('{0} took: {1:.5f}s'.format(name, self.agg[name]/1000.0))
+ print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))
return self.agg[name]
def reset(self):
- self.starts = {}
+ self.starts = {}
self.ends = {}
self.agg = {}
- print('Resetting benchmark data')
+ print("Resetting benchmark data")
+
def setup():
pass
+
def teardown():
pass
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['float', 'half'])
+
+@pytest.mark.parametrize(
+ "dtype", [torch.float32, torch.float16], ids=["float", "half"]
+)
def test_estimate_quantiles(dtype):
- A = torch.rand(1024, 1024, device='cuda')
+ A = torch.rand(1024, 1024, device="cuda")
A = A.to(dtype)
code = F.estimate_quantiles(A)
- percs = torch.linspace(1/512, 511/512, 256, device=A.device)
+ percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)
- A = torch.randn(1024, 1024, device='cuda')
+ A = torch.randn(1024, 1024, device="cuda")
A = A.to(dtype)
code = F.estimate_quantiles(A)
quantiles = torch.quantile(A.float(), percs)
- diff = torch.abs(code-quantiles)
+ diff = torch.abs(code - quantiles)
assert (diff > 5e-02).sum().item() == 0
def test_quantile_quantization():
for i in range(100):
- A1 = torch.randn(1024, 1024, device='cuda')
+ A1 = torch.randn(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
- diff = torch.abs(A1-A2).mean().item()
+ diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0075
- A1 = torch.rand(1024, 1024, device='cuda')
+ A1 = torch.rand(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
- diff = torch.abs(A1-A2).mean().item()
+ diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001
@@ -117,22 +128,22 @@ def test_dynamic_quantization():
diffs = []
reldiffs = []
for i in range(100):
- A1 = torch.randn(1024, 1024, device='cuda')
+ A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
- diff = torch.abs(A1-A2)
- reldiff = diff/torch.abs(A1+1e-8)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135
- #print(sum(diffs)/len(diffs))
- #print(sum(reldiffs)/len(reldiffs))
+ # print(sum(diffs)/len(diffs))
+ # print(sum(reldiffs)/len(reldiffs))
for i in range(100):
- A1 = torch.rand(1024, 1024, device='cuda')
+ A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
- diff = torch.abs(A1-A2).mean().item()
+ diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
assert diff < 0.004
@@ -141,56 +152,62 @@ def test_dynamic_blockwise_quantization():
diffs = []
reldiffs = []
for i in range(100):
- A1 = torch.randn(1024, 1024, device='cuda')
+ A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
- diff = torch.abs(A1-A2)
- reldiff = diff/torch.abs(A1+1e-8)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diffs[-1] < 0.011
- #print(sum(diffs)/len(diffs))
- #print(sum(reldiffs)/len(reldiffs))
+ # print(sum(diffs)/len(diffs))
+ # print(sum(reldiffs)/len(reldiffs))
diffs = []
for i in range(100):
- A1 = torch.rand(1024, 1024, device='cuda')
+ A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
- diff = torch.abs(A1-A2).mean().item()
+ diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0033
diffs.append(diff)
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
- #print(sum(diffs)/len(diffs))
+ # print(sum(diffs)/len(diffs))
+
def test_dynamic_blockwise_stochastic_quantization():
diffs = []
reldiffs = []
rand = torch.rand(1024).cuda()
for i in range(100):
- A1 = torch.randn(1024, 1024, device='cuda')
+ A1 = torch.randn(1024, 1024, device="cuda")
C1, S1 = F.quantize_blockwise(A1, rand=rand)
C2, S2 = F.quantize_blockwise(A1)
# a maximunm distance of quantized values of 1
torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
- fraction_smaller = (C1<C2).float().sum()/C1.numel()
- fraction_larger = (C1>C2).float().sum()/C1.numel()
- torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0)
+ fraction_smaller = (C1 < C2).float().sum() / C1.numel()
+ fraction_larger = (C1 > C2).float().sum() / C1.numel()
+ torch.testing.assert_allclose(
+ fraction_larger, fraction_smaller, atol=0.01, rtol=0
+ )
-
-@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=['float', 'half'])
+@pytest.mark.parametrize(
+ "gtype", [torch.float32, torch.float16], ids=["float", "half"]
+)
def test_percentile_clipping(gtype):
- gnorm_vec1 = torch.zeros(100, device='cuda')
- gnorm_vec2 = torch.zeros(100, device='cuda')
+ gnorm_vec1 = torch.zeros(100, device="cuda")
+ gnorm_vec2 = torch.zeros(100, device="cuda")
n = 4
step = 0
- percentile=5
+ percentile = 5
for i in range(k):
step += 1
- g = torch.randn(n, n, dtype=gtype, device='cuda')
- gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
- assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2/gnorm1
+ g = torch.randn(n, n, dtype=gtype, device="cuda")
+ gnorm1, clip2, gnorm_scale = F.percentile_clipping(
+ g, gnorm_vec2, step, percentile=percentile
+ )
+ assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1
gnorm2 = torch.norm(g.float())
if step == 1:
@@ -208,74 +225,98 @@ def test_percentile_clipping(gtype):
def quant(x):
max1 = torch.abs(x).max()
- x = torch.round(x/max1*127)
+ x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)
+
def dequant(c, maxC):
- return c.float()*(maxC/127)
+ return c.float() * (maxC / 127)
+
def mm_dequant(maxA, maxB, C):
- return C.float()*(maxA/127)*(maxB/127)
+ return C.float() * (maxA / 127) * (maxB / 127)
+
def quant_multi(x, dim):
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
- max1[max1==0] = 1.0
- x = torch.round(x/max1*127)
+ max1[max1 == 0] = 1.0
+ x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)
+
def quant_multi_chunk(x, dim, chunk_size=32):
- if dim==1:
- x_chunked = einops.rearrange(x, '(c a) b -> c a b', c=chunk_size)
- max1 = torch.amax(torch.abs(x_chunked), dim=dim+1, keepdim=True)
+ if dim == 1:
+ x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
+ max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
max1 = torch.tile(max1, (1, 1, x.shape[1]))
max1 = max1.view(x.shape)
- elif dim==0:
- x_chunked = einops.rearrange(x, 'a (b c) -> a b c', c=chunk_size)
+ elif dim == 0:
+ x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
max1 = torch.tile(max1, (x.shape[0], 1, 1))
max1 = max1.view(x.shape)
- max1[max1==0] = 1.0
- x = torch.round(x/max1*127)
+ max1[max1 == 0] = 1.0
+ x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)
+
def quant_minmax(A):
minA = A.min()
maxA = A.max()
-def mean(xx):
- return sum(xx)/float(len(xx))
-#dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
-#dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
-dim1 = [1024*2]
-dim2 = [1024*16]
-methods = [(lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)]
+def mean(xx):
+ return sum(xx) / float(len(xx))
+
+
+# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
+# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
+dim1 = [1024 * 2]
+dim2 = [1024 * 16]
+methods = [
+ (
+ lambda x, dim: quant(x),
+ lambda x, dim: quant(x),
+ dequant,
+ dequant,
+ mm_dequant,
+ )
+]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
-#methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
-method_names = ['linear', 'vectorwise']
+# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
+method_names = ["linear", "vectorwise"]
batched = [False, True]
-values = list(product(dim1,dim2, methods, batched))
-values_names = list(product(dim1,dim2, method_names, batched))
-names = ['dim1_{0}_dim2_{1}_quant_{2}_batched_{3}'.format(*vals) for vals in values_names]
-@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
+values = list(product(dim1, dim2, methods, batched))
+values_names = list(product(dim1, dim2, method_names, batched))
+names = [
+ "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
+ for vals in values_names
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, quant_methods, batched", values, ids=names
+)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim1 = dim1 - (dim1 % 32)
dim2 = dim2 - (dim2 % 32)
errors = []
relerrors = []
- print('')
+ print("")
for i in range(5):
if batched:
- A = torch.normal(0, 0.5, size=(32, dim1, dim2//32), device='cuda')
- B = torch.normal(0, 0.5, size=(32, dim2//32, dim1), device='cuda')
+ A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
+ B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 2)
maxB, Bc = quant_methods[1](B, 1)
else:
- A = torch.normal(0, 0.5, size=(dim1, dim2), device='cuda')
- B = torch.normal(0, 0.5, size=(dim2, dim1), device='cuda')
+ A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
+ B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 1)
maxB, Bc = quant_methods[1](B, 0)
- torch.testing.assert_allclose(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
+ torch.testing.assert_allclose(
+ quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
+ )
if batched:
out2 = torch.bmm(A, B)
C = torch.bmm(Ac.float(), Bc.float())
@@ -284,43 +325,53 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
C = F.igemm(Ac, Bc)
out = quant_methods[4](maxA, maxB, C)
std = out2.std()
- out/= std
- out2/= std
- err = torch.abs(out-out2)
- relerr = err/torch.abs(out2)
+ out /= std
+ out2 /= std
+ err = torch.abs(out - out2)
+ relerr = err / torch.abs(out2)
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
print(mean(errors))
print(mean(relerrors))
-
-
-
-
def test_stable_embedding():
layer = bnb.nn.StableEmbedding(1024, 1024)
layer.reset_parameters()
-
n = 2
-hidden_dim = torch.randint(32,256, size=(n,)).tolist()
-batch_dim = torch.randint(16,256, size=(n,)).tolist()
-seq_dim = torch.randint(16,256, size=(n,)).tolist()
+hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
+batch_dim = torch.randint(16, 256, size=(n,)).tolist()
+seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
-values = list(product(hidden_dim,batch_dim, transpose, seq_dim))
-names = ['hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}'.format(*vals) for vals in values]
-@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
+values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
+names = [
+ "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
+ for vals in values
+]
+
+
+@pytest.mark.parametrize(
+ "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
+)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16)
for i in range(k):
- shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
- shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
- A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+ shapeA = (
+ (batch_dim, hidden_dim)
+ if not transpose[0]
+ else (hidden_dim, batch_dim)
+ )
+ shapeB = (
+ (32 * random.randint(1, 4), hidden_dim)
+ if transpose[1]
+ else (hidden_dim, 32 * random.randint(1, 4))
+ )
+ A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
@@ -338,9 +389,13 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim)
- shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
- A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+ shapeB = (
+ (32 * random.randint(1, 4), hidden_dim)
+ if transpose[1]
+ else (hidden_dim, 32 * random.randint(1, 4))
+ )
+ A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
@@ -352,40 +407,57 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
n = 3
-seq_dim = torch.randint(32,512, size=(n,)).tolist()
-hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist()
-batch_dim = torch.randint(2,16, size=(n,)).tolist()
-values = list(product(seq_dim,hidden_dim,batch_dim))
-names = ['seq_dim{0}_hidden_dim{1}_batch_dim{2}'.format(*vals) for vals in values]
+seq_dim = torch.randint(32, 512, size=(n,)).tolist()
+hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
+batch_dim = torch.randint(2, 16, size=(n,)).tolist()
+values = list(product(seq_dim, hidden_dim, batch_dim))
+names = [
+ "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
+]
+
+
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
seq_dim = seq_dim - (seq_dim % 32)
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 2)
for i in range(25):
- A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device='cuda').to(torch.int8)
- out2 = torch.einsum('bsi, bso->io', A.float(), B.float())
- iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
+ A = torch.randint(
+ -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
+ ).to(torch.int8)
+ B = torch.randint(
+ -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
+ ).to(torch.int8)
+ out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
+ iout = torch.empty(
+ A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
+ )
out = F.igemm(A, B, out=iout)
torch.testing.assert_allclose(out.float(), out2)
+
n = 2
-seq_dim = torch.randint(32,512, size=(n,)).tolist()
-hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist()
-batch_dim = torch.randint(2,16, size=(n,)).tolist()
+seq_dim = torch.randint(32, 512, size=(n,)).tolist()
+hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
+batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
-values = list(product(seq_dim,hidden_dim,batch_dim, transpose))
-names = ['seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}'.format(*vals) for vals in values]
-@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
-def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
+values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
+names = [
+ "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
+ for vals in values
+]
+
+@pytest.mark.parametrize(
+ "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
+)
+def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True)
minA = torch.amin(x, dim=2, keepdim=True)
- scale = (maxA-minA)/2.0
- return (127*(x-minA-scale)/scale).to(torch.int8), minA, scale
+ scale = (maxA - minA) / 2.0
+ return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
seq_dim = seq_dim - (seq_dim % 16)
hidden_dim = hidden_dim - (hidden_dim % 16)
@@ -395,30 +467,32 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
errs2 = []
relerrs2 = []
for i in range(k):
- A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device='cuda')
+ A = torch.normal(
+ 0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
+ )
if transpose:
- B = torch.normal(0, 0.5, size=(256, hidden_dim), device='cuda')
+ B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
else:
- B = torch.normal(0, 0.5, size=(hidden_dim, 256), device='cuda')
+ B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
Ac, minA, scale = min_max(A)
if transpose:
maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
out = F.igemm(Ac, Bc.t())
- out2 = torch.matmul(A,B.t())
- offset = B.t().sum(0)*(minA+scale)
+ out2 = torch.matmul(A, B.t())
+ offset = B.t().sum(0) * (minA + scale)
out = out.float()
- out = (out*maxB.t()*scale/(127*127))+offset
+ out = (out * maxB.t() * scale / (127 * 127)) + offset
maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc.t())
out3 = mm_dequant(maxA, maxB.t(), out3)
else:
maxB, Bc = quant_multi(B, dim=0)
- offset = B.sum(0)*(minA+scale)
+ offset = B.sum(0) * (minA + scale)
out = F.igemm(Ac, Bc)
- out2 = torch.matmul(A,B)
+ out2 = torch.matmul(A, B)
out = out.float()
- out = (out*maxB*scale/(127*127))+offset
+ out = (out * maxB * scale / (127 * 127)) + offset
maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc)
@@ -429,31 +503,37 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
out /= std
out3 /= std
- err = torch.abs(out-out2)
- relerr = err/(torch.abs(out2)+1e-7)
+ err = torch.abs(out - out2)
+ relerr = err / (torch.abs(out2) + 1e-7)
- err2 = torch.abs(out3-out2)
- relerr2 = err2/(torch.abs(out2)+1e-7)
+ err2 = torch.abs(out3 - out2)
+ relerr2 = err2 / (torch.abs(out2) + 1e-7)
errs.append(err.mean().item())
relerrs.append(relerr.mean().item())
errs2.append(err2.mean().item())
relerrs2.append(relerr2.mean().item())
- #print(mean(errs))
- #print(mean(relerrs))
- #print(mean(errs2))
- #print(mean(relerrs2))
+ # print(mean(errs))
+ # print(mean(relerrs))
+ # print(mean(errs2))
+ # print(mean(relerrs2))
assert mean(errs) < 0.015
assert mean(relerrs) < 0.3
+
n = 2
-dim1 = torch.randint(1,64, size=(n,)).tolist()
-dim2 = torch.randint(32,128, size=(n,)).tolist()
-dim3 = torch.randint(32,256, size=(n,)).tolist()
-dim4 = torch.randint(32,256, size=(n,)).tolist()
+dim1 = torch.randint(1, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 128, size=(n,)).tolist()
+dim3 = torch.randint(32, 256, size=(n,)).tolist()
+dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
-values = list(product(dim1,dim2,dim3,dim4,transpose))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}'.format(*vals) for vals in values]
+values = list(product(dim1, dim2, dim3, dim4, transpose))
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
+ for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
dim2 = dim2 - (dim2 % 16)
@@ -462,8 +542,8 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
for i in range(k):
shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
- A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+ A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.bmm(A.float(), B.float())
@@ -475,150 +555,203 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
out = F.igemm(A.permute([0, 2, 1]), B)
elif transpose[0] and transpose[1]:
- out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
+ out2 = torch.bmm(
+ A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
+ )
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
torch.testing.assert_allclose(out.float(), out2.float())
+
n = 1
-dim1 = torch.randint(1,64, size=(n,)).tolist()
-dim2 = torch.randint(32,128, size=(n,)).tolist()
-dim3 = torch.randint(32,256, size=(n,)).tolist()
-values = list(product(dim1,dim2,dim3))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}'.format(*vals) for vals in values]
+dim1 = torch.randint(1, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 128, size=(n,)).tolist()
+dim3 = torch.randint(32, 256, size=(n,)).tolist()
+values = list(product(dim1, dim2, dim3))
+names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
for i in range(k):
- A = torch.randn(size=(dim2, dim3), device='cuda')
+ A = torch.randn(size=(dim2, dim3), device="cuda")
qA, SA = F.vectorwise_quant(A, dim=0)
A1 = F.vectorwise_dequant(qA, SA)
torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)
-
n = 2
-dim1 = torch.randint(2,256, size=(n,)).tolist()
-dim2 = torch.randint(2,256, size=(n,)).tolist()
-dim3 = torch.randint(2,256, size=(n,)).tolist()
-#dim1, dim2 = (256,), (256,)
+dim1 = torch.randint(2, 256, size=(n,)).tolist()
+dim2 = torch.randint(2, 256, size=(n,)).tolist()
+dim3 = torch.randint(2, 256, size=(n,)).tolist()
+# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
-a_order = ['row']
-out_order = ['col', 'row', 'col32']
+a_order = ["row"]
+out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
-values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose))
-
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}'.format(*vals) for vals in values]
-@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names)
-def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
- if dims == 3 and out_order != 'col32': return
- if dtype == torch.int32 and out_order != 'col32': return
+values = list(
+ product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
+)
+
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(
+ *vals
+ )
+ for vals in values
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
+ values,
+ ids=names,
+)
+def test_nvidia_transform(
+ dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose
+):
+ if dims == 3 and out_order != "col32":
+ return
+ if dtype == torch.int32 and out_order != "col32":
+ return
func = F.get_transform_func(dtype, orderA, orderOut, transpose)
if dims == 2:
- A = torch.randint(-128, 127, size=(dim1, dim2), device='cuda').to(dtype)
+ A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3:
- A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(dtype)
+ A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
+ dtype
+ )
out, S = F.nvidia_transform(A, to_order=orderOut)
- if orderOut == 'row':
+ if orderOut == "row":
torch.testing.assert_allclose(A.flatten(), out.flatten())
- elif orderOut == 'col':
+ elif orderOut == "col":
torch.testing.assert_allclose(A.t().flatten(), out.flatten())
- elif orderOut == 'col32':
+ elif orderOut == "col32":
if dims == 2:
- n = A.shape[0]*(A.shape[1] + (32 - (A.shape[1]%32)))
+ n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
elif dims == 3:
- n = A.shape[0]*A.shape[1]*(A.shape[2] + (32 - (A.shape[2]%32)))
+ n = (
+ A.shape[0]
+ * A.shape[1]
+ * (A.shape[2] + (32 - (A.shape[2] % 32)))
+ )
assert out.numel() == n
- elif orderOut == 'col_turing':
+ elif orderOut == "col_turing":
# 32 col 8 row tiles
- n = (A.shape[0]+(8- A.shape[0]%8))*(A.shape[1] + (32 - (A.shape[1]%32)))
+ n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
+ A.shape[1] + (32 - (A.shape[1] % 32))
+ )
assert out.numel() == n
total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
for row in range(A.shape[0]):
for col in range(A.shape[1]):
- i = row*A.shape[1]
+ i = row * A.shape[1]
j = col
coltile = (col // 32) + (1 if col % 32 != 0 else 0)
- rowtile = ((row // 8) + (1 if row % 8 != 0 else 0))*total_coltile
- offset = 32*8*(rowtile+coltile)
+ rowtile = (
+ (row // 8) + (1 if row % 8 != 0 else 0)
+ ) * total_coltile
+ offset = 32 * 8 * (rowtile + coltile)
col2 = col % 32
- row2 = (row%8)*32
+ row2 = (row % 8) * 32
+ assert A.flatten()[i + j] == A[row, col]
+ # assert A.flatten()[i+j] == out.flatten()[row2+col2]
+ # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
+ # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
- assert A.flatten()[i+j] == A[row, col]
- #assert A.flatten()[i+j] == out.flatten()[row2+col2]
- #torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
- #torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
-
- if orderOut == 'col32':
- out2, S = F.nvidia_transform(out, from_order=orderOut, to_order='row', state=S)
+ if orderOut == "col32":
+ out2, S = F.nvidia_transform(
+ out, from_order=orderOut, to_order="row", state=S
+ )
torch.testing.assert_allclose(A, out2)
n = 1
-dim1 = torch.randint(1,256, size=(n,)).tolist()
-dim2 = torch.randint(32,512, size=(n,)).tolist()
-dim3 = torch.randint(32,1024, size=(n,)).tolist()
-dim4 = torch.randint(32,1024, size=(n,)).tolist()
+dim1 = torch.randint(1, 256, size=(n,)).tolist()
+dim2 = torch.randint(32, 512, size=(n,)).tolist()
+dim3 = torch.randint(32, 1024, size=(n,)).tolist()
+dim4 = torch.randint(32, 1024, size=(n,)).tolist()
-#dim1 = [2]
-#dim2 = [2]
-#dim3 = [2]
-#dim4 = [2]
+# dim1 = [2]
+# dim2 = [2]
+# dim3 = [2]
+# dim4 = [2]
-dims = (2,3)
+dims = (2, 3)
ldb = [0]
-#ldb = list(range(256, 1*1024, 256))
-values = list(product(dim1,dim2,dim3,dim4,dims, ldb))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}'.format(*vals) for vals in values]
+# ldb = list(range(256, 1*1024, 256))
+values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
+ for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
for i in range(k):
if dims == 2:
- A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
+ A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
+ torch.int8
+ )
elif dims == 3:
- A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+ A = torch.randint(
+ -128, 127, size=(dim1, dim2, dim3), device="cuda"
+ ).to(torch.int8)
+ B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
+ torch.int8
+ )
C1 = torch.matmul(A.float(), B.t().float())
- A2, SA = F.transform(A, 'col32')
- B2, SB = F.transform(B, 'col_turing')
+ A2, SA = F.transform(A, "col32")
+ B2, SB = F.transform(B, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB)
- C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_allclose(C1, C3.float())
# transpose
- B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
+ torch.int8
+ )
C1 = torch.matmul(A.float(), B.float())
- B2t, SBt = F.transform(B, 'col_turing', transpose=True)
+ B2t, SBt = F.transform(B, "col_turing", transpose=True)
C2, SC = F.igemmlt(A2, B2t, SA, SBt)
- C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_allclose(C1, C3.float())
+
dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]
dims = (2,)
-#ldb = list(range(256, 1*1024, 256))
-values = list(product(dim1,dim2,dim3,dim4,dims))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}'.format(*vals) for vals in values]
+# ldb = list(range(256, 1*1024, 256))
+values = list(product(dim1, dim2, dim3, dim4, dims))
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
+ for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
formatB = F.get_special_format_str()
for i in range(k):
if dims == 2:
- A = torch.normal(0, 0.5, size=(dim1, dim3), device='cuda').half()
+ A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
elif dims == 3:
- A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device='cuda').half()
- B = torch.randn((dim4, dim3), device='cuda').half()
+ A = torch.normal(
+ 0, 0.5, size=(dim1, dim2, dim3), device="cuda"
+ ).half()
+ B = torch.randn((dim4, dim3), device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
C2 = bnb.matmul(A, B.t())
@@ -627,50 +760,58 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
- C32A, SA = F.transform(CA, 'col32')
+ C32A, SA = F.transform(CA, "col32")
CxB, SB = F.transform(CB, to_order=formatB)
out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)
- #print('')
- #print(output.flatten()[:10])
- #print(C1.flatten()[:10])
- #print(C2.flatten()[:10])
+ # print('')
+ # print(output.flatten()[:10])
+ # print(C1.flatten()[:10])
+ # print(C2.flatten()[:10])
-
- #torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
+ # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
# transpose
- #B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
- #C1 = torch.matmul(A.float(), B.float())
+ # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
+ # C1 = torch.matmul(A.float(), B.float())
+
+ # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
+ # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
+ # C3, S = F.transform(C2, 'row', state=SC)
+ # torch.testing.assert_allclose(C1, C3.float())
- #B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
- #C2, SC = F.igemmlt(A2, B2t, SA, SBt)
- #C3, S = F.transform(C2, 'row', state=SC)
- #torch.testing.assert_allclose(C1, C3.float())
batch_size = 2
seqdim = 512
-#values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
-values = [(batch_size, seqdim, 4*1024, 3*4*1024),(batch_size, seqdim, 5120, 3*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
+# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
+values = [
+ (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
+ (batch_size, seqdim, 5120, 3 * 5120),
+ (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
+]
+
+
+# values = list(product(batch, seq, model, hidden))
+names = [
+ "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+]
-#values = list(product(batch, seq, model, hidden))
-names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
formatB = F.get_special_format_str()
- A = torch.randn(batch, seq, model, device='cuda').half()
- grad = torch.randn(batch, seq, model, device='cuda').half()
- w1 = torch.randint(-128, 127, size=(hidden, model), device='cuda').half()
- w2 = torch.randint(-128, 127, size=(model, hidden), device='cuda').half()
- print('')
+ A = torch.randn(batch, seq, model, device="cuda").half()
+ grad = torch.randn(batch, seq, model, device="cuda").half()
+ w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
+ w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
+ print("")
- #torch.cuda.synchronize()
+ # torch.cuda.synchronize()
## warmup
- #for i in range(100):
+ # for i in range(100):
# torch.matmul(A, w1.t())
- #torch.cuda.synchronize()
+ # torch.cuda.synchronize()
dtype = torch.int8
A = A.view(-1, A.shape[-1]).contiguous()
@@ -679,77 +820,77 @@ def test_bench_8bit_training(batch, seq, model, hidden):
t0 = time.time()
for i in range(k):
- out1 = torch.matmul(A, w1.t()) # fc1
- #out2 = torch.matmul(out1, w2.t())# fc2
+ out1 = torch.matmul(A, w1.t()) # fc1
+ # out2 = torch.matmul(out1, w2.t())# fc2
- #d1 = torch.matmul(grad, w2) # delta1
- #d2 = torch.matmul(d1, w1) # delta2
+ # d1 = torch.matmul(grad, w2) # delta1
+ # d2 = torch.matmul(d1, w1) # delta2
- #grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
- #grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
+ # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
+ # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
torch.cuda.synchronize()
t16 = time.time() - t0
print(t16)
- #torch.cuda.empty_cache()
+ # torch.cuda.empty_cache()
- #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
- #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
+ # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
+ # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
- #CTw1, Sw1 = F.transform2(Cw1, formatB)
- #CTw2, Sw2 = F.transform2(Cw2, formatB)
- #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
- #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
+ # CTw1, Sw1 = F.transform2(Cw1, formatB)
+ # CTw2, Sw2 = F.transform2(Cw2, formatB)
+ # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
+ # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
- #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
- #C32A, SA = F.transform2(CA, 'col32')
+ # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
+ # C32A, SA = F.transform2(CA, 'col32')
## fc1
- #out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
+ # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
## fc2
- #Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
- #C32out1, Sout1 = F.transform2(Cout1, 'col32')
- #out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
+ # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
+ # C32out1, Sout1 = F.transform2(Cout1, 'col32')
+ # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
## delta1
- #Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
- #C32grad, Sgrad = F.transform2(Cgrad, 'col32')
+ # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
+ # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
## delta2
- #Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
- #C32d1, Sd1 = F.transform2(Cd1, 'col32')
+ # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
+ # C32d1, Sd1 = F.transform2(Cd1, 'col32')
##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
## grad1
- #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
- #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
+ # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
+ # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
## grad2
- #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
- #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
+ # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
+ # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
- #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
+ # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
- #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
- #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
+ # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
+ # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
- #CTw1, Sw1 = F.transform2(Cw1, formatB)
- #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
- #CTw2, Sw2 = F.transform2(Cw2, formatB)
- #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
- #torch.cuda.synchronize()
- #t0 = time.time()
- #for i in range(k):
+ # CTw1, Sw1 = F.transform2(Cw1, formatB)
+ # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
+ # CTw2, Sw2 = F.transform2(Cw2, formatB)
+ # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
+ # torch.cuda.synchronize()
+ # t0 = time.time()
+ # for i in range(k):
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
@@ -802,74 +943,78 @@ def test_bench_8bit_training(batch, seq, model, hidden):
# #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
# #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
- #torch.cuda.synchronize()
- #t8 = time.time() - t0
- #print(t8)
-
-
-
+ # torch.cuda.synchronize()
+ # t8 = time.time() - t0
+ # print(t8)
n = 2
-dim1 = torch.randint(64,256, size=(n,)).tolist()
-dim4 = torch.randint(64,1024, size=(n,)).tolist()
+dim1 = torch.randint(64, 256, size=(n,)).tolist()
+dim4 = torch.randint(64, 1024, size=(n,)).tolist()
-#dim1 = [2*1024]
-#dim4 = [2*1024]
+# dim1 = [2*1024]
+# dim4 = [2*1024]
-#dim1 = [4]
-#dim4 = [4]
+# dim1 = [4]
+# dim4 = [4]
dims = (2,)
-#ldb = list(range(256, 1*1024, 256))
-formatB = ['col_turing', 'col_ampere']
-values = list(product(dim1,dim4,dims, formatB))
-names = ['dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}'.format(*vals) for vals in values]
+# ldb = list(range(256, 1*1024, 256))
+formatB = ["col_turing", "col_ampere"]
+values = list(product(dim1, dim4, dims, formatB))
+names = [
+ "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB):
inner = torch.randint(1, 128, size=(1,)).item()
formatB = F.get_special_format_str()
for i in range(k):
- A = torch.randn(dim1, inner, device='cuda')
- B = torch.randn(dim4, inner, device='cuda')
+ A = torch.randn(dim1, inner, device="cuda")
+ B = torch.randn(dim4, inner, device="cuda")
C1 = torch.matmul(A.half(), B.t().half())
A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
- A2, SA = F.nvidia_transform(A1, 'col32')
+ A2, SA = F.nvidia_transform(A1, "col32")
B2, SB = F.nvidia_transform(B1, formatB)
C2, SC = F.igemmlt(A2, B2, SA, SB)
- C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ C3, S = F.nvidia_transform(C2, "row", state=SC)
C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
n = C1.numel()
p = 0.06
- assert count/n < p, f'error in more than {p} of elements: {count}/{n}={count/n}'
+ assert (
+ count / n < p
+ ), f"error in more than {p} of elements: {count}/{n}={count/n}"
C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
torch.testing.assert_allclose(C5, C4)
- #print(C2)
-
+ # print(C2)
n = 2
-dim1 = [1*1024]
-dim2 = [1*1024]
-#dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
-#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = [1 * 1024]
+dim2 = [1 * 1024]
+# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
+# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dims = (2,)
-#ldb = list(range(256, 1*1024, 256))
-values = list(product(dim1,dim2,dims))
-names = ['dim1_{0}_dim2_{1}_dims_{2}'.format(*vals) for vals in values]
+# ldb = list(range(256, 1*1024, 256))
+values = list(product(dim1, dim2, dims))
+names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
for i in range(k):
threshold = 3.0
- A = torch.randn(dim1, dim2, device='cuda').half()
+ A = torch.randn(dim1, dim2, device="cuda").half()
A_truncated = A.clone()
A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
if dims == 2:
@@ -880,37 +1025,51 @@ def test_colrow_absmax(dim1, dim2, dims):
else:
assert False
- row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
-
- A_blocked = einops.rearrange(torch.abs(A), '(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size', row_tiles=16, block_size=64*4)
- nnz_rows1_counts = (torch.abs(A_blocked)>=threshold).sum(3).flatten()
- nnz_block_ptr1 = torch.zeros(nnz_rows1_counts.shape[0]+1, dtype=nnz_rows1_counts.dtype, device=nnz_rows1_counts.device)
+ row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
+ A, threshold=threshold
+ )
+
+ A_blocked = einops.rearrange(
+ torch.abs(A),
+ "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
+ row_tiles=16,
+ block_size=64 * 4,
+ )
+ nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
+ nnz_block_ptr1 = torch.zeros(
+ nnz_rows1_counts.shape[0] + 1,
+ dtype=nnz_rows1_counts.dtype,
+ device=nnz_rows1_counts.device,
+ )
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)
- row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
+ row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
+ A, threshold=0.0
+ )
torch.testing.assert_allclose(col_stats1, col_stats2)
torch.testing.assert_allclose(row_stats1, row_stats2)
assert nnz_block_ptr2 is None
-
n = 2
-#dim1 = [8*1024]
-#dim2 = [4*1024]
-dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+# dim1 = [8*1024]
+# dim2 = [4*1024]
+dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+
+values = list(product(dim1, dim2))
+names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+
-values = list(product(dim1,dim2))
-names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
for i in range(k):
- A = torch.randn(dim1, dim2, device='cuda').half()
+ A = torch.randn(dim1, dim2, device="cuda").half()
out_col1, Scol = F.vectorwise_quant(A, dim=0)
out_row1, Srow = F.vectorwise_quant(A, dim=1)
@@ -920,18 +1079,25 @@ def test_double_quant(dim1, dim2):
torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
-
n = CAt.numel()
- num_not_close_rows = (torch.isclose(CA, out_row1, atol=1)==0).sum().item()
- num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1)==0).sum().item()
+ num_not_close_rows = (
+ (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
+ )
+ num_not_close_cols = (
+ (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
+ )
# allow for 1:500 error due to rounding differences
- min_error = 1/500
- if num_not_close_cols > (min_error*n):
- print(f'Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}')
+ min_error = 1 / 500
+ if num_not_close_cols > (min_error * n):
+ print(
+ f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
+ )
assert False
- if num_not_close_rows > (min_error*n):
- print(f'Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}')
+ if num_not_close_rows > (min_error * n):
+ print(
+ f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
+ )
assert False
torch.testing.assert_allclose(Srow.flatten(), statsA)
@@ -939,21 +1105,23 @@ def test_double_quant(dim1, dim2):
n = 4
-dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim4 = torch.randint(1,4*1024, size=(n,)).tolist()
-inner = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim1 = [6]
dim4 = [4]
inner = [8]
values = list(zip(dim1, dim4, inner))
-names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
+names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
for i in range(k):
- A = torch.randn(dim1, inner, device='cuda').half()
- B = torch.randn(dim4, inner, device='cuda').half()
+ A = torch.randn(dim1, inner, device="cuda").half()
+ B = torch.randn(dim4, inner, device="cuda").half()
out1 = torch.matmul(A.half(), B.t().half())
@@ -967,30 +1135,32 @@ def test_integrated_igemmlt(dim1, dim4, inner):
torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)
- A2, SA = F.nvidia_transform(C1a, 'col32')
- B2, SB = F.nvidia_transform(C2a, 'col_turing')
+ A2, SA = F.nvidia_transform(C1a, "col32")
+ B2, SB = F.nvidia_transform(C2a, "col_turing")
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
- A2, SA = F.nvidia_transform(A1, 'col32')
- B2, SB = F.nvidia_transform(B1, 'col_turing')
+ A2, SA = F.nvidia_transform(A1, "col32")
+ B2, SB = F.nvidia_transform(B1, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB)
- C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ C3, S = F.nvidia_transform(C2, "row", state=SC)
out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
- err1 = torch.abs(out1-out2).mean().item()
- err2 = torch.abs(out1-out3).mean().item()
- assert err2 <= err1*1.01
+ err1 = torch.abs(out1 - out2).mean().item()
+ err2 = torch.abs(out1 - out3).mean().item()
+ assert err2 <= err1 * 1.01
n = 6
-dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim4 = torch.randint(1,4*1024, size=(n,)).tolist()
-inner = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
-names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
+names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
@@ -999,79 +1169,81 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
relerr1, relerr2 = [], []
scale = 1
for i in range(k):
- A = torch.randn(dim1, inner, device='cuda').half()
- B = torch.randn(dim4, inner, device='cuda').half()
+ A = torch.randn(dim1, inner, device="cuda").half()
+ B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
out1 = torch.matmul(A.half(), B.t().half())
-
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
- CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
- A2, SA = F.nvidia_transform(C1a, 'col32')
+ CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
+ A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
- c = 10.0*inner*scale
- row_scale = torch.ones_like(maxA)/c
- outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
- C3, S = F.nvidia_transform(outC32, 'row', state=SC)
+ c = 10.0 * inner * scale
+ row_scale = torch.ones_like(maxA) / c
+ outC32, SC = F.igemmlt(
+ A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
+ )
+ C3, S = F.nvidia_transform(outC32, "row", state=SC)
maxval = torch.abs(C3).max()
if maxval == 127:
scale = 1.5
else:
- scale = maxval/120
- out3 = C3*maxA*absmaxB*c/(127*127)
+ scale = maxval / 120
+ out3 = C3 * maxA * absmaxB * c / (127 * 127)
C4 = torch.matmul(C1a.float(), CB.float().t())
-
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
- CA, SA = F.vectorwise_quant(A, dim=1, quant_type='vector')
- CB, SB = F.vectorwise_quant(B, dim=1, quant_type='linear')
+ CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
+ CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")
C = torch.matmul(CA.float(), CB.t().float())
- out4 = C*SA*SB/(127*127)
- #out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
+ out4 = C * SA * SB / (127 * 127)
+ # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
- #print('='*80)
- #print(out1)
- #print(out2)
- #print(out3)
+ # print('='*80)
+ # print(out1)
+ # print(out2)
+ # print(out3)
- #print(out1)
- #print(out2)
- #print(out3)
- err1.append(torch.abs(out1-out2).mean().item())
- err2.append(torch.abs(out1-out3).mean().item())
- err3.append(torch.abs(out1-out4).mean().item())
+ # print(out1)
+ # print(out2)
+ # print(out3)
+ err1.append(torch.abs(out1 - out2).mean().item())
+ err2.append(torch.abs(out1 - out3).mean().item())
+ err3.append(torch.abs(out1 - out4).mean().item())
- #assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
- print('')
- print(sum(err1)/len(err1))
- print(sum(err2)/len(err2))
- print(sum(err3)/len(err3))
+ # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
+ print("")
+ print(sum(err1) / len(err1))
+ print(sum(err2) / len(err2))
+ print(sum(err3) / len(err3))
dim1 = [1024, 2048]
-inner = [12288*4, 4096*4]
+inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
-names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
+names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
scale = 1
- A = torch.randn(dim1, inner, device='cuda').half()
- B = torch.randn(dim4, inner, device='cuda').half()
+ A = torch.randn(dim1, inner, device="cuda").half()
+ B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
# warmpup
for i in range(k):
@@ -1082,23 +1254,24 @@ def test_row_scale_bench(dim1, dim4, inner):
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
- print('16', time.time()-t0)
+ print("16", time.time() - t0)
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
- CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
- A2, SA = F.nvidia_transform(C1a, 'col32')
+ CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
+ A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
- c = 10.0*inner*scale
- row_scale = maxA/c
+ c = 10.0 * inner * scale
+ row_scale = maxA / c
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
- outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
+ outC32, SC = F.igemmlt(
+ A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
+ )
torch.cuda.synchronize()
- print('row-wise', time.time()-t0)
-
+ print("row-wise", time.time() - t0)
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
@@ -1107,32 +1280,47 @@ def test_row_scale_bench(dim1, dim4, inner):
for i in range(k):
outC32, SC = F.igemmlt(A2, B2, SA, SB)
torch.cuda.synchronize()
- print('vector-wise', time.time()-t0)
-
-
+ print("vector-wise", time.time() - t0)
n = 2
-dim1 = torch.randint(2,1024, size=(n,)).tolist()
-dim2 = torch.randint(2,1024, size=(n,)).tolist()
-#dim1 = [8*1024]
-#dim2 = [4*1024]
+dim1 = torch.randint(2, 1024, size=(n,)).tolist()
+dim2 = torch.randint(2, 1024, size=(n,)).tolist()
+# dim1 = [8*1024]
+# dim2 = [4*1024]
dim3 = [0]
dtype = [torch.int8]
-a_order = ['row']
-out_order = ['col32', 'col_turing', 'col_ampere']
+a_order = ["row"]
+out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
-values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}'.format(*vals) for vals in values]
-@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names)
+values = list(
+ product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
+)
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
+ *vals
+ )
+ for vals in values
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
+ values,
+ ids=names,
+)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k):
if dims == 2:
- A = torch.randint(10, 99, size=(dim1, dim2), device='cuda').to(dtype)
+ A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
+ dtype
+ )
elif dims == 3:
- A = torch.randint(10, 99, size=(dim1, dim2, dim3), device='cuda').to(dtype)
+ A = torch.randint(
+ 10, 99, size=(dim1, dim2, dim3), device="cuda"
+ ).to(dtype)
A.view(-1)[-1] = -1
if transpose:
@@ -1144,53 +1332,57 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
assert S1[0][0] == S2[0][0]
assert S1[0][1] == S2[0][1]
- #print(out1)
- #print(out2)
+ # print(out1)
+ # print(out2)
torch.testing.assert_allclose(out1, out2)
+
n = 2
-#dim1 = torch.randint(2,1024, size=(n,)).tolist()
-#dim2 = torch.randint(2,1024, size=(n,)).tolist()
+# dim1 = torch.randint(2,1024, size=(n,)).tolist()
+# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]
dtype = [torch.int8]
-#a_order = ['col_turing', 'col_ampere']
-a_order = ['col_turing']
-out_order = ['row']
-values = list(product(dim1,dim2,dtype, a_order, out_order))
-names = ['dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}'.format(*vals) for vals in values]
-@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
+# a_order = ['col_turing', 'col_ampere']
+a_order = ["col_turing"]
+out_order = ["row"]
+values = list(product(dim1, dim2, dtype, a_order, out_order))
+names = [
+ "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
+ for vals in values
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, dtype, orderA, orderOut", values, ids=names
+)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
for i in range(1):
- A = torch.randint(-127, 127, size=(dim1, dim2), device='cuda').to(dtype)
+ A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)
out2, S2 = F.transform(A, to_order=orderA)
- A2, S3 = F.transform(out2, from_order=orderA, to_order='row', state=S2)
+ A2, S3 = F.transform(out2, from_order=orderA, to_order="row", state=S2)
assert A2.shape[0] == A.shape[0]
assert A2.shape[1] == A.shape[1]
-
- print('')
+ print("")
print(A)
print(out2)
print(A2)
-
- #torch.testing.assert_allclose(A, A2)
-
-
+ # torch.testing.assert_allclose(A, A2)
def test_overflow():
formatB = F.get_special_format_str()
print(formatB)
for i in range(2):
- a = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 )
- b = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 )
+ a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
+ b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
- Ca, Sa = F.nvidia_transform(a, 'col32')
+ Ca, Sa = F.nvidia_transform(a, "col32")
Cb, Sb = F.nvidia_transform(b, formatB)
c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
@@ -1198,46 +1390,57 @@ def test_overflow():
n = 2
-dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
-#dim1 = [4]
-#dim2 = [5]
+dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+# dim1 = [4]
+# dim2 = [5]
+
+values = list(product(dim1, dim2))
+names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+
-values = list(product(dim1,dim2))
-names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
threshold = 3.00
for i in range(k):
- A = torch.randn(dim1, dim2, device='cuda').half()
+ A = torch.randn(dim1, dim2, device="cuda").half()
- idx = (torch.abs(A) >= threshold)
+ idx = torch.abs(A) >= threshold
CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
- CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
+ A, threshold=threshold
+ )
if coo_tensor is not None:
- A1 = A*idx
+ A1 = A * idx
A2 = torch.zeros_like(A)
- A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
+ A2[
+ coo_tensor.rowidx.long(), coo_tensor.colidx.long()
+ ] = coo_tensor.values
torch.testing.assert_allclose(A1, A2)
- A1 = A*(idx==0)
- A2 = (CA.float()*statsA.unsqueeze(1)/127).half()
- torch.testing.assert_allclose(A*(idx==0), A2, rtol=0.05, atol=1.5e-2)
+ A1 = A * (idx == 0)
+ A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
+ torch.testing.assert_allclose(
+ A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
+ )
+
n = 2
-dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
-dim2 = torch.randint(1,1*1024, size=(n,)).tolist()
-#dim1 = [7]
-#dim2 = [11]
+dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
+dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
+# dim1 = [7]
+# dim2 = [11]
transposed_B = [False, True]
-values = list(product(dim1,dim2, transposed_B))
-names = ['dim1_{0}_dim2_{1}_transposed_B_{2}'.format(*vals) for vals in values]
+values = list(product(dim1, dim2, transposed_B))
+names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
threshold = 1.5
dim3 = torch.randint(32, 128, size=(1,)).item()
- #dim3 = 17
+ # dim3 = 17
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
if transposed_B:
@@ -1249,8 +1452,10 @@ def test_spmm_coo(dim1, dim2, transposed_B):
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A*idx
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
+ A2 = A * idx
if transposed_B:
out2 = F.spmm_coo(cooA, B.t())
@@ -1262,18 +1467,17 @@ def test_spmm_coo(dim1, dim2, transposed_B):
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
-
def test_spmm_bench():
batch = 2
- model = 1024*1
- hidden = model*4
+ model = 1024 * 1
+ hidden = model * 4
seq = 1024
- dim1 = batch*seq
+ dim1 = batch * seq
dim2 = model
dim3 = hidden
threshold = 4
- A = torch.randn(dim1, dim2, device='cuda').half()
- B = torch.randn(dim2, dim3, device='cuda').half()
+ A = torch.randn(dim1, dim2, device="cuda").half()
+ B = torch.randn(dim2, dim3, device="cuda").half()
for i in range(10):
C1 = bnb.matmul(A, B)
@@ -1282,14 +1486,16 @@ def test_spmm_bench():
for i in range(k):
C1 = bnb.matmul(A, B)
torch.cuda.synchronize()
- t8 = time.time()-t0
+ t8 = time.time() - t0
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
- print(nnz/idx.numel())
+ print(nnz / idx.numel())
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
for i in range(10):
out2 = F.spmm_coo(cooA, B)
@@ -1299,20 +1505,22 @@ def test_spmm_bench():
for i in range(k):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
- tsp = time.time()-t0
+ tsp = time.time() - t0
print(tsp, t8)
- print(tsp/t8)
+ print(tsp / t8)
n = 2
-dim1 = torch.randint(256,1*1024, size=(n,)).tolist()
-dim2 = torch.randint(256,1*1024, size=(n,)).tolist()
-values = list(product(dim1,dim2))
-names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
+dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
+dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
+values = list(product(dim1, dim2))
+names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
threshold = 3.0
- formatB = 'col_turing'
+ formatB = "col_turing"
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
w1 = torch.randn(dim1, dim2).cuda().half()
@@ -1322,13 +1530,15 @@ def test_integrated_sparse_decomp(dim1, dim2):
CTw1, Sw1 = F.transform(Cw1, formatB)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
- C32A, SA = F.transform(CA, 'col32')
+ C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
- CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
- C32A, SA = F.transform(CA, 'col32')
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
+ A, threshold=threshold
+ )
+ C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
@@ -1338,8 +1548,8 @@ def test_integrated_sparse_decomp(dim1, dim2):
out4 = F.spmm_coo(coo_tensor, w1.t())
out5 = out3 + out4
- err1 = torch.abs(out1-out2).mean().item()
- err2 = torch.abs(out1-out5).mean().item()
+ err1 = torch.abs(out1 - out2).mean().item()
+ err2 = torch.abs(out1 - out5).mean().item()
assert err2 < err1
@@ -1350,91 +1560,99 @@ def test_matmuls():
c2 = bnb.matmul(a, b)
c3 = bnb.matmul(a, b)
- err1 = torch.abs(c1-c2).mean().item()
- err2 = torch.abs(c1-c3).mean().item()
+ err1 = torch.abs(c1 - c2).mean().item()
+ err2 = torch.abs(c1 - c3).mean().item()
assert err1 < 0.2
assert err2 < 0.2
-
n = 2
-#dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
-#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim1 = [1*2048]
+# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
+# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = [1 * 2048]
dim2 = [12288]
-#dim1 = [32]
-#dim2 = [32]
-#dtype = [torch.float16, torch.int8]
+# dim1 = [32]
+# dim2 = [32]
+# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
-out_function = ['zeros', 'ones']
-values = list(product(dim1,dim2, dtype, out_function))
-names = ['dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}'.format(*vals) for vals in values]
+out_function = ["zeros", "ones"]
+values = list(product(dim1, dim2, dtype, out_function))
+names = [
+ "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
out_func = getattr(torch, out_func)
threshold = 3.3
- #threshold = 2.8
- #threshold = 0.0
- A = torch.randn(dim1, dim2, device='cuda').half()
+ # threshold = 2.8
+ # threshold = 0.0
+ A = torch.randn(dim1, dim2, device="cuda").half()
if dtype == torch.float16:
- B = torch.randn(dim2, dim2*4, device='cuda').half()
+ B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
else:
- B = torch.randn(dim2, dim2*4, device='cuda').half()
+ B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
- B, SB = F.vectorwise_quant(B, quant_type='linear')
- #B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
+ B, SB = F.vectorwise_quant(B, quant_type="linear")
+ # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
- print('')
+ print("")
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A*idx
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
+ A2 = A * idx
out1 = torch.matmul(A2.half(), B.half())
out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
out1 += out.clone()
out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
- #print(B)
- #print(out1)
- #print(out2)
- p = 200/(2048*12288*4)
+ # print(B)
+ # print(out1)
+ # print(out2)
+ p = 200 / (2048 * 12288 * 4)
n = out1.numel()
- count = math.ceil(p*n)
+ count = math.ceil(p * n)
std = out1.std()
out1 /= std
out2 /= std
- assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
- #assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
+ assert_all_approx_close(
+ out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
+ )
+ # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
- #torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
+ # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
- #Bt = torch.randn(dim2*4, dim2, device='cuda').half()
- #torch.cuda.synchronize()
- #t0 = time.time()
- #print(A2.shape, B.shape)
- #for i in range(100):
+ # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
+ # torch.cuda.synchronize()
+ # t0 = time.time()
+ # print(A2.shape, B.shape)
+ # for i in range(100):
# #out3 = F.spmm_coo(cooA, Bt.t())
# #out2 = F.spmm_coo(cooA, B)
# #out2 = F.spmm_coo_very_sparse(cooA, B)
# #out1 = torch.matmul(A, Bt.t())
- #torch.cuda.synchronize()
- #print(time.time() - t0)
+ # torch.cuda.synchronize()
+ # print(time.time() - t0)
+
def test_layout():
- a1 = torch.rand(16, 64, device='cuda', dtype=torch.float16)
- a1 = torch.arange(16* 64, device='cuda').reshape(16, 64).byte()
- a2, s2 = F.transform(a1, 'col_turing')
+ a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
+ a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
+ a2, s2 = F.transform(a1, "col_turing")
print(a2.shape)
- print(a1.flatten()[8*64:8*64+32])
+ print(a1.flatten()[8 * 64 : 8 * 64 + 32])
for i in range(4):
- print(a2.flatten()[i*8*32:i*8*32+32], 0)
+ print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)
def test_coo2csr():
@@ -1444,14 +1662,16 @@ def test_coo2csr():
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A*idx
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
+ A2 = A * idx
csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0]
- torch.testing.assert_allclose(counts, (A2!=0).sum(1))
- idx = (A2!=0)
+ torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
+ idx = A2 != 0
torch.testing.assert_allclose(A2[idx], csrA.values)
@@ -1462,41 +1682,43 @@ def test_coo2csc():
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A*idx
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
+ A2 = A * idx
cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1]
- torch.testing.assert_allclose(counts, (A2!=0).sum(0))
+ torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
# torch uses row-major -> use transpose to transfer to col-major
- idx = (A2.t()!=0)
+ idx = A2.t() != 0
torch.testing.assert_allclose(A2.t()[idx], cscA.values)
-
n = 2
-#dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
-#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim1 = [1*2048]
-#dim2 = [12288]
+# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
+# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = [1 * 2048]
+# dim2 = [12288]
dim2 = [2048]
-#dim1 = [2]
-#dim2 = [2]
+# dim1 = [2]
+# dim2 = [2]
dtype = [torch.int8]
-values = list(product(dim1,dim2, dtype))
-names = ['dim1_{0}_dim2_{1}_dtype_{2}'.format(*vals) for vals in values]
+values = list(product(dim1, dim2, dtype))
+names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
threshold = 6.0
- #threshold = 2.8
- #threshold = 0.0
- A = torch.randn(dim1, dim2, device='cuda').half()
- B = torch.empty(dim2, dim2*4, device='cuda', dtype=torch.float16)
+ # threshold = 2.8
+ # threshold = 0.0
+ A = torch.randn(dim1, dim2, device="cuda").half()
+ B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
torch.nn.init.xavier_uniform_(B)
Bt = B.t().contiguous()
-
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
rowidx = torch.randint(0, A.shape[-1], size=(15,))
@@ -1507,12 +1729,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A*idx
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
+ A2 = A * idx
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out1 = torch.matmul(A2, B.half())
out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
- out3 = out3*statsBt.half()/127
+ out3 = out3 * statsBt.half() / 127
values, counts = torch.unique(cooA.rowidx, return_counts=True)
offset = counts.cumsum(0).int()
@@ -1521,56 +1745,54 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)
- p = 200/(2048*12288*4)
+ p = 200 / (2048 * 12288 * 4)
n = out1.numel()
- count = math.ceil(p*n)
+ count = math.ceil(p * n)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
-
-
- #torch.cuda.synchronize()
- #t0 = time.time()
- #for i in range(100):
+ # torch.cuda.synchronize()
+ # t0 = time.time()
+ # for i in range(100):
# out2 = F.spmm_coo_very_sparse(cooA, B)
- #torch.cuda.synchronize()
- #print('fp16', time.time() - t0)
+ # torch.cuda.synchronize()
+ # print('fp16', time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
- out2 = F.spmm_coo(cooA, B)
+ out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
- print('cusparse fp16', time.time() - t0)
+ print("cusparse fp16", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
- out2 = F.spmm_coo_very_sparse(cooA, CBt)
+ out2 = F.spmm_coo_very_sparse(cooA, CBt)
torch.cuda.synchronize()
- print('int8', time.time() - t0)
+ print("int8", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
- out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
+ out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
torch.cuda.synchronize()
- print('int8+dequant', time.time() - t0)
+ print("int8+dequant", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
- out2 = torch.matmul(A, B)
+ out2 = torch.matmul(A, B)
torch.cuda.synchronize()
- print('matmul', time.time() - t0)
+ print("matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
- out = out1+out2
+ out = out1 + out2
torch.cuda.synchronize()
- print('sparse+ matmul', time.time() - t0)
+ print("sparse+ matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
@@ -1578,33 +1800,38 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
out1 = bnb.matmul(A, Bt)
torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
torch.cuda.synchronize()
- print('partial matmul', time.time() - t0)
+ print("partial matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
torch.cuda.synchronize()
- print('partial matmul', time.time() - t0)
+ print("partial matmul", time.time() - t0)
+
batch_size = 1
seqdim = 2048
values = []
-values.append((batch_size, seqdim, 768, 4*768))
-#values.append((batch_size, seqdim, 1024, 4*1024))
-#values.append((batch_size, seqdim, 1536, 4*1536))
-#values.append((batch_size, seqdim, 2048, 4*2048))
-#values.append((batch_size, seqdim, 2560, 4*2560))
-#values.append((batch_size, seqdim, 4096, 4*4096))
-#values.append((batch_size, seqdim, 5140, 4*5140))
-#values.append((batch_size, seqdim, 12288, 4*12288))
-names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values]
+values.append((batch_size, seqdim, 768, 4 * 768))
+# values.append((batch_size, seqdim, 1024, 4*1024))
+# values.append((batch_size, seqdim, 1536, 4*1536))
+# values.append((batch_size, seqdim, 2048, 4*2048))
+# values.append((batch_size, seqdim, 2560, 4*2560))
+# values.append((batch_size, seqdim, 4096, 4*4096))
+# values.append((batch_size, seqdim, 5140, 4*5140))
+# values.append((batch_size, seqdim, 12288, 4*12288))
+names = [
+ "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+]
+
+
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
formatB = F.get_special_format_str()
- A = torch.randn(batch, seq, model, device='cuda').half()
- B = torch.empty(hidden, model, dtype=torch.float16, device='cuda')
+ A = torch.randn(batch, seq, model, device="cuda").half()
+ B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
torch.nn.init.xavier_uniform_(B)
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
@@ -1613,31 +1840,37 @@ def test_bench_matmul(batch, seq, model, hidden):
outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0
- linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
+ linearMixedBit = (
+ bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
+ )
linearMixedBit.eval()
# warmup
for i in range(100):
torch.matmul(A, B.t())
torch.cuda.synchronize()
- print('')
+ print("")
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
torch.matmul(A, B.t())
torch.cuda.synchronize()
- print(f'pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
bnb.matmul(A, B)
torch.cuda.synchronize()
- print(f'bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
- C32A, SA = F.transform(CA, 'col32')
+ C32A, SA = F.transform(CA, "col32")
CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
CxB, SB = F.transform(CB, to_order=formatB)
torch.cuda.synchronize()
@@ -1645,7 +1878,9 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
torch.cuda.synchronize()
- print(f'igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
BA, statsB = F.vectorwise_quant(B, dim=1)
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
@@ -1654,26 +1889,30 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1)
- C32A, SA = F.nvidia_transform(CA, 'col32')
+ C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
- Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
+ Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
torch.cuda.synchronize()
- print(f'vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
- BA, statsB = F.vectorwise_quant(B, dim=1, quant_type='linear')
+ BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
A2 = A.view(-1, A.shape[-1]).contiguous()
- CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type='linear')
- C32A, SA = F.nvidia_transform(CA, 'col32')
+ CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
+ C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
- Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
- out = Cout*statsB*statsA*(1.0/(127*127))
+ Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+ out = Cout * statsB * statsA * (1.0 / (127 * 127))
torch.cuda.synchronize()
- print(f'linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
linear8bit(A)
torch.cuda.synchronize()
@@ -1681,8 +1920,9 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
linear8bit(A)
torch.cuda.synchronize()
- print(f'bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
-
+ print(
+ f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
linearMixedBit(A)
torch.cuda.synchronize()
@@ -1690,65 +1930,66 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
linearMixedBit(A)
torch.cuda.synchronize()
- print(f'bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
def test_zeropoint():
def min_max(x):
maxA = torch.amax(x, dim=1, keepdim=True)
minA = torch.amin(x, dim=1, keepdim=True)
- midpoint = (maxA-minA)/2.0
- dyna = 252/(maxA-minA)
- #dyna *= 0.98
- x = dyna*x
- x = x - torch.round((dyna*(minA+midpoint)))
+ midpoint = (maxA - minA) / 2.0
+ dyna = 252 / (maxA - minA)
+ # dyna *= 0.98
+ x = dyna * x
+ x = x - torch.round((dyna * (minA + midpoint)))
return x.to(torch.int8), minA, midpoint, dyna
+
batch = 2
seq = 2
model = 4
- hidden = 2*model
- #batch = 4
- #seq = 2048
- #model = 1024
- #hidden = 8*model
- A = torch.randn(batch*seq, model, device='cuda').half()-0.4
- B = torch.nn.Parameter(torch.randn(model, hidden, device='cuda').half())
-
- #A[0] = 0
- #B[:, 0] = 0
- #A = A*(A>0)
- #A[0, 0] = 0
- #A[0, 0] = 6.0
+ hidden = 2 * model
+ # batch = 4
+ # seq = 2048
+ # model = 1024
+ # hidden = 8*model
+ A = torch.randn(batch * seq, model, device="cuda").half() - 0.4
+ B = torch.nn.Parameter(torch.randn(model, hidden, device="cuda").half())
+
+ # A[0] = 0
+ # B[:, 0] = 0
+ # A = A*(A>0)
+ # A[0, 0] = 0
+ # A[0, 0] = 6.0
Ac, minA, midpoint, dyna = min_max(A)
- #print(Ac[0, 0], 'zero')
- #print(Ac, Ac.min(), Ac.max())
- Bc, maxB = F.vectorwise_quant(B, quant_type='linear')
+ # print(Ac[0, 0], 'zero')
+ # print(Ac, Ac.min(), Ac.max())
+ Bc, maxB = F.vectorwise_quant(B, quant_type="linear")
out = F.igemm(Ac, Bc)
- out2 = torch.matmul(A,B)
- offset = B.sum(0)*torch.round(dyna*(minA+midpoint))/dyna
+ out2 = torch.matmul(A, B)
+ offset = B.sum(0) * torch.round(dyna * (minA + midpoint)) / dyna
out = out.float()
- #print(out.shape, maxB.shape, scale.shape, offset.shape)
- norm1 = maxB/127
- C4 = (out/dyna)*norm1+offset
-
+ # print(out.shape, maxB.shape, scale.shape, offset.shape)
+ norm1 = maxB / 127
+ C4 = (out / dyna) * norm1 + offset
B1 = torch.nn.Parameter(B.clone())
B2 = torch.nn.Parameter(B.clone())
B3 = torch.nn.Parameter(B.clone())
B4 = torch.nn.Parameter(B.clone())
-
C1 = torch.matmul(A, B1)
- C2 = bnb.matmul_cublas(A, B2, None, 'linear')
- C3 = bnb.matmul_cublas(A, B3, None, 'zeropoint')
- C4 = bnb.matmul_cublas(A, B4, None, 'vector-zeropoint')
+ C2 = bnb.matmul_cublas(A, B2, None, "linear")
+ C3 = bnb.matmul_cublas(A, B3, None, "zeropoint")
+ C4 = bnb.matmul_cublas(A, B4, None, "vector-zeropoint")
- err1 = torch.abs(C1-C2).mean().item()
- err2 = torch.abs(C1-C3).mean().item()
- err3 = torch.abs(C1-C4).mean().item()
+ err1 = torch.abs(C1 - C2).mean().item()
+ err2 = torch.abs(C1 - C3).mean().item()
+ err3 = torch.abs(C1 - C4).mean().item()
print(err1, err2, err3)
- #assert err1 > err2
+ # assert err1 > err2
loss1 = C1.mean()
loss2 = C2.mean()
@@ -1765,40 +2006,38 @@ def test_zeropoint():
print(B2.grad)
print(B3.grad)
print(B4.grad)
- err1 = torch.abs(B1.grad-B2.grad).mean().item()
- err2 = torch.abs(B1.grad-B3.grad).mean().item()
- err3 = torch.abs(B1.grad-B4.grad).mean().item()
+ err1 = torch.abs(B1.grad - B2.grad).mean().item()
+ err2 = torch.abs(B1.grad - B3.grad).mean().item()
+ err3 = torch.abs(B1.grad - B4.grad).mean().item()
print(err1, err2, err3)
-
-
def test_zp():
def quant_zp(x):
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
- if dyna == 0: dyna = 1
- qx = 254./dyna
+ if dyna == 0:
+ dyna = 1
+ qx = 254.0 / dyna
minx = x.min()
- #zpx = torch.round(minx* qx)
- #zpx = 127 - torch.round(x.max()* qx)
- zpx = torch.round(x.min()* qx) - 127
- x = (qx*x) + zpx
+ # zpx = torch.round(minx* qx)
+ # zpx = 127 - torch.round(x.max()* qx)
+ zpx = torch.round(x.min() * qx) - 127
+ x = (qx * x) + zpx
return x, qx, zpx
+
batch = 2
seq = 512
model = 1024
- hidden = 4*model
- A = torch.randn(batch*seq, model, device='cuda').half()*0.1
- B = torch.randn(model, hidden, device='cuda').half()*0.1
-
+ hidden = 4 * model
+ A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
+ B = torch.randn(model, hidden, device="cuda").half() * 0.1
C0 = torch.matmul(A, B)
-
- #A, SA = F.vectorwise_quant(A, quant_type='linear')
- #B, SB = F.vectorwise_quant(B, quant_type='linear')
+ # A, SA = F.vectorwise_quant(A, quant_type='linear')
+ # B, SB = F.vectorwise_quant(B, quant_type='linear')
A = A.float()
B = B.float()
@@ -1806,69 +2045,68 @@ def test_zp():
C3 = bnb.matmul(A.half(), B.t().contiguous().half())
zp = 1
- #C2 = torch.matmul(A-zp, B)
- #C2 += B.sum(0).view(1, -1)*zp
- C2 = torch.matmul(A, B-zp)
- C2 -= A.sum(1).view(-1, 1)*zp
+ # C2 = torch.matmul(A-zp, B)
+ # C2 += B.sum(0).view(1, -1)*zp
+ C2 = torch.matmul(A, B - zp)
+ C2 -= A.sum(1).view(-1, 1) * zp
ca, cqa, cza = quant_zp(A)
print(ca.min(), ca.max())
- print((ca-cza).min(), (ca-cza).max())
+ print((ca - cza).min(), (ca - cza).max())
zp = 1
scale = 2.0
- C5 = torch.matmul((A*scale)-zp, B)
- C5 += B.sum(0)*zp
+ C5 = torch.matmul((A * scale) - zp, B)
+ C5 += B.sum(0) * zp
C5 /= scale
CA, qa, zpa = quant_zp(A)
C4 = torch.matmul(CA, B)
- C4 -= B.sum(0)*zpa
+ C4 -= B.sum(0) * zpa
C4 /= qa
zpb = 1
zpa = 1
qa = 2
qb = 2
- C6 = torch.matmul((A*qa)+zpa, (B*qb)+zpb)
- C6 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
- C6 -= zpa*zpb*A.shape[1]
- C6 /= qa*qb
+ C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
+ C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
+ C6 -= zpa * zpb * A.shape[1]
+ C6 /= qa * qb
CA, qa, zpa = quant_zp(A)
CB, qb, zpb = quant_zp(B)
C7 = torch.matmul(CA, CB)
- C7 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
- C7 -= zpa*zpb*A.shape[1]
- C7 /= qa*qb
+ C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
+ C7 -= zpa * zpb * A.shape[1]
+ C7 /= qa * qb
- print('')
- #print(C0.flatten()[:10])
+ print("")
+ # print(C0.flatten()[:10])
print(C1.flatten()[:10])
print(C2.flatten()[:10])
print(C3.flatten()[:10])
print(C5.flatten()[:10])
print(C6.flatten()[:10])
print(C7.flatten()[:10])
- err1 = torch.abs(C1-C2).mean().item()
- err2 = torch.abs(C1-C3).mean().item()
- err3 = torch.abs(C1-C4).mean().item()
- err4 = torch.abs(C1-C5).mean().item()
- err5 = torch.abs(C1-C6).mean().item()
- err6 = torch.abs(C1-C7).mean().item()
+ err1 = torch.abs(C1 - C2).mean().item()
+ err2 = torch.abs(C1 - C3).mean().item()
+ err3 = torch.abs(C1 - C4).mean().item()
+ err4 = torch.abs(C1 - C5).mean().item()
+ err5 = torch.abs(C1 - C6).mean().item()
+ err6 = torch.abs(C1 - C7).mean().item()
print(err1, err2, err3, err4, err5, err6)
-
def test_extract_outliers():
for i in range(k):
- shapeA = (4096, 4096*4)
+ shapeA = (4096, 4096 * 4)
idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
- #idx = torch.Tensor([0]).int().cuda()
- A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
+ # idx = torch.Tensor([0]).int().cuda()
+ A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
outliers1 = A[:, idx.long()]
- CA, SA = F.transform(A, 'col_turing')
+ CA, SA = F.transform(A, "col_turing")
outliers2 = F.extract_outliers(CA, SA, idx)
@@ -1877,7 +2115,7 @@ def test_extract_outliers():
torch.testing.assert_allclose(outliers1, outliers2)
- CA, SA = F.transform(A, 'col_ampere')
+ CA, SA = F.transform(A, "col_ampere")
outliers2 = F.extract_outliers(CA, SA, idx)
diff --git a/tests/test_modules.py b/tests/test_modules.py
index a2c950b..7faadb8 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -1,21 +1,27 @@
+from itertools import product
+
import pytest
import torch
-
-from itertools import product
from torch import nn
import bitsandbytes as bnb
+
class MockArgs(object):
def __init__(self, initial_data):
for key in initial_data:
setattr(self, key, initial_data[key])
+
class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
super(MLP8bit, self).__init__()
- self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
- self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)
+ self.fc1 = bnb.nn.Linear8bitLt(
+ dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+ )
+ self.fc2 = bnb.nn.Linear8bitLt(
+ dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+ )
def forward(self, x):
x = self.fc1(x)
@@ -25,108 +31,120 @@ class MLP8bit(torch.nn.Module):
def get_args():
args = MockArgs([])
- args.quant_type = 'vector'
- args.use_8bit_training = 'full'
+ args.quant_type = "vector"
+ args.use_8bit_training = "full"
args.clip_freq = 9999
return args
+
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
idx = torch.isclose(a, b, rtol, atol)
- sumval = (idx==0).sum().item()
+ sumval = (idx == 0).sum().item()
if sumval > count:
- print(f'Too many values not close: assert {sumval} < {count}')
+ print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_allclose(a, b, rtol, atol)
-class LinearFunction(torch.autograd.Function):
+class LinearFunction(torch.autograd.Function):
@staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
- round_func = LinearFunction.round_stoachastic if stochastic else torch.round
- norm = math.sqrt(math.pi)/math.sqrt(2.0)
- #std = torch.abs(x).mean()*norm
+ round_func = (
+ LinearFunction.round_stoachastic if stochastic else torch.round
+ )
+ norm = math.sqrt(math.pi) / math.sqrt(2.0)
+ # std = torch.abs(x).mean()*norm
std = torch.std(x)
- max1 = std*trim_value
- x = x/max1*127
+ max1 = std * trim_value
+ x = x / max1 * 127
x = round_func(x)
x[x > 127] = 127
x[x < -127] = -127
- x = x/127*max1
+ x = x / 127 * max1
return x
def quant(x, quant_type, dim=1):
- if quant_type == 'linear':
+ if quant_type == "linear":
max1 = torch.abs(x).max().float()
- xq = torch.round(x/max1*127).to(torch.int8)
+ xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
- elif quant_type == 'vector':
+ elif quant_type == "vector":
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
- xq = torch.round(x/max1*127).to(torch.int8)
+ xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1
- elif quant_type == 'min-max':
+ elif quant_type == "min-max":
maxA = torch.amax(x, dim=dim, keepdim=True).float()
minA = torch.amin(x, dim=dim, keepdim=True).float()
- scale = (maxA-minA)/2.0
- xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8)
+ scale = (maxA - minA) / 2.0
+ xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
return xq, (minA.float(), scale.float())
- else: return None
+ else:
+ return None
def dequant(xq, S1, S2, dtype, quant_type):
- if quant_type == 'linear':
- norm = S1*S2/(127*127)
+ if quant_type == "linear":
+ norm = S1 * S2 / (127 * 127)
# double cast needed to prevent overflows
- return (xq.float()*norm).to(dtype)
- elif quant_type == 'vector':
+ return (xq.float() * norm).to(dtype)
+ elif quant_type == "vector":
x = xq.float()
- if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
- if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
- #print(x.shape, S1.shape, S2.shape)
+ if len(xq.shape) == 2 and len(S1.shape) == 3:
+ S1 = S1.squeeze(0)
+ if len(xq.shape) == 2 and len(S2.shape) == 3:
+ S2 = S2.squeeze(0)
+ # print(x.shape, S1.shape, S2.shape)
if len(S1.shape) == 2:
- x *= S1.t()/127
+ x *= S1.t() / 127
else:
- x *= S1/127
- x *= S2/127
+ x *= S1 / 127
+ x *= S2 / 127
return x.to(dtype)
- else: return None
+ else:
+ return None
def dequant_min_max(xq, A, B, SA, SB, dtype):
- offset = B.float().t().sum(0)*(SA[0]+SA[1])
+ offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float()
- if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
- if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
+ if len(xq.shape) == 2 and len(SB.shape) == 3:
+ SB = SB.squeeze(0)
+ if len(xq.shape) == 2 and len(SA.shape) == 3:
+ SA = SA.squeeze(0)
if len(SB.shape) == 2:
- x *= SB.t()/127
+ x *= SB.t() / 127
else:
- x *= SB/127
- x *= SA[1]/127
- x +=offset
+ x *= SB / 127
+ x *= SA[1] / 127
+ x += offset
return x.to(dtype)
-
def get_8bit_linear(x, stochastic=False):
- round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+ round_func = (
+ LinearFunction.round_stoachastic if stochastic else torch.round
+ )
max1 = torch.abs(x).max()
- x = x/max1*127
- x = round_func(x)/127*max1
- #x = torch.round(x)/128*max1
+ x = x / max1 * 127
+ x = round_func(x) / 127 * max1
+ # x = torch.round(x)/128*max1
return x
@staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False):
- round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+ round_func = (
+ LinearFunction.round_stoachastic if stochastic else torch.round
+ )
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
- max1[max1==0] = 1.0
- x = (x*127)/max1
- x = round_func(x)/127*max1
+ max1[max1 == 0] = 1.0
+ x = (x * 127) / max1
+ x = round_func(x) / 127 * max1
return x
@staticmethod
def round_stoachastic(x):
sign = torch.sign(x)
absx = torch.abs(x)
- decimal = absx-torch.floor(absx)
+ decimal = absx - torch.floor(absx)
rdm = torch.rand_like(decimal)
- return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype))
+ return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))
@staticmethod
def fake_8bit_storage(w, exponent_bits):
@@ -140,10 +158,10 @@ class LinearFunction(torch.autograd.Function):
@staticmethod
def fake_8bit_storage_quantile(w, args):
code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
- #C = bnb.functional.quantize_no_absmax(code, w)
- #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
- #print(out)
- #out = out.half()
+ # C = bnb.functional.quantize_no_absmax(code, w)
+ # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
+ # print(out)
+ # out = out.half()
code /= torch.max(torch.abs(code))
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code)
@@ -162,7 +180,7 @@ class LinearFunction(torch.autograd.Function):
@staticmethod
def fake_8bit_storage_with_max(w, topk=8):
- blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
+ blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
idx = idx[:, :topk]
max_val = max_val[:, :topk]
@@ -191,22 +209,23 @@ class LinearFunction(torch.autograd.Function):
w.copy_(unblocked_w)
return unblocked_w
-
@staticmethod
def forward(ctx, x, weight, bias=None, args=None):
- if args.use_8bit_training != 'off':
+ if args.use_8bit_training != "off":
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t())
- output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
- #if torch.rand(1) < 0.01:
- #output32 = torch.matmul(x, weight.t())
- #err = torch.abs(output-output32).float()
- #relerr = err/(torch.abs(output32).float()+1e-8)
- #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
+ output = LinearFunction.dequant(
+ outputq, S1, S2, x.dtype, args.quant_type
+ )
+ # if torch.rand(1) < 0.01:
+ # output32 = torch.matmul(x, weight.t())
+ # err = torch.abs(output-output32).float()
+ # relerr = err/(torch.abs(output32).float()+1e-8)
+ # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
else:
- #output = torch.matmul(x, weight.t())
- output = torch.einsum('bsi,oi->bso', x, weight)
+ # output = torch.matmul(x, weight.t())
+ output = torch.einsum("bsi,oi->bso", x, weight)
ctx.save_for_backward(x, weight, bias)
ctx.args = args
@@ -221,37 +240,51 @@ class LinearFunction(torch.autograd.Function):
args = ctx.args
stochastic = False
grad_input = grad_weight = grad_bias = None
- if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)
+ if bias is not None and ctx.needs_input_grad[2]:
+ grad_bias = grad_output.sum(0)
# weight and x are already 8bit
# -> transform grad_output to 8-bit
- if args.use_8bit_training == 'forward+wgrad':
- grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+ if args.use_8bit_training == "forward+wgrad":
+ grad_output8, S1 = LinearFunction.quant(
+ grad_output, args.quant_type, dim=[0, 1]
+ )
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = bnb.functional.igemm(grad_output8, x8)
- grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+ grad_weight = LinearFunction.dequant(
+ grad_weight8, S1, S2, grad_output.dtype, args.quant_type
+ )
- #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
+ # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input = grad_output.matmul(weight)
- elif args.use_8bit_training == 'full':
- grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+ elif args.use_8bit_training == "full":
+ grad_output8, S1 = LinearFunction.quant(
+ grad_output, args.quant_type, dim=[0, 1]
+ )
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
- grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+ grad_weight = LinearFunction.dequant(
+ grad_weight8, S1, S2, grad_output.dtype, args.quant_type
+ )
- grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
+ grad_output8, S1 = LinearFunction.quant(
+ grad_output, args.quant_type, dim=2
+ )
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8)
- grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
+ grad_input = LinearFunction.dequant(
+ grad_input8, S1, S3, grad_output.dtype, args.quant_type
+ )
else:
grad_input = grad_output.matmul(weight)
- grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)
+ grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)
return grad_input, grad_weight, grad_bias, None
+
class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None):
super(Linear8bit, self).__init__()
@@ -263,7 +296,7 @@ class Linear8bit(nn.Module):
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
- self.register_parameter('bias', None)
+ self.register_parameter("bias", None)
torch.nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
@@ -275,12 +308,11 @@ class Linear8bit(nn.Module):
return LinearFunction.apply(x, self.weight, self.bias, self.args)
-
def test_linear8bit():
l0 = torch.nn.Linear(32, 64).cuda().half()
- l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half()
+ l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
- l3 = bnb.nn.Linear8bitLt(32,64).cuda().half()
+ l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()
l0.weight.data = l2.weight.data.clone()
l0.bias.data = l2.bias.data.clone()
@@ -292,8 +324,8 @@ def test_linear8bit():
l3.bias.data = l2.bias.data.clone()
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
- t = torch.randn(16, 8, 64, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
+ t = torch.randn(16, 8, 64, device="cuda").half()
b2 = b1.clone()
b3 = b1.clone()
b0 = b1.clone()
@@ -316,18 +348,26 @@ def test_linear8bit():
loss2.backward()
loss3.backward()
- assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
- assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
- assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
- assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
-
- err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
- err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
- err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()
-
- assert err1*0.8 < err2
- assert err2*0.8 < err3
- assert err3*0.8 < err1
+ assert_all_approx_close(
+ l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
+ )
+ assert_all_approx_close(
+ l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
+ )
+ assert_all_approx_close(
+ l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
+ )
+ assert_all_approx_close(
+ l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
+ )
+
+ err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
+ err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
+ err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()
+
+ assert err1 * 0.8 < err2
+ assert err2 * 0.8 < err3
+ assert err3 * 0.8 < err1
l0.weight.grad = None
l1.weight.grad = None
@@ -341,23 +381,30 @@ def test_linear8bit():
threshold = [0.0, 3.0]
values = threshold
-names = ['threshold_{0}'.format(vals) for vals in values]
+names = ["threshold_{0}".format(vals) for vals in values]
+
+
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
- l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half()
- assert l1.weight.device.type == 'cuda'
+ l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
+ assert l1.weight.device.type == "cuda"
assert l1.weight.dtype == torch.float16
l1.eval()
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
if i == 1:
assert l1.state.CxB is not None
+
def test_linear8bitlt_accumulated_gradient():
- l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)])
- l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)])
+ l1 = torch.nn.Sequential(
+ *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
+ )
+ l2 = torch.nn.Sequential(
+ *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
+ )
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
@@ -367,9 +414,8 @@ def test_linear8bitlt_accumulated_gradient():
acc_steps = 10
-
for i in range(10):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
o2 = l2(b1)
loss1 = o1.mean()
@@ -385,8 +431,12 @@ def test_linear8bitlt_accumulated_gradient():
opt1.zero_grad(True)
opt2.step()
opt2.zero_grad(True)
- assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
- assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
+ assert_all_approx_close(
+ l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
+ )
+ assert_all_approx_close(
+ l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
+ )
# we do this copy because otherwise we have small divergences over time that add up
l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data)
@@ -397,15 +447,21 @@ def test_linear8bitlt_accumulated_gradient():
threshold = [0.0, 2.0]
values = threshold
-names = ['threshold_{0}'.format(vals) for vals in values]
+names = ["threshold_{0}".format(vals) for vals in values]
+
+
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
- l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half()
+ l1 = (
+ bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+ .cuda()
+ .half()
+ )
assert l1.weight.dtype == torch.int8
l1.eval()
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1)
assert o1.dtype == torch.float16
@@ -414,57 +470,82 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0: assert mlp.fc1.state.idx is not None
- if threshold > 0: assert mlp.fc2.state.idx is not None
-
- mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+ if threshold > 0:
+ assert mlp.fc1.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc2.state.idx is not None
+
+ mlp = (
+ MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+ .cuda()
+ .half()
+ )
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0: assert mlp.fc1.state.idx is not None
- if threshold > 0: assert mlp.fc2.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc1.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc2.state.idx is not None
- mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+ mlp = (
+ MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+ .half()
+ .cuda()
+ )
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0: assert mlp.fc1.state.idx is not None
- if threshold > 0: assert mlp.fc2.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc1.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
-
- mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')
+ mlp = (
+ MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+ .half()
+ .to("cuda")
+ )
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0: assert mlp.fc1.state.idx is not None
- if threshold > 0: assert mlp.fc2.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc1.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
- assert mlp.fc1.weight.device.type == 'cuda'
- assert mlp.fc2.weight.device.type == 'cuda'
+ assert mlp.fc1.weight.device.type == "cuda"
+ assert mlp.fc2.weight.device.type == "cuda"
- mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')
+ mlp = (
+ MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+ .to(torch.float16)
+ .to("cuda")
+ )
for i in range(100):
- b1 = torch.randn(16, 8, 32, device='cuda').half()
+ b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0: assert mlp.fc1.state.idx is not None
- if threshold > 0: assert mlp.fc2.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc1.state.idx is not None
+ if threshold > 0:
+ assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
- assert mlp.fc1.weight.device.type == 'cuda'
- assert mlp.fc2.weight.device.type == 'cuda'
+ assert mlp.fc1.weight.device.type == "cuda"
+ assert mlp.fc2.weight.device.type == "cuda"
diff --git a/tests/test_optim.py b/tests/test_optim.py
index b173eaa..8e12761 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -1,81 +1,138 @@
+import ctypes
import os
-import time
import shutil
+import time
import uuid
+from itertools import product
+from os.path import join
+
import pytest
-import ctypes
import torch
+
import bitsandbytes as bnb
import bitsandbytes.functional as F
-from os.path import join
-from itertools import product
-
-#import apex
+# import apex
k = 20
+
def get_temp_dir():
- path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
+ path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
os.makedirs(path, exist_ok=True)
return path
+
def rm_path(path):
shutil.rmtree(path)
+
str2optimizers = {}
-str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
-#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
-#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
-str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
-#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
-#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
-
-str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
-#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
-str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
-#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
-str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
-str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
-#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
-str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
-
-str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
-str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
-str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
+str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
+# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
+# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+str2optimizers["momentum_pytorch"] = (
+ None,
+ lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
+ bnb.optim.Adam,
+)
+# str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
+# str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
+
+str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
+# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+str2optimizers["momentum"] = (
+ lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
+)
+str2optimizers["lars"] = (
+ lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
+)
+# str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
+str2optimizers["rmsprop"] = (
+ lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
+)
+str2optimizers["adam8bit"] = (
+ torch.optim.Adam,
+ lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
+)
+str2optimizers["momentum8bit"] = (
+ lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
+)
+str2optimizers["rmsprop8bit"] = (
+ lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
+)
+# str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
+str2optimizers["lars8bit"] = (
+ lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
+)
+
+str2optimizers["adam8bit_blockwise"] = (
+ torch.optim.Adam,
+ lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
+)
+str2optimizers["momentum8bit_blockwise"] = (
+ lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
+)
+str2optimizers["rmsprop8bit_blockwise"] = (
+ lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
+ lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
+)
str2statenames = {}
-str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
-str2statenames['momentum'] = [('momentum_buffer', 'state1')]
-str2statenames['lars'] = [('momentum_buffer', 'state1')]
-str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
-str2statenames['rmsprop'] = [('square_avg', 'state1')]
-str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
-str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
-str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
-str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
-str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
-str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
-str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
-str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
+str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
+str2statenames["momentum"] = [("momentum_buffer", "state1")]
+str2statenames["lars"] = [("momentum_buffer", "state1")]
+str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
+str2statenames["rmsprop"] = [("square_avg", "state1")]
+str2statenames["adam8bit"] = [
+ ("exp_avg", "state1", "qmap1", "max1"),
+ ("exp_avg_sq", "state2", "qmap2", "max2"),
+]
+str2statenames["lamb8bit"] = [
+ ("exp_avg", "state1", "qmap1", "max1"),
+ ("exp_avg_sq", "state2", "qmap2", "max2"),
+]
+str2statenames["adam8bit_blockwise"] = [
+ ("exp_avg", "state1", "qmap1", "absmax1"),
+ ("exp_avg_sq", "state2", "qmap2", "absmax2"),
+]
+str2statenames["momentum8bit"] = [
+ ("momentum_buffer", "state1", "qmap1", "max1")
+]
+str2statenames["momentum8bit_blockwise"] = [
+ ("momentum_buffer", "state1", "qmap1", "absmax1")
+]
+str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
+str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
+str2statenames["rmsprop8bit_blockwise"] = [
+ ("square_avg", "state1", "qmap1", "absmax1")
+]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
-values = list(product(dim1,dim2, gtype, optimizer_names))
-names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
+optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
+values = list(product(dim1, dim2, gtype, optimizer_names))
+names = [
+ "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
- if dim1 == 1 and dim2 == 1: return
- p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
+ if dim1 == 1 and dim2 == 1:
+ return
+ p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
-
torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2])
@@ -84,9 +141,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
else:
atol, rtol = 1e-4, 1e-3
-
for i in range(k):
- g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
+ g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
@@ -94,21 +150,31 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
torch_optimizer.step()
for name1, name2 in str2statenames[optim_name]:
- torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
+ torch.testing.assert_allclose(
+ torch_optimizer.state[p1][name1],
+ bnb_optimizer.state[p2][name2],
+ atol=atol,
+ rtol=rtol,
+ )
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
- if i % (k//5) == 0 and i > 0:
+ if i % (k // 5) == 0 and i > 0:
path = get_temp_dir()
- torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
+ torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
- bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
+ bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
for name1, name2 in str2statenames[optim_name]:
- torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
+ torch.testing.assert_allclose(
+ torch_optimizer.state[p1][name1],
+ bnb_optimizer.state[p2][name2],
+ atol=atol,
+ rtol=rtol,
+ )
if gtype == torch.float16:
# the adam buffers should also be close because they are 32-bit
@@ -118,20 +184,24 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
p1.data = p1.data.half().float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.half(), p2)
- if optim_name in ['lars', 'lamb']:
- assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0
+ if optim_name in ["lars", "lamb"]:
+ assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
+
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
-values = list(product(dim1,dim2, gtype))
-names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values]
+values = list(product(dim1, dim2, gtype))
+names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
- if dim1 == 1 and dim2 == 1: return
- p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
- p2 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
- p3 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
+ if dim1 == 1 and dim2 == 1:
+ return
+ p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
+ p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
+ p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
mask = torch.rand_like(p2) < 0.1
beta1 = 0.9
beta2 = 0.999
@@ -139,9 +209,13 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize()
- bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
+ bnb.optim.GlobalOptimManager.get_instance().override_config(
+ p3, "optim_bits", 8
+ )
- bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
+ bnb.optim.GlobalOptimManager.get_instance().register_parameters(
+ [p1, p2, p3]
+ )
p1 = p1.cuda()
p2 = p2.cuda()
p3 = p3.cuda()
@@ -154,30 +228,43 @@ def test_global_config(dim1, dim2, gtype):
atol, rtol = 1e-4, 1e-3
for i in range(50):
- g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
- g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
- g3 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
+ g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
+ g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
+ g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
p1.grad = g1
p2.grad = g2
p3.grad = g3
adam2.step()
- assert adam2.state[p3]['state1'].dtype == torch.uint8
- assert adam2.state[p3]['state2'].dtype == torch.uint8
-
+ assert adam2.state[p3]["state1"].dtype == torch.uint8
+ assert adam2.state[p3]["state2"].dtype == torch.uint8
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
-values = list(product(dim1,dim2, gtype, optimizer_names))
-names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
+optimizer_names = [
+ "adam8bit",
+ "momentum8bit",
+ "rmsprop8bit",
+ "adam8bit_blockwise",
+ "lamb8bit",
+ "lars8bit",
+ "momentum8bit_blockwise",
+ "rmsprop8bit_blockwise",
+]
+values = list(product(dim1, dim2, gtype, optimizer_names))
+names = [
+ "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
- if dim1 == 1 and dim2 == 1: return
- p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
+ if dim1 == 1 and dim2 == 1:
+ return
+ p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
blocksize = 2048
@@ -197,7 +284,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors = []
for i in range(50):
- g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
+ g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
@@ -208,17 +295,31 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
- #print(bnb_optimizer.state[p2][max_val], name1)
- if 'blockwise' in optim_name:
- s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
+ # print(bnb_optimizer.state[p2][max_val], name1)
+ if "blockwise" in optim_name:
+ s1 = F.dequantize_blockwise(
+ code=bnb_optimizer.state[p2][qmap],
+ absmax=bnb_optimizer.state[p2][max_val],
+ A=bnb_optimizer.state[p2][name2],
+ blocksize=blocksize,
+ )
else:
- s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
- num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
+ s1 = F.dequantize(
+ code=bnb_optimizer.state[p2][qmap],
+ absmax=bnb_optimizer.state[p2][max_val],
+ A=bnb_optimizer.state[p2][name2],
+ )
+ num_not_close = (
+ torch.isclose(
+ torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
+ )
+ == 0
+ )
assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone())
- err = torch.abs(p1-p2)
- relerr = err/torch.abs(p1)
+ err = torch.abs(p1 - p2)
+ relerr = err / torch.abs(p1)
assert err.mean() < 0.0001
assert relerr.mean() < 0.001
@@ -226,54 +327,86 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors.append(relerr.mean().item())
if i % 10 == 0 and i > 0:
- for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
+ for (name1, name2, qmap, max_val), s in zip(
+ str2statenames[optim_name], dequant_states
+ ):
s1cpy = s.clone()
raws1cpy = bnb_optimizer.state[p2][name2].clone()
qmap1 = bnb_optimizer.state[p2][qmap].clone()
path = get_temp_dir()
- torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
+ torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
- bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
+ bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
- torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
- torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
-
- if 'blockwise' in optim_name:
- s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
+ torch.testing.assert_allclose(
+ raws1cpy, bnb_optimizer.state[p2][name2]
+ )
+ torch.testing.assert_allclose(
+ qmap1, bnb_optimizer.state[p2][qmap]
+ )
+
+ if "blockwise" in optim_name:
+ s1 = F.dequantize_blockwise(
+ code=bnb_optimizer.state[p2][qmap],
+ absmax=bnb_optimizer.state[p2][max_val],
+ A=bnb_optimizer.state[p2][name2],
+ blocksize=blocksize,
+ )
else:
- s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
+ s1 = F.dequantize(
+ code=bnb_optimizer.state[p2][qmap],
+ absmax=bnb_optimizer.state[p2][max_val],
+ A=bnb_optimizer.state[p2][name2],
+ )
torch.testing.assert_allclose(s1cpy, s1)
- num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
+ num_not_close = (
+ torch.isclose(
+ torch_optimizer.state[p1][name1],
+ s1,
+ atol=atol,
+ rtol=rtol,
+ )
+ == 0
+ )
assert num_not_close.sum().item() < 20
- torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
+ torch.testing.assert_allclose(
+ p1, p2.float(), atol=patol, rtol=prtol
+ )
# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error
p1.data = p1.data.to(gtype).float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.to(gtype), p2)
- for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
+ for (name1, name2, qmap, max_val), s in zip(
+ str2statenames[optim_name], dequant_states
+ ):
torch_optimizer.state[p1][name1].copy_(s.data)
- #print(sum(errors)/len(errors))
- #print(sum(relerrors)/len(relerrors))
-
+ # print(sum(errors)/len(errors))
+ # print(sum(relerrors)/len(relerrors))
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
-values = list(product(dim1,dim2, gtype, optim_bits))
-names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}'.format(*vals) for vals in values]
+values = list(product(dim1, dim2, gtype, optim_bits))
+names = [
+ "dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
+ for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
- if dim1 == 1 and dim2 == 1: return
- p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
+ if dim1 == 1 and dim2 == 1:
+ return
+ p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
beta1 = 0.9
beta2 = 0.999
lr = 0.001
@@ -281,19 +414,30 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
p1 = p1.cuda()
p2 = p1.clone()
adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
- adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
+ adam2 = bnb.optim.Adam(
+ [p2],
+ lr,
+ (beta1, beta2),
+ eps,
+ optim_bits=optim_bits,
+ percentile_clipping=5,
+ )
gnorm_vec = torch.zeros(100).cuda()
step = 0
for i in range(50):
step += 1
- g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + (0.01*i)
+ g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
+ 0.01 * i
+ )
g2 = g1.clone()
p2.grad = g2
- current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
- g1 = (g1.float()*gnorm_scale).to(gtype)
+ current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
+ g1, gnorm_vec, step, 5
+ )
+ g1 = (g1.float() * gnorm_scale).to(gtype)
p1.grad = g1
adam1.step()
@@ -302,47 +446,77 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
# gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
if optim_bits == 32:
torch.testing.assert_allclose(p1, p2)
- torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=5e-5, rtol=1e-4)
- torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=5e-5, rtol=1e-4)
+ torch.testing.assert_allclose(
+ adam1.state[p1]["state1"],
+ adam2.state[p2]["state1"],
+ atol=5e-5,
+ rtol=1e-4,
+ )
+ torch.testing.assert_allclose(
+ adam1.state[p1]["state2"],
+ adam2.state[p2]["state2"],
+ atol=5e-5,
+ rtol=1e-4,
+ )
elif optim_bits == 8:
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
- torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=2, rtol=1e-3)
- torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=2, rtol=1e-3)
- adam1.state[p1]['state1'].copy_(adam2.state[p2]['state1'])
- adam1.state[p1]['state2'].copy_(adam2.state[p2]['state2'])
+ torch.testing.assert_allclose(
+ adam1.state[p1]["state1"],
+ adam2.state[p2]["state1"],
+ atol=2,
+ rtol=1e-3,
+ )
+ torch.testing.assert_allclose(
+ adam1.state[p1]["state2"],
+ adam2.state[p2]["state2"],
+ atol=2,
+ rtol=1e-3,
+ )
+ adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
+ adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
if i % 10 == 0 and i > 0:
path = get_temp_dir()
- torch.save(adam2.state_dict(),join(path, 'opt.pt'))
+ torch.save(adam2.state_dict(), join(path, "opt.pt"))
del adam2
adam2 = None
- adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
- adam2.load_state_dict(torch.load(join(path, 'opt.pt')))
-
-
+ adam2 = bnb.optim.Adam(
+ [p2],
+ lr,
+ (beta1, beta2),
+ eps,
+ optim_bits=optim_bits,
+ percentile_clipping=5,
+ )
+ adam2.load_state_dict(torch.load(join(path, "opt.pt")))
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
-#optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
-#optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
-#optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
-#optimizer_names = ['lamb_apex', 'lamb8bit']
-#optimizer_names = ['lars_apex', 'lars8bit']
-optimizer_names = ['adam8bit_blockwise']
-values = list(product(dim1,dim2, gtype, optimizer_names))
-names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
+# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
+# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
+# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
+# optimizer_names = ['lamb_apex', 'lamb8bit']
+# optimizer_names = ['lars_apex', 'lars8bit']
+optimizer_names = ["adam8bit_blockwise"]
+values = list(product(dim1, dim2, gtype, optimizer_names))
+names = [
+ "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
- if dim1 == 1 and dim2 == 1: return
- p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
+ if dim1 == 1 and dim2 == 1:
+ return
+ p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
bnb_optimizer = str2optimizers[optim_name][1]([p1])
- g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
+ g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g
for i in range(k):
- if i == k//5:
+ if i == k // 5:
# 100 iterations for burn-in
torch.cuda.synchronize()
t0 = time.time()
@@ -350,10 +524,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch.cuda.synchronize()
- s = time.time()-t0
- print('')
- params = (k-k//5)*dim1*dim2
- print(optim_name, gtype, s/params)
- #assert s < 3.9
-
-
+ s = time.time() - t0
+ print("")
+ params = (k - k // 5) * dim1 * dim2
+ print(optim_name, gtype, s / params)
+ # assert s < 3.9
diff --git a/to_be_fixed__complaints_by_linter.log b/to_be_fixed__complaints_by_linter.log
new file mode 100644
index 0000000..d696729
--- /dev/null
+++ b/to_be_fixed__complaints_by_linter.log
@@ -0,0 +1,149 @@
+./setup.py:20:10: F541 f-string is missing placeholders
+./setup.py:21:13: F541 f-string is missing placeholders
+./quicktest.py:5:1: F401 'bitsandbytes as bnb' imported but unused
+./bitsandbytes/cuda_setup.py:42:56: F821 undefined name 'error_str'
+./bitsandbytes/cuda_setup.py:43:15: F541 f-string is missing placeholders
+./bitsandbytes/cuda_setup.py:67:5: F841 local variable 'context' is assigned to but never used
+./bitsandbytes/cuda_setup.py:68:5: F841 local variable 'error_str' is assigned to but never used
+./bitsandbytes/cuda_setup.py:76:9: F841 local variable 'result' is assigned to but never used
+./bitsandbytes/cuda_setup.py:144:13: F841 local variable 'has_gpu' is assigned to but never used
+./bitsandbytes/functional.py:294:13: F821 undefined name 'math'
+./bitsandbytes/functional.py:295:16: F821 undefined name 'math'
+./bitsandbytes/functional.py:303:5: F841 local variable 'ptrA' is assigned to but never used
+./bitsandbytes/functional.py:304:5: F841 local variable 'ptrOut' is assigned to but never used
+./bitsandbytes/functional.py:1057:17: W503 line break before binary operator
+./bitsandbytes/functional.py:1058:17: W503 line break before binary operator
+./bitsandbytes/functional.py:1059:17: W503 line break before binary operator
+./bitsandbytes/functional.py:1649:1: F811 redefinition of unused 'get_special_format_str' from line 160
+./bitsandbytes/functional.py:1687:5: F841 local variable 'ptrA' is assigned to but never used
+./bitsandbytes/functional.py:1688:5: F841 local variable 'ptrOut' is assigned to but never used
+./bitsandbytes/functional.py:1802:5: F841 local variable 'ccolsA' is assigned to but never used
+./bitsandbytes/functional.py:1805:5: F841 local variable 'cldb' is assigned to but never used
+./bitsandbytes/functional.py:1806:5: F841 local variable 'cldc' is assigned to but never used
+./bitsandbytes/functional.py:1873:9: F841 local variable 'dtype' is assigned to but never used
+./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.MatmulLtState' imported but unused
+./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.bmm_cublas' imported but unused
+./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.matmul' imported but unused
+./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.matmul_cublas' imported but unused
+./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.mm_cublas' imported but unused
+./bitsandbytes/__init__.py:9:1: F401 '.nn.modules' imported but unused
+./bitsandbytes/__init__.py:12:5: F401 '.optim.adam' imported but unused
+./bitsandbytes/autograd/_functions.py:5:1: F401 'bitsandbytes as bnb' imported but unused
+./bitsandbytes/autograd/_functions.py:12:75: W291 trailing whitespace
+./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Int8Params' imported but unused
+./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bit' imported but unused
+./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bitLt' imported but unused
+./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.StableEmbedding' imported but unused
+./bitsandbytes/nn/modules.py:5:1: F401 'typing.Any' imported but unused
+./bitsandbytes/nn/modules.py:5:1: F401 'typing.Callable' imported but unused
+./bitsandbytes/nn/modules.py:5:1: F401 'typing.Dict' imported but unused
+./bitsandbytes/nn/modules.py:5:1: F401 'typing.Iterator' imported but unused
+./bitsandbytes/nn/modules.py:5:1: F401 'typing.Mapping' imported but unused
+./bitsandbytes/nn/modules.py:5:1: F401 'typing.Set' imported but unused
+./bitsandbytes/nn/modules.py:5:1: F401 'typing.Tuple' imported but unused
+./bitsandbytes/nn/modules.py:11:1: F401 'torch.nn.parameter.Parameter' imported but unused
+./bitsandbytes/nn/modules.py:183:13: W503 line break before binary operator
+./bitsandbytes/nn/modules.py:184:13: W503 line break before binary operator
+./bitsandbytes/nn/modules.py:272:24: F821 undefined name 'dist'
+./bitsandbytes/nn/modules.py:272:49: F821 undefined name 'dist'
+./bitsandbytes/optim/optimizer.py:243:9: F841 local variable 'overflows' is assigned to but never used
+./bitsandbytes/optim/optimizer.py:280:35: F541 f-string is missing placeholders
+./bitsandbytes/optim/optimizer.py:283:35: F541 f-string is missing placeholders
+./bitsandbytes/optim/lars.py:27:39: F541 f-string is missing placeholders
+./bitsandbytes/optim/lars.py:59:39: F541 f-string is missing placeholders
+./bitsandbytes/optim/lars.py:91:39: F541 f-string is missing placeholders
+./bitsandbytes/optim/lars.py:157:13: F841 local variable 'params_with_grad' is assigned to but never used
+./bitsandbytes/optim/lars.py:158:13: F841 local variable 'd_p_list' is assigned to but never used
+./bitsandbytes/optim/lars.py:159:13: F841 local variable 'momentum_buffer_list' is assigned to but never used
+./bitsandbytes/optim/lars.py:174:35: F821 undefined name 'param'
+./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam' imported but unused
+./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam8bit' imported but unused
+./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam32bit' imported but unused
+./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW' imported but unused
+./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW8bit' imported but unused
+./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW32bit' imported but unused
+./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD' imported but unused
+./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD8bit' imported but unused
+./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD32bit' imported but unused
+./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS' imported but unused
+./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS8bit' imported but unused
+./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS32bit' imported but unused
+./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.PytorchLARS' imported but unused
+./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB' imported but unused
+./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB8bit' imported but unused
+./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB32bit' imported but unused
+./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop' imported but unused
+./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop8bit' imported but unused
+./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop32bit' imported but unused
+./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad' imported but unused
+./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad8bit' imported but unused
+./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad32bit' imported but unused
+./bitsandbytes/optim/__init__.py:17:1: F401 '.optimizer.GlobalOptimManager' imported but unused
+./bitsandbytes/optim/adam.py:229:21: F841 local variable 'max_exp_avg_sq' is assigned to but never used
+./bitsandbytes/optim/rmsprop.py:25:39: F541 f-string is missing placeholders
+./bitsandbytes/optim/rmsprop.py:27:39: F541 f-string is missing placeholders
+./bitsandbytes/optim/rmsprop.py:59:39: F541 f-string is missing placeholders
+./bitsandbytes/optim/rmsprop.py:61:39: F541 f-string is missing placeholders
+./bitsandbytes/optim/rmsprop.py:94:39: F541 f-string is missing placeholders
+./bitsandbytes/optim/rmsprop.py:96:39: F541 f-string is missing placeholders
+./bitsandbytes/optim/sgd.py:24:39: F541 f-string is missing placeholders
+./bitsandbytes/optim/sgd.py:55:39: F541 f-string is missing placeholders
+./bitsandbytes/optim/sgd.py:86:39: F541 f-string is missing placeholders
+./tests/test_optim.py:1:1: F401 'ctypes' imported but unused
+./tests/test_optim.py:199:5: F841 local variable 'mask' is assigned to but never used
+./tests/test_optim.py:218:9: F841 local variable 'atol' is assigned to but never used
+./tests/test_optim.py:218:15: F841 local variable 'rtol' is assigned to but never used
+./tests/test_optim.py:304:17: W503 line break before binary operator
+./tests/test_optim.py:354:21: W503 line break before binary operator
+./tests/test_autograd.py:309:13: F841 local variable 'err' is assigned to but never used
+./tests/test_cuda_setup_evaluator.py:31:9: F821 undefined name 'test_dir'
+./tests/test_cuda_setup_evaluator.py:33:14: F821 undefined name 'test_input'
+./tests/test_cuda_setup_evaluator.py:81:32: E203 whitespace before ':'
+./tests/test_functional.py:55:13: F841 local variable 'ms' is assigned to but never used
+./tests/test_functional.py:177:5: F841 local variable 'diffs' is assigned to but never used
+./tests/test_functional.py:178:5: F841 local variable 'reldiffs' is assigned to but never used
+./tests/test_functional.py:260:5: F841 local variable 'minA' is assigned to but never used
+./tests/test_functional.py:261:5: F841 local variable 'maxA' is assigned to but never used
+./tests/test_functional.py:584:5: F841 local variable 'func' is assigned to but never used
+./tests/test_functional.py:617:17: F841 local variable 'offset' is assigned to but never used
+./tests/test_functional.py:618:17: F841 local variable 'col2' is assigned to but never used
+./tests/test_functional.py:619:17: F841 local variable 'row2' is assigned to but never used
+./tests/test_functional.py:705:9: F841 local variable 'C1' is assigned to but never used
+./tests/test_functional.py:706:9: F841 local variable 'C2' is assigned to but never used
+./tests/test_functional.py:715:9: F841 local variable 'output' is assigned to but never used
+./tests/test_functional.py:750:5: F841 local variable 'formatB' is assigned to but never used
+./tests/test_functional.py:754:5: F841 local variable 'w2' is assigned to but never used
+./tests/test_functional.py:763:5: F841 local variable 'dtype' is assigned to but never used
+./tests/test_functional.py:770:9: F841 local variable 'out1' is assigned to but never used
+./tests/test_functional.py:1108:5: F841 local variable 'relerr1' is assigned to but never used
+./tests/test_functional.py:1108:14: F841 local variable 'relerr2' is assigned to but never used
+./tests/test_functional.py:1114:9: F841 local variable 'C1' is assigned to but never used
+./tests/test_functional.py:1135:9: F841 local variable 'C4' is assigned to but never used
+./tests/test_functional.py:1179:5: F841 local variable 'err1' is assigned to but never used
+./tests/test_functional.py:1179:11: F841 local variable 'err2' is assigned to but never used
+./tests/test_functional.py:1179:17: F841 local variable 'err3' is assigned to but never used
+./tests/test_functional.py:1180:5: F841 local variable 'relerr1' is assigned to but never used
+./tests/test_functional.py:1180:14: F841 local variable 'relerr2' is assigned to but never used
+./tests/test_functional.py:1192:9: F841 local variable 'C1' is assigned to but never used
+./tests/test_functional.py:1313:9: F841 local variable 'c' is assigned to but never used
+./tests/test_functional.py:1314:9: F841 local variable 'c2' is assigned to but never used
+./tests/test_functional.py:1406:9: F841 local variable 'C1' is assigned to but never used
+./tests/test_functional.py:1425:9: F841 local variable 'out2' is assigned to but never used
+./tests/test_functional.py:1542:5: F841 local variable 'idx_col' is assigned to but never used
+./tests/test_functional.py:1566:30: E203 whitespace before ':'
+./tests/test_functional.py:1568:38: E203 whitespace before ':'
+./tests/test_functional.py:1655:5: F841 local variable 'offset' is assigned to but never used
+./tests/test_functional.py:1706:9: F841 local variable 'out' is assigned to but never used
+./tests/test_functional.py:1822:9: F841 local variable 'out' is assigned to but never used
+./tests/test_functional.py:1882:5: F841 local variable 'out2' is assigned to but never used
+./tests/test_functional.py:1928:9: F841 local variable 'dtype' is assigned to but never used
+./tests/test_functional.py:1934:9: F841 local variable 'minx' is assigned to but never used
+./tests/test_functional.py:1948:5: F841 local variable 'C0' is assigned to but never used
+./tests/test_modules.py:1:1: F401 'itertools.product' imported but unused
+./tests/test_modules.py:52:9: F841 local variable 'norm' is assigned to but never used
+./tests/test_modules.py:52:16: F821 undefined name 'math'
+./tests/test_modules.py:52:26: F821 undefined name 'math'
+./tests/test_modules.py:52:37: F821 undefined name 'math'
+./tests/test_modules.py:177:21: F821 undefined name 'einops'
+./tests/test_modules.py:233:9: F841 local variable 'stochastic' is assigned to but never used
+./tests/test_modules.py:382:9: F841 local variable 'o1' is assigned to but never used