diff --git a/src/DfE.CoreLibs.Testing/AutoFixture/Customizations/DbContextCustomization.cs b/src/DfE.CoreLibs.Testing/AutoFixture/Customizations/DbContextCustomization.cs deleted file mode 100644 index 1d7b566..0000000 --- a/src/DfE.CoreLibs.Testing/AutoFixture/Customizations/DbContextCustomization.cs +++ /dev/null @@ -1,31 +0,0 @@ -using AutoFixture; -using DfE.CoreLibs.Testing.Helpers; -using Microsoft.EntityFrameworkCore; -using Microsoft.Extensions.DependencyInjection; -using System.Diagnostics.CodeAnalysis; - -namespace DfE.CoreLibs.Testing.AutoFixture.Customizations -{ - [ExcludeFromCodeCoverage] - public class DbContextCustomization : ICustomization where TContext : DbContext - { - public void Customize(IFixture fixture) - { - fixture.Register>(() => null!); - - fixture.Customize(composer => composer.FromFactory(() => - { - var services = new ServiceCollection(); - - DbContextHelper.CreateDbContext(services); - - var serviceProvider = services.BuildServiceProvider(); - var dbContext = serviceProvider.GetRequiredService(); - - fixture.Inject(dbContext); - - return dbContext; - }).OmitAutoProperties()); - } - } -} diff --git a/src/DfE.CoreLibs.Testing/Helpers/DbContextHelper.cs b/src/DfE.CoreLibs.Testing/Helpers/DbContextHelper.cs index 4cdd4d0..22f432b 100644 --- a/src/DfE.CoreLibs.Testing/Helpers/DbContextHelper.cs +++ b/src/DfE.CoreLibs.Testing/Helpers/DbContextHelper.cs @@ -1,60 +1,53 @@ -using System.Data.Common; -using System.Diagnostics.CodeAnalysis; -using Microsoft.Data.Sqlite; using Microsoft.EntityFrameworkCore; -using Microsoft.Extensions.Configuration; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.DependencyInjection; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; namespace DfE.CoreLibs.Testing.Helpers { [ExcludeFromCodeCoverage] public static class DbContextHelper { - public static void CreateDbContext(IServiceCollection services, Action? seedTestData = null) where TContext : DbContext + public static void CreateDbContext( + IServiceCollection services, + DbConnection connection, + Action? seedTestData = null) where TContext : DbContext { - var connectionString = GetConnectionStringFromConfig(); + ConfigureDbContext(services, connection); + InitializeDbContext(services, seedTestData); + } - if (string.IsNullOrEmpty(connectionString) || connectionString.Contains("DataSource=:memory:")) + public static void ConfigureDbContext( + IServiceCollection services, + DbConnection connection) where TContext : DbContext + { + services.AddDbContext((sp, options) => { - // Sqlite doesn't seem to allow multiple dbContexts added to the same connection - // We are creating a separate in-memory database for each dbContext - // Please feel free to update if you have a better/ more efficient solution - var connection = new SqliteConnection("DataSource=:memory:"); - - connection.Open(); + options.UseSqlite(connection); + }); + } - services.AddSingleton(connection); + private static void InitializeDbContext( + IServiceCollection services, + Action? seedTestData) where TContext : DbContext + { + var serviceProvider = services.BuildServiceProvider(); + using var scope = serviceProvider.CreateScope(); + var dbContext = scope.ServiceProvider.GetRequiredService(); - services.AddDbContext((sp, options) => - { - options.UseSqlite(connection); - }); + var relationalDatabaseCreator = dbContext.Database.GetService(); + if (!dbContext.Database.CanConnect()) + { + relationalDatabaseCreator.Create(); } else { - services.AddDbContext(options => - { - options.UseSqlServer(connectionString); - }); + relationalDatabaseCreator.CreateTables(); } - var serviceProvider = services.BuildServiceProvider(); - - using var scope = serviceProvider.CreateScope(); - var dbContext = scope.ServiceProvider.GetRequiredService(); - dbContext.Database.EnsureCreated(); - seedTestData?.Invoke(dbContext); } - - private static string? GetConnectionStringFromConfig() - { - var configuration = new ConfigurationBuilder() - .SetBasePath(Directory.GetCurrentDirectory()) - .AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) - .Build(); - - return configuration.GetConnectionString("DefaultConnection"); - } } -} +} \ No newline at end of file diff --git a/src/DfE.CoreLibs.Testing/Mocks/WebApplicationFactory/CustomWebApplicationDbContextFactory.cs b/src/DfE.CoreLibs.Testing/Mocks/WebApplicationFactory/CustomWebApplicationDbContextFactory.cs index 3ebf23a..0fddc4e 100644 --- a/src/DfE.CoreLibs.Testing/Mocks/WebApplicationFactory/CustomWebApplicationDbContextFactory.cs +++ b/src/DfE.CoreLibs.Testing/Mocks/WebApplicationFactory/CustomWebApplicationDbContextFactory.cs @@ -1,6 +1,7 @@ using DfE.CoreLibs.Testing.Helpers; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Mvc.Testing; +using Microsoft.Data.Sqlite; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; using System.Data.Common; @@ -13,7 +14,7 @@ namespace DfE.CoreLibs.Testing.Mocks.WebApplicationFactory public class CustomWebApplicationDbContextFactory : WebApplicationFactory where TProgram : class { - public List? TestClaims { get; set; } = []; + public List? TestClaims { get; set; } = new(); public Dictionary>? SeedData { get; set; } public Action? ExternalServicesConfiguration { get; set; } public Action? ExternalHttpClientConfiguration { get; set; } @@ -22,35 +23,23 @@ protected override void ConfigureWebHost(IWebHostBuilder builder) { builder.ConfigureServices(services => { + RemoveDbContextAndConnectionServices(services); - var dbContextDescriptors = services - .Where(d => d.ServiceType.IsGenericType && d.ServiceType.GetGenericTypeDefinition() == typeof(DbContextOptions<>)) - .ToList(); - foreach (var dbContextDescriptor in dbContextDescriptors) - { - services.Remove(dbContextDescriptor); - } + var connection = new SqliteConnection("DataSource=:memory:"); + connection.Open(); + services.AddSingleton(connection); - var dbConnectionDescriptor = services.SingleOrDefault(d => d.ServiceType == typeof(DbConnection)); - if (dbConnectionDescriptor != null) - { - services.Remove(dbConnectionDescriptor); - } - - foreach (var entry in SeedData ?? []) + foreach (var entry in SeedData ?? new Dictionary>()) { var dbContextType = entry.Key; var seedAction = entry.Value; - var createDbContextMethod = typeof(DbContextHelper).GetMethod(nameof(DbContextHelper.CreateDbContext)) ?.MakeGenericMethod(dbContextType); - - createDbContextMethod?.Invoke(null, new object[] { services, seedAction }); + createDbContextMethod?.Invoke(null, new object[] { services, connection, seedAction }); } ExternalServicesConfiguration?.Invoke(services); - - services.AddSingleton>(sp => TestClaims ?? []); + services.AddSingleton>(sp => TestClaims ?? new()); }); builder.UseEnvironment("Development"); @@ -59,7 +48,6 @@ protected override void ConfigureWebHost(IWebHostBuilder builder) protected override void ConfigureClient(HttpClient client) { ExternalHttpClientConfiguration?.Invoke(client); - base.ConfigureClient(client); } @@ -70,5 +58,21 @@ public TDbContext GetDbContext() where TDbContext : DbContext return scope.ServiceProvider.GetRequiredService(); } + private static void RemoveDbContextAndConnectionServices(IServiceCollection services) + { + var dbContextDescriptors = services + .Where(d => d.ServiceType.IsGenericType && d.ServiceType.GetGenericTypeDefinition() == typeof(DbContextOptions<>)) + .ToList(); + foreach (var dbContextDescriptor in dbContextDescriptors) + { + services.Remove(dbContextDescriptor); + } + + var dbConnectionDescriptor = services.SingleOrDefault(d => d.ServiceType == typeof(DbConnection)); + if (dbConnectionDescriptor != null) + { + services.Remove(dbConnectionDescriptor); + } + } } }