diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt index 4f6a436ceb6..e39388564a3 100644 --- a/.github/actions/spelling/expect.txt +++ b/.github/actions/spelling/expect.txt @@ -448,6 +448,7 @@ libxfixes libxkbcommon libxrandr limactl +limaiptables limaloc linds linebreak diff --git a/src/go/guestagent/main.go b/src/go/guestagent/main.go index 321824d5bf4..278ba15cf83 100644 --- a/src/go/guestagent/main.go +++ b/src/go/guestagent/main.go @@ -173,15 +173,15 @@ func main() { } if *enableKubernetes { - group.Go(func() error { - k8sServiceListenerIP := net.ParseIP(*k8sServiceListenerAddr) + k8sServiceListenerIP := net.ParseIP(*k8sServiceListenerAddr) - if k8sServiceListenerIP == nil || !(k8sServiceListenerIP.Equal(net.IPv4zero) || - k8sServiceListenerIP.Equal(net.IPv4(127, 0, 0, 1))) { - log.Fatalf("empty or none valid input for Kubernetes service listener IP address %s. "+ - "Valid options are 0.0.0.0 and 127.0.0.1.", *k8sServiceListenerAddr) - } + if k8sServiceListenerIP == nil || !(k8sServiceListenerIP.Equal(net.IPv4zero) || + k8sServiceListenerIP.Equal(net.IPv4(127, 0, 0, 1))) { + log.Fatalf("empty or invalid input for Kubernetes service listener IP address %s. "+ + "Valid options are 0.0.0.0 and 127.0.0.1.", *k8sServiceListenerAddr) + } + group.Go(func() error { // Watch for kube err := kube.WatchForServices(ctx, *configPath, @@ -190,19 +190,20 @@ func main() { if err != nil { return fmt.Errorf("kubernetes service watcher failed: %w", err) } + return nil + }) + group.Go(func() error { + iptablesScanner := iptables.NewIptablesScanner() + iptablesHandler := iptables.New(ctx, portTracker, iptablesScanner, k8sServiceListenerIP, iptablesUpdateInterval) + err := iptablesHandler.ForwardPorts() + if err != nil { + return fmt.Errorf("iptables port forwarding failed: %w", err) + } return nil }) } - group.Go(func() error { - err := iptables.ForwardPorts(ctx, portTracker, iptablesUpdateInterval) - if err != nil { - return fmt.Errorf("iptables port forwarding failed: %w", err) - } - return nil - }) - group.Go(func() error { procScanner, err := procnet.NewProcNetScanner(ctx, portTracker, procNetScanInterval) if err != nil { diff --git a/src/go/guestagent/pkg/iptables/iptables.go b/src/go/guestagent/pkg/iptables/iptables.go index 8143fe741f3..cca164abd2e 100644 --- a/src/go/guestagent/pkg/iptables/iptables.go +++ b/src/go/guestagent/pkg/iptables/iptables.go @@ -16,8 +16,6 @@ package iptables import ( "context" - "crypto/sha256" - "encoding/hex" "net" "strconv" "strings" @@ -25,30 +23,54 @@ import ( "github.com/Masterminds/log-go" "github.com/docker/go-connections/nat" - "github.com/lima-vm/lima/pkg/guestagent/iptables" + limaiptables "github.com/lima-vm/lima/pkg/guestagent/iptables" "github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/tracker" + "github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/utils" ) +// Iptables manages port forwarding for ports identified in iptables DNAT rules. +// It is primarily responsible for handling port mappings in Kubernetes environments that +// are not exposed via the Kubernetes API. The package scans iptables for these port and uses +// the k8sServiceListenerAddr setting for the hostIP property to create a port mapping and +// forwards them to both the API tracker and the WSL Proxy for proper routing and handling. +type Iptables struct { + context context.Context + apiTracker tracker.Tracker + scanner Scanner + listenerIP net.IP + // time, in seconds, to wait between updating. + updateInterval time.Duration +} + +func New(ctx context.Context, tracker tracker.Tracker, iptablesScanner Scanner, listenerIP net.IP, updateInterval time.Duration) *Iptables { + return &Iptables{ + context: ctx, + apiTracker: tracker, + scanner: iptablesScanner, + listenerIP: listenerIP, + updateInterval: updateInterval, + } +} + // ForwardPorts forwards ports found in iptables DNAT. In some environments, // like WSL, ports defined using the CNI portmap plugin happen through iptables. // These ports are not sent to places like /proc/net/tcp and are not picked up // as part of the normal forwarding system. This function detects those ports -// and binds them so that they are picked up. -// The argument is a time, in seconds, to wait between updating. -func ForwardPorts(ctx context.Context, tracker tracker.Tracker, updateInterval time.Duration) error { - var ports []iptables.Entry +// and binds them to k8sServiceListenerAddr so that they are picked up. +func (i *Iptables) ForwardPorts() error { + var ports []limaiptables.Entry - ticker := time.NewTicker(updateInterval) + ticker := time.NewTicker(i.updateInterval) defer ticker.Stop() for { select { - case <-ctx.Done(): + case <-i.context.Done(): return nil case <-ticker.C: } // Detect ports for forward - newPorts, err := iptables.GetPorts() + newPorts, err := i.scanner.GetPorts() if err != nil { // iptables exiting with an exit status of 4 means there // is a resource problem. For example, something else is @@ -69,7 +91,7 @@ func ForwardPorts(ctx context.Context, tracker tracker.Tracker, updateInterval t // Remove old forwards for _, p := range removed { name := entryToString(p) - if err := tracker.Remove(generateID(name)); err != nil { + if err := i.apiTracker.Remove(utils.GenerateID(name)); err != nil { log.Warnf("iptables scanner failed to remove portmap for %s: %w", name, err) continue } @@ -88,18 +110,14 @@ func ForwardPorts(ctx context.Context, tracker tracker.Tracker, updateInterval t continue } portBinding := nat.PortBinding{ - // We can set the hostIP to INADDR_ANY the API Tracker will determine - // the admin installation and can adjust this to localhost accordingly - HostIP: "0.0.0.0", + HostIP: i.listenerIP.String(), HostPort: port, } - if pb, ok := portMap[portMapKey]; ok { - portMap[portMapKey] = append(pb, portBinding) - } else { + if _, ok := portMap[portMapKey]; !ok { portMap[portMapKey] = []nat.PortBinding{portBinding} } name := entryToString(p) - if err := tracker.Add(generateID(name), portMap); err != nil { + if err := i.apiTracker.Add(utils.GenerateID(name), portMap); err != nil { log.Errorf("iptables scanner failed to forward portmap for %s: %s", name, err) continue } @@ -114,8 +132,8 @@ func ForwardPorts(ctx context.Context, tracker tracker.Tracker, updateInterval t // licensed under the Apache 2. // //nolint:nonamedreturns -func comparePorts(oldPorts, newPorts []iptables.Entry) (added, removed []iptables.Entry) { - oldPortMap := make(map[string]iptables.Entry, len(oldPorts)) +func comparePorts(oldPorts, newPorts []limaiptables.Entry) (added, removed []limaiptables.Entry) { + oldPortMap := make(map[string]limaiptables.Entry, len(oldPorts)) portExistMap := make(map[string]bool, len(oldPorts)) for _, oldPort := range oldPorts { key := entryToString(oldPort) @@ -139,12 +157,6 @@ func comparePorts(oldPorts, newPorts []iptables.Entry) (added, removed []iptable return } -func entryToString(ip iptables.Entry) string { +func entryToString(ip limaiptables.Entry) string { return net.JoinHostPort(ip.IP.String(), strconv.Itoa(ip.Port)) } - -func generateID(entry string) string { - hasher := sha256.New() - hasher.Write([]byte(entry)) - return hex.EncodeToString(hasher.Sum(nil)) -} diff --git a/src/go/guestagent/pkg/iptables/iptables_test.go b/src/go/guestagent/pkg/iptables/iptables_test.go new file mode 100644 index 00000000000..d5c0489a799 --- /dev/null +++ b/src/go/guestagent/pkg/iptables/iptables_test.go @@ -0,0 +1,275 @@ +/* +Copyright © 2024 SUSE LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package iptables_test + +import ( + "context" + "net" + "strconv" + "testing" + "time" + + "github.com/docker/go-connections/nat" + limaiptables "github.com/lima-vm/lima/pkg/guestagent/iptables" + "github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/iptables" + "github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/utils" + "github.com/stretchr/testify/require" +) + +func TestForwardPorts(t *testing.T) { + tests := []struct { + name string + remove bool + listenerIP net.IP + expectedEntries []limaiptables.Entry + removedEntries []limaiptables.Entry + updateEntries []limaiptables.Entry + expectedAddFuncErr error + }{ + { + name: "With localhost listener and valid port mappings", + listenerIP: net.IPv4(127, 0, 0, 1), + expectedEntries: []limaiptables.Entry{ + {TCP: true, IP: net.IPv4(192, 168, 20, 10), Port: 1080}, + {TCP: true, IP: net.IPv4(192, 168, 20, 11), Port: 1081}, + {TCP: true, IP: net.IPv4(192, 168, 20, 12), Port: 1082}, + }, + }, + { + name: "With wildcard listener and valid port mappings", + listenerIP: net.IPv4(0, 0, 0, 0), + expectedEntries: []limaiptables.Entry{ + {TCP: true, IP: net.IPv4(192, 168, 21, 10), Port: 1080}, + {TCP: true, IP: net.IPv4(192, 168, 21, 11), Port: 1081}, + {TCP: true, IP: net.IPv4(192, 168, 21, 12), Port: 1082}, + }, + }, + { + name: "With entries removed", + remove: true, + listenerIP: net.IPv4(0, 0, 0, 0), + expectedEntries: []limaiptables.Entry{ + {TCP: true, IP: net.IPv4(192, 168, 22, 10), Port: 1080}, + {TCP: true, IP: net.IPv4(192, 168, 22, 11), Port: 1081}, + {TCP: true, IP: net.IPv4(192, 168, 22, 12), Port: 1082}, + {TCP: true, IP: net.IPv4(192, 168, 22, 13), Port: 1083}, + {TCP: true, IP: net.IPv4(192, 168, 22, 14), Port: 1084}, + }, + removedEntries: []limaiptables.Entry{ + {TCP: true, IP: net.IPv4(192, 168, 22, 11), Port: 1081}, + {TCP: true, IP: net.IPv4(192, 168, 22, 12), Port: 1082}, + }, + updateEntries: []limaiptables.Entry{ + {TCP: true, IP: net.IPv4(192, 168, 22, 10), Port: 1080}, + {TCP: true, IP: net.IPv4(192, 168, 22, 13), Port: 1083}, + {TCP: true, IP: net.IPv4(192, 168, 22, 14), Port: 1084}, + {TCP: true, IP: net.IPv4(192, 168, 22, 15), Port: 1085}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iptablesScanner := fakeScanner{ + expectedEntries: tt.expectedEntries, + expectedErr: tt.expectedAddFuncErr, + } + + testTracker := fakeTracker{ + receivedID: make(chan string), + receivedRemoveID: make(chan string), + receivedPortMapping: make(chan nat.PortMap), + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + interval := time.Second + iptablesHandler := iptables.New(ctx, &testTracker, &iptablesScanner, tt.listenerIP, interval) + + go func() { + require.NoError(t, iptablesHandler.ForwardPorts()) + cancel() + }() + + for i := 0; i < len(tt.expectedEntries); i++ { + id := <-testTracker.receivedID + expectedID := utils.GenerateID(entryToString(tt.expectedEntries[i])) + require.Equal(t, expectedID, id) + + pm := <-testTracker.receivedPortMapping + portProto, err := nat.NewPort("tcp", strconv.Itoa(tt.expectedEntries[i].Port)) + require.NoError(t, err) + + expectedPortBinding := nat.PortBinding{ + HostIP: tt.listenerIP.String(), + HostPort: strconv.Itoa(tt.expectedEntries[i].Port), + } + require.Contains(t, pm[portProto], expectedPortBinding) + } + + if tt.remove { + iptablesScanner.expectedEntries = tt.updateEntries + + // Collect all removed IDs. + var actualRemovedIDs []string + for i := 0; i < len(tt.removedEntries); i++ { + select { + case id := <-testTracker.receivedRemoveID: + actualRemovedIDs = append(actualRemovedIDs, id) + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for remove ID for entry %v", tt.removedEntries[i]) + } + } + + for _, removedEntry := range tt.removedEntries { + require.Contains(t, actualRemovedIDs, utils.GenerateID(entryToString(removedEntry))) + } + addedElement := tt.updateEntries[len(tt.updateEntries)-1] + id := <-testTracker.receivedID + expectedID := utils.GenerateID(entryToString(addedElement)) + require.Equal(t, expectedID, id) + + pm := <-testTracker.receivedPortMapping + portProto, err := nat.NewPort("tcp", strconv.Itoa(addedElement.Port)) + require.NoError(t, err) + + expectedPortMap := nat.PortMap{ + portProto: []nat.PortBinding{ + { + HostIP: tt.listenerIP.String(), + HostPort: strconv.Itoa(addedElement.Port), + }, + }, + } + require.ElementsMatch(t, pm[portProto], expectedPortMap[portProto]) + } + }) + } +} + +func TestForwardPortsSamePortDifferentIP(t *testing.T) { + duplicatedPort := 1084 + tests := []struct { + name string + listenerIP net.IP + expectedEntries []limaiptables.Entry + expectedAddFuncErr error + }{ + { + name: "Same Port with different IP", + listenerIP: net.IPv4(0, 0, 0, 0), + expectedEntries: []limaiptables.Entry{ + {TCP: true, IP: net.IPv4(192, 168, 22, 10), Port: 1080}, + {TCP: true, IP: net.IPv4(192, 168, 22, 11), Port: 1081}, + {TCP: true, IP: net.IPv4(192, 168, 22, 12), Port: 1082}, + {TCP: true, IP: net.IPv4(192, 168, 22, 13), Port: 1083}, + {TCP: true, IP: net.IPv4(192, 168, 22, 14), Port: duplicatedPort}, + {TCP: true, IP: net.IPv4(192, 168, 22, 15), Port: duplicatedPort}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iptablesScanner := fakeScanner{ + expectedEntries: tt.expectedEntries, + expectedErr: tt.expectedAddFuncErr, + } + + testTracker := fakeTracker{ + receivedID: make(chan string), + receivedRemoveID: make(chan string), + receivedPortMapping: make(chan nat.PortMap), + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + interval := time.Second + iptablesHandler := iptables.New(ctx, &testTracker, &iptablesScanner, tt.listenerIP, interval) + + go func() { + require.NoError(t, iptablesHandler.ForwardPorts()) + cancel() + }() + + for i := 0; i < len(tt.expectedEntries); i++ { + id := <-testTracker.receivedID + expectedID := utils.GenerateID(entryToString(tt.expectedEntries[i])) + require.Equal(t, expectedID, id) + + pm := <-testTracker.receivedPortMapping + portProto, err := nat.NewPort("tcp", strconv.Itoa(tt.expectedEntries[i].Port)) + require.NoError(t, err) + + // Port bindings for the same port on different IP addresses should appear only once + // in the port mapping. This is because the HostIP is always controlled by the + // k8sServiceListenerAddr, which means that duplicate entries with the same port + // but different IPs are unnecessary and should not be handled. + if tt.expectedEntries[i].Port == duplicatedPort { + require.Len(t, pm[portProto], 1) + } + + expectedPortBinding := nat.PortBinding{ + HostIP: tt.listenerIP.String(), + HostPort: strconv.Itoa(tt.expectedEntries[i].Port), + } + require.Contains(t, pm[portProto], expectedPortBinding) + } + }) + } +} + +// Fake Tracker implementation for mocking behavior +type fakeTracker struct { + receivedID chan string + receivedRemoveID chan string + receivedPortMapping chan nat.PortMap + expectedAddFuncErr error +} + +func (f *fakeTracker) Get(containerID string) nat.PortMap { + return nil +} + +func (f *fakeTracker) Add(containerID string, portMapping nat.PortMap) error { + f.receivedID <- containerID + f.receivedPortMapping <- portMapping + return f.expectedAddFuncErr +} + +func (f *fakeTracker) Remove(containerID string) error { + f.receivedRemoveID <- containerID + return nil +} + +func (f *fakeTracker) RemoveAll() error { + return nil +} + +// Fake Scanner to simulate iptables entries +type fakeScanner struct { + expectedEntries []limaiptables.Entry + expectedErr error +} + +func (f *fakeScanner) GetPorts() ([]limaiptables.Entry, error) { + return f.expectedEntries, f.expectedErr +} + +// Utility function to convert iptables entry to string +func entryToString(ip limaiptables.Entry) string { + return net.JoinHostPort(ip.IP.String(), strconv.Itoa(ip.Port)) +} diff --git a/src/go/guestagent/pkg/iptables/scanner.go b/src/go/guestagent/pkg/iptables/scanner.go new file mode 100644 index 00000000000..92bd943b413 --- /dev/null +++ b/src/go/guestagent/pkg/iptables/scanner.go @@ -0,0 +1,33 @@ +/* +Copyright © 2024 SUSE LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package iptables handles forwarding ports found in iptables DNAT +package iptables + +import "github.com/lima-vm/lima/pkg/guestagent/iptables" + +// Scanner is the interface that wraps the GetPorts method which +// is used to scan the iptables. +type Scanner interface { + GetPorts() ([]iptables.Entry, error) +} + +type IptablesScanner struct{} + +func NewIptablesScanner() *IptablesScanner { + return &IptablesScanner{} +} + +func (i *IptablesScanner) GetPorts() ([]iptables.Entry, error) { + return iptables.GetPorts() +} diff --git a/src/go/guestagent/pkg/procnet/scanner_linux.go b/src/go/guestagent/pkg/procnet/scanner_linux.go index b006f8348da..b32f18cda38 100644 --- a/src/go/guestagent/pkg/procnet/scanner_linux.go +++ b/src/go/guestagent/pkg/procnet/scanner_linux.go @@ -22,8 +22,6 @@ package procnet import ( "context" - "crypto/sha256" - "encoding/hex" "fmt" "io" "net" @@ -37,6 +35,7 @@ import ( "github.com/docker/go-connections/nat" "github.com/lima-vm/lima/pkg/guestagent/procnettcp" "github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/tracker" + "github.com/rancher-sandbox/rancher-desktop/src/go/guestagent/pkg/utils" ) type action string @@ -91,7 +90,7 @@ func (p *ProcNetScanner) ForwardPorts() error { for port, bindings := range newPortMap { if _, exists := previousPortMap[port]; !exists { log.Infof("/proc/net scanner added port: %s -> %+v", port, bindings) - err := p.tracker.Add(generateID(fmt.Sprintf("%s/%s", port.Proto(), port.Port())), nat.PortMap{ + err := p.tracker.Add(utils.GenerateID(fmt.Sprintf("%s/%s", port.Proto(), port.Port())), nat.PortMap{ port: bindings, }) if err != nil { @@ -108,7 +107,7 @@ func (p *ProcNetScanner) ForwardPorts() error { for port, previousBindings := range previousPortMap { if _, exists := newPortMap[port]; !exists { log.Infof("/proc/net scanner removed port: %s -> %+v", port, previousBindings) - err := p.tracker.Remove(generateID(fmt.Sprintf("%s/%s", port.Proto(), port.Port()))) + err := p.tracker.Remove(utils.GenerateID(fmt.Sprintf("%s/%s", port.Proto(), port.Port()))) if err != nil { log.Errorf("/proc/net scanner failed to remove port: %s", err) continue @@ -234,9 +233,3 @@ func writeSysctl(path string, value string) error { log.Infof("/proc/net scanner enabled %s", routeLocalnet) return nil } - -func generateID(entry string) string { - hasher := sha256.New() - hasher.Write([]byte(entry)) - return hex.EncodeToString(hasher.Sum(nil)) -} diff --git a/src/go/guestagent/pkg/utils/util.go b/src/go/guestagent/pkg/utils/util.go index 3c4fedeccb6..e5a6d18c790 100644 --- a/src/go/guestagent/pkg/utils/util.go +++ b/src/go/guestagent/pkg/utils/util.go @@ -13,7 +13,11 @@ limitations under the License. package utils -import "errors" +import ( + "crypto/sha256" + "encoding/hex" + "errors" +) var ( ErrExecIptablesRule = errors.New("failed updating iptables rules") @@ -29,3 +33,9 @@ func NormalizeHostIP(ip string) string { } return "0.0.0.0" } + +func GenerateID(entry string) string { + hasher := sha256.New() + hasher.Write([]byte(entry)) + return hex.EncodeToString(hasher.Sum(nil)) +}