Skip to content

Commit

Permalink
use json.Decoder.UseNumber() when unmarshalling vars
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Scarr committed Aug 2, 2018
1 parent c555f54 commit 95fe07f
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 16 deletions.
4 changes: 2 additions & 2 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 11 additions & 8 deletions example/todo/todo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,23 @@ func TestTodo(t *testing.T) {
srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(New())))
c := client.New(srv.URL)

t.Run("create a new todo", func(t *testing.T) {
var resp struct {
CreateTodo struct{ ID int }
}
c.MustPost(`mutation { createTodo(todo:{text:"Fery important"}) { id } }`, &resp)
var resp struct {
CreateTodo struct{ ID int }
}
c.MustPost(`mutation { createTodo(todo:{text:"Fery important"}) { id } }`, &resp)

require.Equal(t, 4, resp.CreateTodo.ID)
})
require.Equal(t, 4, resp.CreateTodo.ID)

t.Run("update the todo text", func(t *testing.T) {
var resp struct {
UpdateTodo struct{ Text string }
}
c.MustPost(`mutation { updateTodo(id: 4, changes:{text:"Very important"}) { text } }`, &resp)
c.MustPost(
`mutation($id: Int!, $text: String!) { updateTodo(id: $id, changes:{text:$text}) { text } }`,
&resp,
client.Var("id", 4),
client.Var("text", "Very important"),
)

require.Equal(t, "Very important", resp.UpdateTodo.Text)
})
Expand Down
5 changes: 5 additions & 0 deletions graphql/float.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graphql

import (
"encoding/json"
"fmt"
"io"
"strconv"
Expand All @@ -18,8 +19,12 @@ func UnmarshalFloat(v interface{}) (float64, error) {
return strconv.ParseFloat(v, 64)
case int:
return float64(v), nil
case int64:
return float64(v), nil
case float64:
return v, nil
case json.Number:
return strconv.ParseFloat(string(v), 64)
default:
return 0, fmt.Errorf("%T is not an float", v)
}
Expand Down
3 changes: 3 additions & 0 deletions graphql/id.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graphql

import (
"encoding/json"
"fmt"
"io"
"strconv"
Expand All @@ -15,6 +16,8 @@ func UnmarshalID(v interface{}) (string, error) {
switch v := v.(type) {
case string:
return v, nil
case json.Number:
return string(v), nil
case int:
return strconv.Itoa(v), nil
case float64:
Expand Down
5 changes: 3 additions & 2 deletions graphql/int.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graphql

import (
"encoding/json"
"fmt"
"io"
"strconv"
Expand All @@ -20,8 +21,8 @@ func UnmarshalInt(v interface{}) (int, error) {
return v, nil
case int64:
return int(v), nil
case float64:
return int(v), nil
case json.Number:
return strconv.Atoi(string(v))
default:
return 0, fmt.Errorf("%T is not an int", v)
}
Expand Down
11 changes: 9 additions & 2 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

Expand Down Expand Up @@ -140,13 +141,13 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc
reqParams.OperationName = r.URL.Query().Get("operationName")

if variables := r.URL.Query().Get("variables"); variables != "" {
if err := json.Unmarshal([]byte(variables), &reqParams.Variables); err != nil {
if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
return
}
}
case http.MethodPost:
if err := json.NewDecoder(r.Body).Decode(&reqParams); err != nil {
if err := jsonDecode(r.Body, &reqParams); err != nil {
sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
return
}
Expand Down Expand Up @@ -201,6 +202,12 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc
})
}

func jsonDecode(r io.Reader, val interface{}) error {
dec := json.NewDecoder(r)
dec.UseNumber()
return dec.Decode(val)
}

func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
w.WriteHeader(code)
b, err := json.Marshal(&graphql.Response{Errors: errors})
Expand Down
11 changes: 9 additions & 2 deletions handler/websocket.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package handler

import (
"bytes"
"context"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -132,7 +133,7 @@ func (c *wsConnection) run() {

func (c *wsConnection) subscribe(message *operationMessage) bool {
var reqParams params
if err := json.Unmarshal(message.Payload, &reqParams); err != nil {
if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
c.sendConnectionError("invalid json")
return false
}
Expand Down Expand Up @@ -228,11 +229,17 @@ func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
}

func (c *wsConnection) readOp() *operationMessage {
_, r, err := c.conn.NextReader()
if err != nil {
c.sendConnectionError("invalid json")
return nil
}
message := operationMessage{}
if err := c.conn.ReadJSON(&message); err != nil {
if err := jsonDecode(r, &message); err != nil {
c.sendConnectionError("invalid json")
return nil
}

return &message
}

Expand Down

0 comments on commit 95fe07f

Please sign in to comment.