Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: respect input format when producing outfile #56

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 94 additions & 43 deletions workflow/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ package workflow

import (
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"os"
"path/filepath"
"strings"

"github.com/speakeasy-api/sdk-gen-config/workspace"
"gopkg.in/yaml.v3"
)

// Ensure you update schema/workflow.schema.json on changes
Expand All @@ -26,6 +30,14 @@ type Overlay struct {
Document *Document `yaml:"document,omitempty"`
}

// InputFileType identifies the serialization format of an input document.
type InputFileType string

// Known input formats. generateRegistryPath interpolates a resolved value
// directly as the file extension (without the leading dot), so the JSON and
// YAML values double as extensions. InputFileTypeUnknown is what
// getRemoteResolvedExtension returns when the format cannot be determined.
const (
InputFileTypeUnknown InputFileType = "unknown"
InputFileTypeJSON InputFileType = "json"
InputFileTypeYAML InputFileType = "yaml"
)

func (o *Overlay) UnmarshalYAML(unmarshal func(interface{}) error) error {
// Overlay is flat, so we need to unmarshal it into a map to determine if it's a document or fallbackCodeSamples
var overlayMap map[string]interface{}
Expand Down Expand Up @@ -129,59 +141,98 @@ func (s Source) Validate() error {
}

// GetOutputLocation returns the path the source's document should be written to.
//
// An explicitly configured Output is returned unchanged, except that merging
// more than one input into a non-YAML output is rejected with an error.
// Without an explicit Output, a single input with no overlays is delegated to
// handleSingleInput; every other combination gets a generated path from
// generateOutputPath.
func (s Source) GetOutputLocation() (string, error) {
if s.Output != nil {
if len(s.Inputs) > 1 && !isYAMLFile(*s.Output) {
return "", fmt.Errorf("when merging multiple inputs, output must be a yaml file")
}
return *s.Output, nil
}

// Exactly one input and nothing to overlay: no merge is needed.
if len(s.Inputs) == 1 && len(s.Overlays) == 0 {
return s.handleSingleInput()
}

return s.generateOutputPath()
}

// handleSingleInput resolves the output location for the first input based on
// its file status: a local file is returned as-is, a missing file is an error,
// and remote or registry inputs get a generated path from generateRegistryPath.
func (s Source) handleSingleInput() (string, error) {
input := s.Inputs[0].Location
switch getFileStatus(input) {
case fileStatusLocal:
return input, nil
case fileStatusNotExists:
return "", fmt.Errorf("input file %s does not exist", input)
case fileStatusRemote, fileStatusRegistry:
return s.generateRegistryPath(input)
default:
return "", fmt.Errorf("unknown file status for %s", input)
}
// NOTE(review): every branch of the switch above returns, so the statements
// below cannot execute as part of this function. They look like the
// GetOutputLocation body from the other side of the rendered diff — confirm
// which function they belong to.
if s.Output != nil {
if len(s.Inputs) > 1 && !isYAMLFile(*s.Output) {
return "", fmt.Errorf("when merging multiple inputs, output must be a yaml file")
}
return *s.Output, nil
}

return s.generateOutputPath()
}

// generateRegistryPath builds a temp-dir path for a remote or registry input,
// named "registry_" plus the first six hex characters of the SHA-256 of the
// input location, followed by a file extension.
//
// NOTE(review): this listing contains two bodies back to back (ext and hash
// are each declared twice) — presumably the before and after sides of the
// diff. In the first, a missing extension defaults to ".yaml". In the second,
// a missing extension is resolved via getRemoteResolvedExtension, and any
// extension other than .yaml, .yml or .json falls back to ".yaml".
func (s Source) generateRegistryPath(input string) (string, error) {
ext := filepath.Ext(input)
if ext == "" {
ext = ".yaml"
}
hash := fmt.Sprintf("%x", sha256.Sum256([]byte(input)))
return filepath.Join(GetTempDir(), fmt.Sprintf("registry_%s%s", hash[:6], ext)), nil
ext := filepath.Ext(input)
hash := fmt.Sprintf("%x", sha256.Sum256([]byte(input)))

if ext == "" {
resolvedExtension := s.getRemoteResolvedExtension(input)

// The resolved type has no leading dot, hence the ".%s" in the format string.
if resolvedExtension != InputFileTypeUnknown {
return filepath.Join(GetTempDir(), fmt.Sprintf("registry_%s.%s", hash[:6], resolvedExtension)), nil
}
}

// Check if the extension is supported
if ext != ".yaml" && ext != ".yml" && ext != ".json" {
ext = ".yaml"
}
return filepath.Join(GetTempDir(), fmt.Sprintf("registry_%s%s", hash[:6], ext)), nil
}

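// getRemoteResolvedExtension fetches input over HTTP and infers its format
// (JSON or YAML) from the response body. JSON is tried first; if the request
// fails, the body cannot be read, or the body parses as neither format, it
// returns InputFileTypeUnknown.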
func (s Source) getRemoteResolvedExtension(input string) InputFileType {
res, err := http.Get(input)
Copy link
Member

@ThomasRooney ThomasRooney Sep 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this works for a lot of cases:

  1. Some remote files are blocked by authentication headers - these are provided alongside the input location.
  2. This getRemoteResolvedExtension(input) is also used with registry URIs.

I think we need to delegate to ResolveDocument which currently lives within the speakeasy repo. It could also probably cache the document to stop unnecessary duplication of calls.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, sounds good. I will do some investigation into this 👍🏻

Aside from stripping out this fetching logic and replacing it with ResolveDocument, is the approach generally sound?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep; definitely!

if err != nil {
return InputFileTypeUnknown
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return InputFileTypeUnknown
}
if json.Unmarshal(body, &struct{}{}) == nil {
return InputFileTypeJSON
}
if yaml.Unmarshal(body, &struct{}{}) == nil {
return InputFileTypeYAML
}
return InputFileTypeUnknown
}

// generateOutputPath derives an output location when none is configured.
//
// With exactly one input: if overlays are present, the result is a temp-dir
// path "output_<hash><ext>" that reuses the input's extension (".yaml" when it
// has none); otherwise the result depends on the input's file status (local
// path as-is, error when missing, generateRegistryPath for remote/registry).
// With any other number of inputs the result is "output_<hash>.yaml".
// <hash> is the first six hex characters of the SHA-256 of all input
// locations concatenated in order.
func (s Source) generateOutputPath() (string, error) {
hashInputs := func() string {
var combined string
for _, input := range s.Inputs {
combined += input.Location
}
hash := sha256.Sum256([]byte(combined))
return fmt.Sprintf("%x", hash)[:6]
}
// extension is expected to include its leading dot (e.g. ".yaml").
generateOutputPath := func(extension string) string {
var combined string
for _, input := range s.Inputs {
combined += input.Location
}
hash := sha256.Sum256([]byte(combined))
hashStr := fmt.Sprintf("%x", hash)[:6]
return filepath.Join(GetTempDir(), fmt.Sprintf("output_%s%s", hashStr, extension))
}

if len(s.Inputs) == 1 {
hasOverlays := len(s.Overlays) > 0

if hasOverlays {
// The extension is taken from the input location as-is; it is not
// checked against the supported set here.
ext := filepath.Ext(s.Inputs[0].Location)
if ext == "" {
ext = ".yaml"
}
return generateOutputPath(ext), nil
}

input := s.Inputs[0].Location

switch getFileStatus(input) {
case fileStatusLocal:
return input, nil
case fileStatusNotExists:
return "", fmt.Errorf("input file %s does not exist", input)
case fileStatusRemote, fileStatusRegistry:
return s.generateRegistryPath(input)
default:
return "", fmt.Errorf("unknown file status for %s", input)
}
}

// NOTE(review): two consecutive returns that format the same path; the second
// is unreachable as listed — presumably the before/after lines of the diff.
return filepath.Join(GetTempDir(), fmt.Sprintf("output_%s.yaml", hashInputs())), nil
return generateOutputPath(".yaml"), nil
}

// isYAMLFile reports whether path ends in a ".yaml" or ".yml" extension.
// The comparison is case-sensitive.
//
// NOTE(review): the two statements appear twice in this listing — presumably
// a whitespace-only change rendered without diff markers.
func isYAMLFile(path string) bool {
ext := filepath.Ext(path)
return ext == ".yaml" || ext == ".yml"
ext := filepath.Ext(path)
return ext == ".yaml" || ext == ".yml"
}

func GetTempDir() string {
Expand Down
147 changes: 143 additions & 4 deletions workflow/source_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package workflow_test

import (
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
Expand Down Expand Up @@ -359,6 +363,57 @@ func TestSource_GetOutputLocation(t *testing.T) {
type args struct {
source workflow.Source
}

// The URL needs to be deterministic because the hash is based on the URL + path
testServer, err := newTestServerWithURL("127.0.0.1:1234", http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
// Determine the file extension from the URL path
fileExt := filepath.Ext(req.URL.Path)

// Default to JSON or YAML if no file extension is present
if fileExt == "" {
switch {
case strings.Contains(req.URL.Path, "json"):
fileExt = ".json"
case strings.Contains(req.URL.Path, "yaml"):
fileExt = ".yaml"
}
}

// Determine the content type and response body based on the file extension
var (
contentType string
response interface{}
err error
responseBytes []byte
)
response = map[string]interface{}{"openapi": "3.0.0"}

switch fileExt {
case ".json":
contentType = "application/json"
case ".yaml":
contentType = "application/yaml"
default:
http.Error(res, "Unsupported file format", http.StatusBadRequest)
return
}

// Set the content type header
res.Header().Set("Content-Type", contentType)

// Marshal and write the response based on content type
if contentType == "application/json" {
responseBytes, err = json.Marshal(response)
} else {
responseBytes, err = yaml.Marshal(response)
}
assert.NoError(t, err)
res.Write(responseBytes)
}))

require.NoError(t, err)
defer func() { testServer.Close() }()

tests := []struct {
name string
args args
Expand All @@ -383,25 +438,25 @@ func TestSource_GetOutputLocation(t *testing.T) {
source: workflow.Source{
Inputs: []workflow.Document{
{
Location: "http://example.com/openapi.json",
Location: fmt.Sprintf("%s/openapi.json", testServer.URL),
},
},
},
},
wantOutputLocation: ".speakeasy/temp/registry_e8ba45.json",
wantOutputLocation: ".speakeasy/temp/registry_4b5145.json",
},
{
name: "simple remote source without extension returns auto-generated output location assumed to be yaml",
args: args{
source: workflow.Source{
Inputs: []workflow.Document{
{
Location: "http://example.com/openapi",
Location: fmt.Sprintf("%s/openapi", testServer.URL),
},
},
},
},
wantOutputLocation: ".speakeasy/temp/registry_94359d.yaml",
wantOutputLocation: ".speakeasy/temp/registry_61ea27.yaml",
},
{
name: "source with multiple inputs returns specified output location",
Expand Down Expand Up @@ -469,6 +524,76 @@ func TestSource_GetOutputLocation(t *testing.T) {
},
wantOutputLocation: ".speakeasy/temp/output_d910ba.yaml",
},
{
name: "single local source uses same extension as source",
args: args{
source: workflow.Source{
Inputs: []workflow.Document{
{
Location: "openapi.json",
},
},
},
},
wantOutputLocation: "openapi.json",
},
{
name: "single local source with overlays uses same extension as source",
args: args{
source: workflow.Source{
Inputs: []workflow.Document{
{
Location: "openapi.json",
},
},
Overlays: []workflow.Overlay{
{Document: &workflow.Document{Location: "overlay.yaml"}},
},
},
},
wantOutputLocation: ".speakeasy/temp/output_a98653.json",
},
{
name: "single remote source with unknown format uses resolved json extension",
args: args{
source: workflow.Source{
Inputs: []workflow.Document{
{
// The test server is set up so that it will return json if the path includes the word "json"
Location: fmt.Sprintf("%s/thepathincludesjson", testServer.URL),
},
},
},
},
wantOutputLocation: ".speakeasy/temp/registry_411616.json",
},
{
name: "single remote source with unknown format uses resolved yaml extension",
args: args{
source: workflow.Source{
Inputs: []workflow.Document{
{
// The test server is set up so that it will return yaml if the path includes the word "yaml"
Location: fmt.Sprintf("%s/thepathincludesyaml", testServer.URL),
},
},
},
},
wantOutputLocation: ".speakeasy/temp/registry_0254db.yaml",
},
{
name: "single remote source with unsupported file extension returns auto-generated output location",
args: args{
source: workflow.Source{
Inputs: []workflow.Document{
{
Location: fmt.Sprintf("%s/foo.txt", testServer.URL),
},
},
},
},
wantOutputLocation: ".speakeasy/temp/registry_69a6f2.yaml",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -611,3 +736,17 @@ func createEmptyFile(path string) error {

return f.Close()
}

// newTestServerWithURL starts an httptest.Server serving handler.
//
// Despite its name, URL is a TCP listen address in host:port form (for example
// "127.0.0.1:1234"), not a full URL. When non-empty it replaces the randomly
// assigned address httptest would pick, which keeps the server's URL stable
// between runs. An empty URL keeps httptest's default listener.
//
// It returns an error if the requested address cannot be bound. On success the
// caller is responsible for calling Close on the returned server.
func newTestServerWithURL(URL string, handler http.Handler) (*httptest.Server, error) {
	// Bind the requested address before creating the server.
	// NewUnstartedServer opens a listener of its own, which would be left
	// open (leaked) if we created the server first and net.Listen then failed.
	var custom net.Listener
	if URL != "" {
		l, err := net.Listen("tcp", URL)
		if err != nil {
			return nil, err
		}
		custom = l
	}

	ts := httptest.NewUnstartedServer(handler)
	if custom != nil {
		// Swap out the auto-assigned listener. Its close error is irrelevant:
		// the server has not been started, so it never accepted a connection.
		_ = ts.Listener.Close()
		ts.Listener = custom
	}
	ts.Start()
	return ts, nil
}
Loading