summaryrefslogtreecommitdiff
path: root/Biz/Bild/Deps/exllama.nix
diff options
context:
space:
mode:
authorBen Sima <ben@bsima.me>2023-08-10 21:11:23 -0400
committerBen Sima <ben@bsima.me>2023-08-16 14:29:43 -0400
commit247678afc7c74c98f64e8d19f67355d128946974 (patch)
tree6bde2696aab9029f67ff6eb136f26b81bcd5a4c4 /Biz/Bild/Deps/exllama.nix
parent4e67ef22a7508150798413081bf8a5bb4adab6e5 (diff)
Add llama-cpp and exllama
Diffstat (limited to 'Biz/Bild/Deps/exllama.nix')
-rw-r--r--Biz/Bild/Deps/exllama.nix64
1 files changed, 64 insertions, 0 deletions
diff --git a/Biz/Bild/Deps/exllama.nix b/Biz/Bild/Deps/exllama.nix
new file mode 100644
index 0000000..54d6df1
--- /dev/null
+++ b/Biz/Bild/Deps/exllama.nix
@@ -0,0 +1,64 @@
+{ lib
+, sources
+, buildPythonPackage
+, pythonOlder
+, fetchFromGitHub
+, torch # tested on 2.0.1 and 2.1.0 (nightly) with cu118
+, safetensors
+, sentencepiece
+, ninja
+, cudaPackages
+, addOpenGLRunpath
+, which
+, gcc11 # cuda 11.7 requires g++ <12
+}:
+
+buildPythonPackage rec {
+ pname = "exllama";
+ version = sources.exllama.rev;
+ format = "setuptools";
+ disabled = pythonOlder "3.9";
+
+ src = sources.exllama;
+
+ # I only care about compiling for the Ampere architecture, which is what my
+ # RTX 3090 TI is, and for some reason (nix sandbox?) the torch extension
+ # builder
+ # cannot autodetect the arch
+ TORCH_CUDA_ARCH_LIST = "8.0;8.6+PTX";
+
+ CUDA_HOME = "${cudaPackages.cuda_nvcc}";
+
+ nativeBuildInputs = [
+ gcc11
+ which
+ addOpenGLRunpath
+ cudaPackages.cuda_nvcc
+ cudaPackages.cuda_cudart
+ ];
+
+ propagatedBuildInputs = [
+ torch safetensors sentencepiece ninja
+ cudaPackages.cudatoolkit
+ ];
+
+ doCheck = false; # no tests currently
+ pythonImportsCheck = [
+ "exllama"
+ "exllama.cuda_ext"
+ "exllama.generator"
+ "exllama.lora"
+ "exllama.model"
+ "exllama.tokenizer"
+ ];
+
+ meta = with lib; {
+ description = ''
+ A more memory-efficient rewrite of the HF transformers implementation of
+ Llama for use with quantized weights.
+ '';
+ homepage = "https://github.com/jllllll/exllama";
+ license = licenses.mit;
+ maintainers = with maintainers; [ bsima ];
+ };
+}