diff --git a/src/bot.py b/src/bot.py index 635d45d..dc1ec86 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 +import asyncio import importlib.util import inspect import json import os -import socket import ssl import sys -import time import yaml from src.channel_manager import ChannelManager @@ -15,7 +14,6 @@ from src.logger import Logger from src.plugin_base import PluginBase from src.sasl import handle_sasl, handle_authenticate, handle_903 - class Bot: def __init__(self, config_file): self.config = self.load_config(config_file) @@ -24,15 +22,16 @@ class Bot: self.channel_manager = ChannelManager() self.logger = Logger('logs/elitebot.log') self.connected = False - self.ircsock = None + self.reader = None + self.writer = None self.running = True self.plugins = [] self.load_plugins() - + def validate_config(self, config): required_fields = [ ['Connection', 'Port'], - ['Connection' 'Hostname'], + ['Connection', 'Hostname'], ['Connection', 'Nick'], ['Connection', 'Ident'], ['Connection', 'Name'], @@ -92,11 +91,12 @@ class Bot: self.logger.error('Could not decode byte string with any known encoding') return bytes.decode('utf-8', 'ignore') - def ircsend(self, msg): + async def ircsend(self, msg): try: if msg != '': self.logger.info(f'Sending command: {msg}') - self.ircsock.send(bytes(f'{msg}\r\n', 'UTF-8')) + self.writer.write(bytes(f'{msg}\r\n', 'UTF-8')) + await self.writer.drain() except Exception as e: self.logger.error(f'Error sending IRC message: {e}') raise @@ -120,55 +120,61 @@ class Bot: args.append(' '.join(parts[trailing_arg_start:])[1:]) return source, command, args - def connect(self): + async def connect(self): try: - self.ircsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - + ssl_context = None if str(self.config['Connection'].get('Port'))[:1] == '+': - context = ssl.create_default_context() - self.ircsock = context.wrap_socket(self.ircsock, - server_hostname=self.config['Connection'].get('Hostname')) - port = int(self.config['Connection'].get('Port')[1:]) - else: - port = int(self.config['Connection'].get('Port')) + ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) # Corrected here + + self.reader, self.writer = await asyncio.open_connection( + self.config['Connection'].get('Hostname'), + int(self.config['Connection'].get('Port')[1:]) if ssl_context else int(self.config['Connection'].get('Port')), + ssl=ssl_context + ) - if 'BindHost' in self.config: - self.ircsock.bind((self.config['Connection'].get('BindHost'), 0)) - - self.ircsock.connect_ex((self.config['Connection'].get('Hostname'), port)) - self.ircsend(f'NICK {self.config["Connection"].get("Nick")}') - self.ircsend(f'USER {self.config["Connection"].get("Ident")} * * :{self.config["Connection"].get("Name")}') + await self.ircsend(f'NICK {self.config["Connection"].get("Nick")}') + await self.ircsend(f'USER {self.config["Connection"].get("Ident")} * * :{self.config["Connection"].get("Name")}') if self.config['SASL'].get('UseSASL'): - self.ircsend('CAP REQ :sasl') + await self.ircsend('CAP REQ :sasl') except Exception as e: self.logger.error(f'Error establishing connection: {e}') self.connected = False return - - def start(self): + + async def start(self): while True: if not self.connected: try: - self.connect() + await self.connect() self.connected = True except Exception as e: self.logger.error(f'Connection error: {e}') - time.sleep(60) + await asyncio.sleep(60) continue try: - recvText = self.ircsock.recv(2048) + recvText = await self.reader.read(2048) if not recvText: self.connected = False continue ircmsg = self.decode(recvText) - source, command, args = self.parse_message(ircmsg) - self.logger.debug(f'Received: source: {source} | command: {command} | args: {args}') + + if '\r\n' in ircmsg: + messages = ircmsg.split('\r\n') + elif '\n' in ircmsg: + messages = ircmsg.split('\n') + else: + messages = [ircmsg] # If no newline characters, treat the whole message as a single message + + for message in messages: + if message: # Check if message is not empty + source, command, args = self.parse_message(message) + self.logger.debug(f'Received: source: {source} | command: {command} | args: {args}') if command == 'PING': nospoof = args[0][1:] if args[0].startswith(':') else args[0] - self.ircsend(f'PONG :{nospoof}') + await self.ircsend(f'PONG :{nospoof}') continue if command == 'PRIVMSG': @@ -191,7 +197,7 @@ class Bot: if command == 'PRIVMSG' and args[1].startswith('\x01VERSION\x01'): source_nick = source.split('!')[0] - self.ircsend(f'NOTICE {source_nick} :\x01VERSION EliteBot 0.1\x01') + await self.ircsend(f'NOTICE {source_nick} :\x01VERSION EliteBot 0.1\x01') if command == '001': for channel in self.channel_manager.get_channels(): @@ -199,7 +205,7 @@ class Bot: if command == 'INVITE': channel = args[1] - self.ircsend(f'JOIN {channel}') + await self.ircsend(f'JOIN {channel}') self.channel_manager.save_channel(channel) if command == 'VERSION': @@ -209,12 +215,11 @@ class Bot: self.logger.error(f'General error: {e}') self.connected = False - if __name__ == '__main__': try: bot = Bot(sys.argv[1]) - bot.start() + asyncio.run(bot.start()) except KeyboardInterrupt: print('\nEliteBot has been stopped.') except Exception as e: - print(f'An unexpected error occurred: {e}') + print(f'An unexpected error occurred: {e}') \ No newline at end of file