Skip to content

Commit

Permalink
feat: Replace existing codebase with DNSTap implementation
Browse files Browse the repository at this point in the history
- New approach using DNSTap instead of DNS server specific plugin
- Remove existing implementation
  • Loading branch information
nikitawootten committed Dec 29, 2024
1 parent e99fb57 commit 03d6a56
Show file tree
Hide file tree
Showing 38 changed files with 737 additions and 1,844 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ jobs:
github_access_token: ${{ secrets.GITHUB_TOKEN }}
- uses: DeterminateSystems/magic-nix-cache-action@main
- name: Build firewall controller package
run: nix build .#firewall-controller
- name: Build modified CoreDNS package
run: nix build .#coredns
run: nix build .#
go-tests:
runs-on: ubuntu-latest
steps:
Expand Down
13 changes: 5 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,16 @@ test: test-unit
test-unit:
go test ./...

.PHONY: codegen
codegen:
go generate ./...

COREDNS_PORT := 5300
COREDNS_COREFILE := support/TestCorefile

.PHONY: run-coredns
run-coredns:
nix run .#coredns -- -conf support/TestCorefile -p $(COREDNS_PORT)
nix run nixpkgs#coredns -- -conf $(COREDNS_COREFILE) -p $(COREDNS_PORT)

.PHONY: run-firewall-controller
run-firewall-controller:
nix run .#firewall-controller -- server --address :8080
.PHONY: run-dns-firewall-controller
run-dns-firewall-controller:
nix run .#dns-firewall-controller

.PHONY: send-dns-request
send-dns-request:
Expand Down
99 changes: 99 additions & 0 deletions cmd/dns-firewall-controller/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package main

import (
"fmt"
"log"
"net"
"os"

"github.com/coreos/go-iptables/iptables"
"github.com/nikitawootten/dns-firewall-controller/controller"
"github.com/nikitawootten/dns-firewall-controller/firewall"
"github.com/urfave/cli/v2"
)

func createListener(c *cli.Context) (net.Listener, error) {
address := c.String("address")

log.Printf("Starting DNS Firewall Controller on address '%v'", address)
listener, err := net.Listen("tcp", address)
if err != nil {
return nil, fmt.Errorf("failed to listen on address '%v': %w", address, err)
}

return listener, nil
}

func main() {
app := &cli.App{
Name: "dns-firewall-controller",
Description: "Open firewall rules based on DNS responses",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "address",
Usage: "Specify the address to listen on",
Value: ":6000",
},
},
Commands: []*cli.Command{
{
Name: "mock",
Category: "backend",
Usage: "Start the controller with a mock (no-op) backend",
Action: func(c *cli.Context) error {
listener, err := createListener(c)
if err != nil {
return err
}

backend := firewall.NewMockBackend()
return controller.Start(backend, listener)
},
},
{
Name: "iptables",
Category: "backend",
Usage: "Start the controller with an iptables backend",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "table",
Usage: "Specify the iptables table to use",
Value: "filter",
},
&cli.StringFlag{
Name: "chain",
Usage: "Specify the iptables chain to use",
Value: "INPUT",
},
},
Action: func(c *cli.Context) error {
table := c.String("table")
chain := c.String("chain")

listener, err := createListener(c)
if err != nil {
return err
}

iptables, err := iptables.New()
if err != nil {
return err
}

config := firewall.IPTablesFirewallBackend{
Table: table,
Chain: chain,
IPTables: iptables,
}

backend := firewall.NewIPTablesBackend(&config)
return controller.Start(backend, listener)
},
},
},
}

if err := app.Run(os.Args); err != nil {
log.Fatal(err)
}
}
36 changes: 36 additions & 0 deletions controller/wrapper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package controller

import (
"net"
"os"
"os/signal"
"syscall"

"github.com/nikitawootten/dns-firewall-controller/dns"
"github.com/nikitawootten/dns-firewall-controller/firewall"
)

func Start(backend firewall.FirewallBackend, listener net.Listener) error {
receiver := dns.NewDNSTapReceiver(listener, func(response *dns.DNSResponse) {
rules := firewall.FirewallRulesFromDNSResponse(response)
for _, rule := range rules {
backend.AddRule(rule)
}
})

if err := backend.Start(); err != nil {
return err
}

signals := make(chan os.Signal, 1)
signal.Notify(signals, os.Interrupt, syscall.SIGTERM)
go func() {
<-signals
backend.Stop()
listener.Close()
os.Exit(0)
}()

receiver.Start()
return nil
}
7 changes: 7 additions & 0 deletions default.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{ pkgs, ... }:
pkgs.buildGoModule {
pname = "dns-firewall-controller";
version = "0.1.0";
src = ./.;
vendorHash = "sha256-moCBoEjkhGE1UgGb9Pk894RgxGMImZXJ4u9rMYNtWzY=";
}
130 changes: 130 additions & 0 deletions dns/dnstap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package dns

import (
"fmt"
"net"

dnstap "github.com/dnstap/golang-dnstap"
"github.com/miekg/dns"
)

type DNSRecordType string

const (
ARecord DNSRecordType = "A"
AAAARecord DNSRecordType = "AAAA"
)

type DNSRecord struct {
RecordType DNSRecordType
RecordIP net.IP
TTL uint32
}

type DNSResponse struct {
SourceAddress net.IP
Records []DNSRecord
}

func ParseDNSTapMessage(dnstapMessage *dnstap.Message) (*DNSResponse, error) {
sourceAddress, err := extractSourceAddress(dnstapMessage)
if err != nil {
return nil, err
}

records, err := extractRecords(dnstapMessage)
if err != nil {
return nil, err
}

record := DNSResponse{
SourceAddress: sourceAddress,
Records: records,
}

return &record, nil
}

func ParseDNSTap(dt *dnstap.Dnstap) (*DNSResponse, error) {
if dt == nil {
return nil, fmt.Errorf("nil dnstap message")
}

dnstapMessage := dt.GetMessage()
if dnstapMessage == nil {
return nil, fmt.Errorf("nil dnstap message")
}

return ParseDNSTapMessage(dnstapMessage)
}

func dnsTapCallbackAdapter(callback func(*DNSResponse)) func(*dnstap.Dnstap) {
return func(dt *dnstap.Dnstap) {
response, err := ParseDNSTap(dt)
if err != nil {
return
}

callback(response)
}
}

func extractSourceAddress(dnstapMessage *dnstap.Message) (net.IP, error) {
if dnstapMessage == nil {
return nil, fmt.Errorf("nil dnstap message")
}

rawAddress := dnstapMessage.GetQueryAddress()
if rawAddress == nil {
return nil, fmt.Errorf("nil query address")
}

return rawAddress, nil
}

func extractRecords(dnstapMessage *dnstap.Message) ([]DNSRecord, error) {
if dnstapMessage == nil {
return nil, fmt.Errorf("nil dnstap message")
}

records := []DNSRecord{}

if dnstapMessage.ResponseMessage == nil {
return nil, fmt.Errorf("nil response message")
}

dnsMsg := new(dns.Msg)
if err := dnsMsg.Unpack(dnstapMessage.ResponseMessage); err != nil {
return nil, fmt.Errorf("failed to unpack response message: %v", err)
}

for _, answer := range dnsMsg.Answer {
record, err := extractRecord(answer)
if err != nil {
continue
}

records = append(records, record)
}

return records, nil
}

func extractRecord(answer dns.RR) (DNSRecord, error) {
switch answer := answer.(type) {
case *dns.A:
return DNSRecord{
RecordType: ARecord,
RecordIP: answer.A,
TTL: answer.Hdr.Ttl,
}, nil
case *dns.AAAA:
return DNSRecord{
RecordType: AAAARecord,
RecordIP: answer.AAAA,
TTL: answer.Hdr.Ttl,
}, nil
default:
return DNSRecord{}, fmt.Errorf("unsupported record type: %v", answer)
}
}
65 changes: 65 additions & 0 deletions dns/output.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package dns

import (
"log"

dnstap "github.com/dnstap/golang-dnstap"

"google.golang.org/protobuf/proto"
)

const outputChannelSize = 32

type callbackDNSTapOutput struct {
callback func(*dnstap.Dnstap)
outputs []dnstap.Output
data chan []byte
done chan struct{}
}

func NewCallbackOutput(callback func(*dnstap.Dnstap)) dnstap.Output {
return &callbackDNSTapOutput{
callback: callback,
data: make(chan []byte, outputChannelSize),
done: make(chan struct{}),
}
}

func (o *callbackDNSTapOutput) Add(output dnstap.Output) {
o.outputs = append(o.outputs, output)
}

func (o *callbackDNSTapOutput) Close() {
close(o.data)
<-o.done
}

func (o *callbackDNSTapOutput) GetOutputChannel() chan []byte {
return o.data
}

func (o *callbackDNSTapOutput) RunOutputLoop() {
for payload := range o.data {
// Mirror the payload to all outputs
for _, output := range o.outputs {
output.GetOutputChannel() <- payload
}

o.processFrame(payload)
}

for _, output := range o.outputs {
output.Close()
}
close(o.done)
}

func (o *callbackDNSTapOutput) processFrame(frame []byte) {
dt := &dnstap.Dnstap{}
if err := proto.Unmarshal(frame, dt); err != nil {
log.Printf("Error unmarshaling Dnstap message: %v", err)
return
}

o.callback(dt)
}
Loading

0 comments on commit 03d6a56

Please sign in to comment.