Skip to content

Commit

Permalink
Improved Copilot resiliency
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkMpn committed Aug 5, 2024
1 parent 29207c8 commit a5820c2
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 5 deletions.
85 changes: 81 additions & 4 deletions MarkMpn.Sql4Cds.XTB/CopilotScriptObject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class CopilotScriptObject
private bool _runningQuery;
private Dictionary<string, string> _pendingQueries;
private List<ToolOutput> _toolOutputs;
private TimeSpan _toolDelay;

internal CopilotScriptObject(SqlQueryControl control, WebView2 copilotWebView)
{
Expand All @@ -53,6 +54,7 @@ internal CopilotScriptObject(SqlQueryControl control, WebView2 copilotWebView)
_control = control;
_copilotWebView = copilotWebView;
_pendingQueries = new Dictionary<string, string>();
_toolDelay = TimeSpan.FromSeconds(0.5);
}

public void Cancel()
Expand Down Expand Up @@ -89,7 +91,7 @@ public async Task<string[]> SendMessage(string request)
Instructions = definition.Instructions,
DefaultTools = definition.Tools,
});

Settings.Instance.AssistantVersion = Assembly.GetExecutingAssembly().GetName().Version.ToString();
SettingsManager.Instance.Save(typeof(PluginControl), Settings.Instance);
}
Expand Down Expand Up @@ -157,6 +159,13 @@ private async Task ExecuteInternal(string id, bool executeAllowed)
_pendingQueries.Remove(id);
}

public async Task<string[]> Retry()
{
var updates = _assistantClient.CreateRunStreamingAsync(_assistantThread, _assistant);
await DoRunAsync(updates);
return null;
}

private async Task DoRunAsync(AsyncResultCollection<StreamingUpdate> updates)
{
Func<Task> promptSuggestion = null;
Expand Down Expand Up @@ -185,7 +194,27 @@ private async Task DoRunAsync(AsyncResultCollection<StreamingUpdate> updates)
else if (update is RequiredActionUpdate func)
{
promptSuggestion = null;
var args = JsonConvert.DeserializeObject<Dictionary<string, string>>(func.FunctionArguments);
Dictionary<string, string> args;

try
{
args = JsonConvert.DeserializeObject<Dictionary<string, string>>(func.FunctionArguments);
}
catch (JsonException)
{
if (func.FunctionName == "execute_query")
{
// Query is sometimes presented directly rather than wrapped in JSON
args = new Dictionary<string, string>
{
["query"] = func.FunctionArguments
};
}
else
{
throw;
}
}

switch (func.FunctionName)
{
Expand All @@ -205,7 +234,7 @@ private async Task DoRunAsync(AsyncResultCollection<StreamingUpdate> updates)
}
else
{
var columns = metadata.Attributes.ToDictionary(a => a.LogicalName, a => new { displayName = a.DisplayName?.UserLocalizedLabel?.Label, description = a.Description?.UserLocalizedLabel?.Label, type = a.AttributeTypeName?.Value, options = (a as EnumAttributeMetadata)?.OptionSet?.Options?.ToDictionary(o => o.Value, o => o.Label?.UserLocalizedLabel?.Label), lookupTo = (a as LookupAttributeMetadata)?.Targets?.Select(target => target + "." + dataSource.Metadata[target].PrimaryIdAttribute)?.ToArray() });
var columns = metadata.Attributes.ToDictionary(a => a.LogicalName, a => new { displayName = a.DisplayName?.UserLocalizedLabel?.Label, description = ShowDescription(a) ? a.Description?.UserLocalizedLabel?.Label : null, type = a.AttributeTypeName?.Value, options = (a as EnumAttributeMetadata)?.OptionSet?.Options?.ToDictionary(o => o.Value, o => o.Label?.UserLocalizedLabel?.Label), lookupTo = (a as LookupAttributeMetadata)?.Targets?.Select(target => target + "." + dataSource.Metadata[target].PrimaryIdAttribute)?.ToArray() });
_toolOutputs.Add(new ToolOutput(func.ToolCallId, JsonConvert.SerializeObject(columns, new JsonSerializerSettings { NullValueHandling = NullValueHandling.Ignore })));
}
break;
Expand Down Expand Up @@ -444,9 +473,22 @@ private async Task DoRunAsync(AsyncResultCollection<StreamingUpdate> updates)
}

if (_toolOutputs.Any() && _pendingQueries.Count == 0 && !_canceled)
{
// Sleep to avoid overloading the API with sequential function calls
await Task.Delay(_toolDelay);

updates = _assistantClient.SubmitToolOutputsToRunStreamingAsync(_run, _toolOutputs);
}
}
} while (_run?.Status.IsTerminal == false);
} while (_run?.Status.IsTerminal == false && !_canceled);

if (_run?.LastError != null)
{
if (_run.LastError.Code == RunErrorCode.RateLimitExceeded)
_toolDelay = TimeSpan.FromSeconds(_toolDelay.TotalSeconds * 2);

await ShowRetryAsync(HttpUtility.HtmlEncode(_run.LastError.Message));
}
}
catch (Exception ex)
{
Expand All @@ -465,6 +507,36 @@ private async Task DoRunAsync(AsyncResultCollection<StreamingUpdate> updates)
}
}

private bool ShowDescription(AttributeMetadata a)
{
// Check if the description should be output to the API. Some common attributes are not useful
// or are already understood by the model and can be skipped to reduce the required tokens
switch (a.LogicalName)
{
case "createdby":
case "createdon":
case "createdonbehalfby":
case "modifiedby":
case "modifiedon":
case "modifiedonbehalfby":
case "owningbusinessunit":
case "owningteam":
case "owninguser":
case "ownerid":
case "transactioncurrencyid":
case "versionnumber":
case "importsequencenumber":
case "overriddencreatedon":
case "statecode":
case "statuscode":
case "timezoneruleversionnumber":
case "utcconversiontimezonecode":
return false;
}

return true;
}

private bool ContainsJoin(string sql)
{
var parsed = new TSql160Parser(Settings.Instance.QuotedIdentifiers).Parse(new StringReader(sql), out _);
Expand Down Expand Up @@ -499,6 +571,11 @@ private async Task ShowExecutePromptAsync(string html, string id)
await _copilotWebView.ExecuteScriptAsync("showExecutePrompt(" + JsonConvert.SerializeObject(html) + "," + JsonConvert.SerializeObject(id) + ")");
}

private async Task ShowRetryAsync(string html)
{
await _copilotWebView.ExecuteScriptAsync("showRetryPrompt(" + JsonConvert.SerializeObject(html) + ")");
}

private async Task RunStarted()
{
await _copilotWebView.ExecuteScriptAsync("runStarted()");
Expand Down
26 changes: 26 additions & 0 deletions MarkMpn.Sql4Cds.XTB/Resources/Copilot.html
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,32 @@
div.scrollIntoView(false);
}

showRetryPrompt = function (title) {
const div = document.createElement('div');
div.className = 'suggestion warning';
const block = document.createElement('div');
block.className = 'block';
block.innerHTML = title;
const icon = document.getElementById('warning').cloneNode(true);
div.appendChild(icon);
div.appendChild(block);

const button = document.createElement('button');
button.textContent = 'Retry';
button.addEventListener('click', async () => {
result.querySelectorAll("button").forEach(b => b.disabled = true);
send.className = "disabled";

showProcessingMessage();

await window.chrome.webview.hostObjects.sql4cds.Retry();
});
block.appendChild(button);

result.appendChild(div);
div.scrollIntoView(false);
}

sendMessage = async function (message) {
result.querySelectorAll("button").forEach(b => b.disabled = true);
send.className = "disabled";
Expand Down
8 changes: 7 additions & 1 deletion MarkMpn.Sql4Cds.XTB/SqlQueryControl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,13 @@ private void ShowResult(IRootExecutionPlanNode query, ExecuteParams args, DataTa
var dict = new Dictionary<string, object>();

for (var i = 0; i < results.Columns.Count; i++)
dict[results.Columns[i].ColumnName] = row[i];
{
// Do not put nulls in the dictionary as the serializer will fail when attempting to read the .Value property
if (row[i] is INullable nullable && nullable.IsNull)
dict[results.Columns[i].ColumnName] = null;
else
dict[results.Columns[i].ColumnName] = row[i];
}

rows.Add(dict);
}
Expand Down

0 comments on commit a5820c2

Please sign in to comment.