Skip to content
This repository has been archived by the owner on Jan 21, 2020. It is now read-only.

support headers in template fetch #444

Merged
merged 7 commits into from
Mar 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/rpc/server/info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestFetchAPIInfoFromPlugin(t *testing.T) {
server, err := StartPluginAtPath(socketPath, rpc_instance.PluginServer(&testing_instance.Plugin{}))
require.NoError(t, err)

buff, err := template.Fetch(url, template.Options{SocketDir: dir})
buff, err := template.Fetch(url, template.Options{SocketDir: dir}, nil)
require.NoError(t, err)

decoded, err := template.FromJSON(buff)
Expand All @@ -45,7 +45,7 @@ func TestFetchAPIInfoFromPlugin(t *testing.T) {
require.Equal(t, "Instance", result)

url = "unix://" + host + "/info/functions.json"
buff, err = template.Fetch(url, template.Options{SocketDir: dir})
buff, err = template.Fetch(url, template.Options{SocketDir: dir}, nil)
require.NoError(t, err)

server.Stop()
Expand Down Expand Up @@ -91,7 +91,7 @@ func TestFetchFunctionsFromPlugin(t *testing.T) {
server, err := StartPluginAtPath(socketPath, rpc_flavor.PluginServer(&exporter{&testing_flavor.Plugin{}}))
require.NoError(t, err)

buff, err := template.Fetch(url, template.Options{SocketDir: dir})
buff, err := template.Fetch(url, template.Options{SocketDir: dir}, nil)
require.NoError(t, err)

decoded, err := template.FromJSON(buff)
Expand Down
34 changes: 21 additions & 13 deletions pkg/template/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

// Fetch fetchs content from the given URL string. Supported schemes are http:// https:// file:// unix://
func Fetch(s string, opt Options) ([]byte, error) {
func Fetch(s string, opt Options, customize func(*http.Request)) ([]byte, error) {
u, err := url.Parse(s)
if err != nil {
return nil, err
Expand All @@ -21,12 +21,7 @@ func Fetch(s string, opt Options) ([]byte, error) {
return ioutil.ReadFile(u.Path)

case "http", "https":
resp, err := http.Get(u.String())
if err != nil {
return nil, err
}
defer resp.Body.Close()
return ioutil.ReadAll(resp.Body)
return doHTTPGet(u, customize, &http.Client{})

case "unix":
// unix: will look for a socket that matches the host name at a
Expand All @@ -36,17 +31,30 @@ func Fetch(s string, opt Options) ([]byte, error) {
return nil, err
}
u.Scheme = "http"
resp, err := c.Get(u.String())
if err != nil {
return nil, err
}
defer resp.Body.Close()
return ioutil.ReadAll(resp.Body)
return doHTTPGet(u, customize, c)
}

return nil, fmt.Errorf("unsupported url:%s", s)
}

func doHTTPGet(u *url.URL, customize func(*http.Request), client *http.Client) ([]byte, error) {
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return nil, err
}

if customize != nil {
customize(req)
}

resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return ioutil.ReadAll(resp.Body)
}

func socketClient(u *url.URL, socketDir string) (*http.Client, error) {
socketPath := filepath.Join(socketDir, u.Host)
if f, err := os.Stat(socketPath); err != nil {
Expand Down
64 changes: 48 additions & 16 deletions pkg/template/funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"reflect"
"strings"
"time"
Expand Down Expand Up @@ -134,8 +135,45 @@ func IndexOf(srch interface{}, array interface{}, strictOptional ...bool) int {
return -1
}

// given optional args in a template function call, extra headers and the context
func headersAndContext(opt ...interface{}) (headers map[string][]string, context interface{}) {
if len(opt) == 0 {
return
}
// scan through all the args and if it's a string of the form x=y, then use as header
// the element that doesn't follow the form is the context
headers = map[string][]string{}
for _, v := range opt {
if vv, is := v.(string); is && strings.Index(vv, "=") > 0 {
kv := strings.Split(vv, "=")
key := kv[0]
value := ""
if len(kv) == 2 {
value = kv[1]
}
if _, has := headers[key]; !has {
headers[key] = []string{value}
} else {
headers[key] = append(headers[key], value)
}
} else {
context = v
}
}
return
}

func setHeaders(req *http.Request, headers map[string][]string) {
for k, vv := range headers {
for _, v := range vv {
req.Header.Add(k, v)
}
}
}

// DefaultFuncs returns a list of default functions for binding in the template
func (t *Template) DefaultFuncs() []Function {

return []Function{
{
Name: "source",
Expand All @@ -146,10 +184,7 @@ func (t *Template) DefaultFuncs() []Function {
"as the calling template. The context (e.g. variables) of the calling template as a result can be mutated.",
},
Func: func(p string, opt ...interface{}) (string, error) {
var o interface{}
if len(opt) > 0 {
o = opt[0]
}
headers, context := headersAndContext(opt...)
loc := p
if strings.Index(loc, "str://") == -1 {
buff, err := getURL(t.url, p)
Expand All @@ -158,7 +193,7 @@ func (t *Template) DefaultFuncs() []Function {
}
loc = buff
}
sourced, err := NewTemplate(loc, t.options)
sourced, err := NewTemplateCustom(loc, t.options, func(req *http.Request) { setHeaders(req, headers) })
if err != nil {
return "", err
}
Expand All @@ -167,11 +202,11 @@ func (t *Template) DefaultFuncs() []Function {
sourced.forkFrom(t)
sourced.context = t.context

if o == nil {
o = sourced.context
if context == nil {
context = sourced.context
}
// TODO(chungers) -- let the sourced template define new functions that can be called in the parent.
return sourced.Render(o)
return sourced.Render(context)
},
},
{
Expand All @@ -184,10 +219,7 @@ func (t *Template) DefaultFuncs() []Function {
"be visible in the calling template's context.",
},
Func: func(p string, opt ...interface{}) (string, error) {
var o interface{}
if len(opt) > 0 {
o = opt[0]
}
headers, context := headersAndContext(opt...)
loc := p
if strings.Index(loc, "str://") == -1 {
buff, err := getURL(t.url, p)
Expand All @@ -196,7 +228,7 @@ func (t *Template) DefaultFuncs() []Function {
}
loc = buff
}
included, err := NewTemplate(loc, t.options)
included, err := NewTemplateCustom(loc, t.options, func(req *http.Request) { setHeaders(req, headers) })
if err != nil {
return "", err
}
Expand All @@ -206,11 +238,11 @@ func (t *Template) DefaultFuncs() []Function {
}
included.context = dotCopy

if o == nil {
o = included.context
if context == nil {
context = included.context
}

return included.Render(o)
return included.Render(context)
},
},
{
Expand Down
53 changes: 53 additions & 0 deletions pkg/template/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ import (
"sync"
"testing"

"github.com/docker/infrakit/pkg/log"
"github.com/docker/infrakit/pkg/types"
"github.com/stretchr/testify/require"
)

var logger = log.New("module", "template")

func TestTemplateInclusionFromDifferentSources(t *testing.T) {
prefix := testSetupTemplates(t, testFiles)

Expand Down Expand Up @@ -256,3 +260,52 @@ func TestWithFunctions(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "hello=1", view)
}

func TestSourceWithHeaders(t *testing.T) {

h, context := headersAndContext("foo=bar")
logger.Info("result", "context", context, "headers", h)
require.Equal(t, interface{}(nil), context)
require.Equal(t, map[string][]string{"foo": {"bar"}}, h)

h, context = headersAndContext("foo=bar", "bar=baz", 224)
logger.Info("result", "context", context, "headers", h)
require.Equal(t, 224, context)
require.Equal(t, map[string][]string{"foo": {"bar"}, "bar": {"baz"}}, h)

h, context = headersAndContext("foo=bar", "bar=baz")
logger.Info("result", "context", context, "headers", h)
require.Equal(t, nil, context)
require.Equal(t, map[string][]string{"foo": {"bar"}, "bar": {"baz"}}, h)

h, context = headersAndContext("foo")
logger.Info("result", "context", context, "headers", h)
require.Equal(t, "foo", context)
require.Equal(t, map[string][]string{}, h)

h, context = headersAndContext("foo=bar", map[string]string{"hello": "world"})
logger.Info("result", "context", context, "headers", h)
require.Equal(t, map[string]string{"hello": "world"}, context)
require.Equal(t, map[string][]string{"foo": {"bar"}}, h)

// note we don't have to escape -- use the back quote and the string value is valid
r := "{{ include `https://httpbin.org/headers` `A=B` `Foo=Bar` `Foo=Bar` `X=1` 100 }}"
s := `{{ $resp := (source "str://` + r + `" | jsonDecode) }}{{ $resp.headers | jsonEncode}}`
tt, err := NewTemplate("str://"+s, Options{})
require.NoError(t, err)
view, err := tt.Render(nil)
require.NoError(t, err)

any := types.AnyString(view)
headers := map[string]interface{}{}
require.NoError(t, any.Decode(&headers))
require.Equal(t, map[string]interface{}{
"Foo": "Bar,Bar",
"Host": "httpbin.org",
"User-Agent": "Go-http-client/1.1",
"A": "B",
"X": "1",
"Accept-Encoding": "gzip",
"Connection": "close",
}, headers)
}
22 changes: 21 additions & 1 deletion pkg/template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"io"
"net/http"
"reflect"
"strings"
"sync"
Expand Down Expand Up @@ -88,6 +89,25 @@ type Void string

const voidValue Void = ""

// NewTemplateCustom fetches the content at the url and allows configuration of the request
// If the string begins with str:// as scheme, then the rest of the string is interpreted as the body of the template.
func NewTemplateCustom(s string, opt Options, custom func(*http.Request)) (*Template, error) {
var buff []byte
contextURL := s
// Special case of specifying the entire template as a string; otherwise treat as url
if strings.Index(s, "str://") == 0 {
buff = []byte(strings.Replace(s, "str://", "", 1))
contextURL = defaultContextURL()
} else {
b, err := Fetch(s, opt, custom)
if err != nil {
return nil, err
}
buff = b
}
return NewTemplateFromBytes(buff, contextURL, opt)
}

// NewTemplate fetches the content at the url and returns a template. If the string begins
// with str:// as scheme, then the rest of the string is interpreted as the body of the template.
func NewTemplate(s string, opt Options) (*Template, error) {
Expand All @@ -98,7 +118,7 @@ func NewTemplate(s string, opt Options) (*Template, error) {
buff = []byte(strings.Replace(s, "str://", "", 1))
contextURL = defaultContextURL()
} else {
b, err := Fetch(s, opt)
b, err := Fetch(s, opt, nil)
if err != nil {
return nil, err
}
Expand Down