summaryrefslogtreecommitdiff
path: root/ava.py
diff options
context:
space:
mode:
authorBen Sima <ben@bsima.me>2022-12-28 19:53:55 -0500
committerBen Sima <ben@bsima.me>2022-12-28 19:53:55 -0500
commitc3b955145998d39df39370671585a271ca6f80f0 (patch)
tree33614e03c966d205e2eadaf4dd183f52618afebc /ava.py
parent11e480c4b13808f12bc3f5db2765cebebf1aaf46 (diff)
Get ava GPT chatbot prototype working
Mostly this required packaging up some deps, but also had to recompile stuff with cuda support.
Diffstat (limited to 'ava.py')
-rwxr-xr-x  ava.py  53
1 files changed, 38 insertions, 15 deletions
diff --git a/ava.py b/ava.py
index 6ca3a3e..1f08241 100755
--- a/ava.py
+++ b/ava.py
@@ -1,25 +1,48 @@
#!/usr/bin/env python
-import transformers import AutoModelWithLMHead, AutoTokenizer, TextGenerator
+import transformers
+import torch
+import sys
# import sleekxmpp
-model_name = "gpt-neox-20b"
-model = AutoModelWithLMHead.from_pretrained(model_name)
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-generator = TextGenerator(model=model, tokenizer=tokenizer)
+#model_name = "EleutherAI/gpt-neox-20b"
+model_name = "EleutherAI/gpt-j-6B"
-def generate_response(input_text):
- response = model.generate(
- input_ids=input_text,
- max_length=1024,
- temperature=0.7,
- )
- return response
+if torch.cuda.is_available():
+ device = "cuda:0"
+else:
+ raise ValueError("no cuda")
+
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+model = transformers.AutoModelForCausalLM.from_pretrained(
+ model_name,
+ device_map="auto",
+ load_in_8bit=True,
+ pad_token_id=tokenizer.eos_token_id,
+ revision="float16",
+ torch_dtype=torch.float16,
+ low_cpu_mem_usage=True,
+)
+
+# set attention_mask and pad_token_id
+
+def gen(txt):
+ input_ids = tokenizer(txt, return_tensors="pt").input_ids.to('cuda')
+ outputs = model.generate(
+ input_ids=input_ids,
+ max_length=1024,
+ temperature=0.7,
+ )
+ result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ result = "".join(result)
+ return result
# Get user input and generate a response
-user_input = input("User: ")
-response = generate_response(user_input)
-print("Bot: ", response)
+while True:
+ user_input = input("ben: ")
+ response = gen(user_input)
+ print("bot: ", response)
+
"""
# Set up the XMPP client