diff --git a/workflow/source.go b/workflow/source.go index cb56aef..011cdd4 100644 --- a/workflow/source.go +++ b/workflow/source.go @@ -2,13 +2,17 @@ package workflow import ( "crypto/sha256" + "encoding/json" "fmt" + "io" "math/rand" + "net/http" "os" "path/filepath" "strings" "github.com/speakeasy-api/sdk-gen-config/workspace" + "gopkg.in/yaml.v3" ) // Ensure your update schema/workflow.schema.json on changes @@ -26,6 +30,14 @@ type Overlay struct { Document *Document `yaml:"document,omitempty"` } +type InputFileType string + +const ( + InputFileTypeUnknown InputFileType = "unknown" + InputFileTypeJSON InputFileType = "json" + InputFileTypeYAML InputFileType = "yaml" +) + func (o *Overlay) UnmarshalYAML(unmarshal func(interface{}) error) error { // Overlay is flat, so we need to unmarshal it into a map to determine if it's a document or fallbackCodeSamples var overlayMap map[string]interface{} @@ -129,59 +141,98 @@ func (s Source) Validate() error { } func (s Source) GetOutputLocation() (string, error) { - if s.Output != nil { - if len(s.Inputs) > 1 && !isYAMLFile(*s.Output) { - return "", fmt.Errorf("when merging multiple inputs, output must be a yaml file") - } - return *s.Output, nil - } - - if len(s.Inputs) == 1 && len(s.Overlays) == 0 { - return s.handleSingleInput() - } - - return s.generateOutputPath() -} - -func (s Source) handleSingleInput() (string, error) { - input := s.Inputs[0].Location - switch getFileStatus(input) { - case fileStatusLocal: - return input, nil - case fileStatusNotExists: - return "", fmt.Errorf("input file %s does not exist", input) - case fileStatusRemote, fileStatusRegistry: - return s.generateRegistryPath(input) - default: - return "", fmt.Errorf("unknown file status for %s", input) - } + if s.Output != nil { + if len(s.Inputs) > 1 && !isYAMLFile(*s.Output) { + return "", fmt.Errorf("when merging multiple inputs, output must be a yaml file") + } + return *s.Output, nil + } + + return s.generateOutputPath() } func (s Source) generateRegistryPath(input string) (string, error) { - ext := filepath.Ext(input) - if ext == "" { - ext = ".yaml" - } - hash := fmt.Sprintf("%x", sha256.Sum256([]byte(input))) - return filepath.Join(GetTempDir(), fmt.Sprintf("registry_%s%s", hash[:6], ext)), nil + ext := filepath.Ext(input) + hash := fmt.Sprintf("%x", sha256.Sum256([]byte(input))) + + if ext == "" { + resolvedExtension := s.getRemoteResolvedExtension(input) + + if resolvedExtension != InputFileTypeUnknown { + return filepath.Join(GetTempDir(), fmt.Sprintf("registry_%s.%s", hash[:6], resolvedExtension)), nil + } + } + + // Check if the extension is supported + if ext != ".yaml" && ext != ".yml" && ext != ".json" { + ext = ".yaml" + } + return filepath.Join(GetTempDir(), fmt.Sprintf("registry_%s%s", hash[:6], ext)), nil +} + +// Attempts to fetch the remote file and determine the format (yaml / json) +// based on the contents of the file +func (s Source) getRemoteResolvedExtension(input string) InputFileType { + res, err := http.Get(input) + if err != nil { + return InputFileTypeUnknown + } + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + if err != nil { + return InputFileTypeUnknown + } + if json.Unmarshal(body, &struct{}{}) == nil { + return InputFileTypeJSON + } + if yaml.Unmarshal(body, &struct{}{}) == nil { + return InputFileTypeYAML + } + return InputFileTypeUnknown } func (s Source) generateOutputPath() (string, error) { - hashInputs := func() string { - var combined string - for _, input := range s.Inputs { - combined += input.Location - } - hash := sha256.Sum256([]byte(combined)) - return fmt.Sprintf("%x", hash)[:6] - } + generateOutputPath := func(extension string) string { + var combined string + for _, input := range s.Inputs { + combined += input.Location + } + hash := sha256.Sum256([]byte(combined)) + hashStr := fmt.Sprintf("%x", hash)[:6] + return filepath.Join(GetTempDir(), fmt.Sprintf("output_%s%s", hashStr, extension)) + } + + if len(s.Inputs) == 1 { + hasOverlays := len(s.Overlays) > 0 + + if hasOverlays { + ext := filepath.Ext(s.Inputs[0].Location) + if ext == "" { + ext = ".yaml" + } + return generateOutputPath(ext), nil + } + + input := s.Inputs[0].Location + + switch getFileStatus(input) { + case fileStatusLocal: + return input, nil + case fileStatusNotExists: + return "", fmt.Errorf("input file %s does not exist", input) + case fileStatusRemote, fileStatusRegistry: + return s.generateRegistryPath(input) + default: + return "", fmt.Errorf("unknown file status for %s", input) + } + } - return filepath.Join(GetTempDir(), fmt.Sprintf("output_%s.yaml", hashInputs())), nil + return generateOutputPath(".yaml"), nil } func isYAMLFile(path string) bool { - ext := filepath.Ext(path) - return ext == ".yaml" || ext == ".yml" + ext := filepath.Ext(path) + return ext == ".yaml" || ext == ".yml" } func GetTempDir() string { diff --git a/workflow/source_test.go b/workflow/source_test.go index 12f4c10..13678d3 100644 --- a/workflow/source_test.go +++ b/workflow/source_test.go @@ -1,7 +1,11 @@ package workflow_test import ( + "encoding/json" "fmt" + "net" + "net/http" + "net/http/httptest" "net/url" "os" "path/filepath" @@ -359,6 +363,57 @@ func TestSource_GetOutputLocation(t *testing.T) { type args struct { source workflow.Source } + + // The URL needs to be deterministic because the hash is based on the URL + path + testServer, err := newTestServerWithURL("127.0.0.1:1234", http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + // Determine the file extension from the URL path + fileExt := filepath.Ext(req.URL.Path) + + // Default to JSON or YAML if no file extension is present + if fileExt == "" { + switch { + case strings.Contains(req.URL.Path, "json"): + fileExt = ".json" + case strings.Contains(req.URL.Path, "yaml"): + fileExt = ".yaml" + } + } + + // Determine the content type and response body based on the file extension + var ( + contentType string + response interface{} + err error + responseBytes []byte + ) + response = map[string]interface{}{"openapi": "3.0.0"} + + switch fileExt { + case ".json": + contentType = "application/json" + case ".yaml": + contentType = "application/yaml" + default: + http.Error(res, "Unsupported file format", http.StatusBadRequest) + return + } + + // Set the content type header + res.Header().Set("Content-Type", contentType) + + // Marshal and write the response based on content type + if contentType == "application/json" { + responseBytes, err = json.Marshal(response) + } else { + responseBytes, err = yaml.Marshal(response) + } + assert.NoError(t, err) + res.Write(responseBytes) + })) + + require.NoError(t, err) + defer func() { testServer.Close() }() + tests := []struct { name string args args @@ -383,12 +438,12 @@ func TestSource_GetOutputLocation(t *testing.T) { source: workflow.Source{ Inputs: []workflow.Document{ { - Location: "http://example.com/openapi.json", + Location: fmt.Sprintf("%s/openapi.json", testServer.URL), }, }, }, }, - wantOutputLocation: ".speakeasy/temp/registry_e8ba45.json", + wantOutputLocation: ".speakeasy/temp/registry_4b5145.json", }, { name: "simple remote source without extension returns auto-generated output location assumed to be yaml", @@ -396,12 +451,12 @@ func TestSource_GetOutputLocation(t *testing.T) { source: workflow.Source{ Inputs: []workflow.Document{ { - Location: "http://example.com/openapi", + Location: fmt.Sprintf("%s/openapi", testServer.URL), }, }, }, }, - wantOutputLocation: ".speakeasy/temp/registry_94359d.yaml", + wantOutputLocation: ".speakeasy/temp/registry_61ea27.yaml", }, { name: "source with multiple inputs returns specified output location", @@ -469,6 +524,76 @@ func TestSource_GetOutputLocation(t *testing.T) { }, wantOutputLocation: ".speakeasy/temp/output_d910ba.yaml", }, + { + name: "single local source uses same extension as source", + args: args{ + source: workflow.Source{ + Inputs: []workflow.Document{ + { + Location: "openapi.json", + }, + }, + }, + }, + wantOutputLocation: "openapi.json", + }, + { + name: "single local source with overlays uses same extension as source", + args: args{ + source: workflow.Source{ + Inputs: []workflow.Document{ + { + Location: "openapi.json", + }, + }, + Overlays: []workflow.Overlay{ + {Document: &workflow.Document{Location: "overlay.yaml"}}, + }, + }, + }, + wantOutputLocation: ".speakeasy/temp/output_a98653.json", + }, + { + name: "single remote source with unknown format uses resolved json extension", + args: args{ + source: workflow.Source{ + Inputs: []workflow.Document{ + { + // The test server is setup so that it will return json if the path includes the word "json" + Location: fmt.Sprintf("%s/thepathincludesjson", testServer.URL), + }, + }, + }, + }, + wantOutputLocation: ".speakeasy/temp/registry_411616.json", + }, + { + name: "single remote source with unknown format uses resolved yaml extension", + args: args{ + source: workflow.Source{ + Inputs: []workflow.Document{ + { + // The test server is setup so that it will return yaml if the path includes the word "yaml" + Location: fmt.Sprintf("%s/thepathincludesyaml", testServer.URL), + }, + }, + }, + }, + wantOutputLocation: ".speakeasy/temp/registry_0254db.yaml", + }, + { + name: "single remote source with unsupported file extension returns auto-generated output location", + args: args{ + source: workflow.Source{ + Inputs: []workflow.Document{ + { + Location: fmt.Sprintf("%s/foo.txt", testServer.URL), + }, + }, + }, + }, + wantOutputLocation: ".speakeasy/temp/registry_69a6f2.yaml", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -611,3 +736,17 @@ func createEmptyFile(path string) error { return f.Close() } + +func newTestServerWithURL(URL string, handler http.Handler) (*httptest.Server, error) { + ts := httptest.NewUnstartedServer(handler) + if URL != "" { + l, err := net.Listen("tcp", URL) + if err != nil { + return nil, err + } + ts.Listener.Close() + ts.Listener = l + } + ts.Start() + return ts, nil +}