"""Mynion is a helper.""" # : out mynion # : dep exllama # : dep slixmpp import argparse import dataclasses import logging import os import pathlib import sys import typing import exllama # type: ignore[import] import slixmpp import slixmpp.exceptions import torch import Biz.Log def smoosh(s: str) -> str: """Replace newlines with spaces.""" return s.replace("\n", " ") @dataclasses.dataclass class Auth: """Container for XMPP authentication.""" jid: str password: str class Mynion(slixmpp.ClientXMPP): """A helper via xmpp.""" def __init__( self: "Mynion", auth: Auth, model: exllama.model.ExLlama, tokenizer: exllama.tokenizer.ExLlamaTokenizer, generator: exllama.generator.ExLlamaGenerator, ) -> None: """Initialize Mynion chat bot service.""" slixmpp.ClientXMPP.__init__(self, auth.jid, auth.password) self.plugin.enable("xep_0085") # type: ignore[attr-defined] self.plugin.enable("xep_0184") # type: ignore[attr-defined] self.name = "mynion" self.user = "ben" self.first_round = True self.min_response_tokens = 4 self.max_response_tokens = 256 self.extra_prune = 256 self.max_seq_len = 8000 self.model = model self.tokenizer = tokenizer self.generator = generator root = os.getenv("CODEROOT", "") # this should be parameterized somehow promptfile = pathlib.Path(root) / "Biz" / "Mynion" / "Prompt.md" txt = promptfile.read_text().format(user=self.user, name=self.name) # this is the "system prompt", ideally i would load this in/out of a # database with all of the past history. if the history gets too long, i # can roll it up by asking llama to summarize it self.past = smoosh(txt) ids = tokenizer.encode(self.past) self.generator.gen_begin(ids) self.add_event_handler("session_start", self.session_start) self.add_event_handler("message", self.message) def session_start(self: "Mynion") -> None: """Start online session with xmpp server.""" self.send_presence() try: self.get_roster() # type: ignore[no-untyped-call] except slixmpp.exceptions.IqError as err: logging.exception("There was an error getting the roster") logging.exception(err.iq["error"]["condition"]) self.disconnect() except slixmpp.exceptions.IqTimeout: logging.exception("Server is taking too long to respond") self.disconnect() def message(self: "Mynion", msg: slixmpp.Message) -> None: """Send a message.""" if msg["type"] in {"chat", "normal"}: res_line = f"{self.name}: " res_tokens = self.tokenizer.encode(res_line) num_res_tokens = res_tokens.shape[-1] if self.first_round: in_tokens = res_tokens else: # read and format input in_line = f"{self.user}: " + msg["body"].strip() + "\n" in_tokens = self.tokenizer.encode(in_line) in_tokens = torch.cat((in_tokens, res_tokens), dim=1) # If we're approaching the context limit, prune some whole lines # from the start of the context. Also prune a little extra so we # don't end up rebuilding the cache on every line when up against # the limit. expect_tokens = in_tokens.shape[-1] + self.max_response_tokens max_tokens = self.max_seq_len - expect_tokens if self.generator.gen_num_tokens() >= max_tokens: generator.gen_prune_to( self.max_seq_len - expect_tokens - self.extra_prune, self.tokenizer.newline_token_id, ) # feed in the user input and "{self.name}:", tokenized self.generator.gen_feed_tokens(in_tokens) # start beam search? 
            self.generator.begin_beam_search()

            # generate tokens, with streaming
            for i in range(self.max_response_tokens):
                # disallowing the end condition tokens seems like a clean way
                # to force longer replies
                if i < self.min_response_tokens:
                    self.generator.disallow_tokens(
                        [
                            self.tokenizer.newline_token_id,
                            self.tokenizer.eos_token_id,
                        ],
                    )
                else:
                    self.generator.disallow_tokens(None)

                # get a token
                gen_token = self.generator.beam_search()

                # if token is EOS, replace it with a newline before continuing
                if gen_token.item() == self.tokenizer.eos_token_id:
                    self.generator.replace_last_token(
                        self.tokenizer.newline_token_id,
                    )

                # decode the current line
                num_res_tokens += 1
                text = self.tokenizer.decode(
                    self.generator.sequence_actual[:, -num_res_tokens:][0],
                )

                # append to res_line
                res_line += text[len(res_line) :]

                # end conditions
                breakers = [
                    self.tokenizer.eos_token_id,
                    # self.tokenizer.newline_token_id,
                ]
                if gen_token.item() in breakers:
                    break

                # try to drop the "ben:" at the end
                if res_line.endswith(f"{self.user}:"):
                    logging.info("rewinding!")
                    plen = self.tokenizer.encode(f"{self.user}:").shape[-1]
                    self.generator.gen_rewind(plen)
                    break

            # end generation and send the reply
            self.generator.end_beam_search()
            res_line = res_line.removeprefix(f"{self.name}:")
            res_line = res_line.removesuffix(f"{self.user}:")
            self.first_round = False
            msg.reply(res_line).send()  # type: ignore[no-untyped-call]


MY_MODELS = [
    "Llama-2-13B-GPTQ",
    "Nous-Hermes-13B-GPTQ",
    "Nous-Hermes-Llama2-13b-GPTQ",
    "Wizard-Vicuna-13B-Uncensored-GPTQ",
    "Wizard-Vicuna-13B-Uncensored-SuperHOT-8K-GPTQ",
    "Wizard-Vicuna-30B-Uncensored-GPTQ",
    "CodeLlama-13B-Python-GPTQ",
    "CodeLlama-13B-Instruct-GPTQ",
    "CodeLlama-34B-Instruct-GPTQ",
]


def load_model(model_name: str) -> typing.Any:
    """Load an ML model from disk."""
    if model_name not in MY_MODELS:
        msg = f"{model_name} not available"
        raise ValueError(msg)
    if not torch.cuda.is_available():
        msg = "no cuda"
        raise ValueError(msg)
    torch.set_grad_enabled(mode=False)
    torch.cuda.init()  # type: ignore[no-untyped-call]

    ml_models = "/mnt/campbell/ben/ml-models"
    model_dir = pathlib.Path(ml_models) / model_name
    tokenizer_path = pathlib.Path(model_dir) / "tokenizer.model"
    config_path = pathlib.Path(model_dir) / "config.json"

    st = list(pathlib.Path(model_dir).glob("*.safetensors"))
    if len(st) > 1:
        msg = "found multiple safetensors!"
        raise ValueError(msg)
    if len(st) < 1:
        msg = "could not find model"
        raise ValueError(msg)
    model_path = st[0]

    config = exllama.model.ExLlamaConfig(config_path)
    config.model_path = model_path

    # gpu split
    config.set_auto_map("23")

    model = exllama.model.ExLlama(config)
    cache = exllama.model.ExLlamaCache(model)
    tokenizer = exllama.tokenizer.ExLlamaTokenizer(tokenizer_path)
    generator = exllama.generator.ExLlamaGenerator(model, tokenizer, cache)
    generator.settings = exllama.generator.ExLlamaGenerator.Settings()
    return (model, tokenizer, generator)


def main(
    model: exllama.model.ExLlama,
    tokenizer: exllama.tokenizer.ExLlamaTokenizer,
    generator: exllama.generator.ExLlamaGenerator,
    user: str,
    password: str,
) -> None:
    """
    Start the chatbot.

    This purposefully does not call 'load_model()' so that you can load the
    model in the repl and then restart the chatbot without unloading it.
""" Biz.Log.setup() auth = Auth(user, password) xmpp = Mynion(auth, model, tokenizer, generator) xmpp.connect() xmpp.process(forever=True) # type: ignore[no-untyped-call] if __name__ == "__main__": if "test" in sys.argv: sys.stdout.write("pass: test: Biz/Mynion.py\n") sys.exit(0) else: cli = argparse.ArgumentParser(description=__doc__) cli.add_argument("-u", "--user") cli.add_argument("-p", "--password") cli.add_argument("-m", "--model", choices=MY_MODELS) args = cli.parse_args() model, tokenizer, generator = load_model(args.model) main(model, tokenizer, generator, args.user, args.password)