diff --git a/dec_obj_iter_test.go b/dec_obj_iter_test.go index ce8c6ff..2bbbff0 100644 --- a/dec_obj_iter_test.go +++ b/dec_obj_iter_test.go @@ -59,4 +59,20 @@ func TestDecoder_ObjIter(t *testing.T) { d := DecodeStr(``) require.ErrorIs(t, testIter(d), io.ErrUnexpectedEOF) }) + t.Run("Key", testBufferReader(`{"foo":1,"bar":1,"baz":1}`, func(t *testing.T, d *Decoder) { + a := require.New(t) + + iter, err := d.ObjIter() + a.NoError(err) + + var r []string + for iter.Next() { + r = append(r, string(iter.Key())) + a.NoError(d.Skip()) + } + a.False(iter.Next()) + a.NoError(iter.Err()) + + a.Equal([]string{"foo", "bar", "baz"}, r) + })) } diff --git a/dec_raw.go b/dec_raw.go index 3458d5d..326bb02 100644 --- a/dec_raw.go +++ b/dec_raw.go @@ -1,16 +1,35 @@ package jx -import "github.com/go-faster/errors" +import ( + "bytes" + "io" + + "github.com/go-faster/errors" +) // Raw is like Skip(), but saves and returns skipped value as raw json. // // Do not retain returned value, it references underlying buffer. func (d *Decoder) Raw() (Raw, error) { - if d.reader != nil { - return nil, errors.New("not implemented for io.Reader") + start := d.head + if orig := d.reader; orig != nil { + buf := bytes.Buffer{} + buf.Write(d.buf[d.head:d.tail]) + d.reader = io.TeeReader(orig, &buf) + defer func() { + d.reader = orig + }() + + if err := d.Skip(); err != nil { + return nil, errors.Wrap(err, "skip") + } + + unread := d.tail - d.head + raw := buf.Bytes() + raw = raw[:len(raw)-unread] + return raw, nil } - start := d.head if err := d.Skip(); err != nil { return nil, errors.Wrap(err, "skip") } diff --git a/dec_raw_test.go b/dec_raw_test.go index 7b852da..578e3b3 100644 --- a/dec_raw_test.go +++ b/dec_raw_test.go @@ -1,99 +1,133 @@ package jx import ( + "bytes" + "fmt" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestDecoder_Raw(t *testing.T) { - t.Run("Positive", func(t *testing.T) { - v := `{"foo": [1, 2, 3, 4, 5] }}` - t.Run("RawStr", func(t *testing.T) { - d := DecodeStr(v) - require.NoError(t, d.Obj(func(d *Decoder, key string) error { - raw, err := d.Raw() - assert.NoError(t, err) - assert.Equal(t, Array, raw.Type()) - assert.Equal(t, `[1, 2, 3, 4, 5]`, raw.String()) - var rd Decoder - rd.ResetBytes(raw) - assert.NoError(t, rd.Arr(func(d *Decoder) error { - raw, err := d.Raw() - assert.NoError(t, err) - assert.Equal(t, Number, raw.Type()) - n, err := DecodeBytes(raw).Num() - assert.NoError(t, err) - assert.False(t, n.Str()) +func testDecoderRaw(t *testing.T, raw func(d *Decoder) (Raw, error)) { + tests := []struct { + input string + typ Type + expectErr bool + }{ + {`"foo"`, String, false}, + {`"foo\`, Invalid, true}, + + {`10`, Number, false}, + {`1asf0`, Invalid, true}, + + {`null`, Null, false}, + {`nul`, Invalid, true}, + + {`true`, Bool, false}, + {`tru`, Invalid, true}, + + {`[1, 2, 3, 4, 5]`, Array, false}, + {`[1, 2, 3, 4, 5}`, Invalid, true}, + + {`{"foo":"bar"}`, Object, false}, + {`{"foo":"bar", "baz":"foobar"}`, Object, false}, + {`{"foo":"bar}`, Invalid, true}, + } + for i, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("Test%d", i+1), testBufferReader(tt.input, func(t *testing.T, d *Decoder) { + a := require.New(t) + raw, err := raw(d) + if tt.expectErr { + a.Error(err) + return + } + a.NoError(err) + a.Equal(tt.input, raw.String()) + a.Equal(tt.typ, raw.Type()) + })) + } + + t.Run("InsideObject", func(t *testing.T) { + for i, tt := range tests { + tt := tt + e := GetEncoder() + e.Obj(func(e *Encoder) { + const length = 8 + for i := range [length]struct{}{} { + e.FieldStart(fmt.Sprintf("skip%d", i)) + e.Str("it") + } + + e.FieldStart("test") + e.RawStr(tt.input) + + for i := range [length]struct{}{} { + e.FieldStart(fmt.Sprintf("skip%d", i+length)) + e.Str("it") + } + }) + input := e.String() + t.Run(fmt.Sprintf("Test%d", i+1), testBufferReader(input, func(t *testing.T, d *Decoder) { + a := require.New(t) + + err := d.ObjBytes(func(d *Decoder, key []byte) error { + if string(key) != "test" { + return d.Skip() + } + raw, err := raw(d) + if err != nil { + return err + } + a.Equal(tt.input, raw.String()) + a.Equal(tt.typ, raw.Type()) return nil - })) - return err - })) - }) - t.Run("RawAppend", func(t *testing.T) { - d := DecodeStr(v) - require.NoError(t, d.Obj(func(d *Decoder, key string) error { - raw, err := d.RawAppend(nil) - require.NoError(t, err) - t.Logf("%q", raw) - return err + }) + + if tt.expectErr { + a.Error(err) + } else { + a.NoError(err) + } })) - }) + } }) - t.Run("Negative", func(t *testing.T) { - v := `{"foo": [1, 2, 3, 4, 5` - t.Run("RawStr", func(t *testing.T) { - d := DecodeStr(v) - var called bool - require.Error(t, d.Obj(func(d *Decoder, key string) error { - called = true - raw, err := d.Raw() - require.Error(t, err) - require.Nil(t, raw) - return err - })) - require.True(t, called, "should be called") - }) - t.Run("RawAppend", func(t *testing.T) { - d := DecodeStr(v) - var called bool - require.Error(t, d.Obj(func(d *Decoder, key string) error { - called = true - raw, err := d.RawAppend(make([]byte, 10)) - require.Error(t, err) - require.Nil(t, raw) - return err + + t.Run("InsideArray", func(t *testing.T) { + for i, tt := range tests { + tt := tt + input := fmt.Sprintf(`[%s]`, tt.input) + t.Run(fmt.Sprintf("Test%d", i+1), testBufferReader(input, func(t *testing.T, d *Decoder) { + a := require.New(t) + + err := d.Arr(func(d *Decoder) error { + raw, err := raw(d) + if err != nil { + return err + } + a.Equal(tt.input, raw.String()) + a.Equal(tt.typ, raw.Type()) + return nil + }) + + if tt.expectErr { + a.Error(err) + } else { + a.NoError(err) + } })) - require.True(t, called, "should be called") - }) - }) - t.Run("Reader", func(t *testing.T) { - d := Decode(errReader{}, 0) - if _, err := d.Raw(); err == nil { - t.Error("should fail") - } - if _, err := d.RawAppend(nil); err == nil { - t.Error("should fail") } }) } -func BenchmarkDecoder_Raw(b *testing.B) { - data := []byte(`{"foo": [1,2,3,4,5,6,7,8,9,10,11,12,13,14]}`) - b.ReportAllocs() +func TestDecoder_Raw(t *testing.T) { + testDecoderRaw(t, (*Decoder).Raw) +} - var d Decoder - for i := 0; i < b.N; i++ { - d.ResetBytes(data) - raw, err := d.Raw() - if err != nil { - b.Fatal(err) - } - if len(raw) == 0 { - b.Fatal("blank") - } - } +func TestDecoder_RawAppend(t *testing.T) { + testDecoderRaw(t, func(d *Decoder) (Raw, error) { + return d.RawAppend(nil) + }) } func BenchmarkRaw_Type(b *testing.B) { @@ -106,3 +140,40 @@ func BenchmarkRaw_Type(b *testing.B) { } } } + +func BenchmarkDecoder_Raw(b *testing.B) { + data := []byte(`{"foo": [1,2,3,4,5,6,7,8,9,10,11,12,13,14]}`) + b.ReportAllocs() + + b.Run("Bytes", func(b *testing.B) { + var d Decoder + for i := 0; i < b.N; i++ { + d.ResetBytes(data) + raw, err := d.Raw() + if err != nil { + b.Fatal(err) + } + if len(raw) == 0 { + b.Fatal("blank") + } + } + }) + b.Run("Reader", func(b *testing.B) { + var ( + d Decoder + r = new(bytes.Reader) + ) + for i := 0; i < b.N; i++ { + r.Reset(data) + d.Reset(r) + + raw, err := d.Raw() + if err != nil { + b.Fatal(err) + } + if len(raw) == 0 { + b.Fatal("blank") + } + } + }) +} diff --git a/dec_skip_cases_test.go b/dec_skip_cases_test.go index 4aeefda..53c2d95 100644 --- a/dec_skip_cases_test.go +++ b/dec_skip_cases_test.go @@ -7,7 +7,6 @@ import ( "reflect" "strings" "testing" - "testing/iotest" "github.com/stretchr/testify/require" ) @@ -477,49 +476,36 @@ func TestDecoder_Skip(t *testing.T) { inputs: testObjs, }) - testDecode := func(iter *Decoder, input string, stdErr error) func(t *testing.T) { - return func(t *testing.T) { - t.Cleanup(func() { - if t.Failed() { - t.Logf("Input: %q", input) - } - }) - - should := require.New(t) - if stdErr == nil { - should.NoError(iter.Skip()) - should.ErrorIs(iter.Null(), io.ErrUnexpectedEOF) - } else { - should.Error(func() error { - if err := iter.Skip(); err != nil { - return err - } - if err := iter.Skip(); err != io.EOF { - return err - } - return nil - }()) - } - } - } for _, testCase := range testCases { valType := reflect.TypeOf(testCase.ptr).Elem() t.Run(valType.Kind().String(), func(t *testing.T) { for inputIdx, input := range testCase.inputs { - t.Run(fmt.Sprintf("Test%d", inputIdx), func(t *testing.T) { - ptrVal := reflect.New(valType) - stdErr := json.Unmarshal([]byte(input), ptrVal.Interface()) - - t.Run("Buffer", testDecode(DecodeStr(input), input, stdErr)) - - r := strings.NewReader(input) - d := Decode(r, 512) - t.Run("Reader", testDecode(d, input, stdErr)) - - r.Reset(input) - obr := iotest.OneByteReader(r) - t.Run("OneByteReader", testDecode(Decode(obr, 512), input, stdErr)) - }) + input := input + stdErr := json.Unmarshal([]byte(input), reflect.New(valType).Interface()) + cb := func(t *testing.T, iter *Decoder) { + t.Cleanup(func() { + if t.Failed() { + t.Logf("Input: %q", input) + } + }) + + should := require.New(t) + if stdErr == nil { + should.NoError(iter.Skip()) + should.ErrorIs(iter.Null(), io.ErrUnexpectedEOF) + } else { + should.Error(func() error { + if err := iter.Skip(); err != nil { + return err + } + if err := iter.Skip(); err != io.EOF { + return err + } + return nil + }()) + } + } + t.Run(fmt.Sprintf("Test%d", inputIdx), testBufferReader(input, cb)) } }) } diff --git a/dec_test.go b/dec_test.go index 9190c70..d668dc4 100644 --- a/dec_test.go +++ b/dec_test.go @@ -12,40 +12,43 @@ import ( "github.com/stretchr/testify/require" ) -func createTestCase(input string, cb func(t *testing.T, d *Decoder) error) func(t *testing.T) { - run := func(d *Decoder, input string, valid bool) func(t *testing.T) { - return func(t *testing.T) { - t.Cleanup(func() { - if t.Failed() { - t.Logf("Input: %q", input) - } - }) - - err := cb(t, d) - if valid { - require.NoError(t, err) - } else { - require.Error(t, err) - } - } - } - +func testBufferReader(input string, cb func(t *testing.T, d *Decoder)) func(t *testing.T) { return func(t *testing.T) { - valid := json.Valid([]byte(input)) - - t.Run("Buffer", run(DecodeStr(input), input, valid)) + t.Run("Buffer", func(t *testing.T) { + cb(t, DecodeStr(input)) + }) - r := strings.NewReader(input) - d := Decode(r, 512) - t.Run("Reader", run(d, input, valid)) + t.Run("Reader", func(t *testing.T) { + r := strings.NewReader(input) + cb(t, Decode(r, 512)) + }) - r.Reset(input) - obr := iotest.OneByteReader(r) - d.Reset(obr) - t.Run("OneByteReader", run(d, input, valid)) + t.Run("OneByteReader", func(t *testing.T) { + r := strings.NewReader(input) + obr := iotest.OneByteReader(r) + cb(t, Decode(obr, 512)) + }) } } +func createTestCase(input string, cb func(t *testing.T, d *Decoder) error) func(t *testing.T) { + valid := json.Valid([]byte(input)) + return testBufferReader(input, func(t *testing.T, d *Decoder) { + t.Cleanup(func() { + if t.Failed() { + t.Logf("Input: %q", input) + } + }) + + err := cb(t, d) + if valid { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) +} + func runTestCases(t *testing.T, cases []string, cb func(t *testing.T, d *Decoder) error) { for i, input := range cases { t.Run(fmt.Sprintf("Test%d", i), createTestCase(input, cb))