Add a ValidateInput read setting to etree.

When ReadSettings.ValidateInput is true, Document.ReadFrom buffers the whole
input, checks it for well-formedness with validateXML, and only then parses
the buffered bytes. The buffering is required because validation consumes the
reader: handing the same reader to the parser afterwards would yield an empty
document without any error.

newDecoder centralizes xml.Decoder configuration so that the validation pass
and the parsing pass always run with identical decoder settings.

NOTE(review): the ReadFrom hunk calls bytes.NewReader and io.ReadAll. This
patch does not touch the import block of etree.go, so confirm that "bytes" is
already imported there and that the module targets Go 1.16 or newer.

diff --git a/etree.go b/etree.go
index c3bb746..07e1a6d 100644
--- a/etree.go
+++ b/etree.go
@@ -50,6 +50,13 @@ type ReadSettings struct {
 	// preserve them instead of keeping only one. Default: false.
 	PreserveDuplicateAttrs bool
 
+	// ValidateInput forces all ReadFrom* methods to validate that the
+	// provided input is composed of well-formed XML before processing it. If
+	// invalid XML is detected, the ReadFrom* methods return an error. Because
+	// this option requires the input to be processed twice, it incurs a
+	// significant performance penalty. Default: false.
+	ValidateInput bool
+
 	// Entity to be passed to standard xml.Decoder. Default: nil.
 	Entity map[string]string
 
@@ -66,9 +73,6 @@ func newReadSettings() ReadSettings {
 		CharsetReader: func(label string, input io.Reader) (io.Reader, error) {
 			return input, nil
 		},
-		Permissive:    false,
-		PreserveCData: false,
-		Entity:        nil,
 	}
 }
 
@@ -353,6 +357,18 @@ func (d *Document) SetRoot(e *Element) {
 // ReadFrom reads XML from the reader 'r' into this document. The function
 // returns the number of bytes read and any error encountered.
 func (d *Document) ReadFrom(r io.Reader) (n int64, err error) {
+	if d.ReadSettings.ValidateInput {
+		// Buffer the whole input: validation consumes the reader, so the
+		// parse pass below needs its own fresh reader over the same bytes.
+		data, rerr := io.ReadAll(r)
+		if rerr != nil {
+			return 0, rerr
+		}
+		if verr := validateXML(bytes.NewReader(data), d.ReadSettings); verr != nil {
+			return 0, verr
+		}
+		r = bytes.NewReader(data)
+	}
 	return d.Element.readFrom(r, d.ReadSettings)
 }
 
@@ -380,6 +396,35 @@ func (d *Document) ReadFromString(s string) error {
 	return err
 }
 
+// validateXML determines if the data read from the reader 'r' contains
+// well-formed XML according to the rules set by the go xml package.
+func validateXML(r io.Reader, settings ReadSettings) error {
+	dec := newDecoder(r, settings)
+	err := dec.Decode(new(interface{}))
+	if err != nil {
+		return err
+	}
+
+	// If there are any trailing tokens after unmarshalling with Decode(),
+	// then the XML input didn't terminate properly.
+	_, err = dec.Token()
+	if err == io.EOF {
+		return nil
+	}
+	return ErrXML
+}
+
+// newDecoder creates an XML decoder for the reader 'r' configured using
+// the provided read settings.
+func newDecoder(r io.Reader, settings ReadSettings) *xml.Decoder {
+	d := xml.NewDecoder(r)
+	d.CharsetReader = settings.CharsetReader
+	d.Strict = !settings.Permissive
+	d.Entity = settings.Entity
+	d.AutoClose = settings.AutoClose
+	return d
+}
+
 // WriteTo serializes the document out to the writer 'w'. The function returns
 // the number of bytes written and any error encountered.
 func (d *Document) WriteTo(w io.Writer) (n int64, err error) {
@@ -835,10 +880,7 @@ func (e *Element) readFrom(ri io.Reader, settings ReadSettings) (n int64, err er
 		r = newXmlSimpleReader(ri)
 	}
 
-	dec := xml.NewDecoder(r)
-	dec.CharsetReader = settings.CharsetReader
-	dec.Strict = !settings.Permissive
-	dec.Entity = settings.Entity
+	dec := newDecoder(r, settings)
 
 	var stack stack
 	stack.push(e)
diff --git a/etree_test.go b/etree_test.go
index 11c0eea..57c991e 100644
--- a/etree_test.go
+++ b/etree_test.go
@@ -1524,3 +1524,35 @@ func TestNotNil(t *testing.T) {
 		t.Error("got:\n" + got)
 	}
 }
+
+func TestValidateInput(t *testing.T) {
+	tests := []struct {
+		s   string
+		err string
+	}{
+		{`<root>x</root>`, ""},
+		{`<root/>`, ""},
+		{`<root>x`, `XML syntax error on line 1: unexpected EOF`},
+		{`</root>`, `XML syntax error on line 1: unexpected end element </root>`},
+		{`<>`, `XML syntax error on line 1: expected element name after <`},
+		{`<root>x</root>trailing`, "etree: invalid XML format"},
+		{`<root>x</root><`, "etree: invalid XML format"},
+		{`<root>x</other>`, `XML syntax error on line 1: element <root> closed by </other>`},
+	}
+
+	for i, test := range tests {
+		doc := NewDocument()
+		doc.ReadSettings.ValidateInput = true
+		err := doc.ReadFromString(test.s)
+		if err == nil {
+			if test.err != "" {
+				t.Errorf("etree: test #%d:\nExpected error:\n %s\nReceived error:\n nil", i, test.err)
+			}
+		} else {
+			te := err.Error()
+			if te != test.err {
+				t.Errorf("etree: test #%d:\nExpected error:\n %s\nReceived error:\n %s", i, test.err, te)
+			}
+		}
+	}
+}