diff --git a/.changeset/add-get-request-handler.md b/.changeset/add-get-request-handler.md new file mode 100644 index 000000000..3387122bf --- /dev/null +++ b/.changeset/add-get-request-handler.md @@ -0,0 +1,5 @@ +--- +'@modelcontextprotocol/core': minor +--- + +Add `getRequestHandler()` method to `Protocol`, enabling retrieval and wrapping of existing request handlers. This allows composable handler middleware without re-implementing SDK internals — for example, transforming `tools/list` responses by wrapping the default handler. diff --git a/packages/core/src/shared/protocol.examples.ts b/packages/core/src/shared/protocol.examples.ts new file mode 100644 index 000000000..8e203ce68 --- /dev/null +++ b/packages/core/src/shared/protocol.examples.ts @@ -0,0 +1,26 @@ +/** + * Type-checked examples for `protocol.ts`. + * + * These examples are synced into JSDoc comments via the sync-snippets script. + * Each function's region markers define the code snippet that appears in the docs. + * + * @module + */ + +import type { BaseContext, Protocol } from './protocol.js'; + +/** + * Example: Wrapping an existing request handler with getRequestHandler. + */ +function getRequestHandler_wrapping(protocol: Protocol) { + //#region getRequestHandler_wrapping + const original = protocol.getRequestHandler('tools/list'); + if (original) { + protocol.setRequestHandler('tools/list', async (request, ctx) => { + const result = await original(request, ctx); + // Transform the result before returning + return result; + }); + } + //#endregion getRequestHandler_wrapping +} diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index b82731582..61694f24e 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -1508,6 +1508,36 @@ export abstract class Protocol { }); } + /** + * Returns the current request handler for the given method, or undefined if none is registered. + * + * The returned function is a snapshot — it captures the handler registered at call time. + * If the handler is later replaced or removed, the previously returned function still + * delegates to the original handler. + * + * Note: the returned handler includes the SDK's internal schema validation layer, so + * requests passed to it will be re-validated. This is harmless but redundant when + * wrapping an existing handler in the standard pattern below. + * + * ```ts source="./protocol.examples.ts#getRequestHandler_wrapping" + * const original = protocol.getRequestHandler('tools/list'); + * if (original) { + * protocol.setRequestHandler('tools/list', async (request, ctx) => { + * const result = await original(request, ctx); + * // Transform the result before returning + * return result; + * }); + * } + * ``` + */ + getRequestHandler( + method: M + ): ((request: RequestTypeMap[M], ctx: ContextT) => Promise) | undefined { + const raw = this._requestHandlers.get(method); + if (!raw) return undefined; + return (request, ctx) => raw(request as unknown as JSONRPCRequest, ctx) as Promise; + } + /** * Removes the request handler for the given method. */ diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index 8675c1e03..7320c7065 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -5723,3 +5723,169 @@ describe('Error handling for missing resolvers', () => { }); }); }); + +describe('getRequestHandler', () => { + let protocol: Protocol; + let transport: MockTransport; + let sendSpy: MockInstance; + + beforeEach(() => { + transport = new MockTransport(); + sendSpy = vi.spyOn(transport, 'send'); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected buildContext(ctx: BaseContext): BaseContext { + return ctx; + } + protected assertTaskHandlerCapability(): void {} + })(); + }); + + it('should return undefined for unregistered methods', () => { + const handler = protocol.getRequestHandler('tools/list'); + expect(handler).toBeUndefined(); + }); + + it('should return a callable handler after setRequestHandler', async () => { + await protocol.connect(transport); + + protocol.setRequestHandler('ping', async () => { + return {}; + }); + + const handler = protocol.getRequestHandler('ping'); + expect(handler).toBeDefined(); + expect(typeof handler).toBe('function'); + }); + + it('should return undefined after removeRequestHandler', () => { + protocol.setRequestHandler('ping', async () => { + return {}; + }); + + protocol.removeRequestHandler('ping'); + expect(protocol.getRequestHandler('ping')).toBeUndefined(); + }); + + it('should reflect the latest handler after replacement', () => { + protocol.setRequestHandler('ping', async () => { + return {}; + }); + + const handlerA = protocol.getRequestHandler('ping'); + + protocol.setRequestHandler('ping', async () => { + return {}; + }); + + const handlerB = protocol.getRequestHandler('ping'); + + expect(handlerA).toBeDefined(); + expect(handlerB).toBeDefined(); + expect(handlerA).not.toBe(handlerB); + }); + + it('should return a snapshot that still works after the handler is replaced', async () => { + await protocol.connect(transport); + + protocol.setRequestHandler('ping', async () => { + return {}; + }); + + const snapshot = protocol.getRequestHandler('ping')!; + + // Replace with a different handler + protocol.setRequestHandler('ping', async () => { + return {}; + }); + + // Simulate incoming request — the snapshot is used inside the new wrapper + const calls: string[] = []; + protocol.setRequestHandler('ping', async (request, ctx) => { + calls.push('new'); + await snapshot(request, ctx); + calls.push('snapshot-called'); + return {}; + }); + + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'ping', + params: {} + }); + + await new Promise(resolve => setTimeout(resolve, 50)); + expect(calls).toEqual(['new', 'snapshot-called']); + }); + + it('should enable wrapping an existing handler and transforming results', async () => { + await protocol.connect(transport); + + const calls: string[] = []; + + protocol.setRequestHandler('ping', async () => { + calls.push('original'); + return {}; + }); + + const original = protocol.getRequestHandler('ping')!; + + protocol.setRequestHandler('ping', async (request, ctx) => { + calls.push('wrapper-before'); + const result = await original(request, ctx); + calls.push('wrapper-after'); + return result; + }); + + // Simulate incoming ping request + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'ping', + params: {} + }); + + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(calls).toEqual(['wrapper-before', 'original', 'wrapper-after']); + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ id: 1, jsonrpc: '2.0', result: {} })); + }); + + it('should propagate errors from the original handler', async () => { + await protocol.connect(transport); + + protocol.setRequestHandler('ping', async () => { + throw new ProtocolError(ProtocolErrorCode.InternalError, 'original failed'); + }); + + const original = protocol.getRequestHandler('ping')!; + + protocol.setRequestHandler('ping', async (request, ctx) => { + return original(request, ctx); + }); + + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'ping', + params: {} + }); + + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + id: 1, + jsonrpc: '2.0', + error: expect.objectContaining({ + code: ProtocolErrorCode.InternalError, + message: 'original failed' + }) + }) + ); + }); +});