63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
|
|
import json
|
||
|
|
import logging
|
||
|
|
import time
|
||
|
|
from typing import Callable
|
||
|
|
|
||
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||
|
|
from starlette.requests import Request
|
||
|
|
from starlette.responses import Response
|
||
|
|
|
||
|
|
|
||
|
|
def get_logger(name: str, *, level: int | str = logging.INFO) -> logging.Logger:
    """Return the logger called *name*, attaching a JSON stderr handler on first use.

    A handler is attached only when the logger has none yet, so repeated calls
    with the same name do not duplicate output. When this call attaches the
    handler it also disables propagation, so records are not re-emitted by
    ancestor (e.g. root) handlers.

    Args:
        name: Logger name, passed to ``logging.getLogger``.
        level: Level applied when this call configures the logger and the
            logger has no explicit level of its own yet.

    Returns:
        The (possibly newly configured) logger.
    """
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()  # writes to sys.stderr by default
        handler.setFormatter(_JsonFormatter())
        logger.addHandler(handler)
        logger.propagate = False
        # A fresh logger is NOTSET, so its effective level falls back to the
        # root logger (WARNING unless configured) even with propagate=False.
        # Without an explicit level, logger.info(...) calls are discarded.
        if logger.level == logging.NOTSET:
            logger.setLevel(level)
    return logger
|
||
|
|
|
||
|
|
|
||
|
|
class _JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object.

    The object always carries ``time``, ``level``, ``logger`` and ``message``.
    ``exc_info`` and ``stack_info`` are added when the record has them, and
    any fields passed via ``extra=`` are merged in at the top level.
    """

    # Attributes that logging itself puts on a LogRecord. Anything else in
    # ``record.__dict__`` came from ``extra=`` and is copied into the output.
    # "asctime" is included because a standard Formatter that handled the
    # same record earlier sets it; logging forbids it as an ``extra`` key.
    _RESERVED_ATTRS: frozenset[str] = frozenset({
        "name", "msg", "args", "levelname", "levelno", "pathname",
        "filename", "module", "exc_info", "exc_text", "stack_info",
        "lineno", "funcName", "created", "msecs", "relativeCreated",
        "thread", "threadName", "processName", "process", "message",
        "taskName", "asctime",
    })

    def format(self, record: logging.LogRecord) -> str:
        """Return *record* serialized as a JSON string (non-ASCII kept as-is)."""
        payload = {
            # Formatter.converter defaults to time.localtime; no UTC offset
            # is included in this format.
            "time": self.formatTime(record, datefmt="%Y-%m-%dT%H:%M:%S"),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        if record.stack_info:
            # stack_info is excluded from the extras below, so render it
            # explicitly or `logger.x(..., stack_info=True)` loses it.
            payload["stack_info"] = self.formatStack(record.stack_info)
        # Merge any extra fields passed via `extra=`
        for key, value in record.__dict__.items():
            if key not in self._RESERVED_ATTRS:
                payload[key] = value
        # default=str is deliberate: extras are caller-supplied (datetimes,
        # UUIDs, Decimals, ...). Stringifying them beats a TypeError that
        # would drop the whole log line.
        return json.dumps(payload, ensure_ascii=False, default=str)
|
||
|
|
|
||
|
|
|
||
|
|
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware that emits one structured log record per HTTP request.

    Each record has the message ``"request"`` and carries ``method``, ``path``,
    ``status`` and ``duration_ms`` (wall-clock milliseconds, rounded to 0.1)
    as ``extra=`` fields.
    """

    def __init__(self, app, logger: logging.Logger | None = None) -> None:
        """Wrap *app*.

        Args:
            app: The downstream application, passed to ``BaseHTTPMiddleware``.
            logger: Logger to write to; defaults to ``get_logger("request")``.
        """
        super().__init__(app)
        self._logger = logger or get_logger("request")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Time the downstream call and log its outcome.

        On success the record is logged at INFO with the response status. If
        the downstream call raises, the record is logged at ERROR with the
        traceback and the exception is re-raised unchanged, so failing
        requests still appear in the request log.
        """
        start = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            # Unhandled errors are typically rendered as a 500 by Starlette's
            # outermost ServerErrorMiddleware, so record that status here.
            self._logger.exception(
                "request",
                extra=self._request_fields(request, 500, start),
            )
            raise
        self._logger.info(
            "request",
            extra=self._request_fields(request, response.status_code, start),
        )
        return response

    @staticmethod
    def _request_fields(request: Request, status: int, start: float) -> dict[str, object]:
        """Build the ``extra=`` payload shared by the success and error paths."""
        return {
            "method": request.method,
            "path": request.url.path,
            "status": status,
            "duration_ms": round((time.perf_counter() - start) * 1000, 1),
        }
|