Updated and fixed eltiebot.py, bot.py and sasl.py to use asyncio.

This commit is contained in:
Yuuki Chan 2024-02-22 18:41:50 +09:00
parent cb9a52d90a
commit a50a5f019e
3 changed files with 74 additions and 65 deletions

View file

@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import asyncio
import sys import sys
from src.bot import Bot from src.bot import Bot
@ -22,7 +23,7 @@ def main():
try: try:
print('EliteBot started successfully!') print('EliteBot started successfully!')
bot.start() asyncio.run(bot.start())
except Exception as e: except Exception as e:
print(f'Error starting EliteBot: {e}') print(f'Error starting EliteBot: {e}')
sys.exit(1) sys.exit(1)

View file

@ -1,13 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import asyncio
import importlib.util import importlib.util
import inspect import inspect
import json import json
import os import os
import socket
import ssl import ssl
import sys import sys
import time
import yaml import yaml
from src.channel_manager import ChannelManager from src.channel_manager import ChannelManager
@ -24,7 +24,8 @@ class Bot:
self.channel_manager = ChannelManager() self.channel_manager = ChannelManager()
self.logger = Logger('logs/elitebot.log') self.logger = Logger('logs/elitebot.log')
self.connected = False self.connected = False
self.ircsock = None self.reader = None
self.writer = None
self.running = True self.running = True
self.plugins = [] self.plugins = []
self.load_plugins() self.load_plugins()
@ -32,7 +33,7 @@ class Bot:
def validate_config(self, config): def validate_config(self, config):
required_fields = [ required_fields = [
['Connection', 'Port'], ['Connection', 'Port'],
['Connection' 'Hostname'], ['Connection', 'Hostname'],
['Connection', 'Nick'], ['Connection', 'Nick'],
['Connection', 'Ident'], ['Connection', 'Ident'],
['Connection', 'Name'], ['Connection', 'Name'],
@ -92,11 +93,12 @@ class Bot:
self.logger.error('Could not decode byte string with any known encoding') self.logger.error('Could not decode byte string with any known encoding')
return bytes.decode('utf-8', 'ignore') return bytes.decode('utf-8', 'ignore')
def ircsend(self, msg): async def ircsend(self, msg):
try: try:
if msg != '': if msg != '':
self.logger.info(f'Sending command: {msg}') self.logger.info(f'Sending command: {msg}')
self.ircsock.send(bytes(f'{msg}\r\n', 'UTF-8')) self.writer.write(bytes(f'{msg}\r\n', 'UTF-8'))
await self.writer.drain()
except Exception as e: except Exception as e:
self.logger.error(f'Error sending IRC message: {e}') self.logger.error(f'Error sending IRC message: {e}')
raise raise
@ -120,91 +122,97 @@ class Bot:
args.append(' '.join(parts[trailing_arg_start:])[1:]) args.append(' '.join(parts[trailing_arg_start:])[1:])
return source, command, args return source, command, args
def connect(self): async def connect(self):
try: try:
self.ircsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) ssl_context = None
if str(self.config['Connection'].get('Port'))[:1] == '+': if str(self.config['Connection'].get('Port'))[:1] == '+':
context = ssl.create_default_context() ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) # Corrected here
self.ircsock = context.wrap_socket(self.ircsock,
server_hostname=self.config['Connection'].get('Hostname'))
port = int(self.config['Connection'].get('Port')[1:])
else:
port = int(self.config['Connection'].get('Port'))
if 'BindHost' in self.config: self.reader, self.writer = await asyncio.open_connection(
self.ircsock.bind((self.config['Connection'].get('BindHost'), 0)) self.config['Connection'].get('Hostname'),
int(self.config['Connection'].get('Port')[1:]) if ssl_context else int(
self.config['Connection'].get('Port')),
ssl=ssl_context
)
self.ircsock.connect_ex((self.config['Connection'].get('Hostname'), port)) await self.ircsend(f'NICK {self.config["Connection"].get("Nick")}')
self.ircsend(f'NICK {self.config["Connection"].get("Nick")}') await self.ircsend(
self.ircsend(f'USER {self.config["Connection"].get("Ident")} * * :{self.config["Connection"].get("Name")}') f'USER {self.config["Connection"].get("Ident")} * * :{self.config["Connection"].get("Name")}')
if self.config['SASL'].get('UseSASL'): if self.config['SASL'].get('UseSASL'):
self.ircsend('CAP REQ :sasl') await self.ircsend('CAP REQ :sasl')
except Exception as e: except Exception as e:
self.logger.error(f'Error establishing connection: {e}') self.logger.error(f'Error establishing connection: {e}')
self.connected = False self.connected = False
return return
def start(self): async def start(self):
while True: while True:
if not self.connected: if not self.connected:
try: try:
self.connect() await self.connect()
self.connected = True self.connected = True
except Exception as e: except Exception as e:
self.logger.error(f'Connection error: {e}') self.logger.error(f'Connection error: {e}')
time.sleep(60) await asyncio.sleep(60)
continue continue
try: try:
recvText = self.ircsock.recv(2048) recvText = await self.reader.read(2048)
if not recvText: if not recvText:
self.connected = False self.connected = False
continue continue
ircmsg = self.decode(recvText) ircmsg = self.decode(recvText)
source, command, args = self.parse_message(ircmsg)
if '\r\n' in ircmsg:
messages = ircmsg.split('\r\n')
elif '\n' in ircmsg:
messages = ircmsg.split('\n')
else:
messages = [ircmsg] # If no newline characters, treat the whole message as a single message
for message in messages:
if message: # Check if message is not empty
source, command, args = self.parse_message(message)
self.logger.debug(f'Received: source: {source} | command: {command} | args: {args}') self.logger.debug(f'Received: source: {source} | command: {command} | args: {args}')
if command == 'PING': match command:
case 'PING':
nospoof = args[0][1:] if args[0].startswith(':') else args[0] nospoof = args[0][1:] if args[0].startswith(':') else args[0]
self.ircsend(f'PONG :{nospoof}') await self.ircsend(f'PONG :{nospoof}')
continue continue
case 'PRIVMSG':
if command == 'PRIVMSG':
channel, message = args[0], args[1] channel, message = args[0], args[1]
source_nick = source.split('!')[0] source_nick = source.split('!')[0]
if message.startswith('&'): if message.startswith('&'):
cmd, *cmd_args = message[1:].split() cmd, *cmd_args = message[1:].split()
self.handle_command(source_nick, channel, cmd, cmd_args) self.handle_command(source_nick, channel, cmd, cmd_args)
for plugin in self.plugins: elif args[1].startswith('\x01VERSION\x01'):
plugin.handle_message(source_nick, channel, message)
elif command == 'CAP' and args[1] == 'ACK' and 'sasl' in args[2]:
handle_sasl(self.config, self.ircsend)
elif command == 'AUTHENTICATE':
handle_authenticate(args, self.config, self.ircsend)
elif command == '903':
handle_903(self.ircsend)
if command == 'PRIVMSG' and args[1].startswith('\x01VERSION\x01'):
source_nick = source.split('!')[0] source_nick = source.split('!')[0]
self.ircsend(f'NOTICE {source_nick} :\x01VERSION EliteBot 0.1\x01') await self.ircsend(f'NOTICE {source_nick} :\x01VERSION EliteBot 0.1\x01')
if command == '001': for plugin in self.plugins:
for channel in self.channel_manager.get_channels(): await plugin.handle_message(source_nick, channel, message)
self.ircsend(f'JOIN {channel}') case 'CAP':
if args[1] == 'ACK' and 'sasl' in args[2]:
if command == 'INVITE': handle_sasl(self.config, self.ircsend)
case 'AUTHENTICATE':
handle_authenticate(args, self.config, self.ircsend)
case 'INVITE':
channel = args[1] channel = args[1]
self.ircsend(f'JOIN {channel}') await self.ircsend(f'JOIN {channel}')
self.channel_manager.save_channel(channel) self.channel_manager.save_channel(channel)
case 'VERSION':
if command == 'VERSION': await self.ircsend(f'NOTICE {source_nick} :I am a bot version 1.0.0')
self.ircsend('NOTICE', f'{source_nick} :I am a bot version 1.0.0') case '001':
await self.ircsend('JOIN #YuukiTest')
# for channel in self.channel_manager.get_channels():
# await self.ircsend(f'JOIN {channel}')
case '903':
await handle_903(self.ircsend)
case _:
continue
except Exception as e: except Exception as e:
self.logger.error(f'General error: {e}') self.logger.error(f'General error: {e}')
self.connected = False self.connected = False
@ -213,7 +221,7 @@ class Bot:
if __name__ == '__main__': if __name__ == '__main__':
try: try:
bot = Bot(sys.argv[1]) bot = Bot(sys.argv[1])
bot.start() asyncio.run(bot.start())
except KeyboardInterrupt: except KeyboardInterrupt:
print('\nEliteBot has been stopped.') print('\nEliteBot has been stopped.')
except Exception as e: except Exception as e:

View file

@ -36,7 +36,7 @@ def handle_authenticate(args, config, ircsend):
raise KeyError('SASLNICK and/or SASLPASS not found in config') raise KeyError('SASLNICK and/or SASLPASS not found in config')
def handle_903(ircsend): async def handle_903(ircsend):
""" """
Handles the 903 command by sending a CAP END command. Handles the 903 command by sending a CAP END command.