Update src/bot.py
This commit is contained in:
parent
2a8da606c2
commit
04816f0664
1 changed files with 42 additions and 37 deletions
69
src/bot.py
69
src/bot.py
|
@ -1,13 +1,12 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import socket
|
|
||||||
import ssl
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from src.channel_manager import ChannelManager
|
from src.channel_manager import ChannelManager
|
||||||
|
@ -15,7 +14,6 @@ from src.logger import Logger
|
||||||
from src.plugin_base import PluginBase
|
from src.plugin_base import PluginBase
|
||||||
from src.sasl import handle_sasl, handle_authenticate, handle_903
|
from src.sasl import handle_sasl, handle_authenticate, handle_903
|
||||||
|
|
||||||
|
|
||||||
class Bot:
|
class Bot:
|
||||||
def __init__(self, config_file):
|
def __init__(self, config_file):
|
||||||
self.config = self.load_config(config_file)
|
self.config = self.load_config(config_file)
|
||||||
|
@ -24,7 +22,8 @@ class Bot:
|
||||||
self.channel_manager = ChannelManager()
|
self.channel_manager = ChannelManager()
|
||||||
self.logger = Logger('logs/elitebot.log')
|
self.logger = Logger('logs/elitebot.log')
|
||||||
self.connected = False
|
self.connected = False
|
||||||
self.ircsock = None
|
self.reader = None
|
||||||
|
self.writer = None
|
||||||
self.running = True
|
self.running = True
|
||||||
self.plugins = []
|
self.plugins = []
|
||||||
self.load_plugins()
|
self.load_plugins()
|
||||||
|
@ -32,7 +31,7 @@ class Bot:
|
||||||
def validate_config(self, config):
|
def validate_config(self, config):
|
||||||
required_fields = [
|
required_fields = [
|
||||||
['Connection', 'Port'],
|
['Connection', 'Port'],
|
||||||
['Connection' 'Hostname'],
|
['Connection', 'Hostname'],
|
||||||
['Connection', 'Nick'],
|
['Connection', 'Nick'],
|
||||||
['Connection', 'Ident'],
|
['Connection', 'Ident'],
|
||||||
['Connection', 'Name'],
|
['Connection', 'Name'],
|
||||||
|
@ -92,11 +91,12 @@ class Bot:
|
||||||
self.logger.error('Could not decode byte string with any known encoding')
|
self.logger.error('Could not decode byte string with any known encoding')
|
||||||
return bytes.decode('utf-8', 'ignore')
|
return bytes.decode('utf-8', 'ignore')
|
||||||
|
|
||||||
def ircsend(self, msg):
|
async def ircsend(self, msg):
|
||||||
try:
|
try:
|
||||||
if msg != '':
|
if msg != '':
|
||||||
self.logger.info(f'Sending command: {msg}')
|
self.logger.info(f'Sending command: {msg}')
|
||||||
self.ircsock.send(bytes(f'{msg}\r\n', 'UTF-8'))
|
self.writer.write(bytes(f'{msg}\r\n', 'UTF-8'))
|
||||||
|
await self.writer.drain()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f'Error sending IRC message: {e}')
|
self.logger.error(f'Error sending IRC message: {e}')
|
||||||
raise
|
raise
|
||||||
|
@ -120,55 +120,61 @@ class Bot:
|
||||||
args.append(' '.join(parts[trailing_arg_start:])[1:])
|
args.append(' '.join(parts[trailing_arg_start:])[1:])
|
||||||
return source, command, args
|
return source, command, args
|
||||||
|
|
||||||
def connect(self):
|
async def connect(self):
|
||||||
try:
|
try:
|
||||||
self.ircsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
ssl_context = None
|
||||||
|
|
||||||
if str(self.config['Connection'].get('Port'))[:1] == '+':
|
if str(self.config['Connection'].get('Port'))[:1] == '+':
|
||||||
context = ssl.create_default_context()
|
ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) # Corrected here
|
||||||
self.ircsock = context.wrap_socket(self.ircsock,
|
|
||||||
server_hostname=self.config['Connection'].get('Hostname'))
|
|
||||||
port = int(self.config['Connection'].get('Port')[1:])
|
|
||||||
else:
|
|
||||||
port = int(self.config['Connection'].get('Port'))
|
|
||||||
|
|
||||||
if 'BindHost' in self.config:
|
self.reader, self.writer = await asyncio.open_connection(
|
||||||
self.ircsock.bind((self.config['Connection'].get('BindHost'), 0))
|
self.config['Connection'].get('Hostname'),
|
||||||
|
int(self.config['Connection'].get('Port')[1:]) if ssl_context else int(self.config['Connection'].get('Port')),
|
||||||
|
ssl=ssl_context
|
||||||
|
)
|
||||||
|
|
||||||
self.ircsock.connect_ex((self.config['Connection'].get('Hostname'), port))
|
await self.ircsend(f'NICK {self.config["Connection"].get("Nick")}')
|
||||||
self.ircsend(f'NICK {self.config["Connection"].get("Nick")}')
|
await self.ircsend(f'USER {self.config["Connection"].get("Ident")} * * :{self.config["Connection"].get("Name")}')
|
||||||
self.ircsend(f'USER {self.config["Connection"].get("Ident")} * * :{self.config["Connection"].get("Name")}')
|
|
||||||
if self.config['SASL'].get('UseSASL'):
|
if self.config['SASL'].get('UseSASL'):
|
||||||
self.ircsend('CAP REQ :sasl')
|
await self.ircsend('CAP REQ :sasl')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f'Error establishing connection: {e}')
|
self.logger.error(f'Error establishing connection: {e}')
|
||||||
self.connected = False
|
self.connected = False
|
||||||
return
|
return
|
||||||
|
|
||||||
def start(self):
|
async def start(self):
|
||||||
while True:
|
while True:
|
||||||
if not self.connected:
|
if not self.connected:
|
||||||
try:
|
try:
|
||||||
self.connect()
|
await self.connect()
|
||||||
self.connected = True
|
self.connected = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f'Connection error: {e}')
|
self.logger.error(f'Connection error: {e}')
|
||||||
time.sleep(60)
|
await asyncio.sleep(60)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
recvText = self.ircsock.recv(2048)
|
recvText = await self.reader.read(2048)
|
||||||
if not recvText:
|
if not recvText:
|
||||||
self.connected = False
|
self.connected = False
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ircmsg = self.decode(recvText)
|
ircmsg = self.decode(recvText)
|
||||||
source, command, args = self.parse_message(ircmsg)
|
|
||||||
|
if '\r\n' in ircmsg:
|
||||||
|
messages = ircmsg.split('\r\n')
|
||||||
|
elif '\n' in ircmsg:
|
||||||
|
messages = ircmsg.split('\n')
|
||||||
|
else:
|
||||||
|
messages = [ircmsg] # If no newline characters, treat the whole message as a single message
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
if message: # Check if message is not empty
|
||||||
|
source, command, args = self.parse_message(message)
|
||||||
self.logger.debug(f'Received: source: {source} | command: {command} | args: {args}')
|
self.logger.debug(f'Received: source: {source} | command: {command} | args: {args}')
|
||||||
|
|
||||||
if command == 'PING':
|
if command == 'PING':
|
||||||
nospoof = args[0][1:] if args[0].startswith(':') else args[0]
|
nospoof = args[0][1:] if args[0].startswith(':') else args[0]
|
||||||
self.ircsend(f'PONG :{nospoof}')
|
await self.ircsend(f'PONG :{nospoof}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if command == 'PRIVMSG':
|
if command == 'PRIVMSG':
|
||||||
|
@ -191,7 +197,7 @@ class Bot:
|
||||||
|
|
||||||
if command == 'PRIVMSG' and args[1].startswith('\x01VERSION\x01'):
|
if command == 'PRIVMSG' and args[1].startswith('\x01VERSION\x01'):
|
||||||
source_nick = source.split('!')[0]
|
source_nick = source.split('!')[0]
|
||||||
self.ircsend(f'NOTICE {source_nick} :\x01VERSION EliteBot 0.1\x01')
|
await self.ircsend(f'NOTICE {source_nick} :\x01VERSION EliteBot 0.1\x01')
|
||||||
|
|
||||||
if command == '001':
|
if command == '001':
|
||||||
for channel in self.channel_manager.get_channels():
|
for channel in self.channel_manager.get_channels():
|
||||||
|
@ -199,7 +205,7 @@ class Bot:
|
||||||
|
|
||||||
if command == 'INVITE':
|
if command == 'INVITE':
|
||||||
channel = args[1]
|
channel = args[1]
|
||||||
self.ircsend(f'JOIN {channel}')
|
await self.ircsend(f'JOIN {channel}')
|
||||||
self.channel_manager.save_channel(channel)
|
self.channel_manager.save_channel(channel)
|
||||||
|
|
||||||
if command == 'VERSION':
|
if command == 'VERSION':
|
||||||
|
@ -209,11 +215,10 @@ class Bot:
|
||||||
self.logger.error(f'General error: {e}')
|
self.logger.error(f'General error: {e}')
|
||||||
self.connected = False
|
self.connected = False
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
try:
|
try:
|
||||||
bot = Bot(sys.argv[1])
|
bot = Bot(sys.argv[1])
|
||||||
bot.start()
|
asyncio.run(bot.start())
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print('\nEliteBot has been stopped.')
|
print('\nEliteBot has been stopped.')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
Loading…
Add table
Reference in a new issue