Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Programatic connection factory options #587

Merged
merged 3 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,12 @@ akka.persistence.r2dbc {
# This timeout is handled by the database server.
# This timeout should be less than `close-calls-exceeding`.
statement-timeout = off

# Possibility to programatically amend the ConnectionFactoryOptions.
# Enable by specifying the fully qualified class name of a
# `akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider`.
# The class can optionally have a constructor with an ActorSystem parameter.
options-provider = ""
// #connection-settings-postgres
// #connection-settings-yugabyte
}
Expand Down Expand Up @@ -413,6 +419,12 @@ akka.persistence.r2dbc {
# Used to encode tags to and from db. Tags must not contain this separator.
tag-separator = ","

# Possibility to programatically amend the ConnectionFactoryOptions.
# Enable by specifying the fully qualified class name of a
# `akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider`.
# The class can optionally have a constructor with an ActorSystem parameter.
options-provider = ""

// #connection-settings-sqlserver
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
import scala.concurrent.Future
import scala.concurrent.duration.Duration
import scala.util.Failure
import scala.util.Success

import com.typesafe.config.Config
import io.r2dbc.spi.ConnectionFactoryOptions

import akka.annotation.InternalApi
import akka.annotation.InternalStableApi
Expand All @@ -29,9 +34,24 @@ object ConnectionFactoryProvider extends ExtensionId[ConnectionFactoryProvider]

// Java API
def get(system: ActorSystem[_]): ConnectionFactoryProvider = apply(system)

trait ConnectionFactoryOptionsProvider {
def buildOptions(
builder: ConnectionFactoryOptions.Builder,
connectionFactoryConfig: Config): ConnectionFactoryOptions
}

private object DefaultConnectionFactoryOptionsProvider extends ConnectionFactoryOptionsProvider {
override def buildOptions(
builder: ConnectionFactoryOptions.Builder,
connectionFactoryConfig: Config): ConnectionFactoryOptions =
builder.build()
}
}

class ConnectionFactoryProvider(system: ActorSystem[_]) extends Extension {
import ConnectionFactoryProvider.ConnectionFactoryOptionsProvider
import ConnectionFactoryProvider.DefaultConnectionFactoryOptionsProvider

import R2dbcExecutor.PublisherOps
private val sessions = new ConcurrentHashMap[String, ConnectionPool]
Expand All @@ -51,7 +71,8 @@ class ConnectionFactoryProvider(system: ActorSystem[_]) extends Extension {
configLocation,
configLocation => {
val settings = connectionFactorySettingsFor(configLocation)
val connectionFactory = settings.dialect.createConnectionFactory(settings.config)
val optionsProvider = connectionFactoryOptionsProvider(settings)
val connectionFactory = settings.dialect.createConnectionFactory(settings.config, optionsProvider)
createConnectionPoolFactory(settings.poolSettings, connectionFactory)
})
.asInstanceOf[ConnectionFactory]
Expand All @@ -72,6 +93,22 @@ class ConnectionFactoryProvider(system: ActorSystem[_]) extends Extension {
}
}

private def connectionFactoryOptionsProvider(
settings: ConnectionFactorySettings): ConnectionFactoryOptionsProvider = {
settings.optionsProvider match {
case "" => DefaultConnectionFactoryOptionsProvider
case fqcn =>
system.dynamicAccess.createInstanceFor[ConnectionFactoryOptionsProvider](fqcn, Nil) match {
case Success(provider) => provider
case Failure(_) =>
system.dynamicAccess
.createInstanceFor[ConnectionFactoryOptionsProvider](fqcn, List(classOf[ActorSystem[_]] -> system))
.get

}
}
}

/**
* INTERNAL API
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import com.typesafe.config.Config
import org.slf4j.Logger
import org.slf4j.LoggerFactory

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider

/**
* INTERNAL API
*/
Expand Down Expand Up @@ -44,7 +46,10 @@ private[r2dbc] object ConnectionFactorySettings {
// for backwards compatibility/convenience
val poolSettings = new ConnectionPoolSettings(config)

ConnectionFactorySettings(dialect, config, poolSettings)
// H2 dialect doesn't support options-provider
val optionsProvider = if (dialect == H2Dialect) "" else config.getString("options-provider")

ConnectionFactorySettings(dialect, config, poolSettings, optionsProvider)
}

}
Expand All @@ -56,4 +61,5 @@ private[r2dbc] object ConnectionFactorySettings {
private[r2dbc] case class ConnectionFactorySettings(
dialect: Dialect,
config: Config,
poolSettings: ConnectionPoolSettings)
poolSettings: ConnectionPoolSettings,
optionsProvider: String)
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import akka.annotation.InternalStableApi
import akka.persistence.r2dbc.R2dbcSettings
import com.typesafe.config.Config
import io.r2dbc.spi.ConnectionFactory

import scala.concurrent.ExecutionContext

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider

/**
* INTERNAL API
*/
Expand All @@ -31,7 +32,7 @@ private[r2dbc] trait Dialect {

def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext

def createConnectionFactory(config: Config): ConnectionFactory
def createConnectionFactory(config: Config, optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory

def createJournalDao(executorProvider: R2dbcExecutorProvider): JournalDao

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import java.util.Locale

import scala.concurrent.ExecutionContext

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider
import akka.persistence.r2dbc.internal.codec.IdentityAdapter
import akka.persistence.r2dbc.internal.codec.QueryAdapter
Expand All @@ -46,7 +47,9 @@ private[r2dbc] object H2Dialect extends Dialect {
res
}

override def createConnectionFactory(config: Config): ConnectionFactory = {
override def createConnectionFactory(
config: Config,
optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory = {
// starting point for both url and regular configs,
// to allow url to override anything but provide sane defaults
val builder = H2ConnectionConfiguration.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import io.r2dbc.spi.ConnectionFactories
import io.r2dbc.spi.ConnectionFactory
import io.r2dbc.spi.ConnectionFactoryOptions

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider

/**
Expand Down Expand Up @@ -66,7 +67,9 @@ private[r2dbc] object PostgresDialect extends Dialect {
}
}

override def createConnectionFactory(config: Config): ConnectionFactory = {
override def createConnectionFactory(
config: Config,
optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory = {
val settings = new PostgresConnectionFactorySettings(config)
val builder =
settings.urlOption match {
Expand Down Expand Up @@ -115,7 +118,8 @@ private[r2dbc] object PostgresDialect extends Dialect {
builder.option(PostgresqlConnectionFactoryProvider.SSL_PASSWORD, settings.sslPassword)
}

ConnectionFactories.get(builder.build())
val options = optionsProvider.buildOptions(builder, config)
ConnectionFactories.get(options)
}

override def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import akka.persistence.r2dbc.internal.SnapshotDao
import com.typesafe.config.Config
import io.r2dbc.spi.ConnectionFactory

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider

/**
Expand All @@ -27,8 +28,10 @@ private[r2dbc] object YugabyteDialect extends Dialect {

override def name: String = "yugabyte"

override def createConnectionFactory(config: Config): ConnectionFactory =
PostgresDialect.createConnectionFactory(config)
override def createConnectionFactory(
config: Config,
optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory =
PostgresDialect.createConnectionFactory(config, optionsProvider)

override def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext =
system.executionContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import io.r2dbc.spi.ConnectionFactories
import io.r2dbc.spi.ConnectionFactory
import io.r2dbc.spi.ConnectionFactoryOptions

import akka.persistence.r2dbc.ConnectionFactoryProvider.ConnectionFactoryOptionsProvider
import akka.persistence.r2dbc.internal.R2dbcExecutorProvider

/**
Expand Down Expand Up @@ -59,7 +60,9 @@ private[r2dbc] object SqlServerDialect extends Dialect {
res
}

override def createConnectionFactory(config: Config): ConnectionFactory = {
override def createConnectionFactory(
config: Config,
optionsProvider: ConnectionFactoryOptionsProvider): ConnectionFactory = {

val settings = new SqlServerConnectionFactorySettings(config)
val builder =
Expand All @@ -79,11 +82,13 @@ private[r2dbc] object SqlServerDialect extends Dialect {
.option(ConnectionFactoryOptions.DATABASE, settings.database)
.option(ConnectionFactoryOptions.CONNECT_TIMEOUT, JDuration.ofMillis(settings.connectTimeout.toMillis))
}
ConnectionFactories.get(
builder
//the option below is necessary to avoid https://github.com/r2dbc/r2dbc-mssql/issues/276
.option(MssqlConnectionFactoryProvider.PREFER_CURSORED_EXECUTION, false)
.build())

builder
//the option below is necessary to avoid https://github.com/r2dbc/r2dbc-mssql/issues/276
.option(MssqlConnectionFactoryProvider.PREFER_CURSORED_EXECUTION, false)

val options = optionsProvider.buildOptions(builder, config)
ConnectionFactories.get(options)
}

override def daoExecutionContext(settings: R2dbcSettings, system: ActorSystem[_]): ExecutionContext =
Expand Down
49 changes: 49 additions & 0 deletions core/src/test/scala/akka/persistence/r2dbc/R2dbcSettingsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ class R2dbcSettingsSpec extends AnyWordSpec with TestSuite with Matchers {
connectionFactorySettings.sslMode shouldBe "verify-full"
SSLMode.fromValue(connectionFactorySettings.sslMode) shouldBe SSLMode.VERIFY_FULL
}

"support options-provider" in {
val config = ConfigFactory
.parseString("akka.persistence.r2dbc.connection-factory.options-provider=my.OptProvider")
.withFallback(ConfigFactory.load("application-postgres.conf"))
val settings = R2dbcSettings(config.getConfig("akka.persistence.r2dbc"))
settings.connectionFactorySettings(0).optionsProvider shouldBe "my.OptProvider"
}
}

"data-partition settings" should {
Expand Down Expand Up @@ -287,5 +295,46 @@ class R2dbcSettingsSpec extends AnyWordSpec with TestSuite with Matchers {
settings.connectionFactorSliceRanges(0) should be(0 until 1024)
}

"support options-provider" in {
val config = ConfigFactory
.parseString("""
akka.persistence.r2dbc.postgres.options-provider=my.OptProvider
akka.persistence.r2dbc.data-partition {
number-of-partitions = 2
number-of-databases = 2
}
akka.persistence.r2dbc.connection-factory-0-0 = ${akka.persistence.r2dbc.postgres}
akka.persistence.r2dbc.connection-factory-0-0.host = hostA
akka.persistence.r2dbc.connection-factory-1-1 = ${akka.persistence.r2dbc.postgres}
akka.persistence.r2dbc.connection-factory-1-1.host = hostB
""")
.withFallback(ConfigFactory.load("application-postgres.conf"))
.resolve()
val settings = R2dbcSettings(config.getConfig("akka.persistence.r2dbc"))
settings.connectionFactorySettings(0).optionsProvider shouldBe "my.OptProvider"
settings.connectionFactorySettings(1).optionsProvider shouldBe "my.OptProvider"
}

"support options-provider per db" in {
val config = ConfigFactory
.parseString("""
akka.persistence.r2dbc.data-partition {
number-of-partitions = 2
number-of-databases = 2
}
akka.persistence.r2dbc.connection-factory-0-0 = ${akka.persistence.r2dbc.postgres}
akka.persistence.r2dbc.connection-factory-0-0.host = hostA
akka.persistence.r2dbc.connection-factory-0-0.options-provider=my.OptProvider0
akka.persistence.r2dbc.connection-factory-1-1 = ${akka.persistence.r2dbc.postgres}
akka.persistence.r2dbc.connection-factory-1-1.host = hostB
akka.persistence.r2dbc.connection-factory-1-1.options-provider=my.OptProvider1
""")
.withFallback(ConfigFactory.load("application-postgres.conf"))
.resolve()
val settings = R2dbcSettings(config.getConfig("akka.persistence.r2dbc"))
settings.connectionFactorySettings(0).optionsProvider shouldBe "my.OptProvider0"
settings.connectionFactorySettings(1023).optionsProvider shouldBe "my.OptProvider1"
}

}
}