Skip to content

Commit

Permalink
Refactor OpenAI helpers (#46956)
Browse files Browse the repository at this point in the history
  • Loading branch information
christothes authored Nov 4, 2024
1 parent 1bb0be0 commit b6c3306
Show file tree
Hide file tree
Showing 21 changed files with 929 additions and 359 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,57 @@ public static partial class AzureOpenAIExtensions
public static OpenAI.Embeddings.EmbeddingClient GetOpenAIEmbeddingsClient(this Azure.Core.ClientWorkspace workspace) { throw null; }
}
}
namespace Azure.CloudMachine.OpenAI.Chat
{
public partial class ChatTools
{
public ChatTools(params System.Type[] tools) { }
public System.Collections.Generic.IList<OpenAI.Chat.ChatTool> Definitions { get { throw null; } }
public void Add(System.Reflection.MethodInfo function) { }
public void Add(System.Type functions) { }
public string Call(OpenAI.Chat.ChatToolCall call) { throw null; }
public string Call(string name, object[] arguments) { throw null; }
public System.Collections.Generic.IEnumerable<OpenAI.Chat.ToolChatMessage> CallAll(System.Collections.Generic.IEnumerable<OpenAI.Chat.ChatToolCall> toolCalls) { throw null; }
protected string ClrToJsonTypeUtf16(System.Type clrType) { throw null; }
protected System.ReadOnlySpan<byte> ClrToJsonTypeUtf8(System.Type clrType) { throw null; }
protected virtual string GetMethodInfoToDescription(System.Reflection.MethodInfo function) { throw null; }
protected virtual string GetMethodInfoToName(System.Reflection.MethodInfo function) { throw null; }
protected virtual string GetParameterInfoToDescription(System.Reflection.ParameterInfo parameter) { throw null; }
}
}
namespace Azure.CloudMachine.OpenAI.Embeddings
{
public partial class EmbeddingsVectorbase
{
public EmbeddingsVectorbase(OpenAI.Embeddings.EmbeddingClient client, Azure.CloudMachine.OpenAI.Embeddings.VectorbaseStore store = null, int factChunkSize = 0) { }
public void Add(string text) { }
public System.Collections.Generic.IEnumerable<Azure.CloudMachine.OpenAI.Embeddings.VectorbaseEntry> Find(string text, Azure.CloudMachine.OpenAI.Embeddings.FindOptions options = null) { throw null; }
}
public partial class FindOptions
{
public FindOptions() { }
public int MaxEntries { get { throw null; } set { } }
public float Threshold { get { throw null; } set { } }
}
[System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)]
public readonly partial struct VectorbaseEntry
{
private readonly object _dummy;
private readonly int _dummyPrimitive;
public VectorbaseEntry(System.ReadOnlyMemory<float> vector, System.BinaryData data, int? id = default(int?)) { throw null; }
public System.BinaryData Data { get { throw null; } }
public int? Id { get { throw null; } }
public System.ReadOnlyMemory<float> Vector { get { throw null; } }
}
public abstract partial class VectorbaseStore
{
protected VectorbaseStore() { }
public abstract int Add(Azure.CloudMachine.OpenAI.Embeddings.VectorbaseEntry entry);
public abstract void Add(System.Collections.Generic.IReadOnlyList<Azure.CloudMachine.OpenAI.Embeddings.VectorbaseEntry> entry);
public static float CosineSimilarity(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y) { throw null; }
public abstract System.Collections.Generic.IEnumerable<Azure.CloudMachine.OpenAI.Embeddings.VectorbaseEntry> Find(System.ReadOnlyMemory<float> vector, Azure.CloudMachine.OpenAI.Embeddings.FindOptions options);
}
}
namespace Azure.Core
{
public partial class ClientCache
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

<!-- Disable warning CS1591: Missing XML comment for publicly visible type or member -->
<NoWarn>CS1591</NoWarn>
<NoWarn>OPENAI001</NoWarn>
</PropertyGroup>

<ItemGroup>
Expand All @@ -19,6 +20,7 @@
<PackageReference Include="Azure.Storage.Blobs" />
<PackageReference Include="Azure.Security.KeyVault.Secrets" />
<PackageReference Include="Microsoft.Extensions.Configuration.Abstractions" VersionOverride="8.0.0" />
<PackageReference Include="Microsoft.Bcl.Numerics" />
</ItemGroup>

</Project>
10 changes: 10 additions & 0 deletions sdk/cloudmachine/Azure.CloudMachine/src/ClientCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,20 @@
namespace Azure.Core;

// TODO: this is a very demo implementation. We need to do better
/// <summary>
/// The client cache.
/// </summary>
public class ClientCache
{
private readonly Dictionary<(Type, string), object> _clients = new Dictionary<(Type, string), object>();

/// <summary>
/// Gets a client from the cache.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="value"></param>
/// <param name="id"></param>
/// <returns></returns>
public T Get<T>(Func<T> value, string id = default) where T: class
{
var client = (typeof(T), id);
Expand Down
64 changes: 64 additions & 0 deletions sdk/cloudmachine/Azure.CloudMachine/src/ClientWorkspace.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,109 @@

namespace Azure.Core;

/// <summary>
/// Retrieves the connection options for a specified client type and instance ID.
/// Represents a workspace for client operations.
/// </summary>
public abstract class ClientWorkspace
{
/// <summary>
/// Retrieves the connection options for a specified client type and instance ID.
/// </summary>
/// <param name="clientType">The type of the client.</param>
/// <param name="instanceId">The instance ID of the client.</param>
/// <returns>The connection options for the specified client type and instance ID.</returns>
public abstract ClientConnectionOptions GetConnectionOptions(Type clientType, string instanceId = default);

/// <summary>
/// Gets the cache of subclients.
/// </summary>
[EditorBrowsable(EditorBrowsableState.Never)]
public ClientCache Subclients { get; } = new ClientCache();
}

/// <summary>
/// Represents the connection options for a client.
/// </summary>
public readonly struct ClientConnectionOptions
{
/// <summary>
/// Initializes a new instance of the <see cref="ClientConnectionOptions"/> struct with the specified endpoint and API key.
/// </summary>
/// <param name="endpoint">The endpoint URI.</param>
/// <param name="apiKey">The API key credential.</param>
public ClientConnectionOptions(Uri endpoint, string apiKey)
{
Endpoint = endpoint;
ApiKeyCredential = apiKey;
ConnectionKind = ClientConnectionKind.ApiKey;
}

/// <summary>
/// Initializes a new instance of the <see cref="ClientConnectionOptions"/> struct with the specified endpoint and token credential.
/// </summary>
/// <param name="endpoint">The endpoint URI.</param>
/// <param name="credential">The token credential.</param>
public ClientConnectionOptions(Uri endpoint, TokenCredential credential)
{
Endpoint = endpoint;
TokenCredential = credential;
ConnectionKind = ClientConnectionKind.EntraId;
}

/// <summary>
/// Initializes a new instance of the <see cref="ClientConnectionOptions"/> struct with the specified subclient ID.
/// </summary>
/// <param name="subclientId">The subclient ID.</param>
public ClientConnectionOptions(string subclientId)
{
Id = subclientId;
ConnectionKind = ClientConnectionKind.OutOfBand;
}

/// <summary>
/// Gets the kind of connection used by the client.
/// </summary>
public ClientConnectionKind ConnectionKind { get; }

/// <summary>
/// Gets the endpoint URI.
/// </summary>
public Uri Endpoint { get; }

/// <summary>
/// Gets the subclient ID.
/// </summary>
public string Id { get; }

/// <summary>
/// Gets the API key credential.
/// </summary>
public string ApiKeyCredential { get; }

/// <summary>
/// Gets the token credential.
/// </summary>
public TokenCredential TokenCredential { get; }
}

/// <summary>
/// Specifies the kind of connection used by the client.
/// </summary>
public enum ClientConnectionKind
{
/// <summary>
/// Represents a connection using Entra ID.
/// </summary>
EntraId,

/// <summary>
/// Represents a connection using an API key.
/// </summary>
ApiKey,

/// <summary>
/// Represents a connection using an out-of-band method.
/// </summary>
OutOfBand
}
18 changes: 18 additions & 0 deletions sdk/cloudmachine/Azure.CloudMachine/src/CloudMachineClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,25 @@

namespace Azure.CloudMachine;

/// <summary>
/// The cloud machine client.
/// </summary>
public partial class CloudMachineClient : CloudMachineWorkspace
{
/// <summary>
/// Initializes a new instance of the <see cref="CloudMachineClient"/> class for mocking purposes..
/// </summary>
protected CloudMachineClient()
{
Messaging = new MessagingServices(this);
Storage = new StorageServices(this);
}
#pragma warning disable AZC0007 // DO provide a minimal constructor that takes only the parameters required to connect to the service.
/// <summary>
/// Initializes a new instance of the <see cref="CloudMachineClient"/> class.
/// </summary>
/// <param name="credential">The token credential.</param>
/// <param name="configuration">The configuration settings.</param>
public CloudMachineClient(TokenCredential credential = default, IConfiguration configuration = default)
#pragma warning restore AZC0007 // DO provide a minimal constructor that takes only the parameters required to connect to the service.
: base(credential, configuration)
Expand All @@ -22,6 +33,13 @@ public CloudMachineClient(TokenCredential credential = default, IConfiguration c
Storage = new StorageServices(this);
}

/// <summary>
/// Gets the messaging services.
/// </summary>
public MessagingServices Messaging { get; }

/// <summary>
/// Gets the storage services.
/// </summary>
public StorageServices Storage { get; }
}
24 changes: 24 additions & 0 deletions sdk/cloudmachine/Azure.CloudMachine/src/CloudMachineWorkspace.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,28 @@

namespace Azure.CloudMachine;

/// <summary>
/// The cloud machine workspace.
/// </summary>
public class CloudMachineWorkspace : ClientWorkspace
{
private TokenCredential Credential { get; } = new ChainedTokenCredential(
new AzureCliCredential(),
new AzureDeveloperCliCredential()
);

/// <summary>
/// The cloud machine ID.
/// </summary>
[EditorBrowsable(EditorBrowsableState.Never)]
public string Id { get; }

/// <summary>
/// Initializes a new instance of the <see cref="CloudMachineWorkspace"/> class.
/// </summary>
/// <param name="credential"></param>
/// <param name="configuration"></param>
/// <exception cref="Exception"></exception>
[SuppressMessage("Usage", "AZC0007:DO provide a minimal constructor that takes only the parameters required to connect to the service.", Justification = "<Pending>")]
public CloudMachineWorkspace(TokenCredential credential = default, IConfiguration configuration = default)
{
Expand All @@ -46,6 +58,13 @@ public CloudMachineWorkspace(TokenCredential credential = default, IConfiguratio
Id = cmid!;
}

/// <summary>
/// Retrieves the connection options for a specified client type and instance ID.
/// </summary>
/// <param name="clientType"></param>
/// <param name="instanceId"></param>
/// <returns></returns>
/// <exception cref="Exception"></exception>
[EditorBrowsable(EditorBrowsableState.Never)]
public override ClientConnectionOptions GetConnectionOptions(Type clientType, string instanceId = default)
{
Expand Down Expand Up @@ -75,10 +94,15 @@ public override ClientConnectionOptions GetConnectionOptions(Type clientType, st
}
}

/// <inheritdoc/>
[EditorBrowsable(EditorBrowsableState.Never)]
public override bool Equals(object obj) => base.Equals(obj);

/// <inheritdoc/>
[EditorBrowsable(EditorBrowsableState.Never)]
public override int GetHashCode() => base.GetHashCode();

/// <inheritdoc/>
[EditorBrowsable(EditorBrowsableState.Never)]
public override string ToString() => Id;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,18 @@

namespace Azure.CloudMachine;

/// <summary>
/// The messaging services for the cloud machine.
/// </summary>
public readonly struct MessagingServices
{
private readonly CloudMachineClient _cm;
internal MessagingServices(CloudMachineClient cm) => _cm = cm;

/// <summary>
/// Sends a message to the service bus.
/// </summary>
/// <param name="serializable"></param>
public void SendMessage(object serializable)
{
ServiceBusSender sender = GetServiceBusSender();
Expand All @@ -24,6 +31,10 @@ public void SendMessage(object serializable)
#pragma warning restore AZC0102 // Do not use GetAwaiter().GetResult().
}

/// <summary>
/// Adds a function to be called when a message is received.
/// </summary>
/// <param name="received"></param>
public void WhenMessageReceived(Action<string> received)
{
var processor = _cm.Messaging.GetServiceBusProcessor();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@

namespace Azure.CloudMachine;

/// <summary>
/// The storage file for the cloud machine.
/// </summary>
public class StorageFile
{
private readonly Response _response;

private StorageServices _storage;

/// <summary>
/// The path of the file in the storage account.
/// </summary>
public string Path { get; internal set; }

/// <summary>
Expand All @@ -26,14 +33,24 @@ public class StorageFile
/// <remarks>returns null if the file is not created as a return value of a service method call.</remarks>
public static implicit operator Response(StorageFile result) => result._response;

/// <summary>
/// The cancellation token for the storage operation.
/// </summary>
public CancellationToken CancellationToken { get; internal set; }

/// <summary>
/// Downloads the file from the storage account.
/// </summary>
/// <returns></returns>
public BinaryData Download()
=> _storage.DownloadBlob(Path);

// public async Task<BinaryData> DownloadAsync()
// => await _storage.DownloadBlobAsync(Path).ConfigureAwait(false);

/// <summary>
/// Deletes the file from the storage account.
/// </summary>
public void Delete()
=> _storage.DeleteBlob(Path);

Expand All @@ -54,12 +71,15 @@ internal StorageFile(StorageServices storage, string path, string requestId, Res
_response = response;
}

/// <inheritdoc />
[EditorBrowsable(EditorBrowsableState.Never)]
public override bool Equals(object obj) => base.Equals(obj);

/// <inheritdoc />
[EditorBrowsable(EditorBrowsableState.Never)]
public override int GetHashCode() => base.GetHashCode();

/// <inheritdoc />
[EditorBrowsable(EditorBrowsableState.Never)]
public override string ToString() => $"{Path}";
}
Loading

0 comments on commit b6c3306

Please sign in to comment.