diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 51f96b88d..2a92af544 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -22,14 +22,14 @@ jobs: fetch-depth: 0 # Fetch the full history - name: Start Redis Services (docker-compose) working-directory: ./tests/RedisConfigs - run: docker compose -f docker-compose.yml up -d --wait + run: docker compose -f docker-compose.yml up -d --wait - name: Install .NET SDK uses: actions/setup-dotnet@v3 with: - dotnet-version: | + dotnet-version: | 6.0.x 8.0.x - 9.0.x + 10.0.x - name: .NET Build run: dotnet build Build.csproj -c Release /p:CI=true - name: StackExchange.Redis.Tests diff --git a/Directory.Build.props b/Directory.Build.props index 9f10eddcd..06542aa32 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -15,7 +15,7 @@ https://stackexchange.github.io/StackExchange.Redis/ MIT - 13 + 14 git https://github.com/StackExchange/StackExchange.Redis/ diff --git a/Directory.Packages.props b/Directory.Packages.props index 2088a054f..3fa9e0e3d 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -8,7 +8,8 @@ - + + diff --git a/StackExchange.Redis.sln.DotSettings b/StackExchange.Redis.sln.DotSettings index 216edbcca..8dd9095d9 100644 --- a/StackExchange.Redis.sln.DotSettings +++ b/StackExchange.Redis.sln.DotSettings @@ -12,9 +12,12 @@ True True True + True True True + True True + True True True True diff --git a/docs/KeyspaceNotifications.md b/docs/KeyspaceNotifications.md new file mode 100644 index 000000000..eeb156632 --- /dev/null +++ b/docs/KeyspaceNotifications.md @@ -0,0 +1,213 @@ +# Redis Keyspace Notifications + +Redis keyspace notifications let you monitor operations happening on your Redis keys in real-time. StackExchange.Redis provides a strongly-typed API for subscribing to and consuming these events. +This could be used for example to implement a cache invalidation strategy. + +## Prerequisites + +### Redis Configuration + +You must [enable keyspace notifications](https://redis.io/docs/latest/develop/pubsub/keyspace-notifications/#configuration) in your Redis server config, +for example: + +``` conf +notify-keyspace-events AKE +``` + +- **A** - All event types +- **K** - Keyspace notifications (`__keyspace@__:`) +- **E** - Keyevent notifications (`__keyevent@__:`) + +The two types of event (keyspace and keyevent) encode the same information, but in different formats. +To simplify consumption, StackExchange.Redis provides a unified API for both types of event, via the `KeyNotification` type. + +### Event Broadcasting in Redis Cluster + +Importantly, in Redis Cluster, keyspace notifications are **not** broadcast to all nodes - they are only received by clients connecting to the +individual node where the keyspace notification originated, i.e. where the key was modified. +This is different to how regular pub/sub events are handled, where a subscription to a channel on one node will receive events published on any node. +Clients must explicitly subscribe to the same channel on each node they wish to receive events from, which typically means: every primary node in the cluster. +To make this easier, StackExchange.Redis provides dedicated APIs for subscribing to keyspace and keyevent notifications that handle this for you. + +## Quick Start + +As an example, we'll subscribe to all keys with a specific prefix, and print out the key and event type for each notification. First, +we need to create a `RedisChannel`: + +```csharp +// this will subscribe to __keyspace@0__:user:*, including supporting Redis Cluster +var channel = RedisChannel.KeySpacePrefix(prefix: "user:"u8, database: 0); +``` + +Note that there are a range of other `KeySpace...` and `KeyEvent...` methods for different scenarios, including: + +- `KeySpaceSingleKey` - subscribe to notifications for a single key in a specific database +- `KeySpacePattern` - subscribe to notifications for a key pattern, optionally in a specific database +- `KeySpacePrefix` - subscribe to notifications for all keys with a specific prefix, optionally in a specific database +- `KeyEvent` - subscribe to notifications for a specific event type, optionally in a specific database + +The `KeySpace*` methods are similar, and are presented separately to make the intent clear. For example, `KeySpacePattern("foo*")` is equivalent to `KeySpacePrefix("foo")`, and will subscribe to all keys beginning with `"foo"`. + +Next, we subscribe to the channel and process the notifications using the normal pub/sub subscription API; there are two +main approaches: queue-based and callback-based. + +Queue-based: + +```csharp +var queue = await sub.SubscribeAsync(channel); +_ = Task.Run(async () => +{ + await foreach (var msg in queue) + { + if (msg.TryParseKeyNotification(out var notification)) + { + Console.WriteLine($"Key: {notification.GetKey()}"); + Console.WriteLine($"Type: {notification.Type}"); + Console.WriteLine($"Database: {notification.Database}"); + } + } +}); +``` + +Callback-based: + +```csharp +sub.Subscribe(channel, (recvChannel, recvValue) => +{ + if (KeyNotification.TryParse(recvChannel, recvValue, out var notification)) + { + Console.WriteLine($"Key: {notification.GetKey()}"); + Console.WriteLine($"Type: {notification.Type}"); + Console.WriteLine($"Database: {notification.Database}"); + } +}); +``` + +Note that the channels created by the `KeySpace...` and `KeyEvent...` methods cannot be used to manually *publish* events, +only to subscribe to them. The events are published automatically by the Redis server when keys are modified. If you +want to simulate keyspace notifications by publishing events manually, you should use regular pub/sub channels that avoid +the `__keyspace@` and `__keyevent@` prefixes. + +## Performance considerations for KeyNotification + +The `KeyNotification` struct provides parsed notification data, including (as already shown) the key, event type, +database, etc. Note that using `GetKey()` will allocate a copy of the key bytes; to avoid allocations, +you can use `TryCopyKey()` to copy the key bytes into a provided buffer (potentially with `GetKeyByteCount()`, +`GetKeyMaxCharCount()`, etc in order to size the buffer appropriately). Similarly, `KeyStartsWith()` can be used to +efficiently check the key prefix without allocating a string. This approach is designed to be efficient for high-volume +notification processing, and in particular: for use with the alt-lookup (span) APIs that are slowly being introduced +in various .NET APIs. + +For example, with a `ConcurrentDictionary` (for some `T`), you can use `GetAlternateLookup>()` +to get an alternate lookup API that takes a `ReadOnlySpan` instead of a `string`, and then use `TryCopyKey()` to copy +the key bytes into a buffer, and then use the alt-lookup API to find the value. This means that we avoid allocating a string +for the key entirely, and instead just copy the bytes into a buffer. If we consider that commonly a local cache will *not* +contain the key for the majority of notifications (since they are for cache invalidation), this can be a significant +performance win. + +## Considerations when database isolation + +Database isolation is controlled either via the `ConfigurationOptions.DefaultDatabase` option when connecting to Redis, +or by using the `GetDatabase(int? db = null)` method to get a specific database instance. Note that the +`KeySpace...` and `KeyEvent...` APIs may optionally take a database. When a database is specified, subscription will only +respond to notifications for keys in that database. If a database is not specified, the subscription will respond to +notifications for keys in all databases. Often, you will want to pass `db.Database` from the `IDatabase` instance you are +using for your application logic, to ensure that you are monitoring the correct database. When using Redis Cluster, +this usually means database `0`, since Redis Cluster does not usually support multiple databases. + +For example: + +- `RedisChannel.KeySpaceSingleKey("foo", 0)` maps to `SUBSCRIBE __keyspace@0__:foo` +- `RedisChannel.KeySpacePrefix("foo", 0)` maps to `PSUBSCRIBE __keyspace@0__:foo*` +- `RedisChannel.KeySpacePrefix("foo")` maps to `PSUBSCRIBE __keyspace@*__:foo*` +- `RedisChannel.KeyEvent(KeyNotificationType.Set, 0)` maps to `SUBSCRIBE __keyevent@0__:set` +- `RedisChannel.KeyEvent(KeyNotificationType.Set)` maps to `PSUBSCRIBE __keyevent@*__:set` + +Additionally, note that while most of these examples require multi-node subscriptions on Redis Cluster, `KeySpaceSingleKey` +is an exception, and will only subscribe to the single node that owns the key `foo`. + +When subscribing without specifying a database (i.e. listening to changes in all database), the database relating +to the notification can be fetched via `KeyNotification.Database`: + +``` c# +var channel = RedisChannel.KeySpacePrefix("foo"); +sub.SubscribeAsync(channel, (recvChannel, recvValue) => +{ + if (KeyNotification.TryParse(recvChannel, recvValue, out var notification)) + { + var key = notification.GetKey(); + var db = notification.Database; + // ... + } +} +``` + +## Considerations when using keyspace or channel isolation + +StackExchange.Redis supports the concept of keyspace and channel (pub/sub) isolation. + +Channel isolation is controlled using the `ConfigurationOptions.ChannelPrefix` option when connecting to Redis. +Intentionally, this feature *is ignored* by the `KeySpace...` and `KeyEvent...` APIs, because they are designed to +subscribe to specific (server-defined) channels that are outside the control of the client. + +Keyspace isolation is controlled using the `WithKeyPrefix` extension method on `IDatabase`. This is *not* used +by the `KeySpace...` and `KeyEvent...` APIs. Since the database and pub/sub APIs are independent, keyspace isolation +*is not applied* (and cannot be; consuming code could have zero, one, or multiple databases with different prefixes). +The caller is responsible for ensuring that the prefix is applied appropriately when constructing the `RedisChannel`. + +By default, key-related featured of `KeyNotification` will return the full key reported by the server, +including any prefix. However, the `TryParseKeyNotification` and `TryParse` methods can optionally be passed a +key prefix, which will be used both to filter unwanted notifications and strip the prefix from the key when reading. +It is *possible* to handle keyspace isolation manually by checking the key with `KeyNotification.KeyStartsWith` and +manually trimming the prefix, but it is *recommended* to do this via `TryParseKeyNotification` and `TryParse`. + +As an example, with a multi-tenant scenario using keyspace isolation, we might have in the database code: + +``` c# +// multi-tenant scenario using keyspace isolation +byte[] keyPrefix = Encoding.UTF8.GetBytes("client1234:"); +var db = conn.GetDatabase().WithKeyPrefix(keyPrefix); + +// we will later commit order data for example: +await db.StringSetAsync("order/123", "ISBN 9789123684434"); +``` + +To observe this, we could use: + +``` c# +var sub = conn.GetSubscriber(); + +// subscribe to the specific tenant as a prefix: +var channel = RedisChannel.KeySpacePrefix("client1234:order/", db.Database); + +sub.SubscribeAsync(channel, (recvChannel, recvValue) => +{ + // by including prefix in the TryParse, we filter out notifications that are not for this client + // *and* the key is sliced internally to remove this prefix when reading + if (KeyNotification.TryParse(keyPrefix, recvChannel, recvValue, out var notification)) + { + // if we get here, the key prefix was a match + var key = notification.GetKey(); // "order/123" - note no prefix + // ... + } + + /* + // for contrast only: this is *not* usually the recommended approach when using keyspace isolation + if (KeyNotification.TryParse(recvChannel, recvValue, out var notification) + && notification.KeyStartsWith(keyPrefix)) + { + var key = notification.GetKey(); // "client1234:order/123" - note prefix is included + // ... + } + */ +}); + +``` + +Alternatively, if we wanted a single handler that observed *all* tenants, we could use: + +``` c# +var channel = RedisChannel.KeySpacePattern("client*:order/*", db.Database); +``` + +with similar code, parsing the client from the key manually, using the full key length. \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 0b4d9bb2e..b1498d878 100644 --- a/docs/index.md +++ b/docs/index.md @@ -39,6 +39,7 @@ Documentation - [Transactions](Transactions) - how atomic transactions work in redis - [Events](Events) - the events available for logging / information purposes - [Pub/Sub Message Order](PubSubOrder) - advice on sequential and concurrent processing +- [Pub/Sub Key Notifications](KeyspaceNotifications) - how to use keyspace and keyevent notifications - [Using RESP3](Resp3) - information on using RESP3 - [ServerMaintenanceEvent](ServerMaintenanceEvent) - how to listen and prepare for hosted server maintenance (e.g. Azure Cache for Redis) - [Streams](Streams) - how to use the Stream data type diff --git a/src/StackExchange.Redis/ChannelMessage.cs b/src/StackExchange.Redis/ChannelMessage.cs new file mode 100644 index 000000000..a29454f0c --- /dev/null +++ b/src/StackExchange.Redis/ChannelMessage.cs @@ -0,0 +1,73 @@ +using System; + +namespace StackExchange.Redis; + +/// +/// Represents a message that is broadcast via publish/subscribe. +/// +public readonly struct ChannelMessage +{ + // this is *smaller* than storing a RedisChannel for the subscribed channel + private readonly ChannelMessageQueue _queue; + + /// + /// The Channel:Message string representation. + /// + public override string ToString() => ((string?)Channel) + ":" + ((string?)Message); + + /// + public override int GetHashCode() => Channel.GetHashCode() ^ Message.GetHashCode(); + + /// + public override bool Equals(object? obj) => obj is ChannelMessage cm + && cm.Channel == Channel && cm.Message == Message; + + internal ChannelMessage(ChannelMessageQueue queue, in RedisChannel channel, in RedisValue value) + { + _queue = queue; + _channel = channel; + _message = value; + } + + /// + /// The channel that the subscription was created from. + /// + public RedisChannel SubscriptionChannel => _queue.Channel; + + private readonly RedisChannel _channel; + + /// + /// The channel that the message was broadcast to. + /// + public RedisChannel Channel => _channel; + + private readonly RedisValue _message; + + /// + /// The value that was broadcast. + /// + public RedisValue Message => _message; + + /// + /// Checks if 2 messages are .Equal(). + /// + public static bool operator ==(ChannelMessage left, ChannelMessage right) => left.Equals(right); + + /// + /// Checks if 2 messages are not .Equal(). + /// + public static bool operator !=(ChannelMessage left, ChannelMessage right) => !left.Equals(right); + + /// + /// If the channel is either a keyspace or keyevent notification, resolve the key and event type. + /// + public bool TryParseKeyNotification(out KeyNotification notification) + => KeyNotification.TryParse(in _channel, in _message, out notification); + + /// + /// If the channel is either a keyspace or keyevent notification *with the requested prefix*, resolve the key and event type, + /// and remove the prefix when reading the key. + /// + public bool TryParseKeyNotification(ReadOnlySpan keyPrefix, out KeyNotification notification) + => KeyNotification.TryParse(keyPrefix, in _channel, in _message, out notification); +} diff --git a/src/StackExchange.Redis/ChannelMessageQueue.cs b/src/StackExchange.Redis/ChannelMessageQueue.cs index e58fb393b..9f962e52a 100644 --- a/src/StackExchange.Redis/ChannelMessageQueue.cs +++ b/src/StackExchange.Redis/ChannelMessageQueue.cs @@ -1,385 +1,353 @@ using System; +using System.Buffers.Text; using System.Collections.Generic; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; #if NETCOREAPP3_1 +using System.Diagnostics; using System.Reflection; #endif -namespace StackExchange.Redis +namespace StackExchange.Redis; + +/// +/// Represents a message queue of ordered pub/sub notifications. +/// +/// +/// To create a ChannelMessageQueue, use +/// or . +/// +public sealed class ChannelMessageQueue : IAsyncEnumerable { + private readonly Channel _queue; + /// - /// Represents a message that is broadcast via publish/subscribe. + /// The Channel that was subscribed for this queue. /// - public readonly struct ChannelMessage - { - // this is *smaller* than storing a RedisChannel for the subscribed channel - private readonly ChannelMessageQueue _queue; + public RedisChannel Channel { get; } - /// - /// The Channel:Message string representation. - /// - public override string ToString() => ((string?)Channel) + ":" + ((string?)Message); + private RedisSubscriber? _parent; - /// - public override int GetHashCode() => Channel.GetHashCode() ^ Message.GetHashCode(); - - /// - public override bool Equals(object? obj) => obj is ChannelMessage cm - && cm.Channel == Channel && cm.Message == Message; + /// + /// The string representation of this channel. + /// + public override string? ToString() => (string?)Channel; - internal ChannelMessage(ChannelMessageQueue queue, in RedisChannel channel, in RedisValue value) - { - _queue = queue; - Channel = channel; - Message = value; - } + /// + /// An awaitable task the indicates completion of the queue (including drain of data). + /// + public Task Completion => _queue.Reader.Completion; - /// - /// The channel that the subscription was created from. - /// - public RedisChannel SubscriptionChannel => _queue.Channel; - - /// - /// The channel that the message was broadcast to. - /// - public RedisChannel Channel { get; } - - /// - /// The value that was broadcast. - /// - public RedisValue Message { get; } - - /// - /// Checks if 2 messages are .Equal(). - /// - public static bool operator ==(ChannelMessage left, ChannelMessage right) => left.Equals(right); - - /// - /// Checks if 2 messages are not .Equal(). - /// - public static bool operator !=(ChannelMessage left, ChannelMessage right) => !left.Equals(right); + internal ChannelMessageQueue(in RedisChannel redisChannel, RedisSubscriber parent) + { + Channel = redisChannel; + _parent = parent; + _queue = System.Threading.Channels.Channel.CreateUnbounded(s_ChannelOptions); } - /// - /// Represents a message queue of ordered pub/sub notifications. - /// - /// - /// To create a ChannelMessageQueue, use - /// or . - /// - public sealed class ChannelMessageQueue : IAsyncEnumerable + private static readonly UnboundedChannelOptions s_ChannelOptions = new UnboundedChannelOptions { - private readonly Channel _queue; + SingleWriter = true, SingleReader = false, AllowSynchronousContinuations = false, + }; - /// - /// The Channel that was subscribed for this queue. - /// - public RedisChannel Channel { get; } - private RedisSubscriber? _parent; + private void Write(in RedisChannel channel, in RedisValue value) + { + var writer = _queue.Writer; + writer.TryWrite(new ChannelMessage(this, channel, value)); + } - /// - /// The string representation of this channel. - /// - public override string? ToString() => (string?)Channel; + /// + /// Consume a message from the channel. + /// + /// The to use. + public ValueTask ReadAsync(CancellationToken cancellationToken = default) + => _queue.Reader.ReadAsync(cancellationToken); - /// - /// An awaitable task the indicates completion of the queue (including drain of data). - /// - public Task Completion => _queue.Reader.Completion; + /// + /// Attempt to synchronously consume a message from the channel. + /// + /// The read from the Channel. + public bool TryRead(out ChannelMessage item) => _queue.Reader.TryRead(out item); - internal ChannelMessageQueue(in RedisChannel redisChannel, RedisSubscriber parent) + /// + /// Attempt to query the backlog length of the queue. + /// + /// The (approximate) count of items in the Channel. + public bool TryGetCount(out int count) + { + // This is specific to netcoreapp3.1, because full framework was out of band and the new prop is present +#if NETCOREAPP3_1 + // get this using the reflection + try { - Channel = redisChannel; - _parent = parent; - _queue = System.Threading.Channels.Channel.CreateUnbounded(s_ChannelOptions); + var prop = + _queue.GetType().GetProperty("ItemsCountForDebugger", BindingFlags.Instance | BindingFlags.NonPublic); + if (prop is not null) + { + count = (int)prop.GetValue(_queue)!; + return true; + } } - - private static readonly UnboundedChannelOptions s_ChannelOptions = new UnboundedChannelOptions + catch (Exception ex) { - SingleWriter = true, - SingleReader = false, - AllowSynchronousContinuations = false, - }; - - private void Write(in RedisChannel channel, in RedisValue value) + Debug.WriteLine(ex.Message); // but ignore + } +#else + var reader = _queue.Reader; + if (reader.CanCount) { - var writer = _queue.Writer; - writer.TryWrite(new ChannelMessage(this, channel, value)); + count = reader.Count; + return true; } +#endif - /// - /// Consume a message from the channel. - /// - /// The to use. - public ValueTask ReadAsync(CancellationToken cancellationToken = default) - => _queue.Reader.ReadAsync(cancellationToken); - - /// - /// Attempt to synchronously consume a message from the channel. - /// - /// The read from the Channel. - public bool TryRead(out ChannelMessage item) => _queue.Reader.TryRead(out item); - - /// - /// Attempt to query the backlog length of the queue. - /// - /// The (approximate) count of items in the Channel. - public bool TryGetCount(out int count) + count = 0; + return false; + } + + private Delegate? _onMessageHandler; + + private void AssertOnMessage(Delegate handler) + { + if (handler == null) throw new ArgumentNullException(nameof(handler)); + if (Interlocked.CompareExchange(ref _onMessageHandler, handler, null) != null) + throw new InvalidOperationException("Only a single " + nameof(OnMessage) + " is allowed"); + } + + /// + /// Create a message loop that processes messages sequentially. + /// + /// The handler to run when receiving a message. + public void OnMessage(Action handler) + { + AssertOnMessage(handler); + + ThreadPool.QueueUserWorkItem( + state => ((ChannelMessageQueue)state!).OnMessageSyncImpl().RedisFireAndForget(), this); + } + + private async Task OnMessageSyncImpl() + { + var handler = (Action?)_onMessageHandler; + while (!Completion.IsCompleted) { - // This is specific to netcoreapp3.1, because full framework was out of band and the new prop is present -#if NETCOREAPP3_1 - // get this using the reflection + ChannelMessage next; try { - var prop = _queue.GetType().GetProperty("ItemsCountForDebugger", BindingFlags.Instance | BindingFlags.NonPublic); - if (prop is not null) - { - count = (int)prop.GetValue(_queue)!; - return true; - } + if (!TryRead(out next)) next = await ReadAsync().ForAwait(); } - catch { } -#else - var reader = _queue.Reader; - if (reader.CanCount) + catch (ChannelClosedException) { break; } // expected + catch (Exception ex) { - count = reader.Count; - return true; + _parent?.multiplexer?.OnInternalError(ex); + break; } -#endif - count = default; - return false; + try { handler?.Invoke(next); } + catch { } // matches MessageCompletable } + } - private Delegate? _onMessageHandler; - private void AssertOnMessage(Delegate handler) + internal static void Combine(ref ChannelMessageQueue? head, ChannelMessageQueue queue) + { + if (queue != null) { - if (handler == null) throw new ArgumentNullException(nameof(handler)); - if (Interlocked.CompareExchange(ref _onMessageHandler, handler, null) != null) - throw new InvalidOperationException("Only a single " + nameof(OnMessage) + " is allowed"); + // insert at the start of the linked-list + ChannelMessageQueue? old; + do + { + old = Volatile.Read(ref head); + queue._next = old; + } + // format and validator disagree on newline... + while (Interlocked.CompareExchange(ref head, queue, old) != old); } + } - /// - /// Create a message loop that processes messages sequentially. - /// - /// The handler to run when receiving a message. - public void OnMessage(Action handler) - { - AssertOnMessage(handler); + /// + /// Create a message loop that processes messages sequentially. + /// + /// The handler to execute when receiving a message. + public void OnMessage(Func handler) + { + AssertOnMessage(handler); - ThreadPool.QueueUserWorkItem( - state => ((ChannelMessageQueue)state!).OnMessageSyncImpl().RedisFireAndForget(), this); - } + ThreadPool.QueueUserWorkItem( + state => ((ChannelMessageQueue)state!).OnMessageAsyncImpl().RedisFireAndForget(), this); + } - private async Task OnMessageSyncImpl() + internal static void Remove(ref ChannelMessageQueue? head, ChannelMessageQueue queue) + { + if (queue is null) { - var handler = (Action?)_onMessageHandler; - while (!Completion.IsCompleted) - { - ChannelMessage next; - try { if (!TryRead(out next)) next = await ReadAsync().ForAwait(); } - catch (ChannelClosedException) { break; } // expected - catch (Exception ex) - { - _parent?.multiplexer?.OnInternalError(ex); - break; - } - - try { handler?.Invoke(next); } - catch { } // matches MessageCompletable - } + return; } - internal static void Combine(ref ChannelMessageQueue? head, ChannelMessageQueue queue) + bool found; + // if we fail due to a conflict, re-do from start + do { - if (queue != null) + var current = Volatile.Read(ref head); + if (current == null) return; // no queue? nothing to do + if (current == queue) { - // insert at the start of the linked-list - ChannelMessageQueue? old; - do + found = true; + // found at the head - then we need to change the head + if (Interlocked.CompareExchange(ref head, Volatile.Read(ref current._next), current) == current) { - old = Volatile.Read(ref head); - queue._next = old; + return; // success } - while (Interlocked.CompareExchange(ref head, queue, old) != old); } - } - - /// - /// Create a message loop that processes messages sequentially. - /// - /// The handler to execute when receiving a message. - public void OnMessage(Func handler) - { - AssertOnMessage(handler); - - ThreadPool.QueueUserWorkItem( - state => ((ChannelMessageQueue)state!).OnMessageAsyncImpl().RedisFireAndForget(), this); - } - - internal static void Remove(ref ChannelMessageQueue? head, ChannelMessageQueue queue) - { - if (queue is null) + else { - return; - } - - bool found; - // if we fail due to a conflict, re-do from start - do - { - var current = Volatile.Read(ref head); - if (current == null) return; // no queue? nothing to do - if (current == queue) - { - found = true; - // found at the head - then we need to change the head - if (Interlocked.CompareExchange(ref head, Volatile.Read(ref current._next), current) == current) - { - return; // success - } - } - else + ChannelMessageQueue? previous = current; + current = Volatile.Read(ref previous._next); + found = false; + do { - ChannelMessageQueue? previous = current; - current = Volatile.Read(ref previous._next); - found = false; - do + if (current == queue) { - if (current == queue) + found = true; + // found it, not at the head; remove the node + if (Interlocked.CompareExchange( + ref previous._next, + Volatile.Read(ref current._next), + current) == current) { - found = true; - // found it, not at the head; remove the node - if (Interlocked.CompareExchange(ref previous._next, Volatile.Read(ref current._next), current) == current) - { - return; // success - } - else - { - break; // exit the inner loop, and repeat the outer loop - } + return; // success + } + else + { + break; // exit the inner loop, and repeat the outer loop } - previous = current; - current = Volatile.Read(ref previous!._next); } - while (current != null); + + previous = current; + current = Volatile.Read(ref previous!._next); } + // format and validator disagree on newline... + while (current != null); } - while (found); } + // format and validator disagree on newline... + while (found); + } - internal static int Count(ref ChannelMessageQueue? head) + internal static int Count(ref ChannelMessageQueue? head) + { + var current = Volatile.Read(ref head); + int count = 0; + while (current != null) { - var current = Volatile.Read(ref head); - int count = 0; - while (current != null) - { - count++; - current = Volatile.Read(ref current._next); - } - return count; + count++; + current = Volatile.Read(ref current._next); } - internal static void WriteAll(ref ChannelMessageQueue head, in RedisChannel channel, in RedisValue message) + return count; + } + + internal static void WriteAll(ref ChannelMessageQueue head, in RedisChannel channel, in RedisValue message) + { + var current = Volatile.Read(ref head); + while (current != null) { - var current = Volatile.Read(ref head); - while (current != null) - { - current.Write(channel, message); - current = Volatile.Read(ref current._next); - } + current.Write(channel, message); + current = Volatile.Read(ref current._next); } + } - private ChannelMessageQueue? _next; + private ChannelMessageQueue? _next; - private async Task OnMessageAsyncImpl() + private async Task OnMessageAsyncImpl() + { + var handler = (Func?)_onMessageHandler; + while (!Completion.IsCompleted) { - var handler = (Func?)_onMessageHandler; - while (!Completion.IsCompleted) + ChannelMessage next; + try { - ChannelMessage next; - try { if (!TryRead(out next)) next = await ReadAsync().ForAwait(); } - catch (ChannelClosedException) { break; } // expected - catch (Exception ex) - { - _parent?.multiplexer?.OnInternalError(ex); - break; - } - - try - { - var task = handler?.Invoke(next); - if (task != null && task.Status != TaskStatus.RanToCompletion) await task.ForAwait(); - } - catch { } // matches MessageCompletable + if (!TryRead(out next)) next = await ReadAsync().ForAwait(); + } + catch (ChannelClosedException) { break; } // expected + catch (Exception ex) + { + _parent?.multiplexer?.OnInternalError(ex); + break; } - } - internal static void MarkAllCompleted(ref ChannelMessageQueue? head) - { - var current = Interlocked.Exchange(ref head, null); - while (current != null) + try { - current.MarkCompleted(); - current = Volatile.Read(ref current._next); + var task = handler?.Invoke(next); + if (task != null && task.Status != TaskStatus.RanToCompletion) await task.ForAwait(); } + catch { } // matches MessageCompletable } + } - private void MarkCompleted(Exception? error = null) + internal static void MarkAllCompleted(ref ChannelMessageQueue? head) + { + var current = Interlocked.Exchange(ref head, null); + while (current != null) { - _parent = null; - _queue.Writer.TryComplete(error); + current.MarkCompleted(); + current = Volatile.Read(ref current._next); } + } - internal void UnsubscribeImpl(Exception? error = null, CommandFlags flags = CommandFlags.None) - { - var parent = _parent; - _parent = null; - parent?.UnsubscribeAsync(Channel, null, this, flags); - _queue.Writer.TryComplete(error); - } + private void MarkCompleted(Exception? error = null) + { + _parent = null; + _queue.Writer.TryComplete(error); + } - internal async Task UnsubscribeAsyncImpl(Exception? error = null, CommandFlags flags = CommandFlags.None) + internal void UnsubscribeImpl(Exception? error = null, CommandFlags flags = CommandFlags.None) + { + var parent = _parent; + _parent = null; + parent?.UnsubscribeAsync(Channel, null, this, flags); + _queue.Writer.TryComplete(error); + } + + internal async Task UnsubscribeAsyncImpl(Exception? error = null, CommandFlags flags = CommandFlags.None) + { + var parent = _parent; + _parent = null; + if (parent != null) { - var parent = _parent; - _parent = null; - if (parent != null) - { - await parent.UnsubscribeAsync(Channel, null, this, flags).ForAwait(); - } - _queue.Writer.TryComplete(error); + await parent.UnsubscribeAsync(Channel, null, this, flags).ForAwait(); } - /// - /// Stop receiving messages on this channel. - /// - /// The flags to use when unsubscribing. - public void Unsubscribe(CommandFlags flags = CommandFlags.None) => UnsubscribeImpl(null, flags); + _queue.Writer.TryComplete(error); + } - /// - /// Stop receiving messages on this channel. - /// - /// The flags to use when unsubscribing. - public Task UnsubscribeAsync(CommandFlags flags = CommandFlags.None) => UnsubscribeAsyncImpl(null, flags); + /// + /// Stop receiving messages on this channel. + /// + /// The flags to use when unsubscribing. + public void Unsubscribe(CommandFlags flags = CommandFlags.None) => UnsubscribeImpl(null, flags); - /// + /// + /// Stop receiving messages on this channel. + /// + /// The flags to use when unsubscribing. + public Task UnsubscribeAsync(CommandFlags flags = CommandFlags.None) => UnsubscribeAsyncImpl(null, flags); + + /// #if NETCOREAPP3_0_OR_GREATER - public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) - => _queue.Reader.ReadAllAsync().GetAsyncEnumerator(cancellationToken); + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + // ReSharper disable once MethodSupportsCancellation - provided in GetAsyncEnumerator + => _queue.Reader.ReadAllAsync().GetAsyncEnumerator(cancellationToken); #else - public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + while (await _queue.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) { - while (await _queue.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + while (_queue.Reader.TryRead(out var item)) { - while (_queue.Reader.TryRead(out var item)) - { - yield return item; - } + yield return item; } } -#endif } +#endif } diff --git a/src/StackExchange.Redis/ConnectionMultiplexer.cs b/src/StackExchange.Redis/ConnectionMultiplexer.cs index cc766338a..0c6148923 100644 --- a/src/StackExchange.Redis/ConnectionMultiplexer.cs +++ b/src/StackExchange.Redis/ConnectionMultiplexer.cs @@ -730,6 +730,7 @@ private static ConnectionMultiplexer ConnectImpl(ConfigurationOptions configurat ReadOnlySpan IInternalConnectionMultiplexer.GetServerSnapshot() => _serverSnapshot.AsSpan(); internal ReadOnlySpan GetServerSnapshot() => _serverSnapshot.AsSpan(); + internal ReadOnlyMemory GetServerSnaphotMemory() => _serverSnapshot.AsMemory(); internal sealed class ServerSnapshot : IEnumerable { public static ServerSnapshot Empty { get; } = new ServerSnapshot(Array.Empty(), 0); @@ -1281,6 +1282,10 @@ public long OperationCount } } + // note that the RedisChannel->byte[] converter is always direct, so this is not an alloc + // (we deal with channels far less frequently, so pay the encoding cost up-front) + internal byte[] ChannelPrefix => ((byte[]?)RawConfig.ChannelPrefix) ?? []; + /// /// Reconfigure the current connections based on the existing configuration. /// diff --git a/src/StackExchange.Redis/Format.cs b/src/StackExchange.Redis/Format.cs index 86aa9910d..a76b77afc 100644 --- a/src/StackExchange.Redis/Format.cs +++ b/src/StackExchange.Redis/Format.cs @@ -468,6 +468,31 @@ internal static int FormatDouble(double value, Span destination) #endif } + internal static int FormatDouble(double value, Span destination) + { + string s; + if (double.IsInfinity(value)) + { + s = double.IsPositiveInfinity(value) ? "+inf" : "-inf"; + if (!s.AsSpan().TryCopyTo(destination)) ThrowFormatFailed(); + return 4; + } + +#if NET + if (!value.TryFormat(destination, out int len, "G17", NumberFormatInfo.InvariantInfo)) + { + ThrowFormatFailed(); + } + + return len; +#else + s = value.ToString("G17", NumberFormatInfo.InvariantInfo); // this looks inefficient, but is how Utf8Formatter works too, just: more direct + if (s.Length > destination.Length) ThrowFormatFailed(); + s.AsSpan().CopyTo(destination); + return s.Length; +#endif + } + internal static int MeasureInt64(long value) { Span valueSpan = stackalloc byte[MaxInt64TextLen]; @@ -481,12 +506,38 @@ internal static int FormatInt64(long value, Span destination) return len; } + internal static int FormatInt64(long value, Span destination) + { +#if NET + if (!value.TryFormat(destination, out var len)) + ThrowFormatFailed(); + return len; +#else + Span buffer = stackalloc byte[MaxInt64TextLen]; + var bytes = FormatInt64(value, buffer); + return Encoding.UTF8.GetChars(buffer.Slice(0, bytes), destination); +#endif + } + internal static int MeasureUInt64(ulong value) { Span valueSpan = stackalloc byte[MaxInt64TextLen]; return FormatUInt64(value, valueSpan); } + internal static int FormatUInt64(ulong value, Span destination) + { +#if NET + if (!value.TryFormat(destination, out var len)) + ThrowFormatFailed(); + return len; +#else + Span buffer = stackalloc byte[MaxInt64TextLen]; + var bytes = FormatUInt64(value, buffer); + return Encoding.UTF8.GetChars(buffer.Slice(0, bytes), destination); +#endif + } + internal static int FormatUInt64(ulong value, Span destination) { if (!Utf8Formatter.TryFormat(value, destination, out var len)) @@ -501,6 +552,19 @@ internal static int FormatInt32(int value, Span destination) return len; } + internal static int FormatInt32(int value, Span destination) + { +#if NET + if (!value.TryFormat(destination, out var len)) + ThrowFormatFailed(); + return len; +#else + Span buffer = stackalloc byte[MaxInt32TextLen]; + var bytes = FormatInt32(value, buffer); + return Encoding.UTF8.GetChars(buffer.Slice(0, bytes), destination); +#endif + } + internal static bool TryParseVersion(ReadOnlySpan input, [NotNullWhen(true)] out Version? version) { #if NETCOREAPP3_1_OR_GREATER diff --git a/src/StackExchange.Redis/FrameworkShims.cs b/src/StackExchange.Redis/FrameworkShims.cs index 9472df9ae..c0fe4cb1d 100644 --- a/src/StackExchange.Redis/FrameworkShims.cs +++ b/src/StackExchange.Redis/FrameworkShims.cs @@ -15,6 +15,18 @@ internal static class IsExternalInit { } } #endif +#if !NET9_0_OR_GREATER +namespace System.Runtime.CompilerServices +{ + // see https://learn.microsoft.com/dotnet/api/system.runtime.compilerservices.overloadresolutionpriorityattribute + [AttributeUsage(AttributeTargets.Constructor | AttributeTargets.Method | AttributeTargets.Property, Inherited = false)] + internal sealed class OverloadResolutionPriorityAttribute(int priority) : Attribute + { + public int Priority => priority; + } +} +#endif + #if !(NETCOREAPP || NETSTANDARD2_1_OR_GREATER) namespace System.Text @@ -31,6 +43,33 @@ public static unsafe int GetBytes(this Encoding encoding, ReadOnlySpan sou } } } + + public static unsafe int GetChars(this Encoding encoding, ReadOnlySpan source, Span destination) + { + fixed (byte* bPtr = source) + { + fixed (char* cPtr = destination) + { + return encoding.GetChars(bPtr, source.Length, cPtr, destination.Length); + } + } + } + + public static unsafe int GetCharCount(this Encoding encoding, ReadOnlySpan source) + { + fixed (byte* bPtr = source) + { + return encoding.GetCharCount(bPtr, source.Length); + } + } + + public static unsafe string GetString(this Encoding encoding, ReadOnlySpan source) + { + fixed (byte* bPtr = source) + { + return encoding.GetString(bPtr, source.Length); + } + } } } #endif diff --git a/src/StackExchange.Redis/KeyNotification.cs b/src/StackExchange.Redis/KeyNotification.cs new file mode 100644 index 000000000..3427c4dce --- /dev/null +++ b/src/StackExchange.Redis/KeyNotification.cs @@ -0,0 +1,497 @@ +using System; +using System.Buffers; +using System.Buffers.Text; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Text; +using static StackExchange.Redis.KeyNotificationChannels; +namespace StackExchange.Redis; + +/// +/// Represents keyspace and keyevent notifications, with utility methods for accessing the component data. Additionally, +/// since notifications can be high volume, a range of utility APIs is provided for avoiding allocations, in particular +/// to assist in filtering and inspecting the key without performing string allocations and substring operations. +/// In particular, note that this allows use with the alt-lookup (span-based) APIs on dictionaries. +/// +public readonly ref struct KeyNotification +{ + // effectively we just wrap a channel, but: we've pre-validated that things make sense + private readonly RedisChannel _channel; + private readonly RedisValue _value; + private readonly int _keyOffset; // used to efficiently strip key prefixes + + // this type has been designed with the intent of being able to move the entire thing alloc-free in some future + // high-throughput callback, potentially with a ReadOnlySpan field for the key fragment; this is + // not implemented currently, but is why this is a ref struct + + /// + /// If the channel is either a keyspace or keyevent notification, resolve the key and event type. + /// + public static bool TryParse(scoped in RedisChannel channel, scoped in RedisValue value, out KeyNotification notification) + { + // validate that it looks reasonable + var span = channel.Span; + + // KeySpaceStart and KeyEventStart are the same size, see KeyEventPrefix_KeySpacePrefix_Length_Matches + if (span.Length >= KeySpacePrefix.Length + MinSuffixBytes) + { + // check that the prefix is valid, i.e. "__keyspace@" or "__keyevent@" + var prefix = span.Slice(0, KeySpacePrefix.Length); + var hash = prefix.Hash64(); + switch (hash) + { + case KeySpacePrefix.Hash when KeySpacePrefix.Is(hash, prefix): + case KeyEventPrefix.Hash when KeyEventPrefix.Is(hash, prefix): + // check that there is *something* non-empty after the prefix, with __: as the suffix (we don't verify *what*) + if (span.Slice(KeySpacePrefix.Length).IndexOf("__:"u8) > 0) + { + notification = new KeyNotification(in channel, in value); + return true; + } + + break; + } + } + + notification = default; + return false; + } + + /// + /// If the channel is either a keyspace or keyevent notification *with the requested prefix*, resolve the key and event type, + /// and remove the prefix when reading the key. + /// + public static bool TryParse(scoped in ReadOnlySpan keyPrefix, scoped in RedisChannel channel, scoped in RedisValue value, out KeyNotification notification) + { + if (TryParse(in channel, in value, out notification) && notification.KeyStartsWith(keyPrefix)) + { + notification = notification.WithKeySlice(keyPrefix.Length); + return true; + } + + notification = default; + return false; + } + + internal KeyNotification WithKeySlice(int keyPrefixLength) + { + KeyNotification result = this; + Unsafe.AsRef(in result._keyOffset) = keyPrefixLength; + return result; + } + + private const int MinSuffixBytes = 5; // need "0__:x" or similar after prefix + + /// + /// The channel associated with this notification. + /// + public RedisChannel GetChannel() => _channel; + + /// + /// The payload associated with this notification. + /// + public RedisValue GetValue() => _value; + + internal KeyNotification(scoped in RedisChannel channel, scoped in RedisValue value) + { + _channel = channel; + _value = value; + _keyOffset = 0; + } + + internal int KeyOffset => _keyOffset; + + /// + /// The database the key is in. If the database cannot be parsed, -1 is returned. + /// + public int Database + { + get + { + // prevalidated format, so we can just skip past the prefix (except for the default value) + if (_channel.IsNull) return -1; + var span = _channel.Span.Slice(KeySpacePrefix.Length); // also works for KeyEventPrefix + var end = span.IndexOf((byte)'_'); // expecting "__:foo" - we'll just stop at the underscore + if (end <= 0) return -1; + + span = span.Slice(0, end); + return Utf8Parser.TryParse(span, out int database, out var bytes) + && bytes == end ? database : -1; + } + } + + /// + /// The key associated with this event. + /// + /// Note that this will allocate a copy of the key bytes; to avoid allocations, + /// the , , and APIs can be used. + public RedisKey GetKey() + { + if (IsKeySpace) + { + // then the channel contains the key, and the payload contains the event-type + return ChannelSuffix.Slice(_keyOffset).ToArray(); // create an isolated copy + } + + if (IsKeyEvent) + { + // then the channel contains the event-type, and the payload contains the key + byte[]? blob = _value; + if (_keyOffset != 0 & blob is not null) + { + return blob.AsSpan(_keyOffset).ToArray(); + } + return blob; + } + + return RedisKey.Null; + } + + /// + /// Get the number of bytes in the key. + /// + /// If a scratch-buffer is required, it may be preferable to use , which is less expensive. + public int GetKeyByteCount() + { + if (IsKeySpace) + { + return ChannelSuffix.Length - _keyOffset; + } + + if (IsKeyEvent) + { + return _value.GetByteCount() - _keyOffset; + } + + return 0; + } + + /// + /// Get the maximum number of bytes in the key. + /// + public int GetKeyMaxByteCount() + { + if (IsKeySpace) + { + return ChannelSuffix.Length - _keyOffset; + } + + if (IsKeyEvent) + { + return _value.GetMaxByteCount() - _keyOffset; + } + + return 0; + } + + /// + /// Get the maximum number of characters in the key, interpreting as UTF8. + /// + public int GetKeyMaxCharCount() + { + if (IsKeySpace) + { + return Encoding.UTF8.GetMaxCharCount(ChannelSuffix.Length - _keyOffset); + } + + if (IsKeyEvent) + { + return _value.GetMaxCharCount() - _keyOffset; + } + + return 0; + } + + /// + /// Get the number of characters in the key, interpreting as UTF8. + /// + /// If a scratch-buffer is required, it may be preferable to use , which is less expensive. + public int GetKeyCharCount() + { + if (IsKeySpace) + { + return Encoding.UTF8.GetCharCount(ChannelSuffix.Slice(_keyOffset)); + } + + if (IsKeyEvent) + { + return _keyOffset == 0 ? _value.GetCharCount() : SlowMeasure(in this); + } + + return 0; + + static int SlowMeasure(in KeyNotification value) + { + var span = value.GetKeySpan(out var lease, stackalloc byte[128]); + var result = Encoding.UTF8.GetCharCount(span); + Return(lease); + return result; + } + } + + private ReadOnlySpan GetKeySpan(out byte[]? lease, Span buffer) // buffer typically stackalloc + { + lease = null; + if (_value.TryGetSpan(out var direct)) + { + return direct.Slice(_keyOffset); + } + var count = _value.GetMaxByteCount(); + if (count > buffer.Length) + { + buffer = lease = ArrayPool.Shared.Rent(count); + } + count = _value.CopyTo(buffer); + return buffer.Slice(_keyOffset, count - _keyOffset); + } + + private static void Return(byte[]? lease) + { + if (lease is not null) ArrayPool.Shared.Return(lease); + } + + /// + /// Attempt to copy the bytes from the key to a buffer, returning the number of bytes written. + /// + public bool TryCopyKey(Span destination, out int bytesWritten) + { + if (IsKeySpace) + { + var suffix = ChannelSuffix.Slice(_keyOffset); + bytesWritten = suffix.Length; // assume success + if (bytesWritten <= destination.Length) + { + suffix.CopyTo(destination); + return true; + } + } + + if (IsKeyEvent) + { + if (_value.TryGetSpan(out var direct)) + { + bytesWritten = direct.Length - _keyOffset; // assume success + if (bytesWritten <= destination.Length) + { + direct.Slice(_keyOffset).CopyTo(destination); + return true; + } + bytesWritten = 0; + return false; + } + + if (_keyOffset == 0) + { + // get the value to do the hard work + bytesWritten = _value.GetByteCount(); + if (bytesWritten <= destination.Length) + { + _value.CopyTo(destination); + return true; + } + bytesWritten = 0; + return false; + } + + return SlowCopy(in this, destination, out bytesWritten); + + static bool SlowCopy(in KeyNotification value, Span destination, out int bytesWritten) + { + var span = value.GetKeySpan(out var lease, stackalloc byte[128]); + bool result = span.TryCopyTo(destination); + bytesWritten = result ? span.Length : 0; + Return(lease); + return result; + } + } + + bytesWritten = 0; + return false; + } + + /// + /// Attempt to copy the bytes from the key to a buffer, returning the number of bytes written. + /// + public bool TryCopyKey(Span destination, out int charsWritten) + { + if (IsKeySpace) + { + var suffix = ChannelSuffix.Slice(_keyOffset); + if (Encoding.UTF8.GetMaxCharCount(suffix.Length) <= destination.Length || + Encoding.UTF8.GetCharCount(suffix) <= destination.Length) + { + charsWritten = Encoding.UTF8.GetChars(suffix, destination); + return true; + } + } + + if (IsKeyEvent) + { + if (_keyOffset == 0) // can use short-cut + { + if (_value.GetMaxCharCount() <= destination.Length || _value.GetCharCount() <= destination.Length) + { + charsWritten = _value.CopyTo(destination); + return true; + } + } + var span = GetKeySpan(out var lease, stackalloc byte[128]); + charsWritten = 0; + bool result = false; + if (Encoding.UTF8.GetMaxCharCount(span.Length) <= destination.Length || + Encoding.UTF8.GetCharCount(span) <= destination.Length) + { + charsWritten = Encoding.UTF8.GetChars(span, destination); + result = true; + } + Return(lease); + return result; + } + + charsWritten = 0; + return false; + } + + /// + /// Get the portion of the channel after the "__{keyspace|keyevent}@{db}__:". + /// + private ReadOnlySpan ChannelSuffix + { + get + { + var span = _channel.Span; + var index = span.IndexOf("__:"u8); + return index > 0 ? span.Slice(index + 3) : default; + } + } + + /// + /// Indicates whether this notification is of the given type, specified as raw bytes. + /// + /// This is especially useful for working with unknown event types, but repeated calls to this method will be more expensive than + /// a single successful call to . + public bool IsType(ReadOnlySpan type) + { + if (IsKeySpace) + { + if (_value.TryGetSpan(out var direct)) + { + return direct.SequenceEqual(type); + } + + const int MAX_STACK = 64; + byte[]? lease = null; + var maxCount = _value.GetMaxByteCount(); + Span localCopy = maxCount <= MAX_STACK + ? stackalloc byte[MAX_STACK] + : (lease = ArrayPool.Shared.Rent(maxCount)); + var count = _value.CopyTo(localCopy); + bool result = localCopy.Slice(0, count).SequenceEqual(type); + if (lease is not null) ArrayPool.Shared.Return(lease); + return result; + } + + if (IsKeyEvent) + { + return ChannelSuffix.SequenceEqual(type); + } + + return false; + } + + /// + /// The type of notification associated with this event, if it is well-known - otherwise . + /// + /// Unexpected values can be processed manually from the and . + public KeyNotificationType Type + { + get + { + if (IsKeySpace) + { + // then the channel contains the key, and the payload contains the event-type + var count = _value.GetByteCount(); + if (count >= KeyNotificationTypeFastHash.MinBytes & count <= KeyNotificationTypeFastHash.MaxBytes) + { + if (_value.TryGetSpan(out var direct)) + { + return KeyNotificationTypeFastHash.Parse(direct); + } + else + { + Span localCopy = stackalloc byte[KeyNotificationTypeFastHash.MaxBytes]; + return KeyNotificationTypeFastHash.Parse(localCopy.Slice(0, _value.CopyTo(localCopy))); + } + } + } + + if (IsKeyEvent) + { + // then the channel contains the event-type, and the payload contains the key + return KeyNotificationTypeFastHash.Parse(ChannelSuffix); + } + return KeyNotificationType.Unknown; + } + } + + /// + /// Indicates whether this notification originated from a keyspace notification, for example __keyspace@4__:mykey with payload set. + /// + public bool IsKeySpace + { + get + { + var span = _channel.Span; + return span.Length >= KeySpacePrefix.Length + MinSuffixBytes && KeySpacePrefix.Is(span.Hash64(), span.Slice(0, KeySpacePrefix.Length)); + } + } + + /// + /// Indicates whether this notification originated from a keyevent notification, for example __keyevent@4__:set with payload mykey. + /// + public bool IsKeyEvent + { + get + { + var span = _channel.Span; + return span.Length >= KeyEventPrefix.Length + MinSuffixBytes && KeyEventPrefix.Is(span.Hash64(), span.Slice(0, KeyEventPrefix.Length)); + } + } + + /// + /// Indicates whether the key associated with this notification starts with the specified prefix. + /// + /// This API is intended as a high-throughput filter API. + public bool KeyStartsWith(ReadOnlySpan prefix) // intentionally leading people to the BLOB API + { + if (IsKeySpace) + { + return ChannelSuffix.Slice(_keyOffset).StartsWith(prefix); + } + + if (IsKeyEvent) + { + if (_keyOffset == 0) return _value.StartsWith(prefix); + + var span = GetKeySpan(out var lease, stackalloc byte[128]); + bool result = span.StartsWith(prefix); + Return(lease); + return result; + } + + return false; + } +} + +internal static partial class KeyNotificationChannels +{ + [FastHash("__keyspace@")] + internal static partial class KeySpacePrefix + { + } + + [FastHash("__keyevent@")] + internal static partial class KeyEventPrefix + { + } +} diff --git a/src/StackExchange.Redis/KeyNotificationType.cs b/src/StackExchange.Redis/KeyNotificationType.cs new file mode 100644 index 000000000..cc4c74ef1 --- /dev/null +++ b/src/StackExchange.Redis/KeyNotificationType.cs @@ -0,0 +1,69 @@ +namespace StackExchange.Redis; + +/// +/// The type of keyspace or keyevent notification. +/// +public enum KeyNotificationType +{ + // note: initially presented alphabetically, but: new values *must* be appended, not inserted + // (to preserve values of existing elements) +#pragma warning disable CS1591 // docs, redundant + Unknown = 0, + Append = 1, + Copy = 2, + Del = 3, + Expire = 4, + HDel = 5, + HExpired = 6, + HIncrByFloat = 7, + HIncrBy = 8, + HPersist = 9, + HSet = 10, + IncrByFloat = 11, + IncrBy = 12, + LInsert = 13, + LPop = 14, + LPush = 15, + LRem = 16, + LSet = 17, + LTrim = 18, + MoveFrom = 19, + MoveTo = 20, + Persist = 21, + RenameFrom = 22, + RenameTo = 23, + Restore = 24, + RPop = 25, + RPush = 26, + SAdd = 27, + Set = 28, + SetRange = 29, + SortStore = 30, + SRem = 31, + SPop = 32, + XAdd = 33, + XDel = 34, + XGroupCreateConsumer = 35, + XGroupCreate = 36, + XGroupDelConsumer = 37, + XGroupDestroy = 38, + XGroupSetId = 39, + XSetId = 40, + XTrim = 41, + ZAdd = 42, + ZDiffStore = 43, + ZInterStore = 44, + ZUnionStore = 45, + ZIncr = 46, + ZRemByRank = 47, + ZRemByScore = 48, + ZRem = 49, + + // side-effect notifications + Expired = 1000, + Evicted = 1001, + New = 1002, + Overwritten = 1003, + TypeChanged = 1004, // type_changed +#pragma warning restore CS1591 // docs, redundant +} diff --git a/src/StackExchange.Redis/KeyNotificationTypeFastHash.cs b/src/StackExchange.Redis/KeyNotificationTypeFastHash.cs new file mode 100644 index 000000000..bcf08bad2 --- /dev/null +++ b/src/StackExchange.Redis/KeyNotificationTypeFastHash.cs @@ -0,0 +1,413 @@ +using System; + +namespace StackExchange.Redis; + +/// +/// Internal helper type for fast parsing of key notification types, using [FastHash]. +/// +internal static partial class KeyNotificationTypeFastHash +{ + // these are checked by KeyNotificationTypeFastHash_MinMaxBytes_ReflectsActualLengths + public const int MinBytes = 3, MaxBytes = 21; + + public static KeyNotificationType Parse(ReadOnlySpan value) + { + var hash = value.Hash64(); + return hash switch + { + append.Hash when append.Is(hash, value) => KeyNotificationType.Append, + copy.Hash when copy.Is(hash, value) => KeyNotificationType.Copy, + del.Hash when del.Is(hash, value) => KeyNotificationType.Del, + expire.Hash when expire.Is(hash, value) => KeyNotificationType.Expire, + hdel.Hash when hdel.Is(hash, value) => KeyNotificationType.HDel, + hexpired.Hash when hexpired.Is(hash, value) => KeyNotificationType.HExpired, + hincrbyfloat.Hash when hincrbyfloat.Is(hash, value) => KeyNotificationType.HIncrByFloat, + hincrby.Hash when hincrby.Is(hash, value) => KeyNotificationType.HIncrBy, + hpersist.Hash when hpersist.Is(hash, value) => KeyNotificationType.HPersist, + hset.Hash when hset.Is(hash, value) => KeyNotificationType.HSet, + incrbyfloat.Hash when incrbyfloat.Is(hash, value) => KeyNotificationType.IncrByFloat, + incrby.Hash when incrby.Is(hash, value) => KeyNotificationType.IncrBy, + linsert.Hash when linsert.Is(hash, value) => KeyNotificationType.LInsert, + lpop.Hash when lpop.Is(hash, value) => KeyNotificationType.LPop, + lpush.Hash when lpush.Is(hash, value) => KeyNotificationType.LPush, + lrem.Hash when lrem.Is(hash, value) => KeyNotificationType.LRem, + lset.Hash when lset.Is(hash, value) => KeyNotificationType.LSet, + ltrim.Hash when ltrim.Is(hash, value) => KeyNotificationType.LTrim, + move_from.Hash when move_from.Is(hash, value) => KeyNotificationType.MoveFrom, + move_to.Hash when move_to.Is(hash, value) => KeyNotificationType.MoveTo, + persist.Hash when persist.Is(hash, value) => KeyNotificationType.Persist, + rename_from.Hash when rename_from.Is(hash, value) => KeyNotificationType.RenameFrom, + rename_to.Hash when rename_to.Is(hash, value) => KeyNotificationType.RenameTo, + restore.Hash when restore.Is(hash, value) => KeyNotificationType.Restore, + rpop.Hash when rpop.Is(hash, value) => KeyNotificationType.RPop, + rpush.Hash when rpush.Is(hash, value) => KeyNotificationType.RPush, + sadd.Hash when sadd.Is(hash, value) => KeyNotificationType.SAdd, + set.Hash when set.Is(hash, value) => KeyNotificationType.Set, + setrange.Hash when setrange.Is(hash, value) => KeyNotificationType.SetRange, + sortstore.Hash when sortstore.Is(hash, value) => KeyNotificationType.SortStore, + srem.Hash when srem.Is(hash, value) => KeyNotificationType.SRem, + spop.Hash when spop.Is(hash, value) => KeyNotificationType.SPop, + xadd.Hash when xadd.Is(hash, value) => KeyNotificationType.XAdd, + xdel.Hash when xdel.Is(hash, value) => KeyNotificationType.XDel, + xgroup_createconsumer.Hash when xgroup_createconsumer.Is(hash, value) => KeyNotificationType.XGroupCreateConsumer, + xgroup_create.Hash when xgroup_create.Is(hash, value) => KeyNotificationType.XGroupCreate, + xgroup_delconsumer.Hash when xgroup_delconsumer.Is(hash, value) => KeyNotificationType.XGroupDelConsumer, + xgroup_destroy.Hash when xgroup_destroy.Is(hash, value) => KeyNotificationType.XGroupDestroy, + xgroup_setid.Hash when xgroup_setid.Is(hash, value) => KeyNotificationType.XGroupSetId, + xsetid.Hash when xsetid.Is(hash, value) => KeyNotificationType.XSetId, + xtrim.Hash when xtrim.Is(hash, value) => KeyNotificationType.XTrim, + zadd.Hash when zadd.Is(hash, value) => KeyNotificationType.ZAdd, + zdiffstore.Hash when zdiffstore.Is(hash, value) => KeyNotificationType.ZDiffStore, + zinterstore.Hash when zinterstore.Is(hash, value) => KeyNotificationType.ZInterStore, + zunionstore.Hash when zunionstore.Is(hash, value) => KeyNotificationType.ZUnionStore, + zincr.Hash when zincr.Is(hash, value) => KeyNotificationType.ZIncr, + zrembyrank.Hash when zrembyrank.Is(hash, value) => KeyNotificationType.ZRemByRank, + zrembyscore.Hash when zrembyscore.Is(hash, value) => KeyNotificationType.ZRemByScore, + zrem.Hash when zrem.Is(hash, value) => KeyNotificationType.ZRem, + expired.Hash when expired.Is(hash, value) => KeyNotificationType.Expired, + evicted.Hash when evicted.Is(hash, value) => KeyNotificationType.Evicted, + _new.Hash when _new.Is(hash, value) => KeyNotificationType.New, + overwritten.Hash when overwritten.Is(hash, value) => KeyNotificationType.Overwritten, + type_changed.Hash when type_changed.Is(hash, value) => KeyNotificationType.TypeChanged, + _ => KeyNotificationType.Unknown, + }; + } + + internal static ReadOnlySpan GetRawBytes(KeyNotificationType type) + { + return type switch + { + KeyNotificationType.Append => append.U8, + KeyNotificationType.Copy => copy.U8, + KeyNotificationType.Del => del.U8, + KeyNotificationType.Expire => expire.U8, + KeyNotificationType.HDel => hdel.U8, + KeyNotificationType.HExpired => hexpired.U8, + KeyNotificationType.HIncrByFloat => hincrbyfloat.U8, + KeyNotificationType.HIncrBy => hincrby.U8, + KeyNotificationType.HPersist => hpersist.U8, + KeyNotificationType.HSet => hset.U8, + KeyNotificationType.IncrByFloat => incrbyfloat.U8, + KeyNotificationType.IncrBy => incrby.U8, + KeyNotificationType.LInsert => linsert.U8, + KeyNotificationType.LPop => lpop.U8, + KeyNotificationType.LPush => lpush.U8, + KeyNotificationType.LRem => lrem.U8, + KeyNotificationType.LSet => lset.U8, + KeyNotificationType.LTrim => ltrim.U8, + KeyNotificationType.MoveFrom => move_from.U8, + KeyNotificationType.MoveTo => move_to.U8, + KeyNotificationType.Persist => persist.U8, + KeyNotificationType.RenameFrom => rename_from.U8, + KeyNotificationType.RenameTo => rename_to.U8, + KeyNotificationType.Restore => restore.U8, + KeyNotificationType.RPop => rpop.U8, + KeyNotificationType.RPush => rpush.U8, + KeyNotificationType.SAdd => sadd.U8, + KeyNotificationType.Set => set.U8, + KeyNotificationType.SetRange => setrange.U8, + KeyNotificationType.SortStore => sortstore.U8, + KeyNotificationType.SRem => srem.U8, + KeyNotificationType.SPop => spop.U8, + KeyNotificationType.XAdd => xadd.U8, + KeyNotificationType.XDel => xdel.U8, + KeyNotificationType.XGroupCreateConsumer => xgroup_createconsumer.U8, + KeyNotificationType.XGroupCreate => xgroup_create.U8, + KeyNotificationType.XGroupDelConsumer => xgroup_delconsumer.U8, + KeyNotificationType.XGroupDestroy => xgroup_destroy.U8, + KeyNotificationType.XGroupSetId => xgroup_setid.U8, + KeyNotificationType.XSetId => xsetid.U8, + KeyNotificationType.XTrim => xtrim.U8, + KeyNotificationType.ZAdd => zadd.U8, + KeyNotificationType.ZDiffStore => zdiffstore.U8, + KeyNotificationType.ZInterStore => zinterstore.U8, + KeyNotificationType.ZUnionStore => zunionstore.U8, + KeyNotificationType.ZIncr => zincr.U8, + KeyNotificationType.ZRemByRank => zrembyrank.U8, + KeyNotificationType.ZRemByScore => zrembyscore.U8, + KeyNotificationType.ZRem => zrem.U8, + KeyNotificationType.Expired => expired.U8, + KeyNotificationType.Evicted => evicted.U8, + KeyNotificationType.New => _new.U8, + KeyNotificationType.Overwritten => overwritten.U8, + KeyNotificationType.TypeChanged => type_changed.U8, + _ => Throw(), + }; + static ReadOnlySpan Throw() => throw new ArgumentOutOfRangeException(nameof(type)); + } + +#pragma warning disable SA1300, CS8981 + // ReSharper disable InconsistentNaming + [FastHash] + internal static partial class append + { + } + + [FastHash] + internal static partial class copy + { + } + + [FastHash] + internal static partial class del + { + } + + [FastHash] + internal static partial class expire + { + } + + [FastHash] + internal static partial class hdel + { + } + + [FastHash] + internal static partial class hexpired + { + } + + [FastHash] + internal static partial class hincrbyfloat + { + } + + [FastHash] + internal static partial class hincrby + { + } + + [FastHash] + internal static partial class hpersist + { + } + + [FastHash] + internal static partial class hset + { + } + + [FastHash] + internal static partial class incrbyfloat + { + } + + [FastHash] + internal static partial class incrby + { + } + + [FastHash] + internal static partial class linsert + { + } + + [FastHash] + internal static partial class lpop + { + } + + [FastHash] + internal static partial class lpush + { + } + + [FastHash] + internal static partial class lrem + { + } + + [FastHash] + internal static partial class lset + { + } + + [FastHash] + internal static partial class ltrim + { + } + + [FastHash("move_from")] // by default, the generator interprets underscore as hyphen + internal static partial class move_from + { + } + + [FastHash("move_to")] // by default, the generator interprets underscore as hyphen + internal static partial class move_to + { + } + + [FastHash] + internal static partial class persist + { + } + + [FastHash("rename_from")] // by default, the generator interprets underscore as hyphen + internal static partial class rename_from + { + } + + [FastHash("rename_to")] // by default, the generator interprets underscore as hyphen + internal static partial class rename_to + { + } + + [FastHash] + internal static partial class restore + { + } + + [FastHash] + internal static partial class rpop + { + } + + [FastHash] + internal static partial class rpush + { + } + + [FastHash] + internal static partial class sadd + { + } + + [FastHash] + internal static partial class set + { + } + + [FastHash] + internal static partial class setrange + { + } + + [FastHash] + internal static partial class sortstore + { + } + + [FastHash] + internal static partial class srem + { + } + + [FastHash] + internal static partial class spop + { + } + + [FastHash] + internal static partial class xadd + { + } + + [FastHash] + internal static partial class xdel + { + } + + [FastHash] // note: becomes hyphenated + internal static partial class xgroup_createconsumer + { + } + + [FastHash] // note: becomes hyphenated + internal static partial class xgroup_create + { + } + + [FastHash] // note: becomes hyphenated + internal static partial class xgroup_delconsumer + { + } + + [FastHash] // note: becomes hyphenated + internal static partial class xgroup_destroy + { + } + + [FastHash] // note: becomes hyphenated + internal static partial class xgroup_setid + { + } + + [FastHash] + internal static partial class xsetid + { + } + + [FastHash] + internal static partial class xtrim + { + } + + [FastHash] + internal static partial class zadd + { + } + + [FastHash] + internal static partial class zdiffstore + { + } + + [FastHash] + internal static partial class zinterstore + { + } + + [FastHash] + internal static partial class zunionstore + { + } + + [FastHash] + internal static partial class zincr + { + } + + [FastHash] + internal static partial class zrembyrank + { + } + + [FastHash] + internal static partial class zrembyscore + { + } + + [FastHash] + internal static partial class zrem + { + } + + [FastHash] + internal static partial class expired + { + } + + [FastHash] + internal static partial class evicted + { + } + + [FastHash("new")] + internal static partial class _new // it isn't worth making the code-gen keyword aware + { + } + + [FastHash] + internal static partial class overwritten + { + } + + [FastHash("type_changed")] // by default, the generator interprets underscore as hyphen + internal static partial class type_changed + { + } + + // ReSharper restore InconsistentNaming +#pragma warning restore SA1300, CS8981 +} diff --git a/src/StackExchange.Redis/Message.cs b/src/StackExchange.Redis/Message.cs index 386d426d8..37472fd4c 100644 --- a/src/StackExchange.Redis/Message.cs +++ b/src/StackExchange.Redis/Message.cs @@ -890,7 +890,8 @@ protected CommandKeyBase(int db, CommandFlags flags, RedisCommand command, in Re private sealed class CommandChannelMessage : CommandChannelBase { - public CommandChannelMessage(int db, CommandFlags flags, RedisCommand command, in RedisChannel channel) : base(db, flags, command, channel) + public CommandChannelMessage(int db, CommandFlags flags, RedisCommand command, in RedisChannel channel) + : base(db, flags, command, channel) { } protected override void WriteImpl(PhysicalConnection physical) { @@ -903,7 +904,8 @@ protected override void WriteImpl(PhysicalConnection physical) private sealed class CommandChannelValueMessage : CommandChannelBase { private readonly RedisValue value; - public CommandChannelValueMessage(int db, CommandFlags flags, RedisCommand command, in RedisChannel channel, in RedisValue value) : base(db, flags, command, channel) + public CommandChannelValueMessage(int db, CommandFlags flags, RedisCommand command, in RedisChannel channel, in RedisValue value) + : base(db, flags, command, channel) { value.AssertNotNull(); this.value = value; @@ -1746,7 +1748,8 @@ protected override void WriteImpl(PhysicalConnection physical) private sealed class CommandValueChannelMessage : CommandChannelBase { private readonly RedisValue value; - public CommandValueChannelMessage(int db, CommandFlags flags, RedisCommand command, in RedisValue value, in RedisChannel channel) : base(db, flags, command, channel) + public CommandValueChannelMessage(int db, CommandFlags flags, RedisCommand command, in RedisValue value, in RedisChannel channel) + : base(db, flags, command, channel) { value.AssertNotNull(); this.value = value; diff --git a/src/StackExchange.Redis/PhysicalConnection.cs b/src/StackExchange.Redis/PhysicalConnection.cs index 57bcd608d..857902f48 100644 --- a/src/StackExchange.Redis/PhysicalConnection.cs +++ b/src/StackExchange.Redis/PhysicalConnection.cs @@ -90,7 +90,7 @@ public PhysicalConnection(PhysicalBridge bridge) lastBeatTickCount = 0; connectionType = bridge.ConnectionType; _bridge = new WeakReference(bridge); - ChannelPrefix = bridge.Multiplexer.RawConfig.ChannelPrefix; + ChannelPrefix = bridge.Multiplexer.ChannelPrefix; if (ChannelPrefix?.Length == 0) ChannelPrefix = null; // null tests are easier than null+empty var endpoint = bridge.ServerEndPoint.EndPoint; _physicalName = connectionType + "#" + Interlocked.Increment(ref totalCount) + "@" + Format.ToString(endpoint); @@ -820,7 +820,7 @@ internal void Write(in RedisKey key) } internal void Write(in RedisChannel channel) - => WriteUnifiedPrefixedBlob(_ioPipe?.Output, ChannelPrefix, channel.Value); + => WriteUnifiedPrefixedBlob(_ioPipe?.Output, channel.IgnoreChannelPrefix ? null : ChannelPrefix, channel.Value); [MethodImpl(MethodImplOptions.AggressiveInlining)] internal void WriteBulkString(in RedisValue value) @@ -1999,7 +1999,7 @@ static bool TryGetMultiPubSubPayload(in RawResult value, out Sequence } } - private bool PeekChannelMessage(RedisCommand command, RedisChannel channel) + private bool PeekChannelMessage(RedisCommand command, in RedisChannel channel) { Message? msg; bool haveMsg; diff --git a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt index 91b0e1a43..6e96ed550 100644 --- a/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/StackExchange.Redis/PublicAPI/PublicAPI.Unshipped.txt @@ -1 +1,85 @@ -#nullable enable \ No newline at end of file +#nullable enable +StackExchange.Redis.ChannelMessage.TryParseKeyNotification(System.ReadOnlySpan keyPrefix, out StackExchange.Redis.KeyNotification notification) -> bool +StackExchange.Redis.KeyNotification +StackExchange.Redis.KeyNotification.GetChannel() -> StackExchange.Redis.RedisChannel +StackExchange.Redis.KeyNotification.GetKeyByteCount() -> int +StackExchange.Redis.KeyNotification.GetKeyCharCount() -> int +StackExchange.Redis.KeyNotification.GetKeyMaxByteCount() -> int +StackExchange.Redis.KeyNotification.GetKeyMaxCharCount() -> int +StackExchange.Redis.KeyNotification.GetValue() -> StackExchange.Redis.RedisValue +StackExchange.Redis.KeyNotification.IsType(System.ReadOnlySpan type) -> bool +StackExchange.Redis.KeyNotification.KeyStartsWith(System.ReadOnlySpan prefix) -> bool +StackExchange.Redis.KeyNotification.TryCopyKey(System.Span destination, out int charsWritten) -> bool +StackExchange.Redis.KeyNotification.Database.get -> int +StackExchange.Redis.KeyNotification.GetKey() -> StackExchange.Redis.RedisKey +StackExchange.Redis.KeyNotification.IsKeyEvent.get -> bool +StackExchange.Redis.KeyNotification.IsKeySpace.get -> bool +StackExchange.Redis.KeyNotification.KeyNotification() -> void +StackExchange.Redis.KeyNotification.TryCopyKey(System.Span destination, out int bytesWritten) -> bool +StackExchange.Redis.KeyNotification.Type.get -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.RedisValue.StartsWith(System.ReadOnlySpan value) -> bool +static StackExchange.Redis.KeyNotification.TryParse(scoped in StackExchange.Redis.RedisChannel channel, scoped in StackExchange.Redis.RedisValue value, out StackExchange.Redis.KeyNotification notification) -> bool +StackExchange.Redis.ChannelMessage.TryParseKeyNotification(out StackExchange.Redis.KeyNotification notification) -> bool +static StackExchange.Redis.KeyNotification.TryParse(scoped in System.ReadOnlySpan keyPrefix, scoped in StackExchange.Redis.RedisChannel channel, scoped in StackExchange.Redis.RedisValue value, out StackExchange.Redis.KeyNotification notification) -> bool +static StackExchange.Redis.RedisChannel.KeyEvent(StackExchange.Redis.KeyNotificationType type, int? database = null) -> StackExchange.Redis.RedisChannel +static StackExchange.Redis.RedisChannel.KeyEvent(System.ReadOnlySpan type, int? database) -> StackExchange.Redis.RedisChannel +static StackExchange.Redis.RedisChannel.KeySpacePattern(in StackExchange.Redis.RedisKey pattern, int? database = null) -> StackExchange.Redis.RedisChannel +StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.Append = 1 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.Copy = 2 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.Del = 3 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.Evicted = 1001 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.Expire = 4 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.Expired = 1000 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.HDel = 5 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.HExpired = 6 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.HIncrBy = 8 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.HIncrByFloat = 7 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.HPersist = 9 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.HSet = 10 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.IncrBy = 12 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.IncrByFloat = 11 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.LInsert = 13 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.LPop = 14 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.LPush = 15 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.LRem = 16 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.LSet = 17 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.LTrim = 18 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.MoveFrom = 19 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.MoveTo = 20 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.New = 1002 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.Overwritten = 1003 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.Persist = 21 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.RenameFrom = 22 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.RenameTo = 23 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.Restore = 24 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.RPop = 25 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.RPush = 26 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.SAdd = 27 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.Set = 28 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.SetRange = 29 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.SortStore = 30 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.SPop = 32 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.SRem = 31 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.TypeChanged = 1004 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.Unknown = 0 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.XAdd = 33 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.XDel = 34 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.XGroupCreate = 36 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.XGroupCreateConsumer = 35 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.XGroupDelConsumer = 37 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.XGroupDestroy = 38 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.XGroupSetId = 39 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.XSetId = 40 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.XTrim = 41 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.ZAdd = 42 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.ZDiffStore = 43 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.ZIncr = 46 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.ZInterStore = 44 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.ZRem = 49 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.ZRemByRank = 47 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.ZRemByScore = 48 -> StackExchange.Redis.KeyNotificationType +StackExchange.Redis.KeyNotificationType.ZUnionStore = 45 -> StackExchange.Redis.KeyNotificationType +static StackExchange.Redis.RedisChannel.KeySpacePrefix(in StackExchange.Redis.RedisKey prefix, int? database = null) -> StackExchange.Redis.RedisChannel +static StackExchange.Redis.RedisChannel.KeySpacePrefix(System.ReadOnlySpan prefix, int? database = null) -> StackExchange.Redis.RedisChannel +static StackExchange.Redis.RedisChannel.KeySpaceSingleKey(in StackExchange.Redis.RedisKey key, int database) -> StackExchange.Redis.RedisChannel diff --git a/src/StackExchange.Redis/RawResult.cs b/src/StackExchange.Redis/RawResult.cs index 1ac9f081a..e1c91b74e 100644 --- a/src/StackExchange.Redis/RawResult.cs +++ b/src/StackExchange.Redis/RawResult.cs @@ -161,22 +161,34 @@ public bool MoveNext() } public ReadOnlySequence Current { get; private set; } } + internal RedisChannel AsRedisChannel(byte[]? channelPrefix, RedisChannel.RedisChannelOptions options) { switch (Resp2TypeBulkString) { case ResultType.SimpleString: case ResultType.BulkString: - if (channelPrefix == null) + if (channelPrefix is null) { + // no channel-prefix enabled, just use as-is return new RedisChannel(GetBlob(), options); } if (StartsWith(channelPrefix)) { + // we have a channel-prefix, and it matches; strip it byte[] copy = Payload.Slice(channelPrefix.Length).ToArray(); return new RedisChannel(copy, options); } + + // we shouldn't get unexpected events, so to get here: we've received a notification + // on a channel that doesn't match our prefix; this *should* be limited to + // key notifications (see: IgnoreChannelPrefix), but: we need to be sure + if (StartsWith("__keyspace@"u8) || StartsWith("__keyevent@"u8)) + { + // use as-is + return new RedisChannel(GetBlob(), options); + } return default; default: throw new InvalidCastException("Cannot convert to RedisChannel: " + Resp3Type); @@ -270,9 +282,8 @@ internal bool StartsWith(in CommandBytes expected) var rangeToCheck = Payload.Slice(0, len); return new CommandBytes(rangeToCheck).Equals(expected); } - internal bool StartsWith(byte[] expected) + internal bool StartsWith(ReadOnlySpan expected) { - if (expected == null) throw new ArgumentNullException(nameof(expected)); if (expected.Length > Payload.Length) return false; var rangeToCheck = Payload.Slice(0, expected.Length); @@ -282,7 +293,7 @@ internal bool StartsWith(byte[] expected) foreach (var segment in rangeToCheck) { var from = segment.Span; - var to = new Span(expected, offset, from.Length); + var to = expected.Slice(offset, from.Length); if (!from.SequenceEqual(to)) return false; offset += from.Length; diff --git a/src/StackExchange.Redis/RedisChannel.cs b/src/StackExchange.Redis/RedisChannel.cs index d4289f3c6..889525bd2 100644 --- a/src/StackExchange.Redis/RedisChannel.cs +++ b/src/StackExchange.Redis/RedisChannel.cs @@ -1,4 +1,7 @@ using System; +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Text; namespace StackExchange.Redis @@ -10,6 +13,36 @@ namespace StackExchange.Redis { internal readonly byte[]? Value; + internal ReadOnlySpan Span => Value is null ? default : Value.AsSpan(); + + internal ReadOnlySpan RoutingSpan + { + get + { + var span = Span; + if ((Options & (RedisChannelOptions.KeyRouted | RedisChannelOptions.IgnoreChannelPrefix | + RedisChannelOptions.Sharded | RedisChannelOptions.MultiNode | RedisChannelOptions.Pattern)) + == (RedisChannelOptions.KeyRouted | RedisChannelOptions.IgnoreChannelPrefix)) + { + // this *could* be a single-key __keyspace@{db}__:{key} subscription, in which case we want to use the key + // part for routing, but to avoid overhead we'll only even look if the channel starts with an underscore + if (span.Length >= 16 && span[0] == (byte)'_') span = StripKeySpacePrefix(span); + } + return span; + } + } + + internal static ReadOnlySpan StripKeySpacePrefix(ReadOnlySpan span) + { + if (span.Length >= 16 && span.StartsWith("__keyspace@"u8)) + { + var subspan = span.Slice(12); + int end = subspan.IndexOf("__:"u8); + if (end >= 0) return subspan.Slice(end + 3); + } + return span; + } + internal readonly RedisChannelOptions Options; [Flags] @@ -19,19 +52,42 @@ internal enum RedisChannelOptions Pattern = 1 << 0, Sharded = 1 << 1, KeyRouted = 1 << 2, + MultiNode = 1 << 3, + IgnoreChannelPrefix = 1 << 4, } // we don't consider Routed for equality - it's an implementation detail, not a fundamental feature - private const RedisChannelOptions EqualityMask = ~RedisChannelOptions.KeyRouted; + private const RedisChannelOptions EqualityMask = + ~(RedisChannelOptions.KeyRouted | RedisChannelOptions.MultiNode | RedisChannelOptions.IgnoreChannelPrefix); - internal RedisCommand PublishCommand => IsSharded ? RedisCommand.SPUBLISH : RedisCommand.PUBLISH; + internal RedisCommand GetPublishCommand() + { + return (Options & (RedisChannelOptions.Sharded | RedisChannelOptions.MultiNode)) switch + { + RedisChannelOptions.None => RedisCommand.PUBLISH, + RedisChannelOptions.Sharded => RedisCommand.SPUBLISH, + _ => ThrowKeyRouted(), + }; + + static RedisCommand ThrowKeyRouted() => throw new InvalidOperationException("Publishing is not supported for multi-node channels"); + } /// - /// Should we use cluster routing for this channel? This applies *either* to sharded (SPUBLISH) scenarios, + /// Should we use cluster routing for this channel? This applies *either* to sharded (SPUBLISH) scenarios, /// or to scenarios using . /// internal bool IsKeyRouted => (Options & RedisChannelOptions.KeyRouted) != 0; + /// + /// Should this channel be subscribed to on all nodes? This is only relevant for cluster scenarios and keyspace notifications. + /// + internal bool IsMultiNode => (Options & RedisChannelOptions.MultiNode) != 0; + + /// + /// Should the channel prefix be ignored when writing this channel. + /// + internal bool IgnoreChannelPrefix => (Options & RedisChannelOptions.IgnoreChannelPrefix) != 0; + /// /// Indicates whether the channel-name is either null or a zero-length value. /// @@ -58,6 +114,7 @@ public static bool UseImplicitAutoPattern get => s_DefaultPatternMode == PatternMode.Auto; set => s_DefaultPatternMode = value ? PatternMode.Auto : PatternMode.Literal; } + private static PatternMode s_DefaultPatternMode = PatternMode.Auto; /// @@ -82,7 +139,13 @@ public static bool UseImplicitAutoPattern /// a consideration. /// /// Note that channels from Sharded are always routed. - public RedisChannel WithKeyRouting() => new(Value, Options | RedisChannelOptions.KeyRouted); + public RedisChannel WithKeyRouting() + { + if (IsMultiNode) Throw(); + return new(Value, Options | RedisChannelOptions.KeyRouted); + + static void Throw() => throw new InvalidOperationException("Key routing is not supported for multi-node channels"); + } /// /// Creates a new that acts as a wildcard subscription. In cluster @@ -105,7 +168,8 @@ public static bool UseImplicitAutoPattern /// /// The name of the channel to create. /// The mode for name matching. - public RedisChannel(byte[]? value, PatternMode mode) : this(value, DeterminePatternBased(value, mode) ? RedisChannelOptions.Pattern : RedisChannelOptions.None) + public RedisChannel(byte[]? value, PatternMode mode) : this( + value, DeterminePatternBased(value, mode) ? RedisChannelOptions.Pattern : RedisChannelOptions.None) { } @@ -115,7 +179,9 @@ public RedisChannel(byte[]? value, PatternMode mode) : this(value, DeterminePatt /// The string name of the channel to create. /// The mode for name matching. // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract - public RedisChannel(string value, PatternMode mode) : this(value is null ? null : Encoding.UTF8.GetBytes(value), mode) + public RedisChannel(string value, PatternMode mode) : this( + // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract + value is null ? null : Encoding.UTF8.GetBytes(value), mode) { } @@ -128,7 +194,8 @@ public RedisChannel(string value, PatternMode mode) : this(value is null ? null /// The name of the channel to create. /// Note that sharded subscriptions are completely separate to regular subscriptions; subscriptions /// using sharded channels must also be published with sharded channels (and vice versa). - public static RedisChannel Sharded(byte[]? value) => new(value, RedisChannelOptions.Sharded | RedisChannelOptions.KeyRouted); + public static RedisChannel Sharded(byte[]? value) => + new(value, RedisChannelOptions.Sharded | RedisChannelOptions.KeyRouted); /// /// Create a new redis channel from a string, representing a sharded channel. In cluster @@ -139,7 +206,134 @@ public RedisChannel(string value, PatternMode mode) : this(value is null ? null /// The string name of the channel to create. /// Note that sharded subscriptions are completely separate to regular subscriptions; subscriptions /// using sharded channels must also be published with sharded channels (and vice versa). - public static RedisChannel Sharded(string value) => new(value, RedisChannelOptions.Sharded | RedisChannelOptions.KeyRouted); + public static RedisChannel Sharded(string value) => + new(value, RedisChannelOptions.Sharded | RedisChannelOptions.KeyRouted); + + /// + /// Create a key-notification channel for a single key in a single database. + /// + public static RedisChannel KeySpaceSingleKey(in RedisKey key, int database) + // note we can allow patterns, because we aren't using PSUBSCRIBE + => BuildKeySpaceChannel(key, database, RedisChannelOptions.KeyRouted, default, false, true); + + /// + /// Create a key-notification channel for a pattern, optionally in a specified database. + /// + public static RedisChannel KeySpacePattern(in RedisKey pattern, int? database = null) + => BuildKeySpaceChannel(pattern, database, RedisChannelOptions.Pattern | RedisChannelOptions.MultiNode, default, appendStar: pattern.IsNull, allowKeyPatterns: true); + +#pragma warning disable RS0026 // competing overloads - disambiguated via OverloadResolutionPriority + /// + /// Create a key-notification channel using a raw prefix, optionally in a specified database. + /// + public static RedisChannel KeySpacePrefix(in RedisKey prefix, int? database = null) + { + if (prefix.IsEmpty) Throw(); + return BuildKeySpaceChannel(prefix, database, RedisChannelOptions.Pattern | RedisChannelOptions.MultiNode, default, true, false); + static void Throw() => throw new ArgumentNullException(nameof(prefix)); + } + + /// + /// Create a key-notification channel using a raw prefix, optionally in a specified database. + /// + [OverloadResolutionPriority(1)] + public static RedisChannel KeySpacePrefix(ReadOnlySpan prefix, int? database = null) + { + if (prefix.IsEmpty) Throw(); + return BuildKeySpaceChannel(RedisKey.Null, database, RedisChannelOptions.Pattern | RedisChannelOptions.MultiNode, prefix, true, false); + static void Throw() => throw new ArgumentNullException(nameof(prefix)); + } +#pragma warning restore RS0026 // competing overloads - disambiguated via OverloadResolutionPriority + + private const int DatabaseScratchBufferSize = 16; // largest non-negative int32 is 10 digits + + private static ReadOnlySpan AppendDatabase(Span target, int? database, RedisChannelOptions options) + { + if (database is null) + { + if ((options & RedisChannelOptions.Pattern) == 0) throw new ArgumentNullException(nameof(database)); + return "*"u8; // don't worry about the inbound scratch buffer, this is fine + } + else + { + var db32 = database.GetValueOrDefault(); + if (db32 == 0) return "0"u8; // so common, we might as well special case + if (db32 < 0) throw new ArgumentOutOfRangeException(nameof(database)); + return target.Slice(0, Format.FormatInt32(db32, target)); + } + } + + /// + /// Create an event-notification channel for a given event type, optionally in a specified database. + /// +#pragma warning disable RS0027 + public static RedisChannel KeyEvent(KeyNotificationType type, int? database = null) +#pragma warning restore RS0027 + => KeyEvent(KeyNotificationTypeFastHash.GetRawBytes(type), database); + + /// + /// Create an event-notification channel for a given event type, optionally in a specified database. + /// + /// This API is intended for use with custom/unknown event types; for well-known types, use . + public static RedisChannel KeyEvent(ReadOnlySpan type, int? database) + { + if (type.IsEmpty) throw new ArgumentNullException(nameof(type)); + + RedisChannelOptions options = RedisChannelOptions.MultiNode; + if (database is null) options |= RedisChannelOptions.Pattern; + var db = AppendDatabase(stackalloc byte[DatabaseScratchBufferSize], database, options); + + // __keyevent@{db}__:{type} + var arr = new byte[14 + db.Length + type.Length]; + + var target = AppendAndAdvance(arr.AsSpan(), "__keyevent@"u8); + target = AppendAndAdvance(target, db); + target = AppendAndAdvance(target, "__:"u8); + target = AppendAndAdvance(target, type); + Debug.Assert(target.IsEmpty); // should have calculated length correctly + + return new RedisChannel(arr, options | RedisChannelOptions.IgnoreChannelPrefix); + } + + private static Span AppendAndAdvance(Span target, scoped ReadOnlySpan value) + { + value.CopyTo(target); + return target.Slice(value.Length); + } + + private static RedisChannel BuildKeySpaceChannel(in RedisKey key, int? database, RedisChannelOptions options, ReadOnlySpan suffix, bool appendStar, bool allowKeyPatterns) + { + int fullKeyLength = key.TotalLength() + suffix.Length + (appendStar ? 1 : 0); + if (appendStar & (options & RedisChannelOptions.Pattern) == 0) throw new ArgumentNullException(nameof(key)); + if (fullKeyLength == 0) throw new ArgumentOutOfRangeException(nameof(key)); + + var db = AppendDatabase(stackalloc byte[DatabaseScratchBufferSize], database, options); + + // __keyspace@{db}__:{key}[*] + var arr = new byte[14 + db.Length + fullKeyLength]; + + var target = AppendAndAdvance(arr.AsSpan(), "__keyspace@"u8); + target = AppendAndAdvance(target, db); + target = AppendAndAdvance(target, "__:"u8); + var keySpan = target; // remember this for if we need to check for patterns + var keyLen = key.CopyTo(target); + target = target.Slice(keyLen); + target = AppendAndAdvance(target, suffix); + if (!allowKeyPatterns) + { + keySpan = keySpan.Slice(0, keyLen + suffix.Length); + if (keySpan.IndexOfAny((byte)'*', (byte)'?', (byte)'[') >= 0) ThrowPattern(); + } + if (appendStar) + { + target[0] = (byte)'*'; + target = target.Slice(1); + } + Debug.Assert(target.IsEmpty, "length calculated incorrectly"); + return new RedisChannel(arr, options | RedisChannelOptions.IgnoreChannelPrefix); + + static void ThrowPattern() => throw new ArgumentException("The supplied key contains pattern characters, but patterns are not supported in this context."); + } internal RedisChannel(byte[]? value, RedisChannelOptions options) { @@ -351,7 +545,7 @@ public static implicit operator RedisChannel(byte[]? key) { return Encoding.UTF8.GetString(arr); } - catch (Exception e) when // Only catch exception throwed by Encoding.UTF8.GetString + catch (Exception e) when // Only catch exception thrown by Encoding.UTF8.GetString (e is DecoderFallbackException or ArgumentException or ArgumentNullException) { return BitConverter.ToString(arr); diff --git a/src/StackExchange.Redis/RedisDatabase.cs b/src/StackExchange.Redis/RedisDatabase.cs index c1c3c5728..056a5380a 100644 --- a/src/StackExchange.Redis/RedisDatabase.cs +++ b/src/StackExchange.Redis/RedisDatabase.cs @@ -1900,7 +1900,7 @@ public Task StringLongestCommonSubsequenceWithMatchesAsync(Redis public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) { if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel)); - var msg = Message.Create(-1, flags, channel.PublishCommand, channel, message); + var msg = Message.Create(-1, flags, channel.GetPublishCommand(), channel, message); // if we're actively subscribed: send via that connection (otherwise, follow normal rules) return ExecuteSync(msg, ResultProcessor.Int64, server: multiplexer.GetSubscribedServer(channel)); } @@ -1908,7 +1908,7 @@ public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags public Task PublishAsync(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) { if (channel.IsNullOrEmpty) throw new ArgumentNullException(nameof(channel)); - var msg = Message.Create(-1, flags, channel.PublishCommand, channel, message); + var msg = Message.Create(-1, flags, channel.GetPublishCommand(), channel, message); // if we're actively subscribed: send via that connection (otherwise, follow normal rules) return ExecuteAsync(msg, ResultProcessor.Int64, server: multiplexer.GetSubscribedServer(channel)); } diff --git a/src/StackExchange.Redis/RedisKey.cs b/src/StackExchange.Redis/RedisKey.cs index 0ee83d560..e18e0fb7c 100644 --- a/src/StackExchange.Redis/RedisKey.cs +++ b/src/StackExchange.Redis/RedisKey.cs @@ -395,6 +395,14 @@ internal int TotalLength() => _ => ((byte[])KeyValue).Length, }; + internal int MaxByteCount() => + (KeyPrefix is null ? 0 : KeyPrefix.Length) + KeyValue switch + { + null => 0, + string s => Encoding.UTF8.GetMaxByteCount(s.Length), + _ => ((byte[])KeyValue).Length, + }; + internal int CopyTo(Span destination) { int written = 0; diff --git a/src/StackExchange.Redis/RedisSubscriber.cs b/src/StackExchange.Redis/RedisSubscriber.cs index 9ade78c2d..ca66e6113 100644 --- a/src/StackExchange.Redis/RedisSubscriber.cs +++ b/src/StackExchange.Redis/RedisSubscriber.cs @@ -1,11 +1,8 @@ using System; using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; -using System.Diagnostics.SymbolStore; using System.Net; -using System.Threading; using System.Threading.Tasks; -using Pipelines.Sockets.Unofficial; using Pipelines.Sockets.Unofficial.Arenas; using static StackExchange.Redis.ConnectionMultiplexer; @@ -30,7 +27,7 @@ internal Subscription GetOrAddSubscription(in RedisChannel channel, CommandFlags { if (!subscriptions.TryGetValue(channel, out var sub)) { - sub = new Subscription(flags); + sub = channel.IsMultiNode ? new MultiNodeSubscription(flags) : new SingleNodeSubscription(flags); subscriptions.TryAdd(channel, sub); } return sub; @@ -71,7 +68,7 @@ internal bool GetSubscriberCounts(in RedisChannel channel, out int handlers, out { if (!channel.IsNullOrEmpty && subscriptions.TryGetValue(channel, out Subscription? sub)) { - return sub.GetCurrentServer(); + return sub.GetAnyCurrentServer(); } return null; } @@ -123,7 +120,7 @@ internal void UpdateSubscriptions() { foreach (var pair in subscriptions) { - pair.Value.UpdateServer(); + pair.Value.RemoveDisconnectedEndpoints(); } } @@ -135,13 +132,10 @@ internal long EnsureSubscriptions(CommandFlags flags = CommandFlags.None) { // TODO: Subscribe with variadic commands to reduce round trips long count = 0; + var subscriber = DefaultSubscriber; foreach (var pair in subscriptions) { - if (!pair.Value.IsConnected) - { - count++; - DefaultSubscriber.EnsureSubscribedToServer(pair.Value, pair.Key, flags, true); - } + count += pair.Value.EnsureSubscribedToServer(subscriber, pair.Key, flags, true); } return count; } @@ -151,161 +145,6 @@ internal enum SubscriptionAction Subscribe, Unsubscribe, } - - /// - /// This is the record of a single subscription to a redis server. - /// It's the singular channel (which may or may not be a pattern), to one or more handlers. - /// We subscriber to a redis server once (for all messages) and execute 1-many handlers when a message arrives. - /// - internal sealed class Subscription - { - private Action? _handlers; - private readonly object _handlersLock = new object(); - private ChannelMessageQueue? _queues; - private ServerEndPoint? CurrentServer; - public CommandFlags Flags { get; } - public ResultProcessor.TrackSubscriptionsProcessor Processor { get; } - - /// - /// Whether the we have is connected. - /// Since we clear on a disconnect, this should stay correct. - /// - internal bool IsConnected => CurrentServer?.IsSubscriberConnected == true; - - public Subscription(CommandFlags flags) - { - Flags = flags; - Processor = new ResultProcessor.TrackSubscriptionsProcessor(this); - } - - /// - /// Gets the configured (P)SUBSCRIBE or (P)UNSUBSCRIBE for an action. - /// - internal Message GetMessage(RedisChannel channel, SubscriptionAction action, CommandFlags flags, bool internalCall) - { - var command = action switch // note that the Routed flag doesn't impact the message here - just the routing - { - SubscriptionAction.Subscribe => (channel.Options & ~RedisChannel.RedisChannelOptions.KeyRouted) switch - { - RedisChannel.RedisChannelOptions.None => RedisCommand.SUBSCRIBE, - RedisChannel.RedisChannelOptions.Pattern => RedisCommand.PSUBSCRIBE, - RedisChannel.RedisChannelOptions.Sharded => RedisCommand.SSUBSCRIBE, - _ => Unknown(action, channel.Options), - }, - SubscriptionAction.Unsubscribe => (channel.Options & ~RedisChannel.RedisChannelOptions.KeyRouted) switch - { - RedisChannel.RedisChannelOptions.None => RedisCommand.UNSUBSCRIBE, - RedisChannel.RedisChannelOptions.Pattern => RedisCommand.PUNSUBSCRIBE, - RedisChannel.RedisChannelOptions.Sharded => RedisCommand.SUNSUBSCRIBE, - _ => Unknown(action, channel.Options), - }, - _ => Unknown(action, channel.Options), - }; - - // TODO: Consider flags here - we need to pass Fire and Forget, but don't want to intermingle Primary/Replica - var msg = Message.Create(-1, Flags | flags, command, channel); - msg.SetForSubscriptionBridge(); - if (internalCall) - { - msg.SetInternalCall(); - } - return msg; - } - - private RedisCommand Unknown(SubscriptionAction action, RedisChannel.RedisChannelOptions options) - => throw new ArgumentException($"Unable to determine pub/sub operation for '{action}' against '{options}'"); - - public void Add(Action? handler, ChannelMessageQueue? queue) - { - if (handler != null) - { - lock (_handlersLock) - { - _handlers += handler; - } - } - if (queue != null) - { - ChannelMessageQueue.Combine(ref _queues, queue); - } - } - - public bool Remove(Action? handler, ChannelMessageQueue? queue) - { - if (handler != null) - { - lock (_handlersLock) - { - _handlers -= handler; - } - } - if (queue != null) - { - ChannelMessageQueue.Remove(ref _queues, queue); - } - return _handlers == null & _queues == null; - } - - public ICompletable? ForInvoke(in RedisChannel channel, in RedisValue message, out ChannelMessageQueue? queues) - { - var handlers = _handlers; - queues = Volatile.Read(ref _queues); - return handlers == null ? null : new MessageCompletable(channel, message, handlers); - } - - internal void MarkCompleted() - { - lock (_handlersLock) - { - _handlers = null; - } - ChannelMessageQueue.MarkAllCompleted(ref _queues); - } - - internal void GetSubscriberCounts(out int handlers, out int queues) - { - queues = ChannelMessageQueue.Count(ref _queues); - var tmp = _handlers; - if (tmp == null) - { - handlers = 0; - } - else if (tmp.IsSingle()) - { - handlers = 1; - } - else - { - handlers = 0; - foreach (var sub in tmp.AsEnumerable()) { handlers++; } - } - } - - internal ServerEndPoint? GetCurrentServer() => Volatile.Read(ref CurrentServer); - internal void SetCurrentServer(ServerEndPoint? server) => CurrentServer = server; - // conditional clear - internal bool ClearCurrentServer(ServerEndPoint expected) - { - if (CurrentServer == expected) - { - CurrentServer = null; - return true; - } - - return false; - } - - /// - /// Evaluates state and if we're not currently connected, clears the server reference. - /// - internal void UpdateServer() - { - if (!IsConnected) - { - CurrentServer = null; - } - } - } } /// @@ -393,7 +232,7 @@ private static void ThrowIfNull(in RedisChannel channel) public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) { ThrowIfNull(channel); - var msg = Message.Create(-1, flags, channel.PublishCommand, channel, message); + var msg = Message.Create(-1, flags, channel.GetPublishCommand(), channel, message); // if we're actively subscribed: send via that connection (otherwise, follow normal rules) return ExecuteSync(msg, ResultProcessor.Int64, server: multiplexer.GetSubscribedServer(channel)); } @@ -401,7 +240,7 @@ public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags public Task PublishAsync(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) { ThrowIfNull(channel); - var msg = Message.Create(-1, flags, channel.PublishCommand, channel, message); + var msg = Message.Create(-1, flags, channel.GetPublishCommand(), channel, message); // if we're actively subscribed: send via that connection (otherwise, follow normal rules) return ExecuteAsync(msg, ResultProcessor.Int64, server: multiplexer.GetSubscribedServer(channel)); } @@ -416,37 +255,26 @@ public ChannelMessageQueue Subscribe(RedisChannel channel, CommandFlags flags = return queue; } - private bool Subscribe(RedisChannel channel, Action? handler, ChannelMessageQueue? queue, CommandFlags flags) + private int Subscribe(RedisChannel channel, Action? handler, ChannelMessageQueue? queue, CommandFlags flags) { ThrowIfNull(channel); - if (handler == null && queue == null) { return true; } + if (handler == null && queue == null) { return 0; } var sub = multiplexer.GetOrAddSubscription(channel, flags); sub.Add(handler, queue); - return EnsureSubscribedToServer(sub, channel, flags, false); + return sub.EnsureSubscribedToServer(this, channel, flags, false); } - internal bool EnsureSubscribedToServer(Subscription sub, RedisChannel channel, CommandFlags flags, bool internalCall) - { - if (sub.IsConnected) { return true; } - - // TODO: Cleanup old hangers here? - sub.SetCurrentServer(null); // we're not appropriately connected, so blank it out for eligible reconnection - var message = sub.GetMessage(channel, SubscriptionAction.Subscribe, flags, internalCall); - var selected = multiplexer.SelectServer(message); - return ExecuteSync(message, sub.Processor, selected); - } - - internal void ResubscribeToServer(Subscription sub, RedisChannel channel, ServerEndPoint serverEndPoint, string cause) + internal void ResubscribeToServer(Subscription sub, in RedisChannel channel, ServerEndPoint serverEndPoint, string cause) { // conditional: only if that's the server we were connected to, or "none"; we don't want to end up duplicated - if (sub.ClearCurrentServer(serverEndPoint) || !sub.IsConnected) + if (sub.TryRemoveEndpoint(serverEndPoint) || !sub.IsConnectedAny()) { if (serverEndPoint.IsSubscriberConnected) { // we'll *try* for a simple resubscribe, following any -MOVED etc, but if that fails: fall back // to full reconfigure; importantly, note that we've already recorded the disconnect - var message = sub.GetMessage(channel, SubscriptionAction.Subscribe, CommandFlags.None, false); + var message = sub.GetSubscriptionMessage(channel, SubscriptionAction.Subscribe, CommandFlags.None, false); _ = ExecuteAsync(message, sub.Processor, serverEndPoint).ContinueWith( t => multiplexer.ReconfigureIfNeeded(serverEndPoint.EndPoint, false, cause: cause), TaskContinuationOptions.OnlyOnFaulted); @@ -470,25 +298,14 @@ public async Task SubscribeAsync(RedisChannel channel, Comm return queue; } - private Task SubscribeAsync(RedisChannel channel, Action? handler, ChannelMessageQueue? queue, CommandFlags flags, ServerEndPoint? server = null) + private Task SubscribeAsync(RedisChannel channel, Action? handler, ChannelMessageQueue? queue, CommandFlags flags, ServerEndPoint? server = null) { ThrowIfNull(channel); - if (handler == null && queue == null) { return CompletedTask.Default(null); } + if (handler == null && queue == null) { return CompletedTask.Default(null); } var sub = multiplexer.GetOrAddSubscription(channel, flags); sub.Add(handler, queue); - return EnsureSubscribedToServerAsync(sub, channel, flags, false, server); - } - - public Task EnsureSubscribedToServerAsync(Subscription sub, RedisChannel channel, CommandFlags flags, bool internalCall, ServerEndPoint? server = null) - { - if (sub.IsConnected) { return CompletedTask.Default(null); } - - // TODO: Cleanup old hangers here? - sub.SetCurrentServer(null); // we're not appropriately connected, so blank it out for eligible reconnection - var message = sub.GetMessage(channel, SubscriptionAction.Subscribe, flags, internalCall); - server ??= multiplexer.SelectServer(message); - return ExecuteAsync(message, sub.Processor, server); + return sub.EnsureSubscribedToServerAsync(this, channel, flags, false, server); } public EndPoint? SubscribedEndpoint(RedisChannel channel) => multiplexer.GetSubscribedServer(channel)?.EndPoint; @@ -500,21 +317,12 @@ public bool Unsubscribe(in RedisChannel channel, Action? handler, CommandFlags flags) => UnsubscribeAsync(channel, handler, null, flags); @@ -523,20 +331,10 @@ public Task UnsubscribeAsync(in RedisChannel channel, Action.Default(asyncState); } - private Task UnsubscribeFromServerAsync(Subscription sub, RedisChannel channel, CommandFlags flags, object? asyncState, bool internalCall) - { - if (sub.GetCurrentServer() is ServerEndPoint oldOwner) - { - var message = sub.GetMessage(channel, SubscriptionAction.Unsubscribe, flags, internalCall); - return multiplexer.ExecuteAsyncImpl(message, sub.Processor, asyncState, oldOwner); - } - return CompletedTask.FromResult(true, asyncState); - } - /// /// Unregisters a handler or queue and returns if we should remove it from the server. /// @@ -573,7 +371,7 @@ public void UnsubscribeAll(CommandFlags flags = CommandFlags.None) if (subs.TryRemove(pair.Key, out var sub)) { sub.MarkCompleted(); - UnsubscribeFromServer(sub, pair.Key, flags, false); + sub.UnsubscribeFromServer(this, pair.Key, flags, false); } } } @@ -588,7 +386,7 @@ public Task UnsubscribeAllAsync(CommandFlags flags = CommandFlags.None) if (subs.TryRemove(pair.Key, out var sub)) { sub.MarkCompleted(); - last = UnsubscribeFromServerAsync(sub, pair.Key, flags, asyncState, false); + last = sub.UnsubscribeFromServerAsync(this, pair.Key, flags, asyncState, false); } } return last ?? CompletedTask.Default(asyncState); diff --git a/src/StackExchange.Redis/RedisValue.cs b/src/StackExchange.Redis/RedisValue.cs index d306ca0d0..46228a912 100644 --- a/src/StackExchange.Redis/RedisValue.cs +++ b/src/StackExchange.Redis/RedisValue.cs @@ -869,19 +869,58 @@ private static string ToHex(ReadOnlySpan src) /// /// Gets the length of the value in bytes. /// - public int GetByteCount() + public int GetByteCount() => Type switch { - switch (Type) - { - case StorageType.Null: return 0; - case StorageType.Raw: return _memory.Length; - case StorageType.String: return Encoding.UTF8.GetByteCount((string)_objectOrSentinel!); - case StorageType.Int64: return Format.MeasureInt64(OverlappedValueInt64); - case StorageType.UInt64: return Format.MeasureUInt64(OverlappedValueUInt64); - case StorageType.Double: return Format.MeasureDouble(OverlappedValueDouble); - default: return ThrowUnableToMeasure(); - } - } + StorageType.Null => 0, + StorageType.Raw => _memory.Length, + StorageType.String => Encoding.UTF8.GetByteCount((string)_objectOrSentinel!), + StorageType.Int64 => Format.MeasureInt64(OverlappedValueInt64), + StorageType.UInt64 => Format.MeasureUInt64(OverlappedValueUInt64), + StorageType.Double => Format.MeasureDouble(OverlappedValueDouble), + _ => ThrowUnableToMeasure(), + }; + + /// + /// Gets the maximum length of the value in bytes. + /// + internal int GetMaxByteCount() => Type switch + { + StorageType.Null => 0, + StorageType.Raw => _memory.Length, + StorageType.String => Encoding.UTF8.GetMaxByteCount(((string)_objectOrSentinel!).Length), + StorageType.Int64 => Format.MaxInt64TextLen, + StorageType.UInt64 => Format.MaxInt64TextLen, + StorageType.Double => Format.MaxDoubleTextLen, + _ => ThrowUnableToMeasure(), + }; + + /// + /// Gets the length of the value in characters, assuming UTF8 interpretation of BLOB payloads. + /// + internal int GetCharCount() => Type switch + { + StorageType.Null => 0, + StorageType.Raw => Encoding.UTF8.GetCharCount(_memory.Span), + StorageType.String => ((string)_objectOrSentinel!).Length, + StorageType.Int64 => Format.MeasureInt64(OverlappedValueInt64), + StorageType.UInt64 => Format.MeasureUInt64(OverlappedValueUInt64), + StorageType.Double => Format.MeasureDouble(OverlappedValueDouble), + _ => ThrowUnableToMeasure(), + }; + + /// + /// Gets the length of the value in characters, assuming UTF8 interpretation of BLOB payloads. + /// + internal int GetMaxCharCount() => Type switch + { + StorageType.Null => 0, + StorageType.Raw => Encoding.UTF8.GetMaxCharCount(_memory.Length), + StorageType.String => ((string)_objectOrSentinel!).Length, + StorageType.Int64 => Format.MaxInt64TextLen, + StorageType.UInt64 => Format.MaxInt64TextLen, + StorageType.Double => Format.MaxDoubleTextLen, + _ => ThrowUnableToMeasure(), + }; private int ThrowUnableToMeasure() => throw new InvalidOperationException("Unable to compute length of type: " + Type); @@ -918,6 +957,33 @@ public int CopyTo(Span destination) } } + /// + /// Copy the value as character data to the provided . + /// + internal int CopyTo(Span destination) + { + switch (Type) + { + case StorageType.Null: + return 0; + case StorageType.Raw: + var srcBytes = _memory.Span; + return Encoding.UTF8.GetChars(srcBytes, destination); + case StorageType.String: + var span = ((string)_objectOrSentinel!).AsSpan(); + span.CopyTo(destination); + return span.Length; + case StorageType.Int64: + return Format.FormatInt64(OverlappedValueInt64, destination); + case StorageType.UInt64: + return Format.FormatUInt64(OverlappedValueUInt64, destination); + case StorageType.Double: + return Format.FormatDouble(OverlappedValueDouble, destination); + default: + return ThrowUnableToMeasure(); + } + } + /// /// Converts a to a . /// @@ -1245,5 +1311,61 @@ internal ValueCondition Digest() return digest; } } + + internal bool TryGetSpan(out ReadOnlySpan span) + { + if (_objectOrSentinel == Sentinel_Raw) + { + span = _memory.Span; + return true; + } + span = default; + return false; + } + + /// + /// Indicates whether the current value has the supplied value as a prefix. + /// + /// The to check. + [OverloadResolutionPriority(1)] // prefer this when it is an option (vs casting a byte[] to RedisValue) + public bool StartsWith(ReadOnlySpan value) + { + if (IsNull) return false; + if (value.IsEmpty) return true; + if (IsNullOrEmpty) return false; + + int len; + switch (Type) + { + case StorageType.Raw: + return _memory.Span.StartsWith(value); + case StorageType.Int64: + Span buffer = stackalloc byte[Format.MaxInt64TextLen]; + len = Format.FormatInt64(OverlappedValueInt64, buffer); + return buffer.Slice(0, len).StartsWith(value); + case StorageType.UInt64: + buffer = stackalloc byte[Format.MaxInt64TextLen]; + len = Format.FormatUInt64(OverlappedValueUInt64, buffer); + return buffer.Slice(0, len).StartsWith(value); + case StorageType.Double: + buffer = stackalloc byte[Format.MaxDoubleTextLen]; + len = Format.FormatDouble(OverlappedValueDouble, buffer); + return buffer.Slice(0, len).StartsWith(value); + case StorageType.String: + var s = ((string)_objectOrSentinel!).AsSpan(); + if (s.Length < value.Length) return false; // not enough characters to match + if (s.Length > value.Length) s = s.Slice(0, value.Length); // only need to match the prefix + var maxBytes = Encoding.UTF8.GetMaxByteCount(s.Length); + byte[]? lease = null; + const int MAX_STACK = 128; + buffer = maxBytes <= MAX_STACK ? stackalloc byte[MAX_STACK] : (lease = ArrayPool.Shared.Rent(maxBytes)); + var bytes = Encoding.UTF8.GetBytes(s, buffer); + bool isMatch = buffer.Slice(0, bytes).StartsWith(value); + if (lease is not null) ArrayPool.Shared.Return(lease); + return isMatch; + default: + return false; + } + } } } diff --git a/src/StackExchange.Redis/ResultProcessor.cs b/src/StackExchange.Redis/ResultProcessor.cs index 196cabde5..f2c6deb8b 100644 --- a/src/StackExchange.Redis/ResultProcessor.cs +++ b/src/StackExchange.Redis/ResultProcessor.cs @@ -469,12 +469,21 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes connection.SubscriptionCount = count; SetResult(message, true); - var newServer = message.Command switch + var ep = connection.BridgeCouldBeNull?.ServerEndPoint; + if (ep is not null) { - RedisCommand.SUBSCRIBE or RedisCommand.SSUBSCRIBE or RedisCommand.PSUBSCRIBE => connection.BridgeCouldBeNull?.ServerEndPoint, - _ => null, - }; - Subscription?.SetCurrentServer(newServer); + switch (message.Command) + { + case RedisCommand.SUBSCRIBE: + case RedisCommand.SSUBSCRIBE: + case RedisCommand.PSUBSCRIBE: + Subscription?.AddEndpoint(ep); + break; + default: + Subscription?.TryRemoveEndpoint(ep); + break; + } + } return true; } } diff --git a/src/StackExchange.Redis/ServerEndPoint.cs b/src/StackExchange.Redis/ServerEndPoint.cs index f856a5b21..abe8d8afb 100644 --- a/src/StackExchange.Redis/ServerEndPoint.cs +++ b/src/StackExchange.Redis/ServerEndPoint.cs @@ -695,14 +695,15 @@ internal void OnFullyEstablished(PhysicalConnection connection, string source) // Clear the unselectable flag ASAP since we are open for business ClearUnselectable(UnselectableFlags.DidNotRespond); - if (bridge == subscription) + bool isResp3 = KnowOrAssumeResp3(); + if (bridge == subscription || isResp3) { // Note: this MUST be fire and forget, because we might be in the middle of a Sync processing // TracerProcessor which is executing this line inside a SetResultCore(). // Since we're issuing commands inside a SetResult path in a message, we'd create a deadlock by waiting. Multiplexer.EnsureSubscriptions(CommandFlags.FireAndForget); } - if (IsConnected && (IsSubscriberConnected || !SupportsSubscriptions || KnowOrAssumeResp3())) + if (IsConnected && (IsSubscriberConnected || !SupportsSubscriptions || isResp3)) { // Only connect on the second leg - we can accomplish this by checking both // Or the first leg, if we're only making 1 connection because subscriptions aren't supported diff --git a/src/StackExchange.Redis/ServerSelectionStrategy.cs b/src/StackExchange.Redis/ServerSelectionStrategy.cs index ca247c38b..db729ba26 100644 --- a/src/StackExchange.Redis/ServerSelectionStrategy.cs +++ b/src/StackExchange.Redis/ServerSelectionStrategy.cs @@ -101,9 +101,31 @@ public int HashSlot(in RedisKey key) /// /// The to determine a slot ID for. public int HashSlot(in RedisChannel channel) - // note that the RedisChannel->byte[] converter is always direct, so this is not an alloc - // (we deal with channels far less frequently, so pay the encoding cost up-front) - => ServerType == ServerType.Standalone || channel.IsNull ? NoSlot : GetClusterSlot((byte[])channel!); + { + if (ServerType == ServerType.Standalone || channel.IsNull) return NoSlot; + + ReadOnlySpan routingSpan = channel.RoutingSpan; + byte[] prefix; + return channel.IgnoreChannelPrefix || (prefix = multiplexer.ChannelPrefix).Length == 0 + ? GetClusterSlot(routingSpan) : GetClusterSlotWithPrefix(prefix, routingSpan); + + static int GetClusterSlotWithPrefix(byte[] prefixRaw, ReadOnlySpan routingSpan) + { + ReadOnlySpan prefixSpan = prefixRaw; + const int MAX_STACK = 128; + byte[]? lease = null; + var totalLength = prefixSpan.Length + routingSpan.Length; + var span = totalLength <= MAX_STACK + ? stackalloc byte[MAX_STACK] + : (lease = ArrayPool.Shared.Rent(totalLength)); + + prefixSpan.CopyTo(span); + routingSpan.CopyTo(span.Slice(prefixSpan.Length)); + var result = GetClusterSlot(span.Slice(0, totalLength)); + if (lease is not null) ArrayPool.Shared.Return(lease); + return result; + } + } /// /// Gets the hashslot for a given byte sequence. @@ -360,5 +382,25 @@ private ServerEndPoint[] MapForMutation() } return Any(command, flags, allowDisconnected); } + + internal bool CanServeSlot(ServerEndPoint server, in RedisChannel channel) + => CanServeSlot(server, HashSlot(in channel)); + + internal bool CanServeSlot(ServerEndPoint server, int slot) + { + if (slot == NoSlot) return true; + var arr = map; + if (arr is null) return true; // means "any" + + var primary = arr[slot]; + if (server == primary) return true; + + var replicas = primary.Replicas; + for (int i = 0; i < replicas.Length; i++) + { + if (server == replicas[i]) return true; + } + return false; + } } } diff --git a/src/StackExchange.Redis/StackExchange.Redis.csproj b/src/StackExchange.Redis/StackExchange.Redis.csproj index 983624bc0..84e495f1a 100644 --- a/src/StackExchange.Redis/StackExchange.Redis.csproj +++ b/src/StackExchange.Redis/StackExchange.Redis.csproj @@ -19,7 +19,10 @@ - + + + + diff --git a/src/StackExchange.Redis/Subscription.cs b/src/StackExchange.Redis/Subscription.cs new file mode 100644 index 000000000..99f3d00cb --- /dev/null +++ b/src/StackExchange.Redis/Subscription.cs @@ -0,0 +1,520 @@ +using System; +using System.Buffers; +using System.Collections.Concurrent; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Pipelines.Sockets.Unofficial; + +namespace StackExchange.Redis; + +public partial class ConnectionMultiplexer +{ + /// + /// This is the record of a single subscription to a redis server. + /// It's the singular channel (which may or may not be a pattern), to one or more handlers. + /// We subscriber to a redis server once (for all messages) and execute 1-many handlers when a message arrives. + /// + internal abstract class Subscription + { + private Action? _handlers; + private readonly object _handlersLock = new(); + private ChannelMessageQueue? _queues; + public CommandFlags Flags { get; } + public ResultProcessor.TrackSubscriptionsProcessor Processor { get; } + + internal abstract bool IsConnectedAny(); + internal abstract bool IsConnectedTo(EndPoint endpoint); + + internal abstract void AddEndpoint(ServerEndPoint server); + + // conditional clear + internal abstract bool TryRemoveEndpoint(ServerEndPoint expected); + + internal abstract void RemoveDisconnectedEndpoints(); + + // returns the number of changes required + internal abstract int EnsureSubscribedToServer( + RedisSubscriber subscriber, + in RedisChannel channel, + CommandFlags flags, + bool internalCall); + + // returns the number of changes required + internal abstract Task EnsureSubscribedToServerAsync( + RedisSubscriber subscriber, + RedisChannel channel, + CommandFlags flags, + bool internalCall, + ServerEndPoint? server = null); + + internal abstract bool UnsubscribeFromServer( + RedisSubscriber subscriber, + in RedisChannel channel, + CommandFlags flags, + bool internalCall); + + internal abstract Task UnsubscribeFromServerAsync( + RedisSubscriber subscriber, + RedisChannel channel, + CommandFlags flags, + object? asyncState, + bool internalCall); + + internal abstract int GetConnectionCount(); + + internal abstract ServerEndPoint? GetAnyCurrentServer(); + + public Subscription(CommandFlags flags) + { + Flags = flags; + Processor = new ResultProcessor.TrackSubscriptionsProcessor(this); + } + + /// + /// Gets the configured (P)SUBSCRIBE or (P)UNSUBSCRIBE for an action. + /// + internal Message GetSubscriptionMessage( + in RedisChannel channel, + SubscriptionAction action, + CommandFlags flags, + bool internalCall) + { + const RedisChannel.RedisChannelOptions OPTIONS_MASK = ~( + RedisChannel.RedisChannelOptions.KeyRouted | RedisChannel.RedisChannelOptions.IgnoreChannelPrefix); + var command = + action switch // note that the Routed flag doesn't impact the message here - just the routing + { + SubscriptionAction.Subscribe => (channel.Options & OPTIONS_MASK) switch + { + RedisChannel.RedisChannelOptions.None => RedisCommand.SUBSCRIBE, + RedisChannel.RedisChannelOptions.MultiNode => RedisCommand.SUBSCRIBE, + RedisChannel.RedisChannelOptions.Pattern => RedisCommand.PSUBSCRIBE, + RedisChannel.RedisChannelOptions.Pattern | RedisChannel.RedisChannelOptions.MultiNode => + RedisCommand.PSUBSCRIBE, + RedisChannel.RedisChannelOptions.Sharded => RedisCommand.SSUBSCRIBE, + _ => Unknown(action, channel.Options), + }, + SubscriptionAction.Unsubscribe => (channel.Options & OPTIONS_MASK) switch + { + RedisChannel.RedisChannelOptions.None => RedisCommand.UNSUBSCRIBE, + RedisChannel.RedisChannelOptions.MultiNode => RedisCommand.UNSUBSCRIBE, + RedisChannel.RedisChannelOptions.Pattern => RedisCommand.PUNSUBSCRIBE, + RedisChannel.RedisChannelOptions.Pattern | RedisChannel.RedisChannelOptions.MultiNode => + RedisCommand.PUNSUBSCRIBE, + RedisChannel.RedisChannelOptions.Sharded => RedisCommand.SUNSUBSCRIBE, + _ => Unknown(action, channel.Options), + }, + _ => Unknown(action, channel.Options), + }; + + // TODO: Consider flags here - we need to pass Fire and Forget, but don't want to intermingle Primary/Replica + var msg = Message.Create(-1, Flags | flags, command, channel); + msg.SetForSubscriptionBridge(); + if (internalCall) + { + msg.SetInternalCall(); + } + + return msg; + } + + private RedisCommand Unknown(SubscriptionAction action, RedisChannel.RedisChannelOptions options) + => throw new ArgumentException( + $"Unable to determine pub/sub operation for '{action}' against '{options}'"); + + public void Add(Action? handler, ChannelMessageQueue? queue) + { + if (handler != null) + { + lock (_handlersLock) + { + _handlers += handler; + } + } + + if (queue != null) + { + ChannelMessageQueue.Combine(ref _queues, queue); + } + } + + public bool Remove(Action? handler, ChannelMessageQueue? queue) + { + if (handler != null) + { + lock (_handlersLock) + { + _handlers -= handler; + } + } + + if (queue != null) + { + ChannelMessageQueue.Remove(ref _queues, queue); + } + + return _handlers == null & _queues == null; + } + + public ICompletable? ForInvoke(in RedisChannel channel, in RedisValue message, out ChannelMessageQueue? queues) + { + var handlers = _handlers; + queues = Volatile.Read(ref _queues); + return handlers == null ? null : new MessageCompletable(channel, message, handlers); + } + + internal void MarkCompleted() + { + lock (_handlersLock) + { + _handlers = null; + } + + ChannelMessageQueue.MarkAllCompleted(ref _queues); + } + + internal void GetSubscriberCounts(out int handlers, out int queues) + { + queues = ChannelMessageQueue.Count(ref _queues); + var tmp = _handlers; + if (tmp == null) + { + handlers = 0; + } + else if (tmp.IsSingle()) + { + handlers = 1; + } + else + { + handlers = 0; + foreach (var sub in tmp.AsEnumerable()) { handlers++; } + } + } + } + + // used for most subscriptions; routed to a single node + internal sealed class SingleNodeSubscription(CommandFlags flags) : Subscription(flags) + { + internal override bool IsConnectedAny() => _currentServer is { IsSubscriberConnected: true }; + + internal override int GetConnectionCount() => IsConnectedAny() ? 1 : 0; + + internal override bool IsConnectedTo(EndPoint endpoint) + { + var server = _currentServer; + return server is { IsSubscriberConnected: true } && server.EndPoint == endpoint; + } + + internal override void AddEndpoint(ServerEndPoint server) => _currentServer = server; + + internal override bool TryRemoveEndpoint(ServerEndPoint expected) + { + if (_currentServer == expected) + { + _currentServer = null; + return true; + } + + return false; + } + + internal override bool UnsubscribeFromServer( + RedisSubscriber subscriber, + in RedisChannel channel, + CommandFlags flags, + bool internalCall) + { + var server = _currentServer; + if (server is not null) + { + var message = GetSubscriptionMessage(channel, SubscriptionAction.Unsubscribe, flags, internalCall); + return subscriber.multiplexer.ExecuteSyncImpl(message, Processor, server); + } + + return true; + } + + internal override Task UnsubscribeFromServerAsync( + RedisSubscriber subscriber, + RedisChannel channel, + CommandFlags flags, + object? asyncState, + bool internalCall) + { + var server = _currentServer; + if (server is not null) + { + var message = GetSubscriptionMessage(channel, SubscriptionAction.Unsubscribe, flags, internalCall); + return subscriber.multiplexer.ExecuteAsyncImpl(message, Processor, asyncState, server); + } + + return CompletedTask.FromResult(true, asyncState); + } + + private ServerEndPoint? _currentServer; + internal ServerEndPoint? GetCurrentServer() => Volatile.Read(ref _currentServer); + + internal override ServerEndPoint? GetAnyCurrentServer() => Volatile.Read(ref _currentServer); + + /// + /// Evaluates state and if we're not currently connected, clears the server reference. + /// + internal override void RemoveDisconnectedEndpoints() + { + var server = _currentServer; + if (server is { IsSubscriberConnected: false }) + { + _currentServer = null; + } + } + + internal override int EnsureSubscribedToServer( + RedisSubscriber subscriber, + in RedisChannel channel, + CommandFlags flags, + bool internalCall) + { + RemoveIncorrectRouting(subscriber, in channel, flags, internalCall); + if (IsConnectedAny()) return 0; + + // we're not appropriately connected, so blank it out for eligible reconnection + _currentServer = null; + var message = GetSubscriptionMessage(channel, SubscriptionAction.Subscribe, flags, internalCall); + var selected = subscriber.multiplexer.SelectServer(message); + _ = subscriber.ExecuteSync(message, Processor, selected); + return 1; + } + + private void RemoveIncorrectRouting(RedisSubscriber subscriber, in RedisChannel channel, CommandFlags flags, bool internalCall) + { + // only applies to cluster, when using key-routed channels (sharded, explicit key-routed, or + // a single-key keyspace notification); is the subscribed server still handling that channel? + if (channel.IsKeyRouted && _currentServer is { ServerType: ServerType.Cluster } current) + { + // if we consider replicas, there can be multiple valid target servers; we can't ask + // "is this the correct server?", but we can ask "is it suitable?", based on the slot + if (!subscriber.multiplexer.ServerSelectionStrategy.CanServeSlot(_currentServer, channel)) + { + var message = GetSubscriptionMessage(channel, SubscriptionAction.Unsubscribe, flags | CommandFlags.FireAndForget, internalCall); + subscriber.multiplexer.ExecuteSyncImpl(message, Processor, current); + _currentServer = null; // pre-emptively disconnect - F+F + } + } + } + + internal override async Task EnsureSubscribedToServerAsync( + RedisSubscriber subscriber, + RedisChannel channel, + CommandFlags flags, + bool internalCall, + ServerEndPoint? server = null) + { + RemoveIncorrectRouting(subscriber, in channel, flags, internalCall); + if (IsConnectedAny()) return 0; + + // we're not appropriately connected, so blank it out for eligible reconnection + _currentServer = null; + var message = GetSubscriptionMessage(channel, SubscriptionAction.Subscribe, flags, internalCall); + server ??= subscriber.multiplexer.SelectServer(message); + await subscriber.ExecuteAsync(message, Processor, server).ForAwait(); + return 1; + } + } + + // used for keyspace subscriptions, which are routed to multiple nodes + internal sealed class MultiNodeSubscription(CommandFlags flags) : Subscription(flags) + { + private readonly ConcurrentDictionary _servers = new(); + + internal override bool IsConnectedAny() + { + foreach (var server in _servers) + { + if (server.Value is { IsSubscriberConnected: true }) return true; + } + + return false; + } + + internal override int GetConnectionCount() + { + int count = 0; + foreach (var server in _servers) + { + if (server.Value is { IsSubscriberConnected: true }) count++; + } + + return count; + } + + internal override bool IsConnectedTo(EndPoint endpoint) + => _servers.TryGetValue(endpoint, out var server) + && server.IsSubscriberConnected; + + internal override void AddEndpoint(ServerEndPoint server) + { + var ep = server.EndPoint; + if (!_servers.TryAdd(ep, server)) + { + _servers[ep] = server; + } + } + + internal override bool TryRemoveEndpoint(ServerEndPoint expected) + { + return _servers.TryRemove(expected.EndPoint, out _); + } + + internal override ServerEndPoint? GetAnyCurrentServer() + { + ServerEndPoint? last = null; + // prefer actively connected servers, but settle for anything + foreach (var server in _servers) + { + last = server.Value; + if (last is { IsSubscriberConnected: true }) + { + break; + } + } + + return last; + } + + internal override void RemoveDisconnectedEndpoints() + { + // This looks more complicated than it is, because of avoiding mutating the collection + // while iterating; instead, buffer any removals in a scratch buffer, and remove them in a second pass. + EndPoint[] scratch = []; + int count = 0; + foreach (var server in _servers) + { + if (server.Value.IsSubscriberConnected) + { + // flag for removal + if (scratch.Length == count) // need to resize the scratch buffer, using the pool + { + // let the array pool worry about min-sizing etc + var newLease = ArrayPool.Shared.Rent(count + 1); + scratch.CopyTo(newLease, 0); + ArrayPool.Shared.Return(scratch); + scratch = newLease; + } + + scratch[count++] = server.Key; + } + } + + // did we find anything to remove? + if (count != 0) + { + foreach (var ep in scratch.AsSpan(0, count)) + { + _servers.TryRemove(ep, out _); + } + } + + ArrayPool.Shared.Return(scratch); + } + + internal override int EnsureSubscribedToServer( + RedisSubscriber subscriber, + in RedisChannel channel, + CommandFlags flags, + bool internalCall) + { + int delta = 0; + var muxer = subscriber.multiplexer; + foreach (var server in muxer.GetServerSnapshot()) + { + var change = GetSubscriptionChange(server, flags); + if (change is not null) + { + // make it so + var message = GetSubscriptionMessage(channel, change.GetValueOrDefault(), flags, internalCall); + subscriber.ExecuteSync(message, Processor, server); + delta++; + } + } + + return delta; + } + + private SubscriptionAction? GetSubscriptionChange(ServerEndPoint server, CommandFlags flags) + { + // exclude sentinel, and only use replicas if we're explicitly asking for them + bool useReplica = (Flags & CommandFlags.DemandReplica) != 0; + bool shouldBeConnected = server.ServerType != ServerType.Sentinel & server.IsReplica == useReplica; + if (shouldBeConnected == IsConnectedTo(server.EndPoint)) + { + return null; + } + return shouldBeConnected ? SubscriptionAction.Subscribe : SubscriptionAction.Unsubscribe; + } + + internal override async Task EnsureSubscribedToServerAsync( + RedisSubscriber subscriber, + RedisChannel channel, + CommandFlags flags, + bool internalCall, + ServerEndPoint? server = null) + { + int delta = 0; + var muxer = subscriber.multiplexer; + var snapshot = muxer.GetServerSnaphotMemory(); + var len = snapshot.Length; + for (int i = 0; i < len; i++) + { + var loopServer = snapshot.Span[i]; // spans and async do not mix well + if (server is null || server == loopServer) // either "all" or "just the one we passed in" + { + var change = GetSubscriptionChange(loopServer, flags); + if (change is not null) + { + // make it so + var message = GetSubscriptionMessage(channel, change.GetValueOrDefault(), flags, internalCall); + await subscriber.ExecuteAsync(message, Processor, loopServer).ForAwait(); + delta++; + } + } + } + + return delta; + } + + internal override bool UnsubscribeFromServer( + RedisSubscriber subscriber, + in RedisChannel channel, + CommandFlags flags, + bool internalCall) + { + bool any = false; + foreach (var server in _servers) + { + var message = GetSubscriptionMessage(channel, SubscriptionAction.Unsubscribe, flags, internalCall); + any |= subscriber.ExecuteSync(message, Processor, server.Value); + } + + return any; + } + + internal override async Task UnsubscribeFromServerAsync( + RedisSubscriber subscriber, + RedisChannel channel, + CommandFlags flags, + object? asyncState, + bool internalCall) + { + bool any = false; + foreach (var server in _servers) + { + var message = GetSubscriptionMessage(channel, SubscriptionAction.Unsubscribe, flags, internalCall); + any |= await subscriber.ExecuteAsync(message, Processor, server.Value).ForAwait(); + } + + return any; + } + } +} diff --git a/tests/RedisConfigs/3.0.503/redis.windows-service.conf b/tests/RedisConfigs/3.0.503/redis.windows-service.conf index ed44371a3..b374dad58 100644 --- a/tests/RedisConfigs/3.0.503/redis.windows-service.conf +++ b/tests/RedisConfigs/3.0.503/redis.windows-service.conf @@ -829,7 +829,7 @@ latency-monitor-threshold 0 # By default all notifications are disabled because most users don't need # this feature and the feature has some overhead. Note that if you don't # specify at least one of K or E, no events will be delivered. -notify-keyspace-events "" +notify-keyspace-events "AKE" ############################### ADVANCED CONFIG ############################### diff --git a/tests/RedisConfigs/3.0.503/redis.windows.conf b/tests/RedisConfigs/3.0.503/redis.windows.conf index c07a7e9ab..4a99b8fdb 100644 --- a/tests/RedisConfigs/3.0.503/redis.windows.conf +++ b/tests/RedisConfigs/3.0.503/redis.windows.conf @@ -829,7 +829,7 @@ latency-monitor-threshold 0 # By default all notifications are disabled because most users don't need # this feature and the feature has some overhead. Note that if you don't # specify at least one of K or E, no events will be delivered. -notify-keyspace-events "" +notify-keyspace-events "AKE" ############################### ADVANCED CONFIG ############################### diff --git a/tests/RedisConfigs/Basic/primary-6379-3.0.conf b/tests/RedisConfigs/Basic/primary-6379-3.0.conf index 1f4d96da5..889756fec 100644 --- a/tests/RedisConfigs/Basic/primary-6379-3.0.conf +++ b/tests/RedisConfigs/Basic/primary-6379-3.0.conf @@ -6,4 +6,5 @@ maxmemory 6gb dir "../Temp" appendonly no dbfilename "primary-6379.rdb" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Basic/primary-6379.conf b/tests/RedisConfigs/Basic/primary-6379.conf index dee83828c..2da592601 100644 --- a/tests/RedisConfigs/Basic/primary-6379.conf +++ b/tests/RedisConfigs/Basic/primary-6379.conf @@ -7,4 +7,5 @@ dir "../Temp" appendonly no dbfilename "primary-6379.rdb" save "" -enable-debug-command yes \ No newline at end of file +enable-debug-command yes +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Basic/replica-6380.conf b/tests/RedisConfigs/Basic/replica-6380.conf index 8d87e54c2..0c1650513 100644 --- a/tests/RedisConfigs/Basic/replica-6380.conf +++ b/tests/RedisConfigs/Basic/replica-6380.conf @@ -7,4 +7,5 @@ maxmemory 2gb appendonly no dir "../Temp" dbfilename "replica-6380.rdb" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Basic/secure-6381.conf b/tests/RedisConfigs/Basic/secure-6381.conf index bd9359244..ad2e380ad 100644 --- a/tests/RedisConfigs/Basic/secure-6381.conf +++ b/tests/RedisConfigs/Basic/secure-6381.conf @@ -4,4 +4,5 @@ databases 2000 maxmemory 512mb dir "../Temp" dbfilename "secure-6381.rdb" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Basic/tls-ciphers-6384.conf b/tests/RedisConfigs/Basic/tls-ciphers-6384.conf index 52fc7d7b1..857d5c741 100644 --- a/tests/RedisConfigs/Basic/tls-ciphers-6384.conf +++ b/tests/RedisConfigs/Basic/tls-ciphers-6384.conf @@ -9,3 +9,4 @@ tls-protocols "TLSv1.2 TLSv1.3" tls-cert-file /Certs/redis.crt tls-key-file /Certs/redis.key tls-ca-cert-file /Certs/ca.crt +notify-keyspace-events AKE diff --git a/tests/RedisConfigs/Cluster/cluster-7000.conf b/tests/RedisConfigs/Cluster/cluster-7000.conf index f250a3db3..ad11a23fd 100644 --- a/tests/RedisConfigs/Cluster/cluster-7000.conf +++ b/tests/RedisConfigs/Cluster/cluster-7000.conf @@ -6,4 +6,5 @@ cluster-node-timeout 5000 appendonly yes dbfilename "dump-7000.rdb" appendfilename "appendonly-7000.aof" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Cluster/cluster-7001.conf b/tests/RedisConfigs/Cluster/cluster-7001.conf index 1ae0c6f83..589f9ea23 100644 --- a/tests/RedisConfigs/Cluster/cluster-7001.conf +++ b/tests/RedisConfigs/Cluster/cluster-7001.conf @@ -6,4 +6,5 @@ cluster-node-timeout 5000 appendonly yes dbfilename "dump-7001.rdb" appendfilename "appendonly-7001.aof" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Cluster/cluster-7002.conf b/tests/RedisConfigs/Cluster/cluster-7002.conf index 897301f59..66a376865 100644 --- a/tests/RedisConfigs/Cluster/cluster-7002.conf +++ b/tests/RedisConfigs/Cluster/cluster-7002.conf @@ -6,4 +6,5 @@ cluster-node-timeout 5000 appendonly yes dbfilename "dump-7002.rdb" appendfilename "appendonly-7002.aof" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Cluster/cluster-7003.conf b/tests/RedisConfigs/Cluster/cluster-7003.conf index 0b51677fd..1f4883023 100644 --- a/tests/RedisConfigs/Cluster/cluster-7003.conf +++ b/tests/RedisConfigs/Cluster/cluster-7003.conf @@ -6,4 +6,5 @@ cluster-node-timeout 5000 appendonly yes dbfilename "dump-7003.rdb" appendfilename "appendonly-7003.aof" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Cluster/cluster-7004.conf b/tests/RedisConfigs/Cluster/cluster-7004.conf index 9a49d21f5..93d75f38a 100644 --- a/tests/RedisConfigs/Cluster/cluster-7004.conf +++ b/tests/RedisConfigs/Cluster/cluster-7004.conf @@ -6,4 +6,5 @@ cluster-node-timeout 5000 appendonly yes dbfilename "dump-7004.rdb" appendfilename "appendonly-7004.aof" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Cluster/cluster-7005.conf b/tests/RedisConfigs/Cluster/cluster-7005.conf index b333a4b44..c9b5d55e2 100644 --- a/tests/RedisConfigs/Cluster/cluster-7005.conf +++ b/tests/RedisConfigs/Cluster/cluster-7005.conf @@ -6,4 +6,5 @@ cluster-node-timeout 5000 appendonly yes dbfilename "dump-7005.rdb" appendfilename "appendonly-7005.aof" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Failover/primary-6382.conf b/tests/RedisConfigs/Failover/primary-6382.conf index c19e8c701..6055c0347 100644 --- a/tests/RedisConfigs/Failover/primary-6382.conf +++ b/tests/RedisConfigs/Failover/primary-6382.conf @@ -6,4 +6,5 @@ maxmemory 2gb dir "../Temp" appendonly no dbfilename "primary-6382.rdb" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Failover/replica-6383.conf b/tests/RedisConfigs/Failover/replica-6383.conf index 6f1a0fc7d..e07f5a69d 100644 --- a/tests/RedisConfigs/Failover/replica-6383.conf +++ b/tests/RedisConfigs/Failover/replica-6383.conf @@ -7,4 +7,5 @@ maxmemory 2gb appendonly no dir "../Temp" dbfilename "replica-6383.rdb" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Sentinel/redis-7010.conf b/tests/RedisConfigs/Sentinel/redis-7010.conf index 0e27680b2..878160632 100644 --- a/tests/RedisConfigs/Sentinel/redis-7010.conf +++ b/tests/RedisConfigs/Sentinel/redis-7010.conf @@ -5,4 +5,5 @@ maxmemory 100mb appendonly no dir "../Temp" dbfilename "sentinel-target-7010.rdb" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/RedisConfigs/Sentinel/redis-7011.conf b/tests/RedisConfigs/Sentinel/redis-7011.conf index 6d02eb150..08b8dad1a 100644 --- a/tests/RedisConfigs/Sentinel/redis-7011.conf +++ b/tests/RedisConfigs/Sentinel/redis-7011.conf @@ -6,4 +6,5 @@ maxmemory 100mb appendonly no dir "../Temp" dbfilename "sentinel-target-7011.rdb" -save "" \ No newline at end of file +save "" +notify-keyspace-events AKE \ No newline at end of file diff --git a/tests/StackExchange.Redis.Tests/Certificates/CertValidationTests.cs b/tests/StackExchange.Redis.Tests/Certificates/CertValidationTests.cs index a0d9b5c88..fa80114f8 100644 --- a/tests/StackExchange.Redis.Tests/Certificates/CertValidationTests.cs +++ b/tests/StackExchange.Redis.Tests/Certificates/CertValidationTests.cs @@ -51,7 +51,9 @@ public void CheckIssuerValidity() Assert.False(callback(this, endpointCert, null, SslPolicyErrors.RemoteCertificateChainErrors | SslPolicyErrors.RemoteCertificateNotAvailable), "subtest 3f"); } +#pragma warning disable SYSLIB0057 private static X509Certificate2 LoadCert(string certificatePath) => new X509Certificate2(File.ReadAllBytes(certificatePath)); +#pragma warning restore SYSLIB0057 [Fact] public void CheckIssuerArgs() diff --git a/tests/StackExchange.Redis.Tests/ClusterShardedTests.cs b/tests/StackExchange.Redis.Tests/ClusterShardedTests.cs index 8af0a1c7b..c9101fb08 100644 --- a/tests/StackExchange.Redis.Tests/ClusterShardedTests.cs +++ b/tests/StackExchange.Redis.Tests/ClusterShardedTests.cs @@ -72,7 +72,8 @@ public async Task TestShardedPubsubSubscriberAgainstReconnects() public async Task TestShardedPubsubSubscriberAgainsHashSlotMigration() { Skip.UnlessLongRunning(); - var channel = RedisChannel.Sharded(Me()); + var channel = RedisChannel.Sharded(Me()); // invent a channel that will use SSUBSCRIBE + var key = (RedisKey)(byte[])channel!; // use the same value as a key, to test keyspace notifications via a single-key API await using var conn = Create(allowAdmin: true, keepAlive: 1, connectTimeout: 3000, shared: false, require: RedisFeatures.v7_0_0_rc1); Assert.True(conn.IsConnected); var db = conn.GetDatabase(); @@ -80,46 +81,75 @@ public async Task TestShardedPubsubSubscriberAgainsHashSlotMigration() await Task.Delay(50); // let the sub settle (this isn't needed on RESP3, note) var pubsub = conn.GetSubscriber(); - List<(RedisChannel, RedisValue)> received = []; - var queue = await pubsub.SubscribeAsync(channel); - _ = Task.Run(async () => + var keynotify = RedisChannel.KeySpaceSingleKey(key, db.Database); + Assert.False(keynotify.IsSharded); // keyspace notifications do not use SSUBSCRIBE; this matters, because it means we don't get nuked when the slot migrates + Assert.False(keynotify.IsMultiNode); // we specificially want this *not* to be multi-node; we want to test that it follows the key correctly + + int keynotificationCount = 0; + await pubsub.SubscribeAsync(keynotify, (_, _) => Interlocked.Increment(ref keynotificationCount)); + try { - // use queue API to have control over order - await foreach (var item in queue) + List<(RedisChannel, RedisValue)> received = []; + var queue = await pubsub.SubscribeAsync(channel); + _ = Task.Run(async () => { - lock (received) + // use queue API to have control over order + await foreach (var item in queue) { - if (item.Channel.IsSharded && item.Channel == channel) received.Add((item.Channel, item.Message)); + lock (received) + { + if (item.Channel.IsSharded && item.Channel == channel) + received.Add((item.Channel, item.Message)); + } } + }); + Assert.Equal(2, conn.GetSubscriptionsCount()); + + await Task.Delay(50); // let the sub settle (this isn't needed on RESP3, note) + await db.PingAsync(); + + for (int i = 0; i < 5; i++) + { + // check we get a hit + Assert.Equal(1, await db.PublishAsync(channel, i.ToString())); + await db.StringIncrementAsync(key); } - }); - Assert.Equal(1, conn.GetSubscriptionsCount()); - await Task.Delay(50); // let the sub settle (this isn't needed on RESP3, note) - await db.PingAsync(); + await Task.Delay(50); // let the sub settle (this isn't needed on RESP3, note) - for (int i = 0; i < 5; i++) - { - // check we get a hit - Assert.Equal(1, await db.PublishAsync(channel, i.ToString())); - } - await Task.Delay(50); // let the sub settle (this isn't needed on RESP3, note) + // lets migrate the slot for "testShardChannel" to another node + await DoHashSlotMigrationAsync(); + + await Task.Delay(4000); + for (int i = 0; i < 5; i++) + { + // check we get a hit + Assert.Equal(1, await db.PublishAsync(channel, i.ToString())); + await db.StringIncrementAsync(key); + } - // lets migrate the slot for "testShardChannel" to another node - await DoHashSlotMigrationAsync(); + await Task.Delay(50); // let the sub settle (this isn't needed on RESP3, note) - await Task.Delay(4000); - for (int i = 0; i < 5; i++) + Assert.Equal(2, conn.GetSubscriptionsCount()); + Assert.Equal(10, received.Count); + Assert.Equal(10, Volatile.Read(ref keynotificationCount)); + await RollbackHashSlotMigrationAsync(); + ClearAmbientFailures(); + } + finally { - // check we get a hit - Assert.Equal(1, await db.PublishAsync(channel, i.ToString())); + try + { + // ReSharper disable once MethodHasAsyncOverload - F+F + await pubsub.UnsubscribeAsync(keynotify, flags: CommandFlags.FireAndForget); + await pubsub.UnsubscribeAsync(channel, flags: CommandFlags.FireAndForget); + Log("Channels unsubscribed."); + } + catch (Exception ex) + { + Log($"Error while unsubscribing: {ex.Message}"); + } } - await Task.Delay(50); // let the sub settle (this isn't needed on RESP3, note) - - Assert.Equal(1, conn.GetSubscriptionsCount()); - Assert.Equal(10, received.Count); - await RollbackHashSlotMigrationAsync(); - ClearAmbientFailures(); } private Task DoHashSlotMigrationAsync() => MigrateSlotForTestShardChannelAsync(false); diff --git a/tests/StackExchange.Redis.Tests/ClusterTests.cs b/tests/StackExchange.Redis.Tests/ClusterTests.cs index 8146dc9be..781b65fef 100644 --- a/tests/StackExchange.Redis.Tests/ClusterTests.cs +++ b/tests/StackExchange.Redis.Tests/ClusterTests.cs @@ -743,11 +743,15 @@ public async Task ConnectIncludesSubscriber() } [Theory] - [InlineData(true, false)] - [InlineData(true, true)] - [InlineData(false, false)] - [InlineData(false, true)] - public async Task ClusterPubSub(bool sharded, bool withKeyRouting) + [InlineData(true, false, false)] + [InlineData(true, true, false)] + [InlineData(false, false, false)] + [InlineData(false, true, false)] + [InlineData(true, false, true)] + [InlineData(true, true, true)] + [InlineData(false, false, true)] + [InlineData(false, true, true)] + public async Task ClusterPubSub(bool sharded, bool withKeyRouting, bool withKeyPrefix) { var guid = Guid.NewGuid().ToString(); var channel = sharded ? RedisChannel.Sharded(guid) : RedisChannel.Literal(guid); @@ -755,7 +759,12 @@ public async Task ClusterPubSub(bool sharded, bool withKeyRouting) { channel = channel.WithKeyRouting(); } - await using var conn = Create(keepAlive: 1, connectTimeout: 3000, shared: false, require: sharded ? RedisFeatures.v7_0_0_rc1 : RedisFeatures.v2_0_0); + await using var conn = Create( + keepAlive: 1, + connectTimeout: 3000, + shared: false, + require: sharded ? RedisFeatures.v7_0_0_rc1 : RedisFeatures.v2_0_0, + channelPrefix: withKeyPrefix ? "c_prefix:" : null); Assert.True(conn.IsConnected); var pubsub = conn.GetSubscriber(); @@ -778,7 +787,7 @@ public async Task ClusterPubSub(bool sharded, bool withKeyRouting) } List<(RedisChannel, RedisValue)> received = []; - var queue = await pubsub.SubscribeAsync(channel); + var queue = await pubsub.SubscribeAsync(channel, CommandFlags.NoRedirect); _ = Task.Run(async () => { // use queue API to have control over order diff --git a/tests/StackExchange.Redis.Tests/FailoverTests.cs b/tests/StackExchange.Redis.Tests/FailoverTests.cs index 1f33275b5..825c8efce 100644 --- a/tests/StackExchange.Redis.Tests/FailoverTests.cs +++ b/tests/StackExchange.Redis.Tests/FailoverTests.cs @@ -236,10 +236,12 @@ public async Task SubscriptionsSurviveConnectionFailureAsync() server.SimulateConnectionFailure(SimulatedFailureType.All); // Trigger failure (RedisTimeoutException or RedisConnectionException because // of backlog behavior) - var ex = Assert.ThrowsAny(() => sub.Ping()); - Assert.True(ex is RedisTimeoutException or RedisConnectionException); Assert.False(sub.IsConnected(channel)); + var ex = Assert.ThrowsAny(() => Log($"Ping: {sub.Ping(CommandFlags.DemandMaster)}ms")); + Assert.True(ex is RedisTimeoutException or RedisConnectionException); + Log($"Failed as expected: {ex.Message}"); + // Now reconnect... conn.AllowConnect = true; Log("Waiting on reconnect"); @@ -263,7 +265,7 @@ public async Task SubscriptionsSurviveConnectionFailureAsync() foreach (var pair in muxerSubs) { var muxerSub = pair.Value; - Log($" Muxer Sub: {pair.Key}: (EndPoint: {muxerSub.GetCurrentServer()}, Connected: {muxerSub.IsConnected})"); + Log($" Muxer Sub: {pair.Key}: (EndPoint: {muxerSub.GetAnyCurrentServer()}, Connected: {muxerSub.IsConnectedAny()})"); } Log("Publishing"); diff --git a/tests/StackExchange.Redis.Tests/FastHashTests.cs b/tests/StackExchange.Redis.Tests/FastHashTests.cs index 418198cfd..a032cfc80 100644 --- a/tests/StackExchange.Redis.Tests/FastHashTests.cs +++ b/tests/StackExchange.Redis.Tests/FastHashTests.cs @@ -2,13 +2,14 @@ using System.Runtime.InteropServices; using System.Text; using Xunit; +using Xunit.Sdk; #pragma warning disable CS8981, SA1134, SA1300, SA1303, SA1502 // names are weird in this test! // ReSharper disable InconsistentNaming - to better represent expected literals // ReSharper disable IdentifierTypo namespace StackExchange.Redis.Tests; -public partial class FastHashTests +public partial class FastHashTests(ITestOutputHelper log) { // note: if the hashing algorithm changes, we can update the last parameter freely; it doesn't matter // what it *is* - what matters is that we can see that it has entropy between different values @@ -83,6 +84,46 @@ public void FastHashIs_Long() Assert.False(abcdefghijklmnopqrst.Is(hash, value)); } + [Fact] + public void KeyNotificationTypeFastHash_MinMaxBytes_ReflectsActualLengths() + { + // Use reflection to find all nested types in KeyNotificationTypeFastHash + var fastHashType = typeof(KeyNotificationTypeFastHash); + var nestedTypes = fastHashType.GetNestedTypes(System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static); + + int? minLength = null; + int? maxLength = null; + + foreach (var nestedType in nestedTypes) + { + // Look for the Length field (generated by FastHash source generator) + var lengthField = nestedType.GetField("Length", System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.Static); + if (lengthField != null && lengthField.FieldType == typeof(int)) + { + var length = (int)lengthField.GetValue(null)!; + + if (minLength == null || length < minLength) + { + minLength = length; + } + + if (maxLength == null || length > maxLength) + { + maxLength = length; + } + } + } + + // Assert that we found at least some nested types with Length fields + Assert.NotNull(minLength); + Assert.NotNull(maxLength); + + // Assert that MinBytes and MaxBytes match the actual min/max lengths + log.WriteLine($"MinBytes: {KeyNotificationTypeFastHash.MinBytes}, MaxBytes: {KeyNotificationTypeFastHash.MaxBytes}"); + Assert.Equal(KeyNotificationTypeFastHash.MinBytes, minLength.Value); + Assert.Equal(KeyNotificationTypeFastHash.MaxBytes, maxLength.Value); + } + [FastHash] private static partial class a { } [FastHash] private static partial class ab { } [FastHash] private static partial class abc { } diff --git a/tests/StackExchange.Redis.Tests/HashTests.cs b/tests/StackExchange.Redis.Tests/HashTests.cs index af2fa11c8..9523ca102 100644 --- a/tests/StackExchange.Redis.Tests/HashTests.cs +++ b/tests/StackExchange.Redis.Tests/HashTests.cs @@ -265,7 +265,7 @@ public async Task TestGetAll() } var inRedis = (await db.HashGetAllAsync(key).ForAwait()).ToDictionary( - x => Guid.Parse(x.Name!), x => int.Parse(x.Value!)); + x => Guid.Parse((string)x.Name!), x => int.Parse(x.Value!)); Assert.Equal(shouldMatch.Count, inRedis.Count); diff --git a/tests/StackExchange.Redis.Tests/Helpers/TextWriterOutputHelper.cs b/tests/StackExchange.Redis.Tests/Helpers/TextWriterOutputHelper.cs index e41a46670..2a23f3246 100644 --- a/tests/StackExchange.Redis.Tests/Helpers/TextWriterOutputHelper.cs +++ b/tests/StackExchange.Redis.Tests/Helpers/TextWriterOutputHelper.cs @@ -7,7 +7,7 @@ namespace StackExchange.Redis.Tests.Helpers; public class TextWriterOutputHelper(ITestOutputHelper outputHelper) : TextWriter { - private StringBuilder Buffer { get; } = new StringBuilder(2048); + private readonly StringBuilder _buffer = new(2048); private StringBuilder? Echo { get; set; } public override Encoding Encoding => Encoding.UTF8; private readonly ITestOutputHelper Output = outputHelper; @@ -37,7 +37,10 @@ public override void WriteLine(string? value) try { - base.WriteLine(value); + lock (_buffer) // keep everything together + { + base.WriteLine(value); + } } catch (Exception ex) { @@ -49,32 +52,44 @@ public override void WriteLine(string? value) public override void Write(char value) { - if (value == '\n' || value == '\r') + lock (_buffer) { - // Ignore empty lines - if (Buffer.Length > 0) + if (value == '\n' || value == '\r') { - FlushBuffer(); + // Ignore empty lines + if (_buffer.Length > 0) + { + FlushBuffer(); + } + } + else + { + _buffer.Append(value); } - } - else - { - Buffer.Append(value); } } protected override void Dispose(bool disposing) { - if (Buffer.Length > 0) + lock (_buffer) { - FlushBuffer(); + if (_buffer.Length > 0) + { + FlushBuffer(); + } } + base.Dispose(disposing); } private void FlushBuffer() { - var text = Buffer.ToString(); + string text; + lock (_buffer) + { + text = _buffer.ToString(); + _buffer.Clear(); + } try { Output.WriteLine(text); @@ -84,6 +99,5 @@ private void FlushBuffer() // Thrown when writing from a handler after a test has ended - just bail in this case } Echo?.AppendLine(text); - Buffer.Clear(); } } diff --git a/tests/StackExchange.Redis.Tests/Issues/Issue2507.cs b/tests/StackExchange.Redis.Tests/Issues/Issue2507.cs index b548d7031..f77e43e29 100644 --- a/tests/StackExchange.Redis.Tests/Issues/Issue2507.cs +++ b/tests/StackExchange.Redis.Tests/Issues/Issue2507.cs @@ -7,7 +7,7 @@ namespace StackExchange.Redis.Tests.Issues; [Collection(NonParallelCollection.Name)] public class Issue2507(ITestOutputHelper output, SharedConnectionFixture? fixture = null) : TestBase(output, fixture) { - [Fact(Explicit = true)] + [Fact(Explicit = true)] // note this may show as Inconclusive, depending on the runner public async Task Execute() { await using var conn = Create(shared: false); diff --git a/tests/StackExchange.Redis.Tests/KeyNotificationTests.cs b/tests/StackExchange.Redis.Tests/KeyNotificationTests.cs new file mode 100644 index 000000000..60469eb49 --- /dev/null +++ b/tests/StackExchange.Redis.Tests/KeyNotificationTests.cs @@ -0,0 +1,698 @@ +using System; +using System.Buffers; +using System.Text; +using Xunit; +using Xunit.Sdk; + +namespace StackExchange.Redis.Tests; + +public class KeyNotificationTests(ITestOutputHelper log) +{ + [Theory] + [InlineData("foo", "foo")] + [InlineData("__foo__", "__foo__")] + [InlineData("__keyspace@4__:", "__keyspace@4__:")] // not long enough + [InlineData("__keyspace@4__:f", "f")] + [InlineData("__keyspace@4__:fo", "fo")] + [InlineData("__keyspace@4__:foo", "foo")] + [InlineData("__keyspace@42__:foo", "foo")] // check multi-char db + [InlineData("__keyevent@4__:foo", "__keyevent@4__:foo")] // key-event + [InlineData("__keyevent@42__:foo", "__keyevent@42__:foo")] // key-event + public void RoutingSpan_StripKeySpacePrefix(string raw, string routed) + { + ReadOnlySpan srcBytes = Encoding.UTF8.GetBytes(raw); + var strippedBytes = RedisChannel.StripKeySpacePrefix(srcBytes); + var result = Encoding.UTF8.GetString(strippedBytes); + Assert.Equal(routed, result); + } + + [Fact] + public void Keyspace_Del_ParsesCorrectly() + { + // __keyspace@1__:mykey with payload "del" + var channel = RedisChannel.Literal("__keyspace@1__:mykey"); + Assert.False(channel.IgnoreChannelPrefix); // because constructed manually + RedisValue value = "del"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeySpace); + Assert.False(notification.IsKeyEvent); + Assert.Equal(1, notification.Database); + Assert.Equal(KeyNotificationType.Del, notification.Type); + Assert.True(notification.IsType("del"u8)); + Assert.Equal("mykey", (string?)notification.GetKey()); + Assert.Equal(5, notification.GetKeyByteCount()); + Assert.Equal(5, notification.GetKeyMaxByteCount()); + Assert.Equal(5, notification.GetKeyCharCount()); + Assert.Equal(6, notification.GetKeyMaxCharCount()); + } + + [Fact] + public void Keyevent_Del_ParsesCorrectly() + { + // __keyevent@42__:del with value "mykey" + var channel = RedisChannel.Literal("__keyevent@42__:del"); + RedisValue value = "mykey"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.False(notification.IsKeySpace); + Assert.True(notification.IsKeyEvent); + Assert.Equal(42, notification.Database); + Assert.Equal(KeyNotificationType.Del, notification.Type); + Assert.True(notification.IsType("del"u8)); + Assert.Equal("mykey", (string?)notification.GetKey()); + Assert.Equal(5, notification.GetKeyByteCount()); + Assert.Equal(18, notification.GetKeyMaxByteCount()); + Assert.Equal(5, notification.GetKeyCharCount()); + Assert.Equal(5, notification.GetKeyMaxCharCount()); + } + + [Fact] + public void Keyspace_Set_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyspace@0__:testkey"); + RedisValue value = "set"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeySpace); + Assert.Equal(0, notification.Database); + Assert.Equal(KeyNotificationType.Set, notification.Type); + Assert.True(notification.IsType("set"u8)); + Assert.Equal("testkey", (string?)notification.GetKey()); + Assert.Equal(7, notification.GetKeyByteCount()); + Assert.Equal(7, notification.GetKeyMaxByteCount()); + Assert.Equal(7, notification.GetKeyCharCount()); + Assert.Equal(8, notification.GetKeyMaxCharCount()); + } + + [Fact] + public void Keyevent_Expire_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyevent@5__:expire"); + RedisValue value = "session:12345"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeyEvent); + Assert.Equal(5, notification.Database); + Assert.Equal(KeyNotificationType.Expire, notification.Type); + Assert.True(notification.IsType("expire"u8)); + Assert.Equal("session:12345", (string?)notification.GetKey()); + Assert.Equal(13, notification.GetKeyByteCount()); + Assert.Equal(42, notification.GetKeyMaxByteCount()); + Assert.Equal(13, notification.GetKeyCharCount()); + Assert.Equal(13, notification.GetKeyMaxCharCount()); + } + + [Fact] + public void Keyspace_Expired_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyspace@3__:cache:item"); + RedisValue value = "expired"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeySpace); + Assert.Equal(3, notification.Database); + Assert.Equal(KeyNotificationType.Expired, notification.Type); + Assert.True(notification.IsType("expired"u8)); + Assert.Equal("cache:item", (string?)notification.GetKey()); + Assert.Equal(10, notification.GetKeyByteCount()); + Assert.Equal(10, notification.GetKeyMaxByteCount()); + Assert.Equal(10, notification.GetKeyCharCount()); + Assert.Equal(11, notification.GetKeyMaxCharCount()); + } + + [Fact] + public void Keyevent_LPush_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyevent@0__:lpush"); + RedisValue value = "queue:tasks"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeyEvent); + Assert.Equal(0, notification.Database); + Assert.Equal(KeyNotificationType.LPush, notification.Type); + Assert.True(notification.IsType("lpush"u8)); + Assert.Equal("queue:tasks", (string?)notification.GetKey()); + Assert.Equal(11, notification.GetKeyByteCount()); + Assert.Equal(36, notification.GetKeyMaxByteCount()); + Assert.Equal(11, notification.GetKeyCharCount()); + Assert.Equal(11, notification.GetKeyMaxCharCount()); + } + + [Fact] + public void Keyspace_HSet_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyspace@2__:user:1000"); + RedisValue value = "hset"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeySpace); + Assert.Equal(2, notification.Database); + Assert.Equal(KeyNotificationType.HSet, notification.Type); + Assert.True(notification.IsType("hset"u8)); + Assert.Equal("user:1000", (string?)notification.GetKey()); + Assert.Equal(9, notification.GetKeyByteCount()); + Assert.Equal(9, notification.GetKeyMaxByteCount()); + Assert.Equal(9, notification.GetKeyCharCount()); + Assert.Equal(10, notification.GetKeyMaxCharCount()); + } + + [Fact] + public void Keyevent_ZAdd_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyevent@7__:zadd"); + RedisValue value = "leaderboard"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeyEvent); + Assert.Equal(7, notification.Database); + Assert.Equal(KeyNotificationType.ZAdd, notification.Type); + Assert.True(notification.IsType("zadd"u8)); + Assert.Equal("leaderboard", (string?)notification.GetKey()); + Assert.Equal(11, notification.GetKeyByteCount()); + Assert.Equal(36, notification.GetKeyMaxByteCount()); + Assert.Equal(11, notification.GetKeyCharCount()); + Assert.Equal(11, notification.GetKeyMaxCharCount()); + } + + [Fact] + public void CustomEventWithUnusualValue_Works() + { + var channel = RedisChannel.Literal("__keyevent@7__:flooble"); + RedisValue value = 17.5; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeyEvent); + Assert.Equal(7, notification.Database); + Assert.Equal(KeyNotificationType.Unknown, notification.Type); + Assert.False(notification.IsType("zadd"u8)); + Assert.True(notification.IsType("flooble"u8)); + Assert.Equal("17.5", (string?)notification.GetKey()); + Assert.Equal(4, notification.GetKeyByteCount()); + Assert.Equal(40, notification.GetKeyMaxByteCount()); + Assert.Equal(4, notification.GetKeyCharCount()); + Assert.Equal(40, notification.GetKeyMaxCharCount()); + } + + [Fact] + public void TryCopyKey_WorksCorrectly() + { + var channel = RedisChannel.Literal("__keyspace@0__:testkey"); + RedisValue value = "set"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + var lease = ArrayPool.Shared.Rent(20); + Span buffer = lease.AsSpan(0, 20); + Assert.True(notification.TryCopyKey(buffer, out var bytesWritten)); + Assert.Equal(7, bytesWritten); + Assert.Equal("testkey", Encoding.UTF8.GetString(lease, 0, bytesWritten)); + ArrayPool.Shared.Return(lease); + } + + [Fact] + public void TryCopyKey_FailsWithSmallBuffer() + { + var channel = RedisChannel.Literal("__keyspace@0__:testkey"); + RedisValue value = "set"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Span buffer = stackalloc byte[3]; // too small + Assert.False(notification.TryCopyKey(buffer, out var bytesWritten)); + Assert.Equal(0, bytesWritten); + } + + [Fact] + public void InvalidChannel_ReturnsFalse() + { + var channel = RedisChannel.Literal("regular:channel"); + RedisValue value = "data"; + + Assert.False(KeyNotification.TryParse(in channel, in value, out var notification)); + } + + [Fact] + public void InvalidKeyspaceChannel_MissingDelimiter_ReturnsFalse() + { + var channel = RedisChannel.Literal("__keyspace@0__"); // missing the key part + RedisValue value = "set"; + + Assert.False(KeyNotification.TryParse(in channel, in value, out var notification)); + } + + [Fact] + public void Keyspace_UnknownEventType_ReturnsUnknown() + { + var channel = RedisChannel.Literal("__keyspace@0__:mykey"); + RedisValue value = "unknownevent"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeySpace); + Assert.Equal(0, notification.Database); + Assert.Equal(KeyNotificationType.Unknown, notification.Type); + Assert.False(notification.IsType("del"u8)); + Assert.Equal("mykey", (string?)notification.GetKey()); + } + + [Fact] + public void Keyevent_UnknownEventType_ReturnsUnknown() + { + var channel = RedisChannel.Literal("__keyevent@0__:unknownevent"); + RedisValue value = "mykey"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeyEvent); + Assert.Equal(0, notification.Database); + Assert.Equal(KeyNotificationType.Unknown, notification.Type); + Assert.False(notification.IsType("del"u8)); + Assert.Equal("mykey", (string?)notification.GetKey()); + } + + [Fact] + public void Keyspace_WithColonInKey_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyspace@0__:user:session:12345"); + RedisValue value = "del"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeySpace); + Assert.Equal(0, notification.Database); + Assert.Equal(KeyNotificationType.Del, notification.Type); + Assert.True(notification.IsType("del"u8)); + Assert.Equal("user:session:12345", (string?)notification.GetKey()); + } + + [Fact] + public void Keyevent_Evicted_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyevent@1__:evicted"); + RedisValue value = "cache:old"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeyEvent); + Assert.Equal(1, notification.Database); + Assert.Equal(KeyNotificationType.Evicted, notification.Type); + Assert.True(notification.IsType("evicted"u8)); + Assert.Equal("cache:old", (string?)notification.GetKey()); + } + + [Fact] + public void Keyspace_New_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyspace@0__:newkey"); + RedisValue value = "new"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeySpace); + Assert.Equal(0, notification.Database); + Assert.Equal(KeyNotificationType.New, notification.Type); + Assert.True(notification.IsType("new"u8)); + Assert.Equal("newkey", (string?)notification.GetKey()); + } + + [Fact] + public void Keyevent_XGroupCreate_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyevent@0__:xgroup-create"); + RedisValue value = "mystream"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeyEvent); + Assert.Equal(0, notification.Database); + Assert.Equal(KeyNotificationType.XGroupCreate, notification.Type); + Assert.True(notification.IsType("xgroup-create"u8)); + Assert.Equal("mystream", (string?)notification.GetKey()); + } + + [Fact] + public void Keyspace_TypeChanged_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyspace@0__:mykey"); + RedisValue value = "type_changed"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeySpace); + Assert.Equal(0, notification.Database); + Assert.Equal(KeyNotificationType.TypeChanged, notification.Type); + Assert.True(notification.IsType("type_changed"u8)); + Assert.Equal("mykey", (string?)notification.GetKey()); + } + + [Fact] + public void Keyevent_HighDatabaseNumber_ParsesCorrectly() + { + var channel = RedisChannel.Literal("__keyevent@999__:set"); + RedisValue value = "testkey"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeyEvent); + Assert.Equal(999, notification.Database); + Assert.Equal(KeyNotificationType.Set, notification.Type); + Assert.True(notification.IsType("set"u8)); + Assert.Equal("testkey", (string?)notification.GetKey()); + } + + [Fact] + public void Keyevent_NonIntegerDatabase_ParsesWellEnough() + { + var channel = RedisChannel.Literal("__keyevent@abc__:set"); + RedisValue value = "testkey"; + + Assert.True(KeyNotification.TryParse(in channel, in value, out var notification)); + + Assert.True(notification.IsKeyEvent); + Assert.Equal(-1, notification.Database); + Assert.Equal(KeyNotificationType.Set, notification.Type); + Assert.True(notification.IsType("set"u8)); + Assert.Equal("testkey", (string?)notification.GetKey()); + } + + [Fact] + public void DefaultKeyNotification_HasExpectedProperties() + { + var notification = default(KeyNotification); + + Assert.False(notification.IsKeySpace); + Assert.False(notification.IsKeyEvent); + Assert.Equal(-1, notification.Database); + Assert.Equal(KeyNotificationType.Unknown, notification.Type); + Assert.False(notification.IsType("del"u8)); + Assert.True(notification.GetKey().IsNull); + Assert.Equal(0, notification.GetKeyByteCount()); + Assert.Equal(0, notification.GetKeyMaxByteCount()); + Assert.Equal(0, notification.GetKeyCharCount()); + Assert.Equal(0, notification.GetKeyMaxCharCount()); + Assert.True(notification.GetChannel().IsNull); + Assert.True(notification.GetValue().IsNull); + + // TryCopyKey should return false and write 0 bytes + Span buffer = stackalloc byte[10]; + Assert.False(notification.TryCopyKey(buffer, out var bytesWritten)); + Assert.Equal(0, bytesWritten); + } + + [Theory] + [InlineData(KeyNotificationTypeFastHash.append.Text, KeyNotificationType.Append)] + [InlineData(KeyNotificationTypeFastHash.copy.Text, KeyNotificationType.Copy)] + [InlineData(KeyNotificationTypeFastHash.del.Text, KeyNotificationType.Del)] + [InlineData(KeyNotificationTypeFastHash.expire.Text, KeyNotificationType.Expire)] + [InlineData(KeyNotificationTypeFastHash.hdel.Text, KeyNotificationType.HDel)] + [InlineData(KeyNotificationTypeFastHash.hexpired.Text, KeyNotificationType.HExpired)] + [InlineData(KeyNotificationTypeFastHash.hincrbyfloat.Text, KeyNotificationType.HIncrByFloat)] + [InlineData(KeyNotificationTypeFastHash.hincrby.Text, KeyNotificationType.HIncrBy)] + [InlineData(KeyNotificationTypeFastHash.hpersist.Text, KeyNotificationType.HPersist)] + [InlineData(KeyNotificationTypeFastHash.hset.Text, KeyNotificationType.HSet)] + [InlineData(KeyNotificationTypeFastHash.incrbyfloat.Text, KeyNotificationType.IncrByFloat)] + [InlineData(KeyNotificationTypeFastHash.incrby.Text, KeyNotificationType.IncrBy)] + [InlineData(KeyNotificationTypeFastHash.linsert.Text, KeyNotificationType.LInsert)] + [InlineData(KeyNotificationTypeFastHash.lpop.Text, KeyNotificationType.LPop)] + [InlineData(KeyNotificationTypeFastHash.lpush.Text, KeyNotificationType.LPush)] + [InlineData(KeyNotificationTypeFastHash.lrem.Text, KeyNotificationType.LRem)] + [InlineData(KeyNotificationTypeFastHash.lset.Text, KeyNotificationType.LSet)] + [InlineData(KeyNotificationTypeFastHash.ltrim.Text, KeyNotificationType.LTrim)] + [InlineData(KeyNotificationTypeFastHash.move_from.Text, KeyNotificationType.MoveFrom)] + [InlineData(KeyNotificationTypeFastHash.move_to.Text, KeyNotificationType.MoveTo)] + [InlineData(KeyNotificationTypeFastHash.persist.Text, KeyNotificationType.Persist)] + [InlineData(KeyNotificationTypeFastHash.rename_from.Text, KeyNotificationType.RenameFrom)] + [InlineData(KeyNotificationTypeFastHash.rename_to.Text, KeyNotificationType.RenameTo)] + [InlineData(KeyNotificationTypeFastHash.restore.Text, KeyNotificationType.Restore)] + [InlineData(KeyNotificationTypeFastHash.rpop.Text, KeyNotificationType.RPop)] + [InlineData(KeyNotificationTypeFastHash.rpush.Text, KeyNotificationType.RPush)] + [InlineData(KeyNotificationTypeFastHash.sadd.Text, KeyNotificationType.SAdd)] + [InlineData(KeyNotificationTypeFastHash.set.Text, KeyNotificationType.Set)] + [InlineData(KeyNotificationTypeFastHash.setrange.Text, KeyNotificationType.SetRange)] + [InlineData(KeyNotificationTypeFastHash.sortstore.Text, KeyNotificationType.SortStore)] + [InlineData(KeyNotificationTypeFastHash.srem.Text, KeyNotificationType.SRem)] + [InlineData(KeyNotificationTypeFastHash.spop.Text, KeyNotificationType.SPop)] + [InlineData(KeyNotificationTypeFastHash.xadd.Text, KeyNotificationType.XAdd)] + [InlineData(KeyNotificationTypeFastHash.xdel.Text, KeyNotificationType.XDel)] + [InlineData(KeyNotificationTypeFastHash.xgroup_createconsumer.Text, KeyNotificationType.XGroupCreateConsumer)] + [InlineData(KeyNotificationTypeFastHash.xgroup_create.Text, KeyNotificationType.XGroupCreate)] + [InlineData(KeyNotificationTypeFastHash.xgroup_delconsumer.Text, KeyNotificationType.XGroupDelConsumer)] + [InlineData(KeyNotificationTypeFastHash.xgroup_destroy.Text, KeyNotificationType.XGroupDestroy)] + [InlineData(KeyNotificationTypeFastHash.xgroup_setid.Text, KeyNotificationType.XGroupSetId)] + [InlineData(KeyNotificationTypeFastHash.xsetid.Text, KeyNotificationType.XSetId)] + [InlineData(KeyNotificationTypeFastHash.xtrim.Text, KeyNotificationType.XTrim)] + [InlineData(KeyNotificationTypeFastHash.zadd.Text, KeyNotificationType.ZAdd)] + [InlineData(KeyNotificationTypeFastHash.zdiffstore.Text, KeyNotificationType.ZDiffStore)] + [InlineData(KeyNotificationTypeFastHash.zinterstore.Text, KeyNotificationType.ZInterStore)] + [InlineData(KeyNotificationTypeFastHash.zunionstore.Text, KeyNotificationType.ZUnionStore)] + [InlineData(KeyNotificationTypeFastHash.zincr.Text, KeyNotificationType.ZIncr)] + [InlineData(KeyNotificationTypeFastHash.zrembyrank.Text, KeyNotificationType.ZRemByRank)] + [InlineData(KeyNotificationTypeFastHash.zrembyscore.Text, KeyNotificationType.ZRemByScore)] + [InlineData(KeyNotificationTypeFastHash.zrem.Text, KeyNotificationType.ZRem)] + [InlineData(KeyNotificationTypeFastHash.expired.Text, KeyNotificationType.Expired)] + [InlineData(KeyNotificationTypeFastHash.evicted.Text, KeyNotificationType.Evicted)] + [InlineData(KeyNotificationTypeFastHash._new.Text, KeyNotificationType.New)] + [InlineData(KeyNotificationTypeFastHash.overwritten.Text, KeyNotificationType.Overwritten)] + [InlineData(KeyNotificationTypeFastHash.type_changed.Text, KeyNotificationType.TypeChanged)] + public unsafe void FastHashParse_AllKnownValues_ParseCorrectly(string raw, KeyNotificationType parsed) + { + var arr = ArrayPool.Shared.Rent(Encoding.UTF8.GetMaxByteCount(raw.Length)); + int bytes; + fixed (byte* bPtr = arr) // encode into the buffer + { + fixed (char* cPtr = raw) + { + bytes = Encoding.UTF8.GetBytes(cPtr, raw.Length, bPtr, arr.Length); + } + } + + var result = KeyNotificationTypeFastHash.Parse(arr.AsSpan(0, bytes)); + log.WriteLine($"Parsed '{raw}' as {result}"); + Assert.Equal(parsed, result); + + // and the other direction: + var fetchedBytes = KeyNotificationTypeFastHash.GetRawBytes(parsed); + string fetched; + fixed (byte* bPtr = fetchedBytes) + { + fetched = Encoding.UTF8.GetString(bPtr, fetchedBytes.Length); + } + + log.WriteLine($"Fetched '{raw}'"); + Assert.Equal(raw, fetched); + + ArrayPool.Shared.Return(arr); + } + + [Fact] + public void CreateKeySpaceNotification_Valid() + { + var channel = RedisChannel.KeySpaceSingleKey("abc", 42); + Assert.Equal("__keyspace@42__:abc", channel.ToString()); + Assert.False(channel.IsMultiNode); + Assert.True(channel.IsKeyRouted); + Assert.False(channel.IsSharded); + Assert.False(channel.IsPattern); + Assert.True(channel.IgnoreChannelPrefix); + } + + [Theory] + [InlineData(null, null, "__keyspace@*__:*")] + [InlineData("abc*", null, "__keyspace@*__:abc*")] + [InlineData(null, 42, "__keyspace@42__:*")] + [InlineData("abc*", 42, "__keyspace@42__:abc*")] + public void CreateKeySpaceNotificationPattern(string? pattern, int? database, string expected) + { + var channel = RedisChannel.KeySpacePattern(pattern, database); + Assert.Equal(expected, channel.ToString()); + Assert.True(channel.IsMultiNode); + Assert.False(channel.IsKeyRouted); + Assert.False(channel.IsSharded); + Assert.True(channel.IsPattern); + Assert.True(channel.IgnoreChannelPrefix); + } + + [Theory] + [InlineData("abc", null, "__keyspace@*__:abc*")] + [InlineData("abc", 42, "__keyspace@42__:abc*")] + public void CreateKeySpaceNotificationPrefix_Key(string prefix, int? database, string expected) + { + var channel = RedisChannel.KeySpacePrefix((RedisKey)prefix, database); + Assert.Equal(expected, channel.ToString()); + Assert.True(channel.IsMultiNode); + Assert.False(channel.IsKeyRouted); + Assert.False(channel.IsSharded); + Assert.True(channel.IsPattern); + Assert.True(channel.IgnoreChannelPrefix); + } + + [Theory] + [InlineData("abc", null, "__keyspace@*__:abc*")] + [InlineData("abc", 42, "__keyspace@42__:abc*")] + public void CreateKeySpaceNotificationPrefix_Span(string prefix, int? database, string expected) + { + var channel = RedisChannel.KeySpacePrefix((ReadOnlySpan)Encoding.UTF8.GetBytes(prefix), database); + Assert.Equal(expected, channel.ToString()); + Assert.True(channel.IsMultiNode); + Assert.False(channel.IsKeyRouted); + Assert.False(channel.IsSharded); + Assert.True(channel.IsPattern); + Assert.True(channel.IgnoreChannelPrefix); + } + + [Theory] + [InlineData("a?bc", null)] + [InlineData("a?bc", 42)] + [InlineData("a*bc", null)] + [InlineData("a*bc", 42)] + [InlineData("a[bc", null)] + [InlineData("a[bc", 42)] + public void CreateKeySpaceNotificationPrefix_DisallowGlob(string prefix, int? database) + { + var bytes = Encoding.UTF8.GetBytes(prefix); + var ex = Assert.Throws(() => + RedisChannel.KeySpacePrefix((RedisKey)bytes, database)); + Assert.StartsWith("The supplied key contains pattern characters, but patterns are not supported in this context.", ex.Message); + + ex = Assert.Throws(() => + RedisChannel.KeySpacePrefix((ReadOnlySpan)bytes, database)); + Assert.StartsWith("The supplied key contains pattern characters, but patterns are not supported in this context.", ex.Message); + } + + [Theory] + [InlineData(KeyNotificationType.Set, null, "__keyevent@*__:set", true)] + [InlineData(KeyNotificationType.XGroupCreate, null, "__keyevent@*__:xgroup-create", true)] + [InlineData(KeyNotificationType.Set, 42, "__keyevent@42__:set", false)] + [InlineData(KeyNotificationType.XGroupCreate, 42, "__keyevent@42__:xgroup-create", false)] + public void CreateKeyEventNotification(KeyNotificationType type, int? database, string expected, bool isPattern) + { + var channel = RedisChannel.KeyEvent(type, database); + Assert.Equal(expected, channel.ToString()); + Assert.True(channel.IsMultiNode); + Assert.False(channel.IsKeyRouted); + Assert.False(channel.IsSharded); + Assert.True(channel.IgnoreChannelPrefix); + if (isPattern) + { + Assert.True(channel.IsPattern); + } + else + { + Assert.False(channel.IsPattern); + } + } + + [Theory] + [InlineData("abc", "__keyspace@42__:abc")] + [InlineData("a*bc", "__keyspace@42__:a*bc")] // pattern-like is allowed, since not using PSUBSCRIBE + public void Cannot_KeyRoute_KeySpace_SingleKeyIsKeyRouted(string key, string pattern) + { + var channel = RedisChannel.KeySpaceSingleKey(key, 42); + Assert.Equal(pattern, channel.ToString()); + Assert.False(channel.IsMultiNode); + Assert.False(channel.IsPattern); + Assert.False(channel.IsSharded); + Assert.True(channel.IgnoreChannelPrefix); + Assert.True(channel.IsKeyRouted); + Assert.True(channel.WithKeyRouting().IsKeyRouted); // no change, still key-routed + Assert.Equal(RedisCommand.PUBLISH, channel.GetPublishCommand()); + } + + [Fact] + public void Cannot_KeyRoute_KeySpacePattern() + { + var channel = RedisChannel.KeySpacePattern("abc", 42); + Assert.True(channel.IsMultiNode); + Assert.False(channel.IsKeyRouted); + Assert.True(channel.IgnoreChannelPrefix); + Assert.StartsWith("Key routing is not supported for multi-node channels", Assert.Throws(() => channel.WithKeyRouting()).Message); + Assert.StartsWith("Publishing is not supported for multi-node channels", Assert.Throws(() => channel.GetPublishCommand()).Message); + } + + [Fact] + public void Cannot_KeyRoute_KeyEvent() + { + var channel = RedisChannel.KeyEvent(KeyNotificationType.Set, 42); + Assert.True(channel.IsMultiNode); + Assert.False(channel.IsKeyRouted); + Assert.True(channel.IgnoreChannelPrefix); + Assert.StartsWith("Key routing is not supported for multi-node channels", Assert.Throws(() => channel.WithKeyRouting()).Message); + Assert.StartsWith("Publishing is not supported for multi-node channels", Assert.Throws(() => channel.GetPublishCommand()).Message); + } + + [Fact] + public void Cannot_KeyRoute_KeyEvent_Custom() + { + var channel = RedisChannel.KeyEvent("foo"u8, 42); + Assert.True(channel.IsMultiNode); + Assert.False(channel.IsKeyRouted); + Assert.True(channel.IgnoreChannelPrefix); + Assert.StartsWith("Key routing is not supported for multi-node channels", Assert.Throws(() => channel.WithKeyRouting()).Message); + Assert.StartsWith("Publishing is not supported for multi-node channels", Assert.Throws(() => channel.GetPublishCommand()).Message); + } + + [Fact] + public void KeyEventPrefix_KeySpacePrefix_Length_Matches() + { + // this is a sanity check for the parsing step in KeyNotification.TryParse + Assert.Equal(KeyNotificationChannels.KeySpacePrefix.Length, KeyNotificationChannels.KeyEventPrefix.Length); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void KeyNotificationKeyStripping(bool asString) + { + Span blob = stackalloc byte[32]; + Span clob = stackalloc char[32]; + + RedisChannel channel = RedisChannel.Literal("__keyevent@0__:sadd"); + RedisValue value = asString ? "mykey:abc" : "mykey:abc"u8.ToArray(); + KeyNotification.TryParse(in channel, in value, out var notification); + Assert.Equal("mykey:abc", (string?)notification.GetKey()); + Assert.True(notification.KeyStartsWith("mykey:"u8)); + Assert.Equal(0, notification.KeyOffset); + + Assert.Equal(9, notification.GetKeyByteCount()); + Assert.Equal(asString ? 30 : 9, notification.GetKeyMaxByteCount()); + Assert.Equal(9, notification.GetKeyCharCount()); + Assert.Equal(asString ? 9 : 10, notification.GetKeyMaxCharCount()); + + Assert.True(notification.TryCopyKey(blob, out var bytesWritten)); + Assert.Equal(9, bytesWritten); + Assert.Equal("mykey:abc", Encoding.UTF8.GetString(blob.Slice(0, bytesWritten))); + + Assert.True(notification.TryCopyKey(clob, out var charsWritten)); + Assert.Equal(9, charsWritten); + Assert.Equal("mykey:abc", clob.Slice(0, charsWritten).ToString()); + + // now with a prefix + notification = notification.WithKeySlice("mykey:"u8.Length); + Assert.Equal("abc", (string?)notification.GetKey()); + Assert.False(notification.KeyStartsWith("mykey:"u8)); + Assert.Equal(6, notification.KeyOffset); + + Assert.Equal(3, notification.GetKeyByteCount()); + Assert.Equal(asString ? 24 : 3, notification.GetKeyMaxByteCount()); + Assert.Equal(3, notification.GetKeyCharCount()); + Assert.Equal(asString ? 3 : 4, notification.GetKeyMaxCharCount()); + + Assert.True(notification.TryCopyKey(blob, out bytesWritten)); + Assert.Equal(3, bytesWritten); + Assert.Equal("abc", Encoding.UTF8.GetString(blob.Slice(0, bytesWritten))); + + Assert.True(notification.TryCopyKey(clob, out charsWritten)); + Assert.Equal(3, charsWritten); + Assert.Equal("abc", clob.Slice(0, charsWritten).ToString()); + } +} diff --git a/tests/StackExchange.Redis.Tests/PubSubKeyNotificationTests.cs b/tests/StackExchange.Redis.Tests/PubSubKeyNotificationTests.cs new file mode 100644 index 000000000..723921d45 --- /dev/null +++ b/tests/StackExchange.Redis.Tests/PubSubKeyNotificationTests.cs @@ -0,0 +1,419 @@ +using System; +using System.Buffers; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using StackExchange.Redis.KeyspaceIsolation; +using Xunit; + +namespace StackExchange.Redis.Tests; + +// ReSharper disable once UnusedMember.Global - used via test framework +public sealed class PubSubKeyNotificationTestsCluster(ITestOutputHelper output, ITestContextAccessor context, SharedConnectionFixture fixture) + : PubSubKeyNotificationTests(output, context, fixture) +{ + protected override string GetConfiguration() => TestConfig.Current.ClusterServersAndPorts; +} + +// ReSharper disable once UnusedMember.Global - used via test framework +public sealed class PubSubKeyNotificationTestsStandalone(ITestOutputHelper output, ITestContextAccessor context, SharedConnectionFixture fixture) + : PubSubKeyNotificationTests(output, context, fixture) +{ +} + +public abstract class PubSubKeyNotificationTests(ITestOutputHelper output, ITestContextAccessor context, SharedConnectionFixture? fixture = null) + : TestBase(output, fixture) +{ + private const int DefaultKeyCount = 10; + private const int DefaultEventCount = 512; + private CancellationToken CancellationToken => context.Current.CancellationToken; + + private RedisKey[] InventKeys(out byte[] prefix, int count = DefaultKeyCount) + { + RedisKey[] keys = new RedisKey[count]; + var prefixString = $"{Guid.NewGuid()}/"; + prefix = Encoding.UTF8.GetBytes(prefixString); + for (int i = 0; i < count; i++) + { + keys[i] = $"{prefixString}{Guid.NewGuid()}"; + } + return keys; + } + + [Obsolete("Use Create(withChannelPrefix: false) instead", error: true)] + private IInternalConnectionMultiplexer Create() => Create(withChannelPrefix: false); + private IInternalConnectionMultiplexer Create(bool withChannelPrefix) => + Create(channelPrefix: withChannelPrefix ? "prefix:" : null); + + private RedisKey SelectKey(RedisKey[] keys) => keys[SharedRandom.Next(0, keys.Length)]; + +#if NET6_0_OR_GREATER + private static Random SharedRandom => Random.Shared; +#else + private static Random SharedRandom { get; } = new(); +#endif + + [Fact] + public async Task KeySpace_Events_Enabled() + { + // see https://redis.io/docs/latest/develop/pubsub/keyspace-notifications/#configuration + await using var conn = Create(allowAdmin: true); + int failures = 0; + foreach (var ep in conn.GetEndPoints()) + { + var server = conn.GetServer(ep); + var config = (await server.ConfigGetAsync("notify-keyspace-events")).Single(); + Log($"[{Format.ToString(ep)}] notify-keyspace-events: '{config.Value}'"); + + // this is a very broad config, but it's what we use in CI (and probably a common basic config) + if (config.Value != "AKE") + { + failures++; + } + } + // for details, check the log output + Assert.Equal(0, failures); + } + + [Fact] + public async Task KeySpace_CanSubscribe_ManualPublish() + { + await using var conn = Create(withChannelPrefix: false); + var db = conn.GetDatabase(); + + var channel = RedisChannel.KeyEvent("nonesuch"u8, database: null); + Log($"Monitoring channel: {channel}"); + var sub = conn.GetSubscriber(); + await sub.UnsubscribeAsync(channel); + + int count = 0; + await sub.SubscribeAsync(channel, (_, _) => Interlocked.Increment(ref count)); + + // to publish, we need to remove the marker that this is a multi-node channel + var asLiteral = RedisChannel.Literal(channel.ToString()); + await sub.PublishAsync(asLiteral, Guid.NewGuid().ToString()); + + int expected = GetConnectedCount(conn, channel); + await Task.Delay(100).ForAwait(); + Assert.Equal(expected, count); + } + + // this looks past the horizon to see how many connections we actually have for a given channel, + // which could be more than 1 in a cluster scenario + private static int GetConnectedCount(IConnectionMultiplexer muxer, in RedisChannel channel) + => muxer is ConnectionMultiplexer typed && typed.TryGetSubscription(channel, out var sub) + ? sub.GetConnectionCount() : 1; + + private sealed class Counter + { + private int _count; + public int Count => Volatile.Read(ref _count); + public int Increment() => Interlocked.Increment(ref _count); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task KeyEvent_CanObserveSimple_ViaCallbackHandler(bool withChannelPrefix) + { + await using var conn = Create(withChannelPrefix); + var db = conn.GetDatabase(); + + var keys = InventKeys(out var prefix); + var channel = RedisChannel.KeyEvent(KeyNotificationType.SAdd, db.Database); + Assert.True(channel.IsMultiNode); + Assert.False(channel.IsPattern); + Log($"Monitoring channel: {channel}"); + var sub = conn.GetSubscriber(); + await sub.UnsubscribeAsync(channel); + Counter callbackCount = new(), matchingEventCount = new(); + TaskCompletionSource allDone = new(); + + ConcurrentDictionary observedCounts = new(); + foreach (var key in keys) + { + observedCounts[key.ToString()] = new(); + } + + await sub.SubscribeAsync(channel, (recvChannel, recvValue) => + { + callbackCount.Increment(); + if (KeyNotification.TryParse(in recvChannel, in recvValue, out var notification) + && notification is { IsKeyEvent: true, Type: KeyNotificationType.SAdd }) + { + OnNotification(notification, prefix, matchingEventCount, observedCounts, allDone); + } + }); + + await SendAndObserveAsync(keys, db, allDone, callbackCount, observedCounts); + await sub.UnsubscribeAsync(channel); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task KeyEvent_CanObserveSimple_ViaQueue(bool withChannelPrefix) + { + await using var conn = Create(withChannelPrefix); + var db = conn.GetDatabase(); + + var keys = InventKeys(out var prefix); + var channel = RedisChannel.KeyEvent(KeyNotificationType.SAdd, db.Database); + Assert.True(channel.IsMultiNode); + Assert.False(channel.IsPattern); + Log($"Monitoring channel: {channel}"); + var sub = conn.GetSubscriber(); + await sub.UnsubscribeAsync(channel); + Counter callbackCount = new(), matchingEventCount = new(); + TaskCompletionSource allDone = new(); + + ConcurrentDictionary observedCounts = new(); + foreach (var key in keys) + { + observedCounts[key.ToString()] = new(); + } + + var queue = await sub.SubscribeAsync(channel); + _ = Task.Run(async () => + { + await foreach (var msg in queue.WithCancellation(CancellationToken)) + { + callbackCount.Increment(); + if (msg.TryParseKeyNotification(out var notification) + && notification is { IsKeyEvent: true, Type: KeyNotificationType.SAdd }) + { + OnNotification(notification, prefix, matchingEventCount, observedCounts, allDone); + } + } + }); + + await SendAndObserveAsync(keys, db, allDone, callbackCount, observedCounts); + await queue.UnsubscribeAsync(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task KeyNotification_CanObserveSimple_ViaCallbackHandler(bool withChannelPrefix) + { + await using var conn = Create(withChannelPrefix); + var db = conn.GetDatabase(); + + var keys = InventKeys(out var prefix); + var channel = RedisChannel.KeySpacePrefix(prefix, db.Database); + Assert.True(channel.IsMultiNode); + Assert.True(channel.IsPattern); + Log($"Monitoring channel: {channel}"); + var sub = conn.GetSubscriber(); + await sub.UnsubscribeAsync(channel); + Counter callbackCount = new(), matchingEventCount = new(); + TaskCompletionSource allDone = new(); + + ConcurrentDictionary observedCounts = new(); + foreach (var key in keys) + { + observedCounts[key.ToString()] = new(); + } + + var queue = await sub.SubscribeAsync(channel); + _ = Task.Run(async () => + { + await foreach (var msg in queue.WithCancellation(CancellationToken)) + { + callbackCount.Increment(); + if (msg.TryParseKeyNotification(out var notification) + && notification is { IsKeySpace: true, Type: KeyNotificationType.SAdd }) + { + OnNotification(notification, prefix, matchingEventCount, observedCounts, allDone); + } + } + }); + + await SendAndObserveAsync(keys, db, allDone, callbackCount, observedCounts); + await sub.UnsubscribeAsync(channel); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task KeyNotification_CanObserveSimple_ViaQueue(bool withChannelPrefix) + { + await using var conn = Create(withChannelPrefix); + var db = conn.GetDatabase(); + + var keys = InventKeys(out var prefix); + var channel = RedisChannel.KeySpacePrefix(prefix, db.Database); + Assert.True(channel.IsMultiNode); + Assert.True(channel.IsPattern); + Log($"Monitoring channel: {channel}"); + var sub = conn.GetSubscriber(); + await sub.UnsubscribeAsync(channel); + Counter callbackCount = new(), matchingEventCount = new(); + TaskCompletionSource allDone = new(); + + ConcurrentDictionary observedCounts = new(); + foreach (var key in keys) + { + observedCounts[key.ToString()] = new(); + } + + await sub.SubscribeAsync(channel, (recvChannel, recvValue) => + { + callbackCount.Increment(); + if (KeyNotification.TryParse(in recvChannel, in recvValue, out var notification) + && notification is { IsKeySpace: true, Type: KeyNotificationType.SAdd }) + { + OnNotification(notification, prefix, matchingEventCount, observedCounts, allDone); + } + }); + + await SendAndObserveAsync(keys, db, allDone, callbackCount, observedCounts); + await sub.UnsubscribeAsync(channel); + } + + [Theory] + [InlineData(true, false)] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + public async Task KeyNotification_CanObserveSingleKey_ViaQueue(bool withChannelPrefix, bool withKeyPrefix) + { + await using var conn = Create(withChannelPrefix); + string keyPrefix = withKeyPrefix ? "isolated:" : ""; + byte[] keyPrefixBytes = Encoding.UTF8.GetBytes(keyPrefix); + var db = conn.GetDatabase().WithKeyPrefix(keyPrefix); + + var keys = InventKeys(out var prefix, count: 1); + Log($"Using {Encoding.UTF8.GetString(prefix)} as filter prefix, sample key: {SelectKey(keys)}"); + var channel = RedisChannel.KeySpaceSingleKey(RedisKey.WithPrefix(keyPrefixBytes, keys.Single()), db.Database); + + Assert.False(channel.IsMultiNode); + Assert.False(channel.IsPattern); + Log($"Monitoring channel: {channel}, routing via {Encoding.UTF8.GetString(channel.RoutingSpan)}"); + + var sub = conn.GetSubscriber(); + await sub.UnsubscribeAsync(channel); + Counter callbackCount = new(), matchingEventCount = new(); + TaskCompletionSource allDone = new(); + + ConcurrentDictionary observedCounts = new(); + foreach (var key in keys) + { + observedCounts[key.ToString()] = new(); + } + + var queue = await sub.SubscribeAsync(channel); + _ = Task.Run(async () => + { + await foreach (var msg in queue.WithCancellation(CancellationToken)) + { + callbackCount.Increment(); + if (msg.TryParseKeyNotification(keyPrefixBytes, out var notification) + && notification is { IsKeySpace: true, Type: KeyNotificationType.SAdd }) + { + OnNotification(notification, prefix, matchingEventCount, observedCounts, allDone); + } + } + }); + + await SendAndObserveAsync(keys, db, allDone, callbackCount, observedCounts); + await sub.UnsubscribeAsync(channel); + } + + private void OnNotification( + in KeyNotification notification, + ReadOnlySpan prefix, + Counter matchingEventCount, + ConcurrentDictionary observedCounts, + TaskCompletionSource allDone) + { + if (notification.KeyStartsWith(prefix)) // avoid problems with parallel SADD tests + { + int currentCount = matchingEventCount.Increment(); + + // get the key and check that we expected it + var recvKey = notification.GetKey(); + Assert.True(observedCounts.TryGetValue(recvKey.ToString(), out var counter)); + +#if NET9_0_OR_GREATER + // it would be more efficient to stash the alt-lookup, but that would make our API here non-viable, + // since we need to support multiple frameworks + var viaAlt = FindViaAltLookup(notification, observedCounts.GetAlternateLookup>()); + Assert.Same(counter, viaAlt); +#endif + + // accounting... + if (counter.Increment() == 1) + { + Log($"Observed key: '{recvKey}' after {currentCount} events"); + } + + if (currentCount == DefaultEventCount) + { + allDone.TrySetResult(true); + } + } + } + + private async Task SendAndObserveAsync( + RedisKey[] keys, + IDatabase db, + TaskCompletionSource allDone, + Counter callbackCount, + ConcurrentDictionary observedCounts) + { + await Task.Delay(300).ForAwait(); // give it a moment to settle + + Dictionary sentCounts = new(keys.Length); + foreach (var key in keys) + { + sentCounts[key] = new(); + } + + for (int i = 0; i < DefaultEventCount; i++) + { + var key = SelectKey(keys); + sentCounts[key].Increment(); + await db.SetAddAsync(key, i); + } + + // Wait for all events to be observed + try + { + Assert.True(await allDone.Task.WithTimeout(5000)); + } + catch (TimeoutException) when (callbackCount.Count == 0) + { + Assert.Fail($"Timeout with zero events; are keyspace events enabled?"); + } + + foreach (var key in keys) + { + Assert.Equal(sentCounts[key].Count, observedCounts[key.ToString()].Count); + } + } + +#if NET9_0_OR_GREATER + // demonstrate that we can use the alt-lookup APIs to avoid string allocations + private static Counter? FindViaAltLookup( + in KeyNotification notification, + ConcurrentDictionary.AlternateLookup> lookup) + { + // Demonstrate typical alt-lookup usage; this is an advanced topic, so it + // isn't trivial to grok, but: this is typical of perf-focused APIs. + char[]? lease = null; + const int MAX_STACK = 128; + var maxLength = notification.GetKeyMaxCharCount(); + Span scratch = maxLength <= MAX_STACK + ? stackalloc char[MAX_STACK] + : (lease = ArrayPool.Shared.Rent(maxLength)); + Assert.True(notification.TryCopyKey(scratch, out var length)); + if (!lookup.TryGetValue(scratch.Slice(0, length), out var counter)) counter = null; + if (lease is not null) ArrayPool.Shared.Return(lease); + return counter; + } +#endif +} diff --git a/tests/StackExchange.Redis.Tests/PubSubMultiserverTests.cs b/tests/StackExchange.Redis.Tests/PubSubMultiserverTests.cs index 43bb4b2b8..691232218 100644 --- a/tests/StackExchange.Redis.Tests/PubSubMultiserverTests.cs +++ b/tests/StackExchange.Redis.Tests/PubSubMultiserverTests.cs @@ -63,7 +63,7 @@ await sub.SubscribeAsync(channel, (_, val) => Assert.True(subscribedServerEndpoint.IsSubscriberConnected, "subscribedServerEndpoint.IsSubscriberConnected"); Assert.True(conn.GetSubscriptions().TryGetValue(channel, out var subscription)); - var initialServer = subscription.GetCurrentServer(); + var initialServer = subscription.GetAnyCurrentServer(); Assert.NotNull(initialServer); Assert.True(initialServer.IsConnected); Log("Connected to: " + initialServer); @@ -83,10 +83,10 @@ await sub.SubscribeAsync(channel, (_, val) => Assert.True(subscribedServerEndpoint.IsConnected, "subscribedServerEndpoint.IsConnected"); Assert.False(subscribedServerEndpoint.IsSubscriberConnected, "subscribedServerEndpoint.IsSubscriberConnected"); } - await UntilConditionAsync(TimeSpan.FromSeconds(5), () => subscription.IsConnected); - Assert.True(subscription.IsConnected); + await UntilConditionAsync(TimeSpan.FromSeconds(5), () => subscription.IsConnectedAny()); + Assert.True(subscription.IsConnectedAny()); - var newServer = subscription.GetCurrentServer(); + var newServer = subscription.GetAnyCurrentServer(); Assert.NotNull(newServer); Assert.NotEqual(newServer, initialServer); Log("Now connected to: " + newServer); @@ -148,7 +148,7 @@ await sub.SubscribeAsync( Assert.True(subscribedServerEndpoint.IsSubscriberConnected, "subscribedServerEndpoint.IsSubscriberConnected"); Assert.True(conn.GetSubscriptions().TryGetValue(channel, out var subscription)); - var initialServer = subscription.GetCurrentServer(); + var initialServer = subscription.GetAnyCurrentServer(); Assert.NotNull(initialServer); Assert.True(initialServer.IsConnected); Log("Connected to: " + initialServer); @@ -169,10 +169,10 @@ await sub.SubscribeAsync( if (expectSuccess) { - await UntilConditionAsync(TimeSpan.FromSeconds(5), () => subscription.IsConnected); - Assert.True(subscription.IsConnected); + await UntilConditionAsync(TimeSpan.FromSeconds(5), () => subscription.IsConnectedAny()); + Assert.True(subscription.IsConnectedAny()); - var newServer = subscription.GetCurrentServer(); + var newServer = subscription.GetAnyCurrentServer(); Assert.NotNull(newServer); Assert.NotEqual(newServer, initialServer); Log("Now connected to: " + newServer); @@ -180,16 +180,16 @@ await sub.SubscribeAsync( else { // This subscription shouldn't be able to reconnect by flags (demanding an unavailable server) - await UntilConditionAsync(TimeSpan.FromSeconds(5), () => subscription.IsConnected); - Assert.False(subscription.IsConnected); + await UntilConditionAsync(TimeSpan.FromSeconds(5), () => subscription.IsConnectedAny()); + Assert.False(subscription.IsConnectedAny()); Log("Unable to reconnect (as expected)"); // Allow connecting back to the original conn.AllowConnect = true; - await UntilConditionAsync(TimeSpan.FromSeconds(5), () => subscription.IsConnected); - Assert.True(subscription.IsConnected); + await UntilConditionAsync(TimeSpan.FromSeconds(5), () => subscription.IsConnectedAny()); + Assert.True(subscription.IsConnectedAny()); - var newServer = subscription.GetCurrentServer(); + var newServer = subscription.GetAnyCurrentServer(); Assert.NotNull(newServer); Assert.Equal(newServer, initialServer); Log("Now connected to: " + newServer); diff --git a/tests/StackExchange.Redis.Tests/RedisValueEquivalencyTests.cs b/tests/StackExchange.Redis.Tests/RedisValueEquivalencyTests.cs index 7f6ad1561..391a0237a 100644 --- a/tests/StackExchange.Redis.Tests/RedisValueEquivalencyTests.cs +++ b/tests/StackExchange.Redis.Tests/RedisValueEquivalencyTests.cs @@ -297,11 +297,17 @@ public void RedisValueStartsWith() Assert.False(x.StartsWith(123), LineNumber()); Assert.False(x.StartsWith(false), LineNumber()); - Assert.True(x.StartsWith(Encoding.ASCII.GetBytes("a")), LineNumber()); - Assert.True(x.StartsWith(Encoding.ASCII.GetBytes("ab")), LineNumber()); - Assert.True(x.StartsWith(Encoding.ASCII.GetBytes("abc")), LineNumber()); - Assert.False(x.StartsWith(Encoding.ASCII.GetBytes("abd")), LineNumber()); - Assert.False(x.StartsWith(Encoding.ASCII.GetBytes("abcd")), LineNumber()); + Assert.True(x.StartsWith((RedisValue)Encoding.ASCII.GetBytes("a")), LineNumber()); + Assert.True(x.StartsWith((RedisValue)Encoding.ASCII.GetBytes("ab")), LineNumber()); + Assert.True(x.StartsWith((RedisValue)Encoding.ASCII.GetBytes("abc")), LineNumber()); + Assert.False(x.StartsWith((RedisValue)Encoding.ASCII.GetBytes("abd")), LineNumber()); + Assert.False(x.StartsWith((RedisValue)Encoding.ASCII.GetBytes("abcd")), LineNumber()); + + Assert.True(x.StartsWith("a"u8), LineNumber()); + Assert.True(x.StartsWith("ab"u8), LineNumber()); + Assert.True(x.StartsWith("abc"u8), LineNumber()); + Assert.False(x.StartsWith("abd"u8), LineNumber()); + Assert.False(x.StartsWith("abcd"u8), LineNumber()); x = 10; // integers are effectively strings in this context Assert.True(x.StartsWith(1), LineNumber()); diff --git a/tests/StackExchange.Redis.Tests/SSLTests.cs b/tests/StackExchange.Redis.Tests/SSLTests.cs index 0dafe3f9b..c9c5cc2bb 100644 --- a/tests/StackExchange.Redis.Tests/SSLTests.cs +++ b/tests/StackExchange.Redis.Tests/SSLTests.cs @@ -240,7 +240,9 @@ public async Task RedisLabsSSL() Skip.IfNoConfig(nameof(TestConfig.Config.RedisLabsSslServer), TestConfig.Current.RedisLabsSslServer); Skip.IfNoConfig(nameof(TestConfig.Config.RedisLabsPfxPath), TestConfig.Current.RedisLabsPfxPath); +#pragma warning disable SYSLIB0057 var cert = new X509Certificate2(TestConfig.Current.RedisLabsPfxPath, ""); +#pragma warning restore SYSLIB0057 Assert.NotNull(cert); Log("Thumbprint: " + cert.Thumbprint); diff --git a/tests/StackExchange.Redis.Tests/StackExchange.Redis.Tests.csproj b/tests/StackExchange.Redis.Tests/StackExchange.Redis.Tests.csproj index f6e38236b..e02a6ac36 100644 --- a/tests/StackExchange.Redis.Tests/StackExchange.Redis.Tests.csproj +++ b/tests/StackExchange.Redis.Tests/StackExchange.Redis.Tests.csproj @@ -1,6 +1,7 @@  - net481;net8.0 + + net481;net10.0 Exe StackExchange.Redis.Tests true diff --git a/tests/StackExchange.Redis.Tests/SyncContextTests.cs b/tests/StackExchange.Redis.Tests/SyncContextTests.cs index b98caefeb..5feb37e3d 100644 --- a/tests/StackExchange.Redis.Tests/SyncContextTests.cs +++ b/tests/StackExchange.Redis.Tests/SyncContextTests.cs @@ -122,7 +122,7 @@ public MySyncContext(TextWriter log) private int _opCount; private void Incr() => Interlocked.Increment(ref _opCount); - public void Reset() => Thread.VolatileWrite(ref _opCount, 0); + public void Reset() => Volatile.Write(ref _opCount, 0); public override string ToString() => $"Sync context ({(IsCurrent ? "active" : "inactive")}): {OpCount}";