diff --git a/src/LinkDotNet.Blog.Infrastructure/Persistence/Sql/Repository.cs b/src/LinkDotNet.Blog.Infrastructure/Persistence/Sql/Repository.cs index 7b19ad7a..08bafeda 100644 --- a/src/LinkDotNet.Blog.Infrastructure/Persistence/Sql/Repository.cs +++ b/src/LinkDotNet.Blog.Infrastructure/Persistence/Sql/Repository.cs @@ -11,15 +11,16 @@ namespace LinkDotNet.Blog.Infrastructure.Persistence.Sql; public class Repository : IRepository where TEntity : Entity { - private readonly BlogDbContext blogDbContext; + private readonly IDbContextFactory dbContextFactory; - public Repository(BlogDbContext blogDbContext) + public Repository(IDbContextFactory dbContextFactory) { - this.blogDbContext = blogDbContext; + this.dbContextFactory = dbContextFactory; } public async ValueTask GetByIdAsync(string id) { + await using var blogDbContext = await dbContextFactory.CreateDbContextAsync(); return await blogDbContext.Set().SingleOrDefaultAsync(b => b.Id == id); } @@ -30,6 +31,7 @@ public async ValueTask> GetAllAsync( int page = 1, int pageSize = int.MaxValue) { + await using var blogDbContext = await dbContextFactory.CreateDbContextAsync(); var entity = blogDbContext.Set().AsNoTracking().AsQueryable(); if (filter != null) @@ -49,6 +51,7 @@ public async ValueTask> GetAllAsync( public async ValueTask StoreAsync(TEntity entity) { + await using var blogDbContext = await dbContextFactory.CreateDbContextAsync(); if (string.IsNullOrEmpty(entity.Id)) { await blogDbContext.Set().AddAsync(entity); @@ -63,6 +66,7 @@ public async ValueTask StoreAsync(TEntity entity) public async ValueTask DeleteAsync(string id) { + await using var blogDbContext = await dbContextFactory.CreateDbContextAsync(); var entityToDelete = await GetByIdAsync(id); if (entityToDelete != null) { diff --git a/tests/LinkDotNet.Blog.IntegrationTests/SqlDatabaseTestBase.cs b/tests/LinkDotNet.Blog.IntegrationTests/SqlDatabaseTestBase.cs index bf7b29c0..0c98c6a5 100644 --- a/tests/LinkDotNet.Blog.IntegrationTests/SqlDatabaseTestBase.cs +++ b/tests/LinkDotNet.Blog.IntegrationTests/SqlDatabaseTestBase.cs @@ -12,19 +12,26 @@ namespace LinkDotNet.Blog.IntegrationTests; public abstract class SqlDatabaseTestBase : IAsyncLifetime, IAsyncDisposable where TEntity : Entity { + private readonly Mock> dbContextFactory; + protected SqlDatabaseTestBase() { var options = new DbContextOptionsBuilder() .UseSqlite(CreateInMemoryConnection()) .Options; DbContext = new BlogDbContext(options); - Repository = new Repository(new BlogDbContext(options)); + dbContextFactory = new Mock>(); + dbContextFactory.Setup(d => d.CreateDbContextAsync(default)) + .ReturnsAsync(() => new BlogDbContext(options)); + Repository = new Repository(dbContextFactory.Object); } protected IRepository Repository { get; } protected BlogDbContext DbContext { get; } + protected IDbContextFactory DbContextFactory => dbContextFactory.Object; + public Task InitializeAsync() { return Task.CompletedTask; diff --git a/tests/LinkDotNet.Blog.IntegrationTests/Web/Features/Admin/Dashboard/Components/VisitCountPerPageTests.cs b/tests/LinkDotNet.Blog.IntegrationTests/Web/Features/Admin/Dashboard/Components/VisitCountPerPageTests.cs index 04344a74..129adde0 100644 --- a/tests/LinkDotNet.Blog.IntegrationTests/Web/Features/Admin/Dashboard/Components/VisitCountPerPageTests.cs +++ b/tests/LinkDotNet.Blog.IntegrationTests/Web/Features/Admin/Dashboard/Components/VisitCountPerPageTests.cs @@ -10,6 +10,7 @@ using LinkDotNet.Blog.Web.Features.Admin.Dashboard.Components; using LinkDotNet.Blog.Web.Features.Admin.Dashboard.Services; using Microsoft.AspNetCore.Components; +using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; namespace LinkDotNet.Blog.IntegrationTests.Web.Features.Admin.Dashboard.Components; @@ -22,8 +23,7 @@ public async Task ShouldShowCounts() var blogPost = new BlogPostBuilder().WithTitle("I was clicked").WithLikes(2).Build(); await Repository.StoreAsync(blogPost); using var ctx = new TestContext(); - ctx.Services.AddScoped>(_ => new Repository(DbContext)); - ctx.Services.AddScoped>(_ => new Repository(DbContext)); + RegisterRepositories(ctx); await SaveBlogPostArticleClicked(blogPost.Id, 10); var cut = ctx.RenderComponent(); @@ -58,8 +58,7 @@ public async Task ShouldFilterByDate() await DbContext.SaveChangesAsync(); using var ctx = new TestContext(); ctx.ComponentFactories.Add(); - ctx.Services.AddScoped>(_ => new Repository(DbContext)); - ctx.Services.AddScoped>(_ => new Repository(DbContext)); + RegisterRepositories(ctx); var cut = ctx.RenderComponent(); var filter = new Filter { StartDate = new DateTime(2019, 1, 1), EndDate = new DateTime(2020, 12, 31) }; @@ -93,8 +92,7 @@ public async Task ShouldShowTotalClickCount() await DbContext.UserRecords.AddRangeAsync(new[] { clicked1, clicked2, clicked3, clicked4 }); await DbContext.SaveChangesAsync(); using var ctx = new TestContext(); - ctx.Services.AddScoped>(_ => new Repository(DbContext)); - ctx.Services.AddScoped>(_ => new Repository(DbContext)); + RegisterRepositories(ctx); var cut = ctx.RenderComponent(); @@ -102,6 +100,12 @@ public async Task ShouldShowTotalClickCount() cut.Find("#total-clicks").Unwrap().TextContent.Should().Be("4 clicks in total"); } + private void RegisterRepositories(TestContextBase ctx) + { + ctx.Services.AddScoped>(_ => new Repository(DbContextFactory)); + ctx.Services.AddScoped>(_ => new Repository(DbContextFactory)); + } + private async Task SaveBlogPostArticleClicked(string blogPostId, int count) { var urlClicked = $"blogPost/{blogPostId}";