Skip to content

Commit

Permalink
bootstrapDNS not used for upstream DNS resolution (#242) (#246)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xERR0R authored Aug 21, 2021
1 parent 7ea4b54 commit 850baf0
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 68 deletions.
9 changes: 4 additions & 5 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
//nolint:gochecknoglobals
var (
configPath string
cfg config.Config
apiHost string
apiPort uint16
)
Expand Down Expand Up @@ -58,11 +57,11 @@ func init() {
}

func initConfig() {
cfg = config.NewConfig(configPath, false)
log.ConfigureLogger(cfg.LogLevel, cfg.LogFormat, cfg.LogTimestamp)
config.LoadConfig(configPath, false)
log.ConfigureLogger(config.GetConfig().LogLevel, config.GetConfig().LogFormat, config.GetConfig().LogTimestamp)

if cfg.HTTPPort != "" {
split := strings.Split(cfg.HTTPPort, ":")
if config.GetConfig().HTTPPort != "" {
split := strings.Split(config.GetConfig().HTTPPort, ":")

var p uint64
p, err := strconv.ParseUint(strings.TrimSpace(split[len(split)-1]), 10, 16)
Expand Down
38 changes: 7 additions & 31 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ import (
"blocky/evt"
"blocky/server"
"blocky/util"
"context"
"fmt"
"net"
"net/http"
"os"
"os/signal"
Expand Down Expand Up @@ -36,17 +33,17 @@ func newServeCommand() *cobra.Command {
func startServer(_ *cobra.Command, _ []string) {
printBanner()

cfg = config.NewConfig(configPath, true)
log.ConfigureLogger(cfg.LogLevel, cfg.LogFormat, cfg.LogTimestamp)
config.LoadConfig(configPath, true)
log.ConfigureLogger(config.GetConfig().LogLevel, config.GetConfig().LogFormat, config.GetConfig().LogTimestamp)

configureHTTPClient(&cfg)
configureHTTPClient(config.GetConfig())

signals := make(chan os.Signal, 1)
done = make(chan bool, 1)

signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

srv, err := server.NewServer(&cfg)
srv, err := server.NewServer(config.GetConfig())
util.FatalOnError("cant start server: ", err)

srv.Start()
Expand All @@ -63,30 +60,9 @@ func startServer(_ *cobra.Command, _ []string) {
}

func configureHTTPClient(cfg *config.Config) {
if cfg.BootstrapDNS != (config.Upstream{}) {
if cfg.BootstrapDNS.Net == config.NetTCPUDP {
dns := net.JoinHostPort(cfg.BootstrapDNS.Host, fmt.Sprint(cfg.BootstrapDNS.Port))
log.Log().Debugf("using %s as bootstrap dns server", dns)

r := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Millisecond * time.Duration(2000),
}
return d.DialContext(ctx, "udp", dns)
}}

http.DefaultTransport = &http.Transport{
Dial: (&net.Dialer{
Timeout: 5 * time.Second,
Resolver: r,
}).Dial,
TLSHandshakeTimeout: 5 * time.Second,
}
} else {
log.Log().Fatal("bootstrap dns net should be tcp+udp")
}
http.DefaultTransport = &http.Transport{
Dial: (util.Dialer(cfg)).Dial,
TLSHandshakeTimeout: 5 * time.Second,
}
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
var _ = Describe("Serve command", func() {
When("Serve command is called", func() {
It("should start DNS server", func() {
cfg.BootstrapDNS = config.Upstream{
config.GetConfig().BootstrapDNS = config.Upstream{
Net: "tcp+udp",
Host: "1.1.1.1",
Port: 53,
Expand Down
17 changes: 13 additions & 4 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,11 @@ type QueryLogConfig struct {
LogRetentionDays uint64 `yaml:"logRetentionDays"`
}

// NewConfig creates new config from YAML file
func NewConfig(path string, mandatory bool) Config {
// nolint:gochecknoglobals
var config = &Config{}

// LoadConfig creates new config from YAML file
func LoadConfig(path string, mandatory bool) {
cfg := Config{}
setDefaultValues(&cfg)

Expand All @@ -296,7 +299,8 @@ func NewConfig(path string, mandatory bool) Config {
if errors.Is(err, os.ErrNotExist) && !mandatory {
// config file does not exist
// return config with default values
return cfg
config = &cfg
return
}

log.Log().Fatal("Can't read config file: ", err)
Expand All @@ -311,7 +315,12 @@ func NewConfig(path string, mandatory bool) Config {
log.Log().Fatal("LogFormat should be 'text' or 'json'")
}

return cfg
config = &cfg
}

// GetConfig returns the current config
func GetConfig() *Config {
return config
}

func setDefaultValues(cfg *Config) {
Expand Down
54 changes: 27 additions & 27 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,30 @@ var _ = Describe("Config", func() {
err := os.Chdir("../testdata")
Expect(err).Should(Succeed())

cfg := NewConfig("config.yml", true)

Expect(cfg.Port).Should(Equal("55555"))
Expect(cfg.Upstream.ExternalResolvers["default"]).Should(HaveLen(3))
Expect(cfg.Upstream.ExternalResolvers["default"][0].Host).Should(Equal("8.8.8.8"))
Expect(cfg.Upstream.ExternalResolvers["default"][1].Host).Should(Equal("8.8.4.4"))
Expect(cfg.Upstream.ExternalResolvers["default"][2].Host).Should(Equal("1.1.1.1"))
Expect(cfg.CustomDNS.Mapping.HostIPs).Should(HaveLen(2))
Expect(cfg.CustomDNS.Mapping.HostIPs["my.duckdns.org"][0]).Should(Equal(net.ParseIP("192.168.178.3")))
Expect(cfg.CustomDNS.Mapping.HostIPs["multiple.ips"][0]).Should(Equal(net.ParseIP("192.168.178.3")))
Expect(cfg.CustomDNS.Mapping.HostIPs["multiple.ips"][1]).Should(Equal(net.ParseIP("192.168.178.4")))
Expect(cfg.CustomDNS.Mapping.HostIPs["multiple.ips"][2]).Should(Equal(
LoadConfig("config.yml", true)

Expect(config.Port).Should(Equal("55555"))
Expect(config.Upstream.ExternalResolvers["default"]).Should(HaveLen(3))
Expect(config.Upstream.ExternalResolvers["default"][0].Host).Should(Equal("8.8.8.8"))
Expect(config.Upstream.ExternalResolvers["default"][1].Host).Should(Equal("8.8.4.4"))
Expect(config.Upstream.ExternalResolvers["default"][2].Host).Should(Equal("1.1.1.1"))
Expect(config.CustomDNS.Mapping.HostIPs).Should(HaveLen(2))
Expect(config.CustomDNS.Mapping.HostIPs["my.duckdns.org"][0]).Should(Equal(net.ParseIP("192.168.178.3")))
Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][0]).Should(Equal(net.ParseIP("192.168.178.3")))
Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][1]).Should(Equal(net.ParseIP("192.168.178.4")))
Expect(config.CustomDNS.Mapping.HostIPs["multiple.ips"][2]).Should(Equal(
net.ParseIP("2001:0db8:85a3:08d3:1319:8a2e:0370:7344")))
Expect(cfg.Conditional.Mapping.Upstreams).Should(HaveLen(2))
Expect(cfg.Conditional.Mapping.Upstreams["fritz.box"]).Should(HaveLen(1))
Expect(cfg.Conditional.Mapping.Upstreams["multiple.resolvers"]).Should(HaveLen(2))
Expect(cfg.ClientLookup.Upstream.Host).Should(Equal("192.168.178.1"))
Expect(cfg.ClientLookup.SingleNameOrder).Should(Equal([]uint{2, 1}))
Expect(cfg.Blocking.BlackLists).Should(HaveLen(2))
Expect(cfg.Blocking.WhiteLists).Should(HaveLen(1))
Expect(cfg.Blocking.ClientGroupsBlock).Should(HaveLen(2))

Expect(cfg.Caching.MaxCachingTime).Should(Equal(0))
Expect(cfg.Caching.MinCachingTime).Should(Equal(0))
Expect(config.Conditional.Mapping.Upstreams).Should(HaveLen(2))
Expect(config.Conditional.Mapping.Upstreams["fritz.box"]).Should(HaveLen(1))
Expect(config.Conditional.Mapping.Upstreams["multiple.resolvers"]).Should(HaveLen(2))
Expect(config.ClientLookup.Upstream.Host).Should(Equal("192.168.178.1"))
Expect(config.ClientLookup.SingleNameOrder).Should(Equal([]uint{2, 1}))
Expect(config.Blocking.BlackLists).Should(HaveLen(2))
Expect(config.Blocking.WhiteLists).Should(HaveLen(1))
Expect(config.Blocking.ClientGroupsBlock).Should(HaveLen(2))

Expect(config.Caching.MaxCachingTime).Should(Equal(0))
Expect(config.Caching.MinCachingTime).Should(Equal(0))
})
})
When("config file is malformed", func() {
Expand All @@ -61,7 +61,7 @@ var _ = Describe("Config", func() {

Log().ExitFunc = func(int) { fatal = true }

_ = NewConfig("config.yml", true)
LoadConfig("config.yml", true)
Expect(fatal).Should(BeTrue())
})
})
Expand All @@ -75,7 +75,7 @@ var _ = Describe("Config", func() {
var fatal bool

Log().ExitFunc = func(int) { fatal = true }
_ = NewConfig("config.yml", true)
LoadConfig("config.yml", true)

Expect(fatal).Should(BeTrue())
})
Expand All @@ -84,9 +84,9 @@ var _ = Describe("Config", func() {
err := os.Chdir("../..")
Expect(err).Should(Succeed())

cfg := NewConfig("config.yml", false)
LoadConfig("config.yml", false)

Expect(cfg.LogLevel).Should(Equal("info"))
Expect(config.LogLevel).Should(Equal("info"))
})
})
})
Expand Down
7 changes: 7 additions & 0 deletions resolver/upstream_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ func createUpstreamClient(cfg config.Upstream) (client upstreamClient, upstreamU
if cfg.Net == config.NetHTTPS {
return &httpUpstreamClient{
client: &http.Client{
Transport: &http.Transport{
Dial: (util.Dialer(config.GetConfig())).Dial,
TLSHandshakeTimeout: 5 * time.Second,
},
Timeout: defaultTimeout,
},
}, fmt.Sprintf("%s://%s:%d%s", cfg.Net, cfg.Host, cfg.Port, cfg.Path)
Expand All @@ -56,6 +60,7 @@ func createUpstreamClient(cfg config.Upstream) (client upstreamClient, upstreamU
tcpClient: &dns.Client{
Net: cfg.Net,
Timeout: defaultTimeout,
Dialer: util.Dialer(config.GetConfig()),
},
}, net.JoinHostPort(cfg.Host, strconv.Itoa(int(cfg.Port)))
}
Expand All @@ -65,10 +70,12 @@ func createUpstreamClient(cfg config.Upstream) (client upstreamClient, upstreamU
tcpClient: &dns.Client{
Net: "tcp",
Timeout: defaultTimeout,
Dialer: util.Dialer(config.GetConfig()),
},
udpClient: &dns.Client{
Net: "udp",
Timeout: defaultTimeout,
Dialer: util.Dialer(config.GetConfig()),
},
}, net.JoinHostPort(cfg.Host, strconv.Itoa(int(cfg.Port)))
}
Expand Down
39 changes: 39 additions & 0 deletions util/bootstrap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package util

import (
"blocky/config"
"blocky/log"

"context"
"fmt"
"net"
"time"
)

// Dialer creates a new dialer instance with bootstrap DNS as resolver
func Dialer(cfg *config.Config) *net.Dialer {
var resolver *net.Resolver

if cfg.BootstrapDNS != (config.Upstream{}) {
if cfg.BootstrapDNS.Net == config.NetTCPUDP {
dns := net.JoinHostPort(cfg.BootstrapDNS.Host, fmt.Sprint(cfg.BootstrapDNS.Port))
log.Log().Debugf("using %s as bootstrap dns server", dns)

resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Millisecond * time.Duration(2000),
}
return d.DialContext(ctx, "udp", dns)
}}
} else {
log.Log().Fatal("bootstrap dns net should be tcp+udp")
}
}

return &net.Dialer{
Timeout: 5 * time.Second,
Resolver: resolver,
}
}

0 comments on commit 850baf0

Please sign in to comment.