Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions sdks/csharp/src/CompressionHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,12 @@ internal static GZipStream GzipReader(Stream stream)
/// </summary>
/// <param name="bytes">The compressed and encoded server message as a byte array.</param>
/// <returns>The deserialized <see cref="ServerMessage"/> object.</returns>
internal static ServerMessage DecompressDecodeMessage(byte[] bytes)
internal static byte[] DecompressMessagePayload(byte[] bytes)
{
using var stream = new MemoryStream(bytes);

// The stream will never be empty. It will at least contain the compression algo.
var compression = (CompressionAlgos)stream.ReadByte();
// Conditionally decompress and decode.
Stream decompressedStream = compression switch
{
CompressionAlgos.None => stream,
Expand All @@ -67,10 +66,21 @@ internal static ServerMessage DecompressDecodeMessage(byte[] bytes)
// TODO: consider pooling these.
// DO NOT TRY TO TAKE THIS OUT. The BrotliStream ReadByte() implementation allocates an array
// PER BYTE READ. You have to do it all at once to avoid that problem.
MemoryStream memoryStream = new MemoryStream();
using var memoryStream = new MemoryStream();
decompressedStream.CopyTo(memoryStream);
memoryStream.Seek(0, SeekOrigin.Begin);
return new ServerMessage.BSATN().Read(new BinaryReader(memoryStream));
return memoryStream.ToArray();
}

internal static ServerMessage DecodeServerMessage(byte[] bytes)
{
using var stream = new MemoryStream(bytes);
using var reader = new BinaryReader(stream);
return new ServerMessage.BSATN().Read(reader);
}

internal static ServerMessage DecompressDecodeMessage(byte[] bytes)
{
return DecodeServerMessage(DecompressMessagePayload(bytes));
}

/// <summary>
Expand Down
25 changes: 22 additions & 3 deletions sdks/csharp/src/Plugins/WebSocket.jslib
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ mergeInto(LibraryManager.library, {
var host = UTF8ToString(baseUriPtr);
var uri = UTF8ToString(uriPtr);
var protocol = UTF8ToString(protocolPtr);
// The C# WebGL bridge can only pass one string argument here, so
// multiple offered subprotocols are marshalled as a comma-separated string.
var offeredProtocols = protocol.indexOf(',') === -1 ? protocol : protocol.split(',');
var authToken = UTF8ToString(authTokenPtr);
if (authToken)
{
Expand All @@ -46,15 +49,31 @@ mergeInto(LibraryManager.library, {
}
}

var socket = new window.WebSocket(uri, protocol);
var socket = new window.WebSocket(uri, offeredProtocols);
socket.binaryType = "arraybuffer";

var socketId = manager.nextId++;
manager.instances[socketId] = socket;

socket.onopen = function() {
if (manager.callbacks.open) {
dynCall('vi', manager.callbacks.open, [socketId]);
var protocolStr = socket.protocol || "";
// Marshal the negotiated subprotocol to C# just for the duration of
// this callback. We use stack allocation because the pointer only
// needs to remain valid while dynCall is executing synchronously.
var protocolLength = lengthBytesUTF8(protocolStr) + 1;
var stack = stackSave();
try {
var protocolPtr = stackAlloc(protocolLength);
// Write a temporary null-terminated UTF-8 string into the
// Emscripten stack frame so the C# callback can copy it.
stringToUTF8(protocolStr, protocolPtr, protocolLength);
dynCall('vii', manager.callbacks.open, [socketId, protocolPtr]);
} finally {
// Release the temporary stack allocation immediately after
// the callback returns; C# must not retain the pointer.
stackRestore(stack);
}
}
};

Expand Down Expand Up @@ -115,4 +134,4 @@ mergeInto(LibraryManager.library, {
socket.close(code, reason);
delete manager.instances[socketId];
}
});
});
10 changes: 10 additions & 0 deletions sdks/csharp/src/SpacetimeDB/ClientApi/ClientFrame.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#nullable enable

namespace SpacetimeDB.ClientApi
{
[SpacetimeDB.Type]
internal partial record ClientFrame : SpacetimeDB.TaggedEnum<(
byte[] Single,
byte[][] Batch
)>;
}
10 changes: 10 additions & 0 deletions sdks/csharp/src/SpacetimeDB/ClientApi/ServerFrame.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#nullable enable

namespace SpacetimeDB.ClientApi
{
[SpacetimeDB.Type]
internal partial record ServerFrame : SpacetimeDB.TaggedEnum<(
byte[] Single,
byte[][] Batch
)>;
}
46 changes: 32 additions & 14 deletions sdks/csharp/src/SpacetimeDBClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ public abstract class DbConnectionBase<DbConnection, Tables, Reducer> : IDbConne
protected abstract IErrorContext ToErrorContext(Exception errorContext);
protected abstract IProcedureEventContext ToProcedureEventContext(ProcedureEvent procedureEvent);

private Func<byte[], byte[][]> decodeTransportMessages = DecodeV2TransportMessages;

private readonly ConcurrentDictionary<uint, TaskCompletionSource<OneOffQueryResult>> waitingOneOffQueries = new();

private readonly ConcurrentDictionary<uint, PendingReducerCall> pendingReducerCalls = new();
Expand Down Expand Up @@ -219,10 +221,16 @@ protected DbConnectionBase()
{
var options = new WebSocket.ConnectOptions
{
Protocol = "v2.bsatn.spacetimedb"
Protocols = WebSocketProtocols.Preferred
};
webSocket = new WebSocket(options);
webSocket.OnMessage += OnMessageReceived;
webSocket.OnProtocolNegotiated += protocolVersion =>
{
decodeTransportMessages = protocolVersion == WebSocketProtocolVersion.V3
? WebSocketV3Frames.DecodeServerMessages
: DecodeV2TransportMessages;
};
webSocket.OnSendError += a => onSendError?.Invoke(a);
#if UNITY_5_3_OR_NEWER
webSocket.OnClose += (e) =>
Expand Down Expand Up @@ -289,6 +297,8 @@ internal struct ParsedMessage

private static readonly Status Committed = new Status.Committed(default);

private static byte[][] DecodeV2TransportMessages(byte[] payload) => new[] { payload };

/// <summary>
/// Get a description of a message suitable for storing in the tracker metadata.
/// </summary>
Expand Down Expand Up @@ -427,9 +437,18 @@ void ParseOneOffQuery(OneOffQueryResult resp)
#endif
try
{
var message = _parseQueue.Take(_parseCancellationToken);
var parsedMessage = ParseMessage(message);
_applyQueue.Add(parsedMessage, _parseCancellationToken);
var unparsed = _parseQueue.Take(_parseCancellationToken);
var payload = CompressionHelpers.DecompressMessagePayload(unparsed.bytes);
var decodedMessages = decodeTransportMessages(payload);
stats.ParseMessageQueueTracker.FinishTrackingRequest(
unparsed.parseQueueTrackerId,
$"type=ws_frame,count={decodedMessages.Length}"
);
foreach (var messageBytes in decodedMessages)
{
var parsedMessage = ParseMessage(messageBytes, unparsed.timestamp);
_applyQueue.Add(parsedMessage, _parseCancellationToken);
}
}
catch (OperationCanceledException)
{
Expand All @@ -452,13 +471,11 @@ void ParseOneOffQuery(OneOffQueryResult resp)
}
}

ParsedMessage ParseMessage(UnparsedMessage unparsed)
ParsedMessage ParseMessage(byte[] messageBytes, DateTime timestamp)
{
var dbOps = ParsedDatabaseUpdate.New();
var message = CompressionHelpers.DecompressDecodeMessage(unparsed.bytes);
var message = CompressionHelpers.DecodeServerMessage(messageBytes);
var trackerMetadata = TrackerMetadataForMessage(message);

stats.ParseMessageQueueTracker.FinishTrackingRequest(unparsed.parseQueueTrackerId, trackerMetadata);
var parseStart = DateTime.UtcNow;

ReducerEvent<Reducer>? reducerEvent = default;
Expand All @@ -469,11 +486,11 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed)
case ServerMessage.InitialConnection:
break;
case ServerMessage.SubscribeApplied(var subscribeApplied):
stats.SubscriptionRequestTracker.FinishTrackingRequest(subscribeApplied.RequestId, unparsed.timestamp);
stats.SubscriptionRequestTracker.FinishTrackingRequest(subscribeApplied.RequestId, timestamp);
dbOps = ParseSubscribeRows(subscribeApplied.Rows);
break;
case ServerMessage.UnsubscribeApplied(var unsubscribeApplied):
stats.SubscriptionRequestTracker.FinishTrackingRequest(unsubscribeApplied.RequestId, unparsed.timestamp);
stats.SubscriptionRequestTracker.FinishTrackingRequest(unsubscribeApplied.RequestId, timestamp);
if (unsubscribeApplied.Rows != null)
{
dbOps = ParseUnsubscribeRows(unsubscribeApplied.Rows);
Expand All @@ -482,7 +499,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed)
case ServerMessage.SubscriptionError(var subscriptionError):
if (subscriptionError.RequestId.HasValue)
{
stats.SubscriptionRequestTracker.FinishTrackingRequest(subscriptionError.RequestId.Value, unparsed.timestamp);
stats.SubscriptionRequestTracker.FinishTrackingRequest(subscriptionError.RequestId.Value, timestamp);
}
break;
case ServerMessage.TransactionUpdate(var transactionUpdate):
Expand All @@ -492,7 +509,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed)
ParseOneOffQuery(resp);
break;
case ServerMessage.ReducerResult(var reducerResult):
if (!stats.ReducerRequestTracker.FinishTrackingRequest(reducerResult.RequestId, unparsed.timestamp))
if (!stats.ReducerRequestTracker.FinishTrackingRequest(reducerResult.RequestId, timestamp))
{
Log.Warn($"Failed to finish tracking reducer request: {reducerResult.RequestId}");
}
Expand Down Expand Up @@ -545,7 +562,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed)
procedureResult.RequestId
);

if (!stats.ProcedureRequestTracker.FinishTrackingRequest(procedureResult.RequestId, unparsed.timestamp))
if (!stats.ProcedureRequestTracker.FinishTrackingRequest(procedureResult.RequestId, timestamp))
{
Log.Warn($"Failed to finish tracking procedure request: {procedureResult.RequestId}");
}
Expand All @@ -558,7 +575,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed)
stats.ParseMessageTracker.InsertRequest(parseStart, trackerMetadata);
var applyTracker = stats.ApplyMessageQueueTracker.StartTrackingRequest(trackerMetadata);

return new ParsedMessage { message = message, dbOps = dbOps, receiveTimestamp = unparsed.timestamp, applyQueueTrackerId = applyTracker, reducerEvent = reducerEvent, procedureEvent = procedureEvent };
return new ParsedMessage { message = message, dbOps = dbOps, receiveTimestamp = timestamp, applyQueueTrackerId = applyTracker, reducerEvent = reducerEvent, procedureEvent = procedureEvent };
}
}

Expand Down Expand Up @@ -609,6 +626,7 @@ void IDbConnection.Connect(string? token, string uri, string addressOrName, Comp
{
isClosing = false;
connectionClosed = false;
decodeTransportMessages = DecodeV2TransportMessages;
Identity = null;
initialConnectionId = null;
onConnectInvoked = false;
Expand Down
Loading
Loading