Skip to content

Commit

Permalink
Refactor LLM to incude model type.
Browse files Browse the repository at this point in the history
  • Loading branch information
primaryobjects committed Jun 15, 2024
1 parent 330175d commit e828929
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 12 deletions.
37 changes: 27 additions & 10 deletions Monster Collector/Managers/Concrete/BaseLlmManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,37 @@
using LlmTornado.Images;
using LlmTornado.Models;

public abstract class BaseLlmManager(string? apiKey, LLmProviders provider) : LLM
public abstract class BaseLlmManager : LLM
{
protected string? _apiKey = apiKey;
protected LLmProviders _provider = provider;
protected string? _apiKey;
protected LLmProviders _provider;
protected TornadoApi _api;

public bool IsValid() => _apiKey != null;
public BaseLlmManager(string? apiKey, LLmProviders provider)
{
_apiKey = apiKey;
_provider = provider;
_api = new([new ProviderAuthentication(provider, _apiKey ?? "")]);
}

public virtual async Task<string?> GetTextAsync(string prompt, string input)
protected ChatModel GetTextModel()
{
TornadoApi api = new([new ProviderAuthentication(provider, _apiKey ?? "")]);
ChatModel model = ChatModel.Cohere.CommandRPlus;
ChatModel model = _provider switch
{
LLmProviders.Cohere => ChatModel.Cohere.CommandRPlus,
LLmProviders.Anthropic => ChatModel.Anthropic.Claude3.Opus,
LLmProviders.OpenAi or LLmProviders.AzureOpenAi => ChatModel.OpenAi.Gpt35.Turbo,
_ => ChatModel.Cohere.CommandRPlus,
};

return model;
}

public bool IsValid() => _apiKey != null;

string? response = await api.Chat.CreateConversation(model)
public virtual async Task<string?> GetTextAsync(string prompt, string input, ChatModel? model)
{
string? response = await _api.Chat.CreateConversation(model ?? GetTextModel())
.AppendSystemMessage(prompt)
.AppendUserInput(input)
.GetResponse();
Expand All @@ -26,8 +44,7 @@ public abstract class BaseLlmManager(string? apiKey, LLmProviders provider) : LL

public virtual async Task<ImageResult?> GetImage(string prompt)
{
TornadoApi api = new([new ProviderAuthentication(provider, _apiKey ?? "")]);
ImageResult? response = await api.ImageGenerations.CreateImageAsync(
ImageResult? response = await _api.ImageGenerations.CreateImageAsync(
new ImageGenerationRequest(
prompt,
quality: ImageQuality.Standard,
Expand Down
6 changes: 6 additions & 0 deletions Monster Collector/Managers/Concrete/OpenAIManager.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using LlmTornado.Chat.Models;
using LlmTornado.Code;

public class OpenAIManager : BaseLlmManager
Expand All @@ -6,4 +7,9 @@ public OpenAIManager()
: base(Environment.GetEnvironmentVariable("OpenAIApiKey"), LLmProviders.OpenAi)
{
}

public override Task<string?> GetTextAsync(string prompt, string input, ChatModel? model)
{
return base.GetTextAsync(prompt, input, ChatModel.OpenAi.Gpt35.Turbo);
}
}
5 changes: 3 additions & 2 deletions Monster Collector/Managers/Interface/LLM.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using LlmTornado.Images;
using LlmTornado.Chat.Models;
using LlmTornado.Images;

public interface LLM
{
bool IsValid();
Task<string?> GetTextAsync(string prompt, string input);
Task<string?> GetTextAsync(string prompt, string input, ChatModel? model = null);
Task<ImageResult?> GetImage(string prompt);
}

0 comments on commit e828929

Please sign in to comment.