From d6d2717771f9c10f4e719e8b801d48e26905b315 Mon Sep 17 00:00:00 2001 From: RaidMax Date: Sat, 27 May 2023 12:15:22 -0500 Subject: [PATCH] possible fix for remotely loaded plugins --- Application/Misc/RemoteAssemblyHandler.cs | 28 ++--- Application/Plugin/PluginImporter.cs | 141 ++++++++++++---------- 2 files changed, 93 insertions(+), 76 deletions(-) diff --git a/Application/Misc/RemoteAssemblyHandler.cs b/Application/Misc/RemoteAssemblyHandler.cs index f2214d99c..6b7671b1e 100644 --- a/Application/Misc/RemoteAssemblyHandler.cs +++ b/Application/Misc/RemoteAssemblyHandler.cs @@ -13,10 +13,10 @@ namespace IW4MAdmin.Application.Misc { public class RemoteAssemblyHandler : IRemoteAssemblyHandler { - private const int keyLength = 32; - private const int tagLength = 16; - private const int nonceLength = 12; - private const int iterationCount = 10000; + private const int KeyLength = 32; + private const int TagLength = 16; + private const int NonceLength = 12; + private const int IterationCount = 10000; private readonly ApplicationConfiguration _appconfig; private readonly ILogger _logger; @@ -30,7 +30,7 @@ namespace IW4MAdmin.Application.Misc public IEnumerable DecryptAssemblies(string[] encryptedAssemblies) { return DecryptContent(encryptedAssemblies) - .Select(decryptedAssembly => Assembly.Load(decryptedAssembly)); + .Select(Assembly.Load); } public IEnumerable DecryptScripts(string[] encryptedScripts) @@ -38,24 +38,24 @@ namespace IW4MAdmin.Application.Misc return DecryptContent(encryptedScripts).Select(decryptedScript => Encoding.UTF8.GetString(decryptedScript)); } - private byte[][] DecryptContent(string[] content) + private IEnumerable DecryptContent(string[] content) { if (string.IsNullOrEmpty(_appconfig.Id) || string.IsNullOrWhiteSpace(_appconfig.SubscriptionId)) { _logger.LogWarning($"{nameof(_appconfig.Id)} and {nameof(_appconfig.SubscriptionId)} must be provided to attempt loading remote assemblies/scripts"); - return new byte[0][]; + return Array.Empty(); } var assemblies = content.Select(piece => { - byte[] byteContent = Convert.FromBase64String(piece); - byte[] encryptedContent = byteContent.Take(byteContent.Length - (tagLength + nonceLength)).ToArray(); - byte[] tag = byteContent.Skip(byteContent.Length - (tagLength + nonceLength)).Take(tagLength).ToArray(); - byte[] nonce = byteContent.Skip(byteContent.Length - nonceLength).Take(nonceLength).ToArray(); - byte[] decryptedContent = new byte[encryptedContent.Length]; + var byteContent = Convert.FromBase64String(piece); + var encryptedContent = byteContent.Take(byteContent.Length - (TagLength + NonceLength)).ToArray(); + var tag = byteContent.Skip(byteContent.Length - (TagLength + NonceLength)).Take(TagLength).ToArray(); + var nonce = byteContent.Skip(byteContent.Length - NonceLength).Take(NonceLength).ToArray(); + var decryptedContent = new byte[encryptedContent.Length]; - var keyGen = new Rfc2898DeriveBytes(Encoding.UTF8.GetBytes(_appconfig.SubscriptionId), Encoding.UTF8.GetBytes(_appconfig.Id.ToString()), iterationCount, HashAlgorithmName.SHA512); - var encryption = new AesGcm(keyGen.GetBytes(keyLength)); + var keyGen = new Rfc2898DeriveBytes(Encoding.UTF8.GetBytes(_appconfig.SubscriptionId), Encoding.UTF8.GetBytes(_appconfig.Id), IterationCount, HashAlgorithmName.SHA512); + var encryption = new AesGcm(keyGen.GetBytes(KeyLength)); try { diff --git a/Application/Plugin/PluginImporter.cs b/Application/Plugin/PluginImporter.cs index cc5cb0804..9df88890a 100644 --- a/Application/Plugin/PluginImporter.cs +++ b/Application/Plugin/PluginImporter.cs @@ -20,13 +20,21 @@ namespace IW4MAdmin.Application.Plugin public class PluginImporter : IPluginImporter { private IEnumerable _pluginSubscription; - private static readonly string PluginDir = "Plugins"; + private const string PluginDir = "Plugins"; private const string PluginV2Match = "^ *((?:var|const|let) +init)|function init"; private readonly ILogger _logger; private readonly IRemoteAssemblyHandler _remoteAssemblyHandler; private readonly IMasterApi _masterApi; private readonly ApplicationConfiguration _appConfig; + private static readonly Type[] FilterTypes = + { + typeof(IPlugin), + typeof(IPluginV2), + typeof(Command), + typeof(IBaseConfiguration) + }; + public PluginImporter(ILogger logger, ApplicationConfiguration appConfig, IMasterApi masterApi, IRemoteAssemblyHandler remoteAssemblyHandler) { @@ -77,74 +85,80 @@ namespace IW4MAdmin.Application.Plugin public (IEnumerable, IEnumerable, IEnumerable) DiscoverAssemblyPluginImplementations() { var pluginDir = $"{Utilities.OperatingDirectory}{PluginDir}{Path.DirectorySeparatorChar}"; - var pluginTypes = Enumerable.Empty(); - var commandTypes = Enumerable.Empty(); - var configurationTypes = Enumerable.Empty(); + var pluginTypes = new List(); + var commandTypes = new List(); + var configurationTypes = new List(); - if (Directory.Exists(pluginDir)) + if (!Directory.Exists(pluginDir)) { - var dllFileNames = Directory.GetFiles(pluginDir, "*.dll"); - _logger.LogDebug("Discovered {Count} potential plugin assemblies", dllFileNames.Length); + return (pluginTypes, commandTypes, configurationTypes); + } - if (dllFileNames.Length > 0) + var dllFileNames = Directory.GetFiles(pluginDir, "*.dll"); + _logger.LogDebug("Discovered {Count} potential plugin assemblies", dllFileNames.Length); + + if (!dllFileNames.Any()) + { + return (pluginTypes, commandTypes, configurationTypes); + } + + // we only want to load the most recent assembly in case of duplicates + var assemblies = dllFileNames.Select(Assembly.LoadFrom) + .Union(GetRemoteAssemblies()) + .GroupBy(assembly => assembly.FullName).Select(assembly => + assembly.OrderByDescending(asm => asm.GetName().Version).First()); + + var eligibleAssemblyTypes = assemblies + .SelectMany(asm => { - // we only want to load the most recent assembly in case of duplicates - var assemblies = dllFileNames.Select(name => Assembly.LoadFrom(name)) - .Union(GetRemoteAssemblies()) - .GroupBy(assembly => assembly.FullName).Select(assembly => assembly.OrderByDescending(asm => asm.GetName().Version).First()); + try + { + return asm.GetTypes(); + } + catch + { + return Enumerable.Empty(); + } + }).Where(type => + FilterTypes.Any(filterType => type.GetInterface(filterType.Name, false) != null) || + (type.IsClass && FilterTypes.Contains(type.BaseType))); + + foreach (var assemblyType in eligibleAssemblyTypes) + { + var isPlugin = + (assemblyType.GetInterface(nameof(IPlugin), false) ?? + assemblyType.GetInterface(nameof(IPluginV2), false)) != null && + (!assemblyType.Namespace?.StartsWith(nameof(SharedLibraryCore)) ?? false); - pluginTypes = assemblies - .SelectMany(asm => - { - try - { - return asm.GetTypes(); - } - catch - { - return Enumerable.Empty(); - } - }) - .Where(assemblyType => (assemblyType.GetInterface(nameof(IPlugin), false) ?? assemblyType.GetInterface(nameof(IPluginV2), false)) != null) - .Where(assemblyType => !assemblyType.Namespace?.StartsWith(nameof(SharedLibraryCore)) ?? false); + if (isPlugin) + { + pluginTypes.Add(assemblyType); + continue; + } - _logger.LogDebug("Discovered {count} plugin implementations", pluginTypes.Count()); + var isCommand = assemblyType.IsClass && assemblyType.BaseType == typeof(Command) && + (!assemblyType.Namespace?.StartsWith(nameof(SharedLibraryCore)) ?? false); - commandTypes = assemblies - .SelectMany(asm =>{ - try - { - return asm.GetTypes(); - } - catch - { - return Enumerable.Empty(); - } - }) - .Where(assemblyType => assemblyType.IsClass && assemblyType.BaseType == typeof(Command)) - .Where(assemblyType => !assemblyType.Namespace?.StartsWith(nameof(SharedLibraryCore)) ?? false); + if (isCommand) + { + commandTypes.Add(assemblyType); + continue; + } - _logger.LogDebug("Discovered {Count} plugin commands", commandTypes.Count()); + var isConfiguration = assemblyType.IsClass && + assemblyType.GetInterface(nameof(IBaseConfiguration), false) != null && + (!assemblyType.Namespace?.StartsWith(nameof(SharedLibraryCore)) ?? false); - configurationTypes = assemblies - .SelectMany(asm => { - try - { - return asm.GetTypes(); - } - catch - { - return Enumerable.Empty(); - } - }) - .Where(asmType => - asmType.IsClass && asmType.GetInterface(nameof(IBaseConfiguration), false) != null) - .Where(assemblyType => !assemblyType.Namespace?.StartsWith(nameof(SharedLibraryCore)) ?? false); - - _logger.LogDebug("Discovered {Count} configuration implementations", configurationTypes.Count()); + if (isConfiguration) + { + configurationTypes.Add(assemblyType); } } + _logger.LogDebug("Discovered {Count} plugin implementations", pluginTypes.Count); + _logger.LogDebug("Discovered {Count} plugin commands", pluginTypes.Count); + _logger.LogDebug("Discovered {Count} configuration implementations", pluginTypes.Count); + return (pluginTypes, commandTypes, configurationTypes); } @@ -152,10 +166,11 @@ namespace IW4MAdmin.Application.Plugin { try { - if (_pluginSubscription == null) - _pluginSubscription = _masterApi.GetPluginSubscription(Guid.Parse(_appConfig.Id), _appConfig.SubscriptionId).Result; + _pluginSubscription ??= _masterApi + .GetPluginSubscription(Guid.Parse(_appConfig.Id), _appConfig.SubscriptionId).Result; - return _remoteAssemblyHandler.DecryptAssemblies(_pluginSubscription.Where(sub => sub.Type == PluginType.Binary).Select(sub => sub.Content).ToArray()); + return _remoteAssemblyHandler.DecryptAssemblies(_pluginSubscription + .Where(sub => sub.Type == PluginType.Binary).Select(sub => sub.Content).ToArray()); } catch (Exception ex) @@ -169,9 +184,11 @@ namespace IW4MAdmin.Application.Plugin { try { - _pluginSubscription ??= _masterApi.GetPluginSubscription(Guid.Parse(_appConfig.Id), _appConfig.SubscriptionId).Result; + _pluginSubscription ??= _masterApi + .GetPluginSubscription(Guid.Parse(_appConfig.Id), _appConfig.SubscriptionId).Result; - return _remoteAssemblyHandler.DecryptScripts(_pluginSubscription.Where(sub => sub.Type == PluginType.Script).Select(sub => sub.Content).ToArray()); + return _remoteAssemblyHandler.DecryptScripts(_pluginSubscription + .Where(sub => sub.Type == PluginType.Script).Select(sub => sub.Content).ToArray()); } catch (Exception ex)