diff --git a/__tests__/unserializable.test.ts b/__tests__/unserializable.test.ts new file mode 100644 index 00000000..c7754aed --- /dev/null +++ b/__tests__/unserializable.test.ts @@ -0,0 +1,193 @@ +import { beforeEach, describe, expect, test } from 'vitest'; +import { Type } from '@sinclair/typebox'; +import { + Procedure, + createServiceSchema, + Ok, + createClient, + createServer, + UNEXPECTED_DISCONNECT_CODE, +} from '../router'; +import { testMatrix } from '../testUtil/fixtures/matrix'; +import { + advanceFakeTimersBySessionGrace, + cleanupTransports, + createPostTestCleanups, +} from '../testUtil/fixtures/cleanup'; +import { TestSetupHelpers } from '../testUtil/fixtures/transports'; +import { readNextResult } from '../testUtil'; + +const ServiceSchema = createServiceSchema(); + +const UnserializableServiceSchema = ServiceSchema.define({ + returnSymbol: Procedure.rpc({ + requestInit: Type.Object({}), + responseData: Type.Object({ id: Type.String() }), + async handler() { + return Ok({ id: 'test', extra: Symbol('unserializable') }); + }, + }), + streamSymbol: Procedure.subscription({ + requestInit: Type.Object({}), + responseData: Type.Object({ id: Type.String() }), + async handler({ resWritable }) { + resWritable.write(Ok({ id: 'test', extra: Symbol('unserializable') })); + resWritable.close(); + }, + }), +}); + +describe('unserializable values in procedure handlers', () => { + // binary codec (msgpack) throws on Symbol, causing encode failure + // which kills the session -- only test with ws transport since mock + // transport's setImmediate chains conflict with fake timer flushing + describe.each(testMatrix(['ws', 'binary']))( + 'binary codec ($transport.name transport)', + ({ transport, codec }) => { + const opts = { codec: codec.codec }; + const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups(); + let getClientTransport: TestSetupHelpers['getClientTransport']; + let getServerTransport: TestSetupHelpers['getServerTransport']; + + beforeEach(async () => { + const setup = await transport.setup({ client: opts, server: opts }); + getClientTransport = setup.getClientTransport; + getServerTransport = setup.getServerTransport; + + return async () => { + await postTestCleanup(); + await setup.cleanup(); + }; + }); + + test('rpc handler returning symbol causes client disconnect', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + const services = { svc: UnserializableServiceSchema }; + createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + addPostTestCleanup(() => + cleanupTransports([clientTransport, serverTransport]), + ); + + const resultPromise = client.svc.returnSymbol.rpc({}); + await advanceFakeTimersBySessionGrace(); + + const result = await resultPromise; + expect(result).toMatchObject({ + ok: false, + payload: { + code: UNEXPECTED_DISCONNECT_CODE, + }, + }); + }); + + test('client-side encode failure cleans up listeners', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + const services = { svc: UnserializableServiceSchema }; + createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + addPostTestCleanup(() => + cleanupTransports([clientTransport, serverTransport]), + ); + + const messageListenersBefore = + clientTransport.eventDispatcher.numberOfListeners('message'); + const sessionStatusListenersBefore = + clientTransport.eventDispatcher.numberOfListeners('sessionStatus'); + + // sending a Symbol as init payload will fail encoding on the client side + expect(() => + // eslint-disable-next-line @typescript-eslint/no-unsafe-argument, @typescript-eslint/no-explicit-any + client.svc.returnSymbol.rpc({ extra: Symbol('x') } as any), + ).toThrow(); + + // listeners should not leak after the failed send + expect( + clientTransport.eventDispatcher.numberOfListeners('message'), + ).toEqual(messageListenersBefore); + expect( + clientTransport.eventDispatcher.numberOfListeners('sessionStatus'), + ).toEqual(sessionStatusListenersBefore); + }); + + test('subscription handler writing symbol causes client disconnect', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + const services = { svc: UnserializableServiceSchema }; + createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + addPostTestCleanup(() => + cleanupTransports([clientTransport, serverTransport]), + ); + + const { resReadable } = client.svc.streamSymbol.subscribe({}); + await advanceFakeTimersBySessionGrace(); + + const result = await readNextResult(resReadable); + expect(result).toMatchObject({ + ok: false, + payload: { + code: UNEXPECTED_DISCONNECT_CODE, + }, + }); + }); + }, + ); + + // json codec silently drops Symbol values via JSON.stringify + describe.each(testMatrix(['all', 'naive']))( + 'json codec ($transport.name transport)', + ({ transport, codec }) => { + const opts = { codec: codec.codec }; + const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups(); + let getClientTransport: TestSetupHelpers['getClientTransport']; + let getServerTransport: TestSetupHelpers['getServerTransport']; + + beforeEach(async () => { + const setup = await transport.setup({ client: opts, server: opts }); + getClientTransport = setup.getClientTransport; + getServerTransport = setup.getServerTransport; + + return async () => { + await postTestCleanup(); + await setup.cleanup(); + }; + }); + + test('rpc handler returning symbol silently drops the value', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + const services = { svc: UnserializableServiceSchema }; + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + addPostTestCleanup(() => + cleanupTransports([clientTransport, serverTransport]), + ); + + const result = await client.svc.returnSymbol.rpc({}); + // JSON.stringify silently drops Symbol values, so the + // response arrives with the extra symbol field missing + expect(result).toStrictEqual({ + ok: true, + payload: { id: 'test' }, + }); + + await server.close(); + }); + }, + ); +}); diff --git a/package-lock.json b/package-lock.json index e590b98b..ace3e0bd 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@replit/river", - "version": "0.214.0", + "version": "0.215.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@replit/river", - "version": "0.214.0", + "version": "0.215.0", "license": "MIT", "dependencies": { "@msgpack/msgpack": "^3.1.2", diff --git a/package.json b/package.json index cdca6edd..e1d6b519 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@replit/river", "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!", - "version": "0.214.0", + "version": "0.215.0", "type": "module", "exports": { ".": { diff --git a/router/client.ts b/router/client.ts index 798974bc..818a02bc 100644 --- a/router/client.ts +++ b/router/client.ts @@ -508,16 +508,21 @@ function handleProc( transport.addEventListener('message', onMessage); transport.addEventListener('sessionStatus', onSessionStatus); - sessionScopedSend({ - streamId, - serviceName, - procedureName, - tracing: getPropagationContext(ctx), - payload: init, - controlFlags: procClosesWithInit - ? ControlFlags.StreamOpenBit | ControlFlags.StreamClosedBit - : ControlFlags.StreamOpenBit, - }); + try { + sessionScopedSend({ + streamId, + serviceName, + procedureName, + tracing: getPropagationContext(ctx), + payload: init, + controlFlags: procClosesWithInit + ? ControlFlags.StreamOpenBit | ControlFlags.StreamClosedBit + : ControlFlags.StreamOpenBit, + }); + } catch (e) { + cleanup(); + throw e; + } if (procClosesWithInit) { reqWritable.close(); diff --git a/testUtil/fixtures/cleanup.ts b/testUtil/fixtures/cleanup.ts index 2ac0ed43..43cf38eb 100644 --- a/testUtil/fixtures/cleanup.ts +++ b/testUtil/fixtures/cleanup.ts @@ -1,8 +1,7 @@ -import { expect, vi } from 'vitest'; +import { assert, expect, vi } from 'vitest'; import { ClientTransport, Connection, - OpaqueTransportMessage, ServerTransport, Transport, } from '../../transport'; @@ -68,14 +67,17 @@ export async function ensureTransportBuffersAreEventuallyEmpty( [...t.sessions] .map(([client, sess]) => { // get all messages that are not heartbeats - const buff = sess.sendBuffer.filter((msg) => { - return !Value.Check(ControlMessageAckSchema, msg.payload); + const buff = sess.sendBuffer.filter((encodedMsg) => { + const decoded = sess.codec.fromBuffer(encodedMsg.data); + assert(decoded.ok); + + return !Value.Check( + ControlMessageAckSchema, + decoded.value.payload, + ); }); - return [client, buff] as [ - string, - ReadonlyArray, - ]; + return [client, buff] as const; }) .filter((entry) => entry[1].length > 0), ), diff --git a/testUtil/index.ts b/testUtil/index.ts index dc3dee0b..ae9a4330 100644 --- a/testUtil/index.ts +++ b/testUtil/index.ts @@ -184,6 +184,9 @@ export function dummySession() { onSessionGracePeriodElapsed: () => { /* noop */ }, + onMessageSendFailure: () => { + /* noop */ + }, }, testingSessionOptions, currentProtocolVersion, diff --git a/transport/client.ts b/transport/client.ts index 5927e4f9..0d5d876c 100644 --- a/transport/client.ts +++ b/transport/client.ts @@ -107,6 +107,18 @@ export abstract class ClientTransport< onSessionGracePeriodElapsed: () => { this.onSessionGracePeriodElapsed(session); }, + onMessageSendFailure: (msg, reason) => { + this.log?.error(`failed to send message: ${reason}`, { + ...session.loggingMetadata, + transportMessage: msg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: reason, + }); + this.deleteSession(session, { unhealthy: true }); + }, }, this.options, currentProtocolVersion, @@ -186,6 +198,18 @@ export abstract class ClientTransport< onSessionGracePeriodElapsed: () => { this.onSessionGracePeriodElapsed(handshakingSession); }, + onMessageSendFailure: (msg, reason) => { + this.log?.error(`failed to send message: ${reason}`, { + ...handshakingSession.loggingMetadata, + transportMessage: msg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: reason, + }); + this.deleteSession(handshakingSession, { unhealthy: true }); + }, }, ); @@ -395,6 +419,18 @@ export abstract class ClientTransport< onSessionGracePeriodElapsed: () => { this.onSessionGracePeriodElapsed(backingOffSession); }, + onMessageSendFailure: (msg, reason) => { + this.log?.error(`failed to send message: ${reason}`, { + ...backingOffSession.loggingMetadata, + transportMessage: msg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: reason, + }); + this.deleteSession(backingOffSession, { unhealthy: true }); + }, }, ); @@ -470,6 +506,18 @@ export abstract class ClientTransport< onSessionGracePeriodElapsed: () => { this.onSessionGracePeriodElapsed(connectingSession); }, + onMessageSendFailure: (msg, reason) => { + this.log?.error(`failed to send message: ${reason}`, { + ...connectingSession.loggingMetadata, + transportMessage: msg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: reason, + }); + this.deleteSession(connectingSession, { unhealthy: true }); + }, }, ); diff --git a/transport/message.ts b/transport/message.ts index 1879341d..40861281 100644 --- a/transport/message.ts +++ b/transport/message.ts @@ -280,6 +280,18 @@ export function cancelMessage( export type OpaqueTransportMessage = TransportMessage; export type TransportClientId = string; +/** + * An encoded message that is ready to be sent over the transport. + * The seq number is kept to track which messages have been + * acked by the peer and can be dropped from the send buffer. + */ +export interface EncodedTransportMessage { + id: string; + seq: number; + msg: PartialTransportMessage; + data: Uint8Array; +} + /** * Checks if the given control flag (usually found in msg.controlFlag) is an ack message. * @param controlFlag - The control flag to check. diff --git a/transport/results.ts b/transport/results.ts index ff8c12da..df9879b9 100644 --- a/transport/results.ts +++ b/transport/results.ts @@ -1,4 +1,4 @@ -import { OpaqueTransportMessage } from './message'; +import { EncodedTransportMessage, OpaqueTransportMessage } from './message'; // internal use only, not to be used in public API type SessionApiResult = @@ -13,5 +13,6 @@ type SessionApiResult = export type SendResult = SessionApiResult; export type SendBufferResult = SessionApiResult; +export type EncodeResult = SessionApiResult; export type SerializeResult = SessionApiResult; export type DeserializeResult = SessionApiResult; diff --git a/transport/server.ts b/transport/server.ts index 8a7f8554..78cf8042 100644 --- a/transport/server.ts +++ b/transport/server.ts @@ -408,6 +408,18 @@ export abstract class ServerTransport< onSessionGracePeriodElapsed: () => { this.onSessionGracePeriodElapsed(noConnectionSession); }, + onMessageSendFailure: (msg, reason) => { + this.log?.error(`failed to send message: ${reason}`, { + ...noConnectionSession.loggingMetadata, + transportMessage: msg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: reason, + }); + this.deleteSession(noConnectionSession, { unhealthy: true }); + }, }, ); diff --git a/transport/sessionStateMachine/SessionConnected.ts b/transport/sessionStateMachine/SessionConnected.ts index 747c02cb..1316ca5f 100644 --- a/transport/sessionStateMachine/SessionConnected.ts +++ b/transport/sessionStateMachine/SessionConnected.ts @@ -2,26 +2,25 @@ import { Static } from '@sinclair/typebox'; import { ControlFlags, ControlMessageAckSchema, + EncodedTransportMessage, OpaqueTransportMessage, PartialTransportMessage, - TransportMessage, isAck, } from '../message'; import { IdentifiedSession, + IdentifiedSessionListeners, IdentifiedSessionProps, - sendMessage, SessionState, } from './common'; import { Connection } from '../connection'; import { SpanStatusCode } from '@opentelemetry/api'; import { SendBufferResult, SendResult } from '../results'; -export interface SessionConnectedListeners { +export interface SessionConnectedListeners extends IdentifiedSessionListeners { onConnectionErrored: (err: unknown) => void; onConnectionClosed: () => void; onMessage: (msg: OpaqueTransportMessage) => void; - onMessageSendFailure: (msg: PartialTransportMessage, reason: string) => void; onInvalidMessage: (reason: string) => void; } @@ -57,12 +56,11 @@ export class SessionConnected< this.startMissingHeartbeatTimeout(); } - private assertSendOrdering(constructedMsg: TransportMessage) { - if (constructedMsg.seq > this.seqSent + 1) { - const msg = `invariant violation: would have sent out of order msg (seq: ${constructedMsg.seq}, expected: ${this.seqSent} + 1)`; + private assertSendOrdering(encodedMsg: EncodedTransportMessage) { + if (encodedMsg.seq > this.seqSent + 1) { + const msg = `invariant violation: would have sent out of order msg (seq: ${encodedMsg.seq}, expected: ${this.seqSent} + 1)`; this.log?.error(msg, { ...this.loggingMetadata, - transportMessage: constructedMsg, tags: ['invariant-violation'], }); @@ -71,19 +69,29 @@ export class SessionConnected< } send(msg: PartialTransportMessage): SendResult { - const constructedMsg = this.constructMsg(msg); - this.assertSendOrdering(constructedMsg); - this.sendBuffer.push(constructedMsg); - const res = sendMessage(this.conn, this.codec, constructedMsg); - if (!res.ok) { - this.listeners.onMessageSendFailure(constructedMsg, res.reason); - - return res; + const encodeResult = this.encodeMsg(msg); + if (!encodeResult.ok) { + return encodeResult; } - this.seqSent = constructedMsg.seq; + const encodedMsg = encodeResult.value; + this.assertSendOrdering(encodedMsg); + this.sendBuffer.push(encodedMsg); - return res; + const sent = this.conn.send(encodedMsg.data); + if (!sent) { + const reason = 'failed to send message'; + this.listeners.onMessageSendFailure( + { ...encodedMsg.msg, seq: encodedMsg.seq }, + reason, + ); + + return { ok: false, reason }; + } + + this.seqSent = encodedMsg.seq; + + return { ok: true, value: encodedMsg.id }; } constructor(props: SessionConnectedProps) { @@ -110,11 +118,16 @@ export class SessionConnected< for (const msg of this.sendBuffer) { this.assertSendOrdering(msg); - const res = sendMessage(this.conn, this.codec, msg); - if (!res.ok) { - this.listeners.onMessageSendFailure(msg, res.reason); - return res; + const sent = this.conn.send(msg.data); + if (!sent) { + const reason = 'failed to send buffered message'; + this.listeners.onMessageSendFailure( + { ...msg.msg, seq: msg.seq }, + reason, + ); + + return { ok: false, reason }; } this.seqSent = msg.seq; diff --git a/transport/sessionStateMachine/SessionHandshaking.ts b/transport/sessionStateMachine/SessionHandshaking.ts index b7c14a02..50359952 100644 --- a/transport/sessionStateMachine/SessionHandshaking.ts +++ b/transport/sessionStateMachine/SessionHandshaking.ts @@ -9,7 +9,6 @@ import { IdentifiedSessionWithGracePeriod, IdentifiedSessionWithGracePeriodListeners, IdentifiedSessionWithGracePeriodProps, - sendMessage, SessionState, } from './common'; import { SendResult } from '../results'; @@ -83,7 +82,23 @@ export class SessionHandshaking< }; sendHandshake(msg: TransportMessage): SendResult { - return sendMessage(this.conn, this.codec, msg); + const buff = this.codec.toBuffer(msg); + if (!buff.ok) { + return buff; + } + + const sent = this.conn.send(buff.value); + if (!sent) { + return { + ok: false, + reason: 'failed to send handshake', + }; + } + + return { + ok: true, + value: msg.id, + }; } _handleStateExit(): void { diff --git a/transport/sessionStateMachine/SessionWaitingForHandshake.ts b/transport/sessionStateMachine/SessionWaitingForHandshake.ts index d88685de..6a8107d9 100644 --- a/transport/sessionStateMachine/SessionWaitingForHandshake.ts +++ b/transport/sessionStateMachine/SessionWaitingForHandshake.ts @@ -5,12 +5,7 @@ import { OpaqueTransportMessage, TransportMessage, } from '../message'; -import { - CommonSession, - CommonSessionProps, - sendMessage, - SessionState, -} from './common'; +import { CommonSession, CommonSessionProps, SessionState } from './common'; import { SendResult } from '../results'; export interface SessionWaitingForHandshakeListeners { @@ -84,7 +79,23 @@ export class SessionWaitingForHandshake< }; sendHandshake(msg: TransportMessage): SendResult { - return sendMessage(this.conn, this.codec, msg); + const buff = this.codec.toBuffer(msg); + if (!buff.ok) { + return buff; + } + + const sent = this.conn.send(buff.value); + if (!sent) { + return { + ok: false, + reason: 'failed to send handshake', + }; + } + + return { + ok: true, + value: msg.id, + }; } _handleStateExit(): void { diff --git a/transport/sessionStateMachine/common.ts b/transport/sessionStateMachine/common.ts index 2db1499b..e4b46e23 100644 --- a/transport/sessionStateMachine/common.ts +++ b/transport/sessionStateMachine/common.ts @@ -1,17 +1,15 @@ import { Logger, MessageMetadata } from '../../logging'; import { TelemetryInfo } from '../../tracing'; import { - OpaqueTransportMessage, + EncodedTransportMessage, PartialTransportMessage, ProtocolVersion, TransportClientId, - TransportMessage, } from '../message'; import { Codec, CodecMessageAdapter } from '../../codec'; import { generateId } from '../id'; import { Tracer } from '@opentelemetry/api'; -import { SendResult } from '../results'; -import { Connection } from '../connection'; +import { EncodeResult, SendResult } from '../results'; export const enum SessionState { NoConnection = 'NoConnection', @@ -174,11 +172,26 @@ export abstract class CommonSession extends StateMachineState { export type InheritedProperties = Pick< IdentifiedSession, - 'id' | 'from' | 'to' | 'seq' | 'ack' | 'sendBuffer' | 'telemetry' | 'options' + | 'id' + | 'from' + | 'to' + | 'seq' + | 'ack' + | 'seqSent' + | 'sendBuffer' + | 'telemetry' + | 'options' >; export type SessionId = string; +export interface IdentifiedSessionListeners { + onMessageSendFailure: ( + msg: PartialTransportMessage & { seq: number }, + reason: string, + ) => void; +} + // all sessions where we know the other side's client id export interface IdentifiedSessionProps extends CommonSessionProps { id: SessionId; @@ -186,9 +199,10 @@ export interface IdentifiedSessionProps extends CommonSessionProps { seq: number; ack: number; seqSent: number; - sendBuffer: Array; + sendBuffer: Array; telemetry: TelemetryInfo; protocolVersion: ProtocolVersion; + listeners: IdentifiedSessionListeners; } export abstract class IdentifiedSession extends CommonSession { @@ -196,6 +210,7 @@ export abstract class IdentifiedSession extends CommonSession { readonly telemetry: TelemetryInfo; readonly to: TransportClientId; readonly protocolVersion: ProtocolVersion; + listeners: IdentifiedSessionListeners; /** * Index of the message we will send next (excluding handshake) @@ -211,7 +226,7 @@ export abstract class IdentifiedSession extends CommonSession { * Number of unique messages we've received this session (excluding handshake) */ ack: number; - sendBuffer: Array; + sendBuffer: Array; constructor(props: IdentifiedSessionProps) { const { @@ -224,6 +239,7 @@ export abstract class IdentifiedSession extends CommonSession { log, protocolVersion, seqSent: messagesSent, + listeners, } = props; super(props); this.id = id; @@ -235,6 +251,7 @@ export abstract class IdentifiedSession extends CommonSession { this.log = log; this.protocolVersion = protocolVersion; this.seqSent = messagesSent; + this.listeners = listeners; } get loggingMetadata(): MessageMetadata { @@ -255,9 +272,7 @@ export abstract class IdentifiedSession extends CommonSession { return metadata; } - constructMsg( - partialMsg: PartialTransportMessage, - ): TransportMessage { + encodeMsg(partialMsg: PartialTransportMessage): EncodeResult { const msg = { ...partialMsg, id: generateId(), @@ -267,9 +282,29 @@ export abstract class IdentifiedSession extends CommonSession { ack: this.ack, }; + const encoded = this.codec.toBuffer(msg); + if (!encoded.ok) { + // safety: onMessageSendFailure tears down the session via protocol error, + // which emits sessionStatus 'closing' and cleans up all procedure listeners. + this.listeners.onMessageSendFailure( + { ...partialMsg, seq: this.seq }, + encoded.reason, + ); + + return encoded; + } + this.seq++; - return msg; + return { + ok: true, + value: { + id: msg.id, + seq: msg.seq, + msg: partialMsg, + data: encoded.value, + }, + }; } nextSeq(): number { @@ -277,12 +312,16 @@ export abstract class IdentifiedSession extends CommonSession { } send(msg: PartialTransportMessage): SendResult { - const constructedMsg = this.constructMsg(msg); - this.sendBuffer.push(constructedMsg); + const encodeResult = this.encodeMsg(msg); + if (!encodeResult.ok) { + return encodeResult; + } + + this.sendBuffer.push(encodeResult.value); return { ok: true, - value: constructedMsg.id, + value: encodeResult.value.id, }; } @@ -297,7 +336,8 @@ export abstract class IdentifiedSession extends CommonSession { } } -export interface IdentifiedSessionWithGracePeriodListeners { +export interface IdentifiedSessionWithGracePeriodListeners + extends IdentifiedSessionListeners { onSessionGracePeriodElapsed: () => void; } @@ -336,27 +376,3 @@ export abstract class IdentifiedSessionWithGracePeriod extends IdentifiedSession super._handleClose(); } } - -export function sendMessage( - conn: Connection, - codec: CodecMessageAdapter, - msg: TransportMessage, -): SendResult { - const buff = codec.toBuffer(msg); - if (!buff.ok) { - return buff; - } - - const sent = conn.send(buff.value); - if (!sent) { - return { - ok: false, - reason: 'failed to send message', - }; - } - - return { - ok: true, - value: msg.id, - }; -} diff --git a/transport/sessionStateMachine/index.ts b/transport/sessionStateMachine/index.ts index 9d75122a..03830f81 100644 --- a/transport/sessionStateMachine/index.ts +++ b/transport/sessionStateMachine/index.ts @@ -1,7 +1,10 @@ export { SessionState } from './common'; export { type SessionWaitingForHandshake } from './SessionWaitingForHandshake'; export { type SessionConnecting } from './SessionConnecting'; -export { type SessionNoConnection } from './SessionNoConnection'; +export { + type SessionNoConnection, + type SessionNoConnectionListeners, +} from './SessionNoConnection'; export { type SessionHandshaking } from './SessionHandshaking'; export { type SessionConnected } from './SessionConnected'; export { diff --git a/transport/sessionStateMachine/stateMachine.test.ts b/transport/sessionStateMachine/stateMachine.test.ts index 725cd8e4..0d825daa 100644 --- a/transport/sessionStateMachine/stateMachine.test.ts +++ b/transport/sessionStateMachine/stateMachine.test.ts @@ -1,4 +1,4 @@ -import { describe, expect, test, vi } from 'vitest'; +import { assert, describe, expect, test, vi } from 'vitest'; import { payloadToTransportMessage, testingSessionOptions, @@ -99,6 +99,7 @@ function getPendingMockConnection(): PendingMockConnectionHandle { function createSessionNoConnectionListeners(): SessionNoConnectionListeners { return { onSessionGracePeriodElapsed: vi.fn(), + onMessageSendFailure: vi.fn(), }; } @@ -106,6 +107,7 @@ function createSessionBackingOffListeners(): SessionBackingOffListeners { return { onBackoffFinished: vi.fn(), onSessionGracePeriodElapsed: vi.fn(), + onMessageSendFailure: vi.fn(), }; } @@ -115,6 +117,7 @@ function createSessionConnectingListeners(): SessionConnectingListeners { onConnectionFailed: vi.fn(), onConnectionTimeout: vi.fn(), onSessionGracePeriodElapsed: vi.fn(), + onMessageSendFailure: vi.fn(), }; } @@ -126,6 +129,7 @@ function createSessionHandshakingListeners(): SessionHandshakingListeners { onConnectionErrored: vi.fn(), onHandshakeTimeout: vi.fn(), onSessionGracePeriodElapsed: vi.fn(), + onMessageSendFailure: vi.fn(), }; } @@ -1807,6 +1811,103 @@ describe('session state machine', () => { }); }); + test('handshaking sendHandshake: codec failure does not corrupt seq and subsequent success works', async () => { + const sessionHandle = await createSessionHandshaking(); + const session = sessionHandle.session; + + // buffer some messages during handshake + session.send(payloadToTransportMessage('hello')); + session.send(payloadToTransportMessage('world')); + expect(session.seq).toBe(2); + expect(session.ack).toBe(0); + expect(session.sendBuffer.length).toBe(2); + + const msg = handshakeRequestMessage({ + from: 'from', + to: 'to', + sessionId: 'clientSessionId', + expectedSessionState: { + nextExpectedSeq: 0, + nextSentSeq: 0, + }, + }); + + // make codec.toBuffer fail + const spy = vi + .spyOn(session.codec, 'toBuffer') + .mockReturnValue({ ok: false, reason: 'encode error' }); + + const res = session.sendHandshake(msg); + expect(res.ok).toBe(false); + assert(!res.ok); + expect(res.reason).toBe('encode error'); + expect(session.conn.send).not.toHaveBeenCalled(); + + // seq/ack/sendBuffer should be unchanged + expect(session.seq).toBe(2); + expect(session.ack).toBe(0); + expect(session.sendBuffer.length).toBe(2); + + // restore codec and retry handshake + spy.mockRestore(); + const retryRes = session.sendHandshake(msg); + expect(retryRes.ok).toBe(true); + + // transition to connected and verify messages work + const connectedListeners = createSessionConnectedListeners(); + const connected = SessionStateGraph.transition.HandshakingToConnected( + session, + connectedListeners, + ); + + expect(connected.state).toBe(SessionState.Connected); + expect(connected.seq).toBe(2); + expect(connected.ack).toBe(0); + + // flush buffered messages first + const bufferRes = connected.sendBufferedMessages(); + expect(bufferRes.ok).toBe(true); + + // send a new message in connected state + const sendRes = connected.send(payloadToTransportMessage('after')); + expect(sendRes.ok).toBe(true); + expect(connected.seq).toBe(3); + // 1 handshake retry + 2 buffered + 1 new = 4 + expect(connected.conn.send).toHaveBeenCalledTimes(4); + }); + + test('pending identification sendHandshake: codec failure does not prevent subsequent success', () => { + const sessionHandle = createSessionWaitingForHandshake(); + const session = sessionHandle.session; + + const msg = handshakeRequestMessage({ + from: 'from', + to: 'to', + sessionId: 'clientSessionId', + expectedSessionState: { + nextExpectedSeq: 0, + nextSentSeq: 0, + }, + }); + + // make codec.toBuffer fail + const spy = vi + .spyOn(session.codec, 'toBuffer') + .mockReturnValue({ ok: false, reason: 'encode error' }); + + const res = session.sendHandshake(msg); + expect(res.ok).toBe(false); + assert(!res.ok); + expect(res.reason).toBe('encode error'); + expect(session.conn.send).not.toHaveBeenCalled(); + + // restore codec and retry handshake + spy.mockRestore(); + const retryRes = session.sendHandshake(msg); + expect(retryRes.ok).toBe(true); + expect(session.conn.send).toHaveBeenCalledTimes(1); + }); + test('connected event listeners: connectionErrored', async () => { const sessionHandle = await createSessionConnected(); const session = sessionHandle.session; @@ -1866,8 +1967,11 @@ describe('session state machine', () => { expect(onConnectionClosed).not.toHaveBeenCalled(); expect(onConnectionErrored).not.toHaveBeenCalled(); - const msg = session.constructMsg(payloadToTransportMessage('hello')); - session.conn.emitData(session.options.codec.toBuffer(msg)); + const encodeResult = session.encodeMsg( + payloadToTransportMessage('hello'), + ); + assert(encodeResult.ok); + session.conn.emitData(encodeResult.value.data); await waitFor(async () => { expect(onMessage).toHaveBeenCalledTimes(1); diff --git a/transport/sessionStateMachine/transitions.ts b/transport/sessionStateMachine/transitions.ts index b14a9c34..b85745ee 100644 --- a/transport/sessionStateMachine/transitions.ts +++ b/transport/sessionStateMachine/transitions.ts @@ -1,4 +1,4 @@ -import { OpaqueTransportMessage, TransportClientId } from '..'; +import { TransportClientId } from '..'; import { SessionConnecting, SessionConnectingListeners, @@ -38,13 +38,13 @@ import { SessionBackingOff, SessionBackingOffListeners, } from './SessionBackingOff'; -import { ProtocolVersion } from '../message'; +import { EncodedTransportMessage, ProtocolVersion } from '../message'; import { Tracer } from '@opentelemetry/api'; import { CodecMessageAdapter } from '../../codec'; function inheritSharedSession( session: IdentifiedSession, -): IdentifiedSessionProps { +): Omit { return { id: session.id, from: session.from, @@ -84,7 +84,7 @@ export const SessionStateGraph = { ) => { const id = `session-${generateId()}`; const telemetry = createSessionTelemetryInfo(tracer, id, to, from); - const sendBuffer: Array = []; + const sendBuffer: Array = []; const session = new SessionNoConnection({ listeners, @@ -255,7 +255,7 @@ export const SessionStateGraph = { ): SessionConnected => { const conn = pendingSession.conn; const { from, options } = pendingSession; - const carriedState: IdentifiedSessionProps = oldSession + const carriedState: Omit = oldSession ? // old session exists, inherit state inheritSharedSession(oldSession) : // old session does not exist, create new state @@ -279,7 +279,7 @@ export const SessionStateGraph = { log: pendingSession.log, protocolVersion, codec: new CodecMessageAdapter(options.codec), - } satisfies IdentifiedSessionProps); + } satisfies Omit); pendingSession._handleStateExit(); oldSession?._handleStateExit(); diff --git a/transport/transport.ts b/transport/transport.ts index 1dd944fe..85e80083 100644 --- a/transport/transport.ts +++ b/transport/transport.ts @@ -10,7 +10,13 @@ import { LoggingLevel, createLogProxy, } from '../logging/log'; -import { EventDispatcher, EventHandler, EventMap, EventTypes } from './events'; +import { + EventDispatcher, + EventHandler, + EventMap, + EventTypes, + ProtocolError, +} from './events'; import { ProvidedTransportOptions, TransportOptions, @@ -21,6 +27,7 @@ import { SessionConnecting, SessionHandshaking, SessionNoConnection, + SessionNoConnectionListeners, SessionState, } from './sessionStateMachine'; import { Connection } from './connection'; @@ -277,6 +284,18 @@ export abstract class Transport { onSessionGracePeriodElapsed: () => { this.onSessionGracePeriodElapsed(noConnectionSession); }, + onMessageSendFailure: (msg, reason) => { + this.log?.error(`failed to send message: ${reason}`, { + ...noConnectionSession.loggingMetadata, + transportMessage: msg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: reason, + }); + this.deleteSession(noConnectionSession, { unhealthy: true }); + }, }); this.updateSession(noConnectionSession); @@ -289,20 +308,36 @@ export abstract class Transport { ): SessionNoConnection { // transition to no connection let noConnectionSession: SessionNoConnection; + const listeners: SessionNoConnectionListeners = { + onSessionGracePeriodElapsed: () => { + this.onSessionGracePeriodElapsed(noConnectionSession); + }, + onMessageSendFailure: (msg, reason) => { + this.log?.error(`failed to send message: ${reason}`, { + ...noConnectionSession.loggingMetadata, + transportMessage: msg, + }); + + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: reason, + }); + this.deleteSession(noConnectionSession, { unhealthy: true }); + }, + }; + if (session.state === SessionState.Handshaking) { noConnectionSession = - SessionStateGraph.transition.HandshakingToNoConnection(session, { - onSessionGracePeriodElapsed: () => { - this.onSessionGracePeriodElapsed(noConnectionSession); - }, - }); + SessionStateGraph.transition.HandshakingToNoConnection( + session, + listeners, + ); } else { noConnectionSession = - SessionStateGraph.transition.ConnectedToNoConnection(session, { - onSessionGracePeriodElapsed: () => { - this.onSessionGracePeriodElapsed(noConnectionSession); - }, - }); + SessionStateGraph.transition.ConnectedToNoConnection( + session, + listeners, + ); } this.updateSession(noConnectionSession);