diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..aa3f625 --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +.idea +.ipynb_checkpoints +.mypy_cache +.vscode +__pycache__ +.pytest_cache +htmlcov +dist +site +.coverage +coverage.xml +.netlify +test.db +log.txt +Pipfile.lock +env3.* +env +docs_build +site_build +venv +docs.zip +archive.zip + +*~ +.*.sw? +.cache + +.DS_Store diff --git a/README.md b/README.md index e77c04f..e52b3fa 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ pip install logmiddleware ### Setting up middleware ```python +import logging from fastapi import FastAPI from logmiddleware import RouterLoggingMiddleware, logging_config @@ -35,7 +36,7 @@ app = FastAPI() # Add the middleware to your FastAPI app app.add_middleware( RouterLoggingMiddleware, - logger=logging.getLogger(), # Pass your logger instance + logger=logging.getLogger(__name__), # Pass your logger instance api_debug=True, # Set to True to enable debugging of response bodies ) ``` @@ -50,4 +51,4 @@ If you find any issues or have suggestions for improvements, please feel free to ## License -This project is licensed under the MIT license. \ No newline at end of file +This project is licensed under the MIT license. diff --git a/logmiddleware/__init__.py b/logmiddleware/__init__.py new file mode 100644 index 0000000..8ed9af2 --- /dev/null +++ b/logmiddleware/__init__.py @@ -0,0 +1 @@ +from .middleware import RouterLoggingMiddleware, logging_config diff --git a/logmiddleware/middleware.py b/logmiddleware/middleware.py new file mode 100644 index 0000000..1e95287 --- /dev/null +++ b/logmiddleware/middleware.py @@ -0,0 +1,147 @@ +import json +import logging +import sys +import time +from typing import Callable +from uuid import uuid4 + +from fastapi import FastAPI, Request, Response +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import Message + +logging_config = { + "version": 1, + "formatters": { + "json": { + "class": "pythonjsonlogger.jsonlogger.JsonFormatter", + "format": "%(asctime)s %(process)s %(levelname)s", + } + }, + "handlers": { + "console": { + "level": "DEBUG", + "class": "logging.StreamHandler", + "formatter": "json", + "stream": sys.stderr, + } + }, + "root": {"level": "DEBUG", "handlers": ["console"], "propagate": True}, +} + + +class RouterLoggingMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI, *, logger: logging.Logger, api_debug: bool = False) -> None: + self._logger = logger + self.api_debug = api_debug + super().__init__(app) + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + request_header = request.headers.get("x-api-request-id") + if request_header is not None: + request_id = request_header + else: + request_id: str = str(uuid4()) + + logging_dict = {"X-API-REQUEST-ID": request_id} + + if self.api_debug: + await self.set_body(request) + + response, response_dict = await self._log_response( + call_next, request, request_id + ) + request_dict = await self._log_request(request) + logging_dict["request"] = request_dict + logging_dict["response"] = response_dict + + self._logger.info(logging_dict) + + return response + + async def set_body(self, request: Request): + _receive = await request._receive() + + async def receive() -> Message: + return _receive + + request._receive = receive + + async def _log_request(self, request: Request) -> str: + path = request.url.path + if request.query_params: + path += f"?{request.query_params}" + + request_logging = { + "method": request.method, + "path": path, + "ip": request.client.host, + } + + if self.api_debug: + try: + body = await request.json() + request_logging["body"] = body + except ValueError: + body = None + + return request_logging + + async def _log_response( + self, call_next: Callable, request: Request, request_id: str + ) -> Response: + start_time = time.perf_counter() + response = await self._execute_request(call_next, request, request_id) + finish_time = time.perf_counter() + + overall_status = "successful" if response.status_code < 400 else "failed" + execution_time = finish_time - start_time + + response_logging = { + "status": overall_status, + "status_code": response.status_code, + "time_taken": f"{execution_time:0.4f}s", + } + + if self.api_debug: + resp_body = [ + section async for section in response.__dict__["body_iterator"] + ] + response.__setattr__("body_iterator", AsyncIteratorWrapper(resp_body)) + + try: + resp_body = json.loads(resp_body[0].decode()) + except ValueError: + resp_body = str(resp_body) + + response_logging["body"] = resp_body + + return response, response_logging + + async def _execute_request( + self, call_next: Callable, request: Request, request_id: str + ) -> Response: + try: + response: Response = await call_next(request) + + response.headers["X-API-Request-ID"] = request_id + return response + + except Exception as e: + self._logger.exception( + {"path": request.url.path, "method": request.method, "reason": e} + ) + + +class AsyncIteratorWrapper: + def __init__(self, obj): + self._it = iter(obj) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + value = next(self._it) + except StopIteration: + raise StopAsyncIteration + return value