diff --git a/cache.go b/cache.go index 3abffcd..2982e3b 100644 --- a/cache.go +++ b/cache.go @@ -1,11 +1,11 @@ package xmlquery import ( + "fmt" "sync" - "github.com/golang/groupcache/lru" - "github.com/antchfx/xpath" + "github.com/golang/groupcache/lru" ) // DisableSelectorCache will disable caching for the query selector if value is true. @@ -21,23 +21,23 @@ var ( cacheMutex sync.Mutex ) -func getQuery(expr string) (*xpath.Expr, error) { +func getQuery(expr string, opts xpath.CompileOptions) (*xpath.Expr, error) { + key := expr + fmt.Sprintf("%#v", opts) if DisableSelectorCache || SelectorCacheMaxEntries <= 0 { - return xpath.Compile(expr) + return xpath.CompileWithOptions(expr, opts) } cacheOnce.Do(func() { cache = lru.New(SelectorCacheMaxEntries) }) cacheMutex.Lock() defer cacheMutex.Unlock() - if v, ok := cache.Get(expr); ok { + if v, ok := cache.Get(key); ok { return v.(*xpath.Expr), nil } - v, err := xpath.Compile(expr) + v, err := xpath.CompileWithOptions(expr, opts) if err != nil { return nil, err } - cache.Add(expr, v) + cache.Add(key, v) return v, nil - } diff --git a/parse.go b/parse.go index d359b50..d904115 100644 --- a/parse.go +++ b/parse.go @@ -372,7 +372,7 @@ type StreamParser struct { // streamElementFilter, if provided, cannot be successfully parsed and compiled // into a valid xpath query. func CreateStreamParser(r io.Reader, streamElementXPath string, streamElementFilter ...string) (*StreamParser, error) { - return CreateStreamParserWithOptions(r, ParserOptions{}, streamElementXPath, streamElementFilter...) + return CreateStreamParserWithCompileOptions(r, ParserOptions{}, xpath.CompileOptions{}, streamElementXPath, streamElementFilter...) 
} // CreateStreamParserWithOptions is like CreateStreamParser, but with custom options @@ -382,13 +382,24 @@ func CreateStreamParserWithOptions( streamElementXPath string, streamElementFilter ...string, ) (*StreamParser, error) { - elemXPath, err := getQuery(streamElementXPath) + return CreateStreamParserWithCompileOptions(r, options, xpath.CompileOptions{}, streamElementXPath, streamElementFilter...) +} + +// CreateStreamParserWithCompileOptions is like CreateStreamParserWithOptions, but compiles streamElementXPath and streamElementFilter with compileOpts. +func CreateStreamParserWithCompileOptions( + r io.Reader, + options ParserOptions, + compileOpts xpath.CompileOptions, + streamElementXPath string, + streamElementFilter ...string, +) (*StreamParser, error) { + elemXPath, err := getQuery(streamElementXPath, compileOpts) if err != nil { return nil, fmt.Errorf("invalid streamElementXPath '%s', err: %s", streamElementXPath, err.Error()) } elemFilter := (*xpath.Expr)(nil) if len(streamElementFilter) > 0 { - elemFilter, err = getQuery(streamElementFilter[0]) + elemFilter, err = getQuery(streamElementFilter[0], compileOpts) if err != nil { return nil, fmt.Errorf("invalid streamElementFilter '%s', err: %s", streamElementFilter[0], err.Error()) } diff --git a/parse_test.go b/parse_test.go index d87780a..1bb1425 100644 --- a/parse_test.go +++ b/parse_test.go @@ -7,6 +7,8 @@ import ( "net/http/httptest" "strings" "testing" + + "github.com/antchfx/xpath" ) func TestLoadURLSuccess(t *testing.T) { @@ -331,25 +333,25 @@ func TestMissingNamespace(t *testing.T) { func TestTooNested(t *testing.T) { s := ` - - - - - - - - - - - - - - - - - - - ` + + + + + + + + + + + + + + + + + + + ` root, err := Parse(strings.NewReader(s)) if err != nil { t.Error(err) @@ -665,3 +667,46 @@ func TestDirective(t *testing.T) { t.Errorf("expected count is 4 but got %d", m) } } + +func TestStreamParser_XPathStrictEOFOption(t *testing.T) { + // This test checks that passing StrictEOF to the stream parser causes an error on trailing garbage in the XPath expression. 
+ xml := `12` + + validXPath := "/root/a" + garbageXPath := "/root/a,foo" // trailing garbage after valid expr + + // Without StrictEOF: should NOT error for garbageXPath + sp, err := CreateStreamParserWithCompileOptions(strings.NewReader(xml), ParserOptions{}, xpath.CompileOptions{}, garbageXPath) + if err != nil { + t.Fatalf("unexpected error without StrictEOF: %v", err) + } + _, err = sp.Read() + if err != nil && err != io.EOF { + t.Fatalf("unexpected error on Read() without StrictEOF: %v", err) + } + + // With StrictEOF: should error for garbageXPath + sp, err = CreateStreamParserWithCompileOptions(strings.NewReader(xml), ParserOptions{}, xpath.CompileOptions{StrictEOF: true}, garbageXPath) + if err == nil { + _, err = sp.Read() // force evaluation if not already errored + } + if err == nil { + t.Fatal("expected error with StrictEOF and garbage XPath, but got nil") + } + if err != nil && !strings.Contains(err.Error(), "unexpected token after end of expression") { + t.Fatalf("expected strict EOF error, got: %v", err) + } + + // With StrictEOF: should NOT error for valid XPath + sp, err = CreateStreamParserWithCompileOptions(strings.NewReader(xml), ParserOptions{}, xpath.CompileOptions{StrictEOF: true}, validXPath) + if err != nil { + t.Fatalf("unexpected error with StrictEOF and valid XPath: %v", err) + } + n, err := sp.Read() + if err != nil && err != io.EOF { + t.Fatalf("unexpected error on Read() with StrictEOF and valid XPath: %v", err) + } + if n == nil || n.Data != "a" { + t.Fatalf("expected to find node, got: %#v", n) + } +} diff --git a/query.go b/query.go index d1353aa..d145939 100644 --- a/query.go +++ b/query.go @@ -85,24 +85,32 @@ func FindOne(top *Node, expr string) *Node { // QueryAll searches the XML Node that matches by the specified XPath expr. // Returns an error if the expression `expr` cannot be parsed. 
-func QueryAll(top *Node, expr string) ([]*Node, error) { - exp, err := getQuery(expr) +func QueryAllWithOptions(top *Node, expr string, opts xpath.CompileOptions) ([]*Node, error) { + exp, err := getQuery(expr, opts) if err != nil { return nil, err } return QuerySelectorAll(top, exp), nil } +func QueryAll(top *Node, expr string) ([]*Node, error) { + return QueryAllWithOptions(top, expr, xpath.CompileOptions{}) +} + // Query searches the XML Node that matches by the specified XPath expr, // and returns first matched element. -func Query(top *Node, expr string) (*Node, error) { - exp, err := getQuery(expr) +func QueryWithOptions(top *Node, expr string, opts xpath.CompileOptions) (*Node, error) { + exp, err := getQuery(expr, opts) if err != nil { return nil, err } return QuerySelector(top, exp), nil } +func Query(top *Node, expr string) (*Node, error) { + return QueryWithOptions(top, expr, xpath.CompileOptions{}) +} + // QuerySelectorAll searches all of the XML Node that matches the specified // XPath selectors. func QuerySelectorAll(top *Node, selector *xpath.Expr) []*Node { diff --git a/query_test.go b/query_test.go index b4158be..6e3f3d6 100644 --- a/query_test.go +++ b/query_test.go @@ -4,6 +4,8 @@ import ( "fmt" "strings" "testing" + + "github.com/antchfx/xpath" ) // https://msdn.microsoft.com/en-us/library/ms762271(v=vs.85).aspx @@ -161,12 +163,76 @@ func loadXML(s string) *Node { } func TestMissingTextNodes(t *testing.T) { - doc := loadXML(` + doc := loadXML(`

Lorem ipsum dolor

`) - results := Find(doc, "//text()") - if len(results) != 3 { - t.Fatalf("Expected text nodes 3, got %d", len(results)) - } + results := Find(doc, "//text()") + if len(results) != 3 { + t.Fatalf("Expected text nodes 3, got %d", len(results)) + } +} + +func TestQueryWithOptions_XPathStrictEOFOption(t *testing.T) { + validXPath := "/catalog/book" + garbageXPath := "/catalog/book,foo" // trailing garbage after valid expr + + // Without StrictEOF: should NOT error for garbageXPath + node, err := QueryWithOptions(doc, garbageXPath, xpath.CompileOptions{}) + if err != nil { + t.Fatalf("unexpected error without StrictEOF: %v", err) + } + if node == nil || node.Data != "book" { + t.Fatalf("expected to find node, got: %#v", node) + } + + // With StrictEOF: should error for garbageXPath + _, err = QueryWithOptions(doc, garbageXPath, xpath.CompileOptions{StrictEOF: true}) + if err == nil { + t.Fatal("expected error with StrictEOF and garbage XPath, but got nil") + } + if err != nil && !strings.Contains(err.Error(), "unexpected token after end of expression") { + t.Fatalf("expected strict EOF error, got: %v", err) + } + + // With StrictEOF: should NOT error for valid XPath + node, err = QueryWithOptions(doc, validXPath, xpath.CompileOptions{StrictEOF: true}) + if err != nil { + t.Fatalf("unexpected error with StrictEOF and valid XPath: %v", err) + } + if node == nil || node.Data != "book" { + t.Fatalf("expected to find node, got: %#v", node) + } +} + +func TestQueryAllWithOptions_XPathStrictEOFOption(t *testing.T) { + validXPath := "/catalog/book" + garbageXPath := "/catalog/book,foo" // trailing garbage after valid expr + + // Without StrictEOF: should NOT error for garbageXPath + nodes, err := QueryAllWithOptions(doc, garbageXPath, xpath.CompileOptions{}) + if err != nil { + t.Fatalf("unexpected error without StrictEOF: %v", err) + } + if len(nodes) != 3 || nodes[0].Data != "book" || nodes[1].Data != "book" || nodes[2].Data != "book" { + t.Fatalf("expected to find three 
nodes, got: %#v", nodes) + } + + // With StrictEOF: should error for garbageXPath + _, err = QueryAllWithOptions(doc, garbageXPath, xpath.CompileOptions{StrictEOF: true}) + if err == nil { + t.Fatal("expected error with StrictEOF and garbage XPath, but got nil") + } + if err != nil && !strings.Contains(err.Error(), "unexpected token after end of expression") { + t.Fatalf("expected strict EOF error, got: %v", err) + } + + // With StrictEOF: should NOT error for valid XPath + nodes, err = QueryAllWithOptions(doc, validXPath, xpath.CompileOptions{StrictEOF: true}) + if err != nil { + t.Fatalf("unexpected error with StrictEOF and valid XPath: %v", err) + } + if len(nodes) != 3 || nodes[0].Data != "book" || nodes[1].Data != "book" || nodes[2].Data != "book" { + t.Fatalf("expected to find three nodes, got: %#v", nodes) + } }