Update src/bot.py

This commit is contained in:
Colby 2024-02-21 17:30:22 +01:00
parent 2a8da606c2
commit 04816f0664

View file

@ -1,13 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import asyncio
import importlib.util import importlib.util
import inspect import inspect
import json import json
import os import os
import socket
import ssl import ssl
import sys import sys
import time
import yaml import yaml
from src.channel_manager import ChannelManager from src.channel_manager import ChannelManager
@ -15,7 +14,6 @@ from src.logger import Logger
from src.plugin_base import PluginBase from src.plugin_base import PluginBase
from src.sasl import handle_sasl, handle_authenticate, handle_903 from src.sasl import handle_sasl, handle_authenticate, handle_903
class Bot: class Bot:
def __init__(self, config_file): def __init__(self, config_file):
self.config = self.load_config(config_file) self.config = self.load_config(config_file)
@ -24,7 +22,8 @@ class Bot:
self.channel_manager = ChannelManager() self.channel_manager = ChannelManager()
self.logger = Logger('logs/elitebot.log') self.logger = Logger('logs/elitebot.log')
self.connected = False self.connected = False
self.ircsock = None self.reader = None
self.writer = None
self.running = True self.running = True
self.plugins = [] self.plugins = []
self.load_plugins() self.load_plugins()
@ -32,7 +31,7 @@ class Bot:
def validate_config(self, config): def validate_config(self, config):
required_fields = [ required_fields = [
['Connection', 'Port'], ['Connection', 'Port'],
['Connection' 'Hostname'], ['Connection', 'Hostname'],
['Connection', 'Nick'], ['Connection', 'Nick'],
['Connection', 'Ident'], ['Connection', 'Ident'],
['Connection', 'Name'], ['Connection', 'Name'],
@ -92,11 +91,12 @@ class Bot:
self.logger.error('Could not decode byte string with any known encoding') self.logger.error('Could not decode byte string with any known encoding')
return bytes.decode('utf-8', 'ignore') return bytes.decode('utf-8', 'ignore')
def ircsend(self, msg): async def ircsend(self, msg):
try: try:
if msg != '': if msg != '':
self.logger.info(f'Sending command: {msg}') self.logger.info(f'Sending command: {msg}')
self.ircsock.send(bytes(f'{msg}\r\n', 'UTF-8')) self.writer.write(bytes(f'{msg}\r\n', 'UTF-8'))
await self.writer.drain()
except Exception as e: except Exception as e:
self.logger.error(f'Error sending IRC message: {e}') self.logger.error(f'Error sending IRC message: {e}')
raise raise
@ -120,55 +120,61 @@ class Bot:
args.append(' '.join(parts[trailing_arg_start:])[1:]) args.append(' '.join(parts[trailing_arg_start:])[1:])
return source, command, args return source, command, args
def connect(self): async def connect(self):
try: try:
self.ircsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) ssl_context = None
if str(self.config['Connection'].get('Port'))[:1] == '+': if str(self.config['Connection'].get('Port'))[:1] == '+':
context = ssl.create_default_context() ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) # Corrected here
self.ircsock = context.wrap_socket(self.ircsock,
server_hostname=self.config['Connection'].get('Hostname'))
port = int(self.config['Connection'].get('Port')[1:])
else:
port = int(self.config['Connection'].get('Port'))
if 'BindHost' in self.config: self.reader, self.writer = await asyncio.open_connection(
self.ircsock.bind((self.config['Connection'].get('BindHost'), 0)) self.config['Connection'].get('Hostname'),
int(self.config['Connection'].get('Port')[1:]) if ssl_context else int(self.config['Connection'].get('Port')),
ssl=ssl_context
)
self.ircsock.connect_ex((self.config['Connection'].get('Hostname'), port)) await self.ircsend(f'NICK {self.config["Connection"].get("Nick")}')
self.ircsend(f'NICK {self.config["Connection"].get("Nick")}') await self.ircsend(f'USER {self.config["Connection"].get("Ident")} * * :{self.config["Connection"].get("Name")}')
self.ircsend(f'USER {self.config["Connection"].get("Ident")} * * :{self.config["Connection"].get("Name")}')
if self.config['SASL'].get('UseSASL'): if self.config['SASL'].get('UseSASL'):
self.ircsend('CAP REQ :sasl') await self.ircsend('CAP REQ :sasl')
except Exception as e: except Exception as e:
self.logger.error(f'Error establishing connection: {e}') self.logger.error(f'Error establishing connection: {e}')
self.connected = False self.connected = False
return return
def start(self): async def start(self):
while True: while True:
if not self.connected: if not self.connected:
try: try:
self.connect() await self.connect()
self.connected = True self.connected = True
except Exception as e: except Exception as e:
self.logger.error(f'Connection error: {e}') self.logger.error(f'Connection error: {e}')
time.sleep(60) await asyncio.sleep(60)
continue continue
try: try:
recvText = self.ircsock.recv(2048) recvText = await self.reader.read(2048)
if not recvText: if not recvText:
self.connected = False self.connected = False
continue continue
ircmsg = self.decode(recvText) ircmsg = self.decode(recvText)
source, command, args = self.parse_message(ircmsg)
self.logger.debug(f'Received: source: {source} | command: {command} | args: {args}') if '\r\n' in ircmsg:
messages = ircmsg.split('\r\n')
elif '\n' in ircmsg:
messages = ircmsg.split('\n')
else:
messages = [ircmsg] # If no newline characters, treat the whole message as a single message
for message in messages:
if message: # Check if message is not empty
source, command, args = self.parse_message(message)
self.logger.debug(f'Received: source: {source} | command: {command} | args: {args}')
if command == 'PING': if command == 'PING':
nospoof = args[0][1:] if args[0].startswith(':') else args[0] nospoof = args[0][1:] if args[0].startswith(':') else args[0]
self.ircsend(f'PONG :{nospoof}') await self.ircsend(f'PONG :{nospoof}')
continue continue
if command == 'PRIVMSG': if command == 'PRIVMSG':
@ -191,7 +197,7 @@ class Bot:
if command == 'PRIVMSG' and args[1].startswith('\x01VERSION\x01'): if command == 'PRIVMSG' and args[1].startswith('\x01VERSION\x01'):
source_nick = source.split('!')[0] source_nick = source.split('!')[0]
self.ircsend(f'NOTICE {source_nick} :\x01VERSION EliteBot 0.1\x01') await self.ircsend(f'NOTICE {source_nick} :\x01VERSION EliteBot 0.1\x01')
if command == '001': if command == '001':
for channel in self.channel_manager.get_channels(): for channel in self.channel_manager.get_channels():
@ -199,7 +205,7 @@ class Bot:
if command == 'INVITE': if command == 'INVITE':
channel = args[1] channel = args[1]
self.ircsend(f'JOIN {channel}') await self.ircsend(f'JOIN {channel}')
self.channel_manager.save_channel(channel) self.channel_manager.save_channel(channel)
if command == 'VERSION': if command == 'VERSION':
@ -209,11 +215,10 @@ class Bot:
self.logger.error(f'General error: {e}') self.logger.error(f'General error: {e}')
self.connected = False self.connected = False
if __name__ == '__main__': if __name__ == '__main__':
try: try:
bot = Bot(sys.argv[1]) bot = Bot(sys.argv[1])
bot.start() asyncio.run(bot.start())
except KeyboardInterrupt: except KeyboardInterrupt:
print('\nEliteBot has been stopped.') print('\nEliteBot has been stopped.')
except Exception as e: except Exception as e: