diff --git a/etree.go b/etree.go index 07e1a6d..3cbb575 100644 --- a/etree.go +++ b/etree.go @@ -358,9 +358,14 @@ func (d *Document) SetRoot(e *Element) { // returns the number of bytes read and any error encountered. func (d *Document) ReadFrom(r io.Reader) (n int64, err error) { if d.ReadSettings.ValidateInput { - if err := validateXML(r, d.ReadSettings); err != nil { + b, err := io.ReadAll(r) + if err != nil { return 0, err } + if err := validateXML(bytes.NewReader(b), d.ReadSettings); err != nil { + return 0, err + } + r = bytes.NewReader(b) } return d.Element.readFrom(r, d.ReadSettings) } @@ -373,19 +378,30 @@ func (d *Document) ReadFromFile(filepath string) error { return err } defer f.Close() + _, err = d.ReadFrom(f) return err } // ReadFromBytes reads XML from the byte slice 'b' into the this document. func (d *Document) ReadFromBytes(b []byte) error { - _, err := d.ReadFrom(bytes.NewReader(b)) + if d.ReadSettings.ValidateInput { + if err := validateXML(bytes.NewReader(b), d.ReadSettings); err != nil { + return err + } + } + _, err := d.Element.readFrom(bytes.NewReader(b), d.ReadSettings) return err } // ReadFromString reads XML from the string 's' into this document. func (d *Document) ReadFromString(s string) error { - _, err := d.ReadFrom(strings.NewReader(s)) + if d.ReadSettings.ValidateInput { + if err := validateXML(strings.NewReader(s), d.ReadSettings); err != nil { + return err + } + } + _, err := d.Element.readFrom(strings.NewReader(s), d.ReadSettings) return err } diff --git a/etree_test.go b/etree_test.go index 57c991e..4c825ef 100644 --- a/etree_test.go +++ b/etree_test.go @@ -7,8 +7,12 @@ package etree import ( "bytes" "encoding/xml" + "errors" "io" + "io/fs" "math/rand" + "os" + "path" "strings" "testing" ) @@ -1540,19 +1544,46 @@ func TestValidateInput(t *testing.T) { {`x`, `XML syntax error on line 1: element closed by `}, } - for i, test := range tests { - doc := NewDocument() - doc.ReadSettings.ValidateInput = true - err := doc.ReadFromString(test.s) - if err == nil { - if test.err != "" { - t.Errorf("etree: test #%d:\nExpected error:\n %s\nReceived error:\n nil", i, test.err) - } - } else { - te := err.Error() - if te != test.err { - t.Errorf("etree: test #%d:\nExpected error;\n %s\nReceived error:\n %s", i, test.err, te) + type readFunc func(doc *Document, s string) error + runTests := func(t *testing.T, read readFunc) { + for i, test := range tests { + doc := NewDocument() + doc.ReadSettings.ValidateInput = true + err := read(doc, test.s) + if err == nil { + if test.err != "" { + t.Errorf("etree: test #%d:\nExpected error:\n %s\nReceived error:\n nil", i, test.err) + } + root := doc.Root() + if root == nil || root.Tag != "root" { + t.Errorf("etree: test #%d: failed to read document after input validation", i) + } + } else { + te := err.Error() + if te != test.err { + t.Errorf("etree: test #%d:\nExpected error;\n %s\nReceived error:\n %s", i, test.err, te) + } } } } + + readFromString := func(doc *Document, s string) error { + return doc.ReadFromString(s) + } + t.Run("ReadFromString", func(t *testing.T) { runTests(t, readFromString) }) + + readFromBytes := func(doc *Document, s string) error { + return doc.ReadFromBytes([]byte(s)) + } + t.Run("ReadFromBytes", func(t *testing.T) { runTests(t, readFromBytes) }) + + readFromFile := func(doc *Document, s string) error { + pathtmp := path.Join(t.TempDir(), "etree-test") + err := os.WriteFile(pathtmp, []byte(s), fs.ModePerm) + if err != nil { + return errors.New("unable to write tmp file for input validation") + } + return doc.ReadFromFile(pathtmp) + } + t.Run("ReadFromFile", func(t *testing.T) { runTests(t, readFromFile) }) }