"""Token service implementation."""

import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import jwt
from motor.motor_asyncio import AsyncIOMotorDatabase


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Naive UTC is the convention used for every timestamp this service writes
    to (and compares against in) the ``tokens`` collection, matching what
    ``datetime.utcnow()`` produced for the ``created_at``/``updated_at``
    fields.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class TokenService:
    """Service for managing JWT tokens and authentication."""

    def __init__(self, db: AsyncIOMotorDatabase, secret_key: Optional[str] = None):
        """Bind the service to a database and an HMAC signing secret.

        Args:
            db: Motor database; tokens are stored in its ``tokens`` collection.
            secret_key: HS256 signing secret. Falls back to the
                ``JWT_SECRET_KEY`` environment variable, then to a built-in
                default.
        """
        self.db = db
        # NOTE(review): the built-in fallback secret is public in source code;
        # any deployment that relies on it has forgeable tokens. Set
        # JWT_SECRET_KEY (or pass secret_key) in production.
        self.secret_key = secret_key or os.getenv(
            "JWT_SECRET_KEY", "energy-dashboard-secret-key"
        )
        self.tokens_collection = db.tokens

    def generate_token(
        self,
        name: str,
        list_of_resources: List[str],
        data_aggregation: bool = False,
        time_aggregation: bool = False,
        embargo: int = 0,
        exp_hours: int = 24,
    ) -> str:
        """Generate a new HS256-signed JWT with the given permissions.

        Args:
            name: Value stored in the ``name`` claim.
            list_of_resources: Value stored in the ``list_of_resources`` claim.
            data_aggregation: Value stored in the ``data_aggregation`` claim.
            time_aggregation: Value stored in the ``time_aggregation`` claim.
            embargo: Value stored in the ``embargo`` claim.
            exp_hours: Lifetime of the token in hours, from now.

        Returns:
            The encoded JWT string. The payload also carries ``exp``/``iat``
            as Unix timestamps and a random ``jti``.
        """
        # Use a timezone-aware "now": calling .timestamp() on the naive value
        # returned by datetime.utcnow() interprets it as *local* time, which
        # skews exp/iat by the server's UTC offset (west of UTC, iat would lie
        # in the future). A single instant keeps iat and exp consistent.
        issued_at = datetime.now(timezone.utc)
        expires_at = issued_at + timedelta(hours=exp_hours)

        payload = {
            "name": name,
            "list_of_resources": list_of_resources,
            "data_aggregation": data_aggregation,
            "time_aggregation": time_aggregation,
            "embargo": embargo,
            "exp": int(expires_at.timestamp()),
            "iat": int(issued_at.timestamp()),
            "jti": str(uuid.uuid4()),  # unique token ID
        }

        return jwt.encode(payload, self.secret_key, algorithm="HS256")

    def decode_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Decode and verify a JWT signed with this service's secret.

        Returns:
            The decoded payload on success. On failure, returns a dict with a
            single ``"error"`` key (``"Token has expired"`` or
            ``"Invalid token"``) rather than raising; callers check for that
            key.
        """
        try:
            return jwt.decode(token, self.secret_key, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            return {"error": "Token has expired"}
        except jwt.InvalidTokenError:
            return {"error": "Invalid token"}

    async def insert_token(self, token: str) -> Dict[str, Any]:
        """Save a verified token to the database (upsert keyed on the token).

        Returns:
            A summary dict with the token, the save time (ISO format) and
            ``active=True``.

        Raises:
            ValueError: If the token fails verification.
        """
        now = _utcnow()

        # Decode token to verify it's valid before persisting it.
        decoded = self.decode_token(token)
        if not decoded or "error" in decoded:
            raise ValueError("Invalid token cannot be saved")

        token_record = {
            "token": token,
            "datetime": now,
            "active": True,
            "created_at": now,
            "updated_at": now,
            "name": decoded.get("name", ""),
            "resources": decoded.get("list_of_resources", []),
            # Store as naive UTC like the other timestamps; fromtimestamp()
            # without tz would yield local time and be misread as UTC.
            "expires_at": datetime.fromtimestamp(
                decoded.get("exp", 0), tz=timezone.utc
            ).replace(tzinfo=None),
        }

        # Upsert token (update if exists, insert if not)
        await self.tokens_collection.replace_one(
            {"token": token}, token_record, upsert=True
        )

        return {"token": token, "datetime": now.isoformat(), "active": True}

    async def revoke_token(self, token: str) -> Dict[str, Any]:
        """Revoke a token by marking its stored record as inactive.

        Returns:
            A summary dict with the token, the revocation time (ISO format)
            and ``active=False``.

        Raises:
            ValueError: If no stored record matches the token.
        """
        now = _utcnow()

        result = await self.tokens_collection.update_one(
            {"token": token},
            {"$set": {"active": False, "updated_at": now, "revoked_at": now}},
        )

        if result.matched_count == 0:
            raise ValueError("Token not found")

        return {"token": token, "datetime": now.isoformat(), "active": False}

    async def get_tokens(self) -> List[Dict[str, Any]]:
        """Return every stored token record in a JSON-friendly form.

        ``_id`` is converted to ``str``, and any present, truthy timestamp
        field is converted to an ISO-format string.
        """
        tokens = []
        async for token_record in self.tokens_collection.find({}):
            token_record["_id"] = str(token_record["_id"])
            for field in ["datetime", "created_at", "updated_at", "expires_at", "revoked_at"]:
                if field in token_record and token_record[field]:
                    token_record[field] = token_record[field].isoformat()
            tokens.append(token_record)
        return tokens

    async def is_token_valid(self, token: str) -> bool:
        """Return True if the token is stored as active AND still verifies."""
        token_record = await self.tokens_collection.find_one(
            {"token": token, "active": True}
        )
        if not token_record:
            return False

        decoded = self.decode_token(token)
        return decoded is not None and "error" not in decoded

    async def get_token_permissions(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the decoded payload of a valid token, or None otherwise."""
        if not await self.is_token_valid(token):
            return None
        return self.decode_token(token)

    async def cleanup_expired_tokens(self) -> int:
        """Delete stored tokens whose ``expires_at`` is in the past.

        Returns:
            The number of deleted records.
        """
        # Compare in naive UTC, the same convention used when storing
        # expires_at; datetime.now() would be local time.
        now = _utcnow()

        result = await self.tokens_collection.delete_many(
            {"expires_at": {"$lt": now}}
        )
        return result.deleted_count