Skip to content
This repository has been archived by the owner on Jan 21, 2020. It is now read-only.

Commit

Permalink
support headers in template fetch (#444)
Browse files Browse the repository at this point in the history
Signed-off-by: David Chung <david.chung@docker.com>
  • Loading branch information
David Chung authored Mar 24, 2017
1 parent dd74ee6 commit a8b26d3
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 33 deletions.
6 changes: 3 additions & 3 deletions pkg/rpc/server/info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestFetchAPIInfoFromPlugin(t *testing.T) {
server, err := StartPluginAtPath(socketPath, rpc_instance.PluginServer(&testing_instance.Plugin{}))
require.NoError(t, err)

buff, err := template.Fetch(url, template.Options{SocketDir: dir})
buff, err := template.Fetch(url, template.Options{SocketDir: dir}, nil)
require.NoError(t, err)

decoded, err := template.FromJSON(buff)
Expand All @@ -45,7 +45,7 @@ func TestFetchAPIInfoFromPlugin(t *testing.T) {
require.Equal(t, "Instance", result)

url = "unix://" + host + "/info/functions.json"
buff, err = template.Fetch(url, template.Options{SocketDir: dir})
buff, err = template.Fetch(url, template.Options{SocketDir: dir}, nil)
require.NoError(t, err)

server.Stop()
Expand Down Expand Up @@ -91,7 +91,7 @@ func TestFetchFunctionsFromPlugin(t *testing.T) {
server, err := StartPluginAtPath(socketPath, rpc_flavor.PluginServer(&exporter{&testing_flavor.Plugin{}}))
require.NoError(t, err)

buff, err := template.Fetch(url, template.Options{SocketDir: dir})
buff, err := template.Fetch(url, template.Options{SocketDir: dir}, nil)
require.NoError(t, err)

decoded, err := template.FromJSON(buff)
Expand Down
34 changes: 21 additions & 13 deletions pkg/template/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

// Fetch fetchs content from the given URL string. Supported schemes are http:// https:// file:// unix://
func Fetch(s string, opt Options) ([]byte, error) {
func Fetch(s string, opt Options, customize func(*http.Request)) ([]byte, error) {
u, err := url.Parse(s)
if err != nil {
return nil, err
Expand All @@ -21,12 +21,7 @@ func Fetch(s string, opt Options) ([]byte, error) {
return ioutil.ReadFile(u.Path)

case "http", "https":
resp, err := http.Get(u.String())
if err != nil {
return nil, err
}
defer resp.Body.Close()
return ioutil.ReadAll(resp.Body)
return doHTTPGet(u, customize, &http.Client{})

case "unix":
// unix: will look for a socket that matches the host name at a
Expand All @@ -36,17 +31,30 @@ func Fetch(s string, opt Options) ([]byte, error) {
return nil, err
}
u.Scheme = "http"
resp, err := c.Get(u.String())
if err != nil {
return nil, err
}
defer resp.Body.Close()
return ioutil.ReadAll(resp.Body)
return doHTTPGet(u, customize, c)
}

return nil, fmt.Errorf("unsupported url:%s", s)
}

func doHTTPGet(u *url.URL, customize func(*http.Request), client *http.Client) ([]byte, error) {
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return nil, err
}

if customize != nil {
customize(req)
}

resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return ioutil.ReadAll(resp.Body)
}

func socketClient(u *url.URL, socketDir string) (*http.Client, error) {
socketPath := filepath.Join(socketDir, u.Host)
if f, err := os.Stat(socketPath); err != nil {
Expand Down
64 changes: 48 additions & 16 deletions pkg/template/funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"reflect"
"strings"
"time"
Expand Down Expand Up @@ -134,8 +135,45 @@ func IndexOf(srch interface{}, array interface{}, strictOptional ...bool) int {
return -1
}

// given optional args in a template function call, extra headers and the context
func headersAndContext(opt ...interface{}) (headers map[string][]string, context interface{}) {
if len(opt) == 0 {
return
}
// scan through all the args and if it's a string of the form x=y, then use as header
// the element that doesn't follow the form is the context
headers = map[string][]string{}
for _, v := range opt {
if vv, is := v.(string); is && strings.Index(vv, "=") > 0 {
kv := strings.Split(vv, "=")
key := kv[0]
value := ""
if len(kv) == 2 {
value = kv[1]
}
if _, has := headers[key]; !has {
headers[key] = []string{value}
} else {
headers[key] = append(headers[key], value)
}
} else {
context = v
}
}
return
}

func setHeaders(req *http.Request, headers map[string][]string) {
for k, vv := range headers {
for _, v := range vv {
req.Header.Add(k, v)
}
}
}

// DefaultFuncs returns a list of default functions for binding in the template
func (t *Template) DefaultFuncs() []Function {

return []Function{
{
Name: "source",
Expand All @@ -146,10 +184,7 @@ func (t *Template) DefaultFuncs() []Function {
"as the calling template. The context (e.g. variables) of the calling template as a result can be mutated.",
},
Func: func(p string, opt ...interface{}) (string, error) {
var o interface{}
if len(opt) > 0 {
o = opt[0]
}
headers, context := headersAndContext(opt...)
loc := p
if strings.Index(loc, "str://") == -1 {
buff, err := getURL(t.url, p)
Expand All @@ -158,7 +193,7 @@ func (t *Template) DefaultFuncs() []Function {
}
loc = buff
}
sourced, err := NewTemplate(loc, t.options)
sourced, err := NewTemplateCustom(loc, t.options, func(req *http.Request) { setHeaders(req, headers) })
if err != nil {
return "", err
}
Expand All @@ -167,11 +202,11 @@ func (t *Template) DefaultFuncs() []Function {
sourced.forkFrom(t)
sourced.context = t.context

if o == nil {
o = sourced.context
if context == nil {
context = sourced.context
}
// TODO(chungers) -- let the sourced template define new functions that can be called in the parent.
return sourced.Render(o)
return sourced.Render(context)
},
},
{
Expand All @@ -184,10 +219,7 @@ func (t *Template) DefaultFuncs() []Function {
"be visible in the calling template's context.",
},
Func: func(p string, opt ...interface{}) (string, error) {
var o interface{}
if len(opt) > 0 {
o = opt[0]
}
headers, context := headersAndContext(opt...)
loc := p
if strings.Index(loc, "str://") == -1 {
buff, err := getURL(t.url, p)
Expand All @@ -196,7 +228,7 @@ func (t *Template) DefaultFuncs() []Function {
}
loc = buff
}
included, err := NewTemplate(loc, t.options)
included, err := NewTemplateCustom(loc, t.options, func(req *http.Request) { setHeaders(req, headers) })
if err != nil {
return "", err
}
Expand All @@ -206,11 +238,11 @@ func (t *Template) DefaultFuncs() []Function {
}
included.context = dotCopy

if o == nil {
o = included.context
if context == nil {
context = included.context
}

return included.Render(o)
return included.Render(context)
},
},
{
Expand Down
53 changes: 53 additions & 0 deletions pkg/template/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ import (
"sync"
"testing"

"github.com/docker/infrakit/pkg/log"
"github.com/docker/infrakit/pkg/types"
"github.com/stretchr/testify/require"
)

var logger = log.New("module", "template")

func TestTemplateInclusionFromDifferentSources(t *testing.T) {
prefix := testSetupTemplates(t, testFiles)

Expand Down Expand Up @@ -256,3 +260,52 @@ func TestWithFunctions(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "hello=1", view)
}

func TestSourceWithHeaders(t *testing.T) {

h, context := headersAndContext("foo=bar")
logger.Info("result", "context", context, "headers", h)
require.Equal(t, interface{}(nil), context)
require.Equal(t, map[string][]string{"foo": {"bar"}}, h)

h, context = headersAndContext("foo=bar", "bar=baz", 224)
logger.Info("result", "context", context, "headers", h)
require.Equal(t, 224, context)
require.Equal(t, map[string][]string{"foo": {"bar"}, "bar": {"baz"}}, h)

h, context = headersAndContext("foo=bar", "bar=baz")
logger.Info("result", "context", context, "headers", h)
require.Equal(t, nil, context)
require.Equal(t, map[string][]string{"foo": {"bar"}, "bar": {"baz"}}, h)

h, context = headersAndContext("foo")
logger.Info("result", "context", context, "headers", h)
require.Equal(t, "foo", context)
require.Equal(t, map[string][]string{}, h)

h, context = headersAndContext("foo=bar", map[string]string{"hello": "world"})
logger.Info("result", "context", context, "headers", h)
require.Equal(t, map[string]string{"hello": "world"}, context)
require.Equal(t, map[string][]string{"foo": {"bar"}}, h)

// note we don't have to escape -- use the back quote and the string value is valid
r := "{{ include `https://httpbin.org/headers` `A=B` `Foo=Bar` `Foo=Bar` `X=1` 100 }}"
s := `{{ $resp := (source "str://` + r + `" | jsonDecode) }}{{ $resp.headers | jsonEncode}}`
tt, err := NewTemplate("str://"+s, Options{})
require.NoError(t, err)
view, err := tt.Render(nil)
require.NoError(t, err)

any := types.AnyString(view)
headers := map[string]interface{}{}
require.NoError(t, any.Decode(&headers))
require.Equal(t, map[string]interface{}{
"Foo": "Bar,Bar",
"Host": "httpbin.org",
"User-Agent": "Go-http-client/1.1",
"A": "B",
"X": "1",
"Accept-Encoding": "gzip",
"Connection": "close",
}, headers)
}
22 changes: 21 additions & 1 deletion pkg/template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"io"
"net/http"
"reflect"
"strings"
"sync"
Expand Down Expand Up @@ -88,6 +89,25 @@ type Void string

const voidValue Void = ""

// NewTemplateCustom fetches the content at the url and allows configuration of the request
// If the string begins with str:// as scheme, then the rest of the string is interpreted as the body of the template.
func NewTemplateCustom(s string, opt Options, custom func(*http.Request)) (*Template, error) {
var buff []byte
contextURL := s
// Special case of specifying the entire template as a string; otherwise treat as url
if strings.Index(s, "str://") == 0 {
buff = []byte(strings.Replace(s, "str://", "", 1))
contextURL = defaultContextURL()
} else {
b, err := Fetch(s, opt, custom)
if err != nil {
return nil, err
}
buff = b
}
return NewTemplateFromBytes(buff, contextURL, opt)
}

// NewTemplate fetches the content at the url and returns a template. If the string begins
// with str:// as scheme, then the rest of the string is interpreted as the body of the template.
func NewTemplate(s string, opt Options) (*Template, error) {
Expand All @@ -98,7 +118,7 @@ func NewTemplate(s string, opt Options) (*Template, error) {
buff = []byte(strings.Replace(s, "str://", "", 1))
contextURL = defaultContextURL()
} else {
b, err := Fetch(s, opt)
b, err := Fetch(s, opt, nil)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit a8b26d3

Please sign in to comment.