author     Ben Sima <ben@bsima.me>    2023-08-28 21:05:25 -0400
committer  Ben Sima <ben@bsima.me>    2023-09-20 17:56:12 -0400
commit     6e4a65579c3ade76feea0890072099f0d0caf416 (patch)
tree       95671321c951134753323978854cece5f7d5435b
parent     13added53bbf996ec25a19b734326a6834918279 (diff)
Prototype Mynion
This implements a prototype Mynion, my chatbot who will eventually help
me write code here. In fact, he's already helping me, and he works
pretty well over XMPP.
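
For context: Mynion is just a slixmpp ClientXMPP subclass (see
Biz/Mynion.py below) that hooks the session_start and message events;
all of the exllama generation happens inside the message handler. A
minimal sketch of that XMPP plumbing, with the model swapped out for an
echo reply (the JID and password are placeholders, not real accounts):

    import typing

    import slixmpp


    class EchoBot(slixmpp.ClientXMPP):
        """Same event wiring as Mynion, minus the model."""

        def __init__(self, jid: str, password: str) -> None:
            slixmpp.ClientXMPP.__init__(self, jid, password)
            self.add_event_handler("session_start", self.session_start)
            self.add_event_handler("message", self.message)

        def session_start(self, event: typing.Any) -> None:
            # announce presence and fetch the roster, as Mynion does
            self.send_presence()
            self.get_roster()  # type: ignore

        def message(self, msg: slixmpp.Message) -> None:
            if msg["type"] in ("chat", "normal"):
                # Mynion feeds msg["body"] to the model here; this just echoes
                msg.reply("you said: " + msg["body"]).send()


    if __name__ == "__main__":
        bot = EchoBot("mynion@example.com", "changeme")
        bot.connect()
        bot.process(forever=True)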
The prompt is not checked in yet because I'm still experimenting with
it a lot, and it should probably be a runtime parameter anyway.
In the course of writing this I added some helper libraries to get me
going, configured black (I didn't even know that was possible), and
added 'outlines' and its dependencies, even though I didn't end up
using it. I'm not sure how useful outlines really is, because as far as
I can tell it just pre-defines some stop conditions, but it took a
while to get working, so I'll keep it in for now.
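
The helper libraries are Biz/Log.py, which copies the log format of
Biz/Log.hs, and Biz/Repl.py, which Biz/Ide/repl now runs for Python
targets ('python -i Biz/Repl.py MODULE PATH' loads the namespace and
binds r() to reload it after edits). A small sketch of how the Log
helper is used (the messages are made up; Mynion.main just calls
Biz.Log.setup()):

    import logging

    import Biz.Log

    Biz.Log.setup()
    logging.info("connected to xmpp")  # prints "info: root: connected to xmpp"
    logging.debug("loaded prompt")     # prints "dbug: root: loaded prompt"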
-rw-r--r--   .gitignore                       2
-rw-r--r--   Biz/Bild.hs                     94
-rw-r--r--   Biz/Bild/Builder.nix            12
-rw-r--r--   Biz/Bild/Deps.nix                3
-rw-r--r--   Biz/Bild/Deps/interegular.nix   26
-rw-r--r--   Biz/Bild/Deps/outlines.nix      49
-rw-r--r--   Biz/Bild/Deps/perscache.nix     39
-rw-r--r--   Biz/Bild/Sources.json           36
-rwxr-xr-x   Biz/Dragons/main.py             16
-rwxr-xr-x   Biz/Ide/repl                     6
-rwxr-xr-x   Biz/Ide/tidy                     2
-rwxr-xr-x   Biz/Ide/tips                     1
-rw-r--r--   Biz/Lint.hs                      2
-rw-r--r--   Biz/Log.py                      32
-rw-r--r--   Biz/Mynion.py                  246
-rw-r--r--   Biz/Namespace.hs                 4
-rwxr-xr-x   Biz/Que/Client.py                4
-rw-r--r--   Biz/Repl.py                     36
-rw-r--r--   pyproject.toml                   6
19 files changed, 572 insertions, 44 deletions
@@ -15,3 +15,5 @@ tags
 dist*
 .envrc.local
 .direnv/
+__pycache__
+Biz/Mynion/Prompt.md
diff --git a/Biz/Bild.hs b/Biz/Bild.hs
index 9c4f035..22d3882 100644
--- a/Biz/Bild.hs
+++ b/Biz/Bild.hs
@@ -312,10 +312,10 @@ data Target = Target
   deriving (Show, Generic, Aeson.ToJSON)

 -- | Use this to just get a target to play with at the repl.
-dev_getTarget :: IO Target
-dev_getTarget = do
+dev_getTarget :: FilePath -> IO Target
+dev_getTarget fp = do
   root <- Env.getEnv "BIZ_ROOT"
-  path <- Dir.makeAbsolute "Biz/Bild.hs"
+  path <- Dir.makeAbsolute fp
   Namespace.fromPath root path |> \case
     Nothing -> panic "Could not get namespace from path"
@@ -448,34 +448,35 @@ analyze hmap ns = case Map.lookup ns hmap of
       Namespace.Md -> pure Nothing
       Namespace.None -> pure Nothing
       Namespace.Py ->
-        Meta.detectAll "#" contentLines |> \Meta.Parsed {..} ->
-          Target
-            { builder = "python",
-              wrapper = Nothing,
-              compiler = CPython,
-              compilerFlags =
-                -- This doesn't really make sense for python, but I'll leave
-                -- it here for eventual --dev builds
-                [ "-c",
-                  "\"import py_compile;import os;"
-                    <> "py_compile.compile(file='"
-                    <> str quapath
-                    <> "', cfile=os.getenv('BIZ_ROOT')+'/_/int/"
-                    <> str quapath
-                    <> "', doraise=True)\""
-                ],
-              sysdeps = psys,
-              langdeps = pdep,
-              outPath = outToPath pout,
-              out = pout,
-              -- implement detectPythonImports, then I can fill this out
-              srcs = Set.empty,
-              packageSet = "python.packages",
-              mainModule = Namespace.toModule namespace,
-              ..
-            }
-            |> Just
-            |> pure
+        contentLines
+          |> Meta.detectAll "#"
+          |> \Meta.Parsed {..} ->
+            detectPythonImports contentLines +> \srcs ->
+              Target
+                { builder = "python",
+                  wrapper = Nothing,
+                  compiler = CPython,
+                  compilerFlags =
+                    -- This doesn't really make sense for python, but I'll leave
+                    -- it here for eventual --dev builds
+                    [ "-c",
+                      "\"import py_compile;import os;"
+                        <> "py_compile.compile(file='"
+                        <> str quapath
+                        <> "', cfile=os.getenv('BIZ_ROOT')+'/_/int/"
+                        <> str quapath
+                        <> "', doraise=True)\""
+                    ],
+                  sysdeps = psys,
+                  langdeps = pdep,
+                  outPath = outToPath pout,
+                  out = pout,
+                  packageSet = "python.packages",
+                  mainModule = Namespace.toModule namespace,
+                  ..
+                }
+                |> Just
+                |> pure
       Namespace.Sh -> pure Nothing
       Namespace.C ->
         Meta.detectAll "//" contentLines |> \Meta.Parsed {..} -> do
@@ -713,6 +714,27 @@ detectLispImports contentLines =
     |> Set.fromList
     |> pure

+-- | Finds local imports. Does not recurse to find transitive imports like
+-- 'detectHaskellImports' does. Someday I will refactor these detection
+-- functions and have a common, well-performing, complete solution.
+detectPythonImports :: [Text] -> IO (Set FilePath)
+detectPythonImports contentLines =
+  contentLines
+    /> Text.unpack
+    /> Regex.match pythonImport
+    |> catMaybes
+    /> Namespace.fromPythonModule
+    /> Namespace.toPath
+    |> filterM Dir.doesPathExist
+    /> Set.fromList
+  where
+    -- only detects 'import x' because I don't like 'from'
+    pythonImport :: Regex.RE Char String
+    pythonImport =
+      Regex.string "import"
+        *> Regex.some (Regex.psym Char.isSpace)
+        *> Regex.many (Regex.psym isModuleChar)
+
 ghcPkgFindModule :: Set String -> String -> IO (Set String)
 ghcPkgFindModule acc m =
   Env.getEnv "GHC_PACKAGE_PATH" +> \packageDb ->
@@ -755,9 +777,13 @@ build andTest loud analysis =
   Env.getEnv "BIZ_ROOT" +> \root ->
     forM (Map.elems analysis) <| \target@Target {..} ->
       fst </ case compiler of
-        CPython ->
-          Log.info ["bild", "nix", "python", nschunk namespace]
-            >> nixBuild loud target
+        CPython -> case out of
+          Meta.Bin _ ->
+            Log.info ["bild", "nix", "python", nschunk namespace]
+              >> nixBuild loud target
+          _ ->
+            Log.info ["bild", "nix", "python", nschunk namespace, "cannot build library"]
+              >> pure (Exit.ExitSuccess, mempty)
         Gcc ->
           Log.info ["bild", label, "gcc", nschunk namespace]
             >> nixBuild loud target
diff --git a/Biz/Bild/Builder.nix b/Biz/Bild/Builder.nix
index 8f42733..f9eb31d 100644
--- a/Biz/Bild/Builder.nix
+++ b/Biz/Bild/Builder.nix
@@ -92,16 +92,22 @@ let
   python = bild.python.buildPythonApplication rec {
     inherit name src BIZ_ROOT;
-    propagatedBuildInputs = [ (bild.python.pythonWith (_: langdeps_)) ] ++ sysdeps_;
+    propagatedBuildInputs = langdeps_ ++ sysdeps_;
     buildInputs = sysdeps_;
-    checkInputs = [(bild.python.pythonWith (p: with p; [black mypy])) ruff];
+    nativeCheckInputs = [ black mypy ruff ];
     checkPhase = ''
       check() {
        $@ || { echo "fail: $name: $3"; exit 1; }
      }
+      cp ${../../pyproject.toml} ./pyproject.toml
       check python -m black --quiet --exclude 'setup\.py$' --check .
       check ${ruff}/bin/ruff check .
-      check python -m mypy --strict --no-error-summary --exclude 'setup\.py$' .
+      touch ./py.typed
+      check python -m mypy \
+        --explicit-package-bases \
+        --no-error-summary \
+        --exclude 'setup\.py$' \
+        .
       check python -m ${mainModule} test
     '';
     preBuild = ''
diff --git a/Biz/Bild/Deps.nix b/Biz/Bild/Deps.nix
index da18d89..dcb7d50 100644
--- a/Biz/Bild/Deps.nix
+++ b/Biz/Bild/Deps.nix
@@ -36,6 +36,9 @@ in rec
       exllama = callPackage ./Deps/exllama.nix {
         cudaPackages = super.pkgs.cudaPackages_11_7;
       };
+      interegular = callPackage ./Deps/interegular.nix {};
+      outlines = callPackage ./Deps/outlines.nix {};
+      perscache = callPackage ./Deps/perscache.nix {};
     };
   };
diff --git a/Biz/Bild/Deps/interegular.nix b/Biz/Bild/Deps/interegular.nix
new file mode 100644
index 0000000..8b0bc86
--- /dev/null
+++ b/Biz/Bild/Deps/interegular.nix
@@ -0,0 +1,26 @@
+{ lib
+, sources
+, buildPythonPackage
+}:
+
+buildPythonPackage rec {
+  pname = "interegular";
+  version = sources.interegular.rev;
+  format = "setuptools";
+
+  src = sources.interegular;
+
+  propagatedBuildInputs = [ ];
+
+  doCheck = false; # no tests currently
+  pythonImportsCheck = [
+    "interegular"
+  ];
+
+  meta = with lib; {
+    description = "Allows to check regexes for overlaps.";
+    homepage = "https://github.com/MegaIng/interegular";
+    license = licenses.mit;
+    maintainers = with maintainers; [ bsima ];
+  };
+}
diff --git a/Biz/Bild/Deps/outlines.nix b/Biz/Bild/Deps/outlines.nix
new file mode 100644
index 0000000..013581b
--- /dev/null
+++ b/Biz/Bild/Deps/outlines.nix
@@ -0,0 +1,49 @@
+{ lib
+, sources
+, buildPythonPackage
+, interegular
+, jinja2
+, lark
+, numpy
+, perscache
+, pillow
+, pydantic
+, regex
+, scipy
+, tenacity
+, torch
+}:
+
+buildPythonPackage rec {
+  pname = "outlines";
+  version = sources.outlines.rev;
+  format = "pyproject";
+
+  src = sources.outlines;
+
+  propagatedBuildInputs = [
+    interegular
+    jinja2
+    lark
+    numpy
+    perscache
+    pillow
+    pydantic
+    regex
+    scipy
+    tenacity
+    torch
+  ];
+
+  doCheck = false; # no tests currently
+  pythonImportsCheck = [
+    "outlines"
+  ];
+
+  meta = with lib; {
+    description = "Probabilistic Generative Model Programming";
+    homepage = "https://github.com/normal-computing/outlines";
+    license = licenses.asl20;
+    maintainers = with maintainers; [ bsima ];
+  };
+}
diff --git a/Biz/Bild/Deps/perscache.nix b/Biz/Bild/Deps/perscache.nix
new file mode 100644
index 0000000..d757e1a
--- /dev/null
+++ b/Biz/Bild/Deps/perscache.nix
@@ -0,0 +1,39 @@
+{ lib
+, sources
+, buildPythonPackage
+, beartype
+, cloudpickle
+, icontract
+, pbr
+}:
+
+buildPythonPackage rec {
+  pname = "perscache";
+  version = sources.perscache.rev;
+
+  src = sources.perscache;
+
+  propagatedBuildInputs = [
+    beartype
+    cloudpickle
+    icontract
+    pbr
+  ];
+  PBR_VERSION = version;
+
+  doCheck = false; # no tests currently
+  pythonImportsCheck = [
+    "perscache"
+  ];
+
+  meta = with lib; {
+    description = ''
+      An easy to use decorator for persistent memoization: like
+      `functools.lrucache`, but results can be saved in any format to any
+      storage.
+    '';
+    homepage = "https://github.com/leshchenko1979/perscache";
+    license = licenses.mit;
+    maintainers = with maintainers; [ bsima ];
+  };
+}
diff --git a/Biz/Bild/Sources.json b/Biz/Bild/Sources.json
index e4fcfd4..6cc4d48 100644
--- a/Biz/Bild/Sources.json
+++ b/Biz/Bild/Sources.json
@@ -64,6 +64,18 @@
         "url_template": "https://gitlab.com/kavalogic-inc/inspekt3d/-/archive/<version>/inspekt3d-<version>.tar.gz",
         "version": "703f52ccbfedad2bf5240bf8183d1b573c9d54ef"
     },
+    "interegular": {
+        "branch": "master",
+        "description": "Allows to check regexes for overlaps. Based on greenery by @qntm.",
+        "homepage": null,
+        "owner": "MegaIng",
+        "repo": "interegular",
+        "rev": "v0.2.1",
+        "sha256": "14f3jvnczq6qay2qp4rxchbdhkj00qs8kpacl0nrxgr0785km36k",
+        "type": "tarball",
+        "url": "https://github.com/MegaIng/interegular/archive/v0.2.1.tar.gz",
+        "url_template": "https://github.com/<owner>/<repo>/archive/<rev>.tar.gz"
+    },
     "llama-cpp": {
         "branch": "master",
         "description": "Port of Facebook's LLaMA model in C/C++",
@@ -110,6 +122,30 @@
         "url": "https://github.com/nixos/nixpkgs/archive/61676e4dcfeeb058f255294bcb08ea7f3bc3ce56.tar.gz",
         "url_template": "https://github.com/<owner>/<repo>/archive/<rev>.tar.gz"
     },
+    "outlines": {
+        "branch": "main",
+        "description": "Generative Model Programming",
+        "homepage": "https://normal-computing.github.io/outlines/",
+        "owner": "normal-computing",
+        "repo": "outlines",
+        "rev": "0.0.8",
+        "sha256": "1yvx5c5kplmr56nffqcb6ssjnmlikkaw32hxl6i4b607v3s0s6jv",
+        "type": "tarball",
+        "url": "https://github.com/normal-computing/outlines/archive/0.0.8.tar.gz",
+        "url_template": "https://github.com/<owner>/<repo>/archive/<rev>.tar.gz"
+    },
+    "perscache": {
+        "branch": "master",
+        "description": "An easy to use decorator for persistent memoization: like `functools.lrucache`, but results can be saved in any format to any storage.",
+        "homepage": null,
+        "owner": "leshchenko1979",
+        "repo": "perscache",
+        "rev": "0.6.1",
+        "sha256": "0j2775pjll4vw1wmxkjhnb5z6z83x5lhg89abj2d8ivd17n4rhjf",
+        "type": "tarball",
+        "url": "https://github.com/leshchenko1979/perscache/archive/0.6.1.tar.gz",
+        "url_template": "https://github.com/<owner>/<repo>/archive/<rev>.tar.gz"
+    },
     "regex-applicative": {
         "branch": "master",
         "description": "Regex-based parsing with applicative interface",
diff --git a/Biz/Dragons/main.py b/Biz/Dragons/main.py
index 7ec80bb..7a94f99 100755
--- a/Biz/Dragons/main.py
+++ b/Biz/Dragons/main.py
@@ -20,7 +20,9 @@ def find_user(line: str) -> typing.Any:
     return re.findall(r"^[^<]*", line)[0].strip()


-def authors_for(path: str, active_users: typing.List[str]) -> typing.Dict[str, str]:
+def authors_for(
+    path: str, active_users: typing.List[str]
+) -> typing.Dict[str, str]:
     """Return a dictionary of {author: commits} for given path.
     Usernames not in the 'active_users' list will be filtered out."""
     raw = subprocess.check_output(
@@ -71,7 +73,9 @@ def get_args() -> typing.Any:
     "Parse CLI arguments."
     cli = argparse.ArgumentParser(description=__doc__)
     cli.add_argument("test", action="store_true", help="run the test suite")
-    cli.add_argument("repo", default=".", help="the git repo to run on", metavar="REPO")
+    cli.add_argument(
+        "repo", default=".", help="the git repo to run on", metavar="REPO"
+    )
     cli.add_argument(
         "-b",
         "--blackholes",
@@ -138,7 +142,9 @@ class Repo:
     ) -> None:
         self.paths = [
             p
-            for p in subprocess.check_output(["git", "ls-files", "--no-deleted"])
+            for p in subprocess.check_output(
+                ["git", "ls-files", "--no-deleted"]
+            )
             .decode("utf-8")
             .split()
             if not any(i in p for i in ignored_paths)
@@ -147,7 +153,9 @@ class Repo:
         self.stats = {}
         for path in self.paths:
             self.stats[path] = authors_for(path, active_users)
-        self.blackholes = [path for path, authors in self.stats.items() if not authors]
+        self.blackholes = [
+            path for path, authors in self.stats.items() if not authors
+        ]
         self.liabilities = {
             path: list(authors)
             for path, authors in self.stats.items()
diff --git a/Biz/Ide/repl b/Biz/Ide/repl
index 1d94e47..1401218 100755
--- a/Biz/Ide/repl
+++ b/Biz/Ide/repl
@@ -30,6 +30,7 @@ fi
 sysdeps=$(jq --raw-output '.[].sysdeps | join(" ")' <<< $json)
 exts=$(jq --raw-output '.[].namespace.ext' <<< $json | sort | uniq)
 packageSet=$(jq --raw-output '.[].packageSet' <<< $json)
+module=$(jq --raw-output '.[].mainModule' <<< $json)
 BILD="(import ${BIZ_ROOT:?}/Biz/Bild.nix {})"
 for lib in ${sysdeps[@]}; do
   flags+=(--packages "$BILD.pkgs.${lib}")
@@ -64,8 +65,11 @@ fi
       ;;
     Py)
       langdeps="$langdeps mypy"
+      flags+=(--packages ruff)
      flags+=(--packages "$BILD.bild.python.pythonWith (p: with p; [$langdeps])")
-      command=${CMD:-"python -i $targets"}
+      PYTHONPATH=$BIZ_ROOT:$PYTHONPATH
+      pycommand="python -i $BIZ_ROOT/Biz/Repl.py $module ${targets[@]}"
+      command=${CMD:-"$pycommand"}
       ;;
     *)
       echo "unsupported targets: ${targets[@]}"
diff --git a/Biz/Ide/tidy b/Biz/Ide/tidy
new file mode 100755
index 0000000..edea828
--- /dev/null
+++ b/Biz/Ide/tidy
@@ -0,0 +1,2 @@
+#!/usr/bin/env bash
+rm -f $BIZ_ROOT/_/bin/*
diff --git a/Biz/Ide/tips b/Biz/Ide/tips
index 1b998e6..21808eb 100755
--- a/Biz/Ide/tips
+++ b/Biz/Ide/tips
@@ -9,4 +9,5 @@ echo " tips show this message"
 echo " lint auto-lint all changed files"
 echo " push send a namespace to the cloud"
 echo " ship lint, bild, and push one (or all) namespace(s)"
+echo " tidy cleanup common working files"
 echo ""
diff --git a/Biz/Lint.hs b/Biz/Lint.hs
index 5c3bef3..bc91f34 100644
--- a/Biz/Lint.hs
+++ b/Biz/Lint.hs
@@ -115,7 +115,7 @@ printResult r = case r of

 changedFiles :: IO [FilePath]
 changedFiles =
-  git ["merge-base", "HEAD", "origin/master"]
+  git ["merge-base", "HEAD", "origin/live"]
     /> filter (/= '\n')
    +> (\mb -> git ["diff", "--name-only", "--diff-filter=d", mb])
    /> String.lines
diff --git a/Biz/Log.py b/Biz/Log.py
new file mode 100644
index 0000000..af28c41
--- /dev/null
+++ b/Biz/Log.py
@@ -0,0 +1,32 @@
+"""
+Setup logging like Biz/Log.hs.
+"""
+
+import logging
+import typing
+
+
+class LowerFormatter(logging.Formatter):
+    def format(self, record: typing.Any) -> typing.Any:
+        record.levelname = record.levelname.lower()
+        return super(logging.Formatter, self).format(record)  # type: ignore
+
+
+def setup() -> None:
+    "Run this in your __main__ function"
+    logging.basicConfig(
+        level=logging.DEBUG, format="%(levelname)s: %(name)s: %(message)s"
+    )
+    logging.addLevelName(logging.DEBUG, "dbug")
+    logging.addLevelName(logging.ERROR, "fail")
+    logging.addLevelName(logging.INFO, "info")
+    logger = logging.getLogger(__name__)
+    formatter = LowerFormatter()
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+
+if __name__ == "__main__":
+    setup()
+    logging.debug("i am doing testing")
diff --git a/Biz/Mynion.py b/Biz/Mynion.py
new file mode 100644
index 0000000..6bb55e1
--- /dev/null
+++ b/Biz/Mynion.py
@@ -0,0 +1,246 @@
+# : out mynion
+#
+# : dep exllama
+# : dep slixmpp
+import argparse
+import exllama  # type: ignore
+import Biz.Log
+import glob
+import logging
+import os
+import slixmpp
+import slixmpp.exceptions
+import sys
+import torch
+import typing
+
+
+def smoosh(s: str) -> str:
+    return s.replace("\n", " ")
+
+
+class Mynion(slixmpp.ClientXMPP):
+    def __init__(
+        self,
+        jid: str,
+        password: str,
+        model: exllama.model.ExLlama,
+        tokenizer: exllama.tokenizer.ExLlamaTokenizer,
+        generator: exllama.generator.ExLlamaGenerator,
+    ) -> None:
+        slixmpp.ClientXMPP.__init__(self, jid, password)
+        self.plugin.enable("xep_0085")  # type: ignore
+        self.plugin.enable("xep_0184")  # type: ignore
+
+        self.name = "mynion"
+        self.user = "ben"
+        self.first_round = True
+        self.min_response_tokens = 4
+        self.max_response_tokens = 256
+        self.extra_prune = 256
+        # todo: handle summary rollup when max_seq_len is reached
+        self.max_seq_len = 8000
+
+        self.model = model
+        self.tokenizer = tokenizer
+        self.generator = generator
+
+        root = os.getenv("BIZ_ROOT", "")
+        # this should be parameterized somehow
+        with open(os.path.join(root, "Biz", "Mynion", "Prompt.md")) as f:
+            txt = f.read()
+            txt = txt.format(user=self.user, name=self.name)
+
+        # this is the "system prompt", ideally i would load this in/out of a
+        # database with all of the past history. if the history gets too long, i
+        # can roll it up by asking llama to summarize it
+        self.past = smoosh(txt)
+
+        ids = tokenizer.encode(self.past)
+        self.generator.gen_begin(ids)
+
+        self.add_event_handler("session_start", self.session_start)
+        self.add_event_handler("message", self.message)
+
+    def session_start(self) -> None:
+        self.send_presence()
+        try:
+            self.get_roster()  # type: ignore
+        except slixmpp.exceptions.IqError as err:
+            logging.error("There was an error getting the roster")
+            logging.error(err.iq["error"]["condition"])
+            self.disconnect()
+        except slixmpp.exceptions.IqTimeout:
+            logging.error("Server is taking too long to respond")
+            self.disconnect()
+
+    def message(self, msg: slixmpp.Message) -> None:
+        if msg["type"] in ("chat", "normal"):
+            res_line = f"{self.name}: "
+            res_tokens = self.tokenizer.encode(res_line)
+            num_res_tokens = res_tokens.shape[-1]
+
+            if self.first_round:
+                in_tokens = res_tokens
+            else:
+                # read and format input
+                in_line = f"{self.user}: " + msg["body"].strip() + "\n"
+                in_tokens = self.tokenizer.encode(in_line)
+                in_tokens = torch.cat((in_tokens, res_tokens), dim=1)
+
+            # If we're approaching the context limit, prune some whole lines
+            # from the start of the context. Also prune a little extra so we
+            # don't end up rebuilding the cache on every line when up against
+            # the limit.
+            expect_tokens = in_tokens.shape[-1] + self.max_response_tokens
+            max_tokens = self.max_seq_len - expect_tokens
+            if self.generator.gen_num_tokens() >= max_tokens:
+                generator.gen_prune_to(
+                    self.max_seq_len - expect_tokens - self.extra_prune,
+                    self.tokenizer.newline_token_id,
+                )
+
+            # feed in the user input and "{self.name}:", tokenized
+            self.generator.gen_feed_tokens(in_tokens)
+
+            # start beam search?
+            self.generator.begin_beam_search()
+
+            # generate tokens, with streaming
+            # TODO: drop the streaming!
+            for i in range(self.max_response_tokens):
+                # disallowing the end condition tokens seems like a clean way to
+                # force longer replies
+                if i < self.min_response_tokens:
+                    self.generator.disallow_tokens(
+                        [
+                            self.tokenizer.newline_token_id,
+                            self.tokenizer.eos_token_id,
+                        ]
+                    )
+                else:
+                    self.generator.disallow_tokens(None)
+
+                # get a token
+                gen_token = self.generator.beam_search()
+
+                # if token is EOS, replace it with a newline before continuing
+                if gen_token.item() == self.tokenizer.eos_token_id:
+                    self.generator.replace_last_token(
+                        self.tokenizer.newline_token_id
+                    )
+
+                # decode the current line
+                num_res_tokens += 1
+                text = self.tokenizer.decode(
+                    self.generator.sequence_actual[:, -num_res_tokens:][0]
+                )
+
+                # append to res_line
+                res_line += text[len(res_line) :]
+
+                # end conditions
+                breakers = [
+                    self.tokenizer.eos_token_id,
+                    # self.tokenizer.newline_token_id,
+                ]
+                if gen_token.item() in breakers:
+                    break
+
+                # try to drop the "ben:" at the end
+                if res_line.endswith(f"{self.user}:"):
+                    logging.info("rewinding!")
+                    plen = self.tokenizer.encode(f"{self.user}:").shape[-1]
+                    self.generator.gen_rewind(plen)
+                    break
+
+            # end generation and send the reply
+            self.generator.end_beam_search()
+            res_line = res_line.removeprefix(f"{self.name}:")
+            res_line = res_line.removesuffix(f"{self.user}:")
+            self.first_round = False
+            msg.reply(res_line).send()  # type: ignore
+
+
+MY_MODELS = [
+    "Llama-2-13B-GPTQ",
+    "Nous-Hermes-13B-GPTQ",
+    "Nous-Hermes-Llama2-13b-GPTQ",
+    "Wizard-Vicuna-13B-Uncensored-GPTQ",
+    "Wizard-Vicuna-13B-Uncensored-SuperHOT-8K-GPTQ",
+    "Wizard-Vicuna-30B-Uncensored-GPTQ",
+    "CodeLlama-13B-Python-GPTQ",
+    "CodeLlama-13B-Instruct-GPTQ",
+    "CodeLlama-34B-Instruct-GPTQ",
+]
+
+
+def load_model(model_name: str) -> typing.Any:
+    assert model_name in MY_MODELS
+    if not torch.cuda.is_available():
+        raise ValueError("no cuda")
+        sys.exit(1)
+
+    torch.set_grad_enabled(False)
+    torch.cuda._lazy_init()  # type: ignore
+
+    ml_models = "/mnt/campbell/ben/ml-models"
+
+    model_dir = os.path.join(ml_models, model_name)
+
+    tokenizer_path = os.path.join(model_dir, "tokenizer.model")
+    config_path = os.path.join(model_dir, "config.json")
+    st_pattern = os.path.join(model_dir, "*.safetensors")
+    st = glob.glob(st_pattern)
+    if len(st) != 1:
+        print("found multiple safetensors!")
+        sys.exit()
+    model_path = st[0]
+
+    config = exllama.model.ExLlamaConfig(config_path)
+    config.model_path = model_path
+
+    # gpu split
+    config.set_auto_map("23")
+
+    model = exllama.model.ExLlama(config)
+    cache = exllama.model.ExLlamaCache(model)
+    tokenizer = exllama.tokenizer.ExLlamaTokenizer(tokenizer_path)
+
+    generator = exllama.generator.ExLlamaGenerator(model, tokenizer, cache)
+    generator.settings = exllama.generator.ExLlamaGenerator.Settings()
+
+    return (model, tokenizer, generator)
+
+
+def main(
+    model: exllama.model.ExLlama,
+    tokenizer: exllama.tokenizer.ExLlamaTokenizer,
+    generator: exllama.generator.ExLlamaGenerator,
+    user: str,
+    password: str,
+) -> None:
+    """
+    Start the chatbot.
+
+    This purposefully does not call 'load_model()' so that you can load the
+    model in the repl and then restart the chatbot without unloading it.
+    """
+    Biz.Log.setup()
+    xmpp = Mynion(user, password, model, tokenizer, generator)
+    xmpp.connect()
+    xmpp.process(forever=True)  # type: ignore
+
+
+if __name__ == "__main__":
+    if "test" in sys.argv:
+        print("pass: test: Biz/Mynion.py")
+        sys.exit(0)
+    else:
+        cli = argparse.ArgumentParser()
+        cli.add_argument("-u", "--user")
+        cli.add_argument("-p", "--password")
+        cli.add_argument("-m", "--model", choices=MY_MODELS)
+        args = cli.parse_args()
+        model, tokenizer, generator = load_model(args.model)
+        main(model, tokenizer, generator, args.user, args.password)
diff --git a/Biz/Namespace.hs b/Biz/Namespace.hs
index 9621186..48f6277 100644
--- a/Biz/Namespace.hs
+++ b/Biz/Namespace.hs
@@ -14,6 +14,7 @@ module Biz.Namespace
     fromHaskellModule,
     toHaskellModule,
     toSchemeModule,
+    fromPythonModule,
     isCab,
     groupByExt,
   )
@@ -105,6 +106,9 @@ fromHaskellModule s = Namespace (List.splitOn "." s) Hs

 toSchemeModule :: Namespace -> String
 toSchemeModule = toModule

+fromPythonModule :: String -> Namespace
+fromPythonModule s = Namespace (List.splitOn "." s) Py
+
 dot :: Regex.RE Char String
 dot = Regex.some <| Regex.sym '.'
diff --git a/Biz/Que/Client.py b/Biz/Que/Client.py
index ef6d6d2..20349b6 100755
--- a/Biz/Que/Client.py
+++ b/Biz/Que/Client.py
@@ -191,7 +191,9 @@ def get_args() -> argparse.Namespace:
             be used for serving a webpage or other file continuously"""
         ),
     )
-    cli.add_argument("target", help="namespace and path of the que, like 'ns/path'")
+    cli.add_argument(
+        "target", help="namespace and path of the que, like 'ns/path'"
+    )
     cli.add_argument(
         "infile",
         nargs="?",
diff --git a/Biz/Repl.py b/Biz/Repl.py
new file mode 100644
index 0000000..0732fae
--- /dev/null
+++ b/Biz/Repl.py
@@ -0,0 +1,36 @@
+"""
+This module attempts to emulate the workflow of ghci or lisp repls. It uses
+importlib to load a namespace from the given path. It then binds 'r()' to a
+function that reloads the same namespace.
+"""
+
+import importlib
+import sys
+
+
+def use(ns: str, path: str) -> None:
+    """
+    Load or reload the module named 'ns' from 'path'. Like `use` in the Guile
+    Scheme repl.
+    """
+    spec = importlib.util.spec_from_file_location(ns, path)
+    module = importlib.util.module_from_spec(spec)
+    # delete module and its imported names if its already loaded
+    if ns in sys.modules:
+        del sys.modules[ns]
+        for name in module.__dict__:
+            if name in globals():
+                del globals()[name]
+    sys.modules[ns] = module
+    spec.loader.exec_module(module)
+    names = [x for x in module.__dict__]
+    globals().update({k: getattr(module, k) for k in names})
+
+
+if __name__ == "__main__":
+    NS = sys.argv[1]
+    PATH = sys.argv[2]
+    use(NS, PATH)
+
+    def r():
+        return use(NS, PATH)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..7f7d0d3
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,6 @@
+[tool.black]
+line-length = 80
+
+[tool.mypy]
+strict = true
+implicit_reexport = true