Skip to content

Commit

Permalink
Add IP Extractor Middleware
Browse files Browse the repository at this point in the history
Closes #92
  • Loading branch information
AlexMacocian committed Aug 13, 2024
1 parent 09104af commit f94df5d
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 5 deletions.
3 changes: 1 addition & 2 deletions GuildWarsPartySearch/Endpoints/LiveFeed.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ await this.SendMessage(new PartySearchList

public override async Task SocketAccepted(CancellationToken cancellationToken)
{
var ipAddress = this.Context?.Connection.RemoteIpAddress?.ToString();
var ipAddress = this.Context?.GetClientIP();
var scopedLogger = this.logger.CreateScopedLogger(nameof(this.SocketAccepted), ipAddress ?? string.Empty);
if (!await this.liveFeedService.AddClient(this.WebSocket!, ipAddress, this.Context?.GetPermissionLevel() ?? Models.PermissionLevel.None, cancellationToken))
{
Expand All @@ -49,7 +49,6 @@ public override async Task SocketAccepted(CancellationToken cancellationToken)
}

scopedLogger.LogDebug("Client accepted to livefeed");

scopedLogger.LogDebug("Sending all party searches");
var updates = await this.partySearchService.GetAllPartySearches(cancellationToken);
await this.SendMessage(new PartySearchList { Searches = updates }, cancellationToken);
Expand Down
19 changes: 19 additions & 0 deletions GuildWarsPartySearch/Extensions/HttpContextExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace GuildWarsPartySearch.Server.Extensions;
public static class HttpContextExtensions
{
private const string PermissionLevelKey = "PermissionLevel";
private const string ClientIPKey = "ClientIP";

public static void SetPermissionLevel(this HttpContext context, PermissionLevel permissionLevel)
{
Expand All @@ -24,4 +25,22 @@ public static PermissionLevel GetPermissionLevel(this HttpContext context)

return permissionLevel;
}

public static void SetClientIP(this HttpContext context, string ip)
{
context.ThrowIfNull()
.Items.Add(ClientIPKey, ip);
}

public static string GetClientIP(this HttpContext context)
{
context.ThrowIfNull();
if (!context.Items.TryGetValue(ClientIPKey, out var ip) ||
ip is not string ipStr)
{
throw new InvalidOperationException("Unable to extract IP from context");
}

return ipStr;
}
}
3 changes: 2 additions & 1 deletion GuildWarsPartySearch/Launch/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ private static async Task Main()
}

var app = builder.Build();
app.UseMiddleware<PermissioningMiddleware>()
app.UseMiddleware<IPExtractingMiddleware>()
.UseMiddleware<PermissioningMiddleware>()
.UseSwagger()
.UseWebSockets()
.UseRouting()
Expand Down
3 changes: 2 additions & 1 deletion GuildWarsPartySearch/Launch/ServerConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public static ILoggingBuilder SetupLogging(this ILoggingBuilder builder)
{
builder.ThrowIfNull()
.ClearProviders()
.AddConsole();
.AddConsole(o => o.TimestampFormat = "[yyyy-MM-dd HH:mm:ss] ");

return builder;
}
Expand Down Expand Up @@ -84,6 +84,7 @@ public static IServiceCollection SetupServices(this IServiceCollection services)
services.AddSingleton<IPartySearchDatabase, PartySearchSqliteDatabase>();
services.AddSingleton<IBotHistoryDatabase, BotHistorySqliteDatabase>();
services.AddSingleton<IApiKeyDatabase, ApiKeySqliteDatabase>();
services.AddScoped<IPExtractingMiddleware>();
services.AddScoped<PermissioningMiddleware>();
services.AddScoped<UserAgentRequired>();
services.AddScoped<AdminPermissionRequired>();
Expand Down
52 changes: 52 additions & 0 deletions GuildWarsPartySearch/Middleware/IPExtractingMiddleware.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using GuildWarsPartySearch.Server.Extensions;
using System.Core.Extensions;
using System.Extensions;

namespace GuildWarsPartySearch.Server.Middleware;

public sealed class IPExtractingMiddleware : IMiddleware
{
private const string XForwardedForHeaderKey = "X-Forwarded-For";
private const string CFConnectingIPHeaderKey = "CF-Connecting-IP";

private readonly ILogger<IPExtractingMiddleware> logger;

public IPExtractingMiddleware(
ILogger<IPExtractingMiddleware> logger)
{
this.logger = logger.ThrowIfNull();
}

public async Task InvokeAsync(HttpContext context, RequestDelegate next)
{
var address = context.Connection.RemoteIpAddress?.ToString();
var scopedLogger = this.logger.CreateScopedLogger(nameof(this.InvokeAsync), address ?? string.Empty);
scopedLogger.LogDebug($"Received request");
if (context.Request.Headers.TryGetValue(XForwardedForHeaderKey, out var xForwardedForValues))
{
scopedLogger.LogDebug($"X-Forwarded-For {string.Join(',', xForwardedForValues.Select(s => s))}");
}

if (xForwardedForValues.FirstOrDefault() is string xForwardedIpAddress)
{
context.SetClientIP(xForwardedIpAddress);
await next(context);
return;
}

if (context.Request.Headers.TryGetValue(CFConnectingIPHeaderKey, out var cfConnectingIpValues))
{
scopedLogger.LogDebug($"CF-Connecting-IP {string.Join(',', cfConnectingIpValues.Select(s => s))}");
}

if (cfConnectingIpValues.FirstOrDefault() is string xCfIpAddress)
{
context.SetClientIP(xCfIpAddress);
await next(context);
return;
}

context.SetClientIP(address ?? throw new InvalidOperationException("Unable to extract client IP address"));
await next(context);
}
}
1 change: 0 additions & 1 deletion GuildWarsPartySearch/Services/Feed/LiveFeedService.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using GuildWarsPartySearch.Server.Models;
using GuildWarsPartySearch.Server.Models.Endpoints;
using GuildWarsPartySearch.Server.Services.Database;
using System.Core.Extensions;
using System.Extensions;
using System.Net.WebSockets;
Expand Down

0 comments on commit f94df5d

Please sign in to comment.