diff --git a/.gitignore b/.gitignore index f80e5c682..b6babdb0f 100644 --- a/.gitignore +++ b/.gitignore @@ -357,6 +357,7 @@ MigrationBackup/ cscope* build/ +build-*/ build_linux/ !.github/actions/build diff --git a/README.md b/README.md index a20a1d671..19de43011 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,129 @@ sudo sh l_BaseKit_p_2022.1.2.146.sh -a --components intel.oneapi.lin.mkl.devel - mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && make -j ``` +### AVX-512 BF16 (optional acceleration) + +DiskANN includes an optional AVX-512 BF16-accelerated kernel for `bf16` distance computations. + +- Compile-time: the AVX-512 BF16 kernel is enabled only when the compiler supports the required flags; it is compiled for a single source file (`src/bf16_simd_kernels.cpp`) so the rest of the project is not forced to use AVX-512. +- Runtime: `bf16` distance code automatically dispatches to the AVX-512 BF16 kernel only when the running CPU/OS supports AVX-512 BF16; otherwise it falls back to the scalar implementation. + +You can control this with the following CMake options (non-MSVC builds): + +- Default (try to enable when supported): + ```bash + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DDISKANN_AVX512BF16=ON + cmake --build build -j + ``` +- Force disable: + ```bash + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DDISKANN_AVX512BF16=OFF + cmake --build build -j + ``` +- Force enable (fail configure if compiler does not support AVX-512 BF16 flags): + ```bash + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DDISKANN_FORCE_AVX512BF16=ON + cmake --build build -j + ``` + +### AMX BF16 (optional acceleration) + +DiskANN includes an optional AMX BF16-accelerated kernel for `bf16` inner-product computations. + +- Compile-time: the AMX BF16 kernel is enabled only when the compiler supports the required flags; it is compiled for a single source file (`src/bf16_amx_kernels.cpp`) so the rest of the project is not forced to use AMX. +- Runtime: `bf16` distance code automatically dispatches to the AMX kernel only when the running CPU/OS supports AMX and the current thread is permitted to use AMX tile state (Linux `arch_prctl` request). If unavailable, it falls back to AVX-512 BF16 (if enabled) and then scalar. + +You can control this with the following CMake options (non-MSVC builds): + +- Default (try to enable when supported): + ```bash + cmake -S . -B build-amx -DCMAKE_BUILD_TYPE=Release -DDISKANN_AMXBF16=ON + cmake --build build-amx -j + ``` +- Force disable: + ```bash + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DDISKANN_AMXBF16=OFF + cmake --build build -j + ``` +- Force enable (fail configure if compiler does not support AMX flags): + ```bash + cmake -S . -B build-amx -DCMAKE_BUILD_TYPE=Release -DDISKANN_FORCE_AMXBF16=ON + cmake --build build-amx -j + ``` + +### AVX-512 vs AMX (build one or the other) + +If you want to do a strict A/B build where only one ISA path is compiled/used, configure two separate build directories. + +- AVX-512 BF16 only (no AMX codegen): + ```bash + cmake -S . -B build-avx512 -DCMAKE_BUILD_TYPE=Release \ + -DDISKANN_FORCE_AVX512BF16=ON \ + -DDISKANN_AMXBF16=OFF + cmake --build build-avx512 -j + ``` + +- AMX BF16 only (no AVX-512 BF16 code path): + ```bash + cmake -S . -B build-amx -DCMAKE_BUILD_TYPE=Release \ + -DDISKANN_FORCE_AMXBF16=ON \ + -DDISKANN_AVX512BF16=OFF + cmake --build build-amx -j + ``` + +Note: some toolchains/build scripts add global `-march=native`. When AMX is disabled (`-DDISKANN_AMXBF16=OFF`), DiskANN explicitly compiles the AMX translation unit with `-mno-amx-tile`/`-mno-amx-bf16` (when supported) to avoid accidentally emitting AMX instructions. + +### RaBitQ main-search approximate scoring (optional, runtime-gated) + +DiskANN also supports using RaBitQ multi-bit codes as the *main traversal approximate scorer* in SSD search (inside `PQFlashIndex::cached_beam_search`). + +- Default behavior is unchanged: traversal uses the existing PQ distance lookup. +- When enabled, traversal scoring uses RaBitQ approximate inner product (converted to a “distance” as `-ip`) while keeping the rest of the search logic intact. + +#### Runtime enable + +Set: + +```bash +export DISKANN_USE_RABITQ_MAIN_APPROX=1 +``` + +If the environment variable is set but RaBitQ main codes are missing or incompatible, DiskANN prints a one-time message and automatically falls back to PQ. + +#### Main code file naming + +Preferred sidecar file name: + +```text +_rabitq_main.bin +``` + +For example, if your SSD index file is `foo_disk.index`, the RaBitQ main code file should be `foo_disk.index_rabitq_main.bin`. + +#### Generating main codes during disk index build + +You can generate the main-search sidecar automatically as part of disk index build: + +```bash +./build/apps/build_disk_index \ + ... \ + --dist_fn mips \ + --build_rabitq_main_codes \ + --rabitq_nb_bits 4 +``` + +This produces: + +```text +_disk.index_rabitq_main.bin +``` + +#### Constraints + +- Currently supported only for `dist_fn=mips` / `Metric::INNER_PRODUCT`. +- The RaBitQ code `dim` must match the index `_data_dim` (post any preprocessing/augmentation), otherwise main-search RaBitQ is disabled and the search falls back to PQ. +- Ensure you run the updated `search_disk_index`/`build_disk_index` binaries from the same build directory that contains this feature. + ## Windows build: The Windows version has been tested with Enterprise editions of Visual Studio 2022, 2019 and 2017. It should work with the Community and Professional editions as well without any changes. diff --git a/apps/CMakeLists.txt b/apps/CMakeLists.txt index e42c0b6cb..a848c4f10 100644 --- a/apps/CMakeLists.txt +++ b/apps/CMakeLists.txt @@ -22,6 +22,7 @@ target_link_libraries(search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${D add_executable(range_search_disk_index range_search_disk_index.cpp) target_link_libraries(range_search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) + add_executable(test_streaming_scenario test_streaming_scenario.cpp) target_link_libraries(test_streaming_scenario ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) diff --git a/apps/build_disk_index.cpp b/apps/build_disk_index.cpp index f48b61726..0c58cf92c 100644 --- a/apps/build_disk_index.cpp +++ b/apps/build_disk_index.cpp @@ -10,9 +10,63 @@ #include "index.h" #include "partition.h" #include "program_options_utils.hpp" +#include "bfloat16.h" namespace po = boost::program_options; +static int convert_bf16_bin_to_f32_bin(const std::string &bf16_path, const std::string &f32_path) +{ + std::ifstream reader(bf16_path, std::ios::binary); + if (!reader) + { + diskann::cerr << "Error: could not open input file " << bf16_path << std::endl; + return -1; + } + std::ofstream writer(f32_path, std::ios::binary); + if (!writer) + { + diskann::cerr << "Error: could not open output file " << f32_path << std::endl; + return -1; + } + + uint32_t npts = 0, dim = 0; + reader.read(reinterpret_cast(&npts), sizeof(uint32_t)); + reader.read(reinterpret_cast(&dim), sizeof(uint32_t)); + if (!reader) + { + diskann::cerr << "Error: failed to read header from " << bf16_path << std::endl; + return -1; + } + writer.write(reinterpret_cast(&npts), sizeof(uint32_t)); + writer.write(reinterpret_cast(&dim), sizeof(uint32_t)); + + constexpr size_t kBlockElems = 1u << 20; // 1M elements (~2MB bf16, ~4MB float) + std::vector in_buf; + std::vector out_buf; + in_buf.resize(kBlockElems); + out_buf.resize(kBlockElems); + + const uint64_t total_elems = static_cast(npts) * static_cast(dim); + uint64_t done = 0; + while (done < total_elems) + { + const size_t this_block = static_cast(std::min(kBlockElems, total_elems - done)); + reader.read(reinterpret_cast(in_buf.data()), this_block * sizeof(diskann::bfloat16)); + if (!reader) + { + diskann::cerr << "Error: failed reading bf16 payload from " << bf16_path << std::endl; + return -1; + } + for (size_t i = 0; i < this_block; i++) + { + out_buf[i] = static_cast(in_buf[i]); + } + writer.write(reinterpret_cast(out_buf.data()), this_block * sizeof(float)); + done += this_block; + } + return 0; +} + int main(int argc, char **argv) { std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label, @@ -21,6 +75,8 @@ int main(int argc, char **argv) float B, M; bool append_reorder_data = false; bool use_opq = false; + bool build_rabitq_main_codes = false; + uint32_t rabitq_nb_bits = 4; po::options_description desc{ program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")}; @@ -63,6 +119,14 @@ int main(int argc, char **argv) optional_configs.add_options()("append_reorder_data", po::bool_switch()->default_value(false), "Include full precision data in the index. Use only in " "conjuction with compressed data on SSD."); + + optional_configs.add_options()( + "build_rabitq_main_codes", po::bool_switch()->default_value(false), + "Generate RaBitQ main-search codes sidecar file (_disk.index_rabitq_main.bin). " + "Only meaningful for dist_fn=mips."); + + optional_configs.add_options()("rabitq_nb_bits", po::value(&rabitq_nb_bits)->default_value(4), + "Bits per dimension for RaBitQ codes (1..9)"); optional_configs.add_options()("build_PQ_bytes", po::value(&build_PQ)->default_value(0), program_options_utils::BUIlD_GRAPH_PQ_BYTES); optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false), @@ -94,6 +158,8 @@ int main(int argc, char **argv) append_reorder_data = true; if (vm["use_opq"].as()) use_opq = true; + if (vm["build_rabitq_main_codes"].as()) + build_rabitq_main_codes = true; } catch (const std::exception &ex) { @@ -124,10 +190,11 @@ int main(int argc, char **argv) << std::endl; return -1; } - if (data_type != std::string("float")) + if (data_type != std::string("float") && data_type != std::string("bf16") && + data_type != std::string("bfloat16")) { std::cout << "Error: Appending data for reordering currently only " - "supported for float data type." + "supported for float/bf16 data type." << std::endl; return -1; } @@ -137,7 +204,9 @@ int main(int argc, char **argv) std::string(std::to_string(B)) + " " + std::string(std::to_string(M)) + " " + std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " + std::string(std::to_string(append_reorder_data)) + " " + - std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD)); + std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD)) + " " + + std::string(std::to_string(build_rabitq_main_codes)) + " " + + std::string(std::to_string(rabitq_nb_bits)); try { @@ -155,6 +224,12 @@ int main(int argc, char **argv) return diskann::build_disk_index( data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, universal_label, filter_threshold, Lf); + else if (data_type == std::string("bf16") || data_type == std::string("bfloat16")) + { + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, + use_filters, label_file, universal_label, filter_threshold, Lf); + } else { diskann::cerr << "Error. Unsupported data type" << std::endl; @@ -175,6 +250,13 @@ int main(int argc, char **argv) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, universal_label, filter_threshold, Lf); + else if (data_type == std::string("bf16") || data_type == std::string("bfloat16")) + { + return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), + params.c_str(), metric, use_opq, codebook_prefix, + use_filters, label_file, universal_label, + filter_threshold, Lf); + } else { diskann::cerr << "Error. Unsupported data type" << std::endl; diff --git a/apps/build_rabitq_reorder_codes.cpp b/apps/build_rabitq_reorder_codes.cpp new file mode 100644 index 000000000..75d888957 --- /dev/null +++ b/apps/build_rabitq_reorder_codes.cpp @@ -0,0 +1,25 @@ +#include + +#include +#include +#include +#include +#include +#include + +#include "rabitq.h" +#include "utils.h" + +namespace po = boost::program_options; + +namespace +{ +#pragma pack(push, 1) +struct RaBitQReorderHeader +{ + char magic[8]; + uint32_t version; + uint32_t metric; + uint32_t nb_bits; + #error "build_rabitq_reorder_codes has been removed (RaBitQ reorder prefilter deprecated). Use build_disk_index with --build_rabitq_main_codes instead." + uint64_t num_points; diff --git a/apps/search_disk_index.cpp b/apps/search_disk_index.cpp index 6b0793db7..25a25610f 100644 --- a/apps/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -13,6 +13,7 @@ #include "timer.h" #include "percentile_stats.h" #include "program_options_utils.hpp" +#include "bfloat16.h" #ifndef _WINDOWS #include @@ -31,6 +32,57 @@ namespace po = boost::program_options; +static int convert_bf16_bin_to_f32_bin(const std::string &bf16_path, const std::string &f32_path) +{ + std::ifstream reader(bf16_path, std::ios::binary); + if (!reader) + { + diskann::cerr << "Error: could not open input file " << bf16_path << std::endl; + return -1; + } + std::ofstream writer(f32_path, std::ios::binary); + if (!writer) + { + diskann::cerr << "Error: could not open output file " << f32_path << std::endl; + return -1; + } + + uint32_t npts = 0, dim = 0; + reader.read(reinterpret_cast(&npts), sizeof(uint32_t)); + reader.read(reinterpret_cast(&dim), sizeof(uint32_t)); + if (!reader) + { + diskann::cerr << "Error: failed to read header from " << bf16_path << std::endl; + return -1; + } + writer.write(reinterpret_cast(&npts), sizeof(uint32_t)); + writer.write(reinterpret_cast(&dim), sizeof(uint32_t)); + + constexpr size_t kBlockElems = 1u << 20; + std::vector in_buf(kBlockElems); + std::vector out_buf(kBlockElems); + + const uint64_t total_elems = static_cast(npts) * static_cast(dim); + uint64_t done = 0; + while (done < total_elems) + { + const size_t this_block = static_cast(std::min(kBlockElems, total_elems - done)); + reader.read(reinterpret_cast(in_buf.data()), this_block * sizeof(diskann::bfloat16)); + if (!reader) + { + diskann::cerr << "Error: failed reading bf16 payload from " << bf16_path << std::endl; + return -1; + } + for (size_t i = 0; i < this_block; i++) + { + out_buf[i] = static_cast(in_buf[i]); + } + writer.write(reinterpret_cast(out_buf.data()), this_block * sizeof(float)); + done += this_block; + } + return 0; +} + void print_stats(std::string category, std::vector percentiles, std::vector results) { diskann::cout << std::setw(20) << category << ": " << std::flush; @@ -155,7 +207,15 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre { for (uint32_t d = 0; d < warmup_dim; d++) { - warmup[i * warmup_aligned_dim + d] = (T)dis(gen); + const auto sample = dis(gen); + if constexpr (std::is_same_v) + { + warmup[i * warmup_aligned_dim + d] = T(static_cast(sample)); + } + else + { + warmup[i * warmup_aligned_dim + d] = static_cast(sample); + } } } } @@ -396,6 +456,8 @@ int main(int argc, char **argv) return -1; } + const bool is_bf16 = (data_type == std::string("bf16") || data_type == std::string("bfloat16")); + diskann::Metric metric; if (dist_fn == std::string("mips")) { @@ -417,16 +479,16 @@ int main(int argc, char **argv) return -1; } - if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT)) + if ((data_type != std::string("float")) && !is_bf16 && (metric == diskann::Metric::INNER_PRODUCT)) { std::cout << "Currently support only floating point data for Inner Product." << std::endl; return -1; } - if (use_reorder_data && data_type != std::string("float")) + if (use_reorder_data && data_type != std::string("float") && !is_bf16) { std::cout << "Error: Reorder data for reordering currently only " - "supported for float data type." + "supported for float/bf16 data type." << std::endl; return -1; } @@ -452,7 +514,12 @@ int main(int argc, char **argv) if (!query_filters.empty() && label_type == "ushort") { if (data_type == std::string("float")) - return search_disk_index( + return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, + gt_file, num_threads, K, W, num_nodes_to_cache, + search_io_limit, Lvec, fail_if_recall_below, query_filters, + use_reorder_data); + else if (is_bf16) + return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); else if (data_type == std::string("int8")) @@ -465,7 +532,7 @@ int main(int argc, char **argv) num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); else { - std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; + std::cerr << "Unsupported data type. Use float, bf16, int8 or uint8" << std::endl; return -1; } } @@ -475,6 +542,11 @@ int main(int argc, char **argv) return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); + else if (is_bf16) + return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, + gt_file, num_threads, K, W, num_nodes_to_cache, + search_io_limit, Lvec, fail_if_recall_below, query_filters, + use_reorder_data); else if (data_type == std::string("int8")) return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, @@ -485,7 +557,7 @@ int main(int argc, char **argv) fail_if_recall_below, query_filters, use_reorder_data); else { - std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; + std::cerr << "Unsupported data type. Use float, bf16, int8 or uint8" << std::endl; return -1; } } diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 1a9acc285..c3fabda68 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -22,6 +22,7 @@ #include "utils.h" #include "program_options_utils.hpp" #include "index_factory.h" +#include "bfloat16.h" namespace po = boost::program_options; @@ -435,9 +436,15 @@ int main(int argc, char **argv) num_threads, K, print_all_recalls, Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); } + else if (data_type == std::string("bf16") || data_type == std::string("bfloat16")) + { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); + } else { - std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; + std::cout << "Unsupported type. Use float/bf16/int8/uint8" << std::endl; return -1; } } @@ -461,9 +468,16 @@ int main(int argc, char **argv) num_threads, K, print_all_recalls, Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); } + else if (data_type == std::string("bf16") || data_type == std::string("bfloat16")) + { + return search_memory_index(metric, index_path_prefix, result_path, query_file, + gt_file, num_threads, K, print_all_recalls, Lvec, + dynamic, tags, show_qps_per_thread, query_filters, + fail_if_recall_below); + } else { - std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; + std::cout << "Unsupported type. Use float/bf16/int8/uint8" << std::endl; return -1; } } diff --git a/apps/utils/compute_groundtruth.cpp b/apps/utils/compute_groundtruth.cpp index da32fd7c6..1cffa9346 100644 --- a/apps/utils/compute_groundtruth.cpp +++ b/apps/utils/compute_groundtruth.cpp @@ -498,7 +498,8 @@ int main(int argc, char **argv) desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); + desc.add_options()("data_type", po::value(&data_type)->required(), + "data type "); desc.add_options()("dist_fn", po::value(&dist_fn)->required(), "distance function "); desc.add_options()("base_file", po::value(&base_file)->required(), @@ -531,9 +532,10 @@ int main(int argc, char **argv) return -1; } - if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8")) + if (data_type != std::string("float") && data_type != std::string("bf16") && data_type != std::string("int8") && + data_type != std::string("uint8")) { - std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; + std::cout << "Unsupported type. float, bf16, int8 and uint8 types are supported." << std::endl; return -1; } @@ -560,6 +562,8 @@ int main(int argc, char **argv) { if (data_type == std::string("float")) aux_main(base_file, query_file, gt_file, K, metric, tags_file); + if (data_type == std::string("bf16")) + aux_main(base_file, query_file, gt_file, K, metric, tags_file); if (data_type == std::string("int8")) aux_main(base_file, query_file, gt_file, K, metric, tags_file); if (data_type == std::string("uint8")) diff --git a/apps/utils/rand_data_gen.cpp b/apps/utils/rand_data_gen.cpp index e89ede800..b5972dcef 100644 --- a/apps/utils/rand_data_gen.cpp +++ b/apps/utils/rand_data_gen.cpp @@ -44,6 +44,43 @@ int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, bool nor return 0; } +int block_write_bf16(std::ofstream &writer, size_t ndims, size_t npts, bool normalization, float norm, float rand_scale) +{ + auto vec = new float[ndims]; + auto vec_bf16 = new diskann::bfloat16[ndims]; + + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::normal_distribution<> normal_rand{0, 1}; + std::uniform_real_distribution<> unif_dis(1.0, rand_scale); + + for (size_t i = 0; i < npts; i++) + { + float sum = 0; + float scale = 1.0f; + if (rand_scale > 1.0f) + scale = (float)unif_dis(gen); + for (size_t d = 0; d < ndims; ++d) + vec[d] = scale * (float)normal_rand(gen); + if (normalization) + { + for (size_t d = 0; d < ndims; ++d) + sum += vec[d] * vec[d]; + for (size_t d = 0; d < ndims; ++d) + vec[d] = vec[d] * norm / std::sqrt(sum); + } + + for (size_t d = 0; d < ndims; ++d) + vec_bf16[d] = diskann::bfloat16::from_float(vec[d]); + + writer.write((char *)vec_bf16, ndims * sizeof(diskann::bfloat16)); + } + + delete[] vec; + delete[] vec_bf16; + return 0; +} + int block_write_int8(std::ofstream &writer, size_t ndims, size_t npts, float norm) { auto vec = new float[ndims]; @@ -120,7 +157,8 @@ int main(int argc, char **argv) desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); + desc.add_options()("data_type", po::value(&data_type)->required(), + "data type "); desc.add_options()("output_file", po::value(&output_file)->required(), "File name for saving the random vectors"); desc.add_options()("ndims,D", po::value(&ndims)->required(), "Dimensoinality of the vector"); @@ -145,9 +183,10 @@ int main(int argc, char **argv) return -1; } - if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8")) + if (data_type != std::string("float") && data_type != std::string("bf16") && data_type != std::string("int8") && + data_type != std::string("uint8")) { - std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; + std::cout << "Unsupported type. float, bf16, int8 and uint8 types are supported." << std::endl; return -1; } @@ -185,6 +224,12 @@ int main(int argc, char **argv) } } + if (data_type == std::string("bf16")) + { + // bf16 follows floating-point generation rules. + // (Scaling is allowed only when not normalized, same as float.) + } + try { std::ofstream writer; @@ -207,6 +252,10 @@ int main(int argc, char **argv) { ret = block_write_float(writer, ndims, cblk_size, normalization, norm, rand_scaling); } + else if (data_type == std::string("bf16")) + { + ret = block_write_bf16(writer, ndims, cblk_size, normalization, norm, rand_scaling); + } else if (data_type == std::string("int8")) { ret = block_write_int8(writer, ndims, cblk_size, norm); diff --git a/include/abstract_data_store.h b/include/abstract_data_store.h index 89856f1fa..beaebb5dc 100644 --- a/include/abstract_data_store.h +++ b/include/abstract_data_store.h @@ -109,6 +109,12 @@ template class AbstractDataStore // how to align the query vector in a consistent manner virtual size_t get_alignment_factor() const = 0; + // Optional: return a direct pointer to the underlying aligned, row-major data buffer. + // This is used for performance-sensitive paths (e.g., graph-build pruning) where + // callers can batch computations without repeatedly calling get_vector(). + // Implementations that do not keep full-precision vectors in memory should return nullptr. + virtual const data_t *get_raw_data() const { return nullptr; } + protected: // Expand the datastore to new_num_points. Returns the new capacity created, // which should be == new_num_points in the normal case. Implementers can also diff --git a/include/bf16_amx_kernels.h b/include/bf16_amx_kernels.h new file mode 100644 index 000000000..067d7a12f --- /dev/null +++ b/include/bf16_amx_kernels.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +#include "bfloat16.h" + +namespace diskann +{ +// Returns true if this build produced AMX BF16 kernels (i.e. the compiler supported +// AMX BF16 intrinsics for the relevant translation unit). +bool amxbf16_kernels_compiled(); + +// Returns true if the current CPU + OS context is capable of executing AMX BF16 instructions. +// This performs feature detection (CPUID + XCR0) and checks / requests Linux permissions when needed. +bool amxbf16_runtime_available(); + +// Dot product of bf16 vectors with f32 accumulation using AMX BF16. +// If AMX BF16 is not available at runtime, this falls back to a scalar implementation. +float bf16_dot_f32_accum_amx(const bfloat16 *a, const bfloat16 *b, uint32_t length); + +// Batch dot products: computes out[i] = dot(base[i], query) for i in [0, n_vecs). +// base is a row-major matrix of shape [n_vecs x dim]. +// If AMX BF16 is not available at runtime, this falls back to a scalar implementation. +void bf16_dot_f32_accum_amx_batch(const bfloat16 *base, const bfloat16 *query, uint32_t n_vecs, uint32_t dim, + float *out); + +// Matrix of dot products: out[i * n_queries + j] = dot(base[i], queries[j]). +// base is row-major [n_base x dim], queries is row-major [n_queries x dim]. +// If AMX BF16 is not available at runtime, this falls back to a scalar implementation. +void bf16_dot_f32_accum_amx_matmul(const bfloat16 *base, const bfloat16 *queries, uint32_t n_base, uint32_t n_queries, + uint32_t dim, float *out); + +// Matrix of dot products over gathered rows: +// out[i * n_queries + j] = dot(data[base_ids[i]], data[query_ids[j]]) +// where each vector is length `dim` and consecutive vectors are `data_stride` elements apart. +// If AMX BF16 is not available at runtime, this falls back to a scalar implementation. +void bf16_dot_f32_accum_amx_matmul_gather(const bfloat16 *data, uint32_t data_stride, const uint32_t *base_ids, + uint32_t n_base, const uint32_t *query_ids, uint32_t n_queries, uint32_t dim, + float *out); + +} // namespace diskann diff --git a/include/bf16_simd_kernels.h b/include/bf16_simd_kernels.h new file mode 100644 index 000000000..9c1f2cf8f --- /dev/null +++ b/include/bf16_simd_kernels.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +#include "bfloat16.h" + +namespace diskann +{ +// Returns true if this build produced AVX-512 BF16 kernels (i.e. the compiler supported +// AVX-512 BF16 intrinsics for the relevant translation unit). +bool avx512bf16_kernels_compiled(); + +// Dot product of bf16 vectors with f32 accumulation. +// If AVX-512 BF16 kernels are not compiled in, this falls back to a scalar implementation. +float bf16_dot_f32_accum(const bfloat16 *a, const bfloat16 *b, uint32_t length); + +} // namespace diskann diff --git a/include/bfloat16.h b/include/bfloat16.h new file mode 100644 index 000000000..22b398a7d --- /dev/null +++ b/include/bfloat16.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include + +namespace diskann +{ +// Minimal IEEE-754 bfloat16 (bf16) implementation. +// Stores the top 16 bits of a float32, with round-to-nearest-even on conversion. +struct bfloat16 +{ + uint16_t value = 0; + + constexpr bfloat16() = default; + explicit constexpr bfloat16(uint16_t v) : value(v) + { + } + + // Convenience constructor for generic code that does bfloat16(f). + explicit bfloat16(float f) : value(from_float(f).value) {} + + // Convenience constructor for generic code that does static_cast(double_expr). + explicit bfloat16(double f) : value(from_float(static_cast(f)).value) {} + + static inline bfloat16 from_float(float f) + { + uint32_t bits = 0; + std::memcpy(&bits, &f, sizeof(bits)); + + // Round-to-nearest-even: add 0x7FFF + LSB of the truncated part. + const uint32_t lsb = (bits >> 16) & 1u; + bits += 0x7FFFu + lsb; + return bfloat16(static_cast(bits >> 16)); + } + + inline float to_float() const + { + uint32_t bits = static_cast(value) << 16; + float f = 0.0f; + std::memcpy(&f, &bits, sizeof(f)); + return f; + } + + inline operator float() const + { + return to_float(); + } +}; + +// bfloat16 is not a built-in floating point type, but for most DiskANN code +// paths it should be treated as "floating-point-like". +template struct is_floating_point_like : std::is_floating_point +{ +}; + +template <> struct is_floating_point_like : std::true_type +{ +}; + +template inline constexpr bool is_floating_point_like_v = is_floating_point_like::value; + +} // namespace diskann diff --git a/include/disk_utils.h b/include/disk_utils.h index 08f046dcd..7ef0eb888 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -105,4 +105,11 @@ DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std const std::string output_file, const std::string reorder_data_file = std::string("")); +// Store the graph/base payload as T, and (optionally) append reorder vectors stored as ReorderT. +// This is primarily used to support bf16 reorder vectors on SSD. +template +DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file, + const std::string output_file, + const std::string reorder_data_file = std::string("")); + } // namespace diskann diff --git a/include/distance.h b/include/distance.h index f3b1de25a..e0d18314c 100644 --- a/include/distance.h +++ b/include/distance.h @@ -1,5 +1,6 @@ #pragma once #include "windows_customizations.h" +#include "bfloat16.h" #include namespace diskann @@ -232,4 +233,31 @@ class AVXNormalizedCosineDistanceFloat : public Distance template Distance *get_distance_function(Metric m); +class DistanceL2BFloat16 : public Distance +{ + public: + DistanceL2BFloat16() : Distance(diskann::Metric::L2) + { + } + DISKANN_DLLEXPORT virtual float compare(const bfloat16 *a, const bfloat16 *b, uint32_t size) const; +}; + +class DistanceCosineBFloat16 : public Distance +{ + public: + DistanceCosineBFloat16() : Distance(diskann::Metric::COSINE) + { + } + DISKANN_DLLEXPORT virtual float compare(const bfloat16 *a, const bfloat16 *b, uint32_t length) const; +}; + +class DistanceInnerProductBFloat16 : public Distance +{ + public: + DistanceInnerProductBFloat16() : Distance(diskann::Metric::INNER_PRODUCT) + { + } + DISKANN_DLLEXPORT virtual float compare(const bfloat16 *a, const bfloat16 *b, uint32_t length) const; +}; + } // namespace diskann diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h index 0a0a617da..99ba2a2bc 100644 --- a/include/in_mem_data_store.h +++ b/include/in_mem_data_store.h @@ -62,6 +62,8 @@ template class InMemDataStore : public AbstractDataStore class PQFlashIndex uint64_t _ndims_reorder_vecs = 0; uint64_t _reorder_data_start_sector = 0; uint64_t _nvecs_per_sector = 0; + uint64_t _reorder_bytes_per_element = sizeof(float); + + // Optional: RaBitQ multi-bit codes for *main-search* approximate scoring. + // When enabled (runtime-gated), neighbor expansion scoring uses RaBitQ + // instead of PQ distance lookup, while preserving the PQ path as default. + bool _rabitq_main_codes_exist = false; + uint8_t *_rabitq_main_codes = nullptr; + uint64_t _rabitq_main_code_size = 0; + uint64_t _rabitq_main_dim = 0; + uint32_t _rabitq_main_nb_bits = 0; + uint32_t _rabitq_main_metric = 0; diskann::Metric metric = diskann::Metric::L2; diff --git a/include/program_options_utils.hpp b/include/program_options_utils.hpp index 2be60595b..deeeb478d 100644 --- a/include/program_options_utils.hpp +++ b/include/program_options_utils.hpp @@ -18,7 +18,8 @@ const std::string make_program_description(const char *executable_name, const ch } // Required parameters -const char *DATA_TYPE_DESCRIPTION = "data type, one of {int8, uint8, float} - float is single precision (32 bit)"; +const char *DATA_TYPE_DESCRIPTION = + "data type, one of {int8, uint8, float, bf16} - float is single precision (32 bit), bf16 is bfloat16 (16 bit)"; const char *DISTANCE_FUNCTION_DESCRIPTION = "distance function {l2, mips, fast_l2, cosine}. 'fast l2' and 'mips' only support data_type float"; const char *INDEX_PATH_PREFIX_DESCRIPTION = "Path prefix to the index, e.g. '/mnt/data/my_ann_index'"; diff --git a/include/rabitq.h b/include/rabitq.h new file mode 100644 index 000000000..7e276afed --- /dev/null +++ b/include/rabitq.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include +#include + +namespace diskann +{ +namespace rabitq +{ + +enum class Metric : uint32_t +{ + L2 = 0, + INNER_PRODUCT = 1, +}; + +#pragma pack(push, 1) +struct SignBitFactors +{ + float or_minus_c_l2sqr = 0; + float dp_multiplier = 0; +}; + +struct SignBitFactorsWithError : SignBitFactors +{ + float f_error = 0; +}; + +struct ExtraBitsFactors +{ + float f_add_ex = 0; + float f_rescale_ex = 0; +}; +#pragma pack(pop) + +static_assert(sizeof(SignBitFactors) == 8, "Unexpected padding in SignBitFactors"); +static_assert(sizeof(SignBitFactorsWithError) == 12, "Unexpected padding in SignBitFactorsWithError"); +static_assert(sizeof(ExtraBitsFactors) == 8, "Unexpected padding in ExtraBitsFactors"); + +size_t compute_code_size(size_t d, size_t nb_bits); + +// Encodes a single vector (assumes any rotation/centering is already applied externally). +// Layout matches Faiss standard RaBitQ format: +// - sign bits: (d+7)/8 bytes +// - base factors: SignBitFactors (nb_bits==1) or SignBitFactorsWithError (nb_bits>1) +// - ex_code: (d*(nb_bits-1)+7)/8 bytes (only when nb_bits>1) +// - ex_factors: ExtraBitsFactors (only when nb_bits>1) +void encode_vector(const float* x, size_t d, Metric metric, size_t nb_bits, uint8_t* out_code); + +// Approximate IP scorer. +// Returns an approximate inner product (higher is better) computed from a multi-bit RaBitQ code. +// For nb_bits==1, this returns a 1-bit estimate (still usable for prefilter). +float approx_inner_product_from_code(const uint8_t* code, const float* query, size_t d, size_t nb_bits); + +} // namespace rabitq +} // namespace diskann diff --git a/include/utils.h b/include/utils.h index c04a16515..3b91d204a 100644 --- a/include/utils.h +++ b/include/utils.h @@ -27,6 +27,7 @@ typedef int FileHandle; #include "windows_customizations.h" #include "tsl/robin_set.h" #include "types.h" +#include "bfloat16.h" #include "tag_uint128.h" #include @@ -868,10 +869,16 @@ template float prepare_base_for_inner_products(const std::string in size_t BLOCK_SIZE = 100000; size_t block_size = npts <= BLOCK_SIZE ? npts : BLOCK_SIZE; + // IMPORTANT: output must preserve the same element type as the input. + // Disk index build/search reads the preprocessed file using the same T, + // so writing float32 here would corrupt bf16/int8 paths and can cause NaNs + // downstream (e.g., kmeans++ hanging in pivot selection). + using OutT = T; + std::unique_ptr in_block_data = std::make_unique(block_size * in_dims); - std::unique_ptr out_block_data = std::make_unique(block_size * out_dims); + std::unique_ptr out_block_data = std::make_unique(block_size * out_dims); - std::memset(out_block_data.get(), 0, sizeof(float) * block_size * out_dims); + std::memset(out_block_data.get(), 0, sizeof(OutT) * block_size * out_dims); uint64_t num_blocks = DIV_ROUND_UP(npts, block_size); std::vector norms(npts, 0); @@ -886,7 +893,8 @@ template float prepare_base_for_inner_products(const std::string in { for (uint64_t j = 0; j < in_dims; j++) { - norms[start_id + p] += in_block_data[p * in_dims + j] * in_block_data[p * in_dims + j]; + const float v = (float)in_block_data[p * in_dims + j]; + norms[start_id + p] += v * v; } max_norm = max_norm > norms[start_id + p] ? max_norm : norms[start_id + p]; } @@ -905,18 +913,84 @@ template float prepare_base_for_inner_products(const std::string in { for (uint64_t j = 0; j < in_dims; j++) { - out_block_data[p * out_dims + j] = in_block_data[p * in_dims + j] / max_norm; + const float v = (float)in_block_data[p * in_dims + j]; + out_block_data[p * out_dims + j] = (OutT)(v / max_norm); } float res = 1 - (norms[start_id + p] / (max_norm * max_norm)); res = res <= 0 ? 0 : std::sqrt(res); - out_block_data[p * out_dims + out_dims - 1] = res; + out_block_data[p * out_dims + out_dims - 1] = (OutT)res; } - out_writer.write((char *)out_block_data.get(), block_pts * out_dims * sizeof(float)); + out_writer.write((char *)out_block_data.get(), block_pts * out_dims * sizeof(OutT)); } out_writer.close(); return max_norm; } +// Normalize vectors (for cosine) while preserving element type. +// - For float: writes float payload +// - For bfloat16: writes bfloat16 payload +template void normalize_data_file_typed(const std::string &inFileName, const std::string &outFileName) +{ + std::ifstream readr(inFileName, std::ios::binary); + std::ofstream writr(outFileName, std::ios::binary); + + uint32_t npts_u32 = 0, ndims_u32 = 0; + readr.read((char *)&npts_u32, sizeof(uint32_t)); + readr.read((char *)&ndims_u32, sizeof(uint32_t)); + if (!readr) + { + throw diskann::ANNException("Failed to read header from " + inFileName, -1, __FUNCSIG__, __FILE__, __LINE__); + } + + writr.write((char *)&npts_u32, sizeof(uint32_t)); + writr.write((char *)&ndims_u32, sizeof(uint32_t)); + + const uint64_t npts = npts_u32; + const uint64_t ndims = ndims_u32; + + const uint64_t BLOCK_SIZE = 131072; + const uint64_t block_size = npts <= BLOCK_SIZE ? npts : BLOCK_SIZE; + const uint64_t num_blocks = DIV_ROUND_UP(npts, block_size); + + std::unique_ptr buf = std::make_unique(block_size * ndims); + + for (uint64_t b = 0; b < num_blocks; b++) + { + const uint64_t start_id = b * block_size; + const uint64_t end_id = ((b + 1) * block_size < npts) ? ((b + 1) * block_size) : npts; + const uint64_t block_pts = end_id - start_id; + + readr.read((char *)buf.get(), block_pts * ndims * sizeof(T)); + if (!readr) + { + throw diskann::ANNException("Failed to read payload from " + inFileName, -1, __FUNCSIG__, __FILE__, __LINE__); + } + +#pragma omp parallel for schedule(static, 1024) + for (int64_t i = 0; i < (int64_t)block_pts; i++) + { + float norm_pt = std::numeric_limits::epsilon(); + const uint64_t base = (uint64_t)i * ndims; + for (uint64_t d = 0; d < ndims; d++) + { + const float v = (float)buf[base + d]; + norm_pt += v * v; + } + norm_pt = std::sqrt(norm_pt); + for (uint64_t d = 0; d < ndims; d++) + { + buf[base + d] = (T)((float)buf[base + d] / norm_pt); + } + } + + writr.write((char *)buf.get(), block_pts * ndims * sizeof(T)); + if (!writr) + { + throw diskann::ANNException("Failed to write payload to " + outFileName, -1, __FUNCSIG__, __FILE__, __LINE__); + } + } +} + // plain saves data as npts X ndims array into filename template void save_Tvecs(const char *filename, T *data, size_t npts, size_t ndims) { @@ -1161,6 +1235,10 @@ template <> inline const char *diskann_type_to_name() { return "int8"; } +template <> inline const char *diskann_type_to_name() +{ + return "bf16"; +} template <> inline const char *diskann_type_to_name() { return "uint16"; @@ -1190,9 +1268,6 @@ template <> inline const char *diskann_type_to_name() #include #include -extern bool AvxSupportedCPU; -extern bool Avx2SupportedCPU; - inline size_t getMemoryUsage() { PROCESS_MEMORY_COUNTERS_EX pmc; @@ -1250,3 +1325,4 @@ inline size_t getMemoryUsage() extern bool AvxSupportedCPU; extern bool Avx2SupportedCPU; +extern bool Avx512Bf16SupportedCPU; diff --git a/scripts/consistency/.gitignore b/scripts/consistency/.gitignore new file mode 100644 index 000000000..8e68db421 --- /dev/null +++ b/scripts/consistency/.gitignore @@ -0,0 +1,14 @@ +# Consistency harness outputs (default OUT_DIR) +_out/ + +# Common temp outputs if OUT_DIR is set inside this folder +out/ +output/ +outputs/ + +# Python cache +__pycache__/ +*.py[cod] + +# Editor/OS junk +.DS_Store diff --git a/scripts/consistency/README.md b/scripts/consistency/README.md new file mode 100644 index 000000000..d7138b188 --- /dev/null +++ b/scripts/consistency/README.md @@ -0,0 +1,127 @@ +# Consistency / Precision Checks + +This folder contains a lightweight consistency harness to quantify **bf16 precision loss** relative to float32 for DiskANN. + +It runs the same dataset through: + +- float32 **memory** index vs bf16 **memory** index +- float32 **disk** index vs bf16 **disk** index +- float32 **disk-PQ + reorder** vs bf16 **disk-PQ + reorder** (true bf16 reorder vectors on SSD) + +and reports: + +- Recall@K against float32 ground truth +- Top-1 ID match rate (bf16 vs float) +- Distance absolute error for IDs common to both outputs + +## Quick start + +```bash +# From repo root +bash scripts/consistency/run_consistency.sh +``` + +Artifacts are written to `scripts/consistency/_out` by default. + +## What it runs + +The harness generates a small float32 dataset, converts it to bf16 (round-to-nearest-even), then runs: + +- **Memory** + - float: `build_memory_index` + `search_memory_index` + - bf16: `build_memory_index` + `search_memory_index` +- **Disk (full-precision)** + - float: `build_disk_index` with `--PQ_disk_bytes 0` + `search_disk_index` + - bf16: `build_disk_index` with `--PQ_disk_bytes 0` + `search_disk_index` +- **Disk (PQ + reorder)** + - float: `build_disk_index --PQ_disk_bytes $DISK_PQ_BYTES --append_reorder_data` + `search_disk_index --use_reorder_data` + - bf16: `build_disk_index --PQ_disk_bytes $DISK_PQ_BYTES --append_reorder_data` + `search_disk_index --use_reorder_data` + +Ground truth is computed from the **float32** dataset using exact L2. + +## How to run + +### Default (recommended) + +```bash +bash scripts/consistency/run_consistency.sh +``` + +### Custom size / speed knobs + +Smaller + faster: + +```bash +NPTS=2000 NQ=100 DIM=32 THREADS=4 bash scripts/consistency/run_consistency.sh +``` + +Bigger + more stable stats: + +```bash +OUT_DIR=/tmp/diskann_consistency NPTS=20000 NQ=1000 DIM=128 K=10 MEM_L=100 DISK_L=100 bash scripts/consistency/run_consistency.sh +``` + +### Change disk-PQ compression + +```bash +DISK_PQ_BYTES=16 bash scripts/consistency/run_consistency.sh +``` + +## Where outputs go + +Inside `$OUT_DIR` (default: `scripts/consistency/_out`): + +- `data/` + - `base_f32.bin`, `query_f32.bin` + - `base_bf16.bin`, `query_bf16.bin` + - `gt_l2` (or `gt_l2.bin` depending on build) — the truthset file +- `results/` + - Index prefixes (used by the apps): + - `index_mem_f32*`, `index_mem_bf16*` + - `index_disk_f32_full*`, `index_disk_bf16_full*` + - `index_disk_f32_pq*`, `index_disk_bf16_pq*` + - Search outputs (these are what the analyzer reads): + - Memory: `mem_f32_${MEM_L}_idx_uint32.bin`, `mem_f32_${MEM_L}_dists_float.bin` and bf16 equivalents + - Disk: `disk_*_${DISK_L}_idx_uint32.bin`, `disk_*_${DISK_L}_dists_float.bin` + +## How to read the results + +At the end, the runner calls `scripts/consistency/analyze_results.py` and prints three blocks: + +- `== Memory ==` +- `== Disk (full-precision) ==` +- `== Disk (PQ + reorder) ==` + +Each block reports: + +- **Recall@K (float vs bf16)** + - Computed against the float32 ground truth (exact L2). + - The `delta` shows how much recall changes when switching to bf16. + +- **Top1 ID match rate (bf16 vs float)** + - For each query, whether the #1 result ID matches between bf16 and float outputs. + - Useful when Recall@K is similar but top-1 stability differs. + +- **Distance abs error on common IDs (mean / p99 / max)** + - Per query, only considers IDs that appear in *both* top-K lists (to avoid comparing unrelated candidates). + - Measures absolute distance differences between bf16 and float results. + - For normalized vectors, these values are usually small; if they get large, it may indicate bf16 quantization loss + (or a bug in bf16 IO / distance kernels). + +## Tuning (env vars) + +- `OUT_DIR` (default: `scripts/consistency/_out`) +- `NPTS` (default: 5000), `NQ` (default: 200), `DIM` (default: 64) +- `K` (default: 10) +- `MEM_L` (default: 50), `DISK_L` (default: 50) +- `THREADS` (default: 4) +- Disk build/search knobs: + - `DISK_R` (default: 16) + - `DISK_LBUILD` (default: 50) + - `DISK_PQ_BYTES` (default: 8) + +## Notes / assumptions + +- This harness currently uses **L2** to isolate precision effects; cosine/mips can be added similarly if needed. +- The bf16 dataset is generated by converting the float32 dataset, so any differences should come from bf16 rounding + and bf16 compute paths (not from a different random dataset). diff --git a/scripts/consistency/analyze_results.py b/scripts/consistency/analyze_results.py new file mode 100755 index 000000000..08bf22ca4 --- /dev/null +++ b/scripts/consistency/analyze_results.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 + +import argparse +import math +import struct +from pathlib import Path +from typing import List, Tuple + + +def read_bin_u32_matrix(path: Path) -> Tuple[int, int, List[int]]: + with path.open("rb") as f: + header = f.read(8) + if len(header) != 8: + raise EOFError(f"{path}: missing header") + npts, ndims = struct.unpack(" Tuple[int, int, List[int]]: + """Reads DiskANN compute_groundtruth output. + + The truthset file format is: + u32 npts, u32 K, + u32 ids[npts*K], + f32 dists[npts*K] + + Some builds may omit the dists section; this reader handles both. + """ + with path.open("rb") as f: + header = f.read(8) + if len(header) != 8: + raise EOFError(f"{path}: missing header") + npts, k = struct.unpack(" Tuple[int, int, List[float]]: + with path.open("rb") as f: + header = f.read(8) + if len(header) != 8: + raise EOFError(f"{path}: missing header") + npts, ndims = struct.unpack(" float: + hit = 0 + for qi in range(nq): + gt = set(gt_ids[qi * k : (qi + 1) * k]) + pred = pred_ids[qi * k : (qi + 1) * k] + hit += sum(1 for x in pred if x in gt) + return hit / float(nq * k) + + +def top1_match_rate(a_ids: List[int], b_ids: List[int], nq: int, k: int) -> float: + same = 0 + for qi in range(nq): + if a_ids[qi * k] == b_ids[qi * k]: + same += 1 + return same / float(nq) + + +def distance_error_common_ids( + a_ids: List[int], a_dists: List[float], b_ids: List[int], b_dists: List[float], nq: int, k: int +) -> Tuple[float, float, float]: + # Compare distance values only for IDs that appear in both top-k lists per query. + # Returns (mean_abs, p99_abs, max_abs) over all common-id pairs. + abs_errors: List[float] = [] + + for qi in range(nq): + a_map = {} + base = qi * k + for j in range(k): + a_map[a_ids[base + j]] = a_dists[base + j] + + for j in range(k): + bid = b_ids[base + j] + if bid in a_map: + abs_errors.append(abs(float(b_dists[base + j]) - float(a_map[bid]))) + + if not abs_errors: + return math.nan, math.nan, math.nan + + abs_errors.sort() + mean_abs = sum(abs_errors) / float(len(abs_errors)) + p99_abs = abs_errors[int(0.99 * (len(abs_errors) - 1))] + max_abs = abs_errors[-1] + return mean_abs, p99_abs, max_abs + + +def _paths_for_prefix(prefix: Path, L: int) -> Tuple[Path, Path]: + idx = Path(str(prefix) + f"_{L}_idx_uint32.bin") + dist = Path(str(prefix) + f"_{L}_dists_float.bin") + return idx, dist + + +def compare_block(name: str, gt_path: Path, float_prefix: Path, bf16_prefix: Path, L: int, k: int) -> None: + gt_nq, gt_k, gt_ids = read_truthset_ids(gt_path) + if gt_k < k: + raise ValueError(f"GT K={gt_k} < requested K={k}") + + f_idx_path, f_dist_path = _paths_for_prefix(float_prefix, L) + b_idx_path, b_dist_path = _paths_for_prefix(bf16_prefix, L) + + fq, fk, f_ids = read_bin_u32_matrix(f_idx_path) + _, _, f_dists = read_bin_f32_matrix(f_dist_path) + + bq, bk, b_ids = read_bin_u32_matrix(b_idx_path) + _, _, b_dists = read_bin_f32_matrix(b_dist_path) + + if fq != gt_nq or bq != gt_nq: + raise ValueError(f"Query count mismatch: gt={gt_nq}, float={fq}, bf16={bq}") + if fk != k or bk != k: + raise ValueError(f"Result K mismatch: expected {k}, float={fk}, bf16={bk}") + + r_float = recall_at_k(gt_ids[: gt_nq * k], f_ids, gt_nq, k) + r_bf16 = recall_at_k(gt_ids[: gt_nq * k], b_ids, gt_nq, k) + top1 = top1_match_rate(f_ids, b_ids, gt_nq, k) + mean_abs, p99_abs, max_abs = distance_error_common_ids(f_ids, f_dists, b_ids, b_dists, gt_nq, k) + + print(f"== {name} ==") + print(f"L={L} K={k}") + print(f"Recall@{k}: float={r_float*100:.2f}% bf16={r_bf16*100:.2f}% (delta={(r_bf16-r_float)*100:.2f}%)") + print(f"Top1 ID match rate (bf16 vs float): {top1*100:.2f}%") + print( + "Distance abs error on common IDs: " + f"mean={mean_abs:.6g}, p99={p99_abs:.6g}, max={max_abs:.6g}" + ) + print() + + +def main() -> int: + ap = argparse.ArgumentParser(description="Compare float vs bf16 search outputs (memory + disk)") + ap.add_argument("--gt", required=True, help="Ground-truth .bin (uint32 ids), computed from float data") + ap.add_argument("--L", type=int, required=True) + ap.add_argument("--K", type=int, required=True) + + ap.add_argument("--mem_float", required=True, help="Memory float result prefix (no _L suffix)") + ap.add_argument("--mem_bf16", required=True, help="Memory bf16 result prefix (no _L suffix)") + + ap.add_argument("--disk_float_full", required=True) + ap.add_argument("--disk_bf16_full", required=True) + + ap.add_argument("--disk_float_pq", required=True) + ap.add_argument("--disk_bf16_pq", required=True) + + args = ap.parse_args() + + gt_path = Path(args.gt) + + compare_block("Memory", gt_path, Path(args.mem_float), Path(args.mem_bf16), args.L, args.K) + compare_block("Disk (full-precision)", gt_path, Path(args.disk_float_full), Path(args.disk_bf16_full), args.L, args.K) + compare_block("Disk (PQ + reorder)", gt_path, Path(args.disk_float_pq), Path(args.disk_bf16_pq), args.L, args.K) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/consistency/bin_convert.py b/scripts/consistency/bin_convert.py new file mode 100755 index 000000000..abae8c727 --- /dev/null +++ b/scripts/consistency/bin_convert.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 + +import argparse +import struct +from pathlib import Path + + +def _read_u32(f) -> int: + b = f.read(4) + if len(b) != 4: + raise EOFError("Unexpected EOF while reading u32") + return struct.unpack(" int: + bits = struct.unpack("> 16) & 1 + bits = (bits + 0x7FFF + lsb) & 0xFFFFFFFF + return (bits >> 16) & 0xFFFF + + +def float_bin_to_bf16_bin(in_path: Path, out_path: Path) -> None: + with in_path.open("rb") as r: + npts = _read_u32(r) + ndims = _read_u32(r) + total = npts * ndims + + payload = r.read(4 * total) + if len(payload) != 4 * total: + raise EOFError( + f"Unexpected EOF: expected {4*total} bytes of float32 payload, got {len(payload)}" + ) + + floats = struct.unpack(f"<{total}f", payload) + + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("wb") as w: + w.write(struct.pack(" None: + with in_path.open("rb") as r: + npts = _read_u32(r) + ndims = _read_u32(r) + total = npts * ndims + + payload = r.read(2 * total) + if len(payload) != 2 * total: + raise EOFError( + f"Unexpected EOF: expected {2*total} bytes of bf16 payload, got {len(payload)}" + ) + + bf16_vals = struct.unpack(f"<{total}H", payload) + + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("wb") as w: + w.write(struct.pack(" int: + ap = argparse.ArgumentParser(description="Convert DiskANN .bin between float32 and bf16") + ap.add_argument("--mode", choices=["float_to_bf16", "bf16_to_float"], required=True) + ap.add_argument("--input", required=True) + ap.add_argument("--output", required=True) + args = ap.parse_args() + + in_path = Path(args.input) + out_path = Path(args.output) + + if args.mode == "float_to_bf16": + float_bin_to_bf16_bin(in_path, out_path) + else: + bf16_bin_to_float_bin(in_path, out_path) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/consistency/run_consistency.sh b/scripts/consistency/run_consistency.sh new file mode 100755 index 000000000..16e6e169e --- /dev/null +++ b/scripts/consistency/run_consistency.sh @@ -0,0 +1,122 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +REPO_ROOT=$(cd "$SCRIPT_DIR/../.." && pwd) +cd "$REPO_ROOT" + +# if [ -d "build/apps" ]; then +# BASE_PATH="build/apps" +# elif [ -d "build/tests" ]; then +# BASE_PATH="build/tests" +# else +# echo "Error: could not find build outputs under build/apps or build/tests" >&2 +# exit 2 +# fi + +if [ -d "build-amx/apps" ]; then + BASE_PATH="build-amx/apps" +elif [ -d "build/tests" ]; then + BASE_PATH="build/tests" +else + echo "Error: could not find build outputs under build/apps or build/tests" >&2 + exit 2 +fi + +OUT_DIR=${OUT_DIR:-"$SCRIPT_DIR/_out"} +NPTS=${NPTS:-5000} +NQ=${NQ:-200} +DIM=${DIM:-64} +K=${K:-10} +MEM_L=${MEM_L:-50} +DISK_L=${DISK_L:-50} +THREADS=${THREADS:-4} + +DISK_R=${DISK_R:-16} +DISK_LBUILD=${DISK_LBUILD:-50} +DISK_SEARCH_DRAM_BUDGET_GB=${DISK_SEARCH_DRAM_BUDGET_GB:-0.25} +DISK_BUILD_DRAM_BUDGET_GB=${DISK_BUILD_DRAM_BUDGET_GB:-2} +DISK_NUM_NODES_TO_CACHE=${DISK_NUM_NODES_TO_CACHE:-0} +DISK_BEAMWIDTH=${DISK_BEAMWIDTH:-2} +DISK_PQ_BYTES=${DISK_PQ_BYTES:-8} + +mkdir -p "$OUT_DIR/data" "$OUT_DIR/results" + +BASE_F32="$OUT_DIR/data/base_f32.bin" +QUERY_F32="$OUT_DIR/data/query_f32.bin" +BASE_BF16="$OUT_DIR/data/base_bf16.bin" +QUERY_BF16="$OUT_DIR/data/query_bf16.bin" +GT_PREFIX="$OUT_DIR/data/gt_l2" +GT_BIN="$GT_PREFIX" + +MEM_F32_PREFIX="$OUT_DIR/results/mem_f32" +MEM_BF16_PREFIX="$OUT_DIR/results/mem_bf16" + +DISK_F32_FULL_PREFIX="$OUT_DIR/results/disk_f32_full" +DISK_BF16_FULL_PREFIX="$OUT_DIR/results/disk_bf16_full" +DISK_F32_PQ_PREFIX="$OUT_DIR/results/disk_f32_pq" +DISK_BF16_PQ_PREFIX="$OUT_DIR/results/disk_bf16_pq" + +IDX_MEM_F32_PREFIX="$OUT_DIR/results/index_mem_f32" +IDX_MEM_BF16_PREFIX="$OUT_DIR/results/index_mem_bf16" +IDX_DISK_F32_FULL_PREFIX="$OUT_DIR/results/index_disk_f32_full" +IDX_DISK_BF16_FULL_PREFIX="$OUT_DIR/results/index_disk_bf16_full" +IDX_DISK_F32_PQ_PREFIX="$OUT_DIR/results/index_disk_f32_pq" +IDX_DISK_BF16_PQ_PREFIX="$OUT_DIR/results/index_disk_bf16_pq" + +echo "[1/6] Generate float32 base/query (NPTS=$NPTS NQ=$NQ DIM=$DIM)" +"$BASE_PATH/utils/rand_data_gen" --data_type float --output_file "$BASE_F32" -D "$DIM" -N "$NPTS" --norm 1.0 +"$BASE_PATH/utils/rand_data_gen" --data_type float --output_file "$QUERY_F32" -D "$DIM" -N "$NQ" --norm 1.0 + +echo "[2/6] Convert float32 -> bf16 (round-to-nearest-even)" +python3 "$SCRIPT_DIR/bin_convert.py" --mode float_to_bf16 --input "$BASE_F32" --output "$BASE_BF16" +python3 "$SCRIPT_DIR/bin_convert.py" --mode float_to_bf16 --input "$QUERY_F32" --output "$QUERY_BF16" + +echo "[3/6] Compute float32 ground truth (L2)" +"$BASE_PATH/utils/compute_groundtruth" --data_type float --dist_fn l2 --base_file "$BASE_F32" --query_file "$QUERY_F32" --gt_file "$GT_PREFIX" --K "$K" + +# compute_groundtruth historically may or may not append a ".bin" suffix depending on build. +if [[ -f "$GT_PREFIX" ]]; then + GT_BIN="$GT_PREFIX" +elif [[ -f "$GT_PREFIX.bin" ]]; then + GT_BIN="$GT_PREFIX.bin" +else + echo "Error: could not find ground truth output at '$GT_PREFIX' or '$GT_PREFIX.bin'" >&2 + exit 2 +fi + +echo "[4/6] Memory: build + search (float vs bf16)" +"$BASE_PATH/build_memory_index" --data_type float --dist_fn l2 --data_path "$BASE_F32" --index_path_prefix "$IDX_MEM_F32_PREFIX" +"$BASE_PATH/build_memory_index" --data_type bf16 --dist_fn l2 --data_path "$BASE_BF16" --index_path_prefix "$IDX_MEM_BF16_PREFIX" + +"$BASE_PATH/search_memory_index" --data_type float --dist_fn l2 --index_path_prefix "$IDX_MEM_F32_PREFIX" --query_file "$QUERY_F32" --recall_at "$K" --result_path "$MEM_F32_PREFIX" --gt_file "$GT_BIN" -L "$MEM_L" -T "$THREADS" +"$BASE_PATH/search_memory_index" --data_type bf16 --dist_fn l2 --index_path_prefix "$IDX_MEM_BF16_PREFIX" --query_file "$QUERY_BF16" --recall_at "$K" --result_path "$MEM_BF16_PREFIX" --gt_file "$GT_BIN" -L "$MEM_L" -T "$THREADS" + +echo "[5/6] Disk: build + search (full-precision and PQ+reorder; float vs bf16)" +# Full-precision disk +"$BASE_PATH/build_disk_index" --data_type float --dist_fn l2 --data_path "$BASE_F32" --index_path_prefix "$IDX_DISK_F32_FULL_PREFIX" -R "$DISK_R" -L "$DISK_LBUILD" -B "$DISK_SEARCH_DRAM_BUDGET_GB" -M "$DISK_BUILD_DRAM_BUDGET_GB" --PQ_disk_bytes 0 --build_PQ_bytes 0 -T "$THREADS" +"$BASE_PATH/build_disk_index" --data_type bf16 --dist_fn l2 --data_path "$BASE_BF16" --index_path_prefix "$IDX_DISK_BF16_FULL_PREFIX" -R "$DISK_R" -L "$DISK_LBUILD" -B "$DISK_SEARCH_DRAM_BUDGET_GB" -M "$DISK_BUILD_DRAM_BUDGET_GB" --PQ_disk_bytes 0 --build_PQ_bytes 0 -T "$THREADS" + +"$BASE_PATH/search_disk_index" --data_type float --dist_fn l2 --index_path_prefix "$IDX_DISK_F32_FULL_PREFIX" --query_file "$QUERY_F32" --result_path "$DISK_F32_FULL_PREFIX" --gt_file "$GT_BIN" -K "$K" -L "$DISK_L" -W "$DISK_BEAMWIDTH" -T "$THREADS" --num_nodes_to_cache "$DISK_NUM_NODES_TO_CACHE" +"$BASE_PATH/search_disk_index" --data_type bf16 --dist_fn l2 --index_path_prefix "$IDX_DISK_BF16_FULL_PREFIX" --query_file "$QUERY_BF16" --result_path "$DISK_BF16_FULL_PREFIX" --gt_file "$GT_BIN" -K "$K" -L "$DISK_L" -W "$DISK_BEAMWIDTH" -T "$THREADS" --num_nodes_to_cache "$DISK_NUM_NODES_TO_CACHE" + +# Disk PQ + reorder +"$BASE_PATH/build_disk_index" --data_type float --dist_fn l2 --data_path "$BASE_F32" --index_path_prefix "$IDX_DISK_F32_PQ_PREFIX" -R "$DISK_R" -L "$DISK_LBUILD" -B "$DISK_SEARCH_DRAM_BUDGET_GB" -M "$DISK_BUILD_DRAM_BUDGET_GB" --PQ_disk_bytes "$DISK_PQ_BYTES" --build_PQ_bytes 0 --append_reorder_data -T "$THREADS" +"$BASE_PATH/build_disk_index" --data_type bf16 --dist_fn l2 --data_path "$BASE_BF16" --index_path_prefix "$IDX_DISK_BF16_PQ_PREFIX" -R "$DISK_R" -L "$DISK_LBUILD" -B "$DISK_SEARCH_DRAM_BUDGET_GB" -M "$DISK_BUILD_DRAM_BUDGET_GB" --PQ_disk_bytes "$DISK_PQ_BYTES" --build_PQ_bytes 0 --append_reorder_data -T "$THREADS" + +"$BASE_PATH/search_disk_index" --data_type float --dist_fn l2 --index_path_prefix "$IDX_DISK_F32_PQ_PREFIX" --query_file "$QUERY_F32" --result_path "$DISK_F32_PQ_PREFIX" --gt_file "$GT_BIN" -K "$K" -L "$DISK_L" -W "$DISK_BEAMWIDTH" -T "$THREADS" --num_nodes_to_cache "$DISK_NUM_NODES_TO_CACHE" --use_reorder_data +"$BASE_PATH/search_disk_index" --data_type bf16 --dist_fn l2 --index_path_prefix "$IDX_DISK_BF16_PQ_PREFIX" --query_file "$QUERY_BF16" --result_path "$DISK_BF16_PQ_PREFIX" --gt_file "$GT_BIN" -K "$K" -L "$DISK_L" -W "$DISK_BEAMWIDTH" -T "$THREADS" --num_nodes_to_cache "$DISK_NUM_NODES_TO_CACHE" --use_reorder_data + +echo "[6/6] Analyze float vs bf16 deltas" +python3 "$SCRIPT_DIR/analyze_results.py" \ + --gt "$GT_BIN" \ + --L "$MEM_L" \ + --K "$K" \ + --mem_float "$MEM_F32_PREFIX" \ + --mem_bf16 "$MEM_BF16_PREFIX" \ + --disk_float_full "$DISK_F32_FULL_PREFIX" \ + --disk_bf16_full "$DISK_BF16_FULL_PREFIX" \ + --disk_float_pq "$DISK_F32_PQ_PREFIX" \ + --disk_bf16_pq "$DISK_BF16_PQ_PREFIX" + +echo "Done. Artifacts under: $OUT_DIR" diff --git a/scripts/perf/README.md b/scripts/perf/README.md index 692eedca7..3b14f317b 100644 --- a/scripts/perf/README.md +++ b/scripts/perf/README.md @@ -9,6 +9,67 @@ development continues. ## Usage +### Run the perf suite + +The main entrypoint is `scripts/perf/perf_test.sh`. + +Environment variables: + +- `DATA_TYPE` (default: `float`) + - Supported: `float`, `bf16` +- `PERF_MODE` (default: `memory`) + - `memory`: run **in-memory** index perf only + - `disk`: run **SSD/disk** index perf only + - `both`: run **both** memory + disk perf + +`PERF_MODE` details: + +- `PERF_MODE=memory` + - Builds and searches **in-memory** indexes (`build_memory_index`, `search_memory_index`). + - For `DATA_TYPE=float`, also runs `fast_l2` memory search. + +- `PERF_MODE=disk` + - Builds and searches **SSD/disk** indexes (`build_disk_index`, `search_disk_index`). + - For `DATA_TYPE=float`, runs disk perf for `l2`, `mips`, `cosine`. + - For `DATA_TYPE=bf16`, runs disk perf for `l2`, `cosine`: + - full-precision disk (`PQ_disk_bytes=0`) + - disk-PQ + reorder (`--append_reorder_data` / `--use_reorder_data`) + +- `PERF_MODE=both` + - Runs everything from both `memory` and `disk`. + +Legacy compatibility: + +- `RUN_DISK=1` maps to `PERF_MODE=both` +- `RUN_DISK=0` maps to `PERF_MODE=memory` + +Examples: + +```bash +# Memory index perf with bf16 +DATA_TYPE=bf16 PERF_MODE=memory ./scripts/perf/perf_test.sh + +# Memory + disk perf with bf16 +DATA_TYPE=bf16 PERF_MODE=both ./scripts/perf/perf_test.sh + +# Disk index perf (float) +DATA_TYPE=float PERF_MODE=disk ./scripts/perf/perf_test.sh + +# Disk index perf (bf16): runs both full-precision disk and disk-PQ(+reorder) +DATA_TYPE=bf16 PERF_MODE=disk ./scripts/perf/perf_test.sh + +# Disk index perf (bf16) with custom disk-PQ bytes (default: 8) +DATA_TYPE=bf16 PERF_MODE=disk DISK_BF16_PQ_DISK_BYTES=16 ./scripts/perf/perf_test.sh +``` + +Notes: + +- For `DATA_TYPE=bf16` and `PERF_MODE` includes `disk`, the script runs: + - full-precision disk (`PQ_disk_bytes=0`) and + - disk-PQ with reorder (`--append_reorder_data` / `--use_reorder_data`). +- `DISK_BF16_PQ_DISK_BYTES` controls the bf16 disk-PQ compression level for the disk-PQ runs (default: 8). +- For backward compatibility, `RUN_DISK` is still accepted and overrides `PERF_MODE`. + `docker build` must be run with the context directory set to `scripts`, but the Dockerfile set to `scripts/perf/Dockerfile` as in: ```bash docker build [--build-arg GIT_COMMIT_ISH=] -f scripts/perf/Dockerfile scripts diff --git a/scripts/perf/perf_test.sh b/scripts/perf/perf_test.sh index a8d537f01..669d1d566 100644 --- a/scripts/perf/perf_test.sh +++ b/scripts/perf/perf_test.sh @@ -1,40 +1,143 @@ #!/bin/bash +SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +REPO_ROOT=$(cd "$SCRIPT_DIR/../.." && pwd) +cd "$REPO_ROOT" + +LOG_DIR=${LOG_DIR:-/app/logs} +mkdir -p "$LOG_DIR" +TIME_LOG="$LOG_DIR/time.log" + +DATA_TYPE=${DATA_TYPE:-float} +if [[ "$DATA_TYPE" != "float" && "$DATA_TYPE" != "bf16" ]]; then + echo "Unsupported DATA_TYPE='$DATA_TYPE'. Use DATA_TYPE=float or DATA_TYPE=bf16." + exit 2 +fi + +# Choose which index type(s) to run: memory, disk, or both. +# Default is memory. +PERF_MODE=${PERF_MODE:-memory} +if [[ -n "$RUN_DISK" ]]; then + # Backward compatibility: existing RUN_DISK=1 turns on disk tests. + if [[ "$RUN_DISK" == "1" ]]; then + PERF_MODE=both + elif [[ "$RUN_DISK" == "0" ]]; then + PERF_MODE=memory + else + echo "Unsupported RUN_DISK='$RUN_DISK'. Use RUN_DISK=0 or RUN_DISK=1." + exit 2 + fi +fi + +if [[ "$PERF_MODE" != "memory" && "$PERF_MODE" != "disk" && "$PERF_MODE" != "both" ]]; then + echo "Unsupported PERF_MODE='$PERF_MODE'. Use PERF_MODE=memory|disk|both." + exit 2 +fi + function json_time { command="$@" echo "Executing $command" - /usr/bin/time --quiet -o /app/logs/time.log -a --format '{"command":"%C", "wallclock": %e, "user": %U, "sys": %S}' $command + /usr/bin/time --quiet -o "$TIME_LOG" -a --format '{"command":"%C", "wallclock": %e, "user": %U, "sys": %S}' $command ret=$? if [ $ret -ne 0 ]; then - echo "{\"command\": \""$command"\", \"status_code\": $ret}" >> /app/logs/time.log + echo "{\"command\": \""$command"\", \"status_code\": $ret}" >> "$TIME_LOG" fi } mkdir data -rm /app/logs/time.log -touch /app/logs/time.log -chmod 666 /app/logs/time.log +rm -f "$TIME_LOG" +touch "$TIME_LOG" +chmod 666 "$TIME_LOG" if [ -d "build/apps" ]; then - export BASE_PATH="build/apps" + export BASE_PATH="build/apps" else - export BASE_PATH="build/tests" + export BASE_PATH="build/tests" fi -json_time $BASE_PATH/utils/rand_data_gen --data_type float --output_file data/rand_float_768D_1M_norm1.0.bin -D 768 -N 1000000 --norm 1.0 -json_time $BASE_PATH/utils/rand_data_gen --data_type float --output_file data/rand_float_768D_10K_norm1.0.bin -D 768 -N 10000 --norm 1.0 +BASE_FILE="data/rand_${DATA_TYPE}_768D_1M_norm1.0.bin" +QUERY_FILE="data/rand_${DATA_TYPE}_768D_10K_norm1.0.bin" -json_time $BASE_PATH/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_768D_1M_norm1.0.bin --query_file data/rand_float_768D_10K_norm1.0.bin --gt_file data/l2_rand_float_768D_1M_norm1.0_768D_10K_norm1.0_gt100 --K 100 -json_time $BASE_PATH/utils/compute_groundtruth --data_type float --dist_fn mips --base_file data/rand_float_768D_1M_norm1.0.bin --query_file data/rand_float_768D_10K_norm1.0.bin --gt_file data/mips_rand_float_768D_1M_norm1.0_768D_10K_norm1.0_gt100 --K 100 -json_time $BASE_PATH/utils/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_768D_1M_norm1.0.bin --query_file data/rand_float_768D_10K_norm1.0.bin --gt_file data/cosine_rand_float_768D_1M_norm1.0_768D_10K_norm1.0_gt100 --K 100 +GT_L2_FILE="data/l2_rand_${DATA_TYPE}_768D_1M_norm1.0_768D_10K_norm1.0_gt100" +GT_MIPS_FILE="data/mips_rand_${DATA_TYPE}_768D_1M_norm1.0_768D_10K_norm1.0_gt100" +GT_COSINE_FILE="data/cosine_rand_${DATA_TYPE}_768D_1M_norm1.0_768D_10K_norm1.0_gt100" -json_time $BASE_PATH/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_768D_1M_norm1.0.bin --index_path_prefix data/index_l2_rand_float_768D_1M_norm1.0 -json_time $BASE_PATH/search_memory_index --data_type float --dist_fn l2 --index_path_prefix data/index_l2_rand_float_768D_1M_norm1.0 --query_file data/rand_float_768D_10K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_768D_1M_norm1.0_768D_10K_norm1.0_gt100 -L 100 32 -json_time $BASE_PATH/search_memory_index --data_type float --dist_fn fast_l2 --index_path_prefix data/index_l2_rand_float_768D_1M_norm1.0 --query_file data/rand_float_768D_10K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_768D_1M_norm1.0_768D_10K_norm1.0_gt100 -L 100 32 +INDEX_L2_PREFIX="data/index_l2_rand_${DATA_TYPE}_768D_1M_norm1.0" +INDEX_MIPS_PREFIX="data/index_mips_rand_${DATA_TYPE}_768D_1M_norm1.0" +INDEX_COSINE_PREFIX="data/index_cosine_rand_${DATA_TYPE}_768D_1M_norm1.0" -json_time $BASE_PATH/build_memory_index --data_type float --dist_fn mips --data_path data/rand_float_768D_1M_norm1.0.bin --index_path_prefix data/index_mips_rand_float_768D_1M_norm1.0 -json_time $BASE_PATH/search_memory_index --data_type float --dist_fn mips --index_path_prefix data/index_l2_rand_float_768D_1M_norm1.0 --query_file data/rand_float_768D_10K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/mips_rand_float_768D_1M_norm1.0_768D_10K_norm1.0_gt100 -L 100 32 +json_time $BASE_PATH/utils/rand_data_gen --data_type "$DATA_TYPE" --output_file "$BASE_FILE" -D 768 -N 1000000 --norm 1.0 +json_time $BASE_PATH/utils/rand_data_gen --data_type "$DATA_TYPE" --output_file "$QUERY_FILE" -D 768 -N 10000 --norm 1.0 -json_time $BASE_PATH/build_memory_index --data_type float --dist_fn cosine --data_path data/rand_float_768D_1M_norm1.0.bin --index_path_prefix data/index_cosine_rand_float_768D_1M_norm1.0 -json_time $BASE_PATH/search_memory_index --data_type float --dist_fn cosine --index_path_prefix data/index_l2_rand_float_768D_1M_norm1.0 --query_file data/rand_float_768D_10K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_float_768D_1M_norm1.0_768D_10K_norm1.0_gt100 -L 100 32 +json_time $BASE_PATH/utils/compute_groundtruth --data_type "$DATA_TYPE" --dist_fn l2 --base_file "$BASE_FILE" --query_file "$QUERY_FILE" --gt_file "$GT_L2_FILE" --K 100 +json_time $BASE_PATH/utils/compute_groundtruth --data_type "$DATA_TYPE" --dist_fn mips --base_file "$BASE_FILE" --query_file "$QUERY_FILE" --gt_file "$GT_MIPS_FILE" --K 100 +json_time $BASE_PATH/utils/compute_groundtruth --data_type "$DATA_TYPE" --dist_fn cosine --base_file "$BASE_FILE" --query_file "$QUERY_FILE" --gt_file "$GT_COSINE_FILE" --K 100 + +if [[ "$PERF_MODE" == "memory" || "$PERF_MODE" == "both" ]]; then + json_time $BASE_PATH/build_memory_index --data_type "$DATA_TYPE" --dist_fn l2 --data_path "$BASE_FILE" --index_path_prefix "$INDEX_L2_PREFIX" + json_time $BASE_PATH/search_memory_index --data_type "$DATA_TYPE" --dist_fn l2 --index_path_prefix "$INDEX_L2_PREFIX" --query_file "$QUERY_FILE" --recall_at 10 --result_path temp --gt_file "$GT_L2_FILE" -L 100 32 + if [[ "$DATA_TYPE" == "float" ]]; then + json_time $BASE_PATH/search_memory_index --data_type "$DATA_TYPE" --dist_fn fast_l2 --index_path_prefix "$INDEX_L2_PREFIX" --query_file "$QUERY_FILE" --recall_at 10 --result_path temp --gt_file "$GT_L2_FILE" -L 100 32 + fi + + json_time $BASE_PATH/build_memory_index --data_type "$DATA_TYPE" --dist_fn mips --data_path "$BASE_FILE" --index_path_prefix "$INDEX_MIPS_PREFIX" + json_time $BASE_PATH/search_memory_index --data_type "$DATA_TYPE" --dist_fn mips --index_path_prefix "$INDEX_L2_PREFIX" --query_file "$QUERY_FILE" --recall_at 10 --result_path temp --gt_file "$GT_MIPS_FILE" -L 100 32 + + json_time $BASE_PATH/build_memory_index --data_type "$DATA_TYPE" --dist_fn cosine --data_path "$BASE_FILE" --index_path_prefix "$INDEX_COSINE_PREFIX" + json_time $BASE_PATH/search_memory_index --data_type "$DATA_TYPE" --dist_fn cosine --index_path_prefix "$INDEX_L2_PREFIX" --query_file "$QUERY_FILE" --recall_at 10 --result_path temp --gt_file "$GT_COSINE_FILE" -L 100 32 +fi + +# Optional SSD/disk index perf (mixed RAM+SSD). +if [[ "$PERF_MODE" == "disk" || "$PERF_MODE" == "both" ]]; then + DISK_R=${DISK_R:-32} + DISK_LBUILD=${DISK_LBUILD:-50} + DISK_SEARCH_DRAM_BUDGET_GB=${DISK_SEARCH_DRAM_BUDGET_GB:-0.5} + DISK_BUILD_DRAM_BUDGET_GB=${DISK_BUILD_DRAM_BUDGET_GB:-8} + DISK_PQ_DISK_BYTES=${DISK_PQ_DISK_BYTES:-0} + DISK_BUILD_PQ_BYTES=${DISK_BUILD_PQ_BYTES:-0} + DISK_NUM_NODES_TO_CACHE=${DISK_NUM_NODES_TO_CACHE:-10000} + DISK_BEAMWIDTH=${DISK_BEAMWIDTH:-2} + DISK_RECALL_AT=${DISK_RECALL_AT:-10} + DISK_SEARCH_LISTS=${DISK_SEARCH_LISTS:-"10 20 30 40 50 100"} + + mkdir -p temp + + if [[ "$DATA_TYPE" == "float" ]]; then + DISK_INDEX_L2_PREFIX="data/disk_index_l2_rand_${DATA_TYPE}_768D_1M_R${DISK_R}_L${DISK_LBUILD}_B${DISK_SEARCH_DRAM_BUDGET_GB}_M${DISK_BUILD_DRAM_BUDGET_GB}" + DISK_INDEX_MIPS_PREFIX="data/disk_index_mips_rand_${DATA_TYPE}_768D_1M_R${DISK_R}_L${DISK_LBUILD}_B${DISK_SEARCH_DRAM_BUDGET_GB}_M${DISK_BUILD_DRAM_BUDGET_GB}" + DISK_INDEX_COSINE_PREFIX="data/disk_index_cosine_rand_${DATA_TYPE}_768D_1M_R${DISK_R}_L${DISK_LBUILD}_B${DISK_SEARCH_DRAM_BUDGET_GB}_M${DISK_BUILD_DRAM_BUDGET_GB}" + + json_time $BASE_PATH/build_disk_index --data_type "$DATA_TYPE" --dist_fn l2 --data_path "$BASE_FILE" --index_path_prefix "$DISK_INDEX_L2_PREFIX" -R "$DISK_R" -L "$DISK_LBUILD" -B "$DISK_SEARCH_DRAM_BUDGET_GB" -M "$DISK_BUILD_DRAM_BUDGET_GB" --PQ_disk_bytes "$DISK_PQ_DISK_BYTES" --build_PQ_bytes "$DISK_BUILD_PQ_BYTES" + json_time $BASE_PATH/search_disk_index --data_type "$DATA_TYPE" --dist_fn l2 --index_path_prefix "$DISK_INDEX_L2_PREFIX" --query_file "$QUERY_FILE" --gt_file "$GT_L2_FILE" -K "$DISK_RECALL_AT" -L $DISK_SEARCH_LISTS --result_path "temp/disk_l2" --num_nodes_to_cache "$DISK_NUM_NODES_TO_CACHE" -W "$DISK_BEAMWIDTH" + + json_time $BASE_PATH/build_disk_index --data_type "$DATA_TYPE" --dist_fn mips --data_path "$BASE_FILE" --index_path_prefix "$DISK_INDEX_MIPS_PREFIX" -R "$DISK_R" -L "$DISK_LBUILD" -B "$DISK_SEARCH_DRAM_BUDGET_GB" -M "$DISK_BUILD_DRAM_BUDGET_GB" --PQ_disk_bytes "$DISK_PQ_DISK_BYTES" --build_PQ_bytes "$DISK_BUILD_PQ_BYTES" + json_time $BASE_PATH/search_disk_index --data_type "$DATA_TYPE" --dist_fn mips --index_path_prefix "$DISK_INDEX_MIPS_PREFIX" --query_file "$QUERY_FILE" --gt_file "$GT_MIPS_FILE" -K "$DISK_RECALL_AT" -L $DISK_SEARCH_LISTS --result_path "temp/disk_mips" --num_nodes_to_cache "$DISK_NUM_NODES_TO_CACHE" -W "$DISK_BEAMWIDTH" + + json_time $BASE_PATH/build_disk_index --data_type "$DATA_TYPE" --dist_fn cosine --data_path "$BASE_FILE" --index_path_prefix "$DISK_INDEX_COSINE_PREFIX" -R "$DISK_R" -L "$DISK_LBUILD" -B "$DISK_SEARCH_DRAM_BUDGET_GB" -M "$DISK_BUILD_DRAM_BUDGET_GB" --PQ_disk_bytes "$DISK_PQ_DISK_BYTES" --build_PQ_bytes "$DISK_BUILD_PQ_BYTES" + json_time $BASE_PATH/search_disk_index --data_type "$DATA_TYPE" --dist_fn cosine --index_path_prefix "$DISK_INDEX_COSINE_PREFIX" --query_file "$QUERY_FILE" --gt_file "$GT_COSINE_FILE" -K "$DISK_RECALL_AT" -L $DISK_SEARCH_LISTS --result_path "temp/disk_cosine" --num_nodes_to_cache "$DISK_NUM_NODES_TO_CACHE" -W "$DISK_BEAMWIDTH" + elif [[ "$DATA_TYPE" == "bf16" ]]; then + # bf16 disk: run both full-precision and disk-PQ(+reorder) to cover true bf16-on-SSD workflows. + # Note: mips is not part of bf16 disk perf since the CLI advertises mips as float-only. + DISK_BF16_PQ_DISK_BYTES=${DISK_BF16_PQ_DISK_BYTES:-8} + + DISK_INDEX_L2_PREFIX_FULL="data/disk_index_l2_rand_${DATA_TYPE}_768D_1M_full_R${DISK_R}_L${DISK_LBUILD}_B${DISK_SEARCH_DRAM_BUDGET_GB}_M${DISK_BUILD_DRAM_BUDGET_GB}" + DISK_INDEX_COSINE_PREFIX_FULL="data/disk_index_cosine_rand_${DATA_TYPE}_768D_1M_full_R${DISK_R}_L${DISK_LBUILD}_B${DISK_SEARCH_DRAM_BUDGET_GB}_M${DISK_BUILD_DRAM_BUDGET_GB}" + DISK_INDEX_L2_PREFIX_PQ="data/disk_index_l2_rand_${DATA_TYPE}_768D_1M_pq${DISK_BF16_PQ_DISK_BYTES}_R${DISK_R}_L${DISK_LBUILD}_B${DISK_SEARCH_DRAM_BUDGET_GB}_M${DISK_BUILD_DRAM_BUDGET_GB}" + DISK_INDEX_COSINE_PREFIX_PQ="data/disk_index_cosine_rand_${DATA_TYPE}_768D_1M_pq${DISK_BF16_PQ_DISK_BYTES}_R${DISK_R}_L${DISK_LBUILD}_B${DISK_SEARCH_DRAM_BUDGET_GB}_M${DISK_BUILD_DRAM_BUDGET_GB}" + + # Full-precision bf16 on disk. + json_time $BASE_PATH/build_disk_index --data_type "$DATA_TYPE" --dist_fn l2 --data_path "$BASE_FILE" --index_path_prefix "$DISK_INDEX_L2_PREFIX_FULL" -R "$DISK_R" -L "$DISK_LBUILD" -B "$DISK_SEARCH_DRAM_BUDGET_GB" -M "$DISK_BUILD_DRAM_BUDGET_GB" --PQ_disk_bytes 0 --build_PQ_bytes "$DISK_BUILD_PQ_BYTES" + json_time $BASE_PATH/search_disk_index --data_type "$DATA_TYPE" --dist_fn l2 --index_path_prefix "$DISK_INDEX_L2_PREFIX_FULL" --query_file "$QUERY_FILE" --gt_file "$GT_L2_FILE" -K "$DISK_RECALL_AT" -L $DISK_SEARCH_LISTS --result_path "temp/disk_bf16_l2_full" --num_nodes_to_cache "$DISK_NUM_NODES_TO_CACHE" -W "$DISK_BEAMWIDTH" + + json_time $BASE_PATH/build_disk_index --data_type "$DATA_TYPE" --dist_fn cosine --data_path "$BASE_FILE" --index_path_prefix "$DISK_INDEX_COSINE_PREFIX_FULL" -R "$DISK_R" -L "$DISK_LBUILD" -B "$DISK_SEARCH_DRAM_BUDGET_GB" -M "$DISK_BUILD_DRAM_BUDGET_GB" --PQ_disk_bytes 0 --build_PQ_bytes "$DISK_BUILD_PQ_BYTES" + json_time $BASE_PATH/search_disk_index --data_type "$DATA_TYPE" --dist_fn cosine --index_path_prefix "$DISK_INDEX_COSINE_PREFIX_FULL" --query_file "$QUERY_FILE" --gt_file "$GT_COSINE_FILE" -K "$DISK_RECALL_AT" -L $DISK_SEARCH_LISTS --result_path "temp/disk_bf16_cosine_full" --num_nodes_to_cache "$DISK_NUM_NODES_TO_CACHE" -W "$DISK_BEAMWIDTH" + + # Disk-PQ + reorder (true bf16 reorder vectors on SSD). + json_time $BASE_PATH/build_disk_index --data_type "$DATA_TYPE" --dist_fn l2 --data_path "$BASE_FILE" --index_path_prefix "$DISK_INDEX_L2_PREFIX_PQ" -R "$DISK_R" -L "$DISK_LBUILD" -B "$DISK_SEARCH_DRAM_BUDGET_GB" -M "$DISK_BUILD_DRAM_BUDGET_GB" --PQ_disk_bytes "$DISK_BF16_PQ_DISK_BYTES" --build_PQ_bytes "$DISK_BUILD_PQ_BYTES" --append_reorder_data + json_time $BASE_PATH/search_disk_index --data_type "$DATA_TYPE" --dist_fn l2 --index_path_prefix "$DISK_INDEX_L2_PREFIX_PQ" --query_file "$QUERY_FILE" --gt_file "$GT_L2_FILE" -K "$DISK_RECALL_AT" -L $DISK_SEARCH_LISTS --result_path "temp/disk_bf16_l2_pq" --num_nodes_to_cache "$DISK_NUM_NODES_TO_CACHE" -W "$DISK_BEAMWIDTH" --use_reorder_data + + json_time $BASE_PATH/build_disk_index --data_type "$DATA_TYPE" --dist_fn cosine --data_path "$BASE_FILE" --index_path_prefix "$DISK_INDEX_COSINE_PREFIX_PQ" -R "$DISK_R" -L "$DISK_LBUILD" -B "$DISK_SEARCH_DRAM_BUDGET_GB" -M "$DISK_BUILD_DRAM_BUDGET_GB" --PQ_disk_bytes "$DISK_BF16_PQ_DISK_BYTES" --build_PQ_bytes "$DISK_BUILD_PQ_BYTES" --append_reorder_data + json_time $BASE_PATH/search_disk_index --data_type "$DATA_TYPE" --dist_fn cosine --index_path_prefix "$DISK_INDEX_COSINE_PREFIX_PQ" --query_file "$QUERY_FILE" --gt_file "$GT_COSINE_FILE" -K "$DISK_RECALL_AT" -L $DISK_SEARCH_LISTS --result_path "temp/disk_bf16_cosine_pq" --num_nodes_to_cache "$DISK_NUM_NODES_TO_CACHE" -W "$DISK_BEAMWIDTH" --use_reorder_data + fi +fi diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cbca26440..4225d3462 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,6 +14,74 @@ else() in_mem_data_store.cpp in_mem_graph_store.cpp natural_number_set.cpp memory_mapper.cpp partition.cpp pq.cpp pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp filter_utils.cpp index_factory.cpp abstract_index.cpp pq_l2_distance.cpp pq_data_store.cpp) + list(APPEND CPP_SOURCES rabitq.cpp) + list(APPEND CPP_SOURCES bf16_simd_kernels.cpp) + list(APPEND CPP_SOURCES bf16_amx_kernels.cpp) + + include(CheckCXXCompilerFlag) + + # AVX-512 BF16 kernels: compile-time optional, runtime dispatched. + # - DISKANN_AVX512BF16 (default ON): enable kernels when compiler supports required flags. + # - DISKANN_FORCE_AVX512BF16 (default OFF): require kernels; fail configure if unsupported. + option(DISKANN_AVX512BF16 "Enable AVX-512 BF16 kernels when supported by compiler" ON) + option(DISKANN_FORCE_AVX512BF16 "Force AVX-512 BF16 kernels; fail if unsupported" OFF) + + check_cxx_compiler_flag("-mavx512bf16" DISKANN_HAS_MAVX512BF16) + check_cxx_compiler_flag("-mavx512f" DISKANN_HAS_MAVX512F) + + set(DISKANN_ENABLE_AVX512BF16_KERNELS OFF) + if (DISKANN_FORCE_AVX512BF16) + if (NOT (DISKANN_HAS_MAVX512BF16 AND DISKANN_HAS_MAVX512F)) + message(FATAL_ERROR "DISKANN_FORCE_AVX512BF16=ON but compiler does not support -mavx512bf16 and -mavx512f") + endif() + set(DISKANN_ENABLE_AVX512BF16_KERNELS ON) + elseif (DISKANN_AVX512BF16) + if (DISKANN_HAS_MAVX512BF16 AND DISKANN_HAS_MAVX512F) + set(DISKANN_ENABLE_AVX512BF16_KERNELS ON) + endif() + endif() + + if (DISKANN_ENABLE_AVX512BF16_KERNELS) + set_source_files_properties(bf16_simd_kernels.cpp PROPERTIES + COMPILE_OPTIONS "-mavx512f;-mavx512bw;-mavx512vl;-mavx512dq;-mavx512bf16") + endif() + + # AMX BF16 kernels: compile-time optional, runtime dispatched. + # - DISKANN_AMXBF16 (default OFF): enable kernels when compiler supports required flags. + # - DISKANN_FORCE_AMXBF16 (default OFF): require kernels; fail configure if unsupported. + option(DISKANN_AMXBF16 "Enable AMX BF16 kernels when supported by compiler" OFF) + option(DISKANN_FORCE_AMXBF16 "Force AMX BF16 kernels; fail if unsupported" OFF) + + check_cxx_compiler_flag("-mamx-tile" DISKANN_HAS_MAMX_TILE) + check_cxx_compiler_flag("-mamx-bf16" DISKANN_HAS_MAMX_BF16) + # Some builds may set global -march=native, which can implicitly enable AMX. + # When DISKANN_AMXBF16=OFF, we still want to avoid emitting any AMX instructions. + check_cxx_compiler_flag("-mno-amx-tile" DISKANN_HAS_MNO_AMX_TILE) + check_cxx_compiler_flag("-mno-amx-bf16" DISKANN_HAS_MNO_AMX_BF16) + + set(DISKANN_ENABLE_AMXBF16_KERNELS OFF) + if (DISKANN_FORCE_AMXBF16) + if (NOT (DISKANN_HAS_MAMX_TILE AND DISKANN_HAS_MAMX_BF16)) + message(FATAL_ERROR "DISKANN_FORCE_AMXBF16=ON but compiler does not support -mamx-tile and -mamx-bf16") + endif() + set(DISKANN_ENABLE_AMXBF16_KERNELS ON) + elseif (DISKANN_AMXBF16) + if (DISKANN_HAS_MAMX_TILE AND DISKANN_HAS_MAMX_BF16) + set(DISKANN_ENABLE_AMXBF16_KERNELS ON) + endif() + endif() + + if (DISKANN_ENABLE_AMXBF16_KERNELS) + set_source_files_properties(bf16_amx_kernels.cpp PROPERTIES + COMPILE_OPTIONS "-mamx-tile;-mamx-bf16") + else() + # Explicitly disable AMX for this translation unit to prevent accidental + # enablement via global compiler flags such as -march=native. + if (DISKANN_HAS_MNO_AMX_TILE AND DISKANN_HAS_MNO_AMX_BF16) + set_source_files_properties(bf16_amx_kernels.cpp PROPERTIES + COMPILE_OPTIONS "-mno-amx-tile;-mno-amx-bf16") + endif() + endif() if (RESTAPI) list(APPEND CPP_SOURCES restapi/search_wrapper.cpp restapi/server.cpp) endif() diff --git a/src/abstract_data_store.cpp b/src/abstract_data_store.cpp index 0cff0152e..745ac7fca 100644 --- a/src/abstract_data_store.cpp +++ b/src/abstract_data_store.cpp @@ -3,6 +3,7 @@ #include #include "abstract_data_store.h" +#include "bfloat16.h" namespace diskann { @@ -42,4 +43,5 @@ template location_t AbstractDataStore::resize(const lo template DISKANN_DLLEXPORT class AbstractDataStore; template DISKANN_DLLEXPORT class AbstractDataStore; template DISKANN_DLLEXPORT class AbstractDataStore; +template DISKANN_DLLEXPORT class AbstractDataStore; } // namespace diskann diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index 92665825f..498aa26b9 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -1,6 +1,7 @@ #include "common_includes.h" #include "windows_customizations.h" #include "abstract_index.h" +#include "bfloat16.h" namespace diskann { @@ -147,6 +148,8 @@ template DISKANN_DLLEXPORT std::pair AbstractIndex::search AbstractIndex::search( const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const diskann::bfloat16 *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair AbstractIndex::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); @@ -154,6 +157,8 @@ template DISKANN_DLLEXPORT std::pair AbstractIndex::search AbstractIndex::search( const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const diskann::bfloat16 *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices, @@ -166,6 +171,9 @@ template DISKANN_DLLEXPORT std::pair AbstractIndex::search_w template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, std::vector &res_vectors, bool use_filters, const std::string filter_label); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const diskann::bfloat16 *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, @@ -178,6 +186,9 @@ template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, std::vector &res_vectors, bool use_filters, const std::string filter_label); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const diskann::bfloat16 *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, @@ -190,6 +201,9 @@ template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const float *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, std::vector &res_vectors, bool use_filters, const std::string filter_label); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const diskann::bfloat16 *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, @@ -202,6 +216,9 @@ template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const float *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, std::vector &res_vectors, bool use_filters, const std::string filter_label); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const diskann::bfloat16 *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, @@ -217,6 +234,8 @@ template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const int8_t *query, size_t K, size_t L, uint32_t *indices); +template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout( + const diskann::bfloat16 *query, size_t K, size_t L, uint32_t *indices); template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const int32_t tag); template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const int32_t tag); diff --git a/src/bf16_amx_kernels.cpp b/src/bf16_amx_kernels.cpp new file mode 100644 index 000000000..f46e8af55 --- /dev/null +++ b/src/bf16_amx_kernels.cpp @@ -0,0 +1,556 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "bf16_amx_kernels.h" + +#include +#include +#include +#include + +#if defined(__linux__) && (defined(__x86_64__) || defined(__i386__)) +#include +#include +#include +#endif + +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) +#include +#endif + +namespace diskann +{ + +static inline float bf16_dot_scalar(const bfloat16 *a, const bfloat16 *b, uint32_t length) +{ + float dot = 0.0f; +#ifndef _WINDOWS +#pragma omp simd reduction(+ : dot) aligned(a, b : 8) +#endif + for (int32_t i = 0; i < (int32_t)length; i++) + { + dot += a[i].to_float() * b[i].to_float(); + } + return dot; +} + +#if defined(__linux__) && (defined(__x86_64__) || defined(__i386__)) + +static inline uint64_t xgetbv_u32(uint32_t index) +{ + uint32_t eax = 0, edx = 0; + __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index)); + return (static_cast(edx) << 32) | eax; +} + +static inline bool cpu_has_osxsave() +{ + unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0; + if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx)) + return false; + return (ecx & (1u << 27)) != 0; // OSXSAVE +} + +static inline bool cpu_has_amx_bf16_hw() +{ + if (__get_cpuid_max(0, nullptr) < 7) + return false; + + unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0; + __cpuid_count(7, 0, eax, ebx, ecx, edx); + + // Structured Extended Feature Flags Enumeration Leaf (CPUID.07H:EDX) + // - AMX_BF16: EDX[22] + // - AMX_TILE: EDX[24] + const bool has_amx_bf16 = (edx & (1u << 22)) != 0; + const bool has_amx_tile = (edx & (1u << 24)) != 0; + return has_amx_bf16 && has_amx_tile; +} + +static inline bool os_xcr0_allows_amx_state() +{ + if (!cpu_has_osxsave()) + return false; + + // XCR0 bits: + // - 17: XTILECFG + // - 18: XTILEDATA + const uint64_t xcr0 = xgetbv_u32(0); + const uint64_t kAmxMask = (1ULL << 17) | (1ULL << 18); + return (xcr0 & kAmxMask) == kAmxMask; +} + +// Linux xstate permission request. +// Keep local constants to avoid depending on kernel UAPI headers. +static constexpr unsigned long kXfeatureXtilecfg = 17; +static constexpr unsigned long kXfeatureXtiledData = 18; +static constexpr unsigned long kXfeatureMaskXtilecfg = (1UL << kXfeatureXtilecfg); +static constexpr unsigned long kXfeatureMaskXtiledData = (1UL << kXfeatureXtiledData); +static constexpr unsigned long kArchGetXcompPerm = 0x1022; +static constexpr unsigned long kArchReqXcompPerm = 0x1023; + +static inline bool request_linux_amx_perm_this_thread() +{ + unsigned long bitmask = 0; + long status = syscall(SYS_arch_prctl, kArchGetXcompPerm, &bitmask); + if (status != 0) + return false; + + if ((bitmask & kXfeatureMaskXtiledData) != 0) + return true; + + status = syscall(SYS_arch_prctl, kArchReqXcompPerm, kXfeatureXtiledData); + if (status != 0) + return false; + + bitmask = 0; + status = syscall(SYS_arch_prctl, kArchGetXcompPerm, &bitmask); + if (status != 0) + return false; + + return (bitmask & kXfeatureMaskXtiledData) != 0; +} + +static inline bool amx_bf16_runtime_available_impl() +{ + if (!cpu_has_amx_bf16_hw()) + return false; + if (!os_xcr0_allows_amx_state()) + return false; + + // Linux additionally requires per-thread permission before first AMX use. + return request_linux_amx_perm_this_thread(); +} + +#else + +static inline bool amx_bf16_runtime_available_impl() +{ + return false; +} + +#endif + +bool amxbf16_kernels_compiled() +{ +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + return true; +#else + return false; +#endif +} + +bool amxbf16_runtime_available() +{ +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + static thread_local int state = 0; // 0 unknown, 1 ok, -1 no + if (state == 0) + state = amx_bf16_runtime_available_impl() ? 1 : -1; + return state == 1; +#else + return false; +#endif +} + +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + +static inline void bf16_dot_amx_query_batch_impl(const bfloat16 *base, + const bfloat16 *query, + const uint32_t n_vecs, + const uint32_t dim, + float *out) +{ + // Process 32 bf16 elements per AMX dpbf16ps step. + constexpr uint32_t kStep = 32; + const uint32_t blockCount = dim / kStep; + const uint32_t tailCount = dim % kStep; + + // Tile config cache (per-thread) parameterized by A_rows. + // A: [A_rows x 64B], B: [16 x 4B], C: [A_rows x 4B] + alignas(64) static thread_local unsigned char cfg[64]; + static thread_local int prevA = -1; + + const int A_rows = static_cast(n_vecs); + const int N = 1; + const int A_colsb = static_cast(kStep * 2); // 64 + const int B_colsb = N * 4; // 4 bytes per row (2 bf16) + const int B_rows = static_cast(kStep / 2); // 16 rows + const int C_rows = A_rows; + const int C_colsb = N * 4; // 4 + + if (prevA != A_rows) + { + std::memset(cfg, 0, sizeof(cfg)); + cfg[0] = 1; + + // tile0: A + cfg[16] = (unsigned char)A_colsb; + cfg[48] = (unsigned char)A_rows; + + // tile1: B + cfg[18] = (unsigned char)B_colsb; + cfg[49] = (unsigned char)B_rows; + + // tile2: C + cfg[20] = (unsigned char)C_colsb; + cfg[50] = (unsigned char)C_rows; + + _tile_loadconfig((void *)cfg); + prevA = A_rows; + } + + _tile_zero(2); + + const int a_stride = static_cast(dim * sizeof(bfloat16)); + + for (uint32_t blk = 0; blk < blockCount; ++blk) + { + const uint32_t elem_off = blk * kStep; + _tile_loadd(0, (const void *)(base + elem_off), a_stride); + _tile_loadd(1, (const void *)(query + elem_off), 4); + _tile_dpbf16ps(2, 0, 1); + } + + // Store results: N=1, stride=4 bytes => out[0..A_rows-1] + _tile_stored(2, (void *)out, 4); + + // Tail correction (dim % 32) + if (tailCount != 0) + { + const uint32_t base_elem = blockCount * kStep; + for (uint32_t r = 0; r < n_vecs; ++r) + { + out[r] += bf16_dot_scalar(base + r * dim + base_elem, query + base_elem, tailCount); + } + } +} + +#endif + +float bf16_dot_f32_accum_amx(const bfloat16 *a, const bfloat16 *b, uint32_t length) +{ +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (!amxbf16_runtime_available()) + return bf16_dot_scalar(a, b, length); + + // Avoid AMX overhead for tiny vectors. + if (length < 256) + return bf16_dot_scalar(a, b, length); + + float out = 0.0f; + bf16_dot_amx_query_batch_impl(a, b, 1, length, &out); + return out; +#else + return bf16_dot_scalar(a, b, length); +#endif +} + +void bf16_dot_f32_accum_amx_batch(const bfloat16 *base, + const bfloat16 *query, + uint32_t n_vecs, + uint32_t dim, + float *out) +{ +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (!amxbf16_runtime_available()) + { + for (uint32_t i = 0; i < n_vecs; ++i) + out[i] = bf16_dot_scalar(base + i * dim, query, dim); + return; + } + + // Avoid AMX overhead for tiny batches or tiny dims. + if (n_vecs == 0) + return; + if (dim < 256) + { + for (uint32_t i = 0; i < n_vecs; ++i) + out[i] = bf16_dot_scalar(base + i * dim, query, dim); + return; + } + + // Kernel supports up to 16 rows per tile. + constexpr uint32_t kMaxRows = 16; + uint32_t offset = 0; + while (offset < n_vecs) + { + const uint32_t cur = std::min(kMaxRows, n_vecs - offset); + bf16_dot_amx_query_batch_impl(base + offset * dim, query, cur, dim, out + offset); + offset += cur; + } +#else + for (uint32_t i = 0; i < n_vecs; ++i) + out[i] = bf16_dot_scalar(base + i * dim, query, dim); +#endif +} + +void bf16_dot_f32_accum_amx_matmul(const bfloat16 *base, + const bfloat16 *queries, + uint32_t n_base, + uint32_t n_queries, + uint32_t dim, + float *out) +{ + if (n_base == 0 || n_queries == 0) + return; + +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (!amxbf16_runtime_available() || dim < 256) + { + for (uint32_t i = 0; i < n_base; ++i) + for (uint32_t j = 0; j < n_queries; ++j) + out[i * n_queries + j] = bf16_dot_scalar(base + i * dim, queries + j * dim, dim); + return; + } + + // We compute a (Mb x Nb) tile of C = A * B where: + // - A is [Mb x 32 bf16] (64 bytes per row) + // - B is [16 x (Nb*4) bytes] where each column is one query (packed as bf16), and 16 rows correspond to 32 bf16. + // - C is [Mb x (Nb*4) bytes] storing fp32 accumulators. + // This is a direct generalization of the existing N=1 kernel. + constexpr uint32_t kStep = 32; + const uint32_t blockCount = dim / kStep; + const uint32_t tailCount = dim % kStep; + + constexpr uint32_t kMaxMb = 16; // tile row capacity (bytes_per_row=64) + constexpr uint32_t kMaxNb = 16; // keep B_colsb = Nb*4 within 64 bytes + + alignas(64) static thread_local unsigned char cfg[64]; + static thread_local int prevMb = -1; + static thread_local int prevNb = -1; + + // Pack buffer for a B tile: [16 x (Nb*4) bytes]. + // We will pack 32 bf16 from each query into 16 rows (each row holds 2 bf16 per query => 4 bytes). + alignas(64) static thread_local uint8_t bpack[16 * 64]; + alignas(64) static thread_local float cbuf[16 * 16]; + + for (uint32_t i0 = 0; i0 < n_base; i0 += kMaxMb) + { + const uint32_t Mb = std::min(kMaxMb, n_base - i0); + const int A_rows = static_cast(Mb); + const int A_colsb = static_cast(kStep * sizeof(bfloat16)); // 64 + + for (uint32_t j0 = 0; j0 < n_queries; j0 += kMaxNb) + { + const uint32_t Nb = std::min(kMaxNb, n_queries - j0); + const int N = static_cast(Nb); + const int B_colsb = N * 4; + const int B_rows = static_cast(kStep / 2); // 16 + const int C_rows = A_rows; + const int C_colsb = N * 4; + + if (prevMb != A_rows || prevNb != N) + { + std::memset(cfg, 0, sizeof(cfg)); + cfg[0] = 1; + // tile0: A + cfg[16] = (unsigned char)A_colsb; + cfg[48] = (unsigned char)A_rows; + // tile1: B + cfg[18] = (unsigned char)B_colsb; + cfg[49] = (unsigned char)B_rows; + // tile2: C + cfg[20] = (unsigned char)C_colsb; + cfg[50] = (unsigned char)C_rows; + + _tile_loadconfig((void *)cfg); + prevMb = A_rows; + prevNb = N; + } + + _tile_zero(2); + + const int a_stride = static_cast(dim * sizeof(bfloat16)); + + for (uint32_t blk = 0; blk < blockCount; ++blk) + { + const uint32_t elem_off = blk * kStep; + + // Pack B for this block. + // Layout: bpack[row][col] where row in [0..15], col in bytes [0..B_colsb). + // For each query q, we take 32 bf16 starting at elem_off. For each pair (2 bf16) + // we write 4 bytes into bpack[row] at offset q*4. + const bfloat16 *qptr = queries + (j0 * dim) + elem_off; + for (uint32_t r = 0; r < 16; ++r) + { + uint8_t *dst_row = bpack + r * 64; + for (uint32_t q = 0; q < Nb; ++q) + { + const uint16_t *src16 = reinterpret_cast(qptr + q * dim + (r * 2)); + std::memcpy(dst_row + q * 4, src16, 4); + } + } + + _tile_loadd(0, (const void *)(base + (i0 * dim) + elem_off), a_stride); + _tile_loadd(1, (const void *)bpack, 64); + _tile_dpbf16ps(2, 0, 1); + } + + // Store C tile. C is [Mb x Nb] fp32 laid out with row stride (Nb*4 bytes). + _tile_stored(2, (void *)cbuf, (int)(Nb * sizeof(float))); + + // Write out with tail correction. + for (uint32_t ii = 0; ii < Mb; ++ii) + { + for (uint32_t jj = 0; jj < Nb; ++jj) + { + float v = cbuf[ii * Nb + jj]; + if (tailCount != 0) + { + const uint32_t base_elem = blockCount * kStep; + v += bf16_dot_scalar(base + (i0 + ii) * dim + base_elem, queries + (j0 + jj) * dim + base_elem, + tailCount); + } + out[(i0 + ii) * n_queries + (j0 + jj)] = v; + } + } + } + } +#else + for (uint32_t i = 0; i < n_base; ++i) + for (uint32_t j = 0; j < n_queries; ++j) + out[i * n_queries + j] = bf16_dot_scalar(base + i * dim, queries + j * dim, dim); +#endif +} + +void bf16_dot_f32_accum_amx_matmul_gather(const bfloat16 *data, + uint32_t data_stride, + const uint32_t *base_ids, + uint32_t n_base, + const uint32_t *query_ids, + uint32_t n_queries, + uint32_t dim, + float *out) +{ + if (n_base == 0 || n_queries == 0) + return; + +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (!amxbf16_runtime_available() || dim < 256) + { + for (uint32_t i = 0; i < n_base; ++i) + { + const bfloat16 *a = data + (size_t)base_ids[i] * (size_t)data_stride; + for (uint32_t j = 0; j < n_queries; ++j) + { + const bfloat16 *b = data + (size_t)query_ids[j] * (size_t)data_stride; + out[i * n_queries + j] = bf16_dot_scalar(a, b, dim); + } + } + return; + } + + constexpr uint32_t kStep = 32; + const uint32_t blockCount = dim / kStep; + const uint32_t tailCount = dim % kStep; + + constexpr uint32_t kMaxMb = 16; + constexpr uint32_t kMaxNb = 16; + + alignas(64) static thread_local unsigned char cfg[64]; + static thread_local int prevMb = -1; + static thread_local int prevNb = -1; + + alignas(64) static thread_local uint8_t apack[16 * 64]; + alignas(64) static thread_local uint8_t bpack[16 * 64]; + alignas(64) static thread_local float cbuf[16 * 16]; + + for (uint32_t i0 = 0; i0 < n_base; i0 += kMaxMb) + { + const uint32_t Mb = std::min(kMaxMb, n_base - i0); + const int A_rows = static_cast(Mb); + const int A_colsb = static_cast(kStep * sizeof(bfloat16)); // 64 + + for (uint32_t j0 = 0; j0 < n_queries; j0 += kMaxNb) + { + const uint32_t Nb = std::min(kMaxNb, n_queries - j0); + const int N = static_cast(Nb); + const int B_colsb = N * 4; + const int B_rows = static_cast(kStep / 2); // 16 + const int C_rows = A_rows; + const int C_colsb = N * 4; + + if (prevMb != A_rows || prevNb != N) + { + std::memset(cfg, 0, sizeof(cfg)); + cfg[0] = 1; + // tile0: A + cfg[16] = (unsigned char)A_colsb; + cfg[48] = (unsigned char)A_rows; + // tile1: B + cfg[18] = (unsigned char)B_colsb; + cfg[49] = (unsigned char)B_rows; + // tile2: C + cfg[20] = (unsigned char)C_colsb; + cfg[50] = (unsigned char)C_rows; + + _tile_loadconfig((void *)cfg); + prevMb = A_rows; + prevNb = N; + } + + _tile_zero(2); + + for (uint32_t blk = 0; blk < blockCount; ++blk) + { + const uint32_t elem_off = blk * kStep; + + // Pack A rows for this block: apack[row][0..63] = 32 bf16 + for (uint32_t r = 0; r < Mb; ++r) + { + const bfloat16 *src = data + (size_t)base_ids[i0 + r] * (size_t)data_stride + elem_off; + std::memcpy(apack + r * 64, src, 64); + } + + // Pack B for this block. + for (uint32_t r = 0; r < 16; ++r) + { + uint8_t *dst_row = bpack + r * 64; + for (uint32_t q = 0; q < Nb; ++q) + { + const bfloat16 *srcq = + data + (size_t)query_ids[j0 + q] * (size_t)data_stride + elem_off + (r * 2); + const uint16_t *src16 = reinterpret_cast(srcq); + std::memcpy(dst_row + q * 4, src16, 4); + } + } + + _tile_loadd(0, (const void *)apack, 64); + _tile_loadd(1, (const void *)bpack, 64); + _tile_dpbf16ps(2, 0, 1); + } + + _tile_stored(2, (void *)cbuf, (int)(Nb * sizeof(float))); + + for (uint32_t ii = 0; ii < Mb; ++ii) + { + const bfloat16 *a_full = data + (size_t)base_ids[i0 + ii] * (size_t)data_stride; + for (uint32_t jj = 0; jj < Nb; ++jj) + { + float v = cbuf[ii * Nb + jj]; + if (tailCount != 0) + { + const uint32_t base_elem = blockCount * kStep; + const bfloat16 *b_full = data + (size_t)query_ids[j0 + jj] * (size_t)data_stride; + v += bf16_dot_scalar(a_full + base_elem, b_full + base_elem, tailCount); + } + out[(i0 + ii) * n_queries + (j0 + jj)] = v; + } + } + } + } +#else + (void)data_stride; + for (uint32_t i = 0; i < n_base; ++i) + { + const bfloat16 *a = data + (size_t)base_ids[i] * (size_t)data_stride; + for (uint32_t j = 0; j < n_queries; ++j) + { + const bfloat16 *b = data + (size_t)query_ids[j] * (size_t)data_stride; + out[i * n_queries + j] = bf16_dot_scalar(a, b, dim); + } + } +#endif +} + +} // namespace diskann diff --git a/src/bf16_simd_kernels.cpp b/src/bf16_simd_kernels.cpp new file mode 100644 index 000000000..275f23139 --- /dev/null +++ b/src/bf16_simd_kernels.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "bf16_simd_kernels.h" + +#include +#include + +#if defined(__AVX512BF16__) && defined(__AVX512F__) +#include +#endif + +namespace diskann +{ + +static inline float bf16_dot_scalar(const bfloat16 *a, const bfloat16 *b, uint32_t length) +{ + float dot = 0.0f; +#ifndef _WINDOWS +#pragma omp simd reduction(+ : dot) aligned(a, b : 8) +#endif + for (int32_t i = 0; i < (int32_t)length; i++) + { + dot += a[i].to_float() * b[i].to_float(); + } + return dot; +} + +#if defined(__AVX512BF16__) && defined(__AVX512F__) + +bool avx512bf16_kernels_compiled() +{ + return true; +} + +// AVX-512 BF16 dot: each _mm512_dpbf16_ps consumes 32 bf16 elements and accumulates +// into 16 fp32 lanes (pairwise dot). We reduce the accumulator at the end. +float bf16_dot_f32_accum(const bfloat16 *a, const bfloat16 *b, uint32_t length) +{ + constexpr uint32_t kStep = 32; + + __m512 acc = _mm512_setzero_ps(); + uint32_t i = 0; + + for (; i + (kStep - 1) < length; i += kStep) + { + // Load 32 bf16 values (64 bytes) for each vector. + const __m512i va_i = _mm512_loadu_si512((const void *)(a + i)); + const __m512i vb_i = _mm512_loadu_si512((const void *)(b + i)); + + // Reinterpret as bf16 vectors. + const __m512bh va = (__m512bh)va_i; + const __m512bh vb = (__m512bh)vb_i; + + acc = _mm512_dpbf16_ps(acc, va, vb); + } + + alignas(64) float lanes[16]; + _mm512_store_ps(lanes, acc); + + float dot = 0.0f; + for (int lane = 0; lane < 16; ++lane) + dot += lanes[lane]; + + // Remainder. + if (i < length) + { + dot += bf16_dot_scalar(a + i, b + i, length - i); + } + + return dot; +} + +#else + +bool avx512bf16_kernels_compiled() +{ + return false; +} + +float bf16_dot_f32_accum(const bfloat16 *a, const bfloat16 *b, uint32_t length) +{ + return bf16_dot_scalar(a, b, length); +} + +#endif + +} // namespace diskann diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 22f1e98fd..c796b9c94 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -16,9 +16,129 @@ #include "percentile_stats.h" #include "partition.h" #include "pq_flash_index.h" +#include "rabitq.h" #include "timer.h" #include "tsl/robin_set.h" +namespace +{ +#pragma pack(push, 1) +struct RaBitQReorderHeader +{ + char magic[8]; + uint32_t version; + uint32_t metric; + uint32_t nb_bits; + uint32_t dim; + uint64_t num_points; + uint64_t code_size; +}; +#pragma pack(pop) + +static void write_rabitq_reorder_header(std::ofstream &out, uint32_t metric, uint32_t nb_bits, uint32_t dim, + uint64_t num_points, uint64_t code_size) +{ + RaBitQReorderHeader hdr; + std::memset(&hdr, 0, sizeof(hdr)); + hdr.magic[0] = 'D'; + hdr.magic[1] = 'A'; + hdr.magic[2] = 'R'; + hdr.magic[3] = 'B'; + hdr.magic[4] = 'Q'; + hdr.magic[5] = '1'; + hdr.magic[6] = '\0'; + hdr.magic[7] = '\0'; + hdr.version = 1; + hdr.metric = metric; + hdr.nb_bits = nb_bits; + hdr.dim = dim; + hdr.num_points = num_points; + hdr.code_size = code_size; + out.write(reinterpret_cast(&hdr), sizeof(hdr)); +} + +template +static void generate_rabitq_reorder_codes_from_bin(const std::string &data_file_to_use, const std::string &output_file, + diskann::rabitq::Metric metric, uint32_t nb_bits) +{ + std::ifstream in(data_file_to_use, std::ios::binary); + if (!in) + { + throw diskann::ANNException("Failed to open data file for RaBitQ code generation: " + data_file_to_use, -1); + } + + uint32_t npts_u32 = 0, dim_u32 = 0; + in.read(reinterpret_cast(&npts_u32), sizeof(uint32_t)); + in.read(reinterpret_cast(&dim_u32), sizeof(uint32_t)); + if (!in) + { + throw diskann::ANNException("Failed reading header from data file for RaBitQ code generation: " + + data_file_to_use, + -1); + } + + const uint64_t npts = npts_u32; + const uint64_t dim = dim_u32; + const uint64_t code_size = + diskann::rabitq::compute_code_size(static_cast(dim), static_cast(nb_bits)); + + std::ofstream out(output_file, std::ios::binary); + if (!out) + { + throw diskann::ANNException("Failed to open output file for RaBitQ code generation: " + output_file, -1); + } + + write_rabitq_reorder_header(out, static_cast(metric), nb_bits, dim_u32, npts, code_size); + + if (npts == 0) + return; + + const uint64_t kBlockPts = 100000; + const uint64_t block_pts = std::min(kBlockPts, npts); + std::vector in_block; + in_block.resize(static_cast(block_pts * dim)); + + std::vector out_codes; + out_codes.resize(static_cast(block_pts * code_size)); + + std::vector tmp; + tmp.resize(static_cast(dim)); + + const uint64_t num_blocks = DIV_ROUND_UP(npts, block_pts); + for (uint64_t b = 0; b < num_blocks; ++b) + { + const uint64_t start_id = b * block_pts; + const uint64_t end_id = std::min(npts, start_id + block_pts); + const uint64_t cur_pts = end_id - start_id; + + in.read(reinterpret_cast(in_block.data()), static_cast(cur_pts * dim * sizeof(T))); + if (!in) + { + throw diskann::ANNException("Failed reading data payload from: " + data_file_to_use, -1); + } + + for (uint64_t i = 0; i < cur_pts; ++i) + { + const T *row = in_block.data() + i * dim; + for (uint64_t j = 0; j < dim; ++j) + { + tmp[static_cast(j)] = static_cast(row[j]); + } + uint8_t *code = out_codes.data() + i * code_size; + diskann::rabitq::encode_vector(tmp.data(), static_cast(dim), metric, static_cast(nb_bits), + code); + } + + out.write(reinterpret_cast(out_codes.data()), static_cast(cur_pts * code_size)); + if (!out) + { + throw diskann::ANNException("Failed writing RaBitQ codes to: " + output_file, -1); + } + } +} + +} // namespace + namespace diskann { @@ -147,7 +267,14 @@ template T *generateRandomWarmup(uint64_t warmup_num, uint64_t warm { for (uint32_t d = 0; d < warmup_dim; d++) { - warmup[i * warmup_aligned_dim + d] = (T)dis(gen); + if constexpr (std::is_same::value) + { + warmup[i * warmup_aligned_dim + d] = (T)(float)dis(gen); + } + else + { + warmup[i * warmup_aligned_dim + d] = (T)dis(gen); + } } } diskann::cout << "..done" << std::endl; @@ -843,7 +970,7 @@ uint32_t optimize_beamwidth(std::unique_ptr> &p return best_bw; } -template +template void create_disk_layout(const std::string base_file, const std::string mem_index_file, const std::string output_file, const std::string reorder_data_file) { @@ -880,7 +1007,8 @@ void create_disk_layout(const std::string base_file, const std::string mem_index throw ANNException("Mismatch in num_points between reorder " "data file and base file", -1, __FUNCSIG__, __FILE__, __LINE__); - if (reorder_data_file_size != 8 + sizeof(float) * (size_t)npts_reorder_file * (size_t)ndims_reorder_file) + if (reorder_data_file_size != + 8 + sizeof(ReorderT) * (size_t)npts_reorder_file * (size_t)ndims_reorder_file) throw ANNException("Discrepancy in reorder data file size ", -1, __FUNCSIG__, __FILE__, __LINE__); } catch (std::system_error &e) @@ -942,7 +1070,7 @@ void create_disk_layout(const std::string base_file, const std::string mem_index if (append_reorder_data) { - n_data_nodes_per_sector = defaults::SECTOR_LEN / (ndims_reorder_file * sizeof(float)); + n_data_nodes_per_sector = defaults::SECTOR_LEN / (ndims_reorder_file * sizeof(ReorderT)); n_reorder_sectors = ROUND_UP(npts_64, n_data_nodes_per_sector) / n_data_nodes_per_sector; } uint64_t disk_index_file_size = (n_sectors + n_reorder_sectors + 1) * defaults::SECTOR_LEN; @@ -961,6 +1089,7 @@ void create_disk_layout(const std::string base_file, const std::string mem_index output_file_meta.push_back(n_sectors + 1); output_file_meta.push_back(ndims_reorder_file); output_file_meta.push_back(n_data_nodes_per_sector); + output_file_meta.push_back(sizeof(ReorderT)); } output_file_meta.push_back(disk_index_file_size); @@ -1067,7 +1196,7 @@ void create_disk_layout(const std::string base_file, const std::string mem_index { diskann::cout << "Index written. Appending reorder data..." << std::endl; - auto vec_len = ndims_reorder_file * sizeof(float); + auto vec_len = ndims_reorder_file * sizeof(ReorderT); std::unique_ptr vec_buf = std::make_unique(vec_len); for (uint64_t sector = 0; sector < n_reorder_sectors; sector++) @@ -1079,9 +1208,13 @@ void create_disk_layout(const std::string base_file, const std::string mem_index memset(sector_buf.get(), 0, defaults::SECTOR_LEN); - for (uint64_t sector_node_id = 0; sector_node_id < n_data_nodes_per_sector && sector_node_id < npts_64; - sector_node_id++) + for (uint64_t sector_node_id = 0; sector_node_id < n_data_nodes_per_sector; sector_node_id++) { + const uint64_t global_node_id = sector * n_data_nodes_per_sector + sector_node_id; + if (global_node_id >= npts_64) + { + break; + } memset(vec_buf.get(), 0, vec_len); reorder_data_reader.read(vec_buf.get(), vec_len); @@ -1097,6 +1230,14 @@ void create_disk_layout(const std::string base_file, const std::string mem_index diskann::cout << "Output disk index file written to " << output_file << std::endl; } +// Backwards-compatible entry point: reorder data is stored as float. +template +void create_disk_layout(const std::string base_file, const std::string mem_index_file, const std::string output_file, + const std::string reorder_data_file) +{ + create_disk_layout(base_file, mem_index_file, output_file, reorder_data_file); +} + template int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, @@ -1111,7 +1252,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const { param_list.push_back(cur_param); } - if (param_list.size() < 5 || param_list.size() > 9) + if (param_list.size() < 5 || param_list.size() > 11) { diskann::cout << "Correct usage of parameters is R (max degree)\n" "L (indexing list size, better if >= R)\n" @@ -1124,12 +1265,14 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const ": optional paramter, use only when using disk PQ\n" "build_PQ_byte (number of PQ bytes for inde build; set 0 to use " "full precision vectors)\n" - "QD Quantized Dimension to overwrite the derived dim from B " + "QD Quantized Dimension to overwrite the derived dim from B\n" + "build_rabitq_main_codes (0/1, optional; generates _disk.index_rabitq_main.bin)\n" + "rabitq_nb_bits (1..9, optional; default 4)" << std::endl; return -1; } - if (!std::is_same::value && + if (!diskann::is_floating_point_like_v && (compareMetric == diskann::Metric::INNER_PRODUCT || compareMetric == diskann::Metric::COSINE)) { std::stringstream stream; @@ -1168,6 +1311,20 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const build_pq_bytes = atoi(param_list[7].c_str()); } + bool build_rabitq_main_codes = false; + uint32_t rabitq_nb_bits = 4; + if (param_list.size() >= 10) + { + if (1 == atoi(param_list[9].c_str())) + { + build_rabitq_main_codes = true; + } + } + if (param_list.size() >= 11) + { + rabitq_nb_bits = static_cast(atoi(param_list[10].c_str())); + } + std::string base_file(dataFilePath); std::string data_file_to_use = base_file; std::string labels_file_original = label_file; @@ -1227,7 +1384,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const "apart from the interim indices created by DiskANN and the final index." << std::endl; data_file_to_use = prepped_base; - diskann::normalize_data_file(base_file, prepped_base); + diskann::normalize_data_file_typed(base_file, prepped_base); diskann::cout << timer.elapsed_seconds_for_step("preprocessing data for cosine") << std::endl; created_temp_file_for_processed_data = true; } @@ -1338,11 +1495,35 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const if (!reorder_data) diskann::create_disk_layout(disk_pq_compressed_vectors_path, mem_index_path, disk_index_path); else - diskann::create_disk_layout(disk_pq_compressed_vectors_path, mem_index_path, disk_index_path, - data_file_to_use.c_str()); + diskann::create_disk_layout(disk_pq_compressed_vectors_path, mem_index_path, disk_index_path, + data_file_to_use.c_str()); } diskann::cout << timer.elapsed_seconds_for_step("generating disk layout") << std::endl; + if (build_rabitq_main_codes) + { + if (rabitq_nb_bits < 1 || rabitq_nb_bits > 9) + { + throw diskann::ANNException("rabitq_nb_bits must be in [1,9]", -1); + } + if (compareMetric != diskann::Metric::INNER_PRODUCT) + { + throw diskann::ANNException("RaBitQ main code generation is currently supported only for MIPS/IP.", -1); + } + if (!diskann::is_floating_point_like_v) + { + throw diskann::ANNException("RaBitQ main code generation requires floating point data.", -1); + } + + const std::string rabitq_codes_path = disk_index_path + "_rabitq_main.bin"; + Timer rabitq_timer; + diskann::cout << "Generating RaBitQ main codes to " << rabitq_codes_path << " (nb_bits=" << rabitq_nb_bits + << ")" << std::endl; + generate_rabitq_reorder_codes_from_bin(data_file_to_use, rabitq_codes_path, + diskann::rabitq::Metric::INNER_PRODUCT, rabitq_nb_bits); + diskann::cout << rabitq_timer.elapsed_seconds_for_step("generating rabitq main codes") << std::endl; + } + double ten_percent_points = std::ceil(points_num * 0.1); double num_sample_points = ten_percent_points > MAX_SAMPLE_POINTS_FOR_WARMUP ? MAX_SAMPLE_POINTS_FOR_WARMUP : ten_percent_points; @@ -1387,6 +1568,23 @@ template DISKANN_DLLEXPORT void create_disk_layout(const std::string ba template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file, const std::string output_file, const std::string reorder_data_file); +template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, + const std::string mem_index_file, + const std::string output_file, + const std::string reorder_data_file); + +template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, + const std::string mem_index_file, + const std::string output_file, + const std::string reorder_data_file); +template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, + const std::string mem_index_file, + const std::string output_file, + const std::string reorder_data_file); +template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, + const std::string mem_index_file, + const std::string output_file, + const std::string reorder_data_file); template DISKANN_DLLEXPORT int8_t *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim); @@ -1394,6 +1592,9 @@ template DISKANN_DLLEXPORT uint8_t *load_warmup(const std::string &cach uint64_t warmup_dim, uint64_t warmup_aligned_dim); template DISKANN_DLLEXPORT float *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim); +template DISKANN_DLLEXPORT diskann::bfloat16 *load_warmup(const std::string &cache_warmup_file, + uint64_t &warmup_num, uint64_t warmup_dim, + uint64_t warmup_aligned_dim); #ifdef EXEC_ENV_OLS template DISKANN_DLLEXPORT int8_t *load_warmup(MemoryMappedFiles &files, const std::string &cache_warmup_file, @@ -1416,6 +1617,9 @@ template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( std::unique_ptr> &pFlashIndex, float *tuning_sample, uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); +template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( + std::unique_ptr> &pFlashIndex, diskann::bfloat16 *tuning_sample, + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( std::unique_ptr> &pFlashIndex, int8_t *tuning_sample, @@ -1426,6 +1630,9 @@ template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( std::unique_ptr> &pFlashIndex, float *tuning_sample, uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); +template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( + std::unique_ptr> &pFlashIndex, diskann::bfloat16 *tuning_sample, + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, @@ -1448,6 +1655,10 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *dat const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf); +template DISKANN_DLLEXPORT int build_disk_index( + const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, + bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, + const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf); // LabelT = uint16 template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, @@ -1470,6 +1681,10 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *dat const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf); +template DISKANN_DLLEXPORT int build_disk_index( + const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, + bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, + const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, diff --git a/src/distance.cpp b/src/distance.cpp index c2f88c85b..968524856 100644 --- a/src/distance.cpp +++ b/src/distance.cpp @@ -15,6 +15,8 @@ #include #include "distance.h" +#include "bf16_simd_kernels.h" +#include "bf16_amx_kernels.h" #include "utils.h" #include "logger.h" #include "ann_exception.h" @@ -22,6 +24,79 @@ namespace diskann { +float DistanceL2BFloat16::compare(const bfloat16 *a, const bfloat16 *b, uint32_t size) const +{ + if (Avx512Bf16SupportedCPU && avx512bf16_kernels_compiled()) + { + const float aa = bf16_dot_f32_accum(a, a, size); + const float bb = bf16_dot_f32_accum(b, b, size); + const float ab = bf16_dot_f32_accum(a, b, size); + return aa + bb - 2.0f * ab; + } + + float result = 0.0f; +#ifndef _WINDOWS +#pragma omp simd reduction(+ : result) aligned(a, b : 8) +#endif + for (int32_t i = 0; i < (int32_t)size; i++) + { + const float da = a[i].to_float(); + const float db = b[i].to_float(); + const float diff = da - db; + result += diff * diff; + } + return result; +} + +float DistanceCosineBFloat16::compare(const bfloat16 *a, const bfloat16 *b, uint32_t length) const +{ + if (Avx512Bf16SupportedCPU && avx512bf16_kernels_compiled()) + { + const float magA = bf16_dot_f32_accum(a, a, length); + const float magB = bf16_dot_f32_accum(b, b, length); + const float scalarProduct = bf16_dot_f32_accum(a, b, length); + return 1.0f - (scalarProduct / (sqrt(magA) * sqrt(magB))); + } + + float magA = 0.0f, magB = 0.0f, scalarProduct = 0.0f; +#ifndef _WINDOWS +#pragma omp simd reduction(+ : magA, magB, scalarProduct) aligned(a, b : 8) +#endif + for (uint32_t i = 0; i < length; i++) + { + const float da = a[i].to_float(); + const float db = b[i].to_float(); + magA += da * da; + magB += db * db; + scalarProduct += da * db; + } + return 1.0f - (scalarProduct / (sqrt(magA) * sqrt(magB))); +} + +float DistanceInnerProductBFloat16::compare(const bfloat16 *a, const bfloat16 *b, uint32_t length) const +{ + if (amxbf16_kernels_compiled() && amxbf16_runtime_available()) + { + return -bf16_dot_f32_accum_amx(a, b, length); + } + + if (Avx512Bf16SupportedCPU && avx512bf16_kernels_compiled()) + { + return -bf16_dot_f32_accum(a, b, length); + } + + float dot = 0.0f; +#ifndef _WINDOWS +#pragma omp simd reduction(+ : dot) aligned(a, b : 8) +#endif + for (uint32_t i = 0; i < length; i++) + { + dot += a[i].to_float() * b[i].to_float(); + } + // Match DistanceInnerProduct semantics: return negative inner product as a distance. + return -dot; +} + // // Base Class Implementatons // @@ -714,6 +789,32 @@ template <> diskann::Distance *get_distance_function(diskann::Metric m) } } +template <> diskann::Distance *get_distance_function(diskann::Metric m) +{ + if (m == diskann::Metric::L2) + { + diskann::cout << "L2: Using bf16 distance computation DistanceL2BFloat16" << std::endl; + return new diskann::DistanceL2BFloat16(); + } + else if (m == diskann::Metric::COSINE) + { + diskann::cout << "Cosine: Using bf16 distance computation DistanceCosineBFloat16" << std::endl; + return new diskann::DistanceCosineBFloat16(); + } + else if (m == diskann::Metric::INNER_PRODUCT) + { + diskann::cout << "Inner product: Using bf16 distance computation DistanceInnerProductBFloat16" << std::endl; + return new diskann::DistanceInnerProductBFloat16(); + } + else + { + std::stringstream stream; + stream << "Only L2, cosine, and inner product supported for bf16 vectors." << std::endl; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } +} + template DISKANN_DLLEXPORT class DistanceInnerProduct; template DISKANN_DLLEXPORT class DistanceInnerProduct; template DISKANN_DLLEXPORT class DistanceInnerProduct; @@ -729,5 +830,6 @@ template DISKANN_DLLEXPORT class SlowDistanceL2; template DISKANN_DLLEXPORT Distance *get_distance_function(Metric m); template DISKANN_DLLEXPORT Distance *get_distance_function(Metric m); template DISKANN_DLLEXPORT Distance *get_distance_function(Metric m); +template DISKANN_DLLEXPORT Distance *get_distance_function(Metric m); } // namespace diskann diff --git a/src/in_mem_data_store.cpp b/src/in_mem_data_store.cpp index 85dea1af5..2fc0cea61 100644 --- a/src/in_mem_data_store.cpp +++ b/src/in_mem_data_store.cpp @@ -222,8 +222,10 @@ template void InMemDataStore::get_distance(const data_t *preprocessed_query, const std::vector &ids, std::vector &distances, AbstractScratch *scratch_space) const { + // printf("---->InMemDataStore::get_distance ids.size: %ld\n", ids.size()); for (int i = 0; i < ids.size(); i++) { + // printf("ids[i]:%d\n",ids[i]); distances[i] = _distance_fn->compare(preprocessed_query, _data + ids[i] * _aligned_dim, (uint32_t)this->_aligned_dim); } @@ -397,5 +399,6 @@ template Distance *InMemDataStore::get_dist_fn template DISKANN_DLLEXPORT class InMemDataStore; template DISKANN_DLLEXPORT class InMemDataStore; template DISKANN_DLLEXPORT class InMemDataStore; +template DISKANN_DLLEXPORT class InMemDataStore; } // namespace diskann \ No newline at end of file diff --git a/src/index.cpp b/src/index.cpp index 4b38027d7..102c8720b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -3,6 +3,7 @@ #include +#include #include #include "boost/dynamic_bitset.hpp" @@ -22,6 +23,7 @@ #endif #include "index.h" +#include "bf16_amx_kernels.h" #define MAX_POINTS_FOR_USING_BITSET 10000000 @@ -145,14 +147,27 @@ Index::Index(Metric m, const size_t dim, const size_t max_point (size_t)((index_parameters == nullptr ? 0 : index_parameters->max_degree) * defaults::GRAPH_SLACK_FACTOR * 1.05))) { - if (_pq_dist) + if constexpr (std::is_same::value) { - _pq_data_store = IndexFactory::construct_pq_datastore(DataStoreStrategy::MEMORY, max_points + num_frozen_pts, - dim, m, num_pq_chunks, use_opq); + if (_pq_dist) + { + throw ANNException("ERROR: pq_dist_build is not supported for bf16 yet.", -1, __FUNCSIG__, __FILE__, + __LINE__); + } + _pq_data_store = _data_store; } else { - _pq_data_store = _data_store; + if (_pq_dist) + { + _pq_data_store = IndexFactory::construct_pq_datastore(DataStoreStrategy::MEMORY, + max_points + num_frozen_pts, dim, m, + num_pq_chunks, use_opq); + } + else + { + _pq_data_store = _data_store; + } } } @@ -1089,6 +1104,65 @@ void Index::occlude_list(const uint32_t location, std::vector) + { + return false; + } + if (!((_dist_metric == diskann::Metric::L2) || (_dist_metric == diskann::Metric::COSINE))) + { + return false; + } + if (!(amxbf16_kernels_compiled() && amxbf16_runtime_available())) + { + return false; + } + if (_data_store == nullptr) + { + return false; + } + const T *raw = _data_store->get_raw_data(); + return raw != nullptr; + }(); + + std::vector candidate_ids; + std::vector candidate_norms; + std::vector candidate_dots; + size_t candidate_count = 0; + size_t aligned_dim = 0; + + if (use_amx_mxn) + { + candidate_count = pool.size(); + aligned_dim = _data_store->get_aligned_dim(); + const diskann::bfloat16 *raw = reinterpret_cast(_data_store->get_raw_data()); + + candidate_ids.resize(candidate_count); + candidate_norms.resize(candidate_count); + + for (size_t i = 0; i < candidate_count; ++i) + { + candidate_ids[i] = pool[i].id; + const diskann::bfloat16 *vec = raw + (static_cast(candidate_ids[i]) * aligned_dim); + float norm = 0.0f; + for (size_t d = 0; d < aligned_dim; ++d) + { + const float v = vec[d].to_float(); + norm += v * v; + } + candidate_norms[i] = norm; + } + + candidate_dots.resize(candidate_count * candidate_count); + bf16_dot_f32_accum_amx_matmul_gather(raw, (uint32_t)aligned_dim, candidate_ids.data(), (uint32_t)candidate_count, + candidate_ids.data(), (uint32_t)candidate_count, (uint32_t)aligned_dim, + candidate_dots.data()); + } + float cur_alpha = 1; while (cur_alpha <= alpha && result.size() < degree) { @@ -1142,7 +1216,30 @@ void Index::occlude_list(const uint32_t location, std::vectorget_distance(iter2->id, iter->id); + float djk = 0.0f; + if (use_amx_mxn) + { + const size_t i = (size_t)(iter - pool.begin()); + const size_t j = (size_t)(iter2 - pool.begin()); + const float dot = candidate_dots[j * candidate_count + i]; + + if (_dist_metric == diskann::Metric::L2) + { + djk = candidate_norms[i] + candidate_norms[j] - 2.0f * dot; + if (djk < 0.0f) + djk = 0.0f; + } + else + { + const float denom = std::sqrt(candidate_norms[i]) * std::sqrt(candidate_norms[j]); + const float cos_sim = (denom > 0.0f) ? (dot / denom) : 0.0f; + djk = 1.0f - cos_sim; + } + } + else + { + djk = _data_store->get_distance(iter2->id, iter->id); + } if (_dist_metric == diskann::Metric::L2 || _dist_metric == diskann::Metric::COSINE) { occlude_factor[t] = (djk == 0) ? std::numeric_limits::max() @@ -3203,6 +3300,13 @@ template void Index::value) + { + throw diskann::ANNException( + "Optimized index layout is only supported for float (FAST_L2). For bf16, disable optimized layout.", -1, + __FUNCSIG__, __FILE__, __LINE__); + } + float *cur_vec = new float[_data_store->get_aligned_dim()]; std::memset(cur_vec, 0, _data_store->get_aligned_dim() * sizeof(float)); _data_len = (_data_store->get_aligned_dim() + 1) * sizeof(float); @@ -3253,6 +3357,13 @@ void Index::_search_with_optimized_layout(const DataType &query template void Index::search_with_optimized_layout(const T *query, size_t K, size_t L, uint32_t *indices) { + if constexpr (std::is_same::value) + { + throw diskann::ANNException( + "search_with_optimized_layout is only supported for float (FAST_L2). For bf16, disable optimized layout.", + -1, __FUNCSIG__, __FILE__, __LINE__); + } + DistanceFastL2 *dist_fast = (DistanceFastL2 *)(_data_store->get_dist_fn()); NeighborPriorityQueue retset(L); @@ -3340,15 +3451,19 @@ template const float Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; @@ -3356,15 +3471,19 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; diff --git a/src/index_factory.cpp b/src/index_factory.cpp index 35790f8d6..c5f2dca9e 100644 --- a/src/index_factory.cpp +++ b/src/index_factory.cpp @@ -1,6 +1,8 @@ #include "index_factory.h" #include "pq_l2_distance.h" +#include "bfloat16.h" + namespace diskann { @@ -34,13 +36,17 @@ void IndexFactory::check_config() -1, __FUNCSIG__, __FILE__, __LINE__); } - if (_config->data_type != "float" && _config->data_type != "uint8" && _config->data_type != "int8") + const bool is_bf16 = (_config->data_type == "bf16" || _config->data_type == "bfloat16"); + if (_config->data_type != "float" && _config->data_type != "uint8" && _config->data_type != "int8" && + !is_bf16) { throw ANNException("ERROR: invalid data type : + " + _config->data_type + - " is not supported. please select from [float, int8, uint8]", + " is not supported. please select from [float, int8, uint8, bf16]", -1); } + // bf16 now supports pq_dist_build via PQDataStore (internally converts queries to float). + if (_config->tag_type != "int32" && _config->tag_type != "uint32" && _config->tag_type != "int64" && _config->tag_type != "uint64") { @@ -126,9 +132,12 @@ std::unique_ptr IndexFactory::create_instance() if (_config->data_strategy == DataStoreStrategy::MEMORY && _config->pq_dist_build) { - pq_data_store = - construct_pq_datastore(_config->data_strategy, num_points + _config->num_frozen_pts, dim, - _config->metric, _config->num_pq_chunks, _config->use_opq); + pq_data_store = construct_pq_datastore(_config->data_strategy, + num_points + _config->num_frozen_pts, + dim, + _config->metric, + _config->num_pq_chunks, + _config->use_opq); } else { @@ -161,8 +170,12 @@ std::unique_ptr IndexFactory::create_instance(const std::string & { return create_instance(tag_type, label_type); } + else if (data_type == std::string("bf16") || data_type == std::string("bfloat16")) + { + return create_instance(tag_type, label_type); + } else - throw ANNException("Error: unsupported data_type please choose from [float/int8/uint8]", -1); + throw ANNException("Error: unsupported data_type please choose from [float/int8/uint8/bf16]", -1); } template diff --git a/src/partition.cpp b/src/partition.cpp index d0061708a..b198aabd5 100644 --- a/src/partition.cpp +++ b/src/partition.cpp @@ -21,6 +21,7 @@ #include "parameters.h" #include "memory_mapper.h" #include "partition.h" +#include "bfloat16.h" #ifdef _WINDOWS #include #endif @@ -611,6 +612,9 @@ template void DISKANN_DLLEXPORT gen_random_slice(const std::string base double sampling_rate); template void DISKANN_DLLEXPORT gen_random_slice(const std::string base_file, const std::string output_prefix, double sampling_rate); +template void DISKANN_DLLEXPORT gen_random_slice(const std::string base_file, + const std::string output_prefix, + double sampling_rate); template void DISKANN_DLLEXPORT gen_random_slice(const float *inputdata, size_t npts, size_t ndims, double p_val, float *&sampled_data, size_t &slice_size); @@ -625,6 +629,9 @@ template void DISKANN_DLLEXPORT gen_random_slice(const std::string data float *&sampled_data, size_t &slice_size, size_t &ndims); template void DISKANN_DLLEXPORT gen_random_slice(const std::string data_file, double p_val, float *&sampled_data, size_t &slice_size, size_t &ndims); +template void DISKANN_DLLEXPORT gen_random_slice(const std::string data_file, double p_val, + float *&sampled_data, size_t &slice_size, + size_t &ndims); template DISKANN_DLLEXPORT int partition(const std::string data_file, const float sampling_rate, size_t num_centers, size_t max_k_means_reps, @@ -647,6 +654,9 @@ template DISKANN_DLLEXPORT int partition_with_ram_budget(const std::str template DISKANN_DLLEXPORT int partition_with_ram_budget(const std::string data_file, const double sampling_rate, double ram_budget, size_t graph_degree, const std::string prefix_path, size_t k_base); +template DISKANN_DLLEXPORT int partition_with_ram_budget( + const std::string data_file, const double sampling_rate, double ram_budget, size_t graph_degree, + const std::string prefix_path, size_t k_base); template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids(const std::string data_file, std::string idmap_filename, @@ -656,4 +666,7 @@ template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids(const std:: std::string data_filename); template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids(const std::string data_file, std::string idmap_filename, - std::string data_filename); \ No newline at end of file + std::string data_filename); +template DISKANN_DLLEXPORT int retrieve_shard_data_from_ids(const std::string data_file, + std::string idmap_filename, + std::string data_filename); \ No newline at end of file diff --git a/src/pq.cpp b/src/pq.cpp index d2b545c79..837069210 100644 --- a/src/pq.cpp +++ b/src/pq.cpp @@ -9,6 +9,7 @@ #include "partition.h" #include "math_utils.h" #include "tsl/robin_map.h" +#include "bfloat16.h" // block size for reading/processing large files and matrices in blocks #define BLOCK_SIZE 5000000 @@ -258,6 +259,7 @@ void FixedChunkPQTable::populate_chunk_inner_products(const float *query_vec, fl { memset(dist_vec, 0, 256 * n_chunks * sizeof(float)); // chunk wise distance computation + for (size_t chunk = 0; chunk < n_chunks; chunk++) { // sum (q-c)^2 for the dimensions associated with this chunk @@ -1081,12 +1083,10 @@ void generate_disk_quantized_data(const std::string &data_file_to_use, const std std::cout << "Compressing base for disk-PQ into " << disk_pq_dims << " chunks " << std::endl; generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, 256, (uint32_t)disk_pq_dims, NUM_KMEANS_REPS_PQ, disk_pq_pivots_path, false); - if (compareMetric == diskann::Metric::INNER_PRODUCT) - generate_pq_data_from_pivots(data_file_to_use, 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path, - disk_pq_compressed_vectors_path); - else - generate_pq_data_from_pivots(data_file_to_use, 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path, - disk_pq_compressed_vectors_path); + // For INNER_PRODUCT we may preprocess the base vectors into a temp file. That temp file must be + // read using the same element type that was written (T). Historically this was always float. + generate_pq_data_from_pivots(data_file_to_use, 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path, + disk_pq_compressed_vectors_path); delete[] train_data; } @@ -1148,6 +1148,9 @@ template DISKANN_DLLEXPORT int generate_pq_data_from_pivots(const std::st const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path, bool use_opq); +template DISKANN_DLLEXPORT int generate_pq_data_from_pivots( + const std::string &data_file, uint32_t num_centers, uint32_t num_pq_chunks, const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, bool use_opq); template DISKANN_DLLEXPORT void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, @@ -1165,6 +1168,10 @@ template DISKANN_DLLEXPORT void generate_disk_quantized_data(const std::s const std::string &disk_pq_compressed_vectors_path, diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims); +template DISKANN_DLLEXPORT void generate_disk_quantized_data( + const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, + const std::string &disk_pq_compressed_vectors_path, diskann::Metric compareMetric, const double p_val, + size_t &disk_pq_dims); template DISKANN_DLLEXPORT void generate_quantized_data(const std::string &data_file_to_use, const std::string &pq_pivots_path, @@ -1186,4 +1193,8 @@ template DISKANN_DLLEXPORT void generate_quantized_data(const std::string diskann::Metric compareMetric, const double p_val, const size_t num_pq_chunks, const bool use_opq, const std::string &codebook_prefix); +template DISKANN_DLLEXPORT void generate_quantized_data( + const std::string &data_file_to_use, const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path, + diskann::Metric compareMetric, const double p_val, const size_t num_pq_chunks, const bool use_opq, + const std::string &codebook_prefix); } // namespace diskann diff --git a/src/pq_data_store.cpp b/src/pq_data_store.cpp index c47c16705..e5058078f 100644 --- a/src/pq_data_store.cpp +++ b/src/pq_data_store.cpp @@ -5,6 +5,7 @@ #include "pq_scratch.h" #include "utils.h" #include "distance.h" +#include "bfloat16.h" namespace diskann { @@ -256,5 +257,6 @@ template location_t PQDataStore::load_impl(AlignedFile template DISKANN_DLLEXPORT class PQDataStore; template DISKANN_DLLEXPORT class PQDataStore; template DISKANN_DLLEXPORT class PQDataStore; +template DISKANN_DLLEXPORT class PQDataStore; } // namespace diskann \ No newline at end of file diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index d9ad50617..b52fb3b39 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -9,6 +9,18 @@ #include "pq_flash_index.h" #include "cosine_similarity.h" +#include "bf16_amx_kernels.h" +#include "rabitq.h" + +#include + +namespace +{ +std::once_flag g_reorder_amx_msg_once; +std::once_flag g_reorder_std_msg_once; +std::once_flag g_main_rabitq_msg_once; +} // namespace + #ifdef _WINDOWS #include "windows_aligned_file_reader.h" #else @@ -22,9 +34,6 @@ // sector # beyond the end of graph where data for id is present for reordering #define VECTOR_SECTOR_NO(id) (((uint64_t)(id)) / _nvecs_per_sector + _reorder_data_start_sector) -// sector # beyond the end of graph where data for id is present for reordering -#define VECTOR_SECTOR_OFFSET(id) ((((uint64_t)(id)) % _nvecs_per_sector) * _data_dim * sizeof(float)) - namespace diskann { @@ -35,7 +44,7 @@ PQFlashIndex::PQFlashIndex(std::shared_ptr &fileRe diskann::Metric metric_to_invoke = m; if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) { - if (std::is_floating_point::value) + if (diskann::is_floating_point_like_v) { diskann::cout << "Since data is floating point, we assume that it has been appropriately pre-processed " "(normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we " @@ -97,6 +106,12 @@ template PQFlashIndex::~PQFlashIndex() { delete[] _medoids; } + + if (_rabitq_main_codes != nullptr) + { + aligned_free(_rabitq_main_codes); + _rabitq_main_codes = nullptr; + } } template inline uint64_t PQFlashIndex::get_node_sector(uint64_t node_id) @@ -791,6 +806,10 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons std::string medoids_file = std::string(_disk_index_file) + "_medoids.bin"; std::string centroids_file = std::string(_disk_index_file) + "_centroids.bin"; + // Optional: RaBitQ codes for main-search approximate scoring. Stored alongside the disk index file. + // File name convention: _rabitq_main.bin + std::string rabitq_main_file = std::string(_disk_index_file) + "_rabitq_main.bin"; + std::string labels_file = std ::string(_disk_index_file) + "_labels.txt"; std::string labels_to_medoids = std ::string(_disk_index_file) + "_labels_to_medoids.txt"; std::string dummy_map_file = std ::string(_disk_index_file) + "_dummy_map.txt"; @@ -818,200 +837,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons this->_disk_bytes_per_point = this->_data_dim * sizeof(T); this->_aligned_dim = ROUND_UP(pq_file_dim, 8); - size_t npts_u64, nchunks_u64; -#ifdef EXEC_ENV_OLS - diskann::load_bin(files, pq_compressed_vectors, this->data, npts_u64, nchunks_u64); -#else - diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); -#endif - - this->_num_points = npts_u64; - this->_n_chunks = nchunks_u64; -#ifdef EXEC_ENV_OLS - if (files.fileExists(labels_file)) - { - FileContent &content_labels = files.getContent(labels_file); - std::stringstream infile(std::string((const char *)content_labels._content, content_labels._size)); -#else - if (file_exists(labels_file)) - { - std::ifstream infile(labels_file, std::ios::binary); - if (infile.fail()) - { - throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1); - } -#endif - parse_label_file(infile, num_pts_in_label_file); - assert(num_pts_in_label_file == this->_num_points); - -#ifndef EXEC_ENV_OLS - infile.close(); -#endif - -#ifdef EXEC_ENV_OLS - FileContent &content_labels_map = files.getContent(labels_map_file); - std::stringstream map_reader(std::string((const char *)content_labels_map._content, content_labels_map._size)); -#else - std::ifstream map_reader(labels_map_file); -#endif - _label_map = load_label_map(map_reader); - -#ifndef EXEC_ENV_OLS - map_reader.close(); -#endif - -#ifdef EXEC_ENV_OLS - if (files.fileExists(labels_to_medoids)) - { - FileContent &content_labels_to_meoids = files.getContent(labels_to_medoids); - std::stringstream medoid_stream( - std::string((const char *)content_labels_to_meoids._content, content_labels_to_meoids._size)); -#else - if (file_exists(labels_to_medoids)) - { - std::ifstream medoid_stream(labels_to_medoids); - assert(medoid_stream.is_open()); -#endif - std::string line, token; - - _filter_to_medoid_ids.clear(); - try - { - while (std::getline(medoid_stream, line)) - { - std::istringstream iss(line); - uint32_t cnt = 0; - std::vector medoids; - LabelT label; - while (std::getline(iss, token, ',')) - { - if (cnt == 0) - label = (LabelT)std::stoul(token); - else - medoids.push_back((uint32_t)stoul(token)); - cnt++; - } - _filter_to_medoid_ids[label].swap(medoids); - } - } - catch (std::system_error &e) - { - throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, __LINE__); - } - } - std::string univ_label_file = std ::string(_disk_index_file) + "_universal_label.txt"; - -#ifdef EXEC_ENV_OLS - if (files.fileExists(univ_label_file)) - { - FileContent &content_univ_label = files.getContent(univ_label_file); - std::stringstream universal_label_reader( - std::string((const char *)content_univ_label._content, content_univ_label._size)); -#else - if (file_exists(univ_label_file)) - { - std::ifstream universal_label_reader(univ_label_file); - assert(universal_label_reader.is_open()); -#endif - std::string univ_label; - universal_label_reader >> univ_label; -#ifndef EXEC_ENV_OLS - universal_label_reader.close(); -#endif - LabelT label_as_num = (LabelT)std::stoul(univ_label); - set_universal_label(label_as_num); - } - #ifdef EXEC_ENV_OLS - if (files.fileExists(dummy_map_file)) - { - FileContent &content_dummy_map = files.getContent(dummy_map_file); - std::stringstream dummy_map_stream( - std::string((const char *)content_dummy_map._content, content_dummy_map._size)); -#else - if (file_exists(dummy_map_file)) - { - std::ifstream dummy_map_stream(dummy_map_file); - assert(dummy_map_stream.is_open()); -#endif - std::string line, token; - - while (std::getline(dummy_map_stream, line)) - { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t dummy_id; - uint32_t real_id; - while (std::getline(iss, token, ',')) - { - if (cnt == 0) - dummy_id = (uint32_t)stoul(token); - else - real_id = (uint32_t)stoul(token); - cnt++; - } - _dummy_pts.insert(dummy_id); - _has_dummy_pts.insert(real_id); - _dummy_to_real_map[dummy_id] = real_id; - - if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) - _real_to_dummy_map[real_id] = std::vector(); - - _real_to_dummy_map[real_id].emplace_back(dummy_id); - } -#ifndef EXEC_ENV_OLS - dummy_map_stream.close(); -#endif - diskann::cout << "Loaded dummy map" << std::endl; - } - } - -#ifdef EXEC_ENV_OLS - _pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64); -#else - _pq_table.load_pq_centroid_bin(pq_table_bin.c_str(), nchunks_u64); -#endif - - diskann::cout << "Loaded PQ centroids and in-memory compressed vectors. #points: " << _num_points - << " #dim: " << _data_dim << " #aligned_dim: " << _aligned_dim << " #chunks: " << _n_chunks - << std::endl; - - if (_n_chunks > MAX_PQ_CHUNKS) - { - std::stringstream stream; - stream << "Error loading index. Ensure that max PQ bytes for in-memory " - "PQ data does not exceed " - << MAX_PQ_CHUNKS << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } - - std::string disk_pq_pivots_path = this->_disk_index_file + "_pq_pivots.bin"; -#ifdef EXEC_ENV_OLS - if (files.fileExists(disk_pq_pivots_path)) - { - _use_disk_index_pq = true; - // giving 0 chunks to make the _pq_table infer from the - // chunk_offsets file the correct value - _disk_pq_table.load_pq_centroid_bin(files, disk_pq_pivots_path.c_str(), 0); -#else - if (file_exists(disk_pq_pivots_path)) - { - _use_disk_index_pq = true; - // giving 0 chunks to make the _pq_table infer from the - // chunk_offsets file the correct value - _disk_pq_table.load_pq_centroid_bin(disk_pq_pivots_path.c_str(), 0); -#endif - _disk_pq_n_chunks = _disk_pq_table.get_num_chunks(); - _disk_bytes_per_point = - _disk_pq_n_chunks * sizeof(uint8_t); // revising disk_bytes_per_point since DISK PQ is used. - diskann::cout << "Disk index uses PQ data compressed down to " << _disk_pq_n_chunks << " bytes per point." - << std::endl; - } - -// read index metadata -#ifdef EXEC_ENV_OLS - // This is a bit tricky. We have to read the header from the - // disk_index_file. But this is now exclusively a preserve of the // DiskPriorityIO class. So, we need to estimate how many // bytes are needed to store the header and read in that many using our // 'standard' aligned file reader approach. @@ -1031,6 +857,8 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons READ_U32(index_metadata, nr); READ_U32(index_metadata, nc); + const uint64_t metadata_u64_count = nr; + uint64_t disk_nnodes; uint64_t disk_ndims; // can be disk PQ dim if disk_PQ is set to true READ_U64(index_metadata, disk_nnodes); @@ -1083,6 +911,17 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons READ_U64(index_metadata, this->_reorder_data_start_sector); READ_U64(index_metadata, this->_ndims_reorder_vecs); READ_U64(index_metadata, this->_nvecs_per_sector); + + // Newer indexes may also store the element size of reorder vectors. + // Older indexes always used float reorder data. + if (metadata_u64_count >= 13) + { + READ_U64(index_metadata, this->_reorder_bytes_per_element); + } + else + { + this->_reorder_bytes_per_element = sizeof(float); + } } diskann::cout << "Disk-Index File Meta-data: "; @@ -1105,6 +944,93 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons #endif + // Load RaBitQ main-search codes (optional) +#ifdef EXEC_ENV_OLS + if (files.fileExists(rabitq_main_file)) + { + FileContent &content = files.getContent(rabitq_main_file); + if (content._size >= sizeof(diskann::rabitq::RaBitQCodeHeader)) + { + diskann::rabitq::RaBitQCodeHeader hdr; + std::memcpy(&hdr, content._content, sizeof(hdr)); + if (diskann::rabitq::validate_header(hdr) && hdr.num_points == this->_num_points && hdr.dim == this->_data_dim) + { + const uint64_t expected_code_size = + static_cast(diskann::rabitq::compute_code_size(hdr.dim, hdr.nb_bits)); + if (hdr.code_size == expected_code_size) + { + const uint64_t total_bytes = static_cast(hdr.num_points) * static_cast(hdr.code_size); + const uint64_t need = sizeof(hdr) + total_bytes; + if (content._size >= need) + { + const uint64_t alloc_bytes = ROUND_UP(total_bytes, 64); + if (_rabitq_main_codes != nullptr) + aligned_free(_rabitq_main_codes); + diskann::alloc_aligned((void **)&_rabitq_main_codes, alloc_bytes, 64); + std::memset(_rabitq_main_codes, 0, alloc_bytes); + std::memcpy(_rabitq_main_codes, (const uint8_t *)content._content + sizeof(hdr), total_bytes); + _rabitq_main_codes_exist = true; + _rabitq_main_code_size = hdr.code_size; + _rabitq_main_dim = hdr.dim; + _rabitq_main_nb_bits = hdr.nb_bits; + _rabitq_main_metric = hdr.metric; + } + } + } + } + } +#else + if (file_exists(rabitq_main_file)) + { + try + { + std::ifstream in(rabitq_main_file, std::ios::binary); + if (!in) + throw diskann::ANNException(std::string("Failed to open RaBitQ main code file ") + rabitq_main_file, -1); + + diskann::rabitq::RaBitQCodeHeader hdr; + in.read(reinterpret_cast(&hdr), sizeof(hdr)); + if (!in) + throw diskann::ANNException(std::string("Failed to read RaBitQ main code header from ") + rabitq_main_file, + -1); + + if (!diskann::rabitq::validate_header(hdr) || hdr.num_points != this->_num_points || hdr.dim != this->_data_dim) + throw diskann::ANNException(std::string("Invalid RaBitQ main code header in ") + rabitq_main_file, -1); + + const uint64_t expected_code_size = + static_cast(diskann::rabitq::compute_code_size(hdr.dim, hdr.nb_bits)); + if (hdr.code_size != expected_code_size) + throw diskann::ANNException(std::string("RaBitQ main code_size mismatch in ") + rabitq_main_file, -1); + + const uint64_t total_bytes = static_cast(hdr.num_points) * static_cast(hdr.code_size); + const uint64_t alloc_bytes = ROUND_UP(total_bytes, 64); + + if (_rabitq_main_codes != nullptr) + aligned_free(_rabitq_main_codes); + diskann::alloc_aligned((void **)&_rabitq_main_codes, alloc_bytes, 64); + std::memset(_rabitq_main_codes, 0, alloc_bytes); + + in.read(reinterpret_cast(_rabitq_main_codes), total_bytes); + if (!in) + { + aligned_free(_rabitq_main_codes); + _rabitq_main_codes = nullptr; + throw diskann::ANNException(std::string("Failed to read RaBitQ main codes from ") + rabitq_main_file, -1); + } + + _rabitq_main_codes_exist = true; + _rabitq_main_code_size = hdr.code_size; + _rabitq_main_dim = hdr.dim; + _rabitq_main_nb_bits = hdr.nb_bits; + _rabitq_main_metric = hdr.metric; + } + catch (const std::exception &e) + { + diskann::cout << "Warning: failed to load RaBitQ main codes: " << e.what() << std::endl; + } + } +#endif + #ifdef EXEC_ENV_OLS if (files.fileExists(medoids_file)) { @@ -1305,7 +1231,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t query_norm += query1[i] * query1[i]; } if (metric == diskann::Metric::INNER_PRODUCT) - aligned_query_T[this->_data_dim - 1] = 0; + aligned_query_T[this->_data_dim - 1] = (T)0.0f; query_norm = std::sqrt(query_norm); @@ -1344,11 +1270,48 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t float *dist_scratch = pq_query_scratch->aligned_dist_scratch; uint8_t *pq_coord_scratch = pq_query_scratch->aligned_pq_coord_scratch; + const char *use_rabitq_main_env = std::getenv("DISKANN_USE_RABITQ_MAIN_APPROX"); + const bool use_rabitq_main_approx = + (use_rabitq_main_env != nullptr && std::atoi(use_rabitq_main_env) != 0 && _rabitq_main_codes_exist && + _rabitq_main_codes != nullptr && metric == diskann::Metric::INNER_PRODUCT && + _rabitq_main_metric == static_cast(diskann::rabitq::Metric::INNER_PRODUCT)); + + if (use_rabitq_main_env != nullptr && std::atoi(use_rabitq_main_env) != 0 && !use_rabitq_main_approx) + { + std::call_once(g_main_rabitq_msg_once, [&]() { + diskann::cout << "DISKANN_USE_RABITQ_MAIN_APPROX requested but main codes are unavailable or incompatible; " + "falling back to PQ." + << std::endl; + }); + } + else if (use_rabitq_main_approx) + { + std::call_once(g_main_rabitq_msg_once, [&]() { + diskann::cout << "Using RaBitQ main codes for traversal approximate scoring (nb_bits=" + << _rabitq_main_nb_bits << ")." << std::endl; + }); + } + // lambda to batch compute query<-> node distances in PQ space - auto compute_dists = [this, pq_coord_scratch, pq_dists](const uint32_t *ids, const uint64_t n_ids, - float *dists_out) { - diskann::aggregate_coords(ids, n_ids, this->data, this->_n_chunks, pq_coord_scratch); - diskann::pq_dist_lookup(pq_coord_scratch, n_ids, this->_n_chunks, pq_dists, dists_out); + auto compute_dists = [this, pq_coord_scratch, pq_dists, query_float, use_rabitq_main_approx](const uint32_t *ids, + const uint64_t n_ids, + float *dists_out) { + if (use_rabitq_main_approx) + { + for (uint64_t i = 0; i < n_ids; ++i) + { + const uint64_t id = ids[i]; + const uint8_t *code = _rabitq_main_codes + id * _rabitq_main_code_size; + const float approx_ip = diskann::rabitq::approx_inner_product_from_code( + code, query_float, static_cast(_rabitq_main_dim), static_cast(_rabitq_main_nb_bits)); + dists_out[i] = -approx_ip; + } + } + else + { + diskann::aggregate_coords(ids, n_ids, this->data, this->_n_chunks, pq_coord_scratch); + diskann::pq_dist_lookup(pq_coord_scratch, n_ids, this->_n_chunks, pq_dists, dists_out); + } }; Timer query_timer, io_timer, cpu_timer; @@ -1625,6 +1588,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t std::vector vec_read_reqs; + // Candidate pool before exact reorder. if (full_retset.size() > k_search * FULL_PRECISION_REORDER_MULTIPLIER) full_retset.erase(full_retset.begin() + k_search * FULL_PRECISION_REORDER_MULTIPLIER, full_retset.end()); @@ -1652,12 +1616,64 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t stats->io_us += io_timer.elapsed(); } - for (size_t i = 0; i < full_retset.size(); ++i) + // Fast-path for bf16 + INNER_PRODUCT: compute negative dot as distance in batch. + // This matches the rest of the MIPS code path which treats "distance" as -inner_product. + if (metric == diskann::Metric::INNER_PRODUCT && std::is_same::value && + _reorder_bytes_per_element == sizeof(diskann::bfloat16) && amxbf16_kernels_compiled() && + amxbf16_runtime_available()) { - auto id = full_retset[i].id; - // MULTISECTORFIX - auto location = (sector_scratch + i * defaults::SECTOR_LEN) + VECTOR_SECTOR_OFFSET(id); - full_retset[i].distance = _dist_cmp->compare(aligned_query_T, (T *)location, (uint32_t)this->_data_dim); + const uint32_t dim = static_cast(_ndims_reorder_vecs); + constexpr uint32_t kBatch = 16; + + std::call_once(g_reorder_amx_msg_once, [&]() { + printf("Using AMX bf16 batched dot product for reordering distance computation (%zu vectors).\n", + full_retset.size()); + fflush(stdout); + }); + + std::vector base_batch; + base_batch.resize(kBatch * dim); + alignas(64) float dots[kBatch]; + + for (size_t i = 0; i < full_retset.size(); i += kBatch) + { + const uint32_t cur = static_cast(std::min(kBatch, full_retset.size() - i)); + + for (uint32_t j = 0; j < cur; ++j) + { + auto id = full_retset[i + j].id; + const uint64_t elem_offset = + ((uint64_t)id % _nvecs_per_sector) * _ndims_reorder_vecs * _reorder_bytes_per_element; + auto location = (sector_scratch + (i + j) * defaults::SECTOR_LEN) + elem_offset; + std::memcpy(base_batch.data() + j * dim, location, static_cast(dim) * sizeof(diskann::bfloat16)); + } + + bf16_dot_f32_accum_amx_batch(base_batch.data(), (const diskann::bfloat16 *)aligned_query_T, cur, dim, + dots); + + for (uint32_t j = 0; j < cur; ++j) + { + full_retset[i + j].distance = -dots[j]; + } + } + } + else + { + std::call_once(g_reorder_std_msg_once, [&]() { + printf("Using standard distance computation for reordering distance(%zu vectors).\n", + full_retset.size()); + fflush(stdout); + }); + for (size_t i = 0; i < full_retset.size(); ++i) + { + auto id = full_retset[i].id; + // MULTISECTORFIX + const uint64_t elem_offset = + ((uint64_t)id % _nvecs_per_sector) * _ndims_reorder_vecs * _reorder_bytes_per_element; + auto location = (sector_scratch + i * defaults::SECTOR_LEN) + elem_offset; + full_retset[i].distance = + _dist_cmp->compare(aligned_query_T, (T *)location, (uint32_t)_ndims_reorder_vecs); + } } std::sort(full_retset.begin(), full_retset.end()); @@ -1786,8 +1802,10 @@ template std::uint64_t PQFlashIndex::ge template class PQFlashIndex; template class PQFlashIndex; template class PQFlashIndex; +template class PQFlashIndex; template class PQFlashIndex; template class PQFlashIndex; template class PQFlashIndex; +template class PQFlashIndex; } // namespace diskann diff --git a/src/pq_l2_distance.cpp b/src/pq_l2_distance.cpp index c08744c35..a352c3050 100644 --- a/src/pq_l2_distance.cpp +++ b/src/pq_l2_distance.cpp @@ -2,6 +2,7 @@ #include "pq.h" #include "pq_l2_distance.h" #include "pq_scratch.h" +#include "bfloat16.h" // block size for reading/processing large files and matrices in blocks #define BLOCK_SIZE 5000000 @@ -280,5 +281,6 @@ void PQL2Distance::prepopulate_chunkwise_distances(const float *query_ve template DISKANN_DLLEXPORT class PQL2Distance; template DISKANN_DLLEXPORT class PQL2Distance; template DISKANN_DLLEXPORT class PQL2Distance; +template DISKANN_DLLEXPORT class PQL2Distance; } // namespace diskann \ No newline at end of file diff --git a/src/rabitq.cpp b/src/rabitq.cpp new file mode 100644 index 000000000..804123f2a --- /dev/null +++ b/src/rabitq.cpp @@ -0,0 +1,444 @@ +#include "rabitq.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace diskann +{ +namespace rabitq +{ + +namespace +{ +constexpr float kTightStart[9] = {0.0f, 0.15f, 0.20f, 0.52f, 0.59f, 0.71f, 0.75f, 0.77f, 0.81f}; +constexpr double kEps = 1e-5; + +inline float l2_sqr(const float* x, size_t d) +{ + double sum = 0.0; + for (size_t i = 0; i < d; ++i) + sum += static_cast(x[i]) * static_cast(x[i]); + return static_cast(sum); +} + +inline float inner_product(const float* a, const float* b, size_t d) +{ + double sum = 0.0; + for (size_t i = 0; i < d; ++i) + sum += static_cast(a[i]) * static_cast(b[i]); + return static_cast(sum); +} + +inline void set_bit_standard(uint8_t* code, size_t bit_index) +{ + const size_t byte_idx = bit_index / 8; + const size_t bit_offset = bit_index % 8; + code[byte_idx] |= static_cast(1u << bit_offset); +} + +inline bool extract_bit_standard(const uint8_t* code, size_t bit_index) +{ + const size_t byte_idx = bit_index / 8; + const size_t bit_offset = bit_index % 8; + return (code[byte_idx] >> bit_offset) & 1u; +} + +inline int extract_code_inline(const uint8_t* ex_code, size_t index, size_t ex_bits) +{ + size_t bit_pos = index * ex_bits; + int code_value = 0; + + for (size_t bit = 0; bit < ex_bits; ++bit) + { + const size_t byte_idx = bit_pos / 8; + const size_t bit_idx = bit_pos % 8; + if (ex_code[byte_idx] & (1u << bit_idx)) + code_value |= (1u << bit); + bit_pos++; + } + + return code_value; +} + +SignBitFactorsWithError compute_vector_factors(const float* x, size_t d, Metric metric, bool compute_error) +{ + // Mirrors Faiss rabitq_utils::compute_vector_factors but with centroid == nullptr. + // or_minus_c == x + float norm_L2sqr = 0.0f; + float or_L2sqr = 0.0f; + float dp_oO = 0.0f; + + for (size_t j = 0; j < d; ++j) + { + const float x_val = x[j]; + const float or_minus_c = x_val; + const float sq = or_minus_c * or_minus_c; + norm_L2sqr += sq; + or_L2sqr += x_val * x_val; + + const bool xb = (or_minus_c > 0.0f); + dp_oO += xb ? or_minus_c : -or_minus_c; + } + + constexpr float epsilon = std::numeric_limits::epsilon(); + constexpr float kConstEpsilon = 1.9f; + + const float sqrt_norm_L2 = std::sqrt(norm_L2sqr); + const float inv_norm_L2 = (norm_L2sqr < epsilon) ? 1.0f : (1.0f / sqrt_norm_L2); + + const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt(static_cast(d))); + const float normalized_dp = dp_oO * inv_norm_L2 * inv_d_sqrt; + const float inv_dp_oO = (std::abs(normalized_dp) < epsilon) ? 1.0f : (1.0f / normalized_dp); + + SignBitFactorsWithError factors; + if (metric == Metric::INNER_PRODUCT) + factors.or_minus_c_l2sqr = (norm_L2sqr - or_L2sqr); + else + factors.or_minus_c_l2sqr = norm_L2sqr; + + factors.dp_multiplier = inv_dp_oO * sqrt_norm_L2; + + if (compute_error) + { + const float xu_cb_norm_sqr = static_cast(d) * 0.25f; + const float ip_resi_xucb = 0.5f * dp_oO; + + float tmp_error = 0.0f; + if (std::abs(ip_resi_xucb) > epsilon) + { + const float ratio_sq = (norm_L2sqr * xu_cb_norm_sqr) / (ip_resi_xucb * ip_resi_xucb); + if (ratio_sq > 1.0f) + { + if (d == 1) + tmp_error = sqrt_norm_L2 * kConstEpsilon * std::sqrt(ratio_sq - 1.0f); + else + tmp_error = sqrt_norm_L2 * kConstEpsilon * + std::sqrt((ratio_sq - 1.0f) / static_cast(d - 1)); + } + } + + factors.f_error = (metric == Metric::L2) ? (2.0f * tmp_error) : (1.0f * tmp_error); + } + + return factors; +} + +float compute_optimal_scaling_factor(const float* o_abs, size_t d, size_t nb_bits) +{ + const size_t ex_bits = nb_bits - 1; + assert(ex_bits >= 1 && ex_bits <= 8); + + const int kNEnum = 10; + const int max_code = (1 << ex_bits) - 1; + + float max_o = *std::max_element(o_abs, o_abs + d); + if (!(max_o > 0.0f)) + return 1.0f; + + const float t_end = static_cast(max_code + kNEnum) / max_o; + const float t_start = t_end * kTightStart[ex_bits]; + + std::vector inv_o_abs(d); + for (size_t i = 0; i < d; ++i) + inv_o_abs[i] = 1.0f / o_abs[i]; + + std::vector cur_o_bar(d); + float sqr_denominator = static_cast(d) * 0.25f; + float numerator = 0.0f; + + for (size_t i = 0; i < d; ++i) + { + int cur = static_cast((t_start * o_abs[i]) + kEps); + cur_o_bar[i] = cur; + sqr_denominator += static_cast(cur * cur + cur); + numerator += (cur + 0.5f) * o_abs[i]; + } + + float inv_sqrt_denom = 1.0f / std::sqrt(sqr_denominator); + + std::vector> pq_storage; + pq_storage.reserve(d); + + std::priority_queue, std::vector>, std::greater<>> next_t( + std::greater<>(), std::move(pq_storage)); + + for (size_t i = 0; i < d; ++i) + { + float t_next = static_cast(cur_o_bar[i] + 1) * inv_o_abs[i]; + if (t_next < t_end) + next_t.emplace(t_next, i); + } + + float max_ip = 0.0f; + float t = 0.0f; + + while (!next_t.empty()) + { + const float cur_t = next_t.top().first; + const size_t update_id = next_t.top().second; + next_t.pop(); + + cur_o_bar[update_id]++; + const int update_o_bar = cur_o_bar[update_id]; + + const float delta = 2.0f * update_o_bar; + sqr_denominator += delta; + numerator += o_abs[update_id]; + + const float old_denom = sqr_denominator - delta; + inv_sqrt_denom = inv_sqrt_denom * (1.0f - 0.5f * delta / (old_denom + delta * 0.5f)); + + const float cur_ip = numerator * inv_sqrt_denom; + if (cur_ip > max_ip) + { + max_ip = cur_ip; + t = cur_t; + } + + if (update_o_bar < max_code) + { + float t_next = static_cast(update_o_bar + 1) * inv_o_abs[update_id]; + if (t_next < t_end) + next_t.emplace(t_next, update_id); + } + } + + return (t > 0.0f) ? t : 1.0f; +} + +void pack_multibit_codes(const int* tmp_code, uint8_t* ex_code, size_t d, size_t nb_bits) +{ + const size_t ex_bits = nb_bits - 1; + assert(ex_bits >= 1 && ex_bits <= 8); + + const size_t total_bits = d * ex_bits; + const size_t output_size = (total_bits + 7) / 8; + std::memset(ex_code, 0, output_size); + + size_t bit_pos = 0; + for (size_t i = 0; i < d; ++i) + { + const int code_value = tmp_code[i]; + for (size_t bit = 0; bit < ex_bits; ++bit) + { + const size_t byte_idx = bit_pos / 8; + const size_t bit_idx = bit_pos % 8; + if (code_value & (1 << bit)) + ex_code[byte_idx] |= static_cast(1u << bit_idx); + ++bit_pos; + } + } +} + +void compute_ex_factors(const float* residual, const int* tmp_code, size_t d, size_t ex_bits, float norm, + double ipnorm, ExtraBitsFactors& ex_factors, Metric metric) +{ + float ipnorm_inv = static_cast(1.0 / ipnorm); + if (!std::isnormal(ipnorm_inv)) + ipnorm_inv = 1.0f; + + const float cb = -(static_cast(1 << ex_bits) - 0.5f); + + std::vector xu_cb(d); + for (size_t i = 0; i < d; ++i) + xu_cb[i] = static_cast(tmp_code[i]) + cb; + + const float l2_sqr_val = norm * norm; + const float ip_resi_xucb = inner_product(residual, xu_cb.data(), d); + + if (metric == Metric::L2) + { + ex_factors.f_add_ex = l2_sqr_val; + ex_factors.f_rescale_ex = ipnorm_inv * -2.0f * norm; + } + else + { + // centroid == nullptr in this DiskANN integration, so centroid correction term is 0. + ex_factors.f_add_ex = 1; + ex_factors.f_rescale_ex = ipnorm_inv * -norm; + (void)ip_resi_xucb; + } +} + +void quantize_ex_bits(const float* residual, size_t d, size_t nb_bits, uint8_t* ex_code, ExtraBitsFactors& ex_factors, + Metric metric) +{ + const size_t ex_bits = nb_bits - 1; + assert(ex_bits >= 1 && ex_bits <= 8); + + // Normalize residual + const float norm2 = l2_sqr(residual, d); + const float norm = std::sqrt(norm2); + const float inv_norm = (norm2 > 0.0f) ? (1.0f / norm) : 1.0f; + + std::vector o_abs(d); + std::vector tmp_code(d); + + // abs(normalized residual) + for (size_t i = 0; i < d; ++i) + o_abs[i] = std::abs(residual[i] * inv_norm); + + const float t = compute_optimal_scaling_factor(o_abs.data(), d, nb_bits); + const int max_code = (1 << ex_bits) - 1; + + // Quantize and apply sign bit flipping behavior (encode magnitude with sign stored separately) + // total_code = (sign << ex_bits) + ex_code + for (size_t i = 0; i < d; ++i) + { + int q = static_cast((t * o_abs[i]) + kEps); + q = std::clamp(q, 0, max_code); + + // If residual is negative, flip code via complement (matches Faiss logic via sign bit storage) + if (residual[i] <= 0.0f) + q = max_code - q; + + tmp_code[i] = q; + } + + pack_multibit_codes(tmp_code.data(), ex_code, d, nb_bits); + + // ipnorm from quantized normalized residual vs abs residual. Faiss uses a specific normalization; + // here we approximate with a stable fallback. + // We keep ipnorm > 0 to avoid NaNs. + double ipnorm = 0.0; + { + const float cb = -(static_cast(1 << ex_bits) - 0.5f); + std::vector xu_cb(d); + for (size_t i = 0; i < d; ++i) + xu_cb[i] = static_cast(tmp_code[i]) + cb; + ipnorm = inner_product(o_abs.data(), xu_cb.data(), d); + if (!(ipnorm > 0.0)) + ipnorm = 1.0; + } + + compute_ex_factors(residual, tmp_code.data(), d, ex_bits, norm, ipnorm, ex_factors, metric); +} + +} // namespace + +size_t compute_code_size(size_t d, size_t nb_bits) +{ + assert(nb_bits >= 1 && nb_bits <= 9); + + const size_t ex_bits = nb_bits - 1; + + const size_t base_size = (d + 7) / 8 + (ex_bits == 0 ? sizeof(SignBitFactors) : sizeof(SignBitFactorsWithError)); + + size_t ex_size = 0; + if (ex_bits > 0) + ex_size = (d * ex_bits + 7) / 8 + sizeof(ExtraBitsFactors); + + return base_size + ex_size; +} + +void encode_vector(const float* x, size_t d, Metric metric, size_t nb_bits, uint8_t* out_code) +{ + assert(x != nullptr); + assert(out_code != nullptr); + assert(nb_bits >= 1 && nb_bits <= 9); + + const size_t ex_bits = nb_bits - 1; + const size_t code_size = compute_code_size(d, nb_bits); + std::memset(out_code, 0, code_size); + + uint8_t* sign_bits = out_code; + + const bool compute_error = (ex_bits > 0); + const SignBitFactorsWithError factors_data = compute_vector_factors(x, d, metric, compute_error); + + if (ex_bits == 0) + { + auto* base_factors = reinterpret_cast(out_code + (d + 7) / 8); + base_factors->or_minus_c_l2sqr = factors_data.or_minus_c_l2sqr; + base_factors->dp_multiplier = factors_data.dp_multiplier; + } + else + { + auto* full_factors = reinterpret_cast(out_code + (d + 7) / 8); + *full_factors = factors_data; + } + + // Sign bits + std::vector residual(d); + for (size_t j = 0; j < d; ++j) + { + residual[j] = x[j]; + if (x[j] > 0.0f) + set_bit_standard(sign_bits, j); + } + + if (ex_bits > 0) + { + uint8_t* ex_code = out_code + (d + 7) / 8 + sizeof(SignBitFactorsWithError); + auto* ex_factors = reinterpret_cast(ex_code + (d * ex_bits + 7) / 8); + quantize_ex_bits(residual.data(), d, nb_bits, ex_code, *ex_factors, metric); + } +} + +float approx_inner_product_from_code(const uint8_t* code, const float* query, size_t d, size_t nb_bits) +{ + assert(code != nullptr); + assert(query != nullptr); + assert(nb_bits >= 1 && nb_bits <= 9); + + const size_t ex_bits = nb_bits - 1; + + const uint8_t* sign_bits = code; + + // For this integration we assume query is already rotated/centered as needed. + const float qr_to_c_L2sqr = l2_sqr(query, d); + const float qr_norm_L2sqr = qr_to_c_L2sqr; + + float ex_ip = 0.0f; + + const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt(static_cast(d))); + + if (ex_bits == 0) + { + // Very simple 1-bit estimator: reconstructed value per dim is (bit-0.5)*2*inv_d_sqrt*dp_multiplier + const auto* fac = reinterpret_cast(code + (d + 7) / 8); + const float scale = fac->dp_multiplier * 2.0f * inv_d_sqrt; + + for (size_t i = 0; i < d; ++i) + { + const float bit = extract_bit_standard(sign_bits, i) ? 1.0f : 0.0f; + const float reconstructed = (bit - 0.5f) * scale; + ex_ip += query[i] * reconstructed; + } + + return ex_ip; + } + + const auto* base_fac = reinterpret_cast(code + (d + 7) / 8); + const uint8_t* ex_code = code + (d + 7) / 8 + sizeof(SignBitFactorsWithError); + const auto* ex_fac = reinterpret_cast(ex_code + (d * ex_bits + 7) / 8); + + const float cb = -(static_cast(1 << ex_bits) - 0.5f); + for (size_t i = 0; i < d; ++i) + { + const bool sign_bit = extract_bit_standard(sign_bits, i); + const int ex_code_val = extract_code_inline(ex_code, i, ex_bits); + int total_code = (sign_bit ? 1 : 0) << ex_bits; + total_code += ex_code_val; + const float reconstructed = static_cast(total_code) + cb; + ex_ip += query[i] * reconstructed; + } + + // Faiss multi-bit distance formula (metric==IP) transformed yields an IP-like score. + // dist = qr_to_c_L2sqr + ex_fac.f_add_ex + ex_fac.f_rescale_ex * ex_ip + // ip_score = -0.5 * (dist - qr_norm_L2sqr) + const float dist = qr_to_c_L2sqr + ex_fac->f_add_ex + ex_fac->f_rescale_ex * ex_ip; + const float ip_score = -0.5f * (dist - qr_norm_L2sqr); + + (void)base_fac; // base factors not needed for full multi-bit path here + + return ip_score; +} + +} // namespace rabitq +} // namespace diskann diff --git a/src/scratch.cpp b/src/scratch.cpp index 1f8a34bb1..4fb3a1f29 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -6,6 +6,7 @@ #include "scratch.h" #include "pq_scratch.h" +#include "bfloat16.h" namespace diskann { @@ -166,17 +167,21 @@ template void PQScratch::initialize(size_t dim, const T *query, template DISKANN_DLLEXPORT class InMemQueryScratch; template DISKANN_DLLEXPORT class InMemQueryScratch; template DISKANN_DLLEXPORT class InMemQueryScratch; +template DISKANN_DLLEXPORT class InMemQueryScratch; template DISKANN_DLLEXPORT class SSDQueryScratch; template DISKANN_DLLEXPORT class SSDQueryScratch; template DISKANN_DLLEXPORT class SSDQueryScratch; +template DISKANN_DLLEXPORT class SSDQueryScratch; template DISKANN_DLLEXPORT class PQScratch; template DISKANN_DLLEXPORT class PQScratch; template DISKANN_DLLEXPORT class PQScratch; +template DISKANN_DLLEXPORT class PQScratch; template DISKANN_DLLEXPORT class SSDThreadData; template DISKANN_DLLEXPORT class SSDThreadData; template DISKANN_DLLEXPORT class SSDThreadData; +template DISKANN_DLLEXPORT class SSDThreadData; } // namespace diskann diff --git a/src/utils.cpp b/src/utils.cpp index 3773cda22..595d51eb0 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -61,13 +61,124 @@ bool cpuHasAvx2Support() return false; } +bool cpuHasAvx512bf16Support() +{ + // Need OSXSAVE + XCR0 enabling ZMM state, plus AVX-512F and AVX-512 BF16. + int cpuInfo[4]; + __cpuid(cpuInfo, 1); + bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; + if (!osUsesXSAVE_XRSTORE) + return false; + + unsigned long long xcr0 = _xgetbv(_XCR_XFEATURE_ENABLED_MASK); + // Require XMM (bit1), YMM (bit2), opmask (bit5), ZMM_hi256 (bit6), hi16_zmm (bit7) + const unsigned long long kXcr0Avx512Mask = 0xE6; + if ((xcr0 & kXcr0Avx512Mask) != kXcr0Avx512Mask) + return false; + + __cpuid(cpuInfo, 0); + int n = cpuInfo[0]; + if (n < 7) + return false; + + // AVX-512F is CPUID.(EAX=7,ECX=0):EBX[16] + __cpuidex(cpuInfo, 7, 0); + const bool hasAvx512F = (cpuInfo[1] & (1 << 16)) != 0; + if (!hasAvx512F) + return false; + + // AVX512_BF16 is CPUID.(EAX=7,ECX=1):EAX[5] + __cpuidex(cpuInfo, 7, 1); + const bool hasAvx512Bf16 = (cpuInfo[0] & (1 << 5)) != 0; + return hasAvx512Bf16; +} + +bool AvxSupportedCPU = cpuHasAvxSupport(); +bool Avx2SupportedCPU = cpuHasAvx2Support(); +bool Avx512Bf16SupportedCPU = cpuHasAvx512bf16Support(); + +#else + +#if defined(__x86_64__) || defined(__i386__) +#include + +static inline uint64_t xgetbv_u32(uint32_t index) +{ + uint32_t eax = 0, edx = 0; + __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index)); + return (static_cast(edx) << 32) | eax; +} + +static inline bool cpuHasOsAvxSupport() +{ + unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0; + if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx)) + return false; + const bool osxsave = (ecx & (1u << 27)) != 0; + const bool avx = (ecx & (1u << 28)) != 0; + if (!(osxsave && avx)) + return false; + const uint64_t xcr0 = xgetbv_u32(0); + return (xcr0 & 0x6) == 0x6; +} + +bool cpuHasAvxSupport() +{ + return cpuHasOsAvxSupport(); +} + +bool cpuHasAvx2Support() +{ + if (!cpuHasOsAvxSupport()) + return false; + unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0; + if (!__get_cpuid_max(0, nullptr)) + return false; + __cpuid_count(7, 0, eax, ebx, ecx, edx); + return (ebx & (1u << 5)) != 0; +} + +bool cpuHasAvx512bf16Support() +{ + // Require OSXSAVE + XCR0 enabling full AVX-512 ZMM state. + unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0; + if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx)) + return false; + const bool osxsave = (ecx & (1u << 27)) != 0; + if (!osxsave) + return false; + + const uint64_t xcr0 = xgetbv_u32(0); + const uint64_t kXcr0Avx512Mask = 0xE6; // XMM|YMM|opmask|ZMM_hi256|hi16_zmm + if ((xcr0 & kXcr0Avx512Mask) != kXcr0Avx512Mask) + return false; + + if (__get_cpuid_max(0, nullptr) < 7) + return false; + + // AVX-512F: CPUID.(EAX=7,ECX=0):EBX[16] + __cpuid_count(7, 0, eax, ebx, ecx, edx); + const bool hasAvx512F = (ebx & (1u << 16)) != 0; + if (!hasAvx512F) + return false; + + // AVX512_BF16: CPUID.(EAX=7,ECX=1):EAX[5] + __cpuid_count(7, 1, eax, ebx, ecx, edx); + const bool hasAvx512Bf16 = (eax & (1u << 5)) != 0; + return hasAvx512Bf16; +} + bool AvxSupportedCPU = cpuHasAvxSupport(); bool Avx2SupportedCPU = cpuHasAvx2Support(); +bool Avx512Bf16SupportedCPU = cpuHasAvx512bf16Support(); #else -bool Avx2SupportedCPU = true; bool AvxSupportedCPU = false; +bool Avx2SupportedCPU = false; +bool Avx512Bf16SupportedCPU = false; + +#endif #endif namespace diskann