Diffstat (limited to 'ava.py')
-rwxr-xr-x  ava.py  31
1 file changed, 19 insertions(+), 12 deletions(-)
diff --git a/ava.py b/ava.py
index 1f08241..c364f6d 100755
--- a/ava.py
+++ b/ava.py
@@ -1,42 +1,49 @@
 #!/usr/bin/env python
+# : out ava
+# : dep transformers
+# : dep torch
 import transformers
 import torch
 import sys
+
 # import sleekxmpp
-#model_name = "EleutherAI/gpt-neox-20b"
+# model_name = "EleutherAI/gpt-neox-20b"
 model_name = "EleutherAI/gpt-j-6B"
 if torch.cuda.is_available():
     device = "cuda:0"
 else:
     raise ValueError("no cuda")
+    sys.exit(1)
 tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
 model = transformers.AutoModelForCausalLM.from_pretrained(
-    model_name,
-    device_map="auto",
-    load_in_8bit=True,
-    pad_token_id=tokenizer.eos_token_id,
-    revision="float16",
-    torch_dtype=torch.float16,
-    low_cpu_mem_usage=True,
+    model_name,
+    device_map="auto",
+    load_in_8bit=True,
+    pad_token_id=tokenizer.eos_token_id,
+    revision="float16",
+    torch_dtype=torch.float16,
+    low_cpu_mem_usage=True,
 )
 # set attention_mask and pad_token_id
+
 def gen(txt):
-    input_ids = tokenizer(txt, return_tensors="pt").input_ids.to('cuda')
+    input_ids = tokenizer(txt, return_tensors="pt").input_ids.to("cuda")
     outputs = model.generate(
-        input_ids=input_ids,
-        max_length=1024,
-        temperature=0.7,
+        input_ids=input_ids,
+        max_length=1024,
+        temperature=0.7,
     )
     result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
     result = "".join(result)
     return result
+
 # Get user input and generate a response
 while True:
     user_input = input("ben: ")
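
Two notes on the script above: model.generate() defaults to greedy decoding, so temperature=0.7 has no effect unless do_sample=True is also passed, and the sys.exit(1) added after raise is unreachable, since the raise already aborts. Below is a minimal sketch of a gen() variant that enables sampling and closes the "# set attention_mask and pad_token_id" TODO; it is an illustration, not part of the commit, and assumes the same model and 8-bit setup as above (load_in_8bit=True requires the bitsandbytes and accelerate packages to be installed).

import torch
import transformers

model_name = "EleutherAI/gpt-j-6B"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,           # needs bitsandbytes + accelerate
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

def gen(txt):
    # Keep both input_ids and attention_mask from the tokenizer call.
    enc = tokenizer(txt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,  # silences the attention-mask warning
        pad_token_id=tokenizer.eos_token_id,
        max_length=1024,
        do_sample=True,                     # without this, temperature is ignored
        temperature=0.7,
    )
    return "".join(tokenizer.batch_decode(outputs, skip_special_tokens=True))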