diff --git a/core/src/main/resources/reference.conf b/core/src/main/resources/reference.conf index c52a6e33..c9867425 100644 --- a/core/src/main/resources/reference.conf +++ b/core/src/main/resources/reference.conf @@ -325,6 +325,12 @@ akka.persistence.r2dbc { # This timeout is handled by the database server. # This timeout should be less than `close-calls-exceeding`. statement-timeout = off + + # Possibility to programatically amend the ConnectionFactoryOptions. + # Enable by specifying the fully qualified class name of a + # `akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider`. + # The class can optionally have a constructor with an ActorSystem parameter. + options-provider = "" // #connection-settings-postgres // #connection-settings-yugabyte } @@ -413,6 +419,12 @@ akka.persistence.r2dbc { # Used to encode tags to and from db. Tags must not contain this separator. tag-separator = "," + # Possibility to programatically amend the ConnectionFactoryOptions. + # Enable by specifying the fully qualified class name of a + # `akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider`. + # The class can optionally have a constructor with an ActorSystem parameter. + options-provider = "" + // #connection-settings-sqlserver } diff --git a/core/src/main/scala/akka/persistence/r2dbc/ConnectionFactoryProvider.scala b/core/src/main/scala/akka/persistence/r2dbc/ConnectionFactoryProvider.scala index ec5c4e19..2465a5ac 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/ConnectionFactoryProvider.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/ConnectionFactoryProvider.scala @@ -20,6 +20,11 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.concurrent.Future import scala.concurrent.duration.Duration +import scala.util.Failure +import scala.util.Success + +import com.typesafe.config.Config +import io.r2dbc.spi.ConnectionFactoryOptions import akka.annotation.InternalApi import akka.annotation.InternalStableApi @@ -29,9 +34,24 @@ object ConnectionFactoryProvider extends ExtensionId[ConnectionFactoryProvider] // Java API def get(system: ActorSystem[_]): ConnectionFactoryProvider = apply(system) + + trait ConnectionFactoryOptionsProvider { + def buildOptions( + builder: ConnectionFactoryOptions.Builder, + connectionFactoryConfig: Config): ConnectionFactoryOptions + } + + private object DefaultConnectionFactoryOptionsProvider extends ConnectionFactoryOptionsProvider { + override def buildOptions( + builder: ConnectionFactoryOptions.Builder, + connectionFactoryConfig: Config): ConnectionFactoryOptions = + builder.build() + } } class ConnectionFactoryProvider(system: ActorSystem[_]) extends Extension { + import ConnectionFactoryProvider.ConnectionFactoryOptionsProvider + import ConnectionFactoryProvider.DefaultConnectionFactoryOptionsProvider import R2dbcExecutor.PublisherOps private val sessions = new ConcurrentHashMap[String, ConnectionPool] @@ -51,7 +71,8 @@ class ConnectionFactoryProvider(system: ActorSystem[_]) extends Extension { configLocation, configLocation => { val settings = connectionFactorySettingsFor(configLocation) - val connectionFactory = settings.dialect.createConnectionFactory(settings.config) + val optionsProvider = connectionFactoryOptionsProvider(settings) + val connectionFactory = settings.dialect.createConnectionFactory(settings.config, optionsProvider) createConnectionPoolFactory(settings.poolSettings, connectionFactory) }) .asInstanceOf[ConnectionFactory] @@ -72,6 +93,22 @@ class ConnectionFactoryProvider(system: ActorSystem[_]) extends Extension { } } + private def connectionFactoryOptionsProvider( + settings: ConnectionFactorySettings): ConnectionFactoryOptionsProvider = { + settings.optionsProvider match { + case "" => DefaultConnectionFactoryOptionsProvider + case fqcn => + system.dynamicAccess.createInstanceFor[ConnectionFactoryOptionsProvider](fqcn, Nil) match { + case Success(provider) => provider + case Failure(_) => + system.dynamicAccess + .createInstanceFor[ConnectionFactoryOptionsProvider](fqcn, List(classOf[ActorSystem[_]] -> system)) + .get + + } + } + } + /** * INTERNAL API */ diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/ConnectionFactorySettings.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/ConnectionFactorySettings.scala index aaa17363..cb2d11de 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/ConnectionFactorySettings.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/ConnectionFactorySettings.scala @@ -16,6 +16,8 @@ import com.typesafe.config.Config import org.slf4j.Logger import org.slf4j.LoggerFactory +import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider + /** * INTERNAL API */ @@ -44,7 +46,10 @@ private[r2dbc] object ConnectionFactorySettings { // for backwards compatibility/convenience val poolSettings = new ConnectionPoolSettings(config) - ConnectionFactorySettings(dialect, config, poolSettings) + // H2 dialect doesn't support options-provider + val optionsProvider = if (dialect == H2Dialect) "" else config.getString("options-provider") + + ConnectionFactorySettings(dialect, config, poolSettings, optionsProvider) } } @@ -56,4 +61,5 @@ private[r2dbc] object ConnectionFactorySettings { private[r2dbc] case class ConnectionFactorySettings( dialect: Dialect, config: Config, - poolSettings: ConnectionPoolSettings) + poolSettings: ConnectionPoolSettings, + optionsProvider: String) diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/Dialect.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/Dialect.scala index f66617cd..617aae1d 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/Dialect.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/Dialect.scala @@ -9,9 +9,10 @@ import akka.annotation.InternalStableApi import akka.persistence.r2dbc.R2dbcSettings import com.typesafe.config.Config import io.r2dbc.spi.ConnectionFactory - import scala.concurrent.ExecutionContext +import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider + /** * INTERNAL API */ @@ -31,7 +32,7 @@ private[r2dbc] trait Dialect { def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext - def createConnectionFactory(config: Config): ConnectionFactory + def createConnectionFactory(config: Config, optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory def createJournalDao(executorProvider: R2dbcExecutorProvider): JournalDao diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2Dialect.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2Dialect.scala index a95d10ed..c93dffa9 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2Dialect.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/h2/H2Dialect.scala @@ -24,6 +24,7 @@ import java.util.Locale import scala.concurrent.ExecutionContext +import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider import akka.persistence.r2dbc.internal.R2dbcExecutorProvider import akka.persistence.r2dbc.internal.codec.IdentityAdapter import akka.persistence.r2dbc.internal.codec.QueryAdapter @@ -46,7 +47,9 @@ private[r2dbc] object H2Dialect extends Dialect { res } - override def createConnectionFactory(config: Config): ConnectionFactory = { + override def createConnectionFactory( + config: Config, + optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory = { // starting point for both url and regular configs, // to allow url to override anything but provide sane defaults val builder = H2ConnectionConfiguration.builder() diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDialect.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDialect.scala index 01d2e4b8..7c1c7aab 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDialect.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/PostgresDialect.scala @@ -26,6 +26,7 @@ import io.r2dbc.spi.ConnectionFactories import io.r2dbc.spi.ConnectionFactory import io.r2dbc.spi.ConnectionFactoryOptions +import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider import akka.persistence.r2dbc.internal.R2dbcExecutorProvider /** @@ -66,7 +67,9 @@ private[r2dbc] object PostgresDialect extends Dialect { } } - override def createConnectionFactory(config: Config): ConnectionFactory = { + override def createConnectionFactory( + config: Config, + optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory = { val settings = new PostgresConnectionFactorySettings(config) val builder = settings.urlOption match { @@ -115,7 +118,8 @@ private[r2dbc] object PostgresDialect extends Dialect { builder.option(PostgresqlConnectionFactoryProvider.SSL_PASSWORD, settings.sslPassword) } - ConnectionFactories.get(builder.build()) + val options = optionsProvider.buildOptions(builder, config) + ConnectionFactories.get(options) } override def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext = diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDialect.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDialect.scala index 49a69c22..bd2862b8 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDialect.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/postgres/YugabyteDialect.scala @@ -17,6 +17,7 @@ import akka.persistence.r2dbc.internal.SnapshotDao import com.typesafe.config.Config import io.r2dbc.spi.ConnectionFactory +import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider import akka.persistence.r2dbc.internal.R2dbcExecutorProvider /** @@ -27,8 +28,10 @@ private[r2dbc] object YugabyteDialect extends Dialect { override def name: String = "yugabyte" - override def createConnectionFactory(config: Config): ConnectionFactory = - PostgresDialect.createConnectionFactory(config) + override def createConnectionFactory( + config: Config, + optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory = + PostgresDialect.createConnectionFactory(config, optionsProvider) override def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext = system.executionContext diff --git a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialect.scala b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialect.scala index ccb97fff..ee6d5125 100644 --- a/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialect.scala +++ b/core/src/main/scala/akka/persistence/r2dbc/internal/sqlserver/SqlServerDialect.scala @@ -24,6 +24,7 @@ import io.r2dbc.spi.ConnectionFactories import io.r2dbc.spi.ConnectionFactory import io.r2dbc.spi.ConnectionFactoryOptions +import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider import akka.persistence.r2dbc.internal.R2dbcExecutorProvider /** @@ -59,7 +60,9 @@ private[r2dbc] object SqlServerDialect extends Dialect { res } - override def createConnectionFactory(config: Config): ConnectionFactory = { + override def createConnectionFactory( + config: Config, + optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory = { val settings = new SqlServerConnectionFactorySettings(config) val builder = @@ -79,11 +82,13 @@ private[r2dbc] object SqlServerDialect extends Dialect { .option(ConnectionFactoryOptions.DATABASE, settings.database) .option(ConnectionFactoryOptions.CONNECT_TIMEOUT, JDuration.ofMillis(settings.connectTimeout.toMillis)) } - ConnectionFactories.get( - builder - //the option below is necessary to avoid https://github.com/r2dbc/r2dbc-mssql/issues/276 - .option(MssqlConnectionFactoryProvider.PREFER_CURSORED_EXECUTION, false) - .build()) + + builder + //the option below is necessary to avoid https://github.com/r2dbc/r2dbc-mssql/issues/276 + .option(MssqlConnectionFactoryProvider.PREFER_CURSORED_EXECUTION, false) + + val options = optionsProvider.buildOptions(builder, config) + ConnectionFactories.get(options) } override def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext = diff --git a/core/src/test/scala/akka/persistence/r2dbc/R2dbcSettingsSpec.scala b/core/src/test/scala/akka/persistence/r2dbc/R2dbcSettingsSpec.scala index f03374be..66a9cb57 100644 --- a/core/src/test/scala/akka/persistence/r2dbc/R2dbcSettingsSpec.scala +++ b/core/src/test/scala/akka/persistence/r2dbc/R2dbcSettingsSpec.scala @@ -64,6 +64,14 @@ class R2dbcSettingsSpec extends AnyWordSpec with TestSuite with Matchers { connectionFactorySettings.sslMode shouldBe "verify-full" SSLMode.fromValue(connectionFactorySettings.sslMode) shouldBe SSLMode.VERIFY_FULL } + + "support options-provider" in { + val config = ConfigFactory + .parseString("akka.persistence.r2dbc.connection-factory.options-provider=my.OptProvider") + .withFallback(ConfigFactory.load("application-postgres.conf")) + val settings = R2dbcSettings(config.getConfig("akka.persistence.r2dbc")) + settings.connectionFactorySettings(0).optionsProvider shouldBe "my.OptProvider" + } } "data-partition settings" should { @@ -287,5 +295,46 @@ class R2dbcSettingsSpec extends AnyWordSpec with TestSuite with Matchers { settings.connectionFactorSliceRanges(0) should be(0 until 1024) } + "support options-provider" in { + val config = ConfigFactory + .parseString(""" + akka.persistence.r2dbc.postgres.options-provider=my.OptProvider + akka.persistence.r2dbc.data-partition { + number-of-partitions = 2 + number-of-databases = 2 + } + akka.persistence.r2dbc.connection-factory-0-0 = ${akka.persistence.r2dbc.postgres} + akka.persistence.r2dbc.connection-factory-0-0.host = hostA + akka.persistence.r2dbc.connection-factory-1-1 = ${akka.persistence.r2dbc.postgres} + akka.persistence.r2dbc.connection-factory-1-1.host = hostB + """) + .withFallback(ConfigFactory.load("application-postgres.conf")) + .resolve() + val settings = R2dbcSettings(config.getConfig("akka.persistence.r2dbc")) + settings.connectionFactorySettings(0).optionsProvider shouldBe "my.OptProvider" + settings.connectionFactorySettings(1).optionsProvider shouldBe "my.OptProvider" + } + + "support options-provider per db" in { + val config = ConfigFactory + .parseString(""" + akka.persistence.r2dbc.data-partition { + number-of-partitions = 2 + number-of-databases = 2 + } + akka.persistence.r2dbc.connection-factory-0-0 = ${akka.persistence.r2dbc.postgres} + akka.persistence.r2dbc.connection-factory-0-0.host = hostA + akka.persistence.r2dbc.connection-factory-0-0.options-provider=my.OptProvider0 + akka.persistence.r2dbc.connection-factory-1-1 = ${akka.persistence.r2dbc.postgres} + akka.persistence.r2dbc.connection-factory-1-1.host = hostB + akka.persistence.r2dbc.connection-factory-1-1.options-provider=my.OptProvider1 + """) + .withFallback(ConfigFactory.load("application-postgres.conf")) + .resolve() + val settings = R2dbcSettings(config.getConfig("akka.persistence.r2dbc")) + settings.connectionFactorySettings(0).optionsProvider shouldBe "my.OptProvider0" + settings.connectionFactorySettings(1023).optionsProvider shouldBe "my.OptProvider1" + } + } }