Diffstat (limited to 'Biz/Mynion.py')
-rw-r--r--  Biz/Mynion.py  246
1 file changed, 246 insertions, 0 deletions
diff --git a/Biz/Mynion.py b/Biz/Mynion.py
new file mode 100644
index 0000000..6bb55e1
--- /dev/null
+++ b/Biz/Mynion.py
@@ -0,0 +1,246 @@
+# : out mynion
+#
+# : dep exllama
+# : dep slixmpp
+import argparse
+import exllama # type: ignore
+import Biz.Log
+import glob
+import logging
+import os
+import slixmpp
+import slixmpp.exceptions
+import sys
+import torch
+import typing
+
+
+def smoosh(s: str) -> str:
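+    """Collapse newlines into spaces so text fits on a single prompt line."""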
+ return s.replace("\n", " ")
+
+
+class Mynion(slixmpp.ClientXMPP):
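+    """XMPP chat bot that generates replies with a local exllama model."""
+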
+ def __init__(
+ self,
+ jid: str,
+ password: str,
+ model: exllama.model.ExLlama,
+ tokenizer: exllama.tokenizer.ExLlamaTokenizer,
+ generator: exllama.generator.ExLlamaGenerator,
+ ) -> None:
+ slixmpp.ClientXMPP.__init__(self, jid, password)
+ self.plugin.enable("xep_0085") # type: ignore
+ self.plugin.enable("xep_0184") # type: ignore
+
+ self.name = "mynion"
+ self.user = "ben"
+ self.first_round = True
+ self.min_response_tokens = 4
+ self.max_response_tokens = 256
+ self.extra_prune = 256
+ # todo: handle summary rollup when max_seq_len is reached
+ self.max_seq_len = 8000
+
+ self.model = model
+ self.tokenizer = tokenizer
+ self.generator = generator
+
+ root = os.getenv("BIZ_ROOT", "")
+ # this should be parameterized somehow
+ with open(os.path.join(root, "Biz", "Mynion", "Prompt.md")) as f:
+ txt = f.read()
+ txt = txt.format(user=self.user, name=self.name)
+
+ # this is the "system prompt", ideally i would load this in/out of a
+ # database with all of the past history. if the history gets too long, i
+ # can roll it up by asking llama to summarize it
+ self.past = smoosh(txt)
+
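+        # prime the generator's cache with the tokenized system prompt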
+ ids = tokenizer.encode(self.past)
+ self.generator.gen_begin(ids)
+
+ self.add_event_handler("session_start", self.session_start)
+ self.add_event_handler("message", self.message)
+
+    def session_start(self, _event: typing.Any) -> None:
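+        """Send presence and fetch the roster once the XMPP session starts."""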
+ self.send_presence()
+ try:
+ self.get_roster() # type: ignore
+ except slixmpp.exceptions.IqError as err:
+ logging.error("There was an error getting the roster")
+ logging.error(err.iq["error"]["condition"])
+ self.disconnect()
+ except slixmpp.exceptions.IqTimeout:
+ logging.error("Server is taking too long to respond")
+ self.disconnect()
+
+ def message(self, msg: slixmpp.Message) -> None:
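+        """Reply to an incoming chat message using the language model."""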
+ if msg["type"] in ("chat", "normal"):
+ res_line = f"{self.name}: "
+ res_tokens = self.tokenizer.encode(res_line)
+ num_res_tokens = res_tokens.shape[-1]
+
+ if self.first_round:
+ in_tokens = res_tokens
+ else:
+ # read and format input
+ in_line = f"{self.user}: " + msg["body"].strip() + "\n"
+ in_tokens = self.tokenizer.encode(in_line)
+ in_tokens = torch.cat((in_tokens, res_tokens), dim=1)
+
+ # If we're approaching the context limit, prune some whole lines
+ # from the start of the context. Also prune a little extra so we
+ # don't end up rebuilding the cache on every line when up against
+ # the limit.
+ expect_tokens = in_tokens.shape[-1] + self.max_response_tokens
+ max_tokens = self.max_seq_len - expect_tokens
+ if self.generator.gen_num_tokens() >= max_tokens:
+                self.generator.gen_prune_to(
+ self.max_seq_len - expect_tokens - self.extra_prune,
+ self.tokenizer.newline_token_id,
+ )
+
+ # feed in the user input and "{self.name}:", tokenized
+ self.generator.gen_feed_tokens(in_tokens)
+
+            # begin beam search for this reply
+ self.generator.begin_beam_search()
+
+ # generate tokens, with streaming
+ # TODO: drop the streaming!
+ for i in range(self.max_response_tokens):
+ # disallowing the end condition tokens seems like a clean way to
+ # force longer replies
+ if i < self.min_response_tokens:
+ self.generator.disallow_tokens(
+ [
+ self.tokenizer.newline_token_id,
+ self.tokenizer.eos_token_id,
+ ]
+ )
+ else:
+ self.generator.disallow_tokens(None)
+
+ # get a token
+ gen_token = self.generator.beam_search()
+
+ # if token is EOS, replace it with a newline before continuing
+ if gen_token.item() == self.tokenizer.eos_token_id:
+ self.generator.replace_last_token(
+ self.tokenizer.newline_token_id
+ )
+
+ # decode the current line
+ num_res_tokens += 1
+ text = self.tokenizer.decode(
+ self.generator.sequence_actual[:, -num_res_tokens:][0]
+ )
+
+ # append to res_line
+ res_line += text[len(res_line) :]
+
+ # end conditions
+ breakers = [
+ self.tokenizer.eos_token_id,
+ # self.tokenizer.newline_token_id,
+ ]
+ if gen_token.item() in breakers:
+ break
+
+                # if the model starts generating the user's turn, rewind and stop
+ if res_line.endswith(f"{self.user}:"):
+ logging.info("rewinding!")
+ plen = self.tokenizer.encode(f"{self.user}:").shape[-1]
+ self.generator.gen_rewind(plen)
+ break
+
+ # end generation and send the reply
+ self.generator.end_beam_search()
+ res_line = res_line.removeprefix(f"{self.name}:")
+ res_line = res_line.removesuffix(f"{self.user}:")
+ self.first_round = False
+ msg.reply(res_line).send() # type: ignore
+
+
+MY_MODELS = [
+ "Llama-2-13B-GPTQ",
+ "Nous-Hermes-13B-GPTQ",
+ "Nous-Hermes-Llama2-13b-GPTQ",
+ "Wizard-Vicuna-13B-Uncensored-GPTQ",
+ "Wizard-Vicuna-13B-Uncensored-SuperHOT-8K-GPTQ",
+ "Wizard-Vicuna-30B-Uncensored-GPTQ",
+ "CodeLlama-13B-Python-GPTQ",
+ "CodeLlama-13B-Instruct-GPTQ",
+ "CodeLlama-34B-Instruct-GPTQ",
+]
+
+
+def load_model(model_name: str) -> typing.Any:
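+    """Load a quantized model and build its tokenizer and generator."""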
+ assert model_name in MY_MODELS
+ if not torch.cuda.is_available():
+        raise ValueError("no cuda")
+
+ torch.set_grad_enabled(False)
+ torch.cuda._lazy_init() # type: ignore
+
+ ml_models = "/mnt/campbell/ben/ml-models"
+
+ model_dir = os.path.join(ml_models, model_name)
+
+ tokenizer_path = os.path.join(model_dir, "tokenizer.model")
+ config_path = os.path.join(model_dir, "config.json")
+ st_pattern = os.path.join(model_dir, "*.safetensors")
+ st = glob.glob(st_pattern)
+ if len(st) != 1:
+        print(f"expected exactly one .safetensors file, found {len(st)}")
+        sys.exit(1)
+ model_path = st[0]
+
+ config = exllama.model.ExLlamaConfig(config_path)
+ config.model_path = model_path
+
+    # gpu split: let the auto device map use up to 23 GB on the first gpu
+ config.set_auto_map("23")
+
+ model = exllama.model.ExLlama(config)
+ cache = exllama.model.ExLlamaCache(model)
+ tokenizer = exllama.tokenizer.ExLlamaTokenizer(tokenizer_path)
+
+ generator = exllama.generator.ExLlamaGenerator(model, tokenizer, cache)
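+    # use exllama's default sampling settings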
+ generator.settings = exllama.generator.ExLlamaGenerator.Settings()
+
+ return (model, tokenizer, generator)
+
+
+def main(
+ model: exllama.model.ExLlama,
+ tokenizer: exllama.tokenizer.ExLlamaTokenizer,
+ generator: exllama.generator.ExLlamaGenerator,
+ user: str,
+ password: str,
+) -> None:
+ """
+ Start the chatbot.
+
+ This purposefully does not call 'load_model()' so that you can load the
+ model in the repl and then restart the chatbot without unloading it.
+ """
+ Biz.Log.setup()
+ xmpp = Mynion(user, password, model, tokenizer, generator)
+ xmpp.connect()
+ xmpp.process(forever=True) # type: ignore
+
+
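+# Example invocation (hypothetical account details; per the ": out mynion"
+# header this is normally built into an executable named "mynion"):
+#   mynion -u mynion@example.com -p secret -m Llama-2-13B-GPTQ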
+if __name__ == "__main__":
+ if "test" in sys.argv:
+ print("pass: test: Biz/Mynion.py")
+ sys.exit(0)
+ else:
+ cli = argparse.ArgumentParser()
+ cli.add_argument("-u", "--user")
+ cli.add_argument("-p", "--password")
+ cli.add_argument("-m", "--model", choices=MY_MODELS)
+ args = cli.parse_args()
+ model, tokenizer, generator = load_model(args.model)
+ main(model, tokenizer, generator, args.user, args.password)