diff --git a/cmd/ogen/main.go b/cmd/ogen/main.go index dd4d90a20..234ea8097 100644 --- a/cmd/ogen/main.go +++ b/cmd/ogen/main.go @@ -5,7 +5,10 @@ import ( "flag" "fmt" "io" + "net/http" + "net/url" "os" + "path" "path/filepath" "regexp" "runtime" @@ -127,6 +130,84 @@ Also, you can use --ct-alias to map content types to supported ones. return false } +type file struct { + data []byte + fileName string + source string + rootURL *url.URL +} + +func (f file) location() location.File { + return location.NewFile(f.fileName, f.source, f.data) +} + +func parseSpecPath( + p string, + client *http.Client, + readFile func(string) ([]byte, error), +) (f file, opts gen.RemoteOptions, _ error) { + // FIXME(tdakkota): pass context. + if u, _ := url.Parse(p); u != nil { + switch u.Scheme { + case "http", "https": + _, fileName := path.Split(u.Path) + + resp, err := client.Get(p) + if err != nil { + return f, opts, err + } + defer func() { + _ = resp.Body.Close() + }() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return f, opts, err + } + + f = file{ + data: data, + fileName: fileName, + source: p, + rootURL: u, + } + opts = gen.RemoteOptions{ + ReadFile: func(p string) ([]byte, error) { + return nil, errors.New("local files are not supported in remote mode") + }, + HTTPClient: client, + } + return f, opts, nil + case "": + default: + if runtime.GOOS == "windows" && filepath.VolumeName(p) != "" { + break + } + return f, opts, errors.Errorf("unsupported scheme %q", u.Scheme) + } + } + p = filepath.Clean(p) + _, fileName := filepath.Split(p) + + data, err := readFile(p) + if err != nil { + return f, opts, err + } + + f = file{ + data: data, + fileName: fileName, + source: p, + rootURL: &url.URL{Path: filepath.ToSlash(p)}, + } + opts = gen.RemoteOptions{ + HTTPClient: client, + ReadFile: readFile, + } + + return f, opts, nil +} + func run() error { set := flag.NewFlagSet(os.Args[0], flag.ExitOnError) set.Usage = func() { @@ -202,7 +283,6 @@ func run() error { set.Usage() return errors.New("no spec provided") } - specPath = filepath.Clean(specPath) logger, err := ogenzap.Create(logOptions) if err != nil { @@ -247,8 +327,11 @@ func run() error { }() } - specDir, fileName := filepath.Split(specPath) - data, err := os.ReadFile(specPath) + f, remoteOpts, err := parseSpecPath( + specPath, + &http.Client{Timeout: time.Minute}, + os.ReadFile, + ) if err != nil { return err } @@ -261,18 +344,15 @@ func run() error { SkipUnimplemented: *skipUnimplemented, InferSchemaType: *inferTypes, AllowRemote: *allowRemote, - Remote: gen.RemoteOptions{ - ReadFile: func(p string) ([]byte, error) { - return os.ReadFile(filepath.Join(specDir, p)) - }, - }, + RootURL: f.rootURL, + Remote: remoteOpts, Filters: gen.Filters{ PathRegex: filterPath, Methods: filterMethods, }, IgnoreNotImplemented: strings.Split(*debugIgnoreNotImplemented, ","), ContentTypeAliases: ctAliases, - File: location.NewFile(fileName, specPath, data), + File: f.location(), Logger: logger, } if expr := *skipTestsRegex; expr != "" { @@ -291,7 +371,7 @@ func run() error { } } - if err := generate(data, *packageName, *targetDir, *clean, opts); err != nil { + if err := generate(f.data, *packageName, *targetDir, *clean, opts); err != nil { if handleGenerateError(os.Stderr, logOptions.Color, err) { return errors.New("generation failed") } diff --git a/cmd/ogen/main_test.go b/cmd/ogen/main_test.go new file mode 100644 index 000000000..a8390cbe6 --- /dev/null +++ b/cmd/ogen/main_test.go @@ -0,0 +1,91 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/url" + "runtime" + "testing" + + "github.com/go-faster/errors" + "github.com/stretchr/testify/require" +) + +type roundTripFunc func(req *http.Request) (*http.Response, error) + +func (r roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return r(req) +} + +func Test_parseSpecPath(t *testing.T) { + testdata := []byte(`{}`) + urlPath := func(p string) *url.URL { + return &url.URL{Path: p} + } + urlParse := func(s string) *url.URL { + u, err := url.Parse(s) + require.NoError(t, err) + return u + } + + type testCase struct { + input string + httpData []byte + fileData []byte + wantFilename string + wantURL *url.URL + } + + tests := []testCase{ + {"spec.json", nil, testdata, "spec.json", urlPath("spec.json")}, + {"./spec.json", nil, testdata, "spec.json", urlPath("spec.json")}, + {"_testdata/spec.json", nil, testdata, "spec.json", urlPath("_testdata/spec.json")}, + + {"http://example.com/spec.json", testdata, nil, "spec.json", urlParse("http://example.com/spec.json")}, + } + if runtime.GOOS == "windows" { + tests = append(tests, []testCase{ + {`_testdata\spec.json`, nil, testdata, "spec.json", urlPath("_testdata/spec.json")}, + {`C:\_testdata\spec.json`, nil, testdata, "spec.json", urlPath("C:/_testdata/spec.json")}, + }...) + } + for i, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("Test%d", i+1), func(t *testing.T) { + a := require.New(t) + + var ( + data = tt.httpData + readFile = func(filename string) ([]byte, error) { + return nil, errors.Errorf("unexpected read file: %q", filename) + } + httpClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(data)), + }, nil + }), + } + ) + if data == nil { + data = tt.fileData + readFile = func(filename string) ([]byte, error) { + return data, nil + } + httpClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.Errorf("unexpected http request: %q", req.URL) + }), + } + } + + f, _, err := parseSpecPath(tt.input, httpClient, readFile) + a.NoError(err) + a.Equal(tt.wantFilename, f.fileName) + a.Equal(tt.wantURL, f.rootURL) + }) + } +} diff --git a/gen/generator.go b/gen/generator.go index 399c02871..3224865ce 100644 --- a/gen/generator.go +++ b/gen/generator.go @@ -42,6 +42,7 @@ func NewGenerator(spec *ogen.Spec, opts Options) (*Generator, error) { api, err := parser.Parse(spec, parser.Settings{ External: external, File: opts.File, + RootURL: opts.RootURL, InferTypes: opts.InferSchemaType, }) if err != nil {