author    Ben Sima <ben@bsima.me>  2024-04-10 19:56:46 -0400
committer Ben Sima <ben@bsima.me>  2024-04-10 19:56:46 -0400
commit    2c09c7f73e2fc770f42b5dd2588aa9634b4e7c6e (patch)
tree      a6c49dddb7b1735a0f8de2a9a5f4efb605f16f36 /Biz/Mynion.py
parent    051973aba8953ebde51eb1436fb3994e7ae699dc (diff)
Switch from black to ruff format
Ruff is faster, and if it supports everything that black supports, then why not switch? I did have to pull in a more recent version from unstable, but that's easy to do now. I also decided to go ahead and configure ruff with almost all of its checks turned on, which meant I had to fix a whole bunch of things, but I did that and everything is okay now.
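
For a sense of the mechanical churn those checks force, here is the most common rewrite in this diff, sketched on its own (illustrative code, not a hunk from this commit; I'm leaving out the exact rule codes since I'd be guessing at them):

    # Sketch: the exception-message checks want the message bound to a
    # variable instead of being passed inline to the constructor.

    # before
    def pick_model_before(model_name: str, available: list[str]) -> str:
        if model_name not in available:
            raise ValueError(f"{model_name} not available")
        return model_name

    # after
    def pick_model_after(model_name: str, available: list[str]) -> str:
        if model_name not in available:
            msg = f"{model_name} not available"
            raise ValueError(msg)
        return model_name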
Diffstat (limited to 'Biz/Mynion.py')
-rw-r--r--  Biz/Mynion.py  108
1 file changed, 63 insertions(+), 45 deletions(-)
diff --git a/Biz/Mynion.py b/Biz/Mynion.py
index d3dabcf..3b80f5f 100644
--- a/Biz/Mynion.py
+++ b/Biz/Mynion.py
@@ -1,38 +1,51 @@
-"""
-Mynion is a helper.
-"""
+"""Mynion is a helper."""
+
# : out mynion
# : dep exllama
# : dep slixmpp
import argparse
-import exllama # type: ignore
-import Biz.Log
-import glob
+import dataclasses
import logging
import os
+import pathlib
+import sys
+import typing
+
+import exllama # type: ignore[import]
import slixmpp
import slixmpp.exceptions
-import sys
import torch
-import typing
+
+import Biz.Log
def smoosh(s: str) -> str:
+ """Replace newlines with spaces."""
return s.replace("\n", " ")
+@dataclasses.dataclass
+class Auth:
+ """Container for XMPP authentication."""
+
+ jid: str
+ password: str
+
+
class Mynion(slixmpp.ClientXMPP):
+ """A helper via xmpp."""
+
def __init__(
- self,
- jid: str,
- password: str,
+ self: "Mynion",
+ auth: Auth,
model: exllama.model.ExLlama,
tokenizer: exllama.tokenizer.ExLlamaTokenizer,
generator: exllama.generator.ExLlamaGenerator,
) -> None:
- slixmpp.ClientXMPP.__init__(self, jid, password)
- self.plugin.enable("xep_0085") # type: ignore
- self.plugin.enable("xep_0184") # type: ignore
+ """Initialize Mynion chat bot service."""
+ slixmpp.ClientXMPP.__init__(self, auth.jid, auth.password)
+ self.plugin.enable("xep_0085") # type: ignore[attr-defined]
+ self.plugin.enable("xep_0184") # type: ignore[attr-defined]
self.name = "mynion"
self.user = "ben"
@@ -40,7 +53,6 @@ class Mynion(slixmpp.ClientXMPP):
self.min_response_tokens = 4
self.max_response_tokens = 256
self.extra_prune = 256
- # todo: handle summary rollup when max_seq_len is reached
self.max_seq_len = 8000
self.model = model
@@ -49,9 +61,8 @@ class Mynion(slixmpp.ClientXMPP):
root = os.getenv("CODEROOT", "")
# this should be parameterized somehow
- with open(os.path.join(root, "Biz", "Mynion", "Prompt.md")) as f:
- txt = f.read()
- txt = txt.format(user=self.user, name=self.name)
+ promptfile = pathlib.Path(root) / "Biz" / "Mynion" / "Prompt.md"
+ txt = promptfile.read_text().format(user=self.user, name=self.name)
# this is the "system prompt", ideally i would load this in/out of a
# database with all of the past history. if the history gets too long, i
@@ -64,20 +75,22 @@ class Mynion(slixmpp.ClientXMPP):
self.add_event_handler("session_start", self.session_start)
self.add_event_handler("message", self.message)
- def session_start(self) -> None:
+ def session_start(self: "Mynion") -> None:
+ """Start online session with xmpp server."""
self.send_presence()
try:
- self.get_roster() # type: ignore
+ self.get_roster() # type: ignore[no-untyped-call]
except slixmpp.exceptions.IqError as err:
- logging.error("There was an error getting the roster")
- logging.error(err.iq["error"]["condition"])
+ logging.exception("There was an error getting the roster")
+ logging.exception(err.iq["error"]["condition"])
self.disconnect()
except slixmpp.exceptions.IqTimeout:
- logging.error("Server is taking too long to respond")
+ logging.exception("Server is taking too long to respond")
self.disconnect()
- def message(self, msg: slixmpp.Message) -> None:
- if msg["type"] in ("chat", "normal"):
+ def message(self: "Mynion", msg: slixmpp.Message) -> None:
+ """Send a message."""
+ if msg["type"] in {"chat", "normal"}:
res_line = f"{self.name}: "
res_tokens = self.tokenizer.encode(res_line)
num_res_tokens = res_tokens.shape[-1]
@@ -109,7 +122,6 @@ class Mynion(slixmpp.ClientXMPP):
self.generator.begin_beam_search()
# generate tokens, with streaming
- # TODO: drop the streaming!
for i in range(self.max_response_tokens):
# disallowing the end condition tokens seems like a clean way to
# force longer replies
@@ -118,7 +130,7 @@ class Mynion(slixmpp.ClientXMPP):
[
self.tokenizer.newline_token_id,
self.tokenizer.eos_token_id,
- ]
+ ],
)
else:
self.generator.disallow_tokens(None)
@@ -129,13 +141,13 @@ class Mynion(slixmpp.ClientXMPP):
# if token is EOS, replace it with a newline before continuing
if gen_token.item() == self.tokenizer.eos_token_id:
self.generator.replace_last_token(
- self.tokenizer.newline_token_id
+ self.tokenizer.newline_token_id,
)
# decode the current line
num_res_tokens += 1
text = self.tokenizer.decode(
- self.generator.sequence_actual[:, -num_res_tokens:][0]
+ self.generator.sequence_actual[:, -num_res_tokens:][0],
)
# append to res_line
@@ -161,7 +173,7 @@ class Mynion(slixmpp.ClientXMPP):
res_line = res_line.removeprefix(f"{self.name}:")
res_line = res_line.removesuffix(f"{self.user}:")
self.first_round = False
- msg.reply(res_line).send() # type: ignore
+ msg.reply(res_line).send() # type: ignore[no-untyped-call]
MY_MODELS = [
@@ -178,26 +190,31 @@ MY_MODELS = [
def load_model(model_name: str) -> typing.Any:
- assert model_name in MY_MODELS
+ """Load an ML model from disk."""
+ if model_name not in MY_MODELS:
+ msg = f"{model_name} not available"
+ raise ValueError(msg)
if not torch.cuda.is_available():
- raise ValueError("no cuda")
+ msg = "no cuda"
+ raise ValueError(msg)
sys.exit(1)
- torch.set_grad_enabled(False)
- torch.cuda._lazy_init() # type: ignore
+ torch.set_grad_enabled(mode=False)
+ torch.cuda.init() # type: ignore[no-untyped-call]
ml_models = "/mnt/campbell/ben/ml-models"
- model_dir = os.path.join(ml_models, model_name)
+ model_dir = pathlib.Path(ml_models) / model_name
- tokenizer_path = os.path.join(model_dir, "tokenizer.model")
- config_path = os.path.join(model_dir, "config.json")
- st_pattern = os.path.join(model_dir, "*.safetensors")
- st = glob.glob(st_pattern)
+ tokenizer_path = pathlib.Path(model_dir) / "tokenizer.model"
+ config_path = pathlib.Path(model_dir) / "config.json"
+ st = list(pathlib.Path(model_dir).glob("*.safetensors"))
if len(st) > 1:
- raise ValueError("found multiple safetensors!")
- elif len(st) < 1:
- raise ValueError("could not find model")
+ msg = "found multiple safetensors!"
+ raise ValueError(msg)
+ if len(st) < 1:
+ msg = "could not find model"
+ raise ValueError(msg)
model_path = st[0]
config = exllama.model.ExLlamaConfig(config_path)
@@ -230,14 +247,15 @@ def main(
model in the repl and then restart the chatbot without unloading it.
"""
Biz.Log.setup()
- xmpp = Mynion(user, password, model, tokenizer, generator)
+ auth = Auth(user, password)
+ xmpp = Mynion(auth, model, tokenizer, generator)
xmpp.connect()
- xmpp.process(forever=True) # type: ignore
+ xmpp.process(forever=True) # type: ignore[no-untyped-call]
if __name__ == "__main__":
if "test" in sys.argv:
- print("pass: test: Biz/Mynion.py")
+ sys.stdout.write("pass: test: Biz/Mynion.py\n")
sys.exit(0)
else:
cli = argparse.ArgumentParser(description=__doc__)
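
Aside: the streaming-generation loop in message() is spread over several hunks above, so here it is reassembled for readability. This is a sketch built only from the fragments visible in the diff; the beam_search() sampling call, the end_beam_search() cleanup, and any loop-exit conditions are assumptions, since they fall outside the hunks shown.

    import exllama  # type: ignore[import]

    def stream_reply(
        generator: exllama.generator.ExLlamaGenerator,
        tokenizer: exllama.tokenizer.ExLlamaTokenizer,
        min_response_tokens: int = 4,
        max_response_tokens: int = 256,
    ) -> str:
        """Run the streaming loop from Mynion.message, reassembled."""
        generator.begin_beam_search()
        num_res_tokens = 0
        text = ""
        for i in range(max_response_tokens):
            # disallowing the end-condition tokens early forces longer replies
            if i < min_response_tokens:
                generator.disallow_tokens(
                    [
                        tokenizer.newline_token_id,
                        tokenizer.eos_token_id,
                    ],
                )
            else:
                generator.disallow_tokens(None)
            gen_token = generator.beam_search()  # assumed: sampling call
            # if the token is EOS, replace it with a newline before continuing
            if gen_token.item() == tokenizer.eos_token_id:
                generator.replace_last_token(tokenizer.newline_token_id)
            # decode the tokens generated so far into the running reply
            num_res_tokens += 1
            text = tokenizer.decode(
                generator.sequence_actual[:, -num_res_tokens:][0],
            )
        generator.end_beam_search()  # assumed: cleanup call
        return text

Note that the commit also deletes the old "TODO: drop the streaming!" comment, so the streaming shape of this loop appears to be here to stay.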