Skip to content
This repository has been archived by the owner on Nov 24, 2023. It is now read-only.

Commit

Permalink
config: check and correct format of addr and URL (#937)
Browse files Browse the repository at this point in the history
* config: check and correct format of addr and URL

* fix hound

* fix typo

* fix test

* fix wrong use of join

Co-authored-by: GMHDBJD <35025882+GMHDBJD@users.noreply.github.com>
  • Loading branch information
lance6716 and GMHDBJD authored Aug 28, 2020
1 parent 925601c commit 92b3e58
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 8 deletions.
20 changes: 16 additions & 4 deletions dm/master/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,9 @@ func (c *Config) configFromFile(path string) error {

// adjust adjusts configs
func (c *Config) adjust() error {
// MasterAddr's format may be "scheme://host:port", "host:port" or ":port"
host, port, err := net.SplitHostPort(utils.UnwrapScheme(c.MasterAddr))
c.MasterAddr = utils.UnwrapScheme(c.MasterAddr)
// MasterAddr's format may be "host:port" or ":port"
host, port, err := net.SplitHostPort(c.MasterAddr)
if err != nil {
return terror.ErrMasterHostPortNotValid.Delegate(err, c.MasterAddr)
}
Expand All @@ -236,8 +237,9 @@ func (c *Config) adjust() error {
}
c.AdvertiseAddr = c.MasterAddr
} else {
// AdvertiseAddr's format may be "scheme://host:port" or "host:port"
host, port, err = net.SplitHostPort(utils.UnwrapScheme(c.AdvertiseAddr))
c.AdvertiseAddr = utils.UnwrapScheme(c.AdvertiseAddr)
// AdvertiseAddr's format should be "host:port"
host, port, err = net.SplitHostPort(c.AdvertiseAddr)
if err != nil {
return terror.ErrMasterAdvertiseAddrNotValid.Delegate(err, c.AdvertiseAddr)
}
Expand Down Expand Up @@ -294,10 +296,14 @@ func (c *Config) adjust() error {

if c.PeerUrls == "" {
c.PeerUrls = defaultPeerUrls
} else {
c.PeerUrls = utils.WrapSchemes(c.PeerUrls, c.SSLCA != "")
}

if c.AdvertisePeerUrls == "" {
c.AdvertisePeerUrls = c.PeerUrls
} else {
c.AdvertisePeerUrls = utils.WrapSchemes(c.AdvertisePeerUrls, c.SSLCA != "")
}

if c.InitialCluster == "" {
Expand All @@ -306,12 +312,18 @@ func (c *Config) adjust() error {
items[i] = fmt.Sprintf("%s=%s", c.Name, item)
}
c.InitialCluster = strings.Join(items, ",")
} else {
c.InitialCluster = utils.WrapSchemesForInitialCluster(c.InitialCluster, c.SSLCA != "")
}

if c.InitialClusterState == "" {
c.InitialClusterState = defaultInitialClusterState
}

if c.Join != "" {
c.Join = utils.WrapSchemes(c.Join, c.SSLCA != "")
}

return err
}

Expand Down
10 changes: 8 additions & 2 deletions dm/worker/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ func (c *Config) Parse(arguments []string) error {

// adjust adjusts the config.
func (c *Config) adjust() error {
host, port, err := net.SplitHostPort(utils.UnwrapScheme(c.WorkerAddr))
c.WorkerAddr = utils.UnwrapScheme(c.WorkerAddr)
host, port, err := net.SplitHostPort(c.WorkerAddr)
if err != nil {
return terror.ErrWorkerHostPortNotValid.Delegate(err, c.WorkerAddr)
}
Expand All @@ -180,7 +181,8 @@ func (c *Config) adjust() error {
}
c.AdvertiseAddr = c.WorkerAddr
} else {
host, port, err = net.SplitHostPort(utils.UnwrapScheme(c.AdvertiseAddr))
c.AdvertiseAddr = utils.UnwrapScheme(c.AdvertiseAddr)
host, port, err = net.SplitHostPort(c.AdvertiseAddr)
if err != nil {
return terror.ErrWorkerHostPortNotValid.Delegate(err, c.AdvertiseAddr)
}
Expand All @@ -194,6 +196,10 @@ func (c *Config) adjust() error {
c.Name = c.AdvertiseAddr
}

if c.Join != "" {
c.Join = utils.WrapSchemes(c.Join, c.SSLCA != "")
}

return nil
}

Expand Down
3 changes: 2 additions & 1 deletion dm/worker/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/dm/pkg/ha"
"github.com/pingcap/dm/pkg/log"
"github.com/pingcap/dm/pkg/terror"
"github.com/pingcap/dm/pkg/utils"
)

// GetJoinURLs gets the endpoints from the join address.
Expand All @@ -53,7 +54,7 @@ func (s *Server) JoinMaster(endpoints []string) error {

for _, endpoint := range endpoints {
ctx1, cancel1 := context.WithTimeout(ctx, 3*time.Second)
conn, err := grpc.DialContext(ctx1, endpoint, grpc.WithBlock(), tls.ToGRPCDialOption(), grpc.WithBackoffMaxDelay(3*time.Second))
conn, err := grpc.DialContext(ctx1, utils.UnwrapScheme(endpoint), grpc.WithBlock(), tls.ToGRPCDialOption(), grpc.WithBackoffMaxDelay(3*time.Second))
cancel1()
if err != nil {
if conn != nil {
Expand Down
40 changes: 40 additions & 0 deletions pkg/utils/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,43 @@ func UnwrapScheme(s string) string {
}
return s
}

func wrapScheme(s string, https bool) string {
if s == "" {
return s
}
if strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://") {
return s
}
if https {
return "https://" + s
}
return "http://" + s
}

// WrapSchemes adds http or https scheme to input if missing. input could be a comma-separated list
// if input has wrong scheme, don't correct it (maybe user deliberately?)
func WrapSchemes(s string, https bool) string {
items := strings.Split(s, ",")
output := make([]string, 0, len(items))
for _, s := range items {
output = append(output, wrapScheme(s, https))
}
return strings.Join(output, ",")
}

// WrapSchemesForInitialCluster acts like WrapSchemes, except input is "name=URL,..."
func WrapSchemesForInitialCluster(s string, https bool) string {
items := strings.Split(s, ",")
output := make([]string, 0, len(items))
for _, item := range items {
kv := strings.Split(item, "=")
if len(kv) != 2 {
output = append(output, item)
continue
}

output = append(output, kv[0]+"="+wrapScheme(kv[1], https))
}
return strings.Join(output, ",")
}
38 changes: 38 additions & 0 deletions pkg/utils/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,46 @@ func (t *testUtilsSuite) TestUnwrapScheme(c *C) {
"httpsdfpoje.com",
"httpsdfpoje.com",
},
{
"",
"",
},
}
for _, ca := range cases {
c.Assert(UnwrapScheme(ca.old), Equals, ca.new)
}
}

func (t *testUtilsSuite) TestWrapSchemes(c *C) {
cases := []struct {
old string
http string
https string
}{
{
"0.0.0.0:123",
"http://0.0.0.0:123",
"https://0.0.0.0:123",
},
{
"abc.com:123",
"http://abc.com:123",
"https://abc.com:123",
},
{
// if input has wrong scheme, don't correct it (maybe user deliberately?)
"abc.com:123,http://abc.com:123,0.0.0.0:123,https://0.0.0.0:123",
"http://abc.com:123,http://abc.com:123,http://0.0.0.0:123,https://0.0.0.0:123",
"https://abc.com:123,http://abc.com:123,https://0.0.0.0:123,https://0.0.0.0:123",
},
{
"",
"",
"",
},
}
for _, ca := range cases {
c.Assert(WrapSchemes(ca.old, false), Equals, ca.http)
c.Assert(WrapSchemes(ca.old, true), Equals, ca.https)
}
}
2 changes: 1 addition & 1 deletion tests/ha_master/conf/dm-worker1.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
name = "worker1"
join = "localhost:8261,localhost:8361,localhost:8461,localhost:8561,localhost:8661"
join = "localhost:8261,http://localhost:8361,localhost:8461,localhost:8561,localhost:8661"

0 comments on commit 92b3e58

Please sign in to comment.