Skip to content

Commit

Permalink
Merge pull request LykosAI#888 from ionite34/backport/main/pr-887
Browse files Browse the repository at this point in the history
[dev to main] backport: Get base model types from civit dynamically  (887)
  • Loading branch information
mohnjiles authored Nov 14, 2024
2 parents 61dfbed + 97ce099 commit 99b974c
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 14 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2.0.0.html).

## v2.12.4
### Changed
- Model browser base model types are now loaded dynamically from CivitAI, reducing the need for updates to add new types
### Fixed
- Fixed crash when clicking "Remind me Later" on the update dialog
- Fixed some cases of crashing when GitHub API rate limits are exceeded
### Supporters
#### Visionaries
- A huge thank you to our dedicated Visionary-tier Patreon supporter, **Waterclouds**! We’re thrilled to have your ongoing support!
#### Pioneers
- Shoutout to our great Pioneer-tier patrons: **tankfox**, **tanangular**, **Mr. Unknown**, **Szir777**, and our newest Pioneer, **Tigon**!. Your continued support is greatly appreciated!

## v2.12.3
### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Linq;
using System.Net.Http;
using System.Reactive.Linq;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using AsyncAwaitBestPractices;
using Avalonia.Controls;
Expand Down Expand Up @@ -42,6 +43,7 @@ public sealed partial class CivitAiBrowserViewModel : TabViewModelBase, IInfinit
private readonly ISettingsManager settingsManager;
private readonly ILiteDbContext liteDbContext;
private readonly INotificationService notificationService;
private bool dontSearch = false;

private readonly SourceCache<OrderedValue<CivitModel>, int> modelCache = new(static ov => ov.Value.Id);

Expand Down Expand Up @@ -80,7 +82,7 @@ public sealed partial class CivitAiBrowserViewModel : TabViewModelBase, IInfinit
private string noResultsText = string.Empty;

[ObservableProperty]
private string selectedBaseModelType = "All";
private string selectedBaseModelType;

[ObservableProperty]
private bool showSantaHats = true;
Expand All @@ -98,6 +100,11 @@ public sealed partial class CivitAiBrowserViewModel : TabViewModelBase, IInfinit
[NotifyPropertyChangedFor(nameof(StatsResizeFactor))]
private double resizeFactor;

private readonly SourceCache<string, string> baseModelCache = new(static s => s);

[ObservableProperty]
private IObservableCollection<string> allBaseModels = new ObservableCollectionExtended<string>();

public double StatsResizeFactor => Math.Clamp(ResizeFactor, 0.75d, 1.25d);

public IEnumerable<CivitPeriod> AllCivitPeriods =>
Expand All @@ -111,9 +118,6 @@ public sealed partial class CivitAiBrowserViewModel : TabViewModelBase, IInfinit
.Where(t => t == CivitModelType.All || t.ConvertTo<SharedFolderType>() > 0)
.OrderBy(t => t.ToString());

public IEnumerable<string> BaseModelOptions =>
Enum.GetValues<CivitBaseModelType>().Select(t => t.GetStringValue());

public CivitAiBrowserViewModel(
ICivitApi civitApi,
IDownloadService downloadService,
Expand Down Expand Up @@ -165,6 +169,10 @@ or nameof(HideEarlyAccessModels)
.SortAndBind(ModelCards, sortPredicate)
.Subscribe();

baseModelCache.Connect().DeferUntilLoaded().SortAndBind(AllBaseModels).Subscribe();

baseModelCache.AddOrUpdate(Enum.GetValues<CivitBaseModelType>().Select(t => t.GetStringValue()));

settingsManager.RelayPropertyFor(
this,
model => model.ShowNsfw,
Expand Down Expand Up @@ -234,6 +242,34 @@ public override void OnLoaded()
{
SearchModelsCommand.ExecuteAsync(false);
}

base.OnLoaded();
}

protected override async Task OnInitialLoadedAsync()
{
await base.OnInitialLoadedAsync();
var baseModels = await GetBaseModelList();
if (baseModels.Count == 0)
{
LoadSelectedBaseModelType();
return;
}

dontSearch = true;
baseModelCache.AddOrUpdate(baseModels);
dontSearch = false;

LoadSelectedBaseModelType();
}

private void LoadSelectedBaseModelType()
{
var searchOptions = settingsManager.Settings.ModelSearchOptions;
dontSearch = true;
SelectedBaseModelType = "All";
SelectedBaseModelType = searchOptions is null ? "All" : searchOptions.SelectedBaseModelType;
dontSearch = false;
}

/// <summary>
Expand Down Expand Up @@ -340,7 +376,22 @@ private async Task CivitModelQuery(CivitModelsRequest request, bool isInfiniteSc
}
);

UpdateModelCards(models, isInfiniteScroll);
if (cacheNew)
{
var doesBaseModelTypeMatch =
SelectedBaseModelType == "All"
? string.IsNullOrWhiteSpace(request.BaseModel)
: SelectedBaseModelType == request.BaseModel;
var doesModelTypeMatch =
SelectedModelType == CivitModelType.All
? request.Types == null || request.Types.Length == 0
: SelectedModelType == request.Types?.FirstOrDefault();

if (doesBaseModelTypeMatch && doesModelTypeMatch)
{
UpdateModelCards(models, isInfiniteScroll);
}
}

NextPageCursor = modelsResponse?.Metadata?.NextCursor;
}
Expand Down Expand Up @@ -615,6 +666,9 @@ partial void OnSelectedModelTypeChanged(CivitModelType value)

partial void OnSelectedBaseModelTypeChanged(string value)
{
if (dontSearch)
return;

TrySearchAgain().SafeFireAndForget();
settingsManager.Transaction(
s =>
Expand All @@ -632,6 +686,7 @@ private async Task TrySearchAgain(bool shouldUpdatePageNumber = true)
{
if (!HasSearched)
return;

modelCache.Clear();

if (shouldUpdatePageNumber)
Expand All @@ -649,5 +704,35 @@ private void UpdateResultsText()
NoResultsText = "No results found";
}

[Localizable(false)]
private async Task<List<string>> GetBaseModelList()
{
try
{
var baseModelsResponse = await civitApi.GetBaseModelList();
var jsonContent = await baseModelsResponse.Content.ReadAsStringAsync();
var baseModels = JsonNode.Parse(jsonContent);

var jArray =
baseModels?["error"]?["issues"]?[0]?["unionErrors"]?[0]?["issues"]?[0]?["options"]
as JsonArray;
var civitBaseModels = jArray?.GetValues<string>().ToList() ?? [];

civitBaseModels.Insert(0, CivitBaseModelType.All.ToString());

var filteredResults = civitBaseModels
.Where(s => s.Equals("odor", StringComparison.OrdinalIgnoreCase) == false)
.OrderBy(s => s)
.ToList();

return filteredResults;
}
catch (Exception e)
{
Logger.Error(e, "Failed to get base model list");
return [];
}
}

public override string Header => Resources.Label_CivitAi;
}
10 changes: 8 additions & 2 deletions StabilityMatrix.Avalonia/Views/CivitAiBrowserPage.axaml
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,14 @@
Grid.Column="3"
MinWidth="100"
Margin="4,0,4,8"
ItemsSource="{Binding BaseModelOptions}"
SelectedItem="{Binding SelectedBaseModelType}" />
ItemsSource="{Binding AllBaseModels}"
SelectedItem="{Binding SelectedBaseModelType}">
<ComboBox.ItemTemplate>
<DataTemplate x:DataType="system:String">
<TextBlock MinWidth="100" Text="{Binding .}" />
</DataTemplate>
</ComboBox.ItemTemplate>
</ComboBox>

<ui:CommandBar
Grid.Row="0"
Expand Down
6 changes: 5 additions & 1 deletion StabilityMatrix.Core/Api/ICivitApi.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Refit;
using System.Text.Json.Nodes;
using Refit;
using StabilityMatrix.Core.Models.Api;

namespace StabilityMatrix.Core.Api;
Expand All @@ -16,4 +17,7 @@ public interface ICivitApi

[Get("/api/v1/model-versions/{id}")]
Task<CivitModelVersion> GetModelVersionById(int id);

[Get("/api/v1/models?baseModels=gimmethelist")]
Task<HttpResponseMessage> GetBaseModelList();
}
12 changes: 12 additions & 0 deletions StabilityMatrix.Core/Models/Api/CivitBaseModelType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ public enum CivitBaseModelType
[StringValue("Lumina")]
Lumina,

[StringValue("Mochi")]
Mochi,

[StringValue("PixArt a")]
PixArtA,

Expand All @@ -53,6 +56,15 @@ public enum CivitBaseModelType
[StringValue("SD 3.5")]
Sd35,

[StringValue("SD 3.5 Large")]
Sd35Large,

[StringValue("SD 3.5 Large Turbo")]
Sd35LargeTurbo,

[StringValue("SD 3.5 Medium")]
Sd35Medium,

[StringValue("SDXL 0.9")]
Sdxl09,

Expand Down
1 change: 1 addition & 0 deletions StabilityMatrix.Core/Models/Api/CivitFileType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public enum CivitFileType
Model,
VAE,
Config,
Archive,

[EnumMember(Value = "Pruned Model")]
PrunedModel,
Expand Down
1 change: 1 addition & 0 deletions StabilityMatrix.Core/Models/Api/CivitModelFormat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ public enum CivitModelFormat
SafeTensor,
PickleTensor,
Diffusers,
GGUF,
Other
}
23 changes: 18 additions & 5 deletions StabilityMatrix.Core/Models/Packages/BaseGitPackage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ IPrerequisiteHelper prerequisiteHelper
PrerequisiteHelper = prerequisiteHelper;
}

public override async Task<DownloadPackageVersionOptions> GetLatestVersion(bool includePrerelease = false)
public override async Task<DownloadPackageVersionOptions?> GetLatestVersion(
bool includePrerelease = false
)
{
if (ShouldIgnoreReleases)
{
Expand All @@ -88,12 +90,23 @@ public override async Task<DownloadPackageVersionOptions> GetLatestVersion(bool
IsLatest = true,
IsPrerelease = false,
BranchName = MainBranch,
CommitHash = commits?.FirstOrDefault()?.Sha ?? "unknown"
CommitHash = commits?.FirstOrDefault()?.Sha
};
}

var releases = await GithubApi.GetAllReleases(RepositoryAuthor, RepositoryName).ConfigureAwait(false);
var latestRelease = includePrerelease ? releases.First() : releases.First(x => !x.Prerelease);
var releaseList = releases.ToList();
if (releaseList.Count == 0)
{
return new DownloadPackageVersionOptions
{
IsLatest = true,
IsPrerelease = false,
BranchName = MainBranch
};
}

var latestRelease = includePrerelease ? releaseList.First() : releaseList.First(x => !x.Prerelease);

return new DownloadPackageVersionOptions
{
Expand Down Expand Up @@ -319,7 +332,7 @@ public override async Task<bool> CheckForUpdates(InstalledPackage package)
await GetAllCommits(currentVersion.InstalledBranch!).ConfigureAwait(false)
)?.ToList();

if (allCommits == null || !allCommits.Any())
if (allCommits == null || allCommits.Count == 0)
{
Logger.Warn("No commits found for {Package}", package.PackageName);
return false;
Expand Down Expand Up @@ -363,7 +376,7 @@ await GetAllCommits(currentVersion.InstalledBranch!).ConfigureAwait(false)
await GetAllCommits(currentVersion.InstalledBranch!).ConfigureAwait(false)
)?.ToList();

if (allCommits == null || !allCommits.Any())
if (allCommits == null || allCommits.Count == 0)
{
Logger.Warn("No commits found for {Package}", installedPackage.PackageName);
return null;
Expand Down
2 changes: 1 addition & 1 deletion StabilityMatrix.Core/Models/Packages/BasePackage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public virtual TorchIndex GetRecommendedTorchVersion()
int page = 1,
int perPage = 10
);
public abstract Task<DownloadPackageVersionOptions> GetLatestVersion(bool includePrerelease = false);
public abstract Task<DownloadPackageVersionOptions?> GetLatestVersion(bool includePrerelease = false);
public abstract string MainBranch { get; }
public event EventHandler<int>? Exited;
public event EventHandler<string>? StartupComplete;
Expand Down

0 comments on commit 99b974c

Please sign in to comment.