diff options
Diffstat (limited to 'Biz/Mynion.py')
-rw-r--r-- | Biz/Mynion.py | 108 |
1 files changed, 63 insertions, 45 deletions
diff --git a/Biz/Mynion.py b/Biz/Mynion.py index d3dabcf..3b80f5f 100644 --- a/Biz/Mynion.py +++ b/Biz/Mynion.py @@ -1,38 +1,51 @@ -""" -Mynion is a helper. -""" +"""Mynion is a helper.""" + # : out mynion # : dep exllama # : dep slixmpp import argparse -import exllama # type: ignore -import Biz.Log -import glob +import dataclasses import logging import os +import pathlib +import sys +import typing + +import exllama # type: ignore[import] import slixmpp import slixmpp.exceptions -import sys import torch -import typing + +import Biz.Log def smoosh(s: str) -> str: + """Replace newlines with spaces.""" return s.replace("\n", " ") +@dataclasses.dataclass +class Auth: + """Container for XMPP authentication.""" + + jid: str + password: str + + class Mynion(slixmpp.ClientXMPP): + """A helper via xmpp.""" + def __init__( - self, - jid: str, - password: str, + self: "Mynion", + auth: Auth, model: exllama.model.ExLlama, tokenizer: exllama.tokenizer.ExLlamaTokenizer, generator: exllama.generator.ExLlamaGenerator, ) -> None: - slixmpp.ClientXMPP.__init__(self, jid, password) - self.plugin.enable("xep_0085") # type: ignore - self.plugin.enable("xep_0184") # type: ignore + """Initialize Mynion chat bot service.""" + slixmpp.ClientXMPP.__init__(self, auth.jid, auth.password) + self.plugin.enable("xep_0085") # type: ignore[attr-defined] + self.plugin.enable("xep_0184") # type: ignore[attr-defined] self.name = "mynion" self.user = "ben" @@ -40,7 +53,6 @@ class Mynion(slixmpp.ClientXMPP): self.min_response_tokens = 4 self.max_response_tokens = 256 self.extra_prune = 256 - # todo: handle summary rollup when max_seq_len is reached self.max_seq_len = 8000 self.model = model @@ -49,9 +61,8 @@ class Mynion(slixmpp.ClientXMPP): root = os.getenv("CODEROOT", "") # this should be parameterized somehow - with open(os.path.join(root, "Biz", "Mynion", "Prompt.md")) as f: - txt = f.read() - txt = txt.format(user=self.user, name=self.name) + promptfile = pathlib.Path(root) / "Biz" / "Mynion" / "Prompt.md" + txt = promptfile.read_text().format(user=self.user, name=self.name) # this is the "system prompt", ideally i would load this in/out of a # database with all of the past history. if the history gets too long, i @@ -64,20 +75,22 @@ class Mynion(slixmpp.ClientXMPP): self.add_event_handler("session_start", self.session_start) self.add_event_handler("message", self.message) - def session_start(self) -> None: + def session_start(self: "Mynion") -> None: + """Start online session with xmpp server.""" self.send_presence() try: - self.get_roster() # type: ignore + self.get_roster() # type: ignore[no-untyped-call] except slixmpp.exceptions.IqError as err: - logging.error("There was an error getting the roster") - logging.error(err.iq["error"]["condition"]) + logging.exception("There was an error getting the roster") + logging.exception(err.iq["error"]["condition"]) self.disconnect() except slixmpp.exceptions.IqTimeout: - logging.error("Server is taking too long to respond") + logging.exception("Server is taking too long to respond") self.disconnect() - def message(self, msg: slixmpp.Message) -> None: - if msg["type"] in ("chat", "normal"): + def message(self: "Mynion", msg: slixmpp.Message) -> None: + """Send a message.""" + if msg["type"] in {"chat", "normal"}: res_line = f"{self.name}: " res_tokens = self.tokenizer.encode(res_line) num_res_tokens = res_tokens.shape[-1] @@ -109,7 +122,6 @@ class Mynion(slixmpp.ClientXMPP): self.generator.begin_beam_search() # generate tokens, with streaming - # TODO: drop the streaming! for i in range(self.max_response_tokens): # disallowing the end condition tokens seems like a clean way to # force longer replies @@ -118,7 +130,7 @@ class Mynion(slixmpp.ClientXMPP): [ self.tokenizer.newline_token_id, self.tokenizer.eos_token_id, - ] + ], ) else: self.generator.disallow_tokens(None) @@ -129,13 +141,13 @@ class Mynion(slixmpp.ClientXMPP): # if token is EOS, replace it with a newline before continuing if gen_token.item() == self.tokenizer.eos_token_id: self.generator.replace_last_token( - self.tokenizer.newline_token_id + self.tokenizer.newline_token_id, ) # decode the current line num_res_tokens += 1 text = self.tokenizer.decode( - self.generator.sequence_actual[:, -num_res_tokens:][0] + self.generator.sequence_actual[:, -num_res_tokens:][0], ) # append to res_line @@ -161,7 +173,7 @@ class Mynion(slixmpp.ClientXMPP): res_line = res_line.removeprefix(f"{self.name}:") res_line = res_line.removesuffix(f"{self.user}:") self.first_round = False - msg.reply(res_line).send() # type: ignore + msg.reply(res_line).send() # type: ignore[no-untyped-call] MY_MODELS = [ @@ -178,26 +190,31 @@ MY_MODELS = [ def load_model(model_name: str) -> typing.Any: - assert model_name in MY_MODELS + """Load an ML model from disk.""" + if model_name not in MY_MODELS: + msg = f"{model_name} not available" + raise ValueError(msg) if not torch.cuda.is_available(): - raise ValueError("no cuda") + msg = "no cuda" + raise ValueError(msg) sys.exit(1) - torch.set_grad_enabled(False) - torch.cuda._lazy_init() # type: ignore + torch.set_grad_enabled(mode=False) + torch.cuda.init() # type: ignore[no-untyped-call] ml_models = "/mnt/campbell/ben/ml-models" - model_dir = os.path.join(ml_models, model_name) + model_dir = pathlib.Path(ml_models) / model_name - tokenizer_path = os.path.join(model_dir, "tokenizer.model") - config_path = os.path.join(model_dir, "config.json") - st_pattern = os.path.join(model_dir, "*.safetensors") - st = glob.glob(st_pattern) + tokenizer_path = pathlib.Path(model_dir) / "tokenizer.model" + config_path = pathlib.Path(model_dir) / "config.json" + st = list(pathlib.Path(model_dir).glob("*.safetensors")) if len(st) > 1: - raise ValueError("found multiple safetensors!") - elif len(st) < 1: - raise ValueError("could not find model") + msg = "found multiple safetensors!" + raise ValueError(msg) + if len(st) < 1: + msg = "could not find model" + raise ValueError(msg) model_path = st[0] config = exllama.model.ExLlamaConfig(config_path) @@ -230,14 +247,15 @@ def main( model in the repl and then restart the chatbot without unloading it. """ Biz.Log.setup() - xmpp = Mynion(user, password, model, tokenizer, generator) + auth = Auth(user, password) + xmpp = Mynion(auth, model, tokenizer, generator) xmpp.connect() - xmpp.process(forever=True) # type: ignore + xmpp.process(forever=True) # type: ignore[no-untyped-call] if __name__ == "__main__": if "test" in sys.argv: - print("pass: test: Biz/Mynion.py") + sys.stdout.write("pass: test: Biz/Mynion.py\n") sys.exit(0) else: cli = argparse.ArgumentParser(description=__doc__) |