diff --git a/dm/master/config.go b/dm/master/config.go index 937f3ccf64..eccff05922 100644 --- a/dm/master/config.go +++ b/dm/master/config.go @@ -58,6 +58,7 @@ func NewConfig() *Config { fs.BoolVar(&cfg.printSampleConfig, "print-sample-config", false, "print sample config file of dm-worker") fs.StringVar(&cfg.ConfigFile, "config", "", "path to config file") fs.StringVar(&cfg.MasterAddr, "master-addr", "", "master API server and status addr") + fs.StringVar(&cfg.AdvertiseAddr, "advertise-addr", "", `advertise address for client traffic (default "${master-addr}")`) fs.StringVar(&cfg.LogLevel, "L", "info", "log level: debug, info, warn, error, fatal") fs.StringVar(&cfg.LogFile, "log-file", "", "log file path") //fs.StringVar(&cfg.LogRotate, "log-rotate", "day", "log file rotate type, hour/day") @@ -202,18 +203,25 @@ func (c *Config) configFromFile(path string) error { // adjust adjusts configs func (c *Config) adjust() error { // MasterAddr's format may be "host:port" or ":port" - _, _, err := net.SplitHostPort(c.MasterAddr) + host, port, err := net.SplitHostPort(c.MasterAddr) if err != nil { return terror.ErrMasterHostPortNotValid.Delegate(err, c.MasterAddr) } - // AdvertiseAddr's format must be "host:port" - host, port, err := net.SplitHostPort(c.AdvertiseAddr) - if err != nil { - return terror.ErrMasterAdvertiseAddrNotValid.Delegate(err, c.AdvertiseAddr) - } - if len(host) == 0 || len(port) == 0 { - return terror.ErrMasterAdvertiseAddrNotValid.Generate(c.AdvertiseAddr) + if c.AdvertiseAddr == "" { + if host == "" || host == "0.0.0.0" || len(port) == 0 { + return terror.ErrMasterHostPortNotValid.Generatef("master-addr (%s) must include the 'host' part (should not be '0.0.0.0') when advertise-addr is not set", c.MasterAddr) + } + c.AdvertiseAddr = c.MasterAddr + } else { + // AdvertiseAddr's format must be "host:port" + host, port, err = net.SplitHostPort(c.AdvertiseAddr) + if err != nil { + return terror.ErrMasterAdvertiseAddrNotValid.Delegate(err, c.AdvertiseAddr) + } + if len(host) == 0 || host == "0.0.0.0" || len(port) == 0 { + return terror.ErrMasterAdvertiseAddrNotValid.Generate(c.AdvertiseAddr) + } } c.DeployMap = make(map[string]string) diff --git a/dm/master/config_test.go b/dm/master/config_test.go index 0699c51071..ed8a83efac 100644 --- a/dm/master/config_test.go +++ b/dm/master/config_test.go @@ -347,3 +347,23 @@ func (t *testConfigSuite) TestParseURLs(c *check.C) { } } } + +func (t *testConfigSuite) TestAdjustAddr(c *check.C) { + cfg := NewConfig() + c.Assert(cfg.configFromFile(defaultConfigFile), check.IsNil) + c.Assert(cfg.adjust(), check.IsNil) + + // invalid `advertise-addr` + cfg.AdvertiseAddr = "127.0.0.1" + c.Assert(terror.ErrMasterAdvertiseAddrNotValid.Equal(cfg.adjust()), check.IsTrue) + cfg.AdvertiseAddr = "0.0.0.0:8261" + c.Assert(terror.ErrMasterAdvertiseAddrNotValid.Equal(cfg.adjust()), check.IsTrue) + + // clear `advertise-addr`, still invalid because no `host` in `master-addr`. + cfg.AdvertiseAddr = "" + c.Assert(terror.ErrMasterHostPortNotValid.Equal(cfg.adjust()), check.IsTrue) + + cfg.MasterAddr = "127.0.0.1:8261" + c.Assert(cfg.adjust(), check.IsNil) + c.Assert(cfg.AdvertiseAddr, check.Equals, cfg.MasterAddr) +} diff --git a/dm/worker/config.go b/dm/worker/config.go index 42ad05719f..3cc7c3c3b2 100644 --- a/dm/worker/config.go +++ b/dm/worker/config.go @@ -157,7 +157,7 @@ func (c *Config) Parse(arguments []string) error { // adjust adjusts the config. func (c *Config) adjust() error { - host, _, err := net.SplitHostPort(c.WorkerAddr) + host, port, err := net.SplitHostPort(c.WorkerAddr) if err != nil { return terror.ErrWorkerHostPortNotValid.Delegate(err, c.WorkerAddr) } @@ -168,9 +168,12 @@ func (c *Config) adjust() error { } c.AdvertiseAddr = c.WorkerAddr } else { - host, _, err = net.SplitHostPort(c.AdvertiseAddr) - if err != nil || host == "" || host == "0.0.0.0" { - return terror.ErrWorkerHostPortNotValid.AnnotateDelegate(err, "advertise-addr (%s) must include the 'host' part and should not be '0.0.0.0'", c.AdvertiseAddr) + host, port, err = net.SplitHostPort(c.AdvertiseAddr) + if err != nil { + return terror.ErrWorkerHostPortNotValid.Delegate(err, c.AdvertiseAddr) + } + if host == "" || host == "0.0.0.0" || len(port) == 0 { + return terror.ErrWorkerHostPortNotValid.Generate("advertise-addr (%s) must include the 'host' part and should not be '0.0.0.0'", c.AdvertiseAddr) } } diff --git a/dm/worker/config_test.go b/dm/worker/config_test.go index bbed7ad569..8c3393fc0d 100644 --- a/dm/worker/config_test.go +++ b/dm/worker/config_test.go @@ -12,3 +12,37 @@ // limitations under the License. package worker + +import ( + "github.com/pingcap/check" + + "github.com/pingcap/dm/pkg/terror" +) + +var ( + defaultConfigFile = "./dm-worker.toml" + _ = check.Suite(&testConfigSuite{}) +) + +type testConfigSuite struct { +} + +func (t *testConfigSuite) TestAdjustAddr(c *check.C) { + cfg := NewConfig() + c.Assert(cfg.configFromFile(defaultConfigFile), check.IsNil) + c.Assert(cfg.adjust(), check.IsNil) + + // invalid `advertise-addr` + cfg.AdvertiseAddr = "127.0.0.1" + c.Assert(terror.ErrWorkerHostPortNotValid.Equal(cfg.adjust()), check.IsTrue) + cfg.AdvertiseAddr = "0.0.0.0:8262" + c.Assert(terror.ErrWorkerHostPortNotValid.Equal(cfg.adjust()), check.IsTrue) + + // clear `advertise-addr`, still invalid because no `host` in `worker-addr`. + cfg.AdvertiseAddr = "" + c.Assert(terror.ErrWorkerHostPortNotValid.Equal(cfg.adjust()), check.IsTrue) + + cfg.WorkerAddr = "127.0.0.1:8262" + c.Assert(cfg.adjust(), check.IsNil) + c.Assert(cfg.AdvertiseAddr, check.Equals, cfg.WorkerAddr) +}