using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Data.Abstractions;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Logging;
using Timer = System.Timers.Timer;

namespace Data.Helpers
{
    /// <summary>
    /// Caches values computed from a <see cref="DbSet{TEntity}"/> of <typeparamref name="TEntityType"/>.
    /// Each cached value is registered under a string key (optionally qualified by a set of ids)
    /// together with the delegate that produces it, and is refreshed either on demand when it has
    /// expired or periodically by a timer when auto refresh is requested.
    /// </summary>
    /// <typeparam name="TEntityType">Entity type whose set is handed to the getter delegates.</typeparam>
    /// <typeparam name="TReturnType">Type of the cached value.</typeparam>
    public class DataValueCache<TEntityType, TReturnType> : IDataValueCache<TEntityType, TReturnType>, IDisposable
        where TEntityType : class
    {
        private const int DefaultExpireMinutes = 15;

        private readonly ILogger _logger;
        private readonly IDatabaseContextFactory _contextFactory;

        // Outer key: cache name. Inner key: the id string produced by GenerateKeyFromIds.
        // The inner dictionaries are plain Dictionary instances and are only touched under _gate.
        private readonly ConcurrentDictionary<string, Dictionary<string, CacheState<TReturnType>>> _cacheStates = new();

        private readonly string _defaultKey = null;

        // Dedicated lock object guarding the inner dictionaries, the timer list and _disposed.
        private readonly object _gate = new();

        // Every auto-refresh timer that has been started, so that all of them can be released.
        private readonly List<Timer> _timers = new();

        private bool _disposed;

        /// <summary>
        /// Book-keeping for a single cached value.
        /// </summary>
        private class CacheState<TCacheType>
        {
            public string Key { get; set; }
            public DateTime LastRetrieval { get; set; }
            public TimeSpan ExpirationTime { get; set; }
            public Func<DbSet<TEntityType>, IEnumerable<object>, CancellationToken, Task<TCacheType>> Getter { get; set; }
            public TCacheType Value { get; set; }
            public bool IsSet { get; set; }

            // Auto refresh is a per-item choice; keeping it here stops one registration
            // from changing how other items are refreshed.
            public bool AutoRefresh { get; set; }

            // UTC avoids spurious expiry changes around daylight-saving transitions, and comparing
            // the elapsed time (rather than adding the expiration to LastRetrieval) cannot overflow.
            // An entry that has never been retrieved has LastRetrieval == default and is therefore expired.
            public bool IsExpired => ExpirationTime != TimeSpan.MaxValue &&
                                     DateTime.UtcNow - LastRetrieval > ExpirationTime;
        }

        /// <summary>
        /// Creates a cache that logs through <paramref name="logger"/> and opens database
        /// contexts through <paramref name="contextFactory"/>.
        /// </summary>
        public DataValueCache(ILogger<DataValueCache<TEntityType, TReturnType>> logger,
            IDatabaseContextFactory contextFactory)
        {
            _logger = logger;
            _contextFactory = contextFactory;
        }

        /// <summary>
        /// Stops and disposes every auto-refresh timer started by this instance.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases the timers owned by this instance.
        /// </summary>
        /// <param name="disposing">True when called from <see cref="Dispose()"/>.</param>
        protected virtual void Dispose(bool disposing)
        {
            if (!disposing)
            {
                return;
            }

            Timer[] timers;

            lock (_gate)
            {
                if (_disposed)
                {
                    return;
                }

                _disposed = true;
                timers = _timers.ToArray();
                _timers.Clear();
            }

            // Disposed outside the lock so a running Elapsed callback is never blocked on _gate.
            foreach (var timer in timers)
            {
                timer.Stop();
                timer.Dispose();
            }
        }

        /// <summary>
        /// Registers a cache item whose getter does not depend on ids.
        /// </summary>
        /// <param name="getter">Delegate that computes the value from the entity set.</param>
        /// <param name="key">Name the item is registered under.</param>
        /// <param name="expirationTime">How long a value stays fresh; defaults to 15 minutes.</param>
        /// <param name="autoRefresh">When true the value is refreshed by a timer instead of on read.</param>
        public void SetCacheItem(Func<DbSet<TEntityType>, CancellationToken, Task<TReturnType>> getter, string key,
            TimeSpan? expirationTime = null, bool autoRefresh = false)
        {
            SetCacheItem((set, _, token) => getter(set, token), key, null, expirationTime, autoRefresh);
        }

        /// <summary>
        /// Registers a cache item under <paramref name="key"/> and the given <paramref name="ids"/>.
        /// If an item already exists for that key/ids combination the call does nothing.
        /// </summary>
        /// <param name="getter">Delegate that computes the value from the entity set and the ids.</param>
        /// <param name="key">Name the item is registered under.</param>
        /// <param name="ids">Values that qualify the item; when null a single default id is used.</param>
        /// <param name="expirationTime">How long a value stays fresh; defaults to 15 minutes.</param>
        /// <param name="autoRefresh">
        /// When true (and the expiration is not <see cref="TimeSpan.MaxValue"/>) a timer refreshes the
        /// value every expiration interval.
        /// </param>
        /// <exception cref="ObjectDisposedException">The cache has been disposed.</exception>
        public void SetCacheItem(Func<DbSet<TEntityType>, IEnumerable<object>, CancellationToken, Task<TReturnType>> getter,
            string key, IEnumerable<object> ids = null, TimeSpan? expirationTime = null, bool autoRefresh = false)
        {
            ids ??= new[] { _defaultKey };

            var id = GenerateKeyFromIds(ids);
            var state = new CacheState<TReturnType>
            {
                Key = key,
                Getter = getter,
                ExpirationTime = expirationTime ?? TimeSpan.FromMinutes(DefaultExpireMinutes),
                AutoRefresh = autoRefresh
            };

            // One critical section for check + insert + timer registration: a separate
            // "contains" check and "add" would let two concurrent registrations of the same
            // item both pass the check, with the second add throwing on the duplicate key.
            lock (_gate)
            {
                if (_disposed)
                {
                    throw new ObjectDisposedException(GetType().Name);
                }

                var cacheInstance = _cacheStates.GetOrAdd(key,
                    _ => new Dictionary<string, CacheState<TReturnType>>());

                if (!cacheInstance.TryAdd(id, state))
                {
                    return;
                }

                if (!autoRefresh || expirationTime == TimeSpan.MaxValue)
                {
                    return;
                }

                var timer = new Timer(state.ExpirationTime.TotalMilliseconds);
                // RunCacheUpdate catches and logs its own failures, so nothing escapes this handler.
                timer.Elapsed += async (sender, args) => await RunCacheUpdate(state, ids, CancellationToken.None);
                _timers.Add(timer);
                timer.Start();
            }
        }

        /// <summary>
        /// Returns the value registered under <paramref name="keyName"/> without specifying ids.
        /// </summary>
        public Task<TReturnType> GetCacheItem(string keyName, CancellationToken cancellationToken = default) =>
            GetCacheItem(keyName, null, cancellationToken);

        /// <summary>
        /// Returns the cached value for <paramref name="keyName"/> and <paramref name="ids"/>,
        /// refreshing it first when it has never been populated, or when it has expired and is
        /// not refreshed automatically.
        /// </summary>
        /// <param name="keyName">Name the item was registered under.</param>
        /// <param name="ids">Ids the item was registered with; when null the first item stored for the key is used.</param>
        /// <param name="cancellationToken">Token passed to the getter when a refresh is needed.</param>
        /// <exception cref="ArgumentException">No item has been registered under <paramref name="keyName"/>.</exception>
        public async Task<TReturnType> GetCacheItem(string keyName, IEnumerable<object> ids = null,
            CancellationToken cancellationToken = default)
        {
            CacheState<TReturnType> state;

            lock (_gate)
            {
                if (!_cacheStates.TryGetValue(keyName, out var cacheInstance))
                {
                    throw new ArgumentException($"No cache found for key {keyName}", nameof(keyName));
                }

                state = ids is null
                    ? cacheInstance.Values.First()
                    : cacheInstance[GenerateKeyFromIds(ids)];
            }

            // An unset value is always loaded. Expiry is only checked here when auto refresh
            // is off; with auto refresh on, the timer is responsible for keeping the value fresh.
            if (!state.IsSet || (!state.AutoRefresh && state.IsExpired))
            {
                await RunCacheUpdate(state, ids, cancellationToken);
            }

            return state.Value;
        }

        /// <summary>
        /// Runs the getter of <paramref name="state"/> against a new database context and stores
        /// the result. Any exception is logged and swallowed, leaving the previous value in place.
        /// </summary>
        private async Task RunCacheUpdate(CacheState<TReturnType> state, IEnumerable<object> ids,
            CancellationToken token)
        {
            try
            {
                _logger.LogDebug("Running update for {ClassName} {@State}", GetType().Name, state);
                await using var context = _contextFactory.CreateContext(false);
                var set = context.Set<TEntityType>();
                var value = await state.Getter(set, ids, token);
                state.Value = value;
                state.IsSet = true;
                state.LastRetrieval = DateTime.UtcNow;
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Could not get cached value for {Key}", state.Key);
            }
        }

        /// <summary>
        /// Builds the inner dictionary key by joining the string form of each id with "_";
        /// null ids are written as "null".
        /// </summary>
        private static string GenerateKeyFromIds(IEnumerable<object> ids) =>
            string.Join("_", ids.Select(id => id?.ToString() ?? "null"));
    }
}