From 6e4a65579c3ade76feea0890072099f0d0caf416 Mon Sep 17 00:00:00 2001 From: Ben Sima Date: Mon, 28 Aug 2023 21:05:25 -0400 Subject: Prototype Mynion This implements a prototype Mynion, my chatbot which will eventually help me write code here. In fact he's already helping me, and works pretty well over xmpp. The prompt is currently not checked in because I'm experimenting with it a lot, and it should probably be a runtime parameter anyways. In the course of writing this I added some helper libraries to get me going, configured black (didn't even know that was possible), and added 'outlines' and its dependencies even though I didn't end up using it. I'll keep outlines around for now, but I'm not sure how useful it really is because afaict its just pre-defining some stop conditions. But it took a while to get it working so I'll just keep it in for now. --- Biz/Mynion.py | 246 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 Biz/Mynion.py (limited to 'Biz/Mynion.py') diff --git a/Biz/Mynion.py b/Biz/Mynion.py new file mode 100644 index 0000000..6bb55e1 --- /dev/null +++ b/Biz/Mynion.py @@ -0,0 +1,246 @@ +# : out mynion +# +# : dep exllama +# : dep slixmpp +import argparse +import exllama # type: ignore +import Biz.Log +import glob +import logging +import os +import slixmpp +import slixmpp.exceptions +import sys +import torch +import typing + + +def smoosh(s: str) -> str: + return s.replace("\n", " ") + + +class Mynion(slixmpp.ClientXMPP): + def __init__( + self, + jid: str, + password: str, + model: exllama.model.ExLlama, + tokenizer: exllama.tokenizer.ExLlamaTokenizer, + generator: exllama.generator.ExLlamaGenerator, + ) -> None: + slixmpp.ClientXMPP.__init__(self, jid, password) + self.plugin.enable("xep_0085") # type: ignore + self.plugin.enable("xep_0184") # type: ignore + + self.name = "mynion" + self.user = "ben" + self.first_round = True + self.min_response_tokens = 4 + self.max_response_tokens = 256 + self.extra_prune = 256 + # todo: handle summary rollup when max_seq_len is reached + self.max_seq_len = 8000 + + self.model = model + self.tokenizer = tokenizer + self.generator = generator + + root = os.getenv("BIZ_ROOT", "") + # this should be parameterized somehow + with open(os.path.join(root, "Biz", "Mynion", "Prompt.md")) as f: + txt = f.read() + txt = txt.format(user=self.user, name=self.name) + + # this is the "system prompt", ideally i would load this in/out of a + # database with all of the past history. if the history gets too long, i + # can roll it up by asking llama to summarize it + self.past = smoosh(txt) + + ids = tokenizer.encode(self.past) + self.generator.gen_begin(ids) + + self.add_event_handler("session_start", self.session_start) + self.add_event_handler("message", self.message) + + def session_start(self) -> None: + self.send_presence() + try: + self.get_roster() # type: ignore + except slixmpp.exceptions.IqError as err: + logging.error("There was an error getting the roster") + logging.error(err.iq["error"]["condition"]) + self.disconnect() + except slixmpp.exceptions.IqTimeout: + logging.error("Server is taking too long to respond") + self.disconnect() + + def message(self, msg: slixmpp.Message) -> None: + if msg["type"] in ("chat", "normal"): + res_line = f"{self.name}: " + res_tokens = self.tokenizer.encode(res_line) + num_res_tokens = res_tokens.shape[-1] + + if self.first_round: + in_tokens = res_tokens + else: + # read and format input + in_line = f"{self.user}: " + msg["body"].strip() + "\n" + in_tokens = self.tokenizer.encode(in_line) + in_tokens = torch.cat((in_tokens, res_tokens), dim=1) + + # If we're approaching the context limit, prune some whole lines + # from the start of the context. Also prune a little extra so we + # don't end up rebuilding the cache on every line when up against + # the limit. + expect_tokens = in_tokens.shape[-1] + self.max_response_tokens + max_tokens = self.max_seq_len - expect_tokens + if self.generator.gen_num_tokens() >= max_tokens: + generator.gen_prune_to( + self.max_seq_len - expect_tokens - self.extra_prune, + self.tokenizer.newline_token_id, + ) + + # feed in the user input and "{self.name}:", tokenized + self.generator.gen_feed_tokens(in_tokens) + + # start beam search? + self.generator.begin_beam_search() + + # generate tokens, with streaming + # TODO: drop the streaming! + for i in range(self.max_response_tokens): + # disallowing the end condition tokens seems like a clean way to + # force longer replies + if i < self.min_response_tokens: + self.generator.disallow_tokens( + [ + self.tokenizer.newline_token_id, + self.tokenizer.eos_token_id, + ] + ) + else: + self.generator.disallow_tokens(None) + + # get a token + gen_token = self.generator.beam_search() + + # if token is EOS, replace it with a newline before continuing + if gen_token.item() == self.tokenizer.eos_token_id: + self.generator.replace_last_token( + self.tokenizer.newline_token_id + ) + + # decode the current line + num_res_tokens += 1 + text = self.tokenizer.decode( + self.generator.sequence_actual[:, -num_res_tokens:][0] + ) + + # append to res_line + res_line += text[len(res_line) :] + + # end conditions + breakers = [ + self.tokenizer.eos_token_id, + # self.tokenizer.newline_token_id, + ] + if gen_token.item() in breakers: + break + + # try to drop the "ben:" at the end + if res_line.endswith(f"{self.user}:"): + logging.info("rewinding!") + plen = self.tokenizer.encode(f"{self.user}:").shape[-1] + self.generator.gen_rewind(plen) + break + + # end generation and send the reply + self.generator.end_beam_search() + res_line = res_line.removeprefix(f"{self.name}:") + res_line = res_line.removesuffix(f"{self.user}:") + self.first_round = False + msg.reply(res_line).send() # type: ignore + + +MY_MODELS = [ + "Llama-2-13B-GPTQ", + "Nous-Hermes-13B-GPTQ", + "Nous-Hermes-Llama2-13b-GPTQ", + "Wizard-Vicuna-13B-Uncensored-GPTQ", + "Wizard-Vicuna-13B-Uncensored-SuperHOT-8K-GPTQ", + "Wizard-Vicuna-30B-Uncensored-GPTQ", + "CodeLlama-13B-Python-GPTQ", + "CodeLlama-13B-Instruct-GPTQ", + "CodeLlama-34B-Instruct-GPTQ", +] + + +def load_model(model_name: str) -> typing.Any: + assert model_name in MY_MODELS + if not torch.cuda.is_available(): + raise ValueError("no cuda") + sys.exit(1) + + torch.set_grad_enabled(False) + torch.cuda._lazy_init() # type: ignore + + ml_models = "/mnt/campbell/ben/ml-models" + + model_dir = os.path.join(ml_models, model_name) + + tokenizer_path = os.path.join(model_dir, "tokenizer.model") + config_path = os.path.join(model_dir, "config.json") + st_pattern = os.path.join(model_dir, "*.safetensors") + st = glob.glob(st_pattern) + if len(st) != 1: + print("found multiple safetensors!") + sys.exit() + model_path = st[0] + + config = exllama.model.ExLlamaConfig(config_path) + config.model_path = model_path + + # gpu split + config.set_auto_map("23") + + model = exllama.model.ExLlama(config) + cache = exllama.model.ExLlamaCache(model) + tokenizer = exllama.tokenizer.ExLlamaTokenizer(tokenizer_path) + + generator = exllama.generator.ExLlamaGenerator(model, tokenizer, cache) + generator.settings = exllama.generator.ExLlamaGenerator.Settings() + + return (model, tokenizer, generator) + + +def main( + model: exllama.model.ExLlama, + tokenizer: exllama.tokenizer.ExLlamaTokenizer, + generator: exllama.generator.ExLlamaGenerator, + user: str, + password: str, +) -> None: + """ + Start the chatbot. + + This purposefully does not call 'load_model()' so that you can load the + model in the repl and then restart the chatbot without unloading it. + """ + Biz.Log.setup() + xmpp = Mynion(user, password, model, tokenizer, generator) + xmpp.connect() + xmpp.process(forever=True) # type: ignore + + +if __name__ == "__main__": + if "test" in sys.argv: + print("pass: test: Biz/Mynion.py") + sys.exit(0) + else: + cli = argparse.ArgumentParser() + cli.add_argument("-u", "--user") + cli.add_argument("-p", "--password") + cli.add_argument("-m", "--model", choices=MY_MODELS) + args = cli.parse_args() + model, tokenizer, generator = load_model(args.model) + main(model, tokenizer, generator, args.user, args.password) -- cgit v1.2.3