From d7b8577bbe88e9ee5284691c6e5255d39d9cf3e4 Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Wed, 8 Apr 2026 17:17:44 -0700 Subject: [PATCH 1/3] Update C# sdk to use v3 websocket api --- sdks/csharp/src/CompressionHelpers.cs | 20 +- sdks/csharp/src/Plugins/WebSocket.jslib | 14 +- .../src/SpacetimeDB/ClientApi/ClientFrame.cs | 10 + .../src/SpacetimeDB/ClientApi/ServerFrame.cs | 10 + sdks/csharp/src/SpacetimeDBClient.cs | 46 +++-- sdks/csharp/src/WebSocket.cs | 191 +++++++++++++++--- sdks/csharp/src/WebSocketProtocols.cs | 26 +++ sdks/csharp/src/WebSocketV3Frames.cs | 101 +++++++++ sdks/csharp/tests~/SnapshotTests.cs | 66 ++++++ sdks/csharp/tests~/Tests.cs | 44 +++- 10 files changed, 478 insertions(+), 50 deletions(-) create mode 100644 sdks/csharp/src/SpacetimeDB/ClientApi/ClientFrame.cs create mode 100644 sdks/csharp/src/SpacetimeDB/ClientApi/ServerFrame.cs create mode 100644 sdks/csharp/src/WebSocketProtocols.cs create mode 100644 sdks/csharp/src/WebSocketV3Frames.cs diff --git a/sdks/csharp/src/CompressionHelpers.cs b/sdks/csharp/src/CompressionHelpers.cs index 832208938ed..cb37c5e25c6 100644 --- a/sdks/csharp/src/CompressionHelpers.cs +++ b/sdks/csharp/src/CompressionHelpers.cs @@ -49,13 +49,12 @@ internal static GZipStream GzipReader(Stream stream) /// /// The compressed and encoded server message as a byte array. /// The deserialized object. - internal static ServerMessage DecompressDecodeMessage(byte[] bytes) + internal static byte[] DecompressMessagePayload(byte[] bytes) { using var stream = new MemoryStream(bytes); // The stream will never be empty. It will at least contain the compression algo. var compression = (CompressionAlgos)stream.ReadByte(); - // Conditionally decompress and decode. Stream decompressedStream = compression switch { CompressionAlgos.None => stream, @@ -67,10 +66,21 @@ internal static ServerMessage DecompressDecodeMessage(byte[] bytes) // TODO: consider pooling these. // DO NOT TRY TO TAKE THIS OUT. The BrotliStream ReadByte() implementation allocates an array // PER BYTE READ. You have to do it all at once to avoid that problem. - MemoryStream memoryStream = new MemoryStream(); + using var memoryStream = new MemoryStream(); decompressedStream.CopyTo(memoryStream); - memoryStream.Seek(0, SeekOrigin.Begin); - return new ServerMessage.BSATN().Read(new BinaryReader(memoryStream)); + return memoryStream.ToArray(); + } + + internal static ServerMessage DecodeServerMessage(byte[] bytes) + { + using var stream = new MemoryStream(bytes); + using var reader = new BinaryReader(stream); + return new ServerMessage.BSATN().Read(reader); + } + + internal static ServerMessage DecompressDecodeMessage(byte[] bytes) + { + return DecodeServerMessage(DecompressMessagePayload(bytes)); } /// diff --git a/sdks/csharp/src/Plugins/WebSocket.jslib b/sdks/csharp/src/Plugins/WebSocket.jslib index d2427954bb8..34d62dd6613 100644 --- a/sdks/csharp/src/Plugins/WebSocket.jslib +++ b/sdks/csharp/src/Plugins/WebSocket.jslib @@ -24,6 +24,9 @@ mergeInto(LibraryManager.library, { var host = UTF8ToString(baseUriPtr); var uri = UTF8ToString(uriPtr); var protocol = UTF8ToString(protocolPtr); + // The C# WebGL bridge can only pass one string argument here, so + // multiple offered subprotocols are marshalled as a comma-separated string. + var offeredProtocols = protocol.indexOf(',') === -1 ? protocol : protocol.split(','); var authToken = UTF8ToString(authTokenPtr); if (authToken) { @@ -46,7 +49,7 @@ mergeInto(LibraryManager.library, { } } - var socket = new window.WebSocket(uri, protocol); + var socket = new window.WebSocket(uri, offeredProtocols); socket.binaryType = "arraybuffer"; var socketId = manager.nextId++; @@ -54,7 +57,12 @@ mergeInto(LibraryManager.library, { socket.onopen = function() { if (manager.callbacks.open) { - dynCall('vi', manager.callbacks.open, [socketId]); + var protocolStr = socket.protocol || ""; + var protocolArray = intArrayFromString(protocolStr); + var protocolPtr = _malloc(protocolArray.length); + HEAP8.set(protocolArray, protocolPtr); + dynCall('vii', manager.callbacks.open, [socketId, protocolPtr]); + _free(protocolPtr); } }; @@ -115,4 +123,4 @@ mergeInto(LibraryManager.library, { socket.close(code, reason); delete manager.instances[socketId]; } -}); \ No newline at end of file +}); diff --git a/sdks/csharp/src/SpacetimeDB/ClientApi/ClientFrame.cs b/sdks/csharp/src/SpacetimeDB/ClientApi/ClientFrame.cs new file mode 100644 index 00000000000..7cb7980a3f2 --- /dev/null +++ b/sdks/csharp/src/SpacetimeDB/ClientApi/ClientFrame.cs @@ -0,0 +1,10 @@ +#nullable enable + +namespace SpacetimeDB.ClientApi +{ + [SpacetimeDB.Type] + internal partial record ClientFrame : SpacetimeDB.TaggedEnum<( + byte[] Single, + byte[][] Batch + )>; +} diff --git a/sdks/csharp/src/SpacetimeDB/ClientApi/ServerFrame.cs b/sdks/csharp/src/SpacetimeDB/ClientApi/ServerFrame.cs new file mode 100644 index 00000000000..9660ac589ef --- /dev/null +++ b/sdks/csharp/src/SpacetimeDB/ClientApi/ServerFrame.cs @@ -0,0 +1,10 @@ +#nullable enable + +namespace SpacetimeDB.ClientApi +{ + [SpacetimeDB.Type] + internal partial record ServerFrame : SpacetimeDB.TaggedEnum<( + byte[] Single, + byte[][] Batch + )>; +} diff --git a/sdks/csharp/src/SpacetimeDBClient.cs b/sdks/csharp/src/SpacetimeDBClient.cs index ef202ed168b..639ad9f1d54 100644 --- a/sdks/csharp/src/SpacetimeDBClient.cs +++ b/sdks/csharp/src/SpacetimeDBClient.cs @@ -168,6 +168,8 @@ public abstract class DbConnectionBase : IDbConne protected abstract IErrorContext ToErrorContext(Exception errorContext); protected abstract IProcedureEventContext ToProcedureEventContext(ProcedureEvent procedureEvent); + private Func decodeTransportMessages = DecodeV2TransportMessages; + private readonly ConcurrentDictionary> waitingOneOffQueries = new(); private readonly ConcurrentDictionary pendingReducerCalls = new(); @@ -219,10 +221,16 @@ protected DbConnectionBase() { var options = new WebSocket.ConnectOptions { - Protocol = "v2.bsatn.spacetimedb" + Protocols = WebSocketProtocols.Preferred }; webSocket = new WebSocket(options); webSocket.OnMessage += OnMessageReceived; + webSocket.OnProtocolNegotiated += protocolVersion => + { + decodeTransportMessages = protocolVersion == WebSocketProtocolVersion.V3 + ? WebSocketV3Frames.DecodeServerMessages + : DecodeV2TransportMessages; + }; webSocket.OnSendError += a => onSendError?.Invoke(a); #if UNITY_5_3_OR_NEWER webSocket.OnClose += (e) => @@ -289,6 +297,8 @@ internal struct ParsedMessage private static readonly Status Committed = new Status.Committed(default); + private static byte[][] DecodeV2TransportMessages(byte[] payload) => new[] { payload }; + /// /// Get a description of a message suitable for storing in the tracker metadata. /// @@ -427,9 +437,18 @@ void ParseOneOffQuery(OneOffQueryResult resp) #endif try { - var message = _parseQueue.Take(_parseCancellationToken); - var parsedMessage = ParseMessage(message); - _applyQueue.Add(parsedMessage, _parseCancellationToken); + var unparsed = _parseQueue.Take(_parseCancellationToken); + var payload = CompressionHelpers.DecompressMessagePayload(unparsed.bytes); + var decodedMessages = decodeTransportMessages(payload); + stats.ParseMessageQueueTracker.FinishTrackingRequest( + unparsed.parseQueueTrackerId, + $"type=ws_frame,count={decodedMessages.Length}" + ); + foreach (var messageBytes in decodedMessages) + { + var parsedMessage = ParseMessage(messageBytes, unparsed.timestamp); + _applyQueue.Add(parsedMessage, _parseCancellationToken); + } } catch (OperationCanceledException) { @@ -452,13 +471,11 @@ void ParseOneOffQuery(OneOffQueryResult resp) } } - ParsedMessage ParseMessage(UnparsedMessage unparsed) + ParsedMessage ParseMessage(byte[] messageBytes, DateTime timestamp) { var dbOps = ParsedDatabaseUpdate.New(); - var message = CompressionHelpers.DecompressDecodeMessage(unparsed.bytes); + var message = CompressionHelpers.DecodeServerMessage(messageBytes); var trackerMetadata = TrackerMetadataForMessage(message); - - stats.ParseMessageQueueTracker.FinishTrackingRequest(unparsed.parseQueueTrackerId, trackerMetadata); var parseStart = DateTime.UtcNow; ReducerEvent? reducerEvent = default; @@ -469,11 +486,11 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed) case ServerMessage.InitialConnection: break; case ServerMessage.SubscribeApplied(var subscribeApplied): - stats.SubscriptionRequestTracker.FinishTrackingRequest(subscribeApplied.RequestId, unparsed.timestamp); + stats.SubscriptionRequestTracker.FinishTrackingRequest(subscribeApplied.RequestId, timestamp); dbOps = ParseSubscribeRows(subscribeApplied.Rows); break; case ServerMessage.UnsubscribeApplied(var unsubscribeApplied): - stats.SubscriptionRequestTracker.FinishTrackingRequest(unsubscribeApplied.RequestId, unparsed.timestamp); + stats.SubscriptionRequestTracker.FinishTrackingRequest(unsubscribeApplied.RequestId, timestamp); if (unsubscribeApplied.Rows != null) { dbOps = ParseUnsubscribeRows(unsubscribeApplied.Rows); @@ -482,7 +499,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed) case ServerMessage.SubscriptionError(var subscriptionError): if (subscriptionError.RequestId.HasValue) { - stats.SubscriptionRequestTracker.FinishTrackingRequest(subscriptionError.RequestId.Value, unparsed.timestamp); + stats.SubscriptionRequestTracker.FinishTrackingRequest(subscriptionError.RequestId.Value, timestamp); } break; case ServerMessage.TransactionUpdate(var transactionUpdate): @@ -492,7 +509,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed) ParseOneOffQuery(resp); break; case ServerMessage.ReducerResult(var reducerResult): - if (!stats.ReducerRequestTracker.FinishTrackingRequest(reducerResult.RequestId, unparsed.timestamp)) + if (!stats.ReducerRequestTracker.FinishTrackingRequest(reducerResult.RequestId, timestamp)) { Log.Warn($"Failed to finish tracking reducer request: {reducerResult.RequestId}"); } @@ -545,7 +562,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed) procedureResult.RequestId ); - if (!stats.ProcedureRequestTracker.FinishTrackingRequest(procedureResult.RequestId, unparsed.timestamp)) + if (!stats.ProcedureRequestTracker.FinishTrackingRequest(procedureResult.RequestId, timestamp)) { Log.Warn($"Failed to finish tracking procedure request: {procedureResult.RequestId}"); } @@ -558,7 +575,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed) stats.ParseMessageTracker.InsertRequest(parseStart, trackerMetadata); var applyTracker = stats.ApplyMessageQueueTracker.StartTrackingRequest(trackerMetadata); - return new ParsedMessage { message = message, dbOps = dbOps, receiveTimestamp = unparsed.timestamp, applyQueueTrackerId = applyTracker, reducerEvent = reducerEvent, procedureEvent = procedureEvent }; + return new ParsedMessage { message = message, dbOps = dbOps, receiveTimestamp = timestamp, applyQueueTrackerId = applyTracker, reducerEvent = reducerEvent, procedureEvent = procedureEvent }; } } @@ -609,6 +626,7 @@ void IDbConnection.Connect(string? token, string uri, string addressOrName, Comp { isClosing = false; connectionClosed = false; + decodeTransportMessages = DecodeV2TransportMessages; Identity = null; initialConnectionId = null; onConnectInvoked = false; diff --git a/sdks/csharp/src/WebSocket.cs b/sdks/csharp/src/WebSocket.cs index 26ce87127ba..95749f08694 100644 --- a/sdks/csharp/src/WebSocket.cs +++ b/sdks/csharp/src/WebSocket.cs @@ -3,7 +3,7 @@ using System; using System.Collections.Concurrent; -using System.Linq; +using System.Collections.Generic; using System.Net.Sockets; using System.Net.WebSockets; using System.Runtime.InteropServices; @@ -15,6 +15,8 @@ namespace SpacetimeDB { internal class WebSocket { + private delegate (byte[] EncodedMessage, bool ShouldYield) DequeueSendWork(); + public delegate void OpenEventHandler(); public delegate void MessageEventHandler(byte[] message, DateTime timestamp); @@ -26,7 +28,7 @@ internal class WebSocket public struct ConnectOptions { - public string Protocol; + public string[] Protocols; } // WebSocket buffer for incoming messages @@ -36,13 +38,16 @@ public struct ConnectOptions private readonly ConnectOptions _options; private readonly byte[] _receiveBuffer = new byte[MAXMessageSize]; private readonly ConcurrentQueue dispatchQueue = new(); + private static readonly ClientMessage.BSATN clientMessageBsatn = new(); protected ClientWebSocket Ws = new(); private CancellationTokenSource? _connectCts; + private DequeueSendWork dequeueSendWork; public WebSocket(ConnectOptions options) { _options = options; + dequeueSendWork = DequeueV2SendWork; #if UNITY_WEBGL && !UNITY_EDITOR InitializeWebGL(); #endif @@ -57,6 +62,14 @@ public WebSocket(ConnectOptions options) /// public event MessageEventHandler? OnMessage; public event CloseEventHandler? OnClose; + public event Action? OnProtocolNegotiated; + + private WebSocketProtocolVersion protocolVersion = WebSocketProtocolVersion.V2; + public WebSocketProtocolVersion ProtocolVersion + { + get => protocolVersion; + internal set => SetProtocolVersion(value); + } #if UNITY_WEBGL && !UNITY_EDITOR private bool _isConnected = false; @@ -88,10 +101,11 @@ IntPtr errorCallback [DllImport("__Internal")] private static extern void WebSocket_Close(int socketId, int code, string reason); - [AOT.MonoPInvokeCallback(typeof(Action))] - private static void WebGLOnOpen(int socketId) + [AOT.MonoPInvokeCallback(typeof(Action))] + private static void WebGLOnOpen(int socketId, IntPtr protocolPtr) { - Instance?.HandleWebGLOpen(socketId); + var protocol = Marshal.PtrToStringUTF8(protocolPtr) ?? string.Empty; + Instance?.HandleWebGLOpen(socketId, protocol); } [AOT.MonoPInvokeCallback(typeof(Action))] @@ -137,7 +151,7 @@ private void InitializeWebGL() { Instance = this; // Convert callbacks to function pointers - var openPtr = Marshal.GetFunctionPointerForDelegate((Action)WebGLOnOpen); + var openPtr = Marshal.GetFunctionPointerForDelegate((Action)WebGLOnOpen); var messagePtr = Marshal.GetFunctionPointerForDelegate((Action)WebGLOnMessage); var closePtr = Marshal.GetFunctionPointerForDelegate((Action)WebGLOnClose); var errorPtr = Marshal.GetFunctionPointerForDelegate((Action)WebGLOnError); @@ -148,6 +162,7 @@ private void InitializeWebGL() public async Task Connect(string? auth, string host, string nameOrAddress, ConnectionId connectionId, Compression compression, bool light, bool? confirmedReads) { + ResetProtocolVersion(); #if UNITY_WEBGL && !UNITY_EDITOR if (_isConnecting || _isConnected) return; @@ -166,7 +181,7 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne _socketId = new TaskCompletionSource(); var callbackPtr = Marshal.GetFunctionPointerForDelegate((Action)OnSocketIdReceived); - WebSocket_Connect(host, uri, _options.Protocol, auth, callbackPtr); + WebSocket_Connect(host, uri, WebSocketProtocols.SerializeOfferedProtocols(_options.Protocols), auth, callbackPtr); _webglSocketId = await _socketId.Task; if (_webglSocketId == -1) { @@ -189,6 +204,7 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne } // Events will be handled via UnitySendMessage callbacks #else + Ws = new ClientWebSocket(); var uri = $"{host}/v1/database/{nameOrAddress}/subscribe?connection_id={connectionId}&compression={compression}"; if (light) { @@ -201,7 +217,10 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne uri += $"&confirmed={enabled}"; } var url = new Uri(uri); - Ws.Options.AddSubProtocol(_options.Protocol); + foreach (var protocol in _options.Protocols) + { + Ws.Options.AddSubProtocol(protocol); + } _connectCts = new CancellationTokenSource(10000); if (!string.IsNullOrEmpty(auth)) @@ -218,6 +237,7 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne await Ws.ConnectAsync(url, _connectCts.Token); if (Ws.State == WebSocketState.Open) { + SetProtocolVersion(WebSocketProtocols.Normalize(Ws.SubProtocol)); if (OnConnect != null) { dispatchQueue.Enqueue(() => OnConnect()); @@ -373,7 +393,8 @@ await Ws.CloseAsync(WebSocketCloseStatus.MessageTooBig, closeMessage, if (OnMessage != null) { - var message = _receiveBuffer.Take(count).ToArray(); + var message = new byte[count]; + Buffer.BlockCopy(_receiveBuffer, 0, message, 0, count); // directly invoke message handling OnMessage(message, startReceive); } @@ -454,8 +475,8 @@ public void Abort() #endif } - private Task? senderTask; - private readonly ConcurrentQueue messageSendQueue = new(); + private bool senderActive; + private readonly Queue messageSendQueue = new(); /// /// This sender guarantees that that messages are sent out in the order they are received. Our websocket @@ -465,25 +486,66 @@ public void Abort() /// The message to send public void Send(ClientMessage message) { -#if UNITY_WEBGL && !UNITY_EDITOR try { - var messageBSATN = new ClientMessage.BSATN(); - var encodedMessage = IStructuralReadWrite.ToBytes(messageBSATN, message); - WebSocket_Send(_webglSocketId, encodedMessage, encodedMessage.Length); + var encodedMessage = IStructuralReadWrite.ToBytes(clientMessageBsatn, message); + var startProcessor = false; + lock (messageSendQueue) + { + messageSendQueue.Enqueue(encodedMessage); + if (!senderActive) + { + senderActive = true; + startProcessor = true; + } + } + + if (startProcessor) + { + _ = StartProcessSendQueue(); + } } catch (Exception e) { - UnityEngine.Debug.LogError($"WebSocket send error: {e}"); dispatchQueue.Enqueue(() => OnSendError?.Invoke(e)); } + } + + private Task StartProcessSendQueue() + { +#if UNITY_WEBGL && !UNITY_EDITOR + return ProcessSendQueue(); #else + return Task.Run(ProcessSendQueue); +#endif + } + + private void ScheduleSendQueueContinuation() + { +#if UNITY_WEBGL && !UNITY_EDITOR + dispatchQueue.Enqueue(TryStartSendQueueProcessor); +#else + _ = Task.Run(() => + { + TryStartSendQueueProcessor(); + return Task.CompletedTask; + }); +#endif + } + + private void TryStartSendQueueProcessor() + { lock (messageSendQueue) { - messageSendQueue.Enqueue(message); - senderTask ??= Task.Run(ProcessSendQueue); + if (senderActive || messageSendQueue.Count == 0) + { + return; + } + + senderActive = true; } -#endif + + _ = StartProcessSendQueue(); } private async Task ProcessSendQueue() @@ -492,37 +554,111 @@ private async Task ProcessSendQueue() { while (true) { - ClientMessage message; + byte[] encodedMessage; + bool shouldYield; lock (messageSendQueue) { - if (!messageSendQueue.TryDequeue(out message)) + if (messageSendQueue.Count == 0) { // We are out of messages to send - senderTask = null; + senderActive = false; return; } + + (encodedMessage, shouldYield) = dequeueSendWork(); } - var messageBSATN = new ClientMessage.BSATN(); - var encodedMessage = IStructuralReadWrite.ToBytes(messageBSATN, message); - await Ws!.SendAsync(encodedMessage, WebSocketMessageType.Binary, true, CancellationToken.None); + await SendEncodedMessage(encodedMessage); + + if (shouldYield) + { + // After sending one capped v3 frame, stop this queue pump and + // schedule a follow-up pass using the same runtime primitives + // this SDK already uses for send processing on each platform. + lock (messageSendQueue) + { + senderActive = false; + } + ScheduleSendQueueContinuation(); + return; + } } } catch (Exception e) { - senderTask = null; + lock (messageSendQueue) + { + senderActive = false; + } if (OnSendError != null) dispatchQueue.Enqueue(() => OnSendError(e)); } } + private byte[][] DequeueMessagesForV3Frame() + { + var messageCount = WebSocketV3Frames.CountClientMessagesThatFitInFrame(messageSendQueue); + if (messageCount <= 0) + { + throw new InvalidOperationException("Expected at least one queued v2 message when building a v3 frame."); + } + + var messages = new byte[messageCount][]; + for (var i = 0; i < messageCount; i++) + { + messages[i] = messageSendQueue.Dequeue(); + } + return messages; + } + + private (byte[] EncodedMessage, bool ShouldYield) DequeueV2SendWork() => + (messageSendQueue.Dequeue(), false); + + private (byte[] EncodedMessage, bool ShouldYield) DequeueV3SendWork() + { + var queuedMessages = DequeueMessagesForV3Frame(); + return (WebSocketV3Frames.EncodeClientMessages(queuedMessages), messageSendQueue.Count > 0); + } + + private void ResetProtocolVersion() + { + protocolVersion = WebSocketProtocolVersion.V2; + dequeueSendWork = DequeueV2SendWork; + } + + private void SetProtocolVersion(WebSocketProtocolVersion protocolVersion) + { + // Protocol selection is a transport concern: changing it swaps the + // active send strategy and notifies higher layers to swap their + // receive decoder as well. + this.protocolVersion = protocolVersion; + dequeueSendWork = protocolVersion == WebSocketProtocolVersion.V3 + ? DequeueV3SendWork + : DequeueV2SendWork; + OnProtocolNegotiated?.Invoke(protocolVersion); + } + + private Task SendEncodedMessage(byte[] encodedMessage) + { +#if UNITY_WEBGL && !UNITY_EDITOR + var result = WebSocket_Send(_webglSocketId, encodedMessage, encodedMessage.Length); + if (result != 0) + { + throw new InvalidOperationException("WebSocket send failed."); + } + return Task.CompletedTask; +#else + return Ws!.SendAsync(new ArraySegment(encodedMessage), WebSocketMessageType.Binary, true, CancellationToken.None); +#endif + } + public WebSocketState GetState() { return Ws!.State; } #if UNITY_WEBGL && !UNITY_EDITOR - public void HandleWebGLOpen(int socketId) + public void HandleWebGLOpen(int socketId, string protocol) { if (socketId == _webglSocketId) { @@ -535,6 +671,7 @@ public void HandleWebGLOpen(int socketId) _cancelConnectRequested = false; return; } + SetProtocolVersion(WebSocketProtocols.Normalize(protocol)); _isConnected = true; if (OnConnect != null) dispatchQueue.Enqueue(() => OnConnect()); diff --git a/sdks/csharp/src/WebSocketProtocols.cs b/sdks/csharp/src/WebSocketProtocols.cs new file mode 100644 index 00000000000..0e98ec0c48c --- /dev/null +++ b/sdks/csharp/src/WebSocketProtocols.cs @@ -0,0 +1,26 @@ +namespace SpacetimeDB +{ + internal enum WebSocketProtocolVersion + { + V2, + V3, + } + + internal static class WebSocketProtocols + { + internal const string V2 = "v2.bsatn.spacetimedb"; + internal const string V3 = "v3.bsatn.spacetimedb"; + + internal static readonly string[] Preferred = new[] { V3, V2 }; + + internal static WebSocketProtocolVersion Normalize(string? protocol) + { + // Treat an empty negotiated subprotocol as legacy v2 defensively. + return protocol == V3 ? WebSocketProtocolVersion.V3 : WebSocketProtocolVersion.V2; + } + +#if UNITY_WEBGL && !UNITY_EDITOR + internal static string SerializeOfferedProtocols(string[] protocols) => string.Join(",", protocols); +#endif + } +} diff --git a/sdks/csharp/src/WebSocketV3Frames.cs b/sdks/csharp/src/WebSocketV3Frames.cs new file mode 100644 index 00000000000..7f715df7b09 --- /dev/null +++ b/sdks/csharp/src/WebSocketV3Frames.cs @@ -0,0 +1,101 @@ +using SpacetimeDB.BSATN; +using SpacetimeDB.ClientApi; + +using System; +using System.Collections.Generic; +using System.IO; + +namespace SpacetimeDB +{ + internal static class WebSocketV3Frames + { + internal const int MaxFrameBytes = 256 * 1024; + + private const int EnumTagBytes = 1; + private const int CollectionLengthBytes = 4; + private const int ByteArrayLengthBytes = 4; + + private static readonly ClientFrame.BSATN clientFrameBsatn = new(); + private static readonly ServerFrame.BSATN serverFrameBsatn = new(); + + // v3 is only a transport envelope around already-encoded v2 messages, + // so batching works in terms of raw byte payloads rather than logical messages. + internal static byte[] EncodeClientMessages(IReadOnlyList messages) + { + if (messages.Count == 0) + { + throw new InvalidOperationException("Cannot encode an empty v3 client frame."); + } + + ClientFrame frame = messages.Count == 1 + ? new ClientFrame.Single(messages[0]) + : new ClientFrame.Batch(ToArray(messages)); + + return IStructuralReadWrite.ToBytes(clientFrameBsatn, frame); + } + + internal static byte[][] DecodeServerMessages(byte[] encodedFrame) + { + using var stream = new MemoryStream(encodedFrame); + using var reader = new BinaryReader(stream); + var frame = serverFrameBsatn.Read(reader); + return frame switch + { + ServerFrame.Single(var message) => new[] { message }, + ServerFrame.Batch(var messages) => messages, + _ => throw new InvalidOperationException("Unknown v3 server frame variant."), + }; + } + + // Count the maximal prefix of already-encoded client messages that fits in + // one v3 frame using BSATN framing sizes directly instead of trial serialization. + internal static int CountClientMessagesThatFitInFrame( + IEnumerable messages, + int maxFrameBytes = MaxFrameBytes + ) + { + var messageCount = 0; + var payloadBytes = 0; + + foreach (var message in messages) + { + if (messageCount == 0) + { + if (EncodedSingleFrameSize(message.Length) > maxFrameBytes) + { + return 1; + } + } + else + { + var batchSize = EncodedBatchFrameSize(messageCount + 1, payloadBytes + message.Length); + if (batchSize > maxFrameBytes) + { + break; + } + } + + messageCount++; + payloadBytes += message.Length; + } + + return messageCount; + } + + private static int EncodedSingleFrameSize(int messageBytes) => + EnumTagBytes + ByteArrayLengthBytes + messageBytes; + + private static int EncodedBatchFrameSize(int messageCount, int payloadBytes) => + EnumTagBytes + CollectionLengthBytes + (messageCount * ByteArrayLengthBytes) + payloadBytes; + + private static byte[][] ToArray(IReadOnlyList messages) + { + var array = new byte[messages.Count][]; + for (var i = 0; i < messages.Count; i++) + { + array[i] = messages[i]; + } + return array; + } + } +} diff --git a/sdks/csharp/tests~/SnapshotTests.cs b/sdks/csharp/tests~/SnapshotTests.cs index e083928111e..fcaed905f5d 100644 --- a/sdks/csharp/tests~/SnapshotTests.cs +++ b/sdks/csharp/tests~/SnapshotTests.cs @@ -381,6 +381,72 @@ public static IEnumerable SampleDump() } + [Fact] + public void V3BatchedServerFrameIsProcessedInOrder() + { + DbConnection.IsTesting = true; + + var client = + DbConnection.Builder() + .WithUri("wss://spacetimedb.com") + .WithDatabaseName("example") + .Build(); + + client.webSocket.ProtocolVersion = WebSocketProtocolVersion.V3; + + ServerMessage initialConnection = SampleId( + "j5DMlKmWjfbSl7qmZQOok7HDSwsAJopRSJjdlUsNogs=", + "token", + "Vd4dFzcEzhLHJ6uNL8VXFg==" + ); + ServerMessage transactionUpdate = SampleTransactionUpdate( + 1, + [SampleUserInsert("l0qzG1GPRtC1mwr+54q98tv0325gozLc6cNzq4vrzqY=", "A", true)] + ); + + ServerFrame frame = new ServerFrame.Batch(new[] + { + IStructuralReadWrite.ToBytes(new ServerMessage.BSATN(), initialConnection), + IStructuralReadWrite.ToBytes(new ServerMessage.BSATN(), transactionUpdate), + }); + var payload = IStructuralReadWrite.ToBytes(new ServerFrame.BSATN(), frame); + + var transportMessage = new byte[payload.Length + 1]; + transportMessage[0] = 0; + Buffer.BlockCopy(payload, 0, transportMessage, 1, payload.Length); + + client.OnMessageReceived(transportMessage, DateTime.UtcNow); + + var deadline = DateTime.UtcNow.AddSeconds(2); + List? users = null; + while (true) + { + client.FrameTick(); + users = client.Db.User.Iter().ToList(); + if (users.Count == 1) + { + break; + } + + if (DateTime.UtcNow >= deadline) + { + throw new TimeoutException("Timed out waiting for a v3 batched frame to be applied."); + } + Thread.Sleep(1); + } + + Assert.Equal( + Identity.From(Convert.FromBase64String("j5DMlKmWjfbSl7qmZQOok7HDSwsAJopRSJjdlUsNogs=")), + client.Identity + ); + + Assert.Single(users); + Assert.Equal("A", users[0].Name); + Assert.True(users[0].Online); + + client.Disconnect(); + } + [Theory] [MemberData(nameof(SampleDump))] public async Task VerifySampleDump(string dumpName, ServerMessage[] sampleDumpParsed) diff --git a/sdks/csharp/tests~/Tests.cs b/sdks/csharp/tests~/Tests.cs index 3adb4970cba..1559317a382 100644 --- a/sdks/csharp/tests~/Tests.cs +++ b/sdks/csharp/tests~/Tests.cs @@ -2,6 +2,7 @@ using CsCheck; using SpacetimeDB; using SpacetimeDB.BSATN; +using SpacetimeDB.ClientApi; using SpacetimeDB.Types; public class Tests @@ -128,4 +129,45 @@ public static void ListstreamWorks() } }); } -} \ No newline at end of file + + [Fact] + public static void V3BatchSizingCapsAt256KiB() + { + var messages = new[] + { + new byte[100_000], + new byte[100_000], + new byte[100_000], + }; + + Assert.Equal(2, WebSocketV3Frames.CountClientMessagesThatFitInFrame(messages)); + Assert.Equal(1, WebSocketV3Frames.CountClientMessagesThatFitInFrame(new[] { new byte[300_000] })); + Assert.Equal(0, WebSocketV3Frames.CountClientMessagesThatFitInFrame(Array.Empty())); + } + + [Fact] + public static void V3ServerFrameDecodeHandlesSingleAndBatch() + { + static byte[] EncodeFrame(ServerFrame frame) => + IStructuralReadWrite.ToBytes(new ServerFrame.BSATN(), frame); + + var singlePayload = new byte[] { 1, 2, 3 }; + var single = WebSocketV3Frames.DecodeServerMessages( + EncodeFrame(new ServerFrame.Single(singlePayload)) + ); + Assert.Single(single); + Assert.Equal(singlePayload, single[0]); + + var batchPayloads = new[] + { + new byte[] { 4, 5 }, + new byte[] { 6, 7, 8 }, + }; + var batch = WebSocketV3Frames.DecodeServerMessages( + EncodeFrame(new ServerFrame.Batch(batchPayloads)) + ); + Assert.Equal(2, batch.Length); + Assert.Equal(batchPayloads[0], batch[0]); + Assert.Equal(batchPayloads[1], batch[1]); + } +} From 20f5a4f034817da81e31cfa991cf7f059dbb5f0e Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Wed, 8 Apr 2026 17:36:16 -0700 Subject: [PATCH 2/3] Clarify WebGL protocol marshalling --- sdks/csharp/src/Plugins/WebSocket.jslib | 21 ++++++++++++++++----- sdks/csharp/src/WebSocket.cs | 2 ++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/sdks/csharp/src/Plugins/WebSocket.jslib b/sdks/csharp/src/Plugins/WebSocket.jslib index 34d62dd6613..4b4d39a8e49 100644 --- a/sdks/csharp/src/Plugins/WebSocket.jslib +++ b/sdks/csharp/src/Plugins/WebSocket.jslib @@ -58,11 +58,22 @@ mergeInto(LibraryManager.library, { socket.onopen = function() { if (manager.callbacks.open) { var protocolStr = socket.protocol || ""; - var protocolArray = intArrayFromString(protocolStr); - var protocolPtr = _malloc(protocolArray.length); - HEAP8.set(protocolArray, protocolPtr); - dynCall('vii', manager.callbacks.open, [socketId, protocolPtr]); - _free(protocolPtr); + // Marshal the negotiated subprotocol to C# just for the duration of + // this callback. We use stack allocation because the pointer only + // needs to remain valid while dynCall is executing synchronously. + var protocolLength = lengthBytesUTF8(protocolStr) + 1; + var stack = stackSave(); + try { + var protocolPtr = stackAlloc(protocolLength); + // Write a temporary null-terminated UTF-8 string into the + // Emscripten stack frame so the C# callback can copy it. + stringToUTF8(protocolStr, protocolPtr, protocolLength); + dynCall('vii', manager.callbacks.open, [socketId, protocolPtr]); + } finally { + // Release the temporary stack allocation immediately after + // the callback returns; C# must not retain the pointer. + stackRestore(stack); + } } }; diff --git a/sdks/csharp/src/WebSocket.cs b/sdks/csharp/src/WebSocket.cs index 95749f08694..96e61763f91 100644 --- a/sdks/csharp/src/WebSocket.cs +++ b/sdks/csharp/src/WebSocket.cs @@ -104,6 +104,8 @@ IntPtr errorCallback [AOT.MonoPInvokeCallback(typeof(Action))] private static void WebGLOnOpen(int socketId, IntPtr protocolPtr) { + // The JS bridge passes a temporary UTF-8 pointer that is only valid for + // this callback, so copy it into a managed string immediately. var protocol = Marshal.PtrToStringUTF8(protocolPtr) ?? string.Empty; Instance?.HandleWebGLOpen(socketId, protocol); } From ee0c8fb6ab5d2fbf82a4bdd46e39ce6c6e8bfcb5 Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Wed, 8 Apr 2026 17:49:06 -0700 Subject: [PATCH 3/3] Add C# v2 fallback websocket test --- sdks/csharp/tests~/Tests.cs | 84 +++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/sdks/csharp/tests~/Tests.cs b/sdks/csharp/tests~/Tests.cs index 1559317a382..cd92f84bb1d 100644 --- a/sdks/csharp/tests~/Tests.cs +++ b/sdks/csharp/tests~/Tests.cs @@ -1,4 +1,7 @@ using System.Diagnostics; +using System.Net; +using System.Net.Sockets; +using System.Net.WebSockets; using CsCheck; using SpacetimeDB; using SpacetimeDB.BSATN; @@ -170,4 +173,85 @@ static byte[] EncodeFrame(ServerFrame frame) => Assert.Equal(batchPayloads[0], batch[0]); Assert.Equal(batchPayloads[1], batch[1]); } + + [Fact] + public static async Task WebSocketFallsBackToV2WhenServerOnlyNegotiatesV2() + { + static int GetFreePort() + { + using var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + return ((IPEndPoint)listener.LocalEndpoint).Port; + } + + static async Task WaitForAsync(Task task, SpacetimeDB.WebSocket ws, string error) + { + var deadline = DateTime.UtcNow.AddSeconds(5); + while (!task.IsCompleted) + { + ws.Update(); + if (DateTime.UtcNow >= deadline) + { + throw new TimeoutException(error); + } + await Task.Delay(10); + } + + await task; + } + + var port = GetFreePort(); + using var listener = new HttpListener(); + listener.Prefixes.Add($"http://127.0.0.1:{port}/"); + listener.Start(); + + var serverObservedProtocols = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var serverTask = Task.Run(async () => + { + var context = await listener.GetContextAsync(); + serverObservedProtocols.TrySetResult(context.Request.Headers["Sec-WebSocket-Protocol"] ?? string.Empty); + + var webSocketContext = await context.AcceptWebSocketAsync(WebSocketProtocols.V2); + await Task.Delay(100); + await webSocketContext.WebSocket.CloseAsync( + WebSocketCloseStatus.NormalClosure, + "done", + CancellationToken.None + ); + }); + + var ws = new SpacetimeDB.WebSocket(new SpacetimeDB.WebSocket.ConnectOptions + { + Protocols = WebSocketProtocols.Preferred, + }); + + var connected = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closed = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + ws.OnConnect += () => connected.TrySetResult(); + ws.OnClose += _ => closed.TrySetResult(); + + var clientTask = Task.Run(() => ws.Connect( + "test-token", + $"ws://127.0.0.1:{port}", + "example", + ConnectionId.Random(), + Compression.None, + false, + null + )); + + await WaitForAsync(connected.Task, ws, "Timed out waiting for websocket connection."); + + Assert.Equal(WebSocketProtocolVersion.V2, ws.ProtocolVersion); + + var offeredProtocols = await serverObservedProtocols.Task.WaitAsync(TimeSpan.FromSeconds(5)); + Assert.Contains(WebSocketProtocols.V3, offeredProtocols); + Assert.Contains(WebSocketProtocols.V2, offeredProtocols); + + await WaitForAsync(closed.Task, ws, "Timed out waiting for websocket close."); + await serverTask.WaitAsync(TimeSpan.FromSeconds(5)); + await clientTask.WaitAsync(TimeSpan.FromSeconds(5)); + } }