Skip to content

Commit

Permalink
feat: use dns_hijack flag instead of tun_dns_server
Browse files Browse the repository at this point in the history
  • Loading branch information
cxz66666 committed Nov 12, 2023
1 parent 4887d49 commit fb2be8c
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 27 deletions.
2 changes: 1 addition & 1 deletion config.toml.example
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ disable_keep_alive = false
zju_dns_server = "10.10.0.21"
secondary_dns_server = "114.114.114.114"
dns_server_bind = ""
tun_dns_server = ""
dns_hijack = false
debug_dump = false

# Port forwarding
Expand Down
8 changes: 4 additions & 4 deletions init.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type (
ZJUDNSServer string
SecondaryDNSServer string
DNSServerBind string
TUNDNSServer string
DNSHijack bool
DebugDump bool
PortForwardingList []SinglePortForwarding
CustomDNSList []SingleCustomDNS
Expand Down Expand Up @@ -72,7 +72,7 @@ type (
ZJUDNSServer *string `toml:"zju_dns_server"`
SecondaryDNSServer *string `toml:"secondary_dns_server"`
DNSServerBind *string `toml:"dns_server_bind"`
TUNDNSServer *string `toml:"tun_dns_server"`
DNSHijack *bool `toml:"dns_hijack"`
DebugDump *bool `toml:"debug_dump"`
PortForwarding []SinglePortForwardingTOML `toml:"port_forwarding"`
CustomDNS []SingleCustomDNSTOML `toml:"custom_dns"`
Expand Down Expand Up @@ -127,7 +127,7 @@ func parseTOMLConfig(configFile string, conf *Config) error {
conf.ZJUDNSServer = getTOMLVal(confTOML.ZJUDNSServer, "10.10.0.21")
conf.SecondaryDNSServer = getTOMLVal(confTOML.SecondaryDNSServer, "114.114.114.114")
conf.DNSServerBind = getTOMLVal(confTOML.DNSServerBind, "")
conf.TUNDNSServer = getTOMLVal(confTOML.TUNDNSServer, "")
conf.DNSHijack = getTOMLVal(confTOML.DNSHijack, false)

for _, singlePortForwarding := range confTOML.PortForwarding {
if singlePortForwarding.NetworkType == nil {
Expand Down Expand Up @@ -193,7 +193,7 @@ func init() {
flag.StringVar(&conf.ZJUDNSServer, "zju-dns-server", "10.10.0.21", "ZJU DNS server address")
flag.StringVar(&conf.SecondaryDNSServer, "secondary-dns-server", "114.114.114.114", "Secondary DNS server address. Leave empty to use system default DNS server")
flag.StringVar(&conf.DNSServerBind, "dns-server-bind", "", "The address DNS server listens on (e.g. 127.0.0.1:53)")
flag.StringVar(&conf.TUNDNSServer, "tun-dns-server", "", "DNS Server address for TUN interface (e.g. 127.0.0.1). You should not specify the port")
flag.BoolVar(&conf.DNSHijack, "dns-hijack", false, "Hijack all dns query to ZJU Connect")
flag.StringVar(&conf.TwfID, "twf-id", "", "Login using twfID captured (mostly for debug usage)")
flag.StringVar(&tcpPortForwarding, "tcp-port-forwarding", "", "TCP port forwarding (e.g. 0.0.0.0:9898-10.10.98.98:80,127.0.0.1:9899-10.10.98.98:80)")
flag.StringVar(&udpPortForwarding, "udp-port-forwarding", "", "UDP port forwarding (e.g. 127.0.0.1:53-10.10.0.21:53)")
Expand Down
7 changes: 7 additions & 0 deletions internal/terminal_func/terminal_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ type TerminalItem struct {

var terminalFuncList []TerminalItem

var terminalBegin = false

func RegisterTerminalFunc(execName string, fun TerminalFunc) {
terminalFuncList = append(terminalFuncList, TerminalItem{
f: fun,
Expand All @@ -23,6 +25,7 @@ func RegisterTerminalFunc(execName string, fun TerminalFunc) {

func ExecTerminalFunc(ctx context.Context) []error {
var errList []error
terminalBegin = true
for _, item := range terminalFuncList {
log.Println("Exec func on terminal:", item.name)
if err := item.f(ctx); err != nil {
Expand All @@ -34,3 +37,7 @@ func ExecTerminalFunc(ctx context.Context) []error {
}
return errList
}

func IsTermianl() bool {
return terminalBegin
}
6 changes: 5 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func main() {

var vpnStack stack.Stack
if conf.TUNMode {
vpnTUNStack, err := tun.NewStack(vpnClient, conf.TUNDNSServer)
vpnTUNStack, err := tun.NewStack(vpnClient, conf.DNSHijack)
if err != nil {
log.Fatalf("Tun stack setup error: %s", err)
}
Expand Down Expand Up @@ -126,6 +126,10 @@ func main() {
if conf.DNSServerBind != "" {
go service.ServeDNS(conf.DNSServerBind, localResolver)
}
if conf.TUNMode {
clientIP, _ := vpnClient.IP()
go service.ServeDNS(clientIP.String()+":53", localResolver)
}

if conf.SocksBind != "" {
go service.ServeSocks5(conf.SocksBind, vpnDialer, vpnResolver, conf.SocksUser, conf.SocksPasswd)
Expand Down
13 changes: 11 additions & 2 deletions stack/gvisor/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gvisor
import (
"errors"
"github.com/mythologyli/zju-connect/client"
"github.com/mythologyli/zju-connect/internal/terminal_func"
"github.com/mythologyli/zju-connect/internal/zcdns"
"github.com/mythologyli/zju-connect/log"
"gvisor.dev/gvisor/pkg/buffer"
Expand Down Expand Up @@ -80,7 +81,11 @@ func (ep *Endpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error)
if ep.rvpnConn != nil {
n, err := ep.rvpnConn.Write(buf)
if err != nil {
panic(err)
if terminal_func.IsTermianl() {
return list.Len(), nil
} else {
panic(err)
}
}
log.DebugPrintf("Send: wrote %d bytes", n)
log.DebugDumpHex(buf[:n])
Expand Down Expand Up @@ -151,7 +156,11 @@ func (s *Stack) Run() {
buf := make([]byte, MTU)
n, err := s.endpoint.rvpnConn.Read(buf)
if err != nil {
panic(err)
if terminal_func.IsTermianl() {
return
} else {
panic(err)
}
}
log.DebugPrintf("Recv: read %d bytes", n)
log.DebugDumpHex(buf[:n])
Expand Down
19 changes: 14 additions & 5 deletions stack/tun/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tun
import (
"context"
"fmt"
"github.com/mythologyli/zju-connect/internal/terminal_func"
"io"

tun "github.com/cxz66666/sing-tun"
Expand Down Expand Up @@ -44,8 +45,12 @@ func (s *Stack) Run() {

err = s.endpoint.Write(buf[:n])
if err != nil {
log.Printf("Error occurred while writing to TUN stack: %v", err)
panic(err)
if terminal_func.IsTermianl() {
return
} else {
log.Printf("Error occurred while writing to TUN stack: %v", err)
panic(err)
}
}
}
}()
Expand All @@ -55,9 +60,13 @@ func (s *Stack) Run() {
buf := make([]byte, MTU+tun.PacketOffset)
n, err := s.endpoint.Read(buf)
if err != nil {
log.Printf("Error occurred while reading from TUN stack: %v", err)
// TODO graceful shutdown
panic(err)
if terminal_func.IsTermianl() {
return
} else {
log.Printf("Error occurred while reading from TUN stack: %v", err)
// TODO graceful shutdown
panic(err)
}
}

if n < zctcpip.IPv4PacketMinLength {
Expand Down
18 changes: 11 additions & 7 deletions stack/tun/stack_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (s *Stack) AddDnsServer(dnsServer string, targetHost string) error {
return nil
}

func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*Stack, error) {
func NewStack(easyConnectClient *client.EasyConnectClient, dnsHijack bool) (*Stack, error) {
var err error
s := &Stack{}
s.endpoint = &Endpoint{
Expand Down Expand Up @@ -106,6 +106,9 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S
if err != nil {
return nil, err
}
terminal_func.RegisterTerminalFunc("Close Tun Device", func(ctx context.Context) error {
return ifce.Close()
})
s.endpoint.ifce = ifce
s.endpoint.ifceName = tunName
netIfce, err := net.InterfaceByName(tunName)
Expand Down Expand Up @@ -146,12 +149,13 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S
})
},
}

if err = s.AddDnsServer(s.endpoint.ip.String(), "zju.edu.cn"); err != nil {
log.Printf("AddDnsServer failed: %v", err)
}
if err = s.AddDnsServer(s.endpoint.ip.String(), "cc98.org"); err != nil {
log.Printf("AddDnsServer failed: %v", err)
if dnsHijack {
if err = s.AddDnsServer(s.endpoint.ip.String(), "zju.edu.cn"); err != nil {
log.Printf("AddDnsServer failed: %v", err)
}
if err = s.AddDnsServer(s.endpoint.ip.String(), "cc98.org"); err != nil {
log.Printf("AddDnsServer failed: %v", err)
}
}
return s, nil
}
15 changes: 11 additions & 4 deletions stack/tun/stack_linux.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package tun

import (
"context"
tun "github.com/cxz66666/sing-tun"
"github.com/mythologyli/zju-connect/client"
"github.com/mythologyli/zju-connect/internal/terminal_func"
"github.com/mythologyli/zju-connect/log"
"net"
"net/netip"
Expand Down Expand Up @@ -47,7 +49,7 @@ func (s *Stack) AddRoute(target string) error {
return nil
}

func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*Stack, error) {
func NewStack(easyConnectClient *client.EasyConnectClient, dnsHijack bool) (*Stack, error) {
var err error
s := &Stack{}
s.endpoint = &Endpoint{
Expand All @@ -59,7 +61,7 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S
return nil, err
}
ipPrefix, _ := netip.ParsePrefix(s.endpoint.ip.String() + "/8")
tunName := "zjuconnect"
tunName := "ZJU-Connect"
tunName = tun.CalculateInterfaceName(tunName)

tunOptions := tun.Options{
Expand All @@ -68,13 +70,18 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S
Inet4Address: []netip.Prefix{
ipPrefix,
},
AutoRoute: true,
TableIndex: 1897,
}
if dnsHijack {
tunOptions.AutoRoute = true
tunOptions.TableIndex = 1897
}
ifce, err := tun.New(tunOptions)
if err != nil {
return nil, err
}
terminal_func.RegisterTerminalFunc("Close Tun Device", func(ctx context.Context) error {
return ifce.Close()
})
s.endpoint.ifce = ifce
s.endpoint.ifceName = tunName
log.Printf("Interface Name: %s\n", tunName)
Expand Down
13 changes: 10 additions & 3 deletions stack/tun/stack_windows.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package tun

import (
"context"
"fmt"
"github.com/mythologyli/zju-connect/client"
"github.com/mythologyli/zju-connect/internal/terminal_func"
"github.com/mythologyli/zju-connect/log"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun"
Expand Down Expand Up @@ -72,7 +74,7 @@ func (s *Stack) AddRoute(target string) error {
return nil
}

func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*Stack, error) {
func NewStack(easyConnectClient *client.EasyConnectClient, dnsHijack bool) (*Stack, error) {
s := &Stack{}

guid, err := windows.GUIDFromString(guid)
Expand Down Expand Up @@ -125,8 +127,8 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S
log.Printf("Run %s failed: %v", command.String(), err)
}

if dnsServer != "" {
command = exec.Command("netsh", "interface", "ipv4", "add", "dnsservers", "ZJU Connect", dnsServer)
if dnsHijack {
command = exec.Command("netsh", "interface", "ipv4", "add", "dnsservers", "ZJU Connect", s.endpoint.ip.String())
} else {
command = exec.Command("netsh", "interface", "ipv4", "delete", "dnsservers", "ZJU Connect", "all")
}
Expand All @@ -135,5 +137,10 @@ func NewStack(easyConnectClient *client.EasyConnectClient, dnsServer string) (*S
log.Printf("Run %s failed: %v", command.String(), err)
}

terminal_func.RegisterTerminalFunc("Close Tun Device", func(ctx context.Context) error {
dev.Close()
closeCommand := exec.Command("netsh", "interface", "ipv4", "delete", "dnsservers", "ZJU Connect", "all")
return closeCommand.Run()
})
return s, nil
}

0 comments on commit fb2be8c

Please sign in to comment.