diff --git a/crab/router.py b/crab/router.py index 663076f..7341fb8 100644 --- a/crab/router.py +++ b/crab/router.py @@ -1,9 +1,20 @@ -from flask import Flask, Response, request, abort -from urllib.parse import urlparse -from werkzeug.routing import Rule import os + +import httpx import psutil -import requests +import uvicorn +from starlette.applications import Starlette +from starlette.middleware.cors import ALL_METHODS +from starlette.responses import StreamingResponse, Response +from starlette.routing import Route +from uvicorn.config import LOGGING_CONFIG as UVICORN_LOGGING_CONFIG +from uvicorn.logging import AccessFormatter + + +HEADERS_TO_STRIP = ["server", "date"] + + +client = httpx.AsyncClient(timeout=None) def get_routes(): @@ -18,42 +29,54 @@ def get_routes(): return routes -app = Flask(__name__, static_folder=None) -app.url_map.add(Rule("/", endpoint="proxy", defaults={"path": ""})) -app.url_map.add(Rule("/", endpoint="proxy")) +class CustomAccessFormatter(AccessFormatter): + def get_client_addr(self, scope): + """ + _Pretend_ the client address is actually the hostname. + Makes the log messages much nicer! + """ + if "headers" not in scope: + return super().get_client_addr(scope) + return httpx.Headers(scope["headers"])["Host"] -@app.endpoint("proxy") -def proxy(path): +async def proxy(request): routes = get_routes() - hostname = urlparse(request.base_url).hostname + hostname = request.url.hostname if hostname not in routes: - app.logger.warn(f"No backend for {hostname}") - abort(502) - - path = request.full_path if request.args else request.path - target_url = f"http://localhost:{routes[hostname]}{path}" - app.logger.info(f"Routing request to backend - {request.method} {hostname}{path}") - - downstream_response = requests.request( + return Response(status_code=502, content=f"No backend found for {hostname}.") + target_url = f"http://localhost:{routes[hostname]}{request.url.path}" + if request.query_params: + target_url += f"?{request.query_params}" + body = await request.body() + upstream_response = await client.request( method=request.method, url=target_url, - headers=request.headers, - data=request.get_data(), + data=body, + headers=request.headers.raw, allow_redirects=False, stream=True, ) - return Response( - response=downstream_response.raw.data, - status=downstream_response.status_code, - headers=downstream_response.raw.headers.items(), + + # Strip some headers which uvicorn forcefully adds + upstream_headers = upstream_response.headers + for header_name in HEADERS_TO_STRIP: + if header_name in upstream_headers: + del upstream_headers[header_name] + + return StreamingResponse( + content=upstream_response.raw(), + status_code=upstream_response.status_code, + headers=upstream_headers, ) +app = Starlette(routes=[Route("/(.*)", endpoint=proxy, methods=ALL_METHODS)]) + + def start_on_port(port): - app.run( - port=port, debug=True, use_debugger=False, use_reloader=False, load_dotenv=False - ) + UVICORN_LOGGING_CONFIG["formatters"]["access"]["()"] = CustomAccessFormatter + uvicorn.run(app, port=port, host="0.0.0.0", headers=[("server", "crab")]) def run(): diff --git a/requirements.txt b/requirements.txt index b888a07..d247598 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -Flask==1.0.3 psutil==5.6.3 -requests==2.22.0 +uvicorn==0.10.8 +starlette==0.13.0 +httpx==0.7.8