Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: change tracking schema detection issues #397

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace Firebend.AutoCrud.ChangeTracking.EntityFramework.Implementations;
public class ChangeTrackingDbContextProvider<TEntityKey, TEntity>(
IDbContextFactory<ChangeTrackingDbContext<TEntityKey, TEntity>> contextFactory,
ILogger<ChangeTrackingDbContextProvider<TEntityKey, TEntity>> logger,
IChangeTrackingTableNameProvider<TEntityKey, TEntity> tableNameProvider,
IDbContextConnectionStringProvider<TEntityKey, TEntity> connectionStringProvider = null)
:
DbContextProvider<Guid, ChangeTrackingEntity<TEntityKey, TEntity>,
Expand All @@ -26,14 +27,15 @@ public class ChangeTrackingDbContextProvider<TEntityKey, TEntity>(
where TEntity : class, IEntity<TEntityKey>
where TEntityKey : struct
{
private record ScaffoldCacheContext(DbContext DbContext, ILogger Logger);
private record ScaffoldCacheContext(DbContext DbContext, TableNameResult TableName, ILogger Logger);

protected override bool WaitForMigrations => false;

protected override void InitDb(DbContext dbContext)
{
base.InitDb(dbContext);
ScaffoldDbContext(dbContext, logger);
var tableName = tableNameProvider.GetTableName();
ScaffoldDbContext(dbContext, tableName, logger);
}

protected override async Task<string> ProvideConnectionString(CancellationToken cancellationToken)
Expand All @@ -46,29 +48,22 @@ protected override async Task<string> ProvideConnectionString(CancellationToken
return await connectionStringProvider.GetConnectionStringAsync(cancellationToken);
}

private static void ScaffoldDbContext(DbContext context, ILogger logger)
private static void ScaffoldDbContext(DbContext context, TableNameResult tableName, ILogger logger)
{
var dbConn = context.Database.GetDbConnection();
var cacheKey = $"{dbConn.DataSource}_{dbConn.Database}_{typeof(TEntity).FullName}";

ChangeTrackingDbContextProviderCache.ScaffoldCache.GetOrAdd(cacheKey,
ScaffoldCacheFactory,
new ScaffoldCacheContext(context, logger));
new ScaffoldCacheContext(context, tableName, logger));
}

private static bool ScaffoldCacheFactory(string typeName, ScaffoldCacheContext scaffoldCacheContext)
{
var changeTrackingType = typeof(ChangeTrackingEntity<TEntityKey, TEntity>);
var type = scaffoldCacheContext.DbContext.Model.FindEntityType(changeTrackingType);

if (type is null)
{
scaffoldCacheContext.Logger.LogWarning("Could not find entity type for {TypeName}", changeTrackingType.FullName);
return false;
}

var schema = type.GetSchema() ?? "dbo";
var table = type.GetTableName();
var schema = scaffoldCacheContext.TableName.Schema;
var table = scaffoldCacheContext.TableName.Table;

if (string.IsNullOrEmpty(table))
{
Expand Down Expand Up @@ -115,9 +110,7 @@ SELECT 1 FROM sys.tables AS T
INNER JOIN sys.schemas AS S ON T.schema_id = S.schema_id
WHERE S.Name = '{schemaName}' AND T.Name = '{tableName}'
""";

var exists = command.ExecuteScalar() != null;

return exists;
var result = command.ExecuteScalar();
return result != null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public class ChangeTrackingTableNameProvider<TEntityKey, TEntity, TEntityContext
where TEntityKey : struct
where TEntityContext : DbContext, IDbContext
{
private static readonly Type EntityType = typeof(TEntity);
private TableNameResult _result;

public TableNameResult GetTableName()
Expand All @@ -22,12 +23,18 @@ public TableNameResult GetTableName()
}

using var context = contextFactory.CreateDbContext();
var entityType = context.Model.FindEntityType(typeof(TEntity)) ?? throw new Exception($"Entity type {typeof(TEntity).FullName} not found in the model.");

var tableName = entityType.GetTableName() + "_Changes";
var schema = entityType.GetSchema();
_result = new TableNameResult(tableName, schema);
_result = GetTableName(context);

return _result;
}

private TableNameResult GetTableName(DbContext context)
{
var entityType = context.Model.FindEntityType(EntityType) ??
throw new Exception($"Entity type {EntityType.FullName} not found in the model.");

var tableName = entityType.GetTableName() + "_Changes";
var schema = entityType.GetSchema() ?? context.Model.GetDefaultSchema() ?? "dbo";
return new TableNameResult(tableName, schema);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ public interface IChangeTrackingTableNameProvider<TEntityKey, TEntity>
where TEntity : class, IEntity<TEntityKey>
where TEntityKey : struct
{
TableNameResult GetTableName();
public TableNameResult GetTableName();
}
Loading