Skip to content

Commit

Permalink
Updated the Testing package to support multi dbcontext with SQLite
Browse files Browse the repository at this point in the history
  • Loading branch information
Farshad DASHTI authored and Farshad DASHTI committed Oct 15, 2024
1 parent 2d72080 commit 076778d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 91 deletions.

This file was deleted.

71 changes: 32 additions & 39 deletions src/DfE.CoreLibs.Testing/Helpers/DbContextHelper.cs
Original file line number Diff line number Diff line change
@@ -1,60 +1,53 @@
using System.Data.Common;
using System.Diagnostics.CodeAnalysis;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Configuration;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.Extensions.DependencyInjection;
using System.Data.Common;
using System.Diagnostics.CodeAnalysis;

namespace DfE.CoreLibs.Testing.Helpers
{
[ExcludeFromCodeCoverage]
public static class DbContextHelper
{
public static void CreateDbContext<TContext>(IServiceCollection services, Action<TContext>? seedTestData = null) where TContext : DbContext
public static void CreateDbContext<TContext>(
IServiceCollection services,
DbConnection connection,
Action<TContext>? seedTestData = null) where TContext : DbContext
{
var connectionString = GetConnectionStringFromConfig();
ConfigureDbContext<TContext>(services, connection);
InitializeDbContext(services, seedTestData);
}

if (string.IsNullOrEmpty(connectionString) || connectionString.Contains("DataSource=:memory:"))
public static void ConfigureDbContext<TContext>(
IServiceCollection services,
DbConnection connection) where TContext : DbContext
{
services.AddDbContext<TContext>((sp, options) =>
{
// Sqlite doesn't seem to allow multiple dbContexts added to the same connection
// We are creating a separate in-memory database for each dbContext
// Please feel free to update if you have a better/ more efficient solution
var connection = new SqliteConnection("DataSource=:memory:");

connection.Open();
options.UseSqlite(connection);
});
}

services.AddSingleton(connection);
private static void InitializeDbContext<TContext>(
IServiceCollection services,
Action<TContext>? seedTestData) where TContext : DbContext
{
var serviceProvider = services.BuildServiceProvider();
using var scope = serviceProvider.CreateScope();
var dbContext = scope.ServiceProvider.GetRequiredService<TContext>();

services.AddDbContext<TContext>((sp, options) =>
{
options.UseSqlite(connection);
});
var relationalDatabaseCreator = dbContext.Database.GetService<IRelationalDatabaseCreator>();
if (!dbContext.Database.CanConnect())
{
relationalDatabaseCreator.Create();
}
else
{
services.AddDbContext<TContext>(options =>
{
options.UseSqlServer(connectionString);
});
relationalDatabaseCreator.CreateTables();
}

var serviceProvider = services.BuildServiceProvider();

using var scope = serviceProvider.CreateScope();
var dbContext = scope.ServiceProvider.GetRequiredService<TContext>();
dbContext.Database.EnsureCreated();

seedTestData?.Invoke(dbContext);
}

private static string? GetConnectionStringFromConfig()
{
var configuration = new ConfigurationBuilder()
.SetBasePath(Directory.GetCurrentDirectory())
.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true)
.Build();

return configuration.GetConnectionString("DefaultConnection");
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using DfE.CoreLibs.Testing.Helpers;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Mvc.Testing;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using System.Data.Common;
Expand All @@ -13,7 +14,7 @@ namespace DfE.CoreLibs.Testing.Mocks.WebApplicationFactory
public class CustomWebApplicationDbContextFactory<TProgram> : WebApplicationFactory<TProgram>
where TProgram : class
{
public List<Claim>? TestClaims { get; set; } = [];
public List<Claim>? TestClaims { get; set; } = new();
public Dictionary<Type, Action<DbContext>>? SeedData { get; set; }
public Action<IServiceCollection>? ExternalServicesConfiguration { get; set; }
public Action<HttpClient>? ExternalHttpClientConfiguration { get; set; }
Expand All @@ -22,35 +23,23 @@ protected override void ConfigureWebHost(IWebHostBuilder builder)
{
builder.ConfigureServices(services =>
{
RemoveDbContextAndConnectionServices(services);

var dbContextDescriptors = services
.Where(d => d.ServiceType.IsGenericType && d.ServiceType.GetGenericTypeDefinition() == typeof(DbContextOptions<>))
.ToList();
foreach (var dbContextDescriptor in dbContextDescriptors)
{
services.Remove(dbContextDescriptor);
}
var connection = new SqliteConnection("DataSource=:memory:");
connection.Open();
services.AddSingleton(connection);

var dbConnectionDescriptor = services.SingleOrDefault(d => d.ServiceType == typeof(DbConnection));
if (dbConnectionDescriptor != null)
{
services.Remove(dbConnectionDescriptor);
}

foreach (var entry in SeedData ?? [])
foreach (var entry in SeedData ?? new Dictionary<Type, Action<DbContext>>())
{
var dbContextType = entry.Key;
var seedAction = entry.Value;

var createDbContextMethod = typeof(DbContextHelper).GetMethod(nameof(DbContextHelper.CreateDbContext))
?.MakeGenericMethod(dbContextType);

createDbContextMethod?.Invoke(null, new object[] { services, seedAction });
createDbContextMethod?.Invoke(null, new object[] { services, connection, seedAction });
}

ExternalServicesConfiguration?.Invoke(services);

services.AddSingleton<IEnumerable<Claim>>(sp => TestClaims ?? []);
services.AddSingleton<IEnumerable<Claim>>(sp => TestClaims ?? new());
});

builder.UseEnvironment("Development");
Expand All @@ -59,7 +48,6 @@ protected override void ConfigureWebHost(IWebHostBuilder builder)
protected override void ConfigureClient(HttpClient client)
{
ExternalHttpClientConfiguration?.Invoke(client);

base.ConfigureClient(client);
}

Expand All @@ -70,5 +58,21 @@ public TDbContext GetDbContext<TDbContext>() where TDbContext : DbContext
return scope.ServiceProvider.GetRequiredService<TDbContext>();
}

private static void RemoveDbContextAndConnectionServices(IServiceCollection services)
{
var dbContextDescriptors = services
.Where(d => d.ServiceType.IsGenericType && d.ServiceType.GetGenericTypeDefinition() == typeof(DbContextOptions<>))
.ToList();
foreach (var dbContextDescriptor in dbContextDescriptors)
{
services.Remove(dbContextDescriptor);
}

var dbConnectionDescriptor = services.SingleOrDefault(d => d.ServiceType == typeof(DbConnection));
if (dbConnectionDescriptor != null)
{
services.Remove(dbConnectionDescriptor);
}
}
}
}

0 comments on commit 076778d

Please sign in to comment.