Skip to content

Commit c49dc61

Browse files
BagritsevichStepanromange
authored andcommitted
fix(rax_tree): Fix crash caused by destructor in RaxTreeMap (#4228)
* fix(rax_tree): Fix double raxStop call in the SeekIterator fixes #4172 Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> * refactor(rax_tree): Address comments Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io> --------- Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
1 parent 976586c commit c49dc61

File tree

2 files changed

+63
-23
lines changed

2 files changed

+63
-23
lines changed

src/core/search/rax_tree.h

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,37 @@ template <typename V> struct RaxTreeMap {
2424

2525
// Simple seeking iterator
2626
struct SeekIterator {
27-
friend struct FindIterator;
28-
2927
SeekIterator() {
30-
raxStart(&it_, nullptr);
31-
it_.node = nullptr;
32-
}
33-
34-
~SeekIterator() {
35-
raxStop(&it_);
28+
it_.rt = nullptr;
3629
}
3730

38-
SeekIterator(SeekIterator&&) = delete; // self-referential
39-
SeekIterator(const SeekIterator&) = delete; // self-referential
40-
4131
SeekIterator(rax* tree, const char* op, std::string_view key) {
4232
raxStart(&it_, tree);
43-
raxSeek(&it_, op, to_key_ptr(key), key.size());
44-
operator++();
33+
if (raxSeek(&it_, op, to_key_ptr(key), key.size())) { // Successfuly seeked
34+
operator++();
35+
} else {
36+
InvalidateIterator();
37+
}
4538
}
4639

4740
explicit SeekIterator(rax* tree) : SeekIterator(tree, "^", std::string_view{nullptr, 0}) {
4841
}
4942

43+
/* Remove copy/move constructors to avoid double iterator invalidation */
44+
SeekIterator(SeekIterator&&) = delete;
45+
SeekIterator(const SeekIterator&) = delete;
46+
SeekIterator& operator=(SeekIterator&&) = delete;
47+
SeekIterator& operator=(const SeekIterator&) = delete;
48+
49+
~SeekIterator() {
50+
if (IsValid()) {
51+
InvalidateIterator();
52+
}
53+
}
54+
5055
bool operator==(const SeekIterator& rhs) const {
56+
if (!IsValid() || !rhs.IsValid())
57+
return !IsValid() && !rhs.IsValid();
5158
return it_.node == rhs.it_.node;
5259
}
5360

@@ -56,31 +63,40 @@ template <typename V> struct RaxTreeMap {
5663
}
5764

5865
SeekIterator& operator++() {
59-
if (!raxNext(&it_)) {
60-
raxStop(&it_);
61-
it_.node = nullptr;
66+
int next_result = raxNext(&it_);
67+
if (!next_result) { // OOM or we reached the end of the tree
68+
InvalidateIterator();
6269
}
6370
return *this;
6471
}
6572

73+
/* After operator++() the first value (string_view) is invalid. So make sure your copied it to
74+
* string */
6675
std::pair<std::string_view, V&> operator*() const {
76+
assert(IsValid() && it_.node && it_.node->iskey && it_.data);
6777
return {std::string_view{reinterpret_cast<const char*>(it_.key), it_.key_len},
6878
*reinterpret_cast<V*>(it_.data)};
6979
}
7080

81+
bool IsValid() const {
82+
return it_.rt;
83+
}
84+
7185
private:
86+
void InvalidateIterator() {
87+
raxStop(&it_);
88+
it_.rt = nullptr;
89+
}
90+
7291
raxIterator it_;
7392
};
7493

7594
// Result of find() call. Inherits from pair to mimic iterator interface, not incrementable.
7695
struct FindIterator : public std::optional<std::pair<std::string, V&>> {
7796
bool operator==(const SeekIterator& rhs) const {
78-
if (this->has_value() != !bool(rhs.it_.flags & RAX_ITER_EOF))
79-
return false;
80-
if (!this->has_value())
81-
return true;
82-
return (*this)->first ==
83-
std::string_view{reinterpret_cast<const char*>(rhs.it_.key), rhs.it_.key_len};
97+
if (!this->has_value() || !rhs.IsValid())
98+
return !this->has_value() && !rhs.IsValid();
99+
return (*this)->first == (*rhs).first;
84100
}
85101

86102
bool operator!=(const SeekIterator& rhs) const {
@@ -160,7 +176,7 @@ std::pair<typename RaxTreeMap<V>::FindIterator, bool> RaxTreeMap<V>::try_emplace
160176

161177
V* old = nullptr;
162178
raxInsert(tree_, to_key_ptr(key), key.size(), ptr, reinterpret_cast<void**>(&old));
163-
assert(old == nullptr);
179+
assert(!old);
164180

165181
auto it = std::make_optional(std::pair<std::string, V&>(std::string(key), *ptr));
166182
return std::make_pair(std::move(FindIterator{it}), true);

src/core/search/rax_tree_test.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,28 @@ TEST_F(RaxTreeTest, Find) {
104104
EXPECT_TRUE(map.find(string_view{}) == map.end());
105105
}
106106

107+
/* Run with mimalloc to make sure there is no double free */
108+
TEST_F(RaxTreeTest, Iterate) {
109+
const char* kKeys[] = {
110+
"aaaaaaaaaaaaaaaaaaaa",
111+
"bbbbbbbbbbbbbbbbbbbbbb"
112+
"cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
113+
"dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd"
114+
"eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
115+
};
116+
117+
RaxTreeMap<int> map(pmr::get_default_resource());
118+
for (const char* key : kKeys) {
119+
map.try_emplace(key, 2);
120+
}
121+
122+
for (auto it = map.begin(); it != map.end(); ++it) {
123+
EXPECT_EQ((*it).second, 2);
124+
}
125+
126+
for (auto it = map.begin(); it != map.end(); ++it) {
127+
EXPECT_EQ((*it).second, 2);
128+
}
129+
}
130+
107131
} // namespace dfly::search

0 commit comments

Comments
 (0)