diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 9711f2005..90dd0a950 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -17,8 +17,9 @@ Response, status, ) +from fastapi.concurrency import iterate_in_threadpool from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import RedirectResponse +from fastapi.responses import RedirectResponse, StreamingResponse from fastapi.security import OAuth2AuthorizationCodeBearer from observability_utils.tracing import ( add_span_attributes, @@ -596,14 +597,32 @@ async def add_api_version_header( async def log_request_details( - request: Request, call_next: Callable[[Request], Awaitable[Response]] + request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]] ) -> Response: - msg = f"method: {request.method} url: {request.url} body: {await request.body()}" + """Middleware to log all request's host, method, path, status and request and + response bodies""" + request_body = await request.body() + + response = await call_next(request) + + # https://github.com/Kludex/starlette/issues/874#issuecomment-1027743996 + response_body = [section async for section in response.body_iterator] + response.body_iterator = iterate_in_threadpool(iter(response_body)) + + log_message = ( + f"{getattr(request.client, 'host', 'NO_ADDRESS')} {request.method}" + f" {request.url.path} {response.status_code}" + ) + + extra = { + "request_body": request_body, + "response_body": response_body, + } if request.url.path == "/healthz": - LOGGER.debug(msg) + LOGGER.debug(log_message, extra=extra) else: - LOGGER.info(msg) - response = await call_next(request) + LOGGER.info(log_message, extra=extra) + return response diff --git a/tests/unit_tests/service/test_main.py b/tests/unit_tests/service/test_main.py index 92d7cd414..2a7e8cb65 100644 --- a/tests/unit_tests/service/test_main.py +++ b/tests/unit_tests/service/test_main.py @@ -13,16 +13,20 @@ async def test_log_request_details(): app = FastAPI() app.middleware("http")(log_request_details) - @app.get("/") + @app.post("/") async def root(): return {"message": "Hello World"} client = TestClient(app) - response = client.get("/") + response = client.post("/", content="foo") assert response.status_code == 200 logger.info.assert_called_once_with( - "method: GET url: http://testserver/ body: b''" + "testclient POST / 200", + extra={ + "request_body": b"foo", + "response_body": [b'{"message":"Hello World"}'], + }, )