summaryrefslogtreecommitdiff
path: root/ava.py
blob: c364f6de02032f84e13a295ea8aca3953f968963 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python
# : out ava
# : dep transformers
# : dep torch
import transformers
import torch
import sys

# import sleekxmpp


# model_name = "EleutherAI/gpt-neox-20b"
# Hugging Face Hub id of the causal language model to load; the commented-out
# id above is an alternative.
model_name = "EleutherAI/gpt-j-6B"

# Only GPU inference is supported. sys.exit(msg) prints msg to stderr and
# exits with status 1 when no CUDA device is visible.
if torch.cuda.is_available():
    # NOTE: `device` is not read elsewhere in this file; layer placement is
    # left to device_map="auto" below.
    device = "cuda:0"
else:
    sys.exit("no cuda")

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    # Let accelerate spread the weights over the available devices.
    device_map="auto",
    # Quantize weights to int8 (bitsandbytes) to reduce GPU memory use.
    # NOTE(review): newer transformers releases prefer
    # quantization_config=transformers.BitsAndBytesConfig(load_in_8bit=True).
    load_in_8bit=True,
    # Config override: reuse EOS as the padding id for generate()
    # (presumably the tokenizer defines no pad token of its own — confirm).
    pad_token_id=tokenizer.eos_token_id,
    # Load the half-precision weights branch of the Hub repository.
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Text generation


def gen(txt, *, max_length=1024, temperature=0.7):
    """Sample the model's continuation of a prompt.

    Args:
        txt: Prompt text fed to the model.
        max_length: Upper bound on prompt + generated tokens, forwarded to
            ``model.generate``.
        temperature: Sampling temperature; only meaningful because sampling
            is enabled below.

    Returns:
        The decoded continuation only (the prompt is not echoed back), with
        special tokens removed.
    """
    # Tokenize once and keep the attention mask next to input_ids so
    # generate() does not have to infer which positions are padding. Move the
    # tensors to wherever the model's parameters live instead of assuming a
    # hard-coded "cuda".
    inputs = tokenizer(txt, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        # temperature is ignored under greedy decoding (the default), so
        # sampling must be switched on for it to have any effect.
        do_sample=True,
        temperature=temperature,
    )
    # Decoder-only models return prompt + continuation; drop the prompt so
    # the reply does not repeat what the user typed.
    prompt_len = input_ids.shape[-1]
    return tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True)


# Interactive chat loop: read a line from the user and print the model's
# reply. EOF (Ctrl-D) or Ctrl-C at the prompt ends the session without a
# traceback.
while True:
    try:
        user_input = input("ben: ")
    except (EOFError, KeyboardInterrupt):
        print()  # finish the prompt line so the shell starts on a fresh one
        break
    if not user_input.strip():
        continue  # nothing to prompt the model with
    response = gen(user_input)
    # print() already inserts a space between its arguments.
    print("bot:", response)


# NOTE(review): The triple-quoted string below is an inert draft of an XMPP
# front end (it relies on the `sleekxmpp` import that is commented out at the
# top of the file). It is a bare string expression, so none of it executes.
# If it is revived:
#   - `client.process(block=True)` is called before the "message" handler is
#     registered; if `block=True` blocks as its name suggests, the handler
#     would not be in place while messages are processed. Register handlers
#     before connecting/processing.
#   - `@client.add_event_handler("message")` is used as a decorator; sleekxmpp's
#     add_event_handler appears to take the callback as a second argument
#     instead — confirm against the sleekxmpp documentation.
#   - `generate_response` is a placeholder; presumably the model-backed
#     `gen()` defined above is what should produce the reply.
#   - The JID and password are hard-coded; load credentials from config or
#     the environment instead of committing them.
"""
# Set up the XMPP client
client = sleekxmpp.ClientXMPP(
  "ava@simatime.com",
  "test"
)
client.connect()
client.process(block=True)

# Define a function that takes in a user's input and returns a response
def generate_response(input_text):
  # You would use your language model to generate a response here
  response = "This is a response to the user's input: " + input_text
  return response

# Handle incoming messages
@client.add_event_handler("message")
def handle_message(message):
  # Get the user's input
  user_input = message["body"]

  # Generate a response
  response = generate_response(user_input)

  # Send the response to the user
  message.reply(response).send()
"""