diff --git a/src/PowerShellEditorServices/Language/LanguageService.cs b/src/PowerShellEditorServices/Language/LanguageService.cs index ff07acddd..70979cd2c 100644 --- a/src/PowerShellEditorServices/Language/LanguageService.cs +++ b/src/PowerShellEditorServices/Language/LanguageService.cs @@ -26,19 +26,28 @@ public class LanguageService { #region Private Fields - private ILogger logger; - private bool areAliasesLoaded; - private PowerShellContext powerShellContext; - private CompletionResults mostRecentCompletions; - private int mostRecentRequestLine; - private int mostRecentRequestOffest; - private string mostRecentRequestFile; - private Dictionary> CmdletToAliasDictionary; - private Dictionary AliasToCmdletDictionary; - private IDocumentSymbolProvider[] documentSymbolProviders; - const int DefaultWaitTimeoutMilliseconds = 5000; + private readonly ILogger _logger; + + private readonly PowerShellContext _powerShellContext; + + private readonly Dictionary> _cmdletToAliasDictionary; + + private readonly Dictionary _aliasToCmdletDictionary; + + private readonly IDocumentSymbolProvider[] _documentSymbolProviders; + + private bool _areAliasesLoaded; + + private CompletionResults _mostRecentCompletions; + + private int _mostRecentRequestLine; + + private int _mostRecentRequestOffest; + + private string _mostRecentRequestFile; + #endregion #region Constructors @@ -55,14 +64,14 @@ public LanguageService( PowerShellContext powerShellContext, ILogger logger) { - Validate.IsNotNull("powerShellContext", powerShellContext); + Validate.IsNotNull(nameof(powerShellContext), powerShellContext); - this.powerShellContext = powerShellContext; - this.logger = logger; + _powerShellContext = powerShellContext; + _logger = logger; - this.CmdletToAliasDictionary = new Dictionary>(StringComparer.OrdinalIgnoreCase); - this.AliasToCmdletDictionary = new Dictionary(StringComparer.OrdinalIgnoreCase); - this.documentSymbolProviders = new IDocumentSymbolProvider[] + _cmdletToAliasDictionary = new Dictionary>(StringComparer.OrdinalIgnoreCase); + _aliasToCmdletDictionary = new Dictionary(StringComparer.OrdinalIgnoreCase); + _documentSymbolProviders = new IDocumentSymbolProvider[] { new ScriptDocumentSymbolProvider(powerShellContext.LocalPowerShellVersion.Version), new PsdDocumentSymbolProvider(), @@ -95,7 +104,7 @@ public async Task GetCompletionsInFile( int lineNumber, int columnNumber) { - Validate.IsNotNull("scriptFile", scriptFile); + Validate.IsNotNull(nameof(scriptFile), scriptFile); // Get the offset at the specified position. This method // will also validate the given position. @@ -109,39 +118,40 @@ await AstOperations.GetCompletions( scriptFile.ScriptAst, scriptFile.ScriptTokens, fileOffset, - this.powerShellContext, - this.logger, + _powerShellContext, + _logger, new CancellationTokenSource(DefaultWaitTimeoutMilliseconds).Token); - if (commandCompletion != null) + if (commandCompletion == null) { - try - { - CompletionResults completionResults = - CompletionResults.Create( - scriptFile, - commandCompletion); - - // save state of most recent completion - mostRecentCompletions = completionResults; - mostRecentRequestFile = scriptFile.Id; - mostRecentRequestLine = lineNumber; - mostRecentRequestOffest = columnNumber; - - return completionResults; - } - catch (ArgumentException e) - { - // Bad completion results could return an invalid - // replacement range, catch that here - this.logger.Write( - LogLevel.Error, - $"Caught exception while trying to create CompletionResults:\n\n{e.ToString()}"); - } + return new CompletionResults(); } - // If all else fails, return empty results - return new CompletionResults(); + try + { + CompletionResults completionResults = + CompletionResults.Create( + scriptFile, + commandCompletion); + + // save state of most recent completion + _mostRecentCompletions = completionResults; + _mostRecentRequestFile = scriptFile.Id; + _mostRecentRequestLine = lineNumber; + _mostRecentRequestOffest = columnNumber; + + return completionResults; + } + catch (ArgumentException e) + { + // Bad completion results could return an invalid + // replacement range, catch that here + _logger.Write( + LogLevel.Error, + $"Caught exception while trying to create CompletionResults:\n\n{e.ToString()}"); + + return new CompletionResults(); + } } /// @@ -154,19 +164,21 @@ public CompletionDetails GetCompletionDetailsInFile( ScriptFile file, string entryName) { - // Makes sure the most recent completions request was the same line and column as this request - if (file.Id.Equals(mostRecentRequestFile)) + if (!file.Id.Equals(_mostRecentRequestFile)) { - CompletionDetails completionResult = - mostRecentCompletions.Completions.FirstOrDefault( - result => result.CompletionText.Equals(entryName)); - - return completionResult; + return null; } - else + + foreach (CompletionDetails completion in _mostRecentCompletions.Completions) { - return null; + if (completion.CompletionText.Equals(entryName)) + { + return completion; + } } + + // If we found no completions, return null + return null; } /// @@ -245,20 +257,18 @@ public async Task FindSymbolDetailsAtLocation( lineNumber, columnNumber); - if (symbolReference != null) - { - symbolReference.FilePath = scriptFile.FilePath; - symbolDetails = - await SymbolDetails.Create( - symbolReference, - this.powerShellContext); - } - else + if (symbolReference == null) { // TODO #21: Return Result return null; } + symbolReference.FilePath = scriptFile.FilePath; + symbolDetails = + await SymbolDetails.Create( + symbolReference, + _powerShellContext); + return symbolDetails; } @@ -269,18 +279,22 @@ await SymbolDetails.Create( /// public FindOccurrencesResult FindSymbolsInFile(ScriptFile scriptFile) { - Validate.IsNotNull("scriptFile", scriptFile); + Validate.IsNotNull(nameof(scriptFile), scriptFile); + + var foundOccurrences = new List(); + foreach (IDocumentSymbolProvider symbolProvider in _documentSymbolProviders) + { + foreach (SymbolReference reference in symbolProvider.ProvideDocumentSymbols(scriptFile)) + { + reference.SourceLine = scriptFile.GetLine(reference.ScriptRegion.StartLineNumber); + reference.FilePath = scriptFile.FilePath; + foundOccurrences.Add(reference); + } + } + return new FindOccurrencesResult { - FoundOccurrences = documentSymbolProviders - .SelectMany(p => p.ProvideDocumentSymbols(scriptFile)) - .Select(reference => - { - reference.SourceLine = - scriptFile.GetLine(reference.ScriptRegion.StartLineNumber); - reference.FilePath = scriptFile.FilePath; - return reference; - }) + FoundOccurrences = foundOccurrences }; } @@ -346,30 +360,27 @@ public async Task FindReferencesOfSymbol( foreach (object fileName in fileMap.Keys) { var file = (ScriptFile)fileMap[fileName]; - IEnumerable symbolReferencesinFile = - AstOperations - .FindReferencesOfSymbol( - file.ScriptAst, - foundSymbol, - CmdletToAliasDictionary, - AliasToCmdletDictionary) - .Select(reference => - { - try - { - reference.SourceLine = file.GetLine(reference.ScriptRegion.StartLineNumber); - } - catch (ArgumentOutOfRangeException e) - { - reference.SourceLine = string.Empty; - this.logger.WriteException("Found reference is out of range in script file", e); - } - - reference.FilePath = file.FilePath; - return reference; - }); - - symbolReferences.AddRange(symbolReferencesinFile); + + IEnumerable references = AstOperations.FindReferencesOfSymbol( + file.ScriptAst, + foundSymbol, + _cmdletToAliasDictionary, + _aliasToCmdletDictionary); + + foreach (SymbolReference reference in references) + { + try + { + reference.SourceLine = file.GetLine(reference.ScriptRegion.StartLineNumber); + } + catch (ArgumentOutOfRangeException e) + { + reference.SourceLine = string.Empty; + _logger.WriteException("Found reference is out of range in script file", e); + } + reference.FilePath = file.FilePath; + symbolReferences.Add(reference); + } } return new FindReferencesResult @@ -393,9 +404,9 @@ public async Task GetDefinitionOfSymbol( SymbolReference foundSymbol, Workspace workspace) { - Validate.IsNotNull("sourceFile", sourceFile); - Validate.IsNotNull("foundSymbol", foundSymbol); - Validate.IsNotNull("workspace", workspace); + Validate.IsNotNull(nameof(sourceFile), sourceFile); + Validate.IsNotNull(nameof(foundSymbol), foundSymbol); + Validate.IsNotNull(nameof(workspace), workspace); ScriptFile[] referencedFiles = workspace.ExpandScriptReferences( @@ -426,8 +437,8 @@ public async Task GetDefinitionOfSymbol( if (foundDefinition == null) { // Get a list of all powershell files in the workspace path - var allFiles = workspace.EnumeratePSFiles(); - foreach (var file in allFiles) + IEnumerable allFiles = workspace.EnumeratePSFiles(); + foreach (string file in allFiles) { if (filesSearched.Contains(file)) { @@ -458,7 +469,7 @@ public async Task GetDefinitionOfSymbol( CommandInfo cmdInfo = await CommandHelpers.GetCommandInfo( foundSymbol.SymbolName, - this.powerShellContext); + _powerShellContext); foundDefinition = FindDeclarationForBuiltinCommand( @@ -490,26 +501,23 @@ public FindOccurrencesResult FindOccurrencesInFile( lineNumber, columnNumber); - if (foundSymbol != null) - { - // find all references, and indicate that looking for aliases is not needed - IEnumerable symbolOccurrences = - AstOperations - .FindReferencesOfSymbol( - file.ScriptAst, - foundSymbol, - false); - - return - new FindOccurrencesResult - { - FoundOccurrences = symbolOccurrences - }; - } - else + if (foundSymbol == null) { return null; } + + // find all references, and indicate that looking for aliases is not needed + IEnumerable symbolOccurrences = + AstOperations + .FindReferencesOfSymbol( + file.ScriptAst, + foundSymbol, + false); + + return new FindOccurrencesResult + { + FoundOccurrences = symbolOccurrences + }; } /// @@ -530,44 +538,40 @@ public async Task FindParameterSetsInFile( lineNumber, columnNumber); - if (foundSymbol != null) + if (foundSymbol == null) { - CommandInfo commandInfo = - await CommandHelpers.GetCommandInfo( - foundSymbol.SymbolName, - this.powerShellContext); + return null; + } - if (commandInfo != null) - { - try - { - IEnumerable commandParamSets = commandInfo.ParameterSets; - return new ParameterSetSignatures(commandParamSets, foundSymbol); - } - catch (RuntimeException e) - { - // A RuntimeException will be thrown when an invalid attribute is - // on a parameter binding block and then that command/script has - // its signatures resolved by typing it into a script. - this.logger.WriteException("RuntimeException encountered while accessing command parameter sets", e); + CommandInfo commandInfo = + await CommandHelpers.GetCommandInfo( + foundSymbol.SymbolName, + _powerShellContext); - return null; - } - catch (InvalidOperationException) - { - // For some commands there are no paramsets (like applications). Until - // the valid command types are better understood, catch this exception - // which gets raised when there are no ParameterSets for the command type. - return null; - } - } - else - { - return null; - } + if (commandInfo == null) + { + return null; } - else + + try { + IEnumerable commandParamSets = commandInfo.ParameterSets; + return new ParameterSetSignatures(commandParamSets, foundSymbol); + } + catch (RuntimeException e) + { + // A RuntimeException will be thrown when an invalid attribute is + // on a parameter binding block and then that command/script has + // its signatures resolved by typing it into a script. + _logger.WriteException("RuntimeException encountered while accessing command parameter sets", e); + + return null; + } + catch (InvalidOperationException) + { + // For some commands there are no paramsets (like applications). Until + // the valid command types are better understood, catch this exception + // which gets raised when there are no ParameterSets for the command type. return null; } } @@ -585,7 +589,7 @@ public ScriptRegion FindSmallestStatementAstRegion( int lineNumber, int columnNumber) { - var ast = FindSmallestStatementAst(scriptFile, lineNumber, columnNumber); + Ast ast = FindSmallestStatementAst(scriptFile, lineNumber, columnNumber); if (ast == null) { return null; @@ -604,7 +608,7 @@ public FunctionDefinitionAst GetFunctionDefinitionAtLine( ScriptFile scriptFile, int lineNumber) { - var functionDefinitionAst = scriptFile.ScriptAst.Find( + Ast functionDefinitionAst = scriptFile.ScriptAst.Find( ast => ast is FunctionDefinitionAst && ast.Extent.StartLineNumber == lineNumber, true); @@ -624,7 +628,7 @@ public FunctionDefinitionAst GetFunctionDefinitionForHelpComment( out string helpLocation) { // check if the next line contains a function definition - var funcDefnAst = GetFunctionDefinitionAtLine(scriptFile, lineNumber + 1); + FunctionDefinitionAst funcDefnAst = GetFunctionDefinitionAtLine(scriptFile, lineNumber + 1); if (funcDefnAst != null) { helpLocation = "before"; @@ -632,7 +636,7 @@ public FunctionDefinitionAst GetFunctionDefinitionForHelpComment( } // find all the script definitions that contain the line `lineNumber` - var foundAsts = scriptFile.ScriptAst.FindAll( + IEnumerable foundAsts = scriptFile.ScriptAst.FindAll( ast => { var fdAst = ast as FunctionDefinitionAst; @@ -646,34 +650,43 @@ public FunctionDefinitionAst GetFunctionDefinitionForHelpComment( }, true); - if (foundAsts != null && foundAsts.Any()) + if (foundAsts == null || !foundAsts.Any()) { - // of all the function definitions found, return the innermost function - // definition that contains `lineNumber` - funcDefnAst = foundAsts.Cast().Aggregate((x, y) => - { - if (x.Extent.StartOffset >= y.Extent.StartOffset && x.Extent.EndOffset <= x.Extent.EndOffset) - { - return x; - } - - return y; - }); + helpLocation = null; + return null; + } - // TODO use tokens to check for non empty character instead of just checking for line offset - if (funcDefnAst.Body.Extent.StartLineNumber == lineNumber - 1) + // of all the function definitions found, return the innermost function + // definition that contains `lineNumber` + foreach (FunctionDefinitionAst foundAst in foundAsts.Cast()) + { + if (funcDefnAst == null) { - helpLocation = "begin"; - return funcDefnAst; + funcDefnAst = foundAst; + continue; } - if (funcDefnAst.Body.Extent.EndLineNumber == lineNumber + 1) + if (funcDefnAst.Extent.StartOffset >= foundAst.Extent.StartOffset + && funcDefnAst.Extent.EndOffset <= foundAst.Extent.EndOffset) { - helpLocation = "end"; - return funcDefnAst; + funcDefnAst = foundAst; } } + // TODO use tokens to check for non empty character instead of just checking for line offset + if (funcDefnAst.Body.Extent.StartLineNumber == lineNumber - 1) + { + helpLocation = "begin"; + return funcDefnAst; + } + + if (funcDefnAst.Body.Extent.EndLineNumber == lineNumber + 1) + { + helpLocation = "end"; + return funcDefnAst; + } + + // If we didn't find a function definition, then return null helpLocation = null; return null; } @@ -687,48 +700,50 @@ public FunctionDefinitionAst GetFunctionDefinitionForHelpComment( /// private async Task GetAliases() { - if (!this.areAliasesLoaded) + if (_areAliasesLoaded) { - try - { - RunspaceHandle runspaceHandle = - await this.powerShellContext.GetRunspaceHandle( - new CancellationTokenSource(DefaultWaitTimeoutMilliseconds).Token); + return; + } + + try + { + RunspaceHandle runspaceHandle = + await _powerShellContext.GetRunspaceHandle( + new CancellationTokenSource(DefaultWaitTimeoutMilliseconds).Token); - CommandInvocationIntrinsics invokeCommand = runspaceHandle.Runspace.SessionStateProxy.InvokeCommand; - IEnumerable aliases = invokeCommand.GetCommands("*", CommandTypes.Alias, true); + CommandInvocationIntrinsics invokeCommand = runspaceHandle.Runspace.SessionStateProxy.InvokeCommand; + IEnumerable aliases = invokeCommand.GetCommands("*", CommandTypes.Alias, true); - runspaceHandle.Dispose(); + runspaceHandle.Dispose(); - foreach (AliasInfo aliasInfo in aliases) + foreach (AliasInfo aliasInfo in aliases) + { + if (!_cmdletToAliasDictionary.ContainsKey(aliasInfo.Definition)) + { + _cmdletToAliasDictionary.Add(aliasInfo.Definition, new List{ aliasInfo.Name }); + } + else { - if (!CmdletToAliasDictionary.ContainsKey(aliasInfo.Definition)) - { - CmdletToAliasDictionary.Add(aliasInfo.Definition, new List() { aliasInfo.Name }); - } - else - { - CmdletToAliasDictionary[aliasInfo.Definition].Add(aliasInfo.Name); - } - - AliasToCmdletDictionary.Add(aliasInfo.Name, aliasInfo.Definition); + _cmdletToAliasDictionary[aliasInfo.Definition].Add(aliasInfo.Name); } - this.areAliasesLoaded = true; + _aliasToCmdletDictionary.Add(aliasInfo.Name, aliasInfo.Definition); } - catch (PSNotSupportedException e) - { - this.logger.Write( - LogLevel.Warning, - $"Caught PSNotSupportedException while attempting to get aliases from remote session:\n\n{e.ToString()}"); - // Prevent the aliases from being fetched again - no point if the remote doesn't support InvokeCommand. - this.areAliasesLoaded = true; - } - catch (TaskCanceledException) - { - // The wait for a RunspaceHandle has timed out, skip aliases for now - } + _areAliasesLoaded = true; + } + catch (PSNotSupportedException e) + { + _logger.Write( + LogLevel.Warning, + $"Caught PSNotSupportedException while attempting to get aliases from remote session:\n\n{e.ToString()}"); + + // Prevent the aliases from being fetched again - no point if the remote doesn't support InvokeCommand. + _areAliasesLoaded = true; + } + catch (TaskCanceledException) + { + // The wait for a RunspaceHandle has timed out, skip aliases for now } } @@ -736,39 +751,38 @@ private ScriptFile[] GetBuiltinCommandScriptFiles( PSModuleInfo moduleInfo, Workspace workspace) { - // if there is module info for this command - if (moduleInfo != null) + if (moduleInfo == null) { - string modPath = moduleInfo.Path; - List scriptFiles = new List(); - ScriptFile newFile; + return new ScriptFile[0]; + } - // find any files where the moduleInfo's path ends with ps1 or psm1 - // and add it to allowed script files - if (modPath.EndsWith(@".ps1") || modPath.EndsWith(@".psm1")) - { - newFile = workspace.GetFile(modPath); - newFile.IsAnalysisEnabled = false; - scriptFiles.Add(newFile); - } - if (moduleInfo.NestedModules.Count > 0) + string modPath = moduleInfo.Path; + List scriptFiles = new List(); + ScriptFile newFile; + + // find any files where the moduleInfo's path ends with ps1 or psm1 + // and add it to allowed script files + if (modPath.EndsWith(@".ps1") || modPath.EndsWith(@".psm1")) + { + newFile = workspace.GetFile(modPath); + newFile.IsAnalysisEnabled = false; + scriptFiles.Add(newFile); + } + if (moduleInfo.NestedModules.Count > 0) + { + foreach (PSModuleInfo nestedInfo in moduleInfo.NestedModules) { - foreach (PSModuleInfo nestedInfo in moduleInfo.NestedModules) + string nestedModPath = nestedInfo.Path; + if (nestedModPath.EndsWith(@".ps1") || nestedModPath.EndsWith(@".psm1")) { - string nestedModPath = nestedInfo.Path; - if (nestedModPath.EndsWith(@".ps1") || nestedModPath.EndsWith(@".psm1")) - { - newFile = workspace.GetFile(nestedModPath); - newFile.IsAnalysisEnabled = false; - scriptFiles.Add(newFile); - } + newFile = workspace.GetFile(nestedModPath); + newFile.IsAnalysisEnabled = false; + scriptFiles.Add(newFile); } } - - return scriptFiles.ToArray(); } - return new List().ToArray(); + return scriptFiles.ToArray(); } private SymbolReference FindDeclarationForBuiltinCommand( @@ -776,30 +790,27 @@ private SymbolReference FindDeclarationForBuiltinCommand( SymbolReference foundSymbol, Workspace workspace) { - SymbolReference foundDefinition = null; - if (commandInfo != null) + if (commandInfo == null) { - int index = 0; - ScriptFile[] nestedModuleFiles; - - nestedModuleFiles = - GetBuiltinCommandScriptFiles( - commandInfo.Module, - workspace); + return null; + } - while (foundDefinition == null && index < nestedModuleFiles.Length) - { - foundDefinition = - AstOperations.FindDefinitionOfSymbol( - nestedModuleFiles[index].ScriptAst, - foundSymbol); + ScriptFile[] nestedModuleFiles = + GetBuiltinCommandScriptFiles( + commandInfo.Module, + workspace); - if (foundDefinition != null) - { - foundDefinition.FilePath = nestedModuleFiles[index].FilePath; - } + SymbolReference foundDefinition = null; + foreach (ScriptFile nestedModuleFile in nestedModuleFiles) + { + foundDefinition = AstOperations.FindDefinitionOfSymbol( + nestedModuleFile.ScriptAst, + foundSymbol); - index++; + if (foundDefinition != null) + { + foundDefinition.FilePath = nestedModuleFile.FilePath; + break; } } @@ -808,13 +819,22 @@ private SymbolReference FindDeclarationForBuiltinCommand( private Ast FindSmallestStatementAst(ScriptFile scriptFile, int lineNumber, int columnNumber) { - var asts = scriptFile.ScriptAst.FindAll(ast => + IEnumerable asts = scriptFile.ScriptAst.FindAll(ast => { return ast is StatementAst && ast.Extent.Contains(lineNumber, columnNumber); }, true); - // Find ast with the smallest extent - return asts.MinElement((astX, astY) => astX.Extent.ExtentWidthComparer(astY.Extent)); + // Find the Ast with the smallest extent + Ast minAst = scriptFile.ScriptAst; + foreach (Ast ast in asts) + { + if (ast.Extent.ExtentWidthComparer(minAst.Extent) == -1) + { + minAst = ast; + } + } + + return minAst; } #endregion