290 lines
8.7 KiB
Python
290 lines
8.7 KiB
Python
import asyncio
|
|
from collections import defaultdict
|
|
from slixmpp import ClientXMPP
|
|
from slixmpp.stanza import Message
|
|
|
|
class Lifo(list):
    """Bounded list that keeps the newest entries first.

    New items go to index 0; once the list grows past ``size`` the entry at
    the end (the oldest one) is dropped.
    """

    def __init__(self, size):
        """Create an empty container that holds at most ``size`` items."""
        super().__init__()
        self.size = size

    def add(self, item):
        """Put ``item`` at the front, discarding the last entry on overflow."""
        # Slice assignment at position 0 prepends, same as insert(0, item).
        self[:0] = [item]
        overflow = len(self) > self.size
        if overflow:
            # Only one element is ever removed per call.
            del self[-1]
|
|
|
|
def create_messages_dict():
    """Build the per-sender history store.

    Returns a ``defaultdict`` whose missing keys are filled with a dict of
    three ``Lifo`` buffers: ``"messages"`` (capacity 100), ``"links"``
    (capacity 10) and ``"previews"`` (capacity 10).
    """

    def _fresh_history():
        # One independent set of buffers per key.
        history = {}
        history["messages"] = Lifo(100)
        history["links"] = Lifo(10)
        history["previews"] = Lifo(10)
        return history

    return defaultdict(_fresh_history)
|
|
|
|
class RegexCmd:
    """Decorator object that registers a callback as a regex command.

    Calling an instance on a function appends the instance to
    ``bot.regex_cmds`` and stores the function as ``func``; the instance
    itself takes the place of the decorated function.
    """

    def __init__(self, bot, pattern, block=False, matcher=None):
        """Remember the bot, the pattern and the optional block/matcher."""
        self.bot, self.pattern = bot, pattern
        self.block, self.matcher = block, matcher

    def __call__(self, func):
        """Register this command on the bot and bind ``func`` to it."""
        registry = self.bot.regex_cmds
        registry.append(self)
        self.func = func
        return self
|
|
|
|
class AngelBot(ClientXMPP):
    """AngelBot class.

    XMPP client that runs the commands registered in ``regex_cmds`` against
    incoming chat and groupchat messages, and keeps bounded per-sender
    histories of messages, links and previews in ``self.messages``.
    """

    # Commands registered through RegexCmd. This list lives on the class, so
    # it is shared by every AngelBot instance.
    regex_cmds = []

    def __init__(self, jid, password, nick="angel", autojoin=None,
                 youtube_links=None,
                 invidious_instances=None):
        """Initialize the bot.

        ``autojoin``, ``youtube_links`` and ``invidious_instances`` default
        to empty lists when omitted. Plugins and event handlers are
        registered as part of construction.
        """
        super().__init__(jid, password)
        self.jid = jid
        self.nick = nick
        self.autojoin = autojoin or []
        self.invidious_instances = invidious_instances or []
        self.youtube_links = youtube_links or []
        self.messages = create_messages_dict()
        self.register_plugins()
        self.add_handlers()

    def reply(self, msg, body):
        """Record the body of ``msg`` in the history, then send ``body``."""
        self.save_message_history(msg)
        self.raw_reply(msg, body)

    def raw_reply(self, msg, body):
        """Reply to a message without saving history.

        The reply goes to the bare JID of the sender of ``msg`` and uses the
        same message type as ``msg``.
        """
        self.send_message(
            mto=msg["from"].bare,
            mbody=body,
            mtype=msg["type"],
        )

    def save_message_history(self, msg):
        """Store the body of ``msg`` under its sender's bare JID."""
        sender = msg["from"].bare
        self.messages[sender]["messages"].add(msg["body"])

    def get_message_history(self, msg):
        """Return the message history kept for the sender of ``msg``."""
        sender = msg["from"].bare
        return self.messages[sender]["messages"]

    def save_link_history(self, msg, url):
        """Store ``url`` in the link history of the sender of ``msg``."""
        sender = msg["from"].bare
        self.messages[sender]["links"].add(url)

    def get_link_history(self, msg):
        """Return the link history kept for the sender of ``msg``."""
        sender = msg["from"].bare
        return self.messages[sender]["links"]

    def save_preview_history(self, msg, preview):
        """Store ``preview`` in the preview history of the sender of ``msg``."""
        sender = msg["from"].bare
        self.messages[sender]["previews"].add(preview)

    def get_preview_history(self, msg):
        """Return the preview history kept for the sender of ``msg``."""
        sender = msg["from"].bare
        return self.messages[sender]["previews"]

    async def embed_file(self, sender, mtype, ftype, fname, outfile):
        """Upload a file and send its URL to ``sender``.

        The file is uploaded through the xep_0363 plugin; the resulting URL
        is recorded in the sender's link history and sent as a message of
        type ``mtype`` whose body and OOB url are both that URL.
        """
        furl = await self.plugin["xep_0363"].upload_file(
            fname, content_type=ftype, input_file=outfile
        )
        self.messages[sender]["links"].add(furl)
        message = self.make_message(sender)
        message["body"] = furl
        message["type"] = mtype
        message["oob"]["url"] = furl
        message.send()

    def register_plugins(self):
        """Register the slixmpp plugins the bot relies on."""
        self.register_plugin("xep_0030")  # Service Discovery
        self.register_plugin("xep_0060")  # Publish-Subscribe
        self.register_plugin("xep_0054")  # vcard-temp (make/publish vCard)
        self.register_plugin("xep_0045")  # Multi-User Chat (join_muc)
        self.register_plugin("xep_0066")  # Out of Band Data (msg["oob"])
        self.register_plugin("xep_0084")  # User Avatar
        self.register_plugin("xep_0153")  # vCard-based avatars
        self.register_plugin("xep_0363")  # HTTP File Upload

    def add_handlers(self):
        """Attach the bot's event handlers."""
        self.add_event_handler("session_start", self.session_start)
        self.add_event_handler("message", self.message)
        self.add_event_handler("groupchat_message", self.muc_message)
        # self.add_event_handler("vcard_avatar_update", self.debug_event)
        # self.add_event_handler("stream_error", self.debug_event)
        # Connect again whenever the connection drops.
        self.add_event_handler("disconnected", lambda _: self.connect())

    async def session_start(self, event):
        """Start the bot.

        Sends presence, fetches the roster, updates the bot's profile and
        then joins every room listed in ``autojoin``.
        """
        self.send_presence()
        await self.get_roster()
        await self.update_info()
        for channel in self.autojoin:
            try:
                self.plugin["xep_0045"].join_muc(channel, self.nick)
            except Exception as e:
                # Best effort: a failing room must not stop the other joins.
                print(e)

    async def update_info(self):
        """Update the bot info.

        Reads ``angel.png`` (relative to the working directory) and
        publishes it as the avatar, then publishes the vCard. The publish
        calls are scheduled with ``asyncio.gather`` and are not awaited
        here, so their results and failures are not observed by this method.
        """
        with open("angel.png", "rb") as avatar_file:
            avatar = avatar_file.read()

        avatar_type = "image/png"
        avatar_id = self.plugin["xep_0084"].generate_id(avatar)
        avatar_bytes = len(avatar)

        asyncio.gather(self.plugin["xep_0084"].publish_avatar(avatar))

        asyncio.gather(
            self.plugin["xep_0153"].set_avatar(
                avatar=avatar,
                mtype=avatar_type,
            )
        )

        info = {
            "id": avatar_id,
            "type": avatar_type,
            "bytes": avatar_bytes,
        }

        asyncio.gather(self.plugin["xep_0084"].publish_avatar_metadata([info]))

        vcard = self.plugin["xep_0054"].make_vcard()

        vcard["URL"] = "https://wiki.kalli.st/Angel"
        vcard["DESC"] = "Angel is a bot that can do link previews and embeds."
        vcard["NICKNAME"] = "Angel"
        vcard["FN"] = "Angel"

        asyncio.gather(self.plugin["xep_0054"].publish_vcard(vcard))

    @staticmethod
    def _is_correction(msg):
        """Return True when ``msg`` carries a message-correction element.

        Looks for a ``replace`` child in the ``urn:xmpp:message-correct:0``
        namespace on the stanza's XML element. Searching the serialized
        stanza for the namespace string instead would also match ordinary
        messages whose text merely mentions that namespace.
        """
        return msg.xml.find("{urn:xmpp:message-correct:0}replace") is not None

    async def message(self, msg):
        """Process a message.

        Only ``chat`` and ``normal`` messages are handled; corrections of
        earlier messages are ignored.
        """
        if msg["type"] not in ("chat", "normal"):
            return
        if self._is_correction(msg):
            return

        self.process_commands(msg, msg["from"].bare, msg["type"])

    async def muc_message(self, msg):
        """Process a groupchat message.

        Corrections of earlier messages and messages sent under the bot's
        own nick are ignored.
        """
        if msg["type"] not in ("groupchat", "normal"):
            return
        if self._is_correction(msg):
            return
        if msg["mucnick"] == self.nick:
            return

        self.process_commands(msg, msg["from"].bare, msg["type"])

    def process_commands(self, msg, sender, mtype):
        """Run every registered command whose pattern matches the body.

        A command runs when its pattern matches at the start of the message
        body and its optional ``matcher`` accepts the context. A command
        flagged ``block`` ends processing right after it ran, in which case
        the body is not stored by this method. Otherwise the body is added
        to the sender's message history once all commands were tried.
        ``mtype`` is accepted but not used here.
        """
        for cmd in self.regex_cmds:
            if not cmd.pattern.match(msg["body"]):
                continue
            # A fresh context per command, as callbacks receive their own.
            ctx = CommandContext(self, msg)
            if cmd.matcher and not cmd.matcher(ctx):
                continue
            cmd.func(ctx)
            if cmd.block:
                return
        self.messages[sender]["messages"].add(msg["body"])
|
|
|
|
class CommandContext:
    """Command context.

    Pairs the bot with the message being processed so that a command
    callback can reply and read or update the sender's histories without
    passing the message around itself.
    """

    def __init__(self, bot: "AngelBot", msg: "Message"):
        """Initialize the command context with the bot and the message."""
        self.bot = bot
        self.msg = msg

    def reply(self, body):
        """Reply with ``body`` via the bot, which also records the message."""
        return self.bot.reply(self.msg, body)

    def raw_reply(self, body):
        """Reply with ``body`` via the bot without saving history.

        This is a regular method like ``reply``: as a property it could
        never be used, because the getter requires the ``body`` argument.
        """
        return self.bot.raw_reply(self.msg, body)

    @property
    def sender(self):
        """Get the bare JID of the sender of the message."""
        return self.msg["from"].bare

    @property
    def mtype(self):
        """Get the message type."""
        return self.msg["type"]

    @property
    def body(self):
        """Get the message body."""
        return self.msg["body"]

    @property
    def message_history(self):
        """Get the message history of the sender."""
        return self.bot.get_message_history(self.msg)

    @property
    def link_history(self):
        """Get the link history of the sender."""
        return self.bot.get_link_history(self.msg)

    @property
    def preview_history(self):
        """Get the preview history of the sender."""
        return self.bot.get_preview_history(self.msg)

    def save_link_history(self, url):
        """Save ``url`` in the sender's link history."""
        self.bot.save_link_history(self.msg, url)

    def save_message_history(self):
        """Save the message in the sender's message history."""
        self.bot.save_message_history(self.msg)

    def save_preview_history(self, preview):
        """Save ``preview`` in the sender's preview history."""
        self.bot.save_preview_history(self.msg, preview)

    @property
    def is_oob(self):
        """Check if the message carries a non-empty OOB url."""
        return bool(self.msg["oob"]["url"])

    def embed_file(self, ftype, fname, outfile):
        """Embed a file and send the result to the sender.

        The bot's ``embed_file`` coroutine is scheduled with
        ``asyncio.gather`` and not awaited, so this returns immediately.
        """
        asyncio.gather(
            self.bot.embed_file(
                self.sender, self.mtype, ftype, fname, outfile
            )
        )
|