author | Ben Sima <ben@bsima.me> | 2022-12-28 19:53:55 -0500
---|---|---
committer | Ben Sima <ben@bsima.me> | 2022-12-28 19:53:55 -0500
commit | c3b955145998d39df39370671585a271ca6f80f0 (patch) |
tree | 33614e03c966d205e2eadaf4dd183f52618afebc /ava.py |
parent | 11e480c4b13808f12bc3f5db2765cebebf1aaf46 (diff) |
Get ava GPT chatbot prototype working
Mostly this required packaging up some deps, but I also had to recompile
stuff with CUDA support.
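The CUDA recompile is the part most likely to fail silently, so a quick smoke test of the rebuilt stack (a hypothetical check, not part of this commit) is worth running before loading the model:

```python
import torch

# Hypothetical smoke test for the recompiled CUDA-enabled torch build;
# the script in this commit hard-fails without a GPU, so this should print True.
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))  # name of the device the model will load onto
```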
Diffstat (limited to 'ava.py')
-rwxr-xr-x | ava.py | 53 |
1 file changed, 38 insertions, 15 deletions
```diff
@@ -1,25 +1,48 @@
 #!/usr/bin/env python
-import transformers import AutoModelWithLMHead, AutoTokenizer, TextGenerator
+import transformers
+import torch
+import sys
 # import sleekxmpp

-model_name = "gpt-neox-20b"
-model = AutoModelWithLMHead.from_pretrained(model_name)
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-generator = TextGenerator(model=model, tokenizer=tokenizer)
+#model_name = "EleutherAI/gpt-neox-20b"
+model_name = "EleutherAI/gpt-j-6B"

-def generate_response(input_text):
-    response = model.generate(
-        input_ids=input_text,
-        max_length=1024,
-        temperature=0.7,
-    )
-    return response
+if torch.cuda.is_available():
+    device = "cuda:0"
+else:
+    raise ValueError("no cuda")
+
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+model = transformers.AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",
+    load_in_8bit=True,
+    pad_token_id=tokenizer.eos_token_id,
+    revision="float16",
+    torch_dtype=torch.float16,
+    low_cpu_mem_usage=True,
+)
+
+# set attention_mask and pad_token_id
+
+def gen(txt):
+    input_ids = tokenizer(txt, return_tensors="pt").input_ids.to('cuda')
+    outputs = model.generate(
+        input_ids=input_ids,
+        max_length=1024,
+        temperature=0.7,
+    )
+    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    result = "".join(result)
+    return result

 # Get user input and generate a response
-user_input = input("User: ")
-response = generate_response(user_input)
-print("Bot: ", response)
+while True:
+    user_input = input("ben: ")
+    response = gen(user_input)
+    print("bot: ", response)
+
 """
 # Set up the XMPP client
```
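The `# set attention_mask and pad_token_id` comment in the new code marks a known gap: `generate` is called with bare `input_ids`, and in transformers `temperature` only affects decoding when sampling is enabled. A minimal sketch of how `gen` could close that gap (the `do_sample=True` flag and the explicit mask are assumptions on top of this commit, not part of it):

```python
# Sketch, assuming the tokenizer/model objects set up in the diff above.
def gen(txt):
    # BatchEncoding.to() moves both input_ids and attention_mask to the GPU.
    inputs = tokenizer(txt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,  # explicit mask instead of the implicit all-ones default
        max_length=1024,
        do_sample=True,   # without this, generate() decodes greedily and temperature is ignored
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,  # silences the "Setting pad_token_id" warning at call time
    )
    return "".join(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

Passing `pad_token_id` at `generate` time rather than only at `from_pretrained` keeps the workaround visible at the call site where it matters.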