-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunibot.py
179 lines (138 loc) · 5.26 KB
/
unibot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import pyautogui
import pynput.mouse
import pyperclip
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pynput.mouse import Listener
import argparse
import platform
def get_parser_args(argv=None):
    """Parse the bot's command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which is
            what the script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``lang``, ``memory``,
        ``max_mess_len``, ``temperature`` and ``response``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang",
        default="ru",
        help="Bot language.")
    parser.add_argument(
        "--memory",
        default=5,
        type=int,
        help="How many messages should bot remember.")
    parser.add_argument(
        "--max-mess-len",
        default=1024,
        type=int,
        help="Max message length in symbols.")
    parser.add_argument(
        "--temperature",
        default=0.8,
        type=float,
        help="The more is temperature the more nonsense bot will generate.")
    parser.add_argument(
        "--response",
        default=10.0,
        type=float,
        help="How often bot would check chat and respond in seconds")
    return parser.parse_args(argv)
# Setup state filled in by on_click_pos(); the coordinates stay 0 until the
# matching click has been captured.
pos = 0  # point the bot copies incoming messages from (set by a right click)
pos2 = 0  # point the bot pastes its replies into (set by a left click)
right_pressed = False  # becomes True once the input point has been chosen


def on_click_pos(x, y, button, pressed):
    """Mouse-listener callback that records the input and output points.

    A right click stores the input point in ``pos``; a subsequent left click
    stores the output point in ``pos2`` and stops the listener.

    Args:
        x: Horizontal pointer coordinate of the click.
        y: Vertical pointer coordinate of the click.
        button: The mouse button the event refers to.
        pressed: True for a button press, False for a release.

    Returns:
        False once the output point has been captured (this stops the
        listener), otherwise None so the listener keeps running.
    """
    global right_pressed, pos, pos2

    if not pressed:
        # The listener reports press and release separately. Reacting to both
        # would handle every right click twice (duplicate prompts, and the
        # point overwritten with the release position), so only presses count.
        return None

    if button == pynput.mouse.Button.right:
        pos = (x, y)
        print(f"Input point was set to {pos}")
        print("Please left click where bot should paste output data")
        right_pressed = True
    elif button == pynput.mouse.Button.left and right_pressed:
        pos2 = (x, y)
        print(f"Output point was set to {pos2}")
        print("Starting the bot...")
        return False
    return None
def paste(text: str):
    """Place *text* on the clipboard and send the platform's paste shortcut."""
    pyperclip.copy(text)
    # macOS pastes with Command+V; every other platform uses Ctrl+V.
    modifier = "command" if platform.system() == "Darwin" else "ctrl"
    pyautogui.hotkey(modifier, "v")
def get_length_param(text: str) -> str:
    """Return the length bucket of *text*, measured in tokens.

    The text is encoded with the module-level ``tokenizer`` and the token
    count is mapped to '1' (up to 15), '2' (up to 50), '3' (up to 256) or
    '-' (anything longer).
    """
    n_tokens = len(tokenizer.encode(text))
    # Buckets are checked in ascending order; the first bound that fits wins.
    for upper_bound, bucket in ((15, '1'), (50, '2'), (256, '3')):
        if n_tokens <= upper_bound:
            return bucket
    return '-'
# Pretrained model identifier to load for each supported --lang value.
model_dict = {
    'ru': 'Grossmend/rudialogpt3_medium_based_on_gpt2',
    'en': 'microsoft/DialoGPT-medium',
}
def copy_clipboard():
    """Copy the current selection and return the clipboard contents.

    Sends Command+C on macOS and Ctrl+C elsewhere, the same platform split
    that paste() uses for its shortcut.
    """
    modifier = "command" if platform.system() == "Darwin" else "ctrl"
    pyautogui.hotkey(modifier, "c")
    time.sleep(.01)  # the copy shortcut is fast, but this program may read the clipboard even faster
    return pyperclip.paste()


if __name__ == '__main__':
    # Script entry point: load the model, let the user pick the input and
    # output points with the mouse, then poll the chat and answer forever.
    args = get_parser_args()
    if args.lang not in model_dict:
        raise ValueError('Model language must be "en" or "ru"!')

    print("Loading model weights...")
    # `tokenizer` must stay a module-level name: get_length_param() reads it.
    tokenizer = AutoTokenizer.from_pretrained(model_dict[args.lang])
    model = AutoModelForCausalLM.from_pretrained(model_dict[args.lang])

    print("Please right click from where bot should get input data")
    with Listener(on_click=on_click_pos) as listener:
        listener.join()  # blocks until on_click_pos returns False

    phrase = ''  # last reply sent by the bot, stripped for comparison
    step = 0  # number of exchanges currently held in the chat history
    chat_history_ids = None  # token ids of the conversation so far

    while True:  # main loop
        time.sleep(args.response)

        # Repeated double clicks on the input point to select the whole message.
        for _ in range(3):
            pyautogui.doubleClick(pos)

        var = str(copy_clipboard()).strip()
        # Skip when nothing was copied or when the selected text is the bot's
        # own latest reply.
        if not var or var == phrase:
            continue

        # Build the prompt for the new user input; the Russian model expects
        # its speaker/length parameters around the message.
        if args.lang == 'ru':
            prompt = f"|0|{get_length_param(var)}|" + var + tokenizer.eos_token + "|1|1|"
        else:
            prompt = var + tokenizer.eos_token
        new_user_input_ids = tokenizer.encode(prompt, return_tensors="pt")

        # Append the new user input tokens to the chat history, if any.
        if step > 0:
            bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
        else:
            bot_input_ids = new_user_input_ids

        # Fall back to the EOS id when the tokenizer defines no padding token.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # Generate a response.
        # NOTE(review): mask_token_id, unk_token_id and device were dropped;
        # they are not generation parameters, and newer transformers releases
        # appear to reject unknown generate() kwargs - confirm against the
        # pinned version.
        chat_history_ids = model.generate(
            bot_input_ids,
            num_return_sequences=1,
            max_length=args.max_mess_len,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=args.temperature,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
        )

        step += 1
        if step >= args.memory:  # reset memory
            step = 0

        # Double clicks on the output point to focus it before pasting.
        for _ in range(2):
            pyautogui.doubleClick(pos2)

        # Decode only the tokens generated after the prompt.
        reply_ids = chat_history_ids[:, bot_input_ids.shape[-1]:][0]
        gen_text = tokenizer.decode(reply_ids, skip_special_tokens=True)
        print(gen_text)
        # Stored stripped because the copied input is stripped before the
        # comparison above; otherwise the bot could answer its own message.
        phrase = gen_text.strip()
        paste(gen_text)
        # Both send shortcuts are issued, as in the original script.
        pyautogui.press('Enter')
        pyautogui.hotkey('ctrl', 'Enter')