Refactor service registry and load balancer integration

- Pass service registry to load balancer for dependency injection -
Remove dynamic imports of service registry in load balancer - Update
service registration and health check logic - Enable token-service in
docker-compose and service config - Add room names and rooms proxy
endpoints - Improve logging for proxy requests and health checks -
Update deploy script project name to sa4cps - Add test script for
coroutine fix - Minor code cleanup and formatting
This commit is contained in:
rafaeldpsilva
2025-09-22 15:13:06 +01:00
parent 41b8753a92
commit 2008ea0e70
7 changed files with 148 additions and 136 deletions

View File

@@ -11,10 +11,11 @@ logger = logging.getLogger(__name__)
class AuthMiddleware: class AuthMiddleware:
"""Authentication middleware for validating tokens""" """Authentication middleware for validating tokens"""
def __init__(self, token_service_url: str = "http://localhost:8001"): def __init__(self, token_service_url: str = "http://localhost:8001"):
self.token_service_url = token_service_url self.token_service_url = token_service_url
logger.info(f"Initialized AuthMiddleware with token service URL: {self.token_service_url}")
async def verify_token(self, request: Request) -> Optional[Dict[str, Any]]: async def verify_token(self, request: Request) -> Optional[Dict[str, Any]]:
""" """
Verify authentication token from request headers Verify authentication token from request headers
@@ -24,12 +25,12 @@ class AuthMiddleware:
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
if not auth_header: if not auth_header:
raise HTTPException(status_code=401, detail="Authorization header required") raise HTTPException(status_code=401, detail="Authorization header required")
if not auth_header.startswith("Bearer "): if not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Bearer token required") raise HTTPException(status_code=401, detail="Bearer token required")
token = auth_header[7:] # Remove "Bearer " prefix token = auth_header[7:] # Remove "Bearer " prefix
try: try:
# Validate token with token service # Validate token with token service
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@@ -38,19 +39,19 @@ class AuthMiddleware:
json={"token": token}, json={"token": token},
timeout=aiohttp.ClientTimeout(total=5) timeout=aiohttp.ClientTimeout(total=5)
) as response: ) as response:
if response.status != 200: if response.status != 200:
raise HTTPException(status_code=401, detail="Token validation failed") raise HTTPException(status_code=401, detail="Token validation failed")
token_data = await response.json() token_data = await response.json()
if not token_data.get("valid"): if not token_data.get("valid"):
error_msg = token_data.get("error", "Invalid token") error_msg = token_data.get("error", "Invalid token")
raise HTTPException(status_code=401, detail=error_msg) raise HTTPException(status_code=401, detail=error_msg)
# Token is valid, return decoded payload # Token is valid, return decoded payload
return token_data.get("decoded") return token_data.get("decoded")
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
logger.error(f"Token service connection error: {e}") logger.error(f"Token service connection error: {e}")
raise HTTPException(status_code=503, detail="Authentication service unavailable") raise HTTPException(status_code=503, detail="Authentication service unavailable")
@@ -59,24 +60,24 @@ class AuthMiddleware:
except Exception as e: except Exception as e:
logger.error(f"Token verification error: {e}") logger.error(f"Token verification error: {e}")
raise HTTPException(status_code=500, detail="Authentication error") raise HTTPException(status_code=500, detail="Authentication error")
async def check_permissions(self, token_payload: Dict[str, Any], required_resources: list) -> bool: async def check_permissions(self, token_payload: Dict[str, Any], required_resources: list) -> bool:
""" """
Check if token has required permissions for specific resources Check if token has required permissions for specific resources
""" """
if not token_payload: if not token_payload:
return False return False
# Get list of resources the token has access to # Get list of resources the token has access to
token_resources = token_payload.get("list_of_resources", []) token_resources = token_payload.get("list_of_resources", [])
# Check if token has access to all required resources # Check if token has access to all required resources
for resource in required_resources: for resource in required_resources:
if resource not in token_resources: if resource not in token_resources:
return False return False
return True return True
def extract_user_info(self, token_payload: Dict[str, Any]) -> Dict[str, Any]: def extract_user_info(self, token_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Extract user information from token payload""" """Extract user information from token payload"""
return { return {
@@ -86,4 +87,4 @@ class AuthMiddleware:
"time_aggregation": token_payload.get("time_aggregation", False), "time_aggregation": token_payload.get("time_aggregation", False),
"embargo": token_payload.get("embargo", 0), "embargo": token_payload.get("embargo", 0),
"expires_at": token_payload.get("exp") "expires_at": token_payload.get("exp")
} }

View File

@@ -10,11 +10,12 @@ logger = logging.getLogger(__name__)
class LoadBalancer: class LoadBalancer:
"""Simple load balancer for microservice requests""" """Simple load balancer for microservice requests"""
def __init__(self): def __init__(self, service_registry=None):
# In a real implementation, this would track multiple instances per service # In a real implementation, this would track multiple instances per service
self.service_instances: Dict[str, List[str]] = {} self.service_instances: Dict[str, List[str]] = {}
self.current_index: Dict[str, int] = {} self.current_index: Dict[str, int] = {}
self.service_registry = service_registry
def register_service_instance(self, service_name: str, instance_url: str): def register_service_instance(self, service_name: str, instance_url: str):
"""Register a new service instance""" """Register a new service instance"""
@@ -54,9 +55,11 @@ class LoadBalancer:
if strategy == "single": if strategy == "single":
# Default behavior - get the service URL from service registry # Default behavior - get the service URL from service registry
from service_registry import ServiceRegistry if self.service_registry:
service_registry = ServiceRegistry() return await self.service_registry.get_service_url(service_name)
return await service_registry.get_service_url(service_name) else:
logger.error("No service registry available")
return None
elif strategy == "round_robin": elif strategy == "round_robin":
return await self._round_robin_select(service_name) return await self._round_robin_select(service_name)
@@ -73,9 +76,11 @@ class LoadBalancer:
instances = self.service_instances.get(service_name, []) instances = self.service_instances.get(service_name, [])
if not instances: if not instances:
# Fall back to service registry # Fall back to service registry
from service_registry import ServiceRegistry if self.service_registry:
service_registry = ServiceRegistry() return await self.service_registry.get_service_url(service_name)
return await service_registry.get_service_url(service_name) else:
logger.error("No service registry available for fallback")
return None
# Round-robin selection # Round-robin selection
current_idx = self.current_index[service_name] current_idx = self.current_index[service_name]
@@ -92,9 +97,11 @@ class LoadBalancer:
instances = self.service_instances.get(service_name, []) instances = self.service_instances.get(service_name, [])
if not instances: if not instances:
# Fall back to service registry # Fall back to service registry
from service_registry import ServiceRegistry if self.service_registry:
service_registry = ServiceRegistry() return await self.service_registry.get_service_url(service_name)
return await service_registry.get_service_url(service_name) else:
logger.error("No service registry available for fallback")
return None
selected_instance = random.choice(instances) selected_instance = random.choice(instances)
logger.debug(f"Random selected {selected_instance} for {service_name}") logger.debug(f"Random selected {selected_instance} for {service_name}")

View File

@@ -64,47 +64,17 @@ app.add_middleware(
# Service registry and load balancer # Service registry and load balancer
service_registry = ServiceRegistry() service_registry = ServiceRegistry()
load_balancer = LoadBalancer() load_balancer = LoadBalancer(service_registry)
auth_middleware = AuthMiddleware() auth_middleware = AuthMiddleware()
# Service configuration # Service configuration
SERVICES = { SERVICES = {
# "token-service": ServiceConfig( "token-service": ServiceConfig(
# name="token-service", name="token-service",
# base_url=os.getenv("TOKEN_SERVICE_URL", "http://token-service:8001"), base_url=os.getenv("TOKEN_SERVICE_URL", "http://token-service:8001"),
# health_endpoint="/health", health_endpoint="/health",
# auth_required=False auth_required=False
# ), ),
# "battery-service": ServiceConfig(
# name="battery-service",
# base_url=os.getenv("BATTERY_SERVICE_URL", "http://battery-service:8002"),
# health_endpoint="/health",
# auth_required=True
# ),
# "demand-response-service": ServiceConfig(
# name="demand-response-service",
# base_url=os.getenv("DEMAND_RESPONSE_SERVICE_URL", "http://demand-response-service:8003"),
# health_endpoint="/health",
# auth_required=True
# ),
# "p2p-trading-service": ServiceConfig(
# name="p2p-trading-service",
# base_url=os.getenv("P2P_TRADING_SERVICE_URL", "http://p2p-trading-service:8004"),
# health_endpoint="/health",
# auth_required=True
# ),
# "forecasting-service": ServiceConfig(
# name="forecasting-service",
# base_url=os.getenv("FORECASTING_SERVICE_URL", "http://forecasting-service:8005"),
# health_endpoint="/health",
# auth_required=True
# ),
# "iot-control-service": ServiceConfig(
# name="iot-control-service",
# base_url=os.getenv("IOT_CONTROL_SERVICE_URL", "http://iot-control-service:8006"),
# health_endpoint="/health",
# auth_required=True
# ),
"sensor-service": ServiceConfig( "sensor-service": ServiceConfig(
name="sensor-service", name="sensor-service",
base_url=os.getenv("SENSOR_SERVICE_URL", "http://sensor-service:8007"), base_url=os.getenv("SENSOR_SERVICE_URL", "http://sensor-service:8007"),
@@ -187,7 +157,7 @@ async def get_gateway_stats():
@app.api_route("/api/v1/tokens/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/api/v1/tokens/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def token_service_proxy(request: Request, path: str): async def token_service_proxy(request: Request, path: str):
"""Proxy requests to token service""" """Proxy requests to token service"""
return await proxy_request(request, "token-service", f"/{path}") return await proxy_request(request, "token-service", f"/tokens/{path}")
# Battery Service Routes # Battery Service Routes
@app.api_route("/api/v1/batteries/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/api/v1/batteries/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
@@ -241,6 +211,16 @@ async def data_sources_list_proxy(request: Request):
"""Proxy requests to data ingestion service for sources list""" """Proxy requests to data ingestion service for sources list"""
return await proxy_request(request, "data-ingestion-service", "/sources") return await proxy_request(request, "data-ingestion-service", "/sources")
@app.get("/api/v1/rooms/names")
async def room_names_proxy(request: Request):
"""Proxy requests to sensor service for room names list"""
return await proxy_request(request, "sensor-service", "/rooms/names")
@app.get("/api/v1/rooms")
async def rooms_list_proxy(request: Request):
"""Proxy requests to sensor service for rooms list"""
return await proxy_request(request, "sensor-service", "/rooms")
@app.api_route("/api/v1/rooms/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/api/v1/rooms/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def room_service_proxy(request: Request, path: str): async def room_service_proxy(request: Request, path: str):
"""Proxy requests to sensor service for room management""" """Proxy requests to sensor service for room management"""
@@ -302,6 +282,7 @@ async def websocket_proxy(websocket: WebSocket):
async def proxy_request(request: Request, service_name: str, path: str): async def proxy_request(request: Request, service_name: str, path: str):
"""Generic request proxy function""" """Generic request proxy function"""
try: try:
logger.info(f"Proxying request to {service_name} at {path}")
# Update request statistics # Update request statistics
request_stats["total_requests"] += 1 request_stats["total_requests"] += 1
request_stats["service_requests"][service_name] += 1 request_stats["service_requests"][service_name] += 1

View File

@@ -1,7 +1,3 @@
"""
Service registry for managing microservice discovery and health monitoring
"""
import aiohttp import aiohttp
import asyncio import asyncio
from datetime import datetime from datetime import datetime
@@ -13,30 +9,26 @@ from models import ServiceConfig, ServiceHealth
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ServiceRegistry: class ServiceRegistry:
"""Service registry for microservice management"""
def __init__(self): def __init__(self):
self.services: Dict[str, ServiceConfig] = {} self.services: Dict[str, ServiceConfig] = {}
self.service_health: Dict[str, ServiceHealth] = {} self.service_health: Dict[str, ServiceHealth] = {}
self.session: Optional[aiohttp.ClientSession] = None self.session: Optional[aiohttp.ClientSession] = None
async def initialize(self): async def initialize(self):
"""Initialize the service registry"""
self.session = aiohttp.ClientSession( self.session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=10) timeout=aiohttp.ClientTimeout(total=10)
) )
logger.info("Service registry initialized") logger.info("Service registry initialized")
async def close(self): async def close(self):
"""Close the service registry"""
if self.session: if self.session:
await self.session.close() await self.session.close()
logger.info("Service registry closed") logger.info("Service registry closed")
async def register_services(self, services: Dict[str, ServiceConfig]): async def register_services(self, services: Dict[str, ServiceConfig]):
"""Register multiple services"""
self.services.update(services) self.services.update(services)
# Initialize health status for all services # Initialize health status for all services
for service_name, config in services.items(): for service_name, config in services.items():
self.service_health[service_name] = ServiceHealth( self.service_health[service_name] = ServiceHealth(
@@ -44,34 +36,31 @@ class ServiceRegistry:
status="unknown", status="unknown",
last_check=datetime.utcnow() last_check=datetime.utcnow()
) )
logger.info(f"Registered {len(services)} services") logger.info(f"Registered {len(services)} services")
# Perform initial health check # Perform initial health check
await self.update_all_service_health() await self.update_all_service_health()
async def register_service(self, service_config: ServiceConfig): async def register_service(self, service_config: ServiceConfig):
"""Register a single service"""
self.services[service_config.name] = service_config self.services[service_config.name] = service_config
self.service_health[service_config.name] = ServiceHealth( self.service_health[service_config.name] = ServiceHealth(
service=service_config.name, service=service_config.name,
status="unknown", status="unknown",
last_check=datetime.utcnow() last_check=datetime.utcnow()
) )
logger.info(f"Registered service: {service_config.name}") logger.info(f"Registered service: {service_config.name}")
# Check health of the newly registered service # Check health of the newly registered service
await self.check_service_health(service_config.name) await self.check_service_health(service_config.name)
async def unregister_service(self, service_name: str): async def unregister_service(self, service_name: str):
"""Unregister a service"""
self.services.pop(service_name, None) self.services.pop(service_name, None)
self.service_health.pop(service_name, None) self.service_health.pop(service_name, None)
logger.info(f"Unregistered service: {service_name}") logger.info(f"Unregistered service: {service_name}")
async def check_service_health(self, service_name: str) -> ServiceHealth: async def check_service_health(self, service_name: str) -> ServiceHealth:
"""Check health of a specific service"""
service_config = self.services.get(service_name) service_config = self.services.get(service_name)
if not service_config: if not service_config:
logger.error(f"Service {service_name} not found in registry") logger.error(f"Service {service_name} not found in registry")
@@ -81,20 +70,20 @@ class ServiceRegistry:
last_check=datetime.utcnow(), last_check=datetime.utcnow(),
error_message="Service not registered" error_message="Service not registered"
) )
start_time = datetime.utcnow() start_time = datetime.utcnow()
try: try:
health_url = f"{service_config.base_url}{service_config.health_endpoint}" health_url = f"{service_config.base_url}{service_config.health_endpoint}"
async with self.session.get(health_url) as response: async with self.session.get(health_url) as response:
end_time = datetime.utcnow() end_time = datetime.utcnow()
response_time = (end_time - start_time).total_seconds() * 1000 response_time = (end_time - start_time).total_seconds() * 1000
if response.status == 200: if response.status == 200:
health_data = await response.json() health_data = await response.json()
status = "healthy" if health_data.get("status") in ["healthy", "ok"] else "unhealthy" status = "healthy" if health_data.get("status") in ["healthy", "ok"] else "unhealthy"
health = ServiceHealth( health = ServiceHealth(
service=service_name, service=service_name,
status=status, status=status,
@@ -109,7 +98,7 @@ class ServiceRegistry:
last_check=end_time, last_check=end_time,
error_message=f"HTTP {response.status}" error_message=f"HTTP {response.status}"
) )
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
health = ServiceHealth( health = ServiceHealth(
service=service_name, service=service_name,
@@ -124,37 +113,33 @@ class ServiceRegistry:
last_check=datetime.utcnow(), last_check=datetime.utcnow(),
error_message=f"Health check failed: {str(e)}" error_message=f"Health check failed: {str(e)}"
) )
# Update health status # Update health status
self.service_health[service_name] = health self.service_health[service_name] = health
# Log health status changes # Log health status changes
if health.status != "healthy": if health.status != "healthy":
logger.warning(f"Service {service_name} health check failed: {health.error_message}") logger.warning(f"Service {service_name} health check failed: {health.error_message}")
return health return health
async def update_all_service_health(self): async def update_all_service_health(self):
"""Update health status for all registered services"""
health_checks = [ health_checks = [
self.check_service_health(service_name) self.check_service_health(service_name)
for service_name in self.services.keys() for service_name in self.services.keys()
] ]
if health_checks: if health_checks:
await asyncio.gather(*health_checks, return_exceptions=True) await asyncio.gather(*health_checks, return_exceptions=True)
# Log summary
healthy_count = sum(1 for h in self.service_health.values() if h.status == "healthy") healthy_count = sum(1 for h in self.service_health.values() if h.status == "healthy")
total_count = len(self.services) total_count = len(self.services)
logger.info(f"Health check complete: {healthy_count}/{total_count} services healthy") logger.info(f"Health check complete: {healthy_count}/{total_count} services healthy {self.service_health.values()}")
async def get_service_health(self, service_name: str) -> Optional[ServiceHealth]: async def get_service_health(self, service_name: str) -> Optional[ServiceHealth]:
"""Get health status of a specific service"""
return self.service_health.get(service_name) return self.service_health.get(service_name)
async def get_all_service_health(self) -> Dict[str, Dict]: async def get_all_service_health(self) -> Dict[str, Dict]:
"""Get health status of all services"""
health_dict = {} health_dict = {}
for service_name, health in self.service_health.items(): for service_name, health in self.service_health.items():
health_dict[service_name] = { health_dict[service_name] = {
@@ -164,31 +149,26 @@ class ServiceRegistry:
"error_message": health.error_message "error_message": health.error_message
} }
return health_dict return health_dict
async def is_service_healthy(self, service_name: str) -> bool: async def is_service_healthy(self, service_name: str) -> bool:
"""Check if a service is healthy"""
health = self.service_health.get(service_name) health = self.service_health.get(service_name)
return health is not None and health.status == "healthy" return health is not None and health.status == "healthy"
async def get_healthy_services(self) -> List[str]: async def get_healthy_services(self) -> List[str]:
"""Get list of healthy service names"""
return [ return [
service_name service_name
for service_name, health in self.service_health.items() for service_name, health in self.service_health.items()
if health.status == "healthy" if health.status == "healthy"
] ]
def get_service_config(self, service_name: str) -> Optional[ServiceConfig]: def get_service_config(self, service_name: str) -> Optional[ServiceConfig]:
"""Get configuration for a specific service"""
return self.services.get(service_name) return self.services.get(service_name)
def get_all_services(self) -> Dict[str, ServiceConfig]: def get_all_services(self) -> Dict[str, ServiceConfig]:
"""Get all registered services"""
return self.services.copy() return self.services.copy()
async def get_service_url(self, service_name: str) -> Optional[str]: async def get_service_url(self, service_name: str) -> Optional[str]:
"""Get base URL for a healthy service"""
if await self.is_service_healthy(service_name): if await self.is_service_healthy(service_name):
service_config = self.services.get(service_name) service_config = self.services.get(service_name)
return service_config.base_url if service_config else None return service_config.base_url if service_config else None
return None return None

View File

@@ -0,0 +1,42 @@
#!/usr/bin/env python3
"""
Test script to validate that the coroutine fix works
"""
import asyncio
import sys
from unittest.mock import MagicMock, AsyncMock
# Mock the dependencies
sys.modules['aiohttp'] = MagicMock()
sys.modules['models'] = MagicMock()
sys.modules['service_registry'] = MagicMock()
sys.modules['load_balancer'] = MagicMock()
sys.modules['auth_middleware'] = MagicMock()
# Import the main module after mocking
import main
async def test_lifespan():
"""Test that the lifespan function works correctly"""
# Mock the service registry
main.service_registry.initialize = AsyncMock()
main.service_registry.register_services = AsyncMock()
main.service_registry.close = AsyncMock()
# Test the lifespan context manager
async with main.lifespan(None):
print("✅ Lifespan startup completed successfully")
# Verify that the methods were called
main.service_registry.initialize.assert_called_once()
main.service_registry.register_services.assert_called_once_with(main.SERVICES)
# Verify shutdown was called
main.service_registry.close.assert_called_once()
print("✅ Lifespan shutdown completed successfully")
print("✅ All coroutines are properly awaited - RuntimeWarning should be resolved")
if __name__ == "__main__":
asyncio.run(test_lifespan())

View File

@@ -14,7 +14,7 @@ NC='\033[0m' # No Color
# Configuration # Configuration
COMPOSE_FILE="docker-compose.yml" COMPOSE_FILE="docker-compose.yml"
PROJECT_NAME="energy-dashboard" PROJECT_NAME="sa4cps"
# Function to print colored output # Function to print colored output
print_status() { print_status() {

View File

@@ -51,7 +51,7 @@ services:
depends_on: depends_on:
- mongodb - mongodb
- redis - redis
# - token-service - token-service
- sensor-service - sensor-service
- data-ingestion-service - data-ingestion-service
# - battery-service # - battery-service
@@ -60,21 +60,21 @@ services:
- energy-network - energy-network
# Token Management Service # Token Management Service
# token-service: token-service:
# build: build:
# context: ./token-service context: ./token-service
# dockerfile: Dockerfile dockerfile: Dockerfile
# container_name: token-service container_name: token-service
# restart: unless-stopped restart: unless-stopped
# ports: ports:
# - "8001:8001" - "8001:8001"
# environment: environment:
# - MONGO_URL=mongodb://admin:password123@localhost:27017/energy_dashboard_tokens?authSource=admin - MONGO_URL=mongodb://admin:password123@mongodb:27017/energy_dashboard_tokens?authSource=admin
# - JWT_SECRET_KEY=your-super-secret-jwt-key-change-in-production - JWT_SECRET_KEY=your-super-secret-jwt-key-change-in-production
# depends_on: depends_on:
# - mongodb - mongodb
# networks: networks:
# - energy-network - energy-network
# Battery Management Service # Battery Management Service
# battery-service: # battery-service:
@@ -185,6 +185,7 @@ services:
- FTP_SA4CPS_USERNAME=curvascarga@sa4cps.pt - FTP_SA4CPS_USERNAME=curvascarga@sa4cps.pt
- FTP_SA4CPS_REMOTE_PATH=/SLGs/ - FTP_SA4CPS_REMOTE_PATH=/SLGs/
- FTP_CHECK_INTERVAL=21600 - FTP_CHECK_INTERVAL=21600
- FTP_SKIP_INITIAL_SCAN=true
depends_on: depends_on:
- mongodb - mongodb
networks: networks:
@@ -202,7 +203,7 @@ services:
environment: environment:
- MONGO_URL=mongodb://admin:password123@mongodb:27017/energy_dashboard_sensors?authSource=admin - MONGO_URL=mongodb://admin:password123@mongodb:27017/energy_dashboard_sensors?authSource=admin
- REDIS_URL=redis://redis:6379 - REDIS_URL=redis://redis:6379
- TOKEN_SERVICE_URL=http://token-service:8001 # - TOKEN_SERVICE_URL=http://token-service:8001
depends_on: depends_on:
- mongodb - mongodb
- redis - redis