import asyncio
import concurrent.futures
import logging
from datetime import datetime
from typing import List, Dict, Any, Optional

from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError

from config import MONGO_CONFIG

logger = logging.getLogger(__name__)
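# config.py is not shown in this module; the code only assumes that MONGO_CONFIG
# is a dict exposing the two keys read in DatabaseManager.__init__. A minimal
# sketch (the values are placeholders, not real deployment settings):
#
#     MONGO_CONFIG = {
#         "connection_string": "mongodb://localhost:27017",
#         "database_name": "sa4cps",
#     }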
class DatabaseManager:
    """Manages MongoDB storage for SA4CPS energy data.

    Energy records are sharded into one collection per level 3 directory
    (SLGs/Community/Building); file and directory-scan metadata live in
    fixed base collections.
    """

    def __init__(self):
        self.client: Optional[MongoClient] = None
        self.db = None
        self.collections = {}
        self.energy_collections_cache = {}  # Cache for dynamically created energy data collections
        self.connection_string = MONGO_CONFIG["connection_string"]
        self.database_name = MONGO_CONFIG["database_name"]

        logger.info(f"Database manager initialized for: {self.database_name}")

    async def connect(self):
        """Connect to MongoDB and set up base collections and indexes."""
        try:
            logger.info(f"Connecting to MongoDB at: {self.connection_string}")
            self.client = MongoClient(self.connection_string, serverSelectionTimeoutMS=5000)
            await self.ping()

            self.db = self.client[self.database_name]
            self.collections = {
                'files': self.db.sa4cps_files,
                'metadata': self.db.sa4cps_metadata,
                'scanned_directories': self.db.sa4cps_scanned_directories
            }
            self._create_base_indexes()
            logger.info(f"Connected to MongoDB database: {self.database_name}")
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            logger.error(f"Failed to connect to MongoDB: {e}")
            raise

    async def close(self):
        """Close MongoDB connection"""
        if self.client:
            self.client.close()
            logger.debug("MongoDB connection closed")

    async def ping(self):
        """Test database connection"""
        if not self.client:
            raise ConnectionFailure("No database connection")

        try:
            # Run the blocking ping command in a thread pool so the event loop is not blocked
            loop = asyncio.get_running_loop()
            with concurrent.futures.ThreadPoolExecutor() as pool:
                await asyncio.wait_for(
                    loop.run_in_executor(pool, self.client.admin.command, 'ping'),
                    timeout=3.0  # 3 second timeout for ping
                )
            logger.debug("MongoDB ping successful")
        except asyncio.TimeoutError:
            logger.error("MongoDB ping timeout after 3 seconds")
            raise ConnectionFailure("MongoDB ping timeout")
        except ConnectionFailure as e:
            logger.error(f"MongoDB ping failed - server not available: {e}")
            raise
        except Exception as e:
            logger.error(f"MongoDB ping failed with error: {e}")
            raise ConnectionFailure(f"Ping failed: {e}")

    def _create_base_indexes(self):
        """Create indexes for base collections (not energy data collections)"""
        try:
            self.collections['files'].create_index("filename", unique=True)
            self.collections['files'].create_index("processed_at")
            self.collections['files'].create_index("directory_path")

            self.collections['scanned_directories'].create_index("directory_path", unique=True)
            self.collections['scanned_directories'].create_index("last_scanned")
            self.collections['scanned_directories'].create_index("scan_status")

            logger.info("Database indexes created successfully")
        except Exception as e:
            logger.warning(f"Failed to create indexes: {e}")

    def _extract_level3_path(self, directory_path: str) -> Optional[str]:
        """Extract the level 3 directory path (SLGs/Community/Building) from a full path"""
        # Expected structure: /SLGs/Community/Building/...
        parts = directory_path.strip('/').split('/')
        if len(parts) >= 3 and parts[0] == 'SLGs':
            # Return SLGs/Community/Building
            return '/'.join(parts[:3])
        return None

    def _sanitize_collection_name(self, level3_path: str) -> str:
        """Convert a level 3 directory path to a valid MongoDB collection name.

        Example: SLGs/CommunityA/Building1 -> energy_data__CommunityA_Building1
        """
        parts = level3_path.strip('/').split('/')
        if len(parts) >= 3 and parts[0] == 'SLGs':
            # Use Community_Building as the collection suffix
            collection_suffix = f"{parts[1]}_{parts[2]}"
            collection_name = f"energy_data__{collection_suffix}"
            return collection_name

        # Fallback: sanitize the entire path
        sanitized = level3_path.replace('/', '_').replace('.', '_').replace(' ', '_')
        sanitized = sanitized.strip('_')
        return f"energy_data__{sanitized}"

    def _get_energy_collection(self, directory_path: str):
        """Get or create the energy data collection for a specific level 3 directory path"""
        level3_path = self._extract_level3_path(directory_path)

        if not level3_path:
            logger.warning(f"Could not extract level 3 path from: {directory_path}, using default collection")
            # Fallback to a default collection for non-standard paths
            collection_name = "energy_data__other"
        else:
            collection_name = self._sanitize_collection_name(level3_path)

        # Check cache first
        if collection_name in self.energy_collections_cache:
            return self.energy_collections_cache[collection_name]

        # Create/get collection
        collection = self.db[collection_name]

        # Create indexes for this energy collection
        try:
            collection.create_index([("filename", 1), ("timestamp", 1)])
            collection.create_index("timestamp")
            collection.create_index("meter_id")
            logger.debug(f"Created indexes for collection: {collection_name}")
        except Exception as e:
            logger.warning(f"Failed to create indexes for {collection_name}: {e}")

        # Cache the collection
        self.energy_collections_cache[collection_name] = collection
        logger.info(f"Initialized energy data collection: {collection_name} for path: {directory_path}")

        return collection

    def _list_energy_collections(self) -> List[str]:
        """List all energy data collections in the database"""
        try:
            all_collections = self.db.list_collection_names()
            # Filter collections that start with 'energy_data__'
            energy_collections = [c for c in all_collections if c.startswith('energy_data__')]
            return energy_collections
        except Exception as e:
            logger.error(f"Error listing energy collections: {e}")
            return []
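    # Illustrative record shape consumed by store_file_data() below. Only
    # "timestamp" and "meter_id" are implied by the per-collection indexes;
    # the remaining measurement fields depend on the source files and are
    # shown here purely as an assumed example:
    #
    #     {"timestamp": datetime(2024, 1, 1, 0, 15), "meter_id": "M-001", ...}
    #
    # store_file_data() adds "filename", "processed_at" and "directory_path"
    # to every record before insertion.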
    async def store_file_data(self, filename: str, records: List[Dict[str, Any]],
                              directory_path: Optional[str] = None) -> bool:
        """Store parsed records from a file in the appropriate energy collection and record file metadata."""
        current_time = datetime.now()
        try:
            # Determine which collection to use based on directory path
            if directory_path:
                energy_collection = self._get_energy_collection(directory_path)
                level3_path = self._extract_level3_path(directory_path)
            else:
                logger.warning(f"No directory path provided for {filename}, using default collection")
                energy_collection = self._get_energy_collection("/SLGs/unknown/unknown")
                level3_path = None

            # Store file metadata
            file_metadata = {
                "filename": filename,
                "directory_path": directory_path,
                "level3_path": level3_path,
                "record_count": len(records),
                "processed_at": current_time,
                "file_size": sum(len(str(record)) for record in records),
                "status": "processed"
            }

            # Insert or update file record
            self.collections['files'].replace_one(
                {"filename": filename},
                file_metadata,
                upsert=True
            )

            # Add filename and processed timestamp to each record
            for record in records:
                record["filename"] = filename
                record["processed_at"] = current_time
                record["directory_path"] = directory_path

            # Insert energy data records into the appropriate collection
            if records:
                result = energy_collection.insert_many(records)
                inserted_count = len(result.inserted_ids)
                logger.debug(f"Stored {inserted_count} records from {filename} to {energy_collection.name}")
                return True

            return False
        except Exception as e:
            logger.error(f"Error storing data for {filename}: {e}")
            # Store error metadata
            error_metadata = {
                "filename": filename,
                "directory_path": directory_path,
                "processed_at": current_time,
                "status": "error",
                "error_message": str(e)
            }
            self.collections['files'].replace_one(
                {"filename": filename},
                error_metadata,
                upsert=True
            )
            return False

    async def get_processed_files(self) -> List[str]:
        """Get list of successfully processed files"""
        try:
            cursor = self.collections['files'].find(
                {"status": "processed"},
                {"filename": 1, "_id": 0}
            )
            files = []
            for doc in cursor:
                files.append(doc["filename"])
            return files
        except Exception as e:
            logger.error(f"Error getting processed files: {e}")
            return []

    async def is_file_processed(self, filename: str) -> bool:
        """Check whether a file has already been processed successfully"""
        return filename in await self.get_processed_files()

    async def get_file_info(self, filename: str) -> Optional[Dict[str, Any]]:
        """Get information about a specific file"""
        try:
            return self.collections['files'].find_one({"filename": filename})
        except Exception as e:
            logger.error(f"Error getting file info for {filename}: {e}")
            return None

    # Directory scanning tracking methods
    # Note: Only level 4+ directories (/SLGs/Community/Building/SubDir) are tracked
    # to avoid unnecessary caching of high-level organizational directories

    async def is_directory_scanned(self, directory_path: str,
                                   since_timestamp: Optional[datetime] = None) -> bool:
        """Check if directory has been scanned recently.

        Note: Only level 4+ directories are tracked in the database.
        """
        try:
            query = {"directory_path": directory_path, "scan_status": "complete"}
            if since_timestamp:
                query["last_scanned"] = {"$gte": since_timestamp}

            result = self.collections['scanned_directories'].find_one(query)
            return result is not None
        except Exception as e:
            logger.error(f"Error checking directory scan status for {directory_path}: {e}")
            return False

    async def mark_directory_scanned(self, directory_path: str, file_count: int,
                                     ftp_last_modified: Optional[datetime] = None) -> bool:
        """Mark directory as scanned with current timestamp"""
        try:
            scan_record = {
                "directory_path": directory_path,
                "last_scanned": datetime.now(),
                "file_count": file_count,
                "scan_status": "complete"
            }
            if ftp_last_modified:
                scan_record["ftp_last_modified"] = ftp_last_modified

            # Use upsert to update existing or create new record
            self.collections['scanned_directories'].replace_one(
                {"directory_path": directory_path},
                scan_record,
                upsert=True
            )
            logger.debug(f"Marked directory as scanned: {directory_path} ({file_count} files)")
            return True
        except Exception as e:
            logger.error(f"Error marking directory as scanned {directory_path}: {e}")
            return False

    async def get_scanned_directories(self) -> List[Dict[str, Any]]:
        """Get all scanned directory records"""
        try:
            cursor = self.collections['scanned_directories'].find()
            return list(cursor)
        except Exception as e:
            logger.error(f"Error getting scanned directories: {e}")
            return []
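    # Typical use of the scan-tracking methods (an illustrative sketch, not a
    # contract): a caller such as an FTP scanner would consult
    # should_skip_directory() before listing a directory and call
    # mark_directory_scanned() once its files have been stored via
    # store_file_data(). See the example at the bottom of this module.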
    async def should_skip_directory(self, directory_path: str,
                                    ftp_last_modified: Optional[datetime] = None) -> bool:
        """Determine if a directory should be skipped based on scan history and modification time"""
        try:
            scan_record = self.collections['scanned_directories'].find_one(
                {"directory_path": directory_path, "scan_status": "complete"}
            )

            if not scan_record:
                return False  # Never scanned, should scan

            # If we have the FTP modification time, skip only when the directory
            # has not been modified since our last scan
            if ftp_last_modified and scan_record.get("last_scanned"):
                return ftp_last_modified <= scan_record["last_scanned"]

            # If directory was scanned successfully, skip it (assuming it's historical data)
            return True
        except Exception as e:
            logger.error(f"Error determining if directory should be skipped {directory_path}: {e}")
            return False

    async def get_stats(self) -> Dict[str, Any]:
        """Get database statistics including all energy collections"""
        try:
            stats = {
                "database": self.database_name,
                "timestamp": datetime.now().isoformat()
            }

            # Count documents in base collections
            for name, collection in self.collections.items():
                try:
                    count = collection.count_documents({})
                    stats[f"{name}_count"] = count
                except Exception as e:
                    stats[f"{name}_count"] = f"error: {e}"

            # Get all energy collections and their counts
            try:
                energy_collections = self._list_energy_collections()
                energy_stats = []
                total_energy_records = 0

                for collection_name in energy_collections:
                    collection = self.db[collection_name]
                    count = collection.count_documents({})
                    total_energy_records += count
                    energy_stats.append({
                        "collection": collection_name,
                        "record_count": count
                    })

                stats["energy_collections"] = energy_stats
                stats["total_energy_collections"] = len(energy_collections)
                stats["total_energy_records"] = total_energy_records
            except Exception as e:
                stats["energy_collections"] = f"error: {e}"

            # Get recent files
            try:
                recent_files = []
                cursor = self.collections['files'].find(
                    {},
                    {"filename": 1, "processed_at": 1, "record_count": 1, "status": 1,
                     "directory_path": 1, "level3_path": 1, "_id": 0}
                ).sort("processed_at", -1).limit(5)

                for doc in cursor:
                    if doc.get("processed_at"):
                        doc["processed_at"] = doc["processed_at"].isoformat()
                    recent_files.append(doc)

                stats["recent_files"] = recent_files
            except Exception as e:
                stats["recent_files"] = f"error: {e}"

            return stats
        except Exception as e:
            logger.error(f"Error getting database stats: {e}")
            return {"error": str(e), "timestamp": datetime.now().isoformat()}
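    # Shape of the dict returned by get_stats() above (the keys follow directly
    # from the code; the values below are made-up examples):
    #
    #     {
    #         "database": "...", "timestamp": "2024-01-01T00:00:00",
    #         "files_count": 10, "metadata_count": 0, "scanned_directories_count": 4,
    #         "energy_collections": [{"collection": "energy_data__CommunityA_Building1",
    #                                 "record_count": 96}],
    #         "total_energy_collections": 1, "total_energy_records": 96,
    #         "recent_files": [...],
    #     }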
    async def get_energy_data(self, filename: Optional[str] = None,
                              start_time: Optional[datetime] = None,
                              end_time: Optional[datetime] = None,
                              directory_path: Optional[str] = None,
                              limit: int = 100) -> List[Dict[str, Any]]:
        """Retrieve energy data with optional filtering.

        Args:
            filename: Filter by specific filename
            start_time: Filter by start timestamp
            end_time: Filter by end timestamp
            directory_path: Filter by specific directory path (level 3).
                If None, queries all collections
            limit: Maximum number of records to return
        """
        try:
            query = {}
            if filename:
                query["filename"] = filename

            if start_time or end_time:
                time_query = {}
                if start_time:
                    time_query["$gte"] = start_time
                if end_time:
                    time_query["$lte"] = end_time
                query["timestamp"] = time_query

            data = []

            # If directory_path is specified, query only that collection
            if directory_path:
                collection = self._get_energy_collection(directory_path)
                cursor = collection.find(query).sort("timestamp", -1).limit(limit)
                for doc in cursor:
                    data.append(self._format_energy_document(doc))
            else:
                # Query across all energy collections
                energy_collection_names = self._list_energy_collections()

                # Collect data from all collections, then sort and limit
                all_data = []
                per_collection_limit = max(limit, 1000)  # Get more from each to ensure we have enough after sorting

                for collection_name in energy_collection_names:
                    collection = self.db[collection_name]
                    cursor = collection.find(query).sort("timestamp", -1).limit(per_collection_limit)
                    for doc in cursor:
                        all_data.append(self._format_energy_document(doc))

                # Sort all data by timestamp and apply final limit
                all_data.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
                data = all_data[:limit]

            return data
        except Exception as e:
            logger.error(f"Error retrieving energy data: {e}")
            return []

    def _format_energy_document(self, doc: Dict[str, Any]) -> Dict[str, Any]:
        """Format energy document for API response"""
        # Convert ObjectId to string and datetime to ISO string
        if "_id" in doc:
            doc["_id"] = str(doc["_id"])
        if "timestamp" in doc and hasattr(doc["timestamp"], "isoformat"):
            doc["timestamp"] = doc["timestamp"].isoformat()
        if "processed_at" in doc and hasattr(doc["processed_at"], "isoformat"):
            doc["processed_at"] = doc["processed_at"].isoformat()
        return doc
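

# The sketch below shows how the manager is meant to be driven end to end. It is
# illustrative only: it assumes a reachable MongoDB instance, and the directory
# path, filename, meter_id and energy_kwh values are made-up sample data rather
# than content from a real SLG file.
async def _example_usage():
    db = DatabaseManager()
    await db.connect()
    try:
        directory = "/SLGs/CommunityA/Building1/2024-01"
        if not await db.should_skip_directory(directory):
            records = [
                {"timestamp": datetime(2024, 1, 1, 0, 15), "meter_id": "M-001", "energy_kwh": 1.2},
                {"timestamp": datetime(2024, 1, 1, 0, 30), "meter_id": "M-001", "energy_kwh": 1.4},
            ]
            await db.store_file_data("sample.slg", records, directory_path=directory)
            await db.mark_directory_scanned(directory, file_count=1)

        latest = await db.get_energy_data(directory_path=directory, limit=10)
        stats = await db.get_stats()
        logger.info(f"Fetched {len(latest)} records; stats: {stats}")
    finally:
        await db.close()


if __name__ == "__main__":
    asyncio.run(_example_usage())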