From a741f25bde7cf5e2534e0918f6d99e4dca15ca9f Mon Sep 17 00:00:00 2001 From: Eauldane Date: Sun, 14 Sep 2025 23:40:58 +0100 Subject: [PATCH] AccountRegistrationService now uses Redis to track registration throttling instead of spawning a bunch of async tasks --- .../Services/AccountRegistrationService.cs | 87 ++++--------------- 1 file changed, 15 insertions(+), 72 deletions(-) diff --git a/MareSynchronosServer/MareSynchronosAuthService/Services/AccountRegistrationService.cs b/MareSynchronosServer/MareSynchronosAuthService/Services/AccountRegistrationService.cs index 21d8091..8eb811e 100644 --- a/MareSynchronosServer/MareSynchronosAuthService/Services/AccountRegistrationService.cs +++ b/MareSynchronosServer/MareSynchronosAuthService/Services/AccountRegistrationService.cs @@ -8,21 +8,11 @@ using MareSynchronosShared.Utils.Configuration; using Microsoft.EntityFrameworkCore; using System.Text.RegularExpressions; using MareSynchronosShared.Models; +using StackExchange.Redis; +using StackExchange.Redis.Extensions.Core.Abstractions; namespace MareSynchronosAuthService.Services; -internal record IpRegistrationCount -{ - private int count = 1; - public int Count => count; - public Task ResetTask { get; set; } - public CancellationTokenSource ResetTaskCts { get; set; } - public void IncreaseCount() - { - Interlocked.Increment(ref count); - } -} - public class AccountRegistrationService { private readonly MareMetrics _metrics; @@ -30,52 +20,35 @@ public class AccountRegistrationService private readonly IServiceScopeFactory _serviceScopeFactory; private readonly IConfigurationService _configurationService; private readonly ILogger _logger; - private readonly ConcurrentDictionary _registrationsPerIp = new(StringComparer.Ordinal); + private readonly IRedisDatabase _redis; public AccountRegistrationService(MareMetrics metrics, MareDbContext mareDbContext, IServiceScopeFactory serviceScopeFactory, IConfigurationService configuration, - ILogger logger) + ILogger logger, IRedisDatabase redisDb) { _mareDbContext = mareDbContext; _logger = logger; _configurationService = configuration; _metrics = metrics; _serviceScopeFactory = serviceScopeFactory; + _redis = redisDb; } public async Task RegisterAccountAsync(string ua, string ip, string hashedSecretKey) { var reply = new RegisterReplyV2Dto(); - if (!ua.StartsWith("MareSynchronos/", StringComparison.Ordinal)) + if (string.IsNullOrEmpty(ua) || !ua.StartsWith("MareSynchronos/", StringComparison.Ordinal)) { reply.ErrorMessage = "User-Agent not allowed"; return reply; } - if (_registrationsPerIp.TryGetValue(ip, out var registrationCount) - && registrationCount.Count >= _configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpLimit), 3)) + var registrationsByIp = await _redis.GetAsync("IPREG:" + ip).ConfigureAwait(false); + if (registrationsByIp >= _configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpLimit), 3)) { - _logger.LogWarning("Rejecting {ip} for registration spam", ip); - - if (registrationCount.ResetTask == null) - { - registrationCount.ResetTaskCts = new CancellationTokenSource(); - - if (registrationCount.ResetTaskCts != null) - registrationCount.ResetTaskCts.Cancel(); - - registrationCount.ResetTask = Task.Run(async () => - { - await Task.Delay(TimeSpan.FromMinutes(_configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpDurationInMinutes), 10))).ConfigureAwait(false); - - }).ContinueWith((t) => - { - _registrationsPerIp.Remove(ip, out _); - }, registrationCount.ResetTaskCts.Token); - } reply.ErrorMessage = "Too many registrations from this IP. Please try again later."; return reply; } @@ -91,12 +64,6 @@ public class AccountRegistrationService hasValidUid = true; } - // make the first registered user on the service to admin - if (!await _mareDbContext.Users.AnyAsync().ConfigureAwait(false)) - { - user.IsAdmin = true; - } - user.LastLoggedIn = DateTime.UtcNow; var auth = new Auth() @@ -115,38 +82,14 @@ public class AccountRegistrationService reply.Success = true; reply.UID = user.UID; - RecordIpRegistration(ip); + + await _redis.Database.StringIncrementAsync($"IPREG:{ip}").ConfigureAwait(false); + // Naive implementation, but should be good enough. A true sliding window *probably* isn't necessary. + await _redis.Database.KeyExpireAsync($"IPREG:{ip}", TimeSpan. + FromMinutes(_configurationService.GetValueOrDefault(nameof( + AuthServiceConfiguration.RegisterIpDurationInMinutes), 60))). + ConfigureAwait(false); return reply; } - - private void RecordIpRegistration(string ip) - { - var whitelisted = _configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.WhitelistedIps), new List()); - if (!whitelisted.Any(w => ip.Contains(w, StringComparison.OrdinalIgnoreCase))) - { - if (_registrationsPerIp.TryGetValue(ip, out var count)) - { - count.IncreaseCount(); - } - else - { - count = _registrationsPerIp[ip] = new IpRegistrationCount(); - - if (count.ResetTaskCts != null) - count.ResetTaskCts.Cancel(); - - count.ResetTaskCts = new CancellationTokenSource(); - - count.ResetTask = Task.Run(async () => - { - await Task.Delay(TimeSpan.FromMinutes(_configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpDurationInMinutes), 10))).ConfigureAwait(false); - - }).ContinueWith((t) => - { - _registrationsPerIp.Remove(ip, out _); - }, count.ResetTaskCts.Token); - } - } - } }