diff --git a/elitebot.py b/elitebot.py index 18fa6b8..7524b3d 100644 --- a/elitebot.py +++ b/elitebot.py @@ -1,12 +1,15 @@ #!/usr/bin/env python3 import asyncio +import os import sys from src.bot import Bot def main(): + os.makedirs('data', exist_ok=True) + if len(sys.argv) < 2: print('Usage: python elitebot.py ') sys.exit(1) diff --git a/src/bot.py b/src/bot.py index 2dd69f6..8810bb8 100644 --- a/src/bot.py +++ b/src/bot.py @@ -208,9 +208,8 @@ class Bot: case 'VERSION': await self.ircsend(f'NOTICE {source_nick} :I am a bot version 1.0.0') case '001': - await self.ircsend(f'JOIN #YuukiTest') - # for channel in self.channel_manager.get_channels(): - # await self.ircsend(f'JOIN {channel}') + for channel in self.channel_manager.get_channels(): + await self.ircsend(f'JOIN {channel[1]}') case '903': await handle_903(self.ircsend) case _: diff --git a/src/channel_manager.py b/src/channel_manager.py index b65a6c1..50938e7 100644 --- a/src/channel_manager.py +++ b/src/channel_manager.py @@ -1,49 +1,30 @@ #!/usr/bin/env python3 -import json -import os -from os import path +from src.db import Database +from sqlalchemy import Table, Column, Integer, String, Boolean, MetaData + +meta = MetaData() +channel_table = Table( + 'Channels', + meta, + Column('id', Integer, primary_key=True, autoincrement=True), + Column('channel', String, unique=True, nullable=False), + Column('autojoin', Boolean, default=True), +) +db = Database(channel_table, meta) class ChannelManager: def __init__(self): - self.channels = self._load_channels() + db.create_table(channel_table.name) - def _load_channels(self): - os.makedirs('data', exist_ok=True) - if not path.exists('data/channels.json'): - with open('data/channels.json', 'w') as f: - json.dump([], f) - return [] - try: - with open('data/channels.json', 'r') as f: - return json.load(f) - except json.JSONDecodeError as e: - print(f'Error decoding JSON: {e}') - return [] - except Exception as e: - print(f'Error loading channels: {e}') - return [] + self.channels = db._load_channels() def save_channel(self, channel): - channel = channel.lstrip(':') - if channel not in self.channels: - self.channels.append(channel) - self._write_channels() + db._save_channel(channel) def remove_channel(self, channel): - channel = channel.lstrip(':') - if channel in self.channels: - self.channels.remove(channel) - self._write_channels() - - def _write_channels(self): - os.makedirs('data', exist_ok=True) - try: - with open('data/channels.json', 'w') as f: - json.dump(self.channels, f) - except Exception as e: - print(f'Error saving channels: {e}') + db._remove_channel(channel) def get_channels(self): return self.channels diff --git a/src/db.py b/src/db.py index 5ad154e..96a4cc5 100644 --- a/src/db.py +++ b/src/db.py @@ -1,4 +1,4 @@ -from sqlalchemy import create_engine, Table, MetaData, inspect, update, select +from sqlalchemy import create_engine, Table, MetaData, inspect, update, select, insert, delete from sqlalchemy_utils import database_exists, create_database @@ -37,3 +37,31 @@ class Database: return conn.execute(select(self.table).where(self.table.c.name == user)).fetchone()[index] else: return -1 + + def _load_channels(self): + with self.engine.connect() as conn: + return conn.execute(select(self.table)).fetchall() + + def _save_channel(self, channel: str): + with self.engine.connect() as conn: + stmt = select(self.table).where(self.table.c.channel == channel) + cnt = len(conn.execute(stmt).fetchall()) + + if cnt == 0: + conn.execute(( + insert(self.table). + values({'channel': channel}) + )) + conn.commit() + + def _remove_channel(self, channel: str): + with self.engine.connect() as conn: + stmt = select(self.table).where(self.table.c.channel == channel) + cnt = len(conn.execute(stmt).fetchall()) + + if cnt == 1: + conn.execute(( + delete(self.table). + where(self.table.c.channel == channel) + )) + conn.commit()