author     Ben Sima <ben@bsima.me>    2023-08-28 21:05:25 -0400
committer  Ben Sima <ben@bsima.me>    2023-09-20 17:56:12 -0400
commit     6e4a65579c3ade76feea0890072099f0d0caf416 (patch)
tree       95671321c951134753323978854cece5f7d5435b /Biz/Mynion.py
parent     13added53bbf996ec25a19b734326a6834918279 (diff)
Prototype Mynion
This implements a prototype Mynion, my chatbot that will eventually
help me write code here. In fact he's already helping me, and he works
pretty well over XMPP.
The prompt is currently not checked in because I'm experimenting with it
a lot, and it should probably be a runtime parameter anyway.
In the course of writing this I added some helper libraries to get me
going, configured black (I didn't even know that was possible), and
added 'outlines' and its dependencies, even though I didn't end up using it.
I'll keep outlines around for now, but I'm not sure how useful it really
is: as far as I can tell it's just pre-defining some stop conditions.
Still, it took a while to get working, so it stays in for the time being.
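By "stop conditions" I mean the usual break-out check in the sampling loop,
which Mynion.message() below ends up doing by hand anyway. A minimal sketch of
the idea (sample_until_stop, step(), and the token ids here are illustrative
placeholders, not part of this commit):

def sample_until_stop(step, stop_ids, max_new_tokens=256):
    # step() returns one generated token id per call; stop_ids are the
    # pre-defined "stop conditions" (e.g. the EOS and newline token ids)
    out = []
    for _ in range(max_new_tokens):
        tok = step()
        if tok in stop_ids:
            break
        out.append(tok)
    return out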
Diffstat (limited to 'Biz/Mynion.py')
-rw-r--r--   Biz/Mynion.py   246
1 file changed, 246 insertions, 0 deletions
diff --git a/Biz/Mynion.py b/Biz/Mynion.py
new file mode 100644
index 0000000..6bb55e1
--- /dev/null
+++ b/Biz/Mynion.py
@@ -0,0 +1,246 @@
+# : out mynion
+#
+# : dep exllama
+# : dep slixmpp
+import argparse
+import exllama  # type: ignore
+import Biz.Log
+import glob
+import logging
+import os
+import slixmpp
+import slixmpp.exceptions
+import sys
+import torch
+import typing
+
+
+def smoosh(s: str) -> str:
+    return s.replace("\n", " ")
+
+
+class Mynion(slixmpp.ClientXMPP):
+    def __init__(
+        self,
+        jid: str,
+        password: str,
+        model: exllama.model.ExLlama,
+        tokenizer: exllama.tokenizer.ExLlamaTokenizer,
+        generator: exllama.generator.ExLlamaGenerator,
+    ) -> None:
+        slixmpp.ClientXMPP.__init__(self, jid, password)
+        self.plugin.enable("xep_0085")  # type: ignore
+        self.plugin.enable("xep_0184")  # type: ignore
+
+        self.name = "mynion"
+        self.user = "ben"
+        self.first_round = True
+        self.min_response_tokens = 4
+        self.max_response_tokens = 256
+        self.extra_prune = 256
+        # todo: handle summary rollup when max_seq_len is reached
+        self.max_seq_len = 8000
+
+        self.model = model
+        self.tokenizer = tokenizer
+        self.generator = generator
+
+        root = os.getenv("BIZ_ROOT", "")
+        # this should be parameterized somehow
+        with open(os.path.join(root, "Biz", "Mynion", "Prompt.md")) as f:
+            txt = f.read()
+            txt = txt.format(user=self.user, name=self.name)
+
+        # this is the "system prompt", ideally i would load this in/out of a
+        # database with all of the past history. if the history gets too long, i
+        # can roll it up by asking llama to summarize it
+        self.past = smoosh(txt)
+
+        ids = tokenizer.encode(self.past)
+        self.generator.gen_begin(ids)
+
+        self.add_event_handler("session_start", self.session_start)
+        self.add_event_handler("message", self.message)
+
+    def session_start(self) -> None:
+        self.send_presence()
+        try:
+            self.get_roster()  # type: ignore
+        except slixmpp.exceptions.IqError as err:
+            logging.error("There was an error getting the roster")
+            logging.error(err.iq["error"]["condition"])
+            self.disconnect()
+        except slixmpp.exceptions.IqTimeout:
+            logging.error("Server is taking too long to respond")
+            self.disconnect()
+
+    def message(self, msg: slixmpp.Message) -> None:
+        if msg["type"] in ("chat", "normal"):
+            res_line = f"{self.name}: "
+            res_tokens = self.tokenizer.encode(res_line)
+            num_res_tokens = res_tokens.shape[-1]
+
+            if self.first_round:
+                in_tokens = res_tokens
+            else:
+                # read and format input
+                in_line = f"{self.user}: " + msg["body"].strip() + "\n"
+                in_tokens = self.tokenizer.encode(in_line)
+                in_tokens = torch.cat((in_tokens, res_tokens), dim=1)
+
+            # If we're approaching the context limit, prune some whole lines
+            # from the start of the context. Also prune a little extra so we
+            # don't end up rebuilding the cache on every line when up against
+            # the limit.
+            expect_tokens = in_tokens.shape[-1] + self.max_response_tokens
+            max_tokens = self.max_seq_len - expect_tokens
+            if self.generator.gen_num_tokens() >= max_tokens:
+                self.generator.gen_prune_to(
+                    self.max_seq_len - expect_tokens - self.extra_prune,
+                    self.tokenizer.newline_token_id,
+                )
+
+            # feed in the user input and "{self.name}:", tokenized
+            self.generator.gen_feed_tokens(in_tokens)
+
+            # start beam search?
+            self.generator.begin_beam_search()
+
+            # generate tokens, with streaming
+            # TODO: drop the streaming!
+            for i in range(self.max_response_tokens):
+                # disallowing the end condition tokens seems like a clean way to
+                # force longer replies
+                if i < self.min_response_tokens:
+                    self.generator.disallow_tokens(
+                        [
+                            self.tokenizer.newline_token_id,
+                            self.tokenizer.eos_token_id,
+                        ]
+                    )
+                else:
+                    self.generator.disallow_tokens(None)
+
+                # get a token
+                gen_token = self.generator.beam_search()
+
+                # if token is EOS, replace it with a newline before continuing
+                if gen_token.item() == self.tokenizer.eos_token_id:
+                    self.generator.replace_last_token(
+                        self.tokenizer.newline_token_id
+                    )
+
+                # decode the current line
+                num_res_tokens += 1
+                text = self.tokenizer.decode(
+                    self.generator.sequence_actual[:, -num_res_tokens:][0]
+                )
+
+                # append to res_line
+                res_line += text[len(res_line) :]
+
+                # end conditions
+                breakers = [
+                    self.tokenizer.eos_token_id,
+                    # self.tokenizer.newline_token_id,
+                ]
+                if gen_token.item() in breakers:
+                    break
+
+                # try to drop the "ben:" at the end
+                if res_line.endswith(f"{self.user}:"):
+                    logging.info("rewinding!")
+                    plen = self.tokenizer.encode(f"{self.user}:").shape[-1]
+                    self.generator.gen_rewind(plen)
+                    break
+
+            # end generation and send the reply
+            self.generator.end_beam_search()
+            res_line = res_line.removeprefix(f"{self.name}:")
+            res_line = res_line.removesuffix(f"{self.user}:")
+            self.first_round = False
+            msg.reply(res_line).send()  # type: ignore
+
+
+MY_MODELS = [
+    "Llama-2-13B-GPTQ",
+    "Nous-Hermes-13B-GPTQ",
+    "Nous-Hermes-Llama2-13b-GPTQ",
+    "Wizard-Vicuna-13B-Uncensored-GPTQ",
+    "Wizard-Vicuna-13B-Uncensored-SuperHOT-8K-GPTQ",
+    "Wizard-Vicuna-30B-Uncensored-GPTQ",
+    "CodeLlama-13B-Python-GPTQ",
+    "CodeLlama-13B-Instruct-GPTQ",
+    "CodeLlama-34B-Instruct-GPTQ",
+]
+
+
+def load_model(model_name: str) -> typing.Any:
+    assert model_name in MY_MODELS
+    if not torch.cuda.is_available():
+        raise ValueError("no cuda")
+        sys.exit(1)
+
+    torch.set_grad_enabled(False)
+    torch.cuda._lazy_init()  # type: ignore
+
+    ml_models = "/mnt/campbell/ben/ml-models"
+
+    model_dir = os.path.join(ml_models, model_name)
+
+    tokenizer_path = os.path.join(model_dir, "tokenizer.model")
+    config_path = os.path.join(model_dir, "config.json")
+    st_pattern = os.path.join(model_dir, "*.safetensors")
+    st = glob.glob(st_pattern)
+    if len(st) != 1:
+        print("found multiple safetensors!")
+        sys.exit()
+    model_path = st[0]
+
+    config = exllama.model.ExLlamaConfig(config_path)
+    config.model_path = model_path
+
+    # gpu split
+    config.set_auto_map("23")
+
+    model = exllama.model.ExLlama(config)
+    cache = exllama.model.ExLlamaCache(model)
+    tokenizer = exllama.tokenizer.ExLlamaTokenizer(tokenizer_path)
+
+    generator = exllama.generator.ExLlamaGenerator(model, tokenizer, cache)
+    generator.settings = exllama.generator.ExLlamaGenerator.Settings()
+
+    return (model, tokenizer, generator)
+
+
+def main(
+    model: exllama.model.ExLlama,
+    tokenizer: exllama.tokenizer.ExLlamaTokenizer,
+    generator: exllama.generator.ExLlamaGenerator,
+    user: str,
+    password: str,
+) -> None:
+    """
+    Start the chatbot.
+
+    This purposefully does not call 'load_model()' so that you can load the
+    model in the repl and then restart the chatbot without unloading it.
+    """
+    Biz.Log.setup()
+    xmpp = Mynion(user, password, model, tokenizer, generator)
+    xmpp.connect()
+    xmpp.process(forever=True)  # type: ignore
+
+
+if __name__ == "__main__":
+    if "test" in sys.argv:
+        print("pass: test: Biz/Mynion.py")
+        sys.exit(0)
+    else:
+        cli = argparse.ArgumentParser()
+        cli.add_argument("-u", "--user")
+        cli.add_argument("-p", "--password")
+        cli.add_argument("-m", "--model", choices=MY_MODELS)
+        args = cli.parse_args()
+        model, tokenizer, generator = load_model(args.model)
+        main(model, tokenizer, generator, args.user, args.password)