diff --git a/cache.go b/cache.go
index 3abffcd..2982e3b 100644
--- a/cache.go
+++ b/cache.go
@@ -1,11 +1,11 @@
package xmlquery
import (
+ "fmt"
"sync"
- "github.com/golang/groupcache/lru"
-
"github.com/antchfx/xpath"
+ "github.com/golang/groupcache/lru"
)
// DisableSelectorCache will disable caching for the query selector if value is true.
@@ -21,23 +21,23 @@ var (
cacheMutex sync.Mutex
)
-func getQuery(expr string) (*xpath.Expr, error) {
+func getQuery(expr string, opts xpath.CompileOptions) (*xpath.Expr, error) {
+ key := expr + fmt.Sprintf("%#v", opts)
if DisableSelectorCache || SelectorCacheMaxEntries <= 0 {
- return xpath.Compile(expr)
+ return xpath.CompileWithOptions(expr, opts)
}
cacheOnce.Do(func() {
cache = lru.New(SelectorCacheMaxEntries)
})
cacheMutex.Lock()
defer cacheMutex.Unlock()
- if v, ok := cache.Get(expr); ok {
+ if v, ok := cache.Get(key); ok {
return v.(*xpath.Expr), nil
}
- v, err := xpath.Compile(expr)
+ v, err := xpath.CompileWithOptions(expr, opts)
if err != nil {
return nil, err
}
- cache.Add(expr, v)
+ cache.Add(key, v)
return v, nil
-
}
diff --git a/parse.go b/parse.go
index d359b50..d904115 100644
--- a/parse.go
+++ b/parse.go
@@ -372,7 +372,7 @@ type StreamParser struct {
// streamElementFilter, if provided, cannot be successfully parsed and compiled
// into a valid xpath query.
func CreateStreamParser(r io.Reader, streamElementXPath string, streamElementFilter ...string) (*StreamParser, error) {
- return CreateStreamParserWithOptions(r, ParserOptions{}, streamElementXPath, streamElementFilter...)
+ return CreateStreamParserWithCompileOptions(r, ParserOptions{}, xpath.CompileOptions{}, streamElementXPath, streamElementFilter...)
}
// CreateStreamParserWithOptions is like CreateStreamParser, but with custom options
@@ -382,13 +382,24 @@ func CreateStreamParserWithOptions(
streamElementXPath string,
streamElementFilter ...string,
) (*StreamParser, error) {
- elemXPath, err := getQuery(streamElementXPath)
+ return CreateStreamParserWithCompileOptions(r, options, xpath.CompileOptions{}, streamElementXPath, streamElementFilter...)
+}
+
+// New function to allow passing CompileOptions
+func CreateStreamParserWithCompileOptions(
+ r io.Reader,
+ options ParserOptions,
+ compileOpts xpath.CompileOptions,
+ streamElementXPath string,
+ streamElementFilter ...string,
+) (*StreamParser, error) {
+ elemXPath, err := getQuery(streamElementXPath, compileOpts)
if err != nil {
return nil, fmt.Errorf("invalid streamElementXPath '%s', err: %s", streamElementXPath, err.Error())
}
elemFilter := (*xpath.Expr)(nil)
if len(streamElementFilter) > 0 {
- elemFilter, err = getQuery(streamElementFilter[0])
+ elemFilter, err = getQuery(streamElementFilter[0], compileOpts)
if err != nil {
return nil, fmt.Errorf("invalid streamElementFilter '%s', err: %s", streamElementFilter[0], err.Error())
}
diff --git a/parse_test.go b/parse_test.go
index d87780a..1bb1425 100644
--- a/parse_test.go
+++ b/parse_test.go
@@ -7,6 +7,8 @@ import (
"net/http/httptest"
"strings"
"testing"
+
+ "github.com/antchfx/xpath"
)
func TestLoadURLSuccess(t *testing.T) {
@@ -331,25 +333,25 @@ func TestMissingNamespace(t *testing.T) {
func TestTooNested(t *testing.T) {
s := `
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- `
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ `
root, err := Parse(strings.NewReader(s))
if err != nil {
t.Error(err)
@@ -665,3 +667,46 @@ func TestDirective(t *testing.T) {
t.Errorf("expected count is 4 but got %d", m)
}
}
+
+func TestStreamParser_XPathStrictEOFOption(t *testing.T) {
+ // This test checks that passing StrictEOF to the stream parser causes an error on trailing garbage in the XPath expression.
+ xml := `12`
+
+ validXPath := "/root/a"
+ garbageXPath := "/root/a,foo" // trailing garbage after valid expr
+
+ // Without StrictEOF: should NOT error for garbageXPath
+ sp, err := CreateStreamParserWithCompileOptions(strings.NewReader(xml), ParserOptions{}, xpath.CompileOptions{}, garbageXPath)
+ if err != nil {
+ t.Fatalf("unexpected error without StrictEOF: %v", err)
+ }
+ _, err = sp.Read()
+ if err != nil && err != io.EOF {
+ t.Fatalf("unexpected error on Read() without StrictEOF: %v", err)
+ }
+
+ // With StrictEOF: should error for garbageXPath
+ sp, err = CreateStreamParserWithCompileOptions(strings.NewReader(xml), ParserOptions{}, xpath.CompileOptions{StrictEOF: true}, garbageXPath)
+ if err == nil {
+ _, err = sp.Read() // force evaluation if not already errored
+ }
+ if err == nil {
+ t.Fatal("expected error with StrictEOF and garbage XPath, but got nil")
+ }
+ if err != nil && !strings.Contains(err.Error(), "unexpected token after end of expression") {
+ t.Fatalf("expected strict EOF error, got: %v", err)
+ }
+
+ // With StrictEOF: should NOT error for valid XPath
+ sp, err = CreateStreamParserWithCompileOptions(strings.NewReader(xml), ParserOptions{}, xpath.CompileOptions{StrictEOF: true}, validXPath)
+ if err != nil {
+ t.Fatalf("unexpected error with StrictEOF and valid XPath: %v", err)
+ }
+ n, err := sp.Read()
+ if err != nil && err != io.EOF {
+ t.Fatalf("unexpected error on Read() with StrictEOF and valid XPath: %v", err)
+ }
+ if n == nil || n.Data != "a" {
+ t.Fatalf("expected to find node, got: %#v", n)
+ }
+}
diff --git a/query.go b/query.go
index d1353aa..d145939 100644
--- a/query.go
+++ b/query.go
@@ -85,24 +85,32 @@ func FindOne(top *Node, expr string) *Node {
// QueryAll searches the XML Node that matches by the specified XPath expr.
// Returns an error if the expression `expr` cannot be parsed.
-func QueryAll(top *Node, expr string) ([]*Node, error) {
- exp, err := getQuery(expr)
+func QueryAllWithOptions(top *Node, expr string, opts xpath.CompileOptions) ([]*Node, error) {
+ exp, err := getQuery(expr, opts)
if err != nil {
return nil, err
}
return QuerySelectorAll(top, exp), nil
}
+func QueryAll(top *Node, expr string) ([]*Node, error) {
+ return QueryAllWithOptions(top, expr, xpath.CompileOptions{})
+}
+
// Query searches the XML Node that matches by the specified XPath expr,
// and returns first matched element.
-func Query(top *Node, expr string) (*Node, error) {
- exp, err := getQuery(expr)
+func QueryWithOptions(top *Node, expr string, opts xpath.CompileOptions) (*Node, error) {
+ exp, err := getQuery(expr, opts)
if err != nil {
return nil, err
}
return QuerySelector(top, exp), nil
}
+func Query(top *Node, expr string) (*Node, error) {
+ return QueryWithOptions(top, expr, xpath.CompileOptions{})
+}
+
// QuerySelectorAll searches all of the XML Node that matches the specified
// XPath selectors.
func QuerySelectorAll(top *Node, selector *xpath.Expr) []*Node {
diff --git a/query_test.go b/query_test.go
index b4158be..6e3f3d6 100644
--- a/query_test.go
+++ b/query_test.go
@@ -4,6 +4,8 @@ import (
"fmt"
"strings"
"testing"
+
+ "github.com/antchfx/xpath"
)
// https://msdn.microsoft.com/en-us/library/ms762271(v=vs.85).aspx
@@ -161,12 +163,76 @@ func loadXML(s string) *Node {
}
func TestMissingTextNodes(t *testing.T) {
- doc := loadXML(`
+ doc := loadXML(`
Lorem ipsum dolor
`)
- results := Find(doc, "//text()")
- if len(results) != 3 {
- t.Fatalf("Expected text nodes 3, got %d", len(results))
- }
+ results := Find(doc, "//text()")
+ if len(results) != 3 {
+ t.Fatalf("Expected text nodes 3, got %d", len(results))
+ }
+}
+
+func TestQueryWithOptions_XPathStrictEOFOption(t *testing.T) {
+ validXPath := "/catalog/book"
+ garbageXPath := "/catalog/book,foo" // trailing garbage after valid expr
+
+ // Without StrictEOF: should NOT error for garbageXPath
+ node, err := QueryWithOptions(doc, garbageXPath, xpath.CompileOptions{})
+ if err != nil {
+ t.Fatalf("unexpected error without StrictEOF: %v", err)
+ }
+ if node == nil || node.Data != "book" {
+ t.Fatalf("expected to find node, got: %#v", node)
+ }
+
+ // With StrictEOF: should error for garbageXPath
+ _, err = QueryWithOptions(doc, garbageXPath, xpath.CompileOptions{StrictEOF: true})
+ if err == nil {
+ t.Fatal("expected error with StrictEOF and garbage XPath, but got nil")
+ }
+ if err != nil && !strings.Contains(err.Error(), "unexpected token after end of expression") {
+ t.Fatalf("expected strict EOF error, got: %v", err)
+ }
+
+ // With StrictEOF: should NOT error for valid XPath
+ node, err = QueryWithOptions(doc, validXPath, xpath.CompileOptions{StrictEOF: true})
+ if err != nil {
+ t.Fatalf("unexpected error with StrictEOF and valid XPath: %v", err)
+ }
+ if node == nil || node.Data != "book" {
+ t.Fatalf("expected to find node, got: %#v", node)
+ }
+}
+
+func TestQueryAllWithOptions_XPathStrictEOFOption(t *testing.T) {
+ validXPath := "/catalog/book"
+ garbageXPath := "/catalog/book,foo" // trailing garbage after valid expr
+
+ // Without StrictEOF: should NOT error for garbageXPath
+ nodes, err := QueryAllWithOptions(doc, garbageXPath, xpath.CompileOptions{})
+ if err != nil {
+ t.Fatalf("unexpected error without StrictEOF: %v", err)
+ }
+ if len(nodes) != 3 || nodes[0].Data != "book" || nodes[1].Data != "book" || nodes[2].Data != "book" {
+ t.Fatalf("expected to find three nodes, got: %#v", nodes)
+ }
+
+ // With StrictEOF: should error for garbageXPath
+ _, err = QueryAllWithOptions(doc, garbageXPath, xpath.CompileOptions{StrictEOF: true})
+ if err == nil {
+ t.Fatal("expected error with StrictEOF and garbage XPath, but got nil")
+ }
+ if err != nil && !strings.Contains(err.Error(), "unexpected token after end of expression") {
+ t.Fatalf("expected strict EOF error, got: %v", err)
+ }
+
+ // With StrictEOF: should NOT error for valid XPath
+ nodes, err = QueryAllWithOptions(doc, validXPath, xpath.CompileOptions{StrictEOF: true})
+ if err != nil {
+ t.Fatalf("unexpected error with StrictEOF and valid XPath: %v", err)
+ }
+ if len(nodes) != 3 || nodes[0].Data != "book" || nodes[1].Data != "book" || nodes[2].Data != "book" {
+ t.Fatalf("expected to find three nodes, got: %#v", nodes)
+ }
}