summaryrefslogtreecommitdiff
path: root/Biz
diff options
context:
space:
mode:
authorBen Sima <ben@bsima.me>2024-04-10 19:56:46 -0400
committerBen Sima <ben@bsima.me>2024-04-10 19:56:46 -0400
commit2c09c7f73e2fc770f42b5dd2588aa9634b4e7c6e (patch)
treea6c49dddb7b1735a0f8de2a9a5f4efb605f16f36 /Biz
parent051973aba8953ebde51eb1436fb3994e7ae699dc (diff)
Switch from black to ruff format
Ruff is faster and if it supports everything that black supports than why not? I did have to pull in a more recent version from unstable, but that's easy to do now. And I decided to just go ahead and configure ruff by turning on almost all checks, which meant I had to fix a whole bunch of things, but I did that and everything is okay now.
Diffstat (limited to 'Biz')
-rw-r--r--Biz/Bild.nix4
-rw-r--r--Biz/Bild/Builder.nix22
-rw-r--r--Biz/Bild/Example.py22
-rw-r--r--[-rwxr-xr-x]Biz/Dragons/main.py132
-rwxr-xr-xBiz/Ide/repl.sh8
-rw-r--r--Biz/Lint.hs33
-rw-r--r--Biz/Llamacpp.py4
-rw-r--r--Biz/Log.py17
-rw-r--r--Biz/Mynion.py108
-rw-r--r--[-rwxr-xr-x]Biz/Que/Client.py70
-rw-r--r--Biz/Repl.py27
11 files changed, 245 insertions, 202 deletions
diff --git a/Biz/Bild.nix b/Biz/Bild.nix
index 24ce4bf..7a07f36 100644
--- a/Biz/Bild.nix
+++ b/Biz/Bild.nix
@@ -64,7 +64,8 @@ in nixpkgs // {
# expose some packages for inclusion in os/image builds
pkgs = with nixpkgs.pkgs; {
- inherit black deadnix git hlint indent ormolu ruff shellcheck nixfmt;
+ inherit deadnix git hlint indent ormolu shellcheck nixfmt mypy pkg-config;
+ ruff = unstable.ruff;
};
# a standard nix build for bild, for bootstrapping. this should be the only
@@ -146,7 +147,6 @@ in nixpkgs // {
bat
bc
bild
- black
ctags
fd
figlet
diff --git a/Biz/Bild/Builder.nix b/Biz/Bild/Builder.nix
index d2e6875..4bef830 100644
--- a/Biz/Bild/Builder.nix
+++ b/Biz/Bild/Builder.nix
@@ -110,14 +110,14 @@ let
inherit name src CODEROOT;
propagatedBuildInputs = langdeps_ ++ sysdeps_;
buildInputs = sysdeps_;
- nativeCheckInputs = [ black mypy ruff ];
+ nativeCheckInputs = lib.attrsets.attrVals [ "mypy" "ruff" ] bild.pkgs;
checkPhase = ''
check() {
$@ || { echo "fail: $name: $3"; exit 1; }
}
cp ${../../pyproject.toml} ./pyproject.toml
- check python -m black --quiet --exclude 'setup\.py$' --check .
- check ${ruff}/bin/ruff check .
+ check ruff format --exclude 'setup.py' --check .
+ check ruff check --exclude 'setup.py' --exclude '__init__.py' .
touch ./py.typed
check python -m mypy \
--explicit-package-bases \
@@ -133,15 +133,15 @@ let
find . -type d -exec touch {}/__init__.py \;
# generate a minimal setup.py
cat > setup.py << EOF
- from setuptools import setup, find_packages
+ from setuptools import find_packages, setup
setup(
- name='${name}',
- entry_points={'console_scripts':['${name} = ${mainModule}:main']},
- version='0.0.0',
- url='git://simatime.com/biz.git',
- author='dev',
- author_email='dev@simatime.com',
- description='nil',
+ name="${name}",
+ entry_points={"console_scripts":["${name} = ${mainModule}:main"]},
+ version="0.0.0",
+ url="git://simatime.com/biz.git",
+ author="dev",
+ author_email="dev@simatime.com",
+ description="nil",
packages=find_packages(),
install_requires=[],
)
diff --git a/Biz/Bild/Example.py b/Biz/Bild/Example.py
index 5d165d8..1bd30ae 100644
--- a/Biz/Bild/Example.py
+++ b/Biz/Bild/Example.py
@@ -1,32 +1,42 @@
"""
+Test that bild can build Python stuff.
+
Example Python file that also serves as a test case for bild.
"""
+
# : out example
# : dep cryptography
import sys
+
import cryptography.fernet
def cryptic_hello(name: str) -> str:
- "Example taken from `cryptography` docs."
+ """
+ Encrypt and decrypt `name`.
+
+ Example taken from `cryptography` docs.
+ """
key = cryptography.fernet.Fernet.generate_key()
f = cryptography.fernet.Fernet(key)
token = f.encrypt(hello(name).encode("utf-8"))
ret = f.decrypt(token).decode("utf-8")
- assert ret == hello(name)
+ if ret != hello(name):
+ msg = "en/decryption failed!"
+ raise ValueError(msg)
return ret
def hello(name: str) -> str:
- "Say hello"
+ """Say hello."""
return f"Hello {name}"
def main() -> None:
- "Entrypoint"
+ """Entrypoint."""
if "test" in sys.argv:
- print("testing success")
- print(cryptic_hello("world"))
+ sys.stdout.write("testing success")
+ sys.stdout.write(cryptic_hello("world"))
if __name__ == "__main__":
diff --git a/Biz/Dragons/main.py b/Biz/Dragons/main.py
index 7a94f99..6e2c995 100755..100644
--- a/Biz/Dragons/main.py
+++ b/Biz/Dragons/main.py
@@ -1,13 +1,11 @@
-#!/usr/bin/env python
# : out dragons.py
-"""
-Analyze developer allocation across a codebase.
-"""
+"""Analyze developer allocation across a codebase."""
import argparse
import datetime
import logging
import os
+import pathlib
import re
import subprocess
import sys
@@ -15,18 +13,26 @@ import typing
def find_user(line: str) -> typing.Any:
- """Given 'Ben Sima <ben@bsima.me>', finds `Ben Sima'. Returns the first
- matching string."""
+ """
+ Find a person's name in a .mailmap file.
+
+ Given 'Ben Sima <ben@bsima.me>', finds `Ben Sima'. Returns the first
+ matching string.
+ """
return re.findall(r"^[^<]*", line)[0].strip()
def authors_for(
- path: str, active_users: typing.List[str]
-) -> typing.Dict[str, str]:
- """Return a dictionary of {author: commits} for given path. Usernames not in
- the 'active_users' list will be filtered out."""
+ path: str,
+ active_users: list[str],
+) -> dict[str, str]:
+ """
+ Return a dictionary of {author: commits} for given path.
+
+ Usernames not in the 'active_users' list will be filtered out.
+ """
raw = subprocess.check_output(
- ["git", "shortlog", "--numbered", "--summary", "--email", "--", path]
+ ["git", "shortlog", "--numbered", "--summary", "--email", "--", path],
).decode("utf-8")
lines = [s for s in raw.split("\n") if s]
data = {}
@@ -39,21 +45,18 @@ def authors_for(
return data
-def mailmap_users() -> typing.List[str]:
- """Returns users from the .mailmap file."""
- users = []
- with open(".mailmap", encoding="utf-8") as file:
+def mailmap_users() -> list[str]:
+ """Return users from the .mailmap file."""
+ with pathlib.Path(".mailmap").open() as file:
lines = file.readlines()
- for line in lines:
- users.append(find_user(line))
- return users
+ return [find_user(line) for line in lines]
MAX_SCORE = 10
def score(blackhole: float, liability: float, good: int, total: int) -> float:
- "Calculate the score."
+ """Calculate the score."""
weights = {
"blackhole": 0.5,
"liability": 0.7,
@@ -70,17 +73,20 @@ def score(blackhole: float, liability: float, good: int, total: int) -> float:
def get_args() -> typing.Any:
- "Parse CLI arguments."
+ """Parse CLI arguments."""
cli = argparse.ArgumentParser(description=__doc__)
cli.add_argument("test", action="store_true", help="run the test suite")
cli.add_argument(
- "repo", default=".", help="the git repo to run on", metavar="REPO"
+ "repo",
+ default=".",
+ help="the git repo to run on",
+ metavar="REPO",
)
cli.add_argument(
"-b",
"--blackholes",
action="store_true",
- help="print the blackholes (files with one or zero active contributors)",
+ help="print the blackholes (files with 1 or 0 active contributors)",
)
cli.add_argument(
"-l",
@@ -105,12 +111,7 @@ def get_args() -> typing.Any:
"--active-users",
nargs="+",
default=[],
- help=" ".join(
- [
- "list of active user emails."
- "if not provided, this is loaded from .mailmap"
- ]
- ),
+ help="list of active user emails. default: loaded from .mailmap",
)
cli.add_argument(
"-v",
@@ -123,7 +124,7 @@ def get_args() -> typing.Any:
def staleness(path: str, now: datetime.datetime) -> int:
- "How long has it been since this file was touched?"
+ """How long has it been since this file was touched?."""
timestamp = datetime.datetime.strptime(
subprocess.check_output(["git", "log", "-n1", "--pretty=%aI", path])
.decode("utf-8")
@@ -135,15 +136,18 @@ def staleness(path: str, now: datetime.datetime) -> int:
class Repo:
- "Represents a repo and stats for the repo."
+ """Represents a repo and stats for the repo."""
def __init__(
- self, ignored_paths: typing.List[str], active_users: typing.List[str]
+ self: "Repo",
+ ignored_paths: list[str],
+ active_users: list[str],
) -> None:
+ """Create analysis of a git repo."""
self.paths = [
p
for p in subprocess.check_output(
- ["git", "ls-files", "--no-deleted"]
+ ["git", "ls-files", "--no-deleted"],
)
.decode("utf-8")
.split()
@@ -156,60 +160,65 @@ class Repo:
self.blackholes = [
path for path, authors in self.stats.items() if not authors
]
+ max_authors = 3
self.liabilities = {
path: list(authors)
for path, authors in self.stats.items()
- if 1 <= len(authors) < 3
+ if 1 <= len(authors) < max_authors
}
now = datetime.datetime.utcnow().astimezone()
self.stale = {}
- for path, _ in self.stats.items():
+ max_staleness = 180
+ for path in self.stats:
_staleness = staleness(path, now)
- if _staleness > 180:
+ if _staleness > max_staleness:
self.stale[path] = _staleness
- def print_blackholes(self, full: bool) -> None:
- "Print number of blackholes, or list of all blackholes."
+ def print_blackholes(self: "Repo", *, full: bool) -> None:
+ """Print number of blackholes, or list of all blackholes."""
# note: file renames may result in false positives
n_blackhole = len(self.blackholes)
- print(f"Blackholes: {n_blackhole}")
+ sys.stdout.write(f"Blackholes: {n_blackhole}")
if full:
for path in self.blackholes:
- print(f" {path}")
+ sys.stdout.write(f" {path}")
+ sys.stdout.flush()
- def print_liabilities(self, full: bool) -> None:
- "Print number of liabilities, or list of all liabilities."
+ def print_liabilities(self: "Repo", *, full: bool) -> None:
+ """Print number of liabilities, or list of all liabilities."""
n_liabilities = len(self.liabilities)
- print(f"Liabilities: {n_liabilities}")
+ sys.stdout.write(f"Liabilities: {n_liabilities}")
if full:
for path, authors in self.liabilities.items():
- print(f" {path} ({', '.join(authors)})")
+ sys.stdout.write(f" {path} ({', '.join(authors)})")
+ sys.stdout.flush()
- def print_score(self) -> None:
- "Print the overall score."
+ def print_score(self: "Repo") -> None:
+ """Print the overall score."""
n_total = len(self.stats.keys())
n_blackhole = len(self.blackholes)
n_liabilities = len(self.liabilities)
n_good = n_total - n_blackhole - n_liabilities
- print("Total:", n_total)
+ sys.stdout.write(f"Total: {n_total}")
this_score = score(n_blackhole, n_liabilities, n_good, n_total)
- print(f"Score: {this_score:.2f}/{MAX_SCORE}".format())
+ sys.stdout.write(f"Score: {this_score:.2f}/{MAX_SCORE}".format())
+ sys.stdout.flush()
- def print_stale(self, full: bool) -> None:
- "Print stale files"
+ def print_stale(self: "Repo", *, full: bool) -> None:
+ """Print stale files."""
n_stale = len(self.stale)
- print(f"Stale files: {n_stale}")
+ sys.stdout.write(f"Stale files: {n_stale}")
if full:
for path, days in self.stale.items():
- print(f" {path} ({days} days)")
+ sys.stdout.write(f" {path} ({days} days)")
+ sys.stdout.flush()
def guard_git(repo: Repo) -> None:
- "Guard against non-git repos."
+ """Guard against non-git repos."""
is_git = subprocess.run(
["git", "rev-parse"],
- stderr=subprocess.PIPE,
- stdout=subprocess.PIPE,
+ capture_output=True,
check=False,
).returncode
if is_git != 0:
@@ -219,25 +228,24 @@ def guard_git(repo: Repo) -> None:
if __name__ == "__main__":
ARGS = get_args()
if ARGS.test:
- print("ok")
+ sys.stdout.write("ok")
sys.exit()
logging.basicConfig(stream=sys.stderr, level=ARGS.verbosity.upper())
logging.debug("starting")
- os.chdir(os.path.abspath(ARGS.repo))
+ os.chdir(pathlib.Path(ARGS.repo).resolve())
guard_git(ARGS.repo)
# if no active users provided, load from .mailmap
- if ARGS.active_users == []:
- if os.path.exists(".mailmap"):
- ARGS.active_users = mailmap_users()
+ if ARGS.active_users == [] and pathlib.Path(".mailmap").exists():
+ ARGS.active_users = mailmap_users()
# collect data
REPO = Repo(ARGS.ignored, ARGS.active_users)
# print data
REPO.print_score()
- REPO.print_blackholes(ARGS.blackholes)
- REPO.print_liabilities(ARGS.liabilities)
- REPO.print_stale(ARGS.stale)
+ REPO.print_blackholes(full=ARGS.blackholes)
+ REPO.print_liabilities(full=ARGS.liabilities)
+ REPO.print_stale(full=ARGS.stale)
diff --git a/Biz/Ide/repl.sh b/Biz/Ide/repl.sh
index 78fe1eb..8b28dcd 100755
--- a/Biz/Ide/repl.sh
+++ b/Biz/Ide/repl.sh
@@ -33,16 +33,16 @@ fi
packageSet=$(jq --raw-output '.[].packageSet' <<< "$json")
module=$(jq --raw-output '.[].mainModule' <<< "$json")
BILD="(import ${CODEROOT:?}/Biz/Bild.nix {})"
- declare -a flags=(--packages "$BILD.pkgs.pkg-config")
+ declare -a flags=(--packages "$BILD.bild.pkgs.pkg-config")
for lib in "${sysdeps[@]}"; do
- flags+=(--packages "$BILD.pkgs.${lib}")
+ flags+=(--packages "$BILD.bild.pkgs.${lib}")
done
for lib in "${rundeps[@]}"; do
- flags+=(--packages "$BILD.pkgs.${lib}")
+ flags+=(--packages "$BILD.bild.pkgs.${lib}")
done
case $exts in
C)
- flags+=(--packages "$BILD.pkgs.gcc")
+ flags+=(--packages "$BILD.bild.pkgs.gcc")
command="bash"
;;
Hs)
diff --git a/Biz/Lint.hs b/Biz/Lint.hs
index d27ca1d..d387db0 100644
--- a/Biz/Lint.hs
+++ b/Biz/Lint.hs
@@ -10,7 +10,6 @@
-- : out lint
-- : run ormolu
-- : run hlint
--- : run black
-- : run ruff
-- : run deadnix
-- : run shellcheck
@@ -138,7 +137,7 @@ data Linter = Linter
fixArgs :: Maybe [Text],
-- | An optional function to format the output of the linter as you want
-- it, perhaps decoding json or something
- formatter :: Maybe (String -> String)
+ decoder :: Maybe (String -> String)
}
ormolu :: Linter
@@ -147,7 +146,7 @@ ormolu =
{ exe = "ormolu",
checkArgs = ["--mode", "check", "--no-cabal"],
fixArgs = Just ["--mode", "inplace", "--no-cabal"],
- formatter = Nothing
+ decoder = Nothing
}
hlint :: Linter
@@ -158,16 +157,16 @@ hlint =
-- needs apply-refact >0.9.1.0, which needs ghc >9
-- fixArgs = Just ["--refactor", "--refactor-options=-i"]
fixArgs = Nothing,
- formatter = Nothing
+ decoder = Nothing
}
-black :: Linter
-black =
+ruffFormat :: Linter
+ruffFormat =
Linter
- { exe = "black",
- checkArgs = ["--check"],
- fixArgs = Just [],
- formatter = Nothing
+ { exe = "ruff",
+ checkArgs = ["format", "--check", "--silent"],
+ fixArgs = Just ["format", "--silent"],
+ decoder = Nothing
}
ruff :: Linter
@@ -176,7 +175,7 @@ ruff =
{ exe = "ruff",
checkArgs = ["check"],
fixArgs = Just ["check", "--fix"],
- formatter = Nothing
+ decoder = Nothing
}
data DeadnixOutput = DeadnixOutput
@@ -199,7 +198,7 @@ deadnix =
{ exe = "deadnix",
checkArgs = "--fail" : commonArgs,
fixArgs = Just <| "--edit" : commonArgs,
- formatter = Just decodeDeadnixOutput
+ decoder = Just decodeDeadnixOutput
}
where
commonArgs =
@@ -227,7 +226,7 @@ nixfmt =
{ exe = "nixfmt",
checkArgs = ["--check"],
fixArgs = Nothing,
- formatter = Nothing
+ decoder = Nothing
}
shellcheck :: Linter
@@ -236,7 +235,7 @@ shellcheck =
{ exe = "shellcheck",
checkArgs = [],
fixArgs = Nothing,
- formatter = Nothing
+ decoder = Nothing
}
indent :: Linter
@@ -245,7 +244,7 @@ indent =
{ exe = "indent",
checkArgs = [],
fixArgs = Nothing,
- formatter = Nothing
+ decoder = Nothing
}
data Status = Good | Bad String
@@ -272,7 +271,7 @@ runOne mode (ext, ns's) = results +> traverse printResult
lint mode hlint ns's
]
Namespace.Py ->
- [ lint mode black ns's,
+ [ lint mode ruffFormat ns's,
lint mode ruff ns's
]
Namespace.Sh -> [lint mode shellcheck ns's]
@@ -294,7 +293,7 @@ lint mode linter@Linter {..} ns's =
>> Process.readProcessWithExitCode (str exe) args "" /> \case
(Exit.ExitSuccess, _, _) ->
Done linter Good
- (Exit.ExitFailure _, msg, _) -> case formatter of
+ (Exit.ExitFailure _, msg, _) -> case decoder of
Nothing -> Done linter <| Bad msg
Just fmt -> Done linter <| Bad <| fmt msg
where
diff --git a/Biz/Llamacpp.py b/Biz/Llamacpp.py
index e75de5b..cd47e1e 100644
--- a/Biz/Llamacpp.py
+++ b/Biz/Llamacpp.py
@@ -1,6 +1,4 @@
-"""
-Llamacpp
-"""
+"""Llamacpp."""
# : run nixos-23_11.llama-cpp
# : run nixos-23_11.openblas
diff --git a/Biz/Log.py b/Biz/Log.py
index af28c41..68b1e90 100644
--- a/Biz/Log.py
+++ b/Biz/Log.py
@@ -1,21 +1,24 @@
-"""
-Setup logging like Biz/Log.hs.
-"""
+"""Setup logging like Biz/Log.hs."""
+# ruff: noqa: A003
import logging
import typing
class LowerFormatter(logging.Formatter):
- def format(self, record: typing.Any) -> typing.Any:
+ """A logging formatter that formats logs how I like."""
+
+ def format(self: "LowerFormatter", record: typing.Any) -> typing.Any:
+ """Use the format I like for logging."""
record.levelname = record.levelname.lower()
- return super(logging.Formatter, self).format(record) # type: ignore
+ return super(logging.Formatter, self).format(record) # type: ignore[misc]
def setup() -> None:
- "Run this in your __main__ function"
+ """Run this in your __main__ function."""
logging.basicConfig(
- level=logging.DEBUG, format="%(levelname)s: %(name)s: %(message)s"
+ level=logging.DEBUG,
+ format="%(levelname)s: %(name)s: %(message)s",
)
logging.addLevelName(logging.DEBUG, "dbug")
logging.addLevelName(logging.ERROR, "fail")
diff --git a/Biz/Mynion.py b/Biz/Mynion.py
index d3dabcf..3b80f5f 100644
--- a/Biz/Mynion.py
+++ b/Biz/Mynion.py
@@ -1,38 +1,51 @@
-"""
-Mynion is a helper.
-"""
+"""Mynion is a helper."""
+
# : out mynion
# : dep exllama
# : dep slixmpp
import argparse
-import exllama # type: ignore
-import Biz.Log
-import glob
+import dataclasses
import logging
import os
+import pathlib
+import sys
+import typing
+
+import exllama # type: ignore[import]
import slixmpp
import slixmpp.exceptions
-import sys
import torch
-import typing
+
+import Biz.Log
def smoosh(s: str) -> str:
+ """Replace newlines with spaces."""
return s.replace("\n", " ")
+@dataclasses.dataclass
+class Auth:
+ """Container for XMPP authentication."""
+
+ jid: str
+ password: str
+
+
class Mynion(slixmpp.ClientXMPP):
+ """A helper via xmpp."""
+
def __init__(
- self,
- jid: str,
- password: str,
+ self: "Mynion",
+ auth: Auth,
model: exllama.model.ExLlama,
tokenizer: exllama.tokenizer.ExLlamaTokenizer,
generator: exllama.generator.ExLlamaGenerator,
) -> None:
- slixmpp.ClientXMPP.__init__(self, jid, password)
- self.plugin.enable("xep_0085") # type: ignore
- self.plugin.enable("xep_0184") # type: ignore
+ """Initialize Mynion chat bot service."""
+ slixmpp.ClientXMPP.__init__(self, auth.jid, auth.password)
+ self.plugin.enable("xep_0085") # type: ignore[attr-defined]
+ self.plugin.enable("xep_0184") # type: ignore[attr-defined]
self.name = "mynion"
self.user = "ben"
@@ -40,7 +53,6 @@ class Mynion(slixmpp.ClientXMPP):
self.min_response_tokens = 4
self.max_response_tokens = 256
self.extra_prune = 256
- # todo: handle summary rollup when max_seq_len is reached
self.max_seq_len = 8000
self.model = model
@@ -49,9 +61,8 @@ class Mynion(slixmpp.ClientXMPP):
root = os.getenv("CODEROOT", "")
# this should be parameterized somehow
- with open(os.path.join(root, "Biz", "Mynion", "Prompt.md")) as f:
- txt = f.read()
- txt = txt.format(user=self.user, name=self.name)
+ promptfile = pathlib.Path(root) / "Biz" / "Mynion" / "Prompt.md"
+ txt = promptfile.read_text().format(user=self.user, name=self.name)
# this is the "system prompt", ideally i would load this in/out of a
# database with all of the past history. if the history gets too long, i
@@ -64,20 +75,22 @@ class Mynion(slixmpp.ClientXMPP):
self.add_event_handler("session_start", self.session_start)
self.add_event_handler("message", self.message)
- def session_start(self) -> None:
+ def session_start(self: "Mynion") -> None:
+ """Start online session with xmpp server."""
self.send_presence()
try:
- self.get_roster() # type: ignore
+ self.get_roster() # type: ignore[no-untyped-call]
except slixmpp.exceptions.IqError as err:
- logging.error("There was an error getting the roster")
- logging.error(err.iq["error"]["condition"])
+ logging.exception("There was an error getting the roster")
+ logging.exception(err.iq["error"]["condition"])
self.disconnect()
except slixmpp.exceptions.IqTimeout:
- logging.error("Server is taking too long to respond")
+ logging.exception("Server is taking too long to respond")
self.disconnect()
- def message(self, msg: slixmpp.Message) -> None:
- if msg["type"] in ("chat", "normal"):
+ def message(self: "Mynion", msg: slixmpp.Message) -> None:
+ """Send a message."""
+ if msg["type"] in {"chat", "normal"}:
res_line = f"{self.name}: "
res_tokens = self.tokenizer.encode(res_line)
num_res_tokens = res_tokens.shape[-1]
@@ -109,7 +122,6 @@ class Mynion(slixmpp.ClientXMPP):
self.generator.begin_beam_search()
# generate tokens, with streaming
- # TODO: drop the streaming!
for i in range(self.max_response_tokens):
# disallowing the end condition tokens seems like a clean way to
# force longer replies
@@ -118,7 +130,7 @@ class Mynion(slixmpp.ClientXMPP):
[
self.tokenizer.newline_token_id,
self.tokenizer.eos_token_id,
- ]
+ ],
)
else:
self.generator.disallow_tokens(None)
@@ -129,13 +141,13 @@ class Mynion(slixmpp.ClientXMPP):
# if token is EOS, replace it with a newline before continuing
if gen_token.item() == self.tokenizer.eos_token_id:
self.generator.replace_last_token(
- self.tokenizer.newline_token_id
+ self.tokenizer.newline_token_id,
)
# decode the current line
num_res_tokens += 1
text = self.tokenizer.decode(
- self.generator.sequence_actual[:, -num_res_tokens:][0]
+ self.generator.sequence_actual[:, -num_res_tokens:][0],
)
# append to res_line
@@ -161,7 +173,7 @@ class Mynion(slixmpp.ClientXMPP):
res_line = res_line.removeprefix(f"{self.name}:")
res_line = res_line.removesuffix(f"{self.user}:")
self.first_round = False
- msg.reply(res_line).send() # type: ignore
+ msg.reply(res_line).send() # type: ignore[no-untyped-call]
MY_MODELS = [
@@ -178,26 +190,31 @@ MY_MODELS = [
def load_model(model_name: str) -> typing.Any:
- assert model_name in MY_MODELS
+ """Load an ML model from disk."""
+ if model_name not in MY_MODELS:
+ msg = f"{model_name} not available"
+ raise ValueError(msg)
if not torch.cuda.is_available():
- raise ValueError("no cuda")
+ msg = "no cuda"
+ raise ValueError(msg)
sys.exit(1)
- torch.set_grad_enabled(False)
- torch.cuda._lazy_init() # type: ignore
+ torch.set_grad_enabled(mode=False)
+ torch.cuda.init() # type: ignore[no-untyped-call]
ml_models = "/mnt/campbell/ben/ml-models"
- model_dir = os.path.join(ml_models, model_name)
+ model_dir = pathlib.Path(ml_models) / model_name
- tokenizer_path = os.path.join(model_dir, "tokenizer.model")
- config_path = os.path.join(model_dir, "config.json")
- st_pattern = os.path.join(model_dir, "*.safetensors")
- st = glob.glob(st_pattern)
+ tokenizer_path = pathlib.Path(model_dir) / "tokenizer.model"
+ config_path = pathlib.Path(model_dir) / "config.json"
+ st = list(pathlib.Path(model_dir).glob("*.safetensors"))
if len(st) > 1:
- raise ValueError("found multiple safetensors!")
- elif len(st) < 1:
- raise ValueError("could not find model")
+ msg = "found multiple safetensors!"
+ raise ValueError(msg)
+ if len(st) < 1:
+ msg = "could not find model"
+ raise ValueError(msg)
model_path = st[0]
config = exllama.model.ExLlamaConfig(config_path)
@@ -230,14 +247,15 @@ def main(
model in the repl and then restart the chatbot without unloading it.
"""
Biz.Log.setup()
- xmpp = Mynion(user, password, model, tokenizer, generator)
+ auth = Auth(user, password)
+ xmpp = Mynion(auth, model, tokenizer, generator)
xmpp.connect()
- xmpp.process(forever=True) # type: ignore
+ xmpp.process(forever=True) # type: ignore[no-untyped-call]
if __name__ == "__main__":
if "test" in sys.argv:
- print("pass: test: Biz/Mynion.py")
+ sys.stdout.write("pass: test: Biz/Mynion.py\n")
sys.exit(0)
else:
cli = argparse.ArgumentParser(description=__doc__)
diff --git a/Biz/Que/Client.py b/Biz/Que/Client.py
index 20349b6..53f14e4 100755..100644
--- a/Biz/Que/Client.py
+++ b/Biz/Que/Client.py
@@ -1,22 +1,20 @@
-#!/usr/bin/env python3
# : out que
-"""
-simple client for que.run
-"""
+# ruff: noqa: PERF203
+"""simple client for que.run."""
import argparse
import configparser
import functools
import http.client
import logging
-import os
+import pathlib
import subprocess
import sys
import textwrap
import time
-import urllib.parse
-import urllib.request as request
import typing
+import urllib.parse
+from urllib import request
MAX_TIMEOUT = 9999999
RETRIES = 10
@@ -24,14 +22,14 @@ DELAY = 3
BACKOFF = 1
-def auth(args: argparse.Namespace) -> typing.Union[str, None]:
- "Returns the auth key for the given ns from ~/.config/que.conf"
+def auth(args: argparse.Namespace) -> str | None:
+ """Return the auth key for the given ns from ~/.config/que.conf."""
logging.debug("auth")
namespace = args.target.split("/")[0]
if namespace == "pub":
return None
- conf_file = os.path.expanduser("~/.config/que.conf")
- if not os.path.exists(conf_file):
+ conf_file = pathlib.Path("~/.config/que.conf").expanduser()
+ if not conf_file.exists():
sys.exit("you need a ~/.config/que.conf")
cfg = configparser.ConfigParser()
cfg.read(conf_file)
@@ -39,8 +37,11 @@ def auth(args: argparse.Namespace) -> typing.Union[str, None]:
def autodecode(bytestring: bytes) -> typing.Any:
- """Attempt to decode bytes into common codecs, preferably utf-8. If no
- decoding is available, just return the raw bytes.
+ """
+ Automatically decode bytes into common codecs.
+
+ Or at least make an attempt. Output is preferably utf-8. If no decoding is
+ available, just return the raw bytes.
For all available codecs, see:
<https://docs.python.org/3/library/codecs.html#standard-encodings>
@@ -63,7 +64,7 @@ def retry(
delay: typing.Any = DELAY,
backoff: typing.Any = BACKOFF,
) -> typing.Any:
- "Decorator for retrying an action."
+ """Retry an action."""
def decorator(func: typing.Any) -> typing.Any:
@functools.wraps(func)
@@ -90,7 +91,7 @@ def retry(
@retry(http.client.IncompleteRead)
@retry(http.client.RemoteDisconnected)
def send(args: argparse.Namespace) -> None:
- "Send a message to the que."
+ """Send a message to the que."""
logging.debug("send")
key = auth(args)
data = args.infile
@@ -108,13 +109,13 @@ def send(args: argparse.Namespace) -> None:
def then(args: argparse.Namespace, msg: str) -> None:
- "Perform an action when passed `--then`."
+ """Perform an action when passed `--then`."""
if args.then:
logging.debug("then")
subprocess.run(
args.then.format(msg=msg, que=args.target),
check=False,
- shell=True,
+ shell=True, # noqa: S602
)
@@ -123,7 +124,7 @@ def then(args: argparse.Namespace, msg: str) -> None:
@retry(http.client.IncompleteRead)
@retry(http.client.RemoteDisconnected)
def recv(args: argparse.Namespace) -> None:
- "Receive a message from the que."
+ """Receive a message from the que."""
logging.debug("recv on: %s", args.target)
if args.poll:
req = request.Request(f"{args.host}/{args.target}/stream")
@@ -133,6 +134,9 @@ def recv(args: argparse.Namespace) -> None:
key = auth(args)
if key:
req.add_header("Authorization", key)
+ if not req.startswith(("http:", "https:")):
+ msg = "URL must start with 'http:' or 'https:'"
+ raise ValueError(msg)
with request.urlopen(req) as _req:
if args.poll:
logging.debug("polling")
@@ -141,29 +145,31 @@ def recv(args: argparse.Namespace) -> None:
if reply:
msg = autodecode(reply)
logging.debug("read")
- print(msg, end="")
+ sys.stdout.write(msg)
then(args, msg)
else:
continue
else:
msg = autodecode(_req.readline())
- print(msg)
+ sys.stdout.write(msg)
then(args, msg)
def get_args() -> argparse.Namespace:
- "Command line parser"
+ """Command line parser."""
cli = argparse.ArgumentParser(
description=__doc__,
epilog=textwrap.dedent(
f"""Requests will retry up to {RETRIES} times, with {DELAY} seconds
- between attempts."""
+ between attempts.""",
),
)
cli.add_argument("test", action="store_true", help="run tests")
cli.add_argument("--debug", action="store_true", help="log to stderr")
cli.add_argument(
- "--host", default="http://que.run", help="where que-server is running"
+ "--host",
+ default="http://que.run",
+ help="where que-server is running",
)
cli.add_argument(
"--poll",
@@ -171,7 +177,7 @@ def get_args() -> argparse.Namespace:
action="store_true",
help=textwrap.dedent(
"""keep the connection open to stream data from the que. without
- this flag, the program will exit after receiving a message"""
+ this flag, the program will exit after receiving a message""",
),
)
cli.add_argument(
@@ -179,7 +185,7 @@ def get_args() -> argparse.Namespace:
help=textwrap.dedent(
"""when polling, run this shell command after each response,
presumably for side effects, replacing '{que}' with the target and
- '{msg}' with the body of the response"""
+ '{msg}' with the body of the response""",
),
)
cli.add_argument(
@@ -188,22 +194,18 @@ def get_args() -> argparse.Namespace:
action="store_true",
help=textwrap.dedent(
"""when posting to the que, do so continuously in a loop. this can
- be used for serving a webpage or other file continuously"""
+ be used for serving a webpage or other file continuously""",
),
)
cli.add_argument(
- "target", help="namespace and path of the que, like 'ns/path'"
+ "target",
+ help="namespace and path of the que, like 'ns/path'",
)
cli.add_argument(
"infile",
nargs="?",
type=argparse.FileType("rb"),
- help=" ".join(
- [
- "data to put on the que.",
- "use '-' for stdin, otherwise should be a readable file",
- ]
- ),
+ help="file of data to put on the que. use '-' for stdin",
)
return cli.parse_args()
@@ -211,7 +213,7 @@ def get_args() -> argparse.Namespace:
if __name__ == "__main__":
ARGV = get_args()
if ARGV.test:
- print("ok")
+ sys.stdout.write("ok\n")
sys.exit()
if ARGV.debug:
logging.basicConfig(
diff --git a/Biz/Repl.py b/Biz/Repl.py
index 9844abf..9cc0c35 100644
--- a/Biz/Repl.py
+++ b/Biz/Repl.py
@@ -1,19 +1,25 @@
"""
+Improve the standard Python REPL.
+
This module attempts to emulate the workflow of ghci or lisp repls. It uses
importlib to load a namespace from the given path. It then binds 'r()' to a
function that reloads the same namespace.
"""
import importlib
+import logging
import sys
+from Biz import Log
+
def use(ns: str, path: str) -> None:
"""
- Load or reload the module named 'ns' from 'path'. Like `use` in the Guile
- Scheme repl.
+ Load or reload the module named 'ns' from 'path'.
+
+ Like `use` in the Guile Scheme repl.
"""
- info(f"loading {ns} from {path}")
+ logging.info("loading %s from %s", ns, path)
spec = importlib.util.spec_from_file_location(ns, path)
module = importlib.util.module_from_spec(spec)
# delete module and its imported names if its already loaded
@@ -24,25 +30,24 @@ def use(ns: str, path: str) -> None:
del globals()[name]
sys.modules[ns] = module
spec.loader.exec_module(module)
- names = [x for x in module.__dict__]
+ names = list(module.__dict__)
globals().update({k: getattr(module, k) for k in names})
-def info(s):
- print(f"info: repl: {s}")
-
-
if __name__ == "__main__":
+ Log.setup()
NS = sys.argv[1]
PATH = sys.argv[2]
use(NS, PATH)
- info("use reload() or _r() after making changes")
+ logging.info("use reload() or _r() after making changes")
sys.ps1 = f"{NS}> "
sys.ps2 = f"{NS}| "
- def reload():
+ def reload() -> None:
+ """Reload the namespace."""
return use(NS, PATH)
- def _r():
+ def _r() -> None:
+ """Shorthand: Reload the namespace."""
return use(NS, PATH)