diff --git a/pkg/networkservice/chains/nsmgr/vl3_test.go b/pkg/networkservice/chains/nsmgr/vl3_test.go index 0a48591aa..ddfb4ed9a 100644 --- a/pkg/networkservice/chains/nsmgr/vl3_test.go +++ b/pkg/networkservice/chains/nsmgr/vl3_test.go @@ -68,7 +68,10 @@ func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) { nseReg, sandbox.GenerateTestToken, vl3.NewServer(ctx, serverPrefixCh), - vl3dns.NewServer(ctx, vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."), vl3dns.WithDNSPort(40053)), + vl3dns.NewServer(ctx, + func() net.IP { return net.ParseIP("127.0.0.1") }, + vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."), + vl3dns.WithDNSPort(40053)), ) resolver := net.Resolver{ @@ -95,7 +98,6 @@ func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) { require.NoError(t, err) require.Len(t, resp.GetContext().GetDnsContext().GetConfigs(), 1) require.Len(t, resp.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps, 1) - require.Equal(t, "10.0.0.0", resp.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps[0]) req.Connection = resp.Clone() @@ -152,14 +154,22 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) { serverPrefixCh <- &ipam.PrefixResponse{Prefix: "10.0.0.1/24"} + var dnsConfigs = new(vl3dns.Map) + _ = domain.Nodes[0].NewEndpoint( ctx, nseReg, sandbox.GenerateTestToken, vl3.NewServer(ctx, serverPrefixCh), - vl3dns.NewServer(ctx, vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."), vl3dns.WithDNSListenAndServeFunc(func(ctx context.Context, handler dnsutils.Handler, listenOn string) { - dnsutils.ListenAndServe(ctx, handler, ":50053") - }), vl3dns.WithDNSPort(40053)), + vl3dns.NewServer(ctx, + func() net.IP { return net.ParseIP("0.0.0.0") }, + vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."), + vl3dns.WithDNSListenAndServeFunc(func(ctx context.Context, handler dnsutils.Handler, listenOn string) { + dnsutils.ListenAndServe(ctx, handler, ":50053") + }), + vl3dns.WithConfigs(dnsConfigs), + vl3dns.WithDNSPort(40053), + ), ) resolver := net.Resolver{ @@ -174,7 +184,7 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) { defer close(clientPrefixCh) clientPrefixCh <- &ipam.PrefixResponse{Prefix: "127.0.0.1/32"} - nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(vl3.NewClient(ctx, clientPrefixCh))) + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(vl3dns.NewClient(net.ParseIP("127.0.0.1"), dnsConfigs), vl3.NewClient(ctx, clientPrefixCh))) req := defaultRequest(nsReg.Name) req.Connection.Id = uuid.New().String() @@ -183,9 +193,8 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) { resp, err := nsc.Request(ctx, req) require.NoError(t, err) - require.Len(t, resp.GetContext().GetDnsContext().GetConfigs(), 1) require.Len(t, resp.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps, 1) - require.Equal(t, "10.0.0.0", resp.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps[0]) + require.Equal(t, "127.0.0.1", resp.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps[0]) require.Equal(t, "127.0.0.1/32", resp.GetContext().GetIpContext().GetSrcIpAddrs()[0]) req.Connection = resp.Clone() diff --git a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/client.go b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/client.go new file mode 100644 index 000000000..40ccd5b6a --- /dev/null +++ b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/client.go @@ -0,0 +1,84 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vl3dns + +import ( + "context" + "net" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/networkservice" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" +) + +type vl3DNSClient struct { + dnsServerIP net.IP + dnsConfigs *Map +} + +// NewClient - returns a new null client that does nothing but call next.Client(ctx).{Request/Close} and return the result +// This is very useful in testing +func NewClient(dnsServerIP net.IP, dnsConfigs *Map) networkservice.NetworkServiceClient { + return &vl3DNSClient{ + dnsServerIP: dnsServerIP, + dnsConfigs: dnsConfigs, + } +} + +func (n *vl3DNSClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) { + if request.GetConnection() == nil { + request.Connection = new(networkservice.Connection) + } + if request.GetConnection().GetContext() == nil { + request.GetConnection().Context = new(networkservice.ConnectionContext) + } + if request.GetConnection().GetContext().GetDnsContext() == nil { + request.GetConnection().GetContext().DnsContext = new(networkservice.DNSContext) + } + + request.GetConnection().GetContext().GetDnsContext().Configs = []*networkservice.DNSConfig{ + { + DnsServerIps: []string{n.dnsServerIP.String()}, + }, + } + resp, err := next.Client(ctx).Request(ctx, request, opts...) + + if err == nil { + for _, config := range resp.GetContext().GetDnsContext().GetConfigs() { + var skip = false + for _, ip := range config.DnsServerIps { + if ip == n.dnsServerIP.String() { + skip = true + break + } + } + if skip { + continue + } + n.dnsConfigs.Store(resp.GetId(), config) + } + } + + return resp, err +} + +func (n *vl3DNSClient) Close(ctx context.Context, conn *networkservice.Connection, opts ...grpc.CallOption) (*empty.Empty, error) { + n.dnsConfigs.Delete(conn.GetId()) + return next.Client(ctx).Close(ctx, conn, opts...) +} diff --git a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/gen.go b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/gen.go new file mode 100644 index 000000000..7c26deb73 --- /dev/null +++ b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/gen.go @@ -0,0 +1,26 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vl3dns + +import ( + "sync" +) + +//go:generate go-syncmap -output map.gen.go -type Map + +// Map - sync.Map with key == string and value == networkservice.DNSConfig +type Map sync.Map diff --git a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/map.gen.go b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/map.gen.go new file mode 100644 index 000000000..64e2b83c2 --- /dev/null +++ b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/map.gen.go @@ -0,0 +1,75 @@ +// Code generated by "-output map.gen.go -type Map -output map.gen.go -type Map"; DO NOT EDIT. +package vl3dns + +import ( + "sync" // Used by sync.Map. + + "github.com/networkservicemesh/api/pkg/api/networkservice" +) + +// Generate code that will fail if the constants change value. +func _() { + // An "cannot convert Map literal (type Map) to type sync.Map" compiler error signifies that the base type have changed. + // Re-run the go-syncmap command to generate them again. + _ = (sync.Map)(Map{}) +} + +var _nil_Map_networkservice_DNSConfig_value = func() (val *networkservice.DNSConfig) { return }() + +// Load returns the value stored in the map for a key, or nil if no +// value is present. +// The ok result indicates whether value was found in the map. +func (m *Map) Load(key string) (*networkservice.DNSConfig, bool) { + value, ok := (*sync.Map)(m).Load(key) + if value == nil { + return _nil_Map_networkservice_DNSConfig_value, ok + } + return value.(*networkservice.DNSConfig), ok +} + +// Store sets the value for a key. +func (m *Map) Store(key string, value *networkservice.DNSConfig) { + (*sync.Map)(m).Store(key, value) +} + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +func (m *Map) LoadOrStore(key string, value *networkservice.DNSConfig) (*networkservice.DNSConfig, bool) { + actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) + if actual == nil { + return _nil_Map_networkservice_DNSConfig_value, loaded + } + return actual.(*networkservice.DNSConfig), loaded +} + +// LoadAndDelete deletes the value for a key, returning the previous value if any. +// The loaded result reports whether the key was present. +func (m *Map) LoadAndDelete(key string) (value *networkservice.DNSConfig, loaded bool) { + actual, loaded := (*sync.Map)(m).LoadAndDelete(key) + if actual == nil { + return _nil_Map_networkservice_DNSConfig_value, loaded + } + return actual.(*networkservice.DNSConfig), loaded +} + +// Delete deletes the value for a key. +func (m *Map) Delete(key string) { + (*sync.Map)(m).Delete(key) +} + +// Range calls f sequentially for each key and value present in the map. +// If f returns false, range stops the iteration. +// +// Range does not necessarily correspond to any consistent snapshot of the Map's +// contents: no key will be visited more than once, but if the value for any key +// is stored or deleted concurrently, Range may reflect any mapping for that key +// from any point during the Range call. +// +// Range may be O(N) with the number of elements in the map even if f returns +// false after a constant number of calls. +func (m *Map) Range(f func(key string, value *networkservice.DNSConfig) bool) { + (*sync.Map)(m).Range(func(key, value interface{}) bool { + return f(key.(string), value.(*networkservice.DNSConfig)) + }) +} diff --git a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/options.go b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/options.go index c58c72d93..2fc6e818a 100644 --- a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/options.go +++ b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/options.go @@ -19,7 +19,6 @@ package vl3dns import ( "context" "fmt" - "net/url" "text/template" "github.com/networkservicemesh/sdk/pkg/tools/dnsutils" @@ -28,10 +27,10 @@ import ( // Option configures vl3DNSServer type Option func(*vl3DNSServer) -// WithInitialFanoutList sets initial list to fanout queries -func WithInitialFanoutList(initialFanoutList []url.URL) Option { +// WithConfigs sets initial list to fanout queries +func WithConfigs(m *Map) Option { return func(vd *vl3DNSServer) { - vd.initialFanoutList = initialFanoutList + vd.configs = m } } diff --git a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go index a7548ec11..df36e2f57 100644 --- a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go +++ b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go @@ -23,7 +23,6 @@ import ( "net" "net/url" "strings" - "sync" "text/template" "github.com/golang/protobuf/ptypes/empty" @@ -42,12 +41,12 @@ import ( type vl3DNSServer struct { dnsServerRecords memory.Map - fanoutAddresses sync.Map domainSchemeTemplates []*template.Template - initialFanoutList []url.URL + configs *Map dnsPort int dnsServer dnsutils.Handler listenAndServeDNS func(ctx context.Context, handler dnsutils.Handler, listenOn string) + getDNSServerIP func() net.IP } type clientDNSNameKey struct{} @@ -57,10 +56,12 @@ type clientDNSNameKey struct{} // By default is using fanout dns handler to connect to other vl3 nses. // chanCtx is using for signal to stop dns server. // opts confugre vl3dns networkservice instance with specific behavior. -func NewServer(chanCtx context.Context, opts ...Option) networkservice.NetworkServiceServer { +func NewServer(chanCtx context.Context, getDNSServerIP func() net.IP, opts ...Option) networkservice.NetworkServiceServer { var result = &vl3DNSServer{ dnsPort: 53, listenAndServeDNS: dnsutils.ListenAndServe, + getDNSServerIP: getDNSServerIP, + configs: new(Map), } for _, opt := range opts { @@ -86,13 +87,12 @@ func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.Netw request.Connection.Context.DnsContext = new(networkservice.DNSContext) } - var ipContext = request.GetConnection().GetContext().GetIpContext() + var dnsContext = request.GetConnection().GetContext().GetDnsContext() + var clientsConfigs = dnsContext.GetConfigs() - for _, dstIPNet := range ipContext.GetDstIPNets() { - request.GetConnection().GetContext().GetDnsContext().Configs = append(request.GetConnection().GetContext().GetDnsContext().Configs, &networkservice.DNSConfig{ - DnsServerIps: []string{dstIPNet.IP.String()}, - }) - } + dnsContext.Configs = append(dnsContext.Configs, &networkservice.DNSConfig{ + DnsServerIps: []string{n.getDNSServerIP().String()}, + }) var recordNames, err = n.buildSrcDNSRecords(request.GetConnection()) @@ -121,21 +121,29 @@ func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.Netw metadata.Map(ctx, false).Store(clientDNSNameKey{}, recordNames) } - if n.shouldAddToFanoutList(ipContext) { - for _, srcIPNet := range ipContext.GetSrcIPNets() { - var u = url.URL{Scheme: "tcp", Host: fmt.Sprintf("%v:%v", srcIPNet.IP.String(), n.dnsPort)} - n.fanoutAddresses.Store(u, struct{}{}) + resp, err := next.Server(ctx).Request(ctx, request) + + if err == nil { + if srcRoutes := resp.GetContext().GetIpContext().GetSrcIPRoutes(); len(srcRoutes) > 0 { + var lastPrefix = srcRoutes[len(srcRoutes)-1].Prefix + for _, config := range clientsConfigs { + for _, serverIP := range config.DnsServerIps { + if serverIP == n.getDNSServerIP().String() { + continue + } + if withinPrefix(serverIP, lastPrefix) { + n.configs.Store(resp.GetId(), config) + } + } + } } } - return next.Server(ctx).Request(ctx, request) + return resp, err } func (n *vl3DNSServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { - for _, srcIPNet := range conn.Context.GetIpContext().GetSrcIPNets() { - var u = url.URL{Scheme: "tcp", Host: fmt.Sprintf("%v:%v", srcIPNet.IP.String(), n.dnsPort)} - n.fanoutAddresses.Delete(u) - } + n.configs.Delete(conn.GetId()) if v, ok := metadata.Map(ctx, false).LoadAndDelete(clientDNSNameKey{}); ok { var names = v.([]string) @@ -160,31 +168,16 @@ func (n *vl3DNSServer) buildSrcDNSRecords(c *networkservice.Connection) ([]strin } func (n *vl3DNSServer) getFanoutAddresses() []url.URL { - var result = n.initialFanoutList - n.fanoutAddresses.Range(func(key, _ interface{}) bool { - result = append(result, key.(url.URL)) + var result []url.URL + n.configs.Range(func(key string, value *networkservice.DNSConfig) bool { + for _, addr := range value.DnsServerIps { + result = append(result, url.URL{Scheme: "tcp", Host: fmt.Sprintf("%v:%v", addr, n.dnsPort)}) + } return true }) return result } -func (n *vl3DNSServer) shouldAddToFanoutList(ipContext *networkservice.IPContext) bool { - if len(ipContext.SrcRoutes) > 0 { - var lastSrcRoute = ipContext.SrcRoutes[len(ipContext.SrcRoutes)-1] - _, ipNet, err := net.ParseCIDR(lastSrcRoute.Prefix) - if err != nil { - return false - } - var pool = ippool.NewWithNet(ipNet) - for _, srcIP := range ipContext.GetSrcIpAddrs() { - if !pool.ContainsNetString(srcIP) { - return true - } - } - } - return false -} - func compareStringSlices(a, b []string) bool { if len(a) != len(b) { return false @@ -196,3 +189,12 @@ func compareStringSlices(a, b []string) bool { } return true } + +func withinPrefix(ipAddr, prefix string) bool { + _, ipNet, err := net.ParseCIDR(prefix) + if err != nil { + return false + } + var pool = ippool.NewWithNet(ipNet) + return pool.ContainsString(ipAddr) +}