AccountRegistrationService now uses Redis to track registration throttling instead of spawning a bunch of async tasks
This commit is contained in:
@@ -8,21 +8,11 @@ using MareSynchronosShared.Utils.Configuration;
|
|||||||
using Microsoft.EntityFrameworkCore;
|
using Microsoft.EntityFrameworkCore;
|
||||||
using System.Text.RegularExpressions;
|
using System.Text.RegularExpressions;
|
||||||
using MareSynchronosShared.Models;
|
using MareSynchronosShared.Models;
|
||||||
|
using StackExchange.Redis;
|
||||||
|
using StackExchange.Redis.Extensions.Core.Abstractions;
|
||||||
|
|
||||||
namespace MareSynchronosAuthService.Services;
|
namespace MareSynchronosAuthService.Services;
|
||||||
|
|
||||||
internal record IpRegistrationCount
|
|
||||||
{
|
|
||||||
private int count = 1;
|
|
||||||
public int Count => count;
|
|
||||||
public Task ResetTask { get; set; }
|
|
||||||
public CancellationTokenSource ResetTaskCts { get; set; }
|
|
||||||
public void IncreaseCount()
|
|
||||||
{
|
|
||||||
Interlocked.Increment(ref count);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public class AccountRegistrationService
|
public class AccountRegistrationService
|
||||||
{
|
{
|
||||||
private readonly MareMetrics _metrics;
|
private readonly MareMetrics _metrics;
|
||||||
@@ -30,52 +20,35 @@ public class AccountRegistrationService
|
|||||||
private readonly IServiceScopeFactory _serviceScopeFactory;
|
private readonly IServiceScopeFactory _serviceScopeFactory;
|
||||||
private readonly IConfigurationService<AuthServiceConfiguration> _configurationService;
|
private readonly IConfigurationService<AuthServiceConfiguration> _configurationService;
|
||||||
private readonly ILogger<AccountRegistrationService> _logger;
|
private readonly ILogger<AccountRegistrationService> _logger;
|
||||||
private readonly ConcurrentDictionary<string, IpRegistrationCount> _registrationsPerIp = new(StringComparer.Ordinal);
|
private readonly IRedisDatabase _redis;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public AccountRegistrationService(MareMetrics metrics, MareDbContext mareDbContext,
|
public AccountRegistrationService(MareMetrics metrics, MareDbContext mareDbContext,
|
||||||
IServiceScopeFactory serviceScopeFactory, IConfigurationService<AuthServiceConfiguration> configuration,
|
IServiceScopeFactory serviceScopeFactory, IConfigurationService<AuthServiceConfiguration> configuration,
|
||||||
ILogger<AccountRegistrationService> logger)
|
ILogger<AccountRegistrationService> logger, IRedisDatabase redisDb)
|
||||||
{
|
{
|
||||||
_mareDbContext = mareDbContext;
|
_mareDbContext = mareDbContext;
|
||||||
_logger = logger;
|
_logger = logger;
|
||||||
_configurationService = configuration;
|
_configurationService = configuration;
|
||||||
_metrics = metrics;
|
_metrics = metrics;
|
||||||
_serviceScopeFactory = serviceScopeFactory;
|
_serviceScopeFactory = serviceScopeFactory;
|
||||||
|
_redis = redisDb;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<RegisterReplyV2Dto> RegisterAccountAsync(string ua, string ip, string hashedSecretKey)
|
public async Task<RegisterReplyV2Dto> RegisterAccountAsync(string ua, string ip, string hashedSecretKey)
|
||||||
{
|
{
|
||||||
var reply = new RegisterReplyV2Dto();
|
var reply = new RegisterReplyV2Dto();
|
||||||
|
|
||||||
if (!ua.StartsWith("MareSynchronos/", StringComparison.Ordinal))
|
if (string.IsNullOrEmpty(ua) || !ua.StartsWith("MareSynchronos/", StringComparison.Ordinal))
|
||||||
{
|
{
|
||||||
reply.ErrorMessage = "User-Agent not allowed";
|
reply.ErrorMessage = "User-Agent not allowed";
|
||||||
return reply;
|
return reply;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (_registrationsPerIp.TryGetValue(ip, out var registrationCount)
|
var registrationsByIp = await _redis.GetAsync<int>("IPREG:" + ip).ConfigureAwait(false);
|
||||||
&& registrationCount.Count >= _configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpLimit), 3))
|
if (registrationsByIp >= _configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpLimit), 3))
|
||||||
{
|
{
|
||||||
_logger.LogWarning("Rejecting {ip} for registration spam", ip);
|
|
||||||
|
|
||||||
if (registrationCount.ResetTask == null)
|
|
||||||
{
|
|
||||||
registrationCount.ResetTaskCts = new CancellationTokenSource();
|
|
||||||
|
|
||||||
if (registrationCount.ResetTaskCts != null)
|
|
||||||
registrationCount.ResetTaskCts.Cancel();
|
|
||||||
|
|
||||||
registrationCount.ResetTask = Task.Run(async () =>
|
|
||||||
{
|
|
||||||
await Task.Delay(TimeSpan.FromMinutes(_configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpDurationInMinutes), 10))).ConfigureAwait(false);
|
|
||||||
|
|
||||||
}).ContinueWith((t) =>
|
|
||||||
{
|
|
||||||
_registrationsPerIp.Remove(ip, out _);
|
|
||||||
}, registrationCount.ResetTaskCts.Token);
|
|
||||||
}
|
|
||||||
reply.ErrorMessage = "Too many registrations from this IP. Please try again later.";
|
reply.ErrorMessage = "Too many registrations from this IP. Please try again later.";
|
||||||
return reply;
|
return reply;
|
||||||
}
|
}
|
||||||
@@ -91,12 +64,6 @@ public class AccountRegistrationService
|
|||||||
hasValidUid = true;
|
hasValidUid = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// make the first registered user on the service to admin
|
|
||||||
if (!await _mareDbContext.Users.AnyAsync().ConfigureAwait(false))
|
|
||||||
{
|
|
||||||
user.IsAdmin = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
user.LastLoggedIn = DateTime.UtcNow;
|
user.LastLoggedIn = DateTime.UtcNow;
|
||||||
|
|
||||||
var auth = new Auth()
|
var auth = new Auth()
|
||||||
@@ -115,38 +82,14 @@ public class AccountRegistrationService
|
|||||||
reply.Success = true;
|
reply.Success = true;
|
||||||
reply.UID = user.UID;
|
reply.UID = user.UID;
|
||||||
|
|
||||||
RecordIpRegistration(ip);
|
|
||||||
|
await _redis.Database.StringIncrementAsync($"IPREG:{ip}").ConfigureAwait(false);
|
||||||
|
// Naive implementation, but should be good enough. A true sliding window *probably* isn't necessary.
|
||||||
|
await _redis.Database.KeyExpireAsync($"IPREG:{ip}", TimeSpan.
|
||||||
|
FromMinutes(_configurationService.GetValueOrDefault(nameof(
|
||||||
|
AuthServiceConfiguration.RegisterIpDurationInMinutes), 60))).
|
||||||
|
ConfigureAwait(false);
|
||||||
|
|
||||||
return reply;
|
return reply;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void RecordIpRegistration(string ip)
|
|
||||||
{
|
|
||||||
var whitelisted = _configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.WhitelistedIps), new List<string>());
|
|
||||||
if (!whitelisted.Any(w => ip.Contains(w, StringComparison.OrdinalIgnoreCase)))
|
|
||||||
{
|
|
||||||
if (_registrationsPerIp.TryGetValue(ip, out var count))
|
|
||||||
{
|
|
||||||
count.IncreaseCount();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
count = _registrationsPerIp[ip] = new IpRegistrationCount();
|
|
||||||
|
|
||||||
if (count.ResetTaskCts != null)
|
|
||||||
count.ResetTaskCts.Cancel();
|
|
||||||
|
|
||||||
count.ResetTaskCts = new CancellationTokenSource();
|
|
||||||
|
|
||||||
count.ResetTask = Task.Run(async () =>
|
|
||||||
{
|
|
||||||
await Task.Delay(TimeSpan.FromMinutes(_configurationService.GetValueOrDefault(nameof(AuthServiceConfiguration.RegisterIpDurationInMinutes), 10))).ConfigureAwait(false);
|
|
||||||
|
|
||||||
}).ContinueWith((t) =>
|
|
||||||
{
|
|
||||||
_registrationsPerIp.Remove(ip, out _);
|
|
||||||
}, count.ResetTaskCts.Token);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user