diff --git a/src/include/detail/linalg/tdb_matrix.h b/src/include/detail/linalg/tdb_matrix.h index ad2eeaca8..37fcf9ae6 100644 --- a/src/include/detail/linalg/tdb_matrix.h +++ b/src/include/detail/linalg/tdb_matrix.h @@ -351,20 +351,44 @@ class tdbBlockedMatrix : public MatrixBase { auto layout_order = schema_.cell_order(); - // Create a query + // Read TileDB data + size_t read_batch_size_cells = get_read_batch_size_cells(ctx_); + size_t total_size = elements_to_load * dimension; + size_t offset = 0; tiledb::Query query(ctx_, *array_); - query.set_subarray(subarray) - .set_layout(layout_order) - .set_data_buffer(attr_name, this->data(), elements_to_load * dimension); - tiledb_helpers::submit_query(tdb_func__, uri_, query); + query.set_subarray(subarray).set_layout(layout_order); + tiledb::Query::Status status; + do { + // Submit query and get status + size_t request_size = read_batch_size_cells; + if (offset + read_batch_size_cells > total_size) { + request_size = total_size - offset; + } + query.set_data_buffer(attr_name, this->data() + offset, request_size); + query.submit(); + status = query.query_status(); + + auto num_results = query.result_buffer_elements()[attr_name].second; + if (num_results == 0) { + throw std::runtime_error( + "Read error: Got empty results while expecting to retrieve more " + "values."); + } + offset += num_results; + } while (status == tiledb::Query::Status::INCOMPLETE && + offset < total_size); + // Handle errors + if (status == tiledb::Query::Status::COMPLETE && offset != total_size) { + throw std::runtime_error( + "Read error: Read status COMPLETE but result size was different " + "than expected: " + + std::to_string(offset) + " != " + std::to_string(total_size)); + } + if (status != tiledb::Query::Status::COMPLETE) { + throw std::runtime_error("Read error: Query status not COMPLETE"); + } _memory_data.insert_entry( tdb_func__, elements_to_load * dimension * sizeof(T)); - - // @todo Handle incomplete queries. - if (tiledb::Query::Status::COMPLETE != query.query_status()) { - throw std::runtime_error("Query status is not complete"); - } - num_loads_++; return true; } diff --git a/src/include/detail/linalg/tdb_matrix_with_ids.h b/src/include/detail/linalg/tdb_matrix_with_ids.h index 4dc0a3956..9a9947b84 100644 --- a/src/include/detail/linalg/tdb_matrix_with_ids.h +++ b/src/include/detail/linalg/tdb_matrix_with_ids.h @@ -224,19 +224,46 @@ class tdbBlockedMatrixWithIds auto layout_order = ids_schema_.cell_order(); this->ids().resize(elements_to_load * dimension); - // Create a query + + // Read TileDB data + size_t read_batch_size_cells = get_read_batch_size_cells(this->ctx_); + size_t total_size = elements_to_load * dimension; + size_t offset = 0; + auto ptr = this->ids().data(); tiledb::Query query(this->ctx_, *ids_array_); - query.set_subarray(subarray) - .set_layout(layout_order) - .set_data_buffer(attr_name, this->ids()); - tiledb_helpers::submit_query(tdb_func__, ids_uri_, query); + query.set_subarray(subarray).set_layout(layout_order); + tiledb::Query::Status status; + do { + // Submit query and get status + size_t request_size = read_batch_size_cells; + if (offset + read_batch_size_cells > total_size) { + request_size = total_size - offset; + } + query.set_data_buffer(attr_name, ptr + offset, request_size); + tiledb_helpers::submit_query(tdb_func__, ids_uri_, query); + status = query.query_status(); + + auto num_results = query.result_buffer_elements()[attr_name].second; + if (num_results == 0) { + throw std::runtime_error( + "Read error: Got empty results while expecting to retrieve more " + "values."); + } + offset += num_results; + } while (status == tiledb::Query::Status::INCOMPLETE && + offset < total_size); + // Handle errors + if (status == tiledb::Query::Status::COMPLETE && offset != total_size) { + throw std::runtime_error( + "Read error: Read status COMPLETE but result size was different " + "than expected: " + + std::to_string(offset) + " != " + std::to_string(total_size)); + } + if (status != tiledb::Query::Status::COMPLETE) { + throw std::runtime_error("Read error: Query status not COMPLETE"); + } _memory_data.insert_entry( tdb_func__, elements_to_load * dimension * sizeof(T)); - // @todo Handle incomplete queries. - if (tiledb::Query::Status::COMPLETE != query.query_status()) { - throw std::runtime_error("Query status for IDs is not complete"); - } - return true; } }; // tdbBlockedMatrixWithIds diff --git a/src/include/detail/linalg/tdb_partitioned_matrix.h b/src/include/detail/linalg/tdb_partitioned_matrix.h index 8d62a1aa0..8f609b583 100644 --- a/src/include/detail/linalg/tdb_partitioned_matrix.h +++ b/src/include/detail/linalg/tdb_partitioned_matrix.h @@ -559,24 +559,45 @@ class tdbPartitionedMatrix auto cell_order = partitioned_vectors_schema_.cell_order(); auto layout_order = cell_order; - - tiledb::Query query(ctx_, *(this->partitioned_vectors_array_)); - auto ptr = this->data(); - query.set_subarray(subarray) - .set_layout(layout_order) - .set_data_buffer(attr_name, ptr, col_count * dimension); - // tiledb_helpers::submit_query(tdb_func__, partitioned_vectors_uri_, - // query); - query.submit(); - _memory_data.insert_entry(tdb_func__, col_count * dimension * sizeof(T)); - // assert(tiledb::Query::Status::COMPLETE == query.query_dstatus()); - auto qs = query.query_status(); - // @todo Handle incomplete queries. - if (tiledb::Query::Status::COMPLETE != query.query_status()) { - throw std::runtime_error("Query status is not complete -- fix me"); + // Read TileDB data + size_t read_batch_size_cells = get_read_batch_size_cells(ctx_); + size_t total_size = col_count * dimension; + size_t offset = 0; + tiledb::Query query(ctx_, *(this->partitioned_vectors_array_)); + query.set_subarray(subarray).set_layout(layout_order); + tiledb::Query::Status status; + do { + // Submit query and get status + size_t request_size = read_batch_size_cells; + if (offset + read_batch_size_cells > total_size) { + request_size = total_size - offset; + } + query.set_data_buffer(attr_name, ptr + offset, request_size); + query.submit(); + status = query.query_status(); + + auto num_results = query.result_buffer_elements()[attr_name].second; + if (num_results == 0) { + throw std::runtime_error( + "Read error: Got empty results while expecting to retrieve more " + "values."); + } + offset += num_results; + } while (status == tiledb::Query::Status::INCOMPLETE && + offset < total_size); + // Handle errors + if (status == tiledb::Query::Status::COMPLETE && offset != total_size) { + throw std::runtime_error( + "Read error: Read status COMPLETE but result size was different " + "than expected: " + + std::to_string(offset) + " != " + std::to_string(total_size)); } + if (status != tiledb::Query::Status::COMPLETE) { + throw std::runtime_error("Read error: Query status not COMPLETE"); + } + _memory_data.insert_entry(tdb_func__, col_count * dimension * sizeof(T)); } /** diff --git a/src/include/tdb_defs.h b/src/include/tdb_defs.h index 04af65151..a6f9d3fc1 100644 --- a/src/include/tdb_defs.h +++ b/src/include/tdb_defs.h @@ -36,6 +36,28 @@ #include #include +// Default batch size for all TileDB read operations. +// This is expressed in number of array cells read per request. +constexpr size_t DEFAULT_READ_BATCH_SIZE_CELLS = 100000000; +constexpr char READ_BATCH_SIZE_CELLS_CONFIG_KEY[] = + "vectorsearch.read_batch_size_cells"; +static size_t get_read_batch_size_cells(const tiledb::Context& ctx) { + auto config = ctx.config(); + if (config.contains(READ_BATCH_SIZE_CELLS_CONFIG_KEY)) { + auto tmp_str = config.get(READ_BATCH_SIZE_CELLS_CONFIG_KEY); + try { + size_t read_batch_size_cells = std::stoull(tmp_str); + return read_batch_size_cells; + } catch (const std::invalid_argument& e) { + throw std::invalid_argument( + "Failed to convert 'vectorsearch.read_batch_size_cells' to size_t " + "('" + + tmp_str + "')"); + } + } + return DEFAULT_READ_BATCH_SIZE_CELLS; +} + template constexpr bool always_false = false;