-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathServer.py
More file actions
340 lines (291 loc) · 17 KB
/
Server.py
File metadata and controls
340 lines (291 loc) · 17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
import asyncio
import os
import re
import sys
import time
from collections import deque
from configparser import ConfigParser
from io import BytesIO
from typing import Union

import discord
from discord.errors import HTTPException
import requests
from PIL import Image

import LangModelAPI as LangModel
import StableDiffusionAPI as sd
from CaptionerAPI import describe_image
from makeConfig import makeConfig
from CacheUsers import UserCache
from MemeDbAPI import MemeDatabase, get_image_embedding
# The configuration profile is the first command-line argument, or 'default'
# when the script is started without arguments.
_cli_args = sys.argv[1:]
profile = _cli_args[0] if _cli_args else 'default'

# makeConfig is called before config.ini is read (presumably it creates or
# completes the section for this profile - confirm in makeConfig).
makeConfig(profile)

config = ConfigParser()
config.read('config.ini')

# Bot token for the selected profile; a missing section or key raises KeyError.
_profile_section = config[profile]
token = _profile_section['token']
# User-facing usage text sent verbatim to the channel by the `help` command.
# It is a runtime string: keep it in sync with the regexes in `commands`.
help_text = """Commands:
n - refers to the "n" most recent messages used for context in the channel.
n2 - refers to the "n2" most recent messages to skip when using context in the channel.
<> - required argument
[] - optional argument
... - any number of arguments
user_name - @mention of the user
Text operations:
`help` - display this message
`summarize <n> [n2]` - summarize the last n messages, optionally skipping the last n2 messages
`query <n> [n2] <...>` - query the chatbot with the given text, optionally skipping the last n2 messages. Ex: `query 10 2 What conclusions can we draw from this?`
`response <...>` - respond to the chatbot with the given text
`prompt <...>` - prompt the bare chatbot with the given text
`raw [n] <...>` - use the raw language model without user/agent tokens. `n` is the number of tokens to generate (default 4000)
`roast <user_name> <n> [n2]` - roast the user with the given name using the context from the past n messages, optionally skipping the last n2 messages (Doesn't work well. Better prompt engineering needed)
`act_like <user_name> <n> [n2]` - act like the user with the given name and respond as them. n is the number of messages for context, optionally skipping the last n2 messages
Image operations:
`generate [-waifu] <...>` - generate an image with the given prompt (normal Stable Diffusion) or waifu (hakurei/waifu-diffusion) if `-waifu` flag is set
`find_meme [n] <...>` - find a meme by text. `n` is the max number of results to return (default 5)
"""
# Regular expressions for the supported commands, tried in order with
# re.match() against the message text that follows the bot mention.
# Every pattern exposes a `command` group; depending on the command it may
# also expose `n`, `n2`, `user`, `isWaifu` and `text`.
# `[\s\S]+` captures free text including newlines ('.' alone stops at '\n').
commands = [
    r'(?P<command>help)',
    r'(?P<command>summarize)\s+(?P<n>\d+)(?:\s+(?P<n2>\d*))?',
    r'(?P<command>query)\s+(?P<n>\d+)(?:\s+(?P<n2>\d+))?\s+(?P<text>[\s\S]+)',
    r'(?P<command>response)\s+(?P<text>[\s\S]+)',
    r'(?P<command>prompt)\s+(?P<text>[\s\S]+)',
    # `n` is optional: it is grouped together with its trailing whitespace so
    # that the documented `raw <text>` form (no `n`, single space) matches.
    r'(?P<command>raw)\s+(?:(?P<n>\d+)\s+)?(?P<text>[\s\S]+)',
    r'(?P<command>roast)\s+(?P<user>[^\s]+)\s+(?P<n>\d+)(?:\s+(?P<n2>\d*))?',
    r'(?P<command>act_like)\s+(?P<user>[^\s]+)\s+(?P<n>\d+)(?:\s+(?P<n2>\d*))?',
    r'(?P<command>generate)\s*(?P<isWaifu>\-waifu)?\s+(?P<text>[\s\S]+)',
    r'(?P<command>find_meme)\s+(?P<n>\d+)?\s*(?P<text>[\s\S]+)',
]
class myClient(discord.Client):
    """Discord client exposing language-model, image-generation and meme commands.

    A command is a message whose first space-separated word is the bot's
    mention, followed by text matching one of the module-level ``commands``
    patterns. Messages posted in the channel configured as
    ``meme_channel_id`` are additionally passed to :meth:`on_meme`.

    Per-connection state (rate-limit queue, conversation history, user cache,
    meme database client) is created in :meth:`on_ready`.
    """

    # Matches URLs containing a .png/.jpg/.jpeg extension; group 2 is the extension.
    IMAGE_URL_PATTERN = re.compile(r'^(.+\.((png)|(jpg)|(jpeg))).*')
    # Minimum similarity (1 - distance) for an image to count as already posted.
    MEME_MATCH_THRESHOLD = .69
    # Negative prompt passed to the image generator when `-waifu` is requested.
    WAIFU_NEGATIVE_PROMPT = 'lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry'

    @staticmethod
    def _terminal_width() -> int:
        """Return the terminal width in columns, or 80 when no terminal is attached."""
        try:
            return os.get_terminal_size()[0]
        except OSError:
            # os.get_terminal_size() fails when stdout is not a TTY (service, pipe).
            return 80

    @staticmethod
    def _download_image(url: str) -> Union[Image.Image, None]:
        """Fetch ``url`` and open it as a PIL image; return None if that fails."""
        try:
            reply = requests.get(url, stream=True, timeout=30)
            return Image.open(reply.raw)
        except (requests.RequestException, OSError) as error:
            # PIL raises OSError subclasses for unreadable/unsupported image data.
            print(f'Could not load image {url}:', error)
            return None

    @staticmethod
    def _resolve_username(raw: str, mentions: dict) -> str:
        """Turn a ``<@id>`` mention into a display name using ``mentions``.

        ``raw`` is returned unchanged when it is not a mention; a numeric id
        that is not among the message's mentions is returned as the id string.
        """
        user_ref = re.sub(r'<@!?(\d+)>', r'\1', raw)
        if user_ref.isdigit():
            return mentions.get(int(user_ref), user_ref)
        return user_ref

    async def on_ready(self):
        """Initialise per-connection state once the client has logged in."""
        print(f'Logged on as {self.user} ({self.user.mention})')  # type: ignore
        self.model_info = LangModel.get_model_info()
        self.terminal_size = self._terminal_width()
        # Timestamps of recent message edits, used by edit_message() for throttling.
        self.message_time_queue: deque[float] = deque(maxlen=24)
        self.conversation_history: dict[int, str] = {}  # {channel_id: conversation}
        self.keep_history = False
        self.user_cache = UserCache()
        try:
            self.meme_client = MemeDatabase('UCASEmbeddings', config[profile]['MemeDB_ip'], config[profile]['MemeDB_port'])
        except Exception as e:
            # The meme features are optional: the bot keeps running without them.
            print('Could not connect to MemeDB because of error:', e)
            self.meme_client = None

    def get_user_name(self, user: Union[discord.User, discord.Member]):
        """Return the display name for ``user``, caching it in ``self.user_cache``.

        Preference order for uncached users: server nickname, global name,
        account name. Bots are registered through ``add_bot`` and their name
        is read back from the cache's 'bots' entry.
        """
        name = self.user_cache.get_user(str(user.id))
        if not name:
            if isinstance(user, discord.Member) and user.nick:
                name = user.nick
            else:
                name = user.global_name
            name = name if name else user.name
            if user.bot:
                self.user_cache.add_bot(name)
                name = self.user_cache.get_user('bots', name)
            else:
                self.user_cache.add_user(str(user.id), name)
        return name

    async def edit_message(self, message: discord.Message, content: str, no_check=False):
        """Edit ``message`` to ``content``, throttled by ``self.message_time_queue``.

        An edit is performed when the queue has room and the previous edit is
        more than 1.5 seconds old, or unconditionally when ``no_check`` is true.
        Returns the result of ``message.edit`` or None when the edit was skipped.

        Raises:
            HTTPException: re-raised after notifying the channel when the edit
                is rejected.
        """
        stamps = self.message_time_queue
        # Drop the oldest timestamp once it is over a minute old (or when forced).
        if stamps and (time.time() - stamps[0] > 60 or no_check):
            stamps.popleft()
        has_room = len(stamps) < stamps.maxlen  # type: ignore
        spaced_out = not stamps or (time.time() - stamps[-1]) > 1.5
        if not ((has_room and spaced_out) or no_check):
            return None
        stamps.append(time.time())
        try:
            return await message.edit(content=content)
        except HTTPException:
            await message.channel.send('Message too long to continue editing', silent=True)
            raise

    async def format_messages(self, content, message: discord.Message, n: str, _n2: str = '0') -> tuple[discord.Message, str]:
        """Build a model prompt transcript from recent channel history.

        Args:
            content: unused; kept for backward compatibility with callers.
            message: the command message; its channel's history is read.
            n: number of messages to fetch (digits as a string).
            _n2: number of most recent messages to skip (digits as a string);
                anything else is treated as 0.

        Returns:
            The "Working on it..." placeholder message that was sent, and the
            transcript. When ``n`` is not a number the transcript is
            '[Empty response]'.
        """
        if n is None or not n.isdigit():
            return await message.channel.send('Working on it...', silent=True), '[Empty response]'
        n2: int = int(_n2) if isinstance(_n2, str) and _n2.isdigit() else 0
        # history() yields newest first (discord.py default) and includes the
        # command message itself; the reversed slice drops indices 0..n2 and
        # leaves the rest in oldest-first order.
        history = [past async for past in message.channel.history(limit=int(n) + 1)][:n2:-1]
        sent_message = await message.channel.send('Working on it...', silent=True)
        self.message_time_queue.append(time.time())

        transcript = ''
        previous_author = None
        total = len(history)
        for index, past in enumerate(history):
            text = past.content
            for user in past.mentions:
                name = self.get_user_name(user)
                # A callable replacement keeps backslashes in names literal.
                text = re.sub(f'<@!?{user.id}>', lambda _m, name=name: name, text)
            for embed in past.embeds:
                if not embed.url:
                    continue
                found = self.IMAGE_URL_PATTERN.match(embed.url)
                if not found:
                    continue
                url = found.group(0)
                embed_data = embed.to_dict()
                await self.edit_message(sent_message, f'Describing {url}')
                description = f"<{embed_data['type']}>{describe_image(url)}</{embed_data['type']}>"  # type: ignore
                # Literal replacement: URLs contain regex metacharacters ('?', '.', '+').
                text = text.replace(embed_data['url'], description)  # type: ignore
            for attachment in past.attachments:
                if attachment.content_type in ('image/png', 'image/jpeg'):
                    await sent_message.edit(content=f'Analyzing {attachment.url}\n{index+1}/{total} ({(index+1)/total:.2%}) messages')
                    text += f"<{attachment.content_type}>{describe_image(attachment.url)}</{attachment.content_type}>"
            author = self.get_user_name(past.author)
            if not text:
                continue
            if previous_author != author:
                # A new speaker closes the previous turn and opens a new one.
                if previous_author:
                    transcript += self.model_info["end_token"] + '\n'
                transcript += f'{self.model_info["start_token"]}{author}\n{text}'
                previous_author = author
            else:
                # Consecutive messages by the same author share one turn.
                transcript += f'\n{text}'
        transcript += self.model_info["end_token"] + '\n'
        # Merge adjacent identical tags (e.g. two image descriptions) into one.
        transcript = re.sub(r'<\/(.*?)>\s+<\1>', r'\n', transcript)
        return sent_message, transcript

    async def send_message(self, generator, sent_message: discord.Message):
        """Stream a model response from ``generator`` into ``sent_message``.

        The first item of ``generator`` is stored as (or appended to) the
        channel's conversation history; every following item is treated as the
        response so far and written to the message, subject to throttling.
        Returns the result of the final, unthrottled edit.
        """
        self.terminal_size = self._terminal_width()
        channel_id = sent_message.channel.id
        if channel_id not in self.conversation_history or not self.keep_history:
            self.conversation_history[channel_id] = next(generator)
        else:
            self.conversation_history[channel_id] += next(generator).replace(self.conversation_history[channel_id], '')
        self.keep_history = False
        response = ''
        for response in generator:
            try:
                if re.match(r'^\s*$', response):
                    await self.edit_message(sent_message, '[Empty response]', no_check=True)
                    continue
                await self.edit_message(sent_message, response)
                print(response.split('\n')[-1][-self.terminal_size:], end='\r\r')
            except Exception:
                # Stop streaming on any failure (e.g. the message grew too long),
                # but let task cancellation and interpreter exit propagate.
                break
        print()
        self.conversation_history[channel_id] += f'{response}{self.model_info["end_token"]}\n'
        # Discord rejects blank content, so mirror the in-loop placeholder.
        final_text = response if response.strip() else '[Empty response]'
        return await self.edit_message(sent_message, final_text, no_check=True)

    async def send_image(self, image: Image.Image, sent_message: discord.Message):
        """Upload ``image`` as a JPEG to the channel and delete the placeholder message."""
        # Encode in memory so no temp file is left behind or shared between calls.
        buffer = BytesIO()
        image.save(buffer, format='JPEG', quality=95, subsampling=0)
        buffer.seek(0)
        discord_image = discord.File(buffer, filename='temp.jpg')
        await sent_message.channel.send(file=discord_image, silent=True)
        # Delete the sent message
        await sent_message.delete()

    async def reply_and_react(self, message: discord.Message, reaction: str, response: str):
        """Add ``reaction`` to ``message``.

        ``response`` is not sent to the channel (the text reply is disabled);
        when reacting is forbidden it is written to the console instead.
        """
        try:
            await message.add_reaction(reaction)
        except discord.errors.Forbidden:
            print(f'Could not react to message {message.id}: {reaction} {response}')

    async def on_meme(self, message: discord.Message):
        """Check the images of a meme-channel message against the meme database.

        Images whose best match is below ``MEME_MATCH_THRESHOLD`` are inserted
        into the database; the others are flagged as reposts with a reaction.
        Does nothing when the meme database is not connected.
        """
        if not self.meme_client:
            return
        # Re-read the message after a short delay (presumably so Discord has
        # had time to attach link embeds - confirm).
        await asyncio.sleep(2)
        refreshed = None
        async for candidate in message.channel.history(limit=50):
            if candidate.id == message.id:
                refreshed = candidate
                break
        if refreshed is None:
            # Deleted in the meantime, or pushed out of the last 50 messages.
            return

        urls = []
        for embed in refreshed.embeds:
            if not embed.url:
                continue
            found = self.IMAGE_URL_PATTERN.match(embed.url)
            if found:
                urls.append(found.group(0))
        urls.extend(
            attachment.url
            for attachment in refreshed.attachments
            if attachment.content_type and attachment.content_type.startswith('image')
        )
        images = [image for image in map(self._download_image, urls) if image is not None]
        if not images:
            return

        new_vectors = []
        for image in images:
            matches = self.meme_client.query(image)
            if not matches:
                # Nothing to compare against (e.g. empty database): treat as new.
                print('Image not in database (no matches returned)')
                new_vectors.append(self.meme_client.format_img(image))
                continue
            best = matches[0]
            confidence = 1 - best['@distance']
            if confidence < self.MEME_MATCH_THRESHOLD:
                print(f'Image not in database with distance {best["@distance"]} ({confidence:.2%} confidence)')  # type: ignore
                new_vectors.append(self.meme_client.format_img(image))
                continue
            print(f'Image already in database with distance {best["@distance"]}')  # type: ignore
            discord_link = f'https://discord.com/channels/{message.guild.id}/{message.channel.id}/{best["MessageID"]}'  # type: ignore
            print(discord_link)
            await self.reply_and_react(message, '♻️', f'Meme already posted {discord_link} with {confidence:.2%} confidence')  # type: ignore
        if not new_vectors:
            return
        _status, response = self.meme_client.insert([{'MessageID': str(refreshed.id), 'PixelVec': vec} for vec in new_vectors])
        print(response['message'])

    async def on_message(self, message: discord.Message):
        """Dispatch an incoming message to the meme handler and/or a bot command.

        Only messages whose first space-separated word equals the bot's mention
        are treated as commands; text that matches no pattern is ignored.
        """
        if message.channel.id == int(config[profile]['meme_channel_id']):
            await self.on_meme(message)
        if not message.content:
            return
        words = message.content.split(' ')
        if words.pop(0) != self.user.mention:  # type: ignore
            return
        if not words:
            return
        mentions = {user.id: self.get_user_name(user) for user in message.mentions}
        content = ' '.join(words)
        mat = self.get_matching_command(content)
        if not mat:
            return
        command = mat.group('command')

        if command == 'help':
            await message.channel.send(help_text)
            return

        if command == 'summarize':
            sent_message, sent_message_content = await self.format_messages(content, message, mat.group('n'), mat.group('n2'))
            return await self.send_message(LangModel.summarize(sent_message_content), sent_message)

        if command == 'query':
            sent_message, sent_message_content = await self.format_messages(content, message, mat.group('n'), mat.group('n2'))
            return await self.send_message(LangModel.query(sent_message_content, mat.group('text')), sent_message)

        if command == 'response':
            sent_message = await message.channel.send('Working on it...')
            # Tell send_message() to extend the stored conversation instead of replacing it.
            self.keep_history = True
            history = self.conversation_history.get(sent_message.channel.id)
            return await self.send_message(LangModel.response(history, mat.group('text')), sent_message)

        if command == 'prompt':
            sent_message = await message.channel.send('Working on it...')
            return await self.send_message(LangModel.prompt(mat.group('text')), sent_message)

        if command == 'raw':
            sent_message = await message.channel.send('Working on it...')
            # `n` is optional and the pattern only captures digits, so a
            # missing group simply means "use the default".
            n = int(mat.group('n')) if mat.group('n') else 4000
            return await self.send_message(LangModel.raw(mat.group('text'), max_new_tokens=n), sent_message)

        if command in ('roast', 'act_like'):
            sent_message, sent_message_content = await self.format_messages(content, message, mat.group('n'), mat.group('n2'))
            username = self._resolve_username(mat.group('user'), mentions)
            build_prompt = LangModel.roast if command == 'roast' else LangModel.act_like
            return await self.send_message(build_prompt(sent_message_content, username), sent_message)

        if command == 'generate':
            sent_message = await message.channel.send('Generating...')
            is_waifu = bool(mat.group('isWaifu'))
            result = sd.generate(
                mat.group('text'),
                img_type='waifu' if is_waifu else 'normal',
                width=512,
                height=512,
                num_inference_steps=120 if is_waifu else 40,
                neg_prompt=self.WAIFU_NEGATIVE_PROMPT if is_waifu else '',
            )
            print(result)
            # A string result is shown as text (presumably an error message
            # from the generator - confirm); anything else is sent as an image.
            if isinstance(result, str):
                return await self.edit_message(sent_message, result, no_check=True)
            return await self.send_image(result, sent_message)

        if command == 'find_meme':
            if not self.meme_client:
                return await message.channel.send('MemeDB not connected')
            sent_message = await message.channel.send('Searching...')
            limit = int(mat.group('n')) if mat.group('n') else 5
            results = self.meme_client.query(mat.group('text') + ' meme', limit=limit)
            if not results:
                return await self.edit_message(sent_message, 'No results found', no_check=True)
            # De-duplicate while keeping the order the database returned
            # (presumably best match first - confirm).
            message_ids = dict.fromkeys(int(hit['MessageID']) for hit in results)
            links = '\n'.join(
                f'{i+1}. https://discord.com/channels/{message.guild.id}/{config[profile]["meme_channel_id"]}/{m_id}'  # type: ignore
                for i, m_id in enumerate(message_ids)
            )
            return await self.edit_message(sent_message, 'Here are the results:\n' + links, True)

    def get_matching_command(self, content: str):
        """Return the match for the first pattern in ``commands`` that matches ``content``, or None."""
        attempts = (re.match(pattern, content) for pattern in commands)
        return next((found for found in attempts if found), None)
if __name__ == '__main__':
    # Script entry point: create the client and run it with the profile's token.
    # NOTE(review): discord.py 2.x requires an `intents=` argument when
    # constructing a Client - confirm which library version is installed.
    myClient().run(token)