Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support customized DNS resolving for remote registry #696

Merged
merged 17 commits into from
Dec 23, 2022
Merged
73 changes: 68 additions & 5 deletions cmd/oras/internal/option/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net"
"net/http"
"os"
"strconv"
"strings"
"time"

Expand All @@ -35,15 +36,24 @@ import (
"oras.land/oras/internal/version"
)

type ResolveEntry struct {
qweeah marked this conversation as resolved.
Show resolved Hide resolved
from string
to net.IP
port int
}

// Remote options struct.
type Remote struct {
resolveFlag []string
qweeah marked this conversation as resolved.
Show resolved Hide resolved

CACertFilePath string
PlainHTTP bool
Insecure bool
Configs []string
Username string
PasswordFromStdin bool
Password string
Resolves []*ResolveEntry
qweeah marked this conversation as resolved.
Show resolved Hide resolved
qweeah marked this conversation as resolved.
Show resolved Hide resolved
}

// ApplyFlags applies flags to a command flag set.
Expand Down Expand Up @@ -76,6 +86,10 @@ func (opts *Remote) ApplyFlagsWithPrefix(fs *pflag.FlagSet, prefix, description
if fs.Lookup("registry-config") == nil {
fs.StringArrayVarP(&opts.Configs, "registry-config", "", nil, "`path` of the authentication file")
}

if fs.Lookup("resolve") == nil {
fs.StringArrayVarP(&opts.resolveFlag, "resolve", "", nil, "customized DNS formatted in `host:port:address`")
}
}

// ReadPassword tries to read password with optional cmd prompt.
Expand All @@ -94,6 +108,34 @@ func (opts *Remote) ReadPassword() (err error) {
return nil
}

// parseResolve parses resolve flag.
func (opts *Remote) parseResolve() (err error) {
qweeah marked this conversation as resolved.
Show resolved Hide resolved
errorMsg := "failed to parse resolve flag %q: %s"
qweeah marked this conversation as resolved.
Show resolved Hide resolved
for _, r := range opts.resolveFlag {
parts := strings.SplitN(r, ":", 3)
if len(parts) < 3 {
return fmt.Errorf(errorMsg, r, "expecting host:port:address")
}

port, err := strconv.Atoi(parts[1])
if err != nil {
return fmt.Errorf(errorMsg, r, "expecting uint64 port")
}

// ipv6 zone is not parsed
to := net.ParseIP(parts[2])
if to == nil {
return fmt.Errorf(errorMsg, r, "invalid IP address")
}
opts.Resolves = append(opts.Resolves, &ResolveEntry{
from: parts[0],
port: port,
to: to,
})
}
return nil
}

// tlsConfig assembles the tls config.
func (opts *Remote) tlsConfig() (*tls.Config, error) {
config := &tls.Config{
Expand All @@ -109,21 +151,42 @@ func (opts *Remote) tlsConfig() (*tls.Config, error) {
return config, nil
}

var defaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}

// DialContext connects to the addr on the named network using
// the provided context.
func (opts *Remote) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
var matched *ResolveEntry
for _, r := range opts.Resolves {
if addr == fmt.Sprintf("%s:%d", r.from, r.port) {
matched = r
break
}
}
if matched == nil {
return defaultDialer.DialContext(ctx, network, addr)
}
return net.DialTCP(network, nil, &net.TCPAddr{IP: matched.to, Port: matched.port})
}
qweeah marked this conversation as resolved.
Show resolved Hide resolved

// authClient assembles a oras auth client.
func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client, err error) {
config, err := opts.tlsConfig()
if err != nil {
return nil, err
}
if err := opts.parseResolve(); err != nil {
return nil, err
}
client = &auth.Client{
Client: &http.Client{
// default value are derived from http.DefaultTransport
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
Proxy: http.ProxyFromEnvironment,
DialContext: opts.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
Expand Down
106 changes: 102 additions & 4 deletions cmd/oras/internal/option/remote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
nhttp "net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"testing"

nhttp "net/http"
"net/http/httptest"
"net/url"

"github.com/spf13/pflag"
"oras.land/oras-go/v2/registry/remote/auth"
)
Expand Down Expand Up @@ -139,6 +138,31 @@ func TestRemote_authClient_CARoots(t *testing.T) {
}
}

func TestRemote_authClient_resolve(t *testing.T) {
URL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("invalid url in test server: %s", ts.URL)
}

testHost := "test.unit.oras"
opts := Remote{
resolveFlag: []string{fmt.Sprintf("%s:%s:%s", testHost, URL.Port(), URL.Hostname())},
Insecure: true,
}
client, err := opts.authClient(testHost, false)
if err != nil {
t.Fatalf("unexpected error when creating auth client: %v", err)
}
req, err := nhttp.NewRequestWithContext(context.Background(), nhttp.MethodGet, fmt.Sprintf("https://%s:%s", testHost, URL.Port()), nil)
if err != nil {
t.Fatalf("unexpected error when generating request: %v", err)
}
_, err = client.Do(req)
if err != nil {
t.Fatalf("unexpected error when sending request: %v", err)
}
}

func TestRemote_NewRegistry(t *testing.T) {
caPath := filepath.Join(t.TempDir(), "oras-test.pem")
if err := os.WriteFile(caPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ts.Certificate().Raw}), 0644); err != nil {
Expand Down Expand Up @@ -220,3 +244,77 @@ func TestRemote_isPlainHttp_localhost(t *testing.T) {

}
}

func TestRemote_parseResolve_err(t *testing.T) {
tests := []struct {
name string
opts *Remote
wantErr bool
}{
{
name: "invalid flag",
opts: &Remote{resolveFlag: []string{"this-shouldn't_work"}},
wantErr: true,
},
{
name: "no host",
opts: &Remote{resolveFlag: []string{":port:address"}},
wantErr: true,
},
{
name: "no address",
opts: &Remote{resolveFlag: []string{"host:port:"}},
wantErr: true,
},
{
name: "invalid address",
opts: &Remote{resolveFlag: []string{"host:port:invalid-ip"}},
wantErr: true,
},
{
name: "no port",
opts: &Remote{resolveFlag: []string{"host::address"}},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.opts.parseResolve(); (err != nil) != tt.wantErr {
t.Errorf("Remote.parseResolve() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestRemote_parseResolve_defaultFlag(t *testing.T) {
opts := &Remote{resolveFlag: nil}
if err := opts.parseResolve(); err != nil {
t.Fatalf("should succeed parsing empty resolve flag but got %v", err)
}
if len(opts.Resolves) != 0 {
t.Fatalf("expect empty resolve entries but got %v", opts.Resolves)
}
}

func TestRemote_parseResolve_ipv4(t *testing.T) {
host := "mockedHost"
port := 12345
address := "192.168.1.1"
opts := &Remote{resolveFlag: []string{fmt.Sprintf("%s:%d:%s", host, port, address)}}
if err := opts.parseResolve(); err != nil {
t.Fatalf("should succeed parsing resolve flag but got %v", err)
}
if len(opts.Resolves) != 1 {
t.Fatalf("expect 1 resolve entries but got %v", opts.Resolves)
}

entry := opts.Resolves[0]
if entry.from != host {
t.Fatalf("expect resolved host %q but got %q", host, entry.from)
}
if entry.to.To4().String() != address {
t.Fatalf("expect resolved address %q but got %q", address, entry.to)
}
if entry.port != port {
t.Fatalf("expect resolved port %d but port %d", port, entry.port)
}
}