"""
Token management service for authentication and resource access control.
Based on the tiocps JWT token implementation with resource-based permissions.
"""

import jwt
import uuid
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Any
from pydantic import BaseModel
from motor.motor_asyncio import AsyncIOMotorDatabase


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    This keeps the same naive-UTC representation (and the offset-free
    ``isoformat()`` output) that ``datetime.utcnow()`` produced, without
    relying on that function, which is deprecated since Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class TokenPayload(BaseModel):
    """Token payload structure"""
    name: str
    list_of_resources: List[str]
    data_aggregation: bool = False
    time_aggregation: bool = False
    embargo: int = 0  # embargo period in seconds
    exp: int  # expiration timestamp (Unix epoch seconds)


class TokenRecord(BaseModel):
    """Token database record"""
    token: str
    # NOTE: this field name shadows the imported ``datetime`` type inside the
    # class body. Keep it without a default value: assigning one would rebind
    # ``datetime`` and break the annotations of the fields below.
    datetime: datetime
    active: bool = True
    created_at: datetime
    updated_at: datetime


class TokenService:
    """Service for managing JWT tokens and authentication.

    Tokens are HS256-signed JWTs carrying the permission claims described by
    ``TokenPayload``. A token is only considered valid by
    :meth:`is_token_valid` if it has also been stored (and is still marked
    active) in the ``tokens`` collection of the given database.
    """

    def __init__(self, db: AsyncIOMotorDatabase, secret_key: str = "dashboard-secret-key"):
        """Create the service.

        Args:
            db: Motor database; tokens are stored in its ``tokens`` collection.
            secret_key: HMAC key used to sign and verify tokens.

        NOTE(review): the default ``secret_key`` is a hard-coded value visible
        in source; anyone who knows it can forge valid tokens. Callers should
        always pass a real secret outside of local development.
        """
        self.db = db
        self.secret_key = secret_key
        self.tokens_collection = db.tokens

    def generate_token(self, name: str, list_of_resources: List[str],
                       data_aggregation: bool = False,
                       time_aggregation: bool = False,
                       embargo: int = 0, exp_hours: int = 24) -> str:
        """Generate a new signed JWT with the specified permissions.

        Args:
            name: Token holder / label stored in the ``name`` claim.
            list_of_resources: Resources the token grants access to.
            data_aggregation: Value of the ``data_aggregation`` claim.
            time_aggregation: Value of the ``time_aggregation`` claim.
            embargo: Embargo period in seconds.
            exp_hours: Lifetime of the token in hours.

        Returns:
            The encoded JWT. It is NOT persisted; call :meth:`insert_token`
            to make it usable with :meth:`is_token_valid`.
        """
        # Use a single timezone-aware UTC "now". Calling .timestamp() on the
        # naive value returned by datetime.utcnow() interprets it as *local*
        # time, which shifted both ``exp`` and ``iat`` by the host's UTC offset.
        now = datetime.now(timezone.utc)
        exp_timestamp = int((now + timedelta(hours=exp_hours)).timestamp())

        payload = {
            "name": name,
            "list_of_resources": list_of_resources,
            "data_aggregation": data_aggregation,
            "time_aggregation": time_aggregation,
            "embargo": embargo,
            "exp": exp_timestamp,
            "iat": int(now.timestamp()),
            "jti": str(uuid.uuid4()),  # unique token ID
        }

        return jwt.encode(payload, self.secret_key, algorithm="HS256")

    def decode_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Decode and verify a JWT (signature and expiration).

        Returns:
            The decoded payload on success. On failure a dict of the form
            ``{"error": "<message>"}`` is returned instead of raising, so
            callers must check for the ``"error"`` key.
        """
        try:
            return jwt.decode(token, self.secret_key, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            # Must precede InvalidTokenError, of which it is a subclass.
            return {"error": "Token has expired"}
        except jwt.InvalidTokenError:
            return {"error": "Invalid token"}

    async def insert_token(self, token: str) -> Dict[str, Any]:
        """Verify a token and save it to the database as active.

        Returns:
            ``{"token", "datetime" (ISO string), "active": True}``.

        Raises:
            ValueError: if the token fails verification.
        """
        now = _utcnow()

        # Decode token to verify it's valid before persisting it.
        decoded = self.decode_token(token)
        if not decoded or "error" in decoded:
            raise ValueError("Invalid token cannot be saved")

        token_record = {
            "token": token,
            "datetime": now,
            "active": True,
            "created_at": now,
            "updated_at": now,
            "name": decoded.get("name", ""),
            "resources": decoded.get("list_of_resources", []),
            # Store as naive UTC so it is comparable with the naive-UTC "now"
            # used by cleanup_expired_tokens. (datetime.fromtimestamp without
            # tz yields *local* time, which made that comparison wrong.)
            "expires_at": datetime.fromtimestamp(
                decoded.get("exp", 0), tz=timezone.utc
            ).replace(tzinfo=None),
        }

        await self.tokens_collection.insert_one(token_record)

        return {
            "token": token,
            "datetime": now.isoformat(),
            "active": True,
        }

    async def revoke_token(self, token: str) -> Dict[str, Any]:
        """Revoke a token by marking its database record inactive.

        Returns:
            ``{"token", "datetime" (ISO string), "active": False}``.

        Raises:
            ValueError: if no stored record matches ``token``.
        """
        now = _utcnow()

        result = await self.tokens_collection.update_one(
            {"token": token},
            {
                "$set": {
                    "active": False,
                    "updated_at": now,
                    "revoked_at": now,
                }
            },
        )

        if result.matched_count == 0:
            raise ValueError("Token not found")

        return {
            "token": token,
            "datetime": now.isoformat(),
            "active": False,
        }

    async def get_tokens(self) -> List[Dict[str, Any]]:
        """Return all stored token records in a JSON-friendly form.

        ``_id`` is converted to ``str`` and the timestamp fields that are
        present and truthy are converted to ISO-8601 strings.
        """
        cursor = self.tokens_collection.find({})
        tokens = []

        async for token_record in cursor:
            token_record["_id"] = str(token_record["_id"])
            for field in ("datetime", "created_at", "updated_at", "expires_at", "revoked_at"):
                if token_record.get(field):
                    token_record[field] = token_record[field].isoformat()
            tokens.append(token_record)

        return tokens

    async def is_token_valid(self, token: str) -> bool:
        """Return True if the token is stored, active, and verifies as a JWT."""
        # The token must exist in the database and not have been revoked.
        token_record = await self.tokens_collection.find_one({
            "token": token,
            "active": True,
        })
        if not token_record:
            return False

        # Verify JWT signature and expiration.
        decoded = self.decode_token(token)
        return decoded is not None and "error" not in decoded

    async def get_token_permissions(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the decoded payload of a valid token, or None if invalid."""
        if await self.is_token_valid(token):
            return self.decode_token(token)
        return None

    async def cleanup_expired_tokens(self) -> int:
        """Delete stored tokens whose ``expires_at`` lies in the past.

        Returns:
            The number of records deleted.
        """
        now = _utcnow()
        # One server-side bulk delete instead of a find + delete_one per
        # document round-trip.
        result = await self.tokens_collection.delete_many({
            "expires_at": {"$lt": now},
        })
        return result.deleted_count