summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--  .gitignore                    |   2
-rw-r--r--  Biz/Bild.hs                   |  94
-rw-r--r--  Biz/Bild/Builder.nix          |  12
-rw-r--r--  Biz/Bild/Deps.nix             |   3
-rw-r--r--  Biz/Bild/Deps/interegular.nix |  26
-rw-r--r--  Biz/Bild/Deps/outlines.nix    |  49
-rw-r--r--  Biz/Bild/Deps/perscache.nix   |  39
-rw-r--r--  Biz/Bild/Sources.json         |  36
-rwxr-xr-x  Biz/Dragons/main.py           |  16
-rwxr-xr-x  Biz/Ide/repl                  |   6
-rwxr-xr-x  Biz/Ide/tidy                  |   2
-rwxr-xr-x  Biz/Ide/tips                  |   1
-rw-r--r--  Biz/Lint.hs                   |   2
-rw-r--r--  Biz/Log.py                    |  32
-rw-r--r--  Biz/Mynion.py                 | 246
-rw-r--r--  Biz/Namespace.hs              |   4
-rwxr-xr-x  Biz/Que/Client.py             |   4
-rw-r--r--  Biz/Repl.py                   |  36
-rw-r--r--  pyproject.toml                |   6
19 files changed, 572 insertions(+), 44 deletions(-)
diff --git a/.gitignore b/.gitignore
index e9beede..ad99c04 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,5 @@ tags
dist*
.envrc.local
.direnv/
+__pycache__
+Biz/Mynion/Prompt.md
diff --git a/Biz/Bild.hs b/Biz/Bild.hs
index 9c4f035..22d3882 100644
--- a/Biz/Bild.hs
+++ b/Biz/Bild.hs
@@ -312,10 +312,10 @@ data Target = Target
deriving (Show, Generic, Aeson.ToJSON)
-- | Use this to just get a target to play with at the repl.
-dev_getTarget :: IO Target
-dev_getTarget = do
+dev_getTarget :: FilePath -> IO Target
+dev_getTarget fp = do
root <- Env.getEnv "BIZ_ROOT"
- path <- Dir.makeAbsolute "Biz/Bild.hs"
+ path <- Dir.makeAbsolute fp
Namespace.fromPath root path
|> \case
Nothing -> panic "Could not get namespace from path"
@@ -448,34 +448,35 @@ analyze hmap ns = case Map.lookup ns hmap of
Namespace.Md -> pure Nothing
Namespace.None -> pure Nothing
Namespace.Py ->
- Meta.detectAll "#" contentLines |> \Meta.Parsed {..} ->
- Target
- { builder = "python",
- wrapper = Nothing,
- compiler = CPython,
- compilerFlags =
- -- This doesn't really make sense for python, but I'll leave
- -- it here for eventual --dev builds
- [ "-c",
- "\"import py_compile;import os;"
- <> "py_compile.compile(file='"
- <> str quapath
- <> "', cfile=os.getenv('BIZ_ROOT')+'/_/int/"
- <> str quapath
- <> "', doraise=True)\""
- ],
- sysdeps = psys,
- langdeps = pdep,
- outPath = outToPath pout,
- out = pout,
- -- implement detectPythonImports, then I can fill this out
- srcs = Set.empty,
- packageSet = "python.packages",
- mainModule = Namespace.toModule namespace,
- ..
- }
- |> Just
- |> pure
+ contentLines
+ |> Meta.detectAll "#"
+ |> \Meta.Parsed {..} ->
+ detectPythonImports contentLines +> \srcs ->
+ Target
+ { builder = "python",
+ wrapper = Nothing,
+ compiler = CPython,
+ compilerFlags =
+ -- This doesn't really make sense for python, but I'll leave
+ -- it here for eventual --dev builds
+ [ "-c",
+ "\"import py_compile;import os;"
+ <> "py_compile.compile(file='"
+ <> str quapath
+ <> "', cfile=os.getenv('BIZ_ROOT')+'/_/int/"
+ <> str quapath
+ <> "', doraise=True)\""
+ ],
+ sysdeps = psys,
+ langdeps = pdep,
+ outPath = outToPath pout,
+ out = pout,
+ packageSet = "python.packages",
+ mainModule = Namespace.toModule namespace,
+ ..
+ }
+ |> Just
+ |> pure
Namespace.Sh -> pure Nothing
Namespace.C ->
Meta.detectAll "//" contentLines |> \Meta.Parsed {..} -> do
@@ -713,6 +714,27 @@ detectLispImports contentLines =
|> Set.fromList
|> pure
+-- | Finds local imports. Does not recurse to find transitive imports like
+-- 'detectHaskellImports' does. Someday I will refactor these detection
+-- functions and have a common, well-performing, complete solution.
+detectPythonImports :: [Text] -> IO (Set FilePath)
+detectPythonImports contentLines =
+ contentLines
+ /> Text.unpack
+ /> Regex.match pythonImport
+ |> catMaybes
+ /> Namespace.fromPythonModule
+ /> Namespace.toPath
+ |> filterM Dir.doesPathExist
+ /> Set.fromList
+ where
+ -- only detects 'import x' because I don't like 'from'
+ pythonImport :: Regex.RE Char String
+ pythonImport =
+ Regex.string "import"
+ *> Regex.some (Regex.psym Char.isSpace)
+ *> Regex.many (Regex.psym isModuleChar)
+
ghcPkgFindModule :: Set String -> String -> IO (Set String)
ghcPkgFindModule acc m =
Env.getEnv "GHC_PACKAGE_PATH" +> \packageDb ->
@@ -755,9 +777,13 @@ build andTest loud analysis =
Env.getEnv "BIZ_ROOT" +> \root ->
forM (Map.elems analysis) <| \target@Target {..} ->
fst </ case compiler of
- CPython ->
- Log.info ["bild", "nix", "python", nschunk namespace]
- >> nixBuild loud target
+ CPython -> case out of
+ Meta.Bin _ ->
+ Log.info ["bild", "nix", "python", nschunk namespace]
+ >> nixBuild loud target
+ _ ->
+ Log.info ["bild", "nix", "python", nschunk namespace, "cannot build library"]
+ >> pure (Exit.ExitSuccess, mempty)
Gcc ->
Log.info ["bild", label, "gcc", nschunk namespace]
>> nixBuild loud target
diff --git a/Biz/Bild/Builder.nix b/Biz/Bild/Builder.nix
index 8f42733..f9eb31d 100644
--- a/Biz/Bild/Builder.nix
+++ b/Biz/Bild/Builder.nix
@@ -92,16 +92,22 @@ let
python = bild.python.buildPythonApplication rec {
inherit name src BIZ_ROOT;
- propagatedBuildInputs = [ (bild.python.pythonWith (_: langdeps_)) ] ++ sysdeps_;
+ propagatedBuildInputs = langdeps_ ++ sysdeps_;
buildInputs = sysdeps_;
- checkInputs = [(bild.python.pythonWith (p: with p; [black mypy])) ruff];
+ nativeCheckInputs = [ black mypy ruff ];
checkPhase = ''
check() {
$@ || { echo "fail: $name: $3"; exit 1; }
}
+ cp ${../../pyproject.toml} ./pyproject.toml
check python -m black --quiet --exclude 'setup\.py$' --check .
check ${ruff}/bin/ruff check .
- check python -m mypy --strict --no-error-summary --exclude 'setup\.py$' .
+ touch ./py.typed
+ check python -m mypy \
+ --explicit-package-bases \
+ --no-error-summary \
+ --exclude 'setup\.py$' \
+ .
check python -m ${mainModule} test
'';
preBuild = ''
diff --git a/Biz/Bild/Deps.nix b/Biz/Bild/Deps.nix
index da18d89..dcb7d50 100644
--- a/Biz/Bild/Deps.nix
+++ b/Biz/Bild/Deps.nix
@@ -36,6 +36,9 @@ in rec
exllama = callPackage ./Deps/exllama.nix {
cudaPackages = super.pkgs.cudaPackages_11_7;
};
+ interegular = callPackage ./Deps/interegular.nix {};
+ outlines = callPackage ./Deps/outlines.nix {};
+ perscache = callPackage ./Deps/perscache.nix {};
};
};
diff --git a/Biz/Bild/Deps/interegular.nix b/Biz/Bild/Deps/interegular.nix
new file mode 100644
index 0000000..8b0bc86
--- /dev/null
+++ b/Biz/Bild/Deps/interegular.nix
@@ -0,0 +1,26 @@
+{ lib
+, sources
+, buildPythonPackage
+}:
+
+buildPythonPackage rec {
+ pname = "interegular";
+ version = sources.interegular.rev;
+ format = "setuptools";
+
+ src = sources.interegular;
+
+ propagatedBuildInputs = [ ];
+
+ doCheck = false; # no tests currently
+ pythonImportsCheck = [
+ "interegular"
+ ];
+
+ meta = with lib; {
+ description = "Allows to check regexes for overlaps.";
+ homepage = "https://github.com/MegaIng/interegular";
+ license = licenses.mit;
+ maintainers = with maintainers; [ bsima ];
+ };
+}
diff --git a/Biz/Bild/Deps/outlines.nix b/Biz/Bild/Deps/outlines.nix
new file mode 100644
index 0000000..013581b
--- /dev/null
+++ b/Biz/Bild/Deps/outlines.nix
@@ -0,0 +1,49 @@
+{ lib
+, sources
+, buildPythonPackage
+, interegular
+, jinja2
+, lark
+, numpy
+, perscache
+, pillow
+, pydantic
+, regex
+, scipy
+, tenacity
+, torch
+}:
+
+buildPythonPackage rec {
+ pname = "outlines";
+ version = sources.outlines.rev;
+ format = "pyproject";
+
+ src = sources.outlines;
+
+ propagatedBuildInputs = [
+ interegular
+ jinja2
+ lark
+ numpy
+ perscache
+ pillow
+ pydantic
+ regex
+ scipy
+ tenacity
+ torch
+ ];
+
+ doCheck = false; # no tests currently
+ pythonImportsCheck = [
+ "outlines"
+ ];
+
+ meta = with lib; {
+ description = "Probabilistic Generative Model Programming";
+ homepage = "https://github.com/normal-computing/outlines";
+ license = licenses.asl20;
+ maintainers = with maintainers; [ bsima ];
+ };
+}
diff --git a/Biz/Bild/Deps/perscache.nix b/Biz/Bild/Deps/perscache.nix
new file mode 100644
index 0000000..d757e1a
--- /dev/null
+++ b/Biz/Bild/Deps/perscache.nix
@@ -0,0 +1,39 @@
+{ lib
+, sources
+, buildPythonPackage
+, beartype
+, cloudpickle
+, icontract
+, pbr
+}:
+
+buildPythonPackage rec {
+ pname = "perscache";
+ version = sources.perscache.rev;
+
+ src = sources.perscache;
+
+ propagatedBuildInputs = [
+ beartype
+ cloudpickle
+ icontract
+ pbr
+ ];
+ PBR_VERSION = version;
+
+ doCheck = false; # no tests currently
+ pythonImportsCheck = [
+ "perscache"
+ ];
+
+ meta = with lib; {
+ description = ''
+ An easy to use decorator for persistent memoization: like
+ `functools.lrucache`, but results can be saved in any format to any
+ storage.
+ '';
+ homepage = "https://github.com/leshchenko1979/perscache";
+ license = licenses.mit;
+ maintainers = with maintainers; [ bsima ];
+ };
+}
diff --git a/Biz/Bild/Sources.json b/Biz/Bild/Sources.json
index e4fcfd4..6cc4d48 100644
--- a/Biz/Bild/Sources.json
+++ b/Biz/Bild/Sources.json
@@ -64,6 +64,18 @@
"url_template": "https://gitlab.com/kavalogic-inc/inspekt3d/-/archive/<version>/inspekt3d-<version>.tar.gz",
"version": "703f52ccbfedad2bf5240bf8183d1b573c9d54ef"
},
+ "interegular": {
+ "branch": "master",
+ "description": "Allows to check regexes for overlaps. Based on greenery by @qntm.",
+ "homepage": null,
+ "owner": "MegaIng",
+ "repo": "interegular",
+ "rev": "v0.2.1",
+ "sha256": "14f3jvnczq6qay2qp4rxchbdhkj00qs8kpacl0nrxgr0785km36k",
+ "type": "tarball",
+ "url": "https://github.com/MegaIng/interegular/archive/v0.2.1.tar.gz",
+ "url_template": "https://github.com/<owner>/<repo>/archive/<rev>.tar.gz"
+ },
"llama-cpp": {
"branch": "master",
"description": "Port of Facebook's LLaMA model in C/C++",
@@ -110,6 +122,30 @@
"url": "https://github.com/nixos/nixpkgs/archive/61676e4dcfeeb058f255294bcb08ea7f3bc3ce56.tar.gz",
"url_template": "https://github.com/<owner>/<repo>/archive/<rev>.tar.gz"
},
+ "outlines": {
+ "branch": "main",
+ "description": "Generative Model Programming",
+ "homepage": "https://normal-computing.github.io/outlines/",
+ "owner": "normal-computing",
+ "repo": "outlines",
+ "rev": "0.0.8",
+ "sha256": "1yvx5c5kplmr56nffqcb6ssjnmlikkaw32hxl6i4b607v3s0s6jv",
+ "type": "tarball",
+ "url": "https://github.com/normal-computing/outlines/archive/0.0.8.tar.gz",
+ "url_template": "https://github.com/<owner>/<repo>/archive/<rev>.tar.gz"
+ },
+ "perscache": {
+ "branch": "master",
+ "description": "An easy to use decorator for persistent memoization: like `functools.lrucache`, but results can be saved in any format to any storage.",
+ "homepage": null,
+ "owner": "leshchenko1979",
+ "repo": "perscache",
+ "rev": "0.6.1",
+ "sha256": "0j2775pjll4vw1wmxkjhnb5z6z83x5lhg89abj2d8ivd17n4rhjf",
+ "type": "tarball",
+ "url": "https://github.com/leshchenko1979/perscache/archive/0.6.1.tar.gz",
+ "url_template": "https://github.com/<owner>/<repo>/archive/<rev>.tar.gz"
+ },
"regex-applicative": {
"branch": "master",
"description": "Regex-based parsing with applicative interface",
diff --git a/Biz/Dragons/main.py b/Biz/Dragons/main.py
index 7ec80bb..7a94f99 100755
--- a/Biz/Dragons/main.py
+++ b/Biz/Dragons/main.py
@@ -20,7 +20,9 @@ def find_user(line: str) -> typing.Any:
return re.findall(r"^[^<]*", line)[0].strip()
-def authors_for(path: str, active_users: typing.List[str]) -> typing.Dict[str, str]:
+def authors_for(
+ path: str, active_users: typing.List[str]
+) -> typing.Dict[str, str]:
"""Return a dictionary of {author: commits} for given path. Usernames not in
the 'active_users' list will be filtered out."""
raw = subprocess.check_output(
@@ -71,7 +73,9 @@ def get_args() -> typing.Any:
"Parse CLI arguments."
cli = argparse.ArgumentParser(description=__doc__)
cli.add_argument("test", action="store_true", help="run the test suite")
- cli.add_argument("repo", default=".", help="the git repo to run on", metavar="REPO")
+ cli.add_argument(
+ "repo", default=".", help="the git repo to run on", metavar="REPO"
+ )
cli.add_argument(
"-b",
"--blackholes",
@@ -138,7 +142,9 @@ class Repo:
) -> None:
self.paths = [
p
- for p in subprocess.check_output(["git", "ls-files", "--no-deleted"])
+ for p in subprocess.check_output(
+ ["git", "ls-files", "--no-deleted"]
+ )
.decode("utf-8")
.split()
if not any(i in p for i in ignored_paths)
@@ -147,7 +153,9 @@ class Repo:
self.stats = {}
for path in self.paths:
self.stats[path] = authors_for(path, active_users)
- self.blackholes = [path for path, authors in self.stats.items() if not authors]
+ self.blackholes = [
+ path for path, authors in self.stats.items() if not authors
+ ]
self.liabilities = {
path: list(authors)
for path, authors in self.stats.items()
diff --git a/Biz/Ide/repl b/Biz/Ide/repl
index 1d94e47..1401218 100755
--- a/Biz/Ide/repl
+++ b/Biz/Ide/repl
@@ -30,6 +30,7 @@ fi
sysdeps=$(jq --raw-output '.[].sysdeps | join(" ")' <<< $json)
exts=$(jq --raw-output '.[].namespace.ext' <<< $json | sort | uniq)
packageSet=$(jq --raw-output '.[].packageSet' <<< $json)
+ module=$(jq --raw-output '.[].mainModule' <<< $json)
BILD="(import ${BIZ_ROOT:?}/Biz/Bild.nix {})"
for lib in ${sysdeps[@]}; do
flags+=(--packages "$BILD.pkgs.${lib}")
@@ -64,8 +65,11 @@ fi
;;
Py)
langdeps="$langdeps mypy"
+ flags+=(--packages ruff)
flags+=(--packages "$BILD.bild.python.pythonWith (p: with p; [$langdeps])")
- command=${CMD:-"python -i $targets"}
+ PYTHONPATH=$BIZ_ROOT:$PYTHONPATH
+ pycommand="python -i $BIZ_ROOT/Biz/Repl.py $module ${targets[@]}"
+ command=${CMD:-"$pycommand"}
;;
*)
echo "unsupported targets: ${targets[@]}"
diff --git a/Biz/Ide/tidy b/Biz/Ide/tidy
new file mode 100755
index 0000000..edea828
--- /dev/null
+++ b/Biz/Ide/tidy
@@ -0,0 +1,2 @@
+#!/usr/bin/env bash
+rm -f $BIZ_ROOT/_/bin/*
diff --git a/Biz/Ide/tips b/Biz/Ide/tips
index 1b998e6..21808eb 100755
--- a/Biz/Ide/tips
+++ b/Biz/Ide/tips
@@ -9,4 +9,5 @@ echo " tips show this message"
echo " lint auto-lint all changed files"
echo " push send a namespace to the cloud"
echo " ship lint, bild, and push one (or all) namespace(s)"
+echo " tidy cleanup common working files"
echo ""
diff --git a/Biz/Lint.hs b/Biz/Lint.hs
index 5c3bef3..bc91f34 100644
--- a/Biz/Lint.hs
+++ b/Biz/Lint.hs
@@ -115,7 +115,7 @@ printResult r = case r of
changedFiles :: IO [FilePath]
changedFiles =
- git ["merge-base", "HEAD", "origin/master"]
+ git ["merge-base", "HEAD", "origin/live"]
/> filter (/= '\n')
+> (\mb -> git ["diff", "--name-only", "--diff-filter=d", mb])
/> String.lines
diff --git a/Biz/Log.py b/Biz/Log.py
new file mode 100644
index 0000000..af28c41
--- /dev/null
+++ b/Biz/Log.py
@@ -0,0 +1,32 @@
+"""
+Setup logging like Biz/Log.hs.
+"""
+
+import logging
+import typing
+
+
+class LowerFormatter(logging.Formatter):
+ def format(self, record: typing.Any) -> typing.Any:
+ record.levelname = record.levelname.lower()
+ return super(logging.Formatter, self).format(record) # type: ignore
+
+
+def setup() -> None:
+ "Run this in your __main__ function"
+ logging.basicConfig(
+ level=logging.DEBUG, format="%(levelname)s: %(name)s: %(message)s"
+ )
+ logging.addLevelName(logging.DEBUG, "dbug")
+ logging.addLevelName(logging.ERROR, "fail")
+ logging.addLevelName(logging.INFO, "info")
+ logger = logging.getLogger(__name__)
+ formatter = LowerFormatter()
+ handler = logging.StreamHandler()
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+
+
+if __name__ == "__main__":
+ setup()
+ logging.debug("i am doing testing")
diff --git a/Biz/Mynion.py b/Biz/Mynion.py
new file mode 100644
index 0000000..6bb55e1
--- /dev/null
+++ b/Biz/Mynion.py
@@ -0,0 +1,246 @@
+# : out mynion
+#
+# : dep exllama
+# : dep slixmpp
+import argparse
+import exllama # type: ignore
+import Biz.Log
+import glob
+import logging
+import os
+import slixmpp
+import slixmpp.exceptions
+import sys
+import torch
+import typing
+
+
+def smoosh(s: str) -> str:
+ return s.replace("\n", " ")
+
+
+class Mynion(slixmpp.ClientXMPP):
+ def __init__(
+ self,
+ jid: str,
+ password: str,
+ model: exllama.model.ExLlama,
+ tokenizer: exllama.tokenizer.ExLlamaTokenizer,
+ generator: exllama.generator.ExLlamaGenerator,
+ ) -> None:
+ slixmpp.ClientXMPP.__init__(self, jid, password)
+ self.plugin.enable("xep_0085") # type: ignore
+ self.plugin.enable("xep_0184") # type: ignore
+
+ self.name = "mynion"
+ self.user = "ben"
+ self.first_round = True
+ self.min_response_tokens = 4
+ self.max_response_tokens = 256
+ self.extra_prune = 256
+ # todo: handle summary rollup when max_seq_len is reached
+ self.max_seq_len = 8000
+
+ self.model = model
+ self.tokenizer = tokenizer
+ self.generator = generator
+
+ root = os.getenv("BIZ_ROOT", "")
+ # this should be parameterized somehow
+ with open(os.path.join(root, "Biz", "Mynion", "Prompt.md")) as f:
+ txt = f.read()
+ txt = txt.format(user=self.user, name=self.name)
+
+ # this is the "system prompt", ideally i would load this in/out of a
+ # database with all of the past history. if the history gets too long, i
+ # can roll it up by asking llama to summarize it
+ self.past = smoosh(txt)
+
+ ids = tokenizer.encode(self.past)
+ self.generator.gen_begin(ids)
+
+ self.add_event_handler("session_start", self.session_start)
+ self.add_event_handler("message", self.message)
+
+ def session_start(self) -> None:
+ self.send_presence()
+ try:
+ self.get_roster() # type: ignore
+ except slixmpp.exceptions.IqError as err:
+ logging.error("There was an error getting the roster")
+ logging.error(err.iq["error"]["condition"])
+ self.disconnect()
+ except slixmpp.exceptions.IqTimeout:
+ logging.error("Server is taking too long to respond")
+ self.disconnect()
+
+ def message(self, msg: slixmpp.Message) -> None:
+ if msg["type"] in ("chat", "normal"):
+ res_line = f"{self.name}: "
+ res_tokens = self.tokenizer.encode(res_line)
+ num_res_tokens = res_tokens.shape[-1]
+
+ if self.first_round:
+ in_tokens = res_tokens
+ else:
+ # read and format input
+ in_line = f"{self.user}: " + msg["body"].strip() + "\n"
+ in_tokens = self.tokenizer.encode(in_line)
+ in_tokens = torch.cat((in_tokens, res_tokens), dim=1)
+
+ # If we're approaching the context limit, prune some whole lines
+ # from the start of the context. Also prune a little extra so we
+ # don't end up rebuilding the cache on every line when up against
+ # the limit.
+ expect_tokens = in_tokens.shape[-1] + self.max_response_tokens
+ max_tokens = self.max_seq_len - expect_tokens
+ if self.generator.gen_num_tokens() >= max_tokens:
+ generator.gen_prune_to(
+ self.max_seq_len - expect_tokens - self.extra_prune,
+ self.tokenizer.newline_token_id,
+ )
+
+ # feed in the user input and "{self.name}:", tokenized
+ self.generator.gen_feed_tokens(in_tokens)
+
+ # start beam search?
+ self.generator.begin_beam_search()
+
+ # generate tokens, with streaming
+ # TODO: drop the streaming!
+ for i in range(self.max_response_tokens):
+ # disallowing the end condition tokens seems like a clean way to
+ # force longer replies
+ if i < self.min_response_tokens:
+ self.generator.disallow_tokens(
+ [
+ self.tokenizer.newline_token_id,
+ self.tokenizer.eos_token_id,
+ ]
+ )
+ else:
+ self.generator.disallow_tokens(None)
+
+ # get a token
+ gen_token = self.generator.beam_search()
+
+ # if token is EOS, replace it with a newline before continuing
+ if gen_token.item() == self.tokenizer.eos_token_id:
+ self.generator.replace_last_token(
+ self.tokenizer.newline_token_id
+ )
+
+ # decode the current line
+ num_res_tokens += 1
+ text = self.tokenizer.decode(
+ self.generator.sequence_actual[:, -num_res_tokens:][0]
+ )
+
+ # append to res_line
+ res_line += text[len(res_line) :]
+
+ # end conditions
+ breakers = [
+ self.tokenizer.eos_token_id,
+ # self.tokenizer.newline_token_id,
+ ]
+ if gen_token.item() in breakers:
+ break
+
+ # try to drop the "ben:" at the end
+ if res_line.endswith(f"{self.user}:"):
+ logging.info("rewinding!")
+ plen = self.tokenizer.encode(f"{self.user}:").shape[-1]
+ self.generator.gen_rewind(plen)
+ break
+
+ # end generation and send the reply
+ self.generator.end_beam_search()
+ res_line = res_line.removeprefix(f"{self.name}:")
+ res_line = res_line.removesuffix(f"{self.user}:")
+ self.first_round = False
+ msg.reply(res_line).send() # type: ignore
+
+
+MY_MODELS = [
+ "Llama-2-13B-GPTQ",
+ "Nous-Hermes-13B-GPTQ",
+ "Nous-Hermes-Llama2-13b-GPTQ",
+ "Wizard-Vicuna-13B-Uncensored-GPTQ",
+ "Wizard-Vicuna-13B-Uncensored-SuperHOT-8K-GPTQ",
+ "Wizard-Vicuna-30B-Uncensored-GPTQ",
+ "CodeLlama-13B-Python-GPTQ",
+ "CodeLlama-13B-Instruct-GPTQ",
+ "CodeLlama-34B-Instruct-GPTQ",
+]
+
+
+def load_model(model_name: str) -> typing.Any:
+ assert model_name in MY_MODELS
+ if not torch.cuda.is_available():
+ raise ValueError("no cuda")
+ sys.exit(1)
+
+ torch.set_grad_enabled(False)
+ torch.cuda._lazy_init() # type: ignore
+
+ ml_models = "/mnt/campbell/ben/ml-models"
+
+ model_dir = os.path.join(ml_models, model_name)
+
+ tokenizer_path = os.path.join(model_dir, "tokenizer.model")
+ config_path = os.path.join(model_dir, "config.json")
+ st_pattern = os.path.join(model_dir, "*.safetensors")
+ st = glob.glob(st_pattern)
+ if len(st) != 1:
+ print("found multiple safetensors!")
+ sys.exit()
+ model_path = st[0]
+
+ config = exllama.model.ExLlamaConfig(config_path)
+ config.model_path = model_path
+
+ # gpu split
+ config.set_auto_map("23")
+
+ model = exllama.model.ExLlama(config)
+ cache = exllama.model.ExLlamaCache(model)
+ tokenizer = exllama.tokenizer.ExLlamaTokenizer(tokenizer_path)
+
+ generator = exllama.generator.ExLlamaGenerator(model, tokenizer, cache)
+ generator.settings = exllama.generator.ExLlamaGenerator.Settings()
+
+ return (model, tokenizer, generator)
+
+
+def main(
+ model: exllama.model.ExLlama,
+ tokenizer: exllama.tokenizer.ExLlamaTokenizer,
+ generator: exllama.generator.ExLlamaGenerator,
+ user: str,
+ password: str,
+) -> None:
+ """
+ Start the chatbot.
+
+ This purposefully does not call 'load_model()' so that you can load the
+ model in the repl and then restart the chatbot without unloading it.
+ """
+ Biz.Log.setup()
+ xmpp = Mynion(user, password, model, tokenizer, generator)
+ xmpp.connect()
+ xmpp.process(forever=True) # type: ignore
+
+
+if __name__ == "__main__":
+ if "test" in sys.argv:
+ print("pass: test: Biz/Mynion.py")
+ sys.exit(0)
+ else:
+ cli = argparse.ArgumentParser()
+ cli.add_argument("-u", "--user")
+ cli.add_argument("-p", "--password")
+ cli.add_argument("-m", "--model", choices=MY_MODELS)
+ args = cli.parse_args()
+ model, tokenizer, generator = load_model(args.model)
+ main(model, tokenizer, generator, args.user, args.password)
diff --git a/Biz/Namespace.hs b/Biz/Namespace.hs
index 9621186..48f6277 100644
--- a/Biz/Namespace.hs
+++ b/Biz/Namespace.hs
@@ -14,6 +14,7 @@ module Biz.Namespace
fromHaskellModule,
toHaskellModule,
toSchemeModule,
+ fromPythonModule,
isCab,
groupByExt,
)
@@ -105,6 +106,9 @@ fromHaskellModule s = Namespace (List.splitOn "." s) Hs
toSchemeModule :: Namespace -> String
toSchemeModule = toModule
+fromPythonModule :: String -> Namespace
+fromPythonModule s = Namespace (List.splitOn "." s) Py
+
dot :: Regex.RE Char String
dot = Regex.some <| Regex.sym '.'
diff --git a/Biz/Que/Client.py b/Biz/Que/Client.py
index ef6d6d2..20349b6 100755
--- a/Biz/Que/Client.py
+++ b/Biz/Que/Client.py
@@ -191,7 +191,9 @@ def get_args() -> argparse.Namespace:
be used for serving a webpage or other file continuously"""
),
)
- cli.add_argument("target", help="namespace and path of the que, like 'ns/path'")
+ cli.add_argument(
+ "target", help="namespace and path of the que, like 'ns/path'"
+ )
cli.add_argument(
"infile",
nargs="?",
diff --git a/Biz/Repl.py b/Biz/Repl.py
new file mode 100644
index 0000000..0732fae
--- /dev/null
+++ b/Biz/Repl.py
@@ -0,0 +1,36 @@
+"""
+This module attempts to emulate the workflow of ghci or lisp repls. It uses
+importlib to load a namespace from the given path. It then binds 'r()' to a
+function that reloads the same namespace.
+"""
+
+import importlib
+import sys
+
+
+def use(ns: str, path: str) -> None:
+ """
+ Load or reload the module named 'ns' from 'path'. Like `use` in the Guile
+ Scheme repl.
+ """
+ spec = importlib.util.spec_from_file_location(ns, path)
+ module = importlib.util.module_from_spec(spec)
+ # delete module and its imported names if its already loaded
+ if ns in sys.modules:
+ del sys.modules[ns]
+ for name in module.__dict__:
+ if name in globals():
+ del globals()[name]
+ sys.modules[ns] = module
+ spec.loader.exec_module(module)
+ names = [x for x in module.__dict__]
+ globals().update({k: getattr(module, k) for k in names})
+
+
+if __name__ == "__main__":
+ NS = sys.argv[1]
+ PATH = sys.argv[2]
+ use(NS, PATH)
+
+ def r():
+ return use(NS, PATH)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..7f7d0d3
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,6 @@
+[tool.black]
+line-length = 80
+
+[tool.mypy]
+strict = true
+implicit_reexport = true