Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make it possible to access http response headers #80

Merged
merged 4 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions extism.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ type Runtime struct {

// PluginConfig contains configuration options for the Extism plugin.
type PluginConfig struct {
ModuleConfig wazero.ModuleConfig
RuntimeConfig wazero.RuntimeConfig
EnableWasi bool
ObserveAdapter *observe.AdapterBase
ObserveOptions *observe.Options
ModuleConfig wazero.ModuleConfig
RuntimeConfig wazero.RuntimeConfig
EnableWasi bool
ObserveAdapter *observe.AdapterBase
ObserveOptions *observe.Options
EnableHttpResponseHeaders bool
}

// HttpRequest represents an HTTP request to be made by the plugin.
Expand Down Expand Up @@ -123,6 +124,7 @@ type Plugin struct {
AllowedHosts []string
AllowedPaths map[string]string
LastStatusCode int
LastResponseHeaders map[string]string
MaxHttpResponseBytes int64
MaxVarBytes int64
log func(LogLevel, string)
Expand Down Expand Up @@ -508,6 +510,11 @@ func NewPlugin(
if manifest.Memory != nil && manifest.Memory.MaxVarBytes >= 0 {
varMax = int64(manifest.Memory.MaxVarBytes)
}

var headers map[string]string = nil
if config.EnableHttpResponseHeaders {
headers = map[string]string{}
}
for _, m := range modules {
if m.inner.Name() == "main" {
p := &Plugin{
Expand All @@ -519,6 +526,7 @@ func NewPlugin(
AllowedHosts: manifest.AllowedHosts,
AllowedPaths: manifest.AllowedPaths,
LastStatusCode: 0,
LastResponseHeaders: headers,
Timeout: time.Duration(manifest.Timeout) * time.Millisecond,
MaxHttpResponseBytes: httpMax,
MaxVarBytes: varMax,
Expand Down
56 changes: 56 additions & 0 deletions extism_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,62 @@ func TestHTTP_denied(t *testing.T) {
}
}

func TestHTTPHeaders_allowed(t *testing.T) {
manifest := manifest("http_headers.wasm")
manifest.AllowedHosts = []string{"extism.org"}

ctx := context.Background()
config := wasiPluginConfig()
config.EnableHttpResponseHeaders = true

plugin, err := NewPlugin(ctx, manifest, config, []HostFunction{})
if err != nil {
t.Error(err)
}

defer plugin.Close()

req, _ := json.Marshal(map[string]string{
"url": "https://extism.org",
})

exit, output, err := plugin.Call("http_get", req)

if assertCall(t, err, exit) {
headers := map[string]string{}
err = json.Unmarshal(output, &headers)
if err != nil {
t.Error(err)
}
assert.Equal(t, "text/html; charset=utf-8", headers["content-type"])
}
}

func TestHTTPHeaders_denied(t *testing.T) {
manifest := manifest("http_headers.wasm")
manifest.AllowedHosts = []string{"extism.org"}

ctx := context.Background()
config := wasiPluginConfig()

plugin, err := NewPlugin(ctx, manifest, config, []HostFunction{})
if err != nil {
t.Error(err)
}

defer plugin.Close()

req, _ := json.Marshal(map[string]string{
"url": "https://extism.org",
})

exit, output, err := plugin.Call("http_get", req)

if assertCall(t, err, exit) {
assert.Equal(t, output, []byte("{}"))
}
}

func TestLog_default(t *testing.T) {
manifest := manifest("log.wasm")

Expand Down
35 changes: 35 additions & 0 deletions host.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"net/http"
"net/url"
"strings"
"unsafe"

// TODO: is there a better package for this?
Expand Down Expand Up @@ -303,6 +304,7 @@ func buildEnvModule(ctx context.Context, rt wazero.Runtime, extism api.Module) (
hostFunc("var_set", varSet)
hostFunc("http_request", httpRequest)
hostFunc("http_status_code", httpStatusCode)
hostFunc("http_headers", httpHeaders)
hostFunc("get_log_level", getLogLevel)

logFunc := func(name string, level LogLevel) {
Expand Down Expand Up @@ -486,6 +488,13 @@ func httpRequest(ctx context.Context, m api.Module, requestOffset uint64, bodyOf
if plugin, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin); ok {
cp := plugin.currentPlugin()

if plugin.LastResponseHeaders != nil {
for k := range plugin.LastResponseHeaders {
delete(plugin.LastResponseHeaders, k)
}
}
plugin.LastStatusCode = 0

requestJson, err := cp.ReadBytes(requestOffset)
if err != nil {
panic(fmt.Errorf("failed to read http request from memory: %v", err))
Expand Down Expand Up @@ -550,6 +559,12 @@ func httpRequest(ctx context.Context, m api.Module, requestOffset uint64, bodyOf
}
defer resp.Body.Close()

if plugin.LastResponseHeaders != nil {
for k, v := range resp.Header {
plugin.LastResponseHeaders[strings.ToLower(k)] = v[0]
}
}

plugin.LastStatusCode = resp.StatusCode

limiter := http.MaxBytesReader(nil, resp.Body, int64(plugin.MaxHttpResponseBytes))
Expand Down Expand Up @@ -581,6 +596,26 @@ func httpStatusCode(ctx context.Context, m api.Module) int32 {
panic("Invalid context, `plugin` key not found")
}

func httpHeaders(ctx context.Context, _ api.Module) uint64 {
if plugin, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin); ok {
if plugin.LastResponseHeaders == nil {
return 0
}

data, err := json.Marshal(plugin.LastResponseHeaders)
if err != nil {
panic(err)
}
mem, err := plugin.currentPlugin().WriteBytes(data)
if err != nil {
panic(err)
}
return mem
}

panic("Invalid context, `plugin` key not found")
}

func getLogLevel(ctx context.Context, m api.Module) int32 {
// if _, ok := ctx.Value(PluginCtxKey("plugin")).(*Plugin); ok {
// panic("Invalid context, `plugin` key not found")
Expand Down
Binary file added wasm/http_headers.wasm
Binary file not shown.
Loading