diff --git a/dm/master/config.go b/dm/master/config.go index e064ec4a28..a771658ff2 100644 --- a/dm/master/config.go +++ b/dm/master/config.go @@ -224,8 +224,9 @@ func (c *Config) configFromFile(path string) error { // adjust adjusts configs func (c *Config) adjust() error { - // MasterAddr's format may be "scheme://host:port", "host:port" or ":port" - host, port, err := net.SplitHostPort(utils.UnwrapScheme(c.MasterAddr)) + c.MasterAddr = utils.UnwrapScheme(c.MasterAddr) + // MasterAddr's format may be "host:port" or ":port" + host, port, err := net.SplitHostPort(c.MasterAddr) if err != nil { return terror.ErrMasterHostPortNotValid.Delegate(err, c.MasterAddr) } @@ -236,8 +237,9 @@ func (c *Config) adjust() error { } c.AdvertiseAddr = c.MasterAddr } else { - // AdvertiseAddr's format may be "scheme://host:port" or "host:port" - host, port, err = net.SplitHostPort(utils.UnwrapScheme(c.AdvertiseAddr)) + c.AdvertiseAddr = utils.UnwrapScheme(c.AdvertiseAddr) + // AdvertiseAddr's format should be "host:port" + host, port, err = net.SplitHostPort(c.AdvertiseAddr) if err != nil { return terror.ErrMasterAdvertiseAddrNotValid.Delegate(err, c.AdvertiseAddr) } @@ -294,10 +296,14 @@ func (c *Config) adjust() error { if c.PeerUrls == "" { c.PeerUrls = defaultPeerUrls + } else { + c.PeerUrls = utils.WrapSchemes(c.PeerUrls, c.SSLCA != "") } if c.AdvertisePeerUrls == "" { c.AdvertisePeerUrls = c.PeerUrls + } else { + c.AdvertisePeerUrls = utils.WrapSchemes(c.AdvertisePeerUrls, c.SSLCA != "") } if c.InitialCluster == "" { @@ -306,12 +312,18 @@ func (c *Config) adjust() error { items[i] = fmt.Sprintf("%s=%s", c.Name, item) } c.InitialCluster = strings.Join(items, ",") + } else { + c.InitialCluster = utils.WrapSchemesForInitialCluster(c.InitialCluster, c.SSLCA != "") } if c.InitialClusterState == "" { c.InitialClusterState = defaultInitialClusterState } + if c.Join != "" { + c.Join = utils.WrapSchemes(c.Join, c.SSLCA != "") + } + return err } diff --git a/dm/worker/config.go b/dm/worker/config.go index 85e7d314c7..9fae84d429 100644 --- a/dm/worker/config.go +++ b/dm/worker/config.go @@ -169,7 +169,8 @@ func (c *Config) Parse(arguments []string) error { // adjust adjusts the config. func (c *Config) adjust() error { - host, port, err := net.SplitHostPort(utils.UnwrapScheme(c.WorkerAddr)) + c.WorkerAddr = utils.UnwrapScheme(c.WorkerAddr) + host, port, err := net.SplitHostPort(c.WorkerAddr) if err != nil { return terror.ErrWorkerHostPortNotValid.Delegate(err, c.WorkerAddr) } @@ -180,7 +181,8 @@ func (c *Config) adjust() error { } c.AdvertiseAddr = c.WorkerAddr } else { - host, port, err = net.SplitHostPort(utils.UnwrapScheme(c.AdvertiseAddr)) + c.AdvertiseAddr = utils.UnwrapScheme(c.AdvertiseAddr) + host, port, err = net.SplitHostPort(c.AdvertiseAddr) if err != nil { return terror.ErrWorkerHostPortNotValid.Delegate(err, c.AdvertiseAddr) } @@ -194,6 +196,10 @@ func (c *Config) adjust() error { c.Name = c.AdvertiseAddr } + if c.Join != "" { + c.Join = utils.WrapSchemes(c.Join, c.SSLCA != "") + } + return nil } diff --git a/dm/worker/join.go b/dm/worker/join.go index 1e3eaf84fc..7676433140 100644 --- a/dm/worker/join.go +++ b/dm/worker/join.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/dm/pkg/ha" "github.com/pingcap/dm/pkg/log" "github.com/pingcap/dm/pkg/terror" + "github.com/pingcap/dm/pkg/utils" ) // GetJoinURLs gets the endpoints from the join address. @@ -53,7 +54,7 @@ func (s *Server) JoinMaster(endpoints []string) error { for _, endpoint := range endpoints { ctx1, cancel1 := context.WithTimeout(ctx, 3*time.Second) - conn, err := grpc.DialContext(ctx1, endpoint, grpc.WithBlock(), tls.ToGRPCDialOption(), grpc.WithBackoffMaxDelay(3*time.Second)) + conn, err := grpc.DialContext(ctx1, utils.UnwrapScheme(endpoint), grpc.WithBlock(), tls.ToGRPCDialOption(), grpc.WithBackoffMaxDelay(3*time.Second)) cancel1() if err != nil { if conn != nil { diff --git a/pkg/utils/util.go b/pkg/utils/util.go index a80d510b38..6852c5074a 100644 --- a/pkg/utils/util.go +++ b/pkg/utils/util.go @@ -180,3 +180,43 @@ func UnwrapScheme(s string) string { } return s } + +func wrapScheme(s string, https bool) string { + if s == "" { + return s + } + if strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://") { + return s + } + if https { + return "https://" + s + } + return "http://" + s +} + +// WrapSchemes adds http or https scheme to input if missing. input could be a comma-separated list +// if input has wrong scheme, don't correct it (maybe user deliberately?) +func WrapSchemes(s string, https bool) string { + items := strings.Split(s, ",") + output := make([]string, 0, len(items)) + for _, s := range items { + output = append(output, wrapScheme(s, https)) + } + return strings.Join(output, ",") +} + +// WrapSchemesForInitialCluster acts like WrapSchemes, except input is "name=URL,..." +func WrapSchemesForInitialCluster(s string, https bool) string { + items := strings.Split(s, ",") + output := make([]string, 0, len(items)) + for _, item := range items { + kv := strings.Split(item, "=") + if len(kv) != 2 { + output = append(output, item) + continue + } + + output = append(output, kv[0]+"="+wrapScheme(kv[1], https)) + } + return strings.Join(output, ",") +} diff --git a/pkg/utils/util_test.go b/pkg/utils/util_test.go index 0a259ec811..156e381b4c 100644 --- a/pkg/utils/util_test.go +++ b/pkg/utils/util_test.go @@ -216,8 +216,46 @@ func (t *testUtilsSuite) TestUnwrapScheme(c *C) { "httpsdfpoje.com", "httpsdfpoje.com", }, + { + "", + "", + }, } for _, ca := range cases { c.Assert(UnwrapScheme(ca.old), Equals, ca.new) } } + +func (t *testUtilsSuite) TestWrapSchemes(c *C) { + cases := []struct { + old string + http string + https string + }{ + { + "0.0.0.0:123", + "http://0.0.0.0:123", + "https://0.0.0.0:123", + }, + { + "abc.com:123", + "http://abc.com:123", + "https://abc.com:123", + }, + { + // if input has wrong scheme, don't correct it (maybe user deliberately?) + "abc.com:123,http://abc.com:123,0.0.0.0:123,https://0.0.0.0:123", + "http://abc.com:123,http://abc.com:123,http://0.0.0.0:123,https://0.0.0.0:123", + "https://abc.com:123,http://abc.com:123,https://0.0.0.0:123,https://0.0.0.0:123", + }, + { + "", + "", + "", + }, + } + for _, ca := range cases { + c.Assert(WrapSchemes(ca.old, false), Equals, ca.http) + c.Assert(WrapSchemes(ca.old, true), Equals, ca.https) + } +} diff --git a/tests/ha_master/conf/dm-worker1.toml b/tests/ha_master/conf/dm-worker1.toml index 75b251839c..cb4033224e 100644 --- a/tests/ha_master/conf/dm-worker1.toml +++ b/tests/ha_master/conf/dm-worker1.toml @@ -1,2 +1,2 @@ name = "worker1" -join = "localhost:8261,localhost:8361,localhost:8461,localhost:8561,localhost:8661" +join = "localhost:8261,http://localhost:8361,localhost:8461,localhost:8561,localhost:8661"