Skip to content

Commit

Permalink
Enable tab-completion for nuget package versions (#42349)
Browse files Browse the repository at this point in the history
  • Loading branch information
baronfel authored Aug 15, 2024
1 parent 70064ae commit 4a46908
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 63 deletions.
86 changes: 85 additions & 1 deletion src/Cli/dotnet/NugetPackageDownloader/NuGetPackageDownloader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ private IEnumerable<PackageSource> LoadNuGetSources(PackageId packageId, Package

packageSourceMapping ??= PackageSourceMapping.GetPackageSourceMapping(settings);

// filter package patterns if enabled
// filter package patterns if enabled
if (_shouldUsePackageSourceMapping && packageSourceMapping?.IsEnabled == true)
{
IReadOnlyList<string> sources = packageSourceMapping.GetConfiguredPackageSources(packageId.ToString());
Expand Down Expand Up @@ -722,6 +722,90 @@ public async Task<NuGetVersion> GetLatestPackageVersion(PackageId packageId,
return packageMetadata.Identity.Version;
}

public async Task<IEnumerable<string>> GetPackageIdsAsync(string idStem, bool allowPrerelease, PackageSourceLocation packageSourceLocation = null, CancellationToken cancellationToken = default)
{
// grab allowed sources for the package in question
PackageId packageId = new(idStem);
IEnumerable<PackageSource> packagesSources = LoadNuGetSources(packageId, packageSourceLocation);
var autoCompletes = await Task.WhenAll(packagesSources.Select(async (source) => await GetAutocompleteAsync(source, cancellationToken).ConfigureAwait(false))).ConfigureAwait(false);
// filter down to autocomplete endpoints (not all sources support this)
var validAutoCompletes = autoCompletes.SelectMany(x => x);
// get versions valid for this source
var packageIdTasks = validAutoCompletes.Select(autocomplete => GetPackageIdsForSource(autocomplete, packageId, allowPrerelease, cancellationToken)).ToArray();
var packageIdLists = await Task.WhenAll(packageIdTasks).ConfigureAwait(false);
// sources may have the same versions, so we have to dedupe.
return packageIdLists.SelectMany(v => v).Distinct().OrderDescending();
}

public async Task<IEnumerable<NuGetVersion>> GetPackageVersionsAsync(PackageId packageId, string versionPrefix = null, bool allowPrerelease = false, PackageSourceLocation packageSourceLocation = null, CancellationToken cancellationToken = default)
{
// grab allowed sources for the package in question
IEnumerable<PackageSource> packagesSources = LoadNuGetSources(packageId, packageSourceLocation);
var autoCompletes = await Task.WhenAll(packagesSources.Select(async (source) => await GetAutocompleteAsync(source, cancellationToken).ConfigureAwait(false))).ConfigureAwait(false);
// filter down to autocomplete endpoints (not all sources support this)
var validAutoCompletes = autoCompletes.SelectMany(x => x);
// get versions valid for this source
var versionTasks = validAutoCompletes.Select(autocomplete => GetPackageVersionsForSource(autocomplete, packageId, versionPrefix, allowPrerelease, cancellationToken)).ToArray();
var versions = await Task.WhenAll(versionTasks).ConfigureAwait(false);
// sources may have the same versions, so we have to dedupe.
return versions.SelectMany(v => v).Distinct().OrderDescending();
}

private async Task<IEnumerable<AutoCompleteResource>> GetAutocompleteAsync(PackageSource source, CancellationToken cancellationToken)
{
SourceRepository repository = GetSourceRepository(source);
if (await repository.GetResourceAsync<AutoCompleteResource>(cancellationToken).ConfigureAwait(false) is var resource)
{
return [resource];
}
else return Enumerable.Empty<AutoCompleteResource>();
}

// only exposed for testing
internal static TimeSpan CliCompletionsTimeout
{
get => _cliCompletionsTimeout;
set => _cliCompletionsTimeout = value;
}
private static TimeSpan _cliCompletionsTimeout = TimeSpan.FromMilliseconds(500);
private async Task<IEnumerable<NuGetVersion>> GetPackageVersionsForSource(AutoCompleteResource autocomplete, PackageId packageId, string versionPrefix, bool allowPrerelease, CancellationToken cancellationToken)
{
try
{
var timeoutCts = new CancellationTokenSource(_cliCompletionsTimeout);
var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token);
// we use the NullLogger because we don't want to log to stdout for completions - they interfere with the completions mechanism of the shell program.
return await autocomplete.VersionStartsWith(packageId.ToString(), versionPrefix: versionPrefix ?? "", includePrerelease: allowPrerelease, sourceCacheContext: _cacheSettings, log: NullLogger.Instance, token: linkedCts.Token);
}
catch (FatalProtocolException) // this most often means that the source didn't actually have a SearchAutocompleteService
{
return Enumerable.Empty<NuGetVersion>();
}
catch (Exception) // any errors (i.e. auth) should just be ignored for completions
{
return Enumerable.Empty<NuGetVersion>();
}
}

private async Task<IEnumerable<string>> GetPackageIdsForSource(AutoCompleteResource autocomplete, PackageId packageId, bool allowPrerelease, CancellationToken cancellationToken)
{
try
{
var timeoutCts = new CancellationTokenSource(_cliCompletionsTimeout);
var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token);
// we use the NullLogger because we don't want to log to stdout for completions - they interfere with the completions mechanism of the shell program.
return await autocomplete.IdStartsWith(packageId.ToString(), includePrerelease: allowPrerelease, log: NullLogger.Instance, token: linkedCts.Token);
}
catch (FatalProtocolException) // this most often means that the source didn't actually have a SearchAutocompleteService
{
return Enumerable.Empty<string>();
}
catch (Exception) // any errors (i.e. auth) should just be ignored for completions
{
return Enumerable.Empty<string>();
}
}

private SourceRepository GetSourceRepository(PackageSource source)
{
if (!_sourceRepositories.ContainsKey(source))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using System.Text.Json;
using Microsoft.DotNet.Tools;
using Microsoft.DotNet.Tools.Add.PackageReference;
using Microsoft.Extensions.EnvironmentAbstractions;
using NuGet.Versioning;
using LocalizableStrings = Microsoft.DotNet.Tools.Add.PackageReference.LocalizableStrings;

namespace Microsoft.DotNet.Cli
Expand All @@ -15,13 +17,34 @@ internal static class AddPackageParser
public static readonly CliArgument<string> CmdPackageArgument = new CliArgument<string>(LocalizableStrings.CmdPackage)
{
Description = LocalizableStrings.CmdPackageDescription
}.AddCompletions((context) => QueryNuGet(context.WordToComplete).Select(match => new CompletionItem(match)));
}.AddCompletions((context) =>
{
// we should take --prerelease flags into account for version completion
var allowPrerelease = context.ParseResult.GetValue(PrereleaseOption);
return QueryNuGet(context.WordToComplete, allowPrerelease, CancellationToken.None).Result.Select(packageId => new CompletionItem(packageId));
});

public static readonly CliOption<string> VersionOption = new ForwardedOption<string>("--version", "-v")
{
Description = LocalizableStrings.CmdVersionDescription,
HelpName = LocalizableStrings.CmdVersion
}.ForwardAsSingle(o => $"--version {o}");
}.ForwardAsSingle(o => $"--version {o}")
.AddCompletions((context) =>
{
// we can only do version completion if we have a package id
if (context.ParseResult.GetValue(CmdPackageArgument) is string packageId)
{
// we should take --prerelease flags into account for version completion
var allowPrerelease = context.ParseResult.GetValue(PrereleaseOption);
return QueryVersionsForPackage(packageId, context.WordToComplete, allowPrerelease, CancellationToken.None)
.Result
.Select(version => new CompletionItem(version.ToNormalizedString()));
}
else
{
return Enumerable.Empty<CompletionItem>();
}
});

public static readonly CliOption<string> FrameworkOption = new ForwardedOption<string>("--framework", "-f")
{
Expand Down Expand Up @@ -81,44 +104,31 @@ private static CliCommand ConstructCommand()
return command;
}

public static IEnumerable<string> QueryNuGet(string match)
public static async Task<IEnumerable<string>> QueryNuGet(string packageStem, bool allowPrerelease, CancellationToken cancellationToken)
{
var httpClient = new HttpClient();

Stream result;

try
{
using var cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(10));
var response = httpClient.GetAsync($"https://api-v2v3search-0.nuget.org/autocomplete?q={match}&skip=0&take=100", cancellation.Token)
.Result;

result = response.Content.ReadAsStreamAsync().Result;
var downloader = new NuGetPackageDownloader.NuGetPackageDownloader(packageInstallDir: new DirectoryPath());
var versions = await downloader.GetPackageIdsAsync(packageStem, allowPrerelease, cancellationToken: cancellationToken);
return versions;
}
catch (Exception)
{
yield break;
}

foreach (var packageId in EnumerablePackageIdFromQueryResponse(result))
{
yield return packageId;
return Enumerable.Empty<string>();
}
}

internal static IEnumerable<string> EnumerablePackageIdFromQueryResponse(Stream result)
internal static async Task<IEnumerable<NuGetVersion>> QueryVersionsForPackage(string packageId, string versionFragment, bool allowPrerelease, CancellationToken cancellationToken)
{
using (JsonDocument doc = JsonDocument.Parse(result))
try
{
JsonElement root = doc.RootElement;

if (root.TryGetProperty("data", out var data))
{
foreach (JsonElement packageIdElement in data.EnumerateArray())
{
yield return packageIdElement.GetString();
}
}
var downloader = new NuGetPackageDownloader.NuGetPackageDownloader(packageInstallDir: new DirectoryPath());
var versions = await downloader.GetPackageVersionsAsync(new(packageId), versionFragment, allowPrerelease, cancellationToken: cancellationToken);
return versions;
}
catch (Exception)
{
return Enumerable.Empty<NuGetVersion>();
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions test/TestAssets/TestProjects/NugetCompletion/nuget.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<configuration>
<packageSources>
<!--To inherit the global NuGet package sources remove the <clear/> line below -->
<clear />
<add key="nuget" value="https://api.nuget.org/v3/index.json" />
</packageSources>
</configuration>
93 changes: 93 additions & 0 deletions test/dotnet.Tests/CommandTests/CompleteCommandTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.DotNet.Cli;
using Microsoft.DotNet.Cli.NuGetPackageDownloader;

namespace Microsoft.DotNet.Tests.Commands
{
Expand Down Expand Up @@ -330,6 +331,98 @@ public void GivenDotnetToolInWithPosition()
reporter.Lines.OrderBy(c => c).Should().Equal(expected.OrderBy(c => c));
}

[Fact]
public void CompletesNugetPackageIds()
{
NuGetPackageDownloader.CliCompletionsTimeout = TimeSpan.FromDays(1);
var testAsset = _testAssetsManager.CopyTestAsset("NugetCompletion").WithSource();

string[] expected = ["Newtonsoft.Json"];
var reporter = new BufferedReporter();
var currentDirectory = Directory.GetCurrentDirectory();
try
{
Directory.SetCurrentDirectory(testAsset.Path);
CompleteCommand.RunWithReporter(GetArguments("dotnet add package Newt$"), reporter).Should().Be(0);
reporter.Lines.Should().Contain(expected);
}
finally
{
Directory.SetCurrentDirectory(currentDirectory);
}
}

[Fact]
public void CompletesNugetPackageVersions()
{
NuGetPackageDownloader.CliCompletionsTimeout = TimeSpan.FromDays(1);
var testAsset = _testAssetsManager.CopyTestAsset("NugetCompletion").WithSource();

string knownPackage = "Newtonsoft.Json";
string knownVersion = "13.0.1"; // not exhaustive
var reporter = new BufferedReporter();
var currentDirectory = Directory.GetCurrentDirectory();
try
{
Directory.SetCurrentDirectory(testAsset.Path);
CompleteCommand.RunWithReporter(GetArguments($"dotnet add package {knownPackage} --version $"), reporter).Should().Be(0);
reporter.Lines.Should().Contain(knownVersion);
}
finally
{
Directory.SetCurrentDirectory(currentDirectory);
}
}

[Fact]
public void CompletesNugetPackageVersionsWithStem()
{
NuGetPackageDownloader.CliCompletionsTimeout = TimeSpan.FromDays(1);
var testAsset = _testAssetsManager.CopyTestAsset("NugetCompletion").WithSource();

string knownPackage = "Newtonsoft.Json";
string knownVersion = "13.0"; // not exhaustive
string[] expectedVersions = ["13.0.1", "13.0.2", "13.0.3"]; // not exhaustive
var reporter = new BufferedReporter();
var currentDirectory = Directory.GetCurrentDirectory();
try
{
Directory.SetCurrentDirectory(testAsset.Path);
CompleteCommand.RunWithReporter(GetArguments($"dotnet add package {knownPackage} --version {knownVersion}$"), reporter).Should().Be(0);
reporter.Lines.Should().Contain(expectedVersions);
// by default only stable versions should be shown
reporter.Lines.Should().AllSatisfy(v => v.Should().NotContain("-"));

}
finally
{
Directory.SetCurrentDirectory(currentDirectory);
}
}

[Fact]
public void CompletesNugetPackageVersionsWithPrereleaseVersionsWhenSpecified()
{
NuGetPackageDownloader.CliCompletionsTimeout = TimeSpan.FromDays(1);
var testAsset = _testAssetsManager.CopyTestAsset("NugetCompletion").WithSource();

string knownPackage = "Spectre.Console";
string knownVersion = "0.49.1";
string[] expectedVersions = ["0.49.1", "0.49.1-preview.0.2", "0.49.1-preview.0.5"]; // exhaustive for this specific version
var reporter = new BufferedReporter();
var currentDirectory = Directory.GetCurrentDirectory();
try
{
Directory.SetCurrentDirectory(testAsset.Path);
CompleteCommand.RunWithReporter(GetArguments($"dotnet add package {knownPackage} --prerelease --version {knownVersion}$"), reporter).Should().Be(0);
reporter.Lines.Should().Equal(expectedVersions);
}
finally
{
Directory.SetCurrentDirectory(currentDirectory);
}
}

/// <summary>
/// Converts command annotated with dollar sign($) into string array with "--position" option pointing at dollar sign location.
/// </summary>
Expand Down
33 changes: 0 additions & 33 deletions test/dotnet.Tests/ParserTests/AddReferenceParserTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,38 +57,5 @@ public void AddReferenceWithoutArgumentResultsInAnError()
.Should()
.BeEquivalentTo(string.Format(LocalizableStrings.RequiredArgumentMissingForCommand, "'reference'."));
}

[Fact]
public void EnumerablePackageIdFromQueryResponseResultsPackageIds()
{
using (var stream = new MemoryStream())
using (var writer = new StreamWriter(stream))
{
writer.Write(_nugetResponseSample);
writer.Flush();
stream.Position = 0;

AddPackageParser.EnumerablePackageIdFromQueryResponse(stream)
.Should()
.Contain(
new List<string>
{ "System.Text.Json",
"System.Text.Json.Mobile" });
}
}

private string _nugetResponseSample =
@"{
""@context"": {
""@vocab"": ""http://schema.nuget.org/schema#""
},
""totalHits"": 2,
""lastReopen"": ""2019-03-17T22:25:28.9238936Z"",
""index"": ""v3-lucene2-v2v3-20171018"",
""data"": [
""System.Text.Json"",
""System.Text.Json.Mobile""
]
}";
}
}

0 comments on commit 4a46908

Please sign in to comment.