Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/abstract_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class AbstractIndex
#ifdef EXEC_ENV_OLS
virtual void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) = 0;
#else
virtual void load(const IndexLoadParams& index_load_params) = 0;

virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type = LabelFormatType::String) = 0;
#endif

Expand Down
3 changes: 3 additions & 0 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l);
#else
DISKANN_DLLEXPORT void load(const IndexLoadParams& load_params);

DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type = LabelFormatType::String);

#endif

// get some private variables
Expand Down
39 changes: 39 additions & 0 deletions include/parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,43 @@ class IndexWriteParametersBuilder
uint32_t _num_diverse_build{ defaults::NUM_DIVERSE_BUILD };
};

struct IndexLoadParams
{
std::string index_file_path;
uint32_t num_threads{defaults::NUM_THREADS};
uint32_t search_list_size{defaults::SEARCH_LIST_SIZE};
LabelFormatType label_format_type{LabelFormatType::String};

// Optional file paths - if empty, will be derived from index_file_path
std::string data_file_path;
std::string tags_file_path;
std::string delete_set_file_path;
std::string graph_file_path;
std::string labels_file_path;
std::string labels_to_medoids_file_path;
std::string labels_map_file_path;
std::string seller_file_path;
std::string bitmask_label_file_path;
std::string integer_label_file_path;
std::string universal_label_file_path;

IndexLoadParams() = default;

IndexLoadParams(const std::string &index_file_path, uint32_t num_threads = defaults::NUM_THREADS,
uint32_t search_list_size = defaults::SEARCH_LIST_SIZE,
LabelFormatType label_format_type = LabelFormatType::String)
: index_file_path(index_file_path), num_threads(num_threads), search_list_size(search_list_size),
label_format_type(label_format_type)
{
}

IndexLoadParams(const char *index_file_path, uint32_t num_threads = defaults::NUM_THREADS,
uint32_t search_list_size = defaults::SEARCH_LIST_SIZE,
LabelFormatType label_format_type = LabelFormatType::String)
: index_file_path(index_file_path), num_threads(num_threads), search_list_size(search_list_size),
label_format_type(label_format_type)
{
}
};

} // namespace diskann
53 changes: 38 additions & 15 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,13 +581,26 @@ size_t Index<T, TagT, LabelT>::load_delete_set(const std::string &filename)

// load the index from file and update the max_degree, cur (navigating
// node loc), and _final_graph (adjacency list)
template <typename T, typename TagT, typename LabelT>

#ifdef EXEC_ENV_OLS
template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l)
{
#else
template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type)
{
IndexLoadParams load_params(filename, num_threads, search_l, label_format_type);
load(load_params);
}

template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::load(const IndexLoadParams & load_params)
{
const std::string &filename = load_params.index_file_path;
const uint32_t num_threads = load_params.num_threads;
const uint32_t search_l = load_params.search_list_size;
const LabelFormatType label_format_type = load_params.label_format_type;
#endif
std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
std::unique_lock<std::shared_timed_mutex> cl(_consolidate_lock);
Expand All @@ -598,20 +611,27 @@ void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, ui

size_t tags_file_num_pts = 0, graph_num_pts = 0, data_file_num_pts = 0, label_num_pts = 0;

std::string mem_index_file(filename);
std::string labels_file = mem_index_file + "_labels.txt";
std::string labels_to_medoids = mem_index_file + "_labels_to_medoids.txt";
std::string labels_map_file = mem_index_file + "_labels_map.txt";
// Use file paths from load_params if provided, otherwise derive from index_file_path
std::string labels_file = load_params.labels_file_path.empty()
? filename + "_labels.txt" : load_params.labels_file_path;
std::string labels_to_medoids = load_params.labels_to_medoids_file_path.empty()
? filename + "_labels_to_medoids.txt" : load_params.labels_to_medoids_file_path;
std::string labels_map_file = load_params.labels_map_file_path.empty()
? filename + "_labels_map.txt" : load_params.labels_map_file_path;

if (!_save_as_one_file)
{
// For DLVS Store, we will not support saving the index in multiple
// files.
#ifndef EXEC_ENV_OLS
std::string data_file = std::string(filename) + ".data";
std::string tags_file = std::string(filename) + ".tags";
std::string delete_set_file = std::string(filename) + ".del";
std::string graph_file = std::string(filename);
std::string data_file = load_params.data_file_path.empty()
? filename + ".data" : load_params.data_file_path;
std::string tags_file = load_params.tags_file_path.empty()
? filename + ".tags" : load_params.tags_file_path;
std::string delete_set_file = load_params.delete_set_file_path.empty()
? filename + ".del" : load_params.delete_set_file_path;
std::string graph_file = load_params.graph_file_path.empty()
? filename : load_params.graph_file_path;
data_file_num_pts = load_data(data_file);
this->_table_stats.node_count = data_file_num_pts;
this->_table_stats.node_mem_usage = this->_data_store->get_data_size();
Expand Down Expand Up @@ -646,8 +666,9 @@ void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, ui
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}

std::string index_seller_file = std::string(filename) + "_sellers.bin";
std::string old_index_seller_file = std::string(filename) + "_sellers.txt";
std::string index_seller_file = load_params.seller_file_path.empty()
? filename + "_sellers.bin" : load_params.seller_file_path;
std::string old_index_seller_file = filename + "_sellers.txt";
if (file_exists(index_seller_file))
{
//uint64_t nrows_seller_file;
Expand Down Expand Up @@ -678,8 +699,10 @@ void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, ui
_diverse_index = true;
}

std::string bitmask_label_file = std::string(filename) + "_bitmask_labels.bin";
std::string integer_label_file = std::string(filename) + "_integer_labels.bin";
std::string bitmask_label_file = load_params.bitmask_label_file_path.empty()
? filename + "_bitmask_labels.bin" : load_params.bitmask_label_file_path;
std::string integer_label_file = load_params.integer_label_file_path.empty()
? filename + "_integer_labels.bin" : load_params.integer_label_file_path;
if (file_exists(labels_file)
|| file_exists(bitmask_label_file)
|| file_exists(integer_label_file))
Expand Down Expand Up @@ -741,8 +764,8 @@ void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, ui
_label_to_start_id.clear();
label_helper().load_label_medoids(labels_to_medoids, _label_to_start_id);

std::string universal_label_file(filename);
universal_label_file += "_universal_label.txt";
std::string universal_label_file = load_params.universal_label_file_path.empty()
? filename + "_universal_label.txt" : load_params.universal_label_file_path;
if (file_exists(universal_label_file))
{
std::ifstream universal_label_reader(universal_label_file);
Expand Down
Loading