Diffstat (limited to 'Biz/Mynion.py')
-rw-r--r-- | Biz/Mynion.py | 246
1 file changed, 246 insertions, 0 deletions
diff --git a/Biz/Mynion.py b/Biz/Mynion.py
new file mode 100644
index 0000000..6bb55e1
--- /dev/null
+++ b/Biz/Mynion.py
@@ -0,0 +1,246 @@
+# : out mynion
+#
+# : dep exllama
+# : dep slixmpp
+import argparse
+import exllama  # type: ignore
+import Biz.Log
+import glob
+import logging
+import os
+import slixmpp
+import slixmpp.exceptions
+import sys
+import torch
+import typing
+
+
+def smoosh(s: str) -> str:
+    return s.replace("\n", " ")
+
+
+class Mynion(slixmpp.ClientXMPP):
+    def __init__(
+        self,
+        jid: str,
+        password: str,
+        model: exllama.model.ExLlama,
+        tokenizer: exllama.tokenizer.ExLlamaTokenizer,
+        generator: exllama.generator.ExLlamaGenerator,
+    ) -> None:
+        slixmpp.ClientXMPP.__init__(self, jid, password)
+        self.plugin.enable("xep_0085")  # type: ignore
+        self.plugin.enable("xep_0184")  # type: ignore
+
+        self.name = "mynion"
+        self.user = "ben"
+        self.first_round = True
+        self.min_response_tokens = 4
+        self.max_response_tokens = 256
+        self.extra_prune = 256
+        # todo: handle summary rollup when max_seq_len is reached
+        self.max_seq_len = 8000
+
+        self.model = model
+        self.tokenizer = tokenizer
+        self.generator = generator
+
+        root = os.getenv("BIZ_ROOT", "")
+        # this should be parameterized somehow
+        with open(os.path.join(root, "Biz", "Mynion", "Prompt.md")) as f:
+            txt = f.read()
+            txt = txt.format(user=self.user, name=self.name)
+
+        # this is the "system prompt", ideally i would load this in/out of a
+        # database with all of the past history. if the history gets too long,
+        # i can roll it up by asking llama to summarize it
+        self.past = smoosh(txt)
+
+        ids = tokenizer.encode(self.past)
+        self.generator.gen_begin(ids)
+
+        self.add_event_handler("session_start", self.session_start)
+        self.add_event_handler("message", self.message)
+
+    def session_start(self, event: typing.Any = None) -> None:
+        self.send_presence()
+        try:
+            self.get_roster()  # type: ignore
+        except slixmpp.exceptions.IqError as err:
+            logging.error("There was an error getting the roster")
+            logging.error(err.iq["error"]["condition"])
+            self.disconnect()
+        except slixmpp.exceptions.IqTimeout:
+            logging.error("Server is taking too long to respond")
+            self.disconnect()
+
+    def message(self, msg: slixmpp.Message) -> None:
+        if msg["type"] in ("chat", "normal"):
+            res_line = f"{self.name}: "
+            res_tokens = self.tokenizer.encode(res_line)
+            num_res_tokens = res_tokens.shape[-1]
+
+            if self.first_round:
+                in_tokens = res_tokens
+            else:
+                # read and format input
+                in_line = f"{self.user}: " + msg["body"].strip() + "\n"
+                in_tokens = self.tokenizer.encode(in_line)
+                in_tokens = torch.cat((in_tokens, res_tokens), dim=1)
+
+            # If we're approaching the context limit, prune some whole lines
+            # from the start of the context. Also prune a little extra so we
+            # don't end up rebuilding the cache on every line when up against
+            # the limit.
+            expect_tokens = in_tokens.shape[-1] + self.max_response_tokens
+            max_tokens = self.max_seq_len - expect_tokens
+            if self.generator.gen_num_tokens() >= max_tokens:
+                self.generator.gen_prune_to(
+                    self.max_seq_len - expect_tokens - self.extra_prune,
+                    self.tokenizer.newline_token_id,
+                )
+
+            # feed in the user input and "{self.name}:", tokenized
+            self.generator.gen_feed_tokens(in_tokens)
+
+            # start beam search?
+            self.generator.begin_beam_search()
+
+            # generate tokens, with streaming
+            # TODO: drop the streaming!
+            for i in range(self.max_response_tokens):
+                # disallowing the end condition tokens seems like a clean way
+                # to force longer replies
+                if i < self.min_response_tokens:
+                    self.generator.disallow_tokens(
+                        [
+                            self.tokenizer.newline_token_id,
+                            self.tokenizer.eos_token_id,
+                        ]
+                    )
+                else:
+                    self.generator.disallow_tokens(None)
+
+                # get a token
+                gen_token = self.generator.beam_search()
+
+                # if token is EOS, replace it with a newline before continuing
+                if gen_token.item() == self.tokenizer.eos_token_id:
+                    self.generator.replace_last_token(
+                        self.tokenizer.newline_token_id
+                    )
+
+                # decode the current line
+                num_res_tokens += 1
+                text = self.tokenizer.decode(
+                    self.generator.sequence_actual[:, -num_res_tokens:][0]
+                )
+
+                # append to res_line
+                res_line += text[len(res_line) :]
+
+                # end conditions
+                breakers = [
+                    self.tokenizer.eos_token_id,
+                    # self.tokenizer.newline_token_id,
+                ]
+                if gen_token.item() in breakers:
+                    break
+
+                # try to drop the "ben:" at the end
+                if res_line.endswith(f"{self.user}:"):
+                    logging.info("rewinding!")
+                    plen = self.tokenizer.encode(f"{self.user}:").shape[-1]
+                    self.generator.gen_rewind(plen)
+                    break
+
+            # end generation and send the reply
+            self.generator.end_beam_search()
+            res_line = res_line.removeprefix(f"{self.name}:")
+            res_line = res_line.removesuffix(f"{self.user}:")
+            self.first_round = False
+            msg.reply(res_line).send()  # type: ignore
+
+
+MY_MODELS = [
+    "Llama-2-13B-GPTQ",
+    "Nous-Hermes-13B-GPTQ",
+    "Nous-Hermes-Llama2-13b-GPTQ",
+    "Wizard-Vicuna-13B-Uncensored-GPTQ",
+    "Wizard-Vicuna-13B-Uncensored-SuperHOT-8K-GPTQ",
+    "Wizard-Vicuna-30B-Uncensored-GPTQ",
+    "CodeLlama-13B-Python-GPTQ",
+    "CodeLlama-13B-Instruct-GPTQ",
+    "CodeLlama-34B-Instruct-GPTQ",
+]
+
+
+def load_model(model_name: str) -> typing.Any:
+    assert model_name in MY_MODELS
+    if not torch.cuda.is_available():
+        raise ValueError("no cuda")
+        sys.exit(1)
+
+    torch.set_grad_enabled(False)
+    torch.cuda._lazy_init()  # type: ignore
+
+    ml_models = "/mnt/campbell/ben/ml-models"
+
+    model_dir = os.path.join(ml_models, model_name)
+
+    tokenizer_path = os.path.join(model_dir, "tokenizer.model")
+    config_path = os.path.join(model_dir, "config.json")
+    st_pattern = os.path.join(model_dir, "*.safetensors")
+    st = glob.glob(st_pattern)
+    if len(st) != 1:
+        print("found multiple safetensors!")
+        sys.exit()
+    model_path = st[0]
+
+    config = exllama.model.ExLlamaConfig(config_path)
+    config.model_path = model_path
+
+    # gpu split
+    config.set_auto_map("23")
+
+    model = exllama.model.ExLlama(config)
+    cache = exllama.model.ExLlamaCache(model)
+    tokenizer = exllama.tokenizer.ExLlamaTokenizer(tokenizer_path)
+
+    generator = exllama.generator.ExLlamaGenerator(model, tokenizer, cache)
+    generator.settings = exllama.generator.ExLlamaGenerator.Settings()
+
+    return (model, tokenizer, generator)
+
+
+def main(
+    model: exllama.model.ExLlama,
+    tokenizer: exllama.tokenizer.ExLlamaTokenizer,
+    generator: exllama.generator.ExLlamaGenerator,
+    user: str,
+    password: str,
+) -> None:
+    """
+    Start the chatbot.
+
+    This purposefully does not call 'load_model()' so that you can load the
+    model in the repl and then restart the chatbot without unloading it.
+    """
+    Biz.Log.setup()
+    xmpp = Mynion(user, password, model, tokenizer, generator)
+    xmpp.connect()
+    xmpp.process(forever=True)  # type: ignore
+
+
+if __name__ == "__main__":
+    if "test" in sys.argv:
+        print("pass: test: Biz/Mynion.py")
+        sys.exit(0)
+    else:
+        cli = argparse.ArgumentParser()
+        cli.add_argument("-u", "--user")
+        cli.add_argument("-p", "--password")
+        cli.add_argument("-m", "--model", choices=MY_MODELS)
+        args = cli.parse_args()
+        model, tokenizer, generator = load_model(args.model)
+        main(model, tokenizer, generator, args.user, args.password)
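
A note on the context-pruning arithmetic in Mynion.message(): with the constants set in __init__ (max_seq_len = 8000, max_response_tokens = 256, extra_prune = 256), the thresholds work out as in the sketch below. The 40-token input is a made-up example value, not something taken from the code.

    # worked example of the pruning thresholds in Mynion.message(), using
    # the constants from __init__; the 40-token input is hypothetical
    max_seq_len = 8000
    max_response_tokens = 256
    extra_prune = 256

    in_tokens = 40                                        # hypothetical input size
    expect_tokens = in_tokens + max_response_tokens       # 296
    max_tokens = max_seq_len - expect_tokens              # 7704: prune once the cache reaches this
    prune_to = max_seq_len - expect_tokens - extra_prune  # 7448: prune down to this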
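
The docstring on main() describes a repl workflow: load the weights once, then start (or restart) the bot without reloading them. A minimal sketch of that session, assuming Biz.Mynion is importable from the repl; the JID and password below are placeholders.

    >>> from Biz.Mynion import load_model, main
    >>> model, tokenizer, generator = load_model("Nous-Hermes-Llama2-13b-GPTQ")
    >>> # stop the bot with Ctrl-C, tweak things, then restart it here
    >>> # without paying the model-load cost again
    >>> main(model, tokenizer, generator, "mynion@example.com", "hunter2")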