import json
import logging
import time
from typing import Callable

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

# Attribute names present on every LogRecord (plus "message"/"asctime", which
# logging.Formatter may set). Anything else in ``record.__dict__`` was supplied
# by the caller via ``extra=`` and is copied into the JSON payload.
# Hoisted to module level so it is built once rather than per key per record.
_RESERVED_RECORD_ATTRS = frozenset(
    {
        "name",
        "msg",
        "args",
        "levelname",
        "levelno",
        "pathname",
        "filename",
        "module",
        "exc_info",
        "exc_text",
        "stack_info",
        "lineno",
        "funcName",
        "created",
        "msecs",
        "relativeCreated",
        "thread",
        "threadName",
        "processName",
        "process",
        "message",
        "taskName",
        "asctime",
    }
)


def get_logger(name: str) -> logging.Logger:
    """Return the logger *name*, attaching a JSON stderr handler on first use.

    A ``StreamHandler`` using ``_JsonFormatter`` is added only if the logger has
    no handlers yet, so repeated calls do not duplicate output. Propagation to
    ancestor loggers is disabled so records are not emitted twice.

    Note: the logger's level is not set here. It therefore inherits its
    effective level from its ancestors (the root logger defaults to WARNING),
    so INFO records are dropped unless the application configures a level.
    """
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(_JsonFormatter())
        logger.addHandler(handler)
        logger.propagate = False
    return logger


class _JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object."""

    def format(self, record: logging.LogRecord) -> str:
        """Serialize *record* to JSON.

        The payload contains ``time``, ``level``, ``logger`` and ``message``,
        plus ``exc_info`` / ``stack_info`` when present, plus every attribute
        passed via ``extra=``. Extra keys are merged last and therefore
        override the core fields on a name clash.
        """
        payload = {
            "time": self.formatTime(record, datefmt="%Y-%m-%dT%H:%M:%S"),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        if record.stack_info:
            # stack_info is excluded from the extras below, so without this
            # branch a logger.info(..., stack_info=True) call would lose it.
            payload["stack_info"] = self.formatStack(record.stack_info)

        # Merge any extra fields passed via `extra=`.
        payload.update(
            (key, value)
            for key, value in record.__dict__.items()
            if key not in _RESERVED_RECORD_ATTRS
        )
        # default=str is deliberate: a non-serializable extra value (UUID,
        # datetime, ...) would otherwise raise TypeError inside the handler and
        # the whole record would be lost.
        return json.dumps(payload, ensure_ascii=False, default=str)


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that emits one structured log line per request."""

    def __init__(self, app, logger: logging.Logger | None = None) -> None:
        """Wrap *app*; log to *logger*, or to ``get_logger("request")`` if omitted."""
        super().__init__(app)
        self._logger = logger or get_logger("request")

    @staticmethod
    def _request_fields(request: Request, status: int, start: float) -> dict:
        """Build the ``extra=`` fields for one request; *start* is a perf_counter value."""
        duration_ms = round((time.perf_counter() - start) * 1000, 1)
        return {
            "method": request.method,
            "path": request.url.path,
            "status": status,
            "duration_ms": duration_ms,
        }

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Time the downstream call and log method, path, status and duration.

        Successful requests are logged at INFO. If the downstream app raises,
        the request is logged at ERROR with the traceback and the exception is
        re-raised unchanged.
        """
        start = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            # Without this branch, failing requests produced no log line at all.
            # 500 is assumed because an unhandled exception is normally turned
            # into a 500 by the outer server-error middleware (verify if the
            # app installs custom handling).
            self._logger.exception(
                "request",
                extra=self._request_fields(request, 500, start),
            )
            raise
        self._logger.info(
            "request",
            extra=self._request_fields(request, response.status_code, start),
        )
        return response