""" Mynion is a helper. """ # : out mynion # : dep exllama # : dep slixmpp import argparse import exllama # type: ignore import Biz.Log import glob import logging import os import slixmpp import slixmpp.exceptions import sys import torch import typing def smoosh(s: str) -> str: return s.replace("\n", " ") class Mynion(slixmpp.ClientXMPP): def __init__( self, jid: str, password: str, model: exllama.model.ExLlama, tokenizer: exllama.tokenizer.ExLlamaTokenizer, generator: exllama.generator.ExLlamaGenerator, ) -> None: slixmpp.ClientXMPP.__init__(self, jid, password) self.plugin.enable("xep_0085") # type: ignore self.plugin.enable("xep_0184") # type: ignore self.name = "mynion" self.user = "ben" self.first_round = True self.min_response_tokens = 4 self.max_response_tokens = 256 self.extra_prune = 256 # todo: handle summary rollup when max_seq_len is reached self.max_seq_len = 8000 self.model = model self.tokenizer = tokenizer self.generator = generator root = os.getenv("CODEROOT", "") # this should be parameterized somehow with open(os.path.join(root, "Biz", "Mynion", "Prompt.md")) as f: txt = f.read() txt = txt.format(user=self.user, name=self.name) # this is the "system prompt", ideally i would load this in/out of a # database with all of the past history. if the history gets too long, i # can roll it up by asking llama to summarize it self.past = smoosh(txt) ids = tokenizer.encode(self.past) self.generator.gen_begin(ids) self.add_event_handler("session_start", self.session_start) self.add_event_handler("message", self.message) def session_start(self) -> None: self.send_presence() try: self.get_roster() # type: ignore except slixmpp.exceptions.IqError as err: logging.error("There was an error getting the roster") logging.error(err.iq["error"]["condition"]) self.disconnect() except slixmpp.exceptions.IqTimeout: logging.error("Server is taking too long to respond") self.disconnect() def message(self, msg: slixmpp.Message) -> None: if msg["type"] in ("chat", "normal"): res_line = f"{self.name}: " res_tokens = self.tokenizer.encode(res_line) num_res_tokens = res_tokens.shape[-1] if self.first_round: in_tokens = res_tokens else: # read and format input in_line = f"{self.user}: " + msg["body"].strip() + "\n" in_tokens = self.tokenizer.encode(in_line) in_tokens = torch.cat((in_tokens, res_tokens), dim=1) # If we're approaching the context limit, prune some whole lines # from the start of the context. Also prune a little extra so we # don't end up rebuilding the cache on every line when up against # the limit. expect_tokens = in_tokens.shape[-1] + self.max_response_tokens max_tokens = self.max_seq_len - expect_tokens if self.generator.gen_num_tokens() >= max_tokens: generator.gen_prune_to( self.max_seq_len - expect_tokens - self.extra_prune, self.tokenizer.newline_token_id, ) # feed in the user input and "{self.name}:", tokenized self.generator.gen_feed_tokens(in_tokens) # start beam search? self.generator.begin_beam_search() # generate tokens, with streaming # TODO: drop the streaming! 
            for i in range(self.max_response_tokens):
                # disallowing the end condition tokens seems like a clean way
                # to force longer replies
                if i < self.min_response_tokens:
                    self.generator.disallow_tokens(
                        [
                            self.tokenizer.newline_token_id,
                            self.tokenizer.eos_token_id,
                        ]
                    )
                else:
                    self.generator.disallow_tokens(None)

                # get a token
                gen_token = self.generator.beam_search()

                # if token is EOS, replace it with a newline before continuing
                if gen_token.item() == self.tokenizer.eos_token_id:
                    self.generator.replace_last_token(
                        self.tokenizer.newline_token_id
                    )

                # decode the current line
                num_res_tokens += 1
                text = self.tokenizer.decode(
                    self.generator.sequence_actual[:, -num_res_tokens:][0]
                )

                # append to res_line
                res_line += text[len(res_line) :]

                # end conditions
                breakers = [
                    self.tokenizer.eos_token_id,
                    # self.tokenizer.newline_token_id,
                ]
                if gen_token.item() in breakers:
                    break

                # try to drop the "ben:" at the end
                if res_line.endswith(f"{self.user}:"):
                    logging.info("rewinding!")
                    plen = self.tokenizer.encode(f"{self.user}:").shape[-1]
                    self.generator.gen_rewind(plen)
                    break

            # end generation and send the reply
            self.generator.end_beam_search()
            res_line = res_line.removeprefix(f"{self.name}:")
            res_line = res_line.removesuffix(f"{self.user}:")
            self.first_round = False
            msg.reply(res_line).send()  # type: ignore


MY_MODELS = [
    "Llama-2-13B-GPTQ",
    "Nous-Hermes-13B-GPTQ",
    "Nous-Hermes-Llama2-13b-GPTQ",
    "Wizard-Vicuna-13B-Uncensored-GPTQ",
    "Wizard-Vicuna-13B-Uncensored-SuperHOT-8K-GPTQ",
    "Wizard-Vicuna-30B-Uncensored-GPTQ",
    "CodeLlama-13B-Python-GPTQ",
    "CodeLlama-13B-Instruct-GPTQ",
    "CodeLlama-34B-Instruct-GPTQ",
]


def load_model(model_name: str) -> typing.Any:
    assert model_name in MY_MODELS
    if not torch.cuda.is_available():
        raise ValueError("no cuda")
    torch.set_grad_enabled(False)
    torch.cuda._lazy_init()  # type: ignore
    ml_models = "/mnt/campbell/ben/ml-models"
    model_dir = os.path.join(ml_models, model_name)
    tokenizer_path = os.path.join(model_dir, "tokenizer.model")
    config_path = os.path.join(model_dir, "config.json")
    st_pattern = os.path.join(model_dir, "*.safetensors")
    st = glob.glob(st_pattern)
    if len(st) > 1:
        raise ValueError("found multiple safetensors!")
    elif len(st) < 1:
        raise ValueError("could not find model")
    model_path = st[0]
    config = exllama.model.ExLlamaConfig(config_path)
    config.model_path = model_path
    # gpu split
    config.set_auto_map("23")
    model = exllama.model.ExLlama(config)
    cache = exllama.model.ExLlamaCache(model)
    tokenizer = exllama.tokenizer.ExLlamaTokenizer(tokenizer_path)
    generator = exllama.generator.ExLlamaGenerator(model, tokenizer, cache)
    generator.settings = exllama.generator.ExLlamaGenerator.Settings()
    return (model, tokenizer, generator)


def main(
    model: exllama.model.ExLlama,
    tokenizer: exllama.tokenizer.ExLlamaTokenizer,
    generator: exllama.generator.ExLlamaGenerator,
    user: str,
    password: str,
) -> None:
    """
    Start the chatbot.

    This purposefully does not call 'load_model()' so that you can load the
    model in the repl and then restart the chatbot without unloading it.
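
    For example, a repl session might look like this (the model name and the
    XMPP credentials below are placeholders):

        model, tokenizer, generator = load_model("Llama-2-13B-GPTQ")
        main(model, tokenizer, generator, "mynion@example.com", "secret")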
""" Biz.Log.setup() xmpp = Mynion(user, password, model, tokenizer, generator) xmpp.connect() xmpp.process(forever=True) # type: ignore if __name__ == "__main__": if "test" in sys.argv: print("pass: test: Biz/Mynion.py") sys.exit(0) else: cli = argparse.ArgumentParser(description=__doc__) cli.add_argument("-u", "--user") cli.add_argument("-p", "--password") cli.add_argument("-m", "--model", choices=MY_MODELS) args = cli.parse_args() model, tokenizer, generator = load_model(args.model) main(model, tokenizer, generator, args.user, args.password)