diff --git a/src/DispatchR/Configuration/ServiceRegistrator.cs b/src/DispatchR/Configuration/ServiceRegistrator.cs index 99f726d..b716195 100644 --- a/src/DispatchR/Configuration/ServiceRegistrator.cs +++ b/src/DispatchR/Configuration/ServiceRegistrator.cs @@ -11,143 +11,154 @@ public static void RegisterHandlers(IServiceCollection services, List allT Type requestHandlerType, Type pipelineBehaviorType, Type streamRequestHandlerType, Type streamPipelineBehaviorType, bool withPipelines, List? pipelineOrder = null) { + var handlerTypes = new[] { requestHandlerType, streamRequestHandlerType }; + var pipelineTypes = new[] { pipelineBehaviorType, streamPipelineBehaviorType }; + var allHandlers = allTypes - .Where(p => + .Where(type => { - var @interface = p.GetInterfaces().First(i => i.IsGenericType); - return new[] { requestHandlerType, streamRequestHandlerType } - .Contains(@interface.GetGenericTypeDefinition()); - }).ToList(); + var genericInterfaces = type.GetInterfaces() + .Where(i => i.IsGenericType) + .Select(i => i.GetGenericTypeDefinition()) + .ToList(); + + return genericInterfaces.Intersect(handlerTypes).Any() && + !genericInterfaces.Intersect(pipelineTypes).Any(); + }) + .ToList(); var allPipelines = allTypes - .Where(p => - { - var @interface = p.GetInterfaces().First(i => i.IsGenericType); - return new[] { pipelineBehaviorType, streamPipelineBehaviorType } - .Contains(@interface.GetGenericTypeDefinition()); - }).ToList(); + .Where(type => type.GetInterfaces() + .Where(i => i.IsGenericType) + .Select(i => i.GetGenericTypeDefinition()) + .Intersect(pipelineTypes) + .Any()) + .ToList(); foreach (var handler in allHandlers) { - object key = handler.GUID; - var handlerType = requestHandlerType; - var behaviorType = pipelineBehaviorType; + var handlerInterfaces = handler.GetInterfaces() + .Where(p => p.IsGenericType && handlerTypes.Contains(p.GetGenericTypeDefinition())) + .ToList(); - var isStream = handler.GetInterfaces() - .Any(p => p.IsGenericType && p.GetGenericTypeDefinition() == streamRequestHandlerType); - if (isStream) + foreach (var handlerInterface in handlerInterfaces) { - handlerType = streamRequestHandlerType; - behaviorType = streamPipelineBehaviorType; - } - - services.AddKeyedScoped(typeof(IRequestHandler), key, handler); - - var handlerInterface = handler.GetInterfaces() - .First(p => p.IsGenericType && p.GetGenericTypeDefinition() == handlerType); + object key = Guid.NewGuid(); + var handlerType = requestHandlerType; + var behaviorType = pipelineBehaviorType; - // find pipelines - if (withPipelines) - { - var pipelines = allPipelines - .Where(p => - { - var interfaces = p.GetInterfaces(); - if (p.IsGenericType) - { - // handle generic pipelines - return interfaces - .FirstOrDefault(inter => - inter.IsGenericType && - inter.GetGenericTypeDefinition() == behaviorType) - ?.GetInterfaces().First().GetGenericTypeDefinition() == - handlerInterface.GetGenericTypeDefinition(); - } + var isStream = handlerInterface.GetGenericTypeDefinition() == streamRequestHandlerType; + if (isStream) + { + handlerType = streamRequestHandlerType; + behaviorType = streamPipelineBehaviorType; + } - return interfaces - .FirstOrDefault(inter => - inter.IsGenericType && - inter.GetGenericTypeDefinition() == behaviorType) - ?.GetInterfaces().First() == handlerInterface; - }).ToList(); + services.AddKeyedScoped(typeof(IRequestHandler), key, handler); - // Sort pipelines by the specified order passed via ConfigurationOptions - if (pipelineOrder is { Count: > 0 }) + // find pipelines + if (withPipelines) { - pipelines = pipelines - .OrderBy(p => + var pipelines = allPipelines + .Where(p => { - var idx = pipelineOrder.IndexOf(p); - return idx == -1 ? int.MaxValue : idx; - }) - .ToList(); - pipelines.Reverse(); - } + var interfaces = p.GetInterfaces(); + if (p.IsGenericType) + { + // handle generic pipelines + return interfaces + .FirstOrDefault(inter => + inter.IsGenericType && + inter.GetGenericTypeDefinition() == behaviorType) + ?.GetInterfaces().First().GetGenericTypeDefinition() == + handlerInterface.GetGenericTypeDefinition(); + } - foreach (var pipeline in pipelines) - { - if (pipeline.IsGenericType) + return interfaces + .FirstOrDefault(inter => + inter.IsGenericType && + inter.GetGenericTypeDefinition() == behaviorType) + ?.GetInterfaces().First() == handlerInterface; + }).ToList(); + + // Sort pipelines by the specified order passed via ConfigurationOptions + if (pipelineOrder is { Count: > 0 }) { - var genericHandlerResponseType = pipeline.GetInterfaces().First(inter => - inter.IsGenericType && - inter.GetGenericTypeDefinition() == behaviorType).GenericTypeArguments[1]; + pipelines = pipelines + .OrderBy(p => + { + var idx = pipelineOrder.IndexOf(p); + return idx == -1 ? int.MaxValue : idx; + }) + .ToList(); + pipelines.Reverse(); + } - var genericHandlerResponseIsAwaitable = IsAwaitable(genericHandlerResponseType); - var handlerResponseTypeIsAwaitable = IsAwaitable(handlerInterface.GenericTypeArguments[1]); - if (genericHandlerResponseIsAwaitable ^ handlerResponseTypeIsAwaitable) + foreach (var pipeline in pipelines) + { + if (pipeline.IsGenericType) { - continue; - } + var genericHandlerResponseType = pipeline.GetInterfaces().First(inter => + inter.IsGenericType && + inter.GetGenericTypeDefinition() == behaviorType).GenericTypeArguments[1]; - var responseTypeArg = handlerInterface.GenericTypeArguments[1]; - if (genericHandlerResponseIsAwaitable && handlerResponseTypeIsAwaitable) - { - var areGenericTypeArgumentsInHandlerInterfaceMismatched = - genericHandlerResponseType.IsGenericType && - handlerInterface.GenericTypeArguments[1].IsGenericType && - genericHandlerResponseType.GetGenericTypeDefinition() != - handlerInterface.GenericTypeArguments[1].GetGenericTypeDefinition(); - - if (areGenericTypeArgumentsInHandlerInterfaceMismatched || - genericHandlerResponseType.IsGenericType ^ - handlerInterface.GenericTypeArguments[1].IsGenericType) + var genericHandlerResponseIsAwaitable = IsAwaitable(genericHandlerResponseType); + var handlerResponseTypeIsAwaitable = IsAwaitable(handlerInterface.GenericTypeArguments[1]); + if (genericHandlerResponseIsAwaitable ^ handlerResponseTypeIsAwaitable) { continue; } - // register async generic pipelines - if (responseTypeArg.GenericTypeArguments.Any()) + var responseTypeArg = handlerInterface.GenericTypeArguments[1]; + if (genericHandlerResponseIsAwaitable && handlerResponseTypeIsAwaitable) { - responseTypeArg = responseTypeArg.GenericTypeArguments[0]; + var areGenericTypeArgumentsInHandlerInterfaceMismatched = + genericHandlerResponseType.IsGenericType && + handlerInterface.GenericTypeArguments[1].IsGenericType && + genericHandlerResponseType.GetGenericTypeDefinition() != + handlerInterface.GenericTypeArguments[1].GetGenericTypeDefinition(); + + if (areGenericTypeArgumentsInHandlerInterfaceMismatched || + genericHandlerResponseType.IsGenericType ^ + handlerInterface.GenericTypeArguments[1].IsGenericType) + { + continue; + } + + // register async generic pipelines + if (responseTypeArg.GenericTypeArguments.Any()) + { + responseTypeArg = responseTypeArg.GenericTypeArguments[0]; + } } - } - var closedGenericType = pipeline.MakeGenericType(handlerInterface.GenericTypeArguments[0], - responseTypeArg); - services.AddKeyedScoped(typeof(IRequestHandler), key, closedGenericType); - } - else - { - services.AddKeyedScoped(typeof(IRequestHandler), key, pipeline); + var closedGenericType = pipeline.MakeGenericType(handlerInterface.GenericTypeArguments[0], + responseTypeArg); + services.AddKeyedScoped(typeof(IRequestHandler), key, closedGenericType); + } + else + { + services.AddKeyedScoped(typeof(IRequestHandler), key, pipeline); + } } } - } - - services.AddScoped(handlerInterface, sp => - { - var pipelinesWithHandler = Unsafe - .As(sp.GetKeyedServices(key)); - IRequestHandler lastPipeline = pipelinesWithHandler[0]; - for (int i = 1; i < pipelinesWithHandler.Length; i++) + services.AddScoped(handlerInterface, sp => { - var pipeline = pipelinesWithHandler[i]; - pipeline.SetNext(lastPipeline); - lastPipeline = pipeline; - } + var pipelinesWithHandler = Unsafe + .As(sp.GetKeyedServices(key)); - return lastPipeline; - }); + IRequestHandler lastPipeline = pipelinesWithHandler[0]; + for (int i = 1; i < pipelinesWithHandler.Length; i++) + { + var pipeline = pipelinesWithHandler[i]; + pipeline.SetNext(lastPipeline); + lastPipeline = pipeline; + } + + return lastPipeline; + }); + } } } @@ -155,34 +166,14 @@ public static void RegisterNotification(IServiceCollection services, List Type syncNotificationHandlerType) { var allNotifications = allTypes - .Where(p => - { - return p.GetInterfaces() - .Where(i => i.IsGenericType) - .Select(i => i.GetGenericTypeDefinition()) - .Any(i => new[] - { - syncNotificationHandlerType - }.Contains(i)); - }) - .GroupBy(p => - { - var @interface = p.GetInterfaces() - .Where(i => i.IsGenericType) - .First(i => new[] - { - syncNotificationHandlerType - }.Contains(i.GetGenericTypeDefinition())); - return @interface.GenericTypeArguments.First(); - }) + .SelectMany(handlerType => handlerType.GetInterfaces() + .Where(i => i.IsGenericType && syncNotificationHandlerType == i.GetGenericTypeDefinition()) + .Select(i => new { HandlerType = handlerType, Interface = i })) .ToList(); foreach (var notification in allNotifications) { - foreach (var types in notification.ToList()) - { - services.AddScoped(typeof(INotificationHandler<>).MakeGenericType(notification.Key), types); - } + services.AddScoped(notification.Interface, notification.HandlerType); } } diff --git a/src/DispatchR/Extensions/ServiceCollectionExtensions.cs b/src/DispatchR/Extensions/ServiceCollectionExtensions.cs index f522205..f8c164d 100644 --- a/src/DispatchR/Extensions/ServiceCollectionExtensions.cs +++ b/src/DispatchR/Extensions/ServiceCollectionExtensions.cs @@ -50,30 +50,24 @@ public static IServiceCollection AddDispatchR(this IServiceCollection services, var streamPipelineBehaviorType = typeof(IStreamPipelineBehavior<,>); var syncNotificationHandlerType = typeof(INotificationHandler<>); + var otherHandlerTypes = new HashSet() + { + pipelineBehaviorType, + streamRequestHandlerType, + streamPipelineBehaviorType, + syncNotificationHandlerType + }; + var allTypes = configurationOptions.Assemblies.SelectMany(x => x.GetTypes()).Distinct() .Where(p => - { - var interfaces = p.GetInterfaces(); - return interfaces.Length >= 1 && + p.GetInterfaces() is { Length: >= 1 } interfaces && interfaces .Where(i => i.IsGenericType) .Select(i => i.GetGenericTypeDefinition()) - .Any(i => - { - if (i == requestHandlerType) - { - return configurationOptions.IsHandlerIncluded(p); - } - - return new[] - { - pipelineBehaviorType, - streamRequestHandlerType, - streamPipelineBehaviorType, - syncNotificationHandlerType - }.Contains(i); - }); - }).ToList(); + .Any(i => i == requestHandlerType + ? configurationOptions.IsHandlerIncluded(p) + : otherHandlerTypes.Contains(i))) + .ToList(); if (configurationOptions.RegisterNotifications) { diff --git a/tests/DispatchR.IntegrationTest/NotificationTests.cs b/tests/DispatchR.IntegrationTest/NotificationTests.cs index 43cc932..06dd78a 100644 --- a/tests/DispatchR.IntegrationTest/NotificationTests.cs +++ b/tests/DispatchR.IntegrationTest/NotificationTests.cs @@ -20,7 +20,7 @@ public async Task Publish_CallsAllHandlers_WhenMultipleHandlersAreRegistered() cfg.RegisterPipelines = false; cfg.RegisterNotifications = true; }); - + var spyPipelineOneMock = new Mock>(); var spyPipelineTwoMock = new Mock>(); var spyPipelineThreeMock = new Mock>(); @@ -28,23 +28,23 @@ public async Task Publish_CallsAllHandlers_WhenMultipleHandlersAreRegistered() spyPipelineOneMock.Setup(p => p.Handle(It.IsAny(), It.IsAny())); spyPipelineTwoMock.Setup(p => p.Handle(It.IsAny(), It.IsAny())); spyPipelineThreeMock.Setup(p => p.Handle(It.IsAny(), It.IsAny())); - + services.AddScoped>(sp => spyPipelineOneMock.Object); services.AddScoped>(sp => spyPipelineTwoMock.Object); services.AddScoped>(sp => spyPipelineThreeMock.Object); - + var serviceProvider = services.BuildServiceProvider(); var mediator = serviceProvider.GetRequiredService(); - + // Act await mediator.Publish(new MultiHandlersNotification(Guid.Empty), CancellationToken.None); - + // Assert spyPipelineOneMock.Verify(p => p.Handle(It.IsAny(), It.IsAny()), Times.Exactly(1)); spyPipelineTwoMock.Verify(p => p.Handle(It.IsAny(), It.IsAny()), Times.Exactly(1)); spyPipelineThreeMock.Verify(p => p.Handle(It.IsAny(), It.IsAny()), Times.Exactly(1)); } - + [Fact] public async Task PublishObject_CallsAllHandlers_WhenMultipleHandlersAreRegistered() { @@ -56,7 +56,7 @@ public async Task PublishObject_CallsAllHandlers_WhenMultipleHandlersAreRegister cfg.RegisterPipelines = false; cfg.RegisterNotifications = true; }); - + var spyPipelineOneMock = new Mock>(); var spyPipelineTwoMock = new Mock>(); var spyPipelineThreeMock = new Mock>(); @@ -64,21 +64,43 @@ public async Task PublishObject_CallsAllHandlers_WhenMultipleHandlersAreRegister spyPipelineOneMock.Setup(p => p.Handle(It.IsAny(), It.IsAny())); spyPipelineTwoMock.Setup(p => p.Handle(It.IsAny(), It.IsAny())); spyPipelineThreeMock.Setup(p => p.Handle(It.IsAny(), It.IsAny())); - + services.AddScoped>(sp => spyPipelineOneMock.Object); services.AddScoped>(sp => spyPipelineTwoMock.Object); services.AddScoped>(sp => spyPipelineThreeMock.Object); - + var serviceProvider = services.BuildServiceProvider(); var mediator = serviceProvider.GetRequiredService(); - + // Act object notificationObject = new MultiHandlersNotification(Guid.Empty); await mediator.Publish(notificationObject, CancellationToken.None); - + // Assert spyPipelineOneMock.Verify(p => p.Handle(It.IsAny(), It.IsAny()), Times.Exactly(1)); spyPipelineTwoMock.Verify(p => p.Handle(It.IsAny(), It.IsAny()), Times.Exactly(1)); spyPipelineThreeMock.Verify(p => p.Handle(It.IsAny(), It.IsAny()), Times.Exactly(1)); } -} \ No newline at end of file + + [Fact] + public void RegisterNotification_SingleClassWithMultipleNotificationInterfaces_ResolvesAllHandlers() + { + // Arrange + var services = new ServiceCollection(); + services.AddDispatchR(cfg => + { + cfg.Assemblies.Add(typeof(Fixture).Assembly); + cfg.RegisterNotifications = true; + }); + + var serviceProvider = services.BuildServiceProvider(); + + // Act + var handlers1 = serviceProvider.GetServices>(); + var handlers2 = serviceProvider.GetServices>(); + + // Assert + Assert.Contains(handlers1, h => h is MultiNotificationHandler); + Assert.Contains(handlers2, h => h is MultiNotificationHandler); + } +} diff --git a/tests/DispatchR.IntegrationTest/RequestHandlerTests.cs b/tests/DispatchR.IntegrationTest/RequestHandlerTests.cs index ad86a11..8be80f7 100644 --- a/tests/DispatchR.IntegrationTest/RequestHandlerTests.cs +++ b/tests/DispatchR.IntegrationTest/RequestHandlerTests.cs @@ -1,6 +1,8 @@ using DispatchR.Abstractions.Send; using DispatchR.Extensions; using DispatchR.TestCommon.Fixtures; +using DispatchR.TestCommon.Fixtures.SendRequest; +using DispatchR.TestCommon.Fixtures.SendRequest.Task; using DispatchR.TestCommon.Fixtures.SendRequest.ValueTask; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -54,4 +56,26 @@ public async Task Send_UsesPipelineBehaviors_RequestWithSinglePipelines() spyPipelineOneMock.Verify(p => p.Handle(It.IsAny(), It.IsAny()), Times.Exactly(1)); spyPipelineTwoMock.Verify(p => p.Handle(It.IsAny(), It.IsAny()), Times.Exactly(1)); } + + [Fact] + public void RegisterHandlers_SingleClassWithMultipleRequestInterfaces_ResolvesAllHandlers() + { + // Arrange + var services = new ServiceCollection(); + services.AddDispatchR(cfg => + { + cfg.Assemblies.Add(typeof(Fixture).Assembly); + cfg.RegisterPipelines = false; + }); + + var serviceProvider = services.BuildServiceProvider(); + + // Act + var handlers1 = serviceProvider.GetServices>>(); + var handlers2 = serviceProvider.GetServices>>(); + + // Assert + Assert.Contains(handlers1, h => h is MultiRequestHandler); + Assert.Contains(handlers2, h => h is MultiRequestHandler); + } } \ No newline at end of file diff --git a/tests/DispatchR.TestCommon/DispatchR.TestCommon.csproj b/tests/DispatchR.TestCommon/DispatchR.TestCommon.csproj index bb8e64a..8e1adc7 100644 --- a/tests/DispatchR.TestCommon/DispatchR.TestCommon.csproj +++ b/tests/DispatchR.TestCommon/DispatchR.TestCommon.csproj @@ -5,6 +5,10 @@ enable enable + + + + diff --git a/tests/DispatchR.TestCommon/Fixtures/Notification/MultiHandlersNotification2.cs b/tests/DispatchR.TestCommon/Fixtures/Notification/MultiHandlersNotification2.cs new file mode 100644 index 0000000..ecb11c5 --- /dev/null +++ b/tests/DispatchR.TestCommon/Fixtures/Notification/MultiHandlersNotification2.cs @@ -0,0 +1,5 @@ +using DispatchR.Abstractions.Notification; + +namespace DispatchR.TestCommon.Fixtures.Notification; + +public sealed record MultiHandlersNotification2(Guid Id) : INotification; \ No newline at end of file diff --git a/tests/DispatchR.TestCommon/Fixtures/Notification/MultiNotificationHandler.cs b/tests/DispatchR.TestCommon/Fixtures/Notification/MultiNotificationHandler.cs new file mode 100644 index 0000000..35dd829 --- /dev/null +++ b/tests/DispatchR.TestCommon/Fixtures/Notification/MultiNotificationHandler.cs @@ -0,0 +1,18 @@ +using DispatchR.Abstractions.Notification; + +namespace DispatchR.TestCommon.Fixtures.Notification; + +public sealed class MultiNotificationHandler : + INotificationHandler, + INotificationHandler +{ + public ValueTask Handle(MultiHandlersNotification request, CancellationToken cancellationToken) + { + return ValueTask.CompletedTask; + } + + public ValueTask Handle(MultiHandlersNotification2 request, CancellationToken cancellationToken) + { + return ValueTask.CompletedTask; + } +} diff --git a/tests/DispatchR.TestCommon/Fixtures/SendRequest/MultiRequestHandler.cs b/tests/DispatchR.TestCommon/Fixtures/SendRequest/MultiRequestHandler.cs new file mode 100644 index 0000000..443671f --- /dev/null +++ b/tests/DispatchR.TestCommon/Fixtures/SendRequest/MultiRequestHandler.cs @@ -0,0 +1,20 @@ +using DispatchR.Abstractions.Send; +using DispatchR.TestCommon.Fixtures.SendRequest.Task; +using DispatchR.TestCommon.Fixtures.SendRequest.ValueTask; + +namespace DispatchR.TestCommon.Fixtures.SendRequest; + +public sealed class MultiRequestHandler : + IRequestHandler>, + IRequestHandler> +{ + public Task Handle(PingTask request, CancellationToken cancellationToken) + { + return System.Threading.Tasks.Task.FromResult(1); + } + + public ValueTask Handle(PingValueTask request, CancellationToken cancellationToken) + { + return System.Threading.Tasks.ValueTask.FromResult(1); + } +}