Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,9 @@ class Collection: std::enable_shared_from_this<Collection> {
std::vector<std::vector<std::string>>& q_exclude_tokens,
std::vector<std::vector<std::string>>& q_phrases, bool& exclude_operator_prior,
bool& phrase_search_op_prior, std::vector<std::string>& phrase, const std::string& stopwords_set,
const bool& already_segmented, const std::string& locale, std::shared_ptr<Stemmer> stemmer) const;
const bool& already_segmented, const std::string& locale, std::shared_ptr<Stemmer> stemmer,
const std::vector<char>& most_weighted_field_symbols_to_index,
const std::vector<char>& most_weighted_field_token_separators) const;

// PUBLIC OPERATIONS

Expand Down
15 changes: 10 additions & 5 deletions src/collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4583,10 +4583,15 @@ void Collection::process_tokens(std::vector<std::string>& tokens, std::vector<st
std::vector<std::vector<std::string>>& q_exclude_tokens,
std::vector<std::vector<std::string>>& q_phrases, bool& exclude_operator_prior,
bool& phrase_search_op_prior, std::vector<std::string>& phrase, const std::string& stopwords_set,
const bool& already_segmented, const std::string& locale, std::shared_ptr<Stemmer> stemmer) const{
const bool& already_segmented, const std::string& locale, std::shared_ptr<Stemmer> stemmer,
const std::vector<char>& most_weighted_field_symbols_to_index,
const std::vector<char>& most_weighted_field_token_separators) const{

const auto& custom_symbols = most_weighted_field_symbols_to_index.empty() ? symbols_to_index : most_weighted_field_symbols_to_index;
const auto& custom_separators = most_weighted_field_token_separators.empty() ? token_separators : most_weighted_field_token_separators;

auto symbols_to_index_has_minus =
std::find(symbols_to_index.begin(), symbols_to_index.end(), '-') != symbols_to_index.end();
std::find(custom_symbols.begin(), custom_symbols.end(), '-') != custom_symbols.end();

for(auto& token: tokens) {
bool end_of_phrase = false;
Expand Down Expand Up @@ -4624,7 +4629,7 @@ void Collection::process_tokens(std::vector<std::string>& tokens, std::vector<st
if(already_segmented) {
StringUtils::split(token, sub_tokens, " ");
} else {
Tokenizer(token, true, false, locale, symbols_to_index, token_separators).tokenize(sub_tokens);
Tokenizer(token, true, false, locale, custom_symbols, custom_separators).tokenize(sub_tokens);
}

for(auto& sub_token: sub_tokens) {
Expand Down Expand Up @@ -4737,7 +4742,7 @@ void Collection::parse_search_query(const std::string &query, std::vector<std::s
bool phrase_search_op_prior = false;
std::vector<std::string> phrase;

process_tokens(tokens, q_include_tokens, q_exclude_tokens, q_phrases, exclude_operator_prior, phrase_search_op_prior, phrase, stopwords_set, already_segmented, locale, stemmer);
process_tokens(tokens, q_include_tokens, q_exclude_tokens, q_phrases, exclude_operator_prior, phrase_search_op_prior, phrase, stopwords_set, already_segmented, locale, stemmer, most_weighted_field_symbols_to_index, most_weighted_field_token_separators);

if(stemmer) {
exclude_operator_prior = false;
Expand All @@ -4750,7 +4755,7 @@ void Collection::parse_search_query(const std::string &query, std::vector<std::s

process_tokens(tokens_non_stemmed, q_unstemmed_tokens, q_exclude_tokens_dummy, q_phrases_dummy,
exclude_operator_prior, phrase_search_op_prior, phrase, stopwords_set,
already_segmented, locale, nullptr);
already_segmented, locale, nullptr, most_weighted_field_symbols_to_index, most_weighted_field_token_separators);
}
}
}
Expand Down
88 changes: 88 additions & 0 deletions test/collection_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4409,6 +4409,94 @@ TEST_F(CollectionTest, QueryParsingForPhraseSearch) {
collectionManager.drop_collection("coll1");
}

// Verifies that parse_search_query() honors per-field tokenization overrides:
// when the most-weighted field indexes '-' via symbols_to_index, a hyphenated
// query term must survive query parsing as a single token instead of being
// split at the hyphen.
TEST_F(CollectionTest, QueryParsingWithFieldLevelSymbolsToIndex) {
    nlohmann::json schema = R"({
        "name": "coll_symbols",
        "fields": [
            {"name": "title", "type": "string", "symbols_to_index": ["-"]},
            {"name": "points", "type": "int32"}
        ],
        "default_sorting_field": "points"
    })"_json;

    auto create_op = collectionManager.create_collection(schema);
    ASSERT_TRUE(create_op.ok());
    Collection* collection = create_op.get();

    nlohmann::json hyphenated_doc;
    hyphenated_doc["id"] = "0";
    hyphenated_doc["title"] = "test-driven development";
    hyphenated_doc["points"] = 100;

    nlohmann::json plain_doc;
    plain_doc["id"] = "1";
    plain_doc["title"] = "test driven development";
    plain_doc["points"] = 90;

    ASSERT_TRUE(collection->add(hyphenated_doc.dump()).ok());
    ASSERT_TRUE(collection->add(plain_doc.dump()).ok());

    std::vector<std::string> include_tokens, unstemmed_tokens;
    std::vector<std::vector<std::string>> exclude_tokens;
    std::vector<std::vector<std::string>> phrases;

    // Pull the title field's tokenization overrides out of the live schema,
    // exactly as a caller resolving the most-weighted field would.
    std::vector<char> title_symbols;
    std::vector<char> title_separators;
    for(const auto& a_field: collection->get_schema()) {
        if(a_field.name == "title") {
            title_symbols = a_field.symbols_to_index;
            title_separators = a_field.token_separators;
            break;
        }
    }

    // Reset all parse outputs between queries.
    auto reset_outputs = [&]() {
        include_tokens.clear();
        unstemmed_tokens.clear();
        exclude_tokens.clear();
        phrases.clear();
    };

    // A single hyphenated term: '-' is indexed for `title`, so the term
    // must be kept whole.
    std::string query = "test-driven";
    collection->parse_search_query(query, include_tokens, unstemmed_tokens, exclude_tokens, phrases,
                                   "en", false, "", nullptr, title_symbols, title_separators);

    ASSERT_EQ(1, include_tokens.size());
    ASSERT_EQ("test-driven", include_tokens[0]);
    ASSERT_EQ(0, exclude_tokens.size());
    ASSERT_EQ(0, phrases.size());

    // Several hyphenated terms in one query: each stays a single token.
    query = "test-driven code-review";
    reset_outputs();

    collection->parse_search_query(query, include_tokens, unstemmed_tokens, exclude_tokens, phrases,
                                   "en", false, "", nullptr, title_symbols, title_separators);

    ASSERT_EQ(2, include_tokens.size());
    ASSERT_EQ("test-driven", include_tokens[0]);
    ASSERT_EQ("code-review", include_tokens[1]);

    // Phrase search: the field-level symbols must also apply to the
    // sub-tokenization of the quoted phrase.
    query = "\"test-driven development\"";
    reset_outputs();

    collection->parse_search_query(query, include_tokens, unstemmed_tokens, exclude_tokens, phrases,
                                   "en", false, "", nullptr, title_symbols, title_separators);

    ASSERT_EQ(1, include_tokens.size());
    ASSERT_EQ("*", include_tokens[0]);
    ASSERT_EQ(1, phrases.size());
    ASSERT_EQ(2, phrases[0].size());
    ASSERT_EQ("test-driven", phrases[0][0]);
    ASSERT_EQ("development", phrases[0][1]);

    collectionManager.drop_collection("coll_symbols");
}

TEST_F(CollectionTest, WildcardQueryBy) {
nlohmann::json schema = R"({
"name": "posts",
Expand Down