"""JSON database manager for Discord user data storage."""
|
||
|
|
|
||
|
|
import asyncio
import json
import logging
import shutil
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class UserData:
    """Snapshot of a Discord user's profile as stored in the JSON database.

    Timestamps (`created_at` / `updated_at`) are naive-UTC ISO-8601 strings;
    both are filled in automatically when not supplied by the caller.
    """
    user_id: int
    username: str
    discriminator: str
    display_name: Optional[str] = None
    avatar_url: Optional[str] = None
    banner_url: Optional[str] = None
    bio: Optional[str] = None
    status: Optional[str] = None
    activity: Optional[str] = None
    # default_factory avoids a shared mutable default; an explicit None is
    # still accepted (normalized in __post_init__) for backward compatibility.
    servers: List[int] = field(default_factory=list)
    created_at: Optional[str] = None
    updated_at: Optional[str] = None

    def __post_init__(self):
        """Normalize defaults: empty server list and ISO-8601 timestamps.

        A caller-supplied ``updated_at`` is preserved so that rebuilding a
        record from its stored dict (``UserData(**d)``) does not silently
        rewrite the timestamp; previously it was always clobbered.
        """
        if self.servers is None:
            self.servers = []

        # NOTE: utcnow() is deprecated in Python 3.12, but it is kept here so
        # stored timestamps retain their existing naive-UTC ISO format.
        current_time = datetime.utcnow().isoformat()
        if self.created_at is None:
            self.created_at = current_time
        if self.updated_at is None:
            self.updated_at = current_time
|
||
|
|
|
||
|
|
|
||
|
|
class JSONDatabase:
|
||
|
|
"""JSON-based database for storing Discord user data."""
|
||
|
|
|
||
|
|
def __init__(self, database_path: str):
|
||
|
|
"""Initialize the JSON database."""
|
||
|
|
self.database_path = Path(database_path)
|
||
|
|
self.backup_path = Path("data/backups")
|
||
|
|
self.logger = logging.getLogger(__name__)
|
||
|
|
self._lock = asyncio.Lock()
|
||
|
|
self._data: Dict[str, Dict] = {}
|
||
|
|
|
||
|
|
# Ensure database directory exists
|
||
|
|
self.database_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
|
self.backup_path.mkdir(parents=True, exist_ok=True)
|
||
|
|
|
||
|
|
# Load existing data
|
||
|
|
self._load_data()
|
||
|
|
|
||
|
|
def _load_data(self):
|
||
|
|
"""Load data from JSON file."""
|
||
|
|
if self.database_path.exists():
|
||
|
|
try:
|
||
|
|
with open(self.database_path, 'r', encoding='utf-8') as f:
|
||
|
|
self._data = json.load(f)
|
||
|
|
self.logger.info(f"Loaded {len(self._data)} users from database")
|
||
|
|
except Exception as e:
|
||
|
|
self.logger.error(f"Error loading database: {e}")
|
||
|
|
self._data = {}
|
||
|
|
else:
|
||
|
|
self._data = {}
|
||
|
|
self.logger.info("Created new database")
|
||
|
|
|
||
|
|
async def _save_data(self):
|
||
|
|
"""Save data to JSON file."""
|
||
|
|
async with self._lock:
|
||
|
|
try:
|
||
|
|
# Create backup before saving
|
||
|
|
if self.database_path.exists():
|
||
|
|
backup_filename = f"users_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||
|
|
backup_path = self.backup_path / backup_filename
|
||
|
|
shutil.copy2(self.database_path, backup_path)
|
||
|
|
|
||
|
|
# Save data
|
||
|
|
with open(self.database_path, 'w', encoding='utf-8') as f:
|
||
|
|
json.dump(self._data, f, indent=2, ensure_ascii=False)
|
||
|
|
|
||
|
|
self.logger.debug(f"Saved {len(self._data)} users to database")
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
self.logger.error(f"Error saving database: {e}")
|
||
|
|
|
||
|
|
async def get_user(self, user_id: int) -> Optional[UserData]:
|
||
|
|
"""Get user data by ID."""
|
||
|
|
user_key = str(user_id)
|
||
|
|
if user_key in self._data:
|
||
|
|
user_dict = self._data[user_key]
|
||
|
|
return UserData(**user_dict)
|
||
|
|
return None
|
||
|
|
|
||
|
|
async def save_user(self, user_data: UserData):
|
||
|
|
"""Save or update user data."""
|
||
|
|
user_key = str(user_data.user_id)
|
||
|
|
|
||
|
|
# If user exists, preserve created_at timestamp
|
||
|
|
if user_key in self._data:
|
||
|
|
user_data.created_at = self._data[user_key]['created_at']
|
||
|
|
|
||
|
|
# Update timestamp
|
||
|
|
user_data.updated_at = datetime.utcnow().isoformat()
|
||
|
|
|
||
|
|
# Save to memory
|
||
|
|
self._data[user_key] = asdict(user_data)
|
||
|
|
|
||
|
|
# Save to disk
|
||
|
|
await self._save_data()
|
||
|
|
|
||
|
|
self.logger.debug(f"Saved user {user_data.username}#{user_data.discriminator} ({user_data.user_id})")
|
||
|
|
|
||
|
|
async def add_server_to_user(self, user_id: int, server_id: int):
|
||
|
|
"""Add a server to user's server list."""
|
||
|
|
user_key = str(user_id)
|
||
|
|
if user_key in self._data:
|
||
|
|
if server_id not in self._data[user_key]['servers']:
|
||
|
|
self._data[user_key]['servers'].append(server_id)
|
||
|
|
self._data[user_key]['updated_at'] = datetime.utcnow().isoformat()
|
||
|
|
await self._save_data()
|
||
|
|
|
||
|
|
async def get_all_users(self) -> List[UserData]:
|
||
|
|
"""Get all users from the database."""
|
||
|
|
return [UserData(**user_dict) for user_dict in self._data.values()]
|
||
|
|
|
||
|
|
async def get_users_by_server(self, server_id: int) -> List[UserData]:
|
||
|
|
"""Get all users that are members of a specific server."""
|
||
|
|
users = []
|
||
|
|
for user_dict in self._data.values():
|
||
|
|
if server_id in user_dict.get('servers', []):
|
||
|
|
users.append(UserData(**user_dict))
|
||
|
|
return users
|
||
|
|
|
||
|
|
async def get_user_count(self) -> int:
|
||
|
|
"""Get total number of users in database."""
|
||
|
|
return len(self._data)
|
||
|
|
|
||
|
|
async def get_server_count(self) -> int:
|
||
|
|
"""Get total number of unique servers."""
|
||
|
|
servers = set()
|
||
|
|
for user_dict in self._data.values():
|
||
|
|
servers.update(user_dict.get('servers', []))
|
||
|
|
return len(servers)
|
||
|
|
|
||
|
|
async def cleanup_old_backups(self, max_backups: int = 10):
|
||
|
|
"""Clean up old backup files, keeping only the most recent ones."""
|
||
|
|
backup_files = sorted(self.backup_path.glob("users_backup_*.json"))
|
||
|
|
|
||
|
|
if len(backup_files) > max_backups:
|
||
|
|
files_to_remove = backup_files[:-max_backups]
|
||
|
|
for file_path in files_to_remove:
|
||
|
|
try:
|
||
|
|
file_path.unlink()
|
||
|
|
self.logger.info(f"Removed old backup: {file_path.name}")
|
||
|
|
except Exception as e:
|
||
|
|
self.logger.error(f"Error removing backup {file_path.name}: {e}")
|
||
|
|
|
||
|
|
async def export_to_csv(self, output_path: str):
|
||
|
|
"""Export user data to CSV format."""
|
||
|
|
import csv
|
||
|
|
|
||
|
|
output_path = Path(output_path)
|
||
|
|
|
||
|
|
try:
|
||
|
|
with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
|
||
|
|
fieldnames = ['user_id', 'username', 'discriminator', 'display_name',
|
||
|
|
'avatar_url', 'bio', 'status', 'servers', 'created_at', 'updated_at']
|
||
|
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||
|
|
|
||
|
|
writer.writeheader()
|
||
|
|
for user_dict in self._data.values():
|
||
|
|
# Convert servers list to string
|
||
|
|
user_dict_copy = user_dict.copy()
|
||
|
|
user_dict_copy['servers'] = ','.join(map(str, user_dict.get('servers', [])))
|
||
|
|
writer.writerow(user_dict_copy)
|
||
|
|
|
||
|
|
self.logger.info(f"Exported {len(self._data)} users to {output_path}")
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
self.logger.error(f"Error exporting to CSV: {e}")
|
||
|
|
|
||
|
|
async def get_statistics(self) -> Dict[str, Any]:
|
||
|
|
"""Get database statistics."""
|
||
|
|
stats = {
|
||
|
|
'total_users': await self.get_user_count(),
|
||
|
|
'total_servers': await self.get_server_count(),
|
||
|
|
'database_size': self.database_path.stat().st_size if self.database_path.exists() else 0
|
||
|
|
}
|
||
|
|
|
||
|
|
# Most active servers
|
||
|
|
server_counts = {}
|
||
|
|
for user_dict in self._data.values():
|
||
|
|
for server_id in user_dict.get('servers', []):
|
||
|
|
server_counts[server_id] = server_counts.get(server_id, 0) + 1
|
||
|
|
|
||
|
|
stats['most_active_servers'] = sorted(server_counts.items(),
|
||
|
|
key=lambda x: x[1], reverse=True)[:10]
|
||
|
|
|
||
|
|
return stats
|