AccountRegistrationService now uses Redis to track registration throttling instead of spawning a bunch of async tasks

This commit is contained in:
2025-09-14 23:40:58 +01:00
parent 3aa4e62f28
commit a741f25bde

View File

@@ -8,21 +8,11 @@ using MareSynchronosShared.Utils.Configuration;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using MareSynchronosShared.Models; using MareSynchronosShared.Models;
using StackExchange.Redis;
using StackExchange.Redis.Extensions.Core.Abstractions;
namespace MareSynchronosAuthService.Services; namespace MareSynchronosAuthService.Services;
internal record IpRegistrationCount
{
private int count = 1;
public int Count => count;
public Task ResetTask { get; set; }
public CancellationTokenSource ResetTaskCts { get; set; }
public void IncreaseCount()
{
Interlocked.Increment(ref count);
}
}
public class AccountRegistrationService public class AccountRegistrationService
{ {
private readonly MareMetrics _metrics; private readonly MareMetrics _metrics;
@@ -30,52 +20,35 @@ public class AccountRegistrationService
private readonly IServiceScopeFactory _serviceScopeFactory; private readonly IServiceScopeFactory _serviceScopeFactory;
private readonly IConfigurationService<AuthServiceConfiguration> _configurationService; private readonly IConfigurationService<AuthServiceConfiguration> _configurationService;
private readonly ILogger<AccountRegistrationService> _logger; private readonly ILogger<AccountRegistrationService> _logger;
private readonly ConcurrentDictionary<string, IpRegistrationCount> _registrationsPerIp = new(StringComparer.Ordinal); private readonly IRedisDatabase _redis;
public AccountRegistrationService(MareMetrics metrics, MareDbContext mareDbContext, public AccountRegistrationService(MareMetrics metrics, MareDbContext mareDbContext,
IServiceScopeFactory serviceScopeFactory, IConfigurationService<AuthServiceConfiguration> configuration, IServiceScopeFactory serviceScopeFactory, IConfigurationService<AuthServiceConfiguration> configuration,
ILogger<AccountRegistrationService> logger) ILogger<AccountRegistrationService> logger, IRedisDatabase redisDb)
{ {
_mareDbContext = mareDbContext; _mareDbContext = mareDbContext;
_logger = logger; _logger = logger;
_configurationService = configuration; _configurationService = configuration;
_metrics = metrics; _metrics = metrics;
_serviceScopeFactory = serviceScopeFactory; _serviceScopeFactory = serviceScopeFactory;
_redis = redisDb;
} }
public async Task<RegisterReplyV2Dto> RegisterAccountAsync(string ua, string ip, string hashedSecretKey) public async Task<RegisterReplyV2Dto> RegisterAccountAsync(string ua, string ip, string hashedSecretKey)
{ {
var reply = new RegisterReplyV2Dto(); var reply = new RegisterReplyV2Dto();
if (!ua.StartsWith("MareSynchronos/", StringComparison.Ordinal)) if (string.IsNullOrEmpty(ua) || !ua.StartsWith("MareSynchronos/", StringComparison.Ordinal))
{ {
reply.ErrorMessage = "User-Agent not allowed"; reply.ErrorMessage = "User-Agent not allowed";
return reply; return reply;
} }
if (_registrationsPerIp.TryGetValue(ip, out var registrationCount) var registrationsByIp = await _redis.GetAsync<int>("IPREG:" + ip).ConfigureAwait(false);
&& registrationCount.Count >= _configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpLimit), 3)) if (registrationsByIp >= _configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpLimit), 3))
{ {
_logger.LogWarning("Rejecting {ip} for registration spam", ip);
if (registrationCount.ResetTask == null)
{
registrationCount.ResetTaskCts = new CancellationTokenSource();
if (registrationCount.ResetTaskCts != null)
registrationCount.ResetTaskCts.Cancel();
registrationCount.ResetTask = Task.Run(async () =>
{
await Task.Delay(TimeSpan.FromMinutes(_configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpDurationInMinutes), 10))).ConfigureAwait(false);
}).ContinueWith((t) =>
{
_registrationsPerIp.Remove(ip, out _);
}, registrationCount.ResetTaskCts.Token);
}
reply.ErrorMessage = "Too many registrations from this IP. Please try again later."; reply.ErrorMessage = "Too many registrations from this IP. Please try again later.";
return reply; return reply;
} }
@@ -91,12 +64,6 @@ public class AccountRegistrationService
hasValidUid = true; hasValidUid = true;
} }
// make the first registered user on the service to admin
if (!await _mareDbContext.Users.AnyAsync().ConfigureAwait(false))
{
user.IsAdmin = true;
}
user.LastLoggedIn = DateTime.UtcNow; user.LastLoggedIn = DateTime.UtcNow;
var auth = new Auth() var auth = new Auth()
@@ -115,38 +82,14 @@ public class AccountRegistrationService
reply.Success = true; reply.Success = true;
reply.UID = user.UID; reply.UID = user.UID;
RecordIpRegistration(ip);
await _redis.Database.StringIncrementAsync($"IPREG:{ip}").ConfigureAwait(false);
// Naive implementation, but should be good enough. A true sliding window *probably* isn't necessary.
await _redis.Database.KeyExpireAsync($"IPREG:{ip}", TimeSpan.
FromMinutes(_configurationService.GetValueOrDefault(nameof(
AuthServiceConfiguration.RegisterIpDurationInMinutes), 60))).
ConfigureAwait(false);
return reply; return reply;
} }
private void RecordIpRegistration(string ip)
{
var whitelisted = _configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.WhitelistedIps), new List<string>());
if (!whitelisted.Any(w => ip.Contains(w, StringComparison.OrdinalIgnoreCase)))
{
if (_registrationsPerIp.TryGetValue(ip, out var count))
{
count.IncreaseCount();
}
else
{
count = _registrationsPerIp[ip] = new IpRegistrationCount();
if (count.ResetTaskCts != null)
count.ResetTaskCts.Cancel();
count.ResetTaskCts = new CancellationTokenSource();
count.ResetTask = Task.Run(async () =>
{
await Task.Delay(TimeSpan.FromMinutes(_configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpDurationInMinutes), 10))).ConfigureAwait(false);
}).ContinueWith((t) =>
{
_registrationsPerIp.Remove(ip, out _);
}, count.ResetTaskCts.Token);
}
}
}
} }