diff --git a/src/NuGet.Services.Configuration/SecretDictionary.cs b/src/NuGet.Services.Configuration/SecretDictionary.cs index 44262f1526..7b75588d0d 100644 --- a/src/NuGet.Services.Configuration/SecretDictionary.cs +++ b/src/NuGet.Services.Configuration/SecretDictionary.cs @@ -1,11 +1,10 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Collections; using System.Collections.Generic; using System.Linq; -using System.Threading.Tasks; using NuGet.Services.KeyVault; namespace NuGet.Services.Configuration @@ -122,14 +121,9 @@ private string InjectOrSkip(string key, string value) { if (!_notInjectedKeys.Contains(key)) { - return Inject(value).Result; + return _secretInjector.Inject(value); } return value; } - - private Task Inject(string value) - { - return _secretInjector.InjectAsync(value); - } } } diff --git a/src/NuGet.Services.KeyVault/CachingSecretReader.cs b/src/NuGet.Services.KeyVault/CachingSecretReader.cs index c0a09b5fac..4bbf0a8e4a 100644 --- a/src/NuGet.Services.KeyVault/CachingSecretReader.cs +++ b/src/NuGet.Services.KeyVault/CachingSecretReader.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -29,6 +29,16 @@ public CachingSecretReader(ISecretReader secretReader, _refreshIntervalBeforeExpiry = TimeSpan.FromSeconds(refreshIntervalBeforeExpirySec); } + public string GetSecret(string secretName) + { + return GetSecret(secretName, logger: null); + } + + public string GetSecret(string secretName, ILogger logger) + { + return GetSecretObject(secretName, logger).Value; + } + public async Task GetSecretAsync(string secretName) { return await GetSecretAsync(secretName, logger: null); @@ -39,32 +49,51 @@ public async Task GetSecretAsync(string secretName, ILogger logger) return (await GetSecretObjectAsync(secretName, logger)).Value; } - public async Task GetSecretObjectAsync(string secretName) + public ISecret GetSecretObject(string secretName) { - return await GetSecretObjectAsync(secretName, logger: null); + return GetSecretObject(secretName, logger: null); } - public async Task GetSecretObjectAsync(string secretName, ILogger logger) + public ISecret GetSecretObject(string secretName, ILogger logger) { - if (string.IsNullOrEmpty(secretName)) + if (TryGetCachedSecretObject(secretName, logger, out var cachedSecret)) { - throw new ArgumentException("Null or empty secret name", nameof(secretName)); + return cachedSecret; } - // If the cache contains the secret and it is not expired, return the cached value. + var start = DateTimeOffset.UtcNow; + + var updatedValue = new CachedSecret(_internalReader.GetSecretObject(secretName)); + + return UpdateCachedSecret(secretName, logger, start, updatedValue); + } + + public async Task GetSecretObjectAsync(string secretName) + { + return await GetSecretObjectAsync(secretName, logger: null); + } + + public async Task GetSecretObjectAsync(string secretName, ILogger logger) + { if (TryGetCachedSecretObject(secretName, logger, out var cachedSecret)) { return cachedSecret; } var start = DateTimeOffset.UtcNow; - // The cache does not contain a fresh copy of the secret. Fetch and cache the secret. + var updatedValue = new CachedSecret(await _internalReader.GetSecretObjectAsync(secretName)); + + return UpdateCachedSecret(secretName, logger, start, updatedValue); + } + + private ISecret UpdateCachedSecret(string secretName, ILogger logger, DateTimeOffset start, CachedSecret updatedValue) + { var updatedSecret = _cache.AddOrUpdate(secretName, updatedValue, (key, old) => updatedValue).Secret; logger?.LogInformation("Refreshed secret {SecretName}, Expiring at: {ExpirationTime}. Took {ElapsedMilliseconds}ms.", updatedSecret.Name, - updatedSecret.Expiration == null ? "null" : ((DateTimeOffset) updatedSecret.Expiration).UtcDateTime.ToString(), + updatedSecret.Expiration == null ? "null" : ((DateTimeOffset)updatedSecret.Expiration).UtcDateTime.ToString(), (DateTimeOffset.UtcNow - start).TotalMilliseconds.ToString("F2")); return updatedSecret; @@ -87,6 +116,11 @@ public bool TryGetCachedSecret(string secretName, ILogger logger, out string sec public bool TryGetCachedSecretObject(string secretName, ILogger logger, out ISecret secretObject) { + if (string.IsNullOrEmpty(secretName)) + { + throw new ArgumentException("Null or empty secret name", nameof(secretName)); + } + secretObject = null; if (_cache.TryGetValue(secretName, out CachedSecret result) && !IsSecretOutdated(result)) @@ -122,4 +156,4 @@ public CachedSecret(ISecret secret) } } -} \ No newline at end of file +} diff --git a/src/NuGet.Services.KeyVault/EmptySecretReader.cs b/src/NuGet.Services.KeyVault/EmptySecretReader.cs index 558661923f..9ff9af7047 100644 --- a/src/NuGet.Services.KeyVault/EmptySecretReader.cs +++ b/src/NuGet.Services.KeyVault/EmptySecretReader.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Threading.Tasks; @@ -8,6 +8,16 @@ namespace NuGet.Services.KeyVault { public class EmptySecretReader : ICachingSecretReader { + public string GetSecret(string secretName) + { + return GetSecret(secretName, logger: null); + } + + public string GetSecret(string secretName, ILogger logger) + { + return secretName; + } + public Task GetSecretAsync(string secretName) => GetSecretAsync(secretName, logger: null); public Task GetSecretAsync(string secretName, ILogger logger) @@ -15,11 +25,21 @@ public Task GetSecretAsync(string secretName, ILogger logger) return Task.FromResult(secretName); } + public ISecret GetSecretObject(string secretName) + { + return GetSecretObject(secretName, logger: null); + } + + public ISecret GetSecretObject(string secretName, ILogger logger) + { + return new KeyVaultSecret(secretName, secretName, null); + } + public Task GetSecretObjectAsync(string secretName) => GetSecretObjectAsync(secretName, logger: null); public Task GetSecretObjectAsync(string secretName, ILogger logger) { - return Task.FromResult((ISecret)new KeyVaultSecret(secretName, secretName, null)); + return Task.FromResult(GetSecretObject(secretName, logger)); } public bool TryGetCachedSecret(string secretName, out string secretValue) => TryGetCachedSecret(secretName, logger: null, out secretValue); @@ -38,4 +58,4 @@ public bool TryGetCachedSecretObject(string secretName, ILogger logger, out ISec return true; } } -} \ No newline at end of file +} diff --git a/src/NuGet.Services.KeyVault/IRefreshableSecretReaderFactory.cs b/src/NuGet.Services.KeyVault/IRefreshableSecretReaderFactory.cs index e938060e0a..7272044518 100644 --- a/src/NuGet.Services.KeyVault/IRefreshableSecretReaderFactory.cs +++ b/src/NuGet.Services.KeyVault/IRefreshableSecretReaderFactory.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Threading; @@ -11,6 +11,13 @@ namespace NuGet.Services.KeyVault /// public interface IRefreshableSecretReaderFactory : ISecretReaderFactory { + /// + /// Refresh the values of the secrets that have already been read and cached. Since the cache is shared between + /// all instances creates, this refresh applies to all secret readers created by + /// this factory. + /// + void Refresh(); + /// /// Refresh the values of the secrets that have already been read and cached. Since the cache is shared between /// all instances creates, this refresh applies to all secret readers created by diff --git a/src/NuGet.Services.KeyVault/ISecretInjector.cs b/src/NuGet.Services.KeyVault/ISecretInjector.cs index fb77987cd6..b157b2f370 100644 --- a/src/NuGet.Services.KeyVault/ISecretInjector.cs +++ b/src/NuGet.Services.KeyVault/ISecretInjector.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Threading.Tasks; @@ -8,7 +8,9 @@ namespace NuGet.Services.KeyVault { public interface ISecretInjector { + string Inject(string input); + string Inject(string input, ILogger logger); Task InjectAsync(string input); Task InjectAsync(string input, ILogger logger); } -} \ No newline at end of file +} diff --git a/src/NuGet.Services.KeyVault/ISecretReader.cs b/src/NuGet.Services.KeyVault/ISecretReader.cs index f4bcbaa203..fb2bed1b1e 100644 --- a/src/NuGet.Services.KeyVault/ISecretReader.cs +++ b/src/NuGet.Services.KeyVault/ISecretReader.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Threading.Tasks; @@ -8,9 +8,13 @@ namespace NuGet.Services.KeyVault { public interface ISecretReader { + string GetSecret(string secretName); + string GetSecret(string secretName, ILogger logger); Task GetSecretAsync(string secretName); Task GetSecretAsync(string secretName, ILogger logger); + ISecret GetSecretObject(string secretName); + ISecret GetSecretObject(string secretName, ILogger logger); Task GetSecretObjectAsync(string secretName); Task GetSecretObjectAsync(string secretName, ILogger logger); } -} \ No newline at end of file +} diff --git a/src/NuGet.Services.KeyVault/KeyVaultReader.cs b/src/NuGet.Services.KeyVault/KeyVaultReader.cs index bc618528be..5214f8c6f8 100644 --- a/src/NuGet.Services.KeyVault/KeyVaultReader.cs +++ b/src/NuGet.Services.KeyVault/KeyVaultReader.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -33,6 +33,17 @@ public KeyVaultReader(KeyVaultConfiguration configuration) _keyVaultClient = new Lazy(InitializeClient); } + public string GetSecret(string secretName) + { + return GetSecret(secretName, logger: null); + } + + public string GetSecret(string secretName, ILogger logger) + { + AzureSecurityKeyVaultSecret secret = _keyVaultClient.Value.GetSecret(secretName); + return secret.Value; + } + public async Task GetSecretAsync(string secretName) { return await GetSecretAsync(secretName, logger: null); @@ -44,6 +55,17 @@ public async Task GetSecretAsync(string secretName, ILogger logger) return secret.Value; } + public ISecret GetSecretObject(string secretName) + { + return GetSecretObject(secretName, logger: null); + } + + public ISecret GetSecretObject(string secretName, ILogger logger) + { + AzureSecurityKeyVaultSecret secret = _keyVaultClient.Value.GetSecret(secretName); + return MapSecret(secretName, secret); + } + public async Task GetSecretObjectAsync(string secretName) { return await GetSecretObjectAsync(secretName, logger: null); @@ -52,6 +74,11 @@ public async Task GetSecretObjectAsync(string secretName) public async Task GetSecretObjectAsync(string secretName, ILogger logger) { AzureSecurityKeyVaultSecret secret = await _keyVaultClient.Value.GetSecretAsync(secretName); + return MapSecret(secretName, secret); + } + + private static ISecret MapSecret(string secretName, AzureSecurityKeyVaultSecret secret) + { return new KeyVaultSecret(secretName, secret.Value, secret.Properties.ExpiresOn); } diff --git a/src/NuGet.Services.KeyVault/RefreshableSecretReader.cs b/src/NuGet.Services.KeyVault/RefreshableSecretReader.cs index c614b747cb..26f2b7245c 100644 --- a/src/NuGet.Services.KeyVault/RefreshableSecretReader.cs +++ b/src/NuGet.Services.KeyVault/RefreshableSecretReader.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -32,6 +32,14 @@ public RefreshableSecretReader( _settings = settings ?? throw new ArgumentNullException(nameof(settings)); } + public void Refresh() + { + foreach (var secretName in _cache.Keys) + { + UncachedGetSecretObject(secretName); + } + } + public async Task RefreshAsync(CancellationToken token) { foreach (var secretName in _cache.Keys) @@ -45,19 +53,39 @@ public async Task RefreshAsync(CancellationToken token) } } + public string GetSecret(string secretName) + { + return GetSecret(secretName, logger: null); + } + + public string GetSecret(string secretName, ILogger logger) + { + return GetSecretObject(secretName, logger).Value; + } + public Task GetSecretAsync(string secretName) { return GetSecretAsync(secretName, logger: null); } - public Task GetSecretAsync(string secretName, ILogger logger) + public async Task GetSecretAsync(string secretName, ILogger logger) + { + return (await GetSecretObjectAsync(secretName, logger)).Value; + } + + public ISecret GetSecretObject(string secretName) + { + return GetSecretObject(secretName, logger: null); + } + + public ISecret GetSecretObject(string secretName, ILogger logger) { if (TryGetCachedSecretObject(secretName, out var secret)) { - return Task.FromResult(secret.Value); + return secret; } - return UncachedGetSecretAsync(secretName); + return UncachedGetSecretObject(secretName); } public Task GetSecretObjectAsync(string secretName) @@ -108,10 +136,11 @@ public bool TryGetCachedSecretObject(string secretName, ILogger logger, out ISec public bool TryGetCachedSecretObject(string secretName, out ISecret secretObject) => TryGetCachedSecretObject(secretName, logger: null, secretObject: out secretObject); - private async Task UncachedGetSecretAsync(string secretName) + private ISecret UncachedGetSecretObject(string secretName) { - var secretObject = await UncachedGetSecretObjectAsync(secretName); - return secretObject.Value; + var secretObject = _secretReader.GetSecretObject(secretName); + _cache.AddOrUpdate(secretName, secretObject, (_, __) => secretObject); + return secretObject; } private async Task UncachedGetSecretObjectAsync(string secretName) @@ -121,4 +150,4 @@ private async Task UncachedGetSecretObjectAsync(string secretName) return secretObject; } } -} \ No newline at end of file +} diff --git a/src/NuGet.Services.KeyVault/RefreshableSecretReaderFactory.cs b/src/NuGet.Services.KeyVault/RefreshableSecretReaderFactory.cs index 6abc0d27b8..990eb831f1 100644 --- a/src/NuGet.Services.KeyVault/RefreshableSecretReaderFactory.cs +++ b/src/NuGet.Services.KeyVault/RefreshableSecretReaderFactory.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -25,6 +25,11 @@ public RefreshableSecretReaderFactory(ISecretReaderFactory underlyingFactory, Re _settings = settings ?? throw new ArgumentNullException(nameof(settings)); } + public void Refresh() + { + GetRefreshableSecretReader().Refresh(); + } + public async Task RefreshAsync(CancellationToken token) { await GetRefreshableSecretReader().RefreshAsync(token); diff --git a/src/NuGet.Services.KeyVault/SecretInjector.cs b/src/NuGet.Services.KeyVault/SecretInjector.cs index 6d8cf1ffcc..9b55aa4a1f 100644 --- a/src/NuGet.Services.KeyVault/SecretInjector.cs +++ b/src/NuGet.Services.KeyVault/SecretInjector.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -37,6 +37,30 @@ public SecretInjector(ISecretReader secretReader, string frame) _cachingSecretReader = secretReader as ICachingSecretReader; } + public string Inject(string input) + { + return Inject(input, logger: null); + } + + public string Inject(string input, ILogger logger) + { + if (string.IsNullOrEmpty(input)) + { + return input; + } + + var output = new StringBuilder(input); + var secretNames = GetSecretNames(input); + + foreach (var secretName in secretNames) + { + var secretValue = _secretReader.GetSecret(secretName, logger); + output.Replace($"{_frame}{secretName}{_frame}", secretValue); + } + + return output.ToString(); + } + public Task InjectAsync(string input) { return InjectAsync(input, logger: null); @@ -130,4 +154,4 @@ private ICollection GetSecretNames(string input) return secretNames; } } -} \ No newline at end of file +} diff --git a/src/NuGetCDNRedirect/NuGetCDNRedirect.csproj b/src/NuGetCDNRedirect/NuGetCDNRedirect.csproj index 29168d7855..14a9ddaa33 100644 --- a/src/NuGetCDNRedirect/NuGetCDNRedirect.csproj +++ b/src/NuGetCDNRedirect/NuGetCDNRedirect.csproj @@ -1,4 +1,4 @@ - + diff --git a/tests/NuGet.Services.Configuration.Tests/SecretDictionaryFacts.cs b/tests/NuGet.Services.Configuration.Tests/SecretDictionaryFacts.cs index 6314d3fca6..b800544138 100644 --- a/tests/NuGet.Services.Configuration.Tests/SecretDictionaryFacts.cs +++ b/tests/NuGet.Services.Configuration.Tests/SecretDictionaryFacts.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Moq; @@ -18,7 +18,7 @@ public void RefreshesSecretWhenItChanges() { // Arrange var mockSecretInjector = new Mock(); - mockSecretInjector.Setup(x => x.InjectAsync(It.IsAny())).Returns(Task.FromResult(Secret1.InjectedValue)); + mockSecretInjector.Setup(x => x.Inject(It.IsAny())).Returns(Secret1.InjectedValue); var unprocessedDictionary = new Dictionary() { @@ -32,19 +32,19 @@ public void RefreshesSecretWhenItChanges() var value2 = secretDict[Secret1.Key]; // Assert - mockSecretInjector.Verify(x => x.InjectAsync(It.IsAny()), Times.Exactly(2)); + mockSecretInjector.Verify(x => x.Inject(It.IsAny()), Times.Exactly(2)); Assert.Equal(Secret1.InjectedValue, value1); Assert.Equal(value1, value2); // Arrange 2 - mockSecretInjector.Setup(x => x.InjectAsync(It.IsAny())).Returns(Task.FromResult(Secret2.InjectedValue)); + mockSecretInjector.Setup(x => x.Inject(It.IsAny())).Returns(Secret2.InjectedValue); // Act 2 var value3 = secretDict[Secret1.Key]; var value4 = secretDict[Secret1.Key]; // Assert 2 - mockSecretInjector.Verify(x => x.InjectAsync(It.IsAny()), Times.Exactly(4)); + mockSecretInjector.Verify(x => x.Inject(It.IsAny()), Times.Exactly(4)); Assert.Equal(Secret2.InjectedValue, value3); Assert.Equal(value3, value4); } @@ -349,7 +349,7 @@ public void NotInjectedKeys() var notInjectedKeys = new HashSet { key }; var mockSecretInjector = new Mock(); - mockSecretInjector.Setup(x => x.InjectAsync(It.IsAny())); + mockSecretInjector.Setup(x => x.Inject(It.IsAny())); var secretDict = CreatSecretDictionaryWithNotInjectedKeys(mockSecretInjector.Object, unprocessedDictionary, @@ -381,7 +381,7 @@ public void NotInjectedKeys() // Act and Assert 6 Assert.True(secretDict.Remove(key)); - mockSecretInjector.Verify(x => x.InjectAsync(It.IsAny()), Times.Never); + mockSecretInjector.Verify(x => x.Inject(It.IsAny()), Times.Never); } /// @@ -429,7 +429,7 @@ private static IDictionary CreatSecretDictionaryWithNotInjectedK private static Mock CreateMappedSecretInjectorMock(IDictionary keyToValue) { var mockSecretInjector = new Mock(); - mockSecretInjector.Setup(x => x.InjectAsync(It.IsAny())).Returns(key => Task.FromResult(keyToValue[key])); + mockSecretInjector.Setup(x => x.Inject(It.IsAny())).Returns(key => keyToValue[key]); return mockSecretInjector; } diff --git a/tests/NuGet.Services.KeyVault.Tests/CachingSecretReaderFacts.cs b/tests/NuGet.Services.KeyVault.Tests/CachingSecretReaderFacts.cs index 95df900a0c..34543400dc 100644 --- a/tests/NuGet.Services.KeyVault.Tests/CachingSecretReaderFacts.cs +++ b/tests/NuGet.Services.KeyVault.Tests/CachingSecretReaderFacts.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -13,7 +13,7 @@ namespace NuGet.Services.KeyVault.Tests public class CachingSecretReaderFacts { [Fact] - public async Task WhenGetSecretIsCalledCacheIsUsed() + public async Task WhenGetSecretAsyncIsCalledCacheIsUsed() { // Arrange const string secretName = "secretname"; @@ -41,6 +41,35 @@ public async Task WhenGetSecretIsCalledCacheIsUsed() It.IsAny>()), Times.Once); } + [Fact] + public void WhenGetSecretIsCalledCacheIsUsed() + { + // Arrange + const string secretName = "secretname"; + const string secretValue = "testValue"; + KeyVaultSecret secret = new KeyVaultSecret(secretName, secretValue, null); + + var mockSecretReader = new Mock(); + mockSecretReader + .Setup(x => x.GetSecretObject(It.IsAny())) + .Returns(secret); + var mockLogger = new Mock(); + + var cachingSecretReader = new CachingSecretReader(mockSecretReader.Object, int.MaxValue); + + // Act + var value1 = cachingSecretReader.GetSecret("secretname", mockLogger.Object); + var value2 = cachingSecretReader.GetSecret("secretname", mockLogger.Object); + + // Assert + mockSecretReader.Verify(x => x.GetSecretObject(It.IsAny()), Times.Once); + mockLogger.Verify(x => x.Log(It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny>()), Times.Once); + } + [Fact] public async Task WhenSecretIsFreshTryGetCachedSecretReturnsIt() { @@ -268,4 +297,4 @@ public async Task WhenSecretIsStaleTryGetCachedSecretReturnsNullWithoutLogger() Assert.Null(value2); } } -} \ No newline at end of file +} diff --git a/tests/NuGet.Services.KeyVault.Tests/KeyVaultReaderFormatterFacts.cs b/tests/NuGet.Services.KeyVault.Tests/KeyVaultReaderFormatterFacts.cs index 13e2b7f68c..7cfdbeba71 100644 --- a/tests/NuGet.Services.KeyVault.Tests/KeyVaultReaderFormatterFacts.cs +++ b/tests/NuGet.Services.KeyVault.Tests/KeyVaultReaderFormatterFacts.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; @@ -58,13 +58,15 @@ public KeyVaultReaderFormatterFacts() var mockKeyVault = new Mock(); mockKeyVault.Setup(x => x.GetSecretAsync(It.IsAny(), It.IsAny())) .Returns((string s, ILogger logger) => Task.FromResult(s.ToUpper())); + mockKeyVault.Setup(x => x.GetSecret(It.IsAny(), It.IsAny())) + .Returns((string s, ILogger logger) => s.ToUpper()); _secretInjector = new SecretInjector(mockKeyVault.Object); } [Theory] [MemberData(nameof(_testFormatParameters))] - public async Task TestFormat(string input, string expectedOutput) + public async Task TestFormatAsync(string input, string expectedOutput) { // Act string formattedString = await _secretInjector.InjectAsync(input); @@ -72,5 +74,16 @@ public async Task TestFormat(string input, string expectedOutput) // Assert formattedString.Should().BeEquivalentTo(expectedOutput); } + + [Theory] + [MemberData(nameof(_testFormatParameters))] + public void TestFormat(string input, string expectedOutput) + { + // Act + string formattedString = _secretInjector.Inject(input); + + // Assert + formattedString.Should().BeEquivalentTo(expectedOutput); + } } } diff --git a/tests/NuGet.Services.KeyVault.Tests/RefreshableSecretReaderFactoryFacts.cs b/tests/NuGet.Services.KeyVault.Tests/RefreshableSecretReaderFactoryFacts.cs index 1e480c9864..77552745d9 100644 --- a/tests/NuGet.Services.KeyVault.Tests/RefreshableSecretReaderFactoryFacts.cs +++ b/tests/NuGet.Services.KeyVault.Tests/RefreshableSecretReaderFactoryFacts.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -14,7 +14,7 @@ public class RefreshableSecretReaderFactoryFacts public class CreateSecretReader : Facts { [Fact] - public async Task CreatesWrapper() + public async Task CreatesWrapperAsync() { var actual = Target.CreateSecretReader(); @@ -23,6 +23,17 @@ public async Task CreatesWrapper() Assert.Same(secret, Secret.Object); UnderlyingReader.Verify(x => x.GetSecretObjectAsync(SecretName), Times.Once); } + + [Fact] + public void CreatesWrapper() + { + var actual = Target.CreateSecretReader(); + + var secret = actual.GetSecretObject(SecretName); + Assert.IsType(actual); + Assert.Same(secret, Secret.Object); + UnderlyingReader.Verify(x => x.GetSecretObject(SecretName), Times.Once); + } } public class CreateSecretInjector : Facts @@ -41,7 +52,7 @@ public void CreatesWrapper() public class RefreshAsync : Facts { [Fact] - public async Task RefreshesSecrets() + public async Task RefreshesSecretsAsync() { var reader = Target.CreateSecretReader(); await reader.GetSecretAsync(SecretName); @@ -51,18 +62,39 @@ public async Task RefreshesSecrets() UnderlyingReader.Verify(x => x.GetSecretObjectAsync(SecretName), Times.Once); } + + [Fact] + public void RefreshesSecrets() + { + var reader = Target.CreateSecretReader(); + reader.GetSecret(SecretName); + UnderlyingReader.Invocations.Clear(); + + Target.Refresh(); + + UnderlyingReader.Verify(x => x.GetSecretObject(SecretName), Times.Once); + } } public class Settings : Facts { [Fact] - public async Task AffectCreatedReaders() + public async Task AffectCreatedReadersAsync() { var actual = Target.CreateSecretReader(); Settings.BlockUncachedReads = true; await Assert.ThrowsAsync(() => actual.GetSecretAsync(SecretName)); } + + [Fact] + public void AffectCreatedReaders() + { + var actual = Target.CreateSecretReader(); + Settings.BlockUncachedReads = true; + + Assert.Throws(() => actual.GetSecret(SecretName)); + } } public abstract class Facts @@ -83,6 +115,9 @@ public Facts() UnderlyingFactory .Setup(x => x.CreateSecretInjector(It.IsAny())) .Returns(() => SecretInjector.Object); + UnderlyingReader + .Setup(x => x.GetSecretObject(It.IsAny())) + .Returns(() => Secret.Object); UnderlyingReader .Setup(x => x.GetSecretObjectAsync(It.IsAny())) .ReturnsAsync(() => Secret.Object); diff --git a/tests/NuGet.Services.KeyVault.Tests/RefreshableSecretReaderFacts.cs b/tests/NuGet.Services.KeyVault.Tests/RefreshableSecretReaderFacts.cs index 3344397a50..b17b4f20c5 100644 --- a/tests/NuGet.Services.KeyVault.Tests/RefreshableSecretReaderFacts.cs +++ b/tests/NuGet.Services.KeyVault.Tests/RefreshableSecretReaderFacts.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -65,6 +65,44 @@ public async Task RespectsTheToken() } } + public class Refresh : Facts + { + [Fact] + public void DoesNothingWithEmptyCache() + { + Target.Refresh(); + + SecretReader.Verify(x => x.GetSecret(It.IsAny()), Times.Never); + SecretReader.Verify(x => x.GetSecretObject(It.IsAny()), Times.Never); + } + + [Fact] + public void RefreshesAllNames() + { + Target.GetSecret(SecretNameA); + Target.GetSecret(SecretNameB); + SecretReader.Invocations.Clear(); + + Target.Refresh(); + + SecretReader.Verify(x => x.GetSecret(It.IsAny()), Times.Never); + SecretReader.Verify(x => x.GetSecretObject(SecretNameA), Times.Once); + SecretReader.Verify(x => x.GetSecretObject(SecretNameB), Times.Once); + } + + [Fact] + public void CachesLatestValue() + { + Target.GetSecret(SecretNameA); + SecretReader.Setup(x => x.GetSecretObject(SecretNameA)).Returns(() => SecretB.Object); + + Target.Refresh(); + + var secretObject = Target.GetSecretObject(SecretNameA); + Assert.Same(SecretB.Object, secretObject); + } + } + public class GetSecretObjectAsync : Facts { [Fact] @@ -111,6 +149,27 @@ public async Task ThrowsIfReadsAreBlocked() } } + public class GetSecretObject : Facts + { + [Fact] + public void FetchesAnUncachedSecret() + { + var actual = Target.GetSecretObject(SecretNameA); + + Assert.Same(SecretA.Object, actual); + SecretReader.Verify(x => x.GetSecretObject(SecretNameA), Times.Once); + } + + [Fact] + public void ThrowsIfReadsAreBlocked() + { + Settings.BlockUncachedReads = true; + + var ex = Assert.Throws(() => Target.GetSecret(SecretNameA)); + Assert.Equal($"The secret '{SecretNameA}' is not cached.", ex.Message); + } + } + public class GetSecretAsync : Facts { [Fact] @@ -148,6 +207,19 @@ public async Task DoesNotSwitchThreadIfAlreadyCached() } } + public class GetSecret : Facts + { + [Fact] + public void FetchesAnUncachedSecret() + { + var actual = Target.GetSecret(SecretNameA); + + Assert.Same(SecretA.Object.Value, actual); + SecretReader.Verify(x => x.GetSecret(It.IsAny()), Times.Never); + SecretReader.Verify(x => x.GetSecretObject(SecretNameA), Times.Once); + } + } + public abstract class Facts { public Facts() @@ -165,6 +237,19 @@ public Facts() SecretA.Setup(x => x.Value).Returns("A-value"); SecretB.Setup(x => x.Value).Returns("B-value"); + SecretReader + .Setup(x => x.GetSecretObject(SecretNameA)) + .Returns(() => + { + return SecretA.Object; + }); + SecretReader + .Setup(x => x.GetSecretObject(SecretNameB)) + .Returns(() => + { + return SecretB.Object; + }); + SecretReader .Setup(x => x.GetSecretObjectAsync(SecretNameA)) .Returns(async () => diff --git a/tests/NuGet.Services.KeyVault.Tests/SecretReaderFacts.cs b/tests/NuGet.Services.KeyVault.Tests/SecretReaderFacts.cs index a757321511..0950647f81 100644 --- a/tests/NuGet.Services.KeyVault.Tests/SecretReaderFacts.cs +++ b/tests/NuGet.Services.KeyVault.Tests/SecretReaderFacts.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -10,6 +10,31 @@ namespace NuGet.Services.KeyVault.Tests { public class SecretReaderFacts { + [Fact] + public void GetSecretObjectReturnsSecretExpiry() + { + // Arrange + const string secretName = "secretname"; + const string secretValue = "testValue"; + DateTime secretExpiration = DateTime.UtcNow.AddSeconds(3); + KeyVaultSecret secret = new KeyVaultSecret(secretName, secretValue, secretExpiration); + + var mockSecretReader = new Mock(); + mockSecretReader + .SetupSequence(x => x.GetSecretObject(It.IsAny())) + .Returns(secret); + + var cachingSecretReader = new CachingSecretReader(mockSecretReader.Object); + + // Act + var secretObject = cachingSecretReader.GetSecretObject(secretName); + + // Assert + mockSecretReader.Verify(x => x.GetSecretObject(It.IsAny()), Times.Once); + Assert.Equal(secretValue, secretObject.Value); + Assert.Equal(secretObject.Expiration, secretExpiration); + } + [Fact] public async Task GetSecretObjectAsyncReturnsSecretExpiry() {