# Extracted from a cgit web view (page navigation and line-number gutter removed).
# path: root/ava.py
# blob: 1f082414916d21ef93619c566c3ae0240ef5c462
#!/usr/bin/env python
import transformers
import torch
import sys
# import sleekxmpp


# Checkpoint to load. A larger alternative is kept for reference:
# model_name = "EleutherAI/gpt-neox-20b"
model_name = "EleutherAI/gpt-j-6B"

# Guard clause: refuse to start at all when no CUDA device is visible.
if not torch.cuda.is_available():
    raise ValueError("no cuda")
device = "cuda:0"

# Tokenizer and causal-LM weights for the same checkpoint. The weights come
# from the checkpoint's "float16" revision, are loaded in 8-bit, and
# device_map="auto" lets transformers decide where they are placed.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    revision="float16",
    torch_dtype=torch.float16,
    load_in_8bit=True,
    device_map="auto",
    low_cpu_mem_usage=True,
    # Extra kwargs that name config attributes update the loaded config;
    # this one makes the model use EOS as its padding token.
    pad_token_id=tokenizer.eos_token_id,
)

# set attention_mask and pad_token_id

def gen(txt):
    input_ids = tokenizer(txt, return_tensors="pt").input_ids.to('cuda')
    outputs = model.generate(
      input_ids=input_ids,
      max_length=1024,
      temperature=0.7,
    )
    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    result = "".join(result)
    return result

# Get user input and generate a response
while True:
    user_input = input("ben: ")
    response = gen(user_input)
    print("bot: ", response)


"""
# Set up the XMPP client
client = sleekxmpp.ClientXMPP(
  "ava@simatime.com",
  "test"
)
client.connect()
client.process(block=True)

# Define a function that takes in a user's input and returns a response
def generate_response(input_text):
  # You would use your language model to generate a response here
  response = "This is a response to the user's input: " + input_text
  return response

# Handle incoming messages
@client.add_event_handler("message")
def handle_message(message):
  # Get the user's input
  user_input = message["body"]

  # Generate a response
  response = generate_response(user_input)

  # Send the response to the user
  message.reply(response).send()
"""