diff --git a/cmd/oras/internal/option/remote_test.go b/cmd/oras/internal/option/remote_test.go index 532669203..efca51cc2 100644 --- a/cmd/oras/internal/option/remote_test.go +++ b/cmd/oras/internal/option/remote_test.go @@ -18,13 +18,11 @@ package option import ( "context" "crypto/rand" - "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "net/http" - nhttp "net/http" "net/http/httptest" "net/url" "os" @@ -46,12 +44,12 @@ var testTagList = struct { func TestMain(m *testing.M) { // Test server - ts = httptest.NewTLSServer(nhttp.HandlerFunc(func(w nhttp.ResponseWriter, r *nhttp.Request) { + ts = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { p := r.URL.Path m := r.Method switch { case p == "/v2/" && m == "GET": - w.WriteHeader(nhttp.StatusOK) + w.WriteHeader(http.StatusOK) case p == fmt.Sprintf("/v2/%s/tags/list", testRepo) && m == "GET": json.NewEncoder(w).Encode(testTagList) } @@ -103,7 +101,7 @@ func TestRemote_authClient_skipTlsVerify(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - req, err := nhttp.NewRequestWithContext(context.Background(), nhttp.MethodGet, ts.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ts.URL, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -119,9 +117,6 @@ func TestRemote_authClient_CARoots(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - pool := x509.NewCertPool() - pool.AddCert(ts.Certificate()) - opts := Remote{ CACertFilePath: caPath, } @@ -129,7 +124,7 @@ func TestRemote_authClient_CARoots(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - req, err := nhttp.NewRequestWithContext(context.Background(), nhttp.MethodGet, ts.URL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ts.URL, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -154,7 +149,7 @@ func TestRemote_authClient_resolve(t *testing.T) { if err != nil { t.Fatalf("unexpected error when creating auth client: %v", err) } - req, err := nhttp.NewRequestWithContext(context.Background(), nhttp.MethodGet, fmt.Sprintf("https://%s:%s", testHost, URL.Port()), nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("https://%s:%s", testHost, URL.Port()), nil) if err != nil { t.Fatalf("unexpected error when generating request: %v", err) } @@ -170,9 +165,6 @@ func TestRemote_NewRegistry(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - pool := x509.NewCertPool() - pool.AddCert(ts.Certificate()) - opts := struct { Remote Common @@ -200,8 +192,6 @@ func TestRemote_NewRepository(t *testing.T) { if err := os.WriteFile(caPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ts.Certificate().Raw}), 0644); err != nil { t.Fatalf("unexpected error: %v", err) } - pool := x509.NewCertPool() - pool.AddCert(ts.Certificate()) opts := struct { Remote Common @@ -369,7 +359,7 @@ func TestRemote_parseCustomHeaders(t *testing.T) { tests := []struct { name string headerFlags []string - want nhttp.Header + want http.Header wantErr bool }{ { diff --git a/cmd/oras/pull.go b/cmd/oras/pull.go index e66fc56dd..5e3208a4a 100644 --- a/cmd/oras/pull.go +++ b/cmd/oras/pull.go @@ -138,10 +138,7 @@ func runPull(opts pullOptions) error { rc.Close() } }() - if err := display.PrintStatus(target, "Processing ", opts.Verbose); err != nil { - return nil, err - } - return rc, nil + return rc, display.PrintStatus(target, "Processing ", opts.Verbose) }) nodes, subject, config, err := graph.Successors(ctx, statusFetcher, desc) diff --git a/internal/crypto/certificate_test.go b/internal/crypto/certificate_test.go new file mode 100644 index 000000000..38bfba11d --- /dev/null +++ b/internal/crypto/certificate_test.go @@ -0,0 +1,79 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package crypto + +import ( + "context" + "encoding/pem" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +var ts *httptest.Server + +func TestLoadCertPool(t *testing.T) { + // Test server + ts = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer ts.Close() + var err error + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + client := &http.Client{} + _, err = client.Do(req) + if err == nil { + t.Fatalf("expecting TLS check failure error but didn't get one") + } + + tp := http.DefaultTransport.(*http.Transport).Clone() + caPath := filepath.Join(t.TempDir(), "oras-test.pem") + if err = os.WriteFile(caPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ts.Certificate().Raw}), 0644); err != nil { + t.Fatalf("unexpected error: %v", err) + } + tp.TLSClientConfig.RootCAs, err = LoadCertPool(caPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + client = &http.Client{Transport: tp} + _, err = client.Do(req) + if err != nil { + t.Fatalf("failed to trust the self signed pem: %v", err) + } +} + +func TestLoadCertPool_invalidPem(t *testing.T) { + pemPath := filepath.Join(t.TempDir(), "invalid.pem") + if err := os.WriteFile(pemPath, []byte{}, 0644); err != nil { + t.Fatalf("unexpected error: %v", err) + } + got, err := LoadCertPool(pemPath) + if err == nil { + t.Errorf("Expecting LoadCertPool to return error for a non-existent pem file, got: %v", got) + return + } +} + +func TestLoadCertPool_pemNotExist(t *testing.T) { + got, err := LoadCertPool("/???") + if err == nil { + t.Errorf("Expecting LoadCertPool to return error for a non-existent pem file, got: %v", got) + return + } +} diff --git a/internal/graph/graph_test.go b/internal/graph/graph_test.go index b5f20c2cc..50eb0e8e0 100644 --- a/internal/graph/graph_test.go +++ b/internal/graph/graph_test.go @@ -113,8 +113,9 @@ func TestReferrers(t *testing.T) { appendBlob(ocispec.MediaTypeImageIndex, manifestJSON) } const ( - subject = iota + blob = iota imgConfig + subject image artifact index @@ -123,12 +124,13 @@ func TestReferrers(t *testing.T) { appendBlob(ocispec.MediaTypeArtifactManifest, []byte("subject content")) imageType := "test.image" appendBlob(imageType, []byte("config content")) - generateImage(&descs[subject], anno, descs[imgConfig]) + generateImage(nil, nil, descs[imgConfig], descs[blob]) + generateImage(&descs[subject], anno, descs[imgConfig], descs[blob]) imageDesc := descs[image] imageDesc.Annotations = anno imageDesc.ArtifactType = imageType artifactType := "test.artifact" - generateArtifact(artifactType, &descs[subject], anno) + generateArtifact(artifactType, &descs[subject], anno, descs[blob]) generateIndex(descs[subject]) artifactDesc := descs[artifact] artifactDesc.Annotations = anno @@ -157,6 +159,7 @@ func TestReferrers(t *testing.T) { {"should return referrers when target is a referrer lister", args{ctx, &refLister{referrers: referrers}, ocispec.Descriptor{}, ""}, referrers, false}, {"should return nil for index node", args{ctx, finder, descs[index], ""}, nil, false}, {"should return nil for config node", args{ctx, finder, descs[imgConfig], ""}, nil, false}, + {"should return nil for blob/layer node", args{ctx, finder, descs[blob], ""}, nil, false}, {"should find filtered image referrer", args{ctx, finder, descs[subject], imageType}, []ocispec.Descriptor{imageDesc}, false}, {"should find filtered artifact referrer", args{ctx, finder, descs[subject], artifactType}, []ocispec.Descriptor{artifactDesc}, false}, }