diff --git a/WebApiThrottle.Tests/ThrottlingMiddlewareTests.cs b/WebApiThrottle.Tests/ThrottlingMiddlewareTests.cs new file mode 100644 index 0000000..a149686 --- /dev/null +++ b/WebApiThrottle.Tests/ThrottlingMiddlewareTests.cs @@ -0,0 +1,78 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.Owin; +using Moq; +using Xunit; + +namespace WebApiThrottle.Tests +{ + public class ThrottlingMiddlewareTests + { + private static IOwinContext CreateMockContext() + { + var context = Mock.Of(); + + Mock.Get(context).SetupGet(x => x.Request).Returns(Mock.Of()); + Mock.Get(context.Request).SetupAllProperties(); + Mock.Get(context.Request).SetupGet(x => x.Headers).Returns(Mock.Of()); + Mock.Get(context.Request.Headers).SetupGet(x => x.Keys).Returns(new List()); + context.Request.RemoteIpAddress = "127.0.0.1"; + Mock.Get(context.Request).SetupGet(x => x.Uri).Returns(new Uri($"http://{context.Request.RemoteIpAddress}")); + + Mock.Get(context).SetupGet(x => x.Response).Returns(Mock.Of()); + Mock.Get(context.Response).SetupAllProperties(); + Mock.Get(context.Response).SetupGet(x => x.Headers).Returns(Mock.Of()); + Mock.Get(context.Response.Headers).Setup(x => x.Add("Retry-After", It.IsAny())); + context.Response.StatusCode = 200; + + return context; + } + + private static ThrottlingMiddleware CreateThrottlingMiddleware() + { + return new ThrottlingMiddleware( + new DummyMiddleware(null), + new ThrottlePolicy(1) {IpThrottling = true}, + new PolicyMemoryCacheRepository(), + new MemoryCacheRepository(), + null, + null); + } + + + [Fact] + public void When_RateIsExceeded_Should_SetStatusCodeSoItsAvailableToMiddlewareFurtherDownTheStack() + { + var context = CreateMockContext(); + + var throttlingMiddleware = CreateThrottlingMiddleware(); + + throttlingMiddleware.Invoke(context).Wait(); + throttlingMiddleware.Invoke(context).Wait(); + + Assert.Equal(429, context.Response.StatusCode); + } + + [Fact] + public void When_RateIsNotExceeded_Should_NotSetStatusCode() + { + var context = CreateMockContext(); + + CreateThrottlingMiddleware().Invoke(context).Wait(); + + Assert.Equal(200, context.Response.StatusCode); + } + } + + internal class DummyMiddleware : OwinMiddleware + { + public DummyMiddleware(OwinMiddleware next) : base(next) + { + } + + public override async Task Invoke(IOwinContext context) + { + } + } +} \ No newline at end of file diff --git a/WebApiThrottle.Tests/WebApiThrottle.Tests.csproj b/WebApiThrottle.Tests/WebApiThrottle.Tests.csproj index c54f60d..535c5e0 100644 --- a/WebApiThrottle.Tests/WebApiThrottle.Tests.csproj +++ b/WebApiThrottle.Tests/WebApiThrottle.Tests.csproj @@ -30,6 +30,18 @@ 4 + + ..\packages\Castle.Core.4.0.0\lib\net45\Castle.Core.dll + + + ..\packages\Microsoft.Owin.3.0.1\lib\net45\Microsoft.Owin.dll + + + ..\packages\Moq.4.7.1\lib\net45\Moq.dll + + + ..\packages\Owin.1.0\lib\net40\Owin.dll + @@ -58,6 +70,7 @@ + diff --git a/WebApiThrottle.Tests/packages.config b/WebApiThrottle.Tests/packages.config index 44a585a..161e84a 100644 --- a/WebApiThrottle.Tests/packages.config +++ b/WebApiThrottle.Tests/packages.config @@ -1,5 +1,9 @@  + + + + diff --git a/WebApiThrottle/ThrottlingMiddleware.cs b/WebApiThrottle/ThrottlingMiddleware.cs index a9be672..c2eba2e 100644 --- a/WebApiThrottle/ThrottlingMiddleware.cs +++ b/WebApiThrottle/ThrottlingMiddleware.cs @@ -196,13 +196,10 @@ public override async Task Invoke(IOwinContext context) : "API calls quota exceeded! maximum admitted {0} per {1}."; // break execution - response.OnSendingHeaders(state => - { - var resp = (OwinResponse)state; - resp.Headers.Add("Retry-After", new string[] { core.RetryAfterFrom(throttleCounter.Timestamp, rateLimitPeriod) }); - resp.StatusCode = (int)QuotaExceededResponseCode; - resp.ReasonPhrase = string.Format(message, rateLimit, rateLimitPeriod); - }, response); + response.StatusCode = (int)QuotaExceededResponseCode; + response.ReasonPhrase = string.Format(message, rateLimit, rateLimitPeriod); + + response.Headers.Add("Retry-After", new[] { core.RetryAfterFrom(throttleCounter.Timestamp, rateLimitPeriod) }); return; }