diff --git a/pkg/ottl/ottlfuncs/func_remove_xml.go b/pkg/ottl/ottlfuncs/func_remove_xml.go index 1558a2bd3f42..b45ee74fcd1f 100644 --- a/pkg/ottl/ottlfuncs/func_remove_xml.go +++ b/pkg/ottl/ottlfuncs/func_remove_xml.go @@ -47,7 +47,13 @@ func removeXML[K any](target ottl.StringGetter[K], xPath string) ottl.ExprFunc[K } else if doc, err = parseNodesXML(targetVal); err != nil { return nil, err } - xmlquery.FindEach(doc, xPath, func(_ int, n *xmlquery.Node) { + + nodes, err := xmlquery.QueryAll(doc, xPath) + if err != nil { + return nil, err + } + + for _, n := range nodes { switch n.Type { case xmlquery.ElementNode: xmlquery.RemoveFromTree(n) @@ -60,7 +66,7 @@ func removeXML[K any](target ottl.StringGetter[K], xPath string) ottl.ExprFunc[K case xmlquery.CharDataNode: xmlquery.RemoveFromTree(n) } - }) + } return doc.OutputXML(false), nil } } @@ -82,7 +88,7 @@ func parseNodesXML(targetVal string) (*xmlquery.Node, error) { if err != nil { return nil, fmt.Errorf("parse xml: %w", err) } - if !preserveDeclearation { + if !preserveDeclearation && top.FirstChild != nil { xmlquery.RemoveFromTree(top.FirstChild) } return top, nil diff --git a/pkg/ottl/ottlfuncs/func_remove_xml_test.go b/pkg/ottl/ottlfuncs/func_remove_xml_test.go index c22f77c3199e..bf7e9d488bf3 100644 --- a/pkg/ottl/ottlfuncs/func_remove_xml_test.go +++ b/pkg/ottl/ottlfuncs/func_remove_xml_test.go @@ -91,15 +91,28 @@ func Test_RemoveXML(t *testing.T) { xPath: "//text()['*delete*']", want: ``, }, + { + name: "ignore empty", + document: ``, + xPath: "/", + want: ``, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - target := ottl.StandardStringGetter[any]{ - Getter: func(_ context.Context, _ any) (any, error) { - return tt.document, nil - }, - } - exprFunc := removeXML(target, tt.xPath) + factory := NewRemoveXMLFactory[any]() + exprFunc, err := factory.CreateFunction( + ottl.FunctionContext{}, + &RemoveXMLArguments[any]{ + Target: ottl.StandardStringGetter[any]{ + Getter: func(_ context.Context, _ any) (any, error) { + return tt.document, nil + }, + }, + XPath: tt.xPath, + }) + assert.NoError(t, err) + result, err := exprFunc(context.Background(), nil) assert.NoError(t, err) assert.Equal(t, tt.want, result)