-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmain.py
More file actions
386 lines (335 loc) · 15 KB
/
main.py
File metadata and controls
386 lines (335 loc) · 15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
"""ASGI entrypoint for the OpenAPI MCP server."""
import os
import re
import sys
import json
import logging
import asyncio
from copy import copy
from fastapi import FastAPI, Request, HTTPException, Response
from starlette.middleware.cors import CORSMiddleware
from starlette.types import ASGIApp, Scope, Receive, Send, Message
from .mcp_audit import McpAuditMiddleware
from .storage_backend import read_file
from .memory_store import get_callback_result, set_callback_result
from .mcp_core import mcp
from .apis import (
async_tool,
automotive,
cap,
company,
docuengine,
exchange,
geocoding,
info,
pec,
risk,
sms,
trust,
visurecamerali,
)
# ---------------------------------------------------------------------------
# Bootstrap package logger early — before uvicorn configures its own logging.
# This ensures [MCP] and [api] lines appear whether the server is started via
# `openapi-mcp-sdk server` (run()) or directly via `uvicorn ... main:app`.
# ---------------------------------------------------------------------------
def _resolve_log_level(name: str, default: int = logging.INFO) -> int:
    """Translate a LOG_LEVEL value into a numeric logging level.

    Accepts standard level names in any case (``"debug"``, ``" WARNING "``) or a
    plain integer string (``"15"``). Empty or unrecognised values return
    *default* instead of raising, so a typo in the environment cannot stop the
    module from importing.
    """
    candidate = (name or "").strip().upper()
    if not candidate:
        return default
    if candidate.isdigit():
        return int(candidate)
    # getLevelName() maps a registered name to its int; an unknown name comes
    # back as the string "Level <NAME>", which setLevel() would reject.
    level = logging.getLevelName(candidate)
    return level if isinstance(level, int) else default


_pkg_log = logging.getLogger("openapi_mcp_sdk")
_pkg_log.setLevel(_resolve_log_level(os.environ.get("LOG_LEVEL", "INFO")))
if not _pkg_log.handlers:
    _h = logging.StreamHandler(sys.stderr)
    try:
        from uvicorn.logging import DefaultFormatter
        # use_colors=None lets uvicorn auto-detect a TTY (same as _LOG_CONFIG),
        # instead of forcing ANSI escape codes into redirected/captured logs.
        _h.setFormatter(DefaultFormatter("%(levelprefix)s %(message)s", use_colors=None))
    except ImportError:
        _h.setFormatter(logging.Formatter("%(levelname)-8s %(message)s"))
    _pkg_log.addHandler(_h)
# The handler is attached directly to the package logger; do not also pass
# records up to ancestor loggers, which could emit them a second time.
_pkg_log.propagate = False
_logger = logging.getLogger(__name__)
class _SanitizedAccessFormatter:
    """Wrap uvicorn's AccessFormatter to give access lines a uniform layout and mask tokens.

    A formatted line of the shape ``<prefix>[HTTP] <host:port> - "<request>" <status>``
    is rebuilt to match all other log lines in this project:
        [HTTP] X.X.X.X "POST / HTTP/1.1" 200 OK
    i.e. the client port is dropped and every ``token=`` query value becomes ``***``.
    Lines of any other shape are returned with only the token masking applied.
    Attributes not defined here are delegated to the wrapped formatter.
    """

    _TOKEN_RE = re.compile(r'([?&])token=[^&\s"]+')
    # Parses the inner formatter output to extract the parts we need to rebuild
    _FORMAT_RE = re.compile(r'^(.*?)\[HTTP\] (\S+) - "(.+)" (\S.*)$')

    def __init__(self, *args, **kwargs):
        try:
            from uvicorn.logging import AccessFormatter
        except ImportError:
            # logging.Formatter has no use_colors parameter; drop it so the
            # fallback does not fail with TypeError.
            kwargs.pop("use_colors", None)
            self._inner = logging.Formatter(*args, **kwargs)
        else:
            self._inner = AccessFormatter(*args, **kwargs)

    def format(self, record: logging.LogRecord) -> str:
        """Format *record* with the wrapped formatter, then reshape and mask it."""
        result = self._inner.format(record)
        m = self._FORMAT_RE.match(result)
        if m:
            levelprefix, addr, req, status = m.groups()
            # Drop the trailing ":port"; rsplit keeps IPv6 hosts like "::1" intact.
            ip = addr.rsplit(":", 1)[0] if ":" in addr else addr
            req = self._TOKEN_RE.sub(r'\1token=***', req)
            return f'{levelprefix}[HTTP] {ip} "{req}" {status}'
        # Fallback: at least mask tokens
        return self._TOKEN_RE.sub(r'\1token=***', result)

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Guard `_inner` itself: on an
        # instance built without __init__ (copy.copy, __new__) looking it up here
        # would re-enter __getattr__ forever instead of raising AttributeError.
        if name == "_inner":
            raise AttributeError(name)
        return getattr(self._inner, name)
# Create the MCP ASGI app. path='/' places its MCP endpoint at the root of the
# sub-app, which is mounted at "/" further down (after the REST routes).
mcp_app = mcp.http_app(path='/')
# Create the FastAPI app, reusing the MCP app's lifespan as its own.
# NOTE(review): presumably required because Starlette does not run the lifespan
# of mounted sub-apps, so FastMCP's startup/shutdown would otherwise never run — confirm.
app = FastAPI(lifespan=mcp_app.lifespan)
# Attempt to initialize dynamic tools at import time, with no explicit token
# (presumably docuengine falls back to a token from the environment — verify).
# On success, mark app.state so TokenQuerystringMiddleware skips JIT registration.
if docuengine.init_dynamic_tools():
    app.state.dynamic_tools_registered = True
# Serializes JIT registration in TokenQuerystringMiddleware, so concurrent first
# requests do not run docuengine.init_dynamic_tools() at the same time.
registration_lock = asyncio.Lock()
# ---------------------------------------------------------------------------
# Pure ASGI middlewares — BaseHTTPMiddleware is avoided on purpose so that SSE /
# streaming responses from FastMCP pass through untouched. BaseHTTPMiddleware
# buffers the response in an anyio channel and cannot properly forward
# streaming bodies, causing "RuntimeError: No response returned." on POST /.
# ---------------------------------------------------------------------------
class Enrich404Middleware:
    """Rewrite every HTTP 404 response as a JSON ``not_found`` document.

    The replacement body reports the request method and path. Whatever body the
    wrapped app produced for that 404 is discarded. Non-HTTP scopes and non-404
    responses are forwarded untouched.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        replaced = False

        async def forward(message: Message) -> None:
            nonlocal replaced
            kind = message["type"]

            if kind == "http.response.start" and message.get("status") == 404:
                replaced = True
                payload = json.dumps(
                    {"error": "not_found", "method": scope["method"], "path": scope["path"]}
                ).encode()
                await send(
                    {
                        "type": "http.response.start",
                        "status": 404,
                        "headers": [
                            (b"content-type", b"application/json"),
                            (b"content-length", str(len(payload)).encode()),
                        ],
                    }
                )
                await send({"type": "http.response.body", "body": payload, "more_body": False})
                return

            # Start messages for other statuses always pass; body chunks pass
            # only while no 404 has been replaced.
            if kind == "http.response.start" or not replaced:
                await send(message)

        await self.app(scope, receive, forward)
class TokenQuerystringMiddleware:
    """Lift ?token=<value> from the query string into an Authorization: Bearer header.

    Also triggers JIT registration of dynamic tools on the first request that
    carries a token, taken either from the query string or from an existing
    ``Authorization: Bearer`` header.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    @staticmethod
    def _bearer_from_headers(headers):
        """Return the token of the first ``Authorization: Bearer`` header, or None."""
        for key, value in headers:
            if key.lower() != b"authorization":
                continue
            # ASGI header values are raw bytes; latin-1 (as Starlette uses) cannot
            # fail to decode, unlike strict UTF-8.
            decoded = value.decode("latin-1")
            if decoded.lower().startswith("bearer "):
                return decoded[7:].strip() or None
        return None

    @staticmethod
    async def _register_dynamic_tools(scope: Scope, token: str) -> None:
        """Run docuengine.init_dynamic_tools(token) at most once at a time.

        The sync HTTP call runs in a worker thread so the event loop is never
        blocked. On failure the flag stays unset, so the next request with a
        token retries. `app` below is the module-level FastAPI app, not self.app.
        """
        import hashlib

        client_ip = (scope.get("client") or ("?", 0))[0]
        async with registration_lock:
            # Re-check under the lock: a concurrent request may have completed
            # registration while this one was waiting.
            if getattr(app.state, "dynamic_tools_registered", False):
                return
            # Log a short SHA-256 fingerprint rather than raw token characters:
            # still usable to correlate requests, without leaking the secret.
            fingerprint = hashlib.sha256(token.encode()).hexdigest()[:12]
            _logger.info(
                '%s %s "Initializing dynamic tools (token sha256: %s)"',
                "[JIT]",
                client_ip,
                fingerprint,
            )
            success = await asyncio.to_thread(docuengine.init_dynamic_tools, token)
            if success:
                app.state.dynamic_tools_registered = True
            else:
                _logger.warning(
                    '%s %s "Registration failed — will retry on next request"',
                    "[JIT]",
                    client_ip,
                )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        from urllib.parse import parse_qs, urlencode

        # errors="replace" keeps a stray invalid byte from turning the request
        # into a 500; valid UTF-8 query strings decode exactly as before.
        query_string = scope.get("query_string", b"").decode("utf-8", errors="replace")
        params = parse_qs(query_string, keep_blank_values=True)
        token: str | None = (params.pop("token", None) or [None])[0]

        if token:
            # Replace any Authorization header with one built from the query token,
            # and drop ?token= from the query string of the scope passed downstream.
            # NOTE: this builds a new scope dict; uvicorn presumably logs from its own
            # original scope, which is why _SanitizedAccessFormatter masks tokens too.
            clean_qs = urlencode(params, doseq=True).encode()
            headers = [(k, v) for k, v in scope["headers"] if k.lower() != b"authorization"]
            headers.append((b"authorization", f"Bearer {token}".encode()))
            scope = {**scope, "headers": headers, "query_string": clean_qs}
        else:
            # Not in query string — use an existing Bearer header for JIT registration.
            token = self._bearer_from_headers(scope["headers"])

        if token and not getattr(app.state, "dynamic_tools_registered", False):
            await self._register_dynamic_tools(scope, token)

        await self.app(scope, receive, send)
# Pre-encoded JSON body sent by RejectOAuthDiscoveryMiddleware for every probe.
_OAUTH_NOT_SUPPORTED_BODY = json.dumps(
    {
        "error": "oauth_not_supported",
        "message": "This server does not support OAuth. Use a pre-configured Bearer token in the Authorization header.",
    }
).encode()

# Well-known discovery documents that OAuth / OpenID clients probe for.
_OAUTH_DISCOVERY_PATHS = {
    "/.well-known/" + document
    for document in (
        "oauth-authorization-server",
        "oauth-protected-resource",
        "openid-configuration",
    )
}
class RejectOAuthDiscoveryMiddleware:
    """Return a JSON 404 for OAuth discovery endpoints.

    A request is rejected when its path, with any trailing slash ignored, equals
    one of ``_OAUTH_DISCOVERY_PATHS``. It is also rejected when the path lies
    beneath one of them, e.g. ``/.well-known/oauth-protected-resource/mcp``: the
    path-inserted well-known form of RFC 8414 §3.1 / RFC 9728 §3.1. All other
    requests, and non-HTTP scopes, go to the wrapped app unchanged.
    """

    # Normalised once at class creation instead of on every request.
    _PREFIXES = frozenset(p.rstrip("/") for p in _OAUTH_DISCOVERY_PATHS)

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    @classmethod
    def _is_discovery_path(cls, path: str) -> bool:
        """True if *path* is, or sits beneath, an OAuth/OpenID discovery path."""
        normalized = path.rstrip("/")
        return any(
            normalized == prefix or normalized.startswith(prefix + "/")
            for prefix in cls._PREFIXES
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http" or not self._is_discovery_path(scope["path"]):
            await self.app(scope, receive, send)
            return
        await send({
            "type": "http.response.start",
            "status": 404,
            "headers": [
                (b"content-type", b"application/json"),
                (b"content-length", str(len(_OAUTH_NOT_SUPPORTED_BODY)).encode()),
            ],
        })
        await send({
            "type": "http.response.body",
            "body": _OAUTH_NOT_SUPPORTED_BODY,
            "more_body": False,
        })
# Register middlewares.
# Each add_middleware() call inserts at position 0, so whichever is added LAST
# becomes the outermost wrapper and sees every request first.
# Resulting execution order (outer → inner):
#   CORS → RejectOAuth → TokenQuerystring → Enrich404 → McpAudit → FastAPI router
for _middleware_cls in (
    McpAuditMiddleware,              # innermost — logs MCP JSON-RPC actions
    Enrich404Middleware,             # wraps 404s in JSON
    TokenQuerystringMiddleware,      # injects auth header + JIT registration
    RejectOAuthDiscoveryMiddleware,  # short-circuits OAuth discovery paths
):
    app.add_middleware(_middleware_cls)
del _middleware_cls

# CORS is registered last so it is outermost and also handles pre-flight
# OPTIONS requests. expose_headers lets browsers read Mcp-Session-Id.
# allow_credentials must NOT be True when allow_origins=["*"].
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["mcp-session-id"],
)
# ---------------------------------------------------------------------------
# Plain HTTP REST endpoints (outside MCP/JSON-RPC)
# ---------------------------------------------------------------------------
@app.post("/callbacks")
async def callbacks_endpoint(request: Request):
    """Receive an asynchronous job callback and store its payload by request_id.

    The request id is read from ``custom.request_id``. ``custom`` is the
    top-level ``custom`` object or, failing that, ``callback.data``. The stored
    payload is the top-level ``data`` field when present and non-empty, otherwise
    the whole body. Validation failures are answered with
    ``{"status": "error", "message": ...}`` (HTTP 200); success with
    ``{"status": "ok"}``.
    """
    client_ip = request.client.host if request.client else "?"

    raw_body = await request.body()
    try:
        callback = json.loads(raw_body)
    except (ValueError, RecursionError):
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses;
        # RecursionError covers pathologically nested input.
        _logger.warning('%s %s "Invalid JSON body"', "[CB]", client_ip)
        return {"status": "error", "message": "Body not a valid JSON"}

    # Valid JSON need not be an object; a list/string/number would make the
    # .get() calls below raise AttributeError and surface as a 500.
    if not isinstance(callback, dict):
        _logger.warning('%s %s "Body is not a JSON object"', "[CB]", client_ip)
        return {"status": "error", "message": "Body must be a JSON object"}

    cb_obj = callback.get("callback")
    custom = callback.get("custom") or (
        cb_obj.get("data") if isinstance(cb_obj, dict) else None
    )
    if not custom:
        _logger.warning('%s %s "Missing callback.custom field"', "[CB]", client_ip)
        return {"status": "error", "message": "'callback.custom' missing from received data"}

    # A non-object `custom` (e.g. a bare string) carries no request_id.
    request_id = custom.get("request_id") if isinstance(custom, dict) else None
    if not request_id:
        _logger.warning('%s %s "Missing request_id in custom field"', "[CB]", client_ip)
        return {"status": "error", "message": "'request_id' missing from custom field"}

    # Prefer the top-level "data" field, falling back to the whole body. The body
    # is a non-empty dict here, so `data` is always truthy and needs no
    # further emptiness check.
    data = callback.get("data") or callback

    # Store the result keyed by request_id (overwrites on subsequent callbacks)
    set_callback_result(request_id, data, custom)
    _logger.info('%s %s "Received" request_id=%s', "[CB]", client_ip, request_id)
    return {"status": "ok"}
@app.get("/status/{request_id}")
async def get_status(request_id: str):
    """Return the stored callback result for *request_id*.

    Raises:
        HTTPException: 404 when get_callback_result raises KeyError (nothing
            stored under that id). 500 for any other failure; that failure is also
            logged with its traceback, since a raised HTTPException is otherwise
            not logged anywhere server-side.
    """
    try:
        return get_callback_result(request_id)
    except KeyError:
        raise HTTPException(status_code=404, detail="Not Found") from None
    except Exception as exc:
        _logger.exception('%s "Status lookup failed" request_id=%s', "[STATUS]", request_id)
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(exc)}") from exc
def _is_safe_path_segment(segment: str) -> bool:
    """Return True if *segment* can be joined into a storage key without escaping its folder."""
    return (
        bool(segment)
        and segment not in {".", ".."}
        and not any(ch in segment for ch in "/\\\x00")
    )


@app.get("/status/{request_id}/files/{file_name}")
async def get_file(request_id: str, file_name: str):
    """Serve the stored file ``<request_id>/<file_name>`` with its stored content type.

    Both path parameters arrive percent-decoded, so ``%2e%2e`` reaches this handler
    as ``..``. Each must therefore be a single, non-traversing path segment before
    being joined into the key passed to read_file.

    Raises:
        HTTPException: 400 for an unsafe path segment; 404 when read_file raises
            FileNotFoundError; 500 (also logged with traceback) for any other error.
    """
    if not (_is_safe_path_segment(request_id) and _is_safe_path_segment(file_name)):
        raise HTTPException(status_code=400, detail="Invalid file path")
    try:
        file_content, content_type = read_file(f"{request_id}/{file_name}")
        return Response(content=file_content, media_type=content_type)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="File not found") from None
    except HTTPException:
        raise
    except Exception as exc:
        _logger.exception(
            '%s "File retrieval failed" request_id=%s file=%s', "[FILE]", request_id, file_name
        )
        raise HTTPException(status_code=500, detail=f"Error retrieving file: {str(exc)}") from exc
# Mount the MCP app at the root, after the REST routes above were registered:
# /callbacks and /status/* are handled by FastAPI, everything else goes to MCP.
app.mount(path="/", app=mcp_app)
# ---------------------------------------------------------------------------
# Logging configuration
# ---------------------------------------------------------------------------
# Extends uvicorn's default log config so that openapi_mcp_sdk loggers
# (audit, mcp_core, …) emit using the same format as uvicorn itself.
# When run() hands this to uvicorn, dictConfig replaces the bootstrap handler on
# the "openapi_mcp_sdk" logger with the "default" handler defined below.
_LOG_CONFIG: dict = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "()": "uvicorn.logging.DefaultFormatter",
            "fmt": "%(levelprefix)s %(message)s",
            "use_colors": None,
        },
        "access": {
            "()": "uvicorn.logging.AccessFormatter",
            "fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s',
        },
        # Used by the "access" handler: reshapes lines and masks ?token= values.
        "sanitized_access": {
            "()": "openapi_mcp_sdk.main._SanitizedAccessFormatter",
            "fmt": '%(levelprefix)s [HTTP] %(client_addr)s - "%(request_line)s" %(status_code)s',
        },
    },
    "handlers": {
        "default": {
            "formatter": "default",
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stderr",
        },
        "access": {
            "formatter": "sanitized_access",
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        "uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
        "uvicorn.error": {"handlers": ["default"], "level": "INFO", "propagate": False},
        "uvicorn.access": {"handlers": ["access"], "level": "INFO", "propagate": False},
        # Our package loggers — same handler/format as uvicorn, INFO by default.
        # Set LOG_LEVEL=debug in the environment to promote to DEBUG.
        "openapi_mcp_sdk": {
            "handlers": ["default"],
            # Reuse the level already resolved from LOG_LEVEL for the bootstrap
            # logger, so the env var is parsed in one place and both configs agree.
            "level": _pkg_log.level,
            "propagate": False,
        },
    },
}
def run():
    """Serve `app` with uvicorn on 0.0.0.0, port $MCP_PORT (default 8080), using _LOG_CONFIG."""
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": int(os.environ.get("MCP_PORT", "8080")),
        "log_config": _LOG_CONFIG,
    }
    uvicorn.run(app, **server_options)


if __name__ == "__main__":
    run()