From 2c09c7f73e2fc770f42b5dd2588aa9634b4e7c6e Mon Sep 17 00:00:00 2001 From: Ben Sima Date: Wed, 10 Apr 2024 19:56:46 -0400 Subject: Switch from black to ruff format Ruff is faster and if it supports everything that black supports than why not? I did have to pull in a more recent version from unstable, but that's easy to do now. And I decided to just go ahead and configure ruff by turning on almost all checks, which meant I had to fix a whole bunch of things, but I did that and everything is okay now. --- Biz/Bild.nix | 4 +- Biz/Bild/Builder.nix | 22 ++++----- Biz/Bild/Example.py | 22 ++++++--- Biz/Dragons/main.py | 132 +++++++++++++++++++++++++++------------------------ Biz/Ide/repl.sh | 8 ++-- Biz/Lint.hs | 33 +++++++------ Biz/Llamacpp.py | 4 +- Biz/Log.py | 17 ++++--- Biz/Mynion.py | 108 +++++++++++++++++++++++------------------ Biz/Que/Client.py | 70 ++++++++++++++------------- Biz/Repl.py | 27 ++++++----- pyproject.toml | 33 +++++++++++-- 12 files changed, 275 insertions(+), 205 deletions(-) mode change 100755 => 100644 Biz/Dragons/main.py mode change 100755 => 100644 Biz/Que/Client.py diff --git a/Biz/Bild.nix b/Biz/Bild.nix index 24ce4bf..7a07f36 100644 --- a/Biz/Bild.nix +++ b/Biz/Bild.nix @@ -64,7 +64,8 @@ in nixpkgs // { # expose some packages for inclusion in os/image builds pkgs = with nixpkgs.pkgs; { - inherit black deadnix git hlint indent ormolu ruff shellcheck nixfmt; + inherit deadnix git hlint indent ormolu shellcheck nixfmt mypy pkg-config; + ruff = unstable.ruff; }; # a standard nix build for bild, for bootstrapping. this should be the only @@ -146,7 +147,6 @@ in nixpkgs // { bat bc bild - black ctags fd figlet diff --git a/Biz/Bild/Builder.nix b/Biz/Bild/Builder.nix index d2e6875..4bef830 100644 --- a/Biz/Bild/Builder.nix +++ b/Biz/Bild/Builder.nix @@ -110,14 +110,14 @@ let inherit name src CODEROOT; propagatedBuildInputs = langdeps_ ++ sysdeps_; buildInputs = sysdeps_; - nativeCheckInputs = [ black mypy ruff ]; + nativeCheckInputs = lib.attrsets.attrVals [ "mypy" "ruff" ] bild.pkgs; checkPhase = '' check() { $@ || { echo "fail: $name: $3"; exit 1; } } cp ${../../pyproject.toml} ./pyproject.toml - check python -m black --quiet --exclude 'setup\.py$' --check . - check ${ruff}/bin/ruff check . + check ruff format --exclude 'setup.py' --check . + check ruff check --exclude 'setup.py' --exclude '__init__.py' . touch ./py.typed check python -m mypy \ --explicit-package-bases \ @@ -133,15 +133,15 @@ let find . -type d -exec touch {}/__init__.py \; # generate a minimal setup.py cat > setup.py << EOF - from setuptools import setup, find_packages + from setuptools import find_packages, setup setup( - name='${name}', - entry_points={'console_scripts':['${name} = ${mainModule}:main']}, - version='0.0.0', - url='git://simatime.com/biz.git', - author='dev', - author_email='dev@simatime.com', - description='nil', + name="${name}", + entry_points={"console_scripts":["${name} = ${mainModule}:main"]}, + version="0.0.0", + url="git://simatime.com/biz.git", + author="dev", + author_email="dev@simatime.com", + description="nil", packages=find_packages(), install_requires=[], ) diff --git a/Biz/Bild/Example.py b/Biz/Bild/Example.py index 5d165d8..1bd30ae 100644 --- a/Biz/Bild/Example.py +++ b/Biz/Bild/Example.py @@ -1,32 +1,42 @@ """ +Test that bild can build Python stuff. + Example Python file that also serves as a test case for bild. """ + # : out example # : dep cryptography import sys + import cryptography.fernet def cryptic_hello(name: str) -> str: - "Example taken from `cryptography` docs." + """ + Encrypt and decrypt `name`. + + Example taken from `cryptography` docs. + """ key = cryptography.fernet.Fernet.generate_key() f = cryptography.fernet.Fernet(key) token = f.encrypt(hello(name).encode("utf-8")) ret = f.decrypt(token).decode("utf-8") - assert ret == hello(name) + if ret != hello(name): + msg = "en/decryption failed!" + raise ValueError(msg) return ret def hello(name: str) -> str: - "Say hello" + """Say hello.""" return f"Hello {name}" def main() -> None: - "Entrypoint" + """Entrypoint.""" if "test" in sys.argv: - print("testing success") - print(cryptic_hello("world")) + sys.stdout.write("testing success") + sys.stdout.write(cryptic_hello("world")) if __name__ == "__main__": diff --git a/Biz/Dragons/main.py b/Biz/Dragons/main.py old mode 100755 new mode 100644 index 7a94f99..6e2c995 --- a/Biz/Dragons/main.py +++ b/Biz/Dragons/main.py @@ -1,13 +1,11 @@ -#!/usr/bin/env python # : out dragons.py -""" -Analyze developer allocation across a codebase. -""" +"""Analyze developer allocation across a codebase.""" import argparse import datetime import logging import os +import pathlib import re import subprocess import sys @@ -15,18 +13,26 @@ import typing def find_user(line: str) -> typing.Any: - """Given 'Ben Sima ', finds `Ben Sima'. Returns the first - matching string.""" + """ + Find a person's name in a .mailmap file. + + Given 'Ben Sima ', finds `Ben Sima'. Returns the first + matching string. + """ return re.findall(r"^[^<]*", line)[0].strip() def authors_for( - path: str, active_users: typing.List[str] -) -> typing.Dict[str, str]: - """Return a dictionary of {author: commits} for given path. Usernames not in - the 'active_users' list will be filtered out.""" + path: str, + active_users: list[str], +) -> dict[str, str]: + """ + Return a dictionary of {author: commits} for given path. + + Usernames not in the 'active_users' list will be filtered out. + """ raw = subprocess.check_output( - ["git", "shortlog", "--numbered", "--summary", "--email", "--", path] + ["git", "shortlog", "--numbered", "--summary", "--email", "--", path], ).decode("utf-8") lines = [s for s in raw.split("\n") if s] data = {} @@ -39,21 +45,18 @@ def authors_for( return data -def mailmap_users() -> typing.List[str]: - """Returns users from the .mailmap file.""" - users = [] - with open(".mailmap", encoding="utf-8") as file: +def mailmap_users() -> list[str]: + """Return users from the .mailmap file.""" + with pathlib.Path(".mailmap").open() as file: lines = file.readlines() - for line in lines: - users.append(find_user(line)) - return users + return [find_user(line) for line in lines] MAX_SCORE = 10 def score(blackhole: float, liability: float, good: int, total: int) -> float: - "Calculate the score." + """Calculate the score.""" weights = { "blackhole": 0.5, "liability": 0.7, @@ -70,17 +73,20 @@ def score(blackhole: float, liability: float, good: int, total: int) -> float: def get_args() -> typing.Any: - "Parse CLI arguments." + """Parse CLI arguments.""" cli = argparse.ArgumentParser(description=__doc__) cli.add_argument("test", action="store_true", help="run the test suite") cli.add_argument( - "repo", default=".", help="the git repo to run on", metavar="REPO" + "repo", + default=".", + help="the git repo to run on", + metavar="REPO", ) cli.add_argument( "-b", "--blackholes", action="store_true", - help="print the blackholes (files with one or zero active contributors)", + help="print the blackholes (files with 1 or 0 active contributors)", ) cli.add_argument( "-l", @@ -105,12 +111,7 @@ def get_args() -> typing.Any: "--active-users", nargs="+", default=[], - help=" ".join( - [ - "list of active user emails." - "if not provided, this is loaded from .mailmap" - ] - ), + help="list of active user emails. default: loaded from .mailmap", ) cli.add_argument( "-v", @@ -123,7 +124,7 @@ def get_args() -> typing.Any: def staleness(path: str, now: datetime.datetime) -> int: - "How long has it been since this file was touched?" + """How long has it been since this file was touched?.""" timestamp = datetime.datetime.strptime( subprocess.check_output(["git", "log", "-n1", "--pretty=%aI", path]) .decode("utf-8") @@ -135,15 +136,18 @@ def staleness(path: str, now: datetime.datetime) -> int: class Repo: - "Represents a repo and stats for the repo." + """Represents a repo and stats for the repo.""" def __init__( - self, ignored_paths: typing.List[str], active_users: typing.List[str] + self: "Repo", + ignored_paths: list[str], + active_users: list[str], ) -> None: + """Create analysis of a git repo.""" self.paths = [ p for p in subprocess.check_output( - ["git", "ls-files", "--no-deleted"] + ["git", "ls-files", "--no-deleted"], ) .decode("utf-8") .split() @@ -156,60 +160,65 @@ class Repo: self.blackholes = [ path for path, authors in self.stats.items() if not authors ] + max_authors = 3 self.liabilities = { path: list(authors) for path, authors in self.stats.items() - if 1 <= len(authors) < 3 + if 1 <= len(authors) < max_authors } now = datetime.datetime.utcnow().astimezone() self.stale = {} - for path, _ in self.stats.items(): + max_staleness = 180 + for path in self.stats: _staleness = staleness(path, now) - if _staleness > 180: + if _staleness > max_staleness: self.stale[path] = _staleness - def print_blackholes(self, full: bool) -> None: - "Print number of blackholes, or list of all blackholes." + def print_blackholes(self: "Repo", *, full: bool) -> None: + """Print number of blackholes, or list of all blackholes.""" # note: file renames may result in false positives n_blackhole = len(self.blackholes) - print(f"Blackholes: {n_blackhole}") + sys.stdout.write(f"Blackholes: {n_blackhole}") if full: for path in self.blackholes: - print(f" {path}") + sys.stdout.write(f" {path}") + sys.stdout.flush() - def print_liabilities(self, full: bool) -> None: - "Print number of liabilities, or list of all liabilities." + def print_liabilities(self: "Repo", *, full: bool) -> None: + """Print number of liabilities, or list of all liabilities.""" n_liabilities = len(self.liabilities) - print(f"Liabilities: {n_liabilities}") + sys.stdout.write(f"Liabilities: {n_liabilities}") if full: for path, authors in self.liabilities.items(): - print(f" {path} ({', '.join(authors)})") + sys.stdout.write(f" {path} ({', '.join(authors)})") + sys.stdout.flush() - def print_score(self) -> None: - "Print the overall score." + def print_score(self: "Repo") -> None: + """Print the overall score.""" n_total = len(self.stats.keys()) n_blackhole = len(self.blackholes) n_liabilities = len(self.liabilities) n_good = n_total - n_blackhole - n_liabilities - print("Total:", n_total) + sys.stdout.write(f"Total: {n_total}") this_score = score(n_blackhole, n_liabilities, n_good, n_total) - print(f"Score: {this_score:.2f}/{MAX_SCORE}".format()) + sys.stdout.write(f"Score: {this_score:.2f}/{MAX_SCORE}".format()) + sys.stdout.flush() - def print_stale(self, full: bool) -> None: - "Print stale files" + def print_stale(self: "Repo", *, full: bool) -> None: + """Print stale files.""" n_stale = len(self.stale) - print(f"Stale files: {n_stale}") + sys.stdout.write(f"Stale files: {n_stale}") if full: for path, days in self.stale.items(): - print(f" {path} ({days} days)") + sys.stdout.write(f" {path} ({days} days)") + sys.stdout.flush() def guard_git(repo: Repo) -> None: - "Guard against non-git repos." + """Guard against non-git repos.""" is_git = subprocess.run( ["git", "rev-parse"], - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, + capture_output=True, check=False, ).returncode if is_git != 0: @@ -219,25 +228,24 @@ def guard_git(repo: Repo) -> None: if __name__ == "__main__": ARGS = get_args() if ARGS.test: - print("ok") + sys.stdout.write("ok") sys.exit() logging.basicConfig(stream=sys.stderr, level=ARGS.verbosity.upper()) logging.debug("starting") - os.chdir(os.path.abspath(ARGS.repo)) + os.chdir(pathlib.Path(ARGS.repo).resolve()) guard_git(ARGS.repo) # if no active users provided, load from .mailmap - if ARGS.active_users == []: - if os.path.exists(".mailmap"): - ARGS.active_users = mailmap_users() + if ARGS.active_users == [] and pathlib.Path(".mailmap").exists(): + ARGS.active_users = mailmap_users() # collect data REPO = Repo(ARGS.ignored, ARGS.active_users) # print data REPO.print_score() - REPO.print_blackholes(ARGS.blackholes) - REPO.print_liabilities(ARGS.liabilities) - REPO.print_stale(ARGS.stale) + REPO.print_blackholes(full=ARGS.blackholes) + REPO.print_liabilities(full=ARGS.liabilities) + REPO.print_stale(full=ARGS.stale) diff --git a/Biz/Ide/repl.sh b/Biz/Ide/repl.sh index 78fe1eb..8b28dcd 100755 --- a/Biz/Ide/repl.sh +++ b/Biz/Ide/repl.sh @@ -33,16 +33,16 @@ fi packageSet=$(jq --raw-output '.[].packageSet' <<< "$json") module=$(jq --raw-output '.[].mainModule' <<< "$json") BILD="(import ${CODEROOT:?}/Biz/Bild.nix {})" - declare -a flags=(--packages "$BILD.pkgs.pkg-config") + declare -a flags=(--packages "$BILD.bild.pkgs.pkg-config") for lib in "${sysdeps[@]}"; do - flags+=(--packages "$BILD.pkgs.${lib}") + flags+=(--packages "$BILD.bild.pkgs.${lib}") done for lib in "${rundeps[@]}"; do - flags+=(--packages "$BILD.pkgs.${lib}") + flags+=(--packages "$BILD.bild.pkgs.${lib}") done case $exts in C) - flags+=(--packages "$BILD.pkgs.gcc") + flags+=(--packages "$BILD.bild.pkgs.gcc") command="bash" ;; Hs) diff --git a/Biz/Lint.hs b/Biz/Lint.hs index d27ca1d..d387db0 100644 --- a/Biz/Lint.hs +++ b/Biz/Lint.hs @@ -10,7 +10,6 @@ -- : out lint -- : run ormolu -- : run hlint --- : run black -- : run ruff -- : run deadnix -- : run shellcheck @@ -138,7 +137,7 @@ data Linter = Linter fixArgs :: Maybe [Text], -- | An optional function to format the output of the linter as you want -- it, perhaps decoding json or something - formatter :: Maybe (String -> String) + decoder :: Maybe (String -> String) } ormolu :: Linter @@ -147,7 +146,7 @@ ormolu = { exe = "ormolu", checkArgs = ["--mode", "check", "--no-cabal"], fixArgs = Just ["--mode", "inplace", "--no-cabal"], - formatter = Nothing + decoder = Nothing } hlint :: Linter @@ -158,16 +157,16 @@ hlint = -- needs apply-refact >0.9.1.0, which needs ghc >9 -- fixArgs = Just ["--refactor", "--refactor-options=-i"] fixArgs = Nothing, - formatter = Nothing + decoder = Nothing } -black :: Linter -black = +ruffFormat :: Linter +ruffFormat = Linter - { exe = "black", - checkArgs = ["--check"], - fixArgs = Just [], - formatter = Nothing + { exe = "ruff", + checkArgs = ["format", "--check", "--silent"], + fixArgs = Just ["format", "--silent"], + decoder = Nothing } ruff :: Linter @@ -176,7 +175,7 @@ ruff = { exe = "ruff", checkArgs = ["check"], fixArgs = Just ["check", "--fix"], - formatter = Nothing + decoder = Nothing } data DeadnixOutput = DeadnixOutput @@ -199,7 +198,7 @@ deadnix = { exe = "deadnix", checkArgs = "--fail" : commonArgs, fixArgs = Just <| "--edit" : commonArgs, - formatter = Just decodeDeadnixOutput + decoder = Just decodeDeadnixOutput } where commonArgs = @@ -227,7 +226,7 @@ nixfmt = { exe = "nixfmt", checkArgs = ["--check"], fixArgs = Nothing, - formatter = Nothing + decoder = Nothing } shellcheck :: Linter @@ -236,7 +235,7 @@ shellcheck = { exe = "shellcheck", checkArgs = [], fixArgs = Nothing, - formatter = Nothing + decoder = Nothing } indent :: Linter @@ -245,7 +244,7 @@ indent = { exe = "indent", checkArgs = [], fixArgs = Nothing, - formatter = Nothing + decoder = Nothing } data Status = Good | Bad String @@ -272,7 +271,7 @@ runOne mode (ext, ns's) = results +> traverse printResult lint mode hlint ns's ] Namespace.Py -> - [ lint mode black ns's, + [ lint mode ruffFormat ns's, lint mode ruff ns's ] Namespace.Sh -> [lint mode shellcheck ns's] @@ -294,7 +293,7 @@ lint mode linter@Linter {..} ns's = >> Process.readProcessWithExitCode (str exe) args "" /> \case (Exit.ExitSuccess, _, _) -> Done linter Good - (Exit.ExitFailure _, msg, _) -> case formatter of + (Exit.ExitFailure _, msg, _) -> case decoder of Nothing -> Done linter <| Bad msg Just fmt -> Done linter <| Bad <| fmt msg where diff --git a/Biz/Llamacpp.py b/Biz/Llamacpp.py index e75de5b..cd47e1e 100644 --- a/Biz/Llamacpp.py +++ b/Biz/Llamacpp.py @@ -1,6 +1,4 @@ -""" -Llamacpp -""" +"""Llamacpp.""" # : run nixos-23_11.llama-cpp # : run nixos-23_11.openblas diff --git a/Biz/Log.py b/Biz/Log.py index af28c41..68b1e90 100644 --- a/Biz/Log.py +++ b/Biz/Log.py @@ -1,21 +1,24 @@ -""" -Setup logging like Biz/Log.hs. -""" +"""Setup logging like Biz/Log.hs.""" +# ruff: noqa: A003 import logging import typing class LowerFormatter(logging.Formatter): - def format(self, record: typing.Any) -> typing.Any: + """A logging formatter that formats logs how I like.""" + + def format(self: "LowerFormatter", record: typing.Any) -> typing.Any: + """Use the format I like for logging.""" record.levelname = record.levelname.lower() - return super(logging.Formatter, self).format(record) # type: ignore + return super(logging.Formatter, self).format(record) # type: ignore[misc] def setup() -> None: - "Run this in your __main__ function" + """Run this in your __main__ function.""" logging.basicConfig( - level=logging.DEBUG, format="%(levelname)s: %(name)s: %(message)s" + level=logging.DEBUG, + format="%(levelname)s: %(name)s: %(message)s", ) logging.addLevelName(logging.DEBUG, "dbug") logging.addLevelName(logging.ERROR, "fail") diff --git a/Biz/Mynion.py b/Biz/Mynion.py index d3dabcf..3b80f5f 100644 --- a/Biz/Mynion.py +++ b/Biz/Mynion.py @@ -1,38 +1,51 @@ -""" -Mynion is a helper. -""" +"""Mynion is a helper.""" + # : out mynion # : dep exllama # : dep slixmpp import argparse -import exllama # type: ignore -import Biz.Log -import glob +import dataclasses import logging import os +import pathlib +import sys +import typing + +import exllama # type: ignore[import] import slixmpp import slixmpp.exceptions -import sys import torch -import typing + +import Biz.Log def smoosh(s: str) -> str: + """Replace newlines with spaces.""" return s.replace("\n", " ") +@dataclasses.dataclass +class Auth: + """Container for XMPP authentication.""" + + jid: str + password: str + + class Mynion(slixmpp.ClientXMPP): + """A helper via xmpp.""" + def __init__( - self, - jid: str, - password: str, + self: "Mynion", + auth: Auth, model: exllama.model.ExLlama, tokenizer: exllama.tokenizer.ExLlamaTokenizer, generator: exllama.generator.ExLlamaGenerator, ) -> None: - slixmpp.ClientXMPP.__init__(self, jid, password) - self.plugin.enable("xep_0085") # type: ignore - self.plugin.enable("xep_0184") # type: ignore + """Initialize Mynion chat bot service.""" + slixmpp.ClientXMPP.__init__(self, auth.jid, auth.password) + self.plugin.enable("xep_0085") # type: ignore[attr-defined] + self.plugin.enable("xep_0184") # type: ignore[attr-defined] self.name = "mynion" self.user = "ben" @@ -40,7 +53,6 @@ class Mynion(slixmpp.ClientXMPP): self.min_response_tokens = 4 self.max_response_tokens = 256 self.extra_prune = 256 - # todo: handle summary rollup when max_seq_len is reached self.max_seq_len = 8000 self.model = model @@ -49,9 +61,8 @@ class Mynion(slixmpp.ClientXMPP): root = os.getenv("CODEROOT", "") # this should be parameterized somehow - with open(os.path.join(root, "Biz", "Mynion", "Prompt.md")) as f: - txt = f.read() - txt = txt.format(user=self.user, name=self.name) + promptfile = pathlib.Path(root) / "Biz" / "Mynion" / "Prompt.md" + txt = promptfile.read_text().format(user=self.user, name=self.name) # this is the "system prompt", ideally i would load this in/out of a # database with all of the past history. if the history gets too long, i @@ -64,20 +75,22 @@ class Mynion(slixmpp.ClientXMPP): self.add_event_handler("session_start", self.session_start) self.add_event_handler("message", self.message) - def session_start(self) -> None: + def session_start(self: "Mynion") -> None: + """Start online session with xmpp server.""" self.send_presence() try: - self.get_roster() # type: ignore + self.get_roster() # type: ignore[no-untyped-call] except slixmpp.exceptions.IqError as err: - logging.error("There was an error getting the roster") - logging.error(err.iq["error"]["condition"]) + logging.exception("There was an error getting the roster") + logging.exception(err.iq["error"]["condition"]) self.disconnect() except slixmpp.exceptions.IqTimeout: - logging.error("Server is taking too long to respond") + logging.exception("Server is taking too long to respond") self.disconnect() - def message(self, msg: slixmpp.Message) -> None: - if msg["type"] in ("chat", "normal"): + def message(self: "Mynion", msg: slixmpp.Message) -> None: + """Send a message.""" + if msg["type"] in {"chat", "normal"}: res_line = f"{self.name}: " res_tokens = self.tokenizer.encode(res_line) num_res_tokens = res_tokens.shape[-1] @@ -109,7 +122,6 @@ class Mynion(slixmpp.ClientXMPP): self.generator.begin_beam_search() # generate tokens, with streaming - # TODO: drop the streaming! for i in range(self.max_response_tokens): # disallowing the end condition tokens seems like a clean way to # force longer replies @@ -118,7 +130,7 @@ class Mynion(slixmpp.ClientXMPP): [ self.tokenizer.newline_token_id, self.tokenizer.eos_token_id, - ] + ], ) else: self.generator.disallow_tokens(None) @@ -129,13 +141,13 @@ class Mynion(slixmpp.ClientXMPP): # if token is EOS, replace it with a newline before continuing if gen_token.item() == self.tokenizer.eos_token_id: self.generator.replace_last_token( - self.tokenizer.newline_token_id + self.tokenizer.newline_token_id, ) # decode the current line num_res_tokens += 1 text = self.tokenizer.decode( - self.generator.sequence_actual[:, -num_res_tokens:][0] + self.generator.sequence_actual[:, -num_res_tokens:][0], ) # append to res_line @@ -161,7 +173,7 @@ class Mynion(slixmpp.ClientXMPP): res_line = res_line.removeprefix(f"{self.name}:") res_line = res_line.removesuffix(f"{self.user}:") self.first_round = False - msg.reply(res_line).send() # type: ignore + msg.reply(res_line).send() # type: ignore[no-untyped-call] MY_MODELS = [ @@ -178,26 +190,31 @@ MY_MODELS = [ def load_model(model_name: str) -> typing.Any: - assert model_name in MY_MODELS + """Load an ML model from disk.""" + if model_name not in MY_MODELS: + msg = f"{model_name} not available" + raise ValueError(msg) if not torch.cuda.is_available(): - raise ValueError("no cuda") + msg = "no cuda" + raise ValueError(msg) sys.exit(1) - torch.set_grad_enabled(False) - torch.cuda._lazy_init() # type: ignore + torch.set_grad_enabled(mode=False) + torch.cuda.init() # type: ignore[no-untyped-call] ml_models = "/mnt/campbell/ben/ml-models" - model_dir = os.path.join(ml_models, model_name) + model_dir = pathlib.Path(ml_models) / model_name - tokenizer_path = os.path.join(model_dir, "tokenizer.model") - config_path = os.path.join(model_dir, "config.json") - st_pattern = os.path.join(model_dir, "*.safetensors") - st = glob.glob(st_pattern) + tokenizer_path = pathlib.Path(model_dir) / "tokenizer.model" + config_path = pathlib.Path(model_dir) / "config.json" + st = list(pathlib.Path(model_dir).glob("*.safetensors")) if len(st) > 1: - raise ValueError("found multiple safetensors!") - elif len(st) < 1: - raise ValueError("could not find model") + msg = "found multiple safetensors!" + raise ValueError(msg) + if len(st) < 1: + msg = "could not find model" + raise ValueError(msg) model_path = st[0] config = exllama.model.ExLlamaConfig(config_path) @@ -230,14 +247,15 @@ def main( model in the repl and then restart the chatbot without unloading it. """ Biz.Log.setup() - xmpp = Mynion(user, password, model, tokenizer, generator) + auth = Auth(user, password) + xmpp = Mynion(auth, model, tokenizer, generator) xmpp.connect() - xmpp.process(forever=True) # type: ignore + xmpp.process(forever=True) # type: ignore[no-untyped-call] if __name__ == "__main__": if "test" in sys.argv: - print("pass: test: Biz/Mynion.py") + sys.stdout.write("pass: test: Biz/Mynion.py\n") sys.exit(0) else: cli = argparse.ArgumentParser(description=__doc__) diff --git a/Biz/Que/Client.py b/Biz/Que/Client.py old mode 100755 new mode 100644 index 20349b6..53f14e4 --- a/Biz/Que/Client.py +++ b/Biz/Que/Client.py @@ -1,22 +1,20 @@ -#!/usr/bin/env python3 # : out que -""" -simple client for que.run -""" +# ruff: noqa: PERF203 +"""simple client for que.run.""" import argparse import configparser import functools import http.client import logging -import os +import pathlib import subprocess import sys import textwrap import time -import urllib.parse -import urllib.request as request import typing +import urllib.parse +from urllib import request MAX_TIMEOUT = 9999999 RETRIES = 10 @@ -24,14 +22,14 @@ DELAY = 3 BACKOFF = 1 -def auth(args: argparse.Namespace) -> typing.Union[str, None]: - "Returns the auth key for the given ns from ~/.config/que.conf" +def auth(args: argparse.Namespace) -> str | None: + """Return the auth key for the given ns from ~/.config/que.conf.""" logging.debug("auth") namespace = args.target.split("/")[0] if namespace == "pub": return None - conf_file = os.path.expanduser("~/.config/que.conf") - if not os.path.exists(conf_file): + conf_file = pathlib.Path("~/.config/que.conf").expanduser() + if not conf_file.exists(): sys.exit("you need a ~/.config/que.conf") cfg = configparser.ConfigParser() cfg.read(conf_file) @@ -39,8 +37,11 @@ def auth(args: argparse.Namespace) -> typing.Union[str, None]: def autodecode(bytestring: bytes) -> typing.Any: - """Attempt to decode bytes into common codecs, preferably utf-8. If no - decoding is available, just return the raw bytes. + """ + Automatically decode bytes into common codecs. + + Or at least make an attempt. Output is preferably utf-8. If no decoding is + available, just return the raw bytes. For all available codecs, see: @@ -63,7 +64,7 @@ def retry( delay: typing.Any = DELAY, backoff: typing.Any = BACKOFF, ) -> typing.Any: - "Decorator for retrying an action." + """Retry an action.""" def decorator(func: typing.Any) -> typing.Any: @functools.wraps(func) @@ -90,7 +91,7 @@ def retry( @retry(http.client.IncompleteRead) @retry(http.client.RemoteDisconnected) def send(args: argparse.Namespace) -> None: - "Send a message to the que." + """Send a message to the que.""" logging.debug("send") key = auth(args) data = args.infile @@ -108,13 +109,13 @@ def send(args: argparse.Namespace) -> None: def then(args: argparse.Namespace, msg: str) -> None: - "Perform an action when passed `--then`." + """Perform an action when passed `--then`.""" if args.then: logging.debug("then") subprocess.run( args.then.format(msg=msg, que=args.target), check=False, - shell=True, + shell=True, # noqa: S602 ) @@ -123,7 +124,7 @@ def then(args: argparse.Namespace, msg: str) -> None: @retry(http.client.IncompleteRead) @retry(http.client.RemoteDisconnected) def recv(args: argparse.Namespace) -> None: - "Receive a message from the que." + """Receive a message from the que.""" logging.debug("recv on: %s", args.target) if args.poll: req = request.Request(f"{args.host}/{args.target}/stream") @@ -133,6 +134,9 @@ def recv(args: argparse.Namespace) -> None: key = auth(args) if key: req.add_header("Authorization", key) + if not req.startswith(("http:", "https:")): + msg = "URL must start with 'http:' or 'https:'" + raise ValueError(msg) with request.urlopen(req) as _req: if args.poll: logging.debug("polling") @@ -141,29 +145,31 @@ def recv(args: argparse.Namespace) -> None: if reply: msg = autodecode(reply) logging.debug("read") - print(msg, end="") + sys.stdout.write(msg) then(args, msg) else: continue else: msg = autodecode(_req.readline()) - print(msg) + sys.stdout.write(msg) then(args, msg) def get_args() -> argparse.Namespace: - "Command line parser" + """Command line parser.""" cli = argparse.ArgumentParser( description=__doc__, epilog=textwrap.dedent( f"""Requests will retry up to {RETRIES} times, with {DELAY} seconds - between attempts.""" + between attempts.""", ), ) cli.add_argument("test", action="store_true", help="run tests") cli.add_argument("--debug", action="store_true", help="log to stderr") cli.add_argument( - "--host", default="http://que.run", help="where que-server is running" + "--host", + default="http://que.run", + help="where que-server is running", ) cli.add_argument( "--poll", @@ -171,7 +177,7 @@ def get_args() -> argparse.Namespace: action="store_true", help=textwrap.dedent( """keep the connection open to stream data from the que. without - this flag, the program will exit after receiving a message""" + this flag, the program will exit after receiving a message""", ), ) cli.add_argument( @@ -179,7 +185,7 @@ def get_args() -> argparse.Namespace: help=textwrap.dedent( """when polling, run this shell command after each response, presumably for side effects, replacing '{que}' with the target and - '{msg}' with the body of the response""" + '{msg}' with the body of the response""", ), ) cli.add_argument( @@ -188,22 +194,18 @@ def get_args() -> argparse.Namespace: action="store_true", help=textwrap.dedent( """when posting to the que, do so continuously in a loop. this can - be used for serving a webpage or other file continuously""" + be used for serving a webpage or other file continuously""", ), ) cli.add_argument( - "target", help="namespace and path of the que, like 'ns/path'" + "target", + help="namespace and path of the que, like 'ns/path'", ) cli.add_argument( "infile", nargs="?", type=argparse.FileType("rb"), - help=" ".join( - [ - "data to put on the que.", - "use '-' for stdin, otherwise should be a readable file", - ] - ), + help="file of data to put on the que. use '-' for stdin", ) return cli.parse_args() @@ -211,7 +213,7 @@ def get_args() -> argparse.Namespace: if __name__ == "__main__": ARGV = get_args() if ARGV.test: - print("ok") + sys.stdout.write("ok\n") sys.exit() if ARGV.debug: logging.basicConfig( diff --git a/Biz/Repl.py b/Biz/Repl.py index 9844abf..9cc0c35 100644 --- a/Biz/Repl.py +++ b/Biz/Repl.py @@ -1,19 +1,25 @@ """ +Improve the standard Python REPL. + This module attempts to emulate the workflow of ghci or lisp repls. It uses importlib to load a namespace from the given path. It then binds 'r()' to a function that reloads the same namespace. """ import importlib +import logging import sys +from Biz import Log + def use(ns: str, path: str) -> None: """ - Load or reload the module named 'ns' from 'path'. Like `use` in the Guile - Scheme repl. + Load or reload the module named 'ns' from 'path'. + + Like `use` in the Guile Scheme repl. """ - info(f"loading {ns} from {path}") + logging.info("loading %s from %s", ns, path) spec = importlib.util.spec_from_file_location(ns, path) module = importlib.util.module_from_spec(spec) # delete module and its imported names if its already loaded @@ -24,25 +30,24 @@ def use(ns: str, path: str) -> None: del globals()[name] sys.modules[ns] = module spec.loader.exec_module(module) - names = [x for x in module.__dict__] + names = list(module.__dict__) globals().update({k: getattr(module, k) for k in names}) -def info(s): - print(f"info: repl: {s}") - - if __name__ == "__main__": + Log.setup() NS = sys.argv[1] PATH = sys.argv[2] use(NS, PATH) - info("use reload() or _r() after making changes") + logging.info("use reload() or _r() after making changes") sys.ps1 = f"{NS}> " sys.ps2 = f"{NS}| " - def reload(): + def reload() -> None: + """Reload the namespace.""" return use(NS, PATH) - def _r(): + def _r() -> None: + """Shorthand: Reload the namespace.""" return use(NS, PATH) diff --git a/pyproject.toml b/pyproject.toml index 7f7d0d3..dbb1de1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,33 @@ -[tool.black] -line-length = 80 - [tool.mypy] strict = true implicit_reexport = true + +[tool.ruff] +exclude = ["_", ".git"] +line-length = 80 +indent-width = 4 +target-version = "py310" + +[tool.ruff.format] +preview = true +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" + +[tool.ruff.lint] +preview = true +select = ["ALL"] +fixable = ["ALL"] +ignore = [ + "ANN401", # any-type, we allow typing.Any, although we shouldn't + "CPY001", # missing-copyright-notice + "D203", # no-blank-line-before-class + "D212", # multi-line-summary-first-line + "E203", # whitespace-before-punctuation, doesn't work with ruff format + "INP001", # implicit-namespace-package + "N999", # invalid-module-name + "S310", # suspicious-url-open-usage, doesn't work in 0.1.5 + "S603", # subprocess-without-shell-equals-true, false positives + "S607", # start-process-with-partial-path +] -- cgit v1.2.3