diff --git a/cdc/sink/mq/mq.go b/cdc/sink/mq/mq.go index 15c5cc6fd2f..a093cc650d1 100644 --- a/cdc/sink/mq/mq.go +++ b/cdc/sink/mq/mq.go @@ -383,7 +383,7 @@ func NewKafkaSaramaSink(ctx context.Context, sinkURI *url.URL, } baseConfig := kafka.NewConfig() - if err := baseConfig.Apply(sinkURI); err != nil { + if err := baseConfig.Apply(sinkURI, replicaConfig); err != nil { return nil, cerror.WrapError(cerror.ErrKafkaInvalidConfig, err) } diff --git a/cdc/sink/mq/producer/kafka/config.go b/cdc/sink/mq/producer/kafka/config.go index 892896f57ad..4e34961f05a 100644 --- a/cdc/sink/mq/producer/kafka/config.go +++ b/cdc/sink/mq/producer/kafka/config.go @@ -16,6 +16,7 @@ package kafka import ( "context" "crypto/tls" + "encoding/base64" "net/url" "strconv" "strings" @@ -103,8 +104,8 @@ func (c *Config) setPartitionNum(realPartitionCount int32) error { return nil } -// Apply the sinkURI to update Config -func (c *Config) Apply(sinkURI *url.URL) error { +// Apply the configuration to the sarama producer. +func (c *Config) Apply(sinkURI *url.URL, replicaConfig *config.ReplicaConfig) error { c.BrokerEndpoints = strings.Split(sinkURI.Host, ",") params := sinkURI.Query() s := params.Get("partition-num") @@ -185,7 +186,7 @@ func (c *Config) Apply(sinkURI *url.URL) error { c.ReadTimeout = a } - err := c.applySASL(params) + err := c.applySASL(params, replicaConfig) if err != nil { return err } @@ -256,7 +257,7 @@ func (c *Config) applyTLS(params url.Values) error { return nil } -func (c *Config) applySASL(params url.Values) error { +func (c *Config) applySASL(params url.Values, replicaConfig *config.ReplicaConfig) error { s := params.Get("sasl-user") if s != "" { c.SASL.SASLUser = s @@ -274,6 +275,12 @@ func (c *Config) applySASL(params url.Values) error { return cerror.WrapError(cerror.ErrKafkaInvalidConfig, err) } c.SASL.SASLMechanism = mechanism + } else if replicaConfig != nil && replicaConfig.Sink != nil && replicaConfig.Sink.KafkaConfig != nil && replicaConfig.Sink.KafkaConfig.SASLMechanism != nil { + mechanism, err := security.SASLMechanismFromString(*replicaConfig.Sink.KafkaConfig.SASLMechanism) + if err != nil { + return cerror.WrapError(cerror.ErrKafkaInvalidConfig, err) + } + c.SASL.SASLMechanism = mechanism } s = params.Get("sasl-gssapi-auth-type") @@ -324,6 +331,67 @@ func (c *Config) applySASL(params url.Values) error { c.SASL.GSSAPI.DisablePAFXFAST = disablePAFXFAST } + if replicaConfig.Sink != nil && replicaConfig.Sink.KafkaConfig != nil { + if replicaConfig.Sink.KafkaConfig.SASLOAuthClientID != nil { + clientID := *replicaConfig.Sink.KafkaConfig.SASLOAuthClientID + if clientID == "" { + return cerror.ErrKafkaInvalidConfig.GenWithStack("OAuth2 client ID cannot be empty") + } + c.SASL.OAuth2.ClientID = clientID + } + + if replicaConfig.Sink.KafkaConfig.SASLOAuthClientSecret != nil { + clientSecret := *replicaConfig.Sink.KafkaConfig.SASLOAuthClientSecret + if clientSecret == "" { + return cerror.ErrKafkaInvalidConfig.GenWithStack( + "OAuth2 client secret cannot be empty") + } + + // BASE64 decode the client secret + decodedClientSecret, err := base64.StdEncoding.DecodeString(clientSecret) + if err != nil { + log.Error("OAuth2 client secret is not base64 encoded", zap.Error(err)) + return cerror.ErrKafkaInvalidConfig.GenWithStack( + "OAuth2 client secret is not base64 encoded") + } + c.SASL.OAuth2.ClientSecret = string(decodedClientSecret) + } + + if replicaConfig.Sink.KafkaConfig.SASLOAuthTokenURL != nil { + tokenURL := *replicaConfig.Sink.KafkaConfig.SASLOAuthTokenURL + if tokenURL == "" { + return cerror.ErrKafkaInvalidConfig.GenWithStack( + "OAuth2 token URL cannot be empty") + } + c.SASL.OAuth2.TokenURL = tokenURL + } + + if c.SASL.OAuth2.IsEnable() { + if c.SASL.SASLMechanism != security.OAuthMechanism { + return cerror.ErrKafkaInvalidConfig.GenWithStack( + "OAuth2 is only supported with SASL mechanism type OAUTHBEARER, but got %s", + c.SASL.SASLMechanism) + } + + if err := c.SASL.OAuth2.Validate(); err != nil { + return cerror.ErrKafkaInvalidConfig.Wrap(err) + } + c.SASL.OAuth2.SetDefault() + } + + if replicaConfig.Sink.KafkaConfig.SASLOAuthScopes != nil { + c.SASL.OAuth2.Scopes = replicaConfig.Sink.KafkaConfig.SASLOAuthScopes + } + + if replicaConfig.Sink.KafkaConfig.SASLOAuthGrantType != nil { + c.SASL.OAuth2.GrantType = *replicaConfig.Sink.KafkaConfig.SASLOAuthGrantType + } + + if replicaConfig.Sink.KafkaConfig.SASLOAuthAudience != nil { + c.SASL.OAuth2.Audience = *replicaConfig.Sink.KafkaConfig.SASLOAuthAudience + } + } + return nil } @@ -445,12 +513,14 @@ func NewSaramaConfig(ctx context.Context, c *Config) (*sarama.Config, error) { config.Net.TLS.Config.InsecureSkipVerify = c.InsecureSkipVerify } - completeSaramaSASLConfig(config, c) + if err := completeSaramaSASLConfig(ctx, config, c); err != nil { + return nil, errors.Trace(err) + } return config, err } -func completeSaramaSASLConfig(config *sarama.Config, c *Config) { +func completeSaramaSASLConfig(ctx context.Context, config *sarama.Config, c *Config) error { if c.SASL != nil && c.SASL.SASLMechanism != "" { config.Net.SASL.Enable = true config.Net.SASL.Mechanism = sarama.SASLMechanism(c.SASL.SASLMechanism) @@ -480,6 +550,14 @@ func completeSaramaSASLConfig(config *sarama.Config, c *Config) { case security.KeyTabAuth: config.Net.SASL.GSSAPI.KeyTabPath = c.SASL.GSSAPI.KeyTabPath } + case sarama.SASLTypeOAuth: + p, err := newTokenProvider(ctx, c) + if err != nil { + return errors.Trace(err) + } + config.Net.SASL.TokenProvider = p } + } + return nil } diff --git a/cdc/sink/mq/producer/kafka/config_test.go b/cdc/sink/mq/producer/kafka/config_test.go index 2c46b7d08de..748bbff9c6e 100644 --- a/cdc/sink/mq/producer/kafka/config_test.go +++ b/cdc/sink/mq/producer/kafka/config_test.go @@ -106,7 +106,7 @@ func TestConfigTimeouts(t *testing.T) { sinkURI, err := url.Parse(uri) require.Nil(t, err) - err = cfg.Apply(sinkURI) + err = cfg.Apply(sinkURI, config.GetDefaultReplicaConfig()) require.Nil(t, err) require.Equal(t, 5*time.Second, cfg.DialTimeout) @@ -121,6 +121,8 @@ func TestConfigTimeouts(t *testing.T) { } func TestCompleteConfigByOpts(t *testing.T) { + replicaCfg := config.GetDefaultReplicaConfig() + cfg := NewConfig() // Normal config. @@ -132,7 +134,7 @@ func TestCompleteConfigByOpts(t *testing.T) { sinkURI, err := url.Parse(uri) require.Nil(t, err) - err = cfg.Apply(sinkURI) + err = cfg.Apply(sinkURI, replicaCfg) require.Nil(t, err) require.Equal(t, int32(1), cfg.PartitionNum) require.Equal(t, int16(3), cfg.ReplicationFactor) @@ -144,7 +146,7 @@ func TestCompleteConfigByOpts(t *testing.T) { sinkURI, err = url.Parse(uri) require.Nil(t, err) cfg = NewConfig() - err = cfg.Apply(sinkURI) + err = cfg.Apply(sinkURI, replicaCfg) require.Nil(t, err) require.Len(t, cfg.BrokerEndpoints, 3) @@ -153,7 +155,7 @@ func TestCompleteConfigByOpts(t *testing.T) { sinkURI, err = url.Parse(uri) require.Nil(t, err) cfg = NewConfig() - err = cfg.Apply(sinkURI) + err = cfg.Apply(sinkURI, replicaCfg) require.Regexp(t, ".*invalid syntax.*", errors.Cause(err)) // Illegal max-message-bytes. @@ -161,7 +163,7 @@ func TestCompleteConfigByOpts(t *testing.T) { sinkURI, err = url.Parse(uri) require.Nil(t, err) cfg = NewConfig() - err = cfg.Apply(sinkURI) + err = cfg.Apply(sinkURI, replicaCfg) require.Regexp(t, ".*invalid syntax.*", errors.Cause(err)) // Illegal partition-num. @@ -169,7 +171,7 @@ func TestCompleteConfigByOpts(t *testing.T) { sinkURI, err = url.Parse(uri) require.Nil(t, err) cfg = NewConfig() - err = cfg.Apply(sinkURI) + err = cfg.Apply(sinkURI, replicaCfg) require.Regexp(t, ".*invalid syntax.*", errors.Cause(err)) // Out of range partition-num. @@ -177,7 +179,7 @@ func TestCompleteConfigByOpts(t *testing.T) { sinkURI, err = url.Parse(uri) require.Nil(t, err) cfg = NewConfig() - err = cfg.Apply(sinkURI) + err = cfg.Apply(sinkURI, replicaCfg) require.Regexp(t, ".*invalid partition num.*", errors.Cause(err)) } @@ -390,7 +392,7 @@ func TestConfigurationCombinations(t *testing.T) { require.Nil(t, err) baseConfig := NewConfig() - err = baseConfig.Apply(sinkURI) + err = baseConfig.Apply(sinkURI, config.GetDefaultReplicaConfig()) require.Nil(t, err) saramaConfig, err := NewSaramaConfig(context.Background(), baseConfig) @@ -428,26 +430,30 @@ func TestApplySASL(t *testing.T) { t.Parallel() tests := []struct { - name string - URI string - exceptErr string + name string + URI string + replicaConfig func() *config.ReplicaConfig + exceptErr string }{ { - name: "no params", - URI: "kafka://127.0.0.1:9092/abc", - exceptErr: "", + name: "no params", + URI: "kafka://127.0.0.1:9092/abc", + replicaConfig: config.GetDefaultReplicaConfig, + exceptErr: "", }, { name: "valid PLAIN SASL", URI: "kafka://127.0.0.1:9092/abc?kafka-version=2.6.0&partition-num=0" + "&sasl-user=user&sasl-password=password&sasl-mechanism=plain", - exceptErr: "", + replicaConfig: config.GetDefaultReplicaConfig, + exceptErr: "", }, { name: "valid SCRAM SASL", URI: "kafka://127.0.0.1:9092/abc?kafka-version=2.6.0&partition-num=0" + "&sasl-user=user&sasl-password=password&sasl-mechanism=SCRAM-SHA-512", - exceptErr: "", + replicaConfig: config.GetDefaultReplicaConfig, + exceptErr: "", }, { name: "valid GSSAPI user auth SASL", @@ -457,7 +463,8 @@ func TestApplySASL(t *testing.T) { "&sasl-gssapi-service-name=a&sasl-gssapi-user=user" + "&sasl-gssapi-password=pwd" + "&sasl-gssapi-realm=realm&sasl-gssapi-disable-pafxfast=false", - exceptErr: "", + replicaConfig: config.GetDefaultReplicaConfig, + exceptErr: "", }, { name: "valid GSSAPI keytab auth SASL", @@ -467,19 +474,136 @@ func TestApplySASL(t *testing.T) { "&sasl-gssapi-service-name=a&sasl-gssapi-user=user" + "&sasl-gssapi-keytab-path=/root/keytab" + "&sasl-gssapi-realm=realm&sasl-gssapi-disable-pafxfast=false", - exceptErr: "", + replicaConfig: config.GetDefaultReplicaConfig, + exceptErr: "", }, { name: "invalid mechanism", URI: "kafka://127.0.0.1:9092/abc?kafka-version=2.6.0&partition-num=0" + "&sasl-mechanism=a", - exceptErr: "unknown a SASL mechanism", + replicaConfig: config.GetDefaultReplicaConfig, + exceptErr: "unknown a SASL mechanism", }, { name: "invalid GSSAPI auth type", URI: "kafka://127.0.0.1:9092/abc?kafka-version=2.6.0&partition-num=0" + "&sasl-mechanism=gssapi&sasl-gssapi-auth-type=keyta1b", - exceptErr: "unknown keyta1b auth type", + replicaConfig: config.GetDefaultReplicaConfig, + exceptErr: "unknown keyta1b auth type", + }, + { + name: "valid OAUTHBEARER SASL", + URI: "kafka://127.0.0.1:9092/abc?kafka-version=2.6.0" + + "&partition-num=0&sasl-mechanism=OAUTHBEARER", + replicaConfig: func() *config.ReplicaConfig { + cfg := config.GetDefaultReplicaConfig() + oauthMechanism := string(security.OAuthMechanism) + clientID := "client_id" + clientSecret := "Y2xpZW50X3NlY3JldA==" // base64(client_secret) + tokenURL := "127.0.0.1:9093/token" + cfg.Sink.KafkaConfig = &config.KafkaConfig{ + SASLMechanism: &oauthMechanism, + SASLOAuthClientID: &clientID, + SASLOAuthClientSecret: &clientSecret, + SASLOAuthTokenURL: &tokenURL, + } + return cfg + }, + exceptErr: "", + }, + { + name: "invalid OAUTHBEARER SASL: missing client id", + URI: "kafka://127.0.0.1:9092/abc?kafka-version=2.6.0" + + "&partition-num=0&sasl-mechanism=OAUTHBEARER", + replicaConfig: func() *config.ReplicaConfig { + cfg := config.GetDefaultReplicaConfig() + oauthMechanism := string(security.OAuthMechanism) + clientSecret := "Y2xpZW50X3NlY3JldA==" // base64(client_secret) + tokenURL := "127.0.0.1:9093/token" + cfg.Sink.KafkaConfig = &config.KafkaConfig{ + SASLMechanism: &oauthMechanism, + SASLOAuthClientSecret: &clientSecret, + SASLOAuthTokenURL: &tokenURL, + } + return cfg + }, + exceptErr: "OAuth2 client id is empty", + }, + { + name: "invalid OAUTHBEARER SASL: missing client secret", + URI: "kafka://127.0.0.1:9092/abc?kafka-version=2.6.0" + + "&partition-num=0&sasl-mechanism=OAUTHBEARER", + replicaConfig: func() *config.ReplicaConfig { + cfg := config.GetDefaultReplicaConfig() + oauthMechanism := string(security.OAuthMechanism) + clientID := "client_id" + tokenURL := "127.0.0.1:9093/token" + cfg.Sink.KafkaConfig = &config.KafkaConfig{ + SASLMechanism: &oauthMechanism, + SASLOAuthClientID: &clientID, + SASLOAuthTokenURL: &tokenURL, + } + return cfg + }, + exceptErr: "OAuth2 client secret is empty", + }, + { + name: "invalid OAUTHBEARER SASL: missing token url", + URI: "kafka://127.0.0.1:9092/abc?kafka-version=2.6.0" + + "&partition-num=0&sasl-mechanism=OAUTHBEARER", + replicaConfig: func() *config.ReplicaConfig { + cfg := config.GetDefaultReplicaConfig() + oauthMechanism := string(security.OAuthMechanism) + clientID := "client_id" + clientSecret := "Y2xpZW50X3NlY3JldA==" // base64(client_secret) + cfg.Sink.KafkaConfig = &config.KafkaConfig{ + SASLMechanism: &oauthMechanism, + SASLOAuthClientID: &clientID, + SASLOAuthClientSecret: &clientSecret, + } + return cfg + }, + exceptErr: "OAuth2 token url is empty", + }, + { + name: "invalid OAUTHBEARER SASL: non base64 client secret", + URI: "kafka://127.0.0.1:9092/abc?kafka-version=2.6.0" + + "&partition-num=0&sasl-mechanism=OAUTHBEARER", + replicaConfig: func() *config.ReplicaConfig { + cfg := config.GetDefaultReplicaConfig() + oauthMechanism := string(security.OAuthMechanism) + clientID := "client_id" + clientSecret := "client_secret" + tokenURL := "127.0.0.1:9093/token" + cfg.Sink.KafkaConfig = &config.KafkaConfig{ + SASLMechanism: &oauthMechanism, + SASLOAuthClientID: &clientID, + SASLOAuthClientSecret: &clientSecret, + SASLOAuthTokenURL: &tokenURL, + } + return cfg + }, + exceptErr: "OAuth2 client secret is not base64 encoded", + }, + { + name: "invalid OAUTHBEARER SASL: wrong mechanism", + URI: "kafka://127.0.0.1:9092/abc?kafka-version=2.6.0" + + "&partition-num=0&sasl-mechanism=GSSAPI", + replicaConfig: func() *config.ReplicaConfig { + cfg := config.GetDefaultReplicaConfig() + oauthMechanism := string(security.OAuthMechanism) + clientID := "client_id" + clientSecret := "Y2xpZW50X3NlY3JldA==" // base64(client_secret) + tokenURL := "127.0.0.1:9093/token" + cfg.Sink.KafkaConfig = &config.KafkaConfig{ + SASLMechanism: &oauthMechanism, + SASLOAuthClientID: &clientID, + SASLOAuthClientSecret: &clientSecret, + SASLOAuthTokenURL: &tokenURL, + } + return cfg + }, + exceptErr: "OAuth2 is only supported with SASL mechanism type OAUTHBEARER", }, } @@ -490,9 +614,11 @@ func TestApplySASL(t *testing.T) { sinkURI, err := url.Parse(test.URI) require.Nil(t, err) if test.exceptErr == "" { - require.Nil(t, cfg.applySASL(sinkURI.Query())) + require.Nil(t, cfg.applySASL( + sinkURI.Query(), test.replicaConfig())) } else { - require.Regexp(t, test.exceptErr, cfg.applySASL(sinkURI.Query()).Error()) + require.Regexp(t, test.exceptErr, cfg.applySASL( + sinkURI.Query(), test.replicaConfig()).Error()) } }) } @@ -567,6 +693,7 @@ func TestApplyTLS(t *testing.T) { func TestCompleteSaramaSASLConfig(t *testing.T) { t.Parallel() + ctx := context.Background() // Test that SASL is turned on correctly. cfg := NewConfig() cfg.SASL = &security.SASL{ @@ -576,10 +703,10 @@ func TestCompleteSaramaSASLConfig(t *testing.T) { GSSAPI: security.GSSAPI{}, } saramaConfig := sarama.NewConfig() - completeSaramaSASLConfig(saramaConfig, cfg) + completeSaramaSASLConfig(ctx, saramaConfig, cfg) require.False(t, saramaConfig.Net.SASL.Enable) cfg.SASL.SASLMechanism = "plain" - completeSaramaSASLConfig(saramaConfig, cfg) + completeSaramaSASLConfig(ctx, saramaConfig, cfg) require.True(t, saramaConfig.Net.SASL.Enable) // Test that the SCRAMClientGeneratorFunc is set up correctly. cfg = NewConfig() @@ -590,9 +717,9 @@ func TestCompleteSaramaSASLConfig(t *testing.T) { GSSAPI: security.GSSAPI{}, } saramaConfig = sarama.NewConfig() - completeSaramaSASLConfig(saramaConfig, cfg) + completeSaramaSASLConfig(ctx, saramaConfig, cfg) require.Nil(t, saramaConfig.Net.SASL.SCRAMClientGeneratorFunc) cfg.SASL.SASLMechanism = "SCRAM-SHA-512" - completeSaramaSASLConfig(saramaConfig, cfg) + completeSaramaSASLConfig(ctx, saramaConfig, cfg) require.NotNil(t, saramaConfig.Net.SASL.SCRAMClientGeneratorFunc) } diff --git a/cdc/sink/mq/producer/kafka/oauth2_token_provider.go b/cdc/sink/mq/producer/kafka/oauth2_token_provider.go new file mode 100644 index 00000000000..268d33bc3ba --- /dev/null +++ b/cdc/sink/mq/producer/kafka/oauth2_token_provider.go @@ -0,0 +1,92 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package kafka + +import ( + "context" + "net/url" + + "github.com/Shopify/sarama" + "github.com/pingcap/errors" + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" +) + +// tsokenProvider is a user-defined callback for generating +// access tokens for SASL/OAUTHBEARER auth. +type tokenProvider struct { + tokenSource oauth2.TokenSource +} + +var _ sarama.AccessTokenProvider = (*tokenProvider)(nil) + +// Token implements the sarama.AccessTokenProvider interface. +// Token returns an access token. The implementation should ensure token +// reuse so that multiple calls at connect time do not create multiple +// tokens. The implementation should also periodically refresh the token in +// order to guarantee that each call returns an unexpired token. This +// method should not block indefinitely--a timeout error should be returned +// after a short period of inactivity so that the broker connection logic +// can log debugging information and retry. +func (t *tokenProvider) Token() (*sarama.AccessToken, error) { + token, err := t.tokenSource.Token() + if err != nil { + // Errors will result in Sarama retrying the broker connection and logging + // the transient error, with a Broker connection error surfacing after retry + // attempts have been exhausted. + return nil, err + } + + return &sarama.AccessToken{Token: token.AccessToken}, nil +} + +func newTokenProvider(ctx context.Context, + kafkaConfig *Config, +) (sarama.AccessTokenProvider, error) { + // grant_type is by default going to be set to 'client_credentials' by the + // clientcredentials library as defined by the spec, however non-compliant + // auth server implementations may want a custom type + var endpointParams url.Values + if kafkaConfig.SASL.OAuth2.GrantType != "" { + if endpointParams == nil { + endpointParams = url.Values{} + } + endpointParams.Set("grant_type", kafkaConfig.SASL.OAuth2.GrantType) + } + + // audience is an optional parameter that can be used to specify the + // intended audience of the token. + if kafkaConfig.SASL.OAuth2.Audience != "" { + if endpointParams == nil { + endpointParams = url.Values{} + } + endpointParams.Set("audience", kafkaConfig.SASL.OAuth2.Audience) + } + + tokenURL, err := url.Parse(kafkaConfig.SASL.OAuth2.TokenURL) + if err != nil { + return nil, errors.Trace(err) + } + + cfg := clientcredentials.Config{ + ClientID: kafkaConfig.SASL.OAuth2.ClientID, + ClientSecret: kafkaConfig.SASL.OAuth2.ClientSecret, + TokenURL: tokenURL.String(), + EndpointParams: endpointParams, + Scopes: kafkaConfig.SASL.OAuth2.Scopes, + } + return &tokenProvider{ + tokenSource: cfg.TokenSource(ctx), + }, nil +} diff --git a/cdc/sink/mq/producer/kafka/oauth2_token_provider_test.go b/cdc/sink/mq/producer/kafka/oauth2_token_provider_test.go new file mode 100644 index 00000000000..f986c22ba07 --- /dev/null +++ b/cdc/sink/mq/producer/kafka/oauth2_token_provider_test.go @@ -0,0 +1,74 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package kafka + +import ( + "context" + "testing" + + "github.com/pingcap/tiflow/pkg/security" + "github.com/stretchr/testify/require" +) + +func TestNewTokenProvider(t *testing.T) { + t.Parallel() + + for _, test := range []struct { + name string + config *Config + expectedErr string + }{ + { + name: "valid", + config: &Config{ + SASL: &security.SASL{ + OAuth2: security.OAuth2{ + ClientID: "client-id", + ClientSecret: "client-secret", + TokenURL: "http://localhost:8080/oauth2/token", + Scopes: []string{"scope1", "scope2"}, + GrantType: "client_credentials", + }, + }, + }, + }, + { + name: "invalid token URL", + config: &Config{ + SASL: &security.SASL{ + OAuth2: security.OAuth2{ + ClientID: "client-id", + ClientSecret: "client-secret", + TokenURL: "http://test.com/Segment%%2815197306101420000%29", + Scopes: []string{"scope1", "scope2"}, + GrantType: "client_credentials", + }, + }, + }, + expectedErr: "invalid URL escape", + }, + } { + ts := test + t.Run(ts.name, func(t *testing.T) { + t.Parallel() + _, err := newTokenProvider(context.TODO(), ts.config) + if ts.expectedErr == "" { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), ts.expectedErr) + } + }) + } +} diff --git a/cdc/sinkv2/ddlsink/mq/kafka_ddl_sink.go b/cdc/sinkv2/ddlsink/mq/kafka_ddl_sink.go index f0b77168894..bd58836366e 100644 --- a/cdc/sinkv2/ddlsink/mq/kafka_ddl_sink.go +++ b/cdc/sinkv2/ddlsink/mq/kafka_ddl_sink.go @@ -45,7 +45,7 @@ func NewKafkaDDLSink( } baseConfig := kafka.NewConfig() - if err := baseConfig.Apply(sinkURI); err != nil { + if err := baseConfig.Apply(sinkURI, replicaConfig); err != nil { return nil, cerror.WrapError(cerror.ErrKafkaInvalidConfig, err) } saramaConfig, err := kafka.NewSaramaConfig(ctx, baseConfig) diff --git a/cdc/sinkv2/eventsink/mq/kafka_dml_sink.go b/cdc/sinkv2/eventsink/mq/kafka_dml_sink.go index c2ac32af15e..1d6b01c6634 100644 --- a/cdc/sinkv2/eventsink/mq/kafka_dml_sink.go +++ b/cdc/sinkv2/eventsink/mq/kafka_dml_sink.go @@ -45,7 +45,7 @@ func NewKafkaDMLSink( } baseConfig := kafka.NewConfig() - if err := baseConfig.Apply(sinkURI); err != nil { + if err := baseConfig.Apply(sinkURI, replicaConfig); err != nil { return nil, cerror.WrapError(cerror.ErrKafkaInvalidConfig, err) } saramaConfig, err := kafka.NewSaramaConfig(ctx, baseConfig) diff --git a/docs/swagger/docs.go b/docs/swagger/docs.go index 436ea53fe22..ca5cd482da6 100644 --- a/docs/swagger/docs.go +++ b/docs/swagger/docs.go @@ -1244,6 +1244,35 @@ var doc = `{ } } }, + "config.KafkaConfig": { + "type": "object", + "properties": { + "sasl-mechanism": { + "type": "string" + }, + "sasl-oauth-audience": { + "type": "string" + }, + "sasl-oauth-client-id": { + "type": "string" + }, + "sasl-oauth-client-secret": { + "type": "string" + }, + "sasl-oauth-grant-type": { + "type": "string" + }, + "sasl-oauth-scopes": { + "type": "array", + "items": { + "type": "string" + } + }, + "sasl-oauth-token-url": { + "type": "string" + } + } + }, "config.SinkConfig": { "type": "object", "properties": { @@ -1274,6 +1303,9 @@ var doc = `{ "file-index-digit": { "type": "integer" }, + "kafka-config": { + "$ref": "#/definitions/config.KafkaConfig" + }, "protocol": { "type": "string" }, diff --git a/docs/swagger/swagger.json b/docs/swagger/swagger.json index 1a21d578bb9..c0c12db0faa 100644 --- a/docs/swagger/swagger.json +++ b/docs/swagger/swagger.json @@ -1225,6 +1225,35 @@ } } }, + "config.KafkaConfig": { + "type": "object", + "properties": { + "sasl-mechanism": { + "type": "string" + }, + "sasl-oauth-audience": { + "type": "string" + }, + "sasl-oauth-client-id": { + "type": "string" + }, + "sasl-oauth-client-secret": { + "type": "string" + }, + "sasl-oauth-grant-type": { + "type": "string" + }, + "sasl-oauth-scopes": { + "type": "array", + "items": { + "type": "string" + } + }, + "sasl-oauth-token-url": { + "type": "string" + } + } + }, "config.SinkConfig": { "type": "object", "properties": { @@ -1255,6 +1284,9 @@ "file-index-digit": { "type": "integer" }, + "kafka-config": { + "$ref": "#/definitions/config.KafkaConfig" + }, "protocol": { "type": "string" }, diff --git a/docs/swagger/swagger.yaml b/docs/swagger/swagger.yaml index 591cd176c4a..ba542b73a74 100644 --- a/docs/swagger/swagger.yaml +++ b/docs/swagger/swagger.yaml @@ -42,6 +42,25 @@ definitions: topic: type: string type: object + config.KafkaConfig: + properties: + sasl-mechanism: + type: string + sasl-oauth-audience: + type: string + sasl-oauth-client-id: + type: string + sasl-oauth-client-secret: + type: string + sasl-oauth-grant-type: + type: string + sasl-oauth-scopes: + items: + type: string + type: array + sasl-oauth-token-url: + type: string + type: object config.SinkConfig: properties: column-selectors: @@ -62,6 +81,8 @@ definitions: type: integer file-index-digit: type: integer + kafka-config: + $ref: '#/definitions/config.KafkaConfig' protocol: type: string schema-registry: diff --git a/go.mod b/go.mod index fc7e3e0fe41..db005a5b473 100644 --- a/go.mod +++ b/go.mod @@ -99,6 +99,7 @@ require ( go.uber.org/zap v1.23.0 golang.org/x/exp v0.0.0-20221023144134-a1e5550cf13e golang.org/x/net v0.2.0 + golang.org/x/oauth2 v0.2.0 golang.org/x/sync v0.1.0 golang.org/x/sys v0.4.0 golang.org/x/text v0.4.0 @@ -274,7 +275,6 @@ require ( go.opentelemetry.io/otel/trace v0.20.0 // indirect go.opentelemetry.io/proto/otlp v0.7.0 // indirect golang.org/x/crypto v0.1.0 // indirect - golang.org/x/oauth2 v0.2.0 // indirect golang.org/x/term v0.2.0 // indirect golang.org/x/tools v0.2.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect diff --git a/pkg/config/sink.go b/pkg/config/sink.go index 9289859caa9..93a164026af 100644 --- a/pkg/config/sink.go +++ b/pkg/config/sink.go @@ -115,12 +115,25 @@ type SinkConfig struct { EnablePartitionSeparator bool `toml:"enable-partition-separator" json:"enable-partition-separator"` FileIndexWidth int `toml:"file-index-digit,omitempty" json:"file-index-digit,omitempty"` + KafkaConfig *KafkaConfig `toml:"kafka-config" json:"kafka-config,omitempty"` + // TiDBSourceID is the source ID of the upstream TiDB, // which is used to set the `tidb_cdc_write_source` session variable. // Note: This field is only used internally and only used in the MySQL sink. TiDBSourceID uint64 `toml:"-" json:"-"` } +// KafkaConfig represents a kafka sink configuration +type KafkaConfig struct { + SASLMechanism *string `toml:"sasl-mechanism" json:"sasl-mechanism,omitempty"` + SASLOAuthClientID *string `toml:"sasl-oauth-client-id" json:"sasl-oauth-client-id,omitempty"` + SASLOAuthClientSecret *string `toml:"sasl-oauth-client-secret" json:"sasl-oauth-client-secret,omitempty"` + SASLOAuthTokenURL *string `toml:"sasl-oauth-token-url" json:"sasl-oauth-token-url,omitempty"` + SASLOAuthScopes []string `toml:"sasl-oauth-scopes" json:"sasl-oauth-scopes,omitempty"` + SASLOAuthGrantType *string `toml:"sasl-oauth-grant-type" json:"sasl-oauth-grant-type,omitempty"` + SASLOAuthAudience *string `toml:"sasl-oauth-audience" json:"sasl-oauth-audience,omitempty"` +} + // CSVConfig defines a series of configuration items for csv codec. type CSVConfig struct { // delimiter between fields diff --git a/pkg/security/sasl.go b/pkg/security/sasl.go index 8077589d547..3be79701b54 100644 --- a/pkg/security/sasl.go +++ b/pkg/security/sasl.go @@ -35,6 +35,8 @@ const ( SCRAM512Mechanism SASLMechanism = sarama.SASLTypeSCRAMSHA512 // GSSAPIMechanism means the SASL mechanism is GSSAPI. GSSAPIMechanism SASLMechanism = sarama.SASLTypeGSSAPI + // OAuthMechanism means the SASL mechanism is OAuth2. + OAuthMechanism SASLMechanism = sarama.SASLTypeOAuth ) // SASLMechanismFromString converts the string to SASL mechanism. @@ -48,6 +50,8 @@ func SASLMechanismFromString(s string) (SASLMechanism, error) { return SCRAM512Mechanism, nil case "gssapi": return GSSAPIMechanism, nil + case "oauthbearer": + return OAuthMechanism, nil default: return UnknownMechanism, errors.Errorf("unknown %s SASL mechanism", s) } @@ -55,10 +59,47 @@ func SASLMechanismFromString(s string) (SASLMechanism, error) { // SASL holds necessary path parameter to support sasl-scram type SASL struct { - SASLUser string `toml:"sasl-user" json:"sasl-user"` - SASLPassword string `toml:"sasl-password" json:"sasl-password"` - SASLMechanism SASLMechanism `toml:"sasl-mechanism" json:"sasl-mechanism"` - GSSAPI GSSAPI `toml:"sasl-gssapi" json:"sasl-gssapi"` + SASLUser string + SASLPassword string + SASLMechanism SASLMechanism + GSSAPI GSSAPI + OAuth2 OAuth2 +} + +// OAuth2 holds necessary parameters to support sasl-oauth2. +type OAuth2 struct { + ClientID string + ClientSecret string + TokenURL string + Scopes []string + GrantType string + Audience string +} + +// Validate validates the parameters of OAuth2. +// Some parameters are required, some are optional. +func (o *OAuth2) Validate() error { + if len(o.ClientID) == 0 { + return errors.New("OAuth2 client id is empty") + } + if len(o.ClientSecret) == 0 { + return errors.New("OAuth2 client secret is empty") + } + if len(o.TokenURL) == 0 { + return errors.New("OAuth2 token url is empty") + } + return nil +} + +// SetDefault sets the default value of OAuth2. +func (o *OAuth2) SetDefault() { + o.GrantType = "client_credentials" +} + +// IsEnable checks whether the OAuth2 is enabled. +// One of values of ClientID, ClientSecret and TokenURL is not empty means enabled. +func (o *OAuth2) IsEnable() bool { + return len(o.ClientID) > 0 || len(o.ClientSecret) > 0 || len(o.TokenURL) > 0 } // GSSAPIAuthType defines the type of GSSAPI authentication.