Skip to content

Commit

Permalink
feat: Add TrackError to mirror TrackSuccess (#64)
Browse files Browse the repository at this point in the history
Additionally, emit new `$ld:ai:generation:(success|error)` events on
success or failure.
  • Loading branch information
keelerm84 authored Dec 17, 2024
1 parent ac29d46 commit 7acc574
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 16 deletions.
5 changes: 5 additions & 0 deletions pkgs/sdk/server-ai/src/Interfaces/ILdAiConfigTracker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ public interface ILdAiConfigTracker
/// </summary>
public void TrackSuccess();

/// <summary>
/// Tracks an unsuccessful generation event related to this config.
/// </summary>
/// <remarks>
/// Mirrors <see cref="TrackSuccess"/> for the failure path: the provided
/// implementation emits both a specific error event and the generic
/// generation event for this config.
/// </remarks>
public void TrackError();

/// <summary>
/// Tracks a request to a provider. The request is a task that returns a <see cref="Response"/>, which
/// contains information about the request such as token usage and metrics.
Expand Down
52 changes: 36 additions & 16 deletions pkgs/sdk/server-ai/src/LdAiConfigTracker.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using LaunchDarkly.Sdk.Server.Ai.Config;
using LaunchDarkly.Sdk.Server.Ai.Interfaces;
Expand All @@ -21,6 +22,8 @@ public class LdAiConfigTracker : ILdAiConfigTracker
private const string FeedbackPositive = "$ld:ai:feedback:user:positive";
private const string FeedbackNegative = "$ld:ai:feedback:user:negative";
private const string Generation = "$ld:ai:generation";
private const string GenerationSuccess = "$ld:ai:generation:success";
private const string GenerationError = "$ld:ai:generation:error";
private const string TokenTotal = "$ld:ai:tokens:total";
private const string TokenInput = "$ld:ai:tokens:input";
private const string TokenOutput = "$ld:ai:tokens:output";
Expand Down Expand Up @@ -57,18 +60,14 @@ public void TrackDuration(float durationMs) =>

/// <inheritdoc/>
/// <summary>
/// Awaits the given task and tracks how long it took to complete.
/// </summary>
/// <typeparam name="T">result type of the task</typeparam>
/// <param name="task">the task to await and time</param>
/// <returns>the task's result</returns>
public async Task<T> TrackDurationOfTask<T>(Task<T> task)
{
    var sw = Stopwatch.StartNew();
    try
    {
        return await task;
    }
    finally
    {
        // The finally block guarantees the duration is tracked even when the
        // task faults or is cancelled; the original exception still propagates.
        sw.Stop();
        TrackDuration(sw.ElapsedMilliseconds);
    }
}

/// <inheritdoc/>
Expand All @@ -90,23 +89,44 @@ public void TrackFeedback(Feedback feedback)
/// <inheritdoc/>
public void TrackSuccess()
{
    // Emit the specific success metric first, then the generic generation
    // metric, preserving the original emission order.
    foreach (var metric in new[] { GenerationSuccess, Generation })
    {
        _client.Track(metric, _context, _trackData, 1);
    }
}

/// <inheritdoc/>
public void TrackError()
{
    // Local helper keeps the two emissions identical apart from the metric key.
    void Emit(string metric) => _client.Track(metric, _context, _trackData, 1);

    Emit(GenerationError); // specific error counter
    Emit(Generation);      // generic generation counter
}

/// <inheritdoc/>
/// <summary>
/// Awaits a provider request, then tracks success or error, duration, and
/// (on success) token usage for this config.
/// </summary>
/// <param name="request">the provider request to await</param>
/// <returns>the provider's response</returns>
public async Task<Response> TrackRequest(Task<Response> request)
{
    var sw = Stopwatch.StartNew();
    try
    {
        var result = await request;
        TrackSuccess();

        sw.Stop();
        // Prefer the provider-reported latency; fall back to the measured
        // wall-clock time when the response carries no metrics.
        TrackDuration(result.Metrics?.LatencyMs ?? sw.ElapsedMilliseconds);

        if (result.Usage != null)
        {
            TrackTokens(result.Usage.Value);
        }

        return result;
    }
    catch (Exception)
    {
        sw.Stop();
        // No provider metrics are available on failure, so track the measured
        // duration and record the error before propagating the exception.
        TrackDuration(sw.ElapsedMilliseconds);
        TrackError();
        throw;
    }
}

/// <inheritdoc/>
Expand Down
47 changes: 47 additions & 0 deletions pkgs/sdk/server-ai/test/LdAiConfigTrackerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,27 @@ public void CanTrackSuccess()
var tracker = new LdAiConfigTracker(mockClient.Object, flagKey, config, context);
tracker.TrackSuccess();
mockClient.Verify(x => x.Track("$ld:ai:generation", context, data, 1.0f), Times.Once);
mockClient.Verify(x => x.Track("$ld:ai:generation:success", context, data, 1.0f), Times.Once);
}


[Fact]
public void CanTrackError()
{
    // Arrange: a tracker wired to a mocked LaunchDarkly client.
    const string flagKey = "key";
    var mockClient = new Mock<ILaunchDarklyClient>();
    var context = Context.New("key");
    var config = LdAiConfig.Disabled;
    var expectedData = LdValue.ObjectFrom(new Dictionary<string, LdValue>
    {
        ["variationKey"] = LdValue.Of(config.VariationKey),
        ["configKey"] = LdValue.Of(flagKey)
    });
    var tracker = new LdAiConfigTracker(mockClient.Object, flagKey, config, context);

    // Act
    tracker.TrackError();

    // Assert: both the generic and the error-specific events were emitted once.
    mockClient.Verify(x => x.Track("$ld:ai:generation", context, expectedData, 1.0f), Times.Once);
    mockClient.Verify(x => x.Track("$ld:ai:generation:error", context, expectedData, 1.0f), Times.Once);
}


Expand Down Expand Up @@ -189,6 +210,8 @@ public void CanTrackResponseWithSpecificLatency()

var result = tracker.TrackRequest(Task.Run(() => givenResponse));
Assert.Equal(givenResponse, result.Result);
mockClient.Verify(x => x.Track("$ld:ai:generation:success", context, data, 1.0f), Times.Once);
mockClient.Verify(x => x.Track("$ld:ai:generation", context, data, 1.0f), Times.Once);
mockClient.Verify(x => x.Track("$ld:ai:tokens:total", context, data, 1.0f), Times.Once);
mockClient.Verify(x => x.Track("$ld:ai:tokens:input", context, data, 2.0f), Times.Once);
mockClient.Verify(x => x.Track("$ld:ai:tokens:output", context, data, 3.0f), Times.Once);
Expand Down Expand Up @@ -228,5 +251,29 @@ public void CanTrackResponseWithPartialData()
// if latency isn't provided via Statistics, then it is automatically measured.
mockClient.Verify(x => x.Track("$ld:ai:duration:total", context, data, It.IsAny<double>()), Times.Once);
}

[Fact]
public async Task CanTrackExceptionFromResponse()
{
    // Arrange: a tracker wired to a mocked LaunchDarkly client.
    const string flagKey = "key";
    var mockClient = new Mock<ILaunchDarklyClient>();
    var context = Context.New("key");
    var config = LdAiConfig.Disabled;
    var expectedData = LdValue.ObjectFrom(new Dictionary<string, LdValue>
    {
        ["variationKey"] = LdValue.Of(config.VariationKey),
        ["configKey"] = LdValue.Of(flagKey)
    });
    var tracker = new LdAiConfigTracker(mockClient.Object, flagKey, config, context);

    // Act: a faulted request should rethrow after tracking the error.
    var failed = Task.FromException<Response>(new System.Exception("I am an exception"));
    await Assert.ThrowsAsync<System.Exception>(() => tracker.TrackRequest(failed));

    // Assert
    mockClient.Verify(x => x.Track("$ld:ai:generation", context, expectedData, 1.0f), Times.Once);
    mockClient.Verify(x => x.Track("$ld:ai:generation:error", context, expectedData, 1.0f), Times.Once);

    // if latency isn't provided via Statistics, then it is automatically measured.
    mockClient.Verify(x => x.Track("$ld:ai:duration:total", context, expectedData, It.IsAny<double>()), Times.Once);
}
}
}

0 comments on commit 7acc574

Please sign in to comment.