Skip to content

Commit 1d2bc3f

Browse files
committed
Make request limit provider async and scoped
1 parent d038db4 commit 1d2bc3f

9 files changed

Lines changed: 85 additions & 43 deletions

File tree

samples/Sample.Bruno/Authenticated Endpoints/Get Addresses.bru

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,20 @@ meta {
66

77
get {
88
url: {{baseUrl}}/addresses
9+
body: none
10+
auth: none
911
}
1012

1113
tests {
1214
test("Status code is 200", function() {
1315
expect(res.getStatus()).to.equal(200);
1416
});
15-
17+
1618
test("Response is array", function() {
1719
const body = res.getBody();
1820
expect(Array.isArray(body)).to.be.true;
1921
});
20-
22+
2123
test("Address objects have required properties", function() {
2224
const body = res.getBody();
2325
if (body.length > 0) {

samples/Sample.MinimalApi/Program.cs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,22 @@ public static void Main(string[] args)
2323
builder.Services.AddOpenApi();
2424
builder.Services.AddEndpointsApiExplorer();
2525

26-
builder.Services.AddHmacRateLimiter(configure: options =>
27-
{
28-
options.RequestsPerPeriod = 10;
29-
options.BurstFactor = 1;
30-
});
26+
builder.Services
27+
.AddHmacRateLimiter(
28+
policyName: "UserPolicy",
29+
configure: options =>
30+
{
31+
options.RequestsPerPeriod = 2;
32+
options.BurstFactor = 1;
33+
}
34+
)
35+
.AddHmacRateLimiter(
36+
configure: options =>
37+
{
38+
options.RequestsPerPeriod = 10;
39+
options.BurstFactor = 1;
40+
}
41+
);
3142

3243
var application = builder.Build();
3344

@@ -48,19 +59,20 @@ public static void Main(string[] args)
4859

4960
application
5061
.MapPost("/weather", (Weather weather) => Results.Ok(weather))
51-
.WithName("PostWeather");
62+
.WithName("PostWeather")
63+
.RequireHmacRateLimiting();
5264

5365
application
5466
.MapGet("/users", () => UserFaker.Instance.Generate(10))
5567
.WithName("GetUsers")
5668
.RequireAuthorization()
57-
.RequireHmacRateLimiting();
69+
.RequireHmacRateLimiting("UserPolicy");
5870

5971
application
6072
.MapPost("/users", (User user) => Results.Ok(user))
6173
.WithName("PostUser")
6274
.RequireAuthorization()
63-
.RequireHmacRateLimiting();
75+
.RequireHmacRateLimiting("UserPolicy");
6476

6577
application
6678
.MapGet("/addresses", () => AddressFaker.Instance.Generate(10))

src/HashGate.AspNetCore/HmacAuthenticationHandler.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ public partial class HmacAuthenticationHandler : AuthenticationHandler<HmacAuthe
3535
/// <param name="options">The options monitor for <see cref="HmacAuthenticationSchemeOptions"/>.</param>
3636
/// <param name="logger">The logger factory.</param>
3737
/// <param name="encoder">The URL encoder.</param>
38-
/// <param name="keyProvider">The HMAC key provider used to retrieve client secrets.</param>
3938
public HmacAuthenticationHandler(
4039
IOptionsMonitor<HmacAuthenticationSchemeOptions> options,
4140
ILoggerFactory logger,

src/HashGate.AspNetCore/HmacAuthenticationSchemeOptions.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,13 @@ public class HmacAuthenticationSchemeOptions : AuthenticationSchemeOptions
3030
/// When set, the keyed service registered under this key is used instead of the default provider.
3131
/// </summary>
3232
public string? ProviderServiceKey { get; set; }
33+
34+
/// <summary>
35+
/// Gets or sets the duration for which claims are cached.
36+
/// </summary>
37+
/// <value>
38+
/// The time span for caching claims. If <c>null</c>, claims are not cached.
39+
/// </value>
40+
public TimeSpan? CacheTime { get; set; }
3341
}
3442

src/HashGate.AspNetCore/IRequestLimitProvider.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace HashGate.AspNetCore;
55
/// </summary>
66
/// <remarks>
77
/// When registered via <see cref="RequestLimitExtensions.AddHmacRateLimiter{TProvider}"/>,
8-
/// the middleware calls <see cref="Get"/> on every request to determine the effective limit
8+
/// the middleware calls <see cref="GetAsync"/> on every request to determine the effective limit
99
/// for the requesting client. Return <see langword="null"/> to fall back to the defaults
1010
/// configured on <see cref="RequestLimitOptions"/>.
1111
/// </remarks>
@@ -16,5 +16,6 @@ public interface IRequestLimitProvider
1616
/// client should use the defaults from <see cref="RequestLimitOptions"/>.
1717
/// </summary>
1818
/// <param name="client">The client identifier extracted from the HMAC Authorization header.</param>
19-
RequestLimit? Get(string client);
19+
/// <param name="cancellationToken">A token to cancel the asynchronous operation.</param>
20+
Task<RequestLimit?> GetAsync(string client, CancellationToken cancellationToken = default);
2021
}

src/HashGate.AspNetCore/RequestLimitExtensions.cs

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ public static class RequestLimitExtensions
1919
/// </summary>
2020
public const string Policy = "hmac-client";
2121

22+
// Key used to stash the resolved policy name in HttpContext.Items so OnRejectedAsync
23+
// can look up the correct named RequestLimitOptions without knowing the policy name.
24+
private static readonly object _policyItemKey = new();
25+
2226
/// <summary>Returns the resolved rate limiting policy name, lowercased.</summary>
2327
/// <param name="policyName">
2428
/// Optional suffix that scopes the policy name, enabling multiple independent limiters in
@@ -60,7 +64,7 @@ public static IServiceCollection AddHmacRateLimiter(
6064
/// </summary>
6165
/// <remarks>
6266
/// <typeparamref name="TProvider"/> is registered as a <see cref="IRequestLimitProvider"/> service.
63-
/// On each request the provider's <see cref="IRequestLimitProvider.Get"/> is called; returning
67+
/// On each request the provider's <see cref="IRequestLimitProvider.GetAsync"/> is called; returning
6468
/// <see langword="null"/> falls back to <see cref="RequestLimitOptions.RequestsPerPeriod"/> and
6569
/// <see cref="RequestLimitOptions.BurstFactor"/>. Use <see cref="RequestLimitProvider"/> for
6670
/// configuration-backed per-client limits.
@@ -81,7 +85,9 @@ public static IServiceCollection AddHmacRateLimiter<TProvider>(
8185
Action<RequestLimitOptions>? configure = null)
8286
where TProvider : class, IRequestLimitProvider
8387
{
84-
services.TryAdd(ServiceDescriptor.Describe(typeof(IRequestLimitProvider), typeof(TProvider), lifetime));
88+
var policy = PolicyName(policyName);
89+
services.TryAdd(new ServiceDescriptor(typeof(IRequestLimitProvider), policy, typeof(TProvider), lifetime));
90+
8591
return AddHmacRateLimiterCore(services, policyName, configure);
8692
}
8793

@@ -108,16 +114,16 @@ private static IServiceCollection AddHmacRateLimiterCore(
108114
{
109115
var policy = PolicyName(policyName);
110116

111-
services.AddOptions<RequestLimitOptions>();
117+
services.AddOptions<RequestLimitOptions>(policy);
112118
if (configure != null)
113-
services.Configure(configure);
119+
services.Configure(policy, configure);
114120

115121
services.AddRateLimiter(options =>
116122
{
117123
options.RejectionStatusCode = StatusCodes.Status429TooManyRequests;
118124
options.OnRejected = OnRejectedAsync;
119125

120-
options.AddPolicy(policy, Partition);
126+
options.AddPolicy(policy, httpContext => Partition(httpContext, policy));
121127
});
122128

123129
return services;
@@ -127,7 +133,9 @@ private static async ValueTask OnRejectedAsync(OnRejectedContext context, Cancel
127133
{
128134
var httpContext = context.HttpContext;
129135

130-
var opts = httpContext.RequestServices.GetRequiredService<IOptions<RequestLimitOptions>>().Value;
136+
// Look up the policy name and options to determine the appropriate Retry-After value for this request.
137+
var policyName = httpContext.Items[_policyItemKey] as string ?? Options.DefaultName;
138+
var opts = httpContext.RequestServices.GetRequiredService<IOptionsMonitor<RequestLimitOptions>>().Get(policyName);
131139

132140
// Prefer the lease's own retry hint; token buckets replenish continuously so the lease
133141
// knows exactly when tokens will be available. Fall back to a short window within the period.
@@ -142,10 +150,14 @@ private static async ValueTask OnRejectedAsync(OnRejectedContext context, Cancel
142150
await httpContext.Response.WriteAsync($"Rate limit exceeded. Retry after {retrySeconds}s.", token);
143151
}
144152

145-
private static RateLimitPartition<string> Partition(HttpContext httpContext)
153+
private static RateLimitPartition<string> Partition(HttpContext httpContext, string policy)
146154
{
147155
var opts = httpContext.RequestServices
148-
.GetRequiredService<IOptions<RequestLimitOptions>>().Value;
156+
.GetRequiredService<IOptionsMonitor<RequestLimitOptions>>()
157+
.Get(policy);
158+
159+
// Stash the policy name so OnRejectedAsync can look up the same named options.
160+
httpContext.Items[_policyItemKey] = policy;
149161

150162
var authorizationHeader = httpContext.Request.Headers.Authorization.ToString();
151163

@@ -161,9 +173,16 @@ private static RateLimitPartition<string> Partition(HttpContext httpContext)
161173
var endpoint = opts.EndpointSelector(httpContext).ToLowerInvariant();
162174

163175
// Per-client override: provider returns null → use options defaults.
164-
// GetService returns null when no IRequestLimitProvider is registered (non-generic overload).
165-
var provider = httpContext.RequestServices.GetService<IRequestLimitProvider>();
166-
var limit = provider?.Get(client) ?? new RequestLimit(opts.RequestsPerPeriod, opts.BurstFactor);
176+
// Keyed lookup targets the provider registered for this specific policy;
177+
// non-keyed fallback preserves backward compat if no keyed registration exists.
178+
var provider = httpContext.RequestServices.GetKeyedService<IRequestLimitProvider>(policy)
179+
?? httpContext.RequestServices.GetService<IRequestLimitProvider>();
180+
181+
// default to options if provider doesn't exist
182+
// ASP.NET Core's partition callback is synchronous; GetAwaiter().GetResult() is safe here
183+
// because ASP.NET Core has no SynchronizationContext.
184+
var limit = provider?.GetAsync(client, httpContext.RequestAborted).GetAwaiter().GetResult()
185+
?? new RequestLimit(opts.RequestsPerPeriod, opts.BurstFactor);
167186

168187
// Include a content-derived version so that limit changes in configuration
169188
// cause new partition keys — and thus fresh token buckets — rather than

src/HashGate.AspNetCore/RequestLimitProvider.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ public RequestLimitProvider(
5252
/// <see cref="RequestLimitOptions"/> rather than being treated as zero.
5353
/// </summary>
5454
/// <param name="client">The client identifier extracted from the HMAC Authorization header.</param>
55-
public RequestLimit? Get(string client)
55+
/// <param name="cancellationToken">A token to cancel the asynchronous operation.</param>
56+
public Task<RequestLimit?> GetAsync(string client, CancellationToken cancellationToken = default)
5657
{
5758
var opts = _options.CurrentValue;
5859

@@ -66,14 +67,14 @@ public RequestLimitProvider(
6667
"Client '{Client}' not found in rate limit configuration section '{SectionName}'.",
6768
client, opts.SectionName);
6869

69-
return null;
70+
return Task.FromResult<RequestLimit?>(null);
7071
}
7172

7273
// Read each field independently so partial configuration is valid.
7374
// A missing field (null) falls back to the option default rather than 0.
7475
var rpp = clientSection.GetValue<int?>("RequestsPerPeriod") ?? opts.RequestsPerPeriod;
7576
var bf = clientSection.GetValue<int?>("BurstFactor") ?? opts.BurstFactor;
7677

77-
return new RequestLimit(RequestsPerPeriod: rpp, BurstFactor: bf);
78+
return Task.FromResult<RequestLimit?>(new RequestLimit(RequestsPerPeriod: rpp, BurstFactor: bf));
7879
}
7980
}

test/HashGate.AspNetCore.Tests/RequestLimitProviderTests.cs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ private static RequestLimitProvider CreateProvider(
2929
// -----------------------------------------------------------------------
3030

3131
[Fact]
32-
public void Get_WhenClientNotInConfiguration_ReturnsNull()
32+
public async Task GetAsync_WhenClientNotInConfiguration_ReturnsNull()
3333
{
3434
var provider = CreateProvider();
3535

36-
var result = provider.Get("unknown-client");
36+
var result = await provider.GetAsync("unknown-client", TestContext.Current.CancellationToken);
3737

3838
Assert.Null(result);
3939
}
@@ -43,22 +43,22 @@ public void Get_WhenClientNotInConfiguration_ReturnsNull()
4343
// -----------------------------------------------------------------------
4444

4545
[Fact]
46-
public void Get_WhenClientFullyConfigured_ReturnsConfiguredRequestsPerPeriod()
46+
public async Task GetAsync_WhenClientFullyConfigured_ReturnsConfiguredRequestsPerPeriod()
4747
{
4848
var provider = CreateProvider();
4949

50-
var result = provider.Get("full-client");
50+
var result = await provider.GetAsync("full-client", TestContext.Current.CancellationToken);
5151

5252
Assert.NotNull(result);
5353
Assert.Equal(100, result.Value.RequestsPerPeriod);
5454
}
5555

5656
[Fact]
57-
public void Get_WhenClientFullyConfigured_ReturnsConfiguredBurstFactor()
57+
public async Task GetAsync_WhenClientFullyConfigured_ReturnsConfiguredBurstFactor()
5858
{
5959
var provider = CreateProvider();
6060

61-
var result = provider.Get("full-client");
61+
var result = await provider.GetAsync("full-client", TestContext.Current.CancellationToken);
6262

6363
Assert.NotNull(result);
6464
Assert.Equal(3, result.Value.BurstFactor);
@@ -69,27 +69,27 @@ public void Get_WhenClientFullyConfigured_ReturnsConfiguredBurstFactor()
6969
// -----------------------------------------------------------------------
7070

7171
[Fact]
72-
public void Get_WhenBurstFactorNotConfigured_FallsBackToOptionsDefault()
72+
public async Task GetAsync_WhenBurstFactorNotConfigured_FallsBackToOptionsDefault()
7373
{
7474
const int defaultBurstFactor = 7;
7575
var provider = CreateProvider(configure: o => o.BurstFactor = defaultBurstFactor);
7676

7777
// rpp-only-client has RequestsPerPeriod=50 in config but no BurstFactor
78-
var result = provider.Get("rpp-only-client");
78+
var result = await provider.GetAsync("rpp-only-client", TestContext.Current.CancellationToken);
7979

8080
Assert.NotNull(result);
8181
Assert.Equal(50, result.Value.RequestsPerPeriod);
8282
Assert.Equal(defaultBurstFactor, result.Value.BurstFactor); // fell back
8383
}
8484

8585
[Fact]
86-
public void Get_WhenRequestsPerPeriodNotConfigured_FallsBackToOptionsDefault()
86+
public async Task GetAsync_WhenRequestsPerPeriodNotConfigured_FallsBackToOptionsDefault()
8787
{
8888
const int defaultRpp = 99;
8989
var provider = CreateProvider(configure: o => o.RequestsPerPeriod = defaultRpp);
9090

9191
// bf-only-client has BurstFactor=5 in config but no RequestsPerPeriod
92-
var result = provider.Get("bf-only-client");
92+
var result = await provider.GetAsync("bf-only-client", TestContext.Current.CancellationToken);
9393

9494
Assert.NotNull(result);
9595
Assert.Equal(defaultRpp, result.Value.RequestsPerPeriod); // fell back
@@ -101,23 +101,23 @@ public void Get_WhenRequestsPerPeriodNotConfigured_FallsBackToOptionsDefault()
101101
// -----------------------------------------------------------------------
102102

103103
[Fact]
104-
public void Get_WithCustomSectionName_ReadsFromCorrectSection()
104+
public async Task GetAsync_WithCustomSectionName_ReadsFromCorrectSection()
105105
{
106106
var provider = CreateProvider(configure: o => o.SectionName = "CustomRateLimits");
107107

108-
var result = provider.Get("custom-client");
108+
var result = await provider.GetAsync("custom-client", TestContext.Current.CancellationToken);
109109

110110
Assert.NotNull(result);
111111
Assert.Equal(new RequestLimit(RequestsPerPeriod: 200, BurstFactor: 4), result.Value);
112112
}
113113

114114
[Fact]
115-
public void Get_WithCustomSectionName_WhenClientNotInSection_ReturnsNull()
115+
public async Task GetAsync_WithCustomSectionName_WhenClientNotInSection_ReturnsNull()
116116
{
117117
var provider = CreateProvider(configure: o => o.SectionName = "CustomRateLimits");
118118

119119
// full-client exists under HmacRateLimits but not CustomRateLimits
120-
var result = provider.Get("full-client");
120+
var result = await provider.GetAsync("full-client", TestContext.Current.CancellationToken);
121121

122122
Assert.Null(result);
123123
}

test/HashGate.Integration.Tests/Fixtures/TestApplicationFactory.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ private static void ConfigureApp(IApplicationBuilder app)
7575
});
7676
}
7777

78-
private static async Task<IResult> EchoPostAsync(HttpContext ctx)
78+
private static async Task<IResult> EchoPostAsync(HttpRequest request)
7979
{
80-
using var reader = new StreamReader(ctx.Request.Body, Encoding.UTF8);
80+
using var reader = new StreamReader(request.Body, Encoding.UTF8);
8181
var body = await reader.ReadToEndAsync();
8282
return Results.Ok(new { body });
8383
}

0 commit comments

Comments
 (0)