diff --git a/agent/config.go b/agent/config.go index 9052e94..b691a22 100644 --- a/agent/config.go +++ b/agent/config.go @@ -36,7 +36,7 @@ type CoreConfig struct { SuffixFile string SuffixDir string NoDefaults bool - Forward Forward + Forward *Forward } // Generate a CoreDNS (Caddy) style configuration block as a string. @@ -63,7 +63,8 @@ func ConfigFromEnv(e env.Environment) *CoreConfig { upstream2 string ) - var cc CoreConfig + cc := new(CoreConfig) + cc.Forward = new(Forward) if err := env.Parse(e, env.Schema{ "DONUT_DNS_PORT": env.Int(&cc.Port, false), "DONUT_DNS_NO_DEBUG": env.Bool(&cc.NoDebug, false), @@ -85,20 +86,18 @@ func ConfigFromEnv(e env.Environment) *CoreConfig { panic(err) } - var upstreams []string if upstream1 != "" { - upstreams = append(upstreams, upstream1) + cc.Forward.Addresses = append(cc.Forward.Addresses, upstream1) } if upstream2 != "" { - upstreams = append(upstreams, upstream2) + cc.Forward.Addresses = append(cc.Forward.Addresses, upstream2) } - cc.Forward.Addresses = upstreams cc.Allows = split(allow) cc.Blocks = split(block) cc.Suffix = split(suffix) - return &cc + return cc } // Log cc to plog. diff --git a/agent/config.tmpl b/agent/config.tmpl index fe4d25b..69971e5 100644 --- a/agent/config.tmpl +++ b/agent/config.tmpl @@ -12,6 +12,13 @@ {{end}} {{range .Suffix}}suffix {{.}} {{end}} + {{if eq (len .Forward.Addresses) 2 }} + upstream_1 {{index .Forward.Addresses 0}} + upstream_2 {{index .Forward.Addresses 1}} + {{else}} + upstream_1 {{index .Forward.Addresses 0}} + {{end}} + {{if .Forward.ServerName}}forward_server_name {{.Forward.ServerName}}{{end}} } forward . {{range .Forward.Addresses}}{{.}} {{end}}{ tls_servername {{.Forward.ServerName}} diff --git a/agent/config_test.go b/agent/config_test.go index f60a8be..dafaaea 100644 --- a/agent/config_test.go +++ b/agent/config_test.go @@ -19,7 +19,7 @@ func TestCoreConfig_Generate(t *testing.T) { BlockFile: "/etc/block.list", Suffix: []string{"fb.com", "twitter.com"}, SuffixFile: "/etc/suffix.list", - Forward: Forward{ + Forward: &Forward{ Addresses: []string{"1.1.1.1", "1.0.0.1"}, ServerName: "cloudflare-dns.com", }, @@ -59,7 +59,7 @@ func TestCoreConfig_Generate_less(t *testing.T) { Allows: nil, Blocks: nil, NoDefaults: true, - Forward: Forward{ + Forward: &Forward{ Addresses: []string{"8.8.8.8"}, ServerName: "google.dns", }, @@ -122,7 +122,7 @@ func TestConfigFromEnv(t *testing.T) { SuffixFile: "/etc/suffix.list", SuffixDir: "/etc/suffixes", NoDefaults: false, - Forward: Forward{ + Forward: &Forward{ Addresses: []string{"8.8.8.8", "8.8.4.4"}, ServerName: "dns.google", }, @@ -159,7 +159,7 @@ func TestConfigFromEnv_2(t *testing.T) { Blocks: []string{"facebook.com"}, BlockFile: "", NoDefaults: true, - Forward: Forward{ + Forward: &Forward{ Addresses: []string{"8.8.8.8"}, ServerName: "dns.google", }, diff --git a/plugins/donutdns/setup.go b/plugins/donutdns/setup.go index cc063b0..274c908 100644 --- a/plugins/donutdns/setup.go +++ b/plugins/donutdns/setup.go @@ -19,6 +19,7 @@ func Setup(c *caddy.Controller) error { // reconstruct the parts of CoreConfig for initializing the allow/block lists cc := new(agent.CoreConfig) + cc.Forward = new(agent.Forward) for c.Next() { _ = c.RemainingArgs() @@ -65,6 +66,24 @@ func Setup(c *caddy.Controller) error { return c.ArgErr() } cc.Suffix = append(cc.Suffix, c.Val()) + + case "upstream_1": + if !c.NextArg() { + return c.ArgErr() + } + cc.Forward.Addresses = append(cc.Forward.Addresses, c.Val()) + + case "upstream_2": + if !c.NextArg() { + return c.ArgErr() + } + cc.Forward.Addresses = append(cc.Forward.Addresses, c.Val()) + + case "forward_server_name": + if !c.NextArg() { + return c.ArgErr() + } + cc.Forward.ServerName = c.Val() } } } @@ -74,6 +93,8 @@ func Setup(c *caddy.Controller) error { pluginLogger.Infof("domains on explicit allow-list: %d", allow) pluginLogger.Infof("domains on explicit block-list: %d", block) pluginLogger.Infof("domains on suffixes block-list: %d", suffix) + pluginLogger.Infof("forward upstreams: %v", cc.Forward.Addresses) + pluginLogger.Infof("forward name: %s", cc.Forward.ServerName) // Add the Plugin to CoreDNS, so Servers can use it in their plugin chain. dd := DonutDNS{sets: sets} diff --git a/sources/client.go b/sources/client.go index c6fd397..8c5338e 100644 --- a/sources/client.go +++ b/sources/client.go @@ -4,9 +4,11 @@ import ( "context" "net" "net/http" + "strings" "time" "github.com/hashicorp/go-cleanhttp" + "github.com/shoenig/donutdns/agent" ) // client creates an http.Client with an explicit DNS server. This is necessary @@ -19,13 +21,21 @@ import ( // until it's actually ready. A project for a rainy day. // // Totally ripped from https://koraygocmen.medium.com/custom-dns-resolver-for-the-default-http-client-in-go-a1420db38a5d -func client() *http.Client { +func client(fwd *agent.Forward) *http.Client { var ( dnsResolverIP = "1.1.1.1:53" // Cloudflare DNS resolver. dnsResolverProto = "udp" // Protocol to use for the DNS resolver dnsResolverTimeoutMs = 5000 // Timeout (ms) for the DNS resolver (optional) ) + // use the DONUT_DNS_UPSTREAM_1 value if set, setting default port if necessary + if len(fwd.Addresses) > 0 { + dnsResolverIP = fwd.Addresses[0] + if !strings.Contains(dnsResolverIP, ":") { + dnsResolverIP += ":53" + } + } + dialer := &net.Dialer{ Resolver: &net.Resolver{ PreferGo: true, diff --git a/sources/fetch.go b/sources/fetch.go index 2107b8a..6cfdeb0 100644 --- a/sources/fetch.go +++ b/sources/fetch.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/hashicorp/go-set" + "github.com/shoenig/donutdns/agent" "github.com/shoenig/donutdns/output" "github.com/shoenig/donutdns/sources/extract" "github.com/shoenig/ignore" @@ -17,18 +18,20 @@ type Downloader interface { } type downloader struct { - logger output.Logger + logger output.Logger + forward *agent.Forward } // NewDownloader creates a new Downloader for downloading source lists. -func NewDownloader(logger output.Logger) Downloader { +func NewDownloader(fwd *agent.Forward, logger output.Logger) Downloader { return &downloader{ - logger: logger, + forward: fwd, + logger: logger, } } func (d *downloader) Download(lists *Lists) (*set.Set[string], error) { - g := NewGetter(d.logger, extract.New(extract.Generic)) + g := NewGetter(d.logger, d.forward, extract.New(extract.Generic)) combo := set.New[string](100) for _, source := range lists.All() { single, err := g.Get(source) @@ -54,12 +57,9 @@ type getter struct { } // NewGetter creates a new Getter, using Extractor ex to extract domains. -func NewGetter(logger output.Logger, ex extract.Extractor) Getter { +func NewGetter(logger output.Logger, fwd *agent.Forward, ex extract.Extractor) Getter { return &getter{ - client: client( - // todo: pass in one of the upstreams - // currently hard-code cloudflare for bootstrapping the sources - ), + client: client(fwd), ex: ex, logger: logger, } diff --git a/sources/fetch_test.go b/sources/fetch_test.go index 2521c99..bc596f2 100644 --- a/sources/fetch_test.go +++ b/sources/fetch_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/coredns/coredns/plugin/pkg/log" + "github.com/shoenig/donutdns/agent" "github.com/shoenig/donutdns/sources/extract" "github.com/shoenig/test/must" ) @@ -26,12 +27,23 @@ func Test_Get(t *testing.T) { defer ts.Close() ex := extract.New(extract.Generic) - g := NewGetter(pLog, ex) + fwd := new(agent.Forward) + + g := NewGetter(pLog, fwd, ex) s, err := g.Get(ts.URL) must.NoError(t, err) must.EqOp(t, 3, s.Size()) } +func Test_Get_bad_upstream(t *testing.T) { + ex := extract.New(extract.Generic) + fwd := &agent.Forward{Addresses: []string{"0.0.0.0"}} + + g := NewGetter(pLog, fwd, ex) + _, err := g.Get("http://example.com") + must.ErrorContains(t, err, "dial tcp: lookup example.com") +} + func Test_Download(t *testing.T) { hit := 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -48,7 +60,9 @@ func Test_Download(t *testing.T) { Miners: []string{ts.URL}, } - d := NewDownloader(pLog) + fwd := new(agent.Forward) + d := NewDownloader(fwd, pLog) + s, err := d.Download(lists) must.NoError(t, err) must.EqOp(t, 3, s.Size()) diff --git a/sources/sets.go b/sources/sets.go index 0ec47df..66bc4fa 100644 --- a/sources/sets.go +++ b/sources/sets.go @@ -28,7 +28,7 @@ func New(logger output.Logger, cc *agent.CoreConfig) *Sets { // initialize defaults if enabled if !cc.NoDefaults { - defaults(block, logger) + defaults(cc.Forward, block, logger) } // insert individual custom allowable domains @@ -106,8 +106,8 @@ func (s *Sets) BlockBySuffix(domain string) bool { return s.BlockBySuffix(domain[idx+1:]) } -func defaults(set *set.Set[string], logger output.Logger) { - d := NewDownloader(logger) +func defaults(fwd *agent.Forward, set *set.Set[string], logger output.Logger) { + d := NewDownloader(fwd, logger) s, err := d.Download(Defaults()) if err != nil { panic(err)