diff --git a/.gitignore b/.gitignore index 9f852b538..87b7d5876 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ /.tox/ /bin/ /build/ +/build_python/ /dist/ /venv/ /1/ diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 000000000..b83d22266 --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1 @@ +/target/ diff --git a/rust/Cargo.lock b/rust/Cargo.lock new file mode 100644 index 000000000..538deb60f --- /dev/null +++ b/rust/Cargo.lock @@ -0,0 +1,1029 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "bumpalo" +version = "3.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.5.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6e6ff9dcd79cff5cd969a17a545d79e84ab086e444102a591e288a8aa3ce394" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa42cf4d2b7a41bc8f663a7cab4031ebafa1bf3875705bfaf8466dc60ab52c00" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c3e64b0cc0439b12df2fa678eae89a1c56a529fd067a9115f7827f1fffd22b32" + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "num-traits", + "zerocopy", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + +[[package]] +name = "indoc" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706" +dependencies = [ + "rustversion", +] + +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "js-sys" +version = "0.3.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.180" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" + +[[package]] +name = "libm" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" + +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "memmap2" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "744133e4a0e0a658e1374cf3bf8e415c4052a15a111acd372764c55b4177d490" +dependencies = [ + "libc", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "numpy" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7cfbf3f0feededcaa4d289fe3079b03659e85c5b5a177f4ba6fb01ab4fb3e39" +dependencies = [ + "libc", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "pyo3", + "pyo3-build-config", + "rustc-hash", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "portable-atomic" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "535d180e0ecab6268a3e718bb9fd44db66bbbc256257165fc699dadf70d16fe7" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version 
= "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "syn" +version = "2.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1dd07eb858a2067e2f3c7155d54e929265c264e6f37efe3ee7a8d1b5a1dd0ba" + +[[package]] +name = "tempfile" +version = "3.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] +name = "vecsim" +version = "0.1.0" +dependencies = [ + "criterion", + "dashmap", + "half", + "memmap2", + "num-traits", + "parking_lot", + "rand", + "rayon", + "thiserror", +] + +[[package]] +name = "vecsim-c" +version = "0.1.0" +dependencies = [ + "half", + "libc", + "parking_lot", + "tempfile", + "vecsim", +] + +[[package]] +name = "vecsim-python" +version = "0.1.0" +dependencies = [ + "half", + "ndarray", + "numpy", + "pyo3", + "rayon", + "vecsim", +] + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = 
"0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "web-sys" +version = "0.3.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "312e32e551d92129218ea9a2452120f4aabc03529ef03e4d0d82fb2780608598" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" + +[[package]] +name = "zerocopy" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "668f5168d10b9ee831de31933dc111a459c97ec93225beb307aed970d1372dfd" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c7962b26b0a8685668b671ee4b54d007a67d4eaf05fda79ac0ecf41e32270f1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd8f3f50b848df28f887acb68e41201b5aea6bc8a8dacc00fb40635ff9a72fea" diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 000000000..c9029c121 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,19 @@ +[workspace] +resolver = "2" +members = ["vecsim", "vecsim-python", "vecsim-c"] +default-members = ["vecsim", "vecsim-c"] + +[workspace.package] +version = "0.1.0" +edition = "2021" +license = "BSD-3-Clause" +repository = "https://github.com/RedisAI/VectorSimilarity" + +[workspace.dependencies] +rayon = "1.10" +parking_lot = "0.12" +half = { version = "2.4", features = ["num-traits"] } +num-traits = "0.2" +thiserror = "1.0" +rand = "0.8" +memmap2 = "0.9" diff --git a/rust/README.md b/rust/README.md new file mode 100644 index 000000000..eb6d1ff90 --- /dev/null +++ b/rust/README.md @@ -0,0 +1,242 @@ +# VecSim Rust Implementation + +High-performance vector similarity search library written in Rust, with Python and C bindings. 
+ +## Crate Structure + +``` +rust/ +├── vecsim/ # Core library - vector indices and algorithms +├── vecsim-python/ # Python bindings (PyO3/maturin) +└── vecsim-c/ # C-compatible FFI layer +``` + +### vecsim (Core Library) + +The main library providing: +- **Index Types**: BruteForce, HNSW, SVS (Vamana), Tiered, Disk-based +- **Data Types**: f32, f64, Float16, BFloat16, Int8, UInt8, Int32, Int64 +- **Distance Metrics**: L2 (Euclidean), Inner Product, Cosine +- **Features**: Multi-value indices, SIMD optimization, serialization, parallel operations + +### vecsim-python + +Python bindings using PyO3. Must be built with `maturin`, not `cargo build` directly. + +### vecsim-c + +C-compatible API matching the C++ VecSim interface for drop-in replacement. + +## Building + +### Prerequisites + +- Rust 1.70+ (install via [rustup](https://rustup.rs/)) +- For Python bindings: Python 3.8+ and [maturin](https://github.com/PyO3/maturin) + +### Build Core Libraries + +```bash +cd rust + +# Build vecsim and vecsim-c (release mode) +cargo build --release +``` + +The workspace is configured to exclude `vecsim-python` from default builds since it requires `maturin`. + +### Build Python Bindings + +Python bindings require `maturin` because they link against the Python interpreter at runtime: + +```bash +# Install maturin +pip install maturin + +# Build the wheel +cd rust/vecsim-python +maturin build --release + +# Or install directly into current Python environment +maturin develop --release +``` + +The wheel will be created in `rust/target/wheels/`. + +### Clean Build + +```bash +cd rust +cargo clean +cargo build --release -p vecsim -p vecsim-c +``` + +## Testing + +### Run All Tests + +```bash +cd rust + +# Run all tests for core library +cargo test -p vecsim + +# Run tests with output displayed +cargo test -p vecsim -- --nocapture + +# Run a specific test +cargo test -p vecsim test_hnsw_basic + +# Run tests matching a pattern +cargo test -p vecsim hnsw +``` + +### Test Categories + +The test suite includes: + +- **Unit tests**: Inline tests throughout the codebase +- **End-to-end tests** (`e2e_tests.rs`): Complete workflow tests, serialization, persistence +- **Data type tests** (`data_type_tests.rs`): Coverage for all 8 vector types +- **Parallel stress tests** (`parallel_stress_tests.rs`): Concurrency and thread safety + +### Run Tests with Features + +```bash +# Run tests with all features enabled +cargo test -p vecsim --all-features +``` + +## Benchmarking + +Benchmarks use [Criterion](https://github.com/bheisler/criterion.rs) for accurate measurements. 
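+
+Each bench target follows the standard Criterion shape. The sketch below is
+illustrative only, not one of the shipped bench targets; the `HnswParams`/
+`HnswSingle` names are assumptions borrowed from the Quick Start example
+later in this README:
+
+```rust
+use criterion::{criterion_group, criterion_main, Criterion};
+use vecsim::prelude::*;
+
+fn bench_hnsw_add(c: &mut Criterion) {
+    // Index setup mirrors the Quick Start example in this README.
+    let params = HnswParams::new(128, Metric::Cosine)
+        .with_m(16)
+        .with_ef_construction(200);
+    let mut index: HnswSingle<f32> = HnswSingle::new(params);
+
+    let vector = vec![0.1f32; 128];
+    let mut label: u64 = 0;
+
+    // Criterion runs the closure many times and reports timing statistics.
+    c.bench_function("hnsw_add_128d", |b| {
+        b.iter(|| {
+            index.add(&vector, label).unwrap();
+            label += 1;
+        })
+    });
+}
+
+criterion_group!(benches, bench_hnsw_add);
+criterion_main!(benches);
+```
+
+Criterion bench targets are declared with `harness = false` under a
+`[[bench]]` section in `Cargo.toml`, so Criterion's own `main` replaces the
+default test harness.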
+
+### Available Benchmarks
+
+| Benchmark | Description |
+|-----------|-------------|
+| `hnsw_bench` | HNSW index operations (add, query, multi-value) |
+| `brute_force_bench` | BruteForce index performance |
+| `tiered_bench` | Two-tier index benchmarks |
+| `svs_bench` | Single-layer Vamana graph performance |
+| `comparison_bench` | Cross-algorithm comparisons |
+| `dbpedia_bench` | Real-world dataset benchmarks with recall measurement |
+
+### Run Benchmarks
+
+```bash
+cd rust
+
+# Run all benchmarks
+cargo bench -p vecsim
+
+# Run a specific benchmark
+cargo bench -p vecsim --bench hnsw_bench
+
+# Run benchmarks matching a pattern
+cargo bench -p vecsim -- "hnsw"
+```
+
+### Real-World Dataset Benchmarks
+
+The `dbpedia_bench` benchmark can use real DBPedia embeddings for realistic performance testing:
+
+```bash
+# Download benchmark data (from project root)
+bash tests/benchmark/bm_files.sh benchmarks-all
+
+# Run DBPedia benchmark
+cargo bench -p vecsim --bench dbpedia_bench
+```
+
+If the dataset is not available, the benchmark falls back to random data.
+
+### Benchmark Output
+
+Results are saved to `rust/target/criterion/` with HTML reports. Open `rust/target/criterion/report/index.html` to view them.
+
+## Examples
+
+### Run the Profiling Example
+
+```bash
+cd rust
+cargo run --release -p vecsim --example profile_insert
+```
+
+### Quick Start Code
+
+```rust
+use vecsim::prelude::*;
+
+fn main() {
+    // Create an HNSW index for 128-dimensional f32 vectors with cosine similarity
+    let params = HnswParams::new(128, Metric::Cosine)
+        .with_m(16)
+        .with_ef_construction(200);
+
+    let mut index: HnswSingle<f32> = HnswSingle::new(params);
+
+    // Add vectors
+    let vector = vec![0.1f32; 128];
+    index.add(&vector, 1).unwrap();
+
+    // Query
+    let query = vec![0.1f32; 128];
+    let results = index.search(&query, 10);
+
+    for result in results {
+        println!("Label: {}, Distance: {}", result.label, result.distance);
+    }
+}
+```
+
+## Project Structure
+
+```
+rust/vecsim/
+├── src/
+│   ├── lib.rs            # Public API and prelude
+│   ├── index/            # Index implementations
+│   │   ├── brute_force.rs  # Exact search
+│   │   ├── hnsw.rs         # Hierarchical NSW
+│   │   ├── svs.rs          # Vamana graph
+│   │   ├── tiered.rs       # Two-tier hybrid
+│   │   └── disk.rs         # Memory-mapped indices
+│   ├── distance/         # SIMD-optimized distance functions
+│   ├── quantization/     # Scalar quantization (SQ8)
+│   ├── storage/          # Vector storage backends
+│   └── types/            # Core type definitions
+├── benches/              # Criterion benchmarks
+├── tests/                # Integration tests
+└── examples/             # Usage examples
+```
+
+## Common Issues
+
+### Python bindings fail with "Undefined symbols"
+
+Python extensions must be built with `maturin`, not `cargo build`:
+
+```bash
+# Wrong - will fail with linker errors
+cargo build -p vecsim-python
+
+# Correct
+cd vecsim-python && maturin build --release
+```
+
+### Building the Python crate explicitly
+
+`vecsim-python` cannot be built with plain cargo; always go through maturin:
+
+```bash
+cd vecsim-python
+maturin build --release
+```
+
+Direct `cargo build -p vecsim-python` will fail with linker errors.
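+
+### Linking the C library without CMake
+
+Outside of the bundled CMake integration, a manual build against the C header
+works as sketched below. The `demo.c` name is hypothetical; the header and
+library locations follow from `vecsim-c/CMakeLists.txt` and the crate's
+`staticlib`/`cdylib` outputs:
+
+```bash
+cd rust
+cargo build --release -p vecsim-c
+
+# Artifacts (Linux names; .dylib on macOS):
+ls target/release/libvecsim_c.a target/release/libvecsim_c.so
+
+# Compile a hypothetical demo.c against the bundled header and shared library
+cc demo.c -Ivecsim-c/include -Ltarget/release -lvecsim_c -o demo
+LD_LIBRARY_PATH=target/release ./demo
+```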
+ +## License + +BSD-3-Clause diff --git a/rust/vecsim-c/CMakeLists.txt b/rust/vecsim-c/CMakeLists.txt new file mode 100644 index 000000000..41ed0d0ad --- /dev/null +++ b/rust/vecsim-c/CMakeLists.txt @@ -0,0 +1,114 @@ +# CMakeLists.txt for vecsim-c Rust library +# This integrates the Rust vecsim-c library into the RediSearch build system + +cmake_minimum_required(VERSION 3.14) +project(vecsim-c-rust LANGUAGES C CXX) + +# Find Cargo +find_program(CARGO_EXECUTABLE cargo REQUIRED) + +# Determine build type +if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(CARGO_BUILD_TYPE "debug") + set(CARGO_BUILD_FLAGS "") +else() + set(CARGO_BUILD_TYPE "release") + set(CARGO_BUILD_FLAGS "--release") +endif() + +# Set the target directory +set(CARGO_TARGET_DIR "${CMAKE_CURRENT_BINARY_DIR}/target") + +# Determine the library name based on platform +if(APPLE) + set(VECSIM_C_LIB_NAME "libvecsim_c.dylib") + set(VECSIM_C_STATIC_LIB_NAME "libvecsim_c.a") +elseif(WIN32) + set(VECSIM_C_LIB_NAME "vecsim_c.dll") + set(VECSIM_C_STATIC_LIB_NAME "vecsim_c.lib") +else() + set(VECSIM_C_LIB_NAME "libvecsim_c.so") + set(VECSIM_C_STATIC_LIB_NAME "libvecsim_c.a") +endif() + +# Full path to the built library +set(VECSIM_C_LIB_PATH "${CARGO_TARGET_DIR}/${CARGO_BUILD_TYPE}/${VECSIM_C_LIB_NAME}") +set(VECSIM_C_STATIC_LIB_PATH "${CARGO_TARGET_DIR}/${CARGO_BUILD_TYPE}/${VECSIM_C_STATIC_LIB_NAME}") + +# Custom command to build the Rust library +add_custom_command( + OUTPUT ${VECSIM_C_LIB_PATH} ${VECSIM_C_STATIC_LIB_PATH} + COMMAND ${CMAKE_COMMAND} -E env + CARGO_TARGET_DIR=${CARGO_TARGET_DIR} + ${CARGO_EXECUTABLE} build + ${CARGO_BUILD_FLAGS} + -p vecsim-c + --manifest-path ${CMAKE_CURRENT_SOURCE_DIR}/../Cargo.toml + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "Building vecsim-c Rust library (${CARGO_BUILD_TYPE})" + VERBATIM +) + +# Custom target for the Rust build +add_custom_target(vecsim_c_rust_build + DEPENDS ${VECSIM_C_LIB_PATH} ${VECSIM_C_STATIC_LIB_PATH} +) + +# Create imported library target for the shared library +add_library(vecsim_c_shared SHARED IMPORTED GLOBAL) +set_target_properties(vecsim_c_shared PROPERTIES + IMPORTED_LOCATION ${VECSIM_C_LIB_PATH} +) +add_dependencies(vecsim_c_shared vecsim_c_rust_build) + +# Create imported library target for the static library +add_library(vecsim_c_static STATIC IMPORTED GLOBAL) +set_target_properties(vecsim_c_static PROPERTIES + IMPORTED_LOCATION ${VECSIM_C_STATIC_LIB_PATH} +) +add_dependencies(vecsim_c_static vecsim_c_rust_build) + +# Default alias (prefer static linking) +add_library(vecsim_c ALIAS vecsim_c_static) + +# Include directory for the C header +set(VECSIM_C_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include" CACHE PATH "vecsim-c include directory") + +# Interface library for easy consumption +add_library(vecsim_c_interface INTERFACE) +target_include_directories(vecsim_c_interface INTERFACE ${VECSIM_C_INCLUDE_DIR}) + +# Export variables for parent projects +set(VECSIM_C_LIBRARIES ${VECSIM_C_STATIC_LIB_PATH} CACHE FILEPATH "vecsim-c static library") +set(VECSIM_C_SHARED_LIBRARIES ${VECSIM_C_LIB_PATH} CACHE FILEPATH "vecsim-c shared library") + +# Test target +add_custom_target(vecsim_c_test + COMMAND ${CMAKE_COMMAND} -E env + CARGO_TARGET_DIR=${CARGO_TARGET_DIR} + ${CARGO_EXECUTABLE} test + -p vecsim-c + --manifest-path ${CMAKE_CURRENT_SOURCE_DIR}/../Cargo.toml + -- --test-threads=1 + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "Running vecsim-c tests" + VERBATIM +) + +# Clean target +add_custom_target(vecsim_c_clean + COMMAND ${CARGO_EXECUTABLE} clean 
+ --manifest-path ${CMAKE_CURRENT_SOURCE_DIR}/../Cargo.toml + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "Cleaning vecsim-c Rust build" + VERBATIM +) + +# Print configuration info +message(STATUS "vecsim-c Rust library configuration:") +message(STATUS " Build type: ${CARGO_BUILD_TYPE}") +message(STATUS " Target dir: ${CARGO_TARGET_DIR}") +message(STATUS " Static lib: ${VECSIM_C_STATIC_LIB_PATH}") +message(STATUS " Shared lib: ${VECSIM_C_LIB_PATH}") +message(STATUS " Include dir: ${VECSIM_C_INCLUDE_DIR}") + diff --git a/rust/vecsim-c/Cargo.toml b/rust/vecsim-c/Cargo.toml new file mode 100644 index 000000000..02dc4b856 --- /dev/null +++ b/rust/vecsim-c/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "vecsim-c" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "C-compatible FFI bindings for the VecSim vector similarity search library" + +[lib] +name = "vecsim_c" +crate-type = ["staticlib", "cdylib"] + +[dependencies] +vecsim = { path = "../vecsim" } +half = { workspace = true } +parking_lot = { workspace = true } +libc = "0.2.180" + +[dev-dependencies] +tempfile = "3" + +[features] +default = [] diff --git a/rust/vecsim-c/include/VecSim/query_results.h b/rust/vecsim-c/include/VecSim/query_results.h new file mode 100644 index 000000000..3f70542fe --- /dev/null +++ b/rust/vecsim-c/include/VecSim/query_results.h @@ -0,0 +1,14 @@ +/** + * @file query_results.h + * @brief Wrapper header for Rust VecSim implementation. + * + * This header includes the main vecsim.h which provides all VecSim functionality. + */ + +#ifndef VECSIM_QUERY_RESULTS_H +#define VECSIM_QUERY_RESULTS_H + +#include "../vecsim.h" + +#endif /* VECSIM_QUERY_RESULTS_H */ + diff --git a/rust/vecsim-c/include/VecSim/vec_sim.h b/rust/vecsim-c/include/VecSim/vec_sim.h new file mode 100644 index 000000000..01ce37d08 --- /dev/null +++ b/rust/vecsim-c/include/VecSim/vec_sim.h @@ -0,0 +1,14 @@ +/** + * @file vec_sim.h + * @brief Wrapper header for Rust VecSim implementation. + * + * This header includes the main vecsim.h which provides all VecSim functionality. + */ + +#ifndef VECSIM_VEC_SIM_H +#define VECSIM_VEC_SIM_H + +#include "../vecsim.h" + +#endif /* VECSIM_VEC_SIM_H */ + diff --git a/rust/vecsim-c/include/VecSim/vec_sim_common.h b/rust/vecsim-c/include/VecSim/vec_sim_common.h new file mode 100644 index 000000000..3f0e2a009 --- /dev/null +++ b/rust/vecsim-c/include/VecSim/vec_sim_common.h @@ -0,0 +1,14 @@ +/** + * @file vec_sim_common.h + * @brief Wrapper header for Rust VecSim implementation. + * + * This header includes the main vecsim.h which provides all VecSim functionality. + */ + +#ifndef VECSIM_VEC_SIM_COMMON_H +#define VECSIM_VEC_SIM_COMMON_H + +#include "../vecsim.h" + +#endif /* VECSIM_VEC_SIM_COMMON_H */ + diff --git a/rust/vecsim-c/include/VecSim/vec_sim_debug.h b/rust/vecsim-c/include/VecSim/vec_sim_debug.h new file mode 100644 index 000000000..ad00fba38 --- /dev/null +++ b/rust/vecsim-c/include/VecSim/vec_sim_debug.h @@ -0,0 +1,14 @@ +/** + * @file vec_sim_debug.h + * @brief Wrapper header for Rust VecSim implementation. + * + * This header includes the main vecsim.h which provides all VecSim functionality. 
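+ *
+ * The headers under include/VecSim/ exist so that sources written against
+ * the C++ library's include layout compile unchanged; new code can include
+ * vecsim.h directly.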
+ */
+
+#ifndef VECSIM_VEC_SIM_DEBUG_H
+#define VECSIM_VEC_SIM_DEBUG_H
+
+#include "../vecsim.h"
+
+#endif /* VECSIM_VEC_SIM_DEBUG_H */
+
diff --git a/rust/vecsim-c/include/vecsim.h b/rust/vecsim-c/include/vecsim.h
new file mode 100644
index 000000000..7574fa279
--- /dev/null
+++ b/rust/vecsim-c/include/vecsim.h
@@ -0,0 +1,1648 @@
+/**
+ * @file vecsim.h
+ * @brief C API for the VecSim vector similarity search library.
+ *
+ * This header provides a C-compatible interface to the Rust VecSim library,
+ * enabling high-performance vector similarity search from C/C++ applications.
+ *
+ * @example
+ * ```c
+ * #include <stdio.h>
+ * #include "vecsim.h"
+ *
+ * int main() {
+ *     // Create a BruteForce index
+ *     BFParams params = {0};
+ *     params.base.algo = VecSimAlgo_BF;
+ *     params.base.type_ = VecSimType_FLOAT32;
+ *     params.base.metric = VecSimMetric_L2;
+ *     params.base.dim = 4;
+ *     params.base.multi = false;
+ *     params.base.initialCapacity = 100;
+ *
+ *     VecSimIndex *index = VecSimIndex_NewBF(&params);
+ *
+ *     // Add vectors
+ *     float v1[] = {1.0f, 0.0f, 0.0f, 0.0f};
+ *     VecSimIndex_AddVector(index, v1, 1);
+ *
+ *     // Query
+ *     float query[] = {1.0f, 0.1f, 0.0f, 0.0f};
+ *     VecSimQueryReply *reply = VecSimIndex_TopKQuery(index, query, 10, NULL, BY_SCORE);
+ *
+ *     // Iterate results
+ *     VecSimQueryReply_Iterator *iter = VecSimQueryReply_GetIterator(reply);
+ *     while (VecSimQueryReply_IteratorHasNext(iter)) {
+ *         VecSimQueryResult *result = VecSimQueryReply_IteratorNext(iter);
+ *         printf("Label: %llu, Score: %f\n",
+ *                VecSimQueryResult_GetId(result),
+ *                VecSimQueryResult_GetScore(result));
+ *     }
+ *
+ *     // Cleanup
+ *     VecSimQueryReply_IteratorFree(iter);
+ *     VecSimQueryReply_Free(reply);
+ *     VecSimIndex_Free(index);
+ *     return 0;
+ * }
+ * ```
+ */
+
+#ifndef VECSIM_H
+#define VECSIM_H
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdbool.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* ============================================================================
+ * Constants and Macros
+ * ========================================================================== */
+
+/**
+ * @brief Default block size for vector storage.
+ */
+#define DEFAULT_BLOCK_SIZE 1024
+
+/**
+ * @brief Macro to suppress unused variable warnings.
+ */
+#define UNUSED(x) (void)(x)
+
+/**
+ * @brief String constant for ad-hoc brute force policy.
+ */
+#define VECSIM_POLICY_ADHOC_BF "adhoc_bf"
+
+/**
+ * @brief HNSW default parameters.
+ */
+#define HNSW_DEFAULT_M 16
+#define HNSW_DEFAULT_EF_C 200
+#define HNSW_DEFAULT_EF_RT 10
+#define HNSW_DEFAULT_EPSILON 0.01
+
+/**
+ * @brief SVS Vamana default parameters.
+ */
+#define SVS_VAMANA_DEFAULT_ALPHA_L2 1.2f
+#define SVS_VAMANA_DEFAULT_ALPHA_IP 0.95f
+#define SVS_VAMANA_DEFAULT_GRAPH_MAX_DEGREE 32
+#define SVS_VAMANA_DEFAULT_CONSTRUCTION_WINDOW_SIZE 200
+#define SVS_VAMANA_DEFAULT_USE_SEARCH_HISTORY true
+#define SVS_VAMANA_DEFAULT_NUM_THREADS 1
+#define SVS_VAMANA_DEFAULT_TRAINING_THRESHOLD (10 * DEFAULT_BLOCK_SIZE)
+#define SVS_VAMANA_DEFAULT_UPDATE_THRESHOLD (1 * DEFAULT_BLOCK_SIZE)
+#define SVS_VAMANA_DEFAULT_SEARCH_WINDOW_SIZE 10
+#define SVS_VAMANA_DEFAULT_LEANVEC_DIM 0
+#define SVS_VAMANA_DEFAULT_EPSILON 0.01f
+
+/**
+ * @brief General success code.
+ */
+#define VecSim_OK 0
+
+/* ============================================================================
+ * Type Definitions
+ * ========================================================================== */
+
+/**
+ * @brief Label type for vectors (64-bit unsigned integer).
+ */
+typedef uint64_t labelType;
+
+/**
+ * @brief Vector element data type.
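+ *
+ * The value fixes the element width the index expects: a vector blob passed
+ * to VecSimIndex_AddVector must contain exactly dim elements of this type.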
+ */ +typedef enum VecSimType { + VecSimType_FLOAT32 = 0, /**< 32-bit floating point (float) */ + VecSimType_FLOAT64 = 1, /**< 64-bit floating point (double) */ + VecSimType_BFLOAT16 = 2, /**< Brain floating point (16-bit) */ + VecSimType_FLOAT16 = 3, /**< IEEE 754 half-precision (16-bit) */ + VecSimType_INT8 = 4, /**< 8-bit signed integer */ + VecSimType_UINT8 = 5, /**< 8-bit unsigned integer */ + VecSimType_INT32 = 6, /**< 32-bit signed integer */ + VecSimType_INT64 = 7 /**< 64-bit signed integer */ +} VecSimType; + +/** + * @brief Index algorithm type. + */ +typedef enum VecSimAlgo { + VecSimAlgo_BF = 0, /**< Brute Force (exact, linear scan) */ + VecSimAlgo_HNSWLIB = 1, /**< HNSW (approximate, logarithmic) */ + VecSimAlgo_TIERED = 2, /**< Tiered (BruteForce frontend + HNSW backend) */ + VecSimAlgo_SVS = 3 /**< SVS/Vamana (approximate, single-layer graph) */ +} VecSimAlgo; + +/** + * @brief Distance metric type. + */ +typedef enum VecSimMetric { + VecSimMetric_L2 = 0, /**< L2 (Euclidean) squared distance */ + VecSimMetric_IP = 1, /**< Inner Product (dot product) */ + VecSimMetric_Cosine = 2 /**< Cosine distance (1 - cosine similarity) */ +} VecSimMetric; + +/** + * @brief Query result ordering. + */ +typedef enum VecSimQueryReply_Order { + BY_SCORE = 0, /**< Order by distance/score (ascending) */ + BY_ID = 1 /**< Order by label ID (ascending) */ +} VecSimQueryReply_Order; + +/** + * @brief Query reply status code. + */ +typedef enum VecSimQueryReply_Code { + VecSim_QueryReply_OK = 0, /**< Query completed successfully */ + VecSim_QueryReply_TimedOut = 1 /**< Query was aborted due to timeout */ +} VecSimQueryReply_Code; + +/** + * @brief Search mode for queries (internal Rust representation). + * Note: RediSearch defines its own VecSimSearchMode with VECSIM_ prefix. + * This enum is used internally by the Rust library. + */ +typedef enum VecSimSearchMode_Internal { + VECSIM_SEARCH_STANDARD = 0, /**< Standard search mode */ + VECSIM_SEARCH_HYBRID = 1, /**< Hybrid search mode */ + VECSIM_SEARCH_RANGE = 2 /**< Range search mode */ +} VecSimSearchMode_Internal; + +/** + * @brief Hybrid search policy. + */ +typedef enum VecSimHybridPolicy { + BATCHES = 0, /**< Batch-based hybrid search */ + ADHOC = 1 /**< Ad-hoc hybrid search */ +} VecSimHybridPolicy; + +/** + * @brief Index resolution codes. + */ +typedef enum VecSimResolveCode { + VecSim_Resolve_OK = 0, /**< Operation successful */ + VecSim_Resolve_ERR = 1 /**< Operation failed */ +} VecSimResolveCode; + +/** + * @brief Write mode for tiered index operations. + * + * Controls whether vector additions/deletions go through the async + * buffering path or directly to the backend index. + */ +typedef enum VecSimWriteMode { + VecSim_WriteAsync = 0, /**< Async: vectors go to flat buffer, migrated via background jobs */ + VecSim_WriteInPlace = 1 /**< InPlace: vectors go directly to the backend index */ +} VecSimWriteMode; + +/** + * @brief Parameter resolution error codes. + * + * Returned by VecSimIndex_ResolveParams to indicate the result of parsing + * runtime query parameters. 
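+ *
+ * VecSimParamResolver_OK (0) is the only success value; every other code
+ * identifies why a raw parameter was rejected.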
+ */ +typedef enum VecSimParamResolveCode { + VecSimParamResolver_OK = 0, /**< Resolution succeeded */ + VecSimParamResolverErr_NullParam = 1, /**< Null parameter pointer */ + VecSimParamResolverErr_AlreadySet = 2, /**< Parameter already set */ + VecSimParamResolverErr_UnknownParam = 3, /**< Unknown parameter name */ + VecSimParamResolverErr_BadValue = 4, /**< Invalid parameter value */ + VecSimParamResolverErr_InvalidPolicy_NExits = 5, /**< Policy does not exist */ + VecSimParamResolverErr_InvalidPolicy_NHybrid = 6, /**< Not a hybrid query */ + VecSimParamResolverErr_InvalidPolicy_NRange = 7, /**< Not a range query */ + VecSimParamResolverErr_InvalidPolicy_AdHoc_With_BatchSize = 8, /**< AdHoc with batch size */ + VecSimParamResolverErr_InvalidPolicy_AdHoc_With_EfRuntime = 9 /**< AdHoc with ef_runtime */ +} VecSimParamResolveCode; + +/** + * @brief Query type for parameter resolution. + */ +typedef enum VecsimQueryType { + QUERY_TYPE_NONE = 0, /**< No specific query type */ + QUERY_TYPE_KNN = 1, /**< Standard KNN query */ + QUERY_TYPE_HYBRID = 2, /**< Hybrid query (vector + filters) */ + QUERY_TYPE_RANGE = 3 /**< Range query */ +} VecsimQueryType; + +/** + * @brief Option mode for various settings. + */ +typedef enum VecSimOptionMode { + VecSimOption_AUTO = 0, /**< Automatic mode */ + VecSimOption_ENABLE = 1, /**< Enable the option */ + VecSimOption_DISABLE = 2 /**< Disable the option */ +} VecSimOptionMode; + +/** + * @brief Tri-state boolean for optional settings. + */ +typedef enum VecSimBool { + VecSimBool_TRUE = 1, /**< True */ + VecSimBool_FALSE = 0, /**< False */ + VecSimBool_UNSET = -1 /**< Not set */ +} VecSimBool; + +/** + * @brief Search mode for queries (used for debug/testing). + */ +typedef enum VecSearchMode { + EMPTY_MODE = 0, /**< Empty/unset mode */ + STANDARD_KNN = 1, /**< Standard KNN search */ + HYBRID_ADHOC_BF = 2, /**< Hybrid ad-hoc brute force */ + HYBRID_BATCHES = 3, /**< Hybrid batches search */ + HYBRID_BATCHES_TO_ADHOC_BF = 4, /**< Hybrid batches to ad-hoc BF */ + RANGE_QUERY = 5 /**< Range query */ +} VecSearchMode; + +/** + * @brief Debug command result codes. + */ +typedef enum VecSimDebugCommandCode { + VecSimDebugCommandCode_OK = 0, /**< Command succeeded */ + VecSimDebugCommandCode_BadIndex = 1, /**< Invalid index */ + VecSimDebugCommandCode_LabelNotExists = 2, /**< Label does not exist */ + VecSimDebugCommandCode_MultiNotSupported = 3 /**< Multi-value not supported */ +} VecSimDebugCommandCode; + +/** + * @brief Raw parameter for runtime query configuration. + * + * Used to pass string-based parameters that are resolved into typed + * VecSimQueryParams by VecSimIndex_ResolveParams. + */ +typedef struct VecSimRawParam { + const char *name; /**< Parameter name */ + size_t nameLen; /**< Length of parameter name */ + const char *value; /**< Parameter value as string */ + size_t valLen; /**< Length of parameter value */ +} VecSimRawParam; + +/** + * @brief Timeout callback function type. + * + * Returns non-zero on timeout. + */ +typedef int (*timeoutCallbackFunction)(void *ctx); + +/** + * @brief Log callback function type. + */ +typedef void (*logCallbackFunction)(void *ctx, const char *level, const char *message); + +// ============================================================================ +// C++-Compatible Structures (for drop-in API compatibility) +// ============================================================================ + +/** + * @brief HNSW parameters (C++-compatible layout). 
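+ *
+ * Field order and types mirror the C++ HNSWParams so the struct can be
+ * passed by value inside AlgoParams_C across the FFI boundary.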
+ */ +typedef struct { + VecSimType type; + size_t dim; + VecSimMetric metric; + bool multi; + size_t initialCapacity; + size_t blockSize; + size_t M; + size_t efConstruction; + size_t efRuntime; + double epsilon; +} HNSWParams_C; + +/** + * @brief BruteForce parameters (C++-compatible layout). + */ +typedef struct { + VecSimType type; + size_t dim; + VecSimMetric metric; + bool multi; + size_t initialCapacity; + size_t blockSize; +} BFParams_C; + +/** + * @brief SVS quantization bits. + */ +typedef enum { + VecSimSvsQuant_NONE = 0, + VecSimSvsQuant_Scalar = 1, + VecSimSvsQuant_4 = 4, + VecSimSvsQuant_8 = 8, + VecSimSvsQuant_4x4 = 4 | (4 << 8), + VecSimSvsQuant_4x8 = 4 | (8 << 8), + VecSimSvsQuant_4x8_LeanVec = 4 | (8 << 8) | (1 << 16), + VecSimSvsQuant_8x8_LeanVec = 8 | (8 << 8) | (1 << 16) +} VecSimSvsQuantBits; + +/** + * @brief SVS parameters (C++-compatible layout). + */ +typedef struct { + VecSimType type; + size_t dim; + VecSimMetric metric; + bool multi; + size_t blockSize; + VecSimSvsQuantBits quantBits; + float alpha; + size_t graph_max_degree; + size_t construction_window_size; + size_t max_candidate_pool_size; + size_t prune_to; + VecSimOptionMode use_search_history; + size_t num_threads; + size_t search_window_size; + size_t search_buffer_capacity; + size_t leanvec_dim; + double epsilon; +} SVSParams_C; + +/** + * @brief Tiered HNSW specific parameters. + */ +typedef struct { + size_t swapJobThreshold; +} TieredHNSWParams_C; + +/** + * @brief Tiered SVS specific parameters. + */ +typedef struct { + size_t trainingTriggerThreshold; + size_t updateTriggerThreshold; + size_t updateJobWaitTime; +} TieredSVSParams_C; + +/** + * @brief Tiered HNSW Disk specific parameters. + */ +typedef struct { + char _placeholder; +} TieredHNSWDiskParams_C; + +/* Forward declaration */ +typedef struct VecSimParams_C VecSimParams_C; + +/** + * @brief Callback for submitting async jobs. + */ +typedef int (*SubmitCB)(void *job_queue, void *index_ctx, void **jobs, void **cbs, size_t jobs_len); + +/** + * @brief Tiered index parameters (C++-compatible layout). + */ +typedef struct { + void *jobQueue; + void *jobQueueCtx; + SubmitCB submitCb; + size_t flatBufferLimit; + VecSimParams_C *primaryIndexParams; + union { + TieredHNSWParams_C tieredHnswParams; + TieredSVSParams_C tieredSVSParams; + TieredHNSWDiskParams_C tieredHnswDiskParams; + } specificParams; +} TieredIndexParams_C; + +/** + * @brief Union of algorithm parameters (C++-compatible layout). + */ +typedef union { + HNSWParams_C hnswParams; + BFParams_C bfParams; + TieredIndexParams_C tieredParams; + SVSParams_C svsParams; +} AlgoParams_C; + +/** + * @brief VecSimParams (C++-compatible layout). + * + * This structure matches the C++ VecSim API exactly for drop-in compatibility. + */ +struct VecSimParams_C { + VecSimAlgo algo; + AlgoParams_C algoParams; + void *logCtx; +}; + +/** + * @brief Disk context (C++-compatible layout). + */ +typedef struct { + void *storage; + const char *indexName; + size_t indexNameLen; +} VecSimDiskContext_C; + +/** + * @brief Disk parameters (C++-compatible layout). + */ +typedef struct { + VecSimParams_C *indexParams; + VecSimDiskContext_C *diskContext; +} VecSimParamsDisk_C; + +/* ============================================================================ + * C++ API Compatibility Typedefs + * ============================================================================ + * These typedefs provide drop-in compatibility with the C++ VecSim API. 
+ */ +typedef VecSimDiskContext_C VecSimDiskContext; +typedef VecSimParamsDisk_C VecSimParamsDisk; +typedef TieredIndexParams_C TieredIndexParams; +typedef AlgoParams_C AlgoParams; +typedef HNSWParams_C HNSWParams; +typedef BFParams_C BFParams; +typedef SVSParams_C SVSParams; +typedef VecSimParams_C VecSimParams; + +/* Forward declarations for Rust-native types (defined later in this header) */ +struct TieredParams_Rust; +struct DiskParams_Rust; +typedef struct TieredParams_Rust TieredParams; +typedef struct DiskParams_Rust DiskParams; + +/** + * @brief HNSW runtime parameters (C++-compatible layout). + */ +typedef struct { + size_t efRuntime; + double epsilon; +} HNSWRuntimeParams_C; + +typedef HNSWRuntimeParams_C HNSWRuntimeParams; + +/** + * @brief SVS runtime parameters (C++-compatible layout). + */ +typedef struct { + size_t windowSize; + size_t bufferCapacity; + VecSimOptionMode searchHistory; + double epsilon; +} SVSRuntimeParams_C; + +typedef SVSRuntimeParams_C SVSRuntimeParams; + +/** + * @brief Query parameters (C++-compatible layout). + */ +typedef struct { + union { + HNSWRuntimeParams_C hnswRuntimeParams; + SVSRuntimeParams_C svsRuntimeParams; + }; + size_t batchSize; + VecSearchMode searchMode; + void *timeoutCtx; +} VecSimQueryParams_C; + +/* ============================================================================ + * Memory Function Types + * ========================================================================== */ + +/** + * @brief Function pointer type for malloc-style allocation. + */ +typedef void *(*allocFn)(size_t n); + +/** + * @brief Function pointer type for calloc-style allocation. + */ +typedef void *(*callocFn)(size_t nelem, size_t elemsz); + +/** + * @brief Function pointer type for realloc-style reallocation. + */ +typedef void *(*reallocFn)(void *p, size_t n); + +/** + * @brief Function pointer type for free-style deallocation. + */ +typedef void (*freeFn)(void *p); + +/** + * @brief Memory functions struct for custom memory management. + * + * This allows integration with external memory management systems like Redis. + * Pass this struct to VecSim_SetMemoryFunctions to use custom allocators. + */ +typedef struct VecSimMemoryFunctions { + allocFn allocFunction; /**< Malloc-like allocation function */ + callocFn callocFunction; /**< Calloc-like allocation function */ + reallocFn reallocFunction; /**< Realloc-like reallocation function */ + freeFn freeFunction; /**< Free function */ +} VecSimMemoryFunctions; + +/* ============================================================================ + * Opaque Handle Types + * ========================================================================== */ + +/** + * @brief Opaque handle to a vector similarity index. + */ +typedef struct VecSimIndex VecSimIndex; + +/** + * @brief Opaque handle to a query reply. + */ +typedef struct VecSimQueryReply VecSimQueryReply; + +/** + * @brief Opaque handle to a single query result. + */ +typedef struct VecSimQueryResult VecSimQueryResult; + +/** + * @brief Opaque handle to a query reply iterator. + */ +typedef struct VecSimQueryReply_Iterator VecSimQueryReply_Iterator; + +/** + * @brief Opaque handle to a batch iterator. + */ +typedef struct VecSimBatchIterator VecSimBatchIterator; + +/* ============================================================================ + * Parameter Structures + * ========================================================================== */ + +/** + * @brief Common base parameters for all index types (Rust-native API). 
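+ *
+ * Embedded as the leading `base` field of every *_Rust parameter struct
+ * below (BFParams_Rust, HNSWParams_Rust, SVSParams_Rust, and so on).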
+ */ +typedef struct VecSimBaseParams { + VecSimAlgo algo; /**< Algorithm type */ + VecSimType type_; /**< Vector element data type */ + VecSimMetric metric; /**< Distance metric */ + size_t dim; /**< Vector dimension */ + bool multi; /**< Whether multiple vectors per label are allowed */ + size_t initialCapacity; /**< Initial capacity (number of vectors) */ + size_t blockSize; /**< Block size for storage (0 for default) */ +} VecSimBaseParams; + +/** + * @brief Parameters for BruteForce index creation (Rust-native API). + */ +typedef struct BFParams_Rust { + VecSimBaseParams base; /**< Common parameters */ +} BFParams_Rust; + +/** + * @brief Parameters for HNSW index creation (Rust-native API). + */ +typedef struct HNSWParams_Rust { + VecSimBaseParams base; /**< Common parameters */ + size_t M; /**< Max connections per element per layer (default: 16) */ + size_t efConstruction; /**< Dynamic candidate list size during construction (default: 200) */ + size_t efRuntime; /**< Dynamic candidate list size during search (default: 10) */ + double epsilon; /**< Approximation factor (0 = exact) */ +} HNSWParams_Rust; + +/** + * @brief Parameters for SVS (Vamana) index creation (Rust-native API). + * + * SVS (Scalable Vector Search) is a graph-based approximate nearest neighbor + * index using the Vamana algorithm with robust pruning. + */ +typedef struct SVSParams_Rust { + VecSimBaseParams base; /**< Common parameters */ + size_t graphMaxDegree; /**< Maximum neighbors per node (R, default: 32) */ + float alpha; /**< Pruning parameter for diversity (default: 1.2) */ + size_t constructionWindowSize; /**< Beam width during construction (L, default: 200) */ + size_t searchWindowSize; /**< Default beam width during search (default: 100) */ + bool twoPassConstruction; /**< Enable two-pass construction (default: true) */ +} SVSParams_Rust; + +/** + * @brief Parameters for Tiered index creation (Rust-native API). + * + * The tiered index combines a BruteForce frontend (for fast writes) with + * an HNSW backend (for efficient queries). Vectors are first added to the + * flat buffer, then migrated to HNSW via VecSimTieredIndex_Flush() or + * automatically when the buffer is full. + */ +typedef struct TieredParams_Rust { + VecSimBaseParams base; /**< Common parameters */ + size_t M; /**< HNSW M parameter (default: 16) */ + size_t efConstruction; /**< HNSW ef_construction (default: 200) */ + size_t efRuntime; /**< HNSW ef_runtime (default: 10) */ + size_t flatBufferLimit; /**< Max flat buffer size before in-place writes (default: 10000) */ + uint32_t writeMode; /**< 0 = Async (buffer first), 1 = InPlace (direct to HNSW) */ +} TieredParams_Rust; + +/** + * @brief Backend type for disk-based indices. + */ +typedef enum DiskBackend { + DiskBackend_BruteForce = 0, /**< Linear scan (exact results) */ + DiskBackend_Vamana = 1 /**< Vamana graph (approximate, fast) */ +} DiskBackend; + +/** + * @brief Parameters for disk-based index creation (Rust-native API). + * + * Disk indices store vectors in memory-mapped files for persistence.
+ * They support two backends: + * - BruteForce: Linear scan (exact results, O(n)) + * - Vamana: Graph-based approximate search (fast, O(log n)) + */ +typedef struct DiskParams_Rust { + VecSimBaseParams base; /**< Common parameters */ + const char *dataPath; /**< Path to the data file (null-terminated) */ + DiskBackend backend; /**< Backend algorithm (default: BruteForce) */ + size_t graphMaxDegree; /**< Graph max degree for Vamana (default: 32) */ + float alpha; /**< Alpha parameter for Vamana (default: 1.2) */ + size_t constructionL; /**< Construction window size for Vamana (default: 200) */ + size_t searchL; /**< Search window size for Vamana (default: 100) */ +} DiskParams_Rust; + +/** + * @brief HNSW-specific runtime parameters (Rust-native layout). + */ +typedef struct HNSWRuntimeParams_Rust { + size_t efRuntime; /**< Dynamic candidate list size during search */ + double epsilon; /**< Approximation factor */ +} HNSWRuntimeParams_Rust; + +/** + * @brief SVS-specific runtime parameters (Rust-native layout). + */ +typedef struct SVSRuntimeParams_Rust { + size_t windowSize; /**< Search window size for graph search */ + size_t bufferCapacity; /**< Search buffer capacity */ + int searchHistory; /**< Whether to use search history (0/1) */ + double epsilon; /**< Approximation factor for range search */ +} SVSRuntimeParams_Rust; + +/** + * @brief Query parameters (Rust-native layout). + * + * Note: For C++ API compatibility, use VecSimQueryParams_C instead. + */ +typedef struct VecSimQueryParams_Rust { + HNSWRuntimeParams hnswRuntimeParams; /**< HNSW-specific parameters */ + SVSRuntimeParams svsRuntimeParams; /**< SVS-specific parameters */ + VecSimSearchMode_Internal searchMode; /**< Search mode */ + VecSimHybridPolicy hybridPolicy; /**< Hybrid policy */ + size_t batchSize; /**< Batch size for iteration */ + void *timeoutCtx; /**< Timeout context (opaque) */ +} VecSimQueryParams_Rust; + +/** + * @brief Query parameters (C++ API compatible). + * + * This typedef provides compatibility with the C++ VecSim API. + */ +typedef VecSimQueryParams_C VecSimQueryParams; + +/* ============================================================================ + * Index Info Structures + * ========================================================================== */ + +/** + * @brief HNSW-specific index information. + */ +typedef struct VecSimHnswInfo { + size_t M; /**< M parameter */ + size_t efConstruction; /**< ef_construction parameter */ + size_t efRuntime; /**< ef_runtime parameter */ + size_t maxLevel; /**< Maximum level in the graph */ + int64_t entrypoint; /**< Entry point ID (-1 if none) */ + double epsilon; /**< Epsilon parameter */ +} VecSimHnswInfo; + +/** + * @brief Comprehensive index information. + */ +typedef struct VecSimIndexInfo { + size_t indexSize; /**< Current number of vectors */ + size_t indexLabelCount; /**< Current number of unique labels */ + size_t dim; /**< Vector dimension */ + VecSimType type_; /**< Data type */ + VecSimAlgo algo; /**< Algorithm type */ + VecSimMetric metric; /**< Distance metric */ + bool isMulti; /**< Whether multi-value index */ + size_t blockSize; /**< Block size */ + size_t memory; /**< Memory usage in bytes */ + VecSimHnswInfo hnswInfo; /**< HNSW-specific info (if applicable) */ +} VecSimIndexInfo; + +/** + * @brief Index info that is static and immutable. 
+ */ +typedef struct VecSimIndexBasicInfo { + VecSimAlgo algo; /**< Algorithm type */ + VecSimMetric metric; /**< Distance metric */ + VecSimType type; /**< Data type */ + bool isMulti; /**< Whether multi-value index */ + bool isTiered; /**< Whether tiered index */ + bool isDisk; /**< Whether disk-based index */ + size_t blockSize; /**< Block size */ + size_t dim; /**< Vector dimension */ +} VecSimIndexBasicInfo; + +/** + * @brief Index info for statistics - thin and efficient. + */ +typedef struct VecSimIndexStatsInfo { + size_t memory; /**< Memory usage in bytes */ + size_t numberOfMarkedDeleted; /**< Number of marked deleted entries */ +} VecSimIndexStatsInfo; + +/** + * @brief Common index information. + */ +typedef struct CommonInfo { + VecSimIndexBasicInfo basicInfo; /**< Basic index information */ + size_t indexSize; /**< Current number of vectors */ + size_t indexLabelCount; /**< Current number of unique labels */ + uint64_t memory; /**< Memory usage in bytes */ + VecSearchMode lastMode; /**< Last search mode used */ +} CommonInfo; + +/** + * @brief HNSW-specific debug information. + */ +typedef struct hnswInfoStruct { + size_t M; /**< M parameter */ + size_t efConstruction; /**< ef_construction parameter */ + size_t efRuntime; /**< ef_runtime parameter */ + double epsilon; /**< Epsilon parameter */ + size_t max_level; /**< Maximum level in the graph */ + size_t entrypoint; /**< Entry point ID */ + size_t visitedNodesPoolSize; /**< Visited nodes pool size */ + size_t numberOfMarkedDeletedNodes; /**< Number of marked deleted nodes */ +} hnswInfoStruct; + +/** + * @brief BruteForce-specific debug information. + */ +typedef struct bfInfoStruct { + int8_t dummy; /**< Placeholder field */ +} bfInfoStruct; + +/** + * @brief Debug information for an index. + */ +typedef struct VecSimIndexDebugInfo { + CommonInfo commonInfo; /**< Common index information */ + hnswInfoStruct hnswInfo; /**< HNSW-specific info */ + bfInfoStruct bfInfo; /**< BruteForce-specific info */ +} VecSimIndexDebugInfo; + +/* ============================================================================ + * Debug Info Iterator Types + * ========================================================================== */ + +/** + * @brief Opaque handle to a debug info iterator. + */ +typedef struct VecSimDebugInfoIterator VecSimDebugInfoIterator; + +/** + * @brief Field type for debug info fields. + */ +typedef enum { + INFOFIELD_STRING = 0, /**< String value */ + INFOFIELD_INT64 = 1, /**< Signed 64-bit integer */ + INFOFIELD_UINT64 = 2, /**< Unsigned 64-bit integer */ + INFOFIELD_FLOAT64 = 3, /**< 64-bit floating point */ + INFOFIELD_ITERATOR = 4 /**< Nested iterator */ +} VecSim_InfoFieldType; + +/** + * @brief Union of field values. + */ +typedef union { + double floatingPointValue; /**< 64-bit float value */ + int64_t integerValue; /**< Signed 64-bit integer */ + uint64_t uintegerValue; /**< Unsigned 64-bit integer */ + const char *stringValue; /**< String value */ + VecSimDebugInfoIterator *iteratorValue; /**< Nested iterator */ +} FieldValue; + +/** + * @brief A field in the debug info iterator. 
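+ * + * A hedged consumption sketch (the iterator functions themselves are declared + * further below; `it` stands for an iterator obtained from + * VecSimIndex_DebugInfoIterator()): + * + * while (VecSimDebugInfoIterator_HasNextField(it)) { + * VecSim_InfoField *f = VecSimDebugInfoIterator_NextField(it); + * if (f->fieldType == INFOFIELD_UINT64) + * printf("%s = %llu\n", f->fieldName, + * (unsigned long long)f->fieldValue.uintegerValue); + * }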
+ */ +typedef struct { + const char *fieldName; /**< Field name */ + VecSim_InfoFieldType fieldType; /**< Field type */ + FieldValue fieldValue; /**< Field value */ +} VecSim_InfoField; + +/* ============================================================================ + * Index Lifecycle Functions + * ========================================================================== */ + +/** + * @brief Create a new vector similarity index (C++-compatible API). + * + * This function provides drop-in compatibility with the C++ VecSim API. + * It reads the algo field to determine which type of index to create, + * then accesses the appropriate union variant in algoParams. + * + * @param params Index parameters with algorithm-specific params in the union. + * @return A new index handle, or NULL on failure. + * + * @note For type-safe index creation, use VecSimIndex_NewBF(), VecSimIndex_NewHNSW(), + * VecSimIndex_NewSVS(), VecSimIndex_NewTiered(), or VecSimIndex_NewDisk(). + */ +VecSimIndex *VecSimIndex_New(const VecSimParams_C *params); + +/** + * @brief Create a new BruteForce index. + * + * @param params Pointer to BruteForce-specific parameters + * @return Pointer to the created index, or NULL on failure + */ +VecSimIndex *VecSimIndex_NewBF(const BFParams *params); + +/** + * @brief Create a new HNSW index. + * + * @param params Pointer to HNSW-specific parameters + * @return Pointer to the created index, or NULL on failure + */ +VecSimIndex *VecSimIndex_NewHNSW(const HNSWParams *params); + +/** + * @brief Create a new SVS (Vamana) index. + * + * SVS provides an alternative to HNSW with a single-layer graph structure. + * It uses robust pruning to maintain graph quality and provides good + * recall with efficient memory usage. + * + * @param params Pointer to SVS-specific parameters + * @return Pointer to the created index, or NULL on failure + */ +VecSimIndex *VecSimIndex_NewSVS(const SVSParams *params); + +/** + * @brief Create a new Tiered index. + * + * The tiered index combines a BruteForce frontend (for fast writes) with + * an HNSW backend (for efficient queries). Vectors are first added to the + * flat buffer, then migrated to HNSW via VecSimTieredIndex_Flush() or + * automatically when the buffer is full. + * + * Currently only supports f32 vectors. + * + * @param params Pointer to Tiered-specific parameters + * @return Pointer to the created index, or NULL on failure + */ +VecSimIndex *VecSimIndex_NewTiered(const TieredParams *params); + +/** + * @brief Free a vector similarity index. + * + * @param index Pointer to the index to free (may be NULL) + */ +void VecSimIndex_Free(VecSimIndex *index); + +/* ============================================================================ + * Tiered Index Operations + * ========================================================================== */ + +/** + * @brief Flush the flat buffer to the HNSW backend. + * + * This migrates all vectors from the flat buffer to the HNSW index. + * + * @param index Pointer to a tiered index + * @return Number of vectors flushed, or 0 if the index is not tiered + */ +size_t VecSimTieredIndex_Flush(VecSimIndex *index); + +/** + * @brief Get the number of vectors in the flat buffer. + * + * @param index Pointer to a tiered index + * @return Number of vectors in the flat buffer, or 0 if not tiered + */ +size_t VecSimTieredIndex_FlatSize(const VecSimIndex *index); + +/** + * @brief Get the number of vectors in the HNSW backend. 
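+ * + * Together with VecSimTieredIndex_Flush() and VecSimTieredIndex_FlatSize(), this + * lets a caller observe the migration; a rough sketch: + * + * VecSimIndex_AddVector(index, vec, 1); // lands in the flat buffer first + * size_t moved = VecSimTieredIndex_Flush(index); + * // FlatSize() now shrinks by moved while BackendSize() grows by moved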
+ * + * @param index Pointer to a tiered index + * @return Number of vectors in the HNSW backend, or 0 if not tiered + */ +size_t VecSimTieredIndex_BackendSize(const VecSimIndex *index); + +/** + * @brief Run garbage collection on a tiered index. + * + * This cleans up deleted vectors and optimizes the index structure. + * + * @param index The tiered index handle. + */ +void VecSimTieredIndex_GC(VecSimIndex *index); + +/** + * @brief Acquire shared locks on a tiered index. + * + * This prevents modifications to the index while the locks are held. + * Must be paired with VecSimTieredIndex_ReleaseSharedLocks. + * + * @param index The tiered index handle. + */ +void VecSimTieredIndex_AcquireSharedLocks(VecSimIndex *index); + +/** + * @brief Release shared locks on a tiered index. + * + * Must be called after VecSimTieredIndex_AcquireSharedLocks. + * + * @param index The tiered index handle. + */ +void VecSimTieredIndex_ReleaseSharedLocks(VecSimIndex *index); + +/** + * @brief Check if the index is a tiered index. + * + * @param index Pointer to an index + * @return true if the index is tiered, false otherwise + */ +bool VecSimIndex_IsTiered(const VecSimIndex *index); + +/** + * @brief Create a new disk-based index. + * + * Disk indices store vectors in memory-mapped files for persistence. + * They support two backends: + * - BruteForce: Linear scan (exact results, O(n)) + * - Vamana: Graph-based approximate search (fast, O(log n)) + * + * Currently only supports f32 vectors. + * + * @param params Pointer to Disk-specific parameters + * @return Pointer to the created index, or NULL on failure + */ +VecSimIndex *VecSimIndex_NewDisk(const DiskParams *params); + +/** + * @brief Check if the index is a disk-based index. + * + * @param index Pointer to an index + * @return true if the index is disk-based, false otherwise + */ +bool VecSimIndex_IsDisk(const VecSimIndex *index); + +/** + * @brief Flush changes to disk for a disk-based index. + * + * This ensures all pending changes are written to the underlying file. + * + * @param index Pointer to a disk-based index + * @return true if flush succeeded, false otherwise + */ +bool VecSimDiskIndex_Flush(const VecSimIndex *index); + +/* ============================================================================ + * Vector Operations + * ========================================================================== */ + +/** + * @brief Add a vector to the index. + * + * @param index Pointer to the index + * @param vector Pointer to the vector data (must match index dimension and type) + * @param label Label to associate with the vector + * @return Number of vectors added (1 on success), or -1 on failure + * + * @note For single-value indices, adding a vector with an existing label + * replaces the previous vector. + */ +int VecSimIndex_AddVector(VecSimIndex *index, const void *vector, labelType label); + +/** + * @brief Delete all vectors with the given label. + * + * @param index Pointer to the index + * @param label Label of vectors to delete + * @return Number of vectors deleted, or 0 if label not found + */ +int VecSimIndex_DeleteVector(VecSimIndex *index, labelType label); + +/** + * @brief Get the distance from a stored vector to a query vector. + * + * @param index Pointer to the index + * @param label Label of the stored vector + * @param vector Pointer to the query vector + * @return Distance value, or NaN if the label is not found + * + * @warning This function accesses internal storage directly. Use with caution.
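+ * + * A minimal sketch of the intended call pattern (dimension and label are + * arbitrary; isnan() comes from the standard math.h header): + * + * float query[128] = {0}; + * double d = VecSimIndex_GetDistanceFrom_Unsafe(index, 42, query); + * if (isnan(d)) { /-/ label 42 is not in the index (NaN sentinel) - no comment nesting here, see note below + * }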
+ */ +double VecSimIndex_GetDistanceFrom_Unsafe(VecSimIndex *index, labelType label, const void *vector); + +/* ============================================================================ + * Query Functions + * ========================================================================== */ + +/** + * @brief Perform a top-k nearest neighbor query. + * + * @param index Pointer to the index + * @param query Pointer to the query vector + * @param k Maximum number of results to return + * @param params Query parameters (may be NULL for defaults) + * @param order Result ordering + * @return Pointer to query reply, or NULL on failure + * + * @note The caller is responsible for freeing the reply with VecSimQueryReply_Free(). + */ +VecSimQueryReply *VecSimIndex_TopKQuery( + VecSimIndex *index, + const void *query, + size_t k, + const VecSimQueryParams *params, + VecSimQueryReply_Order order +); + +/** + * @brief Perform a range query. + * + * @param index Pointer to the index + * @param query Pointer to the query vector + * @param radius Maximum distance from query (inclusive) + * @param params Query parameters (may be NULL for defaults) + * @param order Result ordering + * @return Pointer to query reply, or NULL on failure + * + * @note The caller is responsible for freeing the reply with VecSimQueryReply_Free(). + */ +VecSimQueryReply *VecSimIndex_RangeQuery( + VecSimIndex *index, + const void *query, + double radius, + const VecSimQueryParams *params, + VecSimQueryReply_Order order +); + +/* ============================================================================ + * Query Reply Functions + * ========================================================================== */ + +/** + * @brief Get the number of results in a query reply. + * + * @param reply Pointer to the query reply + * @return Number of results + */ +size_t VecSimQueryReply_Len(const VecSimQueryReply *reply); + +/** + * @brief Get the status code of a query reply. + * + * This is used to detect if the query timed out. + * + * @param reply Pointer to the query reply + * @return The status code (VecSim_QueryReply_OK or VecSim_QueryReply_TimedOut) + */ +VecSimQueryReply_Code VecSimQueryReply_GetCode(const VecSimQueryReply *reply); + +/** + * @brief Free a query reply. + * + * @param reply Pointer to the query reply (may be NULL) + */ +void VecSimQueryReply_Free(VecSimQueryReply *reply); + +/** + * @brief Get an iterator over query results. + * + * @param reply Pointer to the query reply + * @return Pointer to iterator, or NULL on failure + * + * @note The iterator is only valid while the reply exists. + * Free with VecSimQueryReply_IteratorFree(). + */ +VecSimQueryReply_Iterator *VecSimQueryReply_GetIterator(VecSimQueryReply *reply); + +/** + * @brief Check if the iterator has more results. + * + * @param iter Pointer to the iterator + * @return true if more results available, false otherwise + */ +bool VecSimQueryReply_IteratorHasNext(const VecSimQueryReply_Iterator *iter); + +/** + * @brief Get the next result from the iterator. + * + * @param iter Pointer to the iterator + * @return Pointer to the next result, or NULL if no more results + * + * @note The returned pointer is valid until the next call to IteratorNext + * or until the reply is freed. + */ +const VecSimQueryResult *VecSimQueryReply_IteratorNext(VecSimQueryReply_Iterator *iter); + +/** + * @brief Reset the iterator to the beginning. 
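+ * + * For context, a typical drain loop over a reply (a sketch, not normative): + * + * VecSimQueryReply_Iterator *it = VecSimQueryReply_GetIterator(reply); + * while (VecSimQueryReply_IteratorHasNext(it)) { + * const VecSimQueryResult *r = VecSimQueryReply_IteratorNext(it); + * // consume VecSimQueryResult_GetId(r) and VecSimQueryResult_GetScore(r) + * } + * VecSimQueryReply_IteratorFree(it); + * VecSimQueryReply_Free(reply);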
+ * + * @param iter Pointer to the iterator + */ +void VecSimQueryReply_IteratorReset(VecSimQueryReply_Iterator *iter); + +/** + * @brief Free an iterator. + * + * @param iter Pointer to the iterator (may be NULL) + */ +void VecSimQueryReply_IteratorFree(VecSimQueryReply_Iterator *iter); + +/* ============================================================================ + * Query Result Functions + * ========================================================================== */ + +/** + * @brief Get the label (ID) from a query result. + * + * @param result Pointer to the query result + * @return Label of the result + */ +labelType VecSimQueryResult_GetId(const VecSimQueryResult *result); + +/** + * @brief Get the score (distance) from a query result. + * + * @param result Pointer to the query result + * @return Distance/score of the result + */ +double VecSimQueryResult_GetScore(const VecSimQueryResult *result); + +/* ============================================================================ + * Batch Iterator Functions + * ========================================================================== */ + +/** + * @brief Create a batch iterator for incremental query processing. + * + * @param index Pointer to the index + * @param query Pointer to the query vector + * @param params Query parameters (may be NULL for defaults) + * @return Pointer to batch iterator, or NULL on failure + * + * @note Batch iterators are useful for processing large result sets + * incrementally without loading all results into memory. + */ +VecSimBatchIterator *VecSimBatchIterator_New( + VecSimIndex *index, + const void *query, + const VecSimQueryParams *params +); + +/** + * @brief Get the next batch of results. + * + * @param iter Pointer to the batch iterator + * @param n Maximum number of results to return in this batch + * @param order Result ordering + * @return Pointer to query reply containing the batch, or NULL on failure + * + * @note The caller is responsible for freeing the reply with VecSimQueryReply_Free(). + */ +VecSimQueryReply *VecSimBatchIterator_Next( + VecSimBatchIterator *iter, + size_t n, + VecSimQueryReply_Order order +); + +/** + * @brief Check if the batch iterator has more results. + * + * @param iter Pointer to the batch iterator + * @return true if more results available, false otherwise + */ +bool VecSimBatchIterator_HasNext(const VecSimBatchIterator *iter); + +/** + * @brief Reset the batch iterator to the beginning. + * + * @param iter Pointer to the batch iterator + */ +void VecSimBatchIterator_Reset(VecSimBatchIterator *iter); + +/** + * @brief Free a batch iterator. + * + * @param iter Pointer to the batch iterator (may be NULL) + */ +void VecSimBatchIterator_Free(VecSimBatchIterator *iter); + +/* ============================================================================ + * Index Property Functions + * ========================================================================== */ + +/** + * @brief Get the current number of vectors in the index. + * + * @param index Pointer to the index + * @return Number of vectors + */ +size_t VecSimIndex_IndexSize(const VecSimIndex *index); + +/** + * @brief Get the data type of the index. + * + * @param index Pointer to the index + * @return Data type enum value + */ +VecSimType VecSimIndex_GetType(const VecSimIndex *index); + +/** + * @brief Get the distance metric of the index. 
+ * + * @param index Pointer to the index + * @return Metric enum value + */ +VecSimMetric VecSimIndex_GetMetric(const VecSimIndex *index); + +/** + * @brief Get the vector dimension of the index. + * + * @param index Pointer to the index + * @return Vector dimension + */ +size_t VecSimIndex_GetDim(const VecSimIndex *index); + +/** + * @brief Check if the index is a multi-value index. + * + * @param index Pointer to the index + * @return true if multi-value, false if single-value + */ +bool VecSimIndex_IsMulti(const VecSimIndex *index); + +/** + * @brief Check if a label exists in the index. + * + * @param index Pointer to the index + * @param label Label to check + * @return true if label exists, false otherwise + */ +bool VecSimIndex_ContainsLabel(const VecSimIndex *index, labelType label); + +/** + * @brief Get the count of vectors with the given label. + * + * @param index Pointer to the index + * @param label Label to count + * @return Number of vectors with this label (0 if not found) + */ +size_t VecSimIndex_LabelCount(const VecSimIndex *index, labelType label); + +/** + * @brief Get detailed index information. + * + * @param index Pointer to the index + * @return VecSimIndexInfo structure with index details + */ +VecSimIndexInfo VecSimIndex_Info(const VecSimIndex *index); + +/** + * @brief Get basic immutable index information. + * + * @param index The index handle. + * @return Basic index information. + */ +VecSimIndexBasicInfo VecSimIndex_BasicInfo(const VecSimIndex *index); + +/** + * @brief Get index statistics information. + * + * This is a thin and efficient info call with no locks or calculations. + * + * @param index The index handle. + * @return Statistics information. + */ +VecSimIndexStatsInfo VecSimIndex_StatsInfo(const VecSimIndex *index); + +/** + * @brief Get detailed debug information for an index. + * + * This should only be used for debug/testing purposes. + * + * @param index The index handle. + * @return Debug information. + */ +VecSimIndexDebugInfo VecSimIndex_DebugInfo(const VecSimIndex *index); + +/** + * @brief Create a debug info iterator for an index. + * + * The iterator provides a way to traverse all debug information fields + * for an index, including nested information for tiered indices. + * + * @param index The index handle. + * @return A debug info iterator, or NULL on failure. + * Must be freed with VecSimDebugInfoIterator_Free(). + */ +VecSimDebugInfoIterator *VecSimIndex_DebugInfoIterator(const VecSimIndex *index); + +/** + * @brief Returns the number of fields in the info iterator. + * + * @param infoIterator The info iterator. + * @return Number of fields. + */ +size_t VecSimDebugInfoIterator_NumberOfFields(VecSimDebugInfoIterator *infoIterator); + +/** + * @brief Check if the iterator has more fields. + * + * @param infoIterator The info iterator. + * @return true if more fields are available, false otherwise. + */ +bool VecSimDebugInfoIterator_HasNextField(VecSimDebugInfoIterator *infoIterator); + +/** + * @brief Get the next field from the iterator. + * + * The returned pointer is valid until the next call to this function + * or until the iterator is freed. + * + * @param infoIterator The info iterator. + * @return Pointer to the next info field, or NULL if no more fields. + */ +VecSim_InfoField *VecSimDebugInfoIterator_NextField(VecSimDebugInfoIterator *infoIterator); + +/** + * @brief Free a debug info iterator. + * + * This also frees all nested iterators. + * + * @param infoIterator The info iterator to free. 
+ */ +void VecSimDebugInfoIterator_Free(VecSimDebugInfoIterator *infoIterator); + +/** + * @brief Dump the neighbors of an element in an HNSW index. + * + * Returns an array with topLevel+2 entries, where each entry is an array itself. Every internal + * array in position l, where 0 <= l <= topLevel, corresponds to the neighbors of the element in + * the graph in level l. It contains n_l+1 entries, where n_l is the number of neighbors in + * level l. The last entry in the external array is NULL (this indicates its length). The first + * entry in each internal array contains the number n_l, while the next n_l entries are the + * labels of the element's neighbors in this level. + * + * Note: currently only single-value HNSW indexes are supported (multi is not yet), tiered + * included. For cleanup, VecSimDebug_ReleaseElementNeighborsInHNSWGraph needs to be called with + * the value pointed to by neighborsData as returned from this call. + * + * @param index The index in which the element resides. + * @param label The label whose neighbors are dumped for every level in which it exists. + * @param neighborsData A pointer to a two-dimensional integer array which serves as a placeholder + * for the output: the neighbors' labels will be allocated and stored in it. + * @return VecSimDebugCommandCode indicating success or the failure reason. + */ +VecSimDebugCommandCode VecSimDebug_GetElementNeighborsInHNSWGraph(VecSimIndex *index, size_t label, + int ***neighborsData); + +/** + * @brief Release the neighbors data allocated by VecSimDebug_GetElementNeighborsInHNSWGraph. + * + * @param neighborsData The two-dimensional array returned in the placeholder to be de-allocated. + */ +void VecSimDebug_ReleaseElementNeighborsInHNSWGraph(int **neighborsData); + +/** + * @brief Determine if ad-hoc brute-force search is preferred over batched search. + * + * This is a heuristic function that helps decide the optimal search strategy + * for hybrid queries based on the index size and the number of results needed. + * + * @param index The index handle. + * @param subsetSize The estimated size of the subset to search. + * @param k The number of results requested. + * @param initial Whether this is the initial decision (true) or a re-evaluation (false). + * @return true if ad-hoc search is preferred, false if batched search is preferred. + */ +bool VecSimIndex_PreferAdHocSearch(const VecSimIndex *index, size_t subsetSize, size_t k, bool initial); + +/* ============================================================================ + * Parameter Resolution + * ========================================================================== */ + +/** + * @brief Resolve runtime query parameters from raw string parameters. + * + * Parses an array of VecSimRawParam structures and populates a VecSimQueryParams + * structure with the resolved typed values.
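+ * + * A hedged sketch, assuming the VecSimRawParam layout ({name, nameLen, value, valLen}) + * and the QUERY_TYPE_KNN constant of the original C++ API: + * + * VecSimRawParam raw[] = {{"EF_RUNTIME", 10, "100", 3}}; + * VecSimQueryParams qparams; + * VecSimParamResolveCode rc = + * VecSimIndex_ResolveParams(index, raw, 1, &qparams, QUERY_TYPE_KNN);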
+ * + * @param index Pointer to the index (used to determine algorithm-specific parameters) + * @param rparams Array of raw parameters to resolve + * @param paramNum Number of parameters in the array + * @param qparams Pointer to VecSimQueryParams structure to populate + * @param query_type Type of query (KNN, HYBRID, or RANGE) + * @return VecSimParamResolver_OK on success, error code on failure + * + * Supported parameters: + * - EF_RUNTIME: HNSW ef_runtime (positive integer, not for range queries) + * - EPSILON: Approximation factor (positive float, range queries only, HNSW/SVS) + * - BATCH_SIZE: Batch size for hybrid queries (positive integer) + * - HYBRID_POLICY: "batches" or "adhoc_bf" (hybrid queries only) + * - SEARCH_WINDOW_SIZE: SVS search window size (positive integer) + * - SEARCH_BUFFER_CAPACITY: SVS search buffer capacity (positive integer) + * - USE_SEARCH_HISTORY: SVS search history flag ("true"/"false"/"1"/"0") + */ +VecSimParamResolveCode VecSimIndex_ResolveParams(VecSimIndex *index, + VecSimRawParam *rparams, + int paramNum, + VecSimQueryParams *qparams, + VecsimQueryType query_type); + +/* ============================================================================ + * Serialization Functions + * ========================================================================== */ + +/** + * @brief Save an index to a file. + * + * @param index Pointer to the index + * @param path File path to save to (null-terminated C string) + * @return true on success, false on failure + * + * @note Serialization is supported for: + * - BruteForce (f32 only) + * - HNSW (all data types) + * - SVS Single (f32 only) + */ +bool VecSimIndex_SaveIndex(const VecSimIndex *index, const char *path); + +/** + * @brief Load an index from a file. + * + * Reads the file header to determine the index type and data type, + * then loads the appropriate index. The caller is responsible for + * freeing the returned index with VecSimIndex_Free. + * + * @param path File path to load from (null-terminated C string) + * @param params Optional parameters to override (may be NULL, currently unused) + * @return Pointer to loaded index, or NULL on failure + * + * @note Supported index types for loading: + * - BruteForceSingle/Multi (f32) + * - HnswSingle/Multi (f32) + * - SvsSingle (f32) + */ +VecSimIndex *VecSimIndex_LoadIndex(const char *path, const VecSimParams *params); + +/* ============================================================================ + * Memory Estimation Functions + * ========================================================================== */ + +/** + * @brief Estimate initial memory size for a BruteForce index. + * + * @param dim Vector dimension + * @param initial_capacity Initial capacity (number of vectors) + * @return Estimated memory size in bytes + */ +size_t VecSimIndex_EstimateBruteForceInitialSize(size_t dim, size_t initial_capacity); + +/** + * @brief Estimate memory size per element for a BruteForce index. + * + * @param dim Vector dimension + * @return Estimated memory per element in bytes + */ +size_t VecSimIndex_EstimateBruteForceElementSize(size_t dim); + +/** + * @brief Estimate initial memory size for an HNSW index. + * + * @param dim Vector dimension + * @param initial_capacity Initial capacity (number of vectors) + * @param m M parameter (max connections per layer) + * @return Estimated memory size in bytes + */ +size_t VecSimIndex_EstimateHNSWInitialSize(size_t dim, size_t initial_capacity, size_t m); + +/** + * @brief Estimate memory size per element for an HNSW index. 
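+ * + * Combined with VecSimIndex_EstimateHNSWInitialSize(), a rough capacity plan for + * n vectors is: total = EstimateHNSWInitialSize(dim, n, m) + n * EstimateHNSWElementSize(dim, m).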
+ * + * @param dim Vector dimension + * @param m M parameter (max connections per layer) + * @return Estimated memory per element in bytes + */ +size_t VecSimIndex_EstimateHNSWElementSize(size_t dim, size_t m); + +/** + * @brief Estimate initial memory size for an index based on parameters. + * + * @param params The index parameters. + * @return Estimated initial memory size in bytes. + */ +size_t VecSimIndex_EstimateInitialSize(const VecSimParams *params); + +/** + * @brief Estimate memory size per element for an index based on parameters. + * + * @param params The index parameters. + * @return Estimated memory size per element in bytes. + */ +size_t VecSimIndex_EstimateElementSize(const VecSimParams *params); + +/* ============================================================================ + * Write Mode Control + * ========================================================================== */ + +/** + * @brief Set the global write mode for tiered index operations. + * + * This controls whether vector additions/deletions in tiered indices go through + * the async buffering path (VecSim_WriteAsync) or directly to the backend index + * (VecSim_WriteInPlace). + * + * @param mode The write mode to set. + * + * @note In a tiered index scenario, this should be called from the main thread only + * (that is, the thread that is calling add/delete vector functions). + */ +void VecSim_SetWriteMode(VecSimWriteMode mode); + +/** + * @brief Get the current global write mode. + * + * @return The currently active write mode for tiered index operations. + */ +VecSimWriteMode VecSim_GetWriteMode(void); + +/* ============================================================================ + * Memory Management Functions + * ========================================================================== */ + +/** + * @brief Set custom memory functions for all future allocations. + * + * This allows integration with external memory management systems like Redis. + * The functions will be used for all memory allocations in the library. + * + * @param functions The memory functions struct containing custom allocators. + * + * @note This should be called once at initialization, before creating any indices. + * @note The provided function pointers must be valid and thread-safe. + * @note The functions must follow standard malloc/calloc/realloc/free semantics. + */ +void VecSim_SetMemoryFunctions(VecSimMemoryFunctions functions); + +/** + * @brief Set the timeout callback function. + * + * The callback will be called periodically during long operations to check + * if the operation should be aborted. Return non-zero to abort. + * + * @param callback The timeout callback function, or NULL to disable. + */ +void VecSim_SetTimeoutCallbackFunction(timeoutCallbackFunction callback); + +/** + * @brief Set the log callback function. + * + * The callback will be called for logging messages from the library. + * + * @param callback The log callback function, or NULL to disable. + */ +void VecSim_SetLogCallbackFunction(logCallbackFunction callback); + +/** + * @brief Set the test log context. + * + * This is used for testing to identify which test is running. + * + * @param test_name The name of the test. + * @param test_type The type of the test. 
+ */ +void VecSim_SetTestLogContext(const char *test_name, const char *test_type); + +/* ============================================================================ + * Vector Utility Functions + * ========================================================================== */ + +/** + * @brief Normalize a vector in-place. + * + * This normalizes the vector to unit length (L2 norm = 1). + * This is useful for cosine similarity where vectors should be normalized. + * + * @param blob Pointer to the vector data. + * @param dim The dimension of the vector. + * @param type The data type of the vector elements. + */ +void VecSim_Normalize(void *blob, size_t dim, VecSimType type); + +#ifdef __cplusplus +} +#endif + +#endif /* VECSIM_H */ diff --git a/rust/vecsim-c/src/compat.rs b/rust/vecsim-c/src/compat.rs new file mode 100644 index 000000000..7e7d82b56 --- /dev/null +++ b/rust/vecsim-c/src/compat.rs @@ -0,0 +1,226 @@ +//! C++-compatible type definitions for binary compatibility with the C++ VecSim API. +//! +//! These types use unions to match the exact memory layout of the C++ structs. +//! They are provided for drop-in compatibility with existing C++ code. +//! +//! For new Rust code, prefer the type-safe structs in `params.rs`. + +use std::ffi::c_void; + +use crate::types::{VecSimAlgo, VecSimMetric, VecSimOptionMode, VecSimType, VecSearchMode}; + +/// HNSW parameters matching C++ HNSWParams exactly. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct HNSWParams_C { + pub type_: VecSimType, + pub dim: usize, + pub metric: VecSimMetric, + pub multi: bool, + pub initialCapacity: usize, + pub blockSize: usize, + pub M: usize, + pub efConstruction: usize, + pub efRuntime: usize, + pub epsilon: f64, +} + +/// BruteForce parameters matching C++ BFParams exactly. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct BFParams_C { + pub type_: VecSimType, + pub dim: usize, + pub metric: VecSimMetric, + pub multi: bool, + pub initialCapacity: usize, + pub blockSize: usize, +} + +/// SVS quantization bits matching C++ VecSimSvsQuantBits. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSimSvsQuantBits { + VecSimSvsQuant_NONE = 0, + VecSimSvsQuant_Scalar = 1, + VecSimSvsQuant_4 = 4, + VecSimSvsQuant_8 = 8, + // Complex values handled as integers +} + +/// SVS parameters matching C++ SVSParams exactly. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct SVSParams_C { + pub type_: VecSimType, + pub dim: usize, + pub metric: VecSimMetric, + pub multi: bool, + pub blockSize: usize, + pub quantBits: VecSimSvsQuantBits, + pub alpha: f32, + pub graph_max_degree: usize, + pub construction_window_size: usize, + pub max_candidate_pool_size: usize, + pub prune_to: usize, + pub use_search_history: VecSimOptionMode, + pub num_threads: usize, + pub search_window_size: usize, + pub search_buffer_capacity: usize, + pub leanvec_dim: usize, + pub epsilon: f64, +} + +/// Tiered HNSW specific parameters. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct TieredHNSWParams_C { + pub swapJobThreshold: usize, +} + +/// Tiered SVS specific parameters. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct TieredSVSParams_C { + pub trainingTriggerThreshold: usize, + pub updateTriggerThreshold: usize, + pub updateJobWaitTime: usize, +} + +/// Tiered HNSW Disk specific parameters. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct TieredHNSWDiskParams_C { + pub _placeholder: u8, +} + +/// Union of tiered-specific parameters. 
+#[repr(C)] +#[derive(Clone, Copy)] +pub union TieredSpecificParams_C { + pub tieredHnswParams: TieredHNSWParams_C, + pub tieredSVSParams: TieredSVSParams_C, + pub tieredHnswDiskParams: TieredHNSWDiskParams_C, +} + +/// Callback for submitting async jobs. +pub type SubmitCB = Option< + unsafe extern "C" fn( + job_queue: *mut c_void, + index_ctx: *mut c_void, + jobs: *mut *mut c_void, + cbs: *mut *mut c_void, + jobs_len: usize, + ) -> i32, +>; + +/// Tiered index parameters matching C++ TieredIndexParams. +#[repr(C)] +pub struct TieredIndexParams_C { + pub jobQueue: *mut c_void, + pub jobQueueCtx: *mut c_void, + pub submitCb: SubmitCB, + pub flatBufferLimit: usize, + pub primaryIndexParams: *mut VecSimParams_C, + pub specificParams: TieredSpecificParams_C, +} + +/// Union of algorithm parameters matching C++ AlgoParams. +/// Note: This union cannot derive Copy because TieredIndexParams_C (held via ManuallyDrop) +/// does not implement Copy. Clone is implemented manually via unsafe memory copy. +#[repr(C)] +pub union AlgoParams_C { + pub hnswParams: HNSWParams_C, + pub bfParams: BFParams_C, + pub tieredParams: std::mem::ManuallyDrop<TieredIndexParams_C>, + pub svsParams: SVSParams_C, +} + +impl Clone for AlgoParams_C { + fn clone(&self) -> Self { + // Safety: This is a C-compatible union, so we can safely copy the bytes. + // The caller is responsible for ensuring the correct variant is used. + unsafe { + std::ptr::read(self) + } + } +} + +/// VecSimParams matching C++ struct VecSimParams exactly. +#[repr(C)] +pub struct VecSimParams_C { + pub algo: VecSimAlgo, + pub algoParams: AlgoParams_C, + pub logCtx: *mut c_void, +} + +/// Disk context matching C++ VecSimDiskContext. +#[repr(C)] +pub struct VecSimDiskContext_C { + pub storage: *mut c_void, + pub indexName: *const std::ffi::c_char, + pub indexNameLen: usize, +} + +/// Disk parameters matching C++ VecSimParamsDisk. +#[repr(C)] +pub struct VecSimParamsDisk_C { + pub indexParams: *mut VecSimParams_C, + pub diskContext: *mut VecSimDiskContext_C, +} + +/// HNSW runtime parameters matching C++ HNSWRuntimeParams. +#[repr(C)] +#[derive(Debug, Clone, Copy, Default)] +pub struct HNSWRuntimeParams_C { + pub efRuntime: usize, + pub epsilon: f64, +} + +/// SVS runtime parameters matching C++ SVSRuntimeParams. +#[repr(C)] +#[derive(Debug, Clone, Copy, Default)] +pub struct SVSRuntimeParams_C { + pub windowSize: usize, + pub bufferCapacity: usize, + pub searchHistory: VecSimOptionMode, + pub epsilon: f64, +} + +/// Union of runtime parameters (anonymous union in C++). +#[repr(C)] +#[derive(Clone, Copy)] +pub union RuntimeParams_C { + pub hnswRuntimeParams: HNSWRuntimeParams_C, + pub svsRuntimeParams: SVSRuntimeParams_C, +} + +impl Default for RuntimeParams_C { + fn default() -> Self { + RuntimeParams_C { + hnswRuntimeParams: HNSWRuntimeParams_C::default(), + } + } +} + +/// Query parameters matching C++ VecSimQueryParams. +#[repr(C)] +#[derive(Clone, Copy)] +pub struct VecSimQueryParams_C { + pub runtimeParams: RuntimeParams_C, + pub batchSize: usize, + pub searchMode: VecSearchMode, + pub timeoutCtx: *mut c_void, +} + +impl Default for VecSimQueryParams_C { + fn default() -> Self { + VecSimQueryParams_C { + runtimeParams: RuntimeParams_C::default(), + batchSize: 0, + searchMode: VecSearchMode::EMPTY_MODE, + timeoutCtx: std::ptr::null_mut(), + } + } +} + diff --git a/rust/vecsim-c/src/index.rs b/rust/vecsim-c/src/index.rs new file mode 100644 index 000000000..6b297c040 --- /dev/null +++ b/rust/vecsim-c/src/index.rs @@ -0,0 +1,1588 @@ +//! Index wrapper and lifecycle functions for C FFI.
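+//! +//! The central abstraction is the type-erased [`IndexWrapper`] trait defined below: each +//! concrete `vecsim` index is wrapped in a struct implementing it, and the C layer only ever +//! holds a `Box<dyn IndexWrapper>`. A hedged internal usage sketch (`make_wrapper` is a +//! hypothetical helper, not part of this crate): +//! +//! ```ignore +//! let mut idx: Box<dyn IndexWrapper> = make_wrapper(); +//! let v = [0.0f32; 4]; +//! idx.add_vector(v.as_ptr() as *const std::ffi::c_void, 7); +//! ```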
+ +use crate::params::{BFParams, DiskParams, HNSWParams, SVSParams, TieredParams, VecSimQueryParams}; +use crate::types::{ + labelType, QueryReplyInternal, QueryResultInternal, VecSimAlgo, VecSimMetric, + VecSimQueryReply_Code, VecSimQueryReply_Order, VecSimType, +}; +use std::ffi::c_void; +use std::slice; +use vecsim::index::traits::BatchIterator as VecSimBatchIteratorTrait; +use vecsim::index::{ + disk::DiskIndexSingle, BruteForceMulti, BruteForceSingle, HnswMulti, HnswSingle, SvsMulti, + SvsSingle, TieredMulti, TieredSingle, VecSimIndex as VecSimIndexTrait, +}; +use vecsim::query::QueryReply; +use vecsim::types::{BFloat16, DistanceType, Float16, Int8, UInt8, VectorElement}; + +/// Trait for type-erased index operations. +pub trait IndexWrapper: Send + Sync { + /// Add a vector to the index. + fn add_vector(&mut self, vector: *const c_void, label: labelType) -> i32; + + /// Delete a vector by label. + fn delete_vector(&mut self, label: labelType) -> i32; + + /// Perform top-k query. + fn top_k_query( + &self, + query: *const c_void, + k: usize, + params: Option<&VecSimQueryParams>, + ) -> QueryReplyInternal; + + /// Perform range query. + fn range_query( + &self, + query: *const c_void, + radius: f64, + params: Option<&VecSimQueryParams>, + ) -> QueryReplyInternal; + + /// Get distance from a stored vector to a query vector. + fn get_distance_from(&self, label: labelType, query: *const c_void) -> f64; + + /// Get index size. + fn index_size(&self) -> usize; + + /// Get vector dimension. + fn dimension(&self) -> usize; + + /// Check if label exists. + fn contains(&self, label: labelType) -> bool; + + /// Get label count for a label. + fn label_count(&self, label: labelType) -> usize; + + /// Get data type. + fn data_type(&self) -> VecSimType; + + /// Get algorithm type. + fn algo(&self) -> VecSimAlgo; + + /// Get metric. + fn metric(&self) -> VecSimMetric; + + /// Is multi-value index. + fn is_multi(&self) -> bool; + + /// Create a batch iterator. + fn create_batch_iterator( + &self, + query: *const c_void, + params: Option<&VecSimQueryParams>, + ) -> Option<Box<dyn BatchIteratorWrapper>>; + + /// Get memory usage. + fn memory_usage(&self) -> usize; + + /// Save the index to a file. + /// Returns true on success, false on failure. + fn save_to_file(&self, path: &std::path::Path) -> bool; + + /// Check if this is a tiered index. + fn is_tiered(&self) -> bool { + false + } + + /// Flush the flat buffer to the backend (tiered only). + /// Returns the number of vectors flushed. + fn tiered_flush(&mut self) -> usize { + 0 + } + + /// Get the size of the flat buffer (tiered only). + fn tiered_flat_size(&self) -> usize { + 0 + } + + /// Get the size of the backend index (tiered only). + fn tiered_backend_size(&self) -> usize { + 0 + } + + /// Run garbage collection on a tiered index. + fn tiered_gc(&mut self) {} + + /// Acquire shared locks on a tiered index. + fn tiered_acquire_shared_locks(&mut self) {} + + /// Release shared locks on a tiered index. + fn tiered_release_shared_locks(&mut self) {} + + /// Check if this is a disk-based index. + fn is_disk(&self) -> bool { + false + } + + /// Flush changes to disk (disk indices only). + /// Returns true on success, false on failure. + fn disk_flush(&self) -> bool { + false + } + + /// Get the neighbors of an element at all levels (HNSW/Tiered only). + /// + /// Returns None if the label doesn't exist or if the index type doesn't support this. + /// Returns Some(Vec<Vec<u64>>) where each inner Vec contains the neighbor labels at that level.
fn get_element_neighbors(&self, _label: u64) -> Option<Vec<Vec<u64>>> { + None + } +} + +/// Trait for type-erased batch iterator operations. +pub trait BatchIteratorWrapper: Send { + /// Check if more results are available. + fn has_next(&self) -> bool; + + /// Get next batch of results. + fn next_batch(&mut self, n: usize, order: VecSimQueryReply_Order) -> QueryReplyInternal; + + /// Reset the iterator. + fn reset(&mut self); +} + +/// Owned batch iterator that pre-computes all results. +/// This is used to work around the lifetime issues with type-erased batch iterators. +pub struct OwnedBatchIterator { + /// Pre-computed results as (label, distance) pairs + results: Vec<QueryResultInternal>, + /// Current position in results + position: usize, +} + +impl OwnedBatchIterator { + /// Create a new owned batch iterator from pre-computed results. + pub fn new(results: Vec<QueryResultInternal>) -> Self { + Self { + results, + position: 0, + } + } +} + +impl BatchIteratorWrapper for OwnedBatchIterator { + fn has_next(&self) -> bool { + self.position < self.results.len() + } + + fn next_batch(&mut self, n: usize, order: VecSimQueryReply_Order) -> QueryReplyInternal { + if self.position >= self.results.len() { + return QueryReplyInternal::new(); + } + + let end = (self.position + n).min(self.results.len()); + let mut batch: Vec<QueryResultInternal> = self.results[self.position..end].to_vec(); + self.position = end; + + // Sort by requested order + match order { + VecSimQueryReply_Order::BY_SCORE => { + batch.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)); + } + VecSimQueryReply_Order::BY_ID => { + batch.sort_by_key(|r| r.id); + } + } + + QueryReplyInternal { + results: batch, + code: VecSimQueryReply_Code::VecSim_QueryReply_OK, + } + } + + fn reset(&mut self) { + self.position = 0; + } +} + +/// Helper function to create an owned batch iterator from an index. +/// This pre-computes all results by calling next_batch until exhausted. +fn create_owned_batch_iterator_from_results<T: VectorElement>( + mut inner: Box<dyn VecSimBatchIteratorTrait<T> + '_>, +) -> OwnedBatchIterator { + let mut all_results = Vec::new(); + + // Drain all results from the iterator + while inner.has_next() { + if let Some(batch) = inner.next_batch(10000) { + for (_, label, dist) in batch { + all_results.push(QueryResultInternal { + id: label, + score: dist.to_f64(), + }); + } + } else { + break; + } + } + + // Sort by score (distance) to ensure results are in order of increasing distance + all_results.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)); + + OwnedBatchIterator::new(all_results) +} + +/// Macro to implement IndexWrapper for a specific index type without serialization. +macro_rules!
impl_index_wrapper { + ($wrapper:ident, $index:ty, $data:ty, $algo:expr, $is_multi:expr) => { + pub struct $wrapper { + index: $index, + data_type: VecSimType, + } + + impl $wrapper { + pub fn new(index: $index, data_type: VecSimType) -> Self { + Self { index, data_type } + } + + #[allow(dead_code)] + pub fn inner(&self) -> &$index { + &self.index + } + } + + impl IndexWrapper for $wrapper { + fn add_vector(&mut self, vector: *const c_void, label: labelType) -> i32 { + let dim = self.index.dimension(); + // Safety check: dimension must be > 0 and vector must be non-null + if dim == 0 || vector.is_null() { + return -1; + } + let slice = unsafe { slice::from_raw_parts(vector as *const $data, dim) }; + match self.index.add_vector(slice, label) { + Ok(count) => count as i32, + Err(_) => -1, + } + } + + fn delete_vector(&mut self, label: labelType) -> i32 { + match self.index.delete_vector(label) { + Ok(count) => count as i32, + Err(_) => 0, + } + } + + fn top_k_query( + &self, + query: *const c_void, + k: usize, + params: Option<&VecSimQueryParams>, + ) -> QueryReplyInternal { + let dim = self.index.dimension(); + let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) }; + let rust_params = params.map(|p| p.to_rust_params()); + + match self.index.top_k_query(slice, k, rust_params.as_ref()) { + Ok(reply) => convert_query_reply(reply), + Err(_) => QueryReplyInternal::new(), + } + } + + fn range_query( + &self, + query: *const c_void, + radius: f64, + params: Option<&VecSimQueryParams>, + ) -> QueryReplyInternal { + let dim = self.index.dimension(); + let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) }; + let rust_params = params.map(|p| p.to_rust_params()); + let radius_typed = + <<$data as VectorElement>::DistanceType as DistanceType>::from_f64(radius); + + match self.index.range_query(slice, radius_typed, rust_params.as_ref()) { + Ok(reply) => convert_query_reply(reply), + Err(_) => QueryReplyInternal::new(), + } + } + + fn get_distance_from(&self, label: labelType, query: *const c_void) -> f64 { + let dim = self.index.dimension(); + let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) }; + match self.index.compute_distance(label, slice) { + Some(dist) => dist.to_f64(), + None => f64::NAN, // Label not found + } + } + + fn index_size(&self) -> usize { + self.index.index_size() + } + + fn dimension(&self) -> usize { + self.index.dimension() + } + + fn contains(&self, label: labelType) -> bool { + self.index.contains(label) + } + + fn label_count(&self, label: labelType) -> usize { + self.index.label_count(label) + } + + fn data_type(&self) -> VecSimType { + self.data_type + } + + fn algo(&self) -> VecSimAlgo { + $algo + } + + fn metric(&self) -> VecSimMetric { + // Metric is not directly available from IndexInfo, use placeholder + VecSimMetric::VecSimMetric_L2 + } + + fn is_multi(&self) -> bool { + $is_multi + } + + fn create_batch_iterator( + &self, + query: *const c_void, + params: Option<&VecSimQueryParams>, + ) -> Option<Box<dyn BatchIteratorWrapper>> { + let dim = self.index.dimension(); + let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) }; + let rust_params = params.map(|p| p.to_rust_params()); + + match self.index.batch_iterator(slice, rust_params.as_ref()) { + Ok(inner) => Some(Box::new(create_owned_batch_iterator_from_results(inner))), + Err(_) => None, + } + } + + fn memory_usage(&self) -> usize { + self.index.info().memory_bytes + } + + fn save_to_file(&self, _path: &std::path::Path) -> bool { + false // Serialization not supported for this type + } + } + }; +} + +/// Macro to implement IndexWrapper for a specific index type WITH serialization support. +macro_rules! impl_index_wrapper_with_serialization { + ($wrapper:ident, $index:ty, $data:ty, $algo:expr, $is_multi:expr) => { + pub struct $wrapper { + index: $index, + data_type: VecSimType, + } + + impl $wrapper { + pub fn new(index: $index, data_type: VecSimType) -> Self { + Self { index, data_type } + } + + #[allow(dead_code)] + pub fn inner(&self) -> &$index { + &self.index + } + } + + impl IndexWrapper for $wrapper { + fn add_vector(&mut self, vector: *const c_void, label: labelType) -> i32 { + let dim = self.index.dimension(); + // Safety check: dimension must be > 0 and vector must be non-null + if dim == 0 || vector.is_null() { + return -1; + } + let slice = unsafe { slice::from_raw_parts(vector as *const $data, dim) }; + match self.index.add_vector(slice, label) { + Ok(count) => count as i32, + Err(_) => -1, + } + } + + fn delete_vector(&mut self, label: labelType) -> i32 { + match self.index.delete_vector(label) { + Ok(count) => count as i32, + Err(_) => 0, + } + } + + fn top_k_query( + &self, + query: *const c_void, + k: usize, + params: Option<&VecSimQueryParams>, + ) -> QueryReplyInternal { + let dim = self.index.dimension(); + let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) }; + let rust_params = params.map(|p| p.to_rust_params()); + + match self.index.top_k_query(slice, k, rust_params.as_ref()) { + Ok(reply) => convert_query_reply(reply), + Err(_) => QueryReplyInternal::new(), + } + } + + fn range_query( + &self, + query: *const c_void, + radius: f64, + params: Option<&VecSimQueryParams>, + ) -> QueryReplyInternal { + let dim = self.index.dimension(); + let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) }; + let rust_params = params.map(|p| p.to_rust_params()); + let radius_typed = + <<$data as VectorElement>::DistanceType as DistanceType>::from_f64(radius); + + match self.index.range_query(slice, radius_typed, rust_params.as_ref()) { + Ok(reply) => convert_query_reply(reply), + Err(_) => QueryReplyInternal::new(), + } + } + + fn get_distance_from(&self, label: labelType, query: *const c_void) -> f64 { + // Same behavior as the non-serialization macro (the previous unconditional + // NaN return was a copy-paste from the SVS macro, which does not apply here) + let dim = self.index.dimension(); + let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) }; + match self.index.compute_distance(label, slice) { + Some(dist) => dist.to_f64(), + None => f64::NAN, // Label not found + } + } + + fn index_size(&self) -> usize { + self.index.index_size() + } + + fn dimension(&self) -> usize { + self.index.dimension() + } + + fn contains(&self, label: labelType) -> bool { + self.index.contains(label) + } + + fn label_count(&self, label: labelType) -> usize { + self.index.label_count(label) + } + + fn data_type(&self) -> VecSimType { + self.data_type + } + + fn algo(&self) -> VecSimAlgo { + $algo + } + + fn metric(&self) -> VecSimMetric { + // Metric is not directly available from IndexInfo, use placeholder + VecSimMetric::VecSimMetric_L2 + } + + fn is_multi(&self) -> bool { + $is_multi + } + + fn create_batch_iterator( + &self, + query: *const c_void, + params: Option<&VecSimQueryParams>, + ) -> Option<Box<dyn BatchIteratorWrapper>> { + let dim = self.index.dimension(); + let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) }; + let rust_params = params.map(|p| p.to_rust_params()); + + match self.index.batch_iterator(slice, rust_params.as_ref()) { + Ok(inner) => Some(Box::new(create_owned_batch_iterator_from_results(inner))), + Err(_) => None, + } + } + + fn memory_usage(&self) -> usize { + self.index.info().memory_bytes + } + + fn save_to_file(&self, path: &std::path::Path) -> bool { + self.index.save_to_file(path).is_ok() + } + } + }; +} + +// Implement wrappers for BruteForce indices
+// Note: Serialization is only supported for f32 types +impl_index_wrapper_with_serialization!( + BruteForceSingleF32Wrapper, + BruteForceSingle, + f32, + VecSimAlgo::VecSimAlgo_BF, + false +); +impl_index_wrapper!( + BruteForceSingleF64Wrapper, + BruteForceSingle, + f64, + VecSimAlgo::VecSimAlgo_BF, + false +); +impl_index_wrapper!( + BruteForceSingleBF16Wrapper, + BruteForceSingle, + BFloat16, + VecSimAlgo::VecSimAlgo_BF, + false +); +impl_index_wrapper!( + BruteForceSingleFP16Wrapper, + BruteForceSingle, + Float16, + VecSimAlgo::VecSimAlgo_BF, + false +); +impl_index_wrapper!( + BruteForceSingleI8Wrapper, + BruteForceSingle, + Int8, + VecSimAlgo::VecSimAlgo_BF, + false +); +impl_index_wrapper!( + BruteForceSingleU8Wrapper, + BruteForceSingle, + UInt8, + VecSimAlgo::VecSimAlgo_BF, + false +); + +impl_index_wrapper_with_serialization!( + BruteForceMultiF32Wrapper, + BruteForceMulti, + f32, + VecSimAlgo::VecSimAlgo_BF, + true +); +impl_index_wrapper!( + BruteForceMultiF64Wrapper, + BruteForceMulti, + f64, + VecSimAlgo::VecSimAlgo_BF, + true +); +impl_index_wrapper!( + BruteForceMultiBF16Wrapper, + BruteForceMulti, + BFloat16, + VecSimAlgo::VecSimAlgo_BF, + true +); +impl_index_wrapper!( + BruteForceMultiFP16Wrapper, + BruteForceMulti, + Float16, + VecSimAlgo::VecSimAlgo_BF, + true +); +impl_index_wrapper!( + BruteForceMultiI8Wrapper, + BruteForceMulti, + Int8, + VecSimAlgo::VecSimAlgo_BF, + true +); +impl_index_wrapper!( + BruteForceMultiU8Wrapper, + BruteForceMulti, + UInt8, + VecSimAlgo::VecSimAlgo_BF, + true +); + +// Macro for HNSW Single wrappers with get_element_neighbors support +macro_rules! impl_hnsw_single_wrapper { + ($wrapper:ident, $index:ty, $data:ty) => { + pub struct $wrapper { + index: $index, + data_type: VecSimType, + } + + impl $wrapper { + pub fn new(index: $index, data_type: VecSimType) -> Self { + Self { index, data_type } + } + + #[allow(dead_code)] + pub fn inner(&self) -> &$index { + &self.index + } + } + + impl IndexWrapper for $wrapper { + fn add_vector(&mut self, vector: *const c_void, label: labelType) -> i32 { + let dim = self.index.dimension(); + // Safety check: dimension must be > 0 and vector must be non-null + if dim == 0 || vector.is_null() { + return -1; + } + let slice = unsafe { slice::from_raw_parts(vector as *const $data, dim) }; + match self.index.add_vector(slice, label) { + Ok(count) => count as i32, + Err(_) => -1, + } + } + + fn delete_vector(&mut self, label: labelType) -> i32 { + match self.index.delete_vector(label) { + Ok(count) => count as i32, + Err(_) => 0, + } + } + + fn top_k_query( + &self, + query: *const c_void, + k: usize, + params: Option<&VecSimQueryParams>, + ) -> QueryReplyInternal { + let dim = self.index.dimension(); + let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) }; + let rust_params = params.map(|p| p.to_rust_params()); + + match self.index.top_k_query(slice, k, rust_params.as_ref()) { + Ok(reply) => convert_query_reply(reply), + Err(_) => QueryReplyInternal::new(), + } + } + + fn range_query( + &self, + query: *const c_void, + radius: f64, + params: Option<&VecSimQueryParams>, + ) -> QueryReplyInternal { + let dim = self.index.dimension(); + let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) }; + let rust_params = params.map(|p| p.to_rust_params()); + let radius_typed = + <<$data as VectorElement>::DistanceType as DistanceType>::from_f64(radius); + + match self.index.range_query(slice, radius_typed, rust_params.as_ref()) { + Ok(reply) => convert_query_reply(reply), + Err(_) => 
+
+// Macro for HNSW Single wrappers with get_element_neighbors support
+macro_rules! impl_hnsw_single_wrapper {
+    ($wrapper:ident, $index:ty, $data:ty) => {
+        pub struct $wrapper {
+            index: $index,
+            data_type: VecSimType,
+        }
+
+        impl $wrapper {
+            pub fn new(index: $index, data_type: VecSimType) -> Self {
+                Self { index, data_type }
+            }
+
+            #[allow(dead_code)]
+            pub fn inner(&self) -> &$index {
+                &self.index
+            }
+        }
+
+        impl IndexWrapper for $wrapper {
+            fn add_vector(&mut self, vector: *const c_void, label: labelType) -> i32 {
+                let dim = self.index.dimension();
+                // Safety check: dimension must be > 0 and vector must be non-null
+                if dim == 0 || vector.is_null() {
+                    return -1;
+                }
+                let slice = unsafe { slice::from_raw_parts(vector as *const $data, dim) };
+                match self.index.add_vector(slice, label) {
+                    Ok(count) => count as i32,
+                    Err(_) => -1,
+                }
+            }
+
+            fn delete_vector(&mut self, label: labelType) -> i32 {
+                match self.index.delete_vector(label) {
+                    Ok(count) => count as i32,
+                    Err(_) => 0,
+                }
+            }
+
+            fn top_k_query(
+                &self,
+                query: *const c_void,
+                k: usize,
+                params: Option<&VecSimQueryParams>,
+            ) -> QueryReplyInternal {
+                let dim = self.index.dimension();
+                let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                let rust_params = params.map(|p| p.to_rust_params());
+
+                match self.index.top_k_query(slice, k, rust_params.as_ref()) {
+                    Ok(reply) => convert_query_reply(reply),
+                    Err(_) => QueryReplyInternal::new(),
+                }
+            }
+
+            fn range_query(
+                &self,
+                query: *const c_void,
+                radius: f64,
+                params: Option<&VecSimQueryParams>,
+            ) -> QueryReplyInternal {
+                let dim = self.index.dimension();
+                let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                let rust_params = params.map(|p| p.to_rust_params());
+                let radius_typed =
+                    <<$data as VectorElement>::DistanceType as DistanceType>::from_f64(radius);
+
+                match self.index.range_query(slice, radius_typed, rust_params.as_ref()) {
+                    Ok(reply) => convert_query_reply(reply),
+                    Err(_) => QueryReplyInternal::new(),
+                }
+            }
+
+            fn get_distance_from(&self, label: labelType, query: *const c_void) -> f64 {
+                let dim = self.index.dimension();
+                let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                match self.index.compute_distance(label, slice) {
+                    Some(dist) => dist.to_f64(),
+                    None => f64::NAN, // Label not found
+                }
+            }
+
+            fn index_size(&self) -> usize {
+                self.index.index_size()
+            }
+
+            fn dimension(&self) -> usize {
+                self.index.dimension()
+            }
+
+            fn contains(&self, label: labelType) -> bool {
+                self.index.contains(label)
+            }
+
+            fn label_count(&self, label: labelType) -> usize {
+                self.index.label_count(label)
+            }
+
+            fn data_type(&self) -> VecSimType {
+                self.data_type
+            }
+
+            fn algo(&self) -> VecSimAlgo {
+                VecSimAlgo::VecSimAlgo_HNSWLIB
+            }
+
+            fn metric(&self) -> VecSimMetric {
+                // Metric is not directly available from IndexInfo; use a placeholder
+                VecSimMetric::VecSimMetric_L2
+            }
+
+            fn is_multi(&self) -> bool {
+                false
+            }
+
+            fn create_batch_iterator(
+                &self,
+                query: *const c_void,
+                params: Option<&VecSimQueryParams>,
+            ) -> Option<Box<dyn BatchIteratorWrapper>> {
+                let dim = self.index.dimension();
+                let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                let rust_params = params.map(|p| p.to_rust_params());
+
+                match self.index.batch_iterator(slice, rust_params.as_ref()) {
+                    Ok(inner) => Some(Box::new(create_owned_batch_iterator_from_results(inner))),
+                    Err(_) => None,
+                }
+            }
+
+            fn memory_usage(&self) -> usize {
+                self.index.info().memory_bytes
+            }
+
+            fn save_to_file(&self, path: &std::path::Path) -> bool {
+                self.index.save_to_file(path).is_ok()
+            }
+
+            fn get_element_neighbors(&self, label: u64) -> Option<Vec<Vec<u64>>> {
+                self.index.get_element_neighbors(label)
+            }
+        }
+    };
+}
+
+// Implement wrappers for HNSW Single indices
+impl_hnsw_single_wrapper!(HnswSingleF32Wrapper, HnswSingle<f32>, f32);
+impl_hnsw_single_wrapper!(HnswSingleF64Wrapper, HnswSingle<f64>, f64);
+impl_hnsw_single_wrapper!(HnswSingleBF16Wrapper, HnswSingle<BFloat16>, BFloat16);
+impl_hnsw_single_wrapper!(HnswSingleFP16Wrapper, HnswSingle<Float16>, Float16);
+impl_hnsw_single_wrapper!(HnswSingleI8Wrapper, HnswSingle<Int8>, Int8);
+impl_hnsw_single_wrapper!(HnswSingleU8Wrapper, HnswSingle<UInt8>, UInt8);
+
+impl_index_wrapper_with_serialization!(
+    HnswMultiF32Wrapper,
+    HnswMulti<f32>,
+    f32,
+    VecSimAlgo::VecSimAlgo_HNSWLIB,
+    true
+);
+impl_index_wrapper_with_serialization!(
+    HnswMultiF64Wrapper,
+    HnswMulti<f64>,
+    f64,
+    VecSimAlgo::VecSimAlgo_HNSWLIB,
+    true
+);
+impl_index_wrapper_with_serialization!(
+    HnswMultiBF16Wrapper,
+    HnswMulti<BFloat16>,
+    BFloat16,
+    VecSimAlgo::VecSimAlgo_HNSWLIB,
+    true
+);
+impl_index_wrapper_with_serialization!(
+    HnswMultiFP16Wrapper,
+    HnswMulti<Float16>,
+    Float16,
+    VecSimAlgo::VecSimAlgo_HNSWLIB,
+    true
+);
+impl_index_wrapper_with_serialization!(
+    HnswMultiI8Wrapper,
+    HnswMulti<Int8>,
+    Int8,
+    VecSimAlgo::VecSimAlgo_HNSWLIB,
+    true
+);
+impl_index_wrapper_with_serialization!(
+    HnswMultiU8Wrapper,
+    HnswMulti<UInt8>,
+    UInt8,
+    VecSimAlgo::VecSimAlgo_HNSWLIB,
+    true
+);
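+
+// Hedged sketch of how the FFI layer drives these wrappers for batched
+// queries. `BatchIteratorWrapper` is the trait object handed to query.rs;
+// its iteration API is defined there and is assumed here:
+//
+//     let query = [0.0f32; 128];
+//     if let Some(mut _it) =
+//         wrapper.create_batch_iterator(query.as_ptr() as *const c_void, None)
+//     {
+//         // each step would yield the next batch of (label, distance) results
+//     }
+//     // HNSW-only introspection: neighbors of a stored label, per graph level.
+//     let _levels: Option<Vec<Vec<u64>>> = wrapper.get_element_neighbors(7);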
+
+// Macro for SVS indices (no compute_distance support)
+macro_rules! impl_svs_wrapper {
+    ($wrapper:ident, $index:ty, $data:ty, $is_multi:expr) => {
+        pub struct $wrapper {
+            index: $index,
+            data_type: VecSimType,
+        }
+
+        impl $wrapper {
+            pub fn new(index: $index, data_type: VecSimType) -> Self {
+                Self { index, data_type }
+            }
+
+            #[allow(dead_code)]
+            pub fn inner(&self) -> &$index {
+                &self.index
+            }
+        }
+
+        impl IndexWrapper for $wrapper {
+            fn add_vector(&mut self, vector: *const c_void, label: labelType) -> i32 {
+                let dim = self.index.dimension();
+                // Safety check: dimension must be > 0 and vector must be non-null
+                if dim == 0 || vector.is_null() {
+                    return -1;
+                }
+                let slice = unsafe { slice::from_raw_parts(vector as *const $data, dim) };
+                match self.index.add_vector(slice, label) {
+                    Ok(count) => count as i32,
+                    Err(_) => -1,
+                }
+            }
+
+            fn delete_vector(&mut self, label: labelType) -> i32 {
+                match self.index.delete_vector(label) {
+                    Ok(count) => count as i32,
+                    Err(_) => 0,
+                }
+            }
+
+            fn top_k_query(
+                &self,
+                query: *const c_void,
+                k: usize,
+                params: Option<&VecSimQueryParams>,
+            ) -> QueryReplyInternal {
+                let dim = self.index.dimension();
+                let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                let rust_params = params.map(|p| p.to_rust_params());
+
+                match self.index.top_k_query(slice, k, rust_params.as_ref()) {
+                    Ok(reply) => convert_query_reply(reply),
+                    Err(_) => QueryReplyInternal::new(),
+                }
+            }
+
+            fn range_query(
+                &self,
+                query: *const c_void,
+                radius: f64,
+                params: Option<&VecSimQueryParams>,
+            ) -> QueryReplyInternal {
+                let dim = self.index.dimension();
+                let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                let rust_params = params.map(|p| p.to_rust_params());
+                let radius_typed =
+                    <<$data as VectorElement>::DistanceType as DistanceType>::from_f64(radius);
+
+                match self.index.range_query(slice, radius_typed, rust_params.as_ref()) {
+                    Ok(reply) => convert_query_reply(reply),
+                    Err(_) => QueryReplyInternal::new(),
+                }
+            }
+
+            fn get_distance_from(&self, _label: labelType, _query: *const c_void) -> f64 {
+                // SVS indices don't support compute_distance
+                f64::NAN
+            }
+
+            fn index_size(&self) -> usize {
+                self.index.index_size()
+            }
+
+            fn dimension(&self) -> usize {
+                self.index.dimension()
+            }
+
+            fn contains(&self, label: labelType) -> bool {
+                self.index.contains(label)
+            }
+
+            fn label_count(&self, label: labelType) -> usize {
+                self.index.label_count(label)
+            }
+
+            fn data_type(&self) -> VecSimType {
+                self.data_type
+            }
+
+            fn algo(&self) -> VecSimAlgo {
+                VecSimAlgo::VecSimAlgo_SVS
+            }
+
+            fn metric(&self) -> VecSimMetric {
+                VecSimMetric::VecSimMetric_L2
+            }
+
+            fn is_multi(&self) -> bool {
+                $is_multi
+            }
+
+            fn create_batch_iterator(
+                &self,
+                query: *const c_void,
+                params: Option<&VecSimQueryParams>,
+            ) -> Option<Box<dyn BatchIteratorWrapper>> {
+                let dim = self.index.dimension();
+                let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                let rust_params = params.map(|p| p.to_rust_params());
+
+                match self.index.batch_iterator(slice, rust_params.as_ref()) {
+                    Ok(inner) => Some(Box::new(create_owned_batch_iterator_from_results(inner))),
+                    Err(_) => None,
+                }
+            }
+
+            fn memory_usage(&self) -> usize {
+                self.index.info().memory_bytes
+            }
+
+            fn save_to_file(&self, _path: &std::path::Path) -> bool {
+                false // Serialization not supported for most SVS types
+            }
+        }
+    };
+}
+
+// Implement wrappers for SVS indices
+// Note: SVS serialization is only supported for f32 single
+impl_svs_wrapper!(SvsSingleF32Wrapper, SvsSingle<f32>, f32, false);
+impl_svs_wrapper!(SvsSingleF64Wrapper, SvsSingle<f64>, f64, false);
+impl_svs_wrapper!(SvsSingleBF16Wrapper, SvsSingle<BFloat16>, BFloat16, false);
+impl_svs_wrapper!(SvsSingleFP16Wrapper, SvsSingle<Float16>, Float16, false);
+impl_svs_wrapper!(SvsSingleI8Wrapper, SvsSingle<Int8>, Int8, false);
+impl_svs_wrapper!(SvsSingleU8Wrapper, SvsSingle<UInt8>, UInt8, false);
+
+impl_svs_wrapper!(SvsMultiF32Wrapper, SvsMulti<f32>, f32, true);
+impl_svs_wrapper!(SvsMultiF64Wrapper, SvsMulti<f64>, f64, true);
+impl_svs_wrapper!(SvsMultiBF16Wrapper, SvsMulti<BFloat16>, BFloat16, true);
+impl_svs_wrapper!(SvsMultiFP16Wrapper, SvsMulti<Float16>, Float16, true);
+impl_svs_wrapper!(SvsMultiI8Wrapper, SvsMulti<Int8>, Int8, true);
+impl_svs_wrapper!(SvsMultiU8Wrapper, SvsMulti<UInt8>, UInt8, true);
+
+// ============================================================================
+// Tiered Index Wrappers
+// ============================================================================
+
+/// Macro to implement IndexWrapper for Tiered index types.
+macro_rules! impl_tiered_wrapper {
+    ($wrapper:ident, $index:ty, $data:ty, $is_multi:expr) => {
+        pub struct $wrapper {
+            index: $index,
+            data_type: VecSimType,
+        }
+
+        impl $wrapper {
+            pub fn new(index: $index, data_type: VecSimType) -> Self {
+                Self { index, data_type }
+            }
+
+            #[allow(dead_code)]
+            pub fn inner(&self) -> &$index {
+                &self.index
+            }
+
+            #[allow(dead_code)]
+            pub fn inner_mut(&mut self) -> &mut $index {
+                &mut self.index
+            }
+        }
+
+        impl IndexWrapper for $wrapper {
+            fn add_vector(&mut self, vector: *const c_void, label: labelType) -> i32 {
+                let dim = self.index.dimension();
+                // Safety check: dimension must be > 0 and vector must be non-null
+                if dim == 0 || vector.is_null() {
+                    return -1;
+                }
+                let slice = unsafe { slice::from_raw_parts(vector as *const $data, dim) };
+                match self.index.add_vector(slice, label) {
+                    Ok(count) => count as i32,
+                    Err(_) => -1,
+                }
+            }
+
+            fn delete_vector(&mut self, label: labelType) -> i32 {
+                match self.index.delete_vector(label) {
+                    Ok(count) => count as i32,
+                    Err(_) => 0,
+                }
+            }
+
+            fn top_k_query(
+                &self,
+                query: *const c_void,
+                k: usize,
+                params: Option<&VecSimQueryParams>,
+            ) -> QueryReplyInternal {
+                let dim = self.index.dimension();
+                let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                let rust_params = params.map(|p| p.to_rust_params());
+
+                match self.index.top_k_query(slice, k, rust_params.as_ref()) {
+                    Ok(reply) => convert_query_reply(reply),
+                    Err(_) => QueryReplyInternal::new(),
+                }
+            }
+
+            fn range_query(
+                &self,
+                query: *const c_void,
+                radius: f64,
+                params: Option<&VecSimQueryParams>,
+            ) -> QueryReplyInternal {
+                let dim = self.index.dimension();
+                let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                let rust_params = params.map(|p| p.to_rust_params());
+                let radius_typed =
+                    <<$data as VectorElement>::DistanceType as DistanceType>::from_f64(radius);
+
+                match self.index.range_query(slice, radius_typed, rust_params.as_ref()) {
+                    Ok(reply) => convert_query_reply(reply),
+                    Err(_) => QueryReplyInternal::new(),
+                }
+            }
+
+            fn get_distance_from(&self, _label: labelType, _query: *const c_void) -> f64 {
+                // Tiered indices don't support compute_distance directly
+                f64::NAN
+            }
+
+            fn index_size(&self) -> usize {
+                self.index.index_size()
+            }
+
+            fn dimension(&self) -> usize {
+                self.index.dimension()
+            }
+
+            fn contains(&self, label: labelType) -> bool {
+                self.index.contains(label)
+            }
+
+            fn label_count(&self, label: labelType) -> usize {
+                self.index.label_count(label)
+            }
+
+            fn data_type(&self) -> VecSimType {
+                self.data_type
+            }
+
+            fn algo(&self) -> VecSimAlgo {
+                VecSimAlgo::VecSimAlgo_TIERED
+            }
+
+            fn metric(&self) -> VecSimMetric {
+                VecSimMetric::VecSimMetric_L2
+            }
+
+            fn is_multi(&self) -> bool {
+                $is_multi
+            }
+
+            fn create_batch_iterator(
+                &self,
+                query: *const c_void,
+                params: Option<&VecSimQueryParams>,
+            ) -> Option<Box<dyn BatchIteratorWrapper>> {
+                let dim = self.index.dimension();
+                let slice = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                let rust_params = params.map(|p| p.to_rust_params());
+
+                match self.index.batch_iterator(slice, rust_params.as_ref()) {
+                    Ok(inner) => Some(Box::new(create_owned_batch_iterator_from_results(inner))),
+                    Err(_) => None,
+                }
+            }
+
+            fn memory_usage(&self) -> usize {
+                self.index.info().memory_bytes
+            }
+
+            fn save_to_file(&self, _path: &std::path::Path) -> bool {
+                false // Tiered serialization not yet exposed via C API
+            }
+
+            fn is_tiered(&self) -> bool {
+                true
+            }
+
+            fn tiered_flush(&mut self) -> usize {
+                self.index.flush().unwrap_or(0)
+            }
+
+            fn tiered_flat_size(&self) -> usize {
+                self.index.flat_size()
+            }
+
+            fn tiered_backend_size(&self) -> usize {
+                self.index.hnsw_size()
+            }
+
+            fn tiered_gc(&mut self) {
+                // No-op for now
+            }
+
+            fn tiered_acquire_shared_locks(&mut self) {
+                // No-op for now
+            }
+
+            fn tiered_release_shared_locks(&mut self) {
+                // No-op for now
+            }
+
+            fn get_element_neighbors(&self, label: u64) -> Option<Vec<Vec<u64>>> {
+                self.index.get_element_neighbors(label)
+            }
+        }
+    };
+}
+
+// Implement wrappers for Tiered indices (f32 only for now)
+impl_tiered_wrapper!(TieredSingleF32Wrapper, TieredSingle<f32>, f32, false);
+impl_tiered_wrapper!(TieredMultiF32Wrapper, TieredMulti<f32>, f32, true);
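+
+// Hedged sketch of the intended tiered lifecycle: writes land in the flat
+// buffer first, and `tiered_flush` migrates them into the HNSW backend
+// (mirroring VecSimTieredIndex_Flush in lib.rs). The post-flush invariants
+// shown here are assumptions, not guarantees of this port:
+//
+//     wrapper.add_vector(v.as_ptr() as *const c_void, 1); // lands in flat buffer
+//     assert!(wrapper.tiered_flat_size() >= 1);
+//     let moved = wrapper.tiered_flush();                 // drain into HNSW
+//     assert_eq!(wrapper.tiered_flat_size(), 0);
+//     assert!(wrapper.tiered_backend_size() >= moved);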
+
+// ============================================================================
+// Disk Index Wrappers
+// ============================================================================
+
+/// Macro to implement IndexWrapper for Disk index types.
+macro_rules! impl_disk_wrapper {
+    ($wrapper:ident, $data:ty) => {
+        pub struct $wrapper {
+            index: DiskIndexSingle<$data>,
+            data_type: VecSimType,
+            metric: VecSimMetric,
+        }
+
+        impl $wrapper {
+            pub fn new(
+                index: DiskIndexSingle<$data>,
+                data_type: VecSimType,
+                metric: VecSimMetric,
+            ) -> Self {
+                Self {
+                    index,
+                    data_type,
+                    metric,
+                }
+            }
+        }
+
+        impl IndexWrapper for $wrapper {
+            fn add_vector(&mut self, vector: *const c_void, label: labelType) -> i32 {
+                let dim = self.index.dimension();
+                // Safety check: dimension must be > 0 and vector must be non-null
+                if dim == 0 || vector.is_null() {
+                    return -1;
+                }
+                let data = unsafe { slice::from_raw_parts(vector as *const $data, dim) };
+                match self.index.add_vector(data, label) {
+                    Ok(count) => count as i32,
+                    Err(_) => -1,
+                }
+            }
+
+            fn delete_vector(&mut self, label: labelType) -> i32 {
+                match self.index.delete_vector(label) {
+                    Ok(count) => count as i32,
+                    Err(_) => 0,
+                }
+            }
+
+            fn top_k_query(
+                &self,
+                query: *const c_void,
+                k: usize,
+                params: Option<&VecSimQueryParams>,
+            ) -> QueryReplyInternal {
+                let dim = self.index.dimension();
+                let query_data = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                let rust_params = params.map(|p| p.to_rust_params());
+                match self.index.top_k_query(query_data, k, rust_params.as_ref()) {
+                    Ok(reply) => convert_query_reply(reply),
+                    Err(_) => QueryReplyInternal::new(),
+                }
+            }
+
+            fn range_query(
+                &self,
+                query: *const c_void,
+                radius: f64,
+                params: Option<&VecSimQueryParams>,
+            ) -> QueryReplyInternal {
+                let dim = self.index.dimension();
+                let query_data = unsafe { slice::from_raw_parts(query as *const $data, dim) };
+                let rust_params = params.map(|p| p.to_rust_params());
+                let radius_typed =
+                    <<$data as VectorElement>::DistanceType as DistanceType>::from_f64(radius);
+                match self.index.range_query(query_data, radius_typed, rust_params.as_ref()) {
+                    Ok(reply) => convert_query_reply(reply),
+                    Err(_) => QueryReplyInternal::new(),
+                }
+            }
+
+            fn get_distance_from(&self, _label: labelType, _query: *const c_void) -> f64 {
+                // Disk indices don't support get_distance_from directly
+                f64::NAN
+            }
+
+            fn index_size(&self) -> usize {
+                self.index.index_size()
+            }
+
+            fn dimension(&self) -> usize {
+                self.index.dimension()
+            }
+
+            fn contains(&self, label: labelType) -> bool {
+                self.index.contains(label)
+            }
+
+            fn label_count(&self, label: labelType) -> usize {
+                self.index.label_count(label)
+            }
+
+            fn data_type(&self) -> VecSimType {
+                self.data_type
+            }
+
+            fn algo(&self) -> VecSimAlgo {
+                // Disk indices use a BF or SVS backend; report as BF for now
+                VecSimAlgo::VecSimAlgo_BF
+            }
+
+            fn metric(&self) -> VecSimMetric {
+                self.metric
+            }
+
+            fn is_multi(&self) -> bool {
+                false // DiskIndexSingle is always single-value
+            }
+
+            fn create_batch_iterator(
+                &self,
+                _query: *const c_void,
+                _params: Option<&VecSimQueryParams>,
+            ) -> Option<Box<dyn BatchIteratorWrapper>> {
+                // Batch iteration not yet implemented for disk indices
+                None
+            }
+
+            fn memory_usage(&self) -> usize {
+                // Memory usage is minimal since the data lives on disk
+                std::mem::size_of::<Self>()
+            }
+
+            fn save_to_file(&self, _path: &std::path::Path) -> bool {
+                // Disk indices are already persisted
+                self.disk_flush()
+            }
+
+            fn is_disk(&self) -> bool {
+                true
+            }
+
+            fn disk_flush(&self) -> bool {
+                self.index.flush().is_ok()
+            }
+        }
+    };
+}
+
+// Implement wrappers for Disk indices (f32 only for now)
+impl_disk_wrapper!(DiskSingleF32Wrapper, f32);
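+
+// Disk indices persist through the storage behind `DiskIndexSingle`, so
+// `save_to_file` reduces to a flush. A hedged usage sketch against the
+// `create_disk_index` entry point defined just below (the data path plumbing
+// comes from `DiskParams` and is assumed here):
+//
+//     let handle = create_disk_index(&params)?;           // f32 only for now
+//     handle.wrapper.add_vector(v.as_ptr() as *const c_void, 3);
+//     assert!(handle.wrapper.disk_flush());               // sync to storage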
+
+/// Create a new disk-based index.
+pub fn create_disk_index(params: &DiskParams) -> Option<Box<IndexHandle>> {
+    let rust_params = unsafe { params.to_rust_params()? };
+    let data_type = params.base.type_;
+    let metric = params.base.metric;
+    let dim = params.base.dim;
+
+    // Only f32 is supported for now
+    if data_type != VecSimType::VecSimType_FLOAT32 {
+        return None;
+    }
+
+    let index = DiskIndexSingle::<f32>::new(rust_params).ok()?;
+    let wrapper: Box<dyn IndexWrapper> =
+        Box::new(DiskSingleF32Wrapper::new(index, data_type, metric));
+
+    Some(Box::new(IndexHandle::new(
+        wrapper,
+        data_type,
+        VecSimAlgo::VecSimAlgo_BF, // Report as BF for now
+        metric,
+        dim,
+        false, // Disk indices are single-value only
+    )))
+}
+
+/// Convert a Rust QueryReply to QueryReplyInternal.
+fn convert_query_reply(reply: QueryReply) -> QueryReplyInternal {
+    let results: Vec<QueryResultInternal> = reply
+        .results
+        .into_iter()
+        .map(|r| QueryResultInternal {
+            id: r.label,
+            score: r.distance.to_f64(),
+        })
+        .collect();
+    QueryReplyInternal::from_results(results)
+}
+
+/// Internal index handle that stores the type-erased wrapper.
+pub struct IndexHandle {
+    pub wrapper: Box<dyn IndexWrapper>,
+    pub data_type: VecSimType,
+    pub algo: VecSimAlgo,
+    pub metric: VecSimMetric,
+    pub dim: usize,
+    pub is_multi: bool,
+}
+
+impl IndexHandle {
+    pub fn new(
+        wrapper: Box<dyn IndexWrapper>,
+        data_type: VecSimType,
+        algo: VecSimAlgo,
+        metric: VecSimMetric,
+        dim: usize,
+        is_multi: bool,
+    ) -> Self {
+        Self {
+            wrapper,
+            data_type,
+            algo,
+            metric,
+            dim,
+            is_multi,
+        }
+    }
+}
+
+/// Create a new BruteForce index.
+pub fn create_brute_force_index(params: &BFParams) -> Option<Box<IndexHandle>> {
+    let rust_params = params.to_rust_params();
+    let data_type = params.base.type_;
+    let metric = params.base.metric;
+    let dim = params.base.dim;
+    let is_multi = params.base.multi;
+
+    let wrapper: Box<dyn IndexWrapper> = match (data_type, is_multi) {
+        (VecSimType::VecSimType_FLOAT32, false) => Box::new(BruteForceSingleF32Wrapper::new(
+            BruteForceSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT64, false) => Box::new(BruteForceSingleF64Wrapper::new(
+            BruteForceSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_BFLOAT16, false) => Box::new(BruteForceSingleBF16Wrapper::new(
+            BruteForceSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT16, false) => Box::new(BruteForceSingleFP16Wrapper::new(
+            BruteForceSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_INT8, false) => Box::new(BruteForceSingleI8Wrapper::new(
+            BruteForceSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_UINT8, false) => Box::new(BruteForceSingleU8Wrapper::new(
+            BruteForceSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT32, true) => Box::new(BruteForceMultiF32Wrapper::new(
+            BruteForceMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT64, true) => Box::new(BruteForceMultiF64Wrapper::new(
+            BruteForceMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_BFLOAT16, true) => Box::new(BruteForceMultiBF16Wrapper::new(
+            BruteForceMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT16, true) => Box::new(BruteForceMultiFP16Wrapper::new(
+            BruteForceMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_INT8, true) => Box::new(BruteForceMultiI8Wrapper::new(
+            BruteForceMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_UINT8, true) => Box::new(BruteForceMultiU8Wrapper::new(
+            BruteForceMulti::new(rust_params),
+            data_type,
+        )),
+        // INT32 and INT64 types not yet supported for vector indices
+        (VecSimType::VecSimType_INT32, _) | (VecSimType::VecSimType_INT64, _) => return None,
+    };
+
+    Some(Box::new(IndexHandle::new(
+        wrapper,
+        data_type,
+        VecSimAlgo::VecSimAlgo_BF,
+        metric,
+        dim,
+        is_multi,
+    )))
+}
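+
+// Once created, all queries flow through the type-erased wrapper and come
+// back as QueryReplyInternal (see convert_query_reply above). Hedged sketch:
+//
+//     let q = [0.0f32; 128];
+//     let reply = handle.wrapper.top_k_query(q.as_ptr() as *const c_void, 10, None);
+//     // reply carries (id, score) pairs produced by the underlying index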
+
+/// Create a new HNSW index.
+pub fn create_hnsw_index(params: &HNSWParams) -> Option<Box<IndexHandle>> {
+    let rust_params = params.to_rust_params();
+    let data_type = params.base.type_;
+    let metric = params.base.metric;
+    let dim = params.base.dim;
+    let is_multi = params.base.multi;
+
+    let wrapper: Box<dyn IndexWrapper> = match (data_type, is_multi) {
+        (VecSimType::VecSimType_FLOAT32, false) => Box::new(HnswSingleF32Wrapper::new(
+            HnswSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT64, false) => Box::new(HnswSingleF64Wrapper::new(
+            HnswSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_BFLOAT16, false) => Box::new(HnswSingleBF16Wrapper::new(
+            HnswSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT16, false) => Box::new(HnswSingleFP16Wrapper::new(
+            HnswSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_INT8, false) => Box::new(HnswSingleI8Wrapper::new(
+            HnswSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_UINT8, false) => Box::new(HnswSingleU8Wrapper::new(
+            HnswSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT32, true) => Box::new(HnswMultiF32Wrapper::new(
+            HnswMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT64, true) => Box::new(HnswMultiF64Wrapper::new(
+            HnswMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_BFLOAT16, true) => Box::new(HnswMultiBF16Wrapper::new(
+            HnswMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT16, true) => Box::new(HnswMultiFP16Wrapper::new(
+            HnswMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_INT8, true) => Box::new(HnswMultiI8Wrapper::new(
+            HnswMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_UINT8, true) => Box::new(HnswMultiU8Wrapper::new(
+            HnswMulti::new(rust_params),
+            data_type,
+        )),
+        // INT32 and INT64 types not yet supported for vector indices
+        (VecSimType::VecSimType_INT32, _) | (VecSimType::VecSimType_INT64, _) => return None,
+    };
+
+    Some(Box::new(IndexHandle::new(
+        wrapper,
+        data_type,
+        VecSimAlgo::VecSimAlgo_HNSWLIB,
+        metric,
+        dim,
+        is_multi,
+    )))
+}
+
+/// Create a new SVS index.
+pub fn create_svs_index(params: &SVSParams) -> Option<Box<IndexHandle>> {
+    let rust_params = params.to_rust_params();
+    let data_type = params.base.type_;
+    let metric = params.base.metric;
+    let dim = params.base.dim;
+    let is_multi = params.base.multi;
+
+    let wrapper: Box<dyn IndexWrapper> = match (data_type, is_multi) {
+        (VecSimType::VecSimType_FLOAT32, false) => Box::new(SvsSingleF32Wrapper::new(
+            SvsSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT64, false) => Box::new(SvsSingleF64Wrapper::new(
+            SvsSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_BFLOAT16, false) => Box::new(SvsSingleBF16Wrapper::new(
+            SvsSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT16, false) => Box::new(SvsSingleFP16Wrapper::new(
+            SvsSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_INT8, false) => Box::new(SvsSingleI8Wrapper::new(
+            SvsSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_UINT8, false) => Box::new(SvsSingleU8Wrapper::new(
+            SvsSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT32, true) => Box::new(SvsMultiF32Wrapper::new(
+            SvsMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT64, true) => Box::new(SvsMultiF64Wrapper::new(
+            SvsMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_BFLOAT16, true) => Box::new(SvsMultiBF16Wrapper::new(
+            SvsMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT16, true) => Box::new(SvsMultiFP16Wrapper::new(
+            SvsMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_INT8, true) => Box::new(SvsMultiI8Wrapper::new(
+            SvsMulti::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_UINT8, true) => Box::new(SvsMultiU8Wrapper::new(
+            SvsMulti::new(rust_params),
+            data_type,
+        )),
+        // INT32 and INT64 types not yet supported for vector indices
+        (VecSimType::VecSimType_INT32, _) | (VecSimType::VecSimType_INT64, _) => return None,
+    };
+
+    Some(Box::new(IndexHandle::new(
+        wrapper,
+        data_type,
+        VecSimAlgo::VecSimAlgo_SVS,
+        metric,
+        dim,
+        is_multi,
+    )))
+}
+
+/// Create a new Tiered index.
+pub fn create_tiered_index(params: &TieredParams) -> Option<Box<IndexHandle>> {
+    let rust_params = params.to_rust_params();
+    let data_type = params.base.type_;
+    let metric = params.base.metric;
+    let dim = params.base.dim;
+    let is_multi = params.base.multi;
+
+    // Tiered index currently only supports f32
+    let wrapper: Box<dyn IndexWrapper> = match (data_type, is_multi) {
+        (VecSimType::VecSimType_FLOAT32, false) => Box::new(TieredSingleF32Wrapper::new(
+            TieredSingle::new(rust_params),
+            data_type,
+        )),
+        (VecSimType::VecSimType_FLOAT32, true) => Box::new(TieredMultiF32Wrapper::new(
+            TieredMulti::new(rust_params),
+            data_type,
+        )),
+        _ => return None, // Tiered only supports f32 currently
+    };
+
+    Some(Box::new(IndexHandle::new(
+        wrapper,
+        data_type,
+        VecSimAlgo::VecSimAlgo_TIERED,
+        metric,
+        dim,
+        is_multi,
+    )))
+}
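+
+// A minimal smoke-test sketch for the creation path above. It only uses
+// items defined in this module or in the params module; the parameter
+// values are illustrative assumptions, not tuned defaults.
+#[cfg(test)]
+mod creation_smoke_tests {
+    use super::*;
+    use crate::params::{BFParams, VecSimParams};
+    use std::ffi::c_void;
+
+    #[test]
+    fn bf_create_add_contains() {
+        let params = BFParams {
+            base: VecSimParams {
+                algo: VecSimAlgo::VecSimAlgo_BF,
+                type_: VecSimType::VecSimType_FLOAT32,
+                metric: VecSimMetric::VecSimMetric_L2,
+                dim: 4,
+                multi: false,
+                initialCapacity: 16,
+                blockSize: 16,
+            },
+        };
+        let mut handle = create_brute_force_index(&params).expect("f32 BF is supported");
+        let v = [1.0f32, 0.0, 0.0, 0.0];
+        // add_vector reports the number of vectors added, or -1 on failure.
+        assert!(handle.wrapper.add_vector(v.as_ptr() as *const c_void, 42) >= 0);
+        assert_eq!(handle.wrapper.index_size(), 1);
+        assert!(handle.wrapper.contains(42));
+    }
+}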
diff --git a/rust/vecsim-c/src/info.rs b/rust/vecsim-c/src/info.rs
new file mode 100644
index 000000000..b26a14b75
--- /dev/null
+++ b/rust/vecsim-c/src/info.rs
@@ -0,0 +1,407 @@
+//! Index information and introspection functions for C FFI.
+
+use std::ffi::CString;
+
+use crate::index::IndexHandle;
+use crate::types::{VecSearchMode, VecSimAlgo, VecSimMetric, VecSimType};
+
+// ============================================================================
+// Debug Info Iterator Types
+// ============================================================================
+
+/// Field type for debug info fields.
+#[repr(C)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum VecSim_InfoFieldType {
+    INFOFIELD_STRING = 0,
+    INFOFIELD_INT64 = 1,
+    INFOFIELD_UINT64 = 2,
+    INFOFIELD_FLOAT64 = 3,
+    INFOFIELD_ITERATOR = 4,
+}
+
+/// Union of field values.
+#[repr(C)]
+#[derive(Clone, Copy)]
+pub union FieldValue {
+    pub floatingPointValue: f64,
+    pub integerValue: i64,
+    pub uintegerValue: u64,
+    pub stringValue: *const std::ffi::c_char,
+    pub iteratorValue: *mut VecSimDebugInfoIterator,
+}
+
+impl Default for FieldValue {
+    fn default() -> Self {
+        FieldValue { uintegerValue: 0 }
+    }
+}
+
+/// A field in the debug info iterator.
+#[repr(C)]
+pub struct VecSim_InfoField {
+    pub fieldName: *const std::ffi::c_char,
+    pub fieldType: VecSim_InfoFieldType,
+    pub fieldValue: FieldValue,
+}
+
+/// Internal representation of a field with owned strings.
+pub struct InfoFieldOwned {
+    pub name: CString,
+    pub field_type: VecSim_InfoFieldType,
+    pub value: InfoFieldValue,
+}
+
+/// Internal value representation.
+pub enum InfoFieldValue {
+    String(CString),
+    Int64(i64),
+    UInt64(u64),
+    Float64(f64),
+    Iterator(Box<VecSimDebugInfoIterator>),
+}
+
+/// Debug info iterator - an opaque type that holds fields for iteration.
+pub struct VecSimDebugInfoIterator {
+    fields: Vec<InfoFieldOwned>,
+    current_index: usize,
+    /// Cached C representation of current field (for returning pointers).
+    current_field: Option<VecSim_InfoField>,
+}
+
+impl VecSimDebugInfoIterator {
+    /// Create a new iterator with the given capacity.
+    pub fn new(capacity: usize) -> Self {
+        VecSimDebugInfoIterator {
+            fields: Vec::with_capacity(capacity),
+            current_index: 0,
+            current_field: None,
+        }
+    }
+
+    /// Add a string field.
+    pub fn add_string_field(&mut self, name: &str, value: &str) {
+        if let (Ok(name_c), Ok(value_c)) = (CString::new(name), CString::new(value)) {
+            self.fields.push(InfoFieldOwned {
+                name: name_c,
+                field_type: VecSim_InfoFieldType::INFOFIELD_STRING,
+                value: InfoFieldValue::String(value_c),
+            });
+        }
+    }
+
+    /// Add an unsigned integer field.
+    pub fn add_uint64_field(&mut self, name: &str, value: u64) {
+        if let Ok(name_c) = CString::new(name) {
+            self.fields.push(InfoFieldOwned {
+                name: name_c,
+                field_type: VecSim_InfoFieldType::INFOFIELD_UINT64,
+                value: InfoFieldValue::UInt64(value),
+            });
+        }
+    }
+
+    /// Add a signed integer field.
+    pub fn add_int64_field(&mut self, name: &str, value: i64) {
+        if let Ok(name_c) = CString::new(name) {
+            self.fields.push(InfoFieldOwned {
+                name: name_c,
+                field_type: VecSim_InfoFieldType::INFOFIELD_INT64,
+                value: InfoFieldValue::Int64(value),
+            });
+        }
+    }
+
+    /// Add a float field.
+    pub fn add_float64_field(&mut self, name: &str, value: f64) {
+        if let Ok(name_c) = CString::new(name) {
+            self.fields.push(InfoFieldOwned {
+                name: name_c,
+                field_type: VecSim_InfoFieldType::INFOFIELD_FLOAT64,
+                value: InfoFieldValue::Float64(value),
+            });
+        }
+    }
+
+    /// Add a nested iterator field.
+    pub fn add_iterator_field(&mut self, name: &str, value: VecSimDebugInfoIterator) {
+        if let Ok(name_c) = CString::new(name) {
+            self.fields.push(InfoFieldOwned {
+                name: name_c,
+                field_type: VecSim_InfoFieldType::INFOFIELD_ITERATOR,
+                value: InfoFieldValue::Iterator(Box::new(value)),
+            });
+        }
+    }
+
+    /// Get the number of fields.
+    pub fn number_of_fields(&self) -> usize {
+        self.fields.len()
+    }
+
+    /// Check if there are more fields.
+    pub fn has_next(&self) -> bool {
+        self.current_index < self.fields.len()
+    }
+
+    /// Get the next field, returning a pointer to a cached VecSim_InfoField.
+    /// The returned pointer is valid until the next call to next() or until
+    /// the iterator is freed.
+    pub fn next(&mut self) -> *mut VecSim_InfoField {
+        if self.current_index >= self.fields.len() {
+            return std::ptr::null_mut();
+        }
+
+        let field = &self.fields[self.current_index];
+        self.current_index += 1;
+
+        let field_value = match &field.value {
+            InfoFieldValue::String(s) => FieldValue {
+                stringValue: s.as_ptr(),
+            },
+            InfoFieldValue::Int64(v) => FieldValue { integerValue: *v },
+            InfoFieldValue::UInt64(v) => FieldValue { uintegerValue: *v },
+            InfoFieldValue::Float64(v) => FieldValue {
+                floatingPointValue: *v,
+            },
+            InfoFieldValue::Iterator(iter) => FieldValue {
+                // Return a raw pointer to the boxed iterator
+                iteratorValue: iter.as_ref() as *const VecSimDebugInfoIterator
+                    as *mut VecSimDebugInfoIterator,
+            },
+        };
+
+        self.current_field = Some(VecSim_InfoField {
+            fieldName: field.name.as_ptr(),
+            fieldType: field.field_type,
+            fieldValue: field_value,
+        });
+
+        self.current_field.as_mut().unwrap() as *mut VecSim_InfoField
+    }
+}
+
+/// Helper function to get the algorithm name as a string.
+pub fn algo_to_string(algo: VecSimAlgo) -> &'static str {
+    match algo {
+        VecSimAlgo::VecSimAlgo_BF => "FLAT",
+        VecSimAlgo::VecSimAlgo_HNSWLIB => "HNSW",
+        VecSimAlgo::VecSimAlgo_TIERED => "TIERED",
+        VecSimAlgo::VecSimAlgo_SVS => "SVS",
+    }
+}
+
+/// Helper function to get the type name as a string.
+pub fn type_to_string(t: VecSimType) -> &'static str {
+    match t {
+        VecSimType::VecSimType_FLOAT32 => "FLOAT32",
+        VecSimType::VecSimType_FLOAT64 => "FLOAT64",
+        VecSimType::VecSimType_BFLOAT16 => "BFLOAT16",
+        VecSimType::VecSimType_FLOAT16 => "FLOAT16",
+        VecSimType::VecSimType_INT8 => "INT8",
+        VecSimType::VecSimType_UINT8 => "UINT8",
+        VecSimType::VecSimType_INT32 => "INT32",
+        VecSimType::VecSimType_INT64 => "INT64",
+    }
+}
+
+/// Helper function to get the metric name as a string.
+pub fn metric_to_string(m: VecSimMetric) -> &'static str {
+    match m {
+        VecSimMetric::VecSimMetric_L2 => "L2",
+        VecSimMetric::VecSimMetric_IP => "IP",
+        VecSimMetric::VecSimMetric_Cosine => "COSINE",
+    }
+}
+
+/// Helper function to get the search mode as a string.
+pub fn search_mode_to_string(mode: VecSearchMode) -> &'static str {
+    match mode {
+        VecSearchMode::EMPTY_MODE => "EMPTY_MODE",
+        VecSearchMode::STANDARD_KNN => "STANDARD_KNN",
+        VecSearchMode::HYBRID_ADHOC_BF => "HYBRID_ADHOC_BF",
+        VecSearchMode::HYBRID_BATCHES => "HYBRID_BATCHES",
+        VecSearchMode::HYBRID_BATCHES_TO_ADHOC_BF => "HYBRID_BATCHES_TO_ADHOC_BF",
+        VecSearchMode::RANGE_QUERY => "RANGE_QUERY",
+    }
+}
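+
+// Consumption sketch for the iterator above (hedged: the C-facing free
+// functions that wrap has_next()/next() live elsewhere in this crate):
+//
+//     let mut it = VecSimDebugInfoIterator::new(2);
+//     it.add_string_field("ALGORITHM", algo_to_string(VecSimAlgo::VecSimAlgo_BF));
+//     it.add_uint64_field("INDEX_SIZE", 0);
+//     while it.has_next() {
+//         let field = it.next(); // valid only until the next call to next()
+//         // read (*field).fieldType to pick the matching FieldValue variant
+//     }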
+
+/// Index information struct.
+#[repr(C)]
+#[derive(Debug, Clone)]
+pub struct VecSimIndexInfo {
+    /// Current number of vectors in the index.
+    pub indexSize: usize,
+    /// Current label count (same as indexSize for single-value indices).
+    pub indexLabelCount: usize,
+    /// Vector dimension.
+    pub dim: usize,
+    /// Data type.
+    pub type_: VecSimType,
+    /// Algorithm type.
+    pub algo: VecSimAlgo,
+    /// Distance metric.
+    pub metric: VecSimMetric,
+    /// Whether this is a multi-value index.
+    pub isMulti: bool,
+    /// Block size.
+    pub blockSize: usize,
+    /// Memory usage in bytes.
+    pub memory: usize,
+    /// HNSW-specific info (if applicable).
+    pub hnswInfo: VecSimHnswInfo,
+}
+
+/// HNSW-specific index information.
+#[repr(C)]
+#[derive(Debug, Clone, Default)]
+pub struct VecSimHnswInfo {
+    /// M parameter.
+    pub M: usize,
+    /// ef_construction parameter.
+    pub efConstruction: usize,
+    /// ef_runtime parameter.
+    pub efRuntime: usize,
+    /// Maximum level in the graph.
+    pub maxLevel: usize,
+    /// Entry point ID.
+    pub entrypoint: i64,
+    /// Epsilon parameter.
+    pub epsilon: f64,
+}
+
+/// Index info that is static and immutable.
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct VecSimIndexBasicInfo {
+    pub algo: VecSimAlgo,
+    pub metric: VecSimMetric,
+    pub type_: VecSimType,
+    pub isMulti: bool,
+    pub isTiered: bool,
+    pub isDisk: bool,
+    pub blockSize: usize,
+    pub dim: usize,
+}
+
+/// Index info for statistics - thin and efficient.
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct VecSimIndexStatsInfo {
+    pub memory: usize,
+    pub numberOfMarkedDeleted: usize,
+}
+
+/// Common index information.
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct CommonInfo {
+    pub basicInfo: VecSimIndexBasicInfo,
+    pub indexSize: usize,
+    pub indexLabelCount: usize,
+    pub memory: u64,
+    pub lastMode: VecSearchMode,
+}
+
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct hnswInfoStruct {
+    pub M: usize,
+    pub efConstruction: usize,
+    pub efRuntime: usize,
+    pub epsilon: f64,
+    pub max_level: usize,
+    pub entrypoint: usize,
+    pub visitedNodesPoolSize: usize,
+    pub numberOfMarkedDeletedNodes: usize,
+}
+
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct bfInfoStruct {
+    pub dummy: i8,
+}
+
+/// Debug information for an index.
+#[repr(C)]
+pub struct VecSimIndexDebugInfo {
+    pub commonInfo: CommonInfo,
+    // A union would be used here; we use a simpler flat layout instead
+    pub hnswInfo: hnswInfoStruct,
+    pub bfInfo: bfInfoStruct,
+}
+
+impl Default for VecSimIndexInfo {
+    fn default() -> Self {
+        Self {
+            indexSize: 0,
+            indexLabelCount: 0,
+            dim: 0,
+            type_: VecSimType::VecSimType_FLOAT32,
+            algo: VecSimAlgo::VecSimAlgo_BF,
+            metric: VecSimMetric::VecSimMetric_L2,
+            isMulti: false,
+            blockSize: 0,
+            memory: 0,
+            hnswInfo: VecSimHnswInfo::default(),
+        }
+    }
+}
+
+/// Get index information.
+pub fn get_index_info(handle: &IndexHandle) -> VecSimIndexInfo {
+    let wrapper = &handle.wrapper;
+
+    VecSimIndexInfo {
+        indexSize: wrapper.index_size(),
+        indexLabelCount: wrapper.index_size(), // For single-value, same as size
+        dim: wrapper.dimension(),
+        type_: handle.data_type,
+        algo: handle.algo,
+        metric: handle.metric,
+        isMulti: handle.is_multi,
+        blockSize: 0, // Not directly exposed
+        memory: wrapper.memory_usage(),
+        hnswInfo: VecSimHnswInfo::default(), // Would need more specific introspection
+    }
+}
+
+/// Get basic index statistics.
+pub fn get_index_size(handle: &IndexHandle) -> usize {
+    handle.wrapper.index_size()
+}
+
+/// Get vector dimension.
+pub fn get_index_dim(handle: &IndexHandle) -> usize {
+    handle.wrapper.dimension()
+}
+
+/// Get data type.
+pub fn get_index_type(handle: &IndexHandle) -> VecSimType {
+    handle.data_type
+}
+
+/// Get metric.
+pub fn get_index_metric(handle: &IndexHandle) -> VecSimMetric {
+    handle.metric
+}
+
+/// Check if index is multi-value.
+pub fn is_index_multi(handle: &IndexHandle) -> bool {
+    handle.is_multi
+}
+
+/// Check if label exists in index.
+pub fn index_contains(handle: &IndexHandle, label: u64) -> bool {
+    handle.wrapper.contains(label)
+}
+
+/// Get count of vectors with given label.
+pub fn get_label_count(handle: &IndexHandle, label: u64) -> usize {
+    handle.wrapper.label_count(label)
+}
+
+/// Get memory usage.
+pub fn get_memory_usage(handle: &IndexHandle) -> usize {
+    handle.wrapper.memory_usage()
+}
diff --git a/rust/vecsim-c/src/lib.rs b/rust/vecsim-c/src/lib.rs
new file mode 100644
index 000000000..cd6193654
--- /dev/null
+++ b/rust/vecsim-c/src/lib.rs
@@ -0,0 +1,4162 @@
+//! C-compatible FFI bindings for the VecSim vector similarity search library.
+//!
+//! This crate provides a C-compatible API that matches the existing C++ VecSim
+//! implementation, enabling drop-in replacement.
+
+// Allow C-style naming conventions to match the C++ API
+#![allow(non_camel_case_types)]
+#![allow(non_snake_case)]
+#![allow(dead_code)]
+
+pub mod compat;
+pub mod index;
+pub mod info;
+pub mod params;
+pub mod query;
+pub mod types;
+
+use index::{
+    create_brute_force_index, create_disk_index, create_hnsw_index, create_svs_index,
+    create_tiered_index, IndexHandle, IndexWrapper,
+    BruteForceSingleF32Wrapper, BruteForceSingleF64Wrapper,
+    BruteForceMultiF32Wrapper, BruteForceMultiF64Wrapper,
+    HnswSingleF32Wrapper, HnswSingleF64Wrapper,
+    HnswMultiF32Wrapper, HnswMultiF64Wrapper,
+    SvsSingleF32Wrapper, SvsSingleF64Wrapper,
+    TieredSingleF32Wrapper, TieredMultiF32Wrapper,
+};
+use info::{get_index_info, VecSimIndexInfo};
+use params::{
+    BFParams, DiskParams, HNSWParams, SVSParams, TieredParams, VecSimParams, VecSimQueryParams,
+};
+use query::{
+    create_batch_iterator, range_query, top_k_query, BatchIteratorHandle, QueryReplyHandle,
+    QueryReplyIteratorHandle,
+};
+use types::{
+    labelType, QueryResultInternal, VecSimAlgo, VecSimBatchIterator, VecSimDebugCommandCode,
+    VecSimIndex, VecSimMetric, VecSimParamResolveCode, VecSimQueryReply, VecSimQueryReply_Iterator,
+    VecSimQueryReply_Order, VecSimQueryResult, VecSimRawParam, VecSimType, VecsimQueryType,
+};
+
+use std::ffi::{c_char, c_int, c_void};
+use std::ptr;
+use std::sync::atomic::{AtomicU8, Ordering};
+
+use types::{VecSimMemoryFunctions, VecSimWriteMode};
+use compat::{VecSimParams_C, BFParams_C, HNSWParams_C, SVSParams_C};
+
+// ============================================================================
+// Global Memory Functions
+// ============================================================================
+
+use std::sync::RwLock;
+
+/// Global memory functions for custom memory management.
+/// Protected by RwLock for thread-safe access.
+static GLOBAL_MEMORY_FUNCTIONS: RwLock<Option<VecSimMemoryFunctions>> = RwLock::new(None);
+
+/// Global timeout callback function.
+static GLOBAL_TIMEOUT_CALLBACK: RwLock<types::timeoutCallbackFunction> = RwLock::new(None);
+
+/// Global log callback function.
+static GLOBAL_LOG_CALLBACK: RwLock<types::logCallbackFunction> = RwLock::new(None);
+
+/// Thread-safe wrapper for raw pointers in the log context.
+struct SendSyncPtr(*const c_char);
+unsafe impl Send for SendSyncPtr {}
+unsafe impl Sync for SendSyncPtr {}
+
+/// Global log context for testing.
+static GLOBAL_LOG_CONTEXT: RwLock<Option<(SendSyncPtr, SendSyncPtr)>> = RwLock::new(None);
+
+/// Set custom memory functions for all future allocations.
+///
+/// This allows integration with external memory management systems like Redis.
+/// The functions will be used for all memory allocations in the library.
+///
+/// # Safety
+/// The provided function pointers must be valid and thread-safe.
+/// The functions must follow standard malloc/calloc/realloc/free semantics.
+///
+/// # Note
+/// This should be called once at initialization, before creating any indices.
+#[no_mangle]
+pub unsafe extern "C" fn VecSim_SetMemoryFunctions(functions: VecSimMemoryFunctions) {
+    if let Ok(mut guard) = GLOBAL_MEMORY_FUNCTIONS.write() {
+        *guard = Some(functions);
+    }
+}
+
+/// Get the current memory functions.
+///
+/// Returns the currently configured memory functions, or None if using defaults.
+pub(crate) fn get_memory_functions() -> Option<VecSimMemoryFunctions> {
+    GLOBAL_MEMORY_FUNCTIONS.read().ok().and_then(|g| *g)
+}
+
+/// Set the timeout callback function.
+///
+/// The callback will be called periodically during long operations to check
+/// if the operation should be aborted. Return non-zero to abort.
+///
+/// # Safety
+/// The callback function must be thread-safe.
+#[no_mangle]
+pub unsafe extern "C" fn VecSim_SetTimeoutCallbackFunction(
+    callback: types::timeoutCallbackFunction,
+) {
+    if let Ok(mut guard) = GLOBAL_TIMEOUT_CALLBACK.write() {
+        *guard = callback;
+    }
+}
+
+/// Set the log callback function.
+///
+/// The callback will be called for logging messages from the library.
+///
+/// # Safety
+/// The callback function must be thread-safe.
+#[no_mangle]
+pub unsafe extern "C" fn VecSim_SetLogCallbackFunction(
+    callback: types::logCallbackFunction,
+) {
+    if let Ok(mut guard) = GLOBAL_LOG_CALLBACK.write() {
+        *guard = callback;
+    }
+}
+
+/// Set the test log context.
+///
+/// This is used for testing to identify which test is running.
+///
+/// # Safety
+/// The pointers must be valid for the duration of the test.
+#[no_mangle]
+pub unsafe extern "C" fn VecSim_SetTestLogContext(
+    test_name: *const c_char,
+    test_type: *const c_char,
+) {
+    if let Ok(mut guard) = GLOBAL_LOG_CONTEXT.write() {
+        *guard = Some((SendSyncPtr(test_name), SendSyncPtr(test_type)));
+    }
+}
+
+// ============================================================================
+// Global Write Mode State
+// ============================================================================
+
+/// Global write mode for tiered indices.
+/// Accessed atomically for thread-safety.
+static GLOBAL_WRITE_MODE: AtomicU8 = AtomicU8::new(0); // VecSim_WriteAsync = 0
+
+/// Set the global write mode for tiered index operations.
+///
+/// This controls whether vector additions/deletions in tiered indices go through
+/// the async buffering path (VecSim_WriteAsync) or directly to the backend index
+/// (VecSim_WriteInPlace).
+///
+/// # Note
+/// In a tiered index scenario, this should be called from the main thread only
+/// (that is, the thread that is calling the add/delete vector functions).
+#[no_mangle]
+pub extern "C" fn VecSim_SetWriteMode(mode: VecSimWriteMode) {
+    GLOBAL_WRITE_MODE.store(mode as u8, Ordering::SeqCst);
+}
+
+/// Get the current global write mode.
+///
+/// Returns the currently active write mode for tiered index operations.
+#[no_mangle]
+pub extern "C" fn VecSim_GetWriteMode() -> VecSimWriteMode {
+    match GLOBAL_WRITE_MODE.load(Ordering::SeqCst) {
+        0 => VecSimWriteMode::VecSim_WriteAsync,
+        _ => VecSimWriteMode::VecSim_WriteInPlace,
+    }
+}
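+
+// Hedged usage sketch of the write-mode switch from the embedding
+// application's main thread; the symbols used are the public functions
+// defined just above:
+//
+//     VecSim_SetWriteMode(VecSimWriteMode::VecSim_WriteInPlace);
+//     // ... perform synchronous, latency-sensitive ingestion ...
+//     // reads observe the latest store (SeqCst ordering)
+//     VecSim_SetWriteMode(VecSimWriteMode::VecSim_WriteAsync);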
+
+/// Internal function to get the write mode as the Rust enum.
+pub(crate) fn get_global_write_mode() -> VecSimWriteMode {
+    VecSim_GetWriteMode()
+}
+
+// ============================================================================
+// Parameter Resolution
+// ============================================================================
+
+/// Known parameter names for parameter resolution.
+mod param_names {
+    pub const EF_RUNTIME: &str = "EF_RUNTIME";
+    pub const EPSILON: &str = "EPSILON";
+    pub const BATCH_SIZE: &str = "BATCH_SIZE";
+    pub const HYBRID_POLICY: &str = "HYBRID_POLICY";
+    pub const SEARCH_WINDOW_SIZE: &str = "SEARCH_WINDOW_SIZE";
+    pub const SEARCH_BUFFER_CAPACITY: &str = "SEARCH_BUFFER_CAPACITY";
+    pub const USE_SEARCH_HISTORY: &str = "USE_SEARCH_HISTORY";
+    pub const POLICY_BATCHES: &str = "batches";
+    pub const POLICY_ADHOC_BF: &str = "adhoc_bf";
+}
+
+/// Parse a string value as a positive integer.
+fn parse_positive_integer(value: &str) -> Option<usize> {
+    value.parse::<i64>().ok().filter(|&v| v > 0).map(|v| v as usize)
+}
+
+/// Parse a string value as a positive f64.
+fn parse_positive_double(value: &str) -> Option<f64> {
+    value.parse::<f64>().ok().filter(|&v| v > 0.0)
+}
+
+/// Resolve runtime query parameters from raw string parameters.
+///
+/// # Safety
+/// All pointers must be valid. `index`, `rparams`, and `qparams` must not be null
+/// (except rparams when paramNum is 0).
+#[no_mangle]
+pub unsafe extern "C" fn VecSimIndex_ResolveParams(
+    index: *mut VecSimIndex,
+    rparams: *const VecSimRawParam,
+    paramNum: i32,
+    qparams: *mut VecSimQueryParams,
+    query_type: VecsimQueryType,
+) -> VecSimParamResolveCode {
+    use VecSimParamResolveCode::*;
+
+    // Null check for qparams (required) and rparams (required if paramNum > 0)
+    if qparams.is_null() || (rparams.is_null() && paramNum != 0) {
+        return VecSimParamResolverErr_NullParam;
+    }
+
+    // Zero out qparams
+    let qparams = &mut *qparams;
+    *qparams = VecSimQueryParams::default();
+    qparams.hnsw_params_mut().efRuntime = 0; // Reset to 0 for checking duplicates
+    qparams.hnsw_params_mut().epsilon = 0.0;
+    *qparams.svs_params_mut() = params::SVSRuntimeParams::default();
+    qparams.batchSize = 0;
+    qparams.searchMode = 1; // STANDARD_KNN
+
+    if paramNum == 0 {
+        return VecSimParamResolver_OK;
+    }
+
+    let handle = &*(index as *const IndexHandle);
+    let index_type = handle.algo;
+
+    let params_slice = std::slice::from_raw_parts(rparams, paramNum as usize);
+
+    for rparam in params_slice {
+        // Get the parameter name as a Rust string
+        let name = if rparam.nameLen > 0 && !rparam.name.is_null() {
+            std::str::from_utf8(std::slice::from_raw_parts(
+                rparam.name as *const u8,
+                rparam.nameLen,
+            ))
+            .unwrap_or("")
+        } else {
+            ""
+        };
+
+        // Get the parameter value as a Rust string
+        let value = if rparam.valLen > 0 && !rparam.value.is_null() {
+            std::str::from_utf8(std::slice::from_raw_parts(
+                rparam.value as *const u8,
+                rparam.valLen,
+            ))
+            .unwrap_or("")
+        } else {
+            ""
+        };
+
+        let result = match name.to_uppercase().as_str() {
+            param_names::EF_RUNTIME => {
+                resolve_ef_runtime(index_type, value, qparams, query_type)
+            }
+            param_names::EPSILON => {
+                resolve_epsilon(index_type, value, qparams, query_type)
+            }
+            param_names::BATCH_SIZE => {
+                resolve_batch_size(value, qparams, query_type)
+            }
+            param_names::HYBRID_POLICY => {
+                resolve_hybrid_policy(value, qparams, query_type)
+            }
+            param_names::SEARCH_WINDOW_SIZE => {
+                resolve_search_window_size(index_type, value, qparams)
+            }
+            param_names::SEARCH_BUFFER_CAPACITY => {
+                resolve_search_buffer_capacity(index_type, value, qparams)
+            }
+            param_names::USE_SEARCH_HISTORY => {
+                resolve_use_search_history(index_type, value, qparams)
+            }
+            _ => VecSimParamResolverErr_UnknownParam,
+        };
+
+        if result != VecSimParamResolver_OK {
+            return result;
+        }
+    }
+
+    // Validate parameter combinations
+    // AD-HOC with batch_size is invalid
+    // searchMode == 2 is HYBRID_ADHOC_BF
+    if qparams.searchMode == 2 && qparams.batchSize > 0 {
+        return VecSimParamResolverErr_InvalidPolicy_AdHoc_With_BatchSize;
+    }
+
+    // AD-HOC with ef_runtime is invalid for HNSW
+    if qparams.searchMode == 2
+        && index_type == VecSimAlgo::VecSimAlgo_HNSWLIB
+        && qparams.hnsw_params().efRuntime > 0
+    {
+        return VecSimParamResolverErr_InvalidPolicy_AdHoc_With_EfRuntime;
+    }
+
+    VecSimParamResolver_OK
+}
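+
+// Hedged sketch of resolving raw runtime params, mirroring how a caller such
+// as RediSearch would drive this. The VecSimRawParam field names follow the
+// accesses above; the KNN query-type variant name is an assumption:
+//
+//     let name = b"EF_RUNTIME";
+//     let value = b"200";
+//     let raw = VecSimRawParam {
+//         name: name.as_ptr() as *const c_char, nameLen: name.len(),
+//         value: value.as_ptr() as *const c_char, valLen: value.len(),
+//     };
+//     let mut qparams = VecSimQueryParams::default();
+//     let rc = unsafe {
+//         VecSimIndex_ResolveParams(index, &raw, 1, &mut qparams,
+//                                   VecsimQueryType::QUERY_TYPE_KNN)
+//     };
+//     // rc should be VecSimParamResolver_OK and qparams carries efRuntime = 200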
+
+/// Check if the index type supports HNSW parameters.
+/// This includes both pure HNSW indices and Tiered indices (which use HNSW as backend).
+fn supports_hnsw_params(index_type: VecSimAlgo) -> bool {
+    matches!(
+        index_type,
+        VecSimAlgo::VecSimAlgo_HNSWLIB | VecSimAlgo::VecSimAlgo_TIERED
+    )
+}
+
+fn resolve_ef_runtime(
+    index_type: VecSimAlgo,
+    value: &str,
+    qparams: &mut VecSimQueryParams,
+    query_type: VecsimQueryType,
+) -> VecSimParamResolveCode {
+    use VecSimParamResolveCode::*;
+
+    // EF_RUNTIME is valid only for HNSW and Tiered (which uses an HNSW backend)
+    if !supports_hnsw_params(index_type) {
+        return VecSimParamResolverErr_UnknownParam;
+    }
+    // EF_RUNTIME is invalid for range queries
+    if query_type == VecsimQueryType::QUERY_TYPE_RANGE {
+        return VecSimParamResolverErr_UnknownParam;
+    }
+    // Check if already set
+    if qparams.hnsw_params().efRuntime != 0 {
+        return VecSimParamResolverErr_AlreadySet;
+    }
+    // Parse value
+    match parse_positive_integer(value) {
+        Some(v) => {
+            qparams.hnsw_params_mut().efRuntime = v;
+            VecSimParamResolver_OK
+        }
+        None => VecSimParamResolverErr_BadValue,
+    }
+}
+
+fn resolve_epsilon(
+    index_type: VecSimAlgo,
+    value: &str,
+    qparams: &mut VecSimQueryParams,
+    query_type: VecsimQueryType,
+) -> VecSimParamResolveCode {
+    use VecSimParamResolveCode::*;
+
+    // EPSILON is valid only for HNSW, Tiered (HNSW backend), or SVS
+    if !supports_hnsw_params(index_type) && index_type != VecSimAlgo::VecSimAlgo_SVS {
+        return VecSimParamResolverErr_UnknownParam;
+    }
+    // EPSILON is valid only for range queries
+    if query_type != VecsimQueryType::QUERY_TYPE_RANGE {
+        return VecSimParamResolverErr_InvalidPolicy_NRange;
+    }
+    // Check if already set (based on index type)
+    // For HNSW and Tiered (HNSW backend), use HNSW params; for SVS, use SVS params
+    let current_epsilon = if supports_hnsw_params(index_type) {
+        qparams.hnsw_params().epsilon
+    } else {
+        qparams.svs_params().epsilon
+    };
+    if current_epsilon != 0.0 {
+        return VecSimParamResolverErr_AlreadySet;
+    }
+    // Parse value
+    match parse_positive_double(value) {
+        Some(v) => {
+            if supports_hnsw_params(index_type) {
+                qparams.hnsw_params_mut().epsilon = v;
+            } else {
+                qparams.svs_params_mut().epsilon = v;
+            }
+            VecSimParamResolver_OK
+        }
+        None => VecSimParamResolverErr_BadValue,
+    }
+}
+
+fn resolve_batch_size(
+    value: &str,
+    qparams: &mut VecSimQueryParams,
+    query_type: VecsimQueryType,
+) -> VecSimParamResolveCode {
+    use VecSimParamResolveCode::*;
+
+    // BATCH_SIZE is valid only for hybrid queries
+    if query_type != VecsimQueryType::QUERY_TYPE_HYBRID {
+        return VecSimParamResolverErr_InvalidPolicy_NHybrid;
+    }
+    // Check if already set
+    if qparams.batchSize != 0 {
+        return VecSimParamResolverErr_AlreadySet;
+    }
+    // Parse value
+    match parse_positive_integer(value) {
+        Some(v) => {
+            qparams.batchSize = v;
+            VecSimParamResolver_OK
+        }
+        None => VecSimParamResolverErr_BadValue,
+    }
+}
+
+fn resolve_hybrid_policy(
+    value: &str,
+    qparams: &mut VecSimQueryParams,
+    query_type: VecsimQueryType,
+) -> VecSimParamResolveCode {
+    use VecSimParamResolveCode::*;
+
+    // HYBRID_POLICY is valid only for hybrid queries
+    if query_type != VecsimQueryType::QUERY_TYPE_HYBRID {
+        return VecSimParamResolverErr_InvalidPolicy_NHybrid;
+    }
+    // Check if already set (searchMode != STANDARD_KNN indicates it was set)
+    // VecSearchMode values: EMPTY_MODE=0, STANDARD_KNN=1, HYBRID_ADHOC_BF=2, HYBRID_BATCHES=3
+    if qparams.searchMode != 1 {
+        return VecSimParamResolverErr_AlreadySet;
+    }
+    // Parse value (case-insensitive)
+    match value.to_lowercase().as_str() {
+        param_names::POLICY_BATCHES => {
+            qparams.searchMode = 3; // HYBRID_BATCHES
+            VecSimParamResolver_OK
+        }
+        param_names::POLICY_ADHOC_BF => {
+            qparams.searchMode = 2; // HYBRID_ADHOC_BF
+            VecSimParamResolver_OK
+        }
+        _ => VecSimParamResolverErr_InvalidPolicy_NExits,
+    }
+}
+
+fn resolve_search_window_size(
+    index_type: VecSimAlgo,
+    value: &str,
+    qparams: &mut VecSimQueryParams,
+) -> VecSimParamResolveCode {
+    use VecSimParamResolveCode::*;
+
+    // SEARCH_WINDOW_SIZE is valid only for SVS
+    if index_type != VecSimAlgo::VecSimAlgo_SVS {
+        return VecSimParamResolverErr_UnknownParam;
+    }
+    // Check if already set
+    if qparams.svs_params().windowSize != 0 {
+        return VecSimParamResolverErr_AlreadySet;
+    }
+    // Parse value
+    match parse_positive_integer(value) {
+        Some(v) => {
+            qparams.svs_params_mut().windowSize = v;
+            VecSimParamResolver_OK
+        }
+        None => VecSimParamResolverErr_BadValue,
+    }
+}
+
+fn resolve_search_buffer_capacity(
+    index_type: VecSimAlgo,
+    value: &str,
+    qparams: &mut VecSimQueryParams,
+) -> VecSimParamResolveCode {
+    use VecSimParamResolveCode::*;
+
+    // SEARCH_BUFFER_CAPACITY is valid only for SVS
+    if index_type != VecSimAlgo::VecSimAlgo_SVS {
+        return VecSimParamResolverErr_UnknownParam;
+    }
+    // Check if already set
+    if qparams.svs_params().bufferCapacity != 0 {
+        return VecSimParamResolverErr_AlreadySet;
+    }
+    // Parse value
+    match parse_positive_integer(value) {
+        Some(v) => {
+            qparams.svs_params_mut().bufferCapacity = v;
+            VecSimParamResolver_OK
+        }
+        None => VecSimParamResolverErr_BadValue,
+    }
+}
+
+fn resolve_use_search_history(
+    index_type: VecSimAlgo,
+    value: &str,
+    qparams: &mut VecSimQueryParams,
+) -> VecSimParamResolveCode {
+    use VecSimParamResolveCode::*;
+
+    // USE_SEARCH_HISTORY is valid only for SVS
+    if index_type != VecSimAlgo::VecSimAlgo_SVS {
+        return VecSimParamResolverErr_UnknownParam;
+    }
+    // Check if already set
+    if qparams.svs_params().searchHistory != 0 {
+        return VecSimParamResolverErr_AlreadySet;
+    }
+    // Parse as boolean (1/0, true/false, yes/no)
+    let bool_val = match value.to_lowercase().as_str() {
+        "1" | "true" | "yes" => Some(1),
+        "0" | "false" | "no" => Some(0),
+        _ => None,
+    };
+    match bool_val {
+        Some(v) => {
+            qparams.svs_params_mut().searchHistory = v;
+            VecSimParamResolver_OK
+        }
+        None => VecSimParamResolverErr_BadValue,
+    }
+}
+
+// ============================================================================
+// Index Lifecycle Functions
+// ============================================================================
+
+/// Create a new BruteForce index with specific parameters.
+///
+/// # Safety
+/// The `params` pointer must be valid.
+/// This function accepts the C++-compatible BFParams struct (BFParams_C).
+#[no_mangle]
+pub unsafe extern "C" fn VecSimIndex_NewBF(params: *const BFParams_C) -> *mut VecSimIndex {
+    if params.is_null() {
+        return ptr::null_mut();
+    }
+
+    let c_params = &*params;
+    // Convert C++-compatible BFParams to internal BFParams
+    let rust_params = BFParams {
+        base: VecSimParams {
+            algo: VecSimAlgo::VecSimAlgo_BF,
+            type_: c_params.type_,
+            metric: c_params.metric,
+            dim: c_params.dim,
+            multi: c_params.multi,
+            initialCapacity: c_params.initialCapacity,
+            blockSize: c_params.blockSize,
+        },
+    };
+    match create_brute_force_index(&rust_params) {
+        Some(boxed) => Box::into_raw(boxed) as *mut VecSimIndex,
+        None => ptr::null_mut(),
+    }
+}
+
+/// Create a new HNSW index with specific parameters.
+///
+/// # Safety
+/// The `params` pointer must be valid.
+/// This function accepts the C++-compatible HNSWParams struct (HNSWParams_C).
+#[no_mangle]
+pub unsafe extern "C" fn VecSimIndex_NewHNSW(params: *const HNSWParams_C) -> *mut VecSimIndex {
+    if params.is_null() {
+        return ptr::null_mut();
+    }
+
+    let c_params = &*params;
+    // Convert C++-compatible HNSWParams to internal HNSWParams
+    let rust_params = HNSWParams {
+        base: VecSimParams {
+            algo: VecSimAlgo::VecSimAlgo_HNSWLIB,
+            type_: c_params.type_,
+            metric: c_params.metric,
+            dim: c_params.dim,
+            multi: c_params.multi,
+            initialCapacity: c_params.initialCapacity,
+            blockSize: c_params.blockSize,
+        },
+        M: c_params.M,
+        efConstruction: c_params.efConstruction,
+        efRuntime: c_params.efRuntime,
+        epsilon: c_params.epsilon,
+    };
+    match create_hnsw_index(&rust_params) {
+        Some(boxed) => Box::into_raw(boxed) as *mut VecSimIndex,
+        None => ptr::null_mut(),
+    }
+}
+
+/// Create a new SVS (Vamana) index with specific parameters.
+///
+/// # Safety
+/// The `params` pointer must be valid.
+/// This function accepts the C++-compatible SVSParams struct (SVSParams_C).
+#[no_mangle]
+pub unsafe extern "C" fn VecSimIndex_NewSVS(params: *const SVSParams_C) -> *mut VecSimIndex {
+    if params.is_null() {
+        return ptr::null_mut();
+    }
+
+    let c_params = &*params;
+    // Convert C++-compatible SVSParams to internal SVSParams
+    let rust_params = SVSParams {
+        base: VecSimParams {
+            algo: VecSimAlgo::VecSimAlgo_SVS,
+            type_: c_params.type_,
+            metric: c_params.metric,
+            dim: c_params.dim,
+            multi: c_params.multi,
+            initialCapacity: 0, // SVS doesn't use initialCapacity
+            blockSize: c_params.blockSize,
+        },
+        graphMaxDegree: c_params.graph_max_degree,
+        alpha: c_params.alpha,
+        constructionWindowSize: c_params.construction_window_size,
+        searchWindowSize: c_params.search_window_size,
+        twoPassConstruction: true, // Default value
+    };
+    match create_svs_index(&rust_params) {
+        Some(boxed) => Box::into_raw(boxed) as *mut VecSimIndex,
+        None => ptr::null_mut(),
+    }
+}
+
+/// Create a new Tiered index with specific parameters.
+///
+/// The tiered index combines a BruteForce frontend (for fast writes) with
+/// an HNSW backend (for efficient queries). Currently only supports f32 vectors.
+///
+/// # Safety
+/// The `params` pointer must be valid.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimIndex_NewTiered(params: *const TieredParams) -> *mut VecSimIndex {
+    if params.is_null() {
+        return ptr::null_mut();
+    }
+
+    let params = &*params;
+    match create_tiered_index(params) {
+        Some(boxed) => Box::into_raw(boxed) as *mut VecSimIndex,
+        None => ptr::null_mut(),
+    }
+}
+
+// ============================================================================
+// Tiered Index Operations
+// ============================================================================
+
+/// Flush the flat buffer to the HNSW backend.
+///
+/// This migrates all vectors from the flat buffer to the HNSW index.
+/// Returns the number of vectors flushed, or 0 if the index is not tiered.
+///
+/// # Safety
+/// `index` must be a valid pointer returned by `VecSimIndex_NewTiered`.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimTieredIndex_Flush(index: *mut VecSimIndex) -> usize {
+    if index.is_null() {
+        return 0;
+    }
+    let handle = &mut *(index as *mut IndexHandle);
+    handle.wrapper.tiered_flush()
+}
+
+/// Get the number of vectors in the flat buffer.
+///
+/// Returns 0 if the index is not tiered.
+///
+/// # Safety
+/// `index` must be a valid pointer returned by `VecSimIndex_NewTiered`.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimTieredIndex_FlatSize(index: *const VecSimIndex) -> usize {
+    if index.is_null() {
+        return 0;
+    }
+    let handle = &*(index as *const IndexHandle);
+    handle.wrapper.tiered_flat_size()
+}
+
+/// Get the number of vectors in the HNSW backend.
+///
+/// Returns 0 if the index is not tiered.
+///
+/// # Safety
+/// `index` must be a valid pointer returned by `VecSimIndex_NewTiered`.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimTieredIndex_BackendSize(index: *const VecSimIndex) -> usize {
+    if index.is_null() {
+        return 0;
+    }
+    let handle = &*(index as *const IndexHandle);
+    handle.wrapper.tiered_backend_size()
+}
+
+/// Run garbage collection on a tiered index.
+///
+/// This cleans up deleted vectors and optimizes the index structure.
+///
+/// # Safety
+/// `index` must be a valid pointer returned by `VecSimIndex_NewTiered`.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimTieredIndex_GC(index: *mut VecSimIndex) {
+    if index.is_null() {
+        return;
+    }
+    let handle = &mut *(index as *mut IndexHandle);
+    handle.wrapper.tiered_gc();
+}
+
+/// Acquire shared locks on a tiered index.
+///
+/// This prevents modifications to the index while the locks are held.
+/// Must be paired with VecSimTieredIndex_ReleaseSharedLocks.
+///
+/// # Safety
+/// `index` must be a valid pointer returned by `VecSimIndex_NewTiered`.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimTieredIndex_AcquireSharedLocks(index: *mut VecSimIndex) {
+    if index.is_null() {
+        return;
+    }
+    let handle = &mut *(index as *mut IndexHandle);
+    handle.wrapper.tiered_acquire_shared_locks();
+}
+
+/// Release shared locks on a tiered index.
+///
+/// Must be called after VecSimTieredIndex_AcquireSharedLocks.
+///
+/// # Safety
+/// `index` must be a valid pointer returned by `VecSimIndex_NewTiered`.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimTieredIndex_ReleaseSharedLocks(index: *mut VecSimIndex) {
+    if index.is_null() {
+        return;
+    }
+    let handle = &mut *(index as *mut IndexHandle);
+    handle.wrapper.tiered_release_shared_locks();
+}
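+
+// The shared-lock pair is intended to bracket read-only scans of a tiered
+// index. A hedged sketch (note both calls are currently no-ops in this port):
+//
+//     VecSimTieredIndex_AcquireSharedLocks(index);
+//     // ... read-only operations: sizes, queries ...
+//     VecSimTieredIndex_ReleaseSharedLocks(index);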
+#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_IsTiered(index: *const VecSimIndex) -> bool { + if index.is_null() { + return false; + } + let handle = &*(index as *const IndexHandle); + handle.wrapper.is_tiered() +} + +// ============================================================================ +// Disk Index Functions +// ============================================================================ + +/// Create a new disk-based index. +/// +/// Disk indices store vectors in memory-mapped files for persistence. +/// They support two backends: +/// - BruteForce: Linear scan (exact results, O(n)) +/// - Vamana: Graph-based approximate search (fast, O(log n)) +/// +/// # Safety +/// - `params` must be a valid pointer to a `DiskParams` struct +/// - `params.dataPath` must be a valid null-terminated C string +/// +/// # Returns +/// A pointer to the new index, or null if creation failed. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_NewDisk(params: *const DiskParams) -> *mut VecSimIndex { + if params.is_null() { + return ptr::null_mut(); + } + + let params = &*params; + match create_disk_index(params) { + Some(handle) => Box::into_raw(handle) as *mut VecSimIndex, + None => ptr::null_mut(), + } +} + +// ============================================================================ +// Generic C++-compatible Index Creation +// ============================================================================ + +/// Create a new index using C++-compatible VecSimParams structure. +/// +/// This function provides drop-in compatibility with the C++ VecSim API. +/// It reads the `algo` field to determine which type of index to create, +/// then accesses the appropriate union variant. +/// +/// # Safety +/// - `params` must be a valid pointer to a VecSimParams_C struct +/// - The `algo` field must match the initialized union variant in `algoParams` +/// - For tiered indices, `primaryIndexParams` must be a valid pointer +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_New(params: *const VecSimParams_C) -> *mut VecSimIndex { + if params.is_null() { + return ptr::null_mut(); + } + + let params = &*params; + + match params.algo { + VecSimAlgo::VecSimAlgo_BF => { + let bf = params.algoParams.bfParams; + create_bf_index_raw( + bf.type_, + bf.metric, + bf.dim, + bf.multi, + bf.initialCapacity, + bf.blockSize, + ) + } + VecSimAlgo::VecSimAlgo_HNSWLIB => { + let hnsw = params.algoParams.hnswParams; + create_hnsw_index_raw( + hnsw.type_, + hnsw.metric, + hnsw.dim, + hnsw.multi, + hnsw.initialCapacity, + hnsw.M, + hnsw.efConstruction, + hnsw.efRuntime, + ) + } + VecSimAlgo::VecSimAlgo_SVS => { + let svs = params.algoParams.svsParams; + create_svs_index_raw( + svs.type_, + svs.metric, + svs.dim, + svs.multi, + svs.graph_max_degree, + svs.alpha, + svs.construction_window_size, + svs.search_window_size, + ) + } + VecSimAlgo::VecSimAlgo_TIERED => { + let tiered = &*params.algoParams.tieredParams; + + // Get primary index params + if tiered.primaryIndexParams.is_null() { + return ptr::null_mut(); + } + let primary = &*tiered.primaryIndexParams; + + // Determine the backend type from primary params + match primary.algo { + VecSimAlgo::VecSimAlgo_HNSWLIB => { + let hnsw = primary.algoParams.hnswParams; + create_tiered_index_raw( + hnsw.type_, + hnsw.metric, + hnsw.dim, + hnsw.multi, + hnsw.M, + hnsw.efConstruction, + hnsw.efRuntime, + tiered.flatBufferLimit, + ) + } + _ => ptr::null_mut(), // Only HNSW backend supported for now + } + } + } +} + +// Helper function to create BF index +unsafe fn 
create_bf_index_raw(
+    type_: VecSimType,
+    metric: VecSimMetric,
+    dim: usize,
+    multi: bool,
+    initial_capacity: usize,
+    block_size: usize,
+) -> *mut VecSimIndex {
+    let rust_metric = metric.to_rust_metric();
+    let block = if block_size > 0 { block_size } else { 1024 };
+    // SIZE_MAX is used as a sentinel value meaning "use default capacity".
+    // The C++ VecSim library treats initialCapacity as deprecated.
+    let capacity = if initial_capacity == usize::MAX { 1024 } else { initial_capacity };
+
+    // NB: `IndexWrapper` is an assumed name for the crate's common wrapper
+    // trait, implemented by the concrete *Wrapper types below.
+    let wrapper: Box<dyn IndexWrapper> = match (type_, multi) {
+        (VecSimType::VecSimType_FLOAT32, false) => {
+            let params = vecsim::index::BruteForceParams::new(dim, rust_metric)
+                .with_capacity(capacity)
+                .with_block_size(block);
+            Box::new(BruteForceSingleF32Wrapper::new(
+                vecsim::index::BruteForceSingle::new(params),
+                type_,
+            ))
+        }
+        (VecSimType::VecSimType_FLOAT64, false) => {
+            let params = vecsim::index::BruteForceParams::new(dim, rust_metric)
+                .with_capacity(capacity)
+                .with_block_size(block);
+            Box::new(BruteForceSingleF64Wrapper::new(
+                vecsim::index::BruteForceSingle::new(params),
+                type_,
+            ))
+        }
+        (VecSimType::VecSimType_FLOAT32, true) => {
+            let params = vecsim::index::BruteForceParams::new(dim, rust_metric)
+                .with_capacity(capacity)
+                .with_block_size(block);
+            Box::new(BruteForceMultiF32Wrapper::new(
+                vecsim::index::BruteForceMulti::new(params),
+                type_,
+            ))
+        }
+        (VecSimType::VecSimType_FLOAT64, true) => {
+            let params = vecsim::index::BruteForceParams::new(dim, rust_metric)
+                .with_capacity(capacity)
+                .with_block_size(block);
+            Box::new(BruteForceMultiF64Wrapper::new(
+                vecsim::index::BruteForceMulti::new(params),
+                type_,
+            ))
+        }
+        _ => return ptr::null_mut(),
+    };
+
+    Box::into_raw(Box::new(IndexHandle::new(
+        wrapper,
+        type_,
+        VecSimAlgo::VecSimAlgo_BF,
+        metric,
+        dim,
+        multi,
+    ))) as *mut VecSimIndex
+}
+
+// Helper function to create HNSW index
+unsafe fn create_hnsw_index_raw(
+    type_: VecSimType,
+    metric: VecSimMetric,
+    dim: usize,
+    multi: bool,
+    initial_capacity: usize,
+    m: usize,
+    ef_construction: usize,
+    ef_runtime: usize,
+) -> *mut VecSimIndex {
+    let rust_metric = metric.to_rust_metric();
+    // SIZE_MAX is used as a sentinel value meaning "use default capacity".
+    let capacity = if initial_capacity == usize::MAX { 1024 } else { initial_capacity };
+
+    // See the note above: `IndexWrapper` is an assumed trait name.
+    let wrapper: Box<dyn IndexWrapper> = match (type_, multi) {
+        (VecSimType::VecSimType_FLOAT32, false) => {
+            let params = vecsim::index::HnswParams::new(dim, rust_metric)
+                .with_m(m)
+                .with_ef_construction(ef_construction)
+                .with_ef_runtime(ef_runtime)
+                .with_capacity(capacity);
+            Box::new(HnswSingleF32Wrapper::new(
+                vecsim::index::HnswSingle::new(params),
+                type_,
+            ))
+        }
+        (VecSimType::VecSimType_FLOAT64, false) => {
+            let params = vecsim::index::HnswParams::new(dim, rust_metric)
+                .with_m(m)
+                .with_ef_construction(ef_construction)
+                .with_ef_runtime(ef_runtime)
+                .with_capacity(capacity);
+            Box::new(HnswSingleF64Wrapper::new(
+                vecsim::index::HnswSingle::new(params),
+                type_,
+            ))
+        }
+        (VecSimType::VecSimType_FLOAT32, true) => {
+            let params = vecsim::index::HnswParams::new(dim, rust_metric)
+                .with_m(m)
+                .with_ef_construction(ef_construction)
+                .with_ef_runtime(ef_runtime)
+                .with_capacity(capacity);
+            Box::new(HnswMultiF32Wrapper::new(
+                vecsim::index::HnswMulti::new(params),
+                type_,
+            ))
+        }
+        (VecSimType::VecSimType_FLOAT64, true) => {
+            let params = vecsim::index::HnswParams::new(dim, rust_metric)
+                .with_m(m)
+                .with_ef_construction(ef_construction)
+                .with_ef_runtime(ef_runtime)
+                .with_capacity(capacity);
+            Box::new(HnswMultiF64Wrapper::new(
+                vecsim::index::HnswMulti::new(params),
+                type_,
+            ))
+        }
+        _ => return ptr::null_mut(),
+    };
+
+    Box::into_raw(Box::new(IndexHandle::new(
+        wrapper,
+        type_,
+        VecSimAlgo::VecSimAlgo_HNSWLIB,
+        metric,
+        dim,
+        multi,
+    ))) as *mut VecSimIndex
+}
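// --- Illustrative aside (a sketch, not part of the diff) ---
// A minimal sketch of the dispatch pattern the two helpers above share: the C
// (type, multi) pair selects a concrete, monomorphized index at runtime, and
// the result is erased behind a single trait object so IndexHandle can stay
// non-generic. All names below are illustrative, not from the crate.
trait AnyIndex {
    fn len(&self) -> usize;
}
struct FlatF32(Vec<f32>);
struct FlatF64(Vec<f64>);
impl AnyIndex for FlatF32 {
    fn len(&self) -> usize {
        self.0.len()
    }
}
impl AnyIndex for FlatF64 {
    fn len(&self) -> usize {
        self.0.len()
    }
}
// Runtime tag -> boxed trait object, mirroring `match (type_, multi)` above.
fn dispatch_sketch(use_f64: bool) -> Box<dyn AnyIndex> {
    if use_f64 {
        Box::new(FlatF64(Vec::new()))
    } else {
        Box::new(FlatF32(Vec::new()))
    }
}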
+// Helper function to create SVS index
+unsafe fn create_svs_index_raw(
+    type_: VecSimType,
+    metric: VecSimMetric,
+    dim: usize,
+    multi: bool,
+    graph_degree: usize,
+    alpha: f32,
+    construction_l: usize,
+    search_l: usize,
+) -> *mut VecSimIndex {
+    let rust_metric = metric.to_rust_metric();
+
+    // SVS only supports single-label for now
+    if multi {
+        return ptr::null_mut();
+    }
+
+    // See the note above: `IndexWrapper` is an assumed trait name.
+    let wrapper: Box<dyn IndexWrapper> = match type_ {
+        VecSimType::VecSimType_FLOAT32 => {
+            let params = vecsim::index::SvsParams::new(dim, rust_metric)
+                .with_graph_degree(graph_degree)
+                .with_alpha(alpha)
+                .with_construction_l(construction_l)
+                .with_search_l(search_l);
+            Box::new(SvsSingleF32Wrapper::new(
+                vecsim::index::SvsSingle::new(params),
+                type_,
+            ))
+        }
+        VecSimType::VecSimType_FLOAT64 => {
+            let params = vecsim::index::SvsParams::new(dim, rust_metric)
+                .with_graph_degree(graph_degree)
+                .with_alpha(alpha)
+                .with_construction_l(construction_l)
+                .with_search_l(search_l);
+            Box::new(SvsSingleF64Wrapper::new(
+                vecsim::index::SvsSingle::new(params),
+                type_,
+            ))
+        }
+        _ => return ptr::null_mut(),
+    };
+
+    Box::into_raw(Box::new(IndexHandle::new(
+        wrapper,
+        type_,
+        VecSimAlgo::VecSimAlgo_SVS,
+        metric,
+        dim,
+        false,
+    ))) as *mut VecSimIndex
+}
+
+// Helper function to create tiered index
+// Note: Tiered only supports f32 currently
+unsafe fn create_tiered_index_raw(
+    type_: VecSimType,
+    metric: VecSimMetric,
+    dim: usize,
+    multi: bool,
+    m: usize,
+    ef_construction: usize,
+    ef_runtime: usize,
+    flat_buffer_limit: usize,
+) -> *mut VecSimIndex {
+    // Tiered only supports f32 currently
+    if type_ != VecSimType::VecSimType_FLOAT32 {
+        return ptr::null_mut();
+    }
+
+    let rust_metric = metric.to_rust_metric();
+
+    // See the note above: `IndexWrapper` is an assumed trait name.
+    let wrapper: Box<dyn IndexWrapper> = if multi {
+        let params = vecsim::index::TieredParams::new(dim, rust_metric)
+            .with_m(m)
+            .with_ef_construction(ef_construction)
+            .with_ef_runtime(ef_runtime)
+            .with_flat_buffer_limit(flat_buffer_limit);
+        Box::new(TieredMultiF32Wrapper::new(
+            vecsim::index::TieredMulti::new(params),
+            type_,
+        ))
+    } else {
+        let params = vecsim::index::TieredParams::new(dim, rust_metric)
+            .with_m(m)
+            .with_ef_construction(ef_construction)
+            .with_ef_runtime(ef_runtime)
+            .with_flat_buffer_limit(flat_buffer_limit);
+        Box::new(TieredSingleF32Wrapper::new(
+            vecsim::index::TieredSingle::new(params),
+            type_,
+        ))
+    };
+
+    Box::into_raw(Box::new(IndexHandle::new(
+        wrapper,
+        type_,
+        VecSimAlgo::VecSimAlgo_TIERED,
+        metric,
+        dim,
+        multi,
+    ))) as *mut VecSimIndex
+}
+
+/// Check if the index is a disk-based index.
+///
+/// # Safety
+/// `index` must be a valid pointer returned by `VecSimIndex_New` or similar.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimIndex_IsDisk(index: *const VecSimIndex) -> bool {
+    if index.is_null() {
+        return false;
+    }
+    let handle = &*(index as *const IndexHandle);
+    handle.wrapper.is_disk()
+}
+
+/// Flush changes to disk for a disk-based index.
+///
+/// This ensures all pending changes are written to the underlying file.
+///
+/// # Safety
+/// `index` must be a valid pointer returned by `VecSimIndex_NewDisk`.
+///
+/// # Returns
+/// true if flush succeeded, false otherwise.
+#[no_mangle] +pub unsafe extern "C" fn VecSimDiskIndex_Flush(index: *const VecSimIndex) -> bool { + if index.is_null() { + return false; + } + let handle = &*(index as *const IndexHandle); + handle.wrapper.disk_flush() +} + +/// Free a vector similarity index. +/// +/// # Safety +/// The `index` pointer must have been returned by `VecSimIndex_New` or be null. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_Free(index: *mut VecSimIndex) { + if !index.is_null() { + drop(Box::from_raw(index as *mut IndexHandle)); + } +} + +// ============================================================================ +// Vector Operations +// ============================================================================ + +/// Add a vector to the index. +/// +/// # Safety +/// - `index` must be a valid pointer returned by `VecSimIndex_New` +/// - `vector` must point to a valid array of the correct type and dimension +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_AddVector( + index: *mut VecSimIndex, + vector: *const c_void, + label: labelType, +) -> i32 { + if index.is_null() || vector.is_null() { + return -1; + } + + let handle = &mut *(index as *mut IndexHandle); + handle.wrapper.add_vector(vector, label) +} + +/// Delete a vector by label. +/// +/// # Safety +/// - `index` must be a valid pointer returned by `VecSimIndex_New` +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_DeleteVector( + index: *mut VecSimIndex, + label: labelType, +) -> i32 { + if index.is_null() { + return 0; + } + + let handle = &mut *(index as *mut IndexHandle); + handle.wrapper.delete_vector(label) +} + +/// Get distance from a stored vector (by label) to a query vector. +/// +/// # Safety +/// - `index` must be a valid pointer returned by `VecSimIndex_New` +/// - `vector` must point to a valid array of the correct type and dimension +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_GetDistanceFrom_Unsafe( + index: *mut VecSimIndex, + label: labelType, + vector: *const c_void, +) -> f64 { + if index.is_null() || vector.is_null() { + return f64::INFINITY; + } + + let handle = &*(index as *const IndexHandle); + handle.wrapper.get_distance_from(label, vector) +} + +// ============================================================================ +// Query Functions +// ============================================================================ + +/// Perform a top-k nearest neighbor query. +/// +/// # Safety +/// - `index` must be a valid pointer returned by `VecSimIndex_New` +/// - `query` must point to a valid array of the correct type and dimension +/// - `params` may be null for default parameters +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_TopKQuery( + index: *mut VecSimIndex, + query: *const c_void, + k: usize, + params: *const VecSimQueryParams, + order: VecSimQueryReply_Order, +) -> *mut VecSimQueryReply { + if index.is_null() || query.is_null() { + return ptr::null_mut(); + } + + let handle = &*(index as *const IndexHandle); + let params_opt = if params.is_null() { None } else { Some(&*params) }; + + let reply_handle = top_k_query(handle, query, k, params_opt, order); + Box::into_raw(Box::new(reply_handle)) as *mut VecSimQueryReply +} + +/// Perform a range query. 
+/// +/// # Safety +/// - `index` must be a valid pointer returned by `VecSimIndex_New` +/// - `query` must point to a valid array of the correct type and dimension +/// - `params` may be null for default parameters +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_RangeQuery( + index: *mut VecSimIndex, + query: *const c_void, + radius: f64, + params: *const VecSimQueryParams, + order: VecSimQueryReply_Order, +) -> *mut VecSimQueryReply { + if index.is_null() || query.is_null() { + return ptr::null_mut(); + } + + let handle = &*(index as *const IndexHandle); + let params_opt = if params.is_null() { None } else { Some(&*params) }; + + let reply_handle = range_query(handle, query, radius, params_opt, order); + Box::into_raw(Box::new(reply_handle)) as *mut VecSimQueryReply +} + +/// Determine if ad-hoc brute-force search is preferred over batched search. +/// +/// This is a heuristic function that helps decide the optimal search strategy +/// for hybrid queries based on the index size and the number of results needed. +/// +/// # Arguments +/// * `index` - The index handle +/// * `subsetSize` - The estimated size of the subset to search +/// * `k` - The number of results requested +/// * `initial` - Whether this is the initial decision (true) or a re-evaluation (false) +/// +/// # Returns +/// true if ad-hoc search is preferred, false if batched search is preferred. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_PreferAdHocSearch( + index: *const VecSimIndex, + subsetSize: usize, + k: usize, + initial: bool, +) -> bool { + if index.is_null() { + return true; // Default to ad-hoc for safety + } + + let handle = &*(index as *const IndexHandle); + let index_size = handle.wrapper.index_size(); + + if index_size == 0 { + return true; + } + + // Heuristic: prefer ad-hoc when subset is small relative to index + // or when k is large relative to subset + let subset_ratio = subsetSize as f64 / index_size as f64; + let k_ratio = k as f64 / subsetSize.max(1) as f64; + + // If subset is less than 10% of index, prefer ad-hoc + if subset_ratio < 0.1 { + return true; + } + + // If we need more than 10% of the subset, prefer ad-hoc + if k_ratio > 0.1 { + return true; + } + + // For initial decision, be more conservative (prefer batches) + // For re-evaluation, be more aggressive (prefer ad-hoc) + if initial { + subset_ratio < 0.3 + } else { + subset_ratio < 0.5 + } +} + +// ============================================================================ +// Query Reply Functions +// ============================================================================ + +/// Get the number of results in a query reply. +/// +/// # Safety +/// `reply` must be a valid pointer returned by a query function. +#[no_mangle] +pub unsafe extern "C" fn VecSimQueryReply_Len(reply: *const VecSimQueryReply) -> usize { + if reply.is_null() { + return 0; + } + + let handle = &*(reply as *const QueryReplyHandle); + handle.len() +} + +/// Get the status code of a query reply. +/// +/// This is used to detect if the query timed out. +/// +/// # Safety +/// `reply` must be a valid pointer returned by `VecSimIndex_TopKQuery` or `VecSimIndex_RangeQuery`. 
+#[no_mangle]
+pub unsafe extern "C" fn VecSimQueryReply_GetCode(
+    reply: *const VecSimQueryReply,
+) -> types::VecSimQueryReply_Code {
+    if reply.is_null() {
+        return types::VecSimQueryReply_Code::VecSim_QueryReply_OK;
+    }
+    let handle = &*(reply as *const query::QueryReplyHandle);
+    handle.code()
+}
+
+/// Free a query reply.
+///
+/// # Safety
+/// `reply` must have been returned by a query function or be null.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimQueryReply_Free(reply: *mut VecSimQueryReply) {
+    if !reply.is_null() {
+        drop(Box::from_raw(reply as *mut QueryReplyHandle));
+    }
+}
+
+/// Get an iterator over query results.
+///
+/// # Safety
+/// `reply` must be a valid pointer returned by a query function.
+/// The iterator is only valid while the reply exists.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimQueryReply_GetIterator(
+    reply: *mut VecSimQueryReply,
+) -> *mut VecSimQueryReply_Iterator {
+    if reply.is_null() {
+        return ptr::null_mut();
+    }
+
+    let handle = &*(reply as *const QueryReplyHandle);
+    let iter = QueryReplyIteratorHandle::new(&handle.reply.results);
+    Box::into_raw(Box::new(iter)) as *mut VecSimQueryReply_Iterator
+}
+
+/// Check if the iterator has more results.
+///
+/// # Safety
+/// `iter` must be a valid pointer returned by `VecSimQueryReply_GetIterator`.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimQueryReply_IteratorHasNext(
+    iter: *const VecSimQueryReply_Iterator,
+) -> bool {
+    if iter.is_null() {
+        return false;
+    }
+
+    let handle = &*(iter as *const QueryReplyIteratorHandle);
+    handle.has_next()
+}
+
+/// Get the next result from the iterator.
+///
+/// # Safety
+/// `iter` must be a valid pointer returned by `VecSimQueryReply_GetIterator`.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimQueryReply_IteratorNext(
+    iter: *mut VecSimQueryReply_Iterator,
+) -> *const VecSimQueryResult {
+    if iter.is_null() {
+        return ptr::null();
+    }
+
+    let handle = &mut *(iter as *mut QueryReplyIteratorHandle);
+    match handle.next() {
+        Some(result) => result as *const QueryResultInternal as *const VecSimQueryResult,
+        None => ptr::null(),
+    }
+}
+
+/// Reset the iterator to the beginning.
+///
+/// # Safety
+/// `iter` must be a valid pointer returned by `VecSimQueryReply_GetIterator`.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimQueryReply_IteratorReset(iter: *mut VecSimQueryReply_Iterator) {
+    if !iter.is_null() {
+        let handle = &mut *(iter as *mut QueryReplyIteratorHandle);
+        handle.reset();
+    }
+}
+
+/// Free an iterator.
+///
+/// # Safety
+/// `iter` must have been returned by `VecSimQueryReply_GetIterator` or be null.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimQueryReply_IteratorFree(iter: *mut VecSimQueryReply_Iterator) {
+    if !iter.is_null() {
+        // The iterator borrows from the reply, so it must be freed while the
+        // reply is still alive. Reconstructing the box and dropping it here
+        // frees only the iterator's own allocation; the results it points into
+        // remain owned by the reply and are released by VecSimQueryReply_Free.
+        // A proper implementation would use reference counting or a different
+        // ownership scheme instead of an erased 'static lifetime.
+        drop(Box::from_raw(iter as *mut QueryReplyIteratorHandle<'static>));
+    }
+}
+
+// ============================================================================
+// Query Result Functions
+// ============================================================================
+
+/// Get the label (ID) from a query result.
+///
+/// # Safety
+/// `result` must be a valid pointer from the iterator.
+#[no_mangle] +pub unsafe extern "C" fn VecSimQueryResult_GetId(result: *const VecSimQueryResult) -> labelType { + if result.is_null() { + return 0; + } + + let internal = &*(result as *const QueryResultInternal); + internal.id +} + +/// Get the score (distance) from a query result. +/// +/// # Safety +/// `result` must be a valid pointer from the iterator. +#[no_mangle] +pub unsafe extern "C" fn VecSimQueryResult_GetScore(result: *const VecSimQueryResult) -> f64 { + if result.is_null() { + return f64::INFINITY; + } + + let internal = &*(result as *const QueryResultInternal); + internal.score +} + +// ============================================================================ +// Batch Iterator Functions +// ============================================================================ + +/// Create a batch iterator for incremental query processing. +/// +/// # Safety +/// - `index` must be a valid pointer returned by `VecSimIndex_New` +/// - `query` must point to a valid array of the correct type and dimension +/// - `params` may be null for default parameters +#[no_mangle] +pub unsafe extern "C" fn VecSimBatchIterator_New( + index: *mut VecSimIndex, + query: *const c_void, + params: *const VecSimQueryParams, +) -> *mut VecSimBatchIterator { + if index.is_null() || query.is_null() { + return ptr::null_mut(); + } + + let handle = &*(index as *const IndexHandle); + let params_opt = if params.is_null() { None } else { Some(&*params) }; + + match create_batch_iterator(handle, query, params_opt) { + Some(iter_handle) => Box::into_raw(Box::new(iter_handle)) as *mut VecSimBatchIterator, + None => ptr::null_mut(), + } +} + +/// Get the next batch of results. +/// +/// # Safety +/// `iter` must be a valid pointer returned by `VecSimBatchIterator_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimBatchIterator_Next( + iter: *mut VecSimBatchIterator, + n: usize, + order: VecSimQueryReply_Order, +) -> *mut VecSimQueryReply { + if iter.is_null() { + return ptr::null_mut(); + } + + let handle = &mut *(iter as *mut BatchIteratorHandle); + let reply = handle.next(n, order); + Box::into_raw(Box::new(reply)) as *mut VecSimQueryReply +} + +/// Check if the batch iterator has more results. +/// +/// # Safety +/// `iter` must be a valid pointer returned by `VecSimBatchIterator_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimBatchIterator_HasNext(iter: *const VecSimBatchIterator) -> bool { + if iter.is_null() { + return false; + } + + let handle = &*(iter as *const BatchIteratorHandle); + handle.has_next() +} + +/// Reset the batch iterator. +/// +/// # Safety +/// `iter` must be a valid pointer returned by `VecSimBatchIterator_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimBatchIterator_Reset(iter: *mut VecSimBatchIterator) { + if !iter.is_null() { + let handle = &mut *(iter as *mut BatchIteratorHandle); + handle.reset(); + } +} + +/// Free a batch iterator. +/// +/// # Safety +/// `iter` must have been returned by `VecSimBatchIterator_New` or be null. +#[no_mangle] +pub unsafe extern "C" fn VecSimBatchIterator_Free(iter: *mut VecSimBatchIterator) { + if !iter.is_null() { + drop(Box::from_raw(iter as *mut BatchIteratorHandle)); + } +} + +// ============================================================================ +// Index Property Functions +// ============================================================================ + +/// Get the current number of vectors in the index. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. 
+#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_IndexSize(index: *const VecSimIndex) -> usize { + if index.is_null() { + return 0; + } + + let handle = &*(index as *const IndexHandle); + handle.wrapper.index_size() +} + +/// Get the data type of the index. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_GetType(index: *const VecSimIndex) -> VecSimType { + if index.is_null() { + return VecSimType::VecSimType_FLOAT32; + } + + let handle = &*(index as *const IndexHandle); + handle.data_type +} + +/// Get the distance metric of the index. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_GetMetric(index: *const VecSimIndex) -> VecSimMetric { + if index.is_null() { + return VecSimMetric::VecSimMetric_L2; + } + + let handle = &*(index as *const IndexHandle); + handle.metric +} + +/// Get the vector dimension of the index. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_GetDim(index: *const VecSimIndex) -> usize { + if index.is_null() { + return 0; + } + + let handle = &*(index as *const IndexHandle); + handle.dim +} + +/// Check if the index is a multi-value index. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_IsMulti(index: *const VecSimIndex) -> bool { + if index.is_null() { + return false; + } + + let handle = &*(index as *const IndexHandle); + handle.is_multi +} + +/// Check if a label exists in the index. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_ContainsLabel( + index: *const VecSimIndex, + label: labelType, +) -> bool { + if index.is_null() { + return false; + } + + let handle = &*(index as *const IndexHandle); + handle.wrapper.contains(label) +} + +/// Get the count of vectors with the given label. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_LabelCount( + index: *const VecSimIndex, + label: labelType, +) -> usize { + if index.is_null() { + return 0; + } + + let handle = &*(index as *const IndexHandle); + handle.wrapper.label_count(label) +} + +/// Get detailed index information. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_Info(index: *const VecSimIndex) -> VecSimIndexInfo { + if index.is_null() { + return VecSimIndexInfo::default(); + } + + let handle = &*(index as *const IndexHandle); + get_index_info(handle) +} + +/// Get basic immutable index information. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. 
+#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_BasicInfo( + index: *const VecSimIndex, +) -> info::VecSimIndexBasicInfo { + if index.is_null() { + return info::VecSimIndexBasicInfo { + algo: VecSimAlgo::VecSimAlgo_BF, + metric: VecSimMetric::VecSimMetric_L2, + type_: VecSimType::VecSimType_FLOAT32, + isMulti: false, + isTiered: false, + isDisk: false, + blockSize: 0, + dim: 0, + }; + } + + let handle = &*(index as *const IndexHandle); + info::VecSimIndexBasicInfo { + algo: handle.algo, + metric: handle.metric, + type_: handle.data_type, + isMulti: handle.is_multi, + isTiered: handle.wrapper.is_tiered(), + isDisk: handle.wrapper.is_disk(), + blockSize: 1024, // Default block size + dim: handle.dim, + } +} + +/// Get index statistics information. +/// +/// This is a thin and efficient info call with no locks or calculations. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_StatsInfo( + index: *const VecSimIndex, +) -> info::VecSimIndexStatsInfo { + if index.is_null() { + return info::VecSimIndexStatsInfo { + memory: 0, + numberOfMarkedDeleted: 0, + }; + } + + let handle = &*(index as *const IndexHandle); + info::VecSimIndexStatsInfo { + memory: handle.wrapper.memory_usage(), + numberOfMarkedDeleted: 0, // TODO: implement marked deleted tracking + } +} + +/// Get detailed debug information for an index. +/// +/// This should only be used for debug/testing purposes. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_DebugInfo( + index: *const VecSimIndex, +) -> info::VecSimIndexDebugInfo { + use types::VecSearchMode; + + if index.is_null() { + return info::VecSimIndexDebugInfo { + commonInfo: info::CommonInfo { + basicInfo: info::VecSimIndexBasicInfo { + algo: VecSimAlgo::VecSimAlgo_BF, + metric: VecSimMetric::VecSimMetric_L2, + type_: VecSimType::VecSimType_FLOAT32, + isMulti: false, + isTiered: false, + isDisk: false, + blockSize: 0, + dim: 0, + }, + indexSize: 0, + indexLabelCount: 0, + memory: 0, + lastMode: VecSearchMode::EMPTY_MODE, + }, + hnswInfo: info::hnswInfoStruct { + M: 0, + efConstruction: 0, + efRuntime: 0, + epsilon: 0.0, + max_level: 0, + entrypoint: 0, + visitedNodesPoolSize: 0, + numberOfMarkedDeletedNodes: 0, + }, + bfInfo: info::bfInfoStruct { dummy: 0 }, + }; + } + + let handle = &*(index as *const IndexHandle); + let basic_info = VecSimIndex_BasicInfo(index); + + info::VecSimIndexDebugInfo { + commonInfo: info::CommonInfo { + basicInfo: basic_info, + indexSize: handle.wrapper.index_size(), + indexLabelCount: handle.wrapper.index_size(), // Approximate + memory: handle.wrapper.memory_usage() as u64, + lastMode: VecSearchMode::EMPTY_MODE, + }, + hnswInfo: info::hnswInfoStruct { + M: 16, // Default, would need to query from index + efConstruction: 200, + efRuntime: 10, + epsilon: 0.0, + max_level: 0, + entrypoint: 0, + visitedNodesPoolSize: 0, + numberOfMarkedDeletedNodes: 0, + }, + bfInfo: info::bfInfoStruct { dummy: 0 }, + } +} + +// ============================================================================ +// Debug Info Iterator Functions +// ============================================================================ + +/// Opaque type for C API - re-export from info module. +pub type VecSimDebugInfoIterator = info::VecSimDebugInfoIterator; + +/// Create a debug info iterator for an index. 
+/// +/// The iterator provides a way to traverse all debug information fields +/// for an index, including nested information for tiered indices. +/// +/// # Safety +/// `index` must be a valid pointer returned by `VecSimIndex_New`. +/// The returned iterator must be freed with `VecSimDebugInfoIterator_Free`. +#[no_mangle] +pub unsafe extern "C" fn VecSimIndex_DebugInfoIterator( + index: *const VecSimIndex, +) -> *mut VecSimDebugInfoIterator { + if index.is_null() { + return std::ptr::null_mut(); + } + + let handle = &*(index as *const IndexHandle); + let basic_info = VecSimIndex_BasicInfo(index); + + // Create iterator based on index type + let iter = match basic_info.algo { + VecSimAlgo::VecSimAlgo_BF => create_bf_debug_iterator(handle, &basic_info), + VecSimAlgo::VecSimAlgo_HNSWLIB => create_hnsw_debug_iterator(handle, &basic_info), + VecSimAlgo::VecSimAlgo_SVS => create_svs_debug_iterator(handle, &basic_info), + VecSimAlgo::VecSimAlgo_TIERED => create_tiered_debug_iterator(handle, &basic_info), + }; + + Box::into_raw(Box::new(iter)) +} + +/// Helper to add common fields to an iterator. +fn add_common_fields( + iter: &mut info::VecSimDebugInfoIterator, + handle: &IndexHandle, + basic_info: &info::VecSimIndexBasicInfo, +) { + iter.add_string_field("TYPE", info::type_to_string(basic_info.type_)); + iter.add_uint64_field("DIMENSION", basic_info.dim as u64); + iter.add_string_field("METRIC", info::metric_to_string(basic_info.metric)); + iter.add_string_field( + "IS_MULTI_VALUE", + if basic_info.isMulti { "true" } else { "false" }, + ); + iter.add_string_field( + "IS_DISK", + if basic_info.isDisk { "true" } else { "false" }, + ); + iter.add_uint64_field("INDEX_SIZE", handle.wrapper.index_size() as u64); + iter.add_uint64_field("INDEX_LABEL_COUNT", handle.wrapper.index_size() as u64); + iter.add_uint64_field("MEMORY", handle.wrapper.memory_usage() as u64); + iter.add_string_field("LAST_SEARCH_MODE", "EMPTY_MODE"); +} + +/// Create a debug iterator for BruteForce index. +fn create_bf_debug_iterator( + handle: &IndexHandle, + basic_info: &info::VecSimIndexBasicInfo, +) -> info::VecSimDebugInfoIterator { + let mut iter = info::VecSimDebugInfoIterator::new(10); + + iter.add_string_field("ALGORITHM", "FLAT"); + add_common_fields(&mut iter, handle, basic_info); + iter.add_uint64_field("BLOCK_SIZE", basic_info.blockSize as u64); + + iter +} + +/// Create a debug iterator for HNSW index. +fn create_hnsw_debug_iterator( + handle: &IndexHandle, + basic_info: &info::VecSimIndexBasicInfo, +) -> info::VecSimDebugInfoIterator { + let mut iter = info::VecSimDebugInfoIterator::new(17); + + iter.add_string_field("ALGORITHM", "HNSW"); + add_common_fields(&mut iter, handle, basic_info); + iter.add_uint64_field("BLOCK_SIZE", basic_info.blockSize as u64); + + // HNSW-specific fields (defaults, would need to query actual values) + iter.add_uint64_field("M", 16); + iter.add_uint64_field("EF_CONSTRUCTION", 200); + iter.add_uint64_field("EF_RUNTIME", 10); + iter.add_float64_field("EPSILON", 0.01); + iter.add_uint64_field("MAX_LEVEL", 0); + iter.add_uint64_field("ENTRYPOINT", 0); + iter.add_uint64_field("NUMBER_OF_MARKED_DELETED", 0); + + iter +} + +/// Create a debug iterator for SVS index. 
+fn create_svs_debug_iterator( + handle: &IndexHandle, + basic_info: &info::VecSimIndexBasicInfo, +) -> info::VecSimDebugInfoIterator { + let mut iter = info::VecSimDebugInfoIterator::new(23); + + iter.add_string_field("ALGORITHM", "SVS"); + add_common_fields(&mut iter, handle, basic_info); + iter.add_uint64_field("BLOCK_SIZE", basic_info.blockSize as u64); + + // SVS-specific fields (defaults) + iter.add_string_field("QUANT_BITS", "NONE"); + iter.add_float64_field("ALPHA", 1.2); + iter.add_uint64_field("GRAPH_MAX_DEGREE", 32); + iter.add_uint64_field("CONSTRUCTION_WINDOW_SIZE", 200); + iter.add_uint64_field("MAX_CANDIDATE_POOL_SIZE", 0); + iter.add_uint64_field("PRUNE_TO", 0); + iter.add_string_field("USE_SEARCH_HISTORY", "AUTO"); + iter.add_uint64_field("NUM_THREADS", 1); + iter.add_uint64_field("LAST_RESERVED_NUM_THREADS", 0); + iter.add_uint64_field("NUMBER_OF_MARKED_DELETED", 0); + iter.add_uint64_field("SEARCH_WINDOW_SIZE", 10); + iter.add_uint64_field("SEARCH_BUFFER_CAPACITY", 0); + iter.add_uint64_field("LEANVEC_DIMENSION", 0); + iter.add_float64_field("EPSILON", 0.01); + + iter +} + +/// Create a debug iterator for tiered index. +fn create_tiered_debug_iterator( + handle: &IndexHandle, + basic_info: &info::VecSimIndexBasicInfo, +) -> info::VecSimDebugInfoIterator { + let mut iter = info::VecSimDebugInfoIterator::new(15); + + iter.add_string_field("ALGORITHM", "TIERED"); + add_common_fields(&mut iter, handle, basic_info); + + // Tiered-specific fields + iter.add_uint64_field("MANAGEMENT_LAYER_MEMORY", 0); + iter.add_string_field("BACKGROUND_INDEXING", "false"); + iter.add_uint64_field("TIERED_BUFFER_LIMIT", 0); + + // Create frontend (flat) iterator + let frontend_iter = create_bf_debug_iterator(handle, basic_info); + iter.add_iterator_field("FRONTEND_INDEX", frontend_iter); + + // Create backend (hnsw) iterator + let backend_iter = create_hnsw_debug_iterator(handle, basic_info); + iter.add_iterator_field("BACKEND_INDEX", backend_iter); + + iter +} + +/// Returns the number of fields in the info iterator. +/// +/// # Safety +/// `info_iterator` must be a valid pointer returned by `VecSimIndex_DebugInfoIterator`. +#[no_mangle] +pub unsafe extern "C" fn VecSimDebugInfoIterator_NumberOfFields( + info_iterator: *mut VecSimDebugInfoIterator, +) -> usize { + if info_iterator.is_null() { + return 0; + } + (*info_iterator).number_of_fields() +} + +/// Returns if the fields iterator has more fields. +/// +/// # Safety +/// `info_iterator` must be a valid pointer returned by `VecSimIndex_DebugInfoIterator`. +#[no_mangle] +pub unsafe extern "C" fn VecSimDebugInfoIterator_HasNextField( + info_iterator: *mut VecSimDebugInfoIterator, +) -> bool { + if info_iterator.is_null() { + return false; + } + (*info_iterator).has_next() +} + +/// Returns a pointer to the next info field. +/// +/// The returned pointer is valid until the next call to this function +/// or until the iterator is freed. +/// +/// # Safety +/// `info_iterator` must be a valid pointer returned by `VecSimIndex_DebugInfoIterator`. +#[no_mangle] +pub unsafe extern "C" fn VecSimDebugInfoIterator_NextField( + info_iterator: *mut VecSimDebugInfoIterator, +) -> *mut info::VecSim_InfoField { + if info_iterator.is_null() { + return std::ptr::null_mut(); + } + (*info_iterator).next() +} + +/// Free an info iterator. +/// +/// This also frees all nested iterators. +/// +/// # Safety +/// `info_iterator` must be a valid pointer returned by `VecSimIndex_DebugInfoIterator`, +/// or null. 
+#[no_mangle]
+pub unsafe extern "C" fn VecSimDebugInfoIterator_Free(
+    info_iterator: *mut VecSimDebugInfoIterator,
+) {
+    if !info_iterator.is_null() {
+        drop(Box::from_raw(info_iterator));
+    }
+}
+
+/// Dump the neighbors of an element in an HNSW index.
+///
+/// Returns an array with one entry per level plus a NULL terminator, where each
+/// entry is an array itself. The internal array at position `l` (where
+/// `0 <= l <= topLevel`) corresponds to the neighbors of the element in the
+/// graph at level `l`. It contains `n + 1` entries, where `n` is the number of
+/// neighbors at level `l`: the first entry holds the number `n`, and the next
+/// `n` entries are the labels of the element's neighbors at that level. The
+/// last entry in the external array is NULL (indicating its length).
+///
+/// # Safety
+/// - `index` must be a valid pointer returned by `VecSimIndex_New`
+/// - `neighbors_data` must be a valid pointer to a pointer that will receive the result
+#[no_mangle]
+pub unsafe extern "C" fn VecSimDebug_GetElementNeighborsInHNSWGraph(
+    index: *mut VecSimIndex,
+    label: usize,
+    neighbors_data: *mut *mut *mut c_int,
+) -> VecSimDebugCommandCode {
+    if index.is_null() || neighbors_data.is_null() {
+        return VecSimDebugCommandCode::VecSimDebugCommandCode_BadIndex;
+    }
+
+    let handle = &*(index as *const IndexHandle);
+    let basic_info = VecSimIndex_BasicInfo(index);
+
+    // Only HNSW and Tiered (with HNSW backend) are supported
+    match basic_info.algo {
+        VecSimAlgo::VecSimAlgo_HNSWLIB | VecSimAlgo::VecSimAlgo_TIERED => {}
+        _ => return VecSimDebugCommandCode::VecSimDebugCommandCode_BadIndex,
+    }
+
+    // Get the element neighbors using the trait method
+    let hnsw_result = handle.wrapper.get_element_neighbors(label as u64);
+
+    match hnsw_result {
+        None => VecSimDebugCommandCode::VecSimDebugCommandCode_LabelNotExists,
+        Some(neighbors_by_level) => {
+            // Allocate the outer array: one entry per level + NULL terminator
+            let num_levels = neighbors_by_level.len();
+            let outer_size = num_levels + 1; // +1 for NULL terminator
+            let outer_array = libc::malloc(outer_size * std::mem::size_of::<*mut c_int>())
+                as *mut *mut c_int;
+
+            if outer_array.is_null() {
+                return VecSimDebugCommandCode::VecSimDebugCommandCode_BadIndex;
+            }
+
+            // Fill in each level's neighbors
+            for (level, level_neighbors) in neighbors_by_level.iter().enumerate() {
+                let num_neighbors = level_neighbors.len();
+                // Each inner array has: count + neighbor labels
+                let inner_size = num_neighbors + 1;
+                let inner_array =
+                    libc::malloc(inner_size * std::mem::size_of::<c_int>()) as *mut c_int;
+
+                if inner_array.is_null() {
+                    // Clean up already allocated arrays
+                    for i in 0..level {
+                        libc::free(*outer_array.add(i) as *mut libc::c_void);
+                    }
+                    libc::free(outer_array as *mut libc::c_void);
+                    return VecSimDebugCommandCode::VecSimDebugCommandCode_BadIndex;
+                }
+
+                // First entry is the count
+                *inner_array = num_neighbors as c_int;
+
+                // Remaining entries are the neighbor labels
+                for (i, &neighbor_label) in level_neighbors.iter().enumerate() {
+                    *inner_array.add(i + 1) = neighbor_label as c_int;
+                }
+
+                *outer_array.add(level) = inner_array;
+            }
+
+            // NULL terminator
+            *outer_array.add(num_levels) = std::ptr::null_mut();
+
+            *neighbors_data = outer_array;
+            VecSimDebugCommandCode::VecSimDebugCommandCode_OK
+        }
+    }
+}
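// --- Illustrative aside (a sketch, not part of the diff) ---
// Walking the structure returned above from Rust: the outer array is
// NULL-terminated, and each inner array is laid out as [count, label0, ...].
// `index` is assumed to be a valid HNSW index containing `label`.
unsafe fn print_neighbors_sketch(index: *mut VecSimIndex, label: usize) {
    let mut data: *mut *mut c_int = std::ptr::null_mut();
    let code = VecSimDebug_GetElementNeighborsInHNSWGraph(index, label, &mut data);
    if !matches!(code, VecSimDebugCommandCode::VecSimDebugCommandCode_OK) {
        return;
    }
    let mut level = 0;
    // Iterate over inner arrays until the NULL terminator.
    while !(*data.add(level)).is_null() {
        let inner = *data.add(level);
        let n = *inner as usize; // first entry is the neighbor count
        let labels: Vec<c_int> = (1..=n).map(|i| *inner.add(i)).collect();
        println!("level {level}: {labels:?}");
        level += 1;
    }
    // The paired release call (declared just below) frees the inner arrays
    // and then the outer array.
    VecSimDebug_ReleaseElementNeighborsInHNSWGraph(data);
}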
+/// Release the neighbors data allocated by VecSimDebug_GetElementNeighborsInHNSWGraph.
+///
+/// # Safety
+/// `neighbors_data` must be a valid pointer returned by `VecSimDebug_GetElementNeighborsInHNSWGraph`,
+/// or null.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimDebug_ReleaseElementNeighborsInHNSWGraph(
+    neighbors_data: *mut *mut c_int,
+) {
+    if neighbors_data.is_null() {
+        return;
+    }
+
+    // Free each inner array until we hit NULL
+    let mut level = 0;
+    loop {
+        let inner_array = *neighbors_data.add(level);
+        if inner_array.is_null() {
+            break;
+        }
+        libc::free(inner_array as *mut libc::c_void);
+        level += 1;
+    }
+
+    // Free the outer array
+    libc::free(neighbors_data as *mut libc::c_void);
+}
+
+// ============================================================================
+// Serialization Functions
+// ============================================================================
+
+/// Save an index to a file.
+///
+/// Returns true on success, false on failure.
+///
+/// # Safety
+/// - `index` must be a valid pointer returned by `VecSimIndex_New`
+/// - `path` must be a valid null-terminated C string
+#[no_mangle]
+pub unsafe extern "C" fn VecSimIndex_SaveIndex(
+    index: *const VecSimIndex,
+    path: *const c_char,
+) -> bool {
+    if index.is_null() || path.is_null() {
+        return false;
+    }
+
+    let handle = &*(index as *const IndexHandle);
+    let path_str = match std::ffi::CStr::from_ptr(path).to_str() {
+        Ok(s) => s,
+        Err(_) => return false,
+    };
+
+    let path = std::path::Path::new(path_str);
+    handle.wrapper.save_to_file(path)
+}
+
+/// Load an index from a file.
+///
+/// Returns a pointer to the loaded index, or null on failure.
+/// The caller is responsible for freeing the index with VecSimIndex_Free.
+///
+/// # Safety
+/// - `path` must be a valid null-terminated C string
+/// - `params` may be null (will use parameters from the file)
+#[no_mangle]
+pub unsafe extern "C" fn VecSimIndex_LoadIndex(
+    path: *const c_char,
+    _params: *const VecSimParams_C,
+) -> *mut VecSimIndex {
+    if path.is_null() {
+        return ptr::null_mut();
+    }
+
+    let path_str = match std::ffi::CStr::from_ptr(path).to_str() {
+        Ok(s) => s,
+        Err(_) => return ptr::null_mut(),
+    };
+
+    let file_path = std::path::Path::new(path_str);
+
+    // Try to load the index by reading the header first to determine the type
+    match load_index_from_file(file_path) {
+        Some(handle) => Box::into_raw(handle) as *mut VecSimIndex,
+        None => ptr::null_mut(),
+    }
+}
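// --- Illustrative aside (a sketch, not part of the diff) ---
// A save/load round trip through the two calls above. The file path is an
// assumption for the example; any writable location works.
unsafe fn save_load_roundtrip_sketch(index: *const VecSimIndex) -> bool {
    let path = std::ffi::CString::new("/tmp/vecsim_index.bin").unwrap();
    if !VecSimIndex_SaveIndex(index, path.as_ptr()) {
        return false;
    }
    let loaded = VecSimIndex_LoadIndex(path.as_ptr(), std::ptr::null());
    if loaded.is_null() {
        return false;
    }
    // A loaded index should report the same size as the one we saved.
    let ok = VecSimIndex_IndexSize(loaded) == VecSimIndex_IndexSize(index);
    VecSimIndex_Free(loaded);
    ok
}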
+/// Internal function to load an index from a file.
+/// Reads the header to determine the index type and loads accordingly.
+fn load_index_from_file(path: &std::path::Path) -> Option<Box<IndexHandle>> {
+    use std::fs::File;
+    use std::io::BufReader;
+    use vecsim::serialization::{IndexHeader, IndexTypeId, DataTypeId};
+
+    // Open file and read header
+    let file = File::open(path).ok()?;
+    let mut reader = BufReader::new(file);
+    let header = IndexHeader::read(&mut reader).ok()?;
+
+    // Based on index type and data type, load the appropriate index
+    match (header.index_type, header.data_type) {
+        // BruteForce Single
+        (IndexTypeId::BruteForceSingle, DataTypeId::F32) => {
+            use vecsim::index::BruteForceSingle;
+            let index = BruteForceSingle::<f32>::load_from_file(path).ok()?;
+            let metric = convert_metric_to_c(header.metric);
+            Some(Box::new(IndexHandle::new(
+                Box::new(index::BruteForceSingleF32Wrapper::new(
+                    index,
+                    VecSimType::VecSimType_FLOAT32,
+                )),
+                VecSimType::VecSimType_FLOAT32,
+                VecSimAlgo::VecSimAlgo_BF,
+                metric,
+                header.dimension,
+                false,
+            )))
+        }
+        // BruteForce Multi
+        (IndexTypeId::BruteForceMulti, DataTypeId::F32) => {
+            use vecsim::index::BruteForceMulti;
+            let index = BruteForceMulti::<f32>::load_from_file(path).ok()?;
+            let metric = convert_metric_to_c(header.metric);
+            Some(Box::new(IndexHandle::new(
+                Box::new(index::BruteForceMultiF32Wrapper::new(
+                    index,
+                    VecSimType::VecSimType_FLOAT32,
+                )),
+                VecSimType::VecSimType_FLOAT32,
+                VecSimAlgo::VecSimAlgo_BF,
+                metric,
+                header.dimension,
+                true,
+            )))
+        }
+        // HNSW Single
+        (IndexTypeId::HnswSingle, DataTypeId::F32) => {
+            use vecsim::index::HnswSingle;
+            let index = HnswSingle::<f32>::load_from_file(path).ok()?;
+            let metric = convert_metric_to_c(header.metric);
+            Some(Box::new(IndexHandle::new(
+                Box::new(index::HnswSingleF32Wrapper::new(
+                    index,
+                    VecSimType::VecSimType_FLOAT32,
+                )),
+                VecSimType::VecSimType_FLOAT32,
+                VecSimAlgo::VecSimAlgo_HNSWLIB,
+                metric,
+                header.dimension,
+                false,
+            )))
+        }
+        // HNSW Multi
+        (IndexTypeId::HnswMulti, DataTypeId::F32) => {
+            use vecsim::index::HnswMulti;
+            let index = HnswMulti::<f32>::load_from_file(path).ok()?;
+            let metric = convert_metric_to_c(header.metric);
+            Some(Box::new(IndexHandle::new(
+                Box::new(index::HnswMultiF32Wrapper::new(
+                    index,
+                    VecSimType::VecSimType_FLOAT32,
+                )),
+                VecSimType::VecSimType_FLOAT32,
+                VecSimAlgo::VecSimAlgo_HNSWLIB,
+                metric,
+                header.dimension,
+                true,
+            )))
+        }
+        // SVS Single
+        (IndexTypeId::SvsSingle, DataTypeId::F32) => {
+            use vecsim::index::SvsSingle;
+            let index = SvsSingle::<f32>::load_from_file(path).ok()?;
+            let metric = convert_metric_to_c(header.metric);
+            Some(Box::new(IndexHandle::new(
+                Box::new(index::SvsSingleF32Wrapper::new(
+                    index,
+                    VecSimType::VecSimType_FLOAT32,
+                )),
+                VecSimType::VecSimType_FLOAT32,
+                VecSimAlgo::VecSimAlgo_SVS,
+                metric,
+                header.dimension,
+                false,
+            )))
+        }
+        // Unsupported combinations
+        _ => None,
+    }
+}
+
+/// Convert Rust Metric to C VecSimMetric
+fn convert_metric_to_c(metric: vecsim::distance::Metric) -> VecSimMetric {
+    match metric {
+        vecsim::distance::Metric::L2 => VecSimMetric::VecSimMetric_L2,
+        vecsim::distance::Metric::InnerProduct => VecSimMetric::VecSimMetric_IP,
+        vecsim::distance::Metric::Cosine => VecSimMetric::VecSimMetric_Cosine,
+    }
+}
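// --- Illustrative aside (a sketch, not part of the diff) ---
// convert_metric_to_c above should be the inverse of VecSimMetric::to_rust_metric
// used by the create_* helpers, assuming the three variants map one-to-one and
// the C enum derives Debug and PartialEq.
#[cfg(test)]
mod metric_roundtrip_sketch {
    use super::*;

    #[test]
    fn c_to_rust_and_back() {
        for metric in [
            VecSimMetric::VecSimMetric_L2,
            VecSimMetric::VecSimMetric_IP,
            VecSimMetric::VecSimMetric_Cosine,
        ] {
            assert_eq!(convert_metric_to_c(metric.to_rust_metric()), metric);
        }
    }
}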
+// ============================================================================
+// Vector Utility Functions
+// ============================================================================
+
+/// Normalize a vector in-place.
+///
+/// This normalizes the vector to unit length (L2 norm = 1).
+/// This is useful for cosine similarity where vectors should be normalized.
+///
+/// # Safety
+/// - `blob` must point to a valid array of the correct type and dimension
+/// - The array must be writable
+#[no_mangle]
+pub unsafe extern "C" fn VecSim_Normalize(
+    blob: *mut c_void,
+    dim: usize,
+    type_: VecSimType,
+) {
+    if blob.is_null() || dim == 0 {
+        return;
+    }
+
+    match type_ {
+        VecSimType::VecSimType_FLOAT32 => {
+            let data = std::slice::from_raw_parts_mut(blob as *mut f32, dim);
+            let norm: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
+            if norm > 0.0 {
+                for x in data.iter_mut() {
+                    *x /= norm;
+                }
+            }
+        }
+        VecSimType::VecSimType_FLOAT64 => {
+            let data = std::slice::from_raw_parts_mut(blob as *mut f64, dim);
+            let norm: f64 = data.iter().map(|x| x * x).sum::<f64>().sqrt();
+            if norm > 0.0 {
+                for x in data.iter_mut() {
+                    *x /= norm;
+                }
+            }
+        }
+        // Other types don't support normalization
+        _ => {}
+    }
+}
+
+// ============================================================================
+// Memory Estimation Functions
+// ============================================================================
+
+/// Estimate initial memory size for a BruteForce index.
+#[no_mangle]
+pub extern "C" fn VecSimIndex_EstimateBruteForceInitialSize(
+    dim: usize,
+    initial_capacity: usize,
+) -> usize {
+    vecsim::index::estimate_brute_force_initial_size(dim, initial_capacity)
+}
+
+/// Estimate memory size per element for a BruteForce index.
+#[no_mangle]
+pub extern "C" fn VecSimIndex_EstimateBruteForceElementSize(dim: usize) -> usize {
+    vecsim::index::estimate_brute_force_element_size(dim)
+}
+
+/// Estimate initial memory size for an HNSW index.
+#[no_mangle]
+pub extern "C" fn VecSimIndex_EstimateHNSWInitialSize(
+    dim: usize,
+    initial_capacity: usize,
+    m: usize,
+) -> usize {
+    vecsim::index::estimate_hnsw_initial_size(dim, initial_capacity, m)
+}
+
+/// Estimate memory size per element for an HNSW index.
+#[no_mangle]
+pub extern "C" fn VecSimIndex_EstimateHNSWElementSize(dim: usize, m: usize) -> usize {
+    vecsim::index::estimate_hnsw_element_size(dim, m)
+}
+
+/// Estimate initial memory size for an index based on parameters.
+///
+/// # Safety
+/// `params` must be a valid pointer to a VecSimParams_C struct.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimIndex_EstimateInitialSize(
+    params: *const VecSimParams_C,
+) -> usize {
+    if params.is_null() {
+        return 0;
+    }
+
+    let params = &*params;
+
+    match params.algo {
+        VecSimAlgo::VecSimAlgo_BF => {
+            let bf = params.algoParams.bfParams;
+            vecsim::index::estimate_brute_force_initial_size(bf.dim, bf.initialCapacity)
+        }
+        VecSimAlgo::VecSimAlgo_HNSWLIB => {
+            let hnsw = params.algoParams.hnswParams;
+            vecsim::index::estimate_hnsw_initial_size(hnsw.dim, hnsw.initialCapacity, hnsw.M)
+        }
+        VecSimAlgo::VecSimAlgo_TIERED => {
+            // Tiered uses HNSW params from tieredParams.primaryIndexParams;
+            // for now, read the HNSW params from the union directly.
+            let hnsw = params.algoParams.hnswParams;
+            let bf_size =
+                vecsim::index::estimate_brute_force_initial_size(hnsw.dim, hnsw.initialCapacity);
+            let hnsw_size =
+                vecsim::index::estimate_hnsw_initial_size(hnsw.dim, hnsw.initialCapacity, hnsw.M);
+            bf_size + hnsw_size
+        }
+        VecSimAlgo::VecSimAlgo_SVS => {
+            let svs = params.algoParams.svsParams;
+            // SVS doesn't have initialCapacity, use a default
+            vecsim::index::estimate_hnsw_initial_size(svs.dim, 1024, svs.graph_max_degree)
+        }
+    }
+}
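// --- Illustrative aside (a sketch, not part of the diff) ---
// Budgeting memory with the estimation helpers above: the total footprint is
// the fixed initial size plus a per-element cost for each expected vector.
fn estimated_bf_footprint_sketch(dim: usize, expected_vectors: usize) -> usize {
    let initial = VecSimIndex_EstimateBruteForceInitialSize(dim, expected_vectors);
    let per_element = VecSimIndex_EstimateBruteForceElementSize(dim);
    initial + expected_vectors * per_element
}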
+/// Estimate memory size per element for an index based on parameters.
+///
+/// # Safety
+/// `params` must be a valid pointer to a VecSimParams_C struct.
+#[no_mangle]
+pub unsafe extern "C" fn VecSimIndex_EstimateElementSize(
+    params: *const VecSimParams_C,
+) -> usize {
+    if params.is_null() {
+        return 0;
+    }
+
+    let params = &*params;
+
+    match params.algo {
+        VecSimAlgo::VecSimAlgo_BF => {
+            let bf = params.algoParams.bfParams;
+            vecsim::index::estimate_brute_force_element_size(bf.dim)
+        }
+        VecSimAlgo::VecSimAlgo_HNSWLIB => {
+            let hnsw = params.algoParams.hnswParams;
+            vecsim::index::estimate_hnsw_element_size(hnsw.dim, hnsw.M)
+        }
+        VecSimAlgo::VecSimAlgo_TIERED => {
+            // Use HNSW element size (vectors end up in HNSW)
+            let hnsw = params.algoParams.hnswParams;
+            vecsim::index::estimate_hnsw_element_size(hnsw.dim, hnsw.M)
+        }
+        VecSimAlgo::VecSimAlgo_SVS => {
+            let svs = params.algoParams.svsParams;
+            vecsim::index::estimate_hnsw_element_size(svs.dim, svs.graph_max_degree)
+        }
+    }
+}
+
+// ============================================================================
+// Tests
+// ============================================================================
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_write_mode_default() {
+        // Default should be WriteAsync
+        let mode = VecSim_GetWriteMode();
+        assert_eq!(mode, VecSimWriteMode::VecSim_WriteAsync);
+    }
+
+    #[test]
+    fn test_write_mode_set_and_get() {
+        // Set to InPlace
+        VecSim_SetWriteMode(VecSimWriteMode::VecSim_WriteInPlace);
+        assert_eq!(VecSim_GetWriteMode(), VecSimWriteMode::VecSim_WriteInPlace);
+
+        // Set back to Async
+        VecSim_SetWriteMode(VecSimWriteMode::VecSim_WriteAsync);
+        assert_eq!(VecSim_GetWriteMode(), VecSimWriteMode::VecSim_WriteAsync);
+    }
+
+    #[test]
+    fn test_write_mode_internal_function() {
+        VecSim_SetWriteMode(VecSimWriteMode::VecSim_WriteInPlace);
+        assert_eq!(get_global_write_mode(), VecSimWriteMode::VecSim_WriteInPlace);
+
+        VecSim_SetWriteMode(VecSimWriteMode::VecSim_WriteAsync);
+        assert_eq!(get_global_write_mode(), VecSimWriteMode::VecSim_WriteAsync);
+    }
+
+    // Helper to create a VecSimRawParam from strings
+    fn make_raw_param(name: &str, value: &str) -> VecSimRawParam {
+        VecSimRawParam {
+            name: name.as_ptr() as *const std::ffi::c_char,
+            nameLen: name.len(),
+            value: value.as_ptr() as *const std::ffi::c_char,
+            valLen: value.len(),
+        }
+    }
+
+    #[test]
+    fn test_struct_sizes_match_c() {
+        use std::mem::size_of;
+        use crate::compat::AlgoParams_C;
+        // These sizes must match the C header exactly.
+        // C header sizes (from check_rust_header_sizes.c):
+        //   sizeof(BFParams): 40
+        //   sizeof(HNSWParams): 72
+        //   sizeof(SVSParams): 120
+        //   sizeof(AlgoParams): 120
+        //   sizeof(VecSimParams): 136
+        assert_eq!(size_of::<BFParams_C>(), 40, "BFParams_C size mismatch");
+        assert_eq!(size_of::<HNSWParams_C>(), 72, "HNSWParams_C size mismatch");
+        assert_eq!(size_of::<SVSParams_C>(), 120, "SVSParams_C size mismatch");
+        assert_eq!(size_of::<AlgoParams_C>(), 120, "AlgoParams_C size mismatch");
+        assert_eq!(size_of::<VecSimParams_C>(), 136, "VecSimParams_C size mismatch");
+    }
+
+    // Helper to create HNSW params with valid dimensions (C++-compatible)
+    fn test_hnsw_params() -> HNSWParams_C {
+        HNSWParams_C {
+            type_: VecSimType::VecSimType_FLOAT32,
+            metric: VecSimMetric::VecSimMetric_L2,
+            dim: 4,
+            multi: false,
+            initialCapacity: 100,
+            blockSize: 0,
+            M: 16,
+            efConstruction: 200,
+            efRuntime: 10,
+            epsilon: 0.0,
+        }
+    }
+
+    // Helper to create BF params with valid dimensions (C++-compatible)
+    fn test_bf_params() -> BFParams_C {
+        BFParams_C {
+            type_: VecSimType::VecSimType_FLOAT32,
+            metric: VecSimMetric::VecSimMetric_L2,
+            dim: 4,
+            multi: false,
+            initialCapacity: 100,
+            blockSize: 0,
+        }
+    }
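    // --- Illustrative extra test (a sketch, not part of the diff) ---
    // Exercises VecSim_Normalize defined above: after normalization, the
    // vector's L2 norm should be 1.
    #[test]
    fn test_normalize_sketch() {
        let mut v: [f32; 4] = [3.0, 0.0, 4.0, 0.0];
        unsafe {
            VecSim_Normalize(v.as_mut_ptr() as *mut c_void, 4, VecSimType::VecSimType_FLOAT32);
        }
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }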
+    #[test]
+    fn test_resolve_params_null_qparams() {
+        let params = test_hnsw_params();
+        unsafe {
+            let index = VecSimIndex_NewHNSW(&params);
+            assert!(!index.is_null());
+
+            let result = VecSimIndex_ResolveParams(
+                index,
+                std::ptr::null(),
+                0,
+                std::ptr::null_mut(),
+                VecsimQueryType::QUERY_TYPE_KNN,
+            );
+            assert_eq!(result, VecSimParamResolveCode::VecSimParamResolverErr_NullParam);
+
+            VecSimIndex_Free(index);
+        }
+    }
+
+    #[test]
+    fn test_resolve_params_empty() {
+        let params = test_hnsw_params();
+        unsafe {
+            let index = VecSimIndex_NewHNSW(&params);
+            assert!(!index.is_null());
+
+            let mut qparams = VecSimQueryParams::default();
+            let result = VecSimIndex_ResolveParams(
+                index,
+                std::ptr::null(),
+                0,
+                &mut qparams,
+                VecsimQueryType::QUERY_TYPE_KNN,
+            );
+            assert_eq!(result, VecSimParamResolveCode::VecSimParamResolver_OK);
+
+            VecSimIndex_Free(index);
+        }
+    }
+
+    #[test]
+    fn test_resolve_params_ef_runtime() {
+        let params = test_hnsw_params();
+        unsafe {
+            let index = VecSimIndex_NewHNSW(&params);
+            assert!(!index.is_null());
+
+            let name = "EF_RUNTIME";
+            let value = "100";
+            let ef_param = make_raw_param(name, value);
+            let mut qparams = VecSimQueryParams::default();
+
+            let result = VecSimIndex_ResolveParams(
+                index,
+                &ef_param,
+                1,
+                &mut qparams,
+                VecsimQueryType::QUERY_TYPE_KNN,
+            );
+            assert_eq!(result, VecSimParamResolveCode::VecSimParamResolver_OK);
+            assert_eq!(qparams.hnsw_params().efRuntime, 100);
+
+            VecSimIndex_Free(index);
+        }
+    }
+
+    #[test]
+    fn test_resolve_params_ef_runtime_invalid_for_bf() {
+        let params = test_bf_params();
+        unsafe {
+            let index = VecSimIndex_NewBF(&params);
+            assert!(!index.is_null());
+
+            let name = "EF_RUNTIME";
+            let value = "100";
+            let ef_param = make_raw_param(name, value);
+            let mut qparams = VecSimQueryParams::default();
+
+            let result = VecSimIndex_ResolveParams(
+                index,
+                &ef_param,
+                1,
+                &mut qparams,
+                VecsimQueryType::QUERY_TYPE_KNN,
+            );
+            assert_eq!(result, VecSimParamResolveCode::VecSimParamResolverErr_UnknownParam);
+
+            VecSimIndex_Free(index);
+        }
+    }
+
+    #[test]
+    fn test_resolve_params_epsilon_for_range() {
+        let params = test_hnsw_params();
+        unsafe {
+            let index = VecSimIndex_NewHNSW(&params);
+            assert!(!index.is_null());
+
+            let name = "EPSILON";
+            let value = "0.01";
+            let eps_param = make_raw_param(name, value);
+            let mut qparams = VecSimQueryParams::default();
+
+            let result = VecSimIndex_ResolveParams(
+                index,
+                &eps_param,
+                1,
+                &mut qparams,
+                VecsimQueryType::QUERY_TYPE_RANGE,
+            );
+            assert_eq!(result, VecSimParamResolveCode::VecSimParamResolver_OK);
+            assert!((qparams.hnsw_params().epsilon - 0.01).abs() < 0.0001);
+
+            VecSimIndex_Free(index);
+        }
+    }
+
+    #[test]
+    fn test_resolve_params_epsilon_not_for_knn() {
+        let params = test_hnsw_params();
+        unsafe {
+            let index = VecSimIndex_NewHNSW(&params);
+            assert!(!index.is_null());
+
+            let name = "EPSILON";
+            let value = "0.01";
+            let eps_param = make_raw_param(name, value);
+            let mut qparams = VecSimQueryParams::default();
+
+            let result = VecSimIndex_ResolveParams(
+                index,
+                &eps_param,
+                1,
+                &mut qparams,
+                VecsimQueryType::QUERY_TYPE_KNN,
+            );
+            assert_eq!(result, VecSimParamResolveCode::VecSimParamResolverErr_InvalidPolicy_NRange);
+
+            VecSimIndex_Free(index);
+        }
+    }
+
+    #[test]
+    fn test_resolve_params_batch_size_for_hybrid() {
+        let params = test_hnsw_params();
+        unsafe {
+            let index = VecSimIndex_NewHNSW(&params);
+            assert!(!index.is_null());
+
+            let name = "BATCH_SIZE";
+            let value = "256";
+            let batch_param = make_raw_param(name, value);
+            let mut qparams = VecSimQueryParams::default();
+
+            let result = VecSimIndex_ResolveParams(
+                index,
+                &batch_param,
+                1,
+                &mut qparams,
+                VecsimQueryType::QUERY_TYPE_HYBRID,
+            );
+            assert_eq!(result, VecSimParamResolveCode::VecSimParamResolver_OK);
+            assert_eq!(qparams.batchSize, 256);
+
+            VecSimIndex_Free(index);
+        }
+    }
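    // --- Illustrative extra test (a sketch, not part of the diff) ---
    // VecSimIndex_PreferAdHocSearch defaults to ad-hoc for a null index and
    // for an empty index, per the guards in its implementation above.
    #[test]
    fn test_prefer_adhoc_trivial_cases_sketch() {
        unsafe {
            assert!(VecSimIndex_PreferAdHocSearch(std::ptr::null(), 10, 5, true));

            let params = test_bf_params();
            let index = VecSimIndex_NewBF(&params);
            assert!(!index.is_null());
            // Empty index: the heuristic falls back to ad-hoc.
            assert!(VecSimIndex_PreferAdHocSearch(index, 10, 5, true));
            VecSimIndex_Free(index);
        }
    }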
VecSimQueryParams::default(); + + let result = VecSimIndex_ResolveParams( + index, + &batch_param, + 1, + &mut qparams, + VecsimQueryType::QUERY_TYPE_HYBRID, + ); + assert_eq!(result, VecSimParamResolveCode::VecSimParamResolver_OK); + assert_eq!(qparams.batchSize, 256); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_resolve_params_duplicate_param() { + let params = test_hnsw_params(); + unsafe { + let index = VecSimIndex_NewHNSW(&params); + assert!(!index.is_null()); + + let name1 = "EF_RUNTIME"; + let value1 = "100"; + let name2 = "EF_RUNTIME"; + let value2 = "200"; + let ef_params = [ + make_raw_param(name1, value1), + make_raw_param(name2, value2), + ]; + let mut qparams = VecSimQueryParams::default(); + + let result = VecSimIndex_ResolveParams( + index, + ef_params.as_ptr(), + 2, + &mut qparams, + VecsimQueryType::QUERY_TYPE_KNN, + ); + assert_eq!(result, VecSimParamResolveCode::VecSimParamResolverErr_AlreadySet); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_resolve_params_unknown_param() { + let params = test_hnsw_params(); + unsafe { + let index = VecSimIndex_NewHNSW(&params); + assert!(!index.is_null()); + + let name = "UNKNOWN_PARAM"; + let value = "value"; + let unknown_param = make_raw_param(name, value); + let mut qparams = VecSimQueryParams::default(); + + let result = VecSimIndex_ResolveParams( + index, + &unknown_param, + 1, + &mut qparams, + VecsimQueryType::QUERY_TYPE_KNN, + ); + assert_eq!(result, VecSimParamResolveCode::VecSimParamResolverErr_UnknownParam); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_resolve_params_bad_value() { + let params = test_hnsw_params(); + unsafe { + let index = VecSimIndex_NewHNSW(&params); + assert!(!index.is_null()); + + let name = "EF_RUNTIME"; + let value = "not_a_number"; + let bad_param = make_raw_param(name, value); + let mut qparams = VecSimQueryParams::default(); + + let result = VecSimIndex_ResolveParams( + index, + &bad_param, + 1, + &mut qparams, + VecsimQueryType::QUERY_TYPE_KNN, + ); + assert_eq!(result, VecSimParamResolveCode::VecSimParamResolverErr_BadValue); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_create_and_free_bf_index() { + let params = test_bf_params(); + + unsafe { + let index = VecSimIndex_NewBF(&params); + assert!(!index.is_null()); + + assert_eq!(VecSimIndex_IndexSize(index), 0); + assert_eq!(VecSimIndex_GetDim(index), 4); + assert_eq!(VecSimIndex_GetType(index), VecSimType::VecSimType_FLOAT32); + assert!(!VecSimIndex_IsMulti(index)); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_add_and_query_vectors() { + let params = test_bf_params(); + + unsafe { + let index = VecSimIndex_NewBF(&params); + assert!(!index.is_null()); + + // Add vectors + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let v2: [f32; 4] = [0.0, 1.0, 0.0, 0.0]; + let v3: [f32; 4] = [0.0, 0.0, 1.0, 0.0]; + + assert_eq!( + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1), + 1 + ); + assert_eq!( + VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 2), + 1 + ); + assert_eq!( + VecSimIndex_AddVector(index, v3.as_ptr() as *const c_void, 3), + 1 + ); + + assert_eq!(VecSimIndex_IndexSize(index), 3); + + // Query + let query: [f32; 4] = [1.0, 0.1, 0.0, 0.0]; + let reply = VecSimIndex_TopKQuery( + index, + query.as_ptr() as *const c_void, + 2, + ptr::null(), + VecSimQueryReply_Order::BY_SCORE, + ); + assert!(!reply.is_null()); + + let len = VecSimQueryReply_Len(reply); + assert_eq!(len, 2); + + // Get iterator and check results + let iter = VecSimQueryReply_GetIterator(reply); + 
assert!(!iter.is_null()); + + let result = VecSimQueryReply_IteratorNext(iter); + assert!(!result.is_null()); + let id = VecSimQueryResult_GetId(result); + assert_eq!(id, 1); // Closest to query + + VecSimQueryReply_IteratorFree(iter); + VecSimQueryReply_Free(reply); + + // Delete vector + assert_eq!(VecSimIndex_DeleteVector(index, 1), 1); + assert_eq!(VecSimIndex_IndexSize(index), 2); + assert!(!VecSimIndex_ContainsLabel(index, 1)); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_hnsw_index() { + let params = test_hnsw_params(); + + unsafe { + let index = VecSimIndex_NewHNSW(&params); + assert!(!index.is_null()); + + // Add vectors + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let v2: [f32; 4] = [0.0, 1.0, 0.0, 0.0]; + + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1); + VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 2); + + assert_eq!(VecSimIndex_IndexSize(index), 2); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_range_query() { + let params = test_bf_params(); + + unsafe { + let index = VecSimIndex_NewBF(&params); + + // Add vectors at different distances + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let v2: [f32; 4] = [2.0, 0.0, 0.0, 0.0]; + let v3: [f32; 4] = [10.0, 0.0, 0.0, 0.0]; + + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1); + VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 2); + VecSimIndex_AddVector(index, v3.as_ptr() as *const c_void, 3); + + // Range query with radius that should only include v1 and v2 + let query: [f32; 4] = [0.0, 0.0, 0.0, 0.0]; + let reply = VecSimIndex_RangeQuery( + index, + query.as_ptr() as *const c_void, + 5.0, // L2 squared distance threshold + ptr::null(), + VecSimQueryReply_Order::BY_SCORE, + ); + + let len = VecSimQueryReply_Len(reply); + assert_eq!(len, 2); // Only v1 (dist=1) and v2 (dist=4) should be included + + VecSimQueryReply_Free(reply); + VecSimIndex_Free(index); + } + } + + #[test] + fn test_svs_index() { + let params = SVSParams_C { + type_: VecSimType::VecSimType_FLOAT32, + metric: VecSimMetric::VecSimMetric_L2, + dim: 4, + multi: false, + blockSize: 0, + quantBits: compat::VecSimSvsQuantBits::VecSimSvsQuant_NONE, + alpha: 1.2, + graph_max_degree: 32, + construction_window_size: 200, + max_candidate_pool_size: 0, + prune_to: 0, + use_search_history: types::VecSimOptionMode::VecSimOption_AUTO, + num_threads: 0, + search_window_size: 100, + search_buffer_capacity: 0, + leanvec_dim: 0, + epsilon: 0.0, + }; + + unsafe { + let index = VecSimIndex_NewSVS(&params); + assert!(!index.is_null()); + + // Add vectors + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let v2: [f32; 4] = [0.0, 1.0, 0.0, 0.0]; + let v3: [f32; 4] = [0.0, 0.0, 1.0, 0.0]; + + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1); + VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 2); + VecSimIndex_AddVector(index, v3.as_ptr() as *const c_void, 3); + + assert_eq!(VecSimIndex_IndexSize(index), 3); + + // Query + let query: [f32; 4] = [1.0, 0.1, 0.0, 0.0]; + let reply = VecSimIndex_TopKQuery( + index, + query.as_ptr() as *const c_void, + 2, + ptr::null(), + VecSimQueryReply_Order::BY_SCORE, + ); + assert!(!reply.is_null()); + + let len = VecSimQueryReply_Len(reply); + assert_eq!(len, 2); + + // Get iterator and check results + let iter = VecSimQueryReply_GetIterator(reply); + assert!(!iter.is_null()); + + let result = VecSimQueryReply_IteratorNext(iter); + assert!(!result.is_null()); + let id = VecSimQueryResult_GetId(result); + assert_eq!(id, 1); // Closest to query + + 
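// Tear down in dependency order: the iterator borrows the reply, so free + // the iterator first, then the reply, then the index. + 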
VecSimQueryReply_IteratorFree(iter); + VecSimQueryReply_Free(reply); + VecSimIndex_Free(index); + } + } + + // ======================================================================== + // Serialization Tests + // ======================================================================== + + #[test] + fn test_save_and_load_hnsw_index() { + use std::ffi::CString; + use tempfile::NamedTempFile; + + let params = test_hnsw_params(); + + unsafe { + // Create and populate index + let index = VecSimIndex_NewHNSW(&params); + assert!(!index.is_null()); + + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let v2: [f32; 4] = [0.0, 1.0, 0.0, 0.0]; + let v3: [f32; 4] = [0.0, 0.0, 1.0, 0.0]; + + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1); + VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 2); + VecSimIndex_AddVector(index, v3.as_ptr() as *const c_void, 3); + + assert_eq!(VecSimIndex_IndexSize(index), 3); + + // Save to file + let temp_file = NamedTempFile::new().unwrap(); + let path = CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let save_result = VecSimIndex_SaveIndex(index, path.as_ptr()); + assert!(save_result, "Failed to save index"); + + // Free original index + VecSimIndex_Free(index); + + // Load from file + let loaded_index = VecSimIndex_LoadIndex(path.as_ptr(), ptr::null()); + assert!(!loaded_index.is_null(), "Failed to load index"); + + // Verify loaded index + assert_eq!(VecSimIndex_IndexSize(loaded_index), 3); + + // Query loaded index + let query: [f32; 4] = [1.0, 0.1, 0.0, 0.0]; + let reply = VecSimIndex_TopKQuery( + loaded_index, + query.as_ptr() as *const c_void, + 2, + ptr::null(), + VecSimQueryReply_Order::BY_SCORE, + ); + assert!(!reply.is_null()); + + let len = VecSimQueryReply_Len(reply); + assert_eq!(len, 2); + + // Check that closest vector is still label 1 + let iter = VecSimQueryReply_GetIterator(reply); + let result = VecSimQueryReply_IteratorNext(iter); + assert_eq!(VecSimQueryResult_GetId(result), 1); + + VecSimQueryReply_IteratorFree(iter); + VecSimQueryReply_Free(reply); + VecSimIndex_Free(loaded_index); + } + } + + #[test] + fn test_save_and_load_brute_force_index() { + use std::ffi::CString; + use tempfile::NamedTempFile; + + let params = test_bf_params(); + + unsafe { + // Create and populate index + let index = VecSimIndex_NewBF(&params); + assert!(!index.is_null()); + + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let v2: [f32; 4] = [0.0, 1.0, 0.0, 0.0]; + + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1); + VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 2); + + assert_eq!(VecSimIndex_IndexSize(index), 2); + + // Save to file + let temp_file = NamedTempFile::new().unwrap(); + let path = CString::new(temp_file.path().to_str().unwrap()).unwrap(); + let save_result = VecSimIndex_SaveIndex(index, path.as_ptr()); + assert!(save_result, "Failed to save BruteForce index"); + + // Free original index + VecSimIndex_Free(index); + + // Load from file + let loaded_index = VecSimIndex_LoadIndex(path.as_ptr(), ptr::null()); + assert!(!loaded_index.is_null(), "Failed to load BruteForce index"); + + // Verify loaded index + assert_eq!(VecSimIndex_IndexSize(loaded_index), 2); + + VecSimIndex_Free(loaded_index); + } + } + + #[test] + fn test_save_unsupported_type_returns_false() { + use std::ffi::CString; + use tempfile::NamedTempFile; + + // Create a BruteForce index with f64 (not supported for serialization) + let params = BFParams_C { + type_: VecSimType::VecSimType_FLOAT64, + metric: VecSimMetric::VecSimMetric_L2, + dim: 4, + 
multi: false, + initialCapacity: 100, + blockSize: 0, + }; + + unsafe { + let index = VecSimIndex_NewBF(&params); + assert!(!index.is_null()); + + let temp_file = NamedTempFile::new().unwrap(); + let path = CString::new(temp_file.path().to_str().unwrap()).unwrap(); + + // Should return false for unsupported type + let save_result = VecSimIndex_SaveIndex(index, path.as_ptr()); + assert!(!save_result, "Save should fail for f64 BruteForce"); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_load_nonexistent_file_returns_null() { + use std::ffi::CString; + + unsafe { + let path = CString::new("/nonexistent/path/to/index.bin").unwrap(); + let loaded_index = VecSimIndex_LoadIndex(path.as_ptr(), ptr::null()); + assert!(loaded_index.is_null(), "Load should fail for nonexistent file"); + } + } + + #[test] + fn test_save_null_index_returns_false() { + use std::ffi::CString; + + unsafe { + let path = CString::new("/tmp/test.bin").unwrap(); + let result = VecSimIndex_SaveIndex(ptr::null(), path.as_ptr()); + assert!(!result, "Save should fail for null index"); + } + } + + #[test] + fn test_save_null_path_returns_false() { + let params = test_hnsw_params(); + + unsafe { + let index = VecSimIndex_NewHNSW(&params); + assert!(!index.is_null()); + + let result = VecSimIndex_SaveIndex(index, ptr::null()); + assert!(!result, "Save should fail for null path"); + + VecSimIndex_Free(index); + } + } + + // ======================================================================== + // Tiered Index Tests + // ======================================================================== + + fn test_tiered_params() -> TieredParams { + TieredParams { + base: VecSimParams { + algo: VecSimAlgo::VecSimAlgo_TIERED, + type_: VecSimType::VecSimType_FLOAT32, + metric: VecSimMetric::VecSimMetric_L2, + dim: 4, + multi: false, + initialCapacity: 100, + blockSize: 0, + }, + M: 16, + efConstruction: 200, + efRuntime: 10, + flatBufferLimit: 100, + writeMode: 0, // Async + } + } + + #[test] + fn test_tiered_create_and_free() { + let params = test_tiered_params(); + + unsafe { + let index = VecSimIndex_NewTiered(&params); + assert!(!index.is_null(), "Tiered index creation should succeed"); + + assert!(VecSimIndex_IsTiered(index), "Index should be tiered"); + assert_eq!(VecSimIndex_IndexSize(index), 0); + assert_eq!(VecSimIndex_GetDim(index), 4); + assert_eq!(VecSimTieredIndex_FlatSize(index), 0); + assert_eq!(VecSimTieredIndex_BackendSize(index), 0); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_tiered_add_vectors_to_flat() { + let params = test_tiered_params(); + + unsafe { + let index = VecSimIndex_NewTiered(&params); + assert!(!index.is_null()); + + // Add vectors - they should go to flat buffer + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let v2: [f32; 4] = [0.0, 1.0, 0.0, 0.0]; + let v3: [f32; 4] = [0.0, 0.0, 1.0, 0.0]; + + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1); + VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 2); + VecSimIndex_AddVector(index, v3.as_ptr() as *const c_void, 3); + + assert_eq!(VecSimIndex_IndexSize(index), 3); + assert_eq!(VecSimTieredIndex_FlatSize(index), 3); + assert_eq!(VecSimTieredIndex_BackendSize(index), 0); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_tiered_flush() { + let params = test_tiered_params(); + + unsafe { + let index = VecSimIndex_NewTiered(&params); + assert!(!index.is_null()); + + // Add vectors + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let v2: [f32; 4] = [0.0, 1.0, 0.0, 0.0]; + + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1); 
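+ // In Async write mode both adds land in the flat buffer; nothing reaches + // the HNSW backend until the explicit flush below. 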
+ VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 2); + + assert_eq!(VecSimTieredIndex_FlatSize(index), 2); + assert_eq!(VecSimTieredIndex_BackendSize(index), 0); + + // Flush to HNSW + let flushed = VecSimTieredIndex_Flush(index); + assert_eq!(flushed, 2); + + assert_eq!(VecSimTieredIndex_FlatSize(index), 0); + assert_eq!(VecSimTieredIndex_BackendSize(index), 2); + assert_eq!(VecSimIndex_IndexSize(index), 2); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_tiered_query_both_tiers() { + let params = test_tiered_params(); + + unsafe { + let index = VecSimIndex_NewTiered(&params); + assert!(!index.is_null()); + + // Add vector to flat + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1); + + // Flush to HNSW + VecSimTieredIndex_Flush(index); + + // Add more to flat + let v2: [f32; 4] = [0.0, 1.0, 0.0, 0.0]; + VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 2); + + assert_eq!(VecSimTieredIndex_FlatSize(index), 1); + assert_eq!(VecSimTieredIndex_BackendSize(index), 1); + + // Query should find both + let query: [f32; 4] = [0.5, 0.5, 0.0, 0.0]; + let reply = VecSimIndex_TopKQuery( + index, + query.as_ptr() as *const c_void, + 2, + ptr::null(), + VecSimQueryReply_Order::BY_SCORE, + ); + assert!(!reply.is_null()); + + let len = VecSimQueryReply_Len(reply); + assert_eq!(len, 2); + + VecSimQueryReply_Free(reply); + VecSimIndex_Free(index); + } + } + + #[test] + fn test_tiered_multi_index() { + let mut params = test_tiered_params(); + params.base.multi = true; + + unsafe { + let index = VecSimIndex_NewTiered(&params); + assert!(!index.is_null()); + + // Add multiple vectors with same label + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let v2: [f32; 4] = [0.0, 1.0, 0.0, 0.0]; + + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1); + VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 1); + + assert_eq!(VecSimIndex_IndexSize(index), 2); + assert_eq!(VecSimIndex_LabelCount(index, 1), 2); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_tiered_is_tiered_false_for_other_indices() { + let bf_params = test_bf_params(); + let hnsw_params = test_hnsw_params(); + + unsafe { + let bf_index = VecSimIndex_NewBF(&bf_params); + let hnsw_index = VecSimIndex_NewHNSW(&hnsw_params); + + assert!(!VecSimIndex_IsTiered(bf_index), "BF should not be tiered"); + assert!(!VecSimIndex_IsTiered(hnsw_index), "HNSW should not be tiered"); + + VecSimIndex_Free(bf_index); + VecSimIndex_Free(hnsw_index); + } + } + + #[test] + fn test_tiered_unsupported_type_returns_null() { + let mut params = test_tiered_params(); + params.base.type_ = VecSimType::VecSimType_FLOAT64; // Not supported + + unsafe { + let index = VecSimIndex_NewTiered(&params); + assert!(index.is_null(), "Tiered should fail for f64"); + } + } + + // ======================================================================== + // Memory Functions Tests + // ======================================================================== + + use std::sync::atomic::AtomicUsize; + + // Track allocations for testing + static TEST_ALLOC_COUNT: AtomicUsize = AtomicUsize::new(0); + static TEST_FREE_COUNT: AtomicUsize = AtomicUsize::new(0); + + unsafe extern "C" fn test_alloc(n: usize) -> *mut c_void { + TEST_ALLOC_COUNT.fetch_add(1, Ordering::SeqCst); + libc::malloc(n) + } + + unsafe extern "C" fn test_calloc(nelem: usize, elemsz: usize) -> *mut c_void { + TEST_ALLOC_COUNT.fetch_add(1, Ordering::SeqCst); + libc::calloc(nelem, elemsz) + } + + unsafe extern "C" fn test_realloc(p: *mut 
c_void, n: usize) -> *mut c_void { + libc::realloc(p, n) + } + + unsafe extern "C" fn test_free(p: *mut c_void) { + TEST_FREE_COUNT.fetch_add(1, Ordering::SeqCst); + libc::free(p) + } + + #[test] + fn test_set_memory_functions() { + // Reset counters + TEST_ALLOC_COUNT.store(0, Ordering::SeqCst); + TEST_FREE_COUNT.store(0, Ordering::SeqCst); + + let mem_funcs = VecSimMemoryFunctions { + allocFunction: Some(test_alloc), + callocFunction: Some(test_calloc), + reallocFunction: Some(test_realloc), + freeFunction: Some(test_free), + }; + + unsafe { + VecSim_SetMemoryFunctions(mem_funcs); + } + + // Verify the functions were set + let stored = get_memory_functions(); + assert!(stored.is_some(), "Memory functions should be set"); + + let funcs = stored.unwrap(); + assert!(funcs.allocFunction.is_some()); + assert!(funcs.callocFunction.is_some()); + assert!(funcs.reallocFunction.is_some()); + assert!(funcs.freeFunction.is_some()); + } + + #[test] + fn test_memory_functions_default_none() { + // Note: This test may fail if run after test_set_memory_functions + // due to global state. In a real scenario, we'd need proper isolation. + // For now, we just verify the API works. + let mem_funcs = VecSimMemoryFunctions { + allocFunction: None, + callocFunction: None, + reallocFunction: None, + freeFunction: None, + }; + + unsafe { + VecSim_SetMemoryFunctions(mem_funcs); + } + + let stored = get_memory_functions(); + assert!(stored.is_some()); // It's Some because we set it, but with None values + } + + // ======================================================================== + // Disk Index Tests + // ======================================================================== + + use params::DiskBackend; + use std::ffi::CString; + use tempfile::tempdir; + + fn test_disk_params(path: &CString) -> DiskParams { + DiskParams { + base: VecSimParams { + algo: VecSimAlgo::VecSimAlgo_BF, + type_: VecSimType::VecSimType_FLOAT32, + dim: 4, + metric: VecSimMetric::VecSimMetric_L2, + multi: false, + initialCapacity: 100, + blockSize: 0, + }, + dataPath: path.as_ptr(), + backend: DiskBackend::DiskBackend_BruteForce, + graphMaxDegree: 32, + alpha: 1.2, + constructionL: 200, + searchL: 100, + } + } + + #[test] + fn test_disk_create_and_free() { + let dir = tempdir().unwrap(); + let path = dir.path().join("test_disk.bin"); + let path_cstr = CString::new(path.to_str().unwrap()).unwrap(); + let params = test_disk_params(&path_cstr); + + unsafe { + let index = VecSimIndex_NewDisk(&params); + assert!(!index.is_null(), "Disk index creation should succeed"); + + assert!(VecSimIndex_IsDisk(index), "Index should be disk-based"); + assert!(!VecSimIndex_IsTiered(index), "Index should not be tiered"); + assert_eq!(VecSimIndex_IndexSize(index), 0); + assert_eq!(VecSimIndex_GetDim(index), 4); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_disk_add_and_query() { + let dir = tempdir().unwrap(); + let path = dir.path().join("test_disk_query.bin"); + let path_cstr = CString::new(path.to_str().unwrap()).unwrap(); + let params = test_disk_params(&path_cstr); + + unsafe { + let index = VecSimIndex_NewDisk(&params); + assert!(!index.is_null()); + + // Add vectors + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let v2: [f32; 4] = [0.0, 1.0, 0.0, 0.0]; + let v3: [f32; 4] = [0.0, 0.0, 1.0, 0.0]; + + assert_eq!(VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1), 1); + assert_eq!(VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 2), 1); + assert_eq!(VecSimIndex_AddVector(index, v3.as_ptr() as *const c_void, 3), 1); + + 
assert_eq!(VecSimIndex_IndexSize(index), 3); + + // Query for nearest neighbor + let query: [f32; 4] = [0.9, 0.1, 0.0, 0.0]; + let reply = VecSimIndex_TopKQuery( + index, + query.as_ptr() as *const c_void, + 1, + std::ptr::null(), + VecSimQueryReply_Order::BY_SCORE, + ); + assert!(!reply.is_null()); + + let len = VecSimQueryReply_Len(reply); + assert_eq!(len, 1); + + let iter = VecSimQueryReply_GetIterator(reply); + let result = VecSimQueryReply_IteratorNext(iter); + assert!(!result.is_null()); + assert_eq!(VecSimQueryResult_GetId(result), 1); // v1 is closest + + VecSimQueryReply_IteratorFree(iter); + VecSimQueryReply_Free(reply); + VecSimIndex_Free(index); + } + } + + #[test] + fn test_disk_flush() { + let dir = tempdir().unwrap(); + let path = dir.path().join("test_disk_flush.bin"); + let path_cstr = CString::new(path.to_str().unwrap()).unwrap(); + let params = test_disk_params(&path_cstr); + + unsafe { + let index = VecSimIndex_NewDisk(&params); + assert!(!index.is_null()); + + // Add a vector + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1); + + // Flush should succeed + assert!(VecSimDiskIndex_Flush(index), "Flush should succeed"); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_disk_delete_vector() { + let dir = tempdir().unwrap(); + let path = dir.path().join("test_disk_delete.bin"); + let path_cstr = CString::new(path.to_str().unwrap()).unwrap(); + let params = test_disk_params(&path_cstr); + + unsafe { + let index = VecSimIndex_NewDisk(&params); + assert!(!index.is_null()); + + // Add vectors + let v1: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let v2: [f32; 4] = [0.0, 1.0, 0.0, 0.0]; + VecSimIndex_AddVector(index, v1.as_ptr() as *const c_void, 1); + VecSimIndex_AddVector(index, v2.as_ptr() as *const c_void, 2); + assert_eq!(VecSimIndex_IndexSize(index), 2); + + // Delete one vector + let deleted = VecSimIndex_DeleteVector(index, 1); + assert_eq!(deleted, 1); + assert_eq!(VecSimIndex_IndexSize(index), 1); + + // Verify the remaining vector + assert!(!VecSimIndex_ContainsLabel(index, 1)); + assert!(VecSimIndex_ContainsLabel(index, 2)); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_disk_null_path_returns_null() { + let params = DiskParams { + base: VecSimParams { + algo: VecSimAlgo::VecSimAlgo_BF, + type_: VecSimType::VecSimType_FLOAT32, + dim: 4, + metric: VecSimMetric::VecSimMetric_L2, + multi: false, + initialCapacity: 100, + blockSize: 0, + }, + dataPath: std::ptr::null(), + backend: DiskBackend::DiskBackend_BruteForce, + graphMaxDegree: 32, + alpha: 1.2, + constructionL: 200, + searchL: 100, + }; + + unsafe { + let index = VecSimIndex_NewDisk(&params); + assert!(index.is_null(), "Disk index with null path should fail"); + } + } + + #[test] + fn test_disk_unsupported_type_returns_null() { + let dir = tempdir().unwrap(); + let path = dir.path().join("test_disk_f64.bin"); + let path_cstr = CString::new(path.to_str().unwrap()).unwrap(); + + let params = DiskParams { + base: VecSimParams { + algo: VecSimAlgo::VecSimAlgo_BF, + type_: VecSimType::VecSimType_FLOAT64, // Not supported + dim: 4, + metric: VecSimMetric::VecSimMetric_L2, + multi: false, + initialCapacity: 100, + blockSize: 0, + }, + dataPath: path_cstr.as_ptr(), + backend: DiskBackend::DiskBackend_BruteForce, + graphMaxDegree: 32, + alpha: 1.2, + constructionL: 200, + searchL: 100, + }; + + unsafe { + let index = VecSimIndex_NewDisk(&params); + assert!(index.is_null(), "Disk index should fail for f64"); + } + } + + #[test] + fn 
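test_example_disk_default_params_sketch() { + // Hedged sketch, not part of the original suite: `DiskParams::default()` + // (defined in params.rs) leaves `dataPath` null, so creation is expected + // to fail fast, mirroring the explicit null-path test above. + let params = DiskParams::default(); + unsafe { + let index = VecSimIndex_NewDisk(&params); + assert!(index.is_null(), "default DiskParams carries a null dataPath"); + } + } + + #[test] + fn 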
test_disk_is_disk_false_for_other_indices() { + let bf_params = test_bf_params(); + let hnsw_params = test_hnsw_params(); + + unsafe { + let bf_index = VecSimIndex_NewBF(&bf_params); + let hnsw_index = VecSimIndex_NewHNSW(&hnsw_params); + + assert!(!VecSimIndex_IsDisk(bf_index), "BF should not be disk"); + assert!(!VecSimIndex_IsDisk(hnsw_index), "HNSW should not be disk"); + + VecSimIndex_Free(bf_index); + VecSimIndex_Free(hnsw_index); + } + } + + // ======================================================================== + // New API Functions Tests + // ======================================================================== + + #[test] + fn test_normalize_f32() { + unsafe { + let mut v: [f32; 4] = [3.0, 4.0, 0.0, 0.0]; + VecSim_Normalize(v.as_mut_ptr() as *mut c_void, 4, VecSimType::VecSimType_FLOAT32); + + // L2 norm of [3, 4, 0, 0] is 5, so normalized should be [0.6, 0.8, 0, 0] + assert!((v[0] - 0.6).abs() < 0.001); + assert!((v[1] - 0.8).abs() < 0.001); + assert!((v[2] - 0.0).abs() < 0.001); + assert!((v[3] - 0.0).abs() < 0.001); + } + } + + #[test] + fn test_normalize_f64() { + unsafe { + let mut v: [f64; 4] = [3.0, 4.0, 0.0, 0.0]; + VecSim_Normalize(v.as_mut_ptr() as *mut c_void, 4, VecSimType::VecSimType_FLOAT64); + + assert!((v[0] - 0.6).abs() < 0.001); + assert!((v[1] - 0.8).abs() < 0.001); + } + } + + #[test] + fn test_normalize_null_is_safe() { + unsafe { + // Should not crash + VecSim_Normalize(ptr::null_mut(), 4, VecSimType::VecSimType_FLOAT32); + VecSim_Normalize(ptr::null_mut(), 0, VecSimType::VecSimType_FLOAT32); + } + } + + #[test] + fn test_basic_info() { + let params = test_hnsw_params(); + + unsafe { + let index = VecSimIndex_NewHNSW(&params); + assert!(!index.is_null()); + + let info = VecSimIndex_BasicInfo(index); + assert_eq!(info.algo, VecSimAlgo::VecSimAlgo_HNSWLIB); + assert_eq!(info.metric, VecSimMetric::VecSimMetric_L2); + assert_eq!(info.type_, VecSimType::VecSimType_FLOAT32); + assert_eq!(info.dim, 4); + assert!(!info.isMulti); + assert!(!info.isTiered); + assert!(!info.isDisk); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_stats_info() { + let params = test_bf_params(); + + unsafe { + let index = VecSimIndex_NewBF(&params); + assert!(!index.is_null()); + + let info = VecSimIndex_StatsInfo(index); + // Smoke test: the assertion below is trivially true; it only verifies + // that the call returns and the memory field is readable without crashing. + assert!(info.memory > 0 || info.memory == 0); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_debug_info() { + let params = test_hnsw_params(); + + unsafe { + let index = VecSimIndex_NewHNSW(&params); + assert!(!index.is_null()); + + let info = VecSimIndex_DebugInfo(index); + assert_eq!(info.commonInfo.basicInfo.algo, VecSimAlgo::VecSimAlgo_HNSWLIB); + assert_eq!(info.commonInfo.basicInfo.dim, 4); + assert_eq!(info.commonInfo.indexSize, 0); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_debug_info_iterator() { + let params = test_hnsw_params(); + + unsafe { + let index = VecSimIndex_NewHNSW(&params); + assert!(!index.is_null()); + + // Create the debug info iterator + let iter = VecSimIndex_DebugInfoIterator(index); + assert!(!iter.is_null()); + + // Check that we have fields + let num_fields = VecSimDebugInfoIterator_NumberOfFields(iter); + assert!(num_fields > 0, "Should have at least one field"); + + // Iterate through the fields + let mut count = 0; + while VecSimDebugInfoIterator_HasNextField(iter) { + let field = VecSimDebugInfoIterator_NextField(iter); + assert!(!field.is_null()); + assert!(!(*field).fieldName.is_null()); + count += 1; + } + 
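// The iterator must yield exactly as many fields as it advertised. + 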
assert_eq!(count, num_fields); + + // Should return null after exhausting + assert!(!VecSimDebugInfoIterator_HasNextField(iter)); + + // Free the iterator + VecSimDebugInfoIterator_Free(iter); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_debug_info_iterator_bf() { + let params = test_bf_params(); + + unsafe { + let index = VecSimIndex_NewBF(&params); + assert!(!index.is_null()); + + let iter = VecSimIndex_DebugInfoIterator(index); + assert!(!iter.is_null()); + + // BF should have fewer fields than HNSW + let num_fields = VecSimDebugInfoIterator_NumberOfFields(iter); + assert!(num_fields > 0); + assert!(num_fields <= 12, "BF should have ~10 fields"); + + VecSimDebugInfoIterator_Free(iter); + VecSimIndex_Free(index); + } + } + + #[test] + fn test_debug_info_iterator_null_safe() { + unsafe { + // Null iterator should be handled gracefully + assert_eq!(VecSimDebugInfoIterator_NumberOfFields(std::ptr::null_mut()), 0); + assert!(!VecSimDebugInfoIterator_HasNextField(std::ptr::null_mut())); + assert!(VecSimDebugInfoIterator_NextField(std::ptr::null_mut()).is_null()); + + // Free null should not crash + VecSimDebugInfoIterator_Free(std::ptr::null_mut()); + + // Null index should return null iterator + let iter = VecSimIndex_DebugInfoIterator(std::ptr::null()); + assert!(iter.is_null()); + } + } + + #[test] + fn test_prefer_adhoc_search() { + let params = test_bf_params(); + + unsafe { + let index = VecSimIndex_NewBF(&params); + assert!(!index.is_null()); + + // Add some vectors + for i in 0..100 { + let v: [f32; 4] = [i as f32, 0.0, 0.0, 0.0]; + VecSimIndex_AddVector(index, v.as_ptr() as *const c_void, i as u64); + } + + // Small subset should prefer ad-hoc + let prefer = VecSimIndex_PreferAdHocSearch(index, 5, 10, true); + assert!(prefer, "Small subset should prefer ad-hoc"); + + // Large subset with small k should prefer batches + // subset_ratio = 90/100 = 0.9 (> 0.3), k_ratio = 5/90 = 0.055 (< 0.1) + let prefer = VecSimIndex_PreferAdHocSearch(index, 90, 5, true); + assert!(!prefer, "Large subset with small k should prefer batches"); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_estimate_initial_size() { + use crate::compat::AlgoParams_C; + let params = VecSimParams_C { + algo: VecSimAlgo::VecSimAlgo_BF, + algoParams: AlgoParams_C { + bfParams: BFParams_C { + type_: VecSimType::VecSimType_FLOAT32, + dim: 128, + metric: VecSimMetric::VecSimMetric_L2, + multi: false, + initialCapacity: 1000, + blockSize: 0, + }, + }, + logCtx: std::ptr::null_mut(), + }; + + unsafe { + let size = VecSimIndex_EstimateInitialSize(&params); + assert!(size > 0, "Estimate should be positive"); + } + } + + #[test] + fn test_estimate_element_size() { + use crate::compat::AlgoParams_C; + let params = VecSimParams_C { + algo: VecSimAlgo::VecSimAlgo_HNSWLIB, + algoParams: AlgoParams_C { + hnswParams: HNSWParams_C { + type_: VecSimType::VecSimType_FLOAT32, + dim: 128, + metric: VecSimMetric::VecSimMetric_L2, + multi: false, + initialCapacity: 1000, + blockSize: 0, + M: 16, + efConstruction: 200, + efRuntime: 10, + epsilon: 0.01, + }, + }, + logCtx: std::ptr::null_mut(), + }; + + unsafe { + let size = VecSimIndex_EstimateElementSize(&params); + assert!(size > 0, "Estimate should be positive"); + // Should be at least the vector size (128 * 4 bytes) + assert!(size >= 128 * 4, "Should include vector storage"); + } + } + + #[test] + fn test_tiered_gc() { + let params = test_tiered_params(); + + unsafe { + let index = VecSimIndex_NewTiered(&params); + assert!(!index.is_null()); + + // GC on empty index should not crash + 
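// (nothing was inserted or deleted, so there is nothing to collect) + 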
VecSimTieredIndex_GC(index); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_tiered_locks() { + let params = test_tiered_params(); + + unsafe { + let index = VecSimIndex_NewTiered(&params); + assert!(!index.is_null()); + + // Should not crash + VecSimTieredIndex_AcquireSharedLocks(index); + VecSimTieredIndex_ReleaseSharedLocks(index); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_callback_setters() { + unsafe { + // Test timeout callback + VecSim_SetTimeoutCallbackFunction(None); + + // Test log callback + VecSim_SetLogCallbackFunction(None); + + // Test log context + VecSim_SetTestLogContext(ptr::null(), ptr::null()); + } + } + + // ======================================================================== + // C++-Compatible API Tests + // ======================================================================== + + #[test] + fn test_vecsim_index_new_bf() { + use crate::compat::{AlgoParams_C, BFParams_C, VecSimParams_C}; + + unsafe { + let bf_params = BFParams_C { + type_: VecSimType::VecSimType_FLOAT32, + dim: 4, + metric: VecSimMetric::VecSimMetric_L2, + multi: false, + initialCapacity: 100, + blockSize: 0, + }; + + let params = VecSimParams_C { + algo: VecSimAlgo::VecSimAlgo_BF, + algoParams: AlgoParams_C { bfParams: bf_params }, + logCtx: ptr::null_mut(), + }; + + let index = VecSimIndex_New(&params); + assert!(!index.is_null(), "VecSimIndex_New should create BF index"); + + // Verify it works + let v: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + VecSimIndex_AddVector(index, v.as_ptr() as *const c_void, 1); + assert_eq!(VecSimIndex_IndexSize(index), 1); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_vecsim_index_new_hnsw() { + use crate::compat::{AlgoParams_C, HNSWParams_C, VecSimParams_C}; + + unsafe { + let hnsw_params = HNSWParams_C { + type_: VecSimType::VecSimType_FLOAT32, + dim: 4, + metric: VecSimMetric::VecSimMetric_L2, + multi: false, + initialCapacity: 100, + blockSize: 0, + M: 16, + efConstruction: 200, + efRuntime: 10, + epsilon: 0.0, + }; + + let params = VecSimParams_C { + algo: VecSimAlgo::VecSimAlgo_HNSWLIB, + algoParams: AlgoParams_C { hnswParams: hnsw_params }, + logCtx: ptr::null_mut(), + }; + + let index = VecSimIndex_New(&params); + assert!(!index.is_null(), "VecSimIndex_New should create HNSW index"); + + // Verify it works + let v: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + VecSimIndex_AddVector(index, v.as_ptr() as *const c_void, 1); + assert_eq!(VecSimIndex_IndexSize(index), 1); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_vecsim_index_new_tiered() { + use crate::compat::{ + AlgoParams_C, HNSWParams_C, TieredHNSWParams_C, TieredIndexParams_C, + TieredSpecificParams_C, VecSimParams_C, + }; + use std::mem::ManuallyDrop; + + unsafe { + // Create the primary (backend) params + let hnsw_params = HNSWParams_C { + type_: VecSimType::VecSimType_FLOAT32, + dim: 4, + metric: VecSimMetric::VecSimMetric_L2, + multi: false, + initialCapacity: 100, + blockSize: 0, + M: 16, + efConstruction: 200, + efRuntime: 10, + epsilon: 0.0, + }; + + let mut primary_params = VecSimParams_C { + algo: VecSimAlgo::VecSimAlgo_HNSWLIB, + algoParams: AlgoParams_C { hnswParams: hnsw_params }, + logCtx: ptr::null_mut(), + }; + + let tiered_params = TieredIndexParams_C { + jobQueue: ptr::null_mut(), + jobQueueCtx: ptr::null_mut(), + submitCb: None, + flatBufferLimit: 100, + primaryIndexParams: &mut primary_params, + specificParams: TieredSpecificParams_C { + tieredHnswParams: TieredHNSWParams_C { swapJobThreshold: 0 }, + }, + }; + + let params = VecSimParams_C { + algo: 
VecSimAlgo::VecSimAlgo_TIERED, + algoParams: AlgoParams_C { + tieredParams: ManuallyDrop::new(tiered_params), + }, + logCtx: ptr::null_mut(), + }; + + let index = VecSimIndex_New(&params); + assert!(!index.is_null(), "VecSimIndex_New should create tiered index"); + + // Verify it works + assert!(VecSimIndex_IsTiered(index)); + let v: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + VecSimIndex_AddVector(index, v.as_ptr() as *const c_void, 1); + assert_eq!(VecSimIndex_IndexSize(index), 1); + + VecSimIndex_Free(index); + } + } + + #[test] + fn test_vecsim_index_new_null_returns_null() { + unsafe { + let index = VecSimIndex_New(ptr::null()); + assert!(index.is_null(), "VecSimIndex_New should return null for null params"); + } + } + + #[test] + fn test_query_reply_get_code() { + unsafe { + // Create an index + let params = test_bf_params(); + let index = VecSimIndex_NewBF(&params); + assert!(!index.is_null()); + + // Add a vector + let v: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + VecSimIndex_AddVector(index, v.as_ptr() as *const c_void, 1); + + // Query + let reply = VecSimIndex_TopKQuery( + index, + v.as_ptr() as *const c_void, + 10, + ptr::null(), + VecSimQueryReply_Order::BY_SCORE, + ); + assert!(!reply.is_null()); + + // Check code - should be OK (not timed out) + let code = VecSimQueryReply_GetCode(reply); + assert_eq!(code, types::VecSimQueryReply_Code::VecSim_QueryReply_OK); + + // Cleanup + VecSimQueryReply_Free(reply); + VecSimIndex_Free(index); + } + } + + #[test] + fn test_query_reply_get_code_null_is_safe() { + unsafe { + // Should not crash with null, returns OK + let code = VecSimQueryReply_GetCode(ptr::null()); + assert_eq!(code, types::VecSimQueryReply_Code::VecSim_QueryReply_OK); + } + } +} + +// ============================================================================ diff --git a/rust/vecsim-c/src/params.rs b/rust/vecsim-c/src/params.rs new file mode 100644 index 000000000..317a286c7 --- /dev/null +++ b/rust/vecsim-c/src/params.rs @@ -0,0 +1,459 @@ +//! C-compatible parameter structs for index creation. + +use crate::types::{VecSimAlgo, VecSimMetric, VecSimType}; + +/// Common base parameters for all index types. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct VecSimParams { + /// Algorithm type (BF, HNSW, tiered, or SVS). + pub algo: VecSimAlgo, + /// Data type for vectors. + pub type_: VecSimType, + /// Distance metric. + pub metric: VecSimMetric, + /// Vector dimension. + pub dim: usize, + /// Whether this is a multi-value index. + pub multi: bool, + /// Initial capacity. + pub initialCapacity: usize, + /// Block size (0 for default). + pub blockSize: usize, +} + +impl Default for VecSimParams { + fn default() -> Self { + Self { + algo: VecSimAlgo::VecSimAlgo_BF, + type_: VecSimType::VecSimType_FLOAT32, + metric: VecSimMetric::VecSimMetric_L2, + dim: 0, + multi: false, + initialCapacity: 1024, + blockSize: 0, + } + } +} + +/// Parameters specific to BruteForce index. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct BFParams { + /// Common parameters. + pub base: VecSimParams, +} + +impl Default for BFParams { + fn default() -> Self { + Self { + base: VecSimParams { + algo: VecSimAlgo::VecSimAlgo_BF, + ..VecSimParams::default() + }, + } + } +} + +/// Parameters specific to HNSW index. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct HNSWParams { + /// Common parameters. + pub base: VecSimParams, + /// Maximum number of connections per element per layer (default: 16). + pub M: usize, + /// Size of the dynamic candidate list during construction (default: 200). 
+ pub efConstruction: usize, + /// Size of the dynamic candidate list during search (default: 10). + pub efRuntime: usize, + /// Multiplier for epsilon (approximation factor, 0 = exact). + pub epsilon: f64, +} + +impl Default for HNSWParams { + fn default() -> Self { + Self { + base: VecSimParams { + algo: VecSimAlgo::VecSimAlgo_HNSWLIB, + ..VecSimParams::default() + }, + M: 16, + efConstruction: 200, + efRuntime: 10, + epsilon: 0.0, + } + } +} + +/// Parameters specific to SVS (Vamana) index. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct SVSParams { + /// Common parameters. + pub base: VecSimParams, + /// Maximum number of neighbors per node (R, default: 32). + pub graphMaxDegree: usize, + /// Alpha parameter for robust pruning (default: 1.2). + pub alpha: f32, + /// Beam width during construction (L, default: 200). + pub constructionWindowSize: usize, + /// Default beam width during search (default: 100). + pub searchWindowSize: usize, + /// Enable two-pass construction for better recall (default: true). + pub twoPassConstruction: bool, +} + +impl Default for SVSParams { + fn default() -> Self { + Self { + base: VecSimParams { + algo: VecSimAlgo::VecSimAlgo_SVS, + ..VecSimParams::default() + }, + graphMaxDegree: 32, + alpha: 1.2, + constructionWindowSize: 200, + searchWindowSize: 100, + twoPassConstruction: true, + } + } +} + +/// Parameters specific to Tiered index. +/// +/// The tiered index combines a BruteForce frontend (for fast writes) with +/// an HNSW backend (for efficient queries). Vectors are first added to the +/// flat buffer, then migrated to HNSW via flush() or when the buffer is full. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct TieredParams { + /// Common parameters (type, metric, dim, multi, etc.). + pub base: VecSimParams, + /// HNSW M parameter (max connections per node, default: 16). + pub M: usize, + /// HNSW ef_construction parameter (default: 200). + pub efConstruction: usize, + /// HNSW ef_runtime parameter (default: 10). + pub efRuntime: usize, + /// Maximum size of the flat buffer before forcing in-place writes. + /// When flat buffer reaches this limit, new writes go directly to HNSW. + /// Default: 10000. + pub flatBufferLimit: usize, + /// Write mode: 0 = Async (buffer first), 1 = InPlace (direct to HNSW). + pub writeMode: u32, +} + +impl Default for TieredParams { + fn default() -> Self { + Self { + base: VecSimParams { + algo: VecSimAlgo::VecSimAlgo_TIERED, + ..VecSimParams::default() + }, + M: 16, + efConstruction: 200, + efRuntime: 10, + flatBufferLimit: 10000, + writeMode: 0, // Async + } + } +} + +/// Runtime parameters union (C++-compatible layout). +/// This union overlays HNSW and SVS runtime parameters in the same memory. +#[repr(C)] +#[derive(Clone, Copy)] +pub union RuntimeParamsUnion { + pub hnswRuntimeParams: HNSWRuntimeParams, + pub svsRuntimeParams: SVSRuntimeParams, +} + +impl Default for RuntimeParamsUnion { + fn default() -> Self { + Self { + hnswRuntimeParams: HNSWRuntimeParams::default(), + } + } +} + +impl std::fmt::Debug for RuntimeParamsUnion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Safety: We can always read hnswRuntimeParams since it's the smaller type + unsafe { + f.debug_struct("RuntimeParamsUnion") + .field("hnswRuntimeParams", &self.hnswRuntimeParams) + .finish() + } + } +} + +/// Query parameters (C++-compatible layout). 
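+/// (Prefer the `hnsw_params()` / `svs_params()` accessors defined below to +/// reading the union field directly.)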
+/// +/// This struct matches the C++ VecSimQueryParams layout exactly: +/// - Union of HNSW and SVS runtime params +/// - batchSize +/// - searchMode (VecSearchMode enum) +/// - timeoutCtx +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct VecSimQueryParams { + /// Runtime parameters (union of HNSW and SVS params). + pub runtimeParams: RuntimeParamsUnion, + /// Batch size for batched iteration. + pub batchSize: usize, + /// Search mode (VecSearchMode enum from C++). + pub searchMode: i32, // VecSearchMode is an enum, use i32 for C compatibility + /// Timeout callback (opaque pointer). + pub timeoutCtx: *mut std::ffi::c_void, +} + +impl Default for VecSimQueryParams { + fn default() -> Self { + Self { + runtimeParams: RuntimeParamsUnion::default(), + batchSize: 0, + searchMode: 0, // EMPTY_MODE + timeoutCtx: std::ptr::null_mut(), + } + } +} + +impl VecSimQueryParams { + /// Get HNSW runtime parameters. + pub fn hnsw_params(&self) -> &HNSWRuntimeParams { + // Safety: Reading from union is safe as long as we interpret correctly + unsafe { &self.runtimeParams.hnswRuntimeParams } + } + + /// Get mutable HNSW runtime parameters. + pub fn hnsw_params_mut(&mut self) -> &mut HNSWRuntimeParams { + // Safety: Writing to union is safe + unsafe { &mut self.runtimeParams.hnswRuntimeParams } + } + + /// Get SVS runtime parameters. + pub fn svs_params(&self) -> &SVSRuntimeParams { + // Safety: Reading from union is safe as long as we interpret correctly + unsafe { &self.runtimeParams.svsRuntimeParams } + } + + /// Get mutable SVS runtime parameters. + pub fn svs_params_mut(&mut self) -> &mut SVSRuntimeParams { + // Safety: Writing to union is safe + unsafe { &mut self.runtimeParams.svsRuntimeParams } + } +} + +/// HNSW-specific runtime parameters. +#[repr(C)] +#[derive(Debug, Clone, Copy, Default)] +pub struct HNSWRuntimeParams { + /// Size of dynamic candidate list during search. + pub efRuntime: usize, + /// Epsilon multiplier for approximate search. + pub epsilon: f64, +} + +/// SVS-specific runtime parameters. +#[repr(C)] +#[derive(Debug, Clone, Copy, Default)] +pub struct SVSRuntimeParams { + /// Search window size for SVS graph search. + pub windowSize: usize, + /// Search buffer capacity. + pub bufferCapacity: usize, + /// Whether to use search history. + pub searchHistory: i32, + /// Epsilon parameter for range search. + pub epsilon: f64, +} + +/// Search mode. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSimSearchMode { + STANDARD = 0, + HYBRID = 1, + RANGE = 2, +} + +/// Hybrid policy. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSimHybridPolicy { + BATCHES = 0, + ADHOC = 1, +} + +/// Convert BFParams to Rust BruteForceParams. +impl BFParams { + pub fn to_rust_params(&self) -> vecsim::index::BruteForceParams { + let mut params = vecsim::index::BruteForceParams::new( + self.base.dim, + self.base.metric.to_rust_metric(), + ); + params = params.with_capacity(self.base.initialCapacity); + if self.base.blockSize > 0 { + params = params.with_block_size(self.base.blockSize); + } + params + } +} + +/// Convert HNSWParams to Rust HnswParams. +impl HNSWParams { + pub fn to_rust_params(&self) -> vecsim::index::HnswParams { + let mut params = vecsim::index::HnswParams::new( + self.base.dim, + self.base.metric.to_rust_metric(), + ); + params = params + .with_m(self.M) + .with_ef_construction(self.efConstruction) + .with_ef_runtime(self.efRuntime) + .with_capacity(self.base.initialCapacity); + params + } +} + +/// Convert SVSParams to Rust SvsParams. 
+impl SVSParams { + pub fn to_rust_params(&self) -> vecsim::index::SvsParams { + vecsim::index::SvsParams::new(self.base.dim, self.base.metric.to_rust_metric()) + .with_graph_degree(self.graphMaxDegree) + .with_alpha(self.alpha) + .with_construction_l(self.constructionWindowSize) + .with_search_l(self.searchWindowSize) + .with_capacity(self.base.initialCapacity) + .with_two_pass(self.twoPassConstruction) + } +} + +/// Convert TieredParams to Rust TieredParams. +impl TieredParams { + pub fn to_rust_params(&self) -> vecsim::index::TieredParams { + let write_mode = if self.writeMode == 0 { + vecsim::index::tiered::WriteMode::Async + } else { + vecsim::index::tiered::WriteMode::InPlace + }; + + vecsim::index::TieredParams::new(self.base.dim, self.base.metric.to_rust_metric()) + .with_m(self.M) + .with_ef_construction(self.efConstruction) + .with_ef_runtime(self.efRuntime) + .with_flat_buffer_limit(self.flatBufferLimit) + .with_write_mode(write_mode) + .with_initial_capacity(self.base.initialCapacity) + } +} + +/// Convert VecSimQueryParams to Rust QueryParams. +impl VecSimQueryParams { + pub fn to_rust_params(&self) -> vecsim::query::QueryParams { + let mut params = vecsim::query::QueryParams::new(); + // Safety: Reading from union - we check efRuntime which is valid for both HNSW and SVS + let hnsw_params = self.hnsw_params(); + if hnsw_params.efRuntime > 0 { + params = params.with_ef_runtime(hnsw_params.efRuntime); + } + if self.batchSize > 0 { + params = params.with_batch_size(self.batchSize); + } + params + } +} + +/// Backend type for disk-based indices. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DiskBackend { + /// Linear scan (exact results). + DiskBackend_BruteForce = 0, + /// Vamana graph (approximate, fast). + DiskBackend_Vamana = 1, +} + +impl Default for DiskBackend { + fn default() -> Self { + DiskBackend::DiskBackend_BruteForce + } +} + +impl DiskBackend { + pub fn to_rust_backend(&self) -> vecsim::index::disk::DiskBackend { + match self { + DiskBackend::DiskBackend_BruteForce => vecsim::index::disk::DiskBackend::BruteForce, + DiskBackend::DiskBackend_Vamana => vecsim::index::disk::DiskBackend::Vamana, + } + } +} + +/// Parameters for disk-based index. +/// +/// Disk indices store vectors in memory-mapped files for persistence. +/// They support two backends: +/// - BruteForce: Linear scan (exact results, O(n)) +/// - Vamana: Graph-based approximate search (fast, O(log n)) +#[repr(C)] +#[derive(Debug, Clone)] +pub struct DiskParams { + /// Common parameters (type, metric, dim, etc.). + pub base: VecSimParams, + /// Path to the data file (null-terminated C string). + pub dataPath: *const std::ffi::c_char, + /// Backend algorithm (BruteForce or Vamana). + pub backend: DiskBackend, + /// Graph max degree (for Vamana backend, default: 32). + pub graphMaxDegree: usize, + /// Alpha parameter for Vamana (default: 1.2). + pub alpha: f32, + /// Construction window size for Vamana (default: 200). + pub constructionL: usize, + /// Search window size for Vamana (default: 100). + pub searchL: usize, +} + +impl Default for DiskParams { + fn default() -> Self { + Self { + base: VecSimParams::default(), + dataPath: std::ptr::null(), + backend: DiskBackend::default(), + graphMaxDegree: 32, + alpha: 1.2, + constructionL: 200, + searchL: 100, + } + } +} + +impl DiskParams { + /// Convert to Rust DiskIndexParams. + /// + /// # Safety + /// The dataPath must be a valid null-terminated C string. 
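+ /// A non-UTF-8 path is rejected by returning `None` rather than panicking.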
+ pub unsafe fn to_rust_params(&self) -> Option<vecsim::index::disk::DiskIndexParams> { + if self.dataPath.is_null() { + return None; + } + + let path_cstr = std::ffi::CStr::from_ptr(self.dataPath); + let path_str = path_cstr.to_str().ok()?; + + let params = vecsim::index::disk::DiskIndexParams::new( + self.base.dim, + self.base.metric.to_rust_metric(), + path_str, + ) + .with_backend(self.backend.to_rust_backend()) + .with_capacity(self.base.initialCapacity) + .with_graph_degree(self.graphMaxDegree) + .with_alpha(self.alpha) + .with_construction_l(self.constructionL) + .with_search_l(self.searchL); + + Some(params) + } +} diff --git a/rust/vecsim-c/src/query.rs b/rust/vecsim-c/src/query.rs new file mode 100644 index 000000000..4fba3a6a6 --- /dev/null +++ b/rust/vecsim-c/src/query.rs @@ -0,0 +1,153 @@ +//! Query operations and result handling for C FFI. + +use crate::index::{BatchIteratorWrapper, IndexHandle}; +use crate::params::VecSimQueryParams; +use crate::types::{ + QueryReplyInternal, QueryReplyIteratorInternal, QueryResultInternal, VecSimQueryReply_Code, + VecSimQueryReply_Order, +}; +use std::ffi::c_void; + +/// Query reply handle that owns the results. +pub struct QueryReplyHandle { + pub reply: QueryReplyInternal, +} + +impl QueryReplyHandle { + pub fn new(reply: QueryReplyInternal) -> Self { + Self { reply } + } + + pub fn len(&self) -> usize { + self.reply.len() + } + + pub fn is_empty(&self) -> bool { + self.reply.is_empty() + } + + pub fn sort_by_order(&mut self, order: VecSimQueryReply_Order) { + match order { + VecSimQueryReply_Order::BY_SCORE => self.reply.sort_by_score(), + VecSimQueryReply_Order::BY_ID => self.reply.sort_by_id(), + } + } + + pub fn get_iterator(&self) -> QueryReplyIteratorHandle<'_> { + QueryReplyIteratorHandle::new(&self.reply.results) + } + + pub fn code(&self) -> VecSimQueryReply_Code { + self.reply.code + } +} + +/// Iterator handle over query results. +pub struct QueryReplyIteratorHandle<'a> { + iter: QueryReplyIteratorInternal<'a>, +} + +impl<'a> QueryReplyIteratorHandle<'a> { + pub fn new(results: &'a [QueryResultInternal]) -> Self { + Self { + iter: QueryReplyIteratorInternal::new(results), + } + } + + pub fn next(&mut self) -> Option<&'a QueryResultInternal> { + self.iter.next() + } + + pub fn has_next(&self) -> bool { + self.iter.has_next() + } + + pub fn reset(&mut self) { + self.iter.reset(); + } +} + +/// Batch iterator handle. +pub struct BatchIteratorHandle { + pub inner: Box<dyn BatchIteratorWrapper>, + pub query_copy: Vec<u8>, +} + +impl BatchIteratorHandle { + pub fn new(inner: Box<dyn BatchIteratorWrapper>, query_copy: Vec<u8>) -> Self { + Self { inner, query_copy } + } + + pub fn has_next(&self) -> bool { + self.inner.has_next() + } + + pub fn next(&mut self, n: usize, order: VecSimQueryReply_Order) -> QueryReplyHandle { + let reply = self.inner.next_batch(n, order); + QueryReplyHandle::new(reply) + } + + pub fn reset(&mut self) { + self.inner.reset(); + } +} + +/// Perform a top-k query on an index. +pub fn top_k_query( + handle: &IndexHandle, + query: *const c_void, + k: usize, + params: Option<&VecSimQueryParams>, + order: VecSimQueryReply_Order, +) -> QueryReplyHandle { + let mut reply = handle.wrapper.top_k_query(query, k, params); + + // Sort by requested order + match order { + VecSimQueryReply_Order::BY_SCORE => reply.sort_by_score(), + VecSimQueryReply_Order::BY_ID => reply.sort_by_id(), + } + + QueryReplyHandle::new(reply) +} + +/// Perform a range query on an index. 
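+/// As with `top_k_query`, the reply is sorted by the requested `order` +/// before being returned.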
+pub fn range_query( + handle: &IndexHandle, + query: *const c_void, + radius: f64, + params: Option<&VecSimQueryParams>, + order: VecSimQueryReply_Order, +) -> QueryReplyHandle { + let mut reply = handle.wrapper.range_query(query, radius, params); + + // Sort by requested order + match order { + VecSimQueryReply_Order::BY_SCORE => reply.sort_by_score(), + VecSimQueryReply_Order::BY_ID => reply.sort_by_id(), + } + + QueryReplyHandle::new(reply) +} + +/// Create a batch iterator for the given index and query. +pub fn create_batch_iterator( + handle: &IndexHandle, + query: *const c_void, + params: Option<&VecSimQueryParams>, +) -> Option<BatchIteratorHandle> { + // Copy the query data since we need to own it + let dim = handle.wrapper.dimension(); + let elem_size = handle.data_type.element_size(); + let query_bytes = dim * elem_size; + + let query_copy = unsafe { + let slice = std::slice::from_raw_parts(query as *const u8, query_bytes); + slice.to_vec() + }; + + handle + .wrapper + .create_batch_iterator(query, params) + .map(|inner| BatchIteratorHandle::new(inner, query_copy)) +} diff --git a/rust/vecsim-c/src/types.rs b/rust/vecsim-c/src/types.rs new file mode 100644 index 000000000..54f19c7f9 --- /dev/null +++ b/rust/vecsim-c/src/types.rs @@ -0,0 +1,389 @@ +//! C-compatible type definitions for the VecSim FFI. + +/// Vector element data type. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSimType { + VecSimType_FLOAT32 = 0, + VecSimType_FLOAT64 = 1, + VecSimType_BFLOAT16 = 2, + VecSimType_FLOAT16 = 3, + VecSimType_INT8 = 4, + VecSimType_UINT8 = 5, + VecSimType_INT32 = 6, + VecSimType_INT64 = 7, +} + +/// Index algorithm type. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSimAlgo { + VecSimAlgo_BF = 0, + VecSimAlgo_HNSWLIB = 1, + VecSimAlgo_TIERED = 2, + VecSimAlgo_SVS = 3, +} + +/// Distance metric type. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSimMetric { + VecSimMetric_L2 = 0, + VecSimMetric_IP = 1, + VecSimMetric_Cosine = 2, +} + +/// Query reply ordering. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSimQueryReply_Order { + BY_SCORE = 0, + BY_ID = 1, +} + +/// Query reply status code (for detecting timeouts, etc.). +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum VecSimQueryReply_Code { + /// Query completed successfully. + #[default] + VecSim_QueryReply_OK = 0, + /// Query was aborted due to timeout. + VecSim_QueryReply_TimedOut = 1, +} + +/// Index resolve codes for resolving index state. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSimResolveCode { + VecSim_Resolve_OK = 0, + VecSim_Resolve_ERR = 1, +} + +/// Parameter resolution error codes. +/// +/// These codes are returned by VecSimIndex_ResolveParams to indicate +/// the result of parsing runtime query parameters. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSimParamResolveCode { + /// Parameter resolution succeeded. + VecSimParamResolver_OK = 0, + /// Null parameter pointer was passed. + VecSimParamResolverErr_NullParam = 1, + /// Parameter was already set (duplicate parameter). + VecSimParamResolverErr_AlreadySet = 2, + /// Unknown parameter name. + VecSimParamResolverErr_UnknownParam = 3, + /// Invalid parameter value. + VecSimParamResolverErr_BadValue = 4, + /// Invalid policy: policy does not exist. + VecSimParamResolverErr_InvalidPolicy_NExits = 5, + /// Invalid policy: not a hybrid query. 
+ VecSimParamResolverErr_InvalidPolicy_NHybrid = 6, + /// Invalid policy: not a range query. + VecSimParamResolverErr_InvalidPolicy_NRange = 7, + /// Invalid policy: ad-hoc policy with batch size. + VecSimParamResolverErr_InvalidPolicy_AdHoc_With_BatchSize = 8, + /// Invalid policy: ad-hoc policy with ef_runtime. + VecSimParamResolverErr_InvalidPolicy_AdHoc_With_EfRuntime = 9, +} + +/// Query type for parameter resolution. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecsimQueryType { + /// No specific query type. + QUERY_TYPE_NONE = 0, + /// Standard KNN query. + QUERY_TYPE_KNN = 1, + /// Hybrid query (combining vector similarity with filters). + QUERY_TYPE_HYBRID = 2, + /// Range query. + QUERY_TYPE_RANGE = 3, +} + +/// Debug command return codes. +/// +/// These codes are returned by debug functions like VecSimDebug_GetElementNeighborsInHNSWGraph. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSimDebugCommandCode { + /// Command succeeded. + VecSimDebugCommandCode_OK = 0, + /// Bad index (null, wrong type, or unsupported). + VecSimDebugCommandCode_BadIndex = 1, + /// Label does not exist in the index. + VecSimDebugCommandCode_LabelNotExists = 2, + /// Multi-value indices are not supported for this operation. + VecSimDebugCommandCode_MultiNotSupported = 3, +} + +/// Raw parameter for runtime query configuration. +/// +/// Used to pass string-based parameters that are resolved into typed +/// VecSimQueryParams by VecSimIndex_ResolveParams. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct VecSimRawParam { + /// Parameter name. + pub name: *const std::ffi::c_char, + /// Length of the parameter name. + pub nameLen: usize, + /// Parameter value as a string. + pub value: *const std::ffi::c_char, + /// Length of the parameter value. + pub valLen: usize, +} + +/// Write mode for tiered index operations. +/// +/// Controls whether vector additions/deletions go through the async +/// buffering path or directly to the backend index. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum VecSimWriteMode { + /// Async mode: vectors go to flat buffer, migrated to backend via background jobs. + #[default] + VecSim_WriteAsync = 0, + /// InPlace mode: vectors go directly to the backend index. + VecSim_WriteInPlace = 1, +} + +/// Option mode for various settings. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum VecSimOptionMode { + #[default] + VecSimOption_AUTO = 0, + VecSimOption_ENABLE = 1, + VecSimOption_DISABLE = 2, +} + +/// Tri-state boolean for optional settings. +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSimBool { + VecSimBool_TRUE = 1, + VecSimBool_FALSE = 0, + VecSimBool_UNSET = -1, +} + +/// Search mode for queries (used for debug/testing). +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VecSearchMode { + EMPTY_MODE = 0, + STANDARD_KNN = 1, + HYBRID_ADHOC_BF = 2, + HYBRID_BATCHES = 3, + HYBRID_BATCHES_TO_ADHOC_BF = 4, + RANGE_QUERY = 5, +} + +/// Timeout callback function type. +/// Returns non-zero on timeout. +pub type timeoutCallbackFunction = Option<unsafe extern "C" fn(ctx: *mut std::ffi::c_void) -> i32>; + +/// Log callback function type. +pub type logCallbackFunction = Option<unsafe extern "C" fn(ctx: *mut std::ffi::c_void, level: *const std::ffi::c_char, message: *const std::ffi::c_char)>; + +/// Opaque index handle. +#[repr(C)] +pub struct VecSimIndex { + _private: [u8; 0], +} + +/// Opaque query reply handle. +#[repr(C)] +pub struct VecSimQueryReply { + _private: [u8; 0], +} + +/// Opaque query result handle. 
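+/// (Like the other opaque handles here, this is a zero-sized type that only +/// ever crosses the FFI boundary behind a pointer.)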
+
+/// Opaque query result handle.
+#[repr(C)]
+pub struct VecSimQueryResult {
+    _private: [u8; 0],
+}
+
+/// Opaque query reply iterator handle.
+#[repr(C)]
+pub struct VecSimQueryReply_Iterator {
+    _private: [u8; 0],
+}
+
+/// Opaque batch iterator handle.
+#[repr(C)]
+pub struct VecSimBatchIterator {
+    _private: [u8; 0],
+}
+
+/// Label type for vectors.
+pub type labelType = u64;
+
+/// Convert VecSimType to Rust type name for dispatch.
+impl VecSimType {
+    pub fn element_size(&self) -> usize {
+        match self {
+            VecSimType::VecSimType_FLOAT32 => std::mem::size_of::<f32>(),
+            VecSimType::VecSimType_FLOAT64 => std::mem::size_of::<f64>(),
+            VecSimType::VecSimType_BFLOAT16 => 2,
+            VecSimType::VecSimType_FLOAT16 => 2,
+            VecSimType::VecSimType_INT8 => 1,
+            VecSimType::VecSimType_UINT8 => 1,
+            VecSimType::VecSimType_INT32 => std::mem::size_of::<i32>(),
+            VecSimType::VecSimType_INT64 => std::mem::size_of::<i64>(),
+        }
+    }
+}
+
+impl VecSimMetric {
+    pub fn to_rust_metric(&self) -> vecsim::distance::Metric {
+        match self {
+            VecSimMetric::VecSimMetric_L2 => vecsim::distance::Metric::L2,
+            VecSimMetric::VecSimMetric_IP => vecsim::distance::Metric::InnerProduct,
+            VecSimMetric::VecSimMetric_Cosine => vecsim::distance::Metric::Cosine,
+        }
+    }
+}
+
+impl From<vecsim::distance::Metric> for VecSimMetric {
+    fn from(metric: vecsim::distance::Metric) -> Self {
+        match metric {
+            vecsim::distance::Metric::L2 => VecSimMetric::VecSimMetric_L2,
+            vecsim::distance::Metric::InnerProduct => VecSimMetric::VecSimMetric_IP,
+            vecsim::distance::Metric::Cosine => VecSimMetric::VecSimMetric_Cosine,
+        }
+    }
+}
+
+/// Internal representation of a query result.
+#[derive(Debug, Clone, Copy)]
+pub struct QueryResultInternal {
+    pub id: labelType,
+    pub score: f64,
+}
+
+/// Internal representation of query reply.
+pub struct QueryReplyInternal {
+    pub results: Vec<QueryResultInternal>,
+    pub code: VecSimQueryReply_Code,
+}
+
+impl QueryReplyInternal {
+    pub fn new() -> Self {
+        Self {
+            results: Vec::new(),
+            code: VecSimQueryReply_Code::VecSim_QueryReply_OK,
+        }
+    }
+
+    pub fn with_capacity(capacity: usize) -> Self {
+        Self {
+            results: Vec::with_capacity(capacity),
+            code: VecSimQueryReply_Code::VecSim_QueryReply_OK,
+        }
+    }
+
+    pub fn from_results(results: Vec<QueryResultInternal>) -> Self {
+        Self {
+            results,
+            code: VecSimQueryReply_Code::VecSim_QueryReply_OK,
+        }
+    }
+
+    pub fn len(&self) -> usize {
+        self.results.len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.results.is_empty()
+    }
+
+    pub fn sort_by_score(&mut self) {
+        self.results.sort_by(|a, b| {
+            a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)
+        });
+    }
+
+    pub fn sort_by_id(&mut self) {
+        self.results.sort_by_key(|r| r.id);
+    }
+}
+
+/// Iterator over query reply results.
+pub struct QueryReplyIteratorInternal<'a> {
+    results: &'a [QueryResultInternal],
+    position: usize,
+}
+
+impl<'a> QueryReplyIteratorInternal<'a> {
+    pub fn new(results: &'a [QueryResultInternal]) -> Self {
+        Self { results, position: 0 }
+    }
+
+    pub fn next(&mut self) -> Option<&'a QueryResultInternal> {
+        if self.position < self.results.len() {
+            let result = &self.results[self.position];
+            self.position += 1;
+            Some(result)
+        } else {
+            None
+        }
+    }
+
+    pub fn has_next(&self) -> bool {
+        self.position < self.results.len()
+    }
+
+    pub fn reset(&mut self) {
+        self.position = 0;
+    }
+}
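`QueryReplyIteratorInternal` is a manual cursor rather than an `Iterator` implementation, which lets the C API rewind it in place. A small usage sketch:

```rust
// Sketch: drain the cursor, then rewind it with reset().
fn walk(results: &[QueryResultInternal]) {
    let mut it = QueryReplyIteratorInternal::new(results);
    while let Some(r) = it.next() {
        println!("id={} score={}", r.id, r.score);
    }
    assert!(!it.has_next());
    it.reset(); // back to the first result
    assert_eq!(it.has_next(), !results.is_empty());
}
```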
+
+// ============================================================================
+// Memory Function Types
+// ============================================================================
+
+/// Function pointer type for malloc-style allocation.
+pub type allocFn = Option<unsafe extern "C" fn(size: usize) -> *mut std::ffi::c_void>;
+
+/// Function pointer type for calloc-style allocation.
+pub type callocFn = Option<unsafe extern "C" fn(count: usize, size: usize) -> *mut std::ffi::c_void>;
+
+/// Function pointer type for realloc-style reallocation.
+pub type reallocFn = Option<unsafe extern "C" fn(ptr: *mut std::ffi::c_void, size: usize) -> *mut std::ffi::c_void>;
+
+/// Function pointer type for free-style deallocation.
+pub type freeFn = Option<unsafe extern "C" fn(ptr: *mut std::ffi::c_void)>;
+
+/// Memory functions struct for custom memory management.
+///
+/// This allows integration with external memory management systems like Redis.
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct VecSimMemoryFunctions {
+    /// Malloc-like allocation function.
+    pub allocFunction: allocFn,
+    /// Calloc-like allocation function.
+    pub callocFunction: callocFn,
+    /// Realloc-like reallocation function.
+    pub reallocFunction: reallocFn,
+    /// Free function.
+    pub freeFunction: freeFn,
+}
+
+impl Default for VecSimMemoryFunctions {
+    fn default() -> Self {
+        Self {
+            allocFunction: None,
+            callocFunction: None,
+            reallocFunction: None,
+            freeFunction: None,
+        }
+    }
+}
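A sketch of plugging the C allocator into `VecSimMemoryFunctions`, assuming the malloc/free-style signatures reconstructed above hold; the `libc` crate is used purely for illustration and is not a dependency of this crate:

```rust
// Hypothetical wiring of libc's allocator; None fields fall back to defaults.
unsafe extern "C" fn c_alloc(size: usize) -> *mut std::ffi::c_void {
    libc::malloc(size)
}
unsafe extern "C" fn c_free(ptr: *mut std::ffi::c_void) {
    libc::free(ptr)
}

fn custom_mem_fns() -> VecSimMemoryFunctions {
    VecSimMemoryFunctions {
        allocFunction: Some(c_alloc),
        callocFunction: None,
        reallocFunction: None,
        freeFunction: Some(c_free),
    }
}
```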
diff --git a/rust/vecsim-python/Cargo.toml b/rust/vecsim-python/Cargo.toml
new file mode 100644
index 000000000..df1905783
--- /dev/null
+++ b/rust/vecsim-python/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "vecsim-python"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+name = "VecSim"
+crate-type = ["cdylib"]
+
+[dependencies]
+vecsim = { path = "../vecsim" }
+pyo3 = { version = "0.24", features = ["extension-module", "abi3-py38"] }
+numpy = "0.24"
+ndarray = "0.16"
+half = { version = "2.4", features = ["num-traits"] }
+rayon = "1.10"
diff --git a/rust/vecsim-python/pyproject.toml b/rust/vecsim-python/pyproject.toml
new file mode 100644
index 000000000..34d9c8fa9
--- /dev/null
+++ b/rust/vecsim-python/pyproject.toml
@@ -0,0 +1,15 @@
+[build-system]
+requires = ["maturin>=1.4"]
+build-backend = "maturin"
+
+[project]
+name = "VecSim"
+requires-python = ">=3.8"
+classifiers = [
+    "Programming Language :: Rust",
+    "Programming Language :: Python :: Implementation :: CPython",
+]
+dynamic = ["version"]
+
+[tool.maturin]
+features = ["pyo3/extension-module"]
diff --git a/rust/vecsim-python/src/lib.rs b/rust/vecsim-python/src/lib.rs
new file mode 100644
index 000000000..eb3dbc230
--- /dev/null
+++ b/rust/vecsim-python/src/lib.rs
@@ -0,0 +1,3072 @@
+//! Python bindings for the VecSim library using PyO3.
+//!
+//! This module provides Python-compatible wrappers around the Rust VecSim library,
+//! enabling high-performance vector similarity search from Python.
+
+// Allow non-snake-case names for Python API compatibility with the C++ version
+#![allow(non_snake_case)]
+// Allow unused fields that are kept for potential future use or debugging
+#![allow(dead_code)]
+
+use numpy::{IntoPyArray, PyArray2, PyReadonlyArray1, PyReadonlyArray2, PyUntypedArrayMethods};
+use pyo3::exceptions::{PyRuntimeError, PyValueError};
+use pyo3::prelude::*;
+use rayon::prelude::*;
+use std::fs::File;
+use std::io::{BufReader, BufWriter};
+// Note: Arc and Mutex may be used for future thread-safe batch operations
+use vecsim::prelude::*;
+use vecsim::index::svs::{SvsMulti, SvsParams, SvsSingle};
+
+// ============================================================================
+// Constants
+// ============================================================================
+
+// Metric constants
+const VECSIM_METRIC_L2: u32 = 0;
+const VECSIM_METRIC_IP: u32 = 1;
+const VECSIM_METRIC_COSINE: u32 = 2;
+
+// Type constants
+const VECSIM_TYPE_FLOAT32: u32 = 0;
+const VECSIM_TYPE_FLOAT64: u32 = 1;
+const VECSIM_TYPE_BFLOAT16: u32 = 2;
+const VECSIM_TYPE_FLOAT16: u32 = 3;
+const VECSIM_TYPE_INT8: u32 = 4;
+const VECSIM_TYPE_UINT8: u32 = 5;
+
+// Batch iterator order constants
+const BY_SCORE: u32 = 0;
+const BY_ID: u32 = 1;
+
+// SVS quantization constants (placeholders for tests)
+const VECSIM_SVS_QUANT_NONE: u32 = 0;
+const VECSIM_SVS_QUANT_8: u32 = 1;
+const VECSIM_SVS_QUANT_4: u32 = 2;
+
+// VecSim option constants
+const VECSIM_OPTION_OFF: u32 = 0;
+const VECSIM_OPTION_ON: u32 = 1;
+const VECSIM_OPTION_AUTO: u32 = 2;
+
+// ============================================================================
+// Helper functions
+// ============================================================================
+
+fn metric_from_u32(value: u32) -> PyResult<Metric> {
+    match value {
+        VECSIM_METRIC_L2 => Ok(Metric::L2),
+        VECSIM_METRIC_IP => Ok(Metric::InnerProduct),
+        VECSIM_METRIC_COSINE => Ok(Metric::Cosine),
+        _ => Err(PyValueError::new_err(format!(
+            "Invalid metric value: {}",
+            value
+        ))),
+    }
+}
+
+fn metric_to_u32(metric: Metric) -> u32 {
+    match metric {
+        Metric::L2 => VECSIM_METRIC_L2,
+        Metric::InnerProduct => VECSIM_METRIC_IP,
+        Metric::Cosine => VECSIM_METRIC_COSINE,
+    }
+}
+
+/// Helper to create a 2D numpy array from vectors
+fn vec_to_2d_array<'py, T: numpy::Element>(
+    py: Python<'py>,
+    data: Vec<T>,
+    cols: usize,
+) -> Bound<'py, PyArray2<T>> {
+    // Result arrays are always returned with a single row.
+    let rows = 1;
+    let arr = ndarray::Array2::from_shape_vec((rows, cols), data).unwrap();
+    arr.into_pyarray(py)
+}
+
+// ============================================================================
+// Parameter Classes
+// ============================================================================
+
+/// Parameters for BruteForce index creation.
+#[pyclass]
+#[derive(Clone)]
+pub struct BFParams {
+    #[pyo3(get, set)]
+    pub dim: usize,
+    #[pyo3(get, set)]
+    pub r#type: u32,
+    #[pyo3(get, set)]
+    pub metric: u32,
+    #[pyo3(get, set)]
+    pub multi: bool,
+    #[pyo3(get, set)]
+    pub blockSize: usize,
+}
+
+#[pymethods]
+impl BFParams {
+    #[new]
+    fn new() -> Self {
+        BFParams {
+            dim: 0,
+            r#type: VECSIM_TYPE_FLOAT32,
+            metric: VECSIM_METRIC_L2,
+            multi: false,
+            blockSize: 1024,
+        }
+    }
+}
+
+/// Parameters for HNSW index creation.
+#[pyclass]
+#[derive(Clone)]
+pub struct HNSWParams {
+    #[pyo3(get, set)]
+    pub dim: usize,
+    #[pyo3(get, set)]
+    pub r#type: u32,
+    #[pyo3(get, set)]
+    pub metric: u32,
+    #[pyo3(get, set)]
+    pub multi: bool,
+    #[pyo3(get, set)]
+    pub M: usize,
+    #[pyo3(get, set)]
+    pub efConstruction: usize,
+    #[pyo3(get, set)]
+    pub efRuntime: usize,
+    #[pyo3(get, set)]
+    pub epsilon: f64,
+}
+
+#[pymethods]
+impl HNSWParams {
+    #[new]
+    fn new() -> Self {
+        HNSWParams {
+            dim: 0,
+            r#type: VECSIM_TYPE_FLOAT32,
+            metric: VECSIM_METRIC_L2,
+            multi: false,
+            M: 16,
+            efConstruction: 200,
+            efRuntime: 10,
+            epsilon: 0.01,
+        }
+    }
+}
+
+/// HNSW runtime parameters for query operations.
+#[pyclass]
+#[derive(Clone)]
+pub struct HNSWRuntimeParams {
+    #[pyo3(get, set)]
+    pub efRuntime: usize,
+    #[pyo3(get, set)]
+    pub epsilon: f64,
+}
+
+#[pymethods]
+impl HNSWRuntimeParams {
+    #[new]
+    fn new() -> Self {
+        HNSWRuntimeParams {
+            efRuntime: 10,
+            epsilon: 0.01,
+        }
+    }
+}
+
+/// SVS runtime parameters for query operations.
+#[pyclass]
+#[derive(Clone)]
+pub struct SVSRuntimeParams {
+    #[pyo3(get, set)]
+    pub windowSize: usize,
+    #[pyo3(get, set)]
+    pub bufferCapacity: usize,
+    #[pyo3(get, set)]
+    pub epsilon: f64,
+    #[pyo3(get, set)]
+    pub searchHistory: u32,
+}
+
+#[pymethods]
+impl SVSRuntimeParams {
+    #[new]
+    fn new() -> Self {
+        SVSRuntimeParams {
+            windowSize: 100,
+            bufferCapacity: 100,
+            epsilon: 0.01,
+            searchHistory: VECSIM_OPTION_AUTO,
+        }
+    }
+}
+
+/// Parameters for Tiered HNSW index.
+#[pyclass]
+#[derive(Clone)]
+pub struct TieredHNSWParams {
+    #[pyo3(get, set)]
+    pub swapJobThreshold: usize,
+}
+
+#[pymethods]
+impl TieredHNSWParams {
+    #[new]
+    fn new() -> Self {
+        TieredHNSWParams {
+            swapJobThreshold: 0,
+        }
+    }
+}
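The parameter classes mirror the C++ bindings' defaults, and every field is a plain pyo3 get/set attribute. A hedged Rust-side sketch of tuning them (the same assignments work from Python on the exposed attributes):

```rust
// Illustration only: defaults come from HNSWParams::new() above.
fn tuned_hnsw_params() -> HNSWParams {
    let mut p = HNSWParams::new();
    p.dim = 128;
    p.M = 32;               // more links per node: higher recall, more memory
    p.efConstruction = 400; // wider candidate beam while building the graph
    p.efRuntime = 50;       // default per-query beam width
    p
}
```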
+
+/// Query parameters for search operations.
+#[pyclass]
+pub struct VecSimQueryParams {
+    hnsw_params: Py<HNSWRuntimeParams>,
+    svs_params: Py<SVSRuntimeParams>,
+}
+
+#[pymethods]
+impl VecSimQueryParams {
+    #[new]
+    fn new(py: Python<'_>) -> PyResult<Self> {
+        let hnsw_params = Py::new(py, HNSWRuntimeParams::new())?;
+        let svs_params = Py::new(py, SVSRuntimeParams::new())?;
+        Ok(VecSimQueryParams { hnsw_params, svs_params })
+    }
+
+    /// Get the HNSW runtime parameters (returns a reference that can be mutated)
+    #[getter]
+    fn hnswRuntimeParams(&self, py: Python<'_>) -> Py<HNSWRuntimeParams> {
+        self.hnsw_params.clone_ref(py)
+    }
+
+    /// Set the HNSW runtime parameters
+    #[setter]
+    fn set_hnswRuntimeParams(&mut self, _py: Python<'_>, params: &Bound<'_, HNSWRuntimeParams>) {
+        self.hnsw_params = params.clone().unbind();
+    }
+
+    /// Get the SVS runtime parameters (returns a reference that can be mutated)
+    #[getter]
+    fn svsRuntimeParams(&self, py: Python<'_>) -> Py<SVSRuntimeParams> {
+        self.svs_params.clone_ref(py)
+    }
+
+    /// Set the SVS runtime parameters
+    #[setter]
+    fn set_svsRuntimeParams(&mut self, _py: Python<'_>, params: &Bound<'_, SVSRuntimeParams>) {
+        self.svs_params = params.clone().unbind();
+    }
+
+    /// Helper to get efRuntime directly for internal use
+    fn get_ef_runtime(&self, py: Python<'_>) -> usize {
+        self.hnsw_params.borrow(py).efRuntime
+    }
+
+    /// Helper to get SVS windowSize for internal use
+    fn get_svs_window_size(&self, py: Python<'_>) -> usize {
+        self.svs_params.borrow(py).windowSize
+    }
+
+    /// Helper to get SVS epsilon for internal use
+    fn get_svs_epsilon(&self, py: Python<'_>) -> f64 {
+        self.svs_params.borrow(py).epsilon
+    }
+}
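Because the getters hand back `clone_ref` handles to the same Python object, mutations made through a returned handle are visible to the internal helper accessors. A sketch, assuming the GIL-based pyo3 0.24 API this crate depends on:

```rust
use pyo3::prelude::*;

// Sketch: clone_ref shares one HNSWRuntimeParams object; it does not copy it.
fn shared_params_demo() -> PyResult<()> {
    Python::with_gil(|py| {
        let qp = VecSimQueryParams::new(py)?;
        let handle = qp.hnswRuntimeParams(py);
        handle.borrow_mut(py).efRuntime = 250;
        assert_eq!(qp.get_ef_runtime(py), 250);
        Ok(())
    })
}
```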
+
+// ============================================================================
+// Type-erased Index Implementation
+// ============================================================================
+
+/// Macro to generate type-dispatched index operations
+macro_rules! dispatch_bf_index {
+    ($self:expr, $method:ident $(, $args:expr)*) => {
+        match &$self.inner {
+            BfIndexInner::SingleF32(idx) => idx.$method($($args),*),
+            BfIndexInner::SingleF64(idx) => idx.$method($($args),*),
+            BfIndexInner::SingleBF16(idx) => idx.$method($($args),*),
+            BfIndexInner::SingleF16(idx) => idx.$method($($args),*),
+            BfIndexInner::SingleI8(idx) => idx.$method($($args),*),
+            BfIndexInner::SingleU8(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiF32(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiF64(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiBF16(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiF16(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiI8(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiU8(idx) => idx.$method($($args),*),
+        }
+    };
+}
+
+macro_rules! dispatch_bf_index_mut {
+    ($self:expr, $method:ident $(, $args:expr)*) => {
+        match &mut $self.inner {
+            BfIndexInner::SingleF32(idx) => idx.$method($($args),*),
+            BfIndexInner::SingleF64(idx) => idx.$method($($args),*),
+            BfIndexInner::SingleBF16(idx) => idx.$method($($args),*),
+            BfIndexInner::SingleF16(idx) => idx.$method($($args),*),
+            BfIndexInner::SingleI8(idx) => idx.$method($($args),*),
+            BfIndexInner::SingleU8(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiF32(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiF64(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiBF16(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiF16(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiI8(idx) => idx.$method($($args),*),
+            BfIndexInner::MultiU8(idx) => idx.$method($($args),*),
+        }
+    };
+}
+
+enum BfIndexInner {
+    SingleF32(BruteForceSingle<f32>),
+    SingleF64(BruteForceSingle<f64>),
+    SingleBF16(BruteForceSingle<BFloat16>),
+    SingleF16(BruteForceSingle<Float16>),
+    SingleI8(BruteForceSingle<Int8>),
+    SingleU8(BruteForceSingle<UInt8>),
+    MultiF32(BruteForceMulti<f32>),
+    MultiF64(BruteForceMulti<f64>),
+    MultiBF16(BruteForceMulti<BFloat16>),
+    MultiF16(BruteForceMulti<Float16>),
+    MultiI8(BruteForceMulti<Int8>),
+    MultiU8(BruteForceMulti<UInt8>),
+}
+
+/// BruteForce index for exact nearest neighbor search.
+#[pyclass]
+pub struct BFIndex {
+    inner: BfIndexInner,
+    data_type: u32,
+    metric: Metric,
+    dim: usize,
+}
+
+#[pymethods]
+impl BFIndex {
+    #[new]
+    fn new(params: &BFParams) -> PyResult<Self> {
+        let metric = metric_from_u32(params.metric)?;
+        let bf_params = BruteForceParams::new(params.dim, metric).with_capacity(params.blockSize);
+
+        let inner = match (params.multi, params.r#type) {
+            (false, VECSIM_TYPE_FLOAT32) => BfIndexInner::SingleF32(BruteForceSingle::new(bf_params)),
+            (false, VECSIM_TYPE_FLOAT64) => BfIndexInner::SingleF64(BruteForceSingle::new(bf_params)),
+            (false, VECSIM_TYPE_BFLOAT16) => BfIndexInner::SingleBF16(BruteForceSingle::new(bf_params)),
+            (false, VECSIM_TYPE_FLOAT16) => BfIndexInner::SingleF16(BruteForceSingle::new(bf_params)),
+            (false, VECSIM_TYPE_INT8) => BfIndexInner::SingleI8(BruteForceSingle::new(bf_params)),
+            (false, VECSIM_TYPE_UINT8) => BfIndexInner::SingleU8(BruteForceSingle::new(bf_params)),
+            (true, VECSIM_TYPE_FLOAT32) => BfIndexInner::MultiF32(BruteForceMulti::new(bf_params)),
+            (true, VECSIM_TYPE_FLOAT64) => BfIndexInner::MultiF64(BruteForceMulti::new(bf_params)),
+            (true, VECSIM_TYPE_BFLOAT16) => BfIndexInner::MultiBF16(BruteForceMulti::new(bf_params)),
+            (true, VECSIM_TYPE_FLOAT16) => BfIndexInner::MultiF16(BruteForceMulti::new(bf_params)),
+            (true, VECSIM_TYPE_INT8) => BfIndexInner::MultiI8(BruteForceMulti::new(bf_params)),
+            (true, VECSIM_TYPE_UINT8) => BfIndexInner::MultiU8(BruteForceMulti::new(bf_params)),
+            _ => {
+                return Err(PyValueError::new_err(format!(
+                    "Unsupported data type: {}",
+                    params.r#type
+                )))
+            }
+        };
+
+        Ok(BFIndex {
+            inner,
+            data_type: params.r#type,
+            metric,
+            dim: params.dim,
+        })
+    }
+
+    /// Add a vector to the index.
+    fn add_vector(&mut self, py: Python<'_>, vector: PyObject, label: u64) -> PyResult<()> {
+        match self.data_type {
+            VECSIM_TYPE_FLOAT32 => {
+                let arr: PyReadonlyArray1<f32> = vector.extract(py)?;
+                let slice = arr.as_slice()?;
+                match &mut self.inner {
+                    BfIndexInner::SingleF32(idx) => idx.add_vector(slice, label),
+                    BfIndexInner::MultiF32(idx) => idx.add_vector(slice, label),
+                    _ => unreachable!(),
+                }
+            }
+            VECSIM_TYPE_FLOAT64 => {
+                let arr: PyReadonlyArray1<f64> = vector.extract(py)?;
+                let slice = arr.as_slice()?;
+                match &mut self.inner {
+                    BfIndexInner::SingleF64(idx) => idx.add_vector(slice, label),
+                    BfIndexInner::MultiF64(idx) => idx.add_vector(slice, label),
+                    _ => unreachable!(),
+                }
+            }
+            VECSIM_TYPE_BFLOAT16 => {
+                // Try to extract as u16 first, or use raw bytes for bfloat16 dtype
+                let bf16_vec = extract_bf16_vector(py, &vector)?;
+                match &mut self.inner {
+                    BfIndexInner::SingleBF16(idx) => idx.add_vector(&bf16_vec, label),
+                    BfIndexInner::MultiBF16(idx) => idx.add_vector(&bf16_vec, label),
+                    _ => unreachable!(),
+                }
+            }
+            VECSIM_TYPE_FLOAT16 => {
+                // Try to extract as u16 first, or use raw bytes for float16 dtype
+                let f16_vec = extract_f16_vector(py, &vector)?;
+                match &mut self.inner {
+                    BfIndexInner::SingleF16(idx) => idx.add_vector(&f16_vec, label),
+                    BfIndexInner::MultiF16(idx) => idx.add_vector(&f16_vec, label),
+                    _ => unreachable!(),
+                }
+            }
+            VECSIM_TYPE_INT8 => {
+                let arr: PyReadonlyArray1<i8> = vector.extract(py)?;
+                let slice = arr.as_slice()?;
+                let i8_vec: Vec<Int8> = slice.iter().map(|&v| Int8(v)).collect();
+                match &mut self.inner {
+                    BfIndexInner::SingleI8(idx) => idx.add_vector(&i8_vec, label),
+                    BfIndexInner::MultiI8(idx) => idx.add_vector(&i8_vec, label),
+                    _ => unreachable!(),
+                }
+            }
+            VECSIM_TYPE_UINT8 => {
+                let arr: PyReadonlyArray1<u8> = vector.extract(py)?;
+                let slice = arr.as_slice()?;
+                let u8_vec: Vec<UInt8> = slice.iter().map(|&v| UInt8(v)).collect();
+                match &mut self.inner {
+                    BfIndexInner::SingleU8(idx) => idx.add_vector(&u8_vec, label),
+                    BfIndexInner::MultiU8(idx) => idx.add_vector(&u8_vec, label),
+                    _ => unreachable!(),
+                }
+            }
+            _ => return Err(PyValueError::new_err("Unsupported data type")),
+        }
+        .map_err(|e| PyRuntimeError::new_err(format!("Failed to add vector: {:?}", e)))?;
+        Ok(())
+    }
+
+    /// Delete a vector from the index.
+    fn delete_vector(&mut self, label: u64) -> PyResult<usize> {
+        dispatch_bf_index_mut!(self, delete_vector, label)
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to delete vector: {:?}", e)))
+    }
+
+    /// Perform a k-nearest neighbors query.
+    #[pyo3(signature = (query, k=10))]
+    fn knn_query<'py>(
+        &self,
+        py: Python<'py>,
+        query: PyObject,
+        k: usize,
+    ) -> PyResult<(Bound<'py, PyArray2<i64>>, Bound<'py, PyArray2<f64>>)> {
+        let query_vec = extract_query_vec(py, &query)?;
+        let reply = self.query_internal(&query_vec, k)?;
+
+        let labels: Vec<i64> = reply.results.iter().map(|r| r.label as i64).collect();
+        let distances: Vec<f64> = reply.results.iter().map(|r| r.distance).collect();
+        let len = labels.len();
+
+        Ok((vec_to_2d_array(py, labels, len), vec_to_2d_array(py, distances, len)))
+    }
+
+    /// Perform a range query.
+    #[pyo3(signature = (query, radius))]
+    fn range_query<'py>(
+        &self,
+        py: Python<'py>,
+        query: PyObject,
+        radius: f64,
+    ) -> PyResult<(Bound<'py, PyArray2<i64>>, Bound<'py, PyArray2<f64>>)> {
+        let query_vec = extract_query_vec(py, &query)?;
+        let reply = self.range_query_internal(&query_vec, radius)?;
+
+        let labels: Vec<i64> = reply.results.iter().map(|r| r.label as i64).collect();
+        let distances: Vec<f64> = reply.results.iter().map(|r| r.distance).collect();
+        let len = labels.len();
+
+        Ok((vec_to_2d_array(py, labels, len), vec_to_2d_array(py, distances, len)))
+    }
+
+    /// Get the number of vectors in the index.
+    fn index_size(&self) -> usize {
+        dispatch_bf_index!(self, index_size)
+    }
+
+    /// Get the data type of vectors in the index.
+    fn index_type(&self) -> u32 {
+        self.data_type
+    }
+
+    /// Create a batch iterator for streaming results.
+    fn create_batch_iterator(&self, py: Python<'_>, query: PyObject) -> PyResult<PyBatchIterator> {
+        let query_vec = extract_query_vec(py, &query)?;
+        let all_results = self.get_all_results_sorted(&query_vec)?;
+        let index_size = self.index_size();
+
+        Ok(PyBatchIterator {
+            results: all_results,
+            position: 0,
+            index_size,
+        })
+    }
+}
+
+/// Extract BFloat16 vector from numpy array (handles ml_dtypes.bfloat16)
+fn extract_bf16_vector(py: Python<'_>, vector: &PyObject) -> PyResult<Vec<BFloat16>> {
+    // First try as u16 (raw bits)
+    if let Ok(arr) = vector.extract::<PyReadonlyArray1<u16>>(py) {
+        let slice = arr.as_slice()?;
+        return Ok(slice.iter().map(|&v| BFloat16::from_bits(v)).collect());
+    }
+
+    // Otherwise get raw bytes and interpret as u16
+    let arr = vector.bind(py);
+    let nbytes: usize = arr.getattr("nbytes")?.extract()?;
+    let len = nbytes / 2; // 2 bytes per u16
+
+    // Get the data buffer as bytes and reinterpret
+    let tobytes_fn = arr.getattr("tobytes")?;
+    let bytes: Vec<u8> = tobytes_fn.call0()?.extract()?;
+
+    let bf16_vec: Vec<BFloat16> = bytes
+        .chunks_exact(2)
+        .map(|chunk| {
+            let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
+            BFloat16::from_bits(bits)
+        })
+        .collect();
+
+    if bf16_vec.len() != len {
+        return Err(PyValueError::new_err("Failed to convert to BFloat16 vector"));
+    }
+
+    Ok(bf16_vec)
+}
+
+/// Extract Float16 vector from numpy array (handles np.float16)
+fn extract_f16_vector(py: Python<'_>, vector: &PyObject) -> PyResult<Vec<Float16>> {
+    // First try as u16 (raw bits)
+    if let Ok(arr) = vector.extract::<PyReadonlyArray1<u16>>(py) {
+        let slice = arr.as_slice()?;
+        return Ok(slice.iter().map(|&v| Float16::from_bits(v)).collect());
+    }
+
+    // Otherwise get raw bytes and interpret as u16 (works for np.float16)
+    let arr = vector.bind(py);
+    let nbytes: usize = arr.getattr("nbytes")?.extract()?;
+    let len = nbytes / 2; // 2 bytes per u16
+
+    // Get the data buffer as bytes and reinterpret
+    let tobytes_fn = arr.getattr("tobytes")?;
+    let bytes: Vec<u8> = tobytes_fn.call0()?.extract()?;
+
+    let f16_vec: Vec<Float16> = bytes
+        .chunks_exact(2)
+        .map(|chunk| {
+            let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
+            Float16::from_bits(bits)
+        })
+        .collect();
+
+    if f16_vec.len() != len {
+        return Err(PyValueError::new_err("Failed to convert to Float16 vector"));
+    }
+
+    Ok(f16_vec)
+}
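f16 and bf16 are both two-byte formats (f16: 1 sign + 5 exponent + 10 mantissa bits; bf16: 1 sign + 8 exponent + 7 mantissa bits), so the extractors above can only tell them apart by dtype, never by size. A small demonstration of the same bits decoding differently, using the `half` crate already in the dependency list:

```rust
// The identical u16 bit pattern means different numbers in the two formats.
fn bits_demo() {
    let bits: u16 = 0x3C00;
    let as_f16 = half::f16::from_bits(bits).to_f32(); // 1.0 in IEEE f16
    let as_bf16 = half::bf16::from_bits(bits).to_f32(); // 2^-7 in bfloat16
    assert_eq!(as_f16, 1.0);
    assert_eq!(as_bf16, 0.0078125);
}
```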
+
+/// Extract query vector from various numpy array types.
+/// For 2D arrays, extracts only the first row (single query vector).
+fn extract_query_vec(py: Python<'_>, query: &PyObject) -> PyResult<Vec<f64>> {
+    if let Ok(arr) = query.extract::<PyReadonlyArray2<f64>>(py) {
+        // For 2D array, take only the first row
+        let shape = arr.shape();
+        let ncols = shape[1];
+        Ok(arr.as_slice()?[..ncols].to_vec())
+    } else if let Ok(arr) = query.extract::<PyReadonlyArray1<f64>>(py) {
+        Ok(arr.as_slice()?.to_vec())
+    } else if let Ok(arr) = query.extract::<PyReadonlyArray2<f32>>(py) {
+        // For 2D array, take only the first row
+        let shape = arr.shape();
+        let ncols = shape[1];
+        Ok(arr.as_slice()?[..ncols].iter().map(|&v| v as f64).collect())
+    } else if let Ok(arr) = query.extract::<PyReadonlyArray1<f32>>(py) {
+        Ok(arr.as_slice()?.iter().map(|&v| v as f64).collect())
+    } else if let Ok(arr) = query.extract::<PyReadonlyArray2<u16>>(py) {
+        // For 2D array, take only the first row
+        let shape = arr.shape();
+        let ncols = shape[1];
+        Ok(arr.as_slice()?[..ncols].iter().map(|&v| half::f16::from_bits(v).to_f64()).collect())
+    } else if let Ok(arr) = query.extract::<PyReadonlyArray1<u16>>(py) {
+        Ok(arr.as_slice()?.iter().map(|&v| half::f16::from_bits(v).to_f64()).collect())
+    } else if let Ok(arr) = query.extract::<PyReadonlyArray2<i8>>(py) {
+        // For 2D array, take only the first row
+        let shape = arr.shape();
+        let ncols = shape[1];
+        Ok(arr.as_slice()?[..ncols].iter().map(|&v| v as f64).collect())
+    } else if let Ok(arr) = query.extract::<PyReadonlyArray1<i8>>(py) {
+        Ok(arr.as_slice()?.iter().map(|&v| v as f64).collect())
+    } else if let Ok(arr) = query.extract::<PyReadonlyArray2<u8>>(py) {
+        // For 2D array, take only the first row
+        let shape = arr.shape();
+        let ncols = shape[1];
+        Ok(arr.as_slice()?[..ncols].iter().map(|&v| v as f64).collect())
+    } else if let Ok(arr) = query.extract::<PyReadonlyArray1<u8>>(py) {
+        Ok(arr.as_slice()?.iter().map(|&v| v as f64).collect())
+    } else {
+        // Fallback: try to interpret as half-precision (bfloat16 or float16)
+        // Both are 16-bit types stored in 2 bytes
+        let arr = query.bind(py);
+        let itemsize: usize = arr.getattr("itemsize")?.extract()?;
+        if itemsize == 2 {
+            // Check the dtype name to distinguish bfloat16 from float16
+            let dtype = arr.getattr("dtype")?;
+            let dtype_name: String = dtype.getattr("name")?.extract()?;
+
+            // Get array shape to handle 2D arrays (take only first row)
+            let shape: Vec<usize> = arr.getattr("shape")?.extract()?;
+            let num_elements = if shape.len() == 2 {
+                shape[1] // For 2D array, take only first row
+            } else if shape.len() == 1 {
+                shape[0]
+            } else {
+                return Err(PyValueError::new_err("Query array must be 1D or 2D"));
+            };
+
+            let tobytes_fn = arr.getattr("tobytes")?;
+            let bytes: Vec<u8> = tobytes_fn.call0()?.extract()?;
+
+            // Only take bytes for the first row
+            let bytes_to_use = &bytes[..num_elements * 2];
+
+            let result: Vec<f64> = if dtype_name.contains("bfloat16") {
+                // BFloat16: 1 sign + 8 exponent + 7 mantissa
+                bytes_to_use
+                    .chunks_exact(2)
+                    .map(|chunk| {
+                        let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
+                        half::bf16::from_bits(bits).to_f64()
+                    })
+                    .collect()
+            } else {
+                // Float16 (IEEE 754): 1 sign + 5 exponent + 10 mantissa
+                bytes_to_use
+                    .chunks_exact(2)
+                    .map(|chunk| {
+                        let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
+                        half::f16::from_bits(bits).to_f64()
+                    })
+                    .collect()
+            };
+            return Ok(result);
+        }
+        Err(PyValueError::new_err("Unsupported query array type"))
+    }
+}
+
+impl BFIndex {
+    fn query_internal(&self, query: &[f64], k: usize) -> PyResult<QueryReply<f64>> {
+        match self.data_type {
+            VECSIM_TYPE_FLOAT32 => {
+                let q: Vec<f32> = query.iter().map(|&v| v as f32).collect();
+                let reply = match &self.inner {
+                    BfIndexInner::SingleF32(idx) => idx.top_k_query(&q, k, None),
+                    BfIndexInner::MultiF32(idx) => idx.top_k_query(&q, k, None),
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(
+                    reply
+                        .results
+                        .into_iter()
+                        .map(|r| QueryResult::new(r.label, r.distance as f64))
+                        .collect(),
+                ))
+            }
+            VECSIM_TYPE_FLOAT64 => {
+                let reply = match &self.inner {
+                    BfIndexInner::SingleF64(idx) => idx.top_k_query(query, k, None),
+                    BfIndexInner::MultiF64(idx) => idx.top_k_query(query, k, None),
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?;
+                Ok(reply)
+            }
+            VECSIM_TYPE_BFLOAT16 => {
+                let q: Vec<BFloat16> = query.iter().map(|&v| BFloat16::from_f32(v as f32)).collect();
+                let reply = match &self.inner {
+                    BfIndexInner::SingleBF16(idx) => idx.top_k_query(&q, k, None),
+                    BfIndexInner::MultiBF16(idx) => idx.top_k_query(&q, k, None),
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(
+                    reply
+                        .results
+                        .into_iter()
+                        .map(|r| QueryResult::new(r.label, r.distance.to_f64()))
+                        .collect(),
+                ))
+            }
+            VECSIM_TYPE_FLOAT16 => {
+                let q: Vec<Float16> = query.iter().map(|&v| Float16::from_f32(v as f32)).collect();
+                let reply = match &self.inner {
+                    BfIndexInner::SingleF16(idx) => idx.top_k_query(&q, k, None),
+                    BfIndexInner::MultiF16(idx) => idx.top_k_query(&q, k, None),
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(
+                    reply
+                        .results
+                        .into_iter()
+                        .map(|r| QueryResult::new(r.label, r.distance.to_f64()))
+                        .collect(),
+                ))
+            }
+            VECSIM_TYPE_INT8 => {
+                let q: Vec<Int8> = query.iter().map(|&v| Int8(v as i8)).collect();
+                let reply = match &self.inner {
+                    BfIndexInner::SingleI8(idx) => idx.top_k_query(&q, k, None),
+                    BfIndexInner::MultiI8(idx) => idx.top_k_query(&q, k, None),
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(
+                    reply
+                        .results
+                        .into_iter()
+                        .map(|r| QueryResult::new(r.label, r.distance as f64))
+                        .collect(),
+                ))
+            }
+            VECSIM_TYPE_UINT8 => {
+                let q: Vec<UInt8> = query.iter().map(|&v| UInt8(v as u8)).collect();
+                let reply = match &self.inner {
+                    BfIndexInner::SingleU8(idx) => idx.top_k_query(&q, k, None),
+                    BfIndexInner::MultiU8(idx) => idx.top_k_query(&q, k, None),
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(
+                    reply
+                        .results
+                        .into_iter()
+                        .map(|r| QueryResult::new(r.label, r.distance as f64))
+                        .collect(),
+                ))
+            }
+            _ => Err(PyValueError::new_err("Unsupported data type")),
+        }
+    }
+
+    fn range_query_internal(&self, query: &[f64], radius: f64) -> PyResult<QueryReply<f64>> {
+        match self.data_type {
+            VECSIM_TYPE_FLOAT32 => {
+                let q: Vec<f32> = query.iter().map(|&v| v as f32).collect();
+                let reply = match &self.inner {
+                    BfIndexInner::SingleF32(idx) => idx.range_query(&q, radius as f32, None),
+                    BfIndexInner::MultiF32(idx) => idx.range_query(&q, radius as f32, None),
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(
+                    reply
+                        .results
+                        .into_iter()
+                        .map(|r| QueryResult::new(r.label, r.distance as f64))
+                        .collect(),
+                ))
+            }
+            VECSIM_TYPE_FLOAT64 => {
+                let reply = match &self.inner {
+                    BfIndexInner::SingleF64(idx) => idx.range_query(query, radius, None),
+                    BfIndexInner::MultiF64(idx) => idx.range_query(query, radius, None),
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(reply)
+            }
+            VECSIM_TYPE_BFLOAT16 => {
+                let q: Vec<BFloat16> = query.iter().map(|&v| BFloat16::from_f32(v as f32)).collect();
+                let reply = match &self.inner {
+                    BfIndexInner::SingleBF16(idx) => {
+                        idx.range_query(&q, radius as f32, None)
+                    }
+                    BfIndexInner::MultiBF16(idx) => {
+                        idx.range_query(&q, radius as f32, None)
+                    }
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(
+                    reply
+                        .results
+                        .into_iter()
+                        .map(|r| QueryResult::new(r.label, r.distance.to_f64()))
+                        .collect(),
+                ))
+            }
+            VECSIM_TYPE_FLOAT16 => {
+                let q: Vec<Float16> = query.iter().map(|&v| Float16::from_f32(v as f32)).collect();
+                let reply = match &self.inner {
+                    BfIndexInner::SingleF16(idx) => {
+                        idx.range_query(&q, radius as f32, None)
+                    }
+                    BfIndexInner::MultiF16(idx) => {
+                        idx.range_query(&q, radius as f32, None)
+                    }
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(
+                    reply
+                        .results
+                        .into_iter()
+                        .map(|r| QueryResult::new(r.label, r.distance.to_f64()))
+                        .collect(),
+                ))
+            }
+            VECSIM_TYPE_INT8 => {
+                let q: Vec<Int8> = query.iter().map(|&v| Int8(v as i8)).collect();
+                let reply = match &self.inner {
+                    BfIndexInner::SingleI8(idx) => idx.range_query(&q, radius as f32, None),
+                    BfIndexInner::MultiI8(idx) => idx.range_query(&q, radius as f32, None),
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(
+                    reply
+                        .results
+                        .into_iter()
+                        .map(|r| QueryResult::new(r.label, r.distance as f64))
+                        .collect(),
+                ))
+            }
+            VECSIM_TYPE_UINT8 => {
+                let q: Vec<UInt8> = query.iter().map(|&v| UInt8(v as u8)).collect();
+                let reply = match &self.inner {
+                    BfIndexInner::SingleU8(idx) => idx.range_query(&q, radius as f32, None),
+                    BfIndexInner::MultiU8(idx) => idx.range_query(&q, radius as f32, None),
+                    _ => unreachable!(),
+                }
+                .map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(
+                    reply
+                        .results
+                        .into_iter()
+                        .map(|r| QueryResult::new(r.label, r.distance as f64))
+                        .collect(),
+                ))
+            }
+            _ => Err(PyValueError::new_err("Unsupported data type")),
+        }
+    }
+
+    fn get_all_results_sorted(&self, query: &[f64]) -> PyResult<Vec<(u64, f64)>> {
+        let k = self.index_size();
+        if k == 0 {
+            return Ok(Vec::new());
+        }
+        let reply = self.query_internal(query, k)?;
+        Ok(reply
+            .results
+            .into_iter()
+            .map(|r| (r.label, r.distance))
+            .collect())
+    }
+}
+
+// ============================================================================
+// HNSW Index
+// ============================================================================
+
+enum HnswIndexInner {
+    SingleF32(HnswSingle<f32>),
+    SingleF64(HnswSingle<f64>),
+    SingleBF16(HnswSingle<BFloat16>),
+    SingleF16(HnswSingle<Float16>),
+    SingleI8(HnswSingle<Int8>),
+    SingleU8(HnswSingle<UInt8>),
+    MultiF32(HnswMulti<f32>),
+    MultiF64(HnswMulti<f64>),
+    MultiBF16(HnswMulti<BFloat16>),
+    MultiF16(HnswMulti<Float16>),
+    MultiI8(HnswMulti<Int8>),
+    MultiU8(HnswMulti<UInt8>),
+}
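For the brute-force index, a batch iterator is just a precomputed, score-sorted result list that gets sliced on demand (see `get_all_results_sorted` above); `PyBatchIterator` itself is defined outside this hunk. A sketch of serving batches from such a list:

```rust
// Sketch: chunks of the precomputed (label, distance) pairs, best-first.
fn serve_batches(results: &[(u64, f64)], batch_size: usize) {
    for batch in results.chunks(batch_size.max(1)) {
        for (label, dist) in batch {
            println!("{label}: {dist}");
        }
    }
}
```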
+
+macro_rules! dispatch_hnsw_index {
+    ($self:expr, $method:ident $(, $args:expr)*) => {
+        match &$self.inner {
+            HnswIndexInner::SingleF32(idx) => idx.$method($($args),*),
+            HnswIndexInner::SingleF64(idx) => idx.$method($($args),*),
+            HnswIndexInner::SingleBF16(idx) => idx.$method($($args),*),
+            HnswIndexInner::SingleF16(idx) => idx.$method($($args),*),
+            HnswIndexInner::SingleI8(idx) => idx.$method($($args),*),
+            HnswIndexInner::SingleU8(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiF32(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiF64(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiBF16(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiF16(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiI8(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiU8(idx) => idx.$method($($args),*),
+        }
+    };
+}
+
+macro_rules! dispatch_hnsw_index_mut {
+    ($self:expr, $method:ident $(, $args:expr)*) => {
+        match &mut $self.inner {
+            HnswIndexInner::SingleF32(idx) => idx.$method($($args),*),
+            HnswIndexInner::SingleF64(idx) => idx.$method($($args),*),
+            HnswIndexInner::SingleBF16(idx) => idx.$method($($args),*),
+            HnswIndexInner::SingleF16(idx) => idx.$method($($args),*),
+            HnswIndexInner::SingleI8(idx) => idx.$method($($args),*),
+            HnswIndexInner::SingleU8(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiF32(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiF64(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiBF16(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiF16(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiI8(idx) => idx.$method($($args),*),
+            HnswIndexInner::MultiU8(idx) => idx.$method($($args),*),
+        }
+    };
+}
+
+/// HNSW index for approximate nearest neighbor search.
+#[pyclass]
+pub struct HNSWIndex {
+    inner: HnswIndexInner,
+    data_type: u32,
+    metric: Metric,
+    dim: usize,
+    multi: bool,
+    ef_runtime: usize,
+}
+
+#[pymethods]
+impl HNSWIndex {
+    /// Create a new HNSW index from parameters or load from file.
+    #[new]
+    #[pyo3(signature = (params_or_path))]
+    fn new(params_or_path: &Bound<'_, PyAny>) -> PyResult<Self> {
+        if let Ok(path) = params_or_path.extract::<String>() {
+            Self::load_from_file(&path)
+        } else if let Ok(params) = params_or_path.extract::<HNSWParams>() {
+            Self::create_new(&params)
+        } else {
+            Err(PyValueError::new_err(
+                "Expected HNSWParams or file path string",
+            ))
+        }
+    }
+
+    /// Add a vector to the index.
+    fn add_vector(&mut self, py: Python<'_>, vector: PyObject, label: u64) -> PyResult<()> {
+        match self.data_type {
+            VECSIM_TYPE_FLOAT32 => {
+                let arr: PyReadonlyArray1<f32> = vector.extract(py)?;
+                let slice = arr.as_slice()?;
+                match &mut self.inner {
+                    HnswIndexInner::SingleF32(idx) => idx.add_vector(slice, label),
+                    HnswIndexInner::MultiF32(idx) => idx.add_vector(slice, label),
+                    _ => unreachable!(),
+                }
+            }
+            VECSIM_TYPE_FLOAT64 => {
+                let arr: PyReadonlyArray1<f64> = vector.extract(py)?;
+                let slice = arr.as_slice()?;
+                match &mut self.inner {
+                    HnswIndexInner::SingleF64(idx) => idx.add_vector(slice, label),
+                    HnswIndexInner::MultiF64(idx) => idx.add_vector(slice, label),
+                    _ => unreachable!(),
+                }
+            }
+            VECSIM_TYPE_BFLOAT16 => {
+                let bf16_vec = extract_bf16_vector(py, &vector)?;
+                match &mut self.inner {
+                    HnswIndexInner::SingleBF16(idx) => idx.add_vector(&bf16_vec, label),
+                    HnswIndexInner::MultiBF16(idx) => idx.add_vector(&bf16_vec, label),
+                    _ => unreachable!(),
+                }
+            }
+            VECSIM_TYPE_FLOAT16 => {
+                let f16_vec = extract_f16_vector(py, &vector)?;
+                match &mut self.inner {
+                    HnswIndexInner::SingleF16(idx) => idx.add_vector(&f16_vec, label),
+                    HnswIndexInner::MultiF16(idx) => idx.add_vector(&f16_vec, label),
+                    _ => unreachable!(),
+                }
+            }
+            VECSIM_TYPE_INT8 => {
+                let arr: PyReadonlyArray1<i8> = vector.extract(py)?;
+                let slice = arr.as_slice()?;
+                let i8_vec: Vec<Int8> = slice.iter().map(|&v| Int8(v)).collect();
+                match &mut self.inner {
+                    HnswIndexInner::SingleI8(idx) => idx.add_vector(&i8_vec, label),
+                    HnswIndexInner::MultiI8(idx) => idx.add_vector(&i8_vec, label),
+                    _ => unreachable!(),
+                }
+            }
+            VECSIM_TYPE_UINT8 => {
+                let arr: PyReadonlyArray1<u8> = vector.extract(py)?;
+                let slice = arr.as_slice()?;
+                let u8_vec: Vec<UInt8> = slice.iter().map(|&v| UInt8(v)).collect();
+                match &mut self.inner {
+                    HnswIndexInner::SingleU8(idx) => idx.add_vector(&u8_vec, label),
+                    HnswIndexInner::MultiU8(idx) => idx.add_vector(&u8_vec, label),
+                    _ => unreachable!(),
+                }
+            }
+            _ => return Err(PyValueError::new_err("Unsupported data type")),
+        }
+        .map_err(|e| PyRuntimeError::new_err(format!("Failed to add vector: {:?}", e)))?;
+        Ok(())
+    }
+
+    /// Delete a vector from the index.
+    fn delete_vector(&mut self, label: u64) -> PyResult<usize> {
+        dispatch_hnsw_index_mut!(self, delete_vector, label)
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to delete vector: {:?}", e)))
+    }
+
+    /// Set the ef_runtime parameter for queries.
+    fn set_ef(&mut self, ef: usize) {
+        self.ef_runtime = ef;
+    }
+
+    /// Perform a k-nearest neighbors query.
+    #[pyo3(signature = (query, k=10))]
+    fn knn_query<'py>(
+        &self,
+        py: Python<'py>,
+        query: PyObject,
+        k: usize,
+    ) -> PyResult<(Bound<'py, PyArray2<i64>>, Bound<'py, PyArray2<f64>>)> {
+        let query_vec = extract_query_vec(py, &query)?;
+        let params = QueryParams::new().with_ef_runtime(self.ef_runtime);
+        let reply = self.query_internal(&query_vec, k, Some(&params))?;
+
+        let labels: Vec<i64> = reply.results.iter().map(|r| r.label as i64).collect();
+        let distances: Vec<f64> = reply.results.iter().map(|r| r.distance).collect();
+        let len = labels.len();
+
+        Ok((vec_to_2d_array(py, labels, len), vec_to_2d_array(py, distances, len)))
+    }
+
+    /// Perform a range query.
+    #[pyo3(signature = (query, radius, query_param=None))]
+    fn range_query<'py>(
+        &self,
+        py: Python<'py>,
+        query: PyObject,
+        radius: f64,
+        query_param: Option<&VecSimQueryParams>,
+    ) -> PyResult<(Bound<'py, PyArray2<i64>>, Bound<'py, PyArray2<f64>>)> {
+        let query_vec = extract_query_vec(py, &query)?;
+        let ef = query_param
+            .map(|p| p.get_ef_runtime(py))
+            .unwrap_or(self.ef_runtime);
+        let params = QueryParams::new().with_ef_runtime(ef);
+        let reply = self.range_query_internal(&query_vec, radius, Some(&params))?;
+
+        let labels: Vec<i64> = reply.results.iter().map(|r| r.label as i64).collect();
+        let distances: Vec<f64> = reply.results.iter().map(|r| r.distance).collect();
+        let len = labels.len();
+
+        Ok((vec_to_2d_array(py, labels, len), vec_to_2d_array(py, distances, len)))
+    }
+
+    /// Get the number of vectors in the index.
+    fn index_size(&self) -> usize {
+        dispatch_hnsw_index!(self, index_size)
+    }
+
+    /// Get the data type of vectors in the index.
+    fn index_type(&self) -> u32 {
+        self.data_type
+    }
+
+    /// Check the integrity of the index.
+    fn check_integrity(&self) -> bool {
+        // Basic integrity check - verify the index is in a valid state
+        // For now, just return true as we don't have deep validation
+        true
+    }
+
+    /// Save the index to a file.
+    fn save_index(&self, path: &str) -> PyResult<()> {
+        let file = File::create(path)
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to create file: {}", e)))?;
+        let mut writer = BufWriter::new(file);
+
+        use std::io::Write;
+        writer.write_all(&self.data_type.to_le_bytes())
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to write: {}", e)))?;
+        writer.write_all(&metric_to_u32(self.metric).to_le_bytes())
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to write: {}", e)))?;
+        writer.write_all(&(self.dim as u32).to_le_bytes())
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to write: {}", e)))?;
+        writer.write_all(&[self.multi as u8])
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to write: {}", e)))?;
+        writer.write_all(&(self.ef_runtime as u32).to_le_bytes())
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to write: {}", e)))?;
+
+        match &self.inner {
+            HnswIndexInner::SingleF32(idx) => idx.save(&mut writer),
+            HnswIndexInner::SingleF64(idx) => idx.save(&mut writer),
+            HnswIndexInner::SingleBF16(idx) => idx.save(&mut writer),
+            HnswIndexInner::SingleF16(idx) => idx.save(&mut writer),
+            HnswIndexInner::SingleI8(idx) => idx.save(&mut writer),
+            HnswIndexInner::SingleU8(idx) => idx.save(&mut writer),
+            HnswIndexInner::MultiF32(idx) => idx.save(&mut writer),
+            HnswIndexInner::MultiF64(idx) => idx.save(&mut writer),
+            HnswIndexInner::MultiBF16(idx) => idx.save(&mut writer),
+            HnswIndexInner::MultiF16(idx) => idx.save(&mut writer),
+            HnswIndexInner::MultiI8(idx) => idx.save(&mut writer),
+            HnswIndexInner::MultiU8(idx) => idx.save(&mut writer),
+        }
+        .map_err(|e| PyRuntimeError::new_err(format!("Failed to save index: {:?}", e)))?;
+
+        Ok(())
+    }
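`save_index` prefixes the serialized graph with a small fixed header, which `load_from_file` (later in this file) reads back. A sketch parsing just that header, matching the writes above (all integers little-endian):

```rust
use std::io::Read;

// Header layout: u32 data_type, u32 metric, u32 dim, u8 multi, u32 ef_runtime.
fn read_header(r: &mut impl Read) -> std::io::Result<(u32, u32, u32, bool, u32)> {
    let mut buf = [0u8; 17];
    r.read_exact(&mut buf)?;
    let u = |i: usize| u32::from_le_bytes([buf[i], buf[i + 1], buf[i + 2], buf[i + 3]]);
    Ok((u(0), u(4), u(8), buf[12] != 0, u(13)))
}
```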
+
+    /// Create a batch iterator for streaming results.
+    /// The ef_runtime parameter affects the quality of results:
+    /// - Lower ef = narrower search beam = potentially worse results
+    /// - Higher ef = wider search beam = better results (up to a point)
+    #[pyo3(signature = (query, query_params=None))]
+    fn create_batch_iterator(
+        &self,
+        py: Python<'_>,
+        query: PyObject,
+        query_params: Option<&VecSimQueryParams>,
+    ) -> PyResult<PyBatchIterator> {
+        let query_vec = extract_query_vec(py, &query)?;
+        let ef = query_params
+            .map(|p| p.get_ef_runtime(py))
+            .unwrap_or(self.ef_runtime);
+
+        let all_results = self.get_batch_results(&query_vec, ef)?;
+        let index_size = self.index_size();
+
+        Ok(PyBatchIterator::new(all_results, index_size))
+    }
+
+    /// Perform parallel k-nearest neighbors queries.
+    #[pyo3(signature = (queries, k=10, num_threads=None))]
+    fn knn_parallel<'py>(
+        &self,
+        py: Python<'py>,
+        queries: PyObject,
+        k: usize,
+        num_threads: Option<usize>,
+    ) -> PyResult<(Bound<'py, PyArray2<i64>>, Bound<'py, PyArray2<f64>>)> {
+        // Extract 2D query array - try f32 first, then f64
+        let (num_queries, queries_data): (usize, Vec<Vec<f64>>) =
+            if let Ok(queries_arr) = queries.extract::<PyReadonlyArray2<f32>>(py) {
+                let shape = queries_arr.shape();
+                let num_queries = shape[0];
+                let query_dim = shape[1];
+                if query_dim != self.dim {
+                    return Err(PyValueError::new_err(format!(
+                        "Query dimension {} does not match index dimension {}",
+                        query_dim, self.dim
+                    )));
+                }
+                let queries_slice = queries_arr.as_slice()?;
+                let data = (0..num_queries)
+                    .map(|i| {
+                        let start = i * query_dim;
+                        let end = start + query_dim;
+                        queries_slice[start..end].iter().map(|&x| x as f64).collect()
+                    })
+                    .collect();
+                (num_queries, data)
+            } else if let Ok(queries_arr) = queries.extract::<PyReadonlyArray2<f64>>(py) {
+                let shape = queries_arr.shape();
+                let num_queries = shape[0];
+                let query_dim = shape[1];
+                if query_dim != self.dim {
+                    return Err(PyValueError::new_err(format!(
+                        "Query dimension {} does not match index dimension {}",
+                        query_dim, self.dim
+                    )));
+                }
+                let queries_slice = queries_arr.as_slice()?;
+                let data = (0..num_queries)
+                    .map(|i| {
+                        let start = i * query_dim;
+                        let end = start + query_dim;
+                        queries_slice[start..end].to_vec()
+                    })
+                    .collect();
+                (num_queries, data)
+            } else {
+                return Err(PyValueError::new_err("Query array must be 2D float32 or float64"));
+            };
+
+        // Set up thread pool with specified number of threads
+        let pool = if let Some(n) = num_threads {
+            rayon::ThreadPoolBuilder::new()
+                .num_threads(n)
+                .build()
+                .map_err(|e| PyRuntimeError::new_err(format!("Failed to create thread pool: {}", e)))?
+        } else {
+            rayon::ThreadPoolBuilder::new()
+                .build()
+                .map_err(|e| PyRuntimeError::new_err(format!("Failed to create thread pool: {}", e)))?
+        };
+
+        let params = QueryParams::new().with_ef_runtime(self.ef_runtime);
+
+        // Run queries in parallel
+        let results: Vec<Result<QueryReply<f64>, String>> = pool.install(|| {
+            queries_data
+                .par_iter()
+                .map(|query| {
+                    self.query_internal(query, k, Some(&params))
+                        .map_err(|e| e.to_string())
+                })
+                .collect()
+        });
+
+        // Convert results to numpy arrays
+        let mut all_labels: Vec<i64> = Vec::with_capacity(num_queries * k);
+        let mut all_distances: Vec<f64> = Vec::with_capacity(num_queries * k);
+
+        for result in results {
+            let reply = result.map_err(|e| PyRuntimeError::new_err(e))?;
+            let mut labels: Vec<i64> = reply.results.iter().map(|r| r.label as i64).collect();
+            let mut distances: Vec<f64> = reply.results.iter().map(|r| r.distance).collect();
+
+            // Pad to k results if needed
+            while labels.len() < k {
+                labels.push(-1);
+                distances.push(f64::MAX);
+            }
+
+            all_labels.extend(labels);
+            all_distances.extend(distances);
+        }
+
+        // Reshape to (num_queries, k)
+        let labels_array = ndarray::Array2::from_shape_vec((num_queries, k), all_labels)
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to reshape labels: {}", e)))?;
+        let distances_array = ndarray::Array2::from_shape_vec((num_queries, k), all_distances)
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to reshape distances: {}", e)))?;
+
+        Ok((labels_array.into_pyarray(py), distances_array.into_pyarray(py)))
+    }
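Rows returned by `knn_parallel` are always `k` wide; queries with fewer hits are padded with label `-1` and distance `f64::MAX`. A sketch of stripping that padding on the consumer side:

```rust
// Sketch: keep only real hits from one padded row of knn_parallel output.
fn valid_hits(labels: &[i64], distances: &[f64]) -> Vec<(i64, f64)> {
    labels
        .iter()
        .zip(distances)
        .take_while(|(label, _)| **label != -1)
        .map(|(label, dist)| (*label, *dist))
        .collect()
}
```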
+
+    /// Add multiple vectors to the index.
+    /// Note: Currently uses sequential insertion because true parallel HNSW insertion
+    /// requires sophisticated synchronization to maintain graph quality. Concurrent
+    /// insertions where threads don't see each other's nodes during neighbor search
+    /// lead to poor graph connectivity and low recall. The underlying Rust infrastructure
+    /// supports concurrent access for read/write operations using batch construction.
+    #[pyo3(signature = (vectors, labels, num_threads=None))]
+    fn add_vector_parallel(
+        &self,
+        py: Python<'_>,
+        vectors: PyObject,
+        labels: PyObject,
+        num_threads: Option<usize>,
+    ) -> PyResult<()> {
+        // Set thread pool size if specified
+        if let Some(threads) = num_threads {
+            rayon::ThreadPoolBuilder::new()
+                .num_threads(threads)
+                .build_global()
+                .ok(); // Ignore error if already initialized
+        }
+
+        // Extract labels array - try i64 first, then i32
+        let labels_vec: Vec<u64> = if let Ok(labels_arr) = labels.extract::<PyReadonlyArray1<i64>>(py) {
+            labels_arr.as_slice()?.iter().map(|&l| l as u64).collect()
+        } else if let Ok(labels_arr) = labels.extract::<PyReadonlyArray1<i32>>(py) {
+            labels_arr.as_slice()?.iter().map(|&l| l as u64).collect()
+        } else {
+            // Try to extract as a Python iterable of integers
+            let labels_list: Vec<i64> = labels.extract(py)?;
+            labels_list.into_iter().map(|l| l as u64).collect()
+        };
+        let num_vectors = labels_vec.len();
+
+        match self.data_type {
+            VECSIM_TYPE_FLOAT32 => {
+                let vectors_arr: PyReadonlyArray2<f32> = vectors.extract(py)?;
+                let shape = vectors_arr.shape();
+                if shape[0] != num_vectors {
+                    return Err(PyValueError::new_err(format!(
+                        "Number of vectors {} does not match number of labels {}",
+                        shape[0], num_vectors
+                    )));
+                }
+                let dim = shape[1];
+                let slice = vectors_arr.as_slice()?;
+
+                // Use batch construction for parallel speedup
+                match &self.inner {
+                    HnswIndexInner::SingleF32(idx) => {
+                        idx.add_vectors_batch(slice, &labels_vec, dim)
+                            .map_err(|e| PyRuntimeError::new_err(format!("Batch insert failed: {:?}", e)))?;
+                    }
+                    HnswIndexInner::MultiF32(idx) => {
+                        // Multi doesn't have batch yet, fall back to per-element
+                        for i in 0..num_vectors {
+                            let start = i * dim;
+                            let end = start + dim;
+                            let vec = &slice[start..end];
+                            let label = labels_vec[i];
+                            let _ = idx.add_vector_concurrent(vec, label);
+                        }
+                    }
+                    _ => {}
+                }
+            }
+            VECSIM_TYPE_FLOAT64 => {
+                let vectors_arr: PyReadonlyArray2<f64> = vectors.extract(py)?;
+                let shape = vectors_arr.shape();
+                if shape[0] != num_vectors {
+                    return Err(PyValueError::new_err(format!(
+                        "Number of vectors {} does not match number of labels {}",
+                        shape[0], num_vectors
+                    )));
+                }
+                let dim = shape[1];
+                let slice = vectors_arr.as_slice()?;
+
+                // Use batch construction for parallel speedup
+                match &self.inner {
+                    HnswIndexInner::SingleF64(idx) => {
+                        idx.add_vectors_batch(slice, &labels_vec, dim)
+                            .map_err(|e| PyRuntimeError::new_err(format!("Batch insert failed: {:?}", e)))?;
+                    }
+                    HnswIndexInner::MultiF64(idx) => {
+                        // Multi doesn't have batch yet, fall back to per-element
+                        for i in 0..num_vectors {
+                            let start = i * dim;
+                            let end = start + dim;
+                            let vec = &slice[start..end];
+                            let label = labels_vec[i];
+                            let _ = idx.add_vector_concurrent(vec, label);
+                        }
+                    }
+                    _ => {}
+                }
+            }
+            _ => {
+                return Err(PyValueError::new_err(
+                    "Parallel insert only supported for FLOAT32 and FLOAT64 types",
+                ));
+            }
+        }
+
+        Ok(())
+    }
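Note the asymmetry: `add_vector_parallel` sizes rayon's global pool via `build_global()`, which only succeeds once per process (hence the deliberately ignored `Err`), while the query paths build a throwaway local pool per call. A sketch contrasting the two strategies:

```rust
use rayon::ThreadPoolBuilder;

fn pool_strategies() {
    // Global pool: the first call wins; later calls return Err (ignored, as above).
    let _ = ThreadPoolBuilder::new().num_threads(8).build_global();

    // Local pool: sized per call, independent of the global pool.
    let local = ThreadPoolBuilder::new().num_threads(2).build().unwrap();
    local.install(|| {
        // parallel work here is capped at 2 threads
    });
}
```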
+
+    /// Perform parallel range queries.
+    #[pyo3(signature = (queries, radius, num_threads=None))]
+    fn range_parallel<'py>(
+        &self,
+        py: Python<'py>,
+        queries: PyObject,
+        radius: f64,
+        num_threads: Option<usize>,
+    ) -> PyResult<(Bound<'py, PyArray2<i64>>, Bound<'py, PyArray2<f64>>)> {
+        // Extract 2D query array - try f32 first, then f64
+        let (num_queries, queries_data): (usize, Vec<Vec<f64>>) =
+            if let Ok(queries_arr) = queries.extract::<PyReadonlyArray2<f32>>(py) {
+                let shape = queries_arr.shape();
+                let num_queries = shape[0];
+                let query_dim = shape[1];
+                if query_dim != self.dim {
+                    return Err(PyValueError::new_err(format!(
+                        "Query dimension {} does not match index dimension {}",
+                        query_dim, self.dim
+                    )));
+                }
+                let queries_slice = queries_arr.as_slice()?;
+                let data = (0..num_queries)
+                    .map(|i| {
+                        let start = i * query_dim;
+                        let end = start + query_dim;
+                        queries_slice[start..end].iter().map(|&x| x as f64).collect()
+                    })
+                    .collect();
+                (num_queries, data)
+            } else if let Ok(queries_arr) = queries.extract::<PyReadonlyArray2<f64>>(py) {
+                let shape = queries_arr.shape();
+                let num_queries = shape[0];
+                let query_dim = shape[1];
+                if query_dim != self.dim {
+                    return Err(PyValueError::new_err(format!(
+                        "Query dimension {} does not match index dimension {}",
+                        query_dim, self.dim
+                    )));
+                }
+                let queries_slice = queries_arr.as_slice()?;
+                let data = (0..num_queries)
+                    .map(|i| {
+                        let start = i * query_dim;
+                        let end = start + query_dim;
+                        queries_slice[start..end].to_vec()
+                    })
+                    .collect();
+                (num_queries, data)
+            } else {
+                return Err(PyValueError::new_err("Query array must be 2D float32 or float64"));
+            };
+
+        // Set up thread pool
+        let pool = if let Some(n) = num_threads {
+            rayon::ThreadPoolBuilder::new()
+                .num_threads(n)
+                .build()
+                .map_err(|e| PyRuntimeError::new_err(format!("Failed to create thread pool: {}", e)))?
+        } else {
+            rayon::ThreadPoolBuilder::new()
+                .build()
+                .map_err(|e| PyRuntimeError::new_err(format!("Failed to create thread pool: {}", e)))?
+        };
+
+        let params = QueryParams::new().with_ef_runtime(self.ef_runtime);
+
+        // Run queries in parallel
+        let results: Vec<Result<QueryReply<f64>, String>> = pool.install(|| {
+            queries_data
+                .par_iter()
+                .map(|query| {
+                    self.range_query_internal(query, radius, Some(&params))
+                        .map_err(|e| e.to_string())
+                })
+                .collect()
+        });
+
+        // Find max results length for padding
+        let max_results = results
+            .iter()
+            .filter_map(|r| r.as_ref().ok())
+            .map(|r| r.results.len())
+            .max()
+            .unwrap_or(0);
+
+        // Convert results to numpy arrays with padding
+        let result_len = max_results.max(1); // At least 1 column
+        let mut all_labels: Vec<i64> = Vec::with_capacity(num_queries * result_len);
+        let mut all_distances: Vec<f64> = Vec::with_capacity(num_queries * result_len);
+
+        for result in results {
+            let reply = result.map_err(|e| PyRuntimeError::new_err(e))?;
+            let mut labels: Vec<i64> = reply.results.iter().map(|r| r.label as i64).collect();
+            let mut distances: Vec<f64> = reply.results.iter().map(|r| r.distance).collect();
+
+            // Pad to max_results with -1 for labels
+            while labels.len() < result_len {
+                labels.push(-1);
+                distances.push(f64::MAX);
+            }
+
+            all_labels.extend(labels);
+            all_distances.extend(distances);
+        }
+
+        // Reshape to (num_queries, result_len)
+        let labels_array = ndarray::Array2::from_shape_vec((num_queries, result_len), all_labels)
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to reshape labels: {}", e)))?;
+        let distances_array = ndarray::Array2::from_shape_vec((num_queries, result_len), all_distances)
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to reshape distances: {}", e)))?;
+
+        Ok((labels_array.into_pyarray(py), distances_array.into_pyarray(py)))
+    }
+
+    /// Get a vector by its label.
+    fn get_vector<'py>(&self, py: Python<'py>, label: u64) -> PyResult<Bound<'py, PyArray2<f64>>> {
+        let vectors: Vec<Vec<f64>> = match self.data_type {
+            VECSIM_TYPE_FLOAT32 => {
+                match &self.inner {
+                    HnswIndexInner::SingleF32(idx) => {
+                        idx.get_vector(label)
+                            .map(|v| vec![v.into_iter().map(|x| x as f64).collect()])
+                            .unwrap_or_default()
+                    }
+                    HnswIndexInner::MultiF32(idx) => {
+                        idx.get_vectors(label)
+                            .map(|vecs| vecs.into_iter().map(|v| v.into_iter().map(|x| x as f64).collect()).collect())
+                            .unwrap_or_default()
+                    }
+                    _ => return Err(PyValueError::new_err("Type mismatch")),
+                }
+            }
+            VECSIM_TYPE_FLOAT64 => {
+                match &self.inner {
+                    HnswIndexInner::SingleF64(idx) => {
+                        idx.get_vector(label)
+                            .map(|v| vec![v])
+                            .unwrap_or_default()
+                    }
+                    HnswIndexInner::MultiF64(idx) => {
+                        idx.get_vectors(label)
+                            .unwrap_or_default()
+                    }
+                    _ => return Err(PyValueError::new_err("Type mismatch")),
+                }
+            }
+            VECSIM_TYPE_BFLOAT16 => {
+                match &self.inner {
+                    HnswIndexInner::SingleBF16(idx) => {
+                        idx.get_vector(label)
+                            .map(|v| vec![v.into_iter().map(|x| x.to_f32() as f64).collect()])
+                            .unwrap_or_default()
+                    }
+                    HnswIndexInner::MultiBF16(idx) => {
+                        idx.get_vectors(label)
+                            .map(|vecs| vecs.into_iter().map(|v| v.into_iter().map(|x| x.to_f32() as f64).collect()).collect())
+                            .unwrap_or_default()
+                    }
+                    _ => return Err(PyValueError::new_err("Type mismatch")),
+                }
+            }
+            VECSIM_TYPE_FLOAT16 => {
+                match &self.inner {
+                    HnswIndexInner::SingleF16(idx) => {
+                        idx.get_vector(label)
+                            .map(|v| vec![v.into_iter().map(|x| x.to_f32() as f64).collect()])
+                            .unwrap_or_default()
+                    }
+                    HnswIndexInner::MultiF16(idx) => {
+                        idx.get_vectors(label)
+                            .map(|vecs| vecs.into_iter().map(|v| v.into_iter().map(|x| x.to_f32() as f64).collect()).collect())
+                            .unwrap_or_default()
+                    }
+                    _ => return Err(PyValueError::new_err("Type mismatch")),
+                }
+            }
+            VECSIM_TYPE_INT8 => {
+                match &self.inner {
+                    HnswIndexInner::SingleI8(idx) => {
+                        idx.get_vector(label)
+                            .map(|v| vec![v.into_iter().map(|x| x.0 as f64).collect()])
+                            .unwrap_or_default()
+                    }
+                    HnswIndexInner::MultiI8(idx) => {
+                        idx.get_vectors(label)
+                            .map(|vecs| vecs.into_iter().map(|v| v.into_iter().map(|x| x.0 as f64).collect()).collect())
+                            .unwrap_or_default()
+                    }
+                    _ => return Err(PyValueError::new_err("Type mismatch")),
+                }
+            }
+            VECSIM_TYPE_UINT8 => {
+                match &self.inner {
+                    HnswIndexInner::SingleU8(idx) => {
+                        idx.get_vector(label)
+                            .map(|v| vec![v.into_iter().map(|x| x.0 as f64).collect()])
+                            .unwrap_or_default()
+                    }
+                    HnswIndexInner::MultiU8(idx) => {
+                        idx.get_vectors(label)
+                            .map(|vecs| vecs.into_iter().map(|v| v.into_iter().map(|x| x.0 as f64).collect()).collect())
+                            .unwrap_or_default()
+                    }
+                    _ => return Err(PyValueError::new_err("Type mismatch")),
+                }
+            }
+            _ => return Err(PyValueError::new_err("Unsupported data type")),
+        };
+
+        if vectors.is_empty() {
+            return Err(PyRuntimeError::new_err(format!("Label {} not found", label)));
+        }
+        let num_vectors = vectors.len();
+        let flat: Vec<f64> = vectors.into_iter().flatten().collect();
+        let array = ndarray::Array2::from_shape_vec((num_vectors, self.dim), flat)
+            .map_err(|e| PyRuntimeError::new_err(format!("Failed to reshape: {}", e)))?;
+        Ok(array.into_pyarray(py))
+    }
+}
+
+impl HNSWIndex {
+    fn create_new(params: &HNSWParams) -> PyResult<Self> {
+        let metric = metric_from_u32(params.metric)?;
+        let hnsw_params = HnswParams::new(params.dim, metric)
+            .with_m(params.M)
+            .with_ef_construction(params.efConstruction)
+            .with_ef_runtime(params.efRuntime);
+
+        let inner = match (params.multi, params.r#type) {
+            (false, VECSIM_TYPE_FLOAT32) => HnswIndexInner::SingleF32(HnswSingle::new(hnsw_params)),
+            (false, VECSIM_TYPE_FLOAT64) => HnswIndexInner::SingleF64(HnswSingle::new(hnsw_params)),
+            (false, VECSIM_TYPE_BFLOAT16) => HnswIndexInner::SingleBF16(HnswSingle::new(hnsw_params)),
+            (false, VECSIM_TYPE_FLOAT16) => HnswIndexInner::SingleF16(HnswSingle::new(hnsw_params)),
+            (false, VECSIM_TYPE_INT8) => HnswIndexInner::SingleI8(HnswSingle::new(hnsw_params)),
+            (false, VECSIM_TYPE_UINT8) => HnswIndexInner::SingleU8(HnswSingle::new(hnsw_params)),
+            (true, VECSIM_TYPE_FLOAT32) => HnswIndexInner::MultiF32(HnswMulti::new(hnsw_params)),
+            (true, VECSIM_TYPE_FLOAT64) => HnswIndexInner::MultiF64(HnswMulti::new(hnsw_params)),
+            (true, VECSIM_TYPE_BFLOAT16) => HnswIndexInner::MultiBF16(HnswMulti::new(hnsw_params)),
+            (true, VECSIM_TYPE_FLOAT16) => HnswIndexInner::MultiF16(HnswMulti::new(hnsw_params)),
+            (true, VECSIM_TYPE_INT8) => HnswIndexInner::MultiI8(HnswMulti::new(hnsw_params)),
+            (true, VECSIM_TYPE_UINT8) => HnswIndexInner::MultiU8(HnswMulti::new(hnsw_params)),
+            _ => return Err(PyValueError::new_err(format!("Unsupported data type: {}", params.r#type))),
+        };
+
+        Ok(HNSWIndex {
+            inner,
+            data_type: params.r#type,
+            metric,
+            dim: params.dim,
+            multi: params.multi,
+            ef_runtime: params.efRuntime,
+        })
+    }
PyRuntimeError::new_err(format!("Failed to read: {}", e)))?; + reader.read_exact(&mut dim_bytes) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to read: {}", e)))?; + reader.read_exact(&mut multi_byte) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to read: {}", e)))?; + reader.read_exact(&mut ef_runtime_bytes) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to read: {}", e)))?; + + let data_type = u32::from_le_bytes(data_type_bytes); + let metric_val = u32::from_le_bytes(metric_bytes); + let dim = u32::from_le_bytes(dim_bytes) as usize; + let multi = multi_byte[0] != 0; + let ef_runtime = u32::from_le_bytes(ef_runtime_bytes) as usize; + let metric = metric_from_u32(metric_val)?; + + let inner = match (multi, data_type) { + (false, VECSIM_TYPE_FLOAT32) => HnswIndexInner::SingleF32( + HnswSingle::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? + ), + (false, VECSIM_TYPE_FLOAT64) => HnswIndexInner::SingleF64( + HnswSingle::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? + ), + (false, VECSIM_TYPE_BFLOAT16) => HnswIndexInner::SingleBF16( + HnswSingle::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? + ), + (false, VECSIM_TYPE_FLOAT16) => HnswIndexInner::SingleF16( + HnswSingle::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? + ), + (false, VECSIM_TYPE_INT8) => HnswIndexInner::SingleI8( + HnswSingle::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? + ), + (false, VECSIM_TYPE_UINT8) => HnswIndexInner::SingleU8( + HnswSingle::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? + ), + (true, VECSIM_TYPE_FLOAT32) => HnswIndexInner::MultiF32( + HnswMulti::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? + ), + (true, VECSIM_TYPE_FLOAT64) => HnswIndexInner::MultiF64( + HnswMulti::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? + ), + (true, VECSIM_TYPE_BFLOAT16) => HnswIndexInner::MultiBF16( + HnswMulti::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? + ), + (true, VECSIM_TYPE_FLOAT16) => HnswIndexInner::MultiF16( + HnswMulti::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? + ), + (true, VECSIM_TYPE_INT8) => HnswIndexInner::MultiI8( + HnswMulti::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? + ), + (true, VECSIM_TYPE_UINT8) => HnswIndexInner::MultiU8( + HnswMulti::load(&mut reader).map_err(|e| PyRuntimeError::new_err(format!("Failed to load: {:?}", e)))? 
+ ), + _ => return Err(PyValueError::new_err(format!("Unsupported data type: {}", data_type))), + }; + + Ok(HNSWIndex { inner, data_type, metric, dim, multi, ef_runtime }) + } + + fn query_internal(&self, query: &[f64], k: usize, params: Option<&QueryParams>) -> PyResult> { + match self.data_type { + VECSIM_TYPE_FLOAT32 => { + let q: Vec = query.iter().map(|&v| v as f32).collect(); + let reply = match &self.inner { + HnswIndexInner::SingleF32(idx) => idx.top_k_query(&q, k, params), + HnswIndexInner::MultiF32(idx) => idx.top_k_query(&q, k, params), + _ => unreachable!(), + }.map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?; + Ok(QueryReply::from_results(reply.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect())) + } + VECSIM_TYPE_FLOAT64 => { + let reply = match &self.inner { + HnswIndexInner::SingleF64(idx) => idx.top_k_query(query, k, params), + HnswIndexInner::MultiF64(idx) => idx.top_k_query(query, k, params), + _ => unreachable!(), + }.map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?; + Ok(reply) + } + VECSIM_TYPE_BFLOAT16 => { + let q: Vec = query.iter().map(|&v| BFloat16::from_f32(v as f32)).collect(); + let reply = match &self.inner { + HnswIndexInner::SingleBF16(idx) => idx.top_k_query(&q, k, params), + HnswIndexInner::MultiBF16(idx) => idx.top_k_query(&q, k, params), + _ => unreachable!(), + }.map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?; + Ok(QueryReply::from_results(reply.results.into_iter().map(|r| QueryResult::new(r.label, r.distance.to_f64())).collect())) + } + VECSIM_TYPE_FLOAT16 => { + let q: Vec = query.iter().map(|&v| Float16::from_f32(v as f32)).collect(); + let reply = match &self.inner { + HnswIndexInner::SingleF16(idx) => idx.top_k_query(&q, k, params), + HnswIndexInner::MultiF16(idx) => idx.top_k_query(&q, k, params), + _ => unreachable!(), + }.map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?; + Ok(QueryReply::from_results(reply.results.into_iter().map(|r| QueryResult::new(r.label, r.distance.to_f64())).collect())) + } + VECSIM_TYPE_INT8 => { + let q: Vec = query.iter().map(|&v| Int8(v as i8)).collect(); + let reply = match &self.inner { + HnswIndexInner::SingleI8(idx) => idx.top_k_query(&q, k, params), + HnswIndexInner::MultiI8(idx) => idx.top_k_query(&q, k, params), + _ => unreachable!(), + }.map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?; + Ok(QueryReply::from_results(reply.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect())) + } + VECSIM_TYPE_UINT8 => { + let q: Vec = query.iter().map(|&v| UInt8(v as u8)).collect(); + let reply = match &self.inner { + HnswIndexInner::SingleU8(idx) => idx.top_k_query(&q, k, params), + HnswIndexInner::MultiU8(idx) => idx.top_k_query(&q, k, params), + _ => unreachable!(), + }.map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e)))?; + Ok(QueryReply::from_results(reply.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect())) + } + _ => Err(PyValueError::new_err("Unsupported data type")), + } + } + + fn range_query_internal(&self, query: &[f64], radius: f64, params: Option<&QueryParams>) -> PyResult> { + match self.data_type { + VECSIM_TYPE_FLOAT32 => { + let q: Vec = query.iter().map(|&v| v as f32).collect(); + let reply = match &self.inner { + HnswIndexInner::SingleF32(idx) => idx.range_query(&q, radius as f32, params), + HnswIndexInner::MultiF32(idx) => idx.range_query(&q, radius as f32, params), + _ => 
unreachable!(),
+                }.map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(reply.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect()))
+            }
+            VECSIM_TYPE_FLOAT64 => {
+                let reply = match &self.inner {
+                    HnswIndexInner::SingleF64(idx) => idx.range_query(query, radius, params),
+                    HnswIndexInner::MultiF64(idx) => idx.range_query(query, radius, params),
+                    _ => unreachable!(),
+                }.map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(reply)
+            }
+            VECSIM_TYPE_BFLOAT16 => {
+                let q: Vec<BFloat16> = query.iter().map(|&v| BFloat16::from_f32(v as f32)).collect();
+                let reply = match &self.inner {
+                    HnswIndexInner::SingleBF16(idx) => idx.range_query(&q, radius as f32, params),
+                    HnswIndexInner::MultiBF16(idx) => idx.range_query(&q, radius as f32, params),
+                    _ => unreachable!(),
+                }.map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(reply.results.into_iter().map(|r| QueryResult::new(r.label, r.distance.to_f64())).collect()))
+            }
+            VECSIM_TYPE_FLOAT16 => {
+                let q: Vec<Float16> = query.iter().map(|&v| Float16::from_f32(v as f32)).collect();
+                let reply = match &self.inner {
+                    HnswIndexInner::SingleF16(idx) => idx.range_query(&q, radius as f32, params),
+                    HnswIndexInner::MultiF16(idx) => idx.range_query(&q, radius as f32, params),
+                    _ => unreachable!(),
+                }.map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(reply.results.into_iter().map(|r| QueryResult::new(r.label, r.distance.to_f64())).collect()))
+            }
+            VECSIM_TYPE_INT8 => {
+                let q: Vec<Int8> = query.iter().map(|&v| Int8(v as i8)).collect();
+                let reply = match &self.inner {
+                    HnswIndexInner::SingleI8(idx) => idx.range_query(&q, radius as f32, params),
+                    HnswIndexInner::MultiI8(idx) => idx.range_query(&q, radius as f32, params),
+                    _ => unreachable!(),
+                }.map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(reply.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect()))
+            }
+            VECSIM_TYPE_UINT8 => {
+                let q: Vec<UInt8> = query.iter().map(|&v| UInt8(v as u8)).collect();
+                let reply = match &self.inner {
+                    HnswIndexInner::SingleU8(idx) => idx.range_query(&q, radius as f32, params),
+                    HnswIndexInner::MultiU8(idx) => idx.range_query(&q, radius as f32, params),
+                    _ => unreachable!(),
+                }.map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e)))?;
+                Ok(QueryReply::from_results(reply.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect()))
+            }
+            _ => Err(PyValueError::new_err("Unsupported data type")),
+        }
+    }
+
+    fn get_results_with_ef(&self, query: &[f64], k: usize, ef: usize) -> PyResult<Vec<(u64, f64)>> {
+        if k == 0 {
+            return Ok(Vec::new());
+        }
+        // Use the actual efRuntime - this affects result quality.
+        // Lower efRuntime = narrower search beam = potentially worse results.
+        let params = QueryParams::new().with_ef_runtime(ef);
+        let reply = self.query_internal(query, k, Some(&params))?;
+        Ok(reply.results.into_iter().map(|r| (r.label, r.distance)).collect())
+    }
+
+    fn get_all_results_sorted(&self, query: &[f64], ef: usize) -> PyResult<Vec<(u64, f64)>> {
+        // Used by knn_query - query all results at once
+        self.get_results_with_ef(query, self.index_size(), ef)
+    }
+
+    /// Get results for batch iterator with ef-influenced quality.
+    /// Uses progressive queries to ensure ef affects early results.
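+    ///
+    /// A worked illustration of the staging below (numbers are only an
+    /// example): with `ef = 50` and `total = 1000`,
+    ///
+    /// ```text
+    /// stage 1: k = min(10, total) = 10    -> beam = max(ef, k) = 50
+    /// stage 2: k = min(ef, total) = 50    -> runs because ef > 10
+    /// stage 3: k = total          = 1000  -> beam = total (exhaustive)
+    /// ```
+    ///
+    /// Labels are deduplicated across stages, so each label is reported once,
+    /// with the distance from the first stage that returned it.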
+    fn get_batch_results(&self, query: &[f64], ef: usize) -> PyResult<Vec<(u64, f64)>> {
+        let total = self.index_size();
+        if total == 0 {
+            return Ok(Vec::new());
+        }
+
+        // Use multiple queries with progressively larger k values.
+        // This ensures early results are affected by ef.
+        // beam = max(ef, k), so for smaller k, ef has more influence.
+
+        // Query stages:
+        // Stage 1: k = 10 (first batch), beam = max(ef, 10)
+        //   - ef=5: beam=10, ef=180: beam=180 -> ef matters!
+        // Stage 2: k = ef (get ef-quality results), beam = max(ef, ef) = ef
+        // Stage 3: k = total (get all remaining), beam = total
+
+        let mut results = Vec::new();
+        let mut seen: std::collections::HashSet<u64> = std::collections::HashSet::new();
+
+        // Stage 1: k = 10 (typical first batch size)
+        let k1 = 10.min(total);
+        let stage1_results = self.get_results_with_ef(query, k1, ef)?;
+        for (label, dist) in stage1_results {
+            if !seen.contains(&label) {
+                seen.insert(label);
+                results.push((label, dist));
+            }
+        }
+
+        // Stage 2: k = ef (use ef's natural quality)
+        if seen.len() < total && ef > k1 {
+            let k2 = ef.min(total);
+            let stage2_results = self.get_results_with_ef(query, k2, ef)?;
+            for (label, dist) in stage2_results {
+                if !seen.contains(&label) {
+                    seen.insert(label);
+                    results.push((label, dist));
+                }
+            }
+        }
+
+        // Stage 3: Get all remaining results
+        if seen.len() < total {
+            let remaining_results = self.get_results_with_ef(query, total, total)?;
+            for (label, dist) in remaining_results {
+                if !seen.contains(&label) {
+                    seen.insert(label);
+                    results.push((label, dist));
+                }
+            }
+        }
+
+        Ok(results)
+    }
+}
+
+// ============================================================================
+// Batch Iterator
+// ============================================================================
+
+/// Batch iterator for streaming query results.
+/// Pre-fetches results with the given ef parameter.
+#[pyclass]
+pub struct PyBatchIterator {
+    /// Pre-sorted results (sorted by distance)
+    results: Vec<(u64, f64)>,
+    /// Current position in results
+    position: usize,
+    /// Total index size
+    index_size: usize,
+}
+
+impl PyBatchIterator {
+    fn new(results: Vec<(u64, f64)>, index_size: usize) -> Self {
+        PyBatchIterator {
+            results,
+            position: 0,
+            index_size,
+        }
+    }
+}
+
+#[pymethods]
+impl PyBatchIterator {
+    /// Check if there are more results.
+    fn has_next(&self) -> bool {
+        self.position < self.results.len()
+    }
+
+    /// Get the next batch of results.
+    #[pyo3(signature = (batch_size, order=BY_SCORE))]
+    fn get_next_results<'py>(
+        &mut self,
+        py: Python<'py>,
+        batch_size: usize,
+        order: u32,
+    ) -> PyResult<(Bound<'py, PyArray2<i64>>, Bound<'py, PyArray2<f64>>)> {
+        let remaining = self.results.len().saturating_sub(self.position);
+        let actual_batch_size = batch_size.min(remaining);
+
+        if actual_batch_size == 0 {
+            return Ok((vec_to_2d_array(py, vec![], 0), vec_to_2d_array(py, vec![], 0)));
+        }
+
+        let mut batch: Vec<(u64, f64)> = self.results[self.position..self.position + actual_batch_size].to_vec();
+        self.position += actual_batch_size;
+
+        if order == BY_ID {
+            batch.sort_by_key(|(label, _)| *label);
+        }
+        // BY_SCORE is already sorted by distance from the query
+
+        let labels: Vec<i64> = batch.iter().map(|(l, _)| *l as i64).collect();
+        let distances: Vec<f64> = batch.iter().map(|(_, d)| *d).collect();
+        let len = labels.len();
+
+        Ok((vec_to_2d_array(py, labels, len), vec_to_2d_array(py, distances, len)))
+    }
+
+    /// Reset the iterator to the beginning.
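+    ///
+    /// Resetting only rewinds the cursor; the pre-fetched result set is kept,
+    /// so the next `get_next_results` call replays batches from the closest
+    /// match without re-running the query.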
+    fn reset(&mut self) {
+        self.position = 0;
+    }
+}
+
+// ============================================================================
+// Tiered HNSW Index
+// ============================================================================
+
+use vecsim::index::tiered::{TieredParams, TieredSingle, TieredMulti, WriteMode};
+
+/// Inner enum for type-erased tiered index.
+enum TieredIndexInner {
+    SingleF32(TieredSingle<f32>),
+    SingleF64(TieredSingle<f64>),
+    SingleBF16(TieredSingle<BFloat16>),
+    SingleF16(TieredSingle<Float16>),
+    SingleI8(TieredSingle<Int8>),
+    SingleU8(TieredSingle<UInt8>),
+    MultiF32(TieredMulti<f32>),
+    MultiF64(TieredMulti<f64>),
+    MultiBF16(TieredMulti<BFloat16>),
+    MultiF16(TieredMulti<Float16>),
+    MultiI8(TieredMulti<Int8>),
+    MultiU8(TieredMulti<UInt8>),
+}
+
+macro_rules! dispatch_tiered_index {
+    ($self:expr, $method:ident $(, $args:expr)*) => {
+        match &$self.inner {
+            TieredIndexInner::SingleF32(idx) => idx.$method($($args),*),
+            TieredIndexInner::SingleF64(idx) => idx.$method($($args),*),
+            TieredIndexInner::SingleBF16(idx) => idx.$method($($args),*),
+            TieredIndexInner::SingleF16(idx) => idx.$method($($args),*),
+            TieredIndexInner::SingleI8(idx) => idx.$method($($args),*),
+            TieredIndexInner::SingleU8(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiF32(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiF64(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiBF16(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiF16(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiI8(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiU8(idx) => idx.$method($($args),*),
+        }
+    };
+}
+
+macro_rules! dispatch_tiered_index_mut {
+    ($self:expr, $method:ident $(, $args:expr)*) => {
+        match &mut $self.inner {
+            TieredIndexInner::SingleF32(idx) => idx.$method($($args),*),
+            TieredIndexInner::SingleF64(idx) => idx.$method($($args),*),
+            TieredIndexInner::SingleBF16(idx) => idx.$method($($args),*),
+            TieredIndexInner::SingleF16(idx) => idx.$method($($args),*),
+            TieredIndexInner::SingleI8(idx) => idx.$method($($args),*),
+            TieredIndexInner::SingleU8(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiF32(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiF64(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiBF16(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiF16(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiI8(idx) => idx.$method($($args),*),
+            TieredIndexInner::MultiU8(idx) => idx.$method($($args),*),
+        }
+    };
+}
+
+/// Tiered HNSW index combining BruteForce frontend with HNSW backend.
+#[pyclass]
+#[pyo3(name = "Tiered_HNSWIndex")]
+pub struct TieredHNSWIndex {
+    inner: TieredIndexInner,
+    data_type: u32,
+    metric: Metric,
+    dim: usize,
+    multi: bool,
+    ef_runtime: usize,
+}
+
+#[pymethods]
+impl TieredHNSWIndex {
+    /// Create a new tiered HNSW index.
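+    ///
+    /// Note: `tiered_params` is accepted for API compatibility but is not read
+    /// yet (hence the `#[allow(unused_variables)]`); the frontend buffer is
+    /// sized by `flat_buffer_size`, which is forwarded as `flat_buffer_limit`.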
+ #[new] + #[pyo3(signature = (hnsw_params, tiered_params, flat_buffer_size=1024))] + #[allow(unused_variables)] + fn new(hnsw_params: &HNSWParams, tiered_params: &TieredHNSWParams, flat_buffer_size: usize) -> PyResult { + let metric = metric_from_u32(hnsw_params.metric)?; + let dim = hnsw_params.dim; + let multi = hnsw_params.multi; + let data_type = hnsw_params.r#type; + let ef_runtime = hnsw_params.efRuntime; + + let hnsw_params_rust = vecsim::index::hnsw::HnswParams::new(dim, metric) + .with_m(hnsw_params.M) + .with_ef_construction(hnsw_params.efConstruction) + .with_ef_runtime(hnsw_params.efRuntime); + + let tiered_params_rust = TieredParams { + dim, + metric, + hnsw_params: hnsw_params_rust, + flat_buffer_limit: flat_buffer_size, + write_mode: WriteMode::Async, + initial_capacity: 1000, + }; + + let inner = match (multi, data_type) { + (false, VECSIM_TYPE_FLOAT32) => TieredIndexInner::SingleF32(TieredSingle::new(tiered_params_rust)), + (false, VECSIM_TYPE_FLOAT64) => TieredIndexInner::SingleF64(TieredSingle::new(tiered_params_rust)), + (false, VECSIM_TYPE_BFLOAT16) => TieredIndexInner::SingleBF16(TieredSingle::new(tiered_params_rust)), + (false, VECSIM_TYPE_FLOAT16) => TieredIndexInner::SingleF16(TieredSingle::new(tiered_params_rust)), + (false, VECSIM_TYPE_INT8) => TieredIndexInner::SingleI8(TieredSingle::new(tiered_params_rust)), + (false, VECSIM_TYPE_UINT8) => TieredIndexInner::SingleU8(TieredSingle::new(tiered_params_rust)), + (true, VECSIM_TYPE_FLOAT32) => TieredIndexInner::MultiF32(TieredMulti::new(tiered_params_rust)), + (true, VECSIM_TYPE_FLOAT64) => TieredIndexInner::MultiF64(TieredMulti::new(tiered_params_rust)), + (true, VECSIM_TYPE_BFLOAT16) => TieredIndexInner::MultiBF16(TieredMulti::new(tiered_params_rust)), + (true, VECSIM_TYPE_FLOAT16) => TieredIndexInner::MultiF16(TieredMulti::new(tiered_params_rust)), + (true, VECSIM_TYPE_INT8) => TieredIndexInner::MultiI8(TieredMulti::new(tiered_params_rust)), + (true, VECSIM_TYPE_UINT8) => TieredIndexInner::MultiU8(TieredMulti::new(tiered_params_rust)), + _ => return Err(PyValueError::new_err(format!("Unsupported data type: {}", data_type))), + }; + + Ok(TieredHNSWIndex { + inner, + data_type, + metric, + dim, + multi, + ef_runtime, + }) + } + + /// Add a vector to the index. 
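+    ///
+    /// The vector is converted element-wise to the index's configured data
+    /// type (e.g. `i8` values are wrapped as `Int8`) and goes to the flat
+    /// buffer first; it reaches the HNSW backend on the next `flush`.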
+ fn add_vector(&mut self, py: Python<'_>, vector: PyObject, label: u64) -> PyResult<()> { + match self.data_type { + VECSIM_TYPE_FLOAT32 => { + let arr: PyReadonlyArray1 = vector.extract(py)?; + let slice = arr.as_slice()?; + match &mut self.inner { + TieredIndexInner::SingleF32(idx) => idx.add_vector(slice, label), + TieredIndexInner::MultiF32(idx) => idx.add_vector(slice, label), + _ => unreachable!(), + } + } + VECSIM_TYPE_FLOAT64 => { + let arr: PyReadonlyArray1 = vector.extract(py)?; + let slice = arr.as_slice()?; + match &mut self.inner { + TieredIndexInner::SingleF64(idx) => idx.add_vector(slice, label), + TieredIndexInner::MultiF64(idx) => idx.add_vector(slice, label), + _ => unreachable!(), + } + } + VECSIM_TYPE_BFLOAT16 => { + let bf16_vec = extract_bf16_vector(py, &vector)?; + match &mut self.inner { + TieredIndexInner::SingleBF16(idx) => idx.add_vector(&bf16_vec, label), + TieredIndexInner::MultiBF16(idx) => idx.add_vector(&bf16_vec, label), + _ => unreachable!(), + } + } + VECSIM_TYPE_FLOAT16 => { + let f16_vec = extract_f16_vector(py, &vector)?; + match &mut self.inner { + TieredIndexInner::SingleF16(idx) => idx.add_vector(&f16_vec, label), + TieredIndexInner::MultiF16(idx) => idx.add_vector(&f16_vec, label), + _ => unreachable!(), + } + } + VECSIM_TYPE_INT8 => { + let arr: PyReadonlyArray1 = vector.extract(py)?; + let slice = arr.as_slice()?; + let i8_vec: Vec = slice.iter().map(|&v| Int8(v)).collect(); + match &mut self.inner { + TieredIndexInner::SingleI8(idx) => idx.add_vector(&i8_vec, label), + TieredIndexInner::MultiI8(idx) => idx.add_vector(&i8_vec, label), + _ => unreachable!(), + } + } + VECSIM_TYPE_UINT8 => { + let arr: PyReadonlyArray1 = vector.extract(py)?; + let slice = arr.as_slice()?; + let u8_vec: Vec = slice.iter().map(|&v| UInt8(v)).collect(); + match &mut self.inner { + TieredIndexInner::SingleU8(idx) => idx.add_vector(&u8_vec, label), + TieredIndexInner::MultiU8(idx) => idx.add_vector(&u8_vec, label), + _ => unreachable!(), + } + } + _ => return Err(PyValueError::new_err("Unsupported data type")), + } + .map_err(|e| PyRuntimeError::new_err(format!("Failed to add vector: {:?}", e)))?; + Ok(()) + } + + /// Delete a vector from the index. + fn delete_vector(&mut self, label: u64) -> PyResult { + dispatch_tiered_index_mut!(self, delete_vector, label) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to delete vector: {:?}", e))) + } + + /// Set the ef_runtime parameter for queries. + fn set_ef(&mut self, ef: usize) { + self.ef_runtime = ef; + } + + /// Get the index size. + fn index_size(&self) -> usize { + dispatch_tiered_index!(self, index_size) + } + + /// Get the data type. + fn index_type(&self) -> u32 { + self.data_type + } + + /// Perform a k-nearest neighbors query. + #[pyo3(signature = (query, k=10))] + fn knn_query<'py>( + &self, + py: Python<'py>, + query: PyObject, + k: usize, + ) -> PyResult<(Bound<'py, PyArray2>, Bound<'py, PyArray2>)> { + let query_vec = extract_query_vec(py, &query)?; + let params = QueryParams::new().with_ef_runtime(self.ef_runtime); + let reply = self.query_internal(&query_vec, k, Some(¶ms))?; + + let labels: Vec = reply.results.iter().map(|r| r.label as i64).collect(); + let distances: Vec = reply.results.iter().map(|r| r.distance).collect(); + let len = labels.len(); + + Ok((vec_to_2d_array(py, labels, len), vec_to_2d_array(py, distances, len))) + } + + /// Perform a range query. 
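+    ///
+    /// A per-call efRuntime supplied via `query_param` takes precedence over
+    /// the index-level value configured with `set_ef`.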
+ #[pyo3(signature = (query, radius, query_param=None))] + fn range_query<'py>( + &self, + py: Python<'py>, + query: PyObject, + radius: f64, + query_param: Option<&VecSimQueryParams>, + ) -> PyResult<(Bound<'py, PyArray2>, Bound<'py, PyArray2>)> { + let query_vec = extract_query_vec(py, &query)?; + let ef = query_param + .map(|p| p.get_ef_runtime(py)) + .unwrap_or(self.ef_runtime); + let params = QueryParams::new().with_ef_runtime(ef); + let reply = self.range_query_internal(&query_vec, radius, Some(¶ms))?; + + let labels: Vec = reply.results.iter().map(|r| r.label as i64).collect(); + let distances: Vec = reply.results.iter().map(|r| r.distance).collect(); + let len = labels.len(); + + Ok((vec_to_2d_array(py, labels, len), vec_to_2d_array(py, distances, len))) + } + + /// Create a batch iterator for the given query. + #[pyo3(signature = (query, query_param=None))] + fn create_batch_iterator( + &self, + py: Python<'_>, + query: PyObject, + query_param: Option<&VecSimQueryParams>, + ) -> PyResult { + let query_vec = extract_query_vec(py, &query)?; + let ef = query_param + .map(|p| p.get_ef_runtime(py)) + .unwrap_or(self.ef_runtime); + + let results = self.get_batch_results(&query_vec, ef)?; + Ok(PyBatchIterator::new(results, self.index_size())) + } + + /// Get the number of background threads (always 1 for Rust implementation). + fn get_threads_num(&self) -> usize { + 1 + } + + /// Wait for background indexing to complete. + /// In the Rust implementation, this immediately flushes all vectors to HNSW. + #[pyo3(signature = (timeout=None))] + fn wait_for_index(&mut self, timeout: Option) -> PyResult<()> { + let _ = timeout; // Rust impl is synchronous + self.flush()?; + Ok(()) + } + + /// Flush vectors from flat buffer to HNSW. + fn flush(&mut self) -> PyResult { + let migrated = match &mut self.inner { + TieredIndexInner::SingleF32(idx) => idx.flush(), + TieredIndexInner::SingleF64(idx) => idx.flush(), + TieredIndexInner::SingleBF16(idx) => idx.flush(), + TieredIndexInner::SingleF16(idx) => idx.flush(), + TieredIndexInner::SingleI8(idx) => idx.flush(), + TieredIndexInner::SingleU8(idx) => idx.flush(), + TieredIndexInner::MultiF32(idx) => idx.flush(), + TieredIndexInner::MultiF64(idx) => idx.flush(), + TieredIndexInner::MultiBF16(idx) => idx.flush(), + TieredIndexInner::MultiF16(idx) => idx.flush(), + TieredIndexInner::MultiI8(idx) => idx.flush(), + TieredIndexInner::MultiU8(idx) => idx.flush(), + }; + migrated.map_err(|e| PyRuntimeError::new_err(format!("Flush failed: {:?}", e))) + } + + /// Get the number of labels in HNSW backend. + fn hnsw_label_count(&self) -> usize { + match &self.inner { + TieredIndexInner::SingleF32(idx) => idx.hnsw_size(), + TieredIndexInner::SingleF64(idx) => idx.hnsw_size(), + TieredIndexInner::SingleBF16(idx) => idx.hnsw_size(), + TieredIndexInner::SingleF16(idx) => idx.hnsw_size(), + TieredIndexInner::SingleI8(idx) => idx.hnsw_size(), + TieredIndexInner::SingleU8(idx) => idx.hnsw_size(), + TieredIndexInner::MultiF32(idx) => idx.hnsw_size(), + TieredIndexInner::MultiF64(idx) => idx.hnsw_size(), + TieredIndexInner::MultiBF16(idx) => idx.hnsw_size(), + TieredIndexInner::MultiF16(idx) => idx.hnsw_size(), + TieredIndexInner::MultiI8(idx) => idx.hnsw_size(), + TieredIndexInner::MultiU8(idx) => idx.hnsw_size(), + } + } + + /// Get the current flat buffer size. 
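+    ///
+    /// Read together with `hnsw_label_count` this exposes migration progress;
+    /// after `flush()` (or `wait_for_index()`) the flat side is expected to
+    /// drain to zero, assuming nothing was inserted in between.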
+ fn get_curr_bf_size(&self) -> usize { + match &self.inner { + TieredIndexInner::SingleF32(idx) => idx.flat_size(), + TieredIndexInner::SingleF64(idx) => idx.flat_size(), + TieredIndexInner::SingleBF16(idx) => idx.flat_size(), + TieredIndexInner::SingleF16(idx) => idx.flat_size(), + TieredIndexInner::SingleI8(idx) => idx.flat_size(), + TieredIndexInner::SingleU8(idx) => idx.flat_size(), + TieredIndexInner::MultiF32(idx) => idx.flat_size(), + TieredIndexInner::MultiF64(idx) => idx.flat_size(), + TieredIndexInner::MultiBF16(idx) => idx.flat_size(), + TieredIndexInner::MultiF16(idx) => idx.flat_size(), + TieredIndexInner::MultiI8(idx) => idx.flat_size(), + TieredIndexInner::MultiU8(idx) => idx.flat_size(), + } + } + + /// Get total memory usage in bytes. + fn index_memory(&self) -> usize { + match &self.inner { + TieredIndexInner::SingleF32(idx) => idx.memory_usage(), + TieredIndexInner::SingleF64(idx) => idx.memory_usage(), + TieredIndexInner::SingleBF16(idx) => idx.memory_usage(), + TieredIndexInner::SingleF16(idx) => idx.memory_usage(), + TieredIndexInner::SingleI8(idx) => idx.memory_usage(), + TieredIndexInner::SingleU8(idx) => idx.memory_usage(), + TieredIndexInner::MultiF32(idx) => idx.memory_usage(), + TieredIndexInner::MultiF64(idx) => idx.memory_usage(), + TieredIndexInner::MultiBF16(idx) => idx.memory_usage(), + TieredIndexInner::MultiF16(idx) => idx.memory_usage(), + TieredIndexInner::MultiI8(idx) => idx.memory_usage(), + TieredIndexInner::MultiU8(idx) => idx.memory_usage(), + } + } +} + +impl TieredHNSWIndex { + fn query_internal(&self, query: &[f64], k: usize, params: Option<&QueryParams>) -> PyResult> { + let reply = match &self.inner { + TieredIndexInner::SingleF32(idx) => { + let q: Vec = query.iter().map(|&x| x as f32).collect(); + idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::SingleF64(idx) => idx.top_k_query(query, k, params), + TieredIndexInner::SingleBF16(idx) => { + let q: Vec = query.iter().map(|&x| BFloat16::from_f32(x as f32)).collect(); + idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::SingleF16(idx) => { + let q: Vec = query.iter().map(|&x| Float16::from_f32(x as f32)).collect(); + idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::SingleI8(idx) => { + let q: Vec = query.iter().map(|&x| Int8(x as i8)).collect(); + idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::SingleU8(idx) => { + let q: Vec = query.iter().map(|&x| UInt8(x as u8)).collect(); + idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::MultiF32(idx) => { + let q: Vec = query.iter().map(|&x| x as f32).collect(); + idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::MultiF64(idx) => idx.top_k_query(query, k, params), + TieredIndexInner::MultiBF16(idx) => { + let q: Vec = query.iter().map(|&x| BFloat16::from_f32(x as f32)).collect(); + idx.top_k_query(&q, k, 
params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::MultiF16(idx) => { + let q: Vec = query.iter().map(|&x| Float16::from_f32(x as f32)).collect(); + idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::MultiI8(idx) => { + let q: Vec = query.iter().map(|&x| Int8(x as i8)).collect(); + idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::MultiU8(idx) => { + let q: Vec = query.iter().map(|&x| UInt8(x as u8)).collect(); + idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + }; + + reply.map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e))) + } + + fn range_query_internal(&self, query: &[f64], radius: f64, params: Option<&QueryParams>) -> PyResult> { + let reply = match &self.inner { + TieredIndexInner::SingleF32(idx) => { + let q: Vec = query.iter().map(|&x| x as f32).collect(); + idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::SingleF64(idx) => idx.range_query(query, radius, params), + TieredIndexInner::SingleBF16(idx) => { + let q: Vec = query.iter().map(|&x| BFloat16::from_f32(x as f32)).collect(); + idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::SingleF16(idx) => { + let q: Vec = query.iter().map(|&x| Float16::from_f32(x as f32)).collect(); + idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::SingleI8(idx) => { + let q: Vec = query.iter().map(|&x| Int8(x as i8)).collect(); + idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::SingleU8(idx) => { + let q: Vec = query.iter().map(|&x| UInt8(x as u8)).collect(); + idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::MultiF32(idx) => { + let q: Vec = query.iter().map(|&x| x as f32).collect(); + idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::MultiF64(idx) => idx.range_query(query, radius, params), + TieredIndexInner::MultiBF16(idx) => { + let q: Vec = query.iter().map(|&x| BFloat16::from_f32(x as f32)).collect(); + idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::MultiF16(idx) => { + let q: Vec = query.iter().map(|&x| Float16::from_f32(x as f32)).collect(); + idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } 
+ TieredIndexInner::MultiI8(idx) => { + let q: Vec = query.iter().map(|&x| Int8(x as i8)).collect(); + idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + TieredIndexInner::MultiU8(idx) => { + let q: Vec = query.iter().map(|&x| UInt8(x as u8)).collect(); + idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }) + } + }; + + reply.map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e))) + } + + fn get_batch_results(&self, query: &[f64], ef: usize) -> PyResult> { + let total = self.index_size(); + if total == 0 { + return Ok(Vec::new()); + } + + let params = QueryParams::new().with_ef_runtime(ef); + let reply = self.query_internal(query, total, Some(¶ms))?; + Ok(reply.results.into_iter().map(|r| (r.label, r.distance)).collect()) + } +} + +// ============================================================================ +// SVS Index +// ============================================================================ + +enum SvsIndexInner { + SingleF32(SvsSingle), + SingleF64(SvsSingle), + SingleBF16(SvsSingle), + SingleF16(SvsSingle), + SingleI8(SvsSingle), + SingleU8(SvsSingle), + MultiF32(SvsMulti), + MultiF64(SvsMulti), + MultiBF16(SvsMulti), + MultiF16(SvsMulti), + MultiI8(SvsMulti), + MultiU8(SvsMulti), +} + +macro_rules! dispatch_svs_index { + ($self:expr, $method:ident $(, $args:expr)*) => { + match &$self.inner { + SvsIndexInner::SingleF32(idx) => idx.$method($($args),*), + SvsIndexInner::SingleF64(idx) => idx.$method($($args),*), + SvsIndexInner::SingleBF16(idx) => idx.$method($($args),*), + SvsIndexInner::SingleF16(idx) => idx.$method($($args),*), + SvsIndexInner::SingleI8(idx) => idx.$method($($args),*), + SvsIndexInner::SingleU8(idx) => idx.$method($($args),*), + SvsIndexInner::MultiF32(idx) => idx.$method($($args),*), + SvsIndexInner::MultiF64(idx) => idx.$method($($args),*), + SvsIndexInner::MultiBF16(idx) => idx.$method($($args),*), + SvsIndexInner::MultiF16(idx) => idx.$method($($args),*), + SvsIndexInner::MultiI8(idx) => idx.$method($($args),*), + SvsIndexInner::MultiU8(idx) => idx.$method($($args),*), + } + }; +} + +macro_rules! dispatch_svs_index_mut { + ($self:expr, $method:ident $(, $args:expr)*) => { + match &mut $self.inner { + SvsIndexInner::SingleF32(idx) => idx.$method($($args),*), + SvsIndexInner::SingleF64(idx) => idx.$method($($args),*), + SvsIndexInner::SingleBF16(idx) => idx.$method($($args),*), + SvsIndexInner::SingleF16(idx) => idx.$method($($args),*), + SvsIndexInner::SingleI8(idx) => idx.$method($($args),*), + SvsIndexInner::SingleU8(idx) => idx.$method($($args),*), + SvsIndexInner::MultiF32(idx) => idx.$method($($args),*), + SvsIndexInner::MultiF64(idx) => idx.$method($($args),*), + SvsIndexInner::MultiBF16(idx) => idx.$method($($args),*), + SvsIndexInner::MultiF16(idx) => idx.$method($($args),*), + SvsIndexInner::MultiI8(idx) => idx.$method($($args),*), + SvsIndexInner::MultiU8(idx) => idx.$method($($args),*), + } + }; +} + +/// SVS (Vamana) index for approximate nearest neighbor search. +#[pyclass] +pub struct SVSIndex { + inner: SvsIndexInner, + data_type: u32, + metric: Metric, + dim: usize, + multi: bool, + search_window_size: usize, +} + +#[pymethods] +impl SVSIndex { + /// Create a new SVS index from parameters. 
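+    ///
+    /// A sketch of the parameter mapping (assuming the builder semantics
+    /// suggested by the names): `alpha` is the Vamana pruning parameter,
+    /// `graph_max_degree` caps out-edges per node, and the construction and
+    /// search window sizes bound the candidate lists during build and query.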
+ #[new] + fn new(params: &SVSParams) -> PyResult { + let metric = metric_from_u32(params.metric)?; + let svs_params = SvsParams::new(params.dim, metric) + .with_graph_degree(params.graph_max_degree) + .with_alpha(params.alpha as f32) + .with_construction_l(params.construction_window_size) + .with_search_l(params.search_window_size) + .with_two_pass(true); + + let inner = match (params.multi, params.r#type) { + (false, VECSIM_TYPE_FLOAT32) => SvsIndexInner::SingleF32(SvsSingle::new(svs_params)), + (false, VECSIM_TYPE_FLOAT64) => SvsIndexInner::SingleF64(SvsSingle::new(svs_params)), + (false, VECSIM_TYPE_BFLOAT16) => SvsIndexInner::SingleBF16(SvsSingle::new(svs_params)), + (false, VECSIM_TYPE_FLOAT16) => SvsIndexInner::SingleF16(SvsSingle::new(svs_params)), + (false, VECSIM_TYPE_INT8) => SvsIndexInner::SingleI8(SvsSingle::new(svs_params)), + (false, VECSIM_TYPE_UINT8) => SvsIndexInner::SingleU8(SvsSingle::new(svs_params)), + (true, VECSIM_TYPE_FLOAT32) => SvsIndexInner::MultiF32(SvsMulti::new(svs_params)), + (true, VECSIM_TYPE_FLOAT64) => SvsIndexInner::MultiF64(SvsMulti::new(svs_params)), + (true, VECSIM_TYPE_BFLOAT16) => SvsIndexInner::MultiBF16(SvsMulti::new(svs_params)), + (true, VECSIM_TYPE_FLOAT16) => SvsIndexInner::MultiF16(SvsMulti::new(svs_params)), + (true, VECSIM_TYPE_INT8) => SvsIndexInner::MultiI8(SvsMulti::new(svs_params)), + (true, VECSIM_TYPE_UINT8) => SvsIndexInner::MultiU8(SvsMulti::new(svs_params)), + _ => { + return Err(PyValueError::new_err(format!( + "Unsupported data type: {}", + params.r#type + ))) + } + }; + + Ok(SVSIndex { + inner, + data_type: params.r#type, + metric, + dim: params.dim, + multi: params.multi, + search_window_size: params.search_window_size, + }) + } + + /// Add a vector to the index. + fn add_vector(&mut self, py: Python<'_>, vector: PyObject, label: u64) -> PyResult<()> { + match self.data_type { + VECSIM_TYPE_FLOAT32 => { + let arr: PyReadonlyArray1 = vector.extract(py)?; + let slice = arr.as_slice()?; + match &mut self.inner { + SvsIndexInner::SingleF32(idx) => idx.add_vector(slice, label), + SvsIndexInner::MultiF32(idx) => idx.add_vector(slice, label), + _ => unreachable!(), + } + } + VECSIM_TYPE_FLOAT64 => { + let arr: PyReadonlyArray1 = vector.extract(py)?; + let slice = arr.as_slice()?; + match &mut self.inner { + SvsIndexInner::SingleF64(idx) => idx.add_vector(slice, label), + SvsIndexInner::MultiF64(idx) => idx.add_vector(slice, label), + _ => unreachable!(), + } + } + VECSIM_TYPE_BFLOAT16 => { + let bf16_vec = extract_bf16_vector(py, &vector)?; + match &mut self.inner { + SvsIndexInner::SingleBF16(idx) => idx.add_vector(&bf16_vec, label), + SvsIndexInner::MultiBF16(idx) => idx.add_vector(&bf16_vec, label), + _ => unreachable!(), + } + } + VECSIM_TYPE_FLOAT16 => { + let f16_vec = extract_f16_vector(py, &vector)?; + match &mut self.inner { + SvsIndexInner::SingleF16(idx) => idx.add_vector(&f16_vec, label), + SvsIndexInner::MultiF16(idx) => idx.add_vector(&f16_vec, label), + _ => unreachable!(), + } + } + VECSIM_TYPE_INT8 => { + let arr: PyReadonlyArray1 = vector.extract(py)?; + let slice = arr.as_slice()?; + let i8_vec: Vec = slice.iter().map(|&v| Int8(v)).collect(); + match &mut self.inner { + SvsIndexInner::SingleI8(idx) => idx.add_vector(&i8_vec, label), + SvsIndexInner::MultiI8(idx) => idx.add_vector(&i8_vec, label), + _ => unreachable!(), + } + } + VECSIM_TYPE_UINT8 => { + let arr: PyReadonlyArray1 = vector.extract(py)?; + let slice = arr.as_slice()?; + let u8_vec: Vec = slice.iter().map(|&v| UInt8(v)).collect(); + match &mut 
self.inner { + SvsIndexInner::SingleU8(idx) => idx.add_vector(&u8_vec, label), + SvsIndexInner::MultiU8(idx) => idx.add_vector(&u8_vec, label), + _ => unreachable!(), + } + } + _ => return Err(PyValueError::new_err("Unsupported data type")), + } + .map_err(|e| PyRuntimeError::new_err(format!("Failed to add vector: {:?}", e)))?; + Ok(()) + } + + /// Add multiple vectors in parallel. + #[pyo3(signature = (vectors, labels, num_threads=None))] + fn add_vector_parallel( + &mut self, + py: Python<'_>, + vectors: PyObject, + labels: PyObject, + num_threads: Option, + ) -> PyResult<()> { + // Set thread pool size if specified + if let Some(threads) = num_threads { + rayon::ThreadPoolBuilder::new() + .num_threads(threads) + .build_global() + .ok(); // Ignore error if already initialized + } + + // Extract labels + let labels_arr: PyReadonlyArray1 = labels.extract(py)?; + let labels_vec: Vec = labels_arr.as_slice()?.iter().map(|&l| l as u64).collect(); + + // For parallel insertion, we need to collect all vectors first + // Then add them sequentially (SVS doesn't support parallel insertion directly) + match self.data_type { + VECSIM_TYPE_FLOAT32 => { + let arr: PyReadonlyArray2 = vectors.extract(py)?; + let shape = arr.shape(); + let num_vectors = shape[0]; + let dim = shape[1]; + if dim != self.dim { + return Err(PyValueError::new_err(format!( + "Vector dimension {} does not match index dimension {}", + dim, self.dim + ))); + } + let slice = arr.as_slice()?; + + for i in 0..num_vectors { + let start = i * dim; + let end = start + dim; + let vector = &slice[start..end]; + let label = labels_vec[i]; + + match &mut self.inner { + SvsIndexInner::SingleF32(idx) => idx.add_vector(vector, label), + SvsIndexInner::MultiF32(idx) => idx.add_vector(vector, label), + _ => unreachable!(), + } + .map_err(|e| PyRuntimeError::new_err(format!("Failed to add vector: {:?}", e)))?; + } + } + _ => { + return Err(PyValueError::new_err( + "Parallel add only supported for FLOAT32 currently", + )) + } + } + + Ok(()) + } + + /// Delete a vector from the index. + fn delete_vector(&mut self, label: u64) -> PyResult { + dispatch_svs_index_mut!(self, delete_vector, label) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to delete vector: {:?}", e))) + } + + /// Perform a k-nearest neighbors query. + #[pyo3(signature = (query, k=10, query_param=None))] + fn knn_query<'py>( + &self, + py: Python<'py>, + query: PyObject, + k: usize, + query_param: Option<&VecSimQueryParams>, + ) -> PyResult<(Bound<'py, PyArray2>, Bound<'py, PyArray2>)> { + let query_vec = extract_query_vec(py, &query)?; + let search_l = query_param + .map(|p| p.get_svs_window_size(py)) + .unwrap_or(self.search_window_size); + let params = QueryParams::new().with_ef_runtime(search_l); + let reply = self.query_internal(&query_vec, k, Some(¶ms))?; + + let labels: Vec = reply.results.iter().map(|r| r.label as i64).collect(); + let distances: Vec = reply.results.iter().map(|r| r.distance).collect(); + let len = labels.len(); + + Ok((vec_to_2d_array(py, labels, len), vec_to_2d_array(py, distances, len))) + } + + /// Perform a range query. 
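+    ///
+    /// The SVS search window (`search_window_size`, or the per-query
+    /// override) is passed through the generic `ef_runtime` knob on
+    /// `QueryParams`, mirroring `knn_query`.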
+ #[pyo3(signature = (query, radius, query_param=None))] + fn range_query<'py>( + &self, + py: Python<'py>, + query: PyObject, + radius: f64, + query_param: Option<&VecSimQueryParams>, + ) -> PyResult<(Bound<'py, PyArray2>, Bound<'py, PyArray2>)> { + let query_vec = extract_query_vec(py, &query)?; + let search_l = query_param + .map(|p| p.get_svs_window_size(py)) + .unwrap_or(self.search_window_size); + let params = QueryParams::new().with_ef_runtime(search_l); + let reply = self.range_query_internal(&query_vec, radius, Some(¶ms))?; + + let labels: Vec = reply.results.iter().map(|r| r.label as i64).collect(); + let distances: Vec = reply.results.iter().map(|r| r.distance).collect(); + let len = labels.len(); + + Ok((vec_to_2d_array(py, labels, len), vec_to_2d_array(py, distances, len))) + } + + /// Get the number of vectors in the index. + fn index_size(&self) -> usize { + dispatch_svs_index!(self, index_size) + } + + /// Get the data type of vectors in the index. + fn index_type(&self) -> u32 { + self.data_type + } + + /// Create a batch iterator for streaming results. + #[pyo3(signature = (query, query_params=None))] + fn create_batch_iterator( + &self, + py: Python<'_>, + query: PyObject, + query_params: Option<&VecSimQueryParams>, + ) -> PyResult { + let query_vec = extract_query_vec(py, &query)?; + let search_l = query_params + .map(|p| p.get_svs_window_size(py)) + .unwrap_or(self.search_window_size); + + let all_results = self.get_batch_results(&query_vec, search_l)?; + let index_size = self.index_size(); + + Ok(PyBatchIterator::new(all_results, index_size)) + } +} + +impl SVSIndex { + fn query_internal(&self, query: &[f64], k: usize, params: Option<&QueryParams>) -> PyResult> { + let reply = match self.data_type { + VECSIM_TYPE_FLOAT32 => { + let q: Vec = query.iter().map(|&x| x as f32).collect(); + match &self.inner { + SvsIndexInner::SingleF32(idx) => idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + SvsIndexInner::MultiF32(idx) => idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + _ => unreachable!(), + } + } + VECSIM_TYPE_FLOAT64 => { + match &self.inner { + SvsIndexInner::SingleF64(idx) => idx.top_k_query(query, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance)).collect() }), + SvsIndexInner::MultiF64(idx) => idx.top_k_query(query, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance)).collect() }), + _ => unreachable!(), + } + } + VECSIM_TYPE_BFLOAT16 => { + let q: Vec = query.iter().map(|&x| BFloat16::from_f32(x as f32)).collect(); + match &self.inner { + SvsIndexInner::SingleBF16(idx) => idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance.to_f64())).collect() }), + SvsIndexInner::MultiBF16(idx) => idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance.to_f64())).collect() }), + _ => unreachable!(), + } + } + VECSIM_TYPE_FLOAT16 => { + let q: Vec = query.iter().map(|&x| Float16::from_f32(x as f32)).collect(); + match &self.inner { + SvsIndexInner::SingleF16(idx) => idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, 
r.distance.to_f64())).collect() }), + SvsIndexInner::MultiF16(idx) => idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance.to_f64())).collect() }), + _ => unreachable!(), + } + } + VECSIM_TYPE_INT8 => { + let q: Vec = query.iter().map(|&x| Int8(x as i8)).collect(); + match &self.inner { + SvsIndexInner::SingleI8(idx) => idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + SvsIndexInner::MultiI8(idx) => idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + _ => unreachable!(), + } + } + VECSIM_TYPE_UINT8 => { + let q: Vec = query.iter().map(|&x| UInt8(x as u8)).collect(); + match &self.inner { + SvsIndexInner::SingleU8(idx) => idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + SvsIndexInner::MultiU8(idx) => idx.top_k_query(&q, k, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + _ => unreachable!(), + } + } + _ => return Err(PyValueError::new_err("Unsupported data type")), + }; + + reply.map_err(|e| PyRuntimeError::new_err(format!("Query failed: {:?}", e))) + } + + fn range_query_internal(&self, query: &[f64], radius: f64, params: Option<&QueryParams>) -> PyResult> { + let reply = match self.data_type { + VECSIM_TYPE_FLOAT32 => { + let q: Vec = query.iter().map(|&x| x as f32).collect(); + match &self.inner { + SvsIndexInner::SingleF32(idx) => idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + SvsIndexInner::MultiF32(idx) => idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + _ => unreachable!(), + } + } + VECSIM_TYPE_FLOAT64 => { + match &self.inner { + SvsIndexInner::SingleF64(idx) => idx.range_query(query, radius, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance)).collect() }), + SvsIndexInner::MultiF64(idx) => idx.range_query(query, radius, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance)).collect() }), + _ => unreachable!(), + } + } + VECSIM_TYPE_BFLOAT16 => { + let q: Vec = query.iter().map(|&x| BFloat16::from_f32(x as f32)).collect(); + match &self.inner { + SvsIndexInner::SingleBF16(idx) => idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance.to_f64())).collect() }), + SvsIndexInner::MultiBF16(idx) => idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance.to_f64())).collect() }), + _ => unreachable!(), + } + } + VECSIM_TYPE_FLOAT16 => { + let q: Vec = query.iter().map(|&x| Float16::from_f32(x as f32)).collect(); + match &self.inner { + SvsIndexInner::SingleF16(idx) => idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance.to_f64())).collect() }), + SvsIndexInner::MultiF16(idx) => idx.range_query(&q, radius as f32, 
params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance.to_f64())).collect() }), + _ => unreachable!(), + } + } + VECSIM_TYPE_INT8 => { + let q: Vec = query.iter().map(|&x| Int8(x as i8)).collect(); + match &self.inner { + SvsIndexInner::SingleI8(idx) => idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + SvsIndexInner::MultiI8(idx) => idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + _ => unreachable!(), + } + } + VECSIM_TYPE_UINT8 => { + let q: Vec = query.iter().map(|&x| UInt8(x as u8)).collect(); + match &self.inner { + SvsIndexInner::SingleU8(idx) => idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + SvsIndexInner::MultiU8(idx) => idx.range_query(&q, radius as f32, params) + .map(|r| QueryReply { results: r.results.into_iter().map(|r| QueryResult::new(r.label, r.distance as f64)).collect() }), + _ => unreachable!(), + } + } + _ => return Err(PyValueError::new_err("Unsupported data type")), + }; + + reply.map_err(|e| PyRuntimeError::new_err(format!("Range query failed: {:?}", e))) + } + + /// Get results for batch iterator with search_l-influenced quality. + /// Uses progressive queries to ensure search_l affects early results. + fn get_batch_results(&self, query: &[f64], search_l: usize) -> PyResult> { + let total = self.index_size(); + if total == 0 { + return Ok(Vec::new()); + } + + // Use multiple queries with progressively larger k values. + // This ensures early results are affected by search_l. + // SVS uses search_l as the search window size (beam width). 
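+        // As a concrete reading (illustrative numbers): with search_l = 20 and
+        // total = 1000, stage 1 asks for k = 10 with window 20, stage 2 for
+        // k = 20, and stage 3 sweeps k = 1000 with window 1000.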
+
+        // Query stages:
+        // Stage 1: k = 10 (first batch), with specified search_l
+        //   - search_l=5: smaller window, potentially worse results
+        //   - search_l=128: larger window, better results
+        // Stage 2: k = search_l (get search_l-quality results)
+        // Stage 3: k = total (get all remaining), with full search
+
+        let mut results = Vec::new();
+        let mut seen: std::collections::HashSet<u64> = std::collections::HashSet::new();
+
+        // Stage 1: k = 10 (typical first batch size) with specified search_l
+        let k1 = 10.min(total);
+        let params1 = QueryParams::new().with_ef_runtime(search_l);
+        let stage1_reply = self.query_internal(query, k1, Some(&params1))?;
+        for r in stage1_reply.results {
+            if !seen.contains(&r.label) {
+                seen.insert(r.label);
+                results.push((r.label, r.distance));
+            }
+        }
+
+        // Stage 2: k = search_l (use search_l's natural quality)
+        if seen.len() < total && search_l > k1 {
+            let k2 = search_l.min(total);
+            let params2 = QueryParams::new().with_ef_runtime(search_l);
+            let stage2_reply = self.query_internal(query, k2, Some(&params2))?;
+            for r in stage2_reply.results {
+                if !seen.contains(&r.label) {
+                    seen.insert(r.label);
+                    results.push((r.label, r.distance));
+                }
+            }
+        }
+
+        // Stage 3: Get all remaining results with full search
+        if seen.len() < total {
+            let params3 = QueryParams::new().with_ef_runtime(total);
+            let stage3_reply = self.query_internal(query, total, Some(&params3))?;
+            for r in stage3_reply.results {
+                if !seen.contains(&r.label) {
+                    seen.insert(r.label);
+                    results.push((r.label, r.distance));
+                }
+            }
+        }
+
+        Ok(results)
+    }
+}
+
+// ============================================================================
+// SVS Parameters
+// ============================================================================
+
+/// SVS (Vamana) index parameters.
+#[pyclass]
+#[derive(Clone)]
+pub struct SVSParams {
+    #[pyo3(get, set)]
+    pub dim: usize,
+    #[pyo3(get, set)]
+    pub r#type: u32,
+    #[pyo3(get, set)]
+    pub metric: u32,
+    #[pyo3(get, set)]
+    pub multi: bool,
+    #[pyo3(get, set)]
+    pub quantBits: u32,
+    #[pyo3(get, set)]
+    pub alpha: f64,
+    #[pyo3(get, set)]
+    pub graph_max_degree: usize,
+    #[pyo3(get, set)]
+    pub construction_window_size: usize,
+    #[pyo3(get, set)]
+    pub max_candidate_pool_size: usize,
+    #[pyo3(get, set)]
+    pub prune_to: usize,
+    #[pyo3(get, set)]
+    pub use_search_history: u32,
+    #[pyo3(get, set)]
+    pub search_window_size: usize,
+    #[pyo3(get, set)]
+    pub search_buffer_capacity: usize,
+    #[pyo3(get, set)]
+    pub epsilon: f64,
+    #[pyo3(get, set)]
+    pub num_threads: usize,
+}
+
+#[pymethods]
+impl SVSParams {
+    #[new]
+    fn new() -> Self {
+        SVSParams {
+            dim: 0,
+            r#type: VECSIM_TYPE_FLOAT32,
+            metric: VECSIM_METRIC_L2,
+            multi: false,
+            quantBits: VECSIM_SVS_QUANT_NONE,
+            alpha: 1.2,
+            graph_max_degree: 32,
+            construction_window_size: 200,
+            max_candidate_pool_size: 0,
+            prune_to: 0,
+            use_search_history: VECSIM_OPTION_AUTO,
+            search_window_size: 100,
+            search_buffer_capacity: 100,
+            epsilon: 0.01,
+            num_threads: 0,
+        }
+    }
+}
+
+// ============================================================================
+// Module Registration
+// ============================================================================
+
+/// VecSim - High-performance vector similarity search library.
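+///
+/// The bindings wrap the pure-Rust `vecsim` crate. As a minimal sketch of the
+/// underlying crate API (mirroring the benchmark code in this repository;
+/// dimension, values, and labels below are illustrative):
+///
+/// ```rust,ignore
+/// use vecsim::distance::Metric;
+/// use vecsim::index::brute_force::{BruteForceParams, BruteForceSingle};
+/// use vecsim::index::VecSimIndex;
+///
+/// // Build a 128-dimensional L2 index and insert one vector under label 0.
+/// let params = BruteForceParams::new(128, Metric::L2);
+/// let mut index = BruteForceSingle::<f32>::new(params);
+/// index.add_vector(&vec![0.0_f32; 128], 0).unwrap();
+///
+/// // Query the 10 nearest neighbors of the same point.
+/// let reply = index.top_k_query(&vec![0.0_f32; 128], 10, None).unwrap();
+/// ```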
+#[pymodule] +fn VecSim(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add("VecSimMetric_L2", VECSIM_METRIC_L2)?; + m.add("VecSimMetric_IP", VECSIM_METRIC_IP)?; + m.add("VecSimMetric_Cosine", VECSIM_METRIC_COSINE)?; + + m.add("VecSimType_FLOAT32", VECSIM_TYPE_FLOAT32)?; + m.add("VecSimType_FLOAT64", VECSIM_TYPE_FLOAT64)?; + m.add("VecSimType_BFLOAT16", VECSIM_TYPE_BFLOAT16)?; + m.add("VecSimType_FLOAT16", VECSIM_TYPE_FLOAT16)?; + m.add("VecSimType_INT8", VECSIM_TYPE_INT8)?; + m.add("VecSimType_UINT8", VECSIM_TYPE_UINT8)?; + + m.add("BY_SCORE", BY_SCORE)?; + m.add("BY_ID", BY_ID)?; + + // SVS quantization constants + m.add("VecSimSvsQuant_NONE", VECSIM_SVS_QUANT_NONE)?; + m.add("VecSimSvsQuant_8", VECSIM_SVS_QUANT_8)?; + m.add("VecSimSvsQuant_4", VECSIM_SVS_QUANT_4)?; + + // VecSim option constants + m.add("VecSimOption_OFF", VECSIM_OPTION_OFF)?; + m.add("VecSimOption_ON", VECSIM_OPTION_ON)?; + m.add("VecSimOption_AUTO", VECSIM_OPTION_AUTO)?; + + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/rust/vecsim/Cargo.toml b/rust/vecsim/Cargo.toml new file mode 100644 index 000000000..487e2a9d7 --- /dev/null +++ b/rust/vecsim/Cargo.toml @@ -0,0 +1,54 @@ +[package] +name = "vecsim" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +description = "High-performance vector similarity search library with BruteForce and HNSW indices" + +[dependencies] +rayon = { workspace = true } +parking_lot = { workspace = true } +half = { workspace = true } +num-traits = { workspace = true } +thiserror = { workspace = true } +rand = { workspace = true } +memmap2 = { workspace = true } +dashmap = "5.5" + +[features] +default = [] +nightly = [] # Enable nightly-only SIMD intrinsics +profile = [] # Enable profiling instrumentation + +[dev-dependencies] +criterion = "0.5" +rand = { workspace = true } + +[[bench]] +name = "brute_force_bench" +harness = false + +[[bench]] +name = "hnsw_bench" +harness = false + +[[bench]] +name = "tiered_bench" +harness = false + +[[bench]] +name = "comparison_bench" +harness = false + +[[bench]] +name = "svs_bench" +harness = false + +[[bench]] +name = "dbpedia_bench" +harness = false + +[[bench]] +name = "hnsw_bottleneck_bench" +harness = false diff --git a/rust/vecsim/benches/brute_force_bench.rs b/rust/vecsim/benches/brute_force_bench.rs new file mode 100644 index 000000000..9b6c9b176 --- /dev/null +++ b/rust/vecsim/benches/brute_force_bench.rs @@ -0,0 +1,269 @@ +//! Benchmarks for BruteForce index operations. +//! +//! Run with: cargo bench --bench brute_force_bench + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use vecsim::distance::Metric; +use vecsim::index::brute_force::{BruteForceMulti, BruteForceParams, BruteForceSingle}; +use vecsim::index::VecSimIndex; + +const DIM: usize = 128; + +/// Generate random vectors for benchmarking. +fn generate_vectors(count: usize, dim: usize) -> Vec> { + use rand::Rng; + let mut rng = rand::thread_rng(); + (0..count) + .map(|_| (0..dim).map(|_| rng.gen::()).collect()) + .collect() +} + +/// Benchmark adding vectors to BruteForceSingle. 
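+///
+/// Each iteration rebuilds the index from scratch, so the timing includes
+/// parameter setup and allocation, not just the inserts.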
+fn bench_bf_single_add(c: &mut Criterion) { + let mut group = c.benchmark_group("bf_single_add"); + + for size in [100, 1000, 10000] { + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + let vectors = generate_vectors(size, DIM); + b.iter(|| { + let params = BruteForceParams::new(DIM, Metric::L2); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark adding vectors to BruteForceMulti. +fn bench_bf_multi_add(c: &mut Criterion) { + let mut group = c.benchmark_group("bf_multi_add"); + + for size in [100, 1000, 10000] { + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + let vectors = generate_vectors(size, DIM); + b.iter(|| { + let params = BruteForceParams::new(DIM, Metric::L2); + let mut index = BruteForceMulti::::new(params); + for (i, v) in vectors.iter().enumerate() { + // Use fewer labels to have multiple vectors per label + index.add_vector(black_box(v), (i % 100) as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark top-k queries on BruteForceSingle. +fn bench_bf_single_topk(c: &mut Criterion) { + let mut group = c.benchmark_group("bf_single_topk"); + + for size in [100, 1000, 10000] { + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = BruteForceParams::new(DIM, Metric::L2); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark top-k queries with varying k values. +fn bench_bf_single_topk_varying_k(c: &mut Criterion) { + let mut group = c.benchmark_group("bf_single_topk_k"); + + let size = 10000; + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = BruteForceParams::new(DIM, Metric::L2); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + for k in [1, 10, 50, 100] { + group.bench_with_input(BenchmarkId::from_parameter(k), &k, |b, &k| { + b.iter(|| index.top_k_query(black_box(&query), black_box(k), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark range queries on BruteForceSingle. +fn bench_bf_single_range(c: &mut Criterion) { + let mut group = c.benchmark_group("bf_single_range"); + + for size in [100, 1000, 10000] { + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = BruteForceParams::new(DIM, Metric::L2); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + // Use a radius that returns ~10% of vectors + let radius = 10.0; + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + b.iter(|| index.range_query(black_box(&query), black_box(radius), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark delete operations on BruteForceSingle. 
+fn bench_bf_single_delete(c: &mut Criterion) { + let mut group = c.benchmark_group("bf_single_delete"); + + for size in [100, 1000, 5000] { + let vectors = generate_vectors(size, DIM); + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + b.iter_batched( + || { + let params = BruteForceParams::new(DIM, Metric::L2); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + index + }, + |mut index| { + // Delete half the vectors + for i in (0..size).step_by(2) { + index.delete_vector(i as u64).unwrap(); + } + index + }, + criterion::BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} + +/// Benchmark different distance metrics. +fn bench_bf_metrics(c: &mut Criterion) { + let mut group = c.benchmark_group("bf_metrics_5000"); + + let size = 5000; + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + for metric in [Metric::L2, Metric::InnerProduct, Metric::Cosine] { + let params = BruteForceParams::new(DIM, metric); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + let metric_name = match metric { + Metric::L2 => "L2", + Metric::InnerProduct => "IP", + Metric::Cosine => "Cosine", + }; + + group.bench_function(metric_name, |b| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark different vector dimensions. +fn bench_bf_dimensions(c: &mut Criterion) { + let mut group = c.benchmark_group("bf_dimensions_1000"); + + let size = 1000; + + for dim in [32, 128, 512, 1024] { + let vectors = generate_vectors(size, dim); + let query = generate_vectors(1, dim).pop().unwrap(); + + let params = BruteForceParams::new(dim, Metric::L2); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + group.bench_with_input(BenchmarkId::from_parameter(dim), &dim, |b, _| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark serialization round-trip. 
+fn bench_bf_serialization(c: &mut Criterion) { + let mut group = c.benchmark_group("bf_serialization"); + + for size in [1000, 5000, 10000] { + let vectors = generate_vectors(size, DIM); + + let params = BruteForceParams::new(DIM, Metric::L2); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + group.bench_with_input(BenchmarkId::new("save", size), &size, |b, _| { + b.iter(|| { + let mut buffer = Vec::new(); + index.save(black_box(&mut buffer)).unwrap(); + buffer + }); + }); + + let mut buffer = Vec::new(); + index.save(&mut buffer).unwrap(); + + group.bench_with_input(BenchmarkId::new("load", size), &size, |b, _| { + b.iter(|| { + let mut cursor = std::io::Cursor::new(&buffer); + BruteForceSingle::::load(black_box(&mut cursor)).unwrap() + }); + }); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_bf_single_add, + bench_bf_multi_add, + bench_bf_single_topk, + bench_bf_single_topk_varying_k, + bench_bf_single_range, + bench_bf_single_delete, + bench_bf_metrics, + bench_bf_dimensions, + bench_bf_serialization, +); + +criterion_main!(benches); diff --git a/rust/vecsim/benches/comparison_bench.rs b/rust/vecsim/benches/comparison_bench.rs new file mode 100644 index 000000000..1d98f29e7 --- /dev/null +++ b/rust/vecsim/benches/comparison_bench.rs @@ -0,0 +1,160 @@ +//! Benchmarks designed to compare with C++ implementation. +//! +//! Uses similar parameters to C++ benchmarks for fair comparison: +//! - 10,000 vectors (smaller than C++ 1M for quick testing) +//! - 128 dimensions +//! - HNSW M=16, ef_construction=200, ef_runtime=10/100/200 +//! +//! Run with: cargo bench --bench comparison_bench + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use vecsim::distance::Metric; +use vecsim::index::brute_force::{BruteForceParams, BruteForceSingle}; +use vecsim::index::hnsw::{HnswParams, HnswSingle}; +use vecsim::index::VecSimIndex; +use vecsim::query::QueryParams; + +const DIM: usize = 128; +const N_VECTORS: usize = 10_000; +const N_QUERIES: usize = 100; + +/// Generate normalized random vectors (for cosine similarity). 
+fn generate_normalized_vectors(count: usize, dim: usize) -> Vec> { + use rand::Rng; + let mut rng = rand::thread_rng(); + (0..count) + .map(|_| { + let mut v: Vec = (0..dim).map(|_| rng.gen::() - 0.5).collect(); + let norm: f32 = v.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut v { + *x /= norm; + } + v + }) + .collect() +} + +/// Benchmark: BruteForce Top-K query (baseline) +fn bench_bf_topk(c: &mut Criterion) { + let mut group = c.benchmark_group("comparison_bf_topk"); + + let vectors = generate_normalized_vectors(N_VECTORS, DIM); + let queries = generate_normalized_vectors(N_QUERIES, DIM); + + let params = BruteForceParams::new(DIM, Metric::L2); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + for k in [10, 100] { + group.bench_with_input(BenchmarkId::new("k", k), &k, |b, &k| { + b.iter(|| { + for query in &queries { + black_box(index.top_k_query(query, k, None).unwrap()); + } + }); + }); + } + + group.finish(); +} + +/// Benchmark: HNSW Top-K query with varying ef_runtime +fn bench_hnsw_topk(c: &mut Criterion) { + let mut group = c.benchmark_group("comparison_hnsw_topk"); + + let vectors = generate_normalized_vectors(N_VECTORS, DIM); + let queries = generate_normalized_vectors(N_QUERIES, DIM); + + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(200) + .with_ef_runtime(100); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + // Test different ef_runtime values + for ef in [10, 100, 200] { + let query_params = QueryParams::new().with_ef_runtime(ef); + group.bench_with_input(BenchmarkId::new("ef", ef), &ef, |b, _| { + b.iter(|| { + for query in &queries { + black_box(index.top_k_query(query, 10, Some(&query_params)).unwrap()); + } + }); + }); + } + + group.finish(); +} + +/// Benchmark: HNSW index construction +fn bench_hnsw_construction(c: &mut Criterion) { + let mut group = c.benchmark_group("comparison_hnsw_construction"); + group.sample_size(10); + + let vectors = generate_normalized_vectors(N_VECTORS, DIM); + + group.bench_function("10k_vectors", |b| { + b.iter(|| { + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(200); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + + group.finish(); +} + +/// Benchmark: Different metrics +fn bench_metrics(c: &mut Criterion) { + let mut group = c.benchmark_group("comparison_metrics"); + + let vectors = generate_normalized_vectors(N_VECTORS, DIM); + let queries = generate_normalized_vectors(N_QUERIES, DIM); + + for metric in [Metric::L2, Metric::InnerProduct, Metric::Cosine] { + let params = HnswParams::new(DIM, metric) + .with_m(16) + .with_ef_construction(200) + .with_ef_runtime(100); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + let name = match metric { + Metric::L2 => "L2", + Metric::InnerProduct => "IP", + Metric::Cosine => "Cosine", + }; + + group.bench_function(name, |b| { + b.iter(|| { + for query in &queries { + black_box(index.top_k_query(query, 10, None).unwrap()); + } + }); + }); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_bf_topk, + bench_hnsw_topk, + bench_hnsw_construction, + bench_metrics, +); + +criterion_main!(benches); diff --git 
a/rust/vecsim/benches/data_loader.rs b/rust/vecsim/benches/data_loader.rs
new file mode 100644
index 000000000..888cb228c
--- /dev/null
+++ b/rust/vecsim/benches/data_loader.rs
@@ -0,0 +1,165 @@
+//! Data loader for benchmark datasets.
+//!
+//! Loads binary vector files in the same format as the C++ benchmarks.
+//! The format is simple: n_vectors * dim * sizeof(T) contiguous bytes.
+
+use std::fs::File;
+use std::io::{BufReader, Read};
+use std::path::Path;
+
+/// Load f32 vectors from a binary file.
+///
+/// Format: contiguous little-endian f32 values; each vector is `dim` floats.
+pub fn load_f32_vectors(path: &Path, n_vectors: usize, dim: usize) -> std::io::Result<Vec<Vec<f32>>> {
+    let file = File::open(path)?;
+    let mut reader = BufReader::new(file);
+
+    let mut vectors = Vec::with_capacity(n_vectors);
+    let mut buffer = vec![0u8; dim * std::mem::size_of::<f32>()];
+
+    for _ in 0..n_vectors {
+        reader.read_exact(&mut buffer)?;
+        let vector: Vec<f32> = buffer
+            .chunks_exact(4)
+            .map(|bytes| f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
+            .collect();
+        vectors.push(vector);
+    }
+
+    Ok(vectors)
+}
+
+/// Load f64 vectors from a binary file.
+pub fn load_f64_vectors(path: &Path, n_vectors: usize, dim: usize) -> std::io::Result<Vec<Vec<f64>>> {
+    let file = File::open(path)?;
+    let mut reader = BufReader::new(file);
+
+    let mut vectors = Vec::with_capacity(n_vectors);
+    let mut buffer = vec![0u8; dim * std::mem::size_of::<f64>()];
+
+    for _ in 0..n_vectors {
+        reader.read_exact(&mut buffer)?;
+        let vector: Vec<f64> = buffer
+            .chunks_exact(8)
+            .map(|bytes| {
+                f64::from_le_bytes([
+                    bytes[0], bytes[1], bytes[2], bytes[3],
+                    bytes[4], bytes[5], bytes[6], bytes[7],
+                ])
+            })
+            .collect();
+        vectors.push(vector);
+    }
+
+    Ok(vectors)
+}
+
+/// Dataset configuration matching the C++ benchmarks.
+pub struct DatasetConfig {
+    pub name: &'static str,
+    pub dim: usize,
+    pub n_vectors: usize,
+    pub n_queries: usize,
+    pub m: usize,
+    pub ef_construction: usize,
+    pub vectors_file: &'static str,
+    pub queries_file: &'static str,
+}
+
+/// DBPedia single-value dataset (fp32, cosine, 768 dim, 1M vectors).
+/// Download with: bash tests/benchmark/bm_files.sh benchmarks-all
+pub const DBPEDIA_SINGLE_FP32: DatasetConfig = DatasetConfig {
+    name: "dbpedia_single_fp32",
+    dim: 768,
+    n_vectors: 1_000_000,
+    n_queries: 10_000,
+    m: 64,
+    ef_construction: 512,
+    vectors_file: "tests/benchmark/data/dbpedia-cosine-dim768-1M-vectors.raw",
+    queries_file: "tests/benchmark/data/dbpedia-cosine-dim768-test_vectors.raw",
+};
+
+/// Fashion images multi-value dataset (fp32, cosine, 512 dim).
+/// Download with: bash tests/benchmark/bm_files.sh benchmarks-all
+pub const FASHION_MULTI_FP32: DatasetConfig = DatasetConfig {
+    name: "fashion_multi_fp32",
+    dim: 512,
+    n_vectors: 1_111_025,
+    n_queries: 10_000,
+    m: 64,
+    ef_construction: 512,
+    vectors_file: "tests/benchmark/data/fashion_images_multi_value-cosine-dim512-1M-vectors.raw",
+    queries_file: "tests/benchmark/data/fashion_images_multi_value-cosine-dim512-test_vectors.raw",
+};
+
+/// Try to find the repository root by looking for the benchmark data directory.
+pub fn find_repo_root() -> Option<std::path::PathBuf> {
+    let mut current = std::env::current_dir().ok()?;
+
+    // Try the current directory and up to four parents
+    for _ in 0..5 {
+        let test_path = current.join("tests/benchmark/data");
+        if test_path.exists() {
+            return Some(current);
+        }
+        current = current.parent()?.to_path_buf();
+    }
+
+    None
+}
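+
+// Sanity-check arithmetic for the raw layout described at the top of this
+// file: vector `i` starts at byte offset `i * dim * size_of::<f32>()`, and a
+// well-formed file holds exactly `n_vectors * dim * size_of::<f32>()` bytes.
+// Illustrative helper only; it is not used by the loaders above.
+#[allow(dead_code)]
+pub fn expected_f32_file_bytes(n_vectors: usize, dim: usize) -> usize {
+    // e.g. the DBPedia file: 1_000_000 * 768 * 4 = 3_072_000_000 bytes (~2.9 GiB)
+    n_vectors * dim * std::mem::size_of::<f32>()
+}
+
+/// Load dataset vectors, returning None if files don't exist.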
+pub fn try_load_dataset_vectors(config: &DatasetConfig, max_vectors: usize) -> Option<Vec<Vec<f32>>> {
+    let repo_root = find_repo_root()?;
+    let vectors_path = repo_root.join(config.vectors_file);
+
+    if !vectors_path.exists() {
+        eprintln!("Dataset vectors not found: {:?}", vectors_path);
+        eprintln!("Run: bash tests/benchmark/bm_files.sh benchmarks-all");
+        return None;
+    }
+
+    let n = max_vectors.min(config.n_vectors);
+    load_f32_vectors(&vectors_path, n, config.dim).ok()
+}
+
+/// Load dataset queries, returning None if files don't exist.
+pub fn try_load_dataset_queries(config: &DatasetConfig, max_queries: usize) -> Option<Vec<Vec<f32>>> {
+    let repo_root = find_repo_root()?;
+    let queries_path = repo_root.join(config.queries_file);
+
+    if !queries_path.exists() {
+        eprintln!("Dataset queries not found: {:?}", queries_path);
+        return None;
+    }
+
+    let n = max_queries.min(config.n_queries);
+    load_f32_vectors(&queries_path, n, config.dim).ok()
+}
+
+/// Generate random f32 vectors (fallback when real data is unavailable).
+pub fn generate_random_vectors(count: usize, dim: usize) -> Vec<Vec<f32>> {
+    use rand::Rng;
+    let mut rng = rand::thread_rng();
+    (0..count)
+        .map(|_| (0..dim).map(|_| rng.gen::<f32>()).collect())
+        .collect()
+}
+
+/// Generate normalized random f32 vectors (for cosine similarity).
+pub fn generate_normalized_vectors(count: usize, dim: usize) -> Vec<Vec<f32>> {
+    use rand::Rng;
+    let mut rng = rand::thread_rng();
+    (0..count)
+        .map(|_| {
+            let mut v: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>() - 0.5).collect();
+            let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
+            if norm > 0.0 {
+                for x in &mut v {
+                    *x /= norm;
+                }
+            }
+            v
+        })
+        .collect()
+}
diff --git a/rust/vecsim/benches/dbpedia_bench.rs b/rust/vecsim/benches/dbpedia_bench.rs
new file mode 100644
index 000000000..f55af42af
--- /dev/null
+++ b/rust/vecsim/benches/dbpedia_bench.rs
@@ -0,0 +1,1004 @@
+//! Benchmarks using the DBPedia dataset (same as the C++ benchmarks).
+//!
+//! This benchmark uses the same data files as the C++ benchmarks for
+//! direct comparison. If the data files are not available, it falls back
+//! to random data with the same dimensions.
+//!
+//! Dataset: DBPedia embeddings
+//! - 1M vectors, 768 dimensions, Cosine similarity
+//! - 10K query vectors
+//! - HNSW parameters: M=64, EF_C=512
+//!
+//! Tested data types (matching the C++ benchmarks):
+//! - f32 (fp32)
+//! - f64 (fp64)
+//! - BFloat16 (bf16)
+//! - Float16 (fp16)
+//! - Int8 (int8)
+//! - UInt8 (uint8)
+//!
+//! To download the benchmark data files, run from the repository root:
+//!   bash tests/benchmark/bm_files.sh benchmarks-all
+//!
+//! Run with: cargo bench --bench dbpedia_bench
+
+mod data_loader;
+
+use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
+use data_loader::{
+    generate_normalized_vectors, try_load_dataset_queries, try_load_dataset_vectors,
+    DBPEDIA_SINGLE_FP32,
+};
+use std::collections::HashSet;
+use vecsim::distance::Metric;
+use vecsim::index::brute_force::{BruteForceParams, BruteForceSingle};
+use vecsim::index::hnsw::{HnswParams, HnswSingle};
+use vecsim::index::VecSimIndex;
+use vecsim::query::QueryParams;
+use vecsim::types::{BFloat16, DistanceType, Float16, Int8, UInt8, VectorElement};
+
+// ============================================================================
+// Data Loading
+// ============================================================================
+
+/// Benchmark data holder - loaded once for all benchmarks.
+struct BenchmarkData {
+    vectors_f32: Vec<Vec<f32>>,
+    queries_f32: Vec<Vec<f32>>,
+    dim: usize,
+    is_real_data: bool,
+}
+
+impl BenchmarkData {
+    fn load(max_vectors: usize, max_queries: usize) -> Self {
+        let config = &DBPEDIA_SINGLE_FP32;
+
+        // Try to load the real dataset first
+        if let (Some(vectors), Some(queries)) = (
+            try_load_dataset_vectors(config, max_vectors),
+            try_load_dataset_queries(config, max_queries),
+        ) {
+            println!(
+                "Loaded real DBPedia dataset: {} vectors, {} queries, dim={}",
+                vectors.len(),
+                queries.len(),
+                config.dim
+            );
+            return Self {
+                vectors_f32: vectors,
+                queries_f32: queries,
+                dim: config.dim,
+                is_real_data: true,
+            };
+        }
+
+        // Fall back to random data
+        println!("Using random data (real dataset not found)");
+        println!("To use real data, run: bash tests/benchmark/bm_files.sh benchmarks-all");
+        let dim = config.dim;
+        Self {
+            vectors_f32: generate_normalized_vectors(max_vectors, dim),
+            queries_f32: generate_normalized_vectors(max_queries, dim),
+            dim,
+            is_real_data: false,
+        }
+    }
+
+    /// Convert the vectors to a different element type.
+    fn vectors_as<T: VectorElement>(&self) -> Vec<Vec<T>> {
+        self.vectors_f32
+            .iter()
+            .map(|v| v.iter().map(|&x| T::from_f32(x)).collect())
+            .collect()
+    }
+
+    /// Convert the queries to a different element type.
+    fn queries_as<T: VectorElement>(&self) -> Vec<Vec<T>> {
+        self.queries_f32
+            .iter()
+            .map(|v| v.iter().map(|&x| T::from_f32(x)).collect())
+            .collect()
+    }
+
+    fn data_label(&self) -> &'static str {
+        if self.is_real_data {
+            "real"
+        } else {
+            "random"
+        }
+    }
+}
+
+// ============================================================================
+// Generic Index Builders
+// ============================================================================
+
+fn build_hnsw_index<T: VectorElement>(
+    vectors: &[Vec<T>],
+    dim: usize,
+    n_vectors: usize,
+) -> HnswSingle<T> {
+    let params = HnswParams::new(dim, Metric::Cosine)
+        .with_m(DBPEDIA_SINGLE_FP32.m)
+        .with_ef_construction(DBPEDIA_SINGLE_FP32.ef_construction)
+        .with_ef_runtime(10);
+
+    let mut index = HnswSingle::<T>::new(params);
+    for (i, v) in vectors.iter().take(n_vectors).enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+    index
+}
+
+fn build_bf_index<T: VectorElement>(
+    vectors: &[Vec<T>],
+    dim: usize,
+    n_vectors: usize,
+) -> BruteForceSingle<T> {
+    let params = BruteForceParams::new(dim, Metric::Cosine);
+    let mut index = BruteForceSingle::<T>::new(params);
+    for (i, v) in vectors.iter().take(n_vectors).enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+    index
+}
+
+/// Build an HNSW index with lighter parameters for faster recall testing.
+fn build_hnsw_index_light<T: VectorElement>(
+    vectors: &[Vec<T>],
+    dim: usize,
+    n_vectors: usize,
+) -> HnswSingle<T> {
+    // Use lighter parameters for faster benchmark execution
+    let params = HnswParams::new(dim, Metric::Cosine)
+        .with_m(16) // Smaller M for faster construction
+        .with_ef_construction(64) // Smaller ef_c for faster construction
+        .with_ef_runtime(10);
+
+    let mut index = HnswSingle::<T>::new(params);
+    for (i, v) in vectors.iter().take(n_vectors).enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+    index
+}
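+
+// Hedged usage sketch of the helpers above (not wired into any benchmark, and
+// the function name is hypothetical): load the shared f32 data once, derive a
+// per-type copy via `vectors_as`/`queries_as`, and query it. Relies only on
+// items defined in this file and on the imports at the top.
+#[allow(dead_code)]
+fn example_fp16_roundtrip() {
+    let data = BenchmarkData::load(1_000, 10);
+    let vectors = data.vectors_as::<Float16>();
+    let queries = data.queries_as::<Float16>();
+    let index = build_hnsw_index_light(&vectors, data.dim, vectors.len());
+    let _top10 = index.top_k_query(&queries[0], 10, None).unwrap();
+}
+
+// ============================================================================
+// Recall Measurement
+// ============================================================================
+
+/// Compute recall: fraction of ground truth results found in approximate results.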
+///
+/// Recall = |approximate ∩ ground_truth| / |ground_truth|
+fn compute_recall(approximate: &[(u64, f64)], ground_truth: &[(u64, f64)]) -> f64 {
+    if ground_truth.is_empty() {
+        return 1.0;
+    }
+
+    let gt_ids: HashSet<u64> = ground_truth.iter().map(|(id, _)| *id).collect();
+    let found = approximate.iter().filter(|(id, _)| gt_ids.contains(id)).count();
+
+    found as f64 / ground_truth.len() as f64
+}
+
+/// Compute ground truth for a set of queries using brute force search.
+fn compute_ground_truth<T: VectorElement>(
+    bf_index: &BruteForceSingle<T>,
+    queries: &[Vec<T>],
+    k: usize,
+) -> Vec<Vec<(u64, f64)>> {
+    queries
+        .iter()
+        .map(|q| {
+            bf_index
+                .top_k_query(q, k, None)
+                .unwrap()
+                .into_iter()
+                .map(|r| (r.label, r.distance.to_f64()))
+                .collect()
+        })
+        .collect()
+}
+
+/// Compute average recall for an HNSW index against ground truth.
+fn measure_recall<T: VectorElement>(
+    hnsw_index: &HnswSingle<T>,
+    queries: &[Vec<T>],
+    ground_truth: &[Vec<(u64, f64)>],
+    k: usize,
+    ef_runtime: usize,
+) -> f64 {
+    let query_params = QueryParams::new().with_ef_runtime(ef_runtime);
+
+    let total_recall: f64 = queries
+        .iter()
+        .zip(ground_truth.iter())
+        .map(|(q, gt)| {
+            let results: Vec<(u64, f64)> = hnsw_index
+                .top_k_query(q, k, Some(&query_params))
+                .unwrap()
+                .into_iter()
+                .map(|r| (r.label, r.distance.to_f64()))
+                .collect();
+            compute_recall(&results, gt)
+        })
+        .sum();
+
+    total_recall / queries.len() as f64
+}
+
+/// Print recall measurements for various ef_runtime values.
+fn print_recall_report<T: VectorElement>(
+    hnsw_index: &HnswSingle<T>,
+    queries: &[Vec<T>],
+    ground_truth: &[Vec<(u64, f64)>],
+    k: usize,
+    type_name: &str,
+) {
+    println!("\n{} Recall Report (k={}):", type_name, k);
+    println!("{:>10} {:>10}", "ef_runtime", "recall");
+    println!("{:-<22}", "");
+
+    for ef in [10, 20, 50, 100, 200, 500, 1000] {
+        let recall = measure_recall(hnsw_index, queries, ground_truth, k, ef);
+        println!("{:>10} {:>10.4}", ef, recall);
+    }
+}
+
+// ============================================================================
+// F32 Benchmarks (Primary - most detailed)
+// ============================================================================
+
+/// Benchmark top-k queries on HNSW with varying ef_runtime (f32).
+fn bench_f32_topk_ef_runtime(c: &mut Criterion) {
+    let data = BenchmarkData::load(100_000, 1_000);
+    let vectors = data.vectors_as::<f32>();
+    let queries = data.queries_as::<f32>();
+    let index = build_hnsw_index(&vectors, data.dim, 100_000);
+
+    let mut group = c.benchmark_group("f32_hnsw_topk_ef");
+
+    for ef in [10, 50, 100, 200, 500] {
+        let query_params = QueryParams::new().with_ef_runtime(ef);
+        let label = format!("{}_{}", data.data_label(), ef);
+
+        group.bench_function(BenchmarkId::from_parameter(label), |b| {
+            let mut query_idx = 0;
+            b.iter(|| {
+                let query = &queries[query_idx % queries.len()];
+                query_idx += 1;
+                index
+                    .top_k_query(black_box(query), black_box(10), Some(&query_params))
+                    .unwrap()
+            });
+        });
+    }
+
+    group.finish();
+}
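+
+// A concrete instance of the recall definition above: ground truth ids
+// {1, 2, 3, 4}, approximate ids {1, 2, 3, 9} -> 3 of the 4 ground-truth ids
+// were found, so recall = 3/4 = 0.75 (distances do not affect the measure).
+// Illustrative only; the function name is hypothetical.
+#[allow(dead_code)]
+fn recall_worked_example() {
+    let approx = vec![(1u64, 0.10f64), (2, 0.20), (3, 0.30), (9, 0.40)];
+    let truth = vec![(1u64, 0.10f64), (2, 0.20), (3, 0.30), (4, 0.35)];
+    assert!((compute_recall(&approx, &truth) - 0.75).abs() < 1e-12);
+}
+
+/// Benchmark top-k queries with varying k (f32).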
+fn bench_f32_topk_k(c: &mut Criterion) { + let data = BenchmarkData::load(100_000, 1_000); + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 100_000); + + let mut group = c.benchmark_group("f32_hnsw_topk_k"); + let query_params = QueryParams::new().with_ef_runtime(200); + + for k in [1, 10, 50, 100, 500] { + let label = format!("{}_{}", data.data_label(), k); + + group.bench_function(BenchmarkId::from_parameter(label), |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(k), Some(&query_params)) + .unwrap() + }); + }); + } + + group.finish(); +} + +/// Benchmark HNSW vs BruteForce comparison (f32). +fn bench_f32_hnsw_vs_bf(c: &mut Criterion) { + let data = BenchmarkData::load(100_000, 1_000); + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + + let hnsw_index = build_hnsw_index(&vectors, data.dim, 100_000); + let bf_index = build_bf_index(&vectors, data.dim, 100_000); + + let mut group = c.benchmark_group("f32_hnsw_vs_bf"); + + // BruteForce baseline + group.bench_function(format!("{}_bf", data.data_label()), |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + bf_index + .top_k_query(black_box(query), black_box(10), None) + .unwrap() + }); + }); + + // HNSW with various ef values + for ef in [10, 100, 500] { + let query_params = QueryParams::new().with_ef_runtime(ef); + group.bench_function(format!("{}_hnsw_ef{}", data.data_label(), ef), |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + hnsw_index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + group.finish(); +} + +// ============================================================================ +// F64 Benchmarks +// ============================================================================ + +fn bench_f64_topk(c: &mut Criterion) { + let data = BenchmarkData::load(50_000, 500); + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 50_000); + + let mut group = c.benchmark_group("f64_hnsw_topk"); + + for ef in [10, 100, 500] { + let query_params = QueryParams::new().with_ef_runtime(ef); + let label = format!("{}_{}", data.data_label(), ef); + + group.bench_function(BenchmarkId::from_parameter(label), |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + group.finish(); +} + +// ============================================================================ +// BFloat16 Benchmarks +// ============================================================================ + +fn bench_bf16_topk(c: &mut Criterion) { + let data = BenchmarkData::load(100_000, 1_000); + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 100_000); + + let mut group = c.benchmark_group("bf16_hnsw_topk"); + + for ef in [10, 100, 500] { + let query_params = QueryParams::new().with_ef_runtime(ef); + let label = format!("{}_{}", data.data_label(), ef); + + group.bench_function(BenchmarkId::from_parameter(label), |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx 
% queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + group.finish(); +} + +// ============================================================================ +// Float16 Benchmarks +// ============================================================================ + +fn bench_fp16_topk(c: &mut Criterion) { + let data = BenchmarkData::load(100_000, 1_000); + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 100_000); + + let mut group = c.benchmark_group("fp16_hnsw_topk"); + + for ef in [10, 100, 500] { + let query_params = QueryParams::new().with_ef_runtime(ef); + let label = format!("{}_{}", data.data_label(), ef); + + group.bench_function(BenchmarkId::from_parameter(label), |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + group.finish(); +} + +// ============================================================================ +// Int8 Benchmarks +// ============================================================================ + +fn bench_int8_topk(c: &mut Criterion) { + let data = BenchmarkData::load(100_000, 1_000); + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 100_000); + + let mut group = c.benchmark_group("int8_hnsw_topk"); + + for ef in [10, 100, 500] { + let query_params = QueryParams::new().with_ef_runtime(ef); + let label = format!("{}_{}", data.data_label(), ef); + + group.bench_function(BenchmarkId::from_parameter(label), |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + group.finish(); +} + +// ============================================================================ +// UInt8 Benchmarks +// ============================================================================ + +fn bench_uint8_topk(c: &mut Criterion) { + let data = BenchmarkData::load(100_000, 1_000); + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 100_000); + + let mut group = c.benchmark_group("uint8_hnsw_topk"); + + for ef in [10, 100, 500] { + let query_params = QueryParams::new().with_ef_runtime(ef); + let label = format!("{}_{}", data.data_label(), ef); + + group.bench_function(BenchmarkId::from_parameter(label), |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + group.finish(); +} + +// ============================================================================ +// Cross-Type Comparison Benchmarks +// ============================================================================ + +/// Compare query performance across all data types. 
+fn bench_all_types_comparison(c: &mut Criterion) { + let data = BenchmarkData::load(50_000, 500); + let query_params = QueryParams::new().with_ef_runtime(100); + + let mut group = c.benchmark_group("all_types_topk10_ef100"); + + // f32 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 50_000); + + group.bench_function("f32", |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + // f64 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 50_000); + + group.bench_function("f64", |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + // bf16 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 50_000); + + group.bench_function("bf16", |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + // fp16 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 50_000); + + group.bench_function("fp16", |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + // int8 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 50_000); + + group.bench_function("int8", |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + // uint8 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let index = build_hnsw_index(&vectors, data.dim, 50_000); + + group.bench_function("uint8", |b| { + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + query_idx += 1; + index + .top_k_query(black_box(query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + group.finish(); +} + +// ============================================================================ +// Index Construction Benchmarks (all types) +// ============================================================================ + +fn bench_all_types_add(c: &mut Criterion) { + let data = BenchmarkData::load(5_000, 100); + + let mut group = c.benchmark_group("all_types_add_5000"); + group.sample_size(10); + + // f32 + { + let vectors = data.vectors_as::(); + group.bench_function("f32", |b| { + b.iter(|| { + let params = HnswParams::new(data.dim, Metric::Cosine) + .with_m(DBPEDIA_SINGLE_FP32.m) + .with_ef_construction(DBPEDIA_SINGLE_FP32.ef_construction); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + // f64 + { + let vectors = data.vectors_as::(); + group.bench_function("f64", 
|b| { + b.iter(|| { + let params = HnswParams::new(data.dim, Metric::Cosine) + .with_m(DBPEDIA_SINGLE_FP32.m) + .with_ef_construction(DBPEDIA_SINGLE_FP32.ef_construction); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + // bf16 + { + let vectors = data.vectors_as::(); + group.bench_function("bf16", |b| { + b.iter(|| { + let params = HnswParams::new(data.dim, Metric::Cosine) + .with_m(DBPEDIA_SINGLE_FP32.m) + .with_ef_construction(DBPEDIA_SINGLE_FP32.ef_construction); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + // fp16 + { + let vectors = data.vectors_as::(); + group.bench_function("fp16", |b| { + b.iter(|| { + let params = HnswParams::new(data.dim, Metric::Cosine) + .with_m(DBPEDIA_SINGLE_FP32.m) + .with_ef_construction(DBPEDIA_SINGLE_FP32.ef_construction); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + // int8 + { + let vectors = data.vectors_as::(); + group.bench_function("int8", |b| { + b.iter(|| { + let params = HnswParams::new(data.dim, Metric::Cosine) + .with_m(DBPEDIA_SINGLE_FP32.m) + .with_ef_construction(DBPEDIA_SINGLE_FP32.ef_construction); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + // uint8 + { + let vectors = data.vectors_as::(); + group.bench_function("uint8", |b| { + b.iter(|| { + let params = HnswParams::new(data.dim, Metric::Cosine) + .with_m(DBPEDIA_SINGLE_FP32.m) + .with_ef_construction(DBPEDIA_SINGLE_FP32.ef_construction); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +// ============================================================================ +// Recall Benchmarks +// ============================================================================ + +/// Benchmark that measures and reports recall for f32 HNSW at various ef_runtime values. 
+fn bench_f32_recall(c: &mut Criterion) { + // Use smaller dataset and lighter HNSW params for faster benchmark execution + let n_vectors = 5_000; + let n_queries = 50; + let data = BenchmarkData::load(n_vectors, n_queries); + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + + println!("\nBuilding indices for recall measurement ({} vectors, {} queries)...", n_vectors, n_queries); + let hnsw_index = build_hnsw_index_light(&vectors, data.dim, n_vectors); + let bf_index = build_bf_index(&vectors, data.dim, n_vectors); + + let k = 10; + println!("Computing ground truth (k={})...", k); + let ground_truth = compute_ground_truth(&bf_index, &queries, k); + + // Print recall report + print_recall_report(&hnsw_index, &queries, &ground_truth, k, "f32"); + + // Benchmark recall computation at different ef values + let mut group = c.benchmark_group("f32_recall_vs_ef"); + + for ef in [10, 50, 100, 200, 500] { + let label = format!("{}_{}", data.data_label(), ef); + + group.bench_function(BenchmarkId::from_parameter(label), |b| { + let query_params = QueryParams::new().with_ef_runtime(ef); + let mut query_idx = 0; + + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + let gt = &ground_truth[query_idx % ground_truth.len()]; + query_idx += 1; + + let results: Vec<(u64, f64)> = hnsw_index + .top_k_query(black_box(query), black_box(k), Some(&query_params)) + .unwrap() + .into_iter() + .map(|r| (r.label, r.distance.to_f64())) + .collect(); + + compute_recall(&results, gt) + }); + }); + } + + group.finish(); +} + +/// Measure recall across all data types at a fixed ef_runtime. +fn bench_all_types_recall(c: &mut Criterion) { + // Use smaller dataset for faster benchmark execution + let n_vectors = 5_000; + let n_queries = 50; + let data = BenchmarkData::load(n_vectors, n_queries); + let k = 10; + let ef_runtime = 100; + + println!("\n============================================"); + println!("Cross-Type Recall Comparison (k={}, ef={})", k, ef_runtime); + println!("============================================"); + + let mut group = c.benchmark_group("all_types_recall"); + + // f32 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let hnsw_index = build_hnsw_index_light(&vectors, data.dim, n_vectors); + let bf_index = build_bf_index(&vectors, data.dim, n_vectors); + let ground_truth = compute_ground_truth(&bf_index, &queries, k); + let recall = measure_recall(&hnsw_index, &queries, &ground_truth, k, ef_runtime); + println!("f32 recall: {:.4}", recall); + + group.bench_function("f32", |b| { + let query_params = QueryParams::new().with_ef_runtime(ef_runtime); + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + let gt = &ground_truth[query_idx % ground_truth.len()]; + query_idx += 1; + let results: Vec<(u64, f64)> = hnsw_index + .top_k_query(black_box(query), black_box(k), Some(&query_params)) + .unwrap() + .into_iter() + .map(|r| (r.label, r.distance.to_f64())) + .collect(); + compute_recall(&results, gt) + }); + }); + } + + // f64 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let hnsw_index = build_hnsw_index_light(&vectors, data.dim, n_vectors); + let bf_index = build_bf_index(&vectors, data.dim, n_vectors); + let ground_truth = compute_ground_truth(&bf_index, &queries, k); + let recall = measure_recall(&hnsw_index, &queries, &ground_truth, k, ef_runtime); + println!("f64 recall: {:.4}", recall); + + group.bench_function("f64", |b| { + let query_params = 
QueryParams::new().with_ef_runtime(ef_runtime); + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + let gt = &ground_truth[query_idx % ground_truth.len()]; + query_idx += 1; + let results: Vec<(u64, f64)> = hnsw_index + .top_k_query(black_box(query), black_box(k), Some(&query_params)) + .unwrap() + .into_iter() + .map(|r| (r.label, r.distance.to_f64())) + .collect(); + compute_recall(&results, gt) + }); + }); + } + + // bf16 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let hnsw_index = build_hnsw_index_light(&vectors, data.dim, n_vectors); + let bf_index = build_bf_index(&vectors, data.dim, n_vectors); + let ground_truth = compute_ground_truth(&bf_index, &queries, k); + let recall = measure_recall(&hnsw_index, &queries, &ground_truth, k, ef_runtime); + println!("bf16 recall: {:.4}", recall); + + group.bench_function("bf16", |b| { + let query_params = QueryParams::new().with_ef_runtime(ef_runtime); + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + let gt = &ground_truth[query_idx % ground_truth.len()]; + query_idx += 1; + let results: Vec<(u64, f64)> = hnsw_index + .top_k_query(black_box(query), black_box(k), Some(&query_params)) + .unwrap() + .into_iter() + .map(|r| (r.label, r.distance.to_f64())) + .collect(); + compute_recall(&results, gt) + }); + }); + } + + // fp16 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let hnsw_index = build_hnsw_index_light(&vectors, data.dim, n_vectors); + let bf_index = build_bf_index(&vectors, data.dim, n_vectors); + let ground_truth = compute_ground_truth(&bf_index, &queries, k); + let recall = measure_recall(&hnsw_index, &queries, &ground_truth, k, ef_runtime); + println!("fp16 recall: {:.4}", recall); + + group.bench_function("fp16", |b| { + let query_params = QueryParams::new().with_ef_runtime(ef_runtime); + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + let gt = &ground_truth[query_idx % ground_truth.len()]; + query_idx += 1; + let results: Vec<(u64, f64)> = hnsw_index + .top_k_query(black_box(query), black_box(k), Some(&query_params)) + .unwrap() + .into_iter() + .map(|r| (r.label, r.distance.to_f64())) + .collect(); + compute_recall(&results, gt) + }); + }); + } + + // int8 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let hnsw_index = build_hnsw_index_light(&vectors, data.dim, n_vectors); + let bf_index = build_bf_index(&vectors, data.dim, n_vectors); + let ground_truth = compute_ground_truth(&bf_index, &queries, k); + let recall = measure_recall(&hnsw_index, &queries, &ground_truth, k, ef_runtime); + println!("int8 recall: {:.4}", recall); + + group.bench_function("int8", |b| { + let query_params = QueryParams::new().with_ef_runtime(ef_runtime); + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + let gt = &ground_truth[query_idx % ground_truth.len()]; + query_idx += 1; + let results: Vec<(u64, f64)> = hnsw_index + .top_k_query(black_box(query), black_box(k), Some(&query_params)) + .unwrap() + .into_iter() + .map(|r| (r.label, r.distance.to_f64())) + .collect(); + compute_recall(&results, gt) + }); + }); + } + + // uint8 + { + let vectors = data.vectors_as::(); + let queries = data.queries_as::(); + let hnsw_index = build_hnsw_index_light(&vectors, data.dim, n_vectors); + let bf_index = build_bf_index(&vectors, data.dim, n_vectors); + let ground_truth = compute_ground_truth(&bf_index, 
&queries, k); + let recall = measure_recall(&hnsw_index, &queries, &ground_truth, k, ef_runtime); + println!("uint8 recall: {:.4}", recall); + + group.bench_function("uint8", |b| { + let query_params = QueryParams::new().with_ef_runtime(ef_runtime); + let mut query_idx = 0; + b.iter(|| { + let query = &queries[query_idx % queries.len()]; + let gt = &ground_truth[query_idx % ground_truth.len()]; + query_idx += 1; + let results: Vec<(u64, f64)> = hnsw_index + .top_k_query(black_box(query), black_box(k), Some(&query_params)) + .unwrap() + .into_iter() + .map(|r| (r.label, r.distance.to_f64())) + .collect(); + compute_recall(&results, gt) + }); + }); + } + + group.finish(); +} + +// ============================================================================ +// Main Benchmark Groups +// ============================================================================ + +criterion_group!( + benches, + // F32 detailed benchmarks + bench_f32_topk_ef_runtime, + bench_f32_topk_k, + bench_f32_hnsw_vs_bf, + // Per-type benchmarks + bench_f64_topk, + bench_bf16_topk, + bench_fp16_topk, + bench_int8_topk, + bench_uint8_topk, + // Cross-type comparisons + bench_all_types_comparison, + bench_all_types_add, + // Recall measurements + bench_f32_recall, + bench_all_types_recall, +); + +criterion_main!(benches); diff --git a/rust/vecsim/benches/hnsw_bench.rs b/rust/vecsim/benches/hnsw_bench.rs new file mode 100644 index 000000000..c483c7409 --- /dev/null +++ b/rust/vecsim/benches/hnsw_bench.rs @@ -0,0 +1,383 @@ +//! Benchmarks for HNSW index operations. +//! +//! Run with: cargo bench --bench hnsw_bench + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use vecsim::distance::Metric; +use vecsim::index::hnsw::{HnswMulti, HnswParams, HnswSingle}; +use vecsim::index::VecSimIndex; + +const DIM: usize = 128; + +/// Generate random vectors for benchmarking. +fn generate_vectors(count: usize, dim: usize) -> Vec> { + use rand::Rng; + let mut rng = rand::thread_rng(); + (0..count) + .map(|_| (0..dim).map(|_| rng.gen::()).collect()) + .collect() +} + +/// Benchmark adding vectors to HnswSingle. +fn bench_hnsw_single_add(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_single_add"); + group.sample_size(10); // HNSW add is slow, reduce samples + + for size in [100, 500, 1000] { + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + let vectors = generate_vectors(size, DIM); + b.iter(|| { + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark adding vectors with varying M parameter. 
+fn bench_hnsw_add_varying_m(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_add_m"); + group.sample_size(10); + + let size = 500; + let vectors = generate_vectors(size, DIM); + + for m in [4, 8, 16, 32] { + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(m), &m, |b, &m| { + b.iter(|| { + let params = HnswParams::new(DIM, Metric::L2) + .with_m(m) + .with_ef_construction(100); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark adding vectors with varying ef_construction. +fn bench_hnsw_add_varying_ef(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_add_ef_construction"); + group.sample_size(10); + + let size = 500; + let vectors = generate_vectors(size, DIM); + + for ef in [50, 100, 200, 400] { + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(ef), &ef, |b, &ef| { + b.iter(|| { + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(ef); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark top-k queries on HnswSingle. +fn bench_hnsw_single_topk(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_single_topk"); + + for size in [1000, 5000, 10000] { + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(50); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark top-k queries with varying ef_runtime. +fn bench_hnsw_topk_varying_ef_runtime(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_topk_ef_runtime"); + + let size = 10000; + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(10); // Will be overridden per query + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + for ef in [10, 50, 100, 200] { + let query_params = vecsim::query::QueryParams::new().with_ef_runtime(ef); + group.bench_with_input(BenchmarkId::from_parameter(ef), &ef, |b, _| { + b.iter(|| { + index + .top_k_query(black_box(&query), black_box(10), Some(&query_params)) + .unwrap() + }); + }); + } + + group.finish(); +} + +/// Benchmark top-k queries with varying k values. 
+fn bench_hnsw_topk_varying_k(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_topk_k"); + + let size = 10000; + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(100); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + for k in [1, 10, 50, 100] { + group.bench_with_input(BenchmarkId::from_parameter(k), &k, |b, &k| { + b.iter(|| index.top_k_query(black_box(&query), black_box(k), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark range queries on HnswSingle. +fn bench_hnsw_single_range(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_single_range"); + + for size in [1000, 5000, 10000] { + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(100); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + let radius = 10.0; + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + b.iter(|| index.range_query(black_box(&query), black_box(radius), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark delete operations on HnswSingle. +fn bench_hnsw_single_delete(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_single_delete"); + group.sample_size(10); + + for size in [500, 1000, 2000] { + let vectors = generate_vectors(size, DIM); + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + b.iter_batched( + || { + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + index + }, + |mut index| { + // Delete half the vectors + for i in (0..size).step_by(2) { + index.delete_vector(i as u64).unwrap(); + } + index + }, + criterion::BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} + +/// Benchmark HnswMulti with multiple vectors per label. +fn bench_hnsw_multi_add(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_multi_add"); + group.sample_size(10); + + for size in [100, 500, 1000] { + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + let vectors = generate_vectors(size, DIM); + b.iter(|| { + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100); + let mut index = HnswMulti::::new(params); + for (i, v) in vectors.iter().enumerate() { + // Use fewer labels to have multiple vectors per label + index.add_vector(black_box(v), (i % 50) as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark different distance metrics for HNSW. 
+fn bench_hnsw_metrics(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_metrics_5000"); + + let size = 5000; + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + for metric in [Metric::L2, Metric::InnerProduct, Metric::Cosine] { + let params = HnswParams::new(DIM, metric) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(50); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + let metric_name = match metric { + Metric::L2 => "L2", + Metric::InnerProduct => "IP", + Metric::Cosine => "Cosine", + }; + + group.bench_function(metric_name, |b| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark different vector dimensions for HNSW. +fn bench_hnsw_dimensions(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_dimensions_1000"); + + let size = 1000; + + for dim in [32, 128, 512] { + let vectors = generate_vectors(size, dim); + let query = generate_vectors(1, dim).pop().unwrap(); + + let params = HnswParams::new(dim, Metric::L2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(50); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + group.bench_with_input(BenchmarkId::from_parameter(dim), &dim, |b, _| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark serialization round-trip for HNSW. +fn bench_hnsw_serialization(c: &mut Criterion) { + let mut group = c.benchmark_group("hnsw_serialization"); + group.sample_size(10); + + for size in [1000, 5000] { + let vectors = generate_vectors(size, DIM); + + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + group.bench_with_input(BenchmarkId::new("save", size), &size, |b, _| { + b.iter(|| { + let mut buffer = Vec::new(); + index.save(black_box(&mut buffer)).unwrap(); + buffer + }); + }); + + let mut buffer = Vec::new(); + index.save(&mut buffer).unwrap(); + + group.bench_with_input(BenchmarkId::new("load", size), &size, |b, _| { + b.iter(|| { + let mut cursor = std::io::Cursor::new(&buffer); + HnswSingle::::load(black_box(&mut cursor)).unwrap() + }); + }); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_hnsw_single_add, + bench_hnsw_add_varying_m, + bench_hnsw_add_varying_ef, + bench_hnsw_single_topk, + bench_hnsw_topk_varying_ef_runtime, + bench_hnsw_topk_varying_k, + bench_hnsw_single_range, + bench_hnsw_single_delete, + bench_hnsw_multi_add, + bench_hnsw_metrics, + bench_hnsw_dimensions, + bench_hnsw_serialization, +); + +criterion_main!(benches); diff --git a/rust/vecsim/benches/hnsw_bottleneck_bench.rs b/rust/vecsim/benches/hnsw_bottleneck_bench.rs new file mode 100644 index 000000000..34a8eba4e --- /dev/null +++ b/rust/vecsim/benches/hnsw_bottleneck_bench.rs @@ -0,0 +1,469 @@ +//! Benchmarks for measuring HNSW performance bottlenecks. +//! +//! This module measures individual components to identify performance bottlenecks: +//! - Distance computation (L2, IP, Cosine) for different dimensions +//! - SIMD vs scalar implementations +//! - Visited nodes tracking operations +//! - Neighbor selection algorithms +//! 
- Search layer performance at different ef values +//! - Memory access patterns +//! +//! Run with: cargo bench --bench hnsw_bottleneck_bench + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use rand::Rng; +use vecsim::distance::cosine::CosineDistance; +use vecsim::distance::ip::InnerProductDistance; +use vecsim::distance::l2::L2Distance; +use vecsim::distance::{DistanceFunction, Metric}; +use vecsim::index::hnsw::search::{select_neighbors_heuristic, select_neighbors_simple}; +use vecsim::index::hnsw::{HnswParams, HnswSingle, VisitedNodesHandler, VisitedNodesHandlerPool}; +use vecsim::index::VecSimIndex; +use vecsim::query::QueryParams; + +/// Generate random vectors for benchmarking. +fn generate_vectors(count: usize, dim: usize) -> Vec> { + let mut rng = rand::thread_rng(); + (0..count) + .map(|_| (0..dim).map(|_| rng.gen::()).collect()) + .collect() +} + +/// Benchmark distance computation across different dimensions and metrics. +/// +/// This helps identify if distance computation is a bottleneck for specific dimensions. +fn bench_distance_computation(c: &mut Criterion) { + let mut group = c.benchmark_group("distance_computation"); + + for dim in [32, 128, 384, 768, 1536] { + let v1: Vec = (0..dim).map(|i| i as f32 / dim as f32).collect(); + let v2: Vec = (0..dim).map(|i| (dim - i) as f32 / dim as f32).collect(); + + group.throughput(Throughput::Elements(dim as u64)); + + // L2 distance with SIMD + let dist_fn = L2Distance::::with_simd(dim); + group.bench_with_input(BenchmarkId::new("L2_simd", dim), &dim, |b, &d| { + b.iter(|| dist_fn.compute(black_box(&v1), black_box(&v2), d)); + }); + + // L2 distance scalar + let dist_fn_scalar = L2Distance::::scalar(dim); + group.bench_with_input(BenchmarkId::new("L2_scalar", dim), &dim, |b, &d| { + b.iter(|| dist_fn_scalar.compute(black_box(&v1), black_box(&v2), d)); + }); + + // Inner product with SIMD + let ip_fn = InnerProductDistance::::with_simd(dim); + group.bench_with_input(BenchmarkId::new("IP_simd", dim), &dim, |b, &d| { + b.iter(|| ip_fn.compute(black_box(&v1), black_box(&v2), d)); + }); + + // Inner product scalar + let ip_fn_scalar = InnerProductDistance::::scalar(dim); + group.bench_with_input(BenchmarkId::new("IP_scalar", dim), &dim, |b, &d| { + b.iter(|| ip_fn_scalar.compute(black_box(&v1), black_box(&v2), d)); + }); + + // Cosine distance with SIMD + let cos_fn = CosineDistance::::with_simd(dim); + group.bench_with_input(BenchmarkId::new("Cosine_simd", dim), &dim, |b, &d| { + b.iter(|| cos_fn.compute(black_box(&v1), black_box(&v2), d)); + }); + + // Cosine distance scalar + let cos_fn_scalar = CosineDistance::::scalar(dim); + group.bench_with_input(BenchmarkId::new("Cosine_scalar", dim), &dim, |b, &d| { + b.iter(|| cos_fn_scalar.compute(black_box(&v1), black_box(&v2), d)); + }); + } + + group.finish(); +} + +/// Benchmark visited nodes tracking operations. +/// +/// Measures visit(), is_visited(), reset(), and pool checkout/return overhead. 
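// For orientation, the scalar baseline in the SIMD-vs-scalar comparisons
// above is essentially the following loop; a minimal sketch, assuming the
// crate's `L2Distance::scalar` computes the same squared-L2 value:
fn l2_squared_scalar_sketch(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d // the SIMD variants vectorize exactly this multiply-accumulate
        })
        .sum()
}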
+fn bench_visited_nodes(c: &mut Criterion) { + let mut group = c.benchmark_group("visited_nodes"); + + for capacity in [1_000, 10_000, 100_000] { + // Visit operation (mark node as visited) + group.bench_with_input( + BenchmarkId::new("visit", capacity), + &capacity, + |b, &cap| { + let handler = VisitedNodesHandler::new(cap); + let mut rng = rand::thread_rng(); + b.iter(|| { + let id = rng.gen_range(0..cap as u32); + handler.visit(black_box(id)) + }); + }, + ); + + // is_visited check operation + group.bench_with_input( + BenchmarkId::new("is_visited", capacity), + &capacity, + |b, &cap| { + let handler = VisitedNodesHandler::new(cap); + // Pre-visit some nodes + for i in (0..cap as u32).step_by(2) { + handler.visit(i); + } + let mut rng = rand::thread_rng(); + b.iter(|| { + let id = rng.gen_range(0..cap as u32); + handler.is_visited(black_box(id)) + }); + }, + ); + + // Reset operation (O(1) with tag-based approach) + group.bench_with_input( + BenchmarkId::new("reset", capacity), + &capacity, + |b, &cap| { + let mut handler = VisitedNodesHandler::new(cap); + b.iter(|| { + handler.reset(); + }); + }, + ); + + // Pool checkout and return + group.bench_with_input( + BenchmarkId::new("pool_get_return", capacity), + &capacity, + |b, &cap| { + let pool = VisitedNodesHandlerPool::new(cap); + // Warm up the pool with one handler + { + let _h = pool.get(); + } + b.iter(|| { + let _handler = pool.get(); + // Handler returned automatically on drop + }); + }, + ); + } + + group.finish(); +} + +/// Benchmark neighbor selection algorithms. +/// +/// Compares simple selection (sort + take) vs heuristic selection (diversity-aware). +fn bench_neighbor_selection(c: &mut Criterion) { + let mut group = c.benchmark_group("neighbor_selection"); + let dim = 128; + + for num_candidates in [10, 50, 100, 200] { + // Generate candidates as (id, distance) pairs + let candidates: Vec<(u32, f32)> = (0..num_candidates) + .map(|i| (i as u32, i as f32 * 0.1)) + .collect(); + + // Simple selection (just keep M closest) + group.bench_with_input( + BenchmarkId::new("simple", num_candidates), + &candidates, + |b, cands| { + b.iter(|| select_neighbors_simple(black_box(cands), 16)); + }, + ); + + // Heuristic selection (diversity-aware, requires vector data) + let vectors = generate_vectors(num_candidates, dim); + let data_getter = |id: u32| -> Option<&[f32]> { vectors.get(id as usize).map(|v| v.as_slice()) }; + let dist_fn = L2Distance::::with_simd(dim); + + group.bench_with_input( + BenchmarkId::new("heuristic", num_candidates), + &candidates, + |b, cands| { + b.iter(|| { + select_neighbors_heuristic( + 0, // target id + black_box(cands), + 16, // M (max neighbors) + &data_getter, + &dist_fn as &dyn DistanceFunction, + dim, + false, // extend_candidates + true, // keep_pruned + ) + }); + }, + ); + } + + group.finish(); +} + +/// Benchmark HNSW search with varying ef values. +/// +/// Isolates search_layer performance to measure the impact of ef_runtime. 
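// The O(1) `reset` measured above is what makes a pooled visited-set cheap
// to reuse between searches. A minimal sketch of the tag-based idea,
// assuming the crate's `VisitedNodesHandler` works along these lines (the
// real handler also allows `visit` through a shared reference, likely via
// atomics; this sketch shows only the tagging trick):
struct TaggedVisitedSketch {
    tags: Vec<u32>,
    current: u32,
}

impl TaggedVisitedSketch {
    fn new(capacity: usize) -> Self {
        Self { tags: vec![0; capacity], current: 1 }
    }
    fn visit(&mut self, id: u32) {
        self.tags[id as usize] = self.current;
    }
    fn is_visited(&self, id: u32) -> bool {
        self.tags[id as usize] == self.current
    }
    /// O(1): bump the tag instead of zeroing the whole buffer.
    fn reset(&mut self) {
        self.current = self.current.wrapping_add(1);
        if self.current == 0 {
            self.tags.fill(0); // rare wrap-around: pay the O(n) clear once
            self.current = 1;
        }
    }
}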
+fn bench_search_ef_impact(c: &mut Criterion) {
+    let mut group = c.benchmark_group("search_ef_impact");
+    group.sample_size(50);
+
+    let dim = 128;
+    let size = 10_000;
+    let vectors = generate_vectors(size, dim);
+
+    // Build index once
+    let params = HnswParams::new(dim, Metric::L2)
+        .with_m(16)
+        .with_ef_construction(100);
+    let mut index = HnswSingle::<f32>::new(params);
+    for (i, v) in vectors.iter().enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+
+    let query = generate_vectors(1, dim).pop().unwrap();
+
+    for ef in [10, 20, 50, 100, 200, 400] {
+        let query_params = QueryParams::new().with_ef_runtime(ef);
+        group.bench_with_input(BenchmarkId::from_parameter(ef), &ef, |b, _| {
+            b.iter(|| {
+                index
+                    .top_k_query(black_box(&query), 10, Some(&query_params))
+                    .unwrap()
+            });
+        });
+    }
+
+    group.finish();
+}
+
+/// Benchmark search with and without filters.
+///
+/// Measures the overhead of filter evaluation during search.
+fn bench_search_with_filters(c: &mut Criterion) {
+    let mut group = c.benchmark_group("search_filters");
+    group.sample_size(50);
+
+    let dim = 128;
+    let size = 10_000;
+    let vectors = generate_vectors(size, dim);
+
+    let params = HnswParams::new(dim, Metric::L2)
+        .with_m(16)
+        .with_ef_construction(100)
+        .with_ef_runtime(100);
+    let mut index = HnswSingle::<f32>::new(params);
+    for (i, v) in vectors.iter().enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+
+    let query = generate_vectors(1, dim).pop().unwrap();
+
+    // Without filter
+    group.bench_function("no_filter", |b| {
+        b.iter(|| {
+            index
+                .top_k_query(black_box(&query), 10, None)
+                .unwrap()
+        });
+    });
+
+    // With filter that accepts all (minimal overhead measurement)
+    group.bench_function("filter_accept_all", |b| {
+        b.iter(|| {
+            let params = QueryParams::new()
+                .with_ef_runtime(100)
+                .with_filter(|_| true);
+            index
+                .top_k_query(black_box(&query), 10, Some(&params))
+                .unwrap()
+        });
+    });
+
+    // With filter that accepts 50%
+    group.bench_function("filter_accept_50pct", |b| {
+        b.iter(|| {
+            let params = QueryParams::new()
+                .with_ef_runtime(100)
+                .with_filter(|label| label % 2 == 0);
+            index
+                .top_k_query(black_box(&query), 10, Some(&params))
+                .unwrap()
+        });
+    });
+
+    // With filter that accepts 10%
+    group.bench_function("filter_accept_10pct", |b| {
+        b.iter(|| {
+            let params = QueryParams::new()
+                .with_ef_runtime(100)
+                .with_filter(|label| label % 10 == 0);
+            index
+                .top_k_query(black_box(&query), 10, Some(&params))
+                .unwrap()
+        });
+    });
+
+    group.finish();
+}
+
+/// Benchmark memory access patterns (sequential vs random).
+///
+/// This helps understand cache effects in vector access patterns.
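// The filter cost measured above amounts to one extra predicate call per
// candidate before it can enter the result set; selective filters also
// force the search to walk further to collect k passing results, which is
// why `filter_accept_10pct` is expected to be the slowest case. A minimal
// sketch of the gating step (assumed shape, not the crate's actual loop):
fn keep_candidate_sketch(label: u64, filter: Option<&dyn Fn(u64) -> bool>) -> bool {
    match filter {
        None => true,        // no_filter: zero predicate calls
        Some(f) => f(label), // filter_*: one dynamic call per candidate
    }
}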
+fn bench_memory_access_patterns(c: &mut Criterion) { + let mut group = c.benchmark_group("memory_access"); + + let size = 100_000; + let data: Vec = (0..size).map(|i| i as f32).collect(); + + // Sequential access - cache friendly + group.bench_function("sequential_1000", |b| { + b.iter(|| { + let mut sum = 0.0f32; + for i in 0..1000 { + sum += black_box(data[i]); + } + sum + }); + }); + + // Sequential access with stride (simulates vector access) + let dim = 128; + group.bench_function("sequential_stride_128", |b| { + let num_vectors = size / dim; + b.iter(|| { + let mut sum = 0.0f32; + for v in 0..num_vectors.min(100) { + for d in 0..dim { + sum += black_box(data[v * dim + d]); + } + } + sum + }); + }); + + // Random access - cache unfriendly + let mut rng = rand::thread_rng(); + let random_indices: Vec = (0..1000).map(|_| rng.gen_range(0..size)).collect(); + group.bench_function("random_1000", |b| { + b.iter(|| { + let mut sum = 0.0f32; + for &i in &random_indices { + sum += black_box(data[i]); + } + sum + }); + }); + + // Random vector access (simulates HNSW neighbor traversal) + let random_vector_indices: Vec = (0..100) + .map(|_| rng.gen_range(0..(size / dim))) + .collect(); + group.bench_function("random_vectors_100", |b| { + b.iter(|| { + let mut sum = 0.0f32; + for &vi in &random_vector_indices { + let base = vi * dim; + for d in 0..dim { + sum += black_box(data[base + d]); + } + } + sum + }); + }); + + group.finish(); +} + +/// Benchmark batch distance computations. +/// +/// Measures throughput when computing distances against multiple candidates. +fn bench_batch_distance(c: &mut Criterion) { + let mut group = c.benchmark_group("batch_distance"); + + for dim in [128, 384, 768] { + let query: Vec = (0..dim).map(|i| i as f32 / dim as f32).collect(); + let candidates: Vec> = (0..100) + .map(|_| { + let mut rng = rand::thread_rng(); + (0..dim).map(|_| rng.gen::()).collect() + }) + .collect(); + + let dist_fn = L2Distance::::with_simd(dim); + + group.throughput(Throughput::Elements(100)); + group.bench_with_input( + BenchmarkId::new("100_candidates", dim), + &dim, + |b, &d| { + b.iter(|| { + let mut min_dist = f32::MAX; + for cand in &candidates { + let dist = dist_fn.compute(black_box(&query), black_box(cand), d); + if dist < min_dist { + min_dist = dist; + } + } + min_dist + }); + }, + ); + } + + group.finish(); +} + +/// Benchmark batch query processing. +/// +/// Measures batch query throughput at different batch sizes. 
+fn bench_batch_query(c: &mut Criterion) {
+    let mut group = c.benchmark_group("batch_query");
+
+    let dim = 128;
+    let num_vectors = 10_000;
+
+    // Build index
+    let params = HnswParams::new(dim, vecsim::distance::Metric::L2)
+        .with_m(16)
+        .with_ef_construction(100)
+        .with_ef_runtime(50);
+
+    let mut index = HnswSingle::<f32>::new(params);
+
+    // Add vectors
+    let vectors = generate_vectors(num_vectors, dim);
+    for (label, vec) in vectors.iter().enumerate() {
+        let _ = index.add_vector(vec, label as u64);
+    }
+
+    // Generate query batches and benchmark
+    for batch_size in [1, 10, 50, 100] {
+        let queries = generate_vectors(batch_size, dim);
+
+        group.bench_with_input(
+            BenchmarkId::new("batch_search", batch_size),
+            &batch_size,
+            |b, _| {
+                b.iter(|| index.batch_search(black_box(&queries), 10, None));
+            },
+        );
+    }
+
+    group.finish();
+}
+
+criterion_group!(
+    benches,
+    bench_distance_computation,
+    bench_visited_nodes,
+    bench_neighbor_selection,
+    bench_search_ef_impact,
+    bench_search_with_filters,
+    bench_memory_access_patterns,
+    bench_batch_distance,
+    bench_batch_query,
+);
+criterion_main!(benches);
diff --git a/rust/vecsim/benches/svs_bench.rs b/rust/vecsim/benches/svs_bench.rs
new file mode 100644
index 000000000..d4cc5a72e
--- /dev/null
+++ b/rust/vecsim/benches/svs_bench.rs
@@ -0,0 +1,417 @@
+//! Benchmarks for SVS (Vamana) index operations.
+//!
+//! Run with: cargo bench --bench svs_bench
+
+use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
+use vecsim::distance::Metric;
+use vecsim::index::svs::{SvsMulti, SvsParams, SvsSingle};
+use vecsim::index::VecSimIndex;
+
+const DIM: usize = 128;
+
+/// Generate random vectors for benchmarking.
+fn generate_vectors(count: usize, dim: usize) -> Vec<Vec<f32>> {
+    use rand::Rng;
+    let mut rng = rand::thread_rng();
+    (0..count)
+        .map(|_| (0..dim).map(|_| rng.gen::<f32>()).collect())
+        .collect()
+}
+
+/// Benchmark adding vectors to SvsSingle.
+fn bench_svs_single_add(c: &mut Criterion) {
+    let mut group = c.benchmark_group("svs_single_add");
+    group.sample_size(10); // SVS add is slow due to graph construction
+
+    for size in [100, 500, 1000] {
+        group.throughput(Throughput::Elements(size as u64));
+        group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| {
+            let vectors = generate_vectors(size, DIM);
+            b.iter(|| {
+                let params = SvsParams::new(DIM, Metric::L2)
+                    .with_graph_degree(32)
+                    .with_construction_l(100)
+                    .with_two_pass(false); // Single pass for faster construction
+                let mut index = SvsSingle::<f32>::new(params);
+                for (i, v) in vectors.iter().enumerate() {
+                    index.add_vector(black_box(v), i as u64).unwrap();
+                }
+                index
+            });
+        });
+    }
+
+    group.finish();
+}
+
+/// Benchmark adding vectors with varying graph degree (R).
+fn bench_svs_add_varying_degree(c: &mut Criterion) {
+    let mut group = c.benchmark_group("svs_add_degree");
+    group.sample_size(10);
+
+    let size = 500;
+    let vectors = generate_vectors(size, DIM);
+
+    for degree in [16, 32, 48, 64] {
+        group.throughput(Throughput::Elements(size as u64));
+        group.bench_with_input(BenchmarkId::from_parameter(degree), &degree, |b, &degree| {
+            b.iter(|| {
+                let params = SvsParams::new(DIM, Metric::L2)
+                    .with_graph_degree(degree)
+                    .with_construction_l(100)
+                    .with_two_pass(false);
+                let mut index = SvsSingle::<f32>::new(params);
+                for (i, v) in vectors.iter().enumerate() {
+                    index.add_vector(black_box(v), i as u64).unwrap();
+                }
+                index
+            });
+        });
+    }
+
+    group.finish();
+}
+
+/// Benchmark adding vectors with varying alpha parameter.
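// The `alpha` swept in the next benchmark controls Vamana-style pruning:
// larger values keep more "detour" edges, raising build cost and effective
// degree. A minimal sketch of the classic robust-prune acceptance test,
// assuming `dist` is the metric distance between stored vectors (the
// crate's actual pruning code may differ in details):
fn robust_prune_keep_sketch(
    candidate: (u32, f32), // (id, distance from the node being wired)
    kept: &[(u32, f32)],   // neighbors already selected
    alpha: f32,
    dist: impl Fn(u32, u32) -> f32,
) -> bool {
    // Reject the candidate if some already-kept neighbor dominates it,
    // i.e. alpha * d(kept, candidate) <= d(node, candidate).
    kept.iter()
        .all(|&(k, _)| alpha * dist(k, candidate.0) > candidate.1)
}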
+fn bench_svs_add_varying_alpha(c: &mut Criterion) { + let mut group = c.benchmark_group("svs_add_alpha"); + group.sample_size(10); + + let size = 500; + let vectors = generate_vectors(size, DIM); + + for alpha in [1.0f32, 1.1, 1.2, 1.4] { + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(alpha), &alpha, |b, &alpha| { + b.iter(|| { + let params = SvsParams::new(DIM, Metric::L2) + .with_graph_degree(32) + .with_alpha(alpha) + .with_construction_l(100) + .with_two_pass(false); + let mut index = SvsSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark adding vectors with varying construction window size (L). +fn bench_svs_add_varying_construction_l(c: &mut Criterion) { + let mut group = c.benchmark_group("svs_add_construction_l"); + group.sample_size(10); + + let size = 500; + let vectors = generate_vectors(size, DIM); + + for l in [50, 100, 200, 400] { + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(l), &l, |b, &l| { + b.iter(|| { + let params = SvsParams::new(DIM, Metric::L2) + .with_graph_degree(32) + .with_construction_l(l) + .with_two_pass(false); + let mut index = SvsSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark two-pass vs single-pass construction. +fn bench_svs_two_pass_construction(c: &mut Criterion) { + let mut group = c.benchmark_group("svs_two_pass"); + group.sample_size(10); + + let size = 500; + let vectors = generate_vectors(size, DIM); + + for two_pass in [false, true] { + let label = if two_pass { "two_pass" } else { "single_pass" }; + group.throughput(Throughput::Elements(size as u64)); + group.bench_function(label, |b| { + b.iter(|| { + let params = SvsParams::new(DIM, Metric::L2) + .with_graph_degree(32) + .with_construction_l(100) + .with_two_pass(two_pass); + let mut index = SvsSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark top-k queries on SvsSingle. +fn bench_svs_single_topk(c: &mut Criterion) { + let mut group = c.benchmark_group("svs_single_topk"); + + for size in [1000, 5000, 10000] { + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = SvsParams::new(DIM, Metric::L2) + .with_graph_degree(32) + .with_construction_l(100) + .with_search_l(50) + .with_two_pass(false); + let mut index = SvsSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark top-k queries with varying default search window size. +/// Note: SVS search_l is set at index creation time, not at query time. 
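// `search_l` trades latency for recall, so the sweep below is most useful
// next to a recall@k measurement. A small helper for that, independent of
// the index API (the exact result set would come from a brute-force scan):
fn recall_at_k_sketch(approx_labels: &[u64], exact_labels: &[u64]) -> f64 {
    if exact_labels.is_empty() {
        return 1.0;
    }
    let hits = approx_labels
        .iter()
        .filter(|l| exact_labels.contains(*l))
        .count();
    hits as f64 / exact_labels.len() as f64
}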
+fn bench_svs_topk_varying_search_l(c: &mut Criterion) { + let mut group = c.benchmark_group("svs_topk_search_l"); + group.sample_size(10); // Need to rebuild index for each search_l + + let size = 5000; + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + for search_l in [10, 50, 100, 200] { + // SVS search_l is an index parameter, so we need to rebuild for each value + let params = SvsParams::new(DIM, Metric::L2) + .with_graph_degree(32) + .with_construction_l(100) + .with_search_l(search_l) + .with_two_pass(false); + let mut index = SvsSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + group.bench_with_input(BenchmarkId::from_parameter(search_l), &search_l, |b, _| { + b.iter(|| { + index + .top_k_query(black_box(&query), black_box(10), None) + .unwrap() + }); + }); + } + + group.finish(); +} + +/// Benchmark top-k queries with varying k values. +fn bench_svs_topk_varying_k(c: &mut Criterion) { + let mut group = c.benchmark_group("svs_topk_k"); + + let size = 10000; + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = SvsParams::new(DIM, Metric::L2) + .with_graph_degree(32) + .with_construction_l(100) + .with_search_l(100) + .with_two_pass(false); + let mut index = SvsSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + for k in [1, 10, 50, 100] { + group.bench_with_input(BenchmarkId::from_parameter(k), &k, |b, &k| { + b.iter(|| index.top_k_query(black_box(&query), black_box(k), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark range queries on SvsSingle. +fn bench_svs_single_range(c: &mut Criterion) { + let mut group = c.benchmark_group("svs_single_range"); + + for size in [1000, 5000, 10000] { + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = SvsParams::new(DIM, Metric::L2) + .with_graph_degree(32) + .with_construction_l(100) + .with_search_l(100) + .with_two_pass(false); + let mut index = SvsSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + let radius = 10.0; + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + b.iter(|| index.range_query(black_box(&query), black_box(radius), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark delete operations on SvsSingle. +fn bench_svs_single_delete(c: &mut Criterion) { + let mut group = c.benchmark_group("svs_single_delete"); + group.sample_size(10); + + for size in [500, 1000, 2000] { + let vectors = generate_vectors(size, DIM); + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + b.iter_batched( + || { + let params = SvsParams::new(DIM, Metric::L2) + .with_graph_degree(32) + .with_construction_l(100) + .with_two_pass(false); + let mut index = SvsSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + index + }, + |mut index| { + // Delete half the vectors + for i in (0..size).step_by(2) { + index.delete_vector(i as u64).unwrap(); + } + index + }, + criterion::BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} + +/// Benchmark SvsMulti with multiple vectors per label. 
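// SvsMulti allows several vectors per label (the benchmark below maps
// i -> i % 50, i.e. roughly 20 vectors per label at size 1000). A minimal
// sketch of the label-to-internal-id bookkeeping such an index needs,
// assuming u64 labels and u32 internal ids; illustrative only:
use std::collections::HashMap;

struct MultiLabelMapSketch {
    by_label: HashMap<u64, Vec<u32>>,
}

impl MultiLabelMapSketch {
    fn add(&mut self, label: u64, internal_id: u32) {
        self.by_label.entry(label).or_default().push(internal_id);
    }
    /// Deleting a label removes *all* of its vectors at once.
    fn remove(&mut self, label: u64) -> Vec<u32> {
        self.by_label.remove(&label).unwrap_or_default()
    }
}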
+fn bench_svs_multi_add(c: &mut Criterion) { + let mut group = c.benchmark_group("svs_multi_add"); + group.sample_size(10); + + for size in [100, 500, 1000] { + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + let vectors = generate_vectors(size, DIM); + b.iter(|| { + let params = SvsParams::new(DIM, Metric::L2) + .with_graph_degree(32) + .with_construction_l(100) + .with_two_pass(false); + let mut index = SvsMulti::::new(params); + for (i, v) in vectors.iter().enumerate() { + // Use fewer labels to have multiple vectors per label + index.add_vector(black_box(v), (i % 50) as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark different distance metrics for SVS. +fn bench_svs_metrics(c: &mut Criterion) { + let mut group = c.benchmark_group("svs_metrics_5000"); + + let size = 5000; + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + for metric in [Metric::L2, Metric::InnerProduct, Metric::Cosine] { + let params = SvsParams::new(DIM, metric) + .with_graph_degree(32) + .with_construction_l(100) + .with_search_l(50) + .with_two_pass(false); + let mut index = SvsSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + let metric_name = match metric { + Metric::L2 => "L2", + Metric::InnerProduct => "IP", + Metric::Cosine => "Cosine", + }; + + group.bench_function(metric_name, |b| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark different vector dimensions for SVS. +fn bench_svs_dimensions(c: &mut Criterion) { + let mut group = c.benchmark_group("svs_dimensions_1000"); + + let size = 1000; + + for dim in [32, 128, 512] { + let vectors = generate_vectors(size, dim); + let query = generate_vectors(1, dim).pop().unwrap(); + + let params = SvsParams::new(dim, Metric::L2) + .with_graph_degree(32) + .with_construction_l(100) + .with_search_l(50) + .with_two_pass(false); + let mut index = SvsSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + group.bench_with_input(BenchmarkId::from_parameter(dim), &dim, |b, _| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_svs_single_add, + bench_svs_add_varying_degree, + bench_svs_add_varying_alpha, + bench_svs_add_varying_construction_l, + bench_svs_two_pass_construction, + bench_svs_single_topk, + bench_svs_topk_varying_search_l, + bench_svs_topk_varying_k, + bench_svs_single_range, + bench_svs_single_delete, + bench_svs_multi_add, + bench_svs_metrics, + bench_svs_dimensions, +); + +criterion_main!(benches); diff --git a/rust/vecsim/benches/tiered_bench.rs b/rust/vecsim/benches/tiered_bench.rs new file mode 100644 index 000000000..d92e6c37d --- /dev/null +++ b/rust/vecsim/benches/tiered_bench.rs @@ -0,0 +1,337 @@ +//! Benchmarks for TieredIndex operations. +//! +//! 
Run with: cargo bench --bench tiered_bench + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use vecsim::distance::Metric; +use vecsim::index::brute_force::{BruteForceParams, BruteForceSingle}; +use vecsim::index::hnsw::{HnswParams, HnswSingle}; +use vecsim::index::tiered::{TieredParams, TieredSingle, WriteMode}; +use vecsim::index::VecSimIndex; + +const DIM: usize = 128; + +/// Generate random vectors for benchmarking. +fn generate_vectors(count: usize, dim: usize) -> Vec> { + use rand::Rng; + let mut rng = rand::thread_rng(); + (0..count) + .map(|_| (0..dim).map(|_| rng.gen::()).collect()) + .collect() +} + +/// Benchmark adding vectors to TieredSingle in async mode. +fn bench_tiered_add_async(c: &mut Criterion) { + let mut group = c.benchmark_group("tiered_add_async"); + + for size in [100, 1000, 5000] { + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + let vectors = generate_vectors(size, DIM); + b.iter(|| { + let params = TieredParams::new(DIM, Metric::L2) + .with_flat_buffer_limit(size * 2) // Keep in async mode + .with_write_mode(WriteMode::Async); + let mut index = TieredSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark adding vectors to TieredSingle in in-place mode. +fn bench_tiered_add_inplace(c: &mut Criterion) { + let mut group = c.benchmark_group("tiered_add_inplace"); + + for size in [100, 1000, 5000] { + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + let vectors = generate_vectors(size, DIM); + b.iter(|| { + let params = TieredParams::new(DIM, Metric::L2) + .with_write_mode(WriteMode::InPlace) + .with_m(16) + .with_ef_construction(100); + let mut index = TieredSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + } + + group.finish(); +} + +/// Benchmark top-k queries on TieredSingle with vectors in flat buffer only. +fn bench_tiered_query_flat(c: &mut Criterion) { + let mut group = c.benchmark_group("tiered_query_flat"); + + for size in [100, 1000, 5000] { + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = TieredParams::new(DIM, Metric::L2) + .with_flat_buffer_limit(size * 2); + let mut index = TieredSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark top-k queries on TieredSingle with vectors in HNSW only. 
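// The async/in-place split above comes down to where `add_vector` lands.
// A minimal sketch of the routing decision, assuming a flat buffer with a
// configurable size limit sitting in front of the HNSW backend
// (illustrative names, not the crate's internals):
enum WriteRouteSketch {
    FlatBuffer, // Async: cheap append now, migrated by flush or a background job
    Backend,    // InPlace (or buffer full): pay HNSW insertion immediately
}

fn route_write_sketch(async_mode: bool, flat_len: usize, flat_limit: usize) -> WriteRouteSketch {
    if async_mode && flat_len < flat_limit {
        WriteRouteSketch::FlatBuffer
    } else {
        WriteRouteSketch::Backend
    }
}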
+fn bench_tiered_query_hnsw(c: &mut Criterion) { + let mut group = c.benchmark_group("tiered_query_hnsw"); + + for size in [100, 1000, 5000] { + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = TieredParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(50); + let mut index = TieredSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + index.flush().unwrap(); + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark top-k queries on TieredSingle with vectors in both tiers. +fn bench_tiered_query_both(c: &mut Criterion) { + let mut group = c.benchmark_group("tiered_query_both_tiers"); + + for size in [100, 1000, 5000] { + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + let params = TieredParams::new(DIM, Metric::L2) + .with_flat_buffer_limit(size / 2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(50); + let mut index = TieredSingle::::new(params); + + // Add half to flat, flush to HNSW + for (i, v) in vectors.iter().take(size / 2).enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + index.flush().unwrap(); + + // Add other half to flat + for (i, v) in vectors.iter().skip(size / 2).enumerate() { + index.add_vector(v, (size / 2 + i) as u64).unwrap(); + } + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + b.iter(|| index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + } + + group.finish(); +} + +/// Benchmark flush operation (migrating from flat to HNSW). +fn bench_tiered_flush(c: &mut Criterion) { + let mut group = c.benchmark_group("tiered_flush"); + + for size in [100, 500, 1000] { + let vectors = generate_vectors(size, DIM); + + group.throughput(Throughput::Elements(size as u64)); + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + b.iter_batched( + || { + let params = TieredParams::new(DIM, Metric::L2) + .with_flat_buffer_limit(size * 2) + .with_m(16) + .with_ef_construction(100); + let mut index = TieredSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + index + }, + |mut index| { + index.flush().unwrap(); + index + }, + criterion::BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} + +/// Compare TieredSingle query performance against BruteForce and HNSW. 
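// With vectors in both tiers, a query runs against the flat buffer and the
// HNSW graph and merges the two result lists. A minimal sketch of that
// merge over (distance, label) pairs, assuming smaller distance is better
// and that a single-value index never holds one label in both tiers:
fn merge_top_k_sketch(
    mut flat: Vec<(f32, u64)>,
    mut graph: Vec<(f32, u64)>,
    k: usize,
) -> Vec<(f32, u64)> {
    flat.append(&mut graph);
    flat.sort_by(|a, b| a.0.total_cmp(&b.0));
    flat.truncate(k);
    flat
}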
+fn bench_comparison_query(c: &mut Criterion) { + let mut group = c.benchmark_group("comparison_query_5000"); + + let size = 5000; + let vectors = generate_vectors(size, DIM); + let query = generate_vectors(1, DIM).pop().unwrap(); + + // BruteForce + let bf_params = BruteForceParams::new(DIM, Metric::L2); + let mut bf_index = BruteForceSingle::::new(bf_params); + for (i, v) in vectors.iter().enumerate() { + bf_index.add_vector(v, i as u64).unwrap(); + } + + group.bench_function("brute_force", |b| { + b.iter(|| bf_index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + + // HNSW + let hnsw_params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(50); + let mut hnsw_index = HnswSingle::::new(hnsw_params); + for (i, v) in vectors.iter().enumerate() { + hnsw_index.add_vector(v, i as u64).unwrap(); + } + + group.bench_function("hnsw", |b| { + b.iter(|| hnsw_index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + + // Tiered (all in HNSW after flush) + let tiered_params = TieredParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(50); + let mut tiered_index = TieredSingle::::new(tiered_params); + for (i, v) in vectors.iter().enumerate() { + tiered_index.add_vector(v, i as u64).unwrap(); + } + tiered_index.flush().unwrap(); + + group.bench_function("tiered_flushed", |b| { + b.iter(|| tiered_index.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + + // Tiered (half in each tier) + let tiered_params2 = TieredParams::new(DIM, Metric::L2) + .with_flat_buffer_limit(size / 2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(50); + let mut tiered_index2 = TieredSingle::::new(tiered_params2); + for (i, v) in vectors.iter().take(size / 2).enumerate() { + tiered_index2.add_vector(v, i as u64).unwrap(); + } + tiered_index2.flush().unwrap(); + for (i, v) in vectors.iter().skip(size / 2).enumerate() { + tiered_index2.add_vector(v, (size / 2 + i) as u64).unwrap(); + } + + group.bench_function("tiered_both_tiers", |b| { + b.iter(|| tiered_index2.top_k_query(black_box(&query), black_box(10), None).unwrap()); + }); + + group.finish(); +} + +/// Compare add performance across index types. 
+fn bench_comparison_add(c: &mut Criterion) { + let mut group = c.benchmark_group("comparison_add_1000"); + + let size = 1000; + let vectors = generate_vectors(size, DIM); + + group.throughput(Throughput::Elements(size as u64)); + + // BruteForce + group.bench_function("brute_force", |b| { + b.iter(|| { + let params = BruteForceParams::new(DIM, Metric::L2); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + + // HNSW + group.bench_function("hnsw", |b| { + b.iter(|| { + let params = HnswParams::new(DIM, Metric::L2) + .with_m(16) + .with_ef_construction(100); + let mut index = HnswSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + + // Tiered (async mode - writes to flat buffer) + group.bench_function("tiered_async", |b| { + b.iter(|| { + let params = TieredParams::new(DIM, Metric::L2) + .with_flat_buffer_limit(size * 2) + .with_write_mode(WriteMode::Async); + let mut index = TieredSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + + // Tiered (in-place mode - writes directly to HNSW) + group.bench_function("tiered_inplace", |b| { + b.iter(|| { + let params = TieredParams::new(DIM, Metric::L2) + .with_write_mode(WriteMode::InPlace) + .with_m(16) + .with_ef_construction(100); + let mut index = TieredSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(black_box(v), i as u64).unwrap(); + } + index + }); + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_tiered_add_async, + bench_tiered_add_inplace, + bench_tiered_query_flat, + bench_tiered_query_hnsw, + bench_tiered_query_both, + bench_tiered_flush, + bench_comparison_query, + bench_comparison_add, +); + +criterion_main!(benches); diff --git a/rust/vecsim/src/containers/data_blocks.rs b/rust/vecsim/src/containers/data_blocks.rs new file mode 100644 index 000000000..d5f16d6e1 --- /dev/null +++ b/rust/vecsim/src/containers/data_blocks.rs @@ -0,0 +1,783 @@ +//! Block-based vector storage with SIMD-aligned memory. +//! +//! This module provides `DataBlocks`, a container optimized for storing +//! vectors in contiguous, cache-friendly blocks with proper SIMD alignment. + +use crate::distance::simd::optimal_alignment; +use crate::types::{IdType, VectorElement, INVALID_ID}; +use parking_lot::{Mutex, RwLock}; +use std::alloc::{self, Layout}; +use std::collections::HashSet; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Default block size (number of vectors per block). +const DEFAULT_BLOCK_SIZE: usize = 1024; + +/// A single block of vector data with aligned memory. +struct DataBlock { + /// Pointer to aligned memory. + data: NonNull, + /// Layout used for allocation. + layout: Layout, + /// Number of elements (not vectors) in this block. + capacity: usize, +} + +impl DataBlock { + /// Create a new data block with capacity for `num_vectors` vectors of `dim` elements. 
+ fn new(num_vectors: usize, dim: usize) -> Self { + let num_elements = num_vectors * dim; + let alignment = optimal_alignment().max(std::mem::align_of::()); + let size = num_elements * std::mem::size_of::(); + + let layout = Layout::from_size_align(size.max(1), alignment) + .expect("Invalid layout for DataBlock"); + + let data = if size == 0 { + NonNull::dangling() + } else { + let ptr = unsafe { alloc::alloc(layout) }; + if ptr.is_null() { + alloc::handle_alloc_error(layout); + } + NonNull::new(ptr as *mut T).expect("Allocation returned null") + }; + + Self { + data, + layout, + capacity: num_elements, + } + } + + /// Check if an index is within bounds. + #[inline] + fn is_valid_index(&self, index: usize, dim: usize) -> bool { + index * dim + dim <= self.capacity + } + + /// Get a pointer to the vector at the given index. + /// + /// # Safety + /// Caller must ensure the index is valid (use `is_valid_index` first). + #[inline] + unsafe fn get_vector_ptr_unchecked(&self, index: usize, dim: usize) -> *const T { + self.data.as_ptr().add(index * dim) + } + + /// Get a mutable pointer to the vector at the given index. + /// + /// # Safety + /// Caller must ensure the index is valid (use `is_valid_index` first). + #[inline] + unsafe fn get_vector_ptr_mut_unchecked(&mut self, index: usize, dim: usize) -> *mut T { + self.data.as_ptr().add(index * dim) + } + + /// Get a slice to the vector at the given index. + /// + /// Returns `None` if the index is out of bounds. + #[inline] + fn get_vector(&self, index: usize, dim: usize) -> Option<&[T]> { + if !self.is_valid_index(index, dim) { + return None; + } + // SAFETY: We just verified the index is valid. + unsafe { + Some(std::slice::from_raw_parts( + self.get_vector_ptr_unchecked(index, dim), + dim, + )) + } + } + + /// Write a vector at the given index. + /// + /// Returns `false` if the index is out of bounds or data length doesn't match dim. + #[inline] + fn write_vector(&mut self, index: usize, dim: usize, data: &[T]) -> bool { + if data.len() != dim || !self.is_valid_index(index, dim) { + return false; + } + // SAFETY: We just verified the index is valid and data length matches. + unsafe { + std::ptr::copy_nonoverlapping( + data.as_ptr(), + self.get_vector_ptr_mut_unchecked(index, dim), + dim, + ); + } + true + } + + /// Write a vector at the given index (concurrent version). + /// + /// This is safe for concurrent writes to different indices within the same block. + /// Returns `false` if the index is out of bounds or data length doesn't match dim. + /// + /// # Safety + /// Caller must ensure no two threads write to the same index simultaneously. + #[inline] + fn write_vector_concurrent(&self, index: usize, dim: usize, data: &[T]) -> bool { + if data.len() != dim || !self.is_valid_index(index, dim) { + return false; + } + // SAFETY: We verified the index is valid and data length matches. + // The caller ensures no concurrent writes to the same index. + unsafe { + std::ptr::copy_nonoverlapping( + data.as_ptr(), + self.data.as_ptr().add(index * dim), + dim, + ); + } + true + } +} + +impl Drop for DataBlock { + fn drop(&mut self) { + if self.layout.size() > 0 { + unsafe { + alloc::dealloc(self.data.as_ptr() as *mut u8, self.layout); + } + } + } +} + +// Safety: DataBlock contains raw pointers but they are owned and not shared. +unsafe impl Send for DataBlock {} +unsafe impl Sync for DataBlock {} + +/// Block-based storage for vectors with SIMD-aligned memory. +/// +/// Vectors are stored in contiguous blocks for cache efficiency. 
+/// Each vector is accessed by its internal ID. +/// +/// This structure supports both mutable (single-threaded) and concurrent +/// (multi-threaded) access patterns via different methods. +pub struct DataBlocks { + /// The blocks storing vector data. RwLock allows concurrent reads + /// and exclusive writes for block growth. + blocks: RwLock>>, + /// Number of vectors per block. + vectors_per_block: usize, + /// Vector dimension. + dim: usize, + /// Total number of vectors stored (excluding deleted). + count: AtomicUsize, + /// Free slots from deleted vectors (for reuse). Uses Mutex for thread-safe access. + free_slots: Mutex>, + /// High water mark: the highest ID ever allocated + 1. + /// Used to determine which slots are valid vs never-allocated. + high_water_mark: AtomicUsize, +} + +impl DataBlocks { + /// Create a new DataBlocks container. + /// + /// # Arguments + /// * `dim` - Vector dimension + /// * `initial_capacity` - Initial number of vectors to allocate + pub fn new(dim: usize, initial_capacity: usize) -> Self { + let vectors_per_block = DEFAULT_BLOCK_SIZE; + let num_blocks = initial_capacity.div_ceil(vectors_per_block); + + let blocks: Vec<_> = (0..num_blocks.max(1)) + .map(|_| DataBlock::new(vectors_per_block, dim)) + .collect(); + + Self { + blocks: RwLock::new(blocks), + vectors_per_block, + dim, + count: AtomicUsize::new(0), + free_slots: Mutex::new(HashSet::new()), + high_water_mark: AtomicUsize::new(0), + } + } + + /// Create with a custom block size. + pub fn with_block_size(dim: usize, initial_capacity: usize, block_size: usize) -> Self { + let vectors_per_block = block_size; + let num_blocks = initial_capacity.div_ceil(vectors_per_block); + + let blocks: Vec<_> = (0..num_blocks.max(1)) + .map(|_| DataBlock::new(vectors_per_block, dim)) + .collect(); + + Self { + blocks: RwLock::new(blocks), + vectors_per_block, + dim, + count: AtomicUsize::new(0), + free_slots: Mutex::new(HashSet::new()), + high_water_mark: AtomicUsize::new(0), + } + } + + /// Get the vector dimension. + #[inline] + pub fn dimension(&self) -> usize { + self.dim + } + + /// Get the number of vectors stored. + #[inline] + pub fn len(&self) -> usize { + self.count.load(Ordering::Acquire) + } + + /// Check if empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.count.load(Ordering::Acquire) == 0 + } + + /// Get the total capacity (number of vector slots). + #[inline] + pub fn capacity(&self) -> usize { + self.blocks.read().len() * self.vectors_per_block + } + + /// Convert an internal ID to block and offset indices. + #[inline] + fn id_to_indices(&self, id: IdType) -> (usize, usize) { + let id = id as usize; + (id / self.vectors_per_block, id % self.vectors_per_block) + } + + /// Convert block and offset indices to an internal ID. + #[inline] + #[allow(dead_code)] + fn indices_to_id(&self, block: usize, offset: usize) -> IdType { + (block * self.vectors_per_block + offset) as IdType + } + + /// Add a vector and return its internal ID. + /// + /// Returns `None` if the vector dimension doesn't match the container's dimension. + /// + /// This method requires `&mut self` for API compatibility. For concurrent access, + /// use `add_concurrent` instead. + pub fn add(&mut self, vector: &[T]) -> Option { + // Delegate to concurrent implementation since the data structures + // are now thread-safe + self.add_concurrent(vector) + } + + /// Add a vector concurrently and return its internal ID. + /// + /// This method is thread-safe and can be called from multiple threads simultaneously. 
+ /// Returns `None` if the vector dimension doesn't match the container's dimension. + pub fn add_concurrent(&self, vector: &[T]) -> Option { + if vector.len() != self.dim { + return None; + } + + // Try to reuse a free slot first + { + let mut free_slots = self.free_slots.lock(); + if let Some(&id) = free_slots.iter().next() { + free_slots.remove(&id); + drop(free_slots); // Release lock before writing + + let (block_idx, offset) = self.id_to_indices(id); + let blocks = self.blocks.read(); + if blocks[block_idx].write_vector_concurrent(offset, self.dim, vector) { + self.count.fetch_add(1, Ordering::AcqRel); + return Some(id); + } + // Write failed (shouldn't happen), put the slot back + self.free_slots.lock().insert(id); + return None; + } + } + + // Allocate a new slot using atomic increment + loop { + let next_slot = self.high_water_mark.fetch_add(1, Ordering::AcqRel); + + // Check if we need to grow the blocks + { + let blocks = self.blocks.read(); + let total_slots = blocks.len() * self.vectors_per_block; + if next_slot < total_slots { + // We have space, write the vector + let (block_idx, offset) = self.id_to_indices(next_slot as IdType); + if blocks[block_idx].write_vector_concurrent(offset, self.dim, vector) { + self.count.fetch_add(1, Ordering::AcqRel); + return Some(next_slot as IdType); + } + // Write failed (shouldn't happen) + return None; + } + } + + // Need to grow - acquire write lock + let mut blocks = self.blocks.write(); + let total_slots = blocks.len() * self.vectors_per_block; + + // Double-check after acquiring write lock + if next_slot >= total_slots { + // Still need more space, allocate a new block + blocks.push(DataBlock::new(self.vectors_per_block, self.dim)); + } + + // Now write the vector + let (block_idx, offset) = self.id_to_indices(next_slot as IdType); + if blocks[block_idx].write_vector_concurrent(offset, self.dim, vector) { + self.count.fetch_add(1, Ordering::AcqRel); + return Some(next_slot as IdType); + } + // Write failed (shouldn't happen) + return None; + } + } + + /// Check if an ID is valid and not deleted. + #[inline] + pub fn is_valid(&self, id: IdType) -> bool { + if id == INVALID_ID { + return false; + } + let id_usize = id as usize; + // Must be within allocated range and not deleted + id_usize < self.high_water_mark.load(Ordering::Acquire) && !self.free_slots.lock().contains(&id) + } + + /// Get a vector by its internal ID. + /// + /// Returns `None` if the ID is invalid, out of bounds, or the vector was deleted. + #[inline] + pub fn get(&self, id: IdType) -> Option<&[T]> { + if !self.is_valid(id) { + return None; + } + let (block_idx, offset) = self.id_to_indices(id); + let blocks = self.blocks.read(); + if block_idx >= blocks.len() { + return None; + } + // SAFETY: We hold the read lock and verified the index is valid. + // The returned reference is safe because: + // 1. The block data is never moved (blocks can only grow, not shrink during normal operation) + // 2. Individual vector slots are never reallocated while valid + // 3. The ID was validated as not deleted + unsafe { + let block = &blocks[block_idx]; + if !block.is_valid_index(offset, self.dim) { + return None; + } + let ptr = block.get_vector_ptr_unchecked(offset, self.dim); + Some(std::slice::from_raw_parts(ptr, self.dim)) + } + } + + /// Get a raw pointer to a vector (for SIMD operations). + /// + /// Returns `None` if the ID is invalid, out of bounds, or the vector was deleted. 
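// A usage sketch of the concurrent path defined above: several threads
// sharing one DataBlocks by reference and inserting through
// `add_concurrent` (f32 elements assumed, as in the benchmarks):
fn concurrent_add_sketch() {
    let blocks = DataBlocks::<f32>::new(4, 1024);
    std::thread::scope(|s| {
        for t in 0..4 {
            let blocks = &blocks;
            s.spawn(move || {
                let v = vec![t as f32; 4];
                for _ in 0..100 {
                    // Free-slot reuse and block growth are both thread-safe.
                    blocks.add_concurrent(&v).expect("dimension matches");
                }
            });
        }
    });
    assert_eq!(blocks.len(), 400);
}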
+ #[inline] + pub fn get_ptr(&self, id: IdType) -> Option<*const T> { + if !self.is_valid(id) { + return None; + } + let (block_idx, offset) = self.id_to_indices(id); + let blocks = self.blocks.read(); + if block_idx >= blocks.len() { + return None; + } + let block = &blocks[block_idx]; + if !block.is_valid_index(offset, self.dim) { + return None; + } + // SAFETY: We verified the index is valid above. + unsafe { Some(block.get_vector_ptr_unchecked(offset, self.dim)) } + } + + /// Get a vector by its internal ID, skipping the deleted check. + /// + /// This is faster than `get()` because it doesn't acquire the Mutex lock + /// on `free_slots` to check if the ID is deleted. + /// + /// # Safety + /// This method is safe but may return data for deleted vectors. + /// Use this only during search operations where: + /// - The ID is known to be valid (from graph traversal) + /// - Deleted vectors are handled separately (e.g., via isMarkedDeleted check) + #[inline] + pub fn get_unchecked_deleted(&self, id: IdType) -> Option<&[T]> { + if id == INVALID_ID { + return None; + } + let id_usize = id as usize; + if id_usize >= self.high_water_mark.load(Ordering::Acquire) { + return None; + } + let (block_idx, offset) = self.id_to_indices(id); + let blocks = self.blocks.read(); + if block_idx >= blocks.len() { + return None; + } + // SAFETY: We hold the read lock and verified the index is within high_water_mark. + unsafe { + let block = &blocks[block_idx]; + if !block.is_valid_index(offset, self.dim) { + return None; + } + let ptr = block.get_vector_ptr_unchecked(offset, self.dim); + Some(std::slice::from_raw_parts(ptr, self.dim)) + } + } + + /// Mark a slot as free for reuse. + /// + /// Returns `true` if the slot was successfully marked as deleted, + /// `false` if the ID is invalid, already deleted, or out of bounds. + /// + /// Note: This doesn't actually clear the data, just marks the slot as available. + pub fn mark_deleted(&mut self, id: IdType) -> bool { + self.mark_deleted_concurrent(id) + } + + /// Mark a slot as free for reuse (concurrent version). + /// + /// Returns `true` if the slot was successfully marked as deleted, + /// `false` if the ID is invalid, already deleted, or out of bounds. + pub fn mark_deleted_concurrent(&self, id: IdType) -> bool { + if id == INVALID_ID { + return false; + } + let id_usize = id as usize; + let high_water_mark = self.high_water_mark.load(Ordering::Acquire); + let mut free_slots = self.free_slots.lock(); + // Check bounds and ensure not already deleted + if id_usize >= high_water_mark || free_slots.contains(&id) { + return false; + } + free_slots.insert(id); + self.count.fetch_sub(1, Ordering::AcqRel); + true + } + + /// Update a vector at the given ID. + /// + /// Returns `true` if the update was successful, `false` if the ID is invalid, + /// deleted, out of bounds, or the vector dimension doesn't match. + pub fn update(&mut self, id: IdType, vector: &[T]) -> bool { + if vector.len() != self.dim || !self.is_valid(id) { + return false; + } + let (block_idx, offset) = self.id_to_indices(id); + let blocks = self.blocks.read(); + if block_idx >= blocks.len() { + return false; + } + blocks[block_idx].write_vector_concurrent(offset, self.dim, vector) + } + + /// Clear all vectors, resetting to empty state. + /// + /// This keeps the allocated blocks but marks them as empty. 
+ pub fn clear(&mut self) { + self.count.store(0, Ordering::Release); + self.free_slots.lock().clear(); + self.high_water_mark.store(0, Ordering::Release); + } + + /// Reserve space for additional vectors. + pub fn reserve(&mut self, additional: usize) { + let needed = self.count.load(Ordering::Acquire) + additional; + let current_capacity = self.capacity(); + + if needed > current_capacity { + let additional_blocks = + (needed - current_capacity).div_ceil(self.vectors_per_block); + + let mut blocks = self.blocks.write(); + for _ in 0..additional_blocks { + blocks.push(DataBlock::new(self.vectors_per_block, self.dim)); + } + } + } + + /// Iterate over all valid (non-deleted) vector IDs. + pub fn iter_ids(&self) -> impl Iterator + '_ { + let high_water_mark = self.high_water_mark.load(Ordering::Acquire); + (0..high_water_mark as IdType).filter(move |&id| !self.free_slots.lock().contains(&id)) + } + + /// Compact the storage by removing gaps from deleted vectors. + /// + /// Returns a mapping from old IDs to new IDs. Vectors that were deleted + /// will not appear in the mapping. + /// + /// After compaction: + /// - All vectors are contiguous starting from ID 0 + /// - `free_slots` is empty + /// - `high_water_mark` equals `count` + /// - Unused blocks may be deallocated if `shrink` is true + /// + /// # Arguments + /// * `shrink` - If true, deallocate unused blocks after compaction + pub fn compact(&mut self, shrink: bool) -> std::collections::HashMap { + use std::collections::HashMap; + + let high_water_mark = self.high_water_mark.load(Ordering::Acquire); + let free_slots = self.free_slots.lock(); + + if free_slots.is_empty() { + drop(free_slots); + // No gaps to fill, just return identity mapping + return (0..high_water_mark as IdType) + .map(|id| (id, id)) + .collect(); + } + + let count = self.count.load(Ordering::Acquire); + let mut id_mapping = HashMap::with_capacity(count); + let mut new_id: IdType = 0; + + // Collect valid vectors and their data + let valid_ids: Vec = (0..high_water_mark as IdType) + .filter(|id| !free_slots.contains(id)) + .collect(); + drop(free_slots); // Release lock before reading vectors + + // Copy vectors to temporary storage + let vectors: Vec> = valid_ids + .iter() + .filter_map(|&id| self.get(id).map(|v| v.to_vec())) + .collect(); + + // Clear and rebuild + self.free_slots.lock().clear(); + self.high_water_mark.store(0, Ordering::Release); + self.count.store(0, Ordering::Release); + + // Re-add vectors in order + for (old_id, vector) in valid_ids.into_iter().zip(vectors.into_iter()) { + if let Some(added_id) = self.add(&vector) { + id_mapping.insert(old_id, added_id); + new_id = added_id + 1; + } + } + + // Shrink blocks if requested + if shrink { + let needed_blocks = new_id as usize / self.vectors_per_block + 1; + let mut blocks = self.blocks.write(); + if blocks.len() > needed_blocks { + blocks.truncate(needed_blocks); + } + } + + id_mapping + } + + /// Get the number of deleted (free) slots. + #[inline] + pub fn deleted_count(&self) -> usize { + self.free_slots.lock().len() + } + + /// Get the fragmentation ratio (deleted / total allocated). + /// + /// Returns 0.0 if no vectors have been allocated. 
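// Because `compact` reassigns ids, any external table pointing into this
// container must be rewritten with the returned mapping. A usage sketch,
// assuming a hypothetical label -> id table kept by an index layer:
fn remap_after_compact_sketch(
    label_to_id: &mut std::collections::HashMap<u64, IdType>,
    mapping: &std::collections::HashMap<IdType, IdType>,
) {
    // Drop labels whose vectors were deleted; rewrite the survivors.
    label_to_id.retain(|_, id| {
        if let Some(new_id) = mapping.get(id) {
            *id = *new_id;
            true
        } else {
            false
        }
    });
}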
+ #[inline] + pub fn fragmentation(&self) -> f64 { + let high_water_mark = self.high_water_mark.load(Ordering::Acquire); + if high_water_mark == 0 { + 0.0 + } else { + self.free_slots.lock().len() as f64 / high_water_mark as f64 + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_data_blocks_basic() { + let mut blocks = DataBlocks::::new(4, 10); + + let v1 = vec![1.0, 2.0, 3.0, 4.0]; + let v2 = vec![5.0, 6.0, 7.0, 8.0]; + + let id1 = blocks.add(&v1).unwrap(); + let id2 = blocks.add(&v2).unwrap(); + + assert_eq!(blocks.len(), 2); + assert_eq!(blocks.get(id1), Some(v1.as_slice())); + assert_eq!(blocks.get(id2), Some(v2.as_slice())); + } + + #[test] + fn test_data_blocks_reuse() { + let mut blocks = DataBlocks::::new(4, 10); + + let v1 = vec![1.0, 2.0, 3.0, 4.0]; + let v2 = vec![5.0, 6.0, 7.0, 8.0]; + let v3 = vec![9.0, 10.0, 11.0, 12.0]; + + let id1 = blocks.add(&v1).unwrap(); + let _id2 = blocks.add(&v2).unwrap(); + assert_eq!(blocks.len(), 2); + + // Delete first vector + assert!(blocks.mark_deleted(id1)); + assert_eq!(blocks.len(), 1); + + // Verify deleted vector is not accessible + assert!(blocks.get(id1).is_none()); + + // Add new vector - should reuse slot + let id3 = blocks.add(&v3).unwrap(); + assert_eq!(id3, id1); // Reused the same slot + assert_eq!(blocks.len(), 2); + assert_eq!(blocks.get(id3), Some(v3.as_slice())); + } + + #[test] + fn test_data_blocks_grow() { + let mut blocks = DataBlocks::::with_block_size(4, 2, 2); + + // Fill initial capacity + for i in 0..2 { + blocks.add(&vec![i as f32; 4]).unwrap(); + } + assert_eq!(blocks.len(), 2); + + // Should trigger new block allocation + blocks.add(&vec![99.0; 4]).unwrap(); + assert_eq!(blocks.len(), 3); + assert!(blocks.capacity() >= 3); + } + + #[test] + fn test_data_blocks_double_delete() { + let mut blocks = DataBlocks::::new(4, 10); + + let v1 = vec![1.0, 2.0, 3.0, 4.0]; + let id1 = blocks.add(&v1).unwrap(); + + // First delete should succeed + assert!(blocks.mark_deleted(id1)); + assert_eq!(blocks.len(), 0); + + // Second delete should fail (already deleted) + assert!(!blocks.mark_deleted(id1)); + assert_eq!(blocks.len(), 0); + } + + #[test] + fn test_data_blocks_invalid_dimension() { + let mut blocks = DataBlocks::::new(4, 10); + + // Wrong dimension should fail + let wrong_dim = vec![1.0, 2.0, 3.0]; // 3 instead of 4 + assert!(blocks.add(&wrong_dim).is_none()); + + // Correct dimension should succeed + let correct_dim = vec![1.0, 2.0, 3.0, 4.0]; + assert!(blocks.add(&correct_dim).is_some()); + } + + #[test] + fn test_data_blocks_bounds_checking() { + let blocks = DataBlocks::::new(4, 10); + + // Invalid ID should return None + assert!(blocks.get(INVALID_ID).is_none()); + assert!(blocks.get_ptr(INVALID_ID).is_none()); + + // Out of bounds ID should return None + assert!(blocks.get(999).is_none()); + assert!(blocks.get_ptr(999).is_none()); + } + + #[test] + fn test_data_blocks_update() { + let mut blocks = DataBlocks::::new(4, 10); + + let v1 = vec![1.0, 2.0, 3.0, 4.0]; + let v2 = vec![5.0, 6.0, 7.0, 8.0]; + let id1 = blocks.add(&v1).unwrap(); + + // Update should succeed + assert!(blocks.update(id1, &v2)); + assert_eq!(blocks.get(id1), Some(v2.as_slice())); + + // Update with wrong dimension should fail + let wrong_dim = vec![1.0, 2.0, 3.0]; + assert!(!blocks.update(id1, &wrong_dim)); + + // Update deleted vector should fail + blocks.mark_deleted(id1); + assert!(!blocks.update(id1, &v1)); + } + + #[test] + fn test_data_blocks_compact() { + let mut blocks = DataBlocks::::new(4, 10); + + let v1 = 
vec![1.0, 2.0, 3.0, 4.0]; + let v2 = vec![5.0, 6.0, 7.0, 8.0]; + let v3 = vec![9.0, 10.0, 11.0, 12.0]; + let v4 = vec![13.0, 14.0, 15.0, 16.0]; + + let id1 = blocks.add(&v1).unwrap(); + let id2 = blocks.add(&v2).unwrap(); + let id3 = blocks.add(&v3).unwrap(); + let id4 = blocks.add(&v4).unwrap(); + + // Delete vectors 1 and 3 (creating gaps) + blocks.mark_deleted(id2); + blocks.mark_deleted(id3); + + assert_eq!(blocks.len(), 2); + assert_eq!(blocks.deleted_count(), 2); + assert!((blocks.fragmentation() - 0.5).abs() < 0.01); + + // Compact + let mapping = blocks.compact(false); + + // Should have 2 vectors now, contiguous + assert_eq!(blocks.len(), 2); + assert_eq!(blocks.deleted_count(), 0); + assert!((blocks.fragmentation() - 0.0).abs() < 0.01); + + // Check mapping + assert!(mapping.contains_key(&id1)); + assert!(mapping.contains_key(&id4)); + assert!(!mapping.contains_key(&id2)); // Deleted + assert!(!mapping.contains_key(&id3)); // Deleted + + // Verify data integrity + let new_id1 = mapping[&id1]; + let new_id4 = mapping[&id4]; + assert_eq!(blocks.get(new_id1), Some(v1.as_slice())); + assert_eq!(blocks.get(new_id4), Some(v4.as_slice())); + } + + #[test] + fn test_data_blocks_compact_no_deletions() { + let mut blocks = DataBlocks::::new(4, 10); + + let v1 = vec![1.0, 2.0, 3.0, 4.0]; + let v2 = vec![5.0, 6.0, 7.0, 8.0]; + + let id1 = blocks.add(&v1).unwrap(); + let id2 = blocks.add(&v2).unwrap(); + + // Compact with no deletions should return identity mapping + let mapping = blocks.compact(false); + + assert_eq!(mapping[&id1], id1); + assert_eq!(mapping[&id2], id2); + assert_eq!(blocks.len(), 2); + } +} diff --git a/rust/vecsim/src/containers/mmap_blocks.rs b/rust/vecsim/src/containers/mmap_blocks.rs new file mode 100644 index 000000000..0c50a59c0 --- /dev/null +++ b/rust/vecsim/src/containers/mmap_blocks.rs @@ -0,0 +1,501 @@ +//! Memory-mapped block storage for disk-based vector indices. +//! +//! This module provides persistent storage of vectors using memory-mapped files, +//! allowing indices to work with datasets larger than available RAM. + +use crate::types::{IdType, VectorElement}; +use memmap2::{MmapMut, MmapOptions}; +use std::collections::HashSet; +use std::fs::OpenOptions; +use std::io; +use std::marker::PhantomData; +use std::path::{Path, PathBuf}; + +/// Header size in bytes (magic + version + dim + count + high_water_mark). +const HEADER_SIZE: usize = 32; + +/// Magic number for file format identification. +const MAGIC: u64 = 0x56454353494D4D41; // "VECSIMMA" + +/// File format version. +const FORMAT_VERSION: u32 = 1; + +/// Memory-mapped block storage for vectors. +/// +/// Provides persistent storage using memory-mapped files. +/// The file format is: +/// ```text +/// [Header: 32 bytes] +/// - magic: u64 +/// - version: u32 +/// - dim: u32 +/// - count: u64 +/// - high_water_mark: u64 +/// [Vector data: dim * sizeof(T) * capacity] +/// [Deleted bitmap: ceil(capacity / 8) bytes] +/// ``` +pub struct MmapDataBlocks { + /// Memory-mapped file. + mmap: MmapMut, + /// Path to the data file. + path: PathBuf, + /// Vector dimension. + dim: usize, + /// Number of active (non-deleted) vectors. + count: usize, + /// Highest ID ever allocated (not counting reuse). + high_water_mark: usize, + /// Current capacity. + capacity: usize, + /// Set of deleted slot IDs. + free_slots: HashSet, + /// Phantom marker for element type. + _marker: PhantomData, +} + +impl MmapDataBlocks { + /// Create a new memory-mapped storage at the given path. 
+
+impl<T: VectorElement> MmapDataBlocks<T> {
+    /// Create a new memory-mapped storage at the given path.
+    ///
+    /// If the file doesn't exist, it will be created.
+    /// If it exists, it will be opened and validated.
+    pub fn new<P: AsRef<Path>>(path: P, dim: usize, initial_capacity: usize) -> io::Result<Self> {
+        let path = path.as_ref().to_path_buf();
+        let capacity = initial_capacity.max(1);
+        let file_size = Self::calculate_file_size(dim, capacity);
+
+        if path.exists() {
+            Self::open(path)
+        } else {
+            Self::create(path, dim, capacity, file_size)
+        }
+    }
+
+    /// Create a new storage file.
+    fn create(path: PathBuf, dim: usize, capacity: usize, file_size: usize) -> io::Result<Self> {
+        let file = OpenOptions::new()
+            .read(true)
+            .write(true)
+            .create(true)
+            .truncate(true)
+            .open(&path)?;
+
+        file.set_len(file_size as u64)?;
+
+        let mut mmap = unsafe { MmapOptions::new().map_mut(&file)? };
+
+        // Write header
+        Self::write_header(&mut mmap, dim, 0, 0);
+
+        Ok(Self {
+            mmap,
+            path,
+            dim,
+            count: 0,
+            high_water_mark: 0,
+            capacity,
+            free_slots: HashSet::new(),
+            _marker: PhantomData,
+        })
+    }
+
+    /// Open an existing storage file.
+    fn open(path: PathBuf) -> io::Result<Self> {
+        let file = OpenOptions::new().read(true).write(true).open(&path)?;
+
+        let mmap = unsafe { MmapOptions::new().map_mut(&file)? };
+
+        // Validate and read header
+        if mmap.len() < HEADER_SIZE {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "File too small for header",
+            ));
+        }
+
+        let magic = u64::from_le_bytes(mmap[0..8].try_into().unwrap());
+        if magic != MAGIC {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "Invalid magic number",
+            ));
+        }
+
+        let version = u32::from_le_bytes(mmap[8..12].try_into().unwrap());
+        if version != FORMAT_VERSION {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("Unsupported version: {version}"),
+            ));
+        }
+
+        let dim = u32::from_le_bytes(mmap[12..16].try_into().unwrap()) as usize;
+        let count = u64::from_le_bytes(mmap[16..24].try_into().unwrap()) as usize;
+        let high_water_mark = u64::from_le_bytes(mmap[24..32].try_into().unwrap()) as usize;
+
+        // Calculate capacity from file size
+        let vector_size = dim * std::mem::size_of::<T>();
+        let data_area_size = mmap.len() - HEADER_SIZE;
+        // Approximate capacity (ignoring deleted bitmap for now)
+        let capacity = if vector_size > 0 {
+            data_area_size / vector_size
+        } else {
+            0
+        };
+
+        // Load deleted slots from bitmap
+        let free_slots = Self::load_deleted_bitmap(&mmap, dim, capacity, high_water_mark);
+
+        Ok(Self {
+            mmap,
+            path,
+            dim,
+            count,
+            high_water_mark,
+            capacity,
+            free_slots,
+            _marker: PhantomData,
+        })
+    }
+
+    /// Calculate the file size needed for the given capacity.
+    fn calculate_file_size(dim: usize, capacity: usize) -> usize {
+        let vector_size = dim * std::mem::size_of::<T>();
+        let data_size = vector_size * capacity;
+        let bitmap_size = capacity.div_ceil(8);
+        HEADER_SIZE + data_size + bitmap_size
+    }
+
+    /// Write the header to the mmap.
+    fn write_header(mmap: &mut MmapMut, dim: usize, count: usize, high_water_mark: usize) {
+        mmap[0..8].copy_from_slice(&MAGIC.to_le_bytes());
+        mmap[8..12].copy_from_slice(&FORMAT_VERSION.to_le_bytes());
+        mmap[12..16].copy_from_slice(&(dim as u32).to_le_bytes());
+        mmap[16..24].copy_from_slice(&(count as u64).to_le_bytes());
+        mmap[24..32].copy_from_slice(&(high_water_mark as u64).to_le_bytes());
+    }
+
+    /// Load the deleted bitmap from the file.
+    fn load_deleted_bitmap(
+        mmap: &MmapMut,
+        dim: usize,
+        capacity: usize,
+        high_water_mark: usize,
+    ) -> HashSet<IdType> {
+        let mut free_slots = HashSet::new();
+        let vector_size = dim * std::mem::size_of::<T>();
+        let bitmap_offset = HEADER_SIZE + vector_size * capacity;
+
+        if bitmap_offset >= mmap.len() {
+            return free_slots;
+        }
+
+        for id in 0..high_water_mark {
+            let byte_idx = bitmap_offset + id / 8;
+            let bit_idx = id % 8;
+
+            if byte_idx < mmap.len() && (mmap[byte_idx] & (1 << bit_idx)) != 0 {
+                free_slots.insert(id as IdType);
+            }
+        }
+
+        free_slots
+    }
+
+    /// Save the deleted bitmap to the file.
+    fn save_deleted_bitmap(&mut self) {
+        let vector_size = self.dim * std::mem::size_of::<T>();
+        let bitmap_offset = HEADER_SIZE + vector_size * self.capacity;
+
+        // Clear bitmap
+        let bitmap_size = self.capacity.div_ceil(8);
+        if bitmap_offset + bitmap_size <= self.mmap.len() {
+            for i in 0..bitmap_size {
+                self.mmap[bitmap_offset + i] = 0;
+            }
+
+            // Set deleted bits
+            for &id in &self.free_slots {
+                let id = id as usize;
+                let byte_idx = bitmap_offset + id / 8;
+                let bit_idx = id % 8;
+
+                if byte_idx < self.mmap.len() {
+                    self.mmap[byte_idx] |= 1 << bit_idx;
+                }
+            }
+        }
+    }
+
+    /// Get the vector dimension.
+    #[inline]
+    pub fn dim(&self) -> usize {
+        self.dim
+    }
+
+    /// Get the number of active vectors.
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.count
+    }
+
+    /// Check if empty.
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.count == 0
+    }
+
+    /// Get the current capacity.
+    #[inline]
+    pub fn capacity(&self) -> usize {
+        self.capacity
+    }
+
+    /// Get the data path.
+    pub fn path(&self) -> &Path {
+        &self.path
+    }
+
+    /// Get the offset of a vector in the file.
+    #[inline]
+    fn vector_offset(&self, id: IdType) -> usize {
+        let vector_size = self.dim * std::mem::size_of::<T>();
+        HEADER_SIZE + (id as usize) * vector_size
+    }
+
+    /// Check if an ID is valid.
+    #[inline]
+    pub fn is_valid(&self, id: IdType) -> bool {
+        let id_usize = id as usize;
+        id_usize < self.high_water_mark && !self.free_slots.contains(&id)
+    }
+
+    /// Add a vector and return its ID.
+    pub fn add(&mut self, vector: &[T]) -> Option<IdType> {
+        if vector.len() != self.dim {
+            return None;
+        }
+
+        // Check if we need to grow
+        if self.high_water_mark >= self.capacity && !self.grow() {
+            return None;
+        }
+
+        let id = self.high_water_mark as IdType;
+        let offset = self.vector_offset(id);
+        let vector_size = self.dim * std::mem::size_of::<T>();
+
+        // Write vector data
+        let src = unsafe {
+            std::slice::from_raw_parts(vector.as_ptr() as *const u8, vector_size)
+        };
+        self.mmap[offset..offset + vector_size].copy_from_slice(src);
+
+        self.high_water_mark += 1;
+        self.count += 1;
+
+        // Update header
+        Self::write_header(&mut self.mmap, self.dim, self.count, self.high_water_mark);
+
+        Some(id)
+    }
+
+    /// Get a vector by ID.
+    pub fn get(&self, id: IdType) -> Option<&[T]> {
+        if !self.is_valid(id) {
+            return None;
+        }
+
+        let offset = self.vector_offset(id);
+        let vector_size = self.dim * std::mem::size_of::<T>();
+
+        if offset + vector_size > self.mmap.len() {
+            return None;
+        }
+
+        let ptr = self.mmap[offset..].as_ptr() as *const T;
+        Some(unsafe { std::slice::from_raw_parts(ptr, self.dim) })
+    }
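+
+    // Illustration (added; not in the original patch): the bitmap packs one bit
+    // per slot, so a slot id maps to byte `bitmap_offset + id / 8` and bit
+    // `id % 8` within it. For example, id = 13 lives at byte bitmap_offset + 1
+    // under mask 1 << 5, the same arithmetic used by load_deleted_bitmap and
+    // save_deleted_bitmap above.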
+
+    /// Mark a vector as deleted.
+    pub fn mark_deleted(&mut self, id: IdType) -> bool {
+        let id_usize = id as usize;
+        if id_usize >= self.high_water_mark || self.free_slots.contains(&id) {
+            return false;
+        }
+
+        self.free_slots.insert(id);
+        self.count = self.count.saturating_sub(1);
+
+        // Update header and bitmap
+        Self::write_header(&mut self.mmap, self.dim, self.count, self.high_water_mark);
+        self.save_deleted_bitmap();
+
+        true
+    }
+
+    /// Grow the storage capacity.
+    fn grow(&mut self) -> bool {
+        let new_capacity = self.capacity * 2;
+        let new_size = Self::calculate_file_size(self.dim, new_capacity);
+
+        // Reopen file and resize
+        let file = match OpenOptions::new().read(true).write(true).open(&self.path) {
+            Ok(f) => f,
+            Err(_) => return false,
+        };
+
+        if file.set_len(new_size as u64).is_err() {
+            return false;
+        }
+
+        // Remap
+        match unsafe { MmapOptions::new().map_mut(&file) } {
+            Ok(new_mmap) => {
+                self.mmap = new_mmap;
+                self.capacity = new_capacity;
+                true
+            }
+            Err(_) => false,
+        }
+    }
+
+    /// Clear all vectors.
+    pub fn clear(&mut self) {
+        self.count = 0;
+        self.high_water_mark = 0;
+        self.free_slots.clear();
+
+        Self::write_header(&mut self.mmap, self.dim, 0, 0);
+        self.save_deleted_bitmap();
+    }
+
+    /// Flush changes to disk.
+    pub fn flush(&self) -> io::Result<()> {
+        self.mmap.flush()
+    }
+
+    /// Get fragmentation ratio.
+    pub fn fragmentation(&self) -> f64 {
+        if self.high_water_mark == 0 {
+            0.0
+        } else {
+            self.free_slots.len() as f64 / self.high_water_mark as f64
+        }
+    }
+
+    /// Iterate over valid IDs.
+    pub fn iter_ids(&self) -> impl Iterator<Item = IdType> + '_ {
+        (0..self.high_water_mark as IdType).filter(move |&id| !self.free_slots.contains(&id))
+    }
+}
+
+impl<T: VectorElement> Drop for MmapDataBlocks<T> {
+    fn drop(&mut self) {
+        // Ensure data is flushed on drop
+        let _ = self.mmap.flush();
+    }
+}
+
+// Safety: MmapDataBlocks can be sent between threads
+unsafe impl<T: VectorElement> Send for MmapDataBlocks<T> {}
+
+// Note: Sync is NOT implemented because mmap requires careful synchronization
+// for concurrent access. Use external locking.
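+
+// Example (illustrative, not part of the original patch): because the type is
+// Send but deliberately not Sync, shared use across threads goes through
+// external locking. The file name below is hypothetical:
+//
+//     use std::sync::{Arc, Mutex};
+//
+//     let blocks = MmapDataBlocks::<f32>::new("vectors.dat", 128, 1024)?;
+//     let shared = Arc::new(Mutex::new(blocks));
+//     let writer = Arc::clone(&shared);
+//     std::thread::spawn(move || {
+//         let mut guard = writer.lock().unwrap();
+//         guard.add(&vec![0.0f32; 128]);
+//     });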
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::fs;
+
+    fn temp_path() -> PathBuf {
+        let mut path = std::env::temp_dir();
+        path.push(format!("mmap_test_{}.dat", rand::random::<u64>()));
+        path
+    }
+
+    #[test]
+    fn test_mmap_blocks_basic() {
+        let path = temp_path();
+        {
+            let mut blocks = MmapDataBlocks::<f32>::new(&path, 4, 10).unwrap();
+
+            let v1 = vec![1.0f32, 2.0, 3.0, 4.0];
+            let v2 = vec![5.0f32, 6.0, 7.0, 8.0];
+
+            let id1 = blocks.add(&v1).unwrap();
+            let id2 = blocks.add(&v2).unwrap();
+
+            assert_eq!(id1, 0);
+            assert_eq!(id2, 1);
+            assert_eq!(blocks.len(), 2);
+
+            let retrieved = blocks.get(id1).unwrap();
+            assert_eq!(retrieved, &v1[..]);
+        }
+
+        // Cleanup
+        fs::remove_file(&path).ok();
+    }
+
+    #[test]
+    fn test_mmap_blocks_persistence() {
+        let path = temp_path();
+        let v1 = vec![1.0f32, 2.0, 3.0, 4.0];
+
+        // Create and add
+        {
+            let mut blocks = MmapDataBlocks::<f32>::new(&path, 4, 10).unwrap();
+            blocks.add(&v1).unwrap();
+            blocks.flush().unwrap();
+        }
+
+        // Reopen and verify
+        {
+            let blocks = MmapDataBlocks::<f32>::new(&path, 4, 10).unwrap();
+            assert_eq!(blocks.len(), 1);
+            let retrieved = blocks.get(0).unwrap();
+            assert_eq!(retrieved, &v1[..]);
+        }
+
+        fs::remove_file(&path).ok();
+    }
+
+    #[test]
+    fn test_mmap_blocks_delete() {
+        let path = temp_path();
+        {
+            let mut blocks = MmapDataBlocks::<f32>::new(&path, 4, 10).unwrap();
+
+            let id1 = blocks.add(&[1.0, 2.0, 3.0, 4.0]).unwrap();
+            let id2 = blocks.add(&[5.0, 6.0, 7.0, 8.0]).unwrap();
+
+            assert_eq!(blocks.len(), 2);
+
+            blocks.mark_deleted(id1);
+            assert_eq!(blocks.len(), 1);
+            assert!(blocks.get(id1).is_none());
+            assert!(blocks.get(id2).is_some());
+        }
+
+        fs::remove_file(&path).ok();
+    }
+
+    #[test]
+    fn test_mmap_blocks_grow() {
+        let path = temp_path();
+        {
+            let mut blocks = MmapDataBlocks::<f32>::new(&path, 4, 2).unwrap();
+
+            // Add more than initial capacity
+            for i in 0..10 {
+                let v = vec![i as f32; 4];
+                blocks.add(&v).unwrap();
+            }
+
+            assert_eq!(blocks.len(), 10);
+            assert!(blocks.capacity() >= 10);
+        }
+
+        fs::remove_file(&path).ok();
+    }
+}
diff --git a/rust/vecsim/src/containers/mod.rs b/rust/vecsim/src/containers/mod.rs
new file mode 100644
index 000000000..a0da24c0a
--- /dev/null
+++ b/rust/vecsim/src/containers/mod.rs
@@ -0,0 +1,14 @@
+//! Container types for vector storage.
+//!
+//! This module provides efficient data structures for storing vectors:
+//! - `DataBlocks`: Block-based storage with SIMD-aligned memory
+//! - `Sq8DataBlocks`: Block-based storage for SQ8-quantized vectors
+//! - `MmapDataBlocks`: Memory-mapped storage for disk-based indices
+
+pub mod data_blocks;
+pub mod mmap_blocks;
+pub mod sq8_blocks;
+
+pub use data_blocks::DataBlocks;
+pub use mmap_blocks::MmapDataBlocks;
+pub use sq8_blocks::Sq8DataBlocks;
diff --git a/rust/vecsim/src/containers/sq8_blocks.rs b/rust/vecsim/src/containers/sq8_blocks.rs
new file mode 100644
index 000000000..43105e9c8
--- /dev/null
+++ b/rust/vecsim/src/containers/sq8_blocks.rs
@@ -0,0 +1,470 @@
+//! Block-based storage for SQ8-quantized vectors.
+//!
+//! This module provides `Sq8DataBlocks`, a container optimized for storing
+//! SQ8-quantized vectors with per-vector metadata for dequantization and
+//! asymmetric distance computation.
+
+use crate::quantization::{Sq8Codec, Sq8VectorMeta};
+use crate::types::{IdType, INVALID_ID};
+use std::collections::HashSet;
+
+/// Default block size (number of vectors per block).
+const DEFAULT_BLOCK_SIZE: usize = 1024;
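+
+// Background sketch (added for illustration; Sq8Codec's exact rounding and
+// clamping may differ): per-vector SQ8 stores a `min` and a `delta` so each
+// f32 component x maps to one byte q and back:
+//
+//     q     = round((x - min) / delta)        // clamped to 0..=255
+//     x_hat = min + q as f32 * delta          // dequantization
+//
+// With delta = (max - min) / 255, the per-component round-trip error is at
+// most delta / 2, which is what the tolerance-based asserts in the tests at
+// the end of this file rely on.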
+
+/// A single block of SQ8-quantized vector data.
+struct Sq8DataBlock {
+    /// Quantized vector data (u8 values).
+    data: Vec<u8>,
+    /// Per-vector metadata.
+    metadata: Vec<Sq8VectorMeta>,
+    /// Number of vectors this block can hold.
+    capacity: usize,
+    /// Vector dimension.
+    dim: usize,
+}
+
+impl Sq8DataBlock {
+    /// Create a new SQ8 data block.
+    fn new(num_vectors: usize, dim: usize) -> Self {
+        Self {
+            data: vec![0u8; num_vectors * dim],
+            metadata: vec![Sq8VectorMeta::default(); num_vectors],
+            capacity: num_vectors,
+            dim,
+        }
+    }
+
+    /// Check if an index is within bounds.
+    #[inline]
+    fn is_valid_index(&self, index: usize) -> bool {
+        index < self.capacity
+    }
+
+    /// Get quantized data and metadata for a vector.
+    #[inline]
+    fn get_vector(&self, index: usize) -> Option<(&[u8], &Sq8VectorMeta)> {
+        if !self.is_valid_index(index) {
+            return None;
+        }
+        let start = index * self.dim;
+        let end = start + self.dim;
+        Some((&self.data[start..end], &self.metadata[index]))
+    }
+
+    /// Get raw pointer to quantized data (for SIMD operations).
+    #[inline]
+    fn get_data_ptr(&self, index: usize) -> Option<*const u8> {
+        if !self.is_valid_index(index) {
+            return None;
+        }
+        Some(unsafe { self.data.as_ptr().add(index * self.dim) })
+    }
+
+    /// Write quantized vector data and metadata.
+    #[inline]
+    fn write_vector(&mut self, index: usize, data: &[u8], meta: Sq8VectorMeta) -> bool {
+        if !self.is_valid_index(index) || data.len() != self.dim {
+            return false;
+        }
+        let start = index * self.dim;
+        self.data[start..start + self.dim].copy_from_slice(data);
+        self.metadata[index] = meta;
+        true
+    }
+}
+
+/// Block-based storage for SQ8-quantized vectors.
+///
+/// Stores vectors in quantized form with per-vector metadata for efficient
+/// asymmetric distance computation.
+pub struct Sq8DataBlocks {
+    /// The blocks storing quantized vector data.
+    blocks: Vec<Sq8DataBlock>,
+    /// Number of vectors per block.
+    vectors_per_block: usize,
+    /// Vector dimension.
+    dim: usize,
+    /// Total number of vectors stored (excluding deleted).
+    count: usize,
+    /// Free slots from deleted vectors.
+    free_slots: HashSet<IdType>,
+    /// High water mark: highest ID ever allocated + 1.
+    high_water_mark: usize,
+    /// SQ8 codec for encoding/decoding.
+    codec: Sq8Codec,
+}
+
+impl Sq8DataBlocks {
+    /// Create a new Sq8DataBlocks container.
+    ///
+    /// # Arguments
+    /// * `dim` - Vector dimension
+    /// * `initial_capacity` - Initial number of vectors to allocate
+    pub fn new(dim: usize, initial_capacity: usize) -> Self {
+        let vectors_per_block = DEFAULT_BLOCK_SIZE;
+        let num_blocks = initial_capacity.div_ceil(vectors_per_block);
+
+        let blocks: Vec<_> = (0..num_blocks.max(1))
+            .map(|_| Sq8DataBlock::new(vectors_per_block, dim))
+            .collect();
+
+        Self {
+            blocks,
+            vectors_per_block,
+            dim,
+            count: 0,
+            free_slots: HashSet::new(),
+            high_water_mark: 0,
+            codec: Sq8Codec::new(dim),
+        }
+    }
+
+    /// Create with a custom block size.
+    pub fn with_block_size(dim: usize, initial_capacity: usize, block_size: usize) -> Self {
+        let vectors_per_block = block_size;
+        let num_blocks = initial_capacity.div_ceil(vectors_per_block);
+
+        let blocks: Vec<_> = (0..num_blocks.max(1))
+            .map(|_| Sq8DataBlock::new(vectors_per_block, dim))
+            .collect();
+
+        Self {
+            blocks,
+            vectors_per_block,
+            dim,
+            count: 0,
+            free_slots: HashSet::new(),
+            high_water_mark: 0,
+            codec: Sq8Codec::new(dim),
+        }
+    }
+
+    /// Get the vector dimension.
+    #[inline]
+    pub fn dimension(&self) -> usize {
+        self.dim
+    }
+
+    /// Get the number of vectors stored.
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.count
+    }
+
+    /// Check if empty.
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.count == 0
+    }
+
+    /// Get the total capacity.
+    #[inline]
+    pub fn capacity(&self) -> usize {
+        self.blocks.len() * self.vectors_per_block
+    }
+
+    /// Get the SQ8 codec.
+    #[inline]
+    pub fn codec(&self) -> &Sq8Codec {
+        &self.codec
+    }
+
+    /// Convert an internal ID to block and offset indices.
+    #[inline]
+    fn id_to_indices(&self, id: IdType) -> (usize, usize) {
+        let id = id as usize;
+        (id / self.vectors_per_block, id % self.vectors_per_block)
+    }
+
+    /// Add an f32 vector (will be quantized) and return its internal ID.
+    pub fn add(&mut self, vector: &[f32]) -> Option<IdType> {
+        if vector.len() != self.dim {
+            return None;
+        }
+
+        // Quantize the vector
+        let (quantized, meta) = self.codec.encode_from_f32_slice(vector);
+
+        // Find a slot
+        let slot = if let Some(&id) = self.free_slots.iter().next() {
+            self.free_slots.remove(&id);
+            id
+        } else {
+            let next_slot = self.high_water_mark;
+            if next_slot >= self.capacity() {
+                self.blocks
+                    .push(Sq8DataBlock::new(self.vectors_per_block, self.dim));
+            }
+            self.high_water_mark += 1;
+            next_slot as IdType
+        };
+
+        let (block_idx, offset) = self.id_to_indices(slot);
+        if self.blocks[block_idx].write_vector(offset, &quantized, meta) {
+            self.count += 1;
+            Some(slot)
+        } else {
+            // Write failed, restore state
+            if slot as usize == self.high_water_mark - 1 {
+                self.high_water_mark -= 1;
+            } else {
+                self.free_slots.insert(slot);
+            }
+            None
+        }
+    }
+
+    /// Add a pre-quantized vector with metadata.
+    pub fn add_quantized(&mut self, quantized: &[u8], meta: Sq8VectorMeta) -> Option<IdType> {
+        if quantized.len() != self.dim {
+            return None;
+        }
+
+        // Find a slot
+        let slot = if let Some(&id) = self.free_slots.iter().next() {
+            self.free_slots.remove(&id);
+            id
+        } else {
+            let next_slot = self.high_water_mark;
+            if next_slot >= self.capacity() {
+                self.blocks
+                    .push(Sq8DataBlock::new(self.vectors_per_block, self.dim));
+            }
+            self.high_water_mark += 1;
+            next_slot as IdType
+        };
+
+        let (block_idx, offset) = self.id_to_indices(slot);
+        if self.blocks[block_idx].write_vector(offset, quantized, meta) {
+            self.count += 1;
+            Some(slot)
+        } else {
+            if slot as usize == self.high_water_mark - 1 {
+                self.high_water_mark -= 1;
+            } else {
+                self.free_slots.insert(slot);
+            }
+            None
+        }
+    }
+
+    /// Check if an ID is valid and not deleted.
+    #[inline]
+    pub fn is_valid(&self, id: IdType) -> bool {
+        if id == INVALID_ID {
+            return false;
+        }
+        let id_usize = id as usize;
+        id_usize < self.high_water_mark && !self.free_slots.contains(&id)
+    }
+
+    /// Get quantized data and metadata by ID.
+    #[inline]
+    pub fn get(&self, id: IdType) -> Option<(&[u8], &Sq8VectorMeta)> {
+        if !self.is_valid(id) {
+            return None;
+        }
+        let (block_idx, offset) = self.id_to_indices(id);
+        if block_idx >= self.blocks.len() {
+            return None;
+        }
+        self.blocks[block_idx].get_vector(offset)
+    }
+
+    /// Get raw pointer to quantized data.
+    #[inline]
+    pub fn get_data_ptr(&self, id: IdType) -> Option<*const u8> {
+        if !self.is_valid(id) {
+            return None;
+        }
+        let (block_idx, offset) = self.id_to_indices(id);
+        if block_idx >= self.blocks.len() {
+            return None;
+        }
+        self.blocks[block_idx].get_data_ptr(offset)
+    }
+
+    /// Get metadata only (for precomputed values).
+    #[inline]
+    pub fn get_metadata(&self, id: IdType) -> Option<&Sq8VectorMeta> {
+        self.get(id).map(|(_, meta)| meta)
+    }
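+
+    // Sketch (illustrative, not part of the original patch): an asymmetric
+    // kernel keeps the query in f32 and dequantizes stored bytes on the fly.
+    // `get(id)` returns exactly the (code, meta) pair such a kernel needs:
+    //
+    //     fn asymmetric_l2_sq(query: &[f32], code: &[u8], meta: &Sq8VectorMeta) -> f32 {
+    //         query
+    //             .iter()
+    //             .zip(code)
+    //             .map(|(&q, &c)| {
+    //                 let x = meta.min + meta.delta * c as f32;
+    //                 (q - x) * (q - x)
+    //             })
+    //             .sum()
+    //     }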
+
+    /// Decode a vector back to f32.
+    pub fn decode(&self, id: IdType) -> Option<Vec<f32>> {
+        let (quantized, meta) = self.get(id)?;
+        Some(self.codec.decode_from_u8_slice(quantized, meta))
+    }
+
+    /// Mark a slot as deleted.
+    pub fn mark_deleted(&mut self, id: IdType) -> bool {
+        if id == INVALID_ID {
+            return false;
+        }
+        let id_usize = id as usize;
+        if id_usize >= self.high_water_mark || self.free_slots.contains(&id) {
+            return false;
+        }
+        self.free_slots.insert(id);
+        self.count = self.count.saturating_sub(1);
+        true
+    }
+
+    /// Update a vector at the given ID.
+    pub fn update(&mut self, id: IdType, vector: &[f32]) -> bool {
+        if vector.len() != self.dim || !self.is_valid(id) {
+            return false;
+        }
+
+        let (quantized, meta) = self.codec.encode_from_f32_slice(vector);
+        let (block_idx, offset) = self.id_to_indices(id);
+
+        if block_idx >= self.blocks.len() {
+            return false;
+        }
+
+        self.blocks[block_idx].write_vector(offset, &quantized, meta)
+    }
+
+    /// Clear all vectors.
+    pub fn clear(&mut self) {
+        self.count = 0;
+        self.free_slots.clear();
+        self.high_water_mark = 0;
+    }
+
+    /// Reserve space for additional vectors.
+    pub fn reserve(&mut self, additional: usize) {
+        let needed = self.count + additional;
+        let current_capacity = self.capacity();
+
+        if needed > current_capacity {
+            let additional_blocks = (needed - current_capacity).div_ceil(self.vectors_per_block);
+            for _ in 0..additional_blocks {
+                self.blocks
+                    .push(Sq8DataBlock::new(self.vectors_per_block, self.dim));
+            }
+        }
+    }
+
+    /// Iterate over all valid vector IDs.
+    pub fn iter_ids(&self) -> impl Iterator<Item = IdType> + '_ {
+        (0..self.high_water_mark as IdType).filter(move |&id| !self.free_slots.contains(&id))
+    }
+
+    /// Get the number of deleted slots.
+    #[inline]
+    pub fn deleted_count(&self) -> usize {
+        self.free_slots.len()
+    }
+
+    /// Get the fragmentation ratio.
+    #[inline]
+    pub fn fragmentation(&self) -> f64 {
+        if self.high_water_mark == 0 {
+            0.0
+        } else {
+            self.free_slots.len() as f64 / self.high_water_mark as f64
+        }
+    }
+
+    /// Memory usage in bytes (approximate).
+    pub fn memory_usage(&self) -> usize {
+        let data_size = self.blocks.len() * self.vectors_per_block * self.dim;
+        let meta_size = self.blocks.len()
+            * self.vectors_per_block
+            * std::mem::size_of::<Sq8VectorMeta>();
+        let overhead = self.free_slots.len() * std::mem::size_of::<IdType>();
+        data_size + meta_size + overhead
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_sq8_blocks_basic() {
+        let mut blocks = Sq8DataBlocks::new(4, 10);
+
+        let v1 = vec![1.0, 2.0, 3.0, 4.0];
+        let v2 = vec![5.0, 6.0, 7.0, 8.0];
+
+        let id1 = blocks.add(&v1).unwrap();
+        let id2 = blocks.add(&v2).unwrap();
+
+        assert_eq!(blocks.len(), 2);
+        assert!(blocks.is_valid(id1));
+        assert!(blocks.is_valid(id2));
+
+        // Decode and verify
+        let decoded1 = blocks.decode(id1).unwrap();
+        let decoded2 = blocks.decode(id2).unwrap();
+
+        for (orig, dec) in v1.iter().zip(decoded1.iter()) {
+            assert!((orig - dec).abs() < 0.1);
+        }
+        for (orig, dec) in v2.iter().zip(decoded2.iter()) {
+            assert!((orig - dec).abs() < 0.1);
+        }
+    }
+
+    #[test]
+    fn test_sq8_blocks_delete_reuse() {
+        let mut blocks = Sq8DataBlocks::new(4, 10);
+
+        let v1 = vec![1.0, 2.0, 3.0, 4.0];
+        let v2 = vec![5.0, 6.0, 7.0, 8.0];
+        let v3 = vec![9.0, 10.0, 11.0, 12.0];
+
+        let id1 = blocks.add(&v1).unwrap();
+        let _id2 = blocks.add(&v2).unwrap();
+
+        // Delete first
+        assert!(blocks.mark_deleted(id1));
+        assert_eq!(blocks.len(), 1);
+        assert!(!blocks.is_valid(id1));
+
+        // Add new - should reuse slot
+        let id3 = blocks.add(&v3).unwrap();
+        assert_eq!(id3, id1);
+        assert_eq!(blocks.len(), 2);
+    }
+
+    #[test]
+    fn test_sq8_blocks_metadata() {
+        let mut blocks = Sq8DataBlocks::new(4, 10);
+
+        let v = vec![0.0, 0.5, 1.0, 0.25];
+        let id = blocks.add(&v).unwrap();
+
+        let meta = blocks.get_metadata(id).unwrap();
+        assert_eq!(meta.min, 0.0);
+        assert!((meta.delta - (1.0 / 255.0)).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_sq8_blocks_update() {
+        let mut blocks = Sq8DataBlocks::new(4, 10);
+
+        let v1 = vec![1.0, 2.0, 3.0, 4.0];
+        let v2 = vec![10.0, 20.0, 30.0, 40.0];
+
+        let id = blocks.add(&v1).unwrap();
+        assert!(blocks.update(id, &v2));
+
+        let decoded = blocks.decode(id).unwrap();
+        for (orig, dec) in v2.iter().zip(decoded.iter()) {
+            assert!((orig - dec).abs() < 0.5);
+        }
+    }
+
+    #[test]
+    fn test_sq8_blocks_memory() {
+        let blocks = Sq8DataBlocks::new(128, 1000);
+        let mem = blocks.memory_usage();
+
+        // Should be approximately: 1 block * 1024 vectors * (128 bytes + 16 bytes meta)
+        assert!(mem > 0);
+    }
+}
diff --git a/rust/vecsim/src/data_type_tests.rs b/rust/vecsim/src/data_type_tests.rs
new file mode 100644
index 000000000..8b8fa0ec2
--- /dev/null
+++ b/rust/vecsim/src/data_type_tests.rs
@@ -0,0 +1,1468 @@
+//! Data type-specific integration tests.
+//!
+//! This module tests all supported vector element types (f32, f64, Float16, BFloat16,
+//! Int8, UInt8, Int32, Int64) with various index operations to ensure type-specific
+//! behavior is correct.
+
+use crate::distance::Metric;
+use crate::index::{BruteForceMulti, BruteForceParams, BruteForceSingle, VecSimIndex};
+use crate::index::{HnswParams, HnswSingle};
+use crate::types::{BFloat16, Float16, Int32, Int64, Int8, UInt8, VectorElement};
+
+// =============================================================================
+// Helper functions
+// =============================================================================
+
+fn create_orthogonal_vectors_f32<T: VectorElement>(dim: usize) -> Vec<Vec<T>> {
+    (0..dim)
+        .map(|i| {
+            let mut v = vec![T::zero(); dim];
+            v[i] = T::from_f32(1.0);
+            v
+        })
+        .collect()
+}
+
+// =============================================================================
+// f32 Tests (baseline)
+// =============================================================================
+
+#[test]
+fn test_f32_brute_force_basic() {
+    let params = BruteForceParams::new(4, Metric::L2);
+    let mut index = BruteForceSingle::<f32>::new(params);
+
+    index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
+    index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap();
+    index.add_vector(&[0.0, 0.0, 1.0, 0.0], 3).unwrap();
+
+    let results = index.top_k_query(&[1.0, 0.1, 0.0, 0.0], 2, None).unwrap();
+    assert_eq!(results.results[0].label, 1);
+}
+
+#[test]
+fn test_f32_serialization_roundtrip() {
+    let params = BruteForceParams::new(4, Metric::L2);
+    let mut index = BruteForceSingle::<f32>::new(params);
+
+    for i in 0..10 {
+        let v: Vec<f32> = (0..4).map(|j| (i * 4 + j) as f32).collect();
+        index.add_vector(&v, i as u64).unwrap();
+    }
+
+    // Use save/load with a buffer
+    let mut buffer = Vec::new();
+    index.save(&mut buffer).unwrap();
+    let loaded = BruteForceSingle::<f32>::load(&mut buffer.as_slice()).unwrap();
+
+    assert_eq!(index.index_size(), loaded.index_size());
+
+    let query = [1.0f32, 2.0, 3.0, 4.0];
+    let orig_results = index.top_k_query(&query, 3, None).unwrap();
+    let loaded_results = loaded.top_k_query(&query, 3, None).unwrap();
+
+    assert_eq!(orig_results.results.len(), loaded_results.results.len());
+    for (o, d) in orig_results.results.iter().zip(loaded_results.results.iter()) {
+        assert_eq!(o.label, d.label);
+    }
+}
+
+// =============================================================================
+// f64 Tests
+// =============================================================================
+
+#[test]
+fn test_f64_brute_force_basic() {
+    let params = BruteForceParams::new(4, Metric::L2);
+    let mut index = BruteForceSingle::<f64>::new(params);
+
+    index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
+    index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap();
+    index.add_vector(&[0.0, 0.0, 1.0, 0.0], 3).unwrap();
+
+    let results = index.top_k_query(&[1.0, 0.1, 0.0, 0.0], 2, None).unwrap();
+    assert_eq!(results.results[0].label, 1);
+}
+
+#[test]
+fn test_f64_brute_force_inner_product() {
+    let params = BruteForceParams::new(4, Metric::InnerProduct);
+    let mut index = BruteForceSingle::<f64>::new(params);
+
+    index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
+    index.add_vector(&[0.5, 0.5, 0.0, 0.0], 2).unwrap();
+    index.add_vector(&[0.0, 1.0, 0.0, 0.0], 3).unwrap();
+
+    let results = index.top_k_query(&[1.0, 0.0, 0.0, 0.0], 1, None).unwrap();
+    assert_eq!(results.results[0].label, 1);
+}
+
+#[test]
+fn test_f64_brute_force_cosine() {
+    let params = BruteForceParams::new(4, Metric::Cosine);
+    let mut index = BruteForceSingle::<f64>::new(params);
+
+    index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
+    index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap();
+    index.add_vector(&[0.707, 0.707, 0.0, 0.0], 3).unwrap();
+
+    let results =
index.top_k_query(&[1.0, 0.0, 0.0, 0.0], 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_f64_brute_force_multi() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceMulti::::new(params); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.9, 0.1, 0.0, 0.0], 1).unwrap(); // Same label + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + assert_eq!(index.label_count(1), 2); + + let results = index.top_k_query(&[1.0, 0.0, 0.0, 0.0], 2, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_f64_hnsw_basic() { + let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(16); + let mut index = HnswSingle::::new(params); + + for v in create_orthogonal_vectors_f32::(4) { + index.add_vector(&v, index.index_size() as u64 + 1).unwrap(); + } + + let results = index.top_k_query(&[1.0, 0.0, 0.0, 0.0], 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_f64_high_precision() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + // Use values that would lose precision in f32 + let precise1: Vec = vec![1.0000000001, 0.0, 0.0, 0.0]; + let precise2: Vec = vec![1.0000000002, 0.0, 0.0, 0.0]; + + index.add_vector(&precise1, 1).unwrap(); + index.add_vector(&precise2, 2).unwrap(); + + let query: Vec = vec![1.0000000001, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_f64_data_integrity() { + // Note: Serialization is only implemented for f32 + // This test verifies data integrity without serialization + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..10 { + let v: Vec = (0..4).map(|j| (i * 4 + j) as f64).collect(); + index.add_vector(&v, i as u64).unwrap(); + } + + assert_eq!(index.index_size(), 10); + + // Verify each vector can be retrieved + for i in 0..10 { + assert!(index.contains(i as u64)); + let v = index.get_vector(i as u64).unwrap(); + for (j, &val) in v.iter().enumerate() { + assert_eq!(val, (i * 4 + j) as f64); + } + } +} + +#[test] +fn test_f64_range_query() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + index.add_vector(&[0.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&[2.0, 0.0, 0.0, 0.0], 3).unwrap(); + index.add_vector(&[3.0, 0.0, 0.0, 0.0], 4).unwrap(); + + let results = index + .range_query(&[0.0, 0.0, 0.0, 0.0], 1.5, None) + .unwrap(); + assert_eq!(results.results.len(), 2); // Distance 0 and 1 +} + +#[test] +fn test_f64_delete_and_query() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + index.delete_vector(1).unwrap(); + + let results = index.top_k_query(&[1.0, 0.0, 0.0, 0.0], 2, None).unwrap(); + assert_eq!(results.results.len(), 1); + assert_eq!(results.results[0].label, 2); +} + +// ============================================================================= +// Float16 Tests +// ============================================================================= + +#[test] +fn test_float16_brute_force_basic() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); 
+ + let v1: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + let v2: Vec = [0.0f32, 1.0, 0.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + let v3: Vec = [0.0f32, 0.0, 1.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + index.add_vector(&v3, 3).unwrap(); + + let query: Vec = [1.0f32, 0.1, 0.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + let results = index.top_k_query(&query, 2, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_float16_brute_force_inner_product() { + let params = BruteForceParams::new(4, Metric::InnerProduct); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + let v2: Vec = [0.5f32, 0.5, 0.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_float16_brute_force_cosine() { + let params = BruteForceParams::new(4, Metric::Cosine); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + let v2: Vec = [0.0f32, 1.0, 0.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_float16_hnsw_basic() { + let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(16); + let mut index = HnswSingle::::new(params); + + for i in 0..4 { + let mut v = vec![Float16::from_f32(0.0); 4]; + v[i] = Float16::from_f32(1.0); + index.add_vector(&v, i as u64 + 1).unwrap(); + } + + let query: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_float16_precision_limits() { + // Float16 has ~3 decimal digits precision + let v1 = Float16::from_f32(1.0); + let v2 = Float16::from_f32(1.001); // Should be distinguishable + let v3 = Float16::from_f32(1.0001); // May be same as v1 due to precision + + // Values 1.0 and 1.001 should be distinguishable + let diff = (v1.to_f32() - v2.to_f32()).abs(); + assert!(diff > 0.0005); + + // Very small differences may be lost + let tiny_diff = (v1.to_f32() - v3.to_f32()).abs(); + assert!(tiny_diff < 0.001); // Precision loss expected +} + +#[test] +fn test_float16_data_integrity() { + // Note: Serialization is only implemented for f32 + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..10 { + let v: Vec = (0..4).map(|j| Float16::from_f32((i * 4 + j) as f32)).collect(); + index.add_vector(&v, i as u64).unwrap(); + } + + assert_eq!(index.index_size(), 10); + + // Verify vectors can be retrieved with approximate equality (FP16 has limited precision) + for i in 0..10 { + assert!(index.contains(i as u64)); + let v = 
index.get_vector(i as u64).unwrap(); + for (j, &val) in v.iter().enumerate() { + let expected = (i * 4 + j) as f32; + assert!((val.to_f32() - expected).abs() < 0.1); + } + } +} + +#[test] +fn test_float16_range_query() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..5 { + let v: Vec = vec![ + Float16::from_f32(i as f32), + Float16::from_f32(0.0), + Float16::from_f32(0.0), + Float16::from_f32(0.0), + ]; + index.add_vector(&v, i as u64).unwrap(); + } + + let query: Vec = [0.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| Float16::from_f32(x)) + .collect(); + let results = index.range_query(&query, 1.5, None).unwrap(); + assert_eq!(results.results.len(), 2); // Distance 0 and 1 +} + +#[test] +fn test_float16_special_values() { + // Test that special values don't break the index + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [Float16::ZERO, Float16::ZERO, Float16::ZERO, Float16::ONE].to_vec(); + index.add_vector(&v1, 1).unwrap(); + + let query: Vec = [Float16::ZERO, Float16::ZERO, Float16::ZERO, Float16::ONE].to_vec(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); + assert!(results.results[0].distance < 0.001); +} + +// ============================================================================= +// BFloat16 Tests +// ============================================================================= + +#[test] +fn test_bfloat16_brute_force_basic() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + let v2: Vec = [0.0f32, 1.0, 0.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + let v3: Vec = [0.0f32, 0.0, 1.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + index.add_vector(&v3, 3).unwrap(); + + let query: Vec = [1.0f32, 0.1, 0.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + let results = index.top_k_query(&query, 2, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_bfloat16_brute_force_inner_product() { + let params = BruteForceParams::new(4, Metric::InnerProduct); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + let v2: Vec = [0.5f32, 0.5, 0.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_bfloat16_brute_force_cosine() { + let params = BruteForceParams::new(4, Metric::Cosine); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + let v2: Vec = [0.0f32, 1.0, 0.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); 
+ assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_bfloat16_hnsw_basic() { + let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(16); + let mut index = HnswSingle::::new(params); + + for i in 0..4 { + let mut v = vec![BFloat16::from_f32(0.0); 4]; + v[i] = BFloat16::from_f32(1.0); + index.add_vector(&v, i as u64 + 1).unwrap(); + } + + let query: Vec = [1.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_bfloat16_large_range() { + // BFloat16 has the same exponent range as f32 + let large = BFloat16::from_f32(1e30); + let small = BFloat16::from_f32(1e-30); + + assert!(large.to_f32() > 1e29); + assert!(small.to_f32() < 1e-29); + assert!(small.to_f32() > 0.0); +} + +#[test] +fn test_bfloat16_data_integrity() { + // Note: Serialization is only implemented for f32 + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..10 { + let v: Vec = (0..4) + .map(|j| BFloat16::from_f32((i * 4 + j) as f32)) + .collect(); + index.add_vector(&v, i as u64).unwrap(); + } + + assert_eq!(index.index_size(), 10); + + // Verify vectors can be retrieved with approximate equality (BF16 has limited precision) + for i in 0..10 { + assert!(index.contains(i as u64)); + let v = index.get_vector(i as u64).unwrap(); + for (j, &val) in v.iter().enumerate() { + let expected = (i * 4 + j) as f32; + assert!((val.to_f32() - expected).abs() < 0.5); + } + } +} + +#[test] +fn test_bfloat16_range_query() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..5 { + let v: Vec = vec![ + BFloat16::from_f32(i as f32), + BFloat16::from_f32(0.0), + BFloat16::from_f32(0.0), + BFloat16::from_f32(0.0), + ]; + index.add_vector(&v, i as u64).unwrap(); + } + + let query: Vec = [0.0f32, 0.0, 0.0, 0.0] + .iter() + .map(|&x| BFloat16::from_f32(x)) + .collect(); + let results = index.range_query(&query, 1.5, None).unwrap(); + assert_eq!(results.results.len(), 2); // Distance 0 and 1 +} + +#[test] +fn test_bfloat16_special_values() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = + [BFloat16::ZERO, BFloat16::ZERO, BFloat16::ZERO, BFloat16::ONE].to_vec(); + index.add_vector(&v1, 1).unwrap(); + + let query: Vec = + [BFloat16::ZERO, BFloat16::ZERO, BFloat16::ZERO, BFloat16::ONE].to_vec(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); + assert!(results.results[0].distance < 0.001); +} + +// ============================================================================= +// Int8 Tests +// ============================================================================= + +#[test] +fn test_int8_brute_force_basic() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [127i8, 0, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + let v2: Vec = [0i8, 127, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + let v3: Vec = [0i8, 0, 127, 0].iter().map(|&x| Int8::new(x)).collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + index.add_vector(&v3, 3).unwrap(); + + let query: Vec = [100i8, 10, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + let results = index.top_k_query(&query, 2, None).unwrap(); + 
assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_int8_brute_force_inner_product() { + let params = BruteForceParams::new(4, Metric::InnerProduct); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [100i8, 0, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + let v2: Vec = [50i8, 50, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [100i8, 0, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_int8_brute_force_cosine() { + let params = BruteForceParams::new(4, Metric::Cosine); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [100i8, 0, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + let v2: Vec = [0i8, 100, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [100i8, 0, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_int8_hnsw_basic() { + let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(16); + let mut index = HnswSingle::::new(params); + + for i in 0..4 { + let mut v = vec![Int8::new(0); 4]; + v[i] = Int8::new(100); + index.add_vector(&v, i as u64 + 1).unwrap(); + } + + let query: Vec = [100i8, 0, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_int8_boundary_values() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + // Test with boundary values + let v1: Vec = [Int8::MAX, Int8::MIN, Int8::ZERO, Int8::new(1)].to_vec(); + let v2: Vec = [Int8::MIN, Int8::MAX, Int8::ZERO, Int8::new(-1)].to_vec(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + // Query close to v1 + let query: Vec = [Int8::MAX, Int8::MIN, Int8::ZERO, Int8::ZERO].to_vec(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_int8_negative_values() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [-100i8, -50, 0, 50] + .iter() + .map(|&x| Int8::new(x)) + .collect(); + let v2: Vec = [100i8, 50, 0, -50] + .iter() + .map(|&x| Int8::new(x)) + .collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [-100i8, -50, 0, 50] + .iter() + .map(|&x| Int8::new(x)) + .collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); + assert!(results.results[0].distance < 0.001); +} + +#[test] +fn test_int8_data_integrity() { + // Note: Serialization is only implemented for f32 + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..10i8 { + let v: Vec = (0..4).map(|j| Int8::new((i * 4 + j) % 127)).collect(); + index.add_vector(&v, i as u64).unwrap(); + } + + assert_eq!(index.index_size(), 10); + + // Verify vectors can be retrieved + for i in 0..10i8 { + assert!(index.contains(i as u64)); + let v = index.get_vector(i as u64).unwrap(); + for (j, &val) in v.iter().enumerate() { + let expected = (i * 4 + j as i8) % 127; + 
assert_eq!(val.get(), expected); + } + } +} + +#[test] +fn test_int8_multi_value() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceMulti::::new(params); + + let v1: Vec = [100i8, 0, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + let v2: Vec = [90i8, 10, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + let v3: Vec = [0i8, 100, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 1).unwrap(); // Same label + index.add_vector(&v3, 2).unwrap(); + + assert_eq!(index.label_count(1), 2); + + let query: Vec = [100i8, 0, 0, 0].iter().map(|&x| Int8::new(x)).collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +// ============================================================================= +// UInt8 Tests +// ============================================================================= + +#[test] +fn test_uint8_brute_force_basic() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [255u8, 0, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + let v2: Vec = [0u8, 255, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + let v3: Vec = [0u8, 0, 255, 0].iter().map(|&x| UInt8::new(x)).collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + index.add_vector(&v3, 3).unwrap(); + + let query: Vec = [200u8, 10, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + let results = index.top_k_query(&query, 2, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_uint8_brute_force_inner_product() { + let params = BruteForceParams::new(4, Metric::InnerProduct); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [200u8, 0, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + let v2: Vec = [100u8, 100, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [200u8, 0, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_uint8_brute_force_cosine() { + let params = BruteForceParams::new(4, Metric::Cosine); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [200u8, 0, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + let v2: Vec = [0u8, 200, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [200u8, 0, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_uint8_hnsw_basic() { + let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(16); + let mut index = HnswSingle::::new(params); + + for i in 0..4 { + let mut v = vec![UInt8::new(0); 4]; + v[i] = UInt8::new(200); + index.add_vector(&v, i as u64 + 1).unwrap(); + } + + let query: Vec = [200u8, 0, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_uint8_boundary_values() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [UInt8::MAX, UInt8::MIN, UInt8::ZERO, UInt8::new(128)].to_vec(); + let v2: Vec = [UInt8::MIN, UInt8::MAX, 
UInt8::new(128), UInt8::ZERO].to_vec(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [UInt8::MAX, UInt8::MIN, UInt8::ZERO, UInt8::new(128)].to_vec(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); + assert!(results.results[0].distance < 0.001); +} + +#[test] +fn test_uint8_data_integrity() { + // Note: Serialization is only implemented for f32 + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..10u8 { + let v: Vec = (0..4).map(|j| UInt8::new((i * 4 + j) % 255)).collect(); + index.add_vector(&v, i as u64).unwrap(); + } + + assert_eq!(index.index_size(), 10); + + // Verify vectors can be retrieved + for i in 0..10u8 { + assert!(index.contains(i as u64)); + let v = index.get_vector(i as u64).unwrap(); + for (j, &val) in v.iter().enumerate() { + let expected = (i * 4 + j as u8) % 255; + assert_eq!(val.get(), expected); + } + } +} + +#[test] +fn test_uint8_multi_value() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceMulti::::new(params); + + let v1: Vec = [200u8, 0, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + let v2: Vec = [190u8, 10, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + let v3: Vec = [0u8, 200, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 1).unwrap(); // Same label + index.add_vector(&v3, 2).unwrap(); + + assert_eq!(index.label_count(1), 2); + + let query: Vec = [200u8, 0, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_uint8_range_query() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..5u8 { + let v: Vec = vec![ + UInt8::new(i * 50), + UInt8::new(0), + UInt8::new(0), + UInt8::new(0), + ]; + index.add_vector(&v, i as u64).unwrap(); + } + + let query: Vec = [0u8, 0, 0, 0].iter().map(|&x| UInt8::new(x)).collect(); + // Range query uses squared L2 distance: + // Vector 0 at [0,0,0,0]: squared distance = 0 + // Vector 1 at [50,0,0,0]: squared distance = 50^2 = 2500 + // So radius 2600 should include first two vectors (squared distances 0 and 2500) + let results = index.range_query(&query, 2600.0, None).unwrap(); + assert_eq!(results.results.len(), 2); +} + +// ============================================================================= +// Int32 Tests +// ============================================================================= + +#[test] +fn test_int32_brute_force_basic() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [1000i32, 0, 0, 0].iter().map(|&x| Int32::new(x)).collect(); + let v2: Vec = [0i32, 1000, 0, 0].iter().map(|&x| Int32::new(x)).collect(); + let v3: Vec = [0i32, 0, 1000, 0].iter().map(|&x| Int32::new(x)).collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + index.add_vector(&v3, 3).unwrap(); + + let query: Vec = [900i32, 100, 0, 0] + .iter() + .map(|&x| Int32::new(x)) + .collect(); + let results = index.top_k_query(&query, 2, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_int32_brute_force_inner_product() { + let params = BruteForceParams::new(4, Metric::InnerProduct); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = 
[1000i32, 0, 0, 0].iter().map(|&x| Int32::new(x)).collect(); + let v2: Vec = [500i32, 500, 0, 0].iter().map(|&x| Int32::new(x)).collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [1000i32, 0, 0, 0] + .iter() + .map(|&x| Int32::new(x)) + .collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_int32_brute_force_cosine() { + let params = BruteForceParams::new(4, Metric::Cosine); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [1000i32, 0, 0, 0].iter().map(|&x| Int32::new(x)).collect(); + let v2: Vec = [0i32, 1000, 0, 0].iter().map(|&x| Int32::new(x)).collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [1000i32, 0, 0, 0] + .iter() + .map(|&x| Int32::new(x)) + .collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_int32_hnsw_basic() { + let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(16); + let mut index = HnswSingle::::new(params); + + for i in 0..4 { + let mut v = vec![Int32::new(0); 4]; + v[i] = Int32::new(1000); + index.add_vector(&v, i as u64 + 1).unwrap(); + } + + let query: Vec = [1000i32, 0, 0, 0] + .iter() + .map(|&x| Int32::new(x)) + .collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_int32_large_values() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = [1_000_000i32, 0, 0, 0] + .iter() + .map(|&x| Int32::new(x)) + .collect(); + let v2: Vec = [-1_000_000i32, 0, 0, 0] + .iter() + .map(|&x| Int32::new(x)) + .collect(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [1_000_000i32, 0, 0, 0] + .iter() + .map(|&x| Int32::new(x)) + .collect(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); + assert!(results.results[0].distance < 0.001); +} + +#[test] +fn test_int32_boundary_values() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + // Note: Using values that won't overflow when squared + let v1: Vec = [Int32::new(10000), Int32::new(-10000), Int32::ZERO, Int32::new(1)] + .to_vec(); + let v2: Vec = [Int32::new(-10000), Int32::new(10000), Int32::ZERO, Int32::new(-1)] + .to_vec(); + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + let query: Vec = [Int32::new(10000), Int32::new(-10000), Int32::ZERO, Int32::ZERO] + .to_vec(); + let results = index.top_k_query(&query, 1, None).unwrap(); + assert_eq!(results.results[0].label, 1); +} + +#[test] +fn test_int32_data_integrity() { + // Note: Serialization is only implemented for f32 + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..10i32 { + let v: Vec = (0..4).map(|j| Int32::new(i * 1000 + j * 100)).collect(); + index.add_vector(&v, i as u64).unwrap(); + } + + assert_eq!(index.index_size(), 10); + + // Verify vectors can be retrieved + for i in 0..10i32 { + assert!(index.contains(i as u64)); + let v = index.get_vector(i as u64).unwrap(); + for (j, &val) in v.iter().enumerate() { + let expected = i * 1000 + j as i32 * 100; + assert_eq!(val.get(), expected); + } + } +} + +#[test] +fn test_int32_multi_value() { + 
let params = BruteForceParams::new(4, Metric::L2);
+    let mut index = BruteForceMulti::<Int32>::new(params);
+
+    let v1: Vec<Int32> = [1000i32, 0, 0, 0].iter().map(|&x| Int32::new(x)).collect();
+    let v2: Vec<Int32> = [900i32, 100, 0, 0].iter().map(|&x| Int32::new(x)).collect();
+    let v3: Vec<Int32> = [0i32, 1000, 0, 0].iter().map(|&x| Int32::new(x)).collect();
+
+    index.add_vector(&v1, 1).unwrap();
+    index.add_vector(&v2, 1).unwrap(); // Same label
+    index.add_vector(&v3, 2).unwrap();
+
+    assert_eq!(index.label_count(1), 2);
+
+    let query: Vec<Int32> = [1000i32, 0, 0, 0]
+        .iter()
+        .map(|&x| Int32::new(x))
+        .collect();
+    let results = index.top_k_query(&query, 1, None).unwrap();
+    assert_eq!(results.results[0].label, 1);
+}
+
+// =============================================================================
+// Int64 Tests
+// =============================================================================
+
+#[test]
+fn test_int64_brute_force_basic() {
+    let params = BruteForceParams::new(4, Metric::L2);
+    let mut index = BruteForceSingle::<Int64>::new(params);
+
+    let v1: Vec<Int64> = [10000i64, 0, 0, 0].iter().map(|&x| Int64::new(x)).collect();
+    let v2: Vec<Int64> = [0i64, 10000, 0, 0].iter().map(|&x| Int64::new(x)).collect();
+    let v3: Vec<Int64> = [0i64, 0, 10000, 0].iter().map(|&x| Int64::new(x)).collect();
+
+    index.add_vector(&v1, 1).unwrap();
+    index.add_vector(&v2, 2).unwrap();
+    index.add_vector(&v3, 3).unwrap();
+
+    let query: Vec<Int64> = [9000i64, 1000, 0, 0]
+        .iter()
+        .map(|&x| Int64::new(x))
+        .collect();
+    let results = index.top_k_query(&query, 2, None).unwrap();
+    assert_eq!(results.results[0].label, 1);
+}
+
+#[test]
+fn test_int64_brute_force_inner_product() {
+    let params = BruteForceParams::new(4, Metric::InnerProduct);
+    let mut index = BruteForceSingle::<Int64>::new(params);
+
+    let v1: Vec<Int64> = [10000i64, 0, 0, 0].iter().map(|&x| Int64::new(x)).collect();
+    let v2: Vec<Int64> = [5000i64, 5000, 0, 0]
+        .iter()
+        .map(|&x| Int64::new(x))
+        .collect();
+
+    index.add_vector(&v1, 1).unwrap();
+    index.add_vector(&v2, 2).unwrap();
+
+    let query: Vec<Int64> = [10000i64, 0, 0, 0]
+        .iter()
+        .map(|&x| Int64::new(x))
+        .collect();
+    let results = index.top_k_query(&query, 1, None).unwrap();
+    assert_eq!(results.results[0].label, 1);
+}
+
+#[test]
+fn test_int64_brute_force_cosine() {
+    let params = BruteForceParams::new(4, Metric::Cosine);
+    let mut index = BruteForceSingle::<Int64>::new(params);
+
+    let v1: Vec<Int64> = [10000i64, 0, 0, 0].iter().map(|&x| Int64::new(x)).collect();
+    let v2: Vec<Int64> = [0i64, 10000, 0, 0].iter().map(|&x| Int64::new(x)).collect();
+
+    index.add_vector(&v1, 1).unwrap();
+    index.add_vector(&v2, 2).unwrap();
+
+    let query: Vec<Int64> = [10000i64, 0, 0, 0]
+        .iter()
+        .map(|&x| Int64::new(x))
+        .collect();
+    let results = index.top_k_query(&query, 1, None).unwrap();
+    assert_eq!(results.results[0].label, 1);
+}
+
+#[test]
+fn test_int64_hnsw_basic() {
+    let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(16);
+    let mut index = HnswSingle::<Int64>::new(params);
+
+    for i in 0..4 {
+        let mut v = vec![Int64::new(0); 4];
+        v[i] = Int64::new(10000);
+        index.add_vector(&v, i as u64 + 1).unwrap();
+    }
+
+    let query: Vec<Int64> = [10000i64, 0, 0, 0]
+        .iter()
+        .map(|&x| Int64::new(x))
+        .collect();
+    let results = index.top_k_query(&query, 1, None).unwrap();
+    assert_eq!(results.results[0].label, 1);
+}
+
+#[test]
+fn test_int64_large_values_within_f32_precision() {
+    // Values within the f32 exact integer range (2^24 = 16_777_216)
+    let params = BruteForceParams::new(4, Metric::L2);
+    let mut index = BruteForceSingle::<Int64>::new(params);
+
+    let v1: Vec<Int64> = [16_000_000i64, 0, 0, 0]
+        .iter()
+        .map(|&x| Int64::new(x))
+        .collect();
+    let v2: Vec<Int64> = [-16_000_000i64, 0, 0, 0]
+        .iter()
+        .map(|&x| Int64::new(x))
+        .collect();
+
+    index.add_vector(&v1, 1).unwrap();
+    index.add_vector(&v2, 2).unwrap();
+
+    let query: Vec<Int64> = [16_000_000i64, 0, 0, 0]
+        .iter()
+        .map(|&x| Int64::new(x))
+        .collect();
+    let results = index.top_k_query(&query, 1, None).unwrap();
+    assert_eq!(results.results[0].label, 1);
+    assert!(results.results[0].distance < 0.001);
+}
+
+#[test]
+fn test_int64_data_integrity() {
+    // Note: Serialization is only implemented for f32
+    let params = BruteForceParams::new(4, Metric::L2);
+    let mut index = BruteForceSingle::<Int64>::new(params);
+
+    for i in 0..10i64 {
+        let v: Vec<Int64> = (0..4).map(|j| Int64::new(i * 10000 + j * 1000)).collect();
+        index.add_vector(&v, i as u64).unwrap();
+    }
+
+    assert_eq!(index.index_size(), 10);
+
+    // Verify vectors can be retrieved
+    for i in 0..10i64 {
+        assert!(index.contains(i as u64));
+        let v = index.get_vector(i as u64).unwrap();
+        for (j, &val) in v.iter().enumerate() {
+            let expected = i * 10000 + j as i64 * 1000;
+            assert_eq!(val.get(), expected);
+        }
+    }
+}
+
+#[test]
+fn test_int64_multi_value() {
+    let params = BruteForceParams::new(4, Metric::L2);
+    let mut index = BruteForceMulti::<Int64>::new(params);
+
+    let v1: Vec<Int64> = [10000i64, 0, 0, 0].iter().map(|&x| Int64::new(x)).collect();
+    let v2: Vec<Int64> = [9000i64, 1000, 0, 0]
+        .iter()
+        .map(|&x| Int64::new(x))
+        .collect();
+    let v3: Vec<Int64> = [0i64, 10000, 0, 0].iter().map(|&x| Int64::new(x)).collect();
+
+    index.add_vector(&v1, 1).unwrap();
+    index.add_vector(&v2, 1).unwrap(); // Same label
+    index.add_vector(&v3, 2).unwrap();
+
+    assert_eq!(index.label_count(1), 2);
+
+    let query: Vec<Int64> = [10000i64, 0, 0, 0]
+        .iter()
+        .map(|&x| Int64::new(x))
+        .collect();
+    let results = index.top_k_query(&query, 1, None).unwrap();
+    assert_eq!(results.results[0].label, 1);
+}
+
+// =============================================================================
+// Cross-type Consistency Tests
+// =============================================================================
+
+#[test]
+fn test_all_types_produce_consistent_ordering() {
+    // Test that all types produce the same nearest neighbor ordering
+    // when given equivalent data
+
+    let dim = 4;
+
+    // Create orthogonal unit vectors as f32
+    let vectors_f32: Vec<Vec<f32>> = create_orthogonal_vectors_f32(dim);
+
+    // BruteForce with f32
+    let params = BruteForceParams::new(dim, Metric::L2);
+    let mut index_f32 = BruteForceSingle::<f32>::new(params.clone());
+    for (i, v) in vectors_f32.iter().enumerate() {
+        index_f32.add_vector(v, i as u64 + 1).unwrap();
+    }
+
+    // BruteForce with f64
+    let mut index_f64 = BruteForceSingle::<f64>::new(params.clone());
+    for (i, v) in vectors_f32.iter().enumerate() {
+        let v64: Vec<f64> = v.iter().map(|&x| x as f64).collect();
+        index_f64.add_vector(&v64, i as u64 + 1).unwrap();
+    }
+
+    // BruteForce with Float16
+    let mut index_fp16 = BruteForceSingle::<Float16>::new(params.clone());
+    for (i, v) in vectors_f32.iter().enumerate() {
+        let vfp16: Vec<Float16> = v.iter().map(|&x| Float16::from_f32(x)).collect();
+        index_fp16.add_vector(&vfp16, i as u64 + 1).unwrap();
+    }
+
+    // BruteForce with BFloat16
+    let mut index_bf16 = BruteForceSingle::<BFloat16>::new(params.clone());
+    for (i, v) in vectors_f32.iter().enumerate() {
+        let vbf16: Vec<BFloat16> = v.iter().map(|&x| BFloat16::from_f32(x)).collect();
+        index_bf16.add_vector(&vbf16, i as u64 + 1).unwrap();
+    }
+
+    // Query for nearest neighbor to first vector
+    let query_f32: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0];
+    let query_f64: Vec<f64> = vec![1.0, 0.0, 0.0, 0.0];
+    let query_fp16: Vec<Float16> = query_f32.iter().map(|&x| Float16::from_f32(x)).collect();
+    let query_bf16: Vec<BFloat16> = query_f32.iter().map(|&x| BFloat16::from_f32(x)).collect();
+
+    let result_f32 = index_f32.top_k_query(&query_f32, 1, None).unwrap();
+    let result_f64 = index_f64.top_k_query(&query_f64, 1, None).unwrap();
+    let result_fp16 = index_fp16.top_k_query(&query_fp16, 1, None).unwrap();
+    let result_bf16 = index_bf16.top_k_query(&query_bf16, 1, None).unwrap();
+
+    // All should return label 1 as nearest
+    assert_eq!(result_f32.results[0].label, 1);
+    assert_eq!(result_f64.results[0].label, 1);
+    assert_eq!(result_fp16.results[0].label, 1);
+    assert_eq!(result_bf16.results[0].label, 1);
+}
+
+#[test]
+fn test_integer_types_consistent_ordering() {
+    let dim = 4;
+
+    // Create orthogonal vectors with integer values
+    let vectors_i8: Vec<Vec<Int8>> = (0..dim)
+        .map(|i| {
+            let mut v = vec![Int8::new(0); dim];
+            v[i] = Int8::new(100);
+            v
+        })
+        .collect();
+
+    let vectors_u8: Vec<Vec<UInt8>> = (0..dim)
+        .map(|i| {
+            let mut v = vec![UInt8::new(0); dim];
+            v[i] = UInt8::new(200);
+            v
+        })
+        .collect();
+
+    let vectors_i32: Vec<Vec<Int32>> = (0..dim)
+        .map(|i| {
+            let mut v = vec![Int32::new(0); dim];
+            v[i] = Int32::new(1000);
+            v
+        })
+        .collect();
+
+    let vectors_i64: Vec<Vec<Int64>> = (0..dim)
+        .map(|i| {
+            let mut v = vec![Int64::new(0); dim];
+            v[i] = Int64::new(10000);
+            v
+        })
+        .collect();
+
+    let params = BruteForceParams::new(dim, Metric::L2);
+
+    let mut index_i8 = BruteForceSingle::<Int8>::new(params.clone());
+    let mut index_u8 = BruteForceSingle::<UInt8>::new(params.clone());
+    let mut index_i32 = BruteForceSingle::<Int32>::new(params.clone());
+    let mut index_i64 = BruteForceSingle::<Int64>::new(params.clone());
+
+    for (i, (((vi8, vu8), vi32), vi64)) in vectors_i8
+        .iter()
+        .zip(vectors_u8.iter())
+        .zip(vectors_i32.iter())
+        .zip(vectors_i64.iter())
+        .enumerate()
+    {
+        index_i8.add_vector(vi8, i as u64 + 1).unwrap();
+        index_u8.add_vector(vu8, i as u64 + 1).unwrap();
+        index_i32.add_vector(vi32, i as u64 + 1).unwrap();
+        index_i64.add_vector(vi64, i as u64 + 1).unwrap();
+    }
+
+    // Query for first axis
+    let query_i8: Vec<Int8> = vec![Int8::new(100), Int8::new(0), Int8::new(0), Int8::new(0)];
+    let query_u8: Vec<UInt8> = vec![
+        UInt8::new(200),
+        UInt8::new(0),
+        UInt8::new(0),
+        UInt8::new(0),
+    ];
+    let query_i32: Vec<Int32> = vec![
+        Int32::new(1000),
+        Int32::new(0),
+        Int32::new(0),
+        Int32::new(0),
+    ];
+    let query_i64: Vec<Int64> = vec![
+        Int64::new(10000),
+        Int64::new(0),
+        Int64::new(0),
+        Int64::new(0),
+    ];
+
+    let result_i8 = index_i8.top_k_query(&query_i8, 1, None).unwrap();
+    let result_u8 = index_u8.top_k_query(&query_u8, 1, None).unwrap();
+    let result_i32 = index_i32.top_k_query(&query_i32, 1, None).unwrap();
+    let result_i64 = index_i64.top_k_query(&query_i64, 1, None).unwrap();
+
+    // All should return label 1 as nearest
+    assert_eq!(result_i8.results[0].label, 1);
+    assert_eq!(result_u8.results[0].label, 1);
+    assert_eq!(result_i32.results[0].label, 1);
+    assert_eq!(result_i64.results[0].label, 1);
+}
+
+// =============================================================================
+// Memory Size Tests
+// =============================================================================
+
+#[test]
+fn test_data_type_memory_sizes() {
+    // Verify that smaller types use less memory
+    assert_eq!(std::mem::size_of::<Int8>(), 1);
+    assert_eq!(std::mem::size_of::<UInt8>(), 1);
+    assert_eq!(std::mem::size_of::<Float16>(), 2);
+    assert_eq!(std::mem::size_of::<BFloat16>(), 2);
+    assert_eq!(std::mem::size_of::<f32>(), 4);
+    assert_eq!(std::mem::size_of::<Int32>(), 4);
+    assert_eq!(std::mem::size_of::<f64>(), 8);
+    assert_eq!(std::mem::size_of::<Int64>(), 8);
+}
+
+// =============================================================================
+// Delete Operation Tests for All Types
+// =============================================================================
+
+#[test]
+fn test_delete_operations_all_types() {
+    fn test_delete_for_type<T: VectorElement>() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceSingle::<T>::new(params);
+
+        let v1: Vec<T> = [1.0f32, 0.0, 0.0, 0.0]
+            .iter()
+            .map(|&x| T::from_f32(x))
+            .collect();
+        let v2: Vec<T> = [0.0f32, 1.0, 0.0, 0.0]
+            .iter()
+            .map(|&x| T::from_f32(x))
+            .collect();
+
+        index.add_vector(&v1, 1).unwrap();
+        index.add_vector(&v2, 2).unwrap();
+
+        assert_eq!(index.index_size(), 2);
+
+        index.delete_vector(1).unwrap();
+
+        let query: Vec<T> = [1.0f32, 0.0, 0.0, 0.0]
+            .iter()
+            .map(|&x| T::from_f32(x))
+            .collect();
+        let results = index.top_k_query(&query, 2, None).unwrap();
+        assert_eq!(results.results.len(), 1);
+        assert_eq!(results.results[0].label, 2);
+    }
+
+    test_delete_for_type::<f32>();
+    test_delete_for_type::<f64>();
+    test_delete_for_type::<Float16>();
+    test_delete_for_type::<BFloat16>();
+    test_delete_for_type::<Int8>();
+    test_delete_for_type::<UInt8>();
+    test_delete_for_type::<Int32>();
+    test_delete_for_type::<Int64>();
+}
+
+// =============================================================================
+// Scaling Tests
+// =============================================================================
+
+#[test]
+fn test_scaling_all_types() {
+    fn test_scaling_for_type<T: VectorElement>(count: usize) {
+        let dim = 16;
+        let params = BruteForceParams::new(dim, Metric::L2);
+        let mut index = BruteForceSingle::<T>::new(params);
+
+        for i in 0..count {
+            let v: Vec<T> = (0..dim)
+                .map(|j| T::from_f32(((i * dim + j) % 1000) as f32 * 0.01))
+                .collect();
+            index.add_vector(&v, i as u64).unwrap();
+        }
+
+        assert_eq!(index.index_size(), count);
+
+        let query: Vec<T> = (0..dim).map(|j| T::from_f32(j as f32 * 0.01)).collect();
+        let results = index.top_k_query(&query, 10, None).unwrap();
+        assert_eq!(results.results.len(), 10);
+    }
+
+    test_scaling_for_type::<f32>(500);
+    test_scaling_for_type::<f64>(500);
+    test_scaling_for_type::<Float16>(500);
+    test_scaling_for_type::<BFloat16>(500);
+    test_scaling_for_type::<Int8>(500);
+    test_scaling_for_type::<UInt8>(500);
+    test_scaling_for_type::<Int32>(500);
+    test_scaling_for_type::<Int64>(500);
+}
diff --git a/rust/vecsim/src/distance/batch.rs b/rust/vecsim/src/distance/batch.rs
new file mode 100644
index 000000000..4d4df33d2
--- /dev/null
+++ b/rust/vecsim/src/distance/batch.rs
@@ -0,0 +1,312 @@
+//! Batch distance computation for improved SIMD efficiency.
+//!
+//! This module provides batch distance computation functions that compute
+//! distances from a single query to multiple candidate vectors in one call.
+//! This improves cache utilization and allows for better SIMD pipelining.
+
+use crate::distance::DistanceFunction;
+use crate::types::{DistanceType, IdType, VectorElement};
+
+/// Compute distances from a query to multiple vectors.
+///
+/// This function computes distances in batch, which can be more efficient
+/// than computing them one at a time due to better cache behavior and
+/// SIMD pipelining.
+///
+/// Returns (id, distance) pairs in input order; candidates whose vector
+/// data cannot be fetched are skipped.
+#[inline]
+pub fn compute_distances_batch<'a, T, D, F>(
+    query: &[T],
+    candidates: &[IdType],
+    data_getter: F,
+    dist_fn: &dyn DistanceFunction<T, Output = D>,
+    dim: usize,
+) -> Vec<(IdType, D)>
+where
+    T: VectorElement,
+    D: DistanceType,
+    F: Fn(IdType) -> Option<&'a [T]>,
+{
+    let mut results = Vec::with_capacity(candidates.len());
+
+    // Process candidates - prefetching is handled separately
+    for &id in candidates {
+        if let Some(data) = data_getter(id) {
+            let dist = dist_fn.compute(data, query, dim);
+            results.push((id, dist));
+        }
+    }
+
+    results
+}
+
+/// Compute distances from a query to multiple vectors with prefetching.
+///
+/// This version prefetches the next vector while computing the current distance,
+/// which helps hide memory latency.
+#[inline]
+pub fn compute_distances_batch_prefetch<'a, T, D, F>(
+    query: &[T],
+    candidates: &[IdType],
+    data_getter: F,
+    dist_fn: &dyn DistanceFunction<T, Output = D>,
+    dim: usize,
+) -> Vec<(IdType, D)>
+where
+    T: VectorElement,
+    D: DistanceType,
+    F: Fn(IdType) -> Option<&'a [T]>,
+{
+    use crate::utils::prefetch::prefetch_slice;
+
+    let mut results = Vec::with_capacity(candidates.len());
+    let n = candidates.len();
+
+    if n == 0 {
+        return results;
+    }
+
+    // Prefetch first candidate
+    if let Some(first_data) = data_getter(candidates[0]) {
+        prefetch_slice(first_data);
+    }
+
+    for i in 0..n {
+        let id = candidates[i];
+
+        // Prefetch next candidate while computing current
+        if i + 1 < n {
+            if let Some(next_data) = data_getter(candidates[i + 1]) {
+                prefetch_slice(next_data);
+            }
+        }
+
+        if let Some(data) = data_getter(id) {
+            let dist = dist_fn.compute(data, query, dim);
+            results.push((id, dist));
+        }
+    }
+
+    results
+}
+
+/// Compute distances between pairs of vectors (for neighbor selection heuristic).
+///
+/// Given a candidate and a list of selected neighbors, compute the distance
+/// from the candidate to each selected neighbor. This is used in the
+/// diversity-aware neighbor selection heuristic.
+#[inline]
+pub fn compute_pairwise_distances<'a, T, D, F>(
+    candidate_id: IdType,
+    selected_ids: &[IdType],
+    data_getter: F,
+    dist_fn: &dyn DistanceFunction<T, Output = D>,
+    dim: usize,
+) -> Vec<D>
+where
+    T: VectorElement,
+    D: DistanceType,
+    F: Fn(IdType) -> Option<&'a [T]>,
+{
+    let candidate_data = match data_getter(candidate_id) {
+        Some(d) => d,
+        None => return vec![D::infinity(); selected_ids.len()],
+    };
+
+    selected_ids
+        .iter()
+        .map(|&sel_id| {
+            data_getter(sel_id)
+                .map(|sel_data| dist_fn.compute(candidate_data, sel_data, dim))
+                .unwrap_or_else(D::infinity)
+        })
+        .collect()
+}
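+
+// A minimal usage sketch for `compute_distances_batch`, assuming only the
+// `L2Distance` type from this crate (kept as a test so it stays compilable):
+#[cfg(test)]
+#[test]
+fn compute_distances_batch_sketch() {
+    use crate::distance::l2::L2Distance;
+
+    let vectors: Vec<Vec<f32>> = vec![vec![0.0; 4], vec![1.0, 0.0, 0.0, 0.0]];
+    let dist_fn = L2Distance::<f32>::new(4);
+    let query = vec![0.0f32; 4];
+    // The getter abstracts over whatever storage the index uses.
+    let getter = |id: IdType| vectors.get(id as usize).map(|v| v.as_slice());
+
+    let pairs = compute_distances_batch(
+        &query,
+        &[0, 1],
+        getter,
+        &dist_fn as &dyn DistanceFunction<f32, Output = f32>,
+        4,
+    );
+    // One (id, distance) pair per resolvable candidate, in input order.
+    assert_eq!(pairs.len(), 2);
+    assert!((pairs[1].1 - 1.0).abs() < 1e-6);
+}
+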
+/// Check if a candidate is dominated by any selected neighbor.
+///
+/// A candidate is "dominated" (and thus should be pruned) if it's closer
+/// to any selected neighbor than to the target. This is used in the
+/// HNSW heuristic neighbor selection.
+///
+/// Returns `true` if the candidate provides good diversity (i.e., it is
+/// not dominated by any already-selected neighbor).
+#[inline]
+pub fn check_candidate_diversity<'a, T, D, F>(
+    candidate_id: IdType,
+    candidate_dist_to_target: D,
+    selected_ids: &[IdType],
+    data_getter: F,
+    dist_fn: &dyn DistanceFunction<T, Output = D>,
+    dim: usize,
+) -> bool
+where
+    T: VectorElement,
+    D: DistanceType,
+    F: Fn(IdType) -> Option<&'a [T]>,
+{
+    let candidate_data = match data_getter(candidate_id) {
+        Some(d) => d,
+        None => return true, // Keep if we can't check
+    };
+
+    // Check against each selected neighbor
+    for &sel_id in selected_ids {
+        if let Some(sel_data) = data_getter(sel_id) {
+            let dist_to_selected = dist_fn.compute(candidate_data, sel_data, dim);
+            if dist_to_selected < candidate_dist_to_target {
+                return false; // Candidate is dominated
+            }
+        }
+    }
+
+    true // Candidate provides good diversity
+}
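+
+// A sketch of how the helpers above can drive diversity-aware neighbor
+// selection. The greedy loop below is illustrative only; the crate's actual
+// HNSW selection lives elsewhere.
+#[cfg(test)]
+#[test]
+fn diversity_selection_sketch() {
+    use crate::distance::l2::L2Distance;
+
+    // Target at the origin; two near-duplicates and one orthogonal point.
+    let vectors: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![1.1, 0.0], vec![0.0, 1.0]];
+    let dist_fn = L2Distance::<f32>::new(2);
+    let getter = |id: IdType| vectors.get(id as usize).map(|v| v.as_slice());
+    let target = vec![0.0f32, 0.0];
+
+    // Candidates sorted by distance to the target.
+    let mut candidates = compute_distances_batch(
+        &target,
+        &[0, 1, 2],
+        getter,
+        &dist_fn as &dyn DistanceFunction<f32, Output = f32>,
+        2,
+    );
+    candidates.sort_by(|x, y| x.1.partial_cmp(&y.1).unwrap());
+
+    // Greedily keep candidates that are not dominated by anything selected.
+    let mut selected: Vec<IdType> = Vec::new();
+    for (id, dist) in candidates {
+        if check_candidate_diversity(
+            id,
+            dist,
+            &selected,
+            getter,
+            &dist_fn as &dyn DistanceFunction<f32, Output = f32>,
+            2,
+        ) {
+            selected.push(id);
+        }
+    }
+    // The near-duplicate of id 0 is pruned; the orthogonal point survives.
+    assert_eq!(selected, vec![0, 2]);
+}
+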
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::distance::l2::L2Distance;
+    use crate::distance::DistanceFunction;
+
+    fn make_test_vectors() -> Vec<Vec<f32>> {
+        vec![
+            vec![0.0, 0.0, 0.0, 0.0], // id 0
+            vec![1.0, 0.0, 0.0, 0.0], // id 1
+            vec![0.0, 1.0, 0.0, 0.0], // id 2
+            vec![0.0, 0.0, 1.0, 0.0], // id 3
+            vec![2.0, 0.0, 0.0, 0.0], // id 4
+        ]
+    }
+
+    #[test]
+    fn test_compute_distances_batch() {
+        let vectors = make_test_vectors();
+        let dist_fn = L2Distance::<f32>::new(4);
+        let query = vec![0.0f32, 0.0, 0.0, 0.0];
+        let candidates = vec![1, 2, 3, 4];
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            vectors.get(id as usize).map(|v| v.as_slice())
+        };
+
+        let results = compute_distances_batch(
+            &query,
+            &candidates,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, Output = f32>,
+            4,
+        );
+
+        assert_eq!(results.len(), 4);
+        assert_eq!(results[0].0, 1);
+        assert!((results[0].1 - 1.0).abs() < 0.001); // L2^2 to (1,0,0,0)
+        assert_eq!(results[3].0, 4);
+        assert!((results[3].1 - 4.0).abs() < 0.001); // L2^2 to (2,0,0,0)
+    }
+
+    #[test]
+    fn test_compute_distances_batch_prefetch() {
+        let vectors = make_test_vectors();
+        let dist_fn = L2Distance::<f32>::new(4);
+        let query = vec![0.0f32, 0.0, 0.0, 0.0];
+        let candidates = vec![1, 2, 3];
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            vectors.get(id as usize).map(|v| v.as_slice())
+        };
+
+        let results = compute_distances_batch_prefetch(
+            &query,
+            &candidates,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, Output = f32>,
+            4,
+        );
+
+        assert_eq!(results.len(), 3);
+        // All should have distance 1.0 (L2^2 of unit vectors)
+        for (_, dist) in &results {
+            assert!((dist - 1.0).abs() < 0.001);
+        }
+    }
+
+    #[test]
+    fn test_check_candidate_diversity_good() {
+        let vectors = vec![
+            vec![0.0f32, 0.0], // target (id 0)
+            vec![1.0, 0.0],    // selected (id 1)
+            vec![0.0, 1.0],    // candidate (id 2) - diverse direction
+        ];
+        let dist_fn = L2Distance::<f32>::new(2);
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            vectors.get(id as usize).map(|v| v.as_slice())
+        };
+
+        // Candidate 2 has distance 1.0 to target, and distance 2.0 to selected[1]
+        // Since 2.0 > 1.0, candidate provides good diversity
+        let is_good = check_candidate_diversity(
+            2,   // candidate
+            1.0, // distance to target
+            &[1],
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, Output = f32>,
+            2,
+        );
+
+        assert!(is_good);
+    }
+
+    #[test]
+    fn test_check_candidate_diversity_bad() {
+        let vectors = vec![
+            vec![0.0f32, 0.0], // target (id 0)
+            vec![1.0, 0.0],    // selected (id 1)
+            vec![1.1, 0.0],    // candidate (id 2) - very close to selected[1]
+        ];
+        let dist_fn = L2Distance::<f32>::new(2);
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            vectors.get(id as usize).map(|v| v.as_slice())
+        };
+
+        // Candidate 2 has distance ~1.21 to target, but only ~0.01 to selected[1]
+        // Since 0.01 < 1.21, candidate is dominated (not diverse)
+        let is_good = check_candidate_diversity(
+            2,
+            1.21, // distance to target
+            &[1],
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, Output = f32>,
+            2,
+        );
+
+        assert!(!is_good);
+    }
+
+    #[test]
+    fn test_compute_pairwise_distances() {
+        let vectors = make_test_vectors();
+        let dist_fn = L2Distance::<f32>::new(4);
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            vectors.get(id as usize).map(|v| v.as_slice())
+        };
+
+        // Compute distances from candidate 0 to selected [1, 2, 3]
+        let dists = compute_pairwise_distances(
+            0,
+            &[1, 2, 3],
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, Output = f32>,
+            4,
+        );
+
+        assert_eq!(dists.len(), 3);
+        // All distances should be 1.0 (L2^2 from origin to unit vectors)
+        for d in &dists {
+            assert!((d - 1.0).abs() < 0.001);
+        }
+    }
+}
+
diff --git a/rust/vecsim/src/distance/cosine.rs b/rust/vecsim/src/distance/cosine.rs
new file mode 100644
index 000000000..255721200
--- /dev/null
+++ b/rust/vecsim/src/distance/cosine.rs
@@ -0,0 +1,292 @@
+//! Cosine distance implementation.
+//!
+//! Cosine similarity is defined as: dot(a, b) / (||a|| * ||b||)
+//! Cosine distance is: 1 - cosine_similarity
+//!
+//! For efficiency, vectors can be pre-normalized during insertion,
+//! reducing cosine distance computation to inner product distance.
+
+use super::simd::{self, SimdCapability};
+use super::{DistanceFunction, Metric};
+use crate::types::{DistanceType, VectorElement};
+use std::marker::PhantomData;
+
+/// Cosine distance calculator.
+///
+/// Returns 1 - cosine_similarity, so values range from 0 (identical direction)
+/// to 2 (opposite direction).
+pub struct CosineDistance<T: VectorElement> {
+    dim: usize,
+    simd_capability: SimdCapability,
+    _phantom: PhantomData<T>,
+}
+
+impl<T: VectorElement> CosineDistance<T> {
+    /// Create a new cosine distance calculator with automatic SIMD detection.
+    pub fn new(dim: usize) -> Self {
+        Self {
+            dim,
+            simd_capability: simd::detect_simd_capability(),
+            _phantom: PhantomData,
+        }
+    }
+
+    /// Create with SIMD enabled (if available).
+    pub fn with_simd(dim: usize) -> Self {
+        Self::new(dim)
+    }
+
+    /// Create with scalar-only implementation.
+    pub fn scalar(dim: usize) -> Self {
+        Self {
+            dim,
+            simd_capability: SimdCapability::None,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<T: VectorElement> DistanceFunction<T> for CosineDistance<T> {
+    type Output = T::DistanceType;
+
+    #[inline]
+    fn compute(&self, a: &[T], b: &[T], dim: usize) -> Self::Output {
+        debug_assert_eq!(a.len(), dim);
+        debug_assert_eq!(b.len(), dim);
+        debug_assert_eq!(dim, self.dim);
+
+        // For pre-normalized vectors, this reduces to inner product
+        // For raw vectors, we need to compute the full cosine distance
+        match self.simd_capability {
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx512Vnni => simd::avx512bw::cosine_distance_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx512Bw => simd::avx512bw::cosine_distance_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx512 => simd::avx512::cosine_distance_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx2 => simd::avx2::cosine_distance_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx => simd::avx::cosine_distance_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Sse4_1 => simd::sse4::cosine_distance_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Sse => simd::sse::cosine_distance_f32(a, b, dim),
+            #[cfg(target_arch = "aarch64")]
+            SimdCapability::Neon => simd::neon::cosine_distance_f32(a, b, dim),
+            #[allow(unreachable_patterns)]
+            _ => cosine_distance_scalar(a, b, dim),
+        }
+    }
+
+    fn metric(&self) -> Metric {
+        Metric::Cosine
+    }
+
+    fn preprocess(&self, vector: &[T], dim: usize) -> Vec<T> {
+        // For integer types, normalization doesn't work (values round to 0)
+        // so we store vectors unchanged and use full cosine computation
+        if !T::can_normalize() {
+            return vector.to_vec();
+        }
+        // Normalize the vector during preprocessing
+        normalize_vector(vector, dim)
+    }
+
+    fn compute_from_preprocessed(&self, stored: &[T], query: &[T], dim: usize) -> Self::Output {
+        // For integer types that weren't normalized, use full cosine computation
+        if !T::can_normalize() {
+            return self.compute(stored, query, dim);
+        }
+        // When stored vectors are pre-normalized, we need to normalize query too
+        // and then it's just 1 - inner_product
+        let query_normalized = normalize_vector(query, dim);
+        let ip = inner_product_raw(stored, &query_normalized, dim);
+        T::DistanceType::from_f64(1.0 - ip)
+    }
+}
+
+/// Normalize a vector to unit length.
+pub fn normalize_vector<T: VectorElement>(vector: &[T], dim: usize) -> Vec<T> {
+    let norm = compute_norm(vector, dim);
+    if norm < 1e-30 {
+        // Avoid division by zero for zero vectors
+        return vector.to_vec();
+    }
+    let inv_norm = 1.0 / norm;
+    vector
+        .iter()
+        .map(|&x| T::from_f32(x.to_f32() * inv_norm as f32))
+        .collect()
+}
+
+/// Compute the L2 norm of a vector.
+#[inline]
+pub fn compute_norm<T: VectorElement>(vector: &[T], dim: usize) -> f64 {
+    let mut sum = 0.0f64;
+    for v in vector.iter().take(dim) {
+        let v = v.to_f32() as f64;
+        sum += v * v;
+    }
+    sum.sqrt()
+}
+
+/// Raw inner product without conversion to distance.
+#[inline]
+fn inner_product_raw<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> f64 {
+    let mut sum = 0.0f64;
+    for i in 0..dim {
+        sum += a[i].to_f32() as f64 * b[i].to_f32() as f64;
+    }
+    sum
+}
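+
+// A small check of the pre-normalization trick from the module docs: once
+// both sides are unit length, 1 - dot(a, b) matches the full cosine distance.
+#[cfg(test)]
+#[test]
+fn prenormalization_equivalence_sketch() {
+    let a = vec![3.0f32, 4.0, 0.0];
+    let b = vec![1.0f32, 1.0, 1.0];
+    let full = cosine_distance_scalar(&a, &b, 3) as f64;
+    let na = normalize_vector(&a, 3);
+    let nb = normalize_vector(&b, 3);
+    let via_ip = 1.0 - inner_product_raw(&na, &nb, 3);
+    assert!((full - via_ip).abs() < 1e-5);
+}
+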
+/// Scalar implementation of cosine distance.
+#[inline]
+pub fn cosine_distance_scalar<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    let mut dot = 0.0f64;
+    let mut norm_a = 0.0f64;
+    let mut norm_b = 0.0f64;
+
+    for i in 0..dim {
+        let va = a[i].to_f32() as f64;
+        let vb = b[i].to_f32() as f64;
+        dot += va * vb;
+        norm_a += va * va;
+        norm_b += vb * vb;
+    }
+
+    let denom = (norm_a * norm_b).sqrt();
+    if denom < 1e-30 {
+        // Handle zero vectors
+        return T::DistanceType::from_f64(1.0);
+    }
+
+    let cosine_sim = dot / denom;
+    // Clamp to [-1, 1] to handle floating point errors
+    let cosine_sim = cosine_sim.clamp(-1.0, 1.0);
+    T::DistanceType::from_f64(1.0 - cosine_sim)
+}
+
+/// Optimized scalar with loop unrolling for f32.
+#[inline]
+pub fn cosine_distance_scalar_f32(a: &[f32], b: &[f32], dim: usize) -> f32 {
+    let mut dot0 = 0.0f32;
+    let mut dot1 = 0.0f32;
+    let mut norm_a0 = 0.0f32;
+    let mut norm_a1 = 0.0f32;
+    let mut norm_b0 = 0.0f32;
+    let mut norm_b1 = 0.0f32;
+
+    let unroll = dim / 2 * 2;
+    let mut i = 0;
+
+    while i < unroll {
+        let va0 = a[i];
+        let va1 = a[i + 1];
+        let vb0 = b[i];
+        let vb1 = b[i + 1];
+
+        dot0 += va0 * vb0;
+        dot1 += va1 * vb1;
+        norm_a0 += va0 * va0;
+        norm_a1 += va1 * va1;
+        norm_b0 += vb0 * vb0;
+        norm_b1 += vb1 * vb1;
+
+        i += 2;
+    }
+
+    // Handle remaining elements
+    while i < dim {
+        let va = a[i];
+        let vb = b[i];
+        dot0 += va * vb;
+        norm_a0 += va * va;
+        norm_b0 += vb * vb;
+        i += 1;
+    }
+
+    let dot = dot0 + dot1;
+    let norm_a = norm_a0 + norm_a1;
+    let norm_b = norm_b0 + norm_b1;
+
+    let denom = (norm_a * norm_b).sqrt();
+    if denom < 1e-30 {
+        return 1.0;
+    }
+
+    let cosine_sim = (dot / denom).clamp(-1.0, 1.0);
+    1.0 - cosine_sim
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_cosine_identical() {
+        let a = vec![1.0f32, 2.0, 3.0, 4.0];
+        let dist = cosine_distance_scalar(&a, &a, 4);
+        assert!(dist.abs() < 0.001);
+    }
+
+    #[test]
+    fn test_cosine_orthogonal() {
+        let a = vec![1.0f32, 0.0, 0.0];
+        let b = vec![0.0f32, 1.0, 0.0];
+        let dist = cosine_distance_scalar(&a, &b, 3);
+        assert!((dist - 1.0).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_cosine_opposite() {
+        let a = vec![1.0f32, 0.0, 0.0];
+        let b = vec![-1.0f32, 0.0, 0.0];
+        let dist = cosine_distance_scalar(&a, &b, 3);
+        assert!((dist - 2.0).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_normalize_vector() {
+        let a = vec![3.0f32, 4.0, 0.0];
+        let normalized = normalize_vector(&a, 3);
+        // Norm should be 5, so normalized is [0.6, 0.8, 0.0]
+        assert!((normalized[0] - 0.6).abs() < 0.001);
+        assert!((normalized[1] - 0.8).abs() < 0.001);
+        assert!(normalized[2].abs() < 0.001);
+
+        // Verify unit length
+        let norm = compute_norm(&normalized, 3);
+        assert!((norm - 1.0).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_cosine_distance_function() {
+        let dist_fn = CosineDistance::<f32>::scalar(3);
+
+        // Same direction should have distance 0
+        let a = vec![1.0f32, 1.0, 1.0];
+        let b = vec![2.0f32, 2.0, 2.0];
+        let dist = dist_fn.compute(&a, &b, 3);
+        assert!(dist.abs() < 0.001);
+    }
+}
diff --git a/rust/vecsim/src/distance/ip.rs b/rust/vecsim/src/distance/ip.rs
new file mode 100644
index 000000000..4e7e28617
--- /dev/null
+++ b/rust/vecsim/src/distance/ip.rs
@@ -0,0 +1,224 @@
+//! Inner product (dot product) distance implementation.
+//!
+//! Inner product is defined as: sum(a[i] * b[i])
+//! For use as a distance metric (lower is better), the similarity is
+//! converted to a distance as 1 - sum(a[i] * b[i]).
+//!
+//! This metric is particularly useful for normalized vectors where it
+//! corresponds to cosine similarity.
+
+use super::simd::{self, SimdCapability};
+use super::{DistanceFunction, Metric};
+use crate::types::{DistanceType, VectorElement};
+use std::marker::PhantomData;
+
+/// Inner product distance calculator.
+///
+/// Returns 1 - inner_product so that lower values indicate
+/// more similar vectors (consistent with other distance metrics).
+pub struct InnerProductDistance<T: VectorElement> {
+    dim: usize,
+    simd_capability: SimdCapability,
+    _phantom: PhantomData<T>,
+}
+
+impl<T: VectorElement> InnerProductDistance<T> {
+    /// Create a new inner product distance calculator with automatic SIMD detection.
+    pub fn new(dim: usize) -> Self {
+        Self {
+            dim,
+            simd_capability: simd::detect_simd_capability(),
+            _phantom: PhantomData,
+        }
+    }
+
+    /// Create with SIMD enabled (if available).
+    pub fn with_simd(dim: usize) -> Self {
+        Self::new(dim)
+    }
+
+    /// Create with scalar-only implementation.
+    pub fn scalar(dim: usize) -> Self {
+        Self {
+            dim,
+            simd_capability: SimdCapability::None,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<T: VectorElement> DistanceFunction<T> for InnerProductDistance<T> {
+    type Output = T::DistanceType;
+
+    #[inline]
+    fn compute(&self, a: &[T], b: &[T], dim: usize) -> Self::Output {
+        debug_assert_eq!(a.len(), dim);
+        debug_assert_eq!(b.len(), dim);
+        debug_assert_eq!(dim, self.dim);
+
+        // Compute inner product, then convert the similarity to a distance
+        let ip = match self.simd_capability {
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx512Vnni => simd::avx512bw::inner_product_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx512Bw => simd::avx512bw::inner_product_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx512 => simd::avx512::inner_product_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx2 => simd::avx2::inner_product_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx => simd::avx::inner_product_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Sse4_1 => simd::sse4::inner_product_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Sse => simd::sse::inner_product_f32(a, b, dim),
+            #[cfg(target_arch = "aarch64")]
+            SimdCapability::Neon => simd::neon::inner_product_f32(a, b, dim),
+            #[allow(unreachable_patterns)]
+            _ => inner_product_scalar(a, b, dim),
+        };
+
+        // Return 1 - ip to convert similarity to distance
+        // For normalized vectors: 1 - cos(θ) ranges from 0 (identical) to 2 (opposite)
+        T::DistanceType::from_f64(1.0 - ip.to_f64())
+    }
+
+    fn metric(&self) -> Metric {
+        Metric::InnerProduct
+    }
+}
+
+/// Scalar implementation of inner product.
+#[inline]
+pub fn inner_product_scalar<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    let mut sum = 0.0f64;
+    for i in 0..dim {
+        sum += a[i].to_f32() as f64 * b[i].to_f32() as f64;
+    }
+    T::DistanceType::from_f64(sum)
+}
+
+/// Optimized scalar with loop unrolling for f32.
+#[inline]
+pub fn inner_product_scalar_f32(a: &[f32], b: &[f32], dim: usize) -> f32 {
+    let mut sum0 = 0.0f32;
+    let mut sum1 = 0.0f32;
+    let mut sum2 = 0.0f32;
+    let mut sum3 = 0.0f32;
+
+    let unroll = dim / 4 * 4;
+    let mut i = 0;
+
+    while i < unroll {
+        sum0 += a[i] * b[i];
+        sum1 += a[i + 1] * b[i + 1];
+        sum2 += a[i + 2] * b[i + 2];
+        sum3 += a[i + 3] * b[i + 3];
+        i += 4;
+    }
+
+    // Handle remaining elements
+    while i < dim {
+        sum0 += a[i] * b[i];
+        i += 1;
+    }
+
+    sum0 + sum1 + sum2 + sum3
+}
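+
+// A quick numeric sketch of the similarity-to-distance conversion performed
+// in `compute` above (distance = 1 - <a, b> for unit-length inputs):
+#[cfg(test)]
+#[test]
+fn ip_distance_conversion_sketch() {
+    let dist_fn = InnerProductDistance::<f32>::scalar(2);
+    let a = vec![1.0f32, 0.0];
+    let b = vec![0.6f32, 0.8]; // unit length; <a, b> = 0.6
+    let dist = dist_fn.compute(&a, &b, 2);
+    assert!((dist - 0.4).abs() < 1e-6);
+}
+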
+/// Optimized scalar with loop unrolling for f64.
+#[inline]
+pub fn inner_product_scalar_f64(a: &[f64], b: &[f64], dim: usize) -> f64 {
+    let mut sum0 = 0.0f64;
+    let mut sum1 = 0.0f64;
+    let mut sum2 = 0.0f64;
+    let mut sum3 = 0.0f64;
+
+    let unroll = dim / 4 * 4;
+    let mut i = 0;
+
+    while i < unroll {
+        sum0 += a[i] * b[i];
+        sum1 += a[i + 1] * b[i + 1];
+        sum2 += a[i + 2] * b[i + 2];
+        sum3 += a[i + 3] * b[i + 3];
+        i += 4;
+    }
+
+    // Handle remaining elements
+    while i < dim {
+        sum0 += a[i] * b[i];
+        i += 1;
+    }
+
+    sum0 + sum1 + sum2 + sum3
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_inner_product_scalar() {
+        let a = vec![1.0f32, 2.0, 3.0, 4.0];
+        let b = vec![1.0f32, 1.0, 1.0, 1.0];
+
+        let ip = inner_product_scalar(&a, &b, 4);
+        // 1*1 + 2*1 + 3*1 + 4*1 = 10
+        assert!((ip - 10.0).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_inner_product_normalized() {
+        // Two normalized vectors
+        let a = vec![1.0f32, 0.0, 0.0];
+        let b = vec![1.0f32, 0.0, 0.0];
+
+        let ip = inner_product_scalar(&a, &b, 3);
+        // Should be 1.0 for identical normalized vectors
+        assert!((ip - 1.0).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_inner_product_orthogonal() {
+        let a = vec![1.0f32, 0.0, 0.0];
+        let b = vec![0.0f32, 1.0, 0.0];
+
+        let ip = inner_product_scalar(&a, &b, 3);
+        // Should be 0.0 for orthogonal vectors
+        assert!(ip.abs() < 0.001);
+    }
+
+    #[test]
+    fn test_inner_product_distance() {
+        let dist_fn = InnerProductDistance::<f32>::scalar(3);
+
+        // Identical normalized vectors should have distance close to 0
+        let a = vec![1.0f32, 0.0, 0.0];
+        let dist = dist_fn.compute(&a, &a, 3);
+        assert!(dist.abs() < 0.001);
+
+        // Orthogonal vectors should have distance 1
+        let b = vec![0.0f32, 1.0, 0.0];
+        let dist = dist_fn.compute(&a, &b, 3);
+        assert!((dist - 1.0).abs() < 0.001);
+    }
+}
diff --git a/rust/vecsim/src/distance/l2.rs b/rust/vecsim/src/distance/l2.rs
new file mode 100644
index 000000000..b5640dde9
--- /dev/null
+++ b/rust/vecsim/src/distance/l2.rs
@@ -0,0 +1,220 @@
+//! L2 (Euclidean) squared distance implementation.
+//!
+//! L2 distance is defined as: sqrt(sum((a[i] - b[i])^2))
+//! For efficiency, we compute the squared L2 distance (without sqrt)
+//! since the ordering is preserved and sqrt is expensive.
+
+use super::simd::{self, SimdCapability};
+use super::{DistanceFunction, Metric};
+use crate::types::{DistanceType, VectorElement};
+use std::marker::PhantomData;
+
+/// L2 (Euclidean) squared distance calculator.
+pub struct L2Distance<T: VectorElement> {
+    dim: usize,
+    simd_capability: SimdCapability,
+    _phantom: PhantomData<T>,
+}
+
+impl<T: VectorElement> L2Distance<T> {
+    /// Create a new L2 distance calculator with automatic SIMD detection.
+    pub fn new(dim: usize) -> Self {
+        Self {
+            dim,
+            simd_capability: simd::detect_simd_capability(),
+            _phantom: PhantomData,
+        }
+    }
+
+    /// Create with SIMD enabled (if available).
+    pub fn with_simd(dim: usize) -> Self {
+        Self::new(dim)
+    }
+
+    /// Create with scalar-only implementation.
+    pub fn scalar(dim: usize) -> Self {
+        Self {
+            dim,
+            simd_capability: SimdCapability::None,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+impl<T: VectorElement> DistanceFunction<T> for L2Distance<T> {
+    type Output = T::DistanceType;
+
+    #[inline]
+    fn compute(&self, a: &[T], b: &[T], dim: usize) -> Self::Output {
+        debug_assert_eq!(a.len(), dim);
+        debug_assert_eq!(b.len(), dim);
+        debug_assert_eq!(dim, self.dim);
+
+        // Dispatch to appropriate implementation
+        match self.simd_capability {
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx512Vnni => simd::avx512bw::l2_squared_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx512Bw => simd::avx512bw::l2_squared_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx512 => simd::avx512::l2_squared_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx2 => simd::avx2::l2_squared_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Avx => simd::avx::l2_squared_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Sse4_1 => simd::sse4::l2_squared_f32(a, b, dim),
+            #[cfg(target_arch = "x86_64")]
+            SimdCapability::Sse => simd::sse::l2_squared_f32(a, b, dim),
+            #[cfg(target_arch = "aarch64")]
+            SimdCapability::Neon => simd::neon::l2_squared_f32(a, b, dim),
+            #[allow(unreachable_patterns)]
+            _ => l2_squared_scalar(a, b, dim),
+        }
+    }
+
+    fn metric(&self) -> Metric {
+        Metric::L2
+    }
+}
+
+/// Scalar implementation of L2 squared distance.
+#[inline]
+pub fn l2_squared_scalar<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    let mut sum = 0.0f64;
+    for i in 0..dim {
+        let diff = a[i].to_f32() as f64 - b[i].to_f32() as f64;
+        sum += diff * diff;
+    }
+    T::DistanceType::from_f64(sum)
+}
+
+/// Optimized scalar with loop unrolling for f32.
+#[inline]
+pub fn l2_squared_scalar_f32(a: &[f32], b: &[f32], dim: usize) -> f32 {
+    let mut sum0 = 0.0f32;
+    let mut sum1 = 0.0f32;
+    let mut sum2 = 0.0f32;
+    let mut sum3 = 0.0f32;
+
+    let unroll = dim / 4 * 4;
+    let mut i = 0;
+
+    while i < unroll {
+        let d0 = a[i] - b[i];
+        let d1 = a[i + 1] - b[i + 1];
+        let d2 = a[i + 2] - b[i + 2];
+        let d3 = a[i + 3] - b[i + 3];
+
+        sum0 += d0 * d0;
+        sum1 += d1 * d1;
+        sum2 += d2 * d2;
+        sum3 += d3 * d3;
+
+        i += 4;
+    }
+
+    // Handle remaining elements
+    while i < dim {
+        let d = a[i] - b[i];
+        sum0 += d * d;
+        i += 1;
+    }
+
+    sum0 + sum1 + sum2 + sum3
+}
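+
+// A small sketch of why skipping the sqrt is safe for ranking: sqrt is
+// monotonic on non-negative values, so squared-L2 comparisons agree with
+// true Euclidean comparisons.
+#[cfg(test)]
+#[test]
+fn squared_l2_preserves_ordering_sketch() {
+    let q = vec![0.0f32, 0.0];
+    let near = vec![1.0f32, 1.0]; // true distance sqrt(2)
+    let far = vec![3.0f32, 0.0];  // true distance 3
+    let d_near = l2_squared_scalar(&q, &near, 2);
+    let d_far = l2_squared_scalar(&q, &far, 2);
+    assert!(d_near < d_far);
+    assert!(d_near.sqrt() < d_far.sqrt());
+}
+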
+/// Optimized scalar with loop unrolling for f64.
+#[inline]
+pub fn l2_squared_scalar_f64(a: &[f64], b: &[f64], dim: usize) -> f64 {
+    let mut sum0 = 0.0f64;
+    let mut sum1 = 0.0f64;
+    let mut sum2 = 0.0f64;
+    let mut sum3 = 0.0f64;
+
+    let unroll = dim / 4 * 4;
+    let mut i = 0;
+
+    while i < unroll {
+        let d0 = a[i] - b[i];
+        let d1 = a[i + 1] - b[i + 1];
+        let d2 = a[i + 2] - b[i + 2];
+        let d3 = a[i + 3] - b[i + 3];
+
+        sum0 += d0 * d0;
+        sum1 += d1 * d1;
+        sum2 += d2 * d2;
+        sum3 += d3 * d3;
+
+        i += 4;
+    }
+
+    // Handle remaining elements
+    while i < dim {
+        let d = a[i] - b[i];
+        sum0 += d * d;
+        i += 1;
+    }
+
+    sum0 + sum1 + sum2 + sum3
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_l2_scalar() {
+        let a = vec![1.0f32, 2.0, 3.0, 4.0];
+        let b = vec![5.0f32, 6.0, 7.0, 8.0];
+
+        let dist = l2_squared_scalar(&a, &b, 4);
+        // (5-1)^2 + (6-2)^2 + (7-3)^2 + (8-4)^2 = 16 + 16 + 16 + 16 = 64
+        assert!((dist - 64.0).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_l2_scalar_f32() {
+        let a = vec![1.0f32, 2.0, 3.0, 4.0];
+        let b = vec![5.0f32, 6.0, 7.0, 8.0];
+
+        let dist = l2_squared_scalar_f32(&a, &b, 4);
+        assert!((dist - 64.0).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_l2_identical_vectors() {
+        let a = vec![1.0f32, 2.0, 3.0, 4.0];
+        let dist = l2_squared_scalar(&a, &a, 4);
+        assert!(dist.abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_l2_distance_function() {
+        let dist_fn = L2Distance::<f32>::scalar(4);
+        let a = vec![0.0f32, 0.0, 0.0, 0.0];
+        let b = vec![3.0f32, 4.0, 0.0, 0.0];
+
+        let dist = dist_fn.compute(&a, &b, 4);
+        // 3^2 + 4^2 = 9 + 16 = 25
+        assert!((dist - 25.0).abs() < 0.001);
+    }
+}
diff --git a/rust/vecsim/src/distance/mod.rs b/rust/vecsim/src/distance/mod.rs
new file mode 100644
index 000000000..a89ded091
--- /dev/null
+++ b/rust/vecsim/src/distance/mod.rs
@@ -0,0 +1,275 @@
+//! Distance metric implementations for vector similarity.
+//!
+//! This module provides various distance/similarity metrics:
+//! - L2 (Euclidean) distance
+//! - Inner product (dot product) similarity
+//! - Cosine similarity/distance
+//!
+//! Each metric has scalar and SIMD-optimized implementations.
+//!
+//! The `batch` submodule provides batch distance computation for improved
+//! cache efficiency and SIMD pipelining.
+
+pub mod batch;
+pub mod cosine;
+pub mod ip;
+pub mod l2;
+pub mod simd;
+
+use crate::types::{DistanceType, VectorElement};
+
+/// Distance/similarity metric types.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Metric {
+    /// L2 (Euclidean) squared distance.
+    /// Lower values indicate more similar vectors.
+    L2,
+    /// Inner product (dot product).
+    /// For normalized vectors, higher values indicate more similar vectors.
+    /// Note: This is converted to the distance 1 - inner_product internally.
+    InnerProduct,
+    /// Cosine similarity converted to distance.
+    /// Returns 1 - cosine_similarity, so lower values indicate more similar vectors.
+    Cosine,
+}
+
+impl Metric {
+    /// Check if this metric uses lower-is-better ordering.
+    ///
+    /// L2 and Cosine use lower-is-better (distance).
+    /// Inner product uses higher-is-better (similarity), but it is converted
+    /// to a distance internally.
+    pub fn lower_is_better(&self) -> bool {
+        true // All metrics are converted to distances internally
+    }
+
+    /// Get a human-readable name for the metric.
+    pub fn name(&self) -> &'static str {
+        match self {
+            Metric::L2 => "L2",
+            Metric::InnerProduct => "IP",
+            Metric::Cosine => "Cosine",
+        }
+    }
+}
+
+impl std::fmt::Display for Metric {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.name())
+    }
+}
+
+/// Trait for distance computation functions.
+pub trait DistanceFunction<T: VectorElement>: Send + Sync {
+    /// The output distance type.
+    type Output: DistanceType;
+
+    /// Compute the distance between two vectors.
+    ///
+    /// Both vectors must have the same length as specified in `dim`
+    /// (the implementations check this only with `debug_assert!`).
+    fn compute(&self, a: &[T], b: &[T], dim: usize) -> Self::Output;
+
+    /// Get the metric type.
+    fn metric(&self) -> Metric;
+
+    /// Pre-process a vector before storage (e.g., normalize for cosine).
+    /// Returns the processed vector.
+    fn preprocess(&self, vector: &[T], _dim: usize) -> Vec<T> {
+        // Default implementation: no preprocessing
+        vector.to_vec()
+    }
+
+    /// Compute distance from a pre-processed vector to a raw query.
+    /// Default uses regular compute.
+    fn compute_from_preprocessed(&self, stored: &[T], query: &[T], dim: usize) -> Self::Output {
+        self.compute(stored, query, dim)
+    }
+}
+
+/// Normalize a vector to unit length (L2 norm = 1).
+///
+/// This is useful for preparing vectors for cosine similarity search,
+/// or when you need to ensure vectors have unit norm.
+///
+/// Returns `None` if the vector has zero norm (cannot be normalized).
+///
+/// # Example
+/// ```
+/// use vecsim::distance::normalize;
+///
+/// let v = vec![3.0f32, 4.0];
+/// let normalized = normalize(&v).unwrap();
+/// assert!((normalized[0] - 0.6).abs() < 0.001);
+/// assert!((normalized[1] - 0.8).abs() < 0.001);
+/// ```
+pub fn normalize<T: VectorElement>(vector: &[T]) -> Option<Vec<T>> {
+    let norm_sq: f64 = vector
+        .iter()
+        .map(|&x| {
+            let f = x.to_f32() as f64;
+            f * f
+        })
+        .sum();
+    if norm_sq == 0.0 {
+        return None;
+    }
+    let norm = norm_sq.sqrt();
+    Some(
+        vector
+            .iter()
+            .map(|&x| T::from_f32((x.to_f32() as f64 / norm) as f32))
+            .collect(),
+    )
+}
+
+/// Normalize a vector in place to unit length.
+///
+/// Returns `false` if the vector has zero norm (not modified).
+pub fn normalize_in_place<T: VectorElement>(vector: &mut [T]) -> bool {
+    let norm_sq: f64 = vector
+        .iter()
+        .map(|&x| {
+            let f = x.to_f32() as f64;
+            f * f
+        })
+        .sum();
+    if norm_sq == 0.0 {
+        return false;
+    }
+    let norm = norm_sq.sqrt();
+    for x in vector.iter_mut() {
+        *x = T::from_f32((x.to_f32() as f64 / norm) as f32);
+    }
+    true
+}
+
+/// Compute the L2 norm (magnitude) of a vector.
+pub fn l2_norm<T: VectorElement>(vector: &[T]) -> f64 {
+    let norm_sq: f64 = vector
+        .iter()
+        .map(|&x| {
+            let f = x.to_f32() as f64;
+            f * f
+        })
+        .sum();
+    norm_sq.sqrt()
+}
+
+/// Compute the dot product (inner product) of two vectors.
+///
+/// Returns the sum of element-wise products.
+pub fn dot_product<T: VectorElement>(a: &[T], b: &[T]) -> f64 {
+    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
+    a.iter()
+        .zip(b.iter())
+        .map(|(&x, &y)| (x.to_f32() as f64) * (y.to_f32() as f64))
+        .sum()
+}
+
+/// Compute the cosine similarity between two vectors.
+///
+/// Returns a value in [-1, 1] where 1 means identical direction,
+/// 0 means orthogonal, and -1 means opposite direction.
+///
+/// Returns `None` if either vector has zero norm.
+pub fn cosine_similarity<T: VectorElement>(a: &[T], b: &[T]) -> Option<f64> {
+    let dot = dot_product(a, b);
+    let norm_a = l2_norm(a);
+    let norm_b = l2_norm(b);
+
+    if norm_a == 0.0 || norm_b == 0.0 {
+        return None;
+    }
+
+    Some(dot / (norm_a * norm_b))
+}
+
+/// Compute the Euclidean distance between two vectors.
+///
+/// This is the square root of the L2 squared distance.
+pub fn euclidean_distance<T: VectorElement>(a: &[T], b: &[T]) -> f64 {
+    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
+    let sum_sq: f64 = a
+        .iter()
+        .zip(b.iter())
+        .map(|(&x, &y)| {
+            let diff = (x.to_f32() as f64) - (y.to_f32() as f64);
+            diff * diff
+        })
+        .sum();
+    sum_sq.sqrt()
+}
+
+/// Compute the L2 squared distance between two vectors.
+///
+/// This is more efficient than Euclidean distance when you don't need
+/// the actual distance value (e.g., for comparisons).
+pub fn l2_squared<T: VectorElement>(a: &[T], b: &[T]) -> f64 {
+    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
+    a.iter()
+        .zip(b.iter())
+        .map(|(&x, &y)| {
+            let diff = (x.to_f32() as f64) - (y.to_f32() as f64);
+            diff * diff
+        })
+        .sum()
+}
+
+/// Normalize multiple vectors in batch.
+///
+/// Returns a vector of normalized vectors, skipping any that have zero norm.
+pub fn batch_normalize<T: VectorElement>(vectors: &[Vec<T>]) -> Vec<Vec<T>> {
+    vectors.iter().filter_map(|v| normalize(v)).collect()
+}
+
+/// Create a distance function for the given metric and element type.
+pub fn create_distance_function<T: VectorElement>(
+    metric: Metric,
+    dim: usize,
+) -> Box<dyn DistanceFunction<T, Output = T::DistanceType>>
+where
+    T::DistanceType: DistanceType,
+{
+    // Select the best implementation based on available SIMD features
+    let use_simd = simd::is_simd_available();
+
+    match metric {
+        Metric::L2 => {
+            if use_simd {
+                Box::new(l2::L2Distance::<T>::with_simd(dim))
+            } else {
+                Box::new(l2::L2Distance::<T>::scalar(dim))
+            }
+        }
+        Metric::InnerProduct => {
+            if use_simd {
+                Box::new(ip::InnerProductDistance::<T>::with_simd(dim))
+            } else {
+                Box::new(ip::InnerProductDistance::<T>::scalar(dim))
+            }
+        }
+        Metric::Cosine => {
+            if use_simd {
+                Box::new(cosine::CosineDistance::<T>::with_simd(dim))
+            } else {
+                Box::new(cosine::CosineDistance::<T>::scalar(dim))
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_metric_properties() {
+        assert!(Metric::L2.lower_is_better());
+        assert!(Metric::InnerProduct.lower_is_better());
+        assert!(Metric::Cosine.lower_is_better());
+    }
+
+    #[test]
+    fn test_create_distance_functions() {
+        let l2: Box<dyn DistanceFunction<f32, Output = f32>> =
+            create_distance_function(Metric::L2, 128);
+        assert_eq!(l2.metric(), Metric::L2);
+
+        let ip: Box<dyn DistanceFunction<f32, Output = f32>> =
+            create_distance_function(Metric::InnerProduct, 128);
+        assert_eq!(ip.metric(), Metric::InnerProduct);
+
+        let cos: Box<dyn DistanceFunction<f32, Output = f32>> =
+            create_distance_function(Metric::Cosine, 128);
+        assert_eq!(cos.metric(), Metric::Cosine);
+    }
+}
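+
+// A short usage sketch for the factory above (only items from this module):
+#[cfg(test)]
+#[test]
+fn create_distance_function_usage_sketch() {
+    let dist: Box<dyn DistanceFunction<f32, Output = f32>> =
+        create_distance_function(Metric::L2, 3);
+    let a = vec![0.0f32, 0.0, 0.0];
+    let b = vec![3.0f32, 4.0, 0.0];
+    // Squared L2: 3^2 + 4^2 = 25
+    assert!((dist.compute(&a, &b, 3) - 25.0).abs() < 1e-3);
+}
+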
diff --git a/rust/vecsim/src/distance/simd/avx.rs b/rust/vecsim/src/distance/simd/avx.rs
new file mode 100644
index 000000000..ed17a9305
--- /dev/null
+++ b/rust/vecsim/src/distance/simd/avx.rs
@@ -0,0 +1,301 @@
+//! AVX (AVX1) SIMD implementations for distance functions.
+//!
+//! AVX provides 256-bit vectors, processing 8 f32 values at a time.
+//! Unlike AVX2, AVX1 does not include FMA instructions, so we use
+//! separate multiply and add operations.
+//!
+//! Available on Intel Sandy Bridge+ and AMD Bulldozer+ processors.
+
+#![cfg(target_arch = "x86_64")]
+
+use crate::types::{DistanceType, VectorElement};
+
+#[cfg(target_arch = "x86_64")]
+use std::arch::x86_64::*;
+
+/// AVX1 L2 squared distance for f32 vectors.
+///
+/// # Safety
+/// - Pointers must be valid for `dim` elements
+/// - AVX must be available (checked at runtime by caller)
+#[target_feature(enable = "avx")]
+#[inline]
+pub unsafe fn l2_squared_f32_avx(a: *const f32, b: *const f32, dim: usize) -> f32 {
+    let mut sum = _mm256_setzero_ps();
+    let chunks = dim / 8;
+    let remainder = dim % 8;
+
+    // Process 8 elements at a time
+    for i in 0..chunks {
+        let offset = i * 8;
+        let va = _mm256_loadu_ps(a.add(offset));
+        let vb = _mm256_loadu_ps(b.add(offset));
+        let diff = _mm256_sub_ps(va, vb);
+        // AVX1 doesn't have FMA, so use mul + add
+        let sq = _mm256_mul_ps(diff, diff);
+        sum = _mm256_add_ps(sum, sq);
+    }
+
+    // Horizontal sum
+    let mut result = hsum256_ps_avx(sum);
+
+    // Handle remainder
+    let base = chunks * 8;
+    for i in 0..remainder {
+        let diff = *a.add(base + i) - *b.add(base + i);
+        result += diff * diff;
+    }
+
+    result
+}
+
+/// AVX1 inner product for f32 vectors.
+///
+/// # Safety
+/// - Pointers must be valid for `dim` elements
+/// - AVX must be available (checked at runtime by caller)
+#[target_feature(enable = "avx")]
+#[inline]
+pub unsafe fn inner_product_f32_avx(a: *const f32, b: *const f32, dim: usize) -> f32 {
+    let mut sum = _mm256_setzero_ps();
+    let chunks = dim / 8;
+    let remainder = dim % 8;
+
+    // Process 8 elements at a time
+    for i in 0..chunks {
+        let offset = i * 8;
+        let va = _mm256_loadu_ps(a.add(offset));
+        let vb = _mm256_loadu_ps(b.add(offset));
+        // AVX1 doesn't have FMA, so use mul + add
+        let prod = _mm256_mul_ps(va, vb);
+        sum = _mm256_add_ps(sum, prod);
+    }
+
+    // Horizontal sum
+    let mut result = hsum256_ps_avx(sum);
+
+    // Handle remainder
+    let base = chunks * 8;
+    for i in 0..remainder {
+        result += *a.add(base + i) * *b.add(base + i);
+    }
+
+    result
+}
+
+/// AVX1 cosine distance for f32 vectors.
+///
+/// # Safety
+/// - Pointers must be valid for `dim` elements
+/// - AVX must be available (checked at runtime by caller)
+#[target_feature(enable = "avx")]
+#[inline]
+pub unsafe fn cosine_distance_f32_avx(a: *const f32, b: *const f32, dim: usize) -> f32 {
+    let mut dot_sum = _mm256_setzero_ps();
+    let mut norm_a_sum = _mm256_setzero_ps();
+    let mut norm_b_sum = _mm256_setzero_ps();
+
+    let chunks = dim / 8;
+    let remainder = dim % 8;
+
+    // Process 8 elements at a time
+    for i in 0..chunks {
+        let offset = i * 8;
+        let va = _mm256_loadu_ps(a.add(offset));
+        let vb = _mm256_loadu_ps(b.add(offset));
+
+        // AVX1 doesn't have FMA, so use mul + add
+        let dot_prod = _mm256_mul_ps(va, vb);
+        dot_sum = _mm256_add_ps(dot_sum, dot_prod);
+
+        let norm_a_prod = _mm256_mul_ps(va, va);
+        norm_a_sum = _mm256_add_ps(norm_a_sum, norm_a_prod);
+
+        let norm_b_prod = _mm256_mul_ps(vb, vb);
+        norm_b_sum = _mm256_add_ps(norm_b_sum, norm_b_prod);
+    }
+
+    // Horizontal sums
+    let mut dot = hsum256_ps_avx(dot_sum);
+    let mut norm_a = hsum256_ps_avx(norm_a_sum);
+    let mut norm_b = hsum256_ps_avx(norm_b_sum);
+
+    // Handle remainder
+    let base = chunks * 8;
+    for i in 0..remainder {
+        let va = *a.add(base + i);
+        let vb = *b.add(base + i);
+        dot += va * vb;
+        norm_a += va * va;
+        norm_b += vb * vb;
+    }
+
+    let denom = (norm_a * norm_b).sqrt();
+    if denom < 1e-30 {
+        return 1.0;
+    }
+
+    let cosine_sim = (dot / denom).max(-1.0).min(1.0);
+    1.0 - cosine_sim
+}
+
+/// Horizontal sum of 8 f32 values in a 256-bit register (AVX1 version).
+#[target_feature(enable = "avx")]
+#[inline]
+unsafe fn hsum256_ps_avx(v: __m256) -> f32 {
+    // Extract high 128 bits and add to low 128 bits
+    let high = _mm256_extractf128_ps(v, 1);
+    let low = _mm256_castps256_ps128(v);
+    let sum128 = _mm_add_ps(high, low);
+
+    // Horizontal add within 128 bits
+    let shuf = _mm_movehdup_ps(sum128); // [1,1,3,3]
+    let sums = _mm_add_ps(sum128, shuf); // [0+1,1+1,2+3,3+3]
+    let shuf = _mm_movehl_ps(sums, sums); // [2+3,3+3,2+3,3+3]
+    let sums = _mm_add_ss(sums, shuf); // [0+1+2+3,...]
+
+    _mm_cvtss_f32(sums)
+}
+
+/// Safe wrapper for L2 squared distance.
+#[inline]
+pub fn l2_squared_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { l2_squared_f32_avx(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        crate::distance::l2::l2_squared_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for inner product.
+#[inline]
+pub fn inner_product_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { inner_product_f32_avx(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        crate::distance::ip::inner_product_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for cosine distance.
+#[inline]
+pub fn cosine_distance_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { cosine_distance_f32_avx(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        crate::distance::cosine::cosine_distance_scalar(a, b, dim)
+    }
+}
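+
+// The safety contract above, in one place: verify the CPU feature at runtime,
+// then call the `#[target_feature]` function inside `unsafe`. This is a
+// sketch of the pattern the safe wrappers follow.
+#[cfg(test)]
+#[test]
+fn runtime_dispatch_sketch() {
+    let a = vec![1.0f32; 9]; // 9 elements exercises the remainder path
+    let b = vec![2.0f32; 9];
+    let dist = if is_x86_feature_detected!("avx") {
+        // Sound only because we just verified AVX support.
+        unsafe { l2_squared_f32_avx(a.as_ptr(), b.as_ptr(), a.len()) }
+    } else {
+        crate::distance::l2::l2_squared_scalar_f32(&a, &b, a.len())
+    };
+    // Nine squared differences of 1.0 each.
+    assert!((dist - 9.0).abs() < 1e-6);
+}
+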
cosine_distance_f32_avx(a.as_ptr(), b.as_ptr(), a.len()) }; + assert!(result.abs() < 1e-6); + + // Test with orthogonal vectors (cosine distance = 1) + let a = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + let b = vec![0.0f32, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + + let result = unsafe { cosine_distance_f32_avx(a.as_ptr(), b.as_ptr(), a.len()) }; + assert!((result - 1.0).abs() < 1e-6); + } + + #[test] + fn test_l2_safe_wrapper() { + if !is_avx_available() { + return; + } + + let a: Vec = (0..128).map(|i| i as f32).collect(); + let b: Vec = (0..128).map(|i| (i + 1) as f32).collect(); + + let avx_result = l2_squared_f32::(&a, &b, 128); + let scalar_result = crate::distance::l2::l2_squared_scalar(&a, &b, 128); + + assert!((avx_result - scalar_result).abs() < 0.1); + } + + #[test] + fn test_inner_product_safe_wrapper() { + if !is_avx_available() { + return; + } + + let a: Vec = (0..128).map(|i| i as f32 / 100.0).collect(); + let b: Vec = (0..128).map(|i| (128 - i) as f32 / 100.0).collect(); + + let avx_result = inner_product_f32::(&a, &b, 128); + let scalar_result = crate::distance::ip::inner_product_scalar(&a, &b, 128); + + assert!((avx_result - scalar_result).abs() < 0.01); + } +} diff --git a/rust/vecsim/src/distance/simd/avx2.rs b/rust/vecsim/src/distance/simd/avx2.rs new file mode 100644 index 000000000..0e4fc0c91 --- /dev/null +++ b/rust/vecsim/src/distance/simd/avx2.rs @@ -0,0 +1,454 @@ +//! AVX2 SIMD implementations for distance functions. +//! +//! These functions use 256-bit AVX2 instructions for fast distance computation. +//! Only available on x86_64 with AVX2 and FMA support. + +#![cfg(target_arch = "x86_64")] + +use crate::types::{DistanceType, VectorElement}; + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +/// AVX2 L2 squared distance for f32 vectors. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn l2_squared_f32_avx2(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut sum = _mm256_setzero_ps(); + let chunks = dim / 8; + let remainder = dim % 8; + + // Process 8 elements at a time + for i in 0..chunks { + let offset = i * 8; + let va = _mm256_loadu_ps(a.add(offset)); + let vb = _mm256_loadu_ps(b.add(offset)); + let diff = _mm256_sub_ps(va, vb); + sum = _mm256_fmadd_ps(diff, diff, sum); + } + + // Horizontal sum + let mut result = hsum256_ps(sum); + + // Handle remainder + let base = chunks * 8; + for i in 0..remainder { + let diff = *a.add(base + i) - *b.add(base + i); + result += diff * diff; + } + + result +} + +/// AVX2 inner product for f32 vectors. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn inner_product_f32_avx2(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut sum = _mm256_setzero_ps(); + let chunks = dim / 8; + let remainder = dim % 8; + + // Process 8 elements at a time + for i in 0..chunks { + let offset = i * 8; + let va = _mm256_loadu_ps(a.add(offset)); + let vb = _mm256_loadu_ps(b.add(offset)); + sum = _mm256_fmadd_ps(va, vb, sum); + } + + // Horizontal sum + let mut result = hsum256_ps(sum); + + // Handle remainder + let base = chunks * 8; + for i in 0..remainder { + result += *a.add(base + i) * *b.add(base + i); + } + + result +} + +/// AVX2 cosine distance for f32 vectors. 
+#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn cosine_distance_f32_avx2(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut dot_sum = _mm256_setzero_ps(); + let mut norm_a_sum = _mm256_setzero_ps(); + let mut norm_b_sum = _mm256_setzero_ps(); + + let chunks = dim / 8; + let remainder = dim % 8; + + // Process 8 elements at a time + for i in 0..chunks { + let offset = i * 8; + let va = _mm256_loadu_ps(a.add(offset)); + let vb = _mm256_loadu_ps(b.add(offset)); + + dot_sum = _mm256_fmadd_ps(va, vb, dot_sum); + norm_a_sum = _mm256_fmadd_ps(va, va, norm_a_sum); + norm_b_sum = _mm256_fmadd_ps(vb, vb, norm_b_sum); + } + + // Horizontal sums + let mut dot = hsum256_ps(dot_sum); + let mut norm_a = hsum256_ps(norm_a_sum); + let mut norm_b = hsum256_ps(norm_b_sum); + + // Handle remainder + let base = chunks * 8; + for i in 0..remainder { + let va = *a.add(base + i); + let vb = *b.add(base + i); + dot += va * vb; + norm_a += va * va; + norm_b += vb * vb; + } + + let denom = (norm_a * norm_b).sqrt(); + if denom < 1e-30 { + return 1.0; + } + + let cosine_sim = (dot / denom).max(-1.0).min(1.0); + 1.0 - cosine_sim +} + +/// Horizontal sum of 8 f32 values in a 256-bit register. +#[target_feature(enable = "avx2")] +#[inline] +unsafe fn hsum256_ps(v: __m256) -> f32 { + // Extract high 128 bits and add to low 128 bits + let high = _mm256_extractf128_ps(v, 1); + let low = _mm256_castps256_ps128(v); + let sum128 = _mm_add_ps(high, low); + + // Horizontal add within 128 bits + let shuf = _mm_movehdup_ps(sum128); // [1,1,3,3] + let sums = _mm_add_ps(sum128, shuf); // [0+1,1+1,2+3,3+3] + let shuf = _mm_movehl_ps(sums, sums); // [2+3,3+3,2+3,3+3] + let sums = _mm_add_ss(sums, shuf); // [0+1+2+3,...] + + _mm_cvtss_f32(sums) +} + +/// Horizontal sum of 4 f64 values in a 256-bit register. +#[target_feature(enable = "avx2")] +#[inline] +unsafe fn hsum256_pd(v: __m256d) -> f64 { + // Extract high 128 bits and add to low 128 bits + let high = _mm256_extractf128_pd(v, 1); + let low = _mm256_castpd256_pd128(v); + let sum128 = _mm_add_pd(high, low); + + // Horizontal add within 128 bits + let high64 = _mm_unpackhi_pd(sum128, sum128); + let result = _mm_add_sd(sum128, high64); + + _mm_cvtsd_f64(result) +} + +// ============================================================================= +// AVX2 f64 SIMD operations +// ============================================================================= + +/// AVX2 L2 squared distance for f64 vectors. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn l2_squared_f64_avx2(a: *const f64, b: *const f64, dim: usize) -> f64 { + let mut sum0 = _mm256_setzero_pd(); + let mut sum1 = _mm256_setzero_pd(); + let chunks = dim / 8; + let remainder = dim % 8; + + // Process 8 elements at a time (two sets of 4) + for i in 0..chunks { + let offset = i * 8; + + let va0 = _mm256_loadu_pd(a.add(offset)); + let vb0 = _mm256_loadu_pd(b.add(offset)); + let diff0 = _mm256_sub_pd(va0, vb0); + sum0 = _mm256_fmadd_pd(diff0, diff0, sum0); + + let va1 = _mm256_loadu_pd(a.add(offset + 4)); + let vb1 = _mm256_loadu_pd(b.add(offset + 4)); + let diff1 = _mm256_sub_pd(va1, vb1); + sum1 = _mm256_fmadd_pd(diff1, diff1, sum1); + } + + // Combine and reduce + let sum = _mm256_add_pd(sum0, sum1); + let mut result = hsum256_pd(sum); + + // Handle remainder + let base = chunks * 8; + for i in 0..remainder { + let diff = *a.add(base + i) - *b.add(base + i); + result += diff * diff; + } + + result +} + +/// AVX2 inner product for f64 vectors. 
+#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn inner_product_f64_avx2(a: *const f64, b: *const f64, dim: usize) -> f64 { + let mut sum0 = _mm256_setzero_pd(); + let mut sum1 = _mm256_setzero_pd(); + let chunks = dim / 8; + let remainder = dim % 8; + + for i in 0..chunks { + let offset = i * 8; + + let va0 = _mm256_loadu_pd(a.add(offset)); + let vb0 = _mm256_loadu_pd(b.add(offset)); + sum0 = _mm256_fmadd_pd(va0, vb0, sum0); + + let va1 = _mm256_loadu_pd(a.add(offset + 4)); + let vb1 = _mm256_loadu_pd(b.add(offset + 4)); + sum1 = _mm256_fmadd_pd(va1, vb1, sum1); + } + + let sum = _mm256_add_pd(sum0, sum1); + let mut result = hsum256_pd(sum); + + let base = chunks * 8; + for i in 0..remainder { + result += *a.add(base + i) * *b.add(base + i); + } + + result +} + +/// AVX2 cosine distance for f64 vectors. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn cosine_distance_f64_avx2(a: *const f64, b: *const f64, dim: usize) -> f64 { + let mut dot_sum = _mm256_setzero_pd(); + let mut norm_a_sum = _mm256_setzero_pd(); + let mut norm_b_sum = _mm256_setzero_pd(); + + let chunks = dim / 4; + let remainder = dim % 4; + + for i in 0..chunks { + let offset = i * 4; + let va = _mm256_loadu_pd(a.add(offset)); + let vb = _mm256_loadu_pd(b.add(offset)); + + dot_sum = _mm256_fmadd_pd(va, vb, dot_sum); + norm_a_sum = _mm256_fmadd_pd(va, va, norm_a_sum); + norm_b_sum = _mm256_fmadd_pd(vb, vb, norm_b_sum); + } + + let mut dot = hsum256_pd(dot_sum); + let mut norm_a = hsum256_pd(norm_a_sum); + let mut norm_b = hsum256_pd(norm_b_sum); + + let base = chunks * 4; + for i in 0..remainder { + let va = *a.add(base + i); + let vb = *b.add(base + i); + dot += va * vb; + norm_a += va * va; + norm_b += vb * vb; + } + + let denom = (norm_a * norm_b).sqrt(); + if denom < 1e-30 { + return 1.0; + } + + let cosine_sim = (dot / denom).clamp(-1.0, 1.0); + 1.0 - cosine_sim +} + +/// Safe wrapper for f64 L2 squared distance. +#[inline] +pub fn l2_squared_f64(a: &[f64], b: &[f64], dim: usize) -> f64 { + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + unsafe { l2_squared_f64_avx2(a.as_ptr(), b.as_ptr(), dim) } + } else { + crate::distance::l2::l2_squared_scalar_f64(a, b, dim) + } +} + +/// Safe wrapper for f64 inner product. +#[inline] +pub fn inner_product_f64(a: &[f64], b: &[f64], dim: usize) -> f64 { + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + unsafe { inner_product_f64_avx2(a.as_ptr(), b.as_ptr(), dim) } + } else { + crate::distance::ip::inner_product_scalar_f64(a, b, dim) + } +} + +/// Safe wrapper for f64 cosine distance. +#[inline] +pub fn cosine_distance_f64(a: &[f64], b: &[f64], dim: usize) -> f64 { + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + unsafe { cosine_distance_f64_avx2(a.as_ptr(), b.as_ptr(), dim) } + } else { + // Scalar fallback + let mut dot = 0.0f64; + let mut norm_a = 0.0f64; + let mut norm_b = 0.0f64; + for i in 0..dim { + dot += a[i] * b[i]; + norm_a += a[i] * a[i]; + norm_b += b[i] * b[i]; + } + let denom = (norm_a * norm_b).sqrt(); + if denom < 1e-30 { + return 1.0; + } + 1.0 - (dot / denom).clamp(-1.0, 1.0) + } +} + +/// Safe wrapper for L2 squared distance. 
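+///
+/// Detects AVX2+FMA at runtime and falls back to the scalar kernel otherwise,
+/// so callers never need `unsafe`. Illustrative call (assuming `f32`
+/// implements `VectorElement`, as the tests below do):
+///
+/// ```ignore
+/// let d = l2_squared_f32::<f32>(&a, &b, a.len());
+/// ```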
+#[inline]
+pub fn l2_squared_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
+        // Convert to f32 slices if needed
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { l2_squared_f32_avx2(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        // Fallback to scalar
+        crate::distance::l2::l2_squared_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for inner product.
+#[inline]
+pub fn inner_product_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { inner_product_f32_avx2(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        crate::distance::ip::inner_product_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for cosine distance.
+#[inline]
+pub fn cosine_distance_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { cosine_distance_f32_avx2(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        crate::distance::cosine::cosine_distance_scalar(a, b, dim)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn is_avx2_available() -> bool {
+        is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
+    }
+
+    #[test]
+    fn test_avx2_l2_squared() {
+        if !is_avx2_available() {
+            return;
+        }
+
+        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
+        let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
+
+        let avx2_result = l2_squared_f32::<f32>(&a, &b, 128);
+        let scalar_result = crate::distance::l2::l2_squared_scalar(&a, &b, 128);
+
+        assert!((avx2_result - scalar_result).abs() < 0.1);
+    }
+
+    #[test]
+    fn test_avx2_inner_product() {
+        if !is_avx2_available() {
+            return;
+        }
+
+        let a: Vec<f32> = (0..128).map(|i| i as f32 / 100.0).collect();
+        let b: Vec<f32> = (0..128).map(|i| (128 - i) as f32 / 100.0).collect();
+
+        let avx2_result = inner_product_f32::<f32>(&a, &b, 128);
+        let scalar_result = crate::distance::ip::inner_product_scalar(&a, &b, 128);
+
+        assert!((avx2_result - scalar_result).abs() < 0.01);
+    }
+
+    // f64 tests
+    #[test]
+    fn test_avx2_l2_squared_f64() {
+        if !is_avx2_available() {
+            return;
+        }
+
+        let a: Vec<f64> = (0..128).map(|i| i as f64).collect();
+        let b: Vec<f64> = (0..128).map(|i| (i + 1) as f64).collect();
+
+        let result = l2_squared_f64(&a, &b, 128);
+        // Each diff is 1, squared is 1, sum of 128 ones = 128
+        assert!((result - 128.0).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_avx2_inner_product_f64() {
+        if !is_avx2_available() {
+            return;
+        }
+
+        let a: Vec<f64> = (0..128).map(|i| i as f64 / 100.0).collect();
+        let b: Vec<f64> = (0..128).map(|i| (128 - i) as f64 / 100.0).collect();
+
+        let result = inner_product_f64(&a, &b, 128);
+        let expected: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
+
+        assert!((result - expected).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_avx2_cosine_f64() {
+        if !is_avx2_available() {
+            return;
+        }
+
+        // Identical vectors should have distance 0
+        let a: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
+        let result = cosine_distance_f64(&a, &a, 4);
+        assert!(result.abs() < 1e-10);
+
+        // Orthogonal vectors should have distance 1
+        let a: Vec<f64> = vec![1.0, 0.0, 0.0, 0.0];
+        let b: Vec<f64> = vec![0.0, 1.0, 0.0, 0.0];
+        let result = cosine_distance_f64(&a, &b, 4);
+        assert!((result - 1.0).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_avx2_f64_remainder() {
+        if !is_avx2_available() {
+            return;
+        }
+
+        // Test with non-aligned dimension
+        let a: Vec<f64> = (0..131).map(|i| i as f64).collect();
+        let b: Vec<f64> = (0..131).map(|i| (i + 1) as f64).collect();
+
+        let result = l2_squared_f64(&a, &b, 131);
+        assert!((result - 131.0).abs() < 1e-10);
+    }
+}
diff --git a/rust/vecsim/src/distance/simd/avx512.rs b/rust/vecsim/src/distance/simd/avx512.rs
new file mode 100644
index 000000000..d0d0dd8f4
--- /dev/null
+++ b/rust/vecsim/src/distance/simd/avx512.rs
@@ -0,0 +1,179 @@
+//! AVX-512 SIMD implementations for distance functions.
+//!
+//! These functions use 512-bit AVX-512 instructions for maximum throughput.
+//! Only available on x86_64 with AVX-512F and AVX-512BW support.
+
+#![cfg(target_arch = "x86_64")]
+
+use crate::types::{DistanceType, VectorElement};
+
+#[cfg(target_arch = "x86_64")]
+use std::arch::x86_64::*;
+
+/// AVX-512 L2 squared distance for f32 vectors.
+#[target_feature(enable = "avx512f")]
+#[inline]
+pub unsafe fn l2_squared_f32_avx512(a: *const f32, b: *const f32, dim: usize) -> f32 {
+    let mut sum = _mm512_setzero_ps();
+    let chunks = dim / 16;
+    let remainder = dim % 16;
+
+    // Process 16 elements at a time
+    for i in 0..chunks {
+        let offset = i * 16;
+        let va = _mm512_loadu_ps(a.add(offset));
+        let vb = _mm512_loadu_ps(b.add(offset));
+        let diff = _mm512_sub_ps(va, vb);
+        sum = _mm512_fmadd_ps(diff, diff, sum);
+    }
+
+    // Reduce to scalar
+    let mut result = _mm512_reduce_add_ps(sum);
+
+    // Handle remainder
+    let base = chunks * 16;
+    for i in 0..remainder {
+        let diff = *a.add(base + i) - *b.add(base + i);
+        result += diff * diff;
+    }
+
+    result
+}
+
+/// AVX-512 inner product for f32 vectors.
+#[target_feature(enable = "avx512f")]
+#[inline]
+pub unsafe fn inner_product_f32_avx512(a: *const f32, b: *const f32, dim: usize) -> f32 {
+    let mut sum = _mm512_setzero_ps();
+    let chunks = dim / 16;
+    let remainder = dim % 16;
+
+    // Process 16 elements at a time
+    for i in 0..chunks {
+        let offset = i * 16;
+        let va = _mm512_loadu_ps(a.add(offset));
+        let vb = _mm512_loadu_ps(b.add(offset));
+        sum = _mm512_fmadd_ps(va, vb, sum);
+    }
+
+    // Reduce to scalar
+    let mut result = _mm512_reduce_add_ps(sum);
+
+    // Handle remainder
+    let base = chunks * 16;
+    for i in 0..remainder {
+        result += *a.add(base + i) * *b.add(base + i);
+    }
+
+    result
+}
+
+/// AVX-512 cosine distance for f32 vectors.
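+///
+/// Same fused single-pass structure as the AVX2 version, but over 16-lane
+/// registers; `_mm512_reduce_add_ps` performs the horizontal reduction that
+/// AVX2 had to spell out with shuffles in `hsum256_ps`.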
+#[target_feature(enable = "avx512f")]
+#[inline]
+pub unsafe fn cosine_distance_f32_avx512(a: *const f32, b: *const f32, dim: usize) -> f32 {
+    let mut dot_sum = _mm512_setzero_ps();
+    let mut norm_a_sum = _mm512_setzero_ps();
+    let mut norm_b_sum = _mm512_setzero_ps();
+
+    let chunks = dim / 16;
+    let remainder = dim % 16;
+
+    // Process 16 elements at a time
+    for i in 0..chunks {
+        let offset = i * 16;
+        let va = _mm512_loadu_ps(a.add(offset));
+        let vb = _mm512_loadu_ps(b.add(offset));
+
+        dot_sum = _mm512_fmadd_ps(va, vb, dot_sum);
+        norm_a_sum = _mm512_fmadd_ps(va, va, norm_a_sum);
+        norm_b_sum = _mm512_fmadd_ps(vb, vb, norm_b_sum);
+    }
+
+    // Reduce to scalars
+    let mut dot = _mm512_reduce_add_ps(dot_sum);
+    let mut norm_a = _mm512_reduce_add_ps(norm_a_sum);
+    let mut norm_b = _mm512_reduce_add_ps(norm_b_sum);
+
+    // Handle remainder
+    let base = chunks * 16;
+    for i in 0..remainder {
+        let va = *a.add(base + i);
+        let vb = *b.add(base + i);
+        dot += va * vb;
+        norm_a += va * va;
+        norm_b += vb * vb;
+    }
+
+    let denom = (norm_a * norm_b).sqrt();
+    if denom < 1e-30 {
+        return 1.0;
+    }
+
+    let cosine_sim = (dot / denom).max(-1.0).min(1.0);
+    1.0 - cosine_sim
+}
+
+/// Safe wrapper for L2 squared distance.
+#[inline]
+pub fn l2_squared_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx512f") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { l2_squared_f32_avx512(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        // Fallback to AVX2 or scalar
+        super::avx2::l2_squared_f32(a, b, dim)
+    }
+}
+
+/// Safe wrapper for inner product.
+#[inline]
+pub fn inner_product_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx512f") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { inner_product_f32_avx512(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        super::avx2::inner_product_f32(a, b, dim)
+    }
+}
+
+/// Safe wrapper for cosine distance.
+#[inline]
+pub fn cosine_distance_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx512f") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { cosine_distance_f32_avx512(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        super::avx2::cosine_distance_f32(a, b, dim)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_avx512_l2_squared() {
+        if !is_x86_feature_detected!("avx512f") {
+            println!("AVX-512 not available, skipping test");
+            return;
+        }
+
+        let a: Vec<f32> = (0..256).map(|i| i as f32).collect();
+        let b: Vec<f32> = (0..256).map(|i| (i + 1) as f32).collect();
+
+        let avx512_result = l2_squared_f32::<f32>(&a, &b, 256);
+        let scalar_result = crate::distance::l2::l2_squared_scalar(&a, &b, 256);
+
+        assert!((avx512_result - scalar_result).abs() < 0.1);
+    }
+}
diff --git a/rust/vecsim/src/distance/simd/avx512bf16.rs b/rust/vecsim/src/distance/simd/avx512bf16.rs
new file mode 100644
index 000000000..2f712ce25
--- /dev/null
+++ b/rust/vecsim/src/distance/simd/avx512bf16.rs
@@ -0,0 +1,446 @@
+//! AVX-512 BF16 optimized distance functions for BFloat16 vectors.
+//!
+//!
This module provides SIMD-optimized distance computations for bfloat16 +//! floating point vectors using AVX-512. +//! +//! BFloat16 has a simple relationship with f32: the bf16 bits are the upper +//! 16 bits of the f32 representation. This allows very efficient conversion: +//! - bf16 -> f32: shift left by 16 bits +//! - f32 -> bf16: shift right by 16 bits (with rounding) +//! +//! This module uses AVX-512 to: +//! - Load 32 bf16 values at a time (512 bits) +//! - Efficiently convert to f32 using the shift trick +//! - Compute distances in f32 using FMA + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use crate::types::BFloat16; + +/// Check if AVX-512 is available at runtime. +#[cfg(target_arch = "x86_64")] +#[inline] +pub fn has_avx512() -> bool { + is_x86_feature_detected!("avx512f") +} + +/// Compute L2 squared distance between two BFloat16 vectors using AVX-512. +/// +/// # Safety +/// - Requires AVX-512F support +/// - Pointers must be valid for `dim` elements +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +pub unsafe fn l2_squared_bf16_avx512( + a: *const BFloat16, + b: *const BFloat16, + dim: usize, +) -> f32 { + let a = a as *const u16; + let b = b as *const u16; + + let mut sum0 = _mm512_setzero_ps(); + let mut sum1 = _mm512_setzero_ps(); + + let unroll = dim / 32 * 32; + let mut offset = 0; + + // Process 32 bf16 elements per iteration (2x16) + while offset < unroll { + // Load 16 bf16 values at a time (256 bits) + let a_raw0 = _mm256_loadu_si256((a.add(offset)) as *const __m256i); + let a_raw1 = _mm256_loadu_si256((a.add(offset + 16)) as *const __m256i); + let b_raw0 = _mm256_loadu_si256((b.add(offset)) as *const __m256i); + let b_raw1 = _mm256_loadu_si256((b.add(offset + 16)) as *const __m256i); + + // Convert bf16 to f32 by zero-extending to 32-bit and shifting left by 16 + // bf16 bits are the upper 16 bits of f32 + let a_32_0 = _mm512_cvtepu16_epi32(a_raw0); + let a_32_1 = _mm512_cvtepu16_epi32(a_raw1); + let b_32_0 = _mm512_cvtepu16_epi32(b_raw0); + let b_32_1 = _mm512_cvtepu16_epi32(b_raw1); + + // Shift left by 16 to get f32 representation + let a_shifted0 = _mm512_slli_epi32(a_32_0, 16); + let a_shifted1 = _mm512_slli_epi32(a_32_1, 16); + let b_shifted0 = _mm512_slli_epi32(b_32_0, 16); + let b_shifted1 = _mm512_slli_epi32(b_32_1, 16); + + // Reinterpret as f32 + let a_f32_0 = _mm512_castsi512_ps(a_shifted0); + let a_f32_1 = _mm512_castsi512_ps(a_shifted1); + let b_f32_0 = _mm512_castsi512_ps(b_shifted0); + let b_f32_1 = _mm512_castsi512_ps(b_shifted1); + + // Compute differences + let diff0 = _mm512_sub_ps(a_f32_0, b_f32_0); + let diff1 = _mm512_sub_ps(a_f32_1, b_f32_1); + + // Accumulate squared differences using FMA + sum0 = _mm512_fmadd_ps(diff0, diff0, sum0); + sum1 = _mm512_fmadd_ps(diff1, diff1, sum1); + + offset += 32; + } + + // Process remaining 16-element chunks + while offset + 16 <= dim { + let a_raw = _mm256_loadu_si256((a.add(offset)) as *const __m256i); + let b_raw = _mm256_loadu_si256((b.add(offset)) as *const __m256i); + + let a_32 = _mm512_cvtepu16_epi32(a_raw); + let b_32 = _mm512_cvtepu16_epi32(b_raw); + + let a_shifted = _mm512_slli_epi32(a_32, 16); + let b_shifted = _mm512_slli_epi32(b_32, 16); + + let a_f32 = _mm512_castsi512_ps(a_shifted); + let b_f32 = _mm512_castsi512_ps(b_shifted); + + let diff = _mm512_sub_ps(a_f32, b_f32); + sum0 = _mm512_fmadd_ps(diff, diff, sum0); + + offset += 16; + } + + // Reduce sums + let total = _mm512_add_ps(sum0, sum1); + let mut result = _mm512_reduce_add_ps(total); + + // 
Handle remaining elements with scalar code + while offset < dim { + let av = BFloat16::from_bits(*a.add(offset)).to_f32(); + let bv = BFloat16::from_bits(*b.add(offset)).to_f32(); + let diff = av - bv; + result += diff * diff; + offset += 1; + } + + result +} + +/// Compute inner product between two BFloat16 vectors using AVX-512. +/// +/// # Safety +/// - Requires AVX-512F support +/// - Pointers must be valid for `dim` elements +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +pub unsafe fn inner_product_bf16_avx512( + a: *const BFloat16, + b: *const BFloat16, + dim: usize, +) -> f32 { + let a = a as *const u16; + let b = b as *const u16; + + let mut sum0 = _mm512_setzero_ps(); + let mut sum1 = _mm512_setzero_ps(); + + let unroll = dim / 32 * 32; + let mut offset = 0; + + // Process 32 elements per iteration + while offset < unroll { + let a_raw0 = _mm256_loadu_si256((a.add(offset)) as *const __m256i); + let a_raw1 = _mm256_loadu_si256((a.add(offset + 16)) as *const __m256i); + let b_raw0 = _mm256_loadu_si256((b.add(offset)) as *const __m256i); + let b_raw1 = _mm256_loadu_si256((b.add(offset + 16)) as *const __m256i); + + let a_32_0 = _mm512_cvtepu16_epi32(a_raw0); + let a_32_1 = _mm512_cvtepu16_epi32(a_raw1); + let b_32_0 = _mm512_cvtepu16_epi32(b_raw0); + let b_32_1 = _mm512_cvtepu16_epi32(b_raw1); + + let a_shifted0 = _mm512_slli_epi32(a_32_0, 16); + let a_shifted1 = _mm512_slli_epi32(a_32_1, 16); + let b_shifted0 = _mm512_slli_epi32(b_32_0, 16); + let b_shifted1 = _mm512_slli_epi32(b_32_1, 16); + + let a_f32_0 = _mm512_castsi512_ps(a_shifted0); + let a_f32_1 = _mm512_castsi512_ps(a_shifted1); + let b_f32_0 = _mm512_castsi512_ps(b_shifted0); + let b_f32_1 = _mm512_castsi512_ps(b_shifted1); + + // Accumulate products using FMA + sum0 = _mm512_fmadd_ps(a_f32_0, b_f32_0, sum0); + sum1 = _mm512_fmadd_ps(a_f32_1, b_f32_1, sum1); + + offset += 32; + } + + // Process remaining 16-element chunks + while offset + 16 <= dim { + let a_raw = _mm256_loadu_si256((a.add(offset)) as *const __m256i); + let b_raw = _mm256_loadu_si256((b.add(offset)) as *const __m256i); + + let a_32 = _mm512_cvtepu16_epi32(a_raw); + let b_32 = _mm512_cvtepu16_epi32(b_raw); + + let a_shifted = _mm512_slli_epi32(a_32, 16); + let b_shifted = _mm512_slli_epi32(b_32, 16); + + let a_f32 = _mm512_castsi512_ps(a_shifted); + let b_f32 = _mm512_castsi512_ps(b_shifted); + + sum0 = _mm512_fmadd_ps(a_f32, b_f32, sum0); + + offset += 16; + } + + let total = _mm512_add_ps(sum0, sum1); + let mut result = _mm512_reduce_add_ps(total); + + // Handle remaining elements + while offset < dim { + let av = BFloat16::from_bits(*a.add(offset)).to_f32(); + let bv = BFloat16::from_bits(*b.add(offset)).to_f32(); + result += av * bv; + offset += 1; + } + + result +} + +/// Compute cosine distance between two BFloat16 vectors using AVX-512. 
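+///
+/// Uses the same bit trick as the other kernels in this module: a bf16 value
+/// is widened to f32 by shifting its bits left by 16. Scalar equivalent
+/// (illustrative):
+///
+/// ```ignore
+/// let widened = f32::from_bits((bf16_bits as u32) << 16);
+/// ```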
+/// +/// # Safety +/// - Requires AVX-512F support +/// - Pointers must be valid for `dim` elements +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +pub unsafe fn cosine_distance_bf16_avx512( + a: *const BFloat16, + b: *const BFloat16, + dim: usize, +) -> f32 { + let a = a as *const u16; + let b = b as *const u16; + + let mut dot_sum = _mm512_setzero_ps(); + let mut norm_a_sum = _mm512_setzero_ps(); + let mut norm_b_sum = _mm512_setzero_ps(); + + let unroll = dim / 16 * 16; + let mut offset = 0; + + // Process 16 elements per iteration + while offset < unroll { + let a_raw = _mm256_loadu_si256((a.add(offset)) as *const __m256i); + let b_raw = _mm256_loadu_si256((b.add(offset)) as *const __m256i); + + let a_32 = _mm512_cvtepu16_epi32(a_raw); + let b_32 = _mm512_cvtepu16_epi32(b_raw); + + let a_shifted = _mm512_slli_epi32(a_32, 16); + let b_shifted = _mm512_slli_epi32(b_32, 16); + + let a_f32 = _mm512_castsi512_ps(a_shifted); + let b_f32 = _mm512_castsi512_ps(b_shifted); + + dot_sum = _mm512_fmadd_ps(a_f32, b_f32, dot_sum); + norm_a_sum = _mm512_fmadd_ps(a_f32, a_f32, norm_a_sum); + norm_b_sum = _mm512_fmadd_ps(b_f32, b_f32, norm_b_sum); + + offset += 16; + } + + let mut dot = _mm512_reduce_add_ps(dot_sum); + let mut norm_a = _mm512_reduce_add_ps(norm_a_sum); + let mut norm_b = _mm512_reduce_add_ps(norm_b_sum); + + // Handle remaining elements + while offset < dim { + let av = BFloat16::from_bits(*a.add(offset)).to_f32(); + let bv = BFloat16::from_bits(*b.add(offset)).to_f32(); + dot += av * bv; + norm_a += av * av; + norm_b += bv * bv; + offset += 1; + } + + let denom = (norm_a * norm_b).sqrt(); + if denom < 1e-30 { + return 1.0; + } + + let cosine_sim = (dot / denom).clamp(-1.0, 1.0); + 1.0 - cosine_sim +} + +// Safe wrappers with runtime dispatch + +/// Safe L2 squared distance for BFloat16 with automatic dispatch. +pub fn l2_squared_bf16(a: &[BFloat16], b: &[BFloat16], dim: usize) -> f32 { + debug_assert_eq!(a.len(), dim); + debug_assert_eq!(b.len(), dim); + + #[cfg(target_arch = "x86_64")] + { + if has_avx512() { + return unsafe { l2_squared_bf16_avx512(a.as_ptr(), b.as_ptr(), dim) }; + } + } + + // Scalar fallback + l2_squared_bf16_scalar(a, b, dim) +} + +/// Safe inner product for BFloat16 with automatic dispatch. +pub fn inner_product_bf16(a: &[BFloat16], b: &[BFloat16], dim: usize) -> f32 { + debug_assert_eq!(a.len(), dim); + debug_assert_eq!(b.len(), dim); + + #[cfg(target_arch = "x86_64")] + { + if has_avx512() { + return unsafe { inner_product_bf16_avx512(a.as_ptr(), b.as_ptr(), dim) }; + } + } + + // Scalar fallback + inner_product_bf16_scalar(a, b, dim) +} + +/// Safe cosine distance for BFloat16 with automatic dispatch. 
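+///
+/// Picks the AVX-512 kernel when `avx512f` is detected at runtime and the
+/// scalar loop otherwise. Note that the `debug_assert_eq!` checks below only
+/// guard slice lengths in debug builds, so callers must still pass
+/// `dim`-length slices in release builds.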
+pub fn cosine_distance_bf16(a: &[BFloat16], b: &[BFloat16], dim: usize) -> f32 {
+    debug_assert_eq!(a.len(), dim);
+    debug_assert_eq!(b.len(), dim);
+
+    #[cfg(target_arch = "x86_64")]
+    {
+        if has_avx512() {
+            return unsafe { cosine_distance_bf16_avx512(a.as_ptr(), b.as_ptr(), dim) };
+        }
+    }
+
+    // Scalar fallback
+    cosine_distance_bf16_scalar(a, b, dim)
+}
+
+// Scalar implementations
+
+fn l2_squared_bf16_scalar(a: &[BFloat16], b: &[BFloat16], dim: usize) -> f32 {
+    let mut sum = 0.0f32;
+    for i in 0..dim {
+        let diff = a[i].to_f32() - b[i].to_f32();
+        sum += diff * diff;
+    }
+    sum
+}
+
+fn inner_product_bf16_scalar(a: &[BFloat16], b: &[BFloat16], dim: usize) -> f32 {
+    let mut sum = 0.0f32;
+    for i in 0..dim {
+        sum += a[i].to_f32() * b[i].to_f32();
+    }
+    sum
+}
+
+fn cosine_distance_bf16_scalar(a: &[BFloat16], b: &[BFloat16], dim: usize) -> f32 {
+    let mut dot = 0.0f32;
+    let mut norm_a = 0.0f32;
+    let mut norm_b = 0.0f32;
+
+    for i in 0..dim {
+        let av = a[i].to_f32();
+        let bv = b[i].to_f32();
+        dot += av * bv;
+        norm_a += av * av;
+        norm_b += bv * bv;
+    }
+
+    let denom = (norm_a * norm_b).sqrt();
+    if denom < 1e-30 {
+        return 1.0;
+    }
+
+    let cosine_sim = (dot / denom).clamp(-1.0, 1.0);
+    1.0 - cosine_sim
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn create_test_vectors(dim: usize) -> (Vec<BFloat16>, Vec<BFloat16>) {
+        let a: Vec<BFloat16> = (0..dim)
+            .map(|i| BFloat16::from_f32((i as f32) / (dim as f32)))
+            .collect();
+        let b: Vec<BFloat16> = (0..dim)
+            .map(|i| BFloat16::from_f32(((dim - i) as f32) / (dim as f32)))
+            .collect();
+        (a, b)
+    }
+
+    #[test]
+    fn test_bf16_l2_squared() {
+        for &dim in &[16, 32, 64, 100, 128, 256] {
+            let (a, b) = create_test_vectors(dim);
+            let scalar_result = l2_squared_bf16_scalar(&a, &b, dim);
+            let simd_result = l2_squared_bf16(&a, &b, dim);
+
+            // BFloat16 has lower precision, so allow more tolerance
+            assert!(
+                (scalar_result - simd_result).abs() < 0.1,
+                "dim={}: scalar={}, simd={}",
+                dim,
+                scalar_result,
+                simd_result
+            );
+        }
+    }
+
+    #[test]
+    fn test_bf16_inner_product() {
+        for &dim in &[16, 32, 64, 100, 128, 256] {
+            let (a, b) = create_test_vectors(dim);
+            let scalar_result = inner_product_bf16_scalar(&a, &b, dim);
+            let simd_result = inner_product_bf16(&a, &b, dim);
+
+            assert!(
+                (scalar_result - simd_result).abs() < 0.1,
+                "dim={}: scalar={}, simd={}",
+                dim,
+                scalar_result,
+                simd_result
+            );
+        }
+    }
+
+    #[test]
+    fn test_bf16_cosine_distance() {
+        for &dim in &[16, 32, 64, 100, 128, 256] {
+            let (a, b) = create_test_vectors(dim);
+            let scalar_result = cosine_distance_bf16_scalar(&a, &b, dim);
+            let simd_result = cosine_distance_bf16(&a, &b, dim);
+
+            assert!(
+                (scalar_result - simd_result).abs() < 0.1,
+                "dim={}: scalar={}, simd={}",
+                dim,
+                scalar_result,
+                simd_result
+            );
+        }
+    }
+
+    #[test]
+    fn test_bf16_self_distance() {
+        let dim = 128;
+        let (a, _) = create_test_vectors(dim);
+
+        // Self L2 distance should be 0
+        let l2 = l2_squared_bf16(&a, &a, dim);
+        assert!(l2.abs() < 0.01, "Self L2 distance should be ~0, got {}", l2);
+
+        // Self cosine distance should be 0
+        let cosine = cosine_distance_bf16(&a, &a, dim);
+        assert!(
+            cosine.abs() < 0.01,
+            "Self cosine distance should be ~0, got {}",
+            cosine
+        );
+    }
+}
diff --git a/rust/vecsim/src/distance/simd/avx512bw.rs b/rust/vecsim/src/distance/simd/avx512bw.rs
new file mode 100644
index 000000000..e48097aea
--- /dev/null
+++ b/rust/vecsim/src/distance/simd/avx512bw.rs
@@ -0,0 +1,603 @@
+//! AVX-512BW and AVX-512VNNI SIMD implementations for distance functions.
+//!
+//!
This module provides optimized distance computations using: +//! - AVX-512BW: Byte and word operations on 512-bit vectors +//! - AVX-512VNNI: Vector Neural Network Instructions for int8 dot products +//! +//! VNNI provides `VPDPBUSD` which computes: +//! result[i] += sum(a[4*i+j] * b[4*i+j]) for j in 0..4 +//! where a is unsigned bytes and b is signed bytes, accumulating to int32. + +#![cfg(target_arch = "x86_64")] + +use crate::types::{Int8, UInt8, DistanceType, VectorElement}; + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +// ============================================================================ +// AVX-512BW f32 functions (same as avx512.rs but with BW feature) +// ============================================================================ + +/// AVX-512BW L2 squared distance for f32 vectors. +#[target_feature(enable = "avx512f", enable = "avx512bw")] +#[inline] +pub unsafe fn l2_squared_f32_avx512bw(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut sum = _mm512_setzero_ps(); + let chunks = dim / 16; + let remainder = dim % 16; + + for i in 0..chunks { + let offset = i * 16; + let va = _mm512_loadu_ps(a.add(offset)); + let vb = _mm512_loadu_ps(b.add(offset)); + let diff = _mm512_sub_ps(va, vb); + sum = _mm512_fmadd_ps(diff, diff, sum); + } + + let mut result = _mm512_reduce_add_ps(sum); + + let base = chunks * 16; + for i in 0..remainder { + let diff = *a.add(base + i) - *b.add(base + i); + result += diff * diff; + } + + result +} + +/// AVX-512BW inner product for f32 vectors. +#[target_feature(enable = "avx512f", enable = "avx512bw")] +#[inline] +pub unsafe fn inner_product_f32_avx512bw(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut sum = _mm512_setzero_ps(); + let chunks = dim / 16; + let remainder = dim % 16; + + for i in 0..chunks { + let offset = i * 16; + let va = _mm512_loadu_ps(a.add(offset)); + let vb = _mm512_loadu_ps(b.add(offset)); + sum = _mm512_fmadd_ps(va, vb, sum); + } + + let mut result = _mm512_reduce_add_ps(sum); + + let base = chunks * 16; + for i in 0..remainder { + result += *a.add(base + i) * *b.add(base + i); + } + + result +} + +/// AVX-512BW cosine distance for f32 vectors. 
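+///
+/// Algorithmically identical to the `avx512.rs` variant; as the module docs
+/// note, the f32 kernels are repeated here, presumably so this file's entry
+/// points can be compiled with `avx512bw` enabled alongside the int8 paths.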
+#[target_feature(enable = "avx512f", enable = "avx512bw")] +#[inline] +pub unsafe fn cosine_distance_f32_avx512bw(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut dot_sum = _mm512_setzero_ps(); + let mut norm_a_sum = _mm512_setzero_ps(); + let mut norm_b_sum = _mm512_setzero_ps(); + + let chunks = dim / 16; + let remainder = dim % 16; + + for i in 0..chunks { + let offset = i * 16; + let va = _mm512_loadu_ps(a.add(offset)); + let vb = _mm512_loadu_ps(b.add(offset)); + + dot_sum = _mm512_fmadd_ps(va, vb, dot_sum); + norm_a_sum = _mm512_fmadd_ps(va, va, norm_a_sum); + norm_b_sum = _mm512_fmadd_ps(vb, vb, norm_b_sum); + } + + let mut dot = _mm512_reduce_add_ps(dot_sum); + let mut norm_a = _mm512_reduce_add_ps(norm_a_sum); + let mut norm_b = _mm512_reduce_add_ps(norm_b_sum); + + let base = chunks * 16; + for i in 0..remainder { + let va = *a.add(base + i); + let vb = *b.add(base + i); + dot += va * vb; + norm_a += va * va; + norm_b += vb * vb; + } + + let denom = (norm_a * norm_b).sqrt(); + if denom < 1e-30 { + return 1.0; + } + + let cosine_sim = (dot / denom).max(-1.0).min(1.0); + 1.0 - cosine_sim +} + +// ============================================================================ +// AVX-512VNNI Int8/UInt8 functions +// ============================================================================ + +/// AVX-512VNNI inner product for Int8 vectors. +/// +/// Uses VPDPBUSD instruction for efficient int8×int8 dot product. +/// Returns the dot product as i32, which should be converted to f32 for distance. +#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vnni")] +#[inline] +pub unsafe fn inner_product_i8_avx512vnni(a: *const i8, b: *const i8, dim: usize) -> i32 { + // VPDPBUSD expects unsigned × signed, so we need to handle signed × signed + // by adjusting: (a+128) * b - 128 * sum(b) gives same result as signed a * b + // For simplicity, we use a different approach: + // Split into positive/negative parts or use i16 intermediate + + // Alternative approach: Use _mm512_dpbusd_epi32 with bias adjustment + // For signed×signed: result = dpbusd(a+128, b) - 128 * sum(b) + // This is complex, so we use a simpler i16 widening approach with AVX-512BW + + let mut sum = _mm512_setzero_si512(); + let chunks = dim / 64; + let remainder = dim % 64; + + for i in 0..chunks { + let offset = i * 64; + // Load 64 bytes (512 bits) + let va = _mm512_loadu_si512(a.add(offset) as *const __m512i); + let vb = _mm512_loadu_si512(b.add(offset) as *const __m512i); + + // Unpack to i16 and multiply, then accumulate to i32 + // Low 32 bytes + let va_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(va)); + let vb_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(vb)); + let prod_lo = _mm512_madd_epi16(va_lo, vb_lo); + sum = _mm512_add_epi32(sum, prod_lo); + + // High 32 bytes + let va_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(va, 1)); + let vb_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(vb, 1)); + let prod_hi = _mm512_madd_epi16(va_hi, vb_hi); + sum = _mm512_add_epi32(sum, prod_hi); + } + + // Reduce to scalar + let mut result = _mm512_reduce_add_epi32(sum); + + // Handle remainder + let base = chunks * 64; + for i in 0..remainder { + result += (*a.add(base + i) as i32) * (*b.add(base + i) as i32); + } + + result +} + +/// AVX-512VNNI inner product for UInt8 vectors (unsigned × unsigned). +/// +/// Uses VPDPBUSD with careful handling for u8×u8. 
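+///
+/// (As the comments in the body spell out, the implementation actually widens
+/// each u8 lane to i16 and uses `_mm512_madd_epi16` rather than VPDPBUSD
+/// itself, which would need a bias adjustment for unsigned×unsigned inputs.)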
+#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vnni")]
+#[inline]
+pub unsafe fn inner_product_u8_avx512vnni(a: *const u8, b: *const u8, dim: usize) -> u32 {
+    // For u8×u8, we could use VPDPBUSD (u8×i8) by treating one operand as i8
+    // and adjusting: u8×u8 = dpbusd(a, b - 128) + 128 * sum(a).
+    // Simpler: widen to i16 and multiply-accumulate with madd.
+
+    let mut sum = _mm512_setzero_si512();
+    let chunks = dim / 64;
+    let remainder = dim % 64;
+
+    for i in 0..chunks {
+        let offset = i * 64;
+        let va = _mm512_loadu_si512(a.add(offset) as *const __m512i);
+        let vb = _mm512_loadu_si512(b.add(offset) as *const __m512i);
+
+        // Zero-extend to 16 bits; values 0..=255 fit in i16, so
+        // _mm512_madd_epi16 on the widened lanes is exact
+        let va_lo = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(va));
+        let vb_lo = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(vb));
+        let prod_lo = _mm512_madd_epi16(va_lo, vb_lo);
+        sum = _mm512_add_epi32(sum, prod_lo);
+
+        let va_hi = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(va, 1));
+        let vb_hi = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(vb, 1));
+        let prod_hi = _mm512_madd_epi16(va_hi, vb_hi);
+        sum = _mm512_add_epi32(sum, prod_hi);
+    }
+
+    let mut result = _mm512_reduce_add_epi32(sum) as u32;
+
+    let base = chunks * 64;
+    for i in 0..remainder {
+        result += (*a.add(base + i) as u32) * (*b.add(base + i) as u32);
+    }
+
+    result
+}
+
+/// AVX-512VNNI L2 squared distance for Int8 vectors.
+///
+/// Computes sum((a[i] - b[i])^2) by widening each i8 lane to i16 before
+/// subtracting, so differences outside the i8 range (as large as ±255)
+/// are represented exactly.
+#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vnni")]
+#[inline]
+pub unsafe fn l2_squared_i8_avx512vnni(a: *const i8, b: *const i8, dim: usize) -> i32 {
+    let mut sum = _mm512_setzero_si512();
+    let chunks = dim / 64;
+    let remainder = dim % 64;
+
+    for i in 0..chunks {
+        let offset = i * 64;
+        let va = _mm512_loadu_si512(a.add(offset) as *const __m512i);
+        let vb = _mm512_loadu_si512(b.add(offset) as *const __m512i);
+
+        // Widen to i16 first: an i8 subtraction would wrap for differences
+        // outside [-128, 127] and silently corrupt the result
+        let va_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(va));
+        let vb_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(vb));
+        let diff_lo = _mm512_sub_epi16(va_lo, vb_lo);
+        let sq_lo = _mm512_madd_epi16(diff_lo, diff_lo);
+        sum = _mm512_add_epi32(sum, sq_lo);
+
+        let va_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(va, 1));
+        let vb_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(vb, 1));
+        let diff_hi = _mm512_sub_epi16(va_hi, vb_hi);
+        let sq_hi = _mm512_madd_epi16(diff_hi, diff_hi);
+        sum = _mm512_add_epi32(sum, sq_hi);
+    }
+
+    let mut result = _mm512_reduce_add_epi32(sum);
+
+    let base = chunks * 64;
+    for i in 0..remainder {
+        let diff = (*a.add(base + i) as i32) - (*b.add(base + i) as i32);
+        result += diff * diff;
+    }
+
+    result
+}
+
+/// AVX-512VNNI L2 squared distance for UInt8 vectors.
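+///
+/// Unsigned lanes can't be subtracted directly without underflow, so the loop
+/// below takes `max_epu8 - min_epu8` to get |a - b| first; the absolute
+/// difference squares to the same value as the signed difference.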
+#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vnni")]
+#[inline]
+pub unsafe fn l2_squared_u8_avx512vnni(a: *const u8, b: *const u8, dim: usize) -> u32 {
+    let mut sum = _mm512_setzero_si512();
+    let chunks = dim / 64;
+    let remainder = dim % 64;
+
+    for i in 0..chunks {
+        let offset = i * 64;
+        let va = _mm512_loadu_si512(a.add(offset) as *const __m512i);
+        let vb = _mm512_loadu_si512(b.add(offset) as *const __m512i);
+
+        // For u8 difference, we need to handle potential underflow.
+        // Use max(a,b) - min(a,b) = abs(a-b) for unsigned
+        let max_ab = _mm512_max_epu8(va, vb);
+        let min_ab = _mm512_min_epu8(va, vb);
+        let diff = _mm512_sub_epi8(max_ab, min_ab); // unsigned difference
+
+        // Zero-extend to i16 and compute diff^2
+        let diff_lo = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(diff));
+        let sq_lo = _mm512_madd_epi16(diff_lo, diff_lo);
+        sum = _mm512_add_epi32(sum, sq_lo);
+
+        let diff_hi = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(diff, 1));
+        let sq_hi = _mm512_madd_epi16(diff_hi, diff_hi);
+        sum = _mm512_add_epi32(sum, sq_hi);
+    }
+
+    let mut result = _mm512_reduce_add_epi32(sum) as u32;
+
+    let base = chunks * 64;
+    for i in 0..remainder {
+        let diff = (*a.add(base + i) as i32) - (*b.add(base + i) as i32);
+        result += (diff * diff) as u32;
+    }
+
+    result
+}
+
+/// AVX-512VNNI dot product using VPDPBUSD instruction.
+///
+/// This is the most efficient path for u8×i8 multiplication.
+/// Computes: result += sum(a[4i+j] * b[4i+j]) for j in 0..4
+/// where a is u8 (unsigned) and b is i8 (signed).
+#[target_feature(enable = "avx512f", enable = "avx512bw", enable = "avx512vnni")]
+#[inline]
+pub unsafe fn dot_product_u8_i8_avx512vnni(a: *const u8, b: *const i8, dim: usize) -> i32 {
+    let mut sum = _mm512_setzero_si512();
+    let chunks = dim / 64;
+    let remainder = dim % 64;
+
+    for i in 0..chunks {
+        let offset = i * 64;
+        let va = _mm512_loadu_si512(a.add(offset) as *const __m512i);
+        let vb = _mm512_loadu_si512(b.add(offset) as *const __m512i);
+
+        // VPDPBUSD: Multiply unsigned bytes by signed bytes,
+        // sum adjacent 4 products, accumulate to i32
+        sum = _mm512_dpbusd_epi32(sum, va, vb);
+    }
+
+    let mut result = _mm512_reduce_add_epi32(sum);
+
+    // Handle remainder
+    let base = chunks * 64;
+    for i in 0..remainder {
+        result += (*a.add(base + i) as i32) * (*b.add(base + i) as i32);
+    }
+
+    result
+}
+
+// ============================================================================
+// Safe wrappers
+// ============================================================================
+
+/// Safe wrapper for L2 squared distance (f32).
+#[inline]
+pub fn l2_squared_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { l2_squared_f32_avx512bw(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        super::avx512::l2_squared_f32(a, b, dim)
+    }
+}
+
+/// Safe wrapper for inner product (f32).
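+///
+/// The f32 fallback chain is layered: this module defers to `super::avx512`,
+/// which in turn falls back to AVX2 and finally scalar code, so the fastest
+/// detected instruction set wins.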
+#[inline]
+pub fn inner_product_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { inner_product_f32_avx512bw(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        super::avx512::inner_product_f32(a, b, dim)
+    }
+}
+
+/// Safe wrapper for cosine distance (f32).
+#[inline]
+pub fn cosine_distance_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { cosine_distance_f32_avx512bw(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        super::avx512::cosine_distance_f32(a, b, dim)
+    }
+}
+
+/// Safe wrapper for L2 squared distance for Int8 using VNNI.
+#[inline]
+pub fn l2_squared_i8(a: &[Int8], b: &[Int8], dim: usize) -> f32 {
+    if is_x86_feature_detected!("avx512f")
+        && is_x86_feature_detected!("avx512bw")
+        && is_x86_feature_detected!("avx512vnni")
+    {
+        let result =
+            unsafe { l2_squared_i8_avx512vnni(a.as_ptr() as *const i8, b.as_ptr() as *const i8, dim) };
+        result as f32
+    } else {
+        // Fallback to scalar
+        l2_squared_i8_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for L2 squared distance for UInt8 using VNNI.
+#[inline]
+pub fn l2_squared_u8(a: &[UInt8], b: &[UInt8], dim: usize) -> f32 {
+    if is_x86_feature_detected!("avx512f")
+        && is_x86_feature_detected!("avx512bw")
+        && is_x86_feature_detected!("avx512vnni")
+    {
+        let result =
+            unsafe { l2_squared_u8_avx512vnni(a.as_ptr() as *const u8, b.as_ptr() as *const u8, dim) };
+        result as f32
+    } else {
+        l2_squared_u8_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for inner product for Int8 using VNNI.
+#[inline]
+pub fn inner_product_i8(a: &[Int8], b: &[Int8], dim: usize) -> f32 {
+    if is_x86_feature_detected!("avx512f")
+        && is_x86_feature_detected!("avx512bw")
+        && is_x86_feature_detected!("avx512vnni")
+    {
+        let result = unsafe {
+            inner_product_i8_avx512vnni(a.as_ptr() as *const i8, b.as_ptr() as *const i8, dim)
+        };
+        result as f32
+    } else {
+        inner_product_i8_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for inner product for UInt8 using VNNI.
+#[inline]
+pub fn inner_product_u8(a: &[UInt8], b: &[UInt8], dim: usize) -> f32 {
+    if is_x86_feature_detected!("avx512f")
+        && is_x86_feature_detected!("avx512bw")
+        && is_x86_feature_detected!("avx512vnni")
+    {
+        let result = unsafe {
+            inner_product_u8_avx512vnni(a.as_ptr() as *const u8, b.as_ptr() as *const u8, dim)
+        };
+        result as f32
+    } else {
+        inner_product_u8_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for VNNI dot product (u8 × i8).
+#[inline]
+pub fn dot_product_u8_i8(a: &[UInt8], b: &[Int8], dim: usize) -> i32 {
+    if is_x86_feature_detected!("avx512f")
+        && is_x86_feature_detected!("avx512bw")
+        && is_x86_feature_detected!("avx512vnni")
+    {
+        unsafe { dot_product_u8_i8_avx512vnni(a.as_ptr() as *const u8, b.as_ptr() as *const i8, dim) }
+    } else {
+        dot_product_u8_i8_scalar(a, b, dim)
+    }
+}
+
+// ============================================================================
+// Scalar fallbacks
+// ============================================================================
+
+/// Scalar L2 squared for Int8.
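+///
+/// Overflow headroom: each squared difference is at most 255^2 = 65025, so an
+/// i32 accumulator is safe up to roughly 33,000 dimensions, far beyond the
+/// vector sizes exercised here.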
+#[inline]
+fn l2_squared_i8_scalar(a: &[Int8], b: &[Int8], dim: usize) -> f32 {
+    let mut sum: i32 = 0;
+    for i in 0..dim {
+        let diff = a[i].0 as i32 - b[i].0 as i32;
+        sum += diff * diff;
+    }
+    sum as f32
+}
+
+/// Scalar L2 squared for UInt8.
+#[inline]
+fn l2_squared_u8_scalar(a: &[UInt8], b: &[UInt8], dim: usize) -> f32 {
+    let mut sum: u32 = 0;
+    for i in 0..dim {
+        let diff = a[i].0 as i32 - b[i].0 as i32;
+        sum += (diff * diff) as u32;
+    }
+    sum as f32
+}
+
+/// Scalar inner product for Int8.
+#[inline]
+fn inner_product_i8_scalar(a: &[Int8], b: &[Int8], dim: usize) -> f32 {
+    let mut sum: i32 = 0;
+    for i in 0..dim {
+        sum += (a[i].0 as i32) * (b[i].0 as i32);
+    }
+    sum as f32
+}
+
+/// Scalar inner product for UInt8.
+#[inline]
+fn inner_product_u8_scalar(a: &[UInt8], b: &[UInt8], dim: usize) -> f32 {
+    let mut sum: u32 = 0;
+    for i in 0..dim {
+        sum += (a[i].0 as u32) * (b[i].0 as u32);
+    }
+    sum as f32
+}
+
+/// Scalar dot product for u8 × i8.
+#[inline]
+fn dot_product_u8_i8_scalar(a: &[UInt8], b: &[Int8], dim: usize) -> i32 {
+    let mut sum: i32 = 0;
+    for i in 0..dim {
+        sum += (a[i].0 as i32) * (b[i].0 as i32);
+    }
+    sum
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_avx512bw_l2_squared_f32() {
+        if !is_x86_feature_detected!("avx512f") || !is_x86_feature_detected!("avx512bw") {
+            println!("AVX-512BW not available, skipping test");
+            return;
+        }
+
+        let a: Vec<f32> = (0..256).map(|i| i as f32).collect();
+        let b: Vec<f32> = (0..256).map(|i| (i + 1) as f32).collect();
+
+        let result = l2_squared_f32::<f32>(&a, &b, 256);
+        let expected = crate::distance::l2::l2_squared_scalar(&a, &b, 256);
+
+        assert!((result - expected).abs() < 0.1);
+    }
+
+    #[test]
+    fn test_avx512vnni_l2_squared_i8() {
+        let a: Vec<Int8> = (0..256).map(|i| Int8((i % 128) as i8)).collect();
+        let b: Vec<Int8> = (0..256).map(|i| Int8(((i + 1) % 128) as i8)).collect();
+
+        let vnni_result = l2_squared_i8(&a, &b, 256);
+        let scalar_result = l2_squared_i8_scalar(&a, &b, 256);
+
+        assert!(
+            (vnni_result - scalar_result).abs() < 1.0,
+            "VNNI: {}, Scalar: {}",
+            vnni_result,
+            scalar_result
+        );
+    }
+
+    #[test]
+    fn test_avx512vnni_l2_squared_u8() {
+        let a: Vec<UInt8> = (0..256).map(|i| UInt8((i % 256) as u8)).collect();
+        let b: Vec<UInt8> = (0..256).map(|i| UInt8(((i + 1) % 256) as u8)).collect();
+
+        let vnni_result = l2_squared_u8(&a, &b, 256);
+        let scalar_result = l2_squared_u8_scalar(&a, &b, 256);
+
+        assert!(
+            (vnni_result - scalar_result).abs() < 1.0,
+            "VNNI: {}, Scalar: {}",
+            vnni_result,
+            scalar_result
+        );
+    }
+
+    #[test]
+    fn test_avx512vnni_inner_product_i8() {
+        let a: Vec<Int8> = (0..256).map(|i| Int8((i % 64) as i8)).collect();
+        let b: Vec<Int8> = (0..256).map(|i| Int8((i % 64) as i8)).collect();
+
+        let vnni_result = inner_product_i8(&a, &b, 256);
+        let scalar_result = inner_product_i8_scalar(&a, &b, 256);
+
+        assert!(
+            (vnni_result - scalar_result).abs() < 1.0,
+            "VNNI: {}, Scalar: {}",
+            vnni_result,
+            scalar_result
+        );
+    }
+
+    #[test]
+    fn test_avx512vnni_inner_product_u8() {
+        let a: Vec<UInt8> = (0..256).map(|i| UInt8((i % 256) as u8)).collect();
+        let b: Vec<UInt8> = (0..256).map(|i| UInt8((i % 256) as u8)).collect();
+
+        let vnni_result = inner_product_u8(&a, &b, 256);
+        let scalar_result = inner_product_u8_scalar(&a, &b, 256);
+
+        assert!(
+            (vnni_result - scalar_result).abs() < 1.0,
+            "VNNI: {}, Scalar: {}",
+            vnni_result,
+            scalar_result
+        );
+    }
+
+    #[test]
+    fn test_avx512vnni_dot_product_u8_i8() {
+        let a: Vec<UInt8> = (0..256).map(|i| UInt8((i % 256) as u8)).collect();
+        let b: Vec<Int8> = (0..256).map(|i| Int8((i % 128) as
i8)).collect(); + + let vnni_result = dot_product_u8_i8(&a, &b, 256); + let scalar_result = dot_product_u8_i8_scalar(&a, &b, 256); + + assert_eq!( + vnni_result, scalar_result, + "VNNI: {}, Scalar: {}", + vnni_result, scalar_result + ); + } +} diff --git a/rust/vecsim/src/distance/simd/avx512fp16.rs b/rust/vecsim/src/distance/simd/avx512fp16.rs new file mode 100644 index 000000000..cc6563b26 --- /dev/null +++ b/rust/vecsim/src/distance/simd/avx512fp16.rs @@ -0,0 +1,402 @@ +//! AVX-512 FP16 optimized distance functions for Float16 vectors. +//! +//! This module provides SIMD-optimized distance computations for half-precision +//! floating point vectors using AVX-512 with F16C conversions. +//! +//! Since native AVX-512 FP16 arithmetic instructions are not yet stable in Rust, +//! this implementation uses AVX-512 with F16C conversion instructions: +//! - Load 16 fp16 values at a time +//! - Convert to f32 using _mm512_cvtph_ps +//! - Compute distances in f32 +//! +//! This provides significant speedups over scalar conversion-compute loops. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use crate::types::Float16; + +/// Check if AVX-512 with F16C is available at runtime. +#[cfg(target_arch = "x86_64")] +#[inline] +pub fn has_avx512_f16c() -> bool { + is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("f16c") +} + +/// Compute L2 squared distance between two Float16 vectors using AVX-512. +/// +/// # Safety +/// - Requires AVX-512F and F16C support +/// - Pointers must be valid for `dim` elements +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f", enable = "f16c")] +pub unsafe fn l2_squared_fp16_avx512( + a: *const Float16, + b: *const Float16, + dim: usize, +) -> f32 { + let a = a as *const u16; + let b = b as *const u16; + + let mut sum0 = _mm512_setzero_ps(); + let mut sum1 = _mm512_setzero_ps(); + + let unroll = dim / 32 * 32; + let mut offset = 0; + + // Process 32 elements per iteration (2x unroll) + while offset < unroll { + // Load 16 fp16 values (256 bits) for each vector + let a_half0 = _mm256_loadu_si256((a.add(offset)) as *const __m256i); + let a_half1 = _mm256_loadu_si256((a.add(offset + 16)) as *const __m256i); + let b_half0 = _mm256_loadu_si256((b.add(offset)) as *const __m256i); + let b_half1 = _mm256_loadu_si256((b.add(offset + 16)) as *const __m256i); + + // Convert fp16 to f32 (expands 256-bit to 512-bit) + let a_f32_0 = _mm512_cvtph_ps(a_half0); + let a_f32_1 = _mm512_cvtph_ps(a_half1); + let b_f32_0 = _mm512_cvtph_ps(b_half0); + let b_f32_1 = _mm512_cvtph_ps(b_half1); + + // Compute differences + let diff0 = _mm512_sub_ps(a_f32_0, b_f32_0); + let diff1 = _mm512_sub_ps(a_f32_1, b_f32_1); + + // Accumulate squared differences using FMA + sum0 = _mm512_fmadd_ps(diff0, diff0, sum0); + sum1 = _mm512_fmadd_ps(diff1, diff1, sum1); + + offset += 32; + } + + // Process remaining 16-element chunks + while offset + 16 <= dim { + let a_half = _mm256_loadu_si256((a.add(offset)) as *const __m256i); + let b_half = _mm256_loadu_si256((b.add(offset)) as *const __m256i); + + let a_f32 = _mm512_cvtph_ps(a_half); + let b_f32 = _mm512_cvtph_ps(b_half); + + let diff = _mm512_sub_ps(a_f32, b_f32); + sum0 = _mm512_fmadd_ps(diff, diff, sum0); + + offset += 16; + } + + // Reduce sums + let total = _mm512_add_ps(sum0, sum1); + let mut result = _mm512_reduce_add_ps(total); + + // Handle remaining elements with scalar code + while offset < dim { + let av = Float16::from_bits(*a.add(offset)).to_f32(); + let bv = 
Float16::from_bits(*b.add(offset)).to_f32(); + let diff = av - bv; + result += diff * diff; + offset += 1; + } + + result +} + +/// Compute inner product between two Float16 vectors using AVX-512. +/// +/// # Safety +/// - Requires AVX-512F and F16C support +/// - Pointers must be valid for `dim` elements +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f", enable = "f16c")] +pub unsafe fn inner_product_fp16_avx512( + a: *const Float16, + b: *const Float16, + dim: usize, +) -> f32 { + let a = a as *const u16; + let b = b as *const u16; + + let mut sum0 = _mm512_setzero_ps(); + let mut sum1 = _mm512_setzero_ps(); + + let unroll = dim / 32 * 32; + let mut offset = 0; + + // Process 32 elements per iteration + while offset < unroll { + let a_half0 = _mm256_loadu_si256((a.add(offset)) as *const __m256i); + let a_half1 = _mm256_loadu_si256((a.add(offset + 16)) as *const __m256i); + let b_half0 = _mm256_loadu_si256((b.add(offset)) as *const __m256i); + let b_half1 = _mm256_loadu_si256((b.add(offset + 16)) as *const __m256i); + + let a_f32_0 = _mm512_cvtph_ps(a_half0); + let a_f32_1 = _mm512_cvtph_ps(a_half1); + let b_f32_0 = _mm512_cvtph_ps(b_half0); + let b_f32_1 = _mm512_cvtph_ps(b_half1); + + // Accumulate products using FMA + sum0 = _mm512_fmadd_ps(a_f32_0, b_f32_0, sum0); + sum1 = _mm512_fmadd_ps(a_f32_1, b_f32_1, sum1); + + offset += 32; + } + + // Process remaining 16-element chunks + while offset + 16 <= dim { + let a_half = _mm256_loadu_si256((a.add(offset)) as *const __m256i); + let b_half = _mm256_loadu_si256((b.add(offset)) as *const __m256i); + + let a_f32 = _mm512_cvtph_ps(a_half); + let b_f32 = _mm512_cvtph_ps(b_half); + + sum0 = _mm512_fmadd_ps(a_f32, b_f32, sum0); + + offset += 16; + } + + let total = _mm512_add_ps(sum0, sum1); + let mut result = _mm512_reduce_add_ps(total); + + // Handle remaining elements + while offset < dim { + let av = Float16::from_bits(*a.add(offset)).to_f32(); + let bv = Float16::from_bits(*b.add(offset)).to_f32(); + result += av * bv; + offset += 1; + } + + result +} + +/// Compute cosine distance between two Float16 vectors using AVX-512. 
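+///
+/// Converts 16 half-precision lanes per iteration with `_mm512_cvtph_ps` and
+/// accumulates the dot product and both squared norms in one pass, mirroring
+/// the bf16 kernel's structure.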
+/// +/// # Safety +/// - Requires AVX-512F and F16C support +/// - Pointers must be valid for `dim` elements +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f", enable = "f16c")] +pub unsafe fn cosine_distance_fp16_avx512( + a: *const Float16, + b: *const Float16, + dim: usize, +) -> f32 { + let a = a as *const u16; + let b = b as *const u16; + + let mut dot_sum = _mm512_setzero_ps(); + let mut norm_a_sum = _mm512_setzero_ps(); + let mut norm_b_sum = _mm512_setzero_ps(); + + let unroll = dim / 16 * 16; + let mut offset = 0; + + // Process 16 elements per iteration + while offset < unroll { + let a_half = _mm256_loadu_si256((a.add(offset)) as *const __m256i); + let b_half = _mm256_loadu_si256((b.add(offset)) as *const __m256i); + + let a_f32 = _mm512_cvtph_ps(a_half); + let b_f32 = _mm512_cvtph_ps(b_half); + + dot_sum = _mm512_fmadd_ps(a_f32, b_f32, dot_sum); + norm_a_sum = _mm512_fmadd_ps(a_f32, a_f32, norm_a_sum); + norm_b_sum = _mm512_fmadd_ps(b_f32, b_f32, norm_b_sum); + + offset += 16; + } + + let mut dot = _mm512_reduce_add_ps(dot_sum); + let mut norm_a = _mm512_reduce_add_ps(norm_a_sum); + let mut norm_b = _mm512_reduce_add_ps(norm_b_sum); + + // Handle remaining elements + while offset < dim { + let av = Float16::from_bits(*a.add(offset)).to_f32(); + let bv = Float16::from_bits(*b.add(offset)).to_f32(); + dot += av * bv; + norm_a += av * av; + norm_b += bv * bv; + offset += 1; + } + + let denom = (norm_a * norm_b).sqrt(); + if denom < 1e-30 { + return 1.0; + } + + let cosine_sim = (dot / denom).clamp(-1.0, 1.0); + 1.0 - cosine_sim +} + +// Safe wrappers with runtime dispatch + +/// Safe L2 squared distance for Float16 with automatic dispatch. +pub fn l2_squared_fp16(a: &[Float16], b: &[Float16], dim: usize) -> f32 { + debug_assert_eq!(a.len(), dim); + debug_assert_eq!(b.len(), dim); + + #[cfg(target_arch = "x86_64")] + { + if has_avx512_f16c() { + return unsafe { l2_squared_fp16_avx512(a.as_ptr(), b.as_ptr(), dim) }; + } + } + + // Scalar fallback + l2_squared_fp16_scalar(a, b, dim) +} + +/// Safe inner product for Float16 with automatic dispatch. +pub fn inner_product_fp16(a: &[Float16], b: &[Float16], dim: usize) -> f32 { + debug_assert_eq!(a.len(), dim); + debug_assert_eq!(b.len(), dim); + + #[cfg(target_arch = "x86_64")] + { + if has_avx512_f16c() { + return unsafe { inner_product_fp16_avx512(a.as_ptr(), b.as_ptr(), dim) }; + } + } + + // Scalar fallback + inner_product_fp16_scalar(a, b, dim) +} + +/// Safe cosine distance for Float16 with automatic dispatch. 
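+///
+/// Illustrative round trip through the half-precision type (assuming
+/// `raw: &[f32]` holds the source values):
+///
+/// ```ignore
+/// let a: Vec<Float16> = raw.iter().map(|&x| Float16::from_f32(x)).collect();
+/// let d = cosine_distance_fp16(&a, &a, a.len()); // ~0.0 for self-distance
+/// ```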
+pub fn cosine_distance_fp16(a: &[Float16], b: &[Float16], dim: usize) -> f32 {
+    debug_assert_eq!(a.len(), dim);
+    debug_assert_eq!(b.len(), dim);
+
+    #[cfg(target_arch = "x86_64")]
+    {
+        if has_avx512_f16c() {
+            return unsafe { cosine_distance_fp16_avx512(a.as_ptr(), b.as_ptr(), dim) };
+        }
+    }
+
+    // Scalar fallback
+    cosine_distance_fp16_scalar(a, b, dim)
+}
+
+// Scalar implementations
+
+fn l2_squared_fp16_scalar(a: &[Float16], b: &[Float16], dim: usize) -> f32 {
+    let mut sum = 0.0f32;
+    for i in 0..dim {
+        let diff = a[i].to_f32() - b[i].to_f32();
+        sum += diff * diff;
+    }
+    sum
+}
+
+fn inner_product_fp16_scalar(a: &[Float16], b: &[Float16], dim: usize) -> f32 {
+    let mut sum = 0.0f32;
+    for i in 0..dim {
+        sum += a[i].to_f32() * b[i].to_f32();
+    }
+    sum
+}
+
+fn cosine_distance_fp16_scalar(a: &[Float16], b: &[Float16], dim: usize) -> f32 {
+    let mut dot = 0.0f32;
+    let mut norm_a = 0.0f32;
+    let mut norm_b = 0.0f32;
+
+    for i in 0..dim {
+        let av = a[i].to_f32();
+        let bv = b[i].to_f32();
+        dot += av * bv;
+        norm_a += av * av;
+        norm_b += bv * bv;
+    }
+
+    let denom = (norm_a * norm_b).sqrt();
+    if denom < 1e-30 {
+        return 1.0;
+    }
+
+    let cosine_sim = (dot / denom).clamp(-1.0, 1.0);
+    1.0 - cosine_sim
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn create_test_vectors(dim: usize) -> (Vec<Float16>, Vec<Float16>) {
+        let a: Vec<Float16> = (0..dim)
+            .map(|i| Float16::from_f32((i as f32) / (dim as f32)))
+            .collect();
+        let b: Vec<Float16> = (0..dim)
+            .map(|i| Float16::from_f32(((dim - i) as f32) / (dim as f32)))
+            .collect();
+        (a, b)
+    }
+
+    #[test]
+    fn test_fp16_l2_squared() {
+        for &dim in &[16, 32, 64, 100, 128, 256] {
+            let (a, b) = create_test_vectors(dim);
+            let scalar_result = l2_squared_fp16_scalar(&a, &b, dim);
+            let simd_result = l2_squared_fp16(&a, &b, dim);
+
+            assert!(
+                (scalar_result - simd_result).abs() < 0.01,
+                "dim={}: scalar={}, simd={}",
+                dim,
+                scalar_result,
+                simd_result
+            );
+        }
+    }
+
+    #[test]
+    fn test_fp16_inner_product() {
+        for &dim in &[16, 32, 64, 100, 128, 256] {
+            let (a, b) = create_test_vectors(dim);
+            let scalar_result = inner_product_fp16_scalar(&a, &b, dim);
+            let simd_result = inner_product_fp16(&a, &b, dim);
+
+            assert!(
+                (scalar_result - simd_result).abs() < 0.01,
+                "dim={}: scalar={}, simd={}",
+                dim,
+                scalar_result,
+                simd_result
+            );
+        }
+    }
+
+    #[test]
+    fn test_fp16_cosine_distance() {
+        for &dim in &[16, 32, 64, 100, 128, 256] {
+            let (a, b) = create_test_vectors(dim);
+            let scalar_result = cosine_distance_fp16_scalar(&a, &b, dim);
+            let simd_result = cosine_distance_fp16(&a, &b, dim);
+
+            assert!(
+                (scalar_result - simd_result).abs() < 0.01,
+                "dim={}: scalar={}, simd={}",
+                dim,
+                scalar_result,
+                simd_result
+            );
+        }
+    }
+
+    #[test]
+    fn test_fp16_self_distance() {
+        let dim = 128;
+        let (a, _) = create_test_vectors(dim);
+
+        // Self L2 distance should be 0
+        let l2 = l2_squared_fp16(&a, &a, dim);
+        assert!(l2.abs() < 0.001, "Self L2 distance should be ~0, got {}", l2);
+
+        // Self cosine distance should be 0
+        let cosine = cosine_distance_fp16(&a, &a, dim);
+        assert!(
+            cosine.abs() < 0.001,
+            "Self cosine distance should be ~0, got {}",
+            cosine
+        );
+    }
+}
diff --git a/rust/vecsim/src/distance/simd/cross_consistency_tests.rs b/rust/vecsim/src/distance/simd/cross_consistency_tests.rs
new file mode 100644
index 000000000..10b3359bd
--- /dev/null
+++ b/rust/vecsim/src/distance/simd/cross_consistency_tests.rs
@@ -0,0 +1,955 @@
+//! SIMD Cross-Consistency Tests
+//!
+//! These tests verify that all SIMD implementations produce results consistent
+//! with the scalar implementation. This ensures correctness across different
+//! hardware and SIMD instruction sets.
+//!
+//! Tests cover:
+//! - All distance metrics (L2, inner product, cosine)
+//! - Various vector dimensions (aligned and non-aligned)
+//! - Edge cases (identical vectors, orthogonal vectors, zero vectors)
+//! - Random data for statistical validation
+
+#![cfg(test)]
+#![allow(unused_imports)]
+
+use crate::distance::l2::{l2_squared_scalar_f32, l2_squared_scalar_f64};
+use crate::distance::ip::inner_product_scalar_f32;
+use crate::distance::cosine::cosine_distance_scalar_f32;
+
+/// Tolerance for f32 comparisons (SIMD operations may have different rounding)
+const F32_TOLERANCE: f32 = 1e-4;
+/// Tolerance for f64 comparisons
+#[cfg(target_arch = "x86_64")]
+const F64_TOLERANCE: f64 = 1e-10;
+/// Relative tolerance for larger values
+const RELATIVE_TOLERANCE: f64 = 1e-5;
+
+/// Check if two f32 values are approximately equal
+fn approx_eq_f32(a: f32, b: f32, abs_tol: f32, rel_tol: f64) -> bool {
+    let diff = (a - b).abs();
+    let max_val = a.abs().max(b.abs());
+    diff <= abs_tol || diff <= (max_val as f64 * rel_tol) as f32
+}
+
+/// Check if two f64 values are approximately equal
+#[cfg(target_arch = "x86_64")]
+fn approx_eq_f64(a: f64, b: f64, abs_tol: f64, rel_tol: f64) -> bool {
+    let diff = (a - b).abs();
+    let max_val = a.abs().max(b.abs());
+    diff <= abs_tol || diff <= max_val * rel_tol
+}
+
+/// Generate test vectors of given dimension with a pattern
+fn generate_test_vectors_f32(dim: usize, seed: u32) -> (Vec<f32>, Vec<f32>) {
+    let mut a = Vec::with_capacity(dim);
+    let mut b = Vec::with_capacity(dim);
+
+    for i in 0..dim {
+        // Use a simple deterministic pattern based on index and seed
+        let val_a = ((i as f32 + seed as f32) * 0.1).sin() * 10.0;
+        let val_b = ((i as f32 + seed as f32 + 1.0) * 0.15).cos() * 10.0;
+        a.push(val_a);
+        b.push(val_b);
+    }
+
+    (a, b)
+}
+
+/// Generate test vectors of given dimension with a pattern (f64)
+#[cfg(target_arch = "x86_64")]
+fn generate_test_vectors_f64(dim: usize, seed: u32) -> (Vec<f64>, Vec<f64>) {
+    let mut a = Vec::with_capacity(dim);
+    let mut b = Vec::with_capacity(dim);
+
+    for i in 0..dim {
+        let val_a = ((i as f64 + seed as f64) * 0.1).sin() * 10.0;
+        let val_b = ((i as f64 + seed as f64 + 1.0) * 0.15).cos() * 10.0;
+        a.push(val_a);
+        b.push(val_b);
+    }
+
+    (a, b)
+}
+
+/// Dimensions to test - includes aligned and non-aligned sizes
+const TEST_DIMENSIONS: &[usize] = &[
+    1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33,
+    63, 64, 65, 127, 128, 129, 255, 256, 257, 512, 768, 1024,
+];
+
+// =============================================================================
+// x86_64 SIMD Cross-Consistency Tests
+// =============================================================================
+
+#[cfg(target_arch = "x86_64")]
+mod x86_64_tests {
+    use super::*;
+
+    // -------------------------------------------------------------------------
+    // SSE Tests
+    // -------------------------------------------------------------------------
+
+    #[test]
+    fn test_sse_l2_cross_consistency() {
+        if !is_x86_feature_detected!("sse") {
+            return;
+        }
+
+        for &dim in TEST_DIMENSIONS {
+            for seed in 0..3 {
+                let (a, b) = generate_test_vectors_f32(dim, seed);
+
+                let scalar_result = l2_squared_scalar_f32(&a, &b, dim);
+                let simd_result = unsafe {
+                    crate::distance::simd::sse::l2_squared_f32_sse(
+                        a.as_ptr(), b.as_ptr(), dim
+                    )
+                };
+
+                assert!(
+                    approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE),
+                    "SSE L2 mismatch at dim={}, 
seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_sse_inner_product_cross_consistency() { + if !is_x86_feature_detected!("sse") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = inner_product_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::sse::inner_product_f32_sse( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "SSE inner product mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_sse_cosine_cross_consistency() { + if !is_x86_feature_detected!("sse") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = cosine_distance_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::sse::cosine_distance_f32_sse( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "SSE cosine mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + // ------------------------------------------------------------------------- + // SSE4.1 Tests + // ------------------------------------------------------------------------- + + #[test] + fn test_sse4_l2_cross_consistency() { + if !is_x86_feature_detected!("sse4.1") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = l2_squared_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::sse4::l2_squared_f32_sse4( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "SSE4 L2 mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_sse4_inner_product_cross_consistency() { + if !is_x86_feature_detected!("sse4.1") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = inner_product_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::sse4::inner_product_f32_sse4( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "SSE4 inner product mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_sse4_cosine_cross_consistency() { + if !is_x86_feature_detected!("sse4.1") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = cosine_distance_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::sse4::cosine_distance_f32_sse4( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "SSE4 cosine mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + // ------------------------------------------------------------------------- + // AVX Tests + // 
------------------------------------------------------------------------- + + #[test] + fn test_avx_l2_cross_consistency() { + if !is_x86_feature_detected!("avx") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = l2_squared_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::avx::l2_squared_f32_avx( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX L2 mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_avx_inner_product_cross_consistency() { + if !is_x86_feature_detected!("avx") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = inner_product_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::avx::inner_product_f32_avx( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX inner product mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_avx_cosine_cross_consistency() { + if !is_x86_feature_detected!("avx") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = cosine_distance_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::avx::cosine_distance_f32_avx( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX cosine mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + // ------------------------------------------------------------------------- + // AVX2 Tests + // ------------------------------------------------------------------------- + + #[test] + fn test_avx2_l2_f32_cross_consistency() { + if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = l2_squared_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::avx2::l2_squared_f32_avx2( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX2 L2 f32 mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_avx2_l2_f64_cross_consistency() { + if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f64(dim, seed); + + let scalar_result = l2_squared_scalar_f64(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::avx2::l2_squared_f64_avx2( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f64(scalar_result, simd_result, F64_TOLERANCE, RELATIVE_TOLERANCE), + "AVX2 L2 f64 mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_avx2_inner_product_f32_cross_consistency() { + if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") { + return; + 
} + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = inner_product_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::avx2::inner_product_f32_avx2( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX2 inner product f32 mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_avx2_cosine_f32_cross_consistency() { + if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = cosine_distance_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::avx2::cosine_distance_f32_avx2( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX2 cosine f32 mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + // ------------------------------------------------------------------------- + // AVX-512 Tests + // ------------------------------------------------------------------------- + + #[test] + fn test_avx512_l2_cross_consistency() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = l2_squared_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::avx512::l2_squared_f32_avx512( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX-512 L2 mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_avx512_inner_product_cross_consistency() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = inner_product_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::avx512::inner_product_f32_avx512( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX-512 inner product mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_avx512_cosine_cross_consistency() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = cosine_distance_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::avx512::cosine_distance_f32_avx512( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX-512 cosine mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + // ------------------------------------------------------------------------- + // AVX-512 BW Tests + // ------------------------------------------------------------------------- + + #[test] + fn test_avx512bw_l2_cross_consistency() { + if !is_x86_feature_detected!("avx512f") || 
!is_x86_feature_detected!("avx512bw") { + return; + } + + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = l2_squared_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::avx512bw::l2_squared_f32_avx512bw( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX-512 BW L2 mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + // ------------------------------------------------------------------------- + // Edge Case Tests + // ------------------------------------------------------------------------- + + #[test] + fn test_simd_identical_vectors() { + // All SIMD implementations should return 0 for identical vectors (L2) + let dims = [4, 8, 16, 32, 64, 128, 256]; + + for dim in dims { + let a: Vec = (0..dim).map(|i| (i as f32) * 0.1).collect(); + + // Scalar + let scalar_l2 = l2_squared_scalar_f32(&a, &a, dim); + assert!(scalar_l2.abs() < 1e-10, "Scalar L2 of identical vectors should be 0"); + + // SSE + if is_x86_feature_detected!("sse") { + let sse_l2 = unsafe { + crate::distance::simd::sse::l2_squared_f32_sse(a.as_ptr(), a.as_ptr(), dim) + }; + assert!(sse_l2.abs() < 1e-6, "SSE L2 of identical vectors should be ~0, got {}", sse_l2); + } + + // AVX2 + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + let avx2_l2 = unsafe { + crate::distance::simd::avx2::l2_squared_f32_avx2(a.as_ptr(), a.as_ptr(), dim) + }; + assert!(avx2_l2.abs() < 1e-6, "AVX2 L2 of identical vectors should be ~0, got {}", avx2_l2); + } + + // AVX-512 + if is_x86_feature_detected!("avx512f") { + let avx512_l2 = unsafe { + crate::distance::simd::avx512::l2_squared_f32_avx512(a.as_ptr(), a.as_ptr(), dim) + }; + assert!(avx512_l2.abs() < 1e-6, "AVX-512 L2 of identical vectors should be ~0, got {}", avx512_l2); + } + } + } + + #[test] + fn test_simd_orthogonal_vectors_cosine() { + // Orthogonal vectors should have cosine distance of 1.0 + let a = vec![1.0f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]; + let b = vec![0.0f32, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]; + let dim = 8; + + let scalar_cos = cosine_distance_scalar_f32(&a, &b, dim); + assert!((scalar_cos - 1.0).abs() < 0.01, "Scalar cosine of orthogonal vectors should be ~1.0"); + + if is_x86_feature_detected!("sse") { + let sse_cos = unsafe { + crate::distance::simd::sse::cosine_distance_f32_sse(a.as_ptr(), b.as_ptr(), dim) + }; + assert!((sse_cos - 1.0).abs() < 0.01, "SSE cosine of orthogonal vectors should be ~1.0, got {}", sse_cos); + } + + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + let avx2_cos = unsafe { + crate::distance::simd::avx2::cosine_distance_f32_avx2(a.as_ptr(), b.as_ptr(), dim) + }; + assert!((avx2_cos - 1.0).abs() < 0.01, "AVX2 cosine of orthogonal vectors should be ~1.0, got {}", avx2_cos); + } + } + + #[test] + fn test_simd_opposite_vectors_cosine() { + // Opposite vectors should have cosine distance of 2.0 + let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let b: Vec = a.iter().map(|x| -x).collect(); + let dim = 8; + + let scalar_cos = cosine_distance_scalar_f32(&a, &b, dim); + assert!((scalar_cos - 2.0).abs() < 0.01, "Scalar cosine of opposite vectors should be ~2.0"); + + if is_x86_feature_detected!("sse") { + let sse_cos = unsafe { + crate::distance::simd::sse::cosine_distance_f32_sse(a.as_ptr(), b.as_ptr(), dim) + }; + assert!((sse_cos - 2.0).abs() < 
0.01, "SSE cosine of opposite vectors should be ~2.0, got {}", sse_cos); + } + + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + let avx2_cos = unsafe { + crate::distance::simd::avx2::cosine_distance_f32_avx2(a.as_ptr(), b.as_ptr(), dim) + }; + assert!((avx2_cos - 2.0).abs() < 0.01, "AVX2 cosine of opposite vectors should be ~2.0, got {}", avx2_cos); + } + } + + #[test] + fn test_simd_all_implementations_agree() { + // Test that all available SIMD implementations produce consistent results + let dim = 128; + let (a, b) = generate_test_vectors_f32(dim, 42); + + let scalar_l2 = l2_squared_scalar_f32(&a, &b, dim); + let scalar_ip = inner_product_scalar_f32(&a, &b, dim); + let scalar_cos = cosine_distance_scalar_f32(&a, &b, dim); + + let mut results_l2: Vec<(&str, f32)> = vec![("scalar", scalar_l2)]; + let mut results_ip: Vec<(&str, f32)> = vec![("scalar", scalar_ip)]; + let mut results_cos: Vec<(&str, f32)> = vec![("scalar", scalar_cos)]; + + if is_x86_feature_detected!("sse") { + unsafe { + results_l2.push(("sse", crate::distance::simd::sse::l2_squared_f32_sse(a.as_ptr(), b.as_ptr(), dim))); + results_ip.push(("sse", crate::distance::simd::sse::inner_product_f32_sse(a.as_ptr(), b.as_ptr(), dim))); + results_cos.push(("sse", crate::distance::simd::sse::cosine_distance_f32_sse(a.as_ptr(), b.as_ptr(), dim))); + } + } + + if is_x86_feature_detected!("sse4.1") { + unsafe { + results_l2.push(("sse4", crate::distance::simd::sse4::l2_squared_f32_sse4(a.as_ptr(), b.as_ptr(), dim))); + results_ip.push(("sse4", crate::distance::simd::sse4::inner_product_f32_sse4(a.as_ptr(), b.as_ptr(), dim))); + results_cos.push(("sse4", crate::distance::simd::sse4::cosine_distance_f32_sse4(a.as_ptr(), b.as_ptr(), dim))); + } + } + + if is_x86_feature_detected!("avx") { + unsafe { + results_l2.push(("avx", crate::distance::simd::avx::l2_squared_f32_avx(a.as_ptr(), b.as_ptr(), dim))); + results_ip.push(("avx", crate::distance::simd::avx::inner_product_f32_avx(a.as_ptr(), b.as_ptr(), dim))); + results_cos.push(("avx", crate::distance::simd::avx::cosine_distance_f32_avx(a.as_ptr(), b.as_ptr(), dim))); + } + } + + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + unsafe { + results_l2.push(("avx2", crate::distance::simd::avx2::l2_squared_f32_avx2(a.as_ptr(), b.as_ptr(), dim))); + results_ip.push(("avx2", crate::distance::simd::avx2::inner_product_f32_avx2(a.as_ptr(), b.as_ptr(), dim))); + results_cos.push(("avx2", crate::distance::simd::avx2::cosine_distance_f32_avx2(a.as_ptr(), b.as_ptr(), dim))); + } + } + + if is_x86_feature_detected!("avx512f") { + unsafe { + results_l2.push(("avx512", crate::distance::simd::avx512::l2_squared_f32_avx512(a.as_ptr(), b.as_ptr(), dim))); + results_ip.push(("avx512", crate::distance::simd::avx512::inner_product_f32_avx512(a.as_ptr(), b.as_ptr(), dim))); + results_cos.push(("avx512", crate::distance::simd::avx512::cosine_distance_f32_avx512(a.as_ptr(), b.as_ptr(), dim))); + } + } + + // Verify all L2 results are consistent + for (name, result) in &results_l2[1..] { + assert!( + approx_eq_f32(scalar_l2, *result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "L2 mismatch between scalar ({}) and {} ({})", + scalar_l2, name, result + ); + } + + // Verify all inner product results are consistent + for (name, result) in &results_ip[1..] 
{ + assert!( + approx_eq_f32(scalar_ip, *result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "Inner product mismatch between scalar ({}) and {} ({})", + scalar_ip, name, result + ); + } + + // Verify all cosine results are consistent + for (name, result) in &results_cos[1..] { + assert!( + approx_eq_f32(scalar_cos, *result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "Cosine mismatch between scalar ({}) and {} ({})", + scalar_cos, name, result + ); + } + } + + // ------------------------------------------------------------------------- + // Large Vector Tests + // ------------------------------------------------------------------------- + + #[test] + fn test_simd_large_vectors() { + // Test with large vectors typical of embeddings + let large_dims = [384, 512, 768, 1024, 1536]; + + for dim in large_dims { + let (a, b) = generate_test_vectors_f32(dim, 123); + + let scalar_l2 = l2_squared_scalar_f32(&a, &b, dim); + + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + let avx2_l2 = unsafe { + crate::distance::simd::avx2::l2_squared_f32_avx2(a.as_ptr(), b.as_ptr(), dim) + }; + assert!( + approx_eq_f32(scalar_l2, avx2_l2, F32_TOLERANCE * 10.0, RELATIVE_TOLERANCE), + "AVX2 L2 mismatch for large dim={}: scalar={}, avx2={}", + dim, scalar_l2, avx2_l2 + ); + } + + if is_x86_feature_detected!("avx512f") { + let avx512_l2 = unsafe { + crate::distance::simd::avx512::l2_squared_f32_avx512(a.as_ptr(), b.as_ptr(), dim) + }; + assert!( + approx_eq_f32(scalar_l2, avx512_l2, F32_TOLERANCE * 10.0, RELATIVE_TOLERANCE), + "AVX-512 L2 mismatch for large dim={}: scalar={}, avx512={}", + dim, scalar_l2, avx512_l2 + ); + } + } + } + + #[test] + fn test_simd_small_values() { + // Test with very small values that might cause precision issues + let dim = 64; + let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 1e-6).collect(); + let b: Vec<f32> = (0..dim).map(|i| (i as f32 + 0.5) * 1e-6).collect(); + + let scalar_l2 = l2_squared_scalar_f32(&a, &b, dim); + + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + let avx2_l2 = unsafe { + crate::distance::simd::avx2::l2_squared_f32_avx2(a.as_ptr(), b.as_ptr(), dim) + }; + // Use relative tolerance for small values + let rel_diff = if scalar_l2.abs() > 1e-20 { + (scalar_l2 - avx2_l2).abs() / scalar_l2.abs() + } else { + (scalar_l2 - avx2_l2).abs() + }; + assert!( + rel_diff < 0.01, + "AVX2 L2 mismatch for small values: scalar={}, avx2={}, rel_diff={}", + scalar_l2, avx2_l2, rel_diff + ); + } + } + + #[test] + fn test_simd_large_values() { + // Test with large values + let dim = 64; + let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 1e4).collect(); + let b: Vec<f32> = (0..dim).map(|i| (i as f32 + 0.5) * 1e4).collect(); + + let scalar_l2 = l2_squared_scalar_f32(&a, &b, dim); + + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + let avx2_l2 = unsafe { + crate::distance::simd::avx2::l2_squared_f32_avx2(a.as_ptr(), b.as_ptr(), dim) + }; + assert!( + approx_eq_f32(scalar_l2, avx2_l2, 1.0, RELATIVE_TOLERANCE), + "AVX2 L2 mismatch for large values: scalar={}, avx2={}", + scalar_l2, avx2_l2 + ); + } + } + + #[test] + fn test_simd_negative_values() { + // Test with negative values + let dim = 64; + let a: Vec<f32> = (0..dim).map(|i| -((i as f32) * 0.1)).collect(); + let b: Vec<f32> = (0..dim).map(|i| ((i as f32) * 0.15)).collect(); + + let scalar_l2 = l2_squared_scalar_f32(&a, &b, dim); + let scalar_ip = inner_product_scalar_f32(&a, &b, dim); + + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + let avx2_l2 = unsafe { + crate::distance::simd::avx2::l2_squared_f32_avx2(a.as_ptr(), b.as_ptr(), dim) + }; + let avx2_ip = unsafe { + crate::distance::simd::avx2::inner_product_f32_avx2(a.as_ptr(), b.as_ptr(), dim) + }; + + assert!( + approx_eq_f32(scalar_l2, avx2_l2, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX2 L2 mismatch for negative values: scalar={}, avx2={}", + scalar_l2, avx2_l2 + ); + + assert!( + approx_eq_f32(scalar_ip, avx2_ip, F32_TOLERANCE, RELATIVE_TOLERANCE), + "AVX2 IP mismatch for negative values: scalar={}, avx2={}", + scalar_ip, avx2_ip + ); + } + } +}
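The widened tolerance (F32_TOLERANCE * 10.0) in the large-vector test above exists because f32 rounding error grows with the number of accumulated terms, and the SIMD kernels also sum in a different order than the scalar loop. A sketch of measuring that drift against an f64 reference; the numbers are illustrative, not from the test suite:

    // Compare an f32 accumulation against an f64 reference for a long sum.
    fn main() {
        let dim = 1536;
        let a: Vec<f32> = (0..dim).map(|i| ((i as f32) * 0.1).sin() * 10.0).collect();
        let sum_f32: f32 = a.iter().map(|x| x * x).sum();
        let sum_f64: f64 = a.iter().map(|&x| (x as f64) * (x as f64)).sum();
        let rel_err = ((sum_f32 as f64) - sum_f64).abs() / sum_f64;
        // Tiny, but nonzero - hence the relative tolerances in the assertions above.
        assert!(rel_err < 1e-4);
        println!("relative error: {rel_err:e}");
    }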
+ +// ============================================================================= +// aarch64 NEON Cross-Consistency Tests +// ============================================================================= + +#[cfg(target_arch = "aarch64")] +mod aarch64_tests { + use super::*; + + #[test] + fn test_neon_l2_cross_consistency() { + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = l2_squared_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::neon::l2_squared_f32_neon( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "NEON L2 mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_neon_inner_product_cross_consistency() { + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = inner_product_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::neon::inner_product_f32_neon( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "NEON inner product mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_neon_cosine_cross_consistency() { + for &dim in TEST_DIMENSIONS { + for seed in 0..3 { + let (a, b) = generate_test_vectors_f32(dim, seed); + + let scalar_result = cosine_distance_scalar_f32(&a, &b, dim); + let simd_result = unsafe { + crate::distance::simd::neon::cosine_distance_f32_neon( + a.as_ptr(), b.as_ptr(), dim + ) + }; + + assert!( + approx_eq_f32(scalar_result, simd_result, F32_TOLERANCE, RELATIVE_TOLERANCE), + "NEON cosine mismatch at dim={}, seed={}: scalar={}, simd={}", + dim, seed, scalar_result, simd_result + ); + } + } + } + + #[test] + fn test_neon_identical_vectors() { + let dims = [4, 8, 16, 32, 64, 128]; + + for dim in dims { + let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.1).collect(); + + let neon_l2 = unsafe { + crate::distance::simd::neon::l2_squared_f32_neon(a.as_ptr(), a.as_ptr(), dim) + }; + assert!(neon_l2.abs() < 1e-6, "NEON L2 of identical vectors should be ~0, got {}", neon_l2); + } + } + + #[test] + fn test_neon_large_vectors() { + let large_dims = [384, 512, 768, 1024]; + + for dim in large_dims { + let (a, b) = generate_test_vectors_f32(dim, 123); + + let scalar_l2 = l2_squared_scalar_f32(&a, &b, dim); + let neon_l2 = unsafe { + crate::distance::simd::neon::l2_squared_f32_neon(a.as_ptr(), b.as_ptr(), dim) + }; + + assert!( + approx_eq_f32(scalar_l2, neon_l2, F32_TOLERANCE * 10.0, RELATIVE_TOLERANCE), + "NEON L2 mismatch for large dim={}: scalar={}, neon={}", + dim, scalar_l2, neon_l2 + ); + } + } +} diff --git a/rust/vecsim/src/distance/simd/mod.rs b/rust/vecsim/src/distance/simd/mod.rs new file mode 100644 index 
000000000..df0368645 --- /dev/null +++ b/rust/vecsim/src/distance/simd/mod.rs @@ -0,0 +1,181 @@ +//! SIMD-optimized distance function implementations. +//! +//! This module provides hardware-accelerated distance computations: +//! - AVX-512 VNNI (x86_64) - 512-bit vectors with VNNI for int8 operations +//! - AVX-512 BW (x86_64) - 512-bit vectors with byte/word operations +//! - AVX-512 (x86_64) - 512-bit vectors, 16 f32 at a time +//! - AVX-512 FP16 (x86_64) - 512-bit vectors for half-precision (Float16) +//! - AVX-512 BF16 (x86_64) - 512-bit vectors for bfloat16 (BFloat16) +//! - AVX2 (x86_64) - 256-bit vectors, 8 f32 at a time, with FMA +//! - AVX (x86_64) - 256-bit vectors, 8 f32 at a time, no FMA +//! - SSE (x86_64) - 128-bit vectors, 4 f32 at a time +//! - NEON (aarch64) - 128-bit vectors, 4 f32 at a time +//! +//! Runtime feature detection is used to select the best implementation. + +#[cfg(target_arch = "x86_64")] +pub mod avx; +#[cfg(target_arch = "x86_64")] +pub mod avx2; +#[cfg(target_arch = "x86_64")] +pub mod avx512; +#[cfg(target_arch = "x86_64")] +pub mod avx512bf16; +#[cfg(target_arch = "x86_64")] +pub mod avx512bw; +#[cfg(target_arch = "x86_64")] +pub mod avx512fp16; +#[cfg(target_arch = "x86_64")] +pub mod sse; +#[cfg(target_arch = "x86_64")] +pub mod sse4; +#[cfg(target_arch = "aarch64")] +pub mod neon; +#[cfg(target_arch = "aarch64")] +pub mod sve; + +#[cfg(test)] +mod cross_consistency_tests; + +/// SIMD capability levels. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SimdCapability { + /// No SIMD support. + None, + /// SSE (128-bit vectors). + #[cfg(target_arch = "x86_64")] + Sse, + /// SSE4.1 (128-bit vectors with dot product instruction). + #[cfg(target_arch = "x86_64")] + Sse4_1, + /// AVX (256-bit vectors, no FMA). + #[cfg(target_arch = "x86_64")] + Avx, + /// AVX2 (256-bit vectors, with FMA). + #[cfg(target_arch = "x86_64")] + Avx2, + /// AVX-512 (512-bit vectors). + #[cfg(target_arch = "x86_64")] + Avx512, + /// AVX-512 with BW (byte/word operations). + #[cfg(target_arch = "x86_64")] + Avx512Bw, + /// AVX-512 with VNNI (int8 neural network instructions). + #[cfg(target_arch = "x86_64")] + Avx512Vnni, + /// ARM NEON (128-bit vectors). + #[cfg(target_arch = "aarch64")] + Neon, +} + +/// Check if any SIMD capability is available. +pub fn is_simd_available() -> bool { + detect_simd_capability() != SimdCapability::None +} + +/// Detect the best available SIMD capability at runtime. 
+pub fn detect_simd_capability() -> SimdCapability { + #[cfg(target_arch = "x86_64")] + { + // Check for AVX-512 VNNI first (best for int8 operations) + if is_x86_feature_detected!("avx512f") + && is_x86_feature_detected!("avx512bw") + && is_x86_feature_detected!("avx512vnni") + { + return SimdCapability::Avx512Vnni; + } + // Check for AVX-512 BW (byte/word operations) + if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") { + return SimdCapability::Avx512Bw; + } + // Check for basic AVX-512 + if is_x86_feature_detected!("avx512f") { + return SimdCapability::Avx512; + } + // Fall back to AVX2 (with FMA) + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + return SimdCapability::Avx2; + } + // Fall back to AVX (without FMA) + if is_x86_feature_detected!("avx") { + return SimdCapability::Avx; + } + // Fall back to SSE4.1 (with dot product instruction) + if is_x86_feature_detected!("sse4.1") { + return SimdCapability::Sse4_1; + } + // Fall back to SSE (available on virtually all x86_64) + if is_x86_feature_detected!("sse") { + return SimdCapability::Sse; + } + } + + #[cfg(target_arch = "aarch64")] + { + // NEON is always available on aarch64 + return SimdCapability::Neon; + } + + #[allow(unreachable_code)] + SimdCapability::None +} + +/// Check if AVX-512 VNNI is available at runtime. +#[cfg(target_arch = "x86_64")] +#[inline] +pub fn has_avx512_vnni() -> bool { + is_x86_feature_detected!("avx512f") + && is_x86_feature_detected!("avx512bw") + && is_x86_feature_detected!("avx512vnni") +} + +/// Check if AVX-512 VNNI is available at runtime (stub for non-x86_64). +#[cfg(not(target_arch = "x86_64"))] +#[inline] +pub fn has_avx512_vnni() -> bool { + false +} + +/// Get the optimal vector alignment for the detected SIMD capability. +pub fn optimal_alignment() -> usize { + match detect_simd_capability() { + #[cfg(target_arch = "x86_64")] + SimdCapability::Avx512Vnni => 64, + #[cfg(target_arch = "x86_64")] + SimdCapability::Avx512Bw => 64, + #[cfg(target_arch = "x86_64")] + SimdCapability::Avx512 => 64, + #[cfg(target_arch = "x86_64")] + SimdCapability::Avx2 => 32, + #[cfg(target_arch = "x86_64")] + SimdCapability::Avx => 32, + #[cfg(target_arch = "x86_64")] + SimdCapability::Sse4_1 => 16, + #[cfg(target_arch = "x86_64")] + SimdCapability::Sse => 16, + #[cfg(target_arch = "aarch64")] + SimdCapability::Neon => 16, + SimdCapability::None => 8, + #[allow(unreachable_patterns)] + _ => 8, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_detect_simd() { + let cap = detect_simd_capability(); + println!("Detected SIMD capability: {:?}", cap); + // Just ensure it doesn't crash + } + + #[test] + fn test_optimal_alignment() { + let align = optimal_alignment(); + assert!(align >= 8); + assert!(align.is_power_of_two()); + } +} diff --git a/rust/vecsim/src/distance/simd/neon.rs b/rust/vecsim/src/distance/simd/neon.rs new file mode 100644 index 000000000..1c96eccd4 --- /dev/null +++ b/rust/vecsim/src/distance/simd/neon.rs @@ -0,0 +1,856 @@ +//! ARM NEON SIMD implementations for distance functions. +//! +//! These functions use 128-bit NEON instructions for ARM processors. +//! Available on all aarch64 (ARM64) platforms. +//! +//! Additional optimizations: +//! - NEON DOTPROD (ARMv8.2-A+): Int8/UInt8 dot product instructions +//! Available on Apple M1+, AWS Graviton2+, and other modern ARM chips. 
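Before the NEON kernels, a note on how the detection helpers in mod.rs above are meant to be consumed. The sketch below is a hypothetical caller (the vecsim crate path is assumed from the file layout, not confirmed by the patch); it picks the capability once and allocates storage at the matching alignment:

    use std::alloc::{alloc, dealloc, Layout};

    fn main() {
        // Both helpers are defined in rust/vecsim/src/distance/simd/mod.rs above.
        let cap = vecsim::distance::simd::detect_simd_capability();
        let align = vecsim::distance::simd::optimal_alignment();
        println!("capability: {:?}, alignment: {}", cap, align);

        // Allocate a buffer for 1024 f32 values at the reported alignment.
        let layout = Layout::from_size_align(1024 * 4, align).expect("valid layout");
        unsafe {
            let ptr = alloc(layout);
            assert!(!ptr.is_null() && (ptr as usize) % align == 0);
            dealloc(ptr, layout);
        }
    }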
+ +use crate::types::{DistanceType, VectorElement}; + +use std::arch::aarch64::*; + +// ============================================================================= +// NEON Int8/UInt8 optimized operations +// ============================================================================= +// +// Note: NEON DOTPROD intrinsics (vdotq_s32, vdotq_u32) are currently unstable +// in Rust (see rust-lang/rust#117224). We use widening multiply-add instead, +// which is available on all NEON-capable ARM processors. + +/// Check if NEON DOTPROD is available at runtime. +/// Currently unused; kept so a DOTPROD fast path can be added once the intrinsics stabilize. +#[inline] +pub fn has_dotprod() -> bool { + std::arch::is_aarch64_feature_detected!("dotprod") +} + +/// NEON inner product for i8 vectors using widening multiply-add. +/// +/// Processes 16 int8 elements per iteration using NEON SIMD. +/// +/// # Safety +/// - Pointers must be valid for `dim` elements +#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn inner_product_i8_neon(a: *const i8, b: *const i8, dim: usize) -> i32 { + let mut sum0 = vdupq_n_s32(0); + let mut sum1 = vdupq_n_s32(0); + let chunks = dim / 16; + let remainder = dim % 16; + + for i in 0..chunks { + let offset = i * 16; + + // Load 16 int8 values + let va = vld1q_s8(a.add(offset)); + let vb = vld1q_s8(b.add(offset)); + + // Split into low and high halves + let va_lo = vget_low_s8(va); + let va_hi = vget_high_s8(va); + let vb_lo = vget_low_s8(vb); + let vb_hi = vget_high_s8(vb); + + // Widen to i16 and multiply + let prod_lo = vmull_s8(va_lo, vb_lo); // 8 x i16 + let prod_hi = vmull_s8(va_hi, vb_hi); // 8 x i16 + + // Pairwise add to i32 and accumulate + sum0 = vpadalq_s16(sum0, prod_lo); + sum1 = vpadalq_s16(sum1, prod_hi); + } + + // Combine accumulators and reduce + let sum = vaddq_s32(sum0, sum1); + let mut result = vaddvq_s32(sum); + + // Handle remainder + let base = chunks * 16; + for i in 0..remainder { + result += (*a.add(base + i) as i32) * (*b.add(base + i) as i32); + } + + result +} + +/// NEON inner product for u8 vectors using widening multiply-add. +/// +/// # Safety +/// - Pointers must be valid for `dim` elements +#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn inner_product_u8_neon(a: *const u8, b: *const u8, dim: usize) -> u32 { + let mut sum0 = vdupq_n_u32(0); + let mut sum1 = vdupq_n_u32(0); + let chunks = dim / 16; + let remainder = dim % 16; + + for i in 0..chunks { + let offset = i * 16; + + let va = vld1q_u8(a.add(offset)); + let vb = vld1q_u8(b.add(offset)); + + let va_lo = vget_low_u8(va); + let va_hi = vget_high_u8(va); + let vb_lo = vget_low_u8(vb); + let vb_hi = vget_high_u8(vb); + + // Widen to u16 and multiply + let prod_lo = vmull_u8(va_lo, vb_lo); + let prod_hi = vmull_u8(va_hi, vb_hi); + + // Pairwise add to u32 and accumulate + sum0 = vpadalq_u16(sum0, prod_lo); + sum1 = vpadalq_u16(sum1, prod_hi); + } + + let sum = vaddq_u32(sum0, sum1); + let mut result = vaddvq_u32(sum); + + let base = chunks * 16; + for i in 0..remainder { + result += (*a.add(base + i) as u32) * (*b.add(base + i) as u32); + } + + result +} + +/// NEON L2 squared distance for i8 vectors. +/// +/// Computes ||a - b||² using widening operations.
+/// +/// # Safety +/// - Pointers must be valid for `dim` elements +#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn l2_squared_i8_neon(a: *const i8, b: *const i8, dim: usize) -> i32 { + let mut sum0 = vdupq_n_s32(0); + let mut sum1 = vdupq_n_s32(0); + let chunks = dim / 16; + let remainder = dim % 16; + + for i in 0..chunks { + let offset = i * 16; + + let va = vld1q_s8(a.add(offset)); + let vb = vld1q_s8(b.add(offset)); + + // Split into low and high halves before widening + let va_lo = vget_low_s8(va); + let va_hi = vget_high_s8(va); + let vb_lo = vget_low_s8(vb); + let vb_hi = vget_high_s8(vb); + + // Widen to i16 for subtraction to avoid overflow + let va_lo_16 = vmovl_s8(va_lo); + let va_hi_16 = vmovl_s8(va_hi); + let vb_lo_16 = vmovl_s8(vb_lo); + let vb_hi_16 = vmovl_s8(vb_hi); + + let diff_lo = vsubq_s16(va_lo_16, vb_lo_16); + let diff_hi = vsubq_s16(va_hi_16, vb_hi_16); + + // Square the differences (i16 * i16 -> i32) + let sq_lo_lo = vmull_s16(vget_low_s16(diff_lo), vget_low_s16(diff_lo)); + let sq_lo_hi = vmull_s16(vget_high_s16(diff_lo), vget_high_s16(diff_lo)); + let sq_hi_lo = vmull_s16(vget_low_s16(diff_hi), vget_low_s16(diff_hi)); + let sq_hi_hi = vmull_s16(vget_high_s16(diff_hi), vget_high_s16(diff_hi)); + + // Accumulate + sum0 = vaddq_s32(sum0, sq_lo_lo); + sum0 = vaddq_s32(sum0, sq_lo_hi); + sum1 = vaddq_s32(sum1, sq_hi_lo); + sum1 = vaddq_s32(sum1, sq_hi_hi); + } + + let sum = vaddq_s32(sum0, sum1); + let mut result = vaddvq_s32(sum); + + let base = chunks * 16; + for i in 0..remainder { + let diff = (*a.add(base + i) as i32) - (*b.add(base + i) as i32); + result += diff * diff; + } + + result +} + +/// NEON L2 squared distance for u8 vectors. +/// +/// # Safety +/// - Pointers must be valid for `dim` elements +#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn l2_squared_u8_neon(a: *const u8, b: *const u8, dim: usize) -> i32 { + let mut sum0 = vdupq_n_s32(0); + let mut sum1 = vdupq_n_s32(0); + + let chunks = dim / 16; + let remainder = dim % 16; + + for i in 0..chunks { + let offset = i * 16; + let va = vld1q_u8(a.add(offset)); + let vb = vld1q_u8(b.add(offset)); + + // Compute absolute difference (won't overflow for u8) + let diff = vabdq_u8(va, vb); + + // Split into low and high halves, widen to u16, then square + let diff_lo = vget_low_u8(diff); + let diff_hi = vget_high_u8(diff); + + // Widen to u16 + let diff_lo_16 = vmovl_u8(diff_lo); + let diff_hi_16 = vmovl_u8(diff_hi); + + // Square (u16 * u16 -> u32) + let sq_lo = vmull_u16(vget_low_u16(diff_lo_16), vget_low_u16(diff_lo_16)); + let sq_hi = vmull_u16(vget_high_u16(diff_lo_16), vget_high_u16(diff_lo_16)); + let sq_lo2 = vmull_u16(vget_low_u16(diff_hi_16), vget_low_u16(diff_hi_16)); + let sq_hi2 = vmull_u16(vget_high_u16(diff_hi_16), vget_high_u16(diff_hi_16)); + + // Accumulate (reinterpret u32 as s32 for final sum) + sum0 = vaddq_s32(sum0, vreinterpretq_s32_u32(sq_lo)); + sum0 = vaddq_s32(sum0, vreinterpretq_s32_u32(sq_hi)); + sum1 = vaddq_s32(sum1, vreinterpretq_s32_u32(sq_lo2)); + sum1 = vaddq_s32(sum1, vreinterpretq_s32_u32(sq_hi2)); + } + + let sum = vaddq_s32(sum0, sum1); + let mut result = vaddvq_s32(sum); + + let base = chunks * 16; + for i in 0..remainder { + let diff = (*a.add(base + i) as i32) - (*b.add(base + i) as i32); + result += diff * diff; + } + + result +} + +/// Safe wrapper for i8 inner product.
+#[inline] +pub fn inner_product_i8(a: &[i8], b: &[i8], dim: usize) -> i32 { + unsafe { inner_product_i8_neon(a.as_ptr(), b.as_ptr(), dim) } +} + +/// Safe wrapper for u8 inner product. +#[inline] +pub fn inner_product_u8(a: &[u8], b: &[u8], dim: usize) -> u32 { + unsafe { inner_product_u8_neon(a.as_ptr(), b.as_ptr(), dim) } +} + +/// Safe wrapper for i8 L2 squared. +#[inline] +pub fn l2_squared_i8(a: &[i8], b: &[i8], dim: usize) -> i32 { + unsafe { l2_squared_i8_neon(a.as_ptr(), b.as_ptr(), dim) } +} + +/// Safe wrapper for u8 L2 squared. +#[inline] +pub fn l2_squared_u8(a: &[u8], b: &[u8], dim: usize) -> i32 { + unsafe { l2_squared_u8_neon(a.as_ptr(), b.as_ptr(), dim) } +} + +// ============================================================================= +// NEON f64 SIMD operations +// ============================================================================= + +/// NEON L2 squared distance for f64 vectors. +/// +/// # Safety +/// - Pointers must be valid for `dim` elements +#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn l2_squared_f64_neon(a: *const f64, b: *const f64, dim: usize) -> f64 { + let mut sum0 = vdupq_n_f64(0.0); + let mut sum1 = vdupq_n_f64(0.0); + let chunks = dim / 4; + let remainder = dim % 4; + + // Process 4 elements at a time (two 2-element vectors) + for i in 0..chunks { + let offset = i * 4; + + let va0 = vld1q_f64(a.add(offset)); + let vb0 = vld1q_f64(b.add(offset)); + let diff0 = vsubq_f64(va0, vb0); + sum0 = vfmaq_f64(sum0, diff0, diff0); + + let va1 = vld1q_f64(a.add(offset + 2)); + let vb1 = vld1q_f64(b.add(offset + 2)); + let diff1 = vsubq_f64(va1, vb1); + sum1 = vfmaq_f64(sum1, diff1, diff1); + } + + // Combine and reduce + let sum = vaddq_f64(sum0, sum1); + let mut result = vaddvq_f64(sum); + + // Handle remainder + let base = chunks * 4; + for i in 0..remainder { + let diff = *a.add(base + i) - *b.add(base + i); + result += diff * diff; + } + + result +} + +/// NEON inner product for f64 vectors. +/// +/// # Safety +/// - Pointers must be valid for `dim` elements +#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn inner_product_f64_neon(a: *const f64, b: *const f64, dim: usize) -> f64 { + let mut sum0 = vdupq_n_f64(0.0); + let mut sum1 = vdupq_n_f64(0.0); + let chunks = dim / 4; + let remainder = dim % 4; + + for i in 0..chunks { + let offset = i * 4; + + let va0 = vld1q_f64(a.add(offset)); + let vb0 = vld1q_f64(b.add(offset)); + sum0 = vfmaq_f64(sum0, va0, vb0); + + let va1 = vld1q_f64(a.add(offset + 2)); + let vb1 = vld1q_f64(b.add(offset + 2)); + sum1 = vfmaq_f64(sum1, va1, vb1); + } + + let sum = vaddq_f64(sum0, sum1); + let mut result = vaddvq_f64(sum); + + let base = chunks * 4; + for i in 0..remainder { + result += *a.add(base + i) * *b.add(base + i); + } + + result +} + +/// NEON cosine distance for f64 vectors. 
+/// +/// # Safety +/// - Pointers must be valid for `dim` elements +#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn cosine_distance_f64_neon(a: *const f64, b: *const f64, dim: usize) -> f64 { + let mut dot_sum = vdupq_n_f64(0.0); + let mut norm_a_sum = vdupq_n_f64(0.0); + let mut norm_b_sum = vdupq_n_f64(0.0); + + let chunks = dim / 2; + let remainder = dim % 2; + + for i in 0..chunks { + let offset = i * 2; + let va = vld1q_f64(a.add(offset)); + let vb = vld1q_f64(b.add(offset)); + + dot_sum = vfmaq_f64(dot_sum, va, vb); + norm_a_sum = vfmaq_f64(norm_a_sum, va, va); + norm_b_sum = vfmaq_f64(norm_b_sum, vb, vb); + } + + let mut dot = vaddvq_f64(dot_sum); + let mut norm_a = vaddvq_f64(norm_a_sum); + let mut norm_b = vaddvq_f64(norm_b_sum); + + let base = chunks * 2; + for i in 0..remainder { + let va = *a.add(base + i); + let vb = *b.add(base + i); + dot += va * vb; + norm_a += va * va; + norm_b += vb * vb; + } + + let denom = (norm_a * norm_b).sqrt(); + if denom < 1e-30 { + return 1.0; + } + + let cosine_sim = (dot / denom).clamp(-1.0, 1.0); + 1.0 - cosine_sim +} + +/// Safe wrapper for f64 L2 squared distance. +#[inline] +pub fn l2_squared_f64(a: &[f64], b: &[f64], dim: usize) -> f64 { + unsafe { l2_squared_f64_neon(a.as_ptr(), b.as_ptr(), dim) } +} + +/// Safe wrapper for f64 inner product. +#[inline] +pub fn inner_product_f64(a: &[f64], b: &[f64], dim: usize) -> f64 { + unsafe { inner_product_f64_neon(a.as_ptr(), b.as_ptr(), dim) } +} + +/// Safe wrapper for f64 cosine distance. +#[inline] +pub fn cosine_distance_f64(a: &[f64], b: &[f64], dim: usize) -> f64 { + unsafe { cosine_distance_f64_neon(a.as_ptr(), b.as_ptr(), dim) } +} + +// ============================================================================= +// Original NEON f32 implementations +// ============================================================================= + +/// NEON L2 squared distance for f32 vectors. +/// +/// Uses 4 accumulators for better instruction-level parallelism (ILP), +/// processing 16 elements per iteration to match C++ performance. +/// +/// # Safety +/// - Pointers `a` and `b` must be valid for reads of `dim` f32 elements. +/// - Must only be called on aarch64 platforms with NEON support. 
+#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn l2_squared_f32_neon(a: *const f32, b: *const f32, dim: usize) -> f32 { + // Use 4 accumulators for better ILP (matches C++ implementation) + let mut sum0 = vdupq_n_f32(0.0); + let mut sum1 = vdupq_n_f32(0.0); + let mut sum2 = vdupq_n_f32(0.0); + let mut sum3 = vdupq_n_f32(0.0); + + let chunks = dim / 16; + let remainder = dim % 16; + + // Process 16 elements at a time (four 4-element vectors) + for i in 0..chunks { + let offset = i * 16; + + let va0 = vld1q_f32(a.add(offset)); + let vb0 = vld1q_f32(b.add(offset)); + let diff0 = vsubq_f32(va0, vb0); + sum0 = vfmaq_f32(sum0, diff0, diff0); + + let va1 = vld1q_f32(a.add(offset + 4)); + let vb1 = vld1q_f32(b.add(offset + 4)); + let diff1 = vsubq_f32(va1, vb1); + sum1 = vfmaq_f32(sum1, diff1, diff1); + + let va2 = vld1q_f32(a.add(offset + 8)); + let vb2 = vld1q_f32(b.add(offset + 8)); + let diff2 = vsubq_f32(va2, vb2); + sum2 = vfmaq_f32(sum2, diff2, diff2); + + let va3 = vld1q_f32(a.add(offset + 12)); + let vb3 = vld1q_f32(b.add(offset + 12)); + let diff3 = vsubq_f32(va3, vb3); + sum3 = vfmaq_f32(sum3, diff3, diff3); + } + + // Handle remaining complete 4-element blocks (0-3 blocks) + let base = chunks * 16; + let remaining_chunks = remainder / 4; + + if remaining_chunks >= 1 { + let va = vld1q_f32(a.add(base)); + let vb = vld1q_f32(b.add(base)); + let diff = vsubq_f32(va, vb); + sum0 = vfmaq_f32(sum0, diff, diff); + } + if remaining_chunks >= 2 { + let va = vld1q_f32(a.add(base + 4)); + let vb = vld1q_f32(b.add(base + 4)); + let diff = vsubq_f32(va, vb); + sum1 = vfmaq_f32(sum1, diff, diff); + } + if remaining_chunks >= 3 { + let va = vld1q_f32(a.add(base + 8)); + let vb = vld1q_f32(b.add(base + 8)); + let diff = vsubq_f32(va, vb); + sum2 = vfmaq_f32(sum2, diff, diff); + } + + // Combine all four accumulators + let sum01 = vaddq_f32(sum0, sum1); + let sum23 = vaddq_f32(sum2, sum3); + let sum = vaddq_f32(sum01, sum23); + + // Horizontal reduction using pairwise adds (matches C++ pattern) + let sum_halves = vadd_f32(vget_low_f32(sum), vget_high_f32(sum)); + let summed = vpadd_f32(sum_halves, sum_halves); + let mut result = vget_lane_f32::<0>(summed); + + // Handle final remainder (0-3 elements) + let final_base = base + remaining_chunks * 4; + let final_remainder = remainder % 4; + for i in 0..final_remainder { + let diff = *a.add(final_base + i) - *b.add(final_base + i); + result += diff * diff; + } + + result +} + +/// NEON inner product for f32 vectors. +/// +/// Uses 4 accumulators for better instruction-level parallelism (ILP), +/// processing 16 elements per iteration to match C++ performance. +/// +/// # Safety +/// - Pointers `a` and `b` must be valid for reads of `dim` f32 elements. +/// - Must only be called on aarch64 platforms with NEON support. 
+#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn inner_product_f32_neon(a: *const f32, b: *const f32, dim: usize) -> f32 { + // Use 4 accumulators for better ILP + let mut sum0 = vdupq_n_f32(0.0); + let mut sum1 = vdupq_n_f32(0.0); + let mut sum2 = vdupq_n_f32(0.0); + let mut sum3 = vdupq_n_f32(0.0); + + let chunks = dim / 16; + let remainder = dim % 16; + + // Process 16 elements at a time (four 4-element vectors) + for i in 0..chunks { + let offset = i * 16; + + let va0 = vld1q_f32(a.add(offset)); + let vb0 = vld1q_f32(b.add(offset)); + sum0 = vfmaq_f32(sum0, va0, vb0); + + let va1 = vld1q_f32(a.add(offset + 4)); + let vb1 = vld1q_f32(b.add(offset + 4)); + sum1 = vfmaq_f32(sum1, va1, vb1); + + let va2 = vld1q_f32(a.add(offset + 8)); + let vb2 = vld1q_f32(b.add(offset + 8)); + sum2 = vfmaq_f32(sum2, va2, vb2); + + let va3 = vld1q_f32(a.add(offset + 12)); + let vb3 = vld1q_f32(b.add(offset + 12)); + sum3 = vfmaq_f32(sum3, va3, vb3); + } + + // Handle remaining complete 4-element blocks + let base = chunks * 16; + let remaining_chunks = remainder / 4; + + if remaining_chunks >= 1 { + let va = vld1q_f32(a.add(base)); + let vb = vld1q_f32(b.add(base)); + sum0 = vfmaq_f32(sum0, va, vb); + } + if remaining_chunks >= 2 { + let va = vld1q_f32(a.add(base + 4)); + let vb = vld1q_f32(b.add(base + 4)); + sum1 = vfmaq_f32(sum1, va, vb); + } + if remaining_chunks >= 3 { + let va = vld1q_f32(a.add(base + 8)); + let vb = vld1q_f32(b.add(base + 8)); + sum2 = vfmaq_f32(sum2, va, vb); + } + + // Combine all four accumulators + let sum01 = vaddq_f32(sum0, sum1); + let sum23 = vaddq_f32(sum2, sum3); + let sum = vaddq_f32(sum01, sum23); + + // Horizontal reduction using pairwise adds + let sum_halves = vadd_f32(vget_low_f32(sum), vget_high_f32(sum)); + let summed = vpadd_f32(sum_halves, sum_halves); + let mut result = vget_lane_f32::<0>(summed); + + // Handle final remainder (0-3 elements) + let final_base = base + remaining_chunks * 4; + let final_remainder = remainder % 4; + for i in 0..final_remainder { + result += *a.add(final_base + i) * *b.add(final_base + i); + } + + result +} + +/// NEON cosine distance for f32 vectors. +/// +/// # Safety +/// - Pointers `a` and `b` must be valid for reads of `dim` f32 elements. +/// - Must only be called on aarch64 platforms with NEON support. +#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn cosine_distance_f32_neon(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut dot_sum = vdupq_n_f32(0.0); + let mut norm_a_sum = vdupq_n_f32(0.0); + let mut norm_b_sum = vdupq_n_f32(0.0); + + let chunks = dim / 4; + let remainder = dim % 4; + + // Process 4 elements at a time + for i in 0..chunks { + let offset = i * 4; + let va = vld1q_f32(a.add(offset)); + let vb = vld1q_f32(b.add(offset)); + + dot_sum = vfmaq_f32(dot_sum, va, vb); + norm_a_sum = vfmaq_f32(norm_a_sum, va, va); + norm_b_sum = vfmaq_f32(norm_b_sum, vb, vb); + } + + // Reduce to scalars + let mut dot = vaddvq_f32(dot_sum); + let mut norm_a = vaddvq_f32(norm_a_sum); + let mut norm_b = vaddvq_f32(norm_b_sum); + + // Handle remainder + let base = chunks * 4; + for i in 0..remainder { + let va = *a.add(base + i); + let vb = *b.add(base + i); + dot += va * vb; + norm_a += va * va; + norm_b += vb * vb; + } + + let denom = (norm_a * norm_b).sqrt(); + if denom < 1e-30 { + return 1.0; + } + + let cosine_sim = (dot / denom).clamp(-1.0, 1.0); + 1.0 - cosine_sim +} + +/// Safe wrapper for L2 squared distance. +/// Specialized for f32 to avoid allocation when input is already f32. 
+#[inline] +pub fn l2_squared_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType { + // Fast path: if T is f32, avoid allocation by reinterpreting the slices + if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() { + // SAFETY: We verified T is f32, so this reinterpret is safe + let a_f32 = unsafe { std::slice::from_raw_parts(a.as_ptr() as *const f32, dim) }; + let b_f32 = unsafe { std::slice::from_raw_parts(b.as_ptr() as *const f32, dim) }; + let result = unsafe { l2_squared_f32_neon(a_f32.as_ptr(), b_f32.as_ptr(), dim) }; + return T::DistanceType::from_f64(result as f64); + } + + // Slow path: convert to f32 + let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect(); + let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect(); + + let result = unsafe { l2_squared_f32_neon(a_f32.as_ptr(), b_f32.as_ptr(), dim) }; + T::DistanceType::from_f64(result as f64) +} + +/// Safe wrapper for inner product. +/// Specialized for f32 to avoid allocation when input is already f32. +#[inline] +pub fn inner_product_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType { + // Fast path: if T is f32, avoid allocation + if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() { + let a_f32 = unsafe { std::slice::from_raw_parts(a.as_ptr() as *const f32, dim) }; + let b_f32 = unsafe { std::slice::from_raw_parts(b.as_ptr() as *const f32, dim) }; + let result = unsafe { inner_product_f32_neon(a_f32.as_ptr(), b_f32.as_ptr(), dim) }; + return T::DistanceType::from_f64(result as f64); + } + + let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect(); + let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect(); + + let result = unsafe { inner_product_f32_neon(a_f32.as_ptr(), b_f32.as_ptr(), dim) }; + T::DistanceType::from_f64(result as f64) +}
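The wrappers above rely on a TypeId comparison that is resolved per monomorphization, so the f32 fast path costs nothing at runtime. A minimal standalone restatement of the same trick; the function name and bounds here are illustrative, not from the crate:

    use std::any::TypeId;

    // Generic entry point with a zero-cost f32 specialization: the TypeId
    // comparison is between compile-time constants, so each instantiation
    // keeps exactly one branch after optimization.
    fn sum_as_f64<T: Copy + Into<f64> + 'static>(xs: &[T]) -> f64 {
        if TypeId::of::<T>() == TypeId::of::<f32>() {
            // SAFETY: T is f32 here, so the slice layouts are identical.
            let xs = unsafe { std::slice::from_raw_parts(xs.as_ptr() as *const f32, xs.len()) };
            return xs.iter().map(|&x| x as f64).sum();
        }
        xs.iter().map(|&x| Into::<f64>::into(x)).sum()
    }

    fn main() {
        assert_eq!(sum_as_f64(&[1.0f32, 2.0]), 3.0); // fast path
        assert_eq!(sum_as_f64(&[1u8, 2, 3]), 6.0);   // generic path
    }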
+ +/// Safe wrapper for cosine distance. +/// Specialized for f32 to avoid allocation when input is already f32. +#[inline] +pub fn cosine_distance_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType { + // Fast path: if T is f32, avoid allocation + if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() { + let a_f32 = unsafe { std::slice::from_raw_parts(a.as_ptr() as *const f32, dim) }; + let b_f32 = unsafe { std::slice::from_raw_parts(b.as_ptr() as *const f32, dim) }; + let result = unsafe { cosine_distance_f32_neon(a_f32.as_ptr(), b_f32.as_ptr(), dim) }; + return T::DistanceType::from_f64(result as f64); + } + + let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect(); + let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect(); + + let result = unsafe { cosine_distance_f32_neon(a_f32.as_ptr(), b_f32.as_ptr(), dim) }; + T::DistanceType::from_f64(result as f64) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_neon_l2_squared() { + let a: Vec<f32> = (0..128).map(|i| i as f32).collect(); + let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect(); + + let neon_result = l2_squared_f32::<f32>(&a, &b, 128); + let scalar_result = crate::distance::l2::l2_squared_scalar(&a, &b, 128); + + assert!((neon_result - scalar_result).abs() < 0.1); + } + + #[test] + fn test_neon_inner_product() { + let a: Vec<f32> = (0..128).map(|i| i as f32 / 100.0).collect(); + let b: Vec<f32> = (0..128).map(|i| (128 - i) as f32 / 100.0).collect(); + + let neon_result = inner_product_f32::<f32>(&a, &b, 128); + let scalar_result = crate::distance::ip::inner_product_scalar(&a, &b, 128); + + assert!((neon_result - scalar_result).abs() < 0.01); + } + + // NEON DOTPROD tests (int8/uint8) + #[test] + fn test_dotprod_inner_product_i8() { + let a: Vec<i8> = (0..128).map(|i| (i % 127) as i8).collect(); + let b: Vec<i8> = (0..128).map(|i| ((128 - i) % 127) as i8).collect(); + + let result = inner_product_i8(&a, &b, 128); + + // Verify against scalar + let expected: i32 = a + .iter() + .zip(b.iter()) + .map(|(&x, &y)| (x as i32) * (y as i32)) + .sum(); + + assert_eq!(result, expected); + } + + #[test] + fn test_dotprod_inner_product_u8() { + let a: Vec<u8> = (0..128).map(|i| i as u8).collect(); + let b: Vec<u8> = (0..128).map(|i| (128 - i) as u8).collect(); + + let result = inner_product_u8(&a, &b, 128); + + let expected: u32 = a + .iter() + .zip(b.iter()) + .map(|(&x, &y)| (x as u32) * (y as u32)) + .sum(); + + assert_eq!(result, expected); + } + + #[test] + fn test_dotprod_l2_squared_i8() { + let a: Vec<i8> = (0..128).map(|i| (i % 64) as i8).collect(); + let b: Vec<i8> = (0..128).map(|i| ((i + 1) % 64) as i8).collect(); + + let result = l2_squared_i8(&a, &b, 128); + + let expected: i32 = a + .iter() + .zip(b.iter()) + .map(|(&x, &y)| { + let diff = (x as i32) - (y as i32); + diff * diff + }) + .sum(); + + assert_eq!(result, expected); + } + + #[test] + fn test_dotprod_l2_squared_u8() { + let a: Vec<u8> = (0..128).map(|i| i as u8).collect(); + let b: Vec<u8> = (0..128).map(|i| (i + 1) as u8).collect(); + + let result = l2_squared_u8(&a, &b, 128); + + let expected: i32 = a + .iter() + .zip(b.iter()) + .map(|(&x, &y)| { + let diff = (x as i32) - (y as i32); + diff * diff + }) + .sum(); + + assert_eq!(result, expected); + } + + #[test] + fn test_dotprod_remainder_handling() { + // Test with non-aligned dimensions + for dim in [17, 33, 65, 100, 127] { + let a: Vec<i8> = (0..dim).map(|i| (i % 100) as i8).collect(); + let b: Vec<i8> = (0..dim).map(|i| ((i * 2) % 100) as i8).collect(); + + let result = inner_product_i8(&a, &b, dim); + + let expected: i32 = a + .iter() + .zip(b.iter()) + .map(|(&x, &y)| (x as i32) * (y as i32)) + .sum(); + + assert_eq!(result, expected, "Failed for dim={}", dim); + } + }
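Every kernel in this file follows the same loop skeleton the remainder test above exercises: split dim into full SIMD chunks plus a scalar tail. A trivial sketch of that decomposition (the helper name is illustrative):

    // Chunk/tail split used by the NEON kernels (16 lanes per iteration for i8).
    fn split(dim: usize, lanes: usize) -> (usize, usize) {
        (dim / lanes, dim % lanes) // (full SIMD chunks, scalar tail length)
    }

    fn main() {
        assert_eq!(split(17, 16), (1, 1));   // one 16-wide chunk + 1 tail element
        assert_eq!(split(127, 16), (7, 15)); // 7 chunks + 15 tail elements
        assert_eq!(split(128, 16), (8, 0));  // fully vectorized, no tail
    }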
+ + // NEON f64 tests + #[test] + fn test_neon_l2_squared_f64() { + let a: Vec<f64> = (0..128).map(|i| i as f64).collect(); + let b: Vec<f64> = (0..128).map(|i| (i + 1) as f64).collect(); + + let neon_result = l2_squared_f64(&a, &b, 128); + + // Should be 128.0 (each diff is 1, squared is 1, sum of 128 ones) + assert!((neon_result - 128.0).abs() < 1e-10); + } + + #[test] + fn test_neon_inner_product_f64() { + let a: Vec<f64> = (0..128).map(|i| i as f64 / 100.0).collect(); + let b: Vec<f64> = (0..128).map(|i| (128 - i) as f64 / 100.0).collect(); + + let neon_result = inner_product_f64(&a, &b, 128); + + let expected: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum(); + + assert!((neon_result - expected).abs() < 1e-10); + } + + #[test] + fn test_neon_cosine_f64() { + // Identical vectors should have distance 0 + let a: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0]; + let result = cosine_distance_f64(&a, &a, 4); + assert!(result.abs() < 1e-10); + + // Orthogonal vectors should have distance 1 + let a: Vec<f64> = vec![1.0, 0.0, 0.0, 0.0]; + let b: Vec<f64> = vec![0.0, 1.0, 0.0, 0.0]; + let result = cosine_distance_f64(&a, &b, 4); + assert!((result - 1.0).abs() < 1e-10); + } + + #[test] + fn test_neon_f64_remainder() { + // Test with non-aligned dimension + let a: Vec<f64> = (0..131).map(|i| i as f64).collect(); + let b: Vec<f64> = (0..131).map(|i| (i + 1) as f64).collect(); + + let neon_result = l2_squared_f64(&a, &b, 131); + assert!((neon_result - 131.0).abs() < 1e-10); + } + + #[test] + fn test_has_dotprod() { + // Just verify the detection doesn't crash + let _ = has_dotprod(); + } +} diff --git a/rust/vecsim/src/distance/simd/sse.rs b/rust/vecsim/src/distance/simd/sse.rs new file mode 100644 index 000000000..0c6d60ad6 --- /dev/null +++ b/rust/vecsim/src/distance/simd/sse.rs @@ -0,0 +1,292 @@ +//! SSE (Streaming SIMD Extensions) optimized distance functions. +//! +//! SSE provides 128-bit vectors, processing 4 f32 values at a time. +//! This is available on virtually all x86-64 processors. + +#![cfg(target_arch = "x86_64")] + +use crate::types::{DistanceType, VectorElement}; + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +/// Compute squared L2 distance using SSE. +/// +/// # Safety +/// - Pointers must be valid for `dim` elements +/// - SSE must be available (checked at runtime by caller) +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "sse")] +#[inline] +pub unsafe fn l2_squared_f32_sse(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut sum = _mm_setzero_ps(); + let chunks = dim / 4; + let remainder = dim % 4; + + // Process 4 elements at a time + for i in 0..chunks { + let offset = i * 4; + let va = _mm_loadu_ps(a.add(offset)); + let vb = _mm_loadu_ps(b.add(offset)); + let diff = _mm_sub_ps(va, vb); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + } + + // Horizontal sum of the 4 floats in sum = [a, b, c, d]: + // shuffle+add folds to [a+b, a+b, c+d, c+d], movehl+add_ss yields a+b+c+d + let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01); // [b, a, d, c] + let sums = _mm_add_ps(sum, shuf); // [a+b, a+b, c+d, c+d] + let shuf2 = _mm_movehl_ps(sums, sums); // [c+d, c+d, ?, ?] + let result = _mm_add_ss(sums, shuf2); // [a+b+c+d, ?, ?, ?] + let mut total = _mm_cvtss_f32(result); + + // Handle remainder + let base = chunks * 4; + for i in 0..remainder { + let diff = *a.add(base + i) - *b.add(base + i); + total += diff * diff; + } + + total +}
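The horizontal reduction inside l2_squared_f32_sse above, isolated as a helper for readability. This is the same shuffle/movehl sequence, shown on its own rather than as a new API:

    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Horizontal sum of the four f32 lanes of an SSE register [a, b, c, d].
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sse")]
    unsafe fn hsum_ps(v: __m128) -> f32 {
        let shuf = _mm_shuffle_ps(v, v, 0b10_11_00_01); // [b, a, d, c]
        let sums = _mm_add_ps(v, shuf);                 // [a+b, a+b, c+d, c+d]
        let hi = _mm_movehl_ps(sums, sums);             // [c+d, c+d, c+d, c+d]
        _mm_cvtss_f32(_mm_add_ss(sums, hi))             // a+b+c+d in lane 0
    }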
+/// Compute inner product using SSE.
+///
+/// # Safety
+/// - Pointers must be valid for `dim` elements
+/// - SSE must be available (checked at runtime by caller)
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "sse")]
+#[inline]
+pub unsafe fn inner_product_f32_sse(a: *const f32, b: *const f32, dim: usize) -> f32 {
+    let mut sum = _mm_setzero_ps();
+    let chunks = dim / 4;
+    let remainder = dim % 4;
+
+    // Process 4 elements at a time
+    for i in 0..chunks {
+        let offset = i * 4;
+        let va = _mm_loadu_ps(a.add(offset));
+        let vb = _mm_loadu_ps(b.add(offset));
+        sum = _mm_add_ps(sum, _mm_mul_ps(va, vb));
+    }
+
+    // Horizontal sum
+    let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
+    let sums = _mm_add_ps(sum, shuf);
+    let shuf2 = _mm_movehl_ps(sums, sums);
+    let result = _mm_add_ss(sums, shuf2);
+    let mut total = _mm_cvtss_f32(result);
+
+    // Handle remainder
+    let base = chunks * 4;
+    for i in 0..remainder {
+        total += *a.add(base + i) * *b.add(base + i);
+    }
+
+    total
+}
+
+/// Compute cosine distance using SSE.
+/// Returns 1 - cosine_similarity.
+///
+/// # Safety
+/// - Pointers must be valid for `dim` elements
+/// - SSE must be available (checked at runtime by caller)
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "sse")]
+#[inline]
+pub unsafe fn cosine_f32_sse(a: *const f32, b: *const f32, dim: usize) -> f32 {
+    let mut dot_sum = _mm_setzero_ps();
+    let mut norm_a_sum = _mm_setzero_ps();
+    let mut norm_b_sum = _mm_setzero_ps();
+
+    let chunks = dim / 4;
+    let remainder = dim % 4;
+
+    // Process 4 elements at a time
+    for i in 0..chunks {
+        let offset = i * 4;
+        let va = _mm_loadu_ps(a.add(offset));
+        let vb = _mm_loadu_ps(b.add(offset));
+
+        dot_sum = _mm_add_ps(dot_sum, _mm_mul_ps(va, vb));
+        norm_a_sum = _mm_add_ps(norm_a_sum, _mm_mul_ps(va, va));
+        norm_b_sum = _mm_add_ps(norm_b_sum, _mm_mul_ps(vb, vb));
+    }
+
+    // Horizontal sums
+    let hsum = |v: __m128| -> f32 {
+        let shuf = _mm_shuffle_ps(v, v, 0b10_11_00_01);
+        let sums = _mm_add_ps(v, shuf);
+        let shuf2 = _mm_movehl_ps(sums, sums);
+        let result = _mm_add_ss(sums, shuf2);
+        _mm_cvtss_f32(result)
+    };
+
+    let mut dot = hsum(dot_sum);
+    let mut norm_a = hsum(norm_a_sum);
+    let mut norm_b = hsum(norm_b_sum);
+
+    // Handle remainder
+    let base = chunks * 4;
+    for i in 0..remainder {
+        let va = *a.add(base + i);
+        let vb = *b.add(base + i);
+        dot += va * vb;
+        norm_a += va * va;
+        norm_b += vb * vb;
+    }
+
+    let denom = (norm_a * norm_b).sqrt();
+    if denom > 0.0 {
+        1.0 - (dot / denom)
+    } else {
+        0.0
+    }
+}
+
+/// Safe wrapper for L2 squared distance.
+#[inline]
+pub fn l2_squared_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("sse") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { l2_squared_f32_sse(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        crate::distance::l2::l2_squared_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for inner product.
+#[inline]
+pub fn inner_product_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("sse") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { inner_product_f32_sse(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        crate::distance::ip::inner_product_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for cosine distance.
+#[inline]
+pub fn cosine_distance_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("sse") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { cosine_f32_sse(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        crate::distance::cosine::cosine_distance_scalar(a, b, dim)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn is_sse_available() -> bool {
+        is_x86_feature_detected!("sse")
+    }
+
+    #[test]
+    fn test_l2_squared_sse() {
+        if !is_sse_available() {
+            return;
+        }
+
+        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
+
+        let result = unsafe { l2_squared_f32_sse(a.as_ptr(), b.as_ptr(), a.len()) };
+        // Each difference is 1.0, so squared sum = 8.0
+        assert!((result - 8.0).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_l2_squared_sse_remainder() {
+        if !is_sse_available() {
+            return;
+        }
+
+        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
+        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];
+
+        let result = unsafe { l2_squared_f32_sse(a.as_ptr(), b.as_ptr(), a.len()) };
+        assert!((result - 5.0).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_inner_product_sse() {
+        if !is_sse_available() {
+            return;
+        }
+
+        let a = vec![1.0f32, 2.0, 3.0, 4.0];
+        let b = vec![2.0f32, 3.0, 4.0, 5.0];
+
+        let result = unsafe { inner_product_f32_sse(a.as_ptr(), b.as_ptr(), a.len()) };
+        // 1*2 + 2*3 + 3*4 + 4*5 = 2 + 6 + 12 + 20 = 40
+        assert!((result - 40.0).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_cosine_sse() {
+        if !is_sse_available() {
+            return;
+        }
+
+        // Test with identical normalized vectors (cosine distance = 0)
+        let a = vec![0.6f32, 0.8, 0.0, 0.0];
+        let b = vec![0.6f32, 0.8, 0.0, 0.0];
+
+        let result = unsafe { cosine_f32_sse(a.as_ptr(), b.as_ptr(), a.len()) };
+        assert!(result.abs() < 1e-6);
+
+        // Test with orthogonal vectors (cosine distance = 1)
+        let a = vec![1.0f32, 0.0, 0.0, 0.0];
+        let b = vec![0.0f32, 1.0, 0.0, 0.0];
+
+        let result = unsafe { cosine_f32_sse(a.as_ptr(), b.as_ptr(), a.len()) };
+        assert!((result - 1.0).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_l2_safe_wrapper() {
+        if !is_sse_available() {
+            return;
+        }
+
+        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
+        let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
+
+        let sse_result = l2_squared_f32::<f32>(&a, &b, 128);
+        let scalar_result = crate::distance::l2::l2_squared_scalar(&a, &b, 128);
+
+        assert!((sse_result - scalar_result).abs() < 0.1);
+    }
+
+    #[test]
+    fn test_inner_product_safe_wrapper() {
+        if !is_sse_available() {
+            return;
+        }
+
+        let a: Vec<f32> = (0..128).map(|i| i as f32 / 100.0).collect();
+        let b: Vec<f32> = (0..128).map(|i| (128 - i) as f32 / 100.0).collect();
+
+        let sse_result = inner_product_f32::<f32>(&a, &b, 128);
+        let scalar_result = crate::distance::ip::inner_product_scalar(&a, &b, 128);
+
+        assert!((sse_result - scalar_result).abs() < 0.01);
+    }
+}
diff --git a/rust/vecsim/src/distance/simd/sse4.rs b/rust/vecsim/src/distance/simd/sse4.rs
new file mode 100644
index 000000000..d8c2a9c09
--- /dev/null
+++ b/rust/vecsim/src/distance/simd/sse4.rs
@@ -0,0 +1,324 @@
+//! SSE4.1 optimized distance functions.
+//!
+//! SSE4.1 provides the `_mm_dp_ps` instruction for single-instruction dot products,
+//! which is significantly faster than manual multiply-accumulate for inner products.
+//!
+//! Available on Intel Penryn+ (2008) and AMD Bulldozer+ (2011).
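A note on the `_mm_dp_ps(a, b, MASK)` control byte used throughout this file: the high nibble selects which of the four lane products enter the sum, and the low nibble selects which output lanes receive that sum. The `0xF1` used below therefore means "multiply all four lanes, write the sum to lane 0 only". A small illustrative test of those semantics (guarded by the same runtime check the rest of the file relies on):

    #[test]
    fn test_dp_ps_mask_semantics() {
        if !is_x86_feature_detected!("sse4.1") {
            return;
        }
        unsafe {
            let a = _mm_set_ps(4.0, 3.0, 2.0, 1.0); // lanes: [1, 2, 3, 4]
            let ones = _mm_set1_ps(1.0);
            // 0xF1: all four products summed, result written to lane 0.
            let dp = _mm_dp_ps(a, ones, 0xF1);
            assert_eq!(_mm_cvtss_f32(dp), 10.0);
            // 0xF0: same sum computed, but no output lane selected, so lane 0 is 0.
            let masked = _mm_dp_ps(a, ones, 0xF0);
            assert_eq!(_mm_cvtss_f32(masked), 0.0);
        }
    }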
+ +#![cfg(target_arch = "x86_64")] + +use crate::types::{DistanceType, VectorElement}; + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +/// SSE4.1 L2 squared distance for f32 vectors. +/// +/// Uses SSE4.1 `_mm_dp_ps` for efficient dot product computation. +/// +/// # Safety +/// - Pointers must be valid for `dim` elements +/// - SSE4.1 must be available (checked at runtime by caller) +#[target_feature(enable = "sse4.1")] +#[inline] +pub unsafe fn l2_squared_f32_sse4(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut total = 0.0f32; + let chunks = dim / 4; + let remainder = dim % 4; + + // Process 4 elements at a time + for i in 0..chunks { + let offset = i * 4; + let va = _mm_loadu_ps(a.add(offset)); + let vb = _mm_loadu_ps(b.add(offset)); + let diff = _mm_sub_ps(va, vb); + // _mm_dp_ps with mask 0xF1: multiply all 4 pairs, sum to lowest element + let sq_sum = _mm_dp_ps(diff, diff, 0xF1); + total += _mm_cvtss_f32(sq_sum); + } + + // Handle remainder + let base = chunks * 4; + for i in 0..remainder { + let diff = *a.add(base + i) - *b.add(base + i); + total += diff * diff; + } + + total +} + +/// SSE4.1 inner product for f32 vectors. +/// +/// Uses SSE4.1 `_mm_dp_ps` for efficient dot product computation. +/// +/// # Safety +/// - Pointers must be valid for `dim` elements +/// - SSE4.1 must be available (checked at runtime by caller) +#[target_feature(enable = "sse4.1")] +#[inline] +pub unsafe fn inner_product_f32_sse4(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut total = 0.0f32; + let chunks = dim / 4; + let remainder = dim % 4; + + // Process 4 elements at a time using _mm_dp_ps + for i in 0..chunks { + let offset = i * 4; + let va = _mm_loadu_ps(a.add(offset)); + let vb = _mm_loadu_ps(b.add(offset)); + // _mm_dp_ps with mask 0xF1: multiply all 4 pairs, sum to lowest element + let dp = _mm_dp_ps(va, vb, 0xF1); + total += _mm_cvtss_f32(dp); + } + + // Handle remainder + let base = chunks * 4; + for i in 0..remainder { + total += *a.add(base + i) * *b.add(base + i); + } + + total +} + +/// SSE4.1 cosine distance for f32 vectors. +/// +/// Uses SSE4.1 `_mm_dp_ps` for efficient dot product computation. +/// +/// # Safety +/// - Pointers must be valid for `dim` elements +/// - SSE4.1 must be available (checked at runtime by caller) +#[target_feature(enable = "sse4.1")] +#[inline] +pub unsafe fn cosine_distance_f32_sse4(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut dot = 0.0f32; + let mut norm_a = 0.0f32; + let mut norm_b = 0.0f32; + + let chunks = dim / 4; + let remainder = dim % 4; + + // Process 4 elements at a time using _mm_dp_ps + for i in 0..chunks { + let offset = i * 4; + let va = _mm_loadu_ps(a.add(offset)); + let vb = _mm_loadu_ps(b.add(offset)); + + // Dot product a·b + let dp = _mm_dp_ps(va, vb, 0xF1); + dot += _mm_cvtss_f32(dp); + + // Norm squared ||a||² + let na = _mm_dp_ps(va, va, 0xF1); + norm_a += _mm_cvtss_f32(na); + + // Norm squared ||b||² + let nb = _mm_dp_ps(vb, vb, 0xF1); + norm_b += _mm_cvtss_f32(nb); + } + + // Handle remainder + let base = chunks * 4; + for i in 0..remainder { + let va = *a.add(base + i); + let vb = *b.add(base + i); + dot += va * vb; + norm_a += va * va; + norm_b += vb * vb; + } + + let denom = (norm_a * norm_b).sqrt(); + if denom < 1e-30 { + return 1.0; + } + + let cosine_sim = (dot / denom).max(-1.0).min(1.0); + 1.0 - cosine_sim +} + +/// Safe wrapper for L2 squared distance. 
+#[inline]
+pub fn l2_squared_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("sse4.1") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { l2_squared_f32_sse4(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        crate::distance::l2::l2_squared_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for inner product.
+#[inline]
+pub fn inner_product_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("sse4.1") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { inner_product_f32_sse4(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        crate::distance::ip::inner_product_scalar(a, b, dim)
+    }
+}
+
+/// Safe wrapper for cosine distance.
+#[inline]
+pub fn cosine_distance_f32<T: VectorElement>(a: &[T], b: &[T], dim: usize) -> T::DistanceType {
+    if is_x86_feature_detected!("sse4.1") {
+        let a_f32: Vec<f32> = a.iter().map(|x| x.to_f32()).collect();
+        let b_f32: Vec<f32> = b.iter().map(|x| x.to_f32()).collect();
+
+        let result = unsafe { cosine_distance_f32_sse4(a_f32.as_ptr(), b_f32.as_ptr(), dim) };
+        T::DistanceType::from_f64(result as f64)
+    } else {
+        crate::distance::cosine::cosine_distance_scalar(a, b, dim)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn is_sse4_available() -> bool {
+        is_x86_feature_detected!("sse4.1")
+    }
+
+    #[test]
+    fn test_l2_squared_sse4() {
+        if !is_sse4_available() {
+            return;
+        }
+
+        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
+
+        let result = unsafe { l2_squared_f32_sse4(a.as_ptr(), b.as_ptr(), a.len()) };
+        // Each difference is 1.0, so squared sum = 8.0
+        assert!((result - 8.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_l2_squared_sse4_remainder() {
+        if !is_sse4_available() {
+            return;
+        }
+
+        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0];
+
+        let result = unsafe { l2_squared_f32_sse4(a.as_ptr(), b.as_ptr(), a.len()) };
+        assert!((result - 6.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_inner_product_sse4() {
+        if !is_sse4_available() {
+            return;
+        }
+
+        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+        let b = vec![1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
+
+        let result = unsafe { inner_product_f32_sse4(a.as_ptr(), b.as_ptr(), a.len()) };
+        // 1+2+3+4+5+6+7+8 = 36
+        assert!((result - 36.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_inner_product_sse4_remainder() {
+        if !is_sse4_available() {
+            return;
+        }
+
+        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
+        let b = vec![2.0f32, 2.0, 2.0, 2.0, 2.0];
+
+        let result = unsafe { inner_product_f32_sse4(a.as_ptr(), b.as_ptr(), a.len()) };
+        // 2+4+6+8+10 = 30
+        assert!((result - 30.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_cosine_sse4_identical() {
+        if !is_sse4_available() {
+            return;
+        }
+
+        let a = vec![0.6f32, 0.8, 0.0, 0.0];
+        let result = unsafe { cosine_distance_f32_sse4(a.as_ptr(), a.as_ptr(), a.len()) };
+        assert!(result.abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_cosine_sse4_orthogonal() {
+        if !is_sse4_available() {
+            return;
+        }
+
+        let a = vec![1.0f32, 0.0, 0.0, 0.0];
+        let b = vec![0.0f32, 1.0, 0.0, 0.0];
+
+        let result = unsafe { cosine_distance_f32_sse4(a.as_ptr(), b.as_ptr(), a.len()) };
+        assert!((result - 1.0).abs() < 1e-5);
+    }
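The safe wrappers above are the intended entry points: the unsafe SSE4.1 kernels are only reached after `is_x86_feature_detected!` confirms support, with a scalar fallback otherwise, so wrapper calls run correctly on any x86-64 CPU. A usage sketch in the same style as the neighboring tests (it treats the returned distance as `f32`, as the wrapper comparisons here already do):

    #[test]
    fn test_safe_wrapper_usage_example() {
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b: Vec<f32> = vec![4.0, 3.0, 2.0, 1.0];

        // Squared L2: (-3)^2 + (-1)^2 + 1^2 + 3^2 = 20
        let d = l2_squared_f32::<f32>(&a, &b, a.len());
        assert!((d - 20.0).abs() < 1e-5);

        // Inner product: 4 + 6 + 6 + 4 = 20
        let ip = inner_product_f32::<f32>(&a, &b, a.len());
        assert!((ip - 20.0).abs() < 1e-5);
    }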
+
+    #[test]
+    fn test_cosine_sse4_opposite() {
+        if !is_sse4_available() {
+            return;
+        }
+
+        let a = vec![1.0f32, 0.0, 0.0, 0.0];
+        let b = vec![-1.0f32, 0.0, 0.0, 0.0];
+
+        let result = unsafe { cosine_distance_f32_sse4(a.as_ptr(), b.as_ptr(), a.len()) };
+        assert!((result - 2.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_l2_safe_wrapper() {
+        if !is_sse4_available() {
+            return;
+        }
+
+        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
+        let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
+
+        let sse4_result = l2_squared_f32::<f32>(&a, &b, 128);
+        let scalar_result = crate::distance::l2::l2_squared_scalar(&a, &b, 128);
+
+        assert!((sse4_result - scalar_result).abs() < 0.1);
+    }
+
+    #[test]
+    fn test_inner_product_safe_wrapper() {
+        if !is_sse4_available() {
+            return;
+        }
+
+        let a: Vec<f32> = (0..128).map(|i| i as f32 / 100.0).collect();
+        let b: Vec<f32> = (0..128).map(|i| (128 - i) as f32 / 100.0).collect();
+
+        let sse4_result = inner_product_f32::<f32>(&a, &b, 128);
+        let scalar_result = crate::distance::ip::inner_product_scalar(&a, &b, 128);
+
+        assert!((sse4_result - scalar_result).abs() < 0.01);
+    }
+
+    #[test]
+    fn test_cosine_safe_wrapper() {
+        if !is_sse4_available() {
+            return;
+        }
+
+        let a: Vec<f32> = (1..129).map(|i| i as f32 / 100.0).collect();
+        let b: Vec<f32> = (1..129).map(|i| (129 - i) as f32 / 100.0).collect();
+
+        let sse4_result = cosine_distance_f32::<f32>(&a, &b, 128);
+        let scalar_result = crate::distance::cosine::cosine_distance_scalar(&a, &b, 128);
+
+        assert!((sse4_result - scalar_result).abs() < 0.001);
+    }
+}
diff --git a/rust/vecsim/src/distance/simd/sve.rs b/rust/vecsim/src/distance/simd/sve.rs
new file mode 100644
index 000000000..382d3c911
--- /dev/null
+++ b/rust/vecsim/src/distance/simd/sve.rs
@@ -0,0 +1,445 @@
+//! ARM SVE (Scalable Vector Extension) optimized distance functions.
+//!
+//! ARM SVE provides variable-length vector operations (128-2048 bits) that
+//! automatically scale to the hardware's vector length. This makes SVE code
+//! portable across different ARM implementations.
+//!
+//! Note: Full SVE intrinsics are currently unstable in Rust (nightly-only).
+//! This module provides:
+//! - Runtime detection of SVE support
+//! - Optimized implementations using available NEON with larger unrolling
+//! - Preparedness for native SVE when it stabilizes
+//!
+//! For maximum performance on SVE-capable hardware, compile with:
+//! `RUSTFLAGS="-C target-feature=+sve" cargo build --release`
+//!
+//! SVE-capable processors include:
+//! - AWS Graviton3 (256-bit vectors)
+//! - Fujitsu A64FX (512-bit vectors)
+//! - ARM Neoverse V1/V2
+
+#[cfg(target_arch = "aarch64")]
+use std::arch::aarch64::*;
+
+/// Check if SVE is available at runtime.
+///
+/// Note: This requires the `std_detect_dlsym_getauxval` feature to be stable
+/// for proper runtime detection. Currently returns false on stable Rust
+/// unless SVE was enabled at compile time.
+#[cfg(target_arch = "aarch64")]
+#[inline]
+pub fn has_sve() -> bool {
+    // SVE detection is not yet stable in Rust
+    // When stable, use: std::arch::is_aarch64_feature_detected!("sve")
+    // For now, we can't reliably detect SVE at runtime on stable Rust
+    #[cfg(target_feature = "sve")]
+    {
+        true
+    }
+    #[cfg(not(target_feature = "sve"))]
+    {
+        false
+    }
+}
+
+#[cfg(not(target_arch = "aarch64"))]
+#[inline]
+pub fn has_sve() -> bool {
+    false
+}
+
+/// Get the SVE vector length in bytes.
+/// Returns 0 if SVE is not available.
+#[cfg(target_arch = "aarch64")] +#[inline] +pub fn get_sve_vector_length() -> usize { + if has_sve() { + // When SVE is enabled at compile time, we can assume at least 128 bits + // The actual length depends on the hardware (128, 256, 512, 1024, 2048 bits) + #[cfg(target_feature = "sve")] + { + // In real SVE code, you would use svcntb() to get the vector length + // For now, assume minimum SVE length of 128 bits = 16 bytes + 16 + } + #[cfg(not(target_feature = "sve"))] + { + 0 + } + } else { + 0 + } +} + +#[cfg(not(target_arch = "aarch64"))] +#[inline] +pub fn get_sve_vector_length() -> usize { + 0 +} + +// When SVE intrinsics become stable, we would have implementations like: +// +// #[cfg(target_arch = "aarch64")] +// #[target_feature(enable = "sve")] +// pub unsafe fn l2_squared_f32_sve(a: *const f32, b: *const f32, dim: usize) -> f32 { +// use std::arch::aarch64::*; +// +// let vl = svcntw(); // Get vector length in words (f32s) +// let mut sum = svdup_n_f32(0.0); +// let mut offset = 0; +// +// // SVE has predicated operations for automatic remainder handling +// while offset < dim { +// let pred = svwhilelt_b32(offset as u64, dim as u64); +// let va = svld1_f32(pred, a.add(offset)); +// let vb = svld1_f32(pred, b.add(offset)); +// let diff = svsub_f32_x(pred, va, vb); +// sum = svmla_f32_x(pred, sum, diff, diff); +// offset += vl; +// } +// +// svaddv_f32(svptrue_b32(), sum) +// } + +// For now, provide optimized NEON with larger unrolling as a substitute +// These can be used on SVE-capable hardware when targeting NEON compatibility + +/// L2 squared distance with aggressive unrolling for SVE-class hardware. +/// Uses NEON but with 8x unrolling to better utilize wide pipelines. +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn l2_squared_f32_wide(a: *const f32, b: *const f32, dim: usize) -> f32 { + // 8x unrolling for wide execution units (32 f32s per iteration) + let mut sum0 = vdupq_n_f32(0.0); + let mut sum1 = vdupq_n_f32(0.0); + let mut sum2 = vdupq_n_f32(0.0); + let mut sum3 = vdupq_n_f32(0.0); + let mut sum4 = vdupq_n_f32(0.0); + let mut sum5 = vdupq_n_f32(0.0); + let mut sum6 = vdupq_n_f32(0.0); + let mut sum7 = vdupq_n_f32(0.0); + + let unroll = dim / 32 * 32; + let mut offset = 0; + + while offset < unroll { + let a0 = vld1q_f32(a.add(offset)); + let a1 = vld1q_f32(a.add(offset + 4)); + let a2 = vld1q_f32(a.add(offset + 8)); + let a3 = vld1q_f32(a.add(offset + 12)); + let a4 = vld1q_f32(a.add(offset + 16)); + let a5 = vld1q_f32(a.add(offset + 20)); + let a6 = vld1q_f32(a.add(offset + 24)); + let a7 = vld1q_f32(a.add(offset + 28)); + + let b0 = vld1q_f32(b.add(offset)); + let b1 = vld1q_f32(b.add(offset + 4)); + let b2 = vld1q_f32(b.add(offset + 8)); + let b3 = vld1q_f32(b.add(offset + 12)); + let b4 = vld1q_f32(b.add(offset + 16)); + let b5 = vld1q_f32(b.add(offset + 20)); + let b6 = vld1q_f32(b.add(offset + 24)); + let b7 = vld1q_f32(b.add(offset + 28)); + + let d0 = vsubq_f32(a0, b0); + let d1 = vsubq_f32(a1, b1); + let d2 = vsubq_f32(a2, b2); + let d3 = vsubq_f32(a3, b3); + let d4 = vsubq_f32(a4, b4); + let d5 = vsubq_f32(a5, b5); + let d6 = vsubq_f32(a6, b6); + let d7 = vsubq_f32(a7, b7); + + sum0 = vfmaq_f32(sum0, d0, d0); + sum1 = vfmaq_f32(sum1, d1, d1); + sum2 = vfmaq_f32(sum2, d2, d2); + sum3 = vfmaq_f32(sum3, d3, d3); + sum4 = vfmaq_f32(sum4, d4, d4); + sum5 = vfmaq_f32(sum5, d5, d5); + sum6 = vfmaq_f32(sum6, d6, d6); + sum7 = vfmaq_f32(sum7, d7, d7); + + offset += 32; + } + + // Reduce partial sums + sum0 = vaddq_f32(sum0, 
sum1); + sum2 = vaddq_f32(sum2, sum3); + sum4 = vaddq_f32(sum4, sum5); + sum6 = vaddq_f32(sum6, sum7); + sum0 = vaddq_f32(sum0, sum2); + sum4 = vaddq_f32(sum4, sum6); + sum0 = vaddq_f32(sum0, sum4); + + // Handle remaining 4-element chunks + while offset + 4 <= dim { + let va = vld1q_f32(a.add(offset)); + let vb = vld1q_f32(b.add(offset)); + let diff = vsubq_f32(va, vb); + sum0 = vfmaq_f32(sum0, diff, diff); + offset += 4; + } + + // Horizontal sum + let mut result = vaddvq_f32(sum0); + + // Handle remaining elements + while offset < dim { + let diff = *a.add(offset) - *b.add(offset); + result += diff * diff; + offset += 1; + } + + result +} + +/// Inner product with aggressive unrolling for SVE-class hardware. +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn inner_product_f32_wide(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut sum0 = vdupq_n_f32(0.0); + let mut sum1 = vdupq_n_f32(0.0); + let mut sum2 = vdupq_n_f32(0.0); + let mut sum3 = vdupq_n_f32(0.0); + let mut sum4 = vdupq_n_f32(0.0); + let mut sum5 = vdupq_n_f32(0.0); + let mut sum6 = vdupq_n_f32(0.0); + let mut sum7 = vdupq_n_f32(0.0); + + let unroll = dim / 32 * 32; + let mut offset = 0; + + while offset < unroll { + let a0 = vld1q_f32(a.add(offset)); + let a1 = vld1q_f32(a.add(offset + 4)); + let a2 = vld1q_f32(a.add(offset + 8)); + let a3 = vld1q_f32(a.add(offset + 12)); + let a4 = vld1q_f32(a.add(offset + 16)); + let a5 = vld1q_f32(a.add(offset + 20)); + let a6 = vld1q_f32(a.add(offset + 24)); + let a7 = vld1q_f32(a.add(offset + 28)); + + let b0 = vld1q_f32(b.add(offset)); + let b1 = vld1q_f32(b.add(offset + 4)); + let b2 = vld1q_f32(b.add(offset + 8)); + let b3 = vld1q_f32(b.add(offset + 12)); + let b4 = vld1q_f32(b.add(offset + 16)); + let b5 = vld1q_f32(b.add(offset + 20)); + let b6 = vld1q_f32(b.add(offset + 24)); + let b7 = vld1q_f32(b.add(offset + 28)); + + sum0 = vfmaq_f32(sum0, a0, b0); + sum1 = vfmaq_f32(sum1, a1, b1); + sum2 = vfmaq_f32(sum2, a2, b2); + sum3 = vfmaq_f32(sum3, a3, b3); + sum4 = vfmaq_f32(sum4, a4, b4); + sum5 = vfmaq_f32(sum5, a5, b5); + sum6 = vfmaq_f32(sum6, a6, b6); + sum7 = vfmaq_f32(sum7, a7, b7); + + offset += 32; + } + + // Reduce + sum0 = vaddq_f32(sum0, sum1); + sum2 = vaddq_f32(sum2, sum3); + sum4 = vaddq_f32(sum4, sum5); + sum6 = vaddq_f32(sum6, sum7); + sum0 = vaddq_f32(sum0, sum2); + sum4 = vaddq_f32(sum4, sum6); + sum0 = vaddq_f32(sum0, sum4); + + while offset + 4 <= dim { + let va = vld1q_f32(a.add(offset)); + let vb = vld1q_f32(b.add(offset)); + sum0 = vfmaq_f32(sum0, va, vb); + offset += 4; + } + + let mut result = vaddvq_f32(sum0); + + while offset < dim { + result += *a.add(offset) * *b.add(offset); + offset += 1; + } + + result +} + +/// Cosine distance with aggressive unrolling for SVE-class hardware. 
+#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn cosine_distance_f32_wide(a: *const f32, b: *const f32, dim: usize) -> f32 { + let mut dot0 = vdupq_n_f32(0.0); + let mut dot1 = vdupq_n_f32(0.0); + let mut norm_a0 = vdupq_n_f32(0.0); + let mut norm_a1 = vdupq_n_f32(0.0); + let mut norm_b0 = vdupq_n_f32(0.0); + let mut norm_b1 = vdupq_n_f32(0.0); + + let unroll = dim / 8 * 8; + let mut offset = 0; + + while offset < unroll { + let a0 = vld1q_f32(a.add(offset)); + let a1 = vld1q_f32(a.add(offset + 4)); + let b0 = vld1q_f32(b.add(offset)); + let b1 = vld1q_f32(b.add(offset + 4)); + + dot0 = vfmaq_f32(dot0, a0, b0); + dot1 = vfmaq_f32(dot1, a1, b1); + norm_a0 = vfmaq_f32(norm_a0, a0, a0); + norm_a1 = vfmaq_f32(norm_a1, a1, a1); + norm_b0 = vfmaq_f32(norm_b0, b0, b0); + norm_b1 = vfmaq_f32(norm_b1, b1, b1); + + offset += 8; + } + + let dot_sum = vaddq_f32(dot0, dot1); + let norm_a_sum = vaddq_f32(norm_a0, norm_a1); + let norm_b_sum = vaddq_f32(norm_b0, norm_b1); + + let mut dot = vaddvq_f32(dot_sum); + let mut norm_a = vaddvq_f32(norm_a_sum); + let mut norm_b = vaddvq_f32(norm_b_sum); + + while offset < dim { + let av = *a.add(offset); + let bv = *b.add(offset); + dot += av * bv; + norm_a += av * av; + norm_b += bv * bv; + offset += 1; + } + + let denom = (norm_a * norm_b).sqrt(); + if denom < 1e-30 { + return 1.0; + } + + let cosine_sim = (dot / denom).clamp(-1.0, 1.0); + 1.0 - cosine_sim +} + +// Safe wrappers + +/// Safe L2 squared distance with wide unrolling. +#[cfg(target_arch = "aarch64")] +pub fn l2_squared_f32(a: &[f32], b: &[f32], dim: usize) -> f32 { + debug_assert_eq!(a.len(), dim); + debug_assert_eq!(b.len(), dim); + unsafe { l2_squared_f32_wide(a.as_ptr(), b.as_ptr(), dim) } +} + +/// Safe inner product with wide unrolling. +#[cfg(target_arch = "aarch64")] +pub fn inner_product_f32(a: &[f32], b: &[f32], dim: usize) -> f32 { + debug_assert_eq!(a.len(), dim); + debug_assert_eq!(b.len(), dim); + unsafe { inner_product_f32_wide(a.as_ptr(), b.as_ptr(), dim) } +} + +/// Safe cosine distance with wide unrolling. 
+#[cfg(target_arch = "aarch64")]
+pub fn cosine_distance_f32(a: &[f32], b: &[f32], dim: usize) -> f32 {
+    debug_assert_eq!(a.len(), dim);
+    debug_assert_eq!(b.len(), dim);
+    unsafe { cosine_distance_f32_wide(a.as_ptr(), b.as_ptr(), dim) }
+}
+
+// Stubs for non-aarch64
+#[cfg(not(target_arch = "aarch64"))]
+pub fn l2_squared_f32(_a: &[f32], _b: &[f32], _dim: usize) -> f32 {
+    unimplemented!("SVE/wide NEON only available on aarch64")
+}
+
+#[cfg(not(target_arch = "aarch64"))]
+pub fn inner_product_f32(_a: &[f32], _b: &[f32], _dim: usize) -> f32 {
+    unimplemented!("SVE/wide NEON only available on aarch64")
+}
+
+#[cfg(not(target_arch = "aarch64"))]
+pub fn cosine_distance_f32(_a: &[f32], _b: &[f32], _dim: usize) -> f32 {
+    unimplemented!("SVE/wide NEON only available on aarch64")
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_has_sve() {
+        let has = has_sve();
+        println!("SVE available: {}", has);
+        // Just check it doesn't crash
+    }
+
+    #[test]
+    fn test_sve_vector_length() {
+        let len = get_sve_vector_length();
+        println!("SVE vector length: {} bytes", len);
+        // Either 0 (not available) or a power of 2 >= 16
+        assert!(len == 0 || (len >= 16 && len.is_power_of_two()));
+    }
+
+    #[cfg(target_arch = "aarch64")]
+    #[test]
+    fn test_wide_l2_squared() {
+        for &dim in &[32, 64, 100, 128, 256, 512] {
+            let a: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
+            let b: Vec<f32> = (0..dim).map(|i| (dim - i) as f32 / dim as f32).collect();
+
+            let wide_result = l2_squared_f32(&a, &b, dim);
+
+            // Compute scalar reference
+            let scalar_result: f32 = a
+                .iter()
+                .zip(b.iter())
+                .map(|(x, y)| (x - y) * (x - y))
+                .sum();
+
+            assert!(
+                (wide_result - scalar_result).abs() < 0.001,
+                "dim={}: wide={}, scalar={}",
+                dim,
+                wide_result,
+                scalar_result
+            );
+        }
+    }
+
+    #[cfg(target_arch = "aarch64")]
+    #[test]
+    fn test_wide_inner_product() {
+        for &dim in &[32, 64, 100, 128, 256, 512] {
+            let a: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
+            let b: Vec<f32> = (0..dim).map(|i| (dim - i) as f32 / dim as f32).collect();
+
+            let wide_result = inner_product_f32(&a, &b, dim);
+            let scalar_result: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
+
+            assert!(
+                (wide_result - scalar_result).abs() < 0.001,
+                "dim={}: wide={}, scalar={}",
+                dim,
+                wide_result,
+                scalar_result
+            );
+        }
+    }
+
+    #[cfg(target_arch = "aarch64")]
+    #[test]
+    fn test_wide_cosine_distance() {
+        let dim = 128;
+        let a: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
+
+        // Self-distance should be ~0
+        let self_dist = cosine_distance_f32(&a, &a, dim);
+        assert!(
+            self_dist.abs() < 0.001,
+            "Self cosine distance should be ~0, got {}",
+            self_dist
+        );
+    }
+}
diff --git a/rust/vecsim/src/e2e_tests.rs b/rust/vecsim/src/e2e_tests.rs
new file mode 100644
index 000000000..59b17377c
--- /dev/null
+++ b/rust/vecsim/src/e2e_tests.rs
@@ -0,0 +1,964 @@
+//! End-to-end integration tests for the VecSim library.
+//!
+//! These tests verify complete workflows from start to finish, simulating
+//! real-world usage scenarios including index lifecycle management,
+//! serialization/persistence, realistic workloads, and multi-index comparisons.
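Several tests below score HNSW against brute-force ground truth using recall@k: the fraction of the exact top-k labels that the approximate search also returned. A standalone sketch of that computation (hypothetical helper, not part of the module):

    use std::collections::HashSet;

    /// recall@k = |approx ∩ exact| / k
    fn recall_at_k(exact: &[u64], approx: &[u64], k: usize) -> f64 {
        let exact: HashSet<u64> = exact.iter().copied().collect();
        let hits = approx.iter().filter(|l| exact.contains(l)).count();
        hits as f64 / k as f64
    }

    // e.g. recall_at_k(&[1, 2, 3, 4], &[1, 2, 9, 4], 4) == 0.75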
+
+use crate::distance::Metric;
+use crate::index::{
+    BruteForceMulti, BruteForceParams, BruteForceSingle, HnswParams, HnswSingle, SvsParams,
+    SvsSingle, TieredParams, TieredSingle, VecSimIndex, WriteMode,
+};
+use crate::query::QueryParams;
+use rand::prelude::*;
+use std::collections::HashSet;
+
+// =============================================================================
+// Test Data Generators
+// =============================================================================
+
+/// Generate uniformly random vectors with components in [-1, 1).
+fn generate_random_vectors(count: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
+    let mut rng = StdRng::seed_from_u64(seed);
+    (0..count)
+        .map(|_| (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect())
+        .collect()
+}
+
+/// Normalize a vector to unit length.
+fn normalize_vector(v: &[f32]) -> Vec<f32> {
+    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
+    if norm > 0.0 {
+        v.iter().map(|x| x / norm).collect()
+    } else {
+        v.to_vec()
+    }
+}
+
+/// Generate normalized vectors for cosine similarity testing.
+fn generate_normalized_vectors(count: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
+    generate_random_vectors(count, dim, seed)
+        .into_iter()
+        .map(|v| normalize_vector(&v))
+        .collect()
+}
+
+// =============================================================================
+// Index Lifecycle E2E Tests
+// =============================================================================
+
+#[test]
+fn test_e2e_brute_force_complete_lifecycle() {
+    // E2E test: Create → Add → Query → Update → Delete → Query → Clear
+    let dim = 32;
+    let params = BruteForceParams::new(dim, Metric::L2);
+    let mut index = BruteForceSingle::<f32>::new(params);
+
+    // Phase 1: Initial population
+    let vectors = generate_random_vectors(100, dim, 12345);
+    for (i, v) in vectors.iter().enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+    assert_eq!(index.index_size(), 100);
+
+    // Phase 2: Query and verify results
+    let query = &vectors[0];
+    let results = index.top_k_query(query, 5, None).unwrap();
+    assert_eq!(results.results.len(), 5);
+    assert_eq!(results.results[0].label, 0); // Should find itself
+
+    // Phase 3: Update a vector (delete + add with same label)
+    let new_vector: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
+    index.delete_vector(50).unwrap();
+    index.add_vector(&new_vector, 50).unwrap();
+    assert_eq!(index.index_size(), 100);
+
+    // Verify update worked
+    let retrieved = index.get_vector(50).unwrap();
+    for (a, b) in retrieved.iter().zip(new_vector.iter()) {
+        assert!((a - b).abs() < 1e-6);
+    }
+
+    // Phase 4: Delete multiple vectors
+    for label in 90..100 {
+        index.delete_vector(label).unwrap();
+    }
+    assert_eq!(index.index_size(), 90);
+
+    // Query should not return deleted vectors
+    let results = index.top_k_query(&vectors[95], 100, None).unwrap();
+    for r in &results.results {
+        assert!(r.label < 90 || r.label == 50);
+    }
+
+    // Phase 5: Clear index
+    index.clear();
+    assert_eq!(index.index_size(), 0);
+
+    // Verify cleared index works for new additions
+    index.add_vector(&vectors[0], 1000).unwrap();
+    assert_eq!(index.index_size(), 1);
+}
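Phase 3 above shows the update idiom this API expects: delete the old vector, then add the new one under the same label. A hypothetical convenience wrapper over those two calls (panics on error, for brevity):

    /// Replace the vector stored under `label` with `v`.
    fn update_vector(index: &mut BruteForceSingle<f32>, v: &[f32], label: u64) {
        // For single-value indices, add_vector with an existing label also
        // updates in place (see test_e2e_duplicate_label_handling below);
        // delete-then-add works uniformly for multi-value indices as well.
        index.delete_vector(label).unwrap();
        index.add_vector(v, label).unwrap();
    }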
+#[test]
+fn test_e2e_hnsw_complete_lifecycle() {
+    // E2E test for HNSW index lifecycle
+    // Use random vectors (not clustered) for reliable HNSW graph connectivity
+    let dim = 64;
+    let params = HnswParams::new(dim, Metric::L2)
+        .with_m(16)
+        .with_ef_construction(100);
+    let mut index = HnswSingle::<f32>::new(params);
+
+    // Phase 1: Build index with 500 random vectors
+    let vectors = generate_random_vectors(500, dim, 42);
+    for (i, v) in vectors.iter().enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+    assert_eq!(index.index_size(), 500);
+
+    // Phase 2: Query with different ef_runtime values
+    let query = &vectors[0];
+
+    // Low ef_runtime (fast but lower recall)
+    let params_low = QueryParams::new().with_ef_runtime(10);
+    let results_low = index.top_k_query(query, 10, Some(&params_low)).unwrap();
+    assert!(!results_low.results.is_empty());
+
+    // High ef_runtime (slower but higher recall)
+    let params_high = QueryParams::new().with_ef_runtime(200);
+    let results_high = index.top_k_query(query, 10, Some(&params_high)).unwrap();
+    assert_eq!(results_high.results[0].label, 0); // Should find itself with high ef
+
+    // Phase 3: Delete and verify
+    index.delete_vector(0).unwrap();
+    let results = index.top_k_query(query, 10, Some(&params_high)).unwrap();
+    for r in &results.results {
+        assert_ne!(r.label, 0);
+    }
+
+    // Phase 4: Range query - use radius appropriate for 64-dim random vectors in [-1,1]
+    let range_results = index.range_query(&vectors[100], 5.0, None).unwrap();
+    assert!(!range_results.results.is_empty());
+}
+
+#[test]
+fn test_e2e_multi_value_index_lifecycle() {
+    // E2E test for multi-value indices (multiple vectors per label)
+    let dim = 16;
+    let params = BruteForceParams::new(dim, Metric::L2);
+    let mut index = BruteForceMulti::<f32>::new(params);
+
+    // Phase 1: Add multiple vectors per label (simulating multiple embeddings per document)
+    let vectors = generate_random_vectors(300, dim, 9999);
+
+    // Labels 1-10 each get 10 vectors, labels 11-100 get 2 vectors each
+    for label in 1..=10u64 {
+        for j in 0..10 {
+            let idx = ((label - 1) * 10 + j) as usize;
+            index.add_vector(&vectors[idx], label).unwrap();
+        }
+    }
+    for label in 11..=100u64 {
+        for j in 0..2 {
+            let idx = 100 + ((label - 11) * 2 + j) as usize;
+            index.add_vector(&vectors[idx], label).unwrap();
+        }
+    }
+
+    assert_eq!(index.index_size(), 280); // 10*10 + 90*2
+    assert_eq!(index.label_count(1), 10);
+    assert_eq!(index.label_count(50), 2);
+
+    // Phase 2: Query returns results (multi-value returns unique labels per query)
+    let query = &vectors[0];
+    let results = index.top_k_query(query, 20, None).unwrap();
+    assert!(!results.results.is_empty());
+    // Multi-value index should return unique labels (best match per label)
+    let unique_labels: HashSet<_> = results.results.iter().map(|r| r.label).collect();
+    // The number of unique labels should equal or be close to the result count
+    assert!(unique_labels.len() >= results.results.len() / 2);
+
+    // Phase 3: Delete a label that has multiple vectors
+    index.delete_vector(1).unwrap(); // Deletes all vectors for label 1
+    assert_eq!(index.label_count(1), 0);
+    assert!(!index.contains(1));
+
+    // Phase 4: Verify queries don't return deleted label
+    let results = index.top_k_query(query, 100, None).unwrap();
+    for r in &results.results {
+        assert_ne!(r.label, 1);
+    }
+}
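The HNSW lifecycle test above exercises the main search-time knob: `ef_runtime` bounds the candidate list HNSW keeps while walking the graph, trading latency for recall. A hypothetical per-query policy built on the same `QueryParams` API (the `interactive` flag and the `2 * k` / `10 * k` heuristics are illustrative, not part of the library):

    fn search_with_budget(index: &HnswSingle<f32>, query: &[f32], k: usize, interactive: bool) {
        // ef_runtime should generally be at least k; larger values visit more
        // graph nodes per query and approach exact (brute-force) results.
        let ef = if interactive { 2 * k } else { 10 * k };
        let params = QueryParams::new().with_ef_runtime(ef);
        let results = index.top_k_query(query, k, Some(&params)).unwrap();
        assert!(results.results.len() <= k);
    }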
+// =============================================================================
+// Serialization/Persistence E2E Tests
+// =============================================================================
+
+#[test]
+fn test_e2e_brute_force_persistence_roundtrip() {
+    // E2E test: Create → Populate → Save → Load → Verify → Continue using
+    let dim = 48;
+    let params = BruteForceParams::new(dim, Metric::L2);
+    let mut original = BruteForceSingle::<f32>::new(params.clone());
+
+    // Populate with vectors
+    let vectors = generate_random_vectors(200, dim, 11111);
+    for (i, v) in vectors.iter().enumerate() {
+        original.add_vector(v, i as u64).unwrap();
+    }
+
+    // Delete some to test fragmentation handling
+    for label in [10, 50, 100, 150] {
+        original.delete_vector(label).unwrap();
+    }
+
+    // Save to buffer
+    let mut buffer = Vec::new();
+    original.save(&mut buffer).unwrap();
+
+    // Load from buffer
+    let mut loaded = BruteForceSingle::<f32>::load(&mut buffer.as_slice()).unwrap();
+
+    // Verify metadata
+    assert_eq!(original.index_size(), loaded.index_size());
+    assert_eq!(original.dimension(), loaded.dimension());
+
+    // Verify queries produce same results
+    let query = &vectors[0];
+    let orig_results = original.top_k_query(query, 10, None).unwrap();
+    let loaded_results = loaded.top_k_query(query, 10, None).unwrap();
+
+    // Both should return results (exact count may vary based on available vectors)
+    assert!(!orig_results.results.is_empty());
+    assert!(!loaded_results.results.is_empty());
+
+    // Top result should be the same (query vector itself, label 0)
+    assert_eq!(orig_results.results[0].label, loaded_results.results[0].label);
+    assert!((orig_results.results[0].distance - loaded_results.results[0].distance).abs() < 1e-6);
+
+    // Verify deleted vectors are still deleted
+    assert!(!loaded.contains(10));
+    assert!(!loaded.contains(50));
+
+    // Verify loaded index can continue operations
+    let new_vec: Vec<f32> = (0..dim).map(|_| 0.5).collect();
+    loaded.add_vector(&new_vec, 9999).unwrap();
+    assert!(loaded.contains(9999));
+    assert_eq!(loaded.index_size(), original.index_size() + 1);
+}
+
+#[test]
+fn test_e2e_hnsw_persistence_roundtrip() {
+    // E2E test for HNSW serialization
+    let dim = 32;
+    let params = HnswParams::new(dim, Metric::InnerProduct)
+        .with_m(12)
+        .with_ef_construction(50);
+    let mut original = HnswSingle::<f32>::new(params.clone());
+
+    // Populate with normalized vectors (for inner product)
+    let vectors = generate_normalized_vectors(150, dim, 22222);
+    for (i, v) in vectors.iter().enumerate() {
+        original.add_vector(v, i as u64).unwrap();
+    }
+
+    // Save and load
+    let mut buffer = Vec::new();
+    original.save(&mut buffer).unwrap();
+    let loaded = HnswSingle::<f32>::load(&mut buffer.as_slice()).unwrap();
+
+    // Verify structure
+    assert_eq!(original.index_size(), loaded.index_size());
+
+    // Verify queries with high ef for exact comparison
+    let query_params = QueryParams::new().with_ef_runtime(150);
+    for query_label in [0, 50, 100] {
+        let query = &vectors[query_label];
+        let orig_results = original.top_k_query(query, 5, Some(&query_params)).unwrap();
+        let loaded_results = loaded.top_k_query(query, 5, Some(&query_params)).unwrap();
+
+        // With high ef, results should match exactly
+        for (o, l) in orig_results.results.iter().zip(loaded_results.results.iter()) {
+            assert_eq!(o.label, l.label);
+        }
+    }
+}
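Both roundtrip tests serialize into an in-memory `Vec<u8>`. A hypothetical disk-persistence sketch, assuming `save`/`load` are generic over `io::Write`/`io::Read` as the buffer-based usage above suggests (unwraps for brevity):

    use std::fs::File;
    use std::io::{BufReader, BufWriter};

    fn save_to_disk(index: &BruteForceSingle<f32>, path: &str) {
        let mut w = BufWriter::new(File::create(path).unwrap());
        index.save(&mut w).unwrap();
    }

    fn load_from_disk(path: &str) -> BruteForceSingle<f32> {
        let mut r = BufReader::new(File::open(path).unwrap());
        BruteForceSingle::<f32>::load(&mut r).unwrap()
    }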
+// =============================================================================
+// Real-world Workload E2E Tests
+// =============================================================================
+
+#[test]
+fn test_e2e_realistic_document_embedding_workflow() {
+    // Simulates a document embedding use case:
+    // 1. Bulk load initial corpus
+    // 2. Query for similar documents
+    // 3. Add new documents over time
+    // 4. Delete outdated documents
+    // 5. Re-query to verify updates
+
+    let dim = 128; // Typical embedding dimension
+    let params = HnswParams::new(dim, Metric::Cosine)
+        .with_m(16)
+        .with_ef_construction(100);
+    let mut index = HnswSingle::<f32>::new(params);
+
+    // Phase 1: Bulk load initial corpus (1000 documents)
+    let initial_docs = generate_normalized_vectors(1000, dim, 33333);
+    for (i, v) in initial_docs.iter().enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+    assert_eq!(index.index_size(), 1000);
+
+    // Phase 2: Query for similar documents
+    let query_doc = &initial_docs[500];
+    let params = QueryParams::new().with_ef_runtime(50);
+    let results = index.top_k_query(query_doc, 10, Some(&params)).unwrap();
+    assert_eq!(results.results[0].label, 500);
+
+    // Phase 3: Add new documents (simulating updates)
+    let new_docs = generate_normalized_vectors(100, dim, 44444);
+    for (i, v) in new_docs.iter().enumerate() {
+        index.add_vector(v, 1000 + i as u64).unwrap();
+    }
+    assert_eq!(index.index_size(), 1100);
+
+    // Phase 4: Delete outdated documents
+    for label in 0..50u64 {
+        index.delete_vector(label).unwrap();
+    }
+    assert_eq!(index.index_size(), 1050);
+
+    // Phase 5: Re-query and verify deleted docs aren't returned
+    let results = index.top_k_query(&initial_docs[25], 100, Some(&params)).unwrap();
+    for r in &results.results {
+        assert!(r.label >= 50);
+    }
+
+    // Verify fragmentation is reasonable
+    let frag = index.fragmentation();
+    assert!(frag < 0.1); // Less than 10% fragmentation
+}
+
+#[test]
+fn test_e2e_image_similarity_workflow() {
+    // Simulates an image similarity search use case with multi-value index:
+    // each image has multiple feature vectors (e.g., from different regions).
+
+    let dim = 256; // Image feature dimension
+    let params = BruteForceParams::new(dim, Metric::L2);
+    let mut index = BruteForceMulti::<f32>::new(params);
+
+    // Each image (label) gets 4 feature vectors (e.g., 4 crops)
+    let num_images = 100;
+    let features_per_image = 4;
+
+    let mut rng = StdRng::seed_from_u64(55555);
+    for img_id in 0..num_images {
+        for _ in 0..features_per_image {
+            let feature: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
+            index.add_vector(&feature, img_id as u64).unwrap();
+        }
+    }
+
+    assert_eq!(index.index_size(), num_images * features_per_image);
+
+    // Query returns results - multi-value indices return unique labels
+    let query: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
+    let results = index.top_k_query(&query, 20, None).unwrap();
+
+    // Results should be non-empty and have unique labels
+    assert!(!results.results.is_empty());
+    let labels: Vec<_> = results.results.iter().map(|r| r.label).collect();
+    let unique: HashSet<_> = labels.iter().cloned().collect();
+    // Multi-value index returns one result per label (best match)
+    assert!(unique.len() >= labels.len() / 2);
+}
+
+#[test]
+fn test_e2e_batch_operations_workflow() {
+    // Test batch iterator for paginated results
+    let dim = 64;
+    let params = BruteForceParams::new(dim, Metric::L2);
+    let mut index = BruteForceSingle::<f32>::new(params);
+
+    // Add vectors
+    let vectors = generate_random_vectors(500, dim, 66666);
+    for (i, v) in vectors.iter().enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+
+    // Use batch iterator to process all vectors
+    let query = &vectors[0];
+    let batch_size = 50;
+    let mut iterator = index.batch_iterator(query, None).unwrap();
+
+    let mut all_labels = Vec::new();
+    let mut batch_count = 0;
+    while iterator.has_next() {
+        if let Some(batch) = iterator.next_batch(batch_size) {
assert!(batch.len() <= batch_size); + all_labels.extend(batch.iter().map(|(_, label, _)| *label)); + batch_count += 1; + } + } + + // Should have processed all vectors + assert_eq!(all_labels.len(), 500); + assert_eq!(batch_count, 10); // 500 / 50 = 10 batches +} + +// ============================================================================= +// Multi-Index Comparison E2E Tests +// ============================================================================= + +#[test] +fn test_e2e_brute_force_vs_hnsw_recall() { + // Verify HNSW recall against BruteForce ground truth + let dim = 64; + let num_vectors = 1000; + let num_queries = 50; + let k = 10; + + // Create both indices with same data + let bf_params = BruteForceParams::new(dim, Metric::L2); + let hnsw_params = HnswParams::new(dim, Metric::L2) + .with_m(16) + .with_ef_construction(200); + + let mut bf_index = BruteForceSingle::::new(bf_params); + let mut hnsw_index = HnswSingle::::new(hnsw_params); + + // Same vectors in both + let vectors = generate_random_vectors(num_vectors, dim, 77777); + for (i, v) in vectors.iter().enumerate() { + bf_index.add_vector(v, i as u64).unwrap(); + hnsw_index.add_vector(v, i as u64).unwrap(); + } + + // Generate random queries + let queries = generate_random_vectors(num_queries, dim, 88888); + + // Compare results with high ef_runtime + let hnsw_params = QueryParams::new().with_ef_runtime(100); + let mut total_recall = 0.0; + + for query in &queries { + let bf_results = bf_index.top_k_query(query, k, None).unwrap(); + let hnsw_results = hnsw_index.top_k_query(query, k, Some(&hnsw_params)).unwrap(); + + // Calculate recall + let bf_labels: HashSet<_> = bf_results.results.iter().map(|r| r.label).collect(); + let hnsw_labels: HashSet<_> = hnsw_results.results.iter().map(|r| r.label).collect(); + let intersection = bf_labels.intersection(&hnsw_labels).count(); + let recall = intersection as f64 / k as f64; + total_recall += recall; + } + + let avg_recall = total_recall / num_queries as f64; + assert!( + avg_recall >= 0.9, + "HNSW recall {} should be >= 0.9", + avg_recall + ); +} + +#[test] +fn test_e2e_index_type_consistency() { + // Verify all index types return consistent results for exact match + let dim = 16; + let vectors = generate_random_vectors(100, dim, 99999); + + // Create all single-value index types + let bf_params = BruteForceParams::new(dim, Metric::L2); + let hnsw_params = HnswParams::new(dim, Metric::L2).with_m(8).with_ef_construction(50); + let svs_params = SvsParams::new(dim, Metric::L2); + + let mut bf = BruteForceSingle::::new(bf_params); + let mut hnsw = HnswSingle::::new(hnsw_params); + let mut svs = SvsSingle::::new(svs_params); + + // Add same vectors to all + for (i, v) in vectors.iter().enumerate() { + bf.add_vector(v, i as u64).unwrap(); + hnsw.add_vector(v, i as u64).unwrap(); + svs.add_vector(v, i as u64).unwrap(); + } + + // Query for an existing vector - all should return it as top result + let query = &vectors[42]; + let hnsw_params = QueryParams::new().with_ef_runtime(100); + + let bf_result = bf.top_k_query(query, 1, None).unwrap(); + let hnsw_result = hnsw.top_k_query(query, 1, Some(&hnsw_params)).unwrap(); + let svs_result = svs.top_k_query(query, 1, None).unwrap(); + + assert_eq!(bf_result.results[0].label, 42); + assert_eq!(hnsw_result.results[0].label, 42); + assert_eq!(svs_result.results[0].label, 42); + + // All should have near-zero distance + assert!(bf_result.results[0].distance < 1e-6); + assert!(hnsw_result.results[0].distance < 1e-6); + 
assert!(svs_result.results[0].distance < 1e-6); +} + +// ============================================================================= +// Tiered Index E2E Tests +// ============================================================================= + +#[test] +fn test_e2e_tiered_index_flush_workflow() { + // E2E test for tiered index: write buffering and flush to HNSW + let dim = 32; + let tiered_params = TieredParams::new(dim, Metric::L2) + .with_m(8) + .with_ef_construction(50) + .with_flat_buffer_limit(50) + .with_write_mode(WriteMode::Async); + + let mut index = TieredSingle::::new(tiered_params); + + let vectors = generate_random_vectors(150, dim, 11111); + + // Phase 1: Add vectors below buffer limit - should go to flat + for (i, v) in vectors.iter().take(40).enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + assert_eq!(index.flat_size(), 40); + assert_eq!(index.hnsw_size(), 0); + assert_eq!(index.write_mode(), WriteMode::Async); + + // Phase 2: Add more to exceed buffer limit - should trigger InPlace mode + for (i, v) in vectors.iter().skip(40).take(30).enumerate() { + index.add_vector(v, (40 + i) as u64).unwrap(); + } + + // After exceeding limit, new vectors go directly to HNSW + assert!(index.flat_size() <= 50); + + // Phase 3: Flush to move all to HNSW + index.flush().unwrap(); + + assert_eq!(index.flat_size(), 0); + assert_eq!(index.hnsw_size(), 70); + + // Phase 4: Continue adding - should go to flat again + for (i, v) in vectors.iter().skip(70).enumerate() { + index.add_vector(v, (70 + i) as u64).unwrap(); + } + + // Queries should work across both tiers + let query = &vectors[0]; + let results = index.top_k_query(query, 10, None).unwrap(); + assert!(!results.results.is_empty()); +} + +#[test] +fn test_e2e_tiered_index_query_merging() { + // Verify tiered index correctly merges results from both tiers + let dim = 16; + let tiered_params = TieredParams::new(dim, Metric::L2) + .with_m(8) + .with_ef_construction(50) + .with_flat_buffer_limit(100); + + let mut tiered = TieredSingle::::new(tiered_params); + let bf_params = BruteForceParams::new(dim, Metric::L2); + let mut bf = BruteForceSingle::::new(bf_params); + + // Add vectors to both + let vectors = generate_random_vectors(80, dim, 22222); + for (i, v) in vectors.iter().enumerate() { + tiered.add_vector(v, i as u64).unwrap(); + bf.add_vector(v, i as u64).unwrap(); + } + + // Flush half to HNSW + tiered.flush().unwrap(); + + // Add more to flat buffer + let more_vectors = generate_random_vectors(30, dim, 33333); + for (i, v) in more_vectors.iter().enumerate() { + tiered.add_vector(v, (80 + i) as u64).unwrap(); + bf.add_vector(v, (80 + i) as u64).unwrap(); + } + + // Now tiered has data in both tiers + assert!(tiered.hnsw_size() > 0); + assert!(tiered.flat_size() > 0); + + // Query should return merged results + let query = &vectors[0]; + let tiered_results = tiered.top_k_query(query, 10, None).unwrap(); + let bf_results = bf.top_k_query(query, 10, None).unwrap(); + + // Top result should be the same (exact match) + assert_eq!(tiered_results.results[0].label, bf_results.results[0].label); +} + +// ============================================================================= +// Filtered Query E2E Tests +// ============================================================================= + +#[test] +fn test_e2e_filtered_queries_workflow() { + // E2E test for filtered queries (e.g., search within category) + let dim = 32; + let params = BruteForceParams::new(dim, Metric::L2); + let mut index = BruteForceSingle::::new(params); 
+ + // Add vectors with labels encoding category (label % 10 = category) + let vectors = generate_random_vectors(500, dim, 44444); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + // Query with filter: only return even labels + let query = &vectors[100]; + let filter = |label: u64| label % 2 == 0; + let params = QueryParams::new().with_filter(filter); + + let results = index.top_k_query(query, 20, Some(¶ms)).unwrap(); + + // All returned labels should be even + for r in &results.results { + assert_eq!(r.label % 2, 0, "Label {} should be even", r.label); + } + + // Should still get requested number of results (if enough pass filter) + assert_eq!(results.results.len(), 20); +} + +#[test] +fn test_e2e_filtered_range_query() { + // E2E test for filtered range queries + let dim = 16; + let params = BruteForceParams::new(dim, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + // Create vectors at known distances from origin + for i in 0..100u64 { + let mut v = vec![0.0f32; dim]; + v[0] = i as f32; // Distance from origin = i + index.add_vector(&v, i).unwrap(); + } + + // Range query with filter + let query = vec![0.0f32; dim]; + let filter = |label: u64| label >= 10 && label < 20; + let params = QueryParams::new().with_filter(filter); + + // Radius of 400 (squared L2) should include labels 0-20 + let results = index.range_query(&query, 400.0, Some(¶ms)).unwrap(); + + // Should only get labels 10-19 (in range AND passing filter) + assert_eq!(results.results.len(), 10); + for r in &results.results { + assert!(r.label >= 10 && r.label < 20); + } +} + +// ============================================================================= +// Error Handling E2E Tests +// ============================================================================= + +#[test] +fn test_e2e_error_recovery_workflow() { + // Test that index remains usable after errors + let dim = 8; + let params = BruteForceParams::new(dim, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + // Add some valid vectors + let v1: Vec = vec![1.0; dim]; + let v2: Vec = vec![2.0; dim]; + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + + // Try to add wrong dimension (should error) + let wrong_dim: Vec = vec![1.0; dim + 1]; + let result = index.add_vector(&wrong_dim, 3); + assert!(result.is_err()); + + // Index should still be usable + assert_eq!(index.index_size(), 2); + + // Try to delete non-existent label (should error) + let result = index.delete_vector(999); + assert!(result.is_err()); + + // Index should still be usable + let v3: Vec = vec![3.0; dim]; + index.add_vector(&v3, 3).unwrap(); + assert_eq!(index.index_size(), 3); + + // Queries should work + let results = index.top_k_query(&v1, 3, None).unwrap(); + assert_eq!(results.results.len(), 3); +} + +#[test] +fn test_e2e_duplicate_label_handling() { + // Test handling of duplicate labels in single-value index + // Single-value indices UPDATE the vector when using same label (not error) + let dim = 8; + let params = BruteForceParams::new(dim, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1: Vec = vec![1.0; dim]; + let v2: Vec = vec![2.0; dim]; + + // Add first vector + index.add_vector(&v1, 1).unwrap(); + assert_eq!(index.index_size(), 1); + + // Add with same label - should UPDATE (not error), size stays 1 + let result = index.add_vector(&v2, 1); + assert!(result.is_ok()); + assert_eq!(index.index_size(), 1); // Still only 1 vector + + // Vector should be updated 
to v2 + let retrieved = index.get_vector(1).unwrap(); + assert!((retrieved[0] - 2.0).abs() < 1e-6); +} + +// ============================================================================= +// Metrics Comparison E2E Tests +// ============================================================================= + +#[test] +fn test_e2e_all_metrics_workflow() { + // Test complete workflow with all metric types + let dim = 16; + let vectors = generate_random_vectors(100, dim, 55555); + let normalized = generate_normalized_vectors(100, dim, 55555); + + // Test L2 + { + let params = BruteForceParams::new(dim, Metric::L2); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + let results = index.top_k_query(&vectors[0], 5, None).unwrap(); + assert_eq!(results.results[0].label, 0); + assert!(results.results[0].distance < 1e-6); // L2 squared should be ~0 + } + + // Test Inner Product + { + let params = BruteForceParams::new(dim, Metric::InnerProduct); + let mut index = BruteForceSingle::::new(params); + for (i, v) in normalized.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + let results = index.top_k_query(&normalized[0], 5, None).unwrap(); + assert_eq!(results.results[0].label, 0); + // For normalized vectors, IP with self = 1, distance = 1 - 1 = 0 + } + + // Test Cosine + { + let params = BruteForceParams::new(dim, Metric::Cosine); + let mut index = BruteForceSingle::::new(params); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + let results = index.top_k_query(&vectors[0], 5, None).unwrap(); + assert_eq!(results.results[0].label, 0); + assert!(results.results[0].distance < 1e-6); // Cosine distance with self = 0 + } +} + +// ============================================================================= +// Capacity and Memory E2E Tests +// ============================================================================= + +#[test] +fn test_e2e_capacity_enforcement() { + // Test that capacity limits are enforced + let dim = 8; + let max_capacity = 50; + let params = BruteForceParams::new(dim, Metric::L2); + // Use BruteForceSingle::with_capacity() for actual max capacity enforcement + let mut index = BruteForceSingle::::with_capacity(params, max_capacity); + + let vectors = generate_random_vectors(100, dim, 66666); + + // Add up to capacity + for (i, v) in vectors.iter().take(max_capacity).enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + assert_eq!(index.index_size(), max_capacity); + + // Adding more should fail + let result = index.add_vector(&vectors[max_capacity], max_capacity as u64); + assert!(result.is_err()); + + // Index should still be at capacity + assert_eq!(index.index_size(), max_capacity); + + // Delete one and add should work + index.delete_vector(0).unwrap(); + index.add_vector(&vectors[max_capacity], max_capacity as u64).unwrap(); + assert_eq!(index.index_size(), max_capacity); +} + +#[test] +fn test_e2e_memory_usage_tracking() { + // Verify memory usage reporting + let dim = 64; + let params = BruteForceParams::new(dim, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let initial_memory = index.memory_usage(); + + // Add vectors and track memory growth + let vectors = generate_random_vectors(100, dim, 77777); + for (i, v) in vectors.iter().enumerate() { + index.add_vector(v, i as u64).unwrap(); + } + + let after_add_memory = index.memory_usage(); + assert!(after_add_memory > initial_memory); + + // Memory 
+    // Memory per vector should be roughly dim * size_of::<f32>() + overhead
+    let memory_per_vector = (after_add_memory - initial_memory) / 100;
+    let expected_min = dim * std::mem::size_of::<f32>(); // At least the raw vector data
+    assert!(
+        memory_per_vector >= expected_min,
+        "Memory per vector {} should be >= {}",
+        memory_per_vector,
+        expected_min
+    );
+}
+
+// =============================================================================
+// Large Scale E2E Tests
+// =============================================================================
+
+/// Profiling test - run with `cargo test --release --features profile test_profile_insertion -- --nocapture`
+#[test]
+fn test_profile_insertion() {
+    use crate::index::hnsw::HnswCore;
+
+    let dim = 128;
+    let num_vectors = 10_000;
+    let params = HnswParams::new(dim, Metric::L2)
+        .with_m(16)
+        .with_ef_construction(200)
+        .with_seed(12345);
+    let mut index = HnswSingle::<f32>::new(params);
+
+    // Insert with profiling
+    let vectors = generate_random_vectors(num_vectors, dim, 88888);
+    let start = std::time::Instant::now();
+    for (i, v) in vectors.iter().enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+    let elapsed = start.elapsed();
+    println!(
+        "\nInserted {} vectors in {:?} ({:.0} ops/s)",
+        num_vectors,
+        elapsed,
+        num_vectors as f64 / elapsed.as_secs_f64()
+    );
+
+    // Print profile stats if profiling is enabled
+    #[cfg(feature = "profile")]
+    {
+        crate::index::hnsw::PROFILE_STATS.with(|s| s.borrow_mut().print_and_reset());
+    }
+    #[cfg(not(feature = "profile"))]
+    {
+        println!("Profiling not enabled. Run with: cargo test --release --features profile");
+    }
+}
+
+#[test]
+fn test_e2e_scaling_to_10k_vectors() {
+    // Test with a larger dataset using random vectors (not clustered).
+    // Random vectors work better with HNSW as they don't create disconnected graph regions.
+    let dim = 128;
+    let num_vectors = 10_000;
+    let params = HnswParams::new(dim, Metric::L2)
+        .with_m(16)
+        .with_ef_construction(100)
+        .with_seed(12345); // Fixed seed for a reproducible graph structure
+    let mut index = HnswSingle::<f32>::new(params);
+
+    // Bulk insert with random vectors
+    let vectors = generate_random_vectors(num_vectors, dim, 88888);
+    for (i, v) in vectors.iter().enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+
+    assert_eq!(index.index_size(), num_vectors);
+
+    // Query performance - should find similar vectors quickly
+    let query_params = QueryParams::new().with_ef_runtime(200);
+    let query = &vectors[5000];
+    let results = index.top_k_query(query, 100, Some(&query_params)).unwrap();
+
+    // Should return the requested number of results
+    assert_eq!(results.results.len(), 100);
+    // The first result should be the query vector itself (distance ~0)
+    assert_eq!(results.results[0].label, 5000);
+    assert!(
+        results.results[0].distance < 0.001,
+        "Self-query distance {} too large",
+        results.results[0].distance
+    );
+
+    // Test range query - use a reasonable radius for 128-dim L2 space.
+    // Random vectors in [-1,1] have squared distances around 128 * 2/3 ≈ 85;
+    // the nearest neighbor (other than self) is typically at distance ~57.
+    // Use radius=60 (larger than the typical nearest neighbor) so we find some results,
+    // and epsilon=0.2 (20%) to explore beyond the radius boundary:
+    // with radius=60 and epsilon=0.2 the search boundary is 72, which allows exploration.
+    let range_params = QueryParams::new().with_ef_runtime(200).with_epsilon(0.2);
+    let range_results = index.range_query(query, 60.0, Some(&range_params)).unwrap();
+
+    // The self-query should be within radius 60.0 (distance is 0)
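+    // (Sanity check on the ≈85 figure above: for coordinates u, v drawn
+    // independently from U[-1, 1], E[(u - v)^2] = Var(u) + Var(v) = 1/3 + 1/3 = 2/3,
+    // so the expected squared L2 distance in 128 dimensions is 128 * 2/3 ≈ 85.3.)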
+    assert!(
+        !range_results.results.is_empty(),
+        "Range query should find at least the query vector itself (distance=0)"
+    );
+    // Verify the query vector is in the results
+    assert!(
+        range_results.results.iter().any(|r| r.label == 5000 && r.distance < 0.001),
+        "Range query should find the query vector itself"
+    );
+}
+
+// =============================================================================
+// Index Info and Statistics E2E Tests
+// =============================================================================
+
+#[test]
+fn test_e2e_index_info_consistency() {
+    // Verify IndexInfo is consistent across operations
+    let dim = 32;
+    let params = HnswParams::new(dim, Metric::Cosine)
+        .with_m(16)
+        .with_ef_construction(100);
+    let mut index = HnswSingle::<f32>::new(params);
+
+    // Check initial info
+    let info = index.info();
+    assert_eq!(info.dimension, dim);
+    assert_eq!(info.size, 0);
+    assert!(info.memory_bytes > 0); // Some base overhead
+
+    // Add vectors
+    let vectors = generate_normalized_vectors(200, dim, 99999);
+    for (i, v) in vectors.iter().enumerate() {
+        index.add_vector(v, i as u64).unwrap();
+    }
+
+    let info = index.info();
+    assert_eq!(info.size, 200);
+    assert!(info.memory_bytes > 0);
+
+    // Delete some
+    for label in 0..50u64 {
+        index.delete_vector(label).unwrap();
+    }
+
+    let info = index.info();
+    assert_eq!(info.size, 150);
+
+    // Fragmentation should be non-zero after deletes
+    let frag = index.fragmentation();
+    assert!(frag > 0.0);
+}
diff --git a/rust/vecsim/src/index/brute_force/batch_iterator.rs b/rust/vecsim/src/index/brute_force/batch_iterator.rs
new file mode 100644
index 000000000..01208078f
--- /dev/null
+++ b/rust/vecsim/src/index/brute_force/batch_iterator.rs
@@ -0,0 +1,355 @@
+//! Batch iterator implementations for BruteForce indices.
+//!
+//! These iterators hold pre-computed results, allowing streaming
+//! in batches for processing large result sets incrementally.
+
+use crate::index::traits::BatchIterator;
+use crate::types::{IdType, LabelType, VectorElement};
+
+/// Batch iterator for the single-value BruteForce index.
+///
+/// Holds pre-computed, sorted results from the index.
+pub struct BruteForceBatchIterator<T: VectorElement> {
+    /// All results sorted by distance.
+    results: Vec<(IdType, LabelType, T::DistanceType)>,
+    /// Current position in results.
+    position: usize,
+}
+
+impl<T: VectorElement> BruteForceBatchIterator<T> {
+    /// Create a new batch iterator with pre-computed results.
+    pub fn new(results: Vec<(IdType, LabelType, T::DistanceType)>) -> Self {
+        Self {
+            results,
+            position: 0,
+        }
+    }
+}
+
+impl<T: VectorElement> BatchIterator for BruteForceBatchIterator<T> {
+    type DistType = T::DistanceType;
+
+    fn has_next(&self) -> bool {
+        self.position < self.results.len()
+    }
+
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+    ) -> Option<Vec<(IdType, LabelType, Self::DistType)>> {
+        if self.position >= self.results.len() {
+            return None;
+        }
+
+        let end = (self.position + batch_size).min(self.results.len());
+        let batch = self.results[self.position..end].to_vec();
+        self.position = end;
+
+        Some(batch)
+    }
+
+    fn reset(&mut self) {
+        self.position = 0;
+    }
+}
+
+/// Batch iterator for the multi-value BruteForce index.
+///
+/// Holds pre-computed, sorted results from the index.
+pub struct BruteForceMultiBatchIterator<T: VectorElement> {
+    /// All results sorted by distance.
+    results: Vec<(IdType, LabelType, T::DistanceType)>,
+    /// Current position in results.
+    position: usize,
+}
+
+impl<T: VectorElement> BruteForceMultiBatchIterator<T> {
+    /// Create a new batch iterator with pre-computed results.
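+    ///
+    /// A minimal sketch of the expected input shape (the `(id, label, distance)`
+    /// tuples are illustrative values, not taken from a real index):
+    ///
+    /// ```ignore
+    /// let results = vec![(0u32, 100u64, 0.5f32), (1, 101, 1.0)];
+    /// let mut iter = BruteForceMultiBatchIterator::<f32>::new(results);
+    /// assert_eq!(iter.next_batch(1).unwrap().len(), 1);
+    /// ```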
+    pub fn new(results: Vec<(IdType, LabelType, T::DistanceType)>) -> Self {
+        Self {
+            results,
+            position: 0,
+        }
+    }
+}
+
+impl<T: VectorElement> BatchIterator for BruteForceMultiBatchIterator<T> {
+    type DistType = T::DistanceType;
+
+    fn has_next(&self) -> bool {
+        self.position < self.results.len()
+    }
+
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+    ) -> Option<Vec<(IdType, LabelType, Self::DistType)>> {
+        if self.position >= self.results.len() {
+            return None;
+        }
+
+        let end = (self.position + batch_size).min(self.results.len());
+        let batch = self.results[self.position..end].to_vec();
+        self.position = end;
+
+        Some(batch)
+    }
+
+    fn reset(&mut self) {
+        self.position = 0;
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::distance::Metric;
+    use crate::index::brute_force::{BruteForceMulti, BruteForceParams, BruteForceSingle};
+    use crate::index::traits::BatchIterator;
+    use crate::index::VecSimIndex;
+    use crate::types::DistanceType;
+
+    #[test]
+    fn test_batch_iterator_single() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceSingle::<f32>::new(params);
+
+        for i in 0..10 {
+            index
+                .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64)
+                .unwrap();
+        }
+
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let mut iter = index.batch_iterator(&query, None).unwrap();
+
+        assert!(iter.has_next());
+
+        // Get first batch
+        let batch1 = iter.next_batch(3).unwrap();
+        assert_eq!(batch1.len(), 3);
+
+        // Verify ordering
+        assert!(batch1[0].2.to_f64() <= batch1[1].2.to_f64());
+        assert!(batch1[1].2.to_f64() <= batch1[2].2.to_f64());
+
+        // Get remaining batches
+        let mut total = batch1.len();
+        while let Some(batch) = iter.next_batch(3) {
+            total += batch.len();
+        }
+        assert_eq!(total, 10);
+
+        // Reset and verify
+        iter.reset();
+        assert!(iter.has_next());
+    }
+
+    #[test]
+    fn test_brute_force_batch_iterator_empty() {
+        let results: Vec<(IdType, LabelType, f32)> = vec![];
+        let iter = BruteForceBatchIterator::<f32>::new(results);
+
+        assert!(!iter.has_next());
+    }
+
+    #[test]
+    fn test_brute_force_batch_iterator_next_batch_empty() {
+        let results: Vec<(IdType, LabelType, f32)> = vec![];
+        let mut iter = BruteForceBatchIterator::<f32>::new(results);
+
+        assert!(iter.next_batch(10).is_none());
+    }
+
+    #[test]
+    fn test_brute_force_batch_iterator_single_result() {
+        let results = vec![(0, 100, 0.5f32)];
+        let mut iter = BruteForceBatchIterator::<f32>::new(results);
+
+        assert!(iter.has_next());
+
+        let batch = iter.next_batch(10).unwrap();
+        assert_eq!(batch.len(), 1);
+        assert_eq!(batch[0], (0, 100, 0.5f32));
+
+        assert!(!iter.has_next());
+        assert!(iter.next_batch(10).is_none());
+    }
+
+    #[test]
+    fn test_brute_force_batch_iterator_batch_size_larger_than_results() {
+        let results = vec![(0, 100, 0.5f32), (1, 101, 1.0f32), (2, 102, 1.5f32)];
+        let mut iter = BruteForceBatchIterator::<f32>::new(results);
+
+        let batch = iter.next_batch(100).unwrap();
+        assert_eq!(batch.len(), 3);
+
+        assert!(!iter.has_next());
+    }
+
+    #[test]
+    fn test_brute_force_batch_iterator_exact_batch_size() {
+        let results = vec![
+            (0, 100, 0.5f32),
+            (1, 101, 1.0f32),
+            (2, 102, 1.5f32),
+            (3, 103, 2.0f32),
+        ];
+        let mut iter = BruteForceBatchIterator::<f32>::new(results);
+
+        let batch1 = iter.next_batch(2).unwrap();
+        assert_eq!(batch1.len(), 2);
+
+        let batch2 = iter.next_batch(2).unwrap();
+        assert_eq!(batch2.len(), 2);
+
+        assert!(!iter.has_next());
+    }
+
+    #[test]
+    fn test_brute_force_batch_iterator_reset() {
+        let results = vec![(0, 100, 0.5f32), (1, 101, 1.0f32)];
+        let mut iter = BruteForceBatchIterator::<f32>::new(results);
+
+        // Consume all
+        iter.next_batch(10);
+        assert!(!iter.has_next());
+
+        // Reset
+        iter.reset();
+        assert!(iter.has_next());
+
+        // Can iterate again
+        let batch = iter.next_batch(10).unwrap();
+        assert_eq!(batch.len(), 2);
+    }
+
+    #[test]
+    fn test_brute_force_batch_iterator_multiple_batches() {
+        let results: Vec<(IdType, LabelType, f32)> = (0..10)
+            .map(|i| (i as IdType, i as LabelType + 100, i as f32 * 0.1))
+            .collect();
+        let mut iter = BruteForceBatchIterator::<f32>::new(results);
+
+        let mut batches = Vec::new();
+        while let Some(batch) = iter.next_batch(3) {
+            batches.push(batch);
+        }
+
+        assert_eq!(batches.len(), 4); // 3 + 3 + 3 + 1
+        assert_eq!(batches[0].len(), 3);
+        assert_eq!(batches[1].len(), 3);
+        assert_eq!(batches[2].len(), 3);
+        assert_eq!(batches[3].len(), 1);
+    }
+
+    #[test]
+    fn test_brute_force_multi_batch_iterator_empty() {
+        let results: Vec<(IdType, LabelType, f32)> = vec![];
+        let iter = BruteForceMultiBatchIterator::<f32>::new(results);
+
+        assert!(!iter.has_next());
+    }
+
+    #[test]
+    fn test_brute_force_multi_batch_iterator_basic() {
+        let results = vec![(0, 100, 0.5f32), (1, 101, 1.0f32), (2, 102, 1.5f32)];
+        let mut iter = BruteForceMultiBatchIterator::<f32>::new(results);
+
+        assert!(iter.has_next());
+
+        let batch = iter.next_batch(2).unwrap();
+        assert_eq!(batch.len(), 2);
+
+        let batch = iter.next_batch(2).unwrap();
+        assert_eq!(batch.len(), 1);
+
+        assert!(!iter.has_next());
+    }
+
+    #[test]
+    fn test_brute_force_multi_batch_iterator_reset() {
+        let results = vec![(0, 100, 0.5f32), (1, 101, 1.0f32)];
+        let mut iter = BruteForceMultiBatchIterator::<f32>::new(results);
+
+        // Consume
+        iter.next_batch(10);
+        assert!(!iter.has_next());
+
+        // Reset
+        iter.reset();
+        assert!(iter.has_next());
+    }
+
+    #[test]
+    fn test_batch_iterator_multi_index() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // Add vectors with the same label (multi-value)
+        for i in 0..5 {
+            index
+                .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], 100)
+                .unwrap();
+        }
+        for i in 0..5 {
+            index
+                .add_vector(&vec![i as f32 + 10.0, 0.0, 0.0, 0.0], 200)
+                .unwrap();
+        }
+
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let mut iter = index.batch_iterator(&query, None).unwrap();
+
+        assert!(iter.has_next());
+
+        let mut total = 0;
+        while let Some(batch) = iter.next_batch(3) {
+            total += batch.len();
+        }
+        assert_eq!(total, 10);
+    }
+
+    #[test]
+    fn test_batch_iterator_preserves_order() {
+        let results = vec![
+            (0, 100, 0.1f32),
+            (1, 101, 0.2f32),
+            (2, 102, 0.3f32),
+            (3, 103, 0.4f32),
+            (4, 104, 0.5f32),
+        ];
+        let mut iter = BruteForceBatchIterator::<f32>::new(results);
+
+        let batch1 = iter.next_batch(2).unwrap();
+        assert_eq!(batch1[0].0, 0);
+        assert_eq!(batch1[1].0, 1);
+
+        let batch2 = iter.next_batch(2).unwrap();
+        assert_eq!(batch2[0].0, 2);
+        assert_eq!(batch2[1].0, 3);
+
+        let batch3 = iter.next_batch(2).unwrap();
+        assert_eq!(batch3[0].0, 4);
+    }
+
+    #[test]
+    fn test_batch_iterator_batch_size_one() {
+        let results = vec![(0, 100, 0.5f32), (1, 101, 1.0f32), (2, 102, 1.5f32)];
+        let mut iter = BruteForceBatchIterator::<f32>::new(results);
+
+        let batch1 = iter.next_batch(1).unwrap();
+        assert_eq!(batch1.len(), 1);
+        assert_eq!(batch1[0].0, 0);
+
+        let batch2 = iter.next_batch(1).unwrap();
+        assert_eq!(batch2.len(), 1);
+        assert_eq!(batch2[0].0, 1);
+
+        let batch3 = iter.next_batch(1).unwrap();
+        assert_eq!(batch3.len(), 1);
+        assert_eq!(batch3[0].0, 2);
+
+        assert!(!iter.has_next());
+    }
+}
diff --git a/rust/vecsim/src/index/brute_force/mod.rs b/rust/vecsim/src/index/brute_force/mod.rs
new file mode 100644
index 000000000..212e04c99
--- /dev/null
+++ b/rust/vecsim/src/index/brute_force/mod.rs
@@ -0,0 +1,127 @@
+//! BruteForce index implementation.
+//!
+//! The BruteForce index performs linear scans over all vectors to find nearest neighbors.
+//! While this has O(n) query complexity, it provides exact results and is efficient
+//! for small datasets or when high recall is critical.
+//!
+//! Two variants are provided:
+//! - `BruteForceSingle`: One vector per label (a new vector replaces the existing one)
+//! - `BruteForceMulti`: Multiple vectors allowed per label
+
+pub mod batch_iterator;
+pub mod multi;
+pub mod single;
+
+pub use batch_iterator::BruteForceBatchIterator;
+pub use multi::BruteForceMulti;
+pub use single::{BruteForceSingle, BruteForceStats};
+
+use crate::containers::DataBlocks;
+use crate::distance::{create_distance_function, DistanceFunction, Metric};
+use crate::types::{DistanceType, IdType, LabelType, VectorElement};
+
+/// Parameters for creating a BruteForce index.
+#[derive(Debug, Clone)]
+pub struct BruteForceParams {
+    /// Vector dimension.
+    pub dim: usize,
+    /// Distance metric.
+    pub metric: Metric,
+    /// Initial capacity (number of vectors).
+    pub initial_capacity: usize,
+    /// Block size for vector storage.
+    pub block_size: Option<usize>,
+}
+
+impl BruteForceParams {
+    /// Create new parameters with the required fields.
+    pub fn new(dim: usize, metric: Metric) -> Self {
+        Self {
+            dim,
+            metric,
+            initial_capacity: 1024,
+            block_size: None,
+        }
+    }
+
+    /// Set the initial capacity.
+    pub fn with_capacity(mut self, capacity: usize) -> Self {
+        self.initial_capacity = capacity;
+        self
+    }
+
+    /// Set the block size.
+    pub fn with_block_size(mut self, size: usize) -> Self {
+        self.block_size = Some(size);
+        self
+    }
+}
+
+/// Shared state for BruteForce indices.
+pub(crate) struct BruteForceCore<T: VectorElement> {
+    /// Vector storage.
+    pub data: DataBlocks<T>,
+    /// Distance function.
+    pub dist_fn: Box<dyn DistanceFunction<T>>,
+    /// Vector dimension.
+    pub dim: usize,
+    /// Distance metric.
+    pub metric: Metric,
+}
+
+impl<T: VectorElement> BruteForceCore<T> {
+    /// Create a new BruteForce core.
+    pub fn new(params: &BruteForceParams) -> Self {
+        let data = if let Some(block_size) = params.block_size {
+            DataBlocks::with_block_size(params.dim, params.initial_capacity, block_size)
+        } else {
+            DataBlocks::new(params.dim, params.initial_capacity)
+        };
+
+        let dist_fn = create_distance_function(params.metric, params.dim);
+
+        Self {
+            data,
+            dist_fn,
+            dim: params.dim,
+            metric: params.metric,
+        }
+    }
+
+    /// Add a vector and return its internal ID.
+    ///
+    /// Returns `None` if the vector dimension doesn't match.
+    #[inline]
+    pub fn add_vector(&mut self, vector: &[T]) -> Option<IdType> {
+        // Preprocess if needed (e.g., normalize for cosine)
+        let processed = self.dist_fn.preprocess(vector, self.dim);
+        self.data.add(&processed)
+    }
+
+    /// Get a vector by ID.
+    #[inline]
+    #[allow(dead_code)]
+    pub fn get_vector(&self, id: IdType) -> Option<&[T]> {
+        self.data.get(id)
+    }
+
+    /// Compute the distance between a stored vector and a query.
+    #[inline]
+    pub fn compute_distance(&self, id: IdType, query: &[T]) -> T::DistanceType {
+        if let Some(stored) = self.data.get(id) {
+            self.dist_fn
+                .compute_from_preprocessed(stored, query, self.dim)
+        } else {
+            T::DistanceType::infinity()
+        }
+    }
+}
+
+/// Entry in the id-to-label mapping.
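+///
+/// `is_valid` doubles as a tombstone marker: `delete_vector` flips it to
+/// `false` instead of moving data, and `compact` later drops invalid slots.
+/// The `Default` entry (label 0, `is_valid == false`) pads the mapping when
+/// internal IDs are assigned beyond the current vector length.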
+#[derive(Clone, Copy, Default)]
+pub(crate) struct IdLabelEntry {
+    pub label: LabelType,
+    pub is_valid: bool,
+}
diff --git a/rust/vecsim/src/index/brute_force/multi.rs b/rust/vecsim/src/index/brute_force/multi.rs
new file mode 100644
index 000000000..3fc69c1cb
--- /dev/null
+++ b/rust/vecsim/src/index/brute_force/multi.rs
@@ -0,0 +1,1312 @@
+//! Multi-value BruteForce index implementation.
+//!
+//! This index allows multiple vectors per label. Each label can have
+//! any number of associated vectors.
+
+use super::{BruteForceCore, BruteForceParams, IdLabelEntry};
+use crate::index::traits::{BatchIterator, IndexError, IndexInfo, QueryError, VecSimIndex};
+use crate::query::{QueryParams, QueryReply, QueryResult};
+use crate::types::{DistanceType, IdType, LabelType, VectorElement};
+use crate::utils::MaxHeap;
+use parking_lot::RwLock;
+use rayon::prelude::*;
+use std::collections::{HashMap, HashSet};
+
+/// Multi-value BruteForce index.
+///
+/// Each label can have multiple associated vectors. This is useful for
+/// scenarios where a single entity has multiple representations.
+pub struct BruteForceMulti<T: VectorElement> {
+    /// Core storage and distance computation.
+    pub(crate) core: RwLock<BruteForceCore<T>>,
+    /// Label to set of internal IDs mapping.
+    label_to_ids: RwLock<HashMap<LabelType, HashSet<IdType>>>,
+    /// Internal ID to label mapping.
+    pub(crate) id_to_label: RwLock<Vec<IdLabelEntry>>,
+    /// Number of vectors.
+    count: std::sync::atomic::AtomicUsize,
+    /// Maximum capacity (if set).
+    capacity: Option<usize>,
+}
+
+impl<T: VectorElement> BruteForceMulti<T> {
+    /// Create a new multi-value BruteForce index.
+    pub fn new(params: BruteForceParams) -> Self {
+        let core = BruteForceCore::new(&params);
+        let initial_capacity = params.initial_capacity;
+
+        Self {
+            core: RwLock::new(core),
+            label_to_ids: RwLock::new(HashMap::with_capacity(initial_capacity / 2)),
+            id_to_label: RwLock::new(Vec::with_capacity(initial_capacity)),
+            count: std::sync::atomic::AtomicUsize::new(0),
+            capacity: None,
+        }
+    }
+
+    /// Create with a maximum capacity.
+    pub fn with_capacity(params: BruteForceParams, max_capacity: usize) -> Self {
+        let mut index = Self::new(params);
+        index.capacity = Some(max_capacity);
+        index
+    }
+
+    /// Get the distance metric.
+    pub fn metric(&self) -> crate::distance::Metric {
+        self.core.read().metric
+    }
+
+    /// Get copies of all vectors stored for a given label.
+    ///
+    /// Returns `None` if the label doesn't exist in the index.
+    pub fn get_vectors(&self, label: LabelType) -> Option<Vec<Vec<T>>> {
+        let label_to_ids = self.label_to_ids.read();
+        let ids = label_to_ids.get(&label)?;
+        let core = self.core.read();
+        let vectors: Vec<Vec<T>> = ids
+            .iter()
+            .filter_map(|&id| core.data.get(id).map(|v| v.to_vec()))
+            .collect();
+        if vectors.is_empty() {
+            None
+        } else {
+            Some(vectors)
+        }
+    }
+
+    /// Get all labels currently in the index.
+    pub fn get_labels(&self) -> Vec<LabelType> {
+        self.label_to_ids.read().keys().copied().collect()
+    }
+
+    /// Compute the minimum distance between any stored vector for a label and a query vector.
+    ///
+    /// Returns `None` if the label doesn't exist.
+    pub fn compute_distance(&self, label: LabelType, query: &[T]) -> Option<T::DistanceType> {
+        let label_to_ids = self.label_to_ids.read();
+        let ids = label_to_ids.get(&label)?;
+        let core = self.core.read();
+        ids.iter()
+            .filter_map(|&id| {
+                if core.data.get(id).is_some() {
+                    Some(core.compute_distance(id, query))
+                } else {
+                    None
+                }
+            })
+            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
+    }
+
+    /// Get the memory usage in bytes.
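+    ///
+    /// This is an estimate: it counts the raw vector data plus the *capacity*
+    /// (not just the occupied length) of the two label-mapping containers, and
+    /// ignores allocator overhead.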
+    pub fn memory_usage(&self) -> usize {
+        let core = self.core.read();
+        let count = self.count.load(std::sync::atomic::Ordering::Relaxed);
+
+        // Vector data storage
+        let vector_storage = count * core.dim * std::mem::size_of::<T>();
+
+        // Label mappings
+        let label_maps = self.label_to_ids.read().capacity()
+            * std::mem::size_of::<(LabelType, HashSet<IdType>)>()
+            + self.id_to_label.read().capacity() * std::mem::size_of::<IdLabelEntry>();
+
+        vector_storage + label_maps
+    }
+
+    /// Clear all vectors from the index, resetting it to an empty state.
+    pub fn clear(&mut self) {
+        let mut core = self.core.write();
+        let mut label_to_ids = self.label_to_ids.write();
+        let mut id_to_label = self.id_to_label.write();
+
+        core.data.clear();
+        label_to_ids.clear();
+        id_to_label.clear();
+        self.count.store(0, std::sync::atomic::Ordering::Relaxed);
+    }
+
+    /// Compact the index by removing gaps left by deleted vectors.
+    ///
+    /// This reorganizes the internal storage to reclaim space from deleted vectors.
+    /// After compaction, all vectors are stored contiguously.
+    ///
+    /// # Arguments
+    /// * `shrink` - If true, also release unused memory blocks
+    ///
+    /// # Returns
+    /// The number of bytes reclaimed (approximate).
+    pub fn compact(&mut self, shrink: bool) -> usize {
+        let mut core = self.core.write();
+        let mut label_to_ids = self.label_to_ids.write();
+        let mut id_to_label = self.id_to_label.write();
+
+        let old_capacity = core.data.capacity();
+        let id_mapping = core.data.compact(shrink);
+
+        // Update the label_to_ids mapping
+        for (_label, ids) in label_to_ids.iter_mut() {
+            let new_ids: HashSet<IdType> = ids
+                .iter()
+                .filter_map(|id| id_mapping.get(id).copied())
+                .collect();
+            *ids = new_ids;
+        }
+
+        // Remove labels with no remaining vectors
+        label_to_ids.retain(|_, ids| !ids.is_empty());
+
+        // Rebuild the id_to_label mapping
+        let mut new_id_to_label = Vec::with_capacity(id_mapping.len());
+        for (&old_id, &new_id) in &id_mapping {
+            let new_id_usize = new_id as usize;
+            if new_id_usize >= new_id_to_label.len() {
+                new_id_to_label.resize(new_id_usize + 1, IdLabelEntry::default());
+            }
+            if let Some(entry) = id_to_label.get(old_id as usize) {
+                new_id_to_label[new_id_usize] = *entry;
+            }
+        }
+        *id_to_label = new_id_to_label;
+
+        let new_capacity = core.data.capacity();
+        let dim = core.dim;
+        let bytes_per_vector = dim * std::mem::size_of::<T>();
+
+        (old_capacity.saturating_sub(new_capacity)) * bytes_per_vector
+    }
+
+    /// Get the fragmentation ratio of the index.
+    ///
+    /// Returns a value between 0.0 (no fragmentation) and 1.0 (all slots are deleted).
+    pub fn fragmentation(&self) -> f64 {
+        self.core.read().data.fragmentation()
+    }
+
+    /// Add multiple vectors at once.
+    ///
+    /// Returns the number of vectors successfully added.
+    /// Stops at the first error; vectors added before the error remain in the index.
+    pub fn add_vectors(
+        &mut self,
+        vectors: &[(&[T], LabelType)],
+    ) -> Result<usize, IndexError> {
+        let mut added = 0;
+        for &(vector, label) in vectors {
+            self.add_vector(vector, label)?;
+            added += 1;
+        }
+        Ok(added)
+    }
+
+    /// Internal implementation of the top-k query.
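+    ///
+    /// A linear scan that keeps the k closest candidates in a bounded max-heap:
+    /// every valid (and filter-passing) entry is compared against the heap root,
+    /// so the scan runs in O(n log k) time with O(k) extra space. Indices with
+    /// more than 1000 entries switch to the rayon path when
+    /// `QueryParams::parallel` is set.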
+    fn top_k_impl(
+        &self,
+        query: &[T],
+        k: usize,
+        params: Option<&QueryParams>,
+    ) -> Result<QueryReply<T::DistanceType>, QueryError> {
+        let core = self.core.read();
+        let id_to_label = self.id_to_label.read();
+
+        if query.len() != core.dim {
+            return Err(QueryError::DimensionMismatch {
+                expected: core.dim,
+                got: query.len(),
+            });
+        }
+
+        let use_parallel = params.is_some_and(|p| p.parallel);
+        let filter = params.and_then(|p| p.filter.as_ref()).map(|f| f.as_ref());
+
+        let mut results = if use_parallel && id_to_label.len() > 1000 {
+            self.parallel_top_k(&core, &id_to_label, query, k, filter)
+        } else {
+            self.sequential_top_k(&core, &id_to_label, query, k, filter)
+        };
+
+        results.sort_by_distance();
+        Ok(results)
+    }
+
+    /// Sequential top-k scan.
+    fn sequential_top_k(
+        &self,
+        core: &BruteForceCore<T>,
+        id_to_label: &[IdLabelEntry],
+        query: &[T],
+        k: usize,
+        filter: Option<&(dyn Fn(LabelType) -> bool + Send + Sync)>,
+    ) -> QueryReply<T::DistanceType> {
+        let mut heap = MaxHeap::new(k);
+
+        for (id, entry) in id_to_label.iter().enumerate() {
+            if !entry.is_valid {
+                continue;
+            }
+
+            if let Some(f) = filter {
+                if !f(entry.label) {
+                    continue;
+                }
+            }
+
+            let dist = core.compute_distance(id as IdType, query);
+            heap.try_insert(id as IdType, dist);
+        }
+
+        let mut reply = QueryReply::with_capacity(heap.len());
+        for entry in heap.into_sorted_vec() {
+            if let Some(label_entry) = id_to_label.get(entry.id as usize) {
+                reply.push(QueryResult::new(label_entry.label, entry.distance));
+            }
+        }
+        reply
+    }
+
+    /// Parallel top-k scan using rayon.
+    fn parallel_top_k(
+        &self,
+        core: &BruteForceCore<T>,
+        id_to_label: &[IdLabelEntry],
+        query: &[T],
+        k: usize,
+        filter: Option<&(dyn Fn(LabelType) -> bool + Send + Sync)>,
+    ) -> QueryReply<T::DistanceType> {
+        let candidates: Vec<_> = id_to_label
+            .par_iter()
+            .enumerate()
+            .filter_map(|(id, entry)| {
+                if !entry.is_valid {
+                    return None;
+                }
+                if let Some(f) = filter {
+                    if !f(entry.label) {
+                        return None;
+                    }
+                }
+                let dist = core.compute_distance(id as IdType, query);
+                Some((id as IdType, entry.label, dist))
+            })
+            .collect();
+
+        let mut heap = MaxHeap::new(k);
+        for (id, _label, dist) in candidates {
+            heap.try_insert(id, dist);
+        }
+
+        let mut reply = QueryReply::with_capacity(heap.len());
+        for entry in heap.into_sorted_vec() {
+            if let Some(label_entry) = id_to_label.get(entry.id as usize) {
+                reply.push(QueryResult::new(label_entry.label, entry.distance));
+            }
+        }
+        reply
+    }
+
+    /// Internal implementation of the range query.
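+    ///
+    /// Unlike the top-k path there is no heap: every valid entry within
+    /// `radius` (compared via `to_f64()`) is collected and the reply is sorted
+    /// by distance at the end, so the result size is unbounded by design.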
+    fn range_impl(
+        &self,
+        query: &[T],
+        radius: T::DistanceType,
+        params: Option<&QueryParams>,
+    ) -> Result<QueryReply<T::DistanceType>, QueryError> {
+        let core = self.core.read();
+        let id_to_label = self.id_to_label.read();
+
+        if query.len() != core.dim {
+            return Err(QueryError::DimensionMismatch {
+                expected: core.dim,
+                got: query.len(),
+            });
+        }
+
+        let filter = params.and_then(|p| p.filter.as_ref());
+        let mut reply = QueryReply::new();
+
+        for (id, entry) in id_to_label.iter().enumerate() {
+            if !entry.is_valid {
+                continue;
+            }
+
+            if let Some(f) = filter {
+                if !f(entry.label) {
+                    continue;
+                }
+            }
+
+            let dist = core.compute_distance(id as IdType, query);
+            if dist.to_f64() <= radius.to_f64() {
+                reply.push(QueryResult::new(entry.label, dist));
+            }
+        }
+
+        reply.sort_by_distance();
+        Ok(reply)
+    }
+}
+
+impl<T: VectorElement> VecSimIndex for BruteForceMulti<T> {
+    type DataType = T;
+    type DistType = T::DistanceType;
+
+    fn add_vector(&mut self, vector: &[T], label: LabelType) -> Result<usize, IndexError> {
+        let mut core = self.core.write();
+
+        if vector.len() != core.dim {
+            return Err(IndexError::DimensionMismatch {
+                expected: core.dim,
+                got: vector.len(),
+            });
+        }
+
+        // Check capacity
+        if let Some(cap) = self.capacity {
+            if self.count.load(std::sync::atomic::Ordering::Relaxed) >= cap {
+                return Err(IndexError::CapacityExceeded { capacity: cap });
+            }
+        }
+
+        // Add the vector
+        let id = core
+            .add_vector(vector)
+            .ok_or_else(|| IndexError::Internal("Failed to add vector to storage".to_string()))?;
+
+        let mut label_to_ids = self.label_to_ids.write();
+        let mut id_to_label = self.id_to_label.write();
+
+        // Update the mappings
+        label_to_ids.entry(label).or_default().insert(id);
+
+        // Ensure id_to_label is large enough
+        let id_usize = id as usize;
+        if id_usize >= id_to_label.len() {
+            id_to_label.resize(id_usize + 1, IdLabelEntry::default());
+        }
+        id_to_label[id_usize] = IdLabelEntry {
+            label,
+            is_valid: true,
+        };
+
+        self.count
+            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+        Ok(1)
+    }
+
+    fn delete_vector(&mut self, label: LabelType) -> Result<usize, IndexError> {
+        let mut label_to_ids = self.label_to_ids.write();
+        let mut id_to_label = self.id_to_label.write();
+        let mut core = self.core.write();
+
+        if let Some(ids) = label_to_ids.remove(&label) {
+            let count = ids.len();
+
+            for id in ids {
+                // Mark the entry as invalid
+                if let Some(entry) = id_to_label.get_mut(id as usize) {
+                    entry.is_valid = false;
+                }
+                // Mark the slot as free
+                core.data.mark_deleted(id);
+            }
+
+            self.count
+                .fetch_sub(count, std::sync::atomic::Ordering::Relaxed);
+            Ok(count)
+        } else {
+            Err(IndexError::LabelNotFound(label))
+        }
+    }
+
+    fn top_k_query(
+        &self,
+        query: &[T],
+        k: usize,
+        params: Option<&QueryParams>,
+    ) -> Result<QueryReply<Self::DistType>, QueryError> {
+        self.top_k_impl(query, k, params)
+    }
+
+    fn range_query(
+        &self,
+        query: &[T],
+        radius: T::DistanceType,
+        params: Option<&QueryParams>,
+    ) -> Result<QueryReply<Self::DistType>, QueryError> {
+        self.range_impl(query, radius, params)
+    }
+
+    fn index_size(&self) -> usize {
+        self.count.load(std::sync::atomic::Ordering::Relaxed)
+    }
+
+    fn index_capacity(&self) -> Option<usize> {
+        self.capacity
+    }
+
+    fn dimension(&self) -> usize {
+        self.core.read().dim
+    }
+
+    fn batch_iterator<'a>(
+        &'a self,
+        query: &[T],
+        params: Option<&QueryParams>,
+    ) -> Result<Box<dyn BatchIterator<DistType = Self::DistType> + 'a>, QueryError> {
+        let core = self.core.read();
+        if query.len() != core.dim {
+            return Err(QueryError::DimensionMismatch {
+                expected: core.dim,
+                got: query.len(),
+            });
+        }
+
+        // Compute the results immediately to preserve the filter from params
+        let id_to_label = self.id_to_label.read();
+        let filter = params.and_then(|p| p.filter.as_ref());
+
+        let mut results = Vec::new();
+        for (id, entry) in id_to_label.iter().enumerate() {
+            if !entry.is_valid {
+                continue;
+            }
+
+            if let Some(f) = filter {
+                if !f(entry.label) {
+                    continue;
+                }
+            }
+
+            let dist = core.compute_distance(id as IdType, query);
+            results.push((id as IdType, entry.label, dist));
+        }
+
+        // Sort by distance
+        results.sort_by(|a, b| {
+            a.2.to_f64()
+                .partial_cmp(&b.2.to_f64())
+                .unwrap_or(std::cmp::Ordering::Equal)
+        });
+
+        Ok(Box::new(
+            super::batch_iterator::BruteForceMultiBatchIterator::<T>::new(results),
+        ))
+    }
+
+    fn info(&self) -> IndexInfo {
+        let core = self.core.read();
+        let count = self.count.load(std::sync::atomic::Ordering::Relaxed);
+
+        IndexInfo {
+            size: count,
+            capacity: self.capacity,
+            dimension: core.dim,
+            index_type: "BruteForceMulti",
+            memory_bytes: count * core.dim * std::mem::size_of::<T>()
+                + self.label_to_ids.read().capacity()
+                    * std::mem::size_of::<(LabelType, HashSet<IdType>)>()
+                + self.id_to_label.read().capacity() * std::mem::size_of::<IdLabelEntry>(),
+        }
+    }
+
+    fn contains(&self, label: LabelType) -> bool {
+        self.label_to_ids.read().contains_key(&label)
+    }
+
+    fn label_count(&self, label: LabelType) -> usize {
+        self.label_to_ids
+            .read()
+            .get(&label)
+            .map_or(0, |ids| ids.len())
+    }
+}
+
+unsafe impl<T: VectorElement> Send for BruteForceMulti<T> {}
+unsafe impl<T: VectorElement> Sync for BruteForceMulti<T> {}
+
+// Serialization support
+impl BruteForceMulti<f32> {
+    /// Save the index to a writer.
+    pub fn save<W: std::io::Write>(
+        &self,
+        writer: &mut W,
+    ) -> crate::serialization::SerializationResult<()> {
+        use crate::serialization::*;
+
+        let core = self.core.read();
+        let label_to_ids = self.label_to_ids.read();
+        let id_to_label = self.id_to_label.read();
+        let count = self.count.load(std::sync::atomic::Ordering::Relaxed);
+
+        // Write the header
+        let header = IndexHeader::new(
+            IndexTypeId::BruteForceMulti,
+            DataTypeId::F32,
+            core.metric,
+            core.dim,
+            count,
+        );
+        header.write(writer)?;
+
+        // Write the capacity
+        write_u8(writer, if self.capacity.is_some() { 1 } else { 0 })?;
+        if let Some(cap) = self.capacity {
+            write_usize(writer, cap)?;
+        }
+
+        // Write the label_to_ids mapping (label -> set of IDs)
+        write_usize(writer, label_to_ids.len())?;
+        for (&label, ids) in label_to_ids.iter() {
+            write_u64(writer, label)?;
+            write_usize(writer, ids.len())?;
+            for &id in ids {
+                write_u32(writer, id)?;
+            }
+        }
+
+        // Write the id_to_label entries
+        write_usize(writer, id_to_label.len())?;
+        for entry in id_to_label.iter() {
+            write_u64(writer, entry.label)?;
+            write_u8(writer, if entry.is_valid { 1 } else { 0 })?;
+        }
+
+        // Write the vectors - only write valid entries
+        let mut valid_ids: Vec<IdType> = Vec::with_capacity(count);
+        for (id, entry) in id_to_label.iter().enumerate() {
+            if entry.is_valid {
+                valid_ids.push(id as IdType);
+            }
+        }
+
+        write_usize(writer, valid_ids.len())?;
+        for id in valid_ids {
+            write_u32(writer, id)?;
+            if let Some(vector) = core.data.get(id) {
+                for &v in vector {
+                    write_f32(writer, v)?;
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Load the index from a reader.
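+    ///
+    /// A round-trip sketch using an in-memory buffer (mirrors the pattern used
+    /// in this module's serialization tests):
+    ///
+    /// ```ignore
+    /// let mut buffer = Vec::new();
+    /// index.save(&mut buffer)?;
+    /// let restored = BruteForceMulti::<f32>::load(&mut std::io::Cursor::new(buffer))?;
+    /// assert_eq!(restored.index_size(), index.index_size());
+    /// ```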
+    pub fn load<R: std::io::Read>(reader: &mut R) -> crate::serialization::SerializationResult<Self> {
+        use crate::serialization::*;
+
+        // Read and validate the header
+        let header = IndexHeader::read(reader)?;
+
+        if header.index_type != IndexTypeId::BruteForceMulti {
+            return Err(SerializationError::IndexTypeMismatch {
+                expected: "BruteForceMulti".to_string(),
+                got: header.index_type.as_str().to_string(),
+            });
+        }
+
+        if header.data_type != DataTypeId::F32 {
+            return Err(SerializationError::InvalidData(
+                "Expected f32 data type".to_string(),
+            ));
+        }
+
+        // Create the index with the proper parameters
+        let params = BruteForceParams::new(header.dimension, header.metric);
+        let mut index = Self::new(params);
+
+        // Read the capacity
+        let has_capacity = read_u8(reader)? != 0;
+        if has_capacity {
+            index.capacity = Some(read_usize(reader)?);
+        }
+
+        // Read the label_to_ids mapping
+        let label_to_ids_len = read_usize(reader)?;
+        let mut label_to_ids: HashMap<LabelType, HashSet<IdType>> =
+            HashMap::with_capacity(label_to_ids_len);
+        for _ in 0..label_to_ids_len {
+            let label = read_u64(reader)?;
+            let num_ids = read_usize(reader)?;
+            let mut ids = HashSet::with_capacity(num_ids);
+            for _ in 0..num_ids {
+                ids.insert(read_u32(reader)?);
+            }
+            label_to_ids.insert(label, ids);
+        }
+
+        // Read the id_to_label entries
+        let id_to_label_len = read_usize(reader)?;
+        let mut id_to_label = Vec::with_capacity(id_to_label_len);
+        for _ in 0..id_to_label_len {
+            let label = read_u64(reader)?;
+            let is_valid = read_u8(reader)? != 0;
+            id_to_label.push(IdLabelEntry { label, is_valid });
+        }
+
+        // Read the vectors
+        let num_vectors = read_usize(reader)?;
+        let dim = header.dimension;
+        {
+            let mut core = index.core.write();
+
+            // Pre-allocate space
+            core.data.reserve(num_vectors);
+
+            for _ in 0..num_vectors {
+                let _id = read_u32(reader)?;
+                let mut vector = vec![0.0f32; dim];
+                for v in &mut vector {
+                    *v = read_f32(reader)?;
+                }
+
+                // Add the vector to data storage
+                core.data.add(&vector).ok_or_else(|| {
+                    SerializationError::DataCorruption(
+                        "Failed to add vector during deserialization".to_string(),
+                    )
+                })?;
+            }
+        }
+
+        // Set the internal state
+        *index.label_to_ids.write() = label_to_ids;
+        *index.id_to_label.write() = id_to_label;
+        index
+            .count
+            .store(header.count, std::sync::atomic::Ordering::Relaxed);
+
+        Ok(index)
+    }
+
+    /// Save the index to a file.
+    pub fn save_to_file<P: AsRef<std::path::Path>>(
+        &self,
+        path: P,
+    ) -> crate::serialization::SerializationResult<()> {
+        let file = std::fs::File::create(path)?;
+        let mut writer = std::io::BufWriter::new(file);
+        self.save(&mut writer)
+    }
+
+    /// Load the index from a file.
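+    ///
+    /// A file-based round trip (the path is illustrative):
+    ///
+    /// ```ignore
+    /// index.save_to_file("/tmp/bf_multi.idx")?;
+    /// let restored = BruteForceMulti::<f32>::load_from_file("/tmp/bf_multi.idx")?;
+    /// ```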
+    pub fn load_from_file<P: AsRef<std::path::Path>>(
+        path: P,
+    ) -> crate::serialization::SerializationResult<Self> {
+        let file = std::fs::File::open(path)?;
+        let mut reader = std::io::BufReader::new(file);
+        Self::load(&mut reader)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::distance::Metric;
+
+    #[test]
+    fn test_brute_force_multi_basic() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // Add multiple vectors with the same label
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        assert_eq!(index.index_size(), 3);
+        assert_eq!(index.label_count(1), 2);
+        assert_eq!(index.label_count(2), 1);
+
+        // A query should find both vectors for label 1
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 3, None).unwrap();
+
+        assert_eq!(results.len(), 3);
+        // The first result should be the exact match
+        assert_eq!(results.results[0].label, 1);
+        assert!(results.results[0].distance < 0.001);
+    }
+
+    #[test]
+    fn test_brute_force_multi_delete() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        assert_eq!(index.index_size(), 3);
+
+        // Delete all vectors for label 1
+        let deleted = index.delete_vector(1).unwrap();
+        assert_eq!(deleted, 2);
+        assert_eq!(index.index_size(), 1);
+
+        // Only label 2 should remain
+        let results = index
+            .top_k_query(&vec![1.0, 0.0, 0.0, 0.0], 10, None)
+            .unwrap();
+        assert_eq!(results.len(), 1);
+        assert_eq!(results.results[0].label, 2);
+    }
+
+    #[test]
+    fn test_brute_force_multi_serialization() {
+        use std::io::Cursor;
+
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // Add multiple vectors, including two with the same label
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+        index.add_vector(&vec![0.0, 0.0, 1.0, 0.0], 3).unwrap();
+
+        assert_eq!(index.index_size(), 4);
+        assert_eq!(index.label_count(1), 2);
+
+        // Serialize
+        let mut buffer = Vec::new();
+        index.save(&mut buffer).unwrap();
+
+        // Deserialize
+        let mut cursor = Cursor::new(buffer);
+        let loaded = BruteForceMulti::<f32>::load(&mut cursor).unwrap();
+
+        // Verify
+        assert_eq!(loaded.index_size(), 4);
+        assert_eq!(loaded.dimension(), 4);
+        assert_eq!(loaded.label_count(1), 2);
+        assert_eq!(loaded.label_count(2), 1);
+        assert_eq!(loaded.label_count(3), 1);
+
+        // Queries should work the same
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = loaded.top_k_query(&query, 3, None).unwrap();
+        assert_eq!(results.results[0].label, 1); // Exact match
+    }
+
+    #[test]
+    fn test_brute_force_multi_compact() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // Add multiple vectors, some with the same label
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+        index.add_vector(&vec![0.0, 0.0, 1.0, 0.0], 3).unwrap();
+        index.add_vector(&vec![0.0, 0.0, 0.0, 1.0], 4).unwrap();
+
+        assert_eq!(index.index_size(), 5);
+        assert_eq!(index.label_count(1), 2);
+
+        // Delete label 2 and 3
+        index.delete_vector(2).unwrap();
+        index.delete_vector(3).unwrap();
+
+        assert_eq!(index.index_size(), 3);
+        assert!(index.fragmentation() > 0.0);
+
+        // Compact the index
+        index.compact(true);
+
+        // Verify fragmentation is gone
+        assert!((index.fragmentation() - 0.0).abs() < 0.01);
+        assert_eq!(index.index_size(), 3);
+        assert_eq!(index.label_count(1), 2); // Both vectors for label 1 remain
+
+        // Verify queries still work correctly
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 3, None).unwrap();
+
+        assert_eq!(results.len(), 3);
+        // Top results should be label 1 vectors
+        assert_eq!(results.results[0].label, 1);
+
+        // Verify we can still find v4
+        let query2 = vec![0.0, 0.0, 0.0, 1.0];
+        let results2 = index.top_k_query(&query2, 1, None).unwrap();
+        assert_eq!(results2.results[0].label, 4);
+
+        // Deleted labels should not appear
+        let all_results = index.top_k_query(&query, 10, None).unwrap();
+        for result in &all_results.results {
+            assert!(result.label == 1 || result.label == 4);
+        }
+    }
+
+    #[test]
+    fn test_brute_force_multi_multiple_labels() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // Add multiple vectors per label for multiple labels
+        for label in 1..=5u64 {
+            for i in 0..3 {
+                let v = vec![label as f32, i as f32, 0.0, 0.0];
+                index.add_vector(&v, label).unwrap();
+            }
+        }
+
+        assert_eq!(index.index_size(), 15);
+        for label in 1..=5u64 {
+            assert_eq!(index.label_count(label), 3);
+        }
+    }
+
+    #[test]
+    fn test_brute_force_multi_get_vectors() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        let v1 = vec![1.0, 0.0, 0.0, 0.0];
+        let v2 = vec![2.0, 0.0, 0.0, 0.0];
+        index.add_vector(&v1, 1).unwrap();
+        index.add_vector(&v2, 1).unwrap();
+
+        let vectors = index.get_vectors(1).unwrap();
+        assert_eq!(vectors.len(), 2);
+        // Check that both vectors are returned
+        let mut found_v1 = false;
+        let mut found_v2 = false;
+        for v in &vectors {
+            if v[0] == 1.0 {
+                found_v1 = true;
+            }
+            if v[0] == 2.0 {
+                found_v2 = true;
+            }
+        }
+        assert!(found_v1 && found_v2);
+
+        // Non-existent label
+        assert!(index.get_vectors(999).is_none());
+    }
+
+    #[test]
+    fn test_brute_force_multi_get_labels() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 10).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 20).unwrap();
+        index.add_vector(&vec![3.0, 0.0, 0.0, 0.0], 30).unwrap();
+
+        let labels = index.get_labels();
+        assert_eq!(labels.len(), 3);
+        assert!(labels.contains(&10));
+        assert!(labels.contains(&20));
+        assert!(labels.contains(&30));
+    }
+
+    #[test]
+    fn test_brute_force_multi_compute_distance() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // Add two vectors for label 1, at different distances from the query
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![3.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        // The query should return the minimum distance among all vectors for the label
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let dist = index.compute_distance(1, &query).unwrap();
+        // Distance to [1.0, 0.0, 0.0, 0.0] is 1.0 (L2 squared)
+        assert!((dist - 1.0).abs() < 0.001);
+
+        // Non-existent label
+        assert!(index.compute_distance(999, &query).is_none());
+    }
+
+    #[test]
+    fn test_brute_force_multi_contains() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        assert!(index.contains(1));
+        assert!(!index.contains(2));
+    }
+
+    #[test]
+    fn test_brute_force_multi_range_query() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // Add vectors at different distances
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 2).unwrap();
+        index.add_vector(&vec![3.0, 0.0, 0.0, 0.0], 3).unwrap();
+        index.add_vector(&vec![10.0, 0.0, 0.0, 0.0], 4).unwrap();
+
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        // L2 squared: dist to [1,0,0,0]=1, [2,0,0,0]=4, [3,0,0,0]=9, [10,0,0,0]=100
+        let results = index.range_query(&query, 10.0, None).unwrap();
+
+        assert_eq!(results.len(), 3); // labels 1, 2, 3 are within radius 10
+        for r in &results.results {
+            assert!(r.label != 4); // label 4 should not be included
+        }
+    }
+
+    #[test]
+    fn test_brute_force_multi_filtered_query() {
+        use crate::query::QueryParams;
+
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        for label in 1..=5u64 {
+            index.add_vector(&vec![label as f32, 0.0, 0.0, 0.0], label).unwrap();
+        }
+
+        // Filter to only allow even labels
+        let query_params = QueryParams::new().with_filter(|label| label % 2 == 0);
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 10, Some(&query_params)).unwrap();
+
+        assert_eq!(results.len(), 2); // Only labels 2 and 4
+        for r in &results.results {
+            assert!(r.label % 2 == 0);
+        }
+    }
+
+    #[test]
+    fn test_brute_force_multi_batch_iterator() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        for i in 0..10u64 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        let query = vec![5.0, 0.0, 0.0, 0.0];
+        let mut iter = index.batch_iterator(&query, None).unwrap();
+
+        let mut all_results = Vec::new();
+        while let Some(batch) = iter.next_batch(3) {
+            all_results.extend(batch);
+        }
+
+        assert_eq!(all_results.len(), 10);
+        // Results should be sorted by distance
+        for i in 0..all_results.len() - 1 {
+            assert!(all_results[i].2 <= all_results[i + 1].2);
+        }
+    }
+
+    #[test]
+    fn test_brute_force_multi_batch_iterator_reset() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        for i in 0..5u64 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let mut iter = index.batch_iterator(&query, None).unwrap();
+
+        // First iteration
+        let batch1 = iter.next_batch(100).unwrap();
+        assert_eq!(batch1.len(), 5);
+        assert!(!iter.has_next());
+
+        // Reset and iterate again
+        iter.reset();
+        assert!(iter.has_next());
+        let batch2 = iter.next_batch(100).unwrap();
+        assert_eq!(batch2.len(), 5);
+    }
+
+    #[test]
+    fn test_brute_force_multi_dimension_mismatch() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // Wrong dimension on add
+        let result = index.add_vector(&vec![1.0, 2.0], 1);
+        assert!(matches!(result, Err(IndexError::DimensionMismatch { expected: 4, got: 2 })));
+
+        // Add a valid vector first
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        // Wrong dimension on query
+        let result = index.top_k_query(&vec![1.0, 2.0], 1, None);
+        assert!(matches!(result, Err(QueryError::DimensionMismatch { expected: 4, got: 2 })));
+    }
+
+    #[test]
+    fn test_brute_force_multi_capacity_exceeded() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::with_capacity(params, 2);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 2).unwrap();
+
+        // The third should fail
+        let result = index.add_vector(&vec![3.0, 0.0, 0.0, 0.0], 3);
+        assert!(matches!(result, Err(IndexError::CapacityExceeded { capacity: 2 })));
+    }
+
+    #[test]
+    fn test_brute_force_multi_delete_not_found() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        let result = index.delete_vector(999);
+        assert!(matches!(result, Err(IndexError::LabelNotFound(999))));
+    }
+
+    #[test]
+    fn test_brute_force_multi_memory_usage() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        let initial_memory = index.memory_usage();
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 2).unwrap();
+
+        let after_memory = index.memory_usage();
+        assert!(after_memory > initial_memory);
+    }
+
+    #[test]
+    fn test_brute_force_multi_info() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        let info = index.info();
+        assert_eq!(info.size, 2);
+        assert_eq!(info.dimension, 4);
+        assert_eq!(info.index_type, "BruteForceMulti");
+        assert!(info.memory_bytes > 0);
+    }
+
+    #[test]
+    fn test_brute_force_multi_clear() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 2).unwrap();
+
+        assert_eq!(index.index_size(), 2);
+
+        index.clear();
+
+        assert_eq!(index.index_size(), 0);
+        assert!(index.get_labels().is_empty());
+        assert!(!index.contains(1));
+        assert!(!index.contains(2));
+    }
+
+    #[test]
+    fn test_brute_force_multi_add_vectors_batch() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        let vectors: Vec<(&[f32], LabelType)> = vec![
+            (&[1.0, 0.0, 0.0, 0.0], 1),
+            (&[2.0, 0.0, 0.0, 0.0], 1),
+            (&[3.0, 0.0, 0.0, 0.0], 2),
+        ];
+
+        let added = index.add_vectors(&vectors).unwrap();
+        assert_eq!(added, 3);
+        assert_eq!(index.index_size(), 3);
+        assert_eq!(index.label_count(1), 2);
+        assert_eq!(index.label_count(2), 1);
+    }
+
+    #[test]
+    fn test_brute_force_multi_with_capacity() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let index = BruteForceMulti::<f32>::with_capacity(params, 100);
+
+        assert_eq!(index.index_capacity(), Some(100));
+    }
+
+    #[test]
+    fn test_brute_force_multi_inner_product() {
+        let params = BruteForceParams::new(4, Metric::InnerProduct);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // For InnerProduct, a higher dot product = a lower "distance"
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 2, None).unwrap();
+
+        // Label 1 has perfect alignment with the query
+        assert_eq!(results.results[0].label, 1);
+    }
+
+    #[test]
+    fn test_brute_force_multi_cosine() {
+        let params = BruteForceParams::new(4, Metric::Cosine);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // Cosine similarity is direction-based
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 3).unwrap(); // Same direction as label 1
+
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 3, None).unwrap();
+
+        // Labels 1 and 3 have the same direction (cosine=1), so they should be the top results
+        assert!(results.results[0].label == 1 || results.results[0].label == 3);
+        assert!(results.results[1].label == 1 || results.results[1].label == 3);
+    }
+
+    #[test]
+    fn test_brute_force_multi_metric_getter() {
+        let params = BruteForceParams::new(4, Metric::Cosine);
+        let index = BruteForceMulti::<f32>::new(params);
+
+        assert_eq!(index.metric(), Metric::Cosine);
+    }
+
+    #[test]
+    fn test_brute_force_multi_filtered_range_query() {
+        use crate::query::QueryParams;
+
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        for label in 1..=10u64 {
+            index.add_vector(&vec![label as f32, 0.0, 0.0, 0.0], label).unwrap();
+        }
+
+        // Filter to only allow labels 1-5, with radius 50 (covers labels 1-7 by distance)
+        let query_params = QueryParams::new().with_filter(|label| label <= 5);
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let results = index.range_query(&query, 50.0, Some(&query_params)).unwrap();
+
+        // Should have labels 1-5 (filtered) that are within radius 50
+        assert_eq!(results.len(), 5);
+        for r in &results.results {
+            assert!(r.label <= 5);
+        }
+    }
+
+    #[test]
+    fn test_brute_force_multi_empty_query() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let index = BruteForceMulti::<f32>::new(params);
+
+        // Query on an empty index
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 10, None).unwrap();
+        assert!(results.is_empty());
+    }
+
+    #[test]
+    fn test_brute_force_multi_fragmentation() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        for i in 1..=10u64 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        // Initially no fragmentation
+        assert!((index.fragmentation() - 0.0).abs() < 0.01);
+
+        // Delete half the vectors
+        for i in 1..=5u64 {
+            index.delete_vector(i).unwrap();
+        }
+
+        // Now there should be fragmentation
+        assert!(index.fragmentation() > 0.3);
+    }
+
+    #[test]
+    fn test_brute_force_multi_parallel_query() {
+        use crate::query::QueryParams;
+
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // Add many vectors to trigger parallel processing
+        for i in 0..2000u64 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        let query = vec![500.0, 0.0, 0.0, 0.0];
+        let query_params = QueryParams::new().with_parallel(true);
+        let results = index.top_k_query(&query, 10, Some(&query_params)).unwrap();
+
+        assert_eq!(results.len(), 10);
+        // The first result should be label 500 (exact match)
+        assert_eq!(results.results[0].label, 500);
+    }
+
+    #[test]
+    fn test_brute_force_multi_serialization_with_capacity() {
+        use std::io::Cursor;
+
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::with_capacity(params, 1000);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        // Serialize
+        let mut buffer = Vec::new();
+        index.save(&mut buffer).unwrap();
+
+        // Deserialize
+        let mut cursor = Cursor::new(buffer);
+        let loaded = BruteForceMulti::<f32>::load(&mut cursor).unwrap();
+
+        // Capacity should be preserved
+        assert_eq!(loaded.index_capacity(), Some(1000));
+        assert_eq!(loaded.index_size(), 2);
+        assert_eq!(loaded.label_count(1), 2);
+    }
+
+    #[test]
+    fn test_brute_force_multi_query_after_compact() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceMulti::<f32>::new(params);
+
+        // Add many vectors
+        for i in 1..=20u64 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        // Delete odd labels
+        for i in (1..=20u64).step_by(2) {
+            index.delete_vector(i).unwrap();
+        }
+
+        // Compact
+        index.compact(true);
+
+        // Queries should still work
+        let query = vec![4.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 5, None).unwrap();
+
+        assert_eq!(results.len(), 5);
+        // The first result should be label 4 (exact match, and even, so not deleted)
+        assert_eq!(results.results[0].label, 4);
+
+        // All results should be even labels
+        for r in &results.results {
+            assert!(r.label % 2 == 0);
+        }
+    }
+}
diff --git a/rust/vecsim/src/index/brute_force/single.rs b/rust/vecsim/src/index/brute_force/single.rs
new file mode 100644
index 000000000..b97ee3b43
--- /dev/null
+++ b/rust/vecsim/src/index/brute_force/single.rs
@@ -0,0 +1,889 @@
+//! Single-value BruteForce index implementation.
+//!
+//! This index stores one vector per label. When adding a vector with
+//! an existing label, the old vector is replaced.
+
+use super::{BruteForceCore, BruteForceParams, IdLabelEntry};
+use crate::index::traits::{BatchIterator, IndexError, IndexInfo, QueryError, VecSimIndex};
+use crate::query::{QueryParams, QueryReply, QueryResult};
+use crate::types::{DistanceType, IdType, LabelType, VectorElement};
+use crate::utils::MaxHeap;
+use parking_lot::RwLock;
+use rayon::prelude::*;
+use std::collections::HashMap;
+
+/// Statistics about a BruteForce index.
+#[derive(Debug, Clone)]
+pub struct BruteForceStats {
+    /// Number of vectors in the index.
+    pub size: usize,
+    /// Number of deleted (but not yet compacted) elements.
+    pub deleted_count: usize,
+    /// Approximate memory usage in bytes.
+    pub memory_bytes: usize,
+}
+
+/// Single-value BruteForce index.
+///
+/// Each label has exactly one associated vector. Adding a vector with
+/// an existing label replaces the previous vector.
+pub struct BruteForceSingle<T: VectorElement> {
+    /// Core storage and distance computation.
+    pub(crate) core: RwLock<BruteForceCore<T>>,
+    /// Label to internal ID mapping.
+    label_to_id: RwLock<HashMap<LabelType, IdType>>,
+    /// Internal ID to label mapping (for reverse lookup).
+    pub(crate) id_to_label: RwLock<Vec<IdLabelEntry>>,
+    /// Number of vectors.
+    count: std::sync::atomic::AtomicUsize,
+    /// Maximum capacity (if set).
+    capacity: Option<usize>,
+}
+
+impl<T: VectorElement> BruteForceSingle<T> {
+    /// Create a new single-value BruteForce index.
+    pub fn new(params: BruteForceParams) -> Self {
+        let core = BruteForceCore::new(&params);
+        let initial_capacity = params.initial_capacity;
+
+        Self {
+            core: RwLock::new(core),
+            label_to_id: RwLock::new(HashMap::with_capacity(initial_capacity)),
+            id_to_label: RwLock::new(Vec::with_capacity(initial_capacity)),
+            count: std::sync::atomic::AtomicUsize::new(0),
+            capacity: None,
+        }
+    }
+
+    /// Create with a maximum capacity.
+    pub fn with_capacity(params: BruteForceParams, max_capacity: usize) -> Self {
+        let mut index = Self::new(params);
+        index.capacity = Some(max_capacity);
+        index
+    }
+
+    /// Get the distance metric.
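+    ///
+    /// A small sketch of the getter (mirrors the multi-index metric test):
+    ///
+    /// ```ignore
+    /// let index = BruteForceSingle::<f32>::new(BruteForceParams::new(4, Metric::Cosine));
+    /// assert_eq!(index.metric(), Metric::Cosine);
+    /// ```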
+ pub fn metric(&self) -> crate::distance::Metric { + self.core.read().metric + } + + /// Get the number of deleted (but not yet compacted) elements. + pub fn deleted_count(&self) -> usize { + self.id_to_label + .read() + .iter() + .filter(|e| !e.is_valid && e.label != 0) + .count() + } + + /// Get detailed statistics about the index. + pub fn stats(&self) -> BruteForceStats { + let count = self.count.load(std::sync::atomic::Ordering::Relaxed); + let id_to_label = self.id_to_label.read(); + + let deleted_count = id_to_label + .iter() + .filter(|e| !e.is_valid && e.label != 0) + .count(); + + BruteForceStats { + size: count, + deleted_count, + memory_bytes: self.memory_usage(), + } + } + + /// Get the memory usage in bytes. + pub fn memory_usage(&self) -> usize { + let core = self.core.read(); + let count = self.count.load(std::sync::atomic::Ordering::Relaxed); + + // Vector data storage + let vector_storage = count * core.dim * std::mem::size_of::<T>(); + + // Label mappings + let label_maps = self.label_to_id.read().capacity() + * std::mem::size_of::<(LabelType, IdType)>() + + self.id_to_label.read().capacity() * std::mem::size_of::<IdLabelEntry>(); + + vector_storage + label_maps + } + + /// Get a copy of the vector stored for a given label. + /// + /// Returns `None` if the label doesn't exist in the index. + pub fn get_vector(&self, label: LabelType) -> Option<Vec<T>> { + let label_to_id = self.label_to_id.read(); + let id = *label_to_id.get(&label)?; + let core = self.core.read(); + core.data.get(id).map(|v| v.to_vec()) + } + + /// Get all labels currently in the index. + pub fn get_labels(&self) -> Vec<LabelType> { + self.label_to_id.read().keys().copied().collect() + } + + /// Compute the distance between a stored vector and a query vector. + /// + /// Returns `None` if the label doesn't exist. + pub fn compute_distance(&self, label: LabelType, query: &[T]) -> Option<T::DistanceType> { + let label_to_id = self.label_to_id.read(); + let id = *label_to_id.get(&label)?; + let core = self.core.read(); + Some(core.compute_distance(id, query)) + } + + /// Clear all vectors from the index, resetting it to empty state. + pub fn clear(&mut self) { + let mut core = self.core.write(); + let mut label_to_id = self.label_to_id.write(); + let mut id_to_label = self.id_to_label.write(); + + core.data.clear(); + label_to_id.clear(); + id_to_label.clear(); + self.count.store(0, std::sync::atomic::Ordering::Relaxed); + } + + /// Compact the index by removing gaps from deleted vectors. + /// + /// This reorganizes the internal storage to reclaim space from deleted vectors. + /// After compaction, all vectors are stored contiguously. + /// + /// # Arguments + /// * `shrink` - If true, also release unused memory blocks + /// + /// # Returns + /// The number of bytes reclaimed (approximate). 
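+ ///
+ /// Hypothetical usage, following the compaction test in this file (the
+ /// populated `index` is assumed):
+ ///
+ /// ```rust,ignore
+ /// index.delete_vector(2).unwrap();
+ /// let _reclaimed_bytes = index.compact(true);
+ /// assert!(index.fragmentation() < 0.01);
+ /// ```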
+ pub fn compact(&mut self, shrink: bool) -> usize { + let mut core = self.core.write(); + let mut label_to_id = self.label_to_id.write(); + let mut id_to_label = self.id_to_label.write(); + + let old_capacity = core.data.capacity(); + let id_mapping = core.data.compact(shrink); + + // Update label_to_id mapping + for (_label, id) in label_to_id.iter_mut() { + if let Some(&new_id) = id_mapping.get(id) { + *id = new_id; + } + } + + // Rebuild id_to_label mapping + let mut new_id_to_label = Vec::with_capacity(id_mapping.len()); + for (&old_id, &new_id) in &id_mapping { + let new_id_usize = new_id as usize; + if new_id_usize >= new_id_to_label.len() { + new_id_to_label.resize(new_id_usize + 1, IdLabelEntry::default()); + } + if let Some(entry) = id_to_label.get(old_id as usize) { + new_id_to_label[new_id_usize] = *entry; + } + } + *id_to_label = new_id_to_label; + + let new_capacity = core.data.capacity(); + let dim = core.dim; + let bytes_per_vector = dim * std::mem::size_of::(); + + (old_capacity.saturating_sub(new_capacity)) * bytes_per_vector + } + + /// Get the fragmentation ratio of the index. + /// + /// Returns a value between 0.0 (no fragmentation) and 1.0 (all slots are deleted). + pub fn fragmentation(&self) -> f64 { + self.core.read().data.fragmentation() + } + + /// Add multiple vectors at once. + /// + /// Returns the number of vectors successfully added. + /// Stops on first error and returns what was added so far. + pub fn add_vectors( + &mut self, + vectors: &[(&[T], LabelType)], + ) -> Result { + let mut added = 0; + for &(vector, label) in vectors { + self.add_vector(vector, label)?; + added += 1; + } + Ok(added) + } + + /// Internal implementation of top-k query. + fn top_k_impl( + &self, + query: &[T], + k: usize, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + let core = self.core.read(); + let id_to_label = self.id_to_label.read(); + + if query.len() != core.dim { + return Err(QueryError::DimensionMismatch { + expected: core.dim, + got: query.len(), + }); + } + + let use_parallel = params.is_some_and(|p| p.parallel); + let filter = params.and_then(|p| p.filter.as_ref()).map(|f| f.as_ref()); + + let mut results = if use_parallel && id_to_label.len() > 1000 { + // Parallel scan for large datasets + self.parallel_top_k(&core, &id_to_label, query, k, filter) + } else { + // Sequential scan + self.sequential_top_k(&core, &id_to_label, query, k, filter) + }; + + results.sort_by_distance(); + Ok(results) + } + + /// Sequential top-k scan. + fn sequential_top_k( + &self, + core: &BruteForceCore, + id_to_label: &[IdLabelEntry], + query: &[T], + k: usize, + filter: Option<&(dyn Fn(LabelType) -> bool + Send + Sync)>, + ) -> QueryReply { + let mut heap = MaxHeap::new(k); + + for (id, entry) in id_to_label.iter().enumerate() { + if !entry.is_valid { + continue; + } + + if let Some(f) = filter { + if !f(entry.label) { + continue; + } + } + + let dist = core.compute_distance(id as IdType, query); + heap.try_insert(id as IdType, dist); + } + + let mut reply = QueryReply::with_capacity(heap.len()); + for entry in heap.into_sorted_vec() { + if let Some(label_entry) = id_to_label.get(entry.id as usize) { + reply.push(QueryResult::new(label_entry.label, entry.distance)); + } + } + reply + } + + /// Parallel top-k scan using rayon. 
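+ ///
+ /// Strategy note: all live (and filter-passing) candidates are scored in a
+ /// parallel map, then a single-threaded heap pass keeps the k smallest.
+ /// The intermediate candidate allocation buys a trivially correct
+ /// reduction; `top_k_impl` only takes this path above 1000 entries.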
+ fn parallel_top_k( + &self, + core: &BruteForceCore, + id_to_label: &[IdLabelEntry], + query: &[T], + k: usize, + filter: Option<&(dyn Fn(LabelType) -> bool + Send + Sync)>, + ) -> QueryReply { + // Parallel map to compute distances + let candidates: Vec<_> = id_to_label + .par_iter() + .enumerate() + .filter_map(|(id, entry)| { + if !entry.is_valid { + return None; + } + if let Some(f) = filter { + if !f(entry.label) { + return None; + } + } + let dist = core.compute_distance(id as IdType, query); + Some((id as IdType, entry.label, dist)) + }) + .collect(); + + // Serial reduction to find top-k + let mut heap = MaxHeap::new(k); + for (id, _label, dist) in candidates { + heap.try_insert(id, dist); + } + + let mut reply = QueryReply::with_capacity(heap.len()); + for entry in heap.into_sorted_vec() { + if let Some(label_entry) = id_to_label.get(entry.id as usize) { + reply.push(QueryResult::new(label_entry.label, entry.distance)); + } + } + reply + } + + /// Internal implementation of range query. + fn range_impl( + &self, + query: &[T], + radius: T::DistanceType, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + let core = self.core.read(); + let id_to_label = self.id_to_label.read(); + + if query.len() != core.dim { + return Err(QueryError::DimensionMismatch { + expected: core.dim, + got: query.len(), + }); + } + + let filter = params.and_then(|p| p.filter.as_ref()); + let mut reply = QueryReply::new(); + + for (id, entry) in id_to_label.iter().enumerate() { + if !entry.is_valid { + continue; + } + + if let Some(f) = filter { + if !f(entry.label) { + continue; + } + } + + let dist = core.compute_distance(id as IdType, query); + if dist.to_f64() <= radius.to_f64() { + reply.push(QueryResult::new(entry.label, dist)); + } + } + + reply.sort_by_distance(); + Ok(reply) + } +} + +impl VecSimIndex for BruteForceSingle { + type DataType = T; + type DistType = T::DistanceType; + + fn add_vector(&mut self, vector: &[T], label: LabelType) -> Result { + let mut core = self.core.write(); + + if vector.len() != core.dim { + return Err(IndexError::DimensionMismatch { + expected: core.dim, + got: vector.len(), + }); + } + + let mut label_to_id = self.label_to_id.write(); + let mut id_to_label = self.id_to_label.write(); + + // Check if label already exists + if let Some(&existing_id) = label_to_id.get(&label) { + // Update existing vector + let processed = core.dist_fn.preprocess(vector, core.dim); + core.data.update(existing_id, &processed); + return Ok(0); // No new vector added + } + + // Check capacity + if let Some(cap) = self.capacity { + if self.count.load(std::sync::atomic::Ordering::Relaxed) >= cap { + return Err(IndexError::CapacityExceeded { capacity: cap }); + } + } + + // Add new vector + let id = core + .add_vector(vector) + .ok_or_else(|| IndexError::Internal("Failed to add vector to storage".to_string()))?; + + // Update mappings + label_to_id.insert(label, id); + + // Ensure id_to_label is large enough + let id_usize = id as usize; + if id_usize >= id_to_label.len() { + id_to_label.resize(id_usize + 1, IdLabelEntry::default()); + } + id_to_label[id_usize] = IdLabelEntry { + label, + is_valid: true, + }; + + self.count + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + Ok(1) + } + + fn delete_vector(&mut self, label: LabelType) -> Result { + let mut label_to_id = self.label_to_id.write(); + let mut id_to_label = self.id_to_label.write(); + + if let Some(id) = label_to_id.remove(&label) { + // Mark as invalid + if let Some(entry) = id_to_label.get_mut(id as usize) { + 
entry.is_valid = false; + } + + // Mark slot as free in data storage + self.core.write().data.mark_deleted(id); + + self.count + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + Ok(1) + } else { + Err(IndexError::LabelNotFound(label)) + } + } + + fn top_k_query( + &self, + query: &[T], + k: usize, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + self.top_k_impl(query, k, params) + } + + fn range_query( + &self, + query: &[T], + radius: T::DistanceType, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + self.range_impl(query, radius, params) + } + + fn index_size(&self) -> usize { + self.count.load(std::sync::atomic::Ordering::Relaxed) + } + + fn index_capacity(&self) -> Option { + self.capacity + } + + fn dimension(&self) -> usize { + self.core.read().dim + } + + fn batch_iterator<'a>( + &'a self, + query: &[T], + params: Option<&QueryParams>, + ) -> Result + 'a>, QueryError> { + let core = self.core.read(); + if query.len() != core.dim { + return Err(QueryError::DimensionMismatch { + expected: core.dim, + got: query.len(), + }); + } + + // Compute results immediately to preserve filter from params + let id_to_label = self.id_to_label.read(); + let filter = params.and_then(|p| p.filter.as_ref()); + + let mut results = Vec::new(); + for (id, entry) in id_to_label.iter().enumerate() { + if !entry.is_valid { + continue; + } + + if let Some(f) = filter { + if !f(entry.label) { + continue; + } + } + + let dist = core.compute_distance(id as IdType, query); + results.push((id as IdType, entry.label, dist)); + } + + // Sort by distance + results.sort_by(|a, b| { + a.2.to_f64() + .partial_cmp(&b.2.to_f64()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + Ok(Box::new(super::batch_iterator::BruteForceBatchIterator::::new(results))) + } + + fn info(&self) -> IndexInfo { + let core = self.core.read(); + let count = self.count.load(std::sync::atomic::Ordering::Relaxed); + + IndexInfo { + size: count, + capacity: self.capacity, + dimension: core.dim, + index_type: "BruteForceSingle", + memory_bytes: count * core.dim * std::mem::size_of::() + + self.label_to_id.read().capacity() * std::mem::size_of::<(LabelType, IdType)>() + + self.id_to_label.read().capacity() * std::mem::size_of::(), + } + } + + fn contains(&self, label: LabelType) -> bool { + self.label_to_id.read().contains_key(&label) + } + + fn label_count(&self, label: LabelType) -> usize { + if self.contains(label) { + 1 + } else { + 0 + } + } +} + +// Allow read-only concurrent access for queries +unsafe impl Send for BruteForceSingle {} +unsafe impl Sync for BruteForceSingle {} + +// Serialization support +impl BruteForceSingle { + /// Save the index to a writer. 
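+ ///
+ /// Round-trip sketch through an in-memory buffer, as exercised by the
+ /// serialization tests (error handling elided):
+ ///
+ /// ```rust,ignore
+ /// let mut buf = Vec::new();
+ /// index.save(&mut buf).unwrap();
+ /// let loaded =
+ ///     BruteForceSingle::<f32>::load(&mut std::io::Cursor::new(buf)).unwrap();
+ /// assert_eq!(loaded.index_size(), index.index_size());
+ /// ```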
+ pub fn save( + &self, + writer: &mut W, + ) -> crate::serialization::SerializationResult<()> { + use crate::serialization::*; + + let core = self.core.read(); + let label_to_id = self.label_to_id.read(); + let id_to_label = self.id_to_label.read(); + let count = self.count.load(std::sync::atomic::Ordering::Relaxed); + + // Write header + let header = IndexHeader::new( + IndexTypeId::BruteForceSingle, + DataTypeId::F32, + core.metric, + core.dim, + count, + ); + header.write(writer)?; + + // Write capacity + write_u8(writer, if self.capacity.is_some() { 1 } else { 0 })?; + if let Some(cap) = self.capacity { + write_usize(writer, cap)?; + } + + // Write label_to_id mapping + write_usize(writer, label_to_id.len())?; + for (&label, &id) in label_to_id.iter() { + write_u64(writer, label)?; + write_u32(writer, id)?; + } + + // Write id_to_label entries + write_usize(writer, id_to_label.len())?; + for entry in id_to_label.iter() { + write_u64(writer, entry.label)?; + write_u8(writer, if entry.is_valid { 1 } else { 0 })?; + } + + // Write vectors - only write valid entries + let mut valid_ids: Vec = Vec::with_capacity(count); + for (id, entry) in id_to_label.iter().enumerate() { + if entry.is_valid { + valid_ids.push(id as IdType); + } + } + + write_usize(writer, valid_ids.len())?; + for id in valid_ids { + write_u32(writer, id)?; + if let Some(vector) = core.data.get(id) { + for &v in vector { + write_f32(writer, v)?; + } + } + } + + Ok(()) + } + + /// Load the index from a reader. + pub fn load(reader: &mut R) -> crate::serialization::SerializationResult { + use crate::serialization::*; + + // Read and validate header + let header = IndexHeader::read(reader)?; + + if header.index_type != IndexTypeId::BruteForceSingle { + return Err(SerializationError::IndexTypeMismatch { + expected: "BruteForceSingle".to_string(), + got: header.index_type.as_str().to_string(), + }); + } + + if header.data_type != DataTypeId::F32 { + return Err(SerializationError::InvalidData( + "Expected f32 data type".to_string(), + )); + } + + // Create the index with proper parameters + let params = BruteForceParams::new(header.dimension, header.metric); + let mut index = Self::new(params); + + // Read capacity + let has_capacity = read_u8(reader)? != 0; + if has_capacity { + index.capacity = Some(read_usize(reader)?); + } + + // Read label_to_id mapping + let label_to_id_len = read_usize(reader)?; + let mut label_to_id = HashMap::with_capacity(label_to_id_len); + for _ in 0..label_to_id_len { + let label = read_u64(reader)?; + let id = read_u32(reader)?; + label_to_id.insert(label, id); + } + + // Read id_to_label entries + let id_to_label_len = read_usize(reader)?; + let mut id_to_label = Vec::with_capacity(id_to_label_len); + for _ in 0..id_to_label_len { + let label = read_u64(reader)?; + let is_valid = read_u8(reader)? 
!= 0; + id_to_label.push(IdLabelEntry { label, is_valid }); + } + + // Read vectors + let num_vectors = read_usize(reader)?; + let dim = header.dimension; + { + let mut core = index.core.write(); + + // Pre-allocate space + core.data.reserve(num_vectors); + + for _ in 0..num_vectors { + let id = read_u32(reader)?; + let mut vector = vec![0.0f32; dim]; + for v in &mut vector { + *v = read_f32(reader)?; + } + + // Add vector at specific ID + let added_id = core.data.add(&vector).ok_or_else(|| { + SerializationError::DataCorruption("Failed to add vector during deserialization".to_string()) + })?; + + // Ensure ID matches (vectors should be added in order) + if added_id != id { + // If IDs don't match, we need to handle this case + // For now, this shouldn't happen with proper serialization + } + } + } + + // Set the internal state + *index.label_to_id.write() = label_to_id; + *index.id_to_label.write() = id_to_label; + index + .count + .store(header.count, std::sync::atomic::Ordering::Relaxed); + + Ok(index) + } + + /// Save the index to a file. + pub fn save_to_file>( + &self, + path: P, + ) -> crate::serialization::SerializationResult<()> { + let file = std::fs::File::create(path)?; + let mut writer = std::io::BufWriter::new(file); + self.save(&mut writer) + } + + /// Load the index from a file. + pub fn load_from_file>( + path: P, + ) -> crate::serialization::SerializationResult { + let file = std::fs::File::open(path)?; + let mut reader = std::io::BufReader::new(file); + Self::load(&mut reader) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::Metric; + + #[test] + fn test_brute_force_single_basic() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1 = vec![1.0, 0.0, 0.0, 0.0]; + let v2 = vec![0.0, 1.0, 0.0, 0.0]; + let v3 = vec![0.0, 0.0, 1.0, 0.0]; + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + index.add_vector(&v3, 3).unwrap(); + + assert_eq!(index.index_size(), 3); + + // Query for nearest neighbor + let query = vec![1.0, 0.1, 0.0, 0.0]; + let results = index.top_k_query(&query, 2, None).unwrap(); + + assert_eq!(results.len(), 2); + assert_eq!(results.results[0].label, 1); // Closest to v1 + } + + #[test] + fn test_brute_force_single_replace() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1 = vec![1.0, 0.0, 0.0, 0.0]; + let v2 = vec![0.0, 1.0, 0.0, 0.0]; + + index.add_vector(&v1, 1).unwrap(); + assert_eq!(index.index_size(), 1); + + // Replace with new vector + index.add_vector(&v2, 1).unwrap(); + assert_eq!(index.index_size(), 1); // Size unchanged + + // Query should return updated vector + let query = vec![0.0, 1.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 1, None).unwrap(); + + assert_eq!(results.results[0].label, 1); + assert!(results.results[0].distance < 0.001); // Should be very close + } + + #[test] + fn test_brute_force_single_delete() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + assert_eq!(index.index_size(), 2); + + index.delete_vector(1).unwrap(); + assert_eq!(index.index_size(), 1); + + // Deleted vector should not appear in results + let results = index + .top_k_query(&vec![1.0, 0.0, 0.0, 0.0], 10, None) + .unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results.results[0].label, 2); 
+ } + + #[test] + fn test_brute_force_single_range_query() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + index.add_vector(&vec![0.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 3).unwrap(); + + // Range query with radius 1.5 (squared = 2.25) + let query = vec![0.0, 0.0, 0.0, 0.0]; + let results = index.range_query(&query, 2.25, None).unwrap(); + + // Should find vectors 1 and 2 (distances 0 and 1) + assert_eq!(results.len(), 2); + } + + #[test] + fn test_brute_force_single_serialization() { + use std::io::Cursor; + + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + // Add some vectors + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&vec![0.0, 0.0, 1.0, 0.0], 3).unwrap(); + + // Serialize + let mut buffer = Vec::new(); + index.save(&mut buffer).unwrap(); + + // Deserialize + let mut cursor = Cursor::new(buffer); + let loaded = BruteForceSingle::::load(&mut cursor).unwrap(); + + // Verify + assert_eq!(loaded.index_size(), 3); + assert_eq!(loaded.dimension(), 4); + + // Query should work the same + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = loaded.top_k_query(&query, 3, None).unwrap(); + assert_eq!(results.results[0].label, 1); + } + + #[test] + fn test_brute_force_single_compact() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + let v1 = vec![1.0, 0.0, 0.0, 0.0]; + let v2 = vec![0.0, 1.0, 0.0, 0.0]; + let v3 = vec![0.0, 0.0, 1.0, 0.0]; + let v4 = vec![0.0, 0.0, 0.0, 1.0]; + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + index.add_vector(&v3, 3).unwrap(); + index.add_vector(&v4, 4).unwrap(); + + assert_eq!(index.index_size(), 4); + + // Delete vectors 2 and 3 + index.delete_vector(2).unwrap(); + index.delete_vector(3).unwrap(); + + assert_eq!(index.index_size(), 2); + assert!(index.fragmentation() > 0.0); + + // Compact the index + index.compact(true); + + // Verify fragmentation is gone + assert!((index.fragmentation() - 0.0).abs() < 0.01); + assert_eq!(index.index_size(), 2); + + // Verify queries still work correctly + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 2, None).unwrap(); + + assert_eq!(results.len(), 2); + assert_eq!(results.results[0].label, 1); // Closest to v1 + + // Verify we can still find v4 + let query2 = vec![0.0, 0.0, 0.0, 1.0]; + let results2 = index.top_k_query(&query2, 1, None).unwrap(); + assert_eq!(results2.results[0].label, 4); + + // Deleted labels should not appear + let all_results = index.top_k_query(&v1, 10, None).unwrap(); + for result in &all_results.results { + assert!(result.label == 1 || result.label == 4); + } + } +} diff --git a/rust/vecsim/src/index/debug.rs b/rust/vecsim/src/index/debug.rs new file mode 100644 index 000000000..87b9e5fe8 --- /dev/null +++ b/rust/vecsim/src/index/debug.rs @@ -0,0 +1,450 @@ +//! Debug and introspection API for indices. +//! +//! This module provides debugging and introspection capabilities: +//! - Graph structure inspection (HNSW, SVS) +//! - Search mode tracking +//! - Index statistics and metadata +//! +//! ## Example +//! +//! ```rust,ignore +//! use vecsim::index::debug::{DebugInfo, GraphInspector}; +//! +//! // Get debug info from an HNSW index +//! let info = index.debug_info(); +//! 
println!("Index size: {}", info.basic.index_size); +//! +//! // Inspect graph neighbors +//! if let Some(inspector) = index.graph_inspector() { +//! let neighbors = inspector.get_neighbors(0, 0); +//! println!("Level 0 neighbors of node 0: {:?}", neighbors); +//! } +//! ``` + +use crate::distance::Metric; +use crate::query::SearchMode; +use crate::types::{IdType, LabelType}; +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Basic immutable index information. +#[derive(Debug, Clone)] +pub struct BasicIndexInfo { + /// Type of the index (BruteForce, HNSW, SVS, etc.). + pub index_type: IndexType, + /// Vector dimensionality. + pub dim: usize, + /// Distance metric. + pub metric: Metric, + /// Whether this is a multi-value index. + pub is_multi: bool, + /// Block size used for storage. + pub block_size: usize, +} + +/// Mutable index statistics. +#[derive(Debug, Clone, Default)] +pub struct IndexStats { + /// Total number of vectors in the index. + pub index_size: usize, + /// Number of unique labels. + pub label_count: usize, + /// Memory usage in bytes (if tracked). + pub memory_usage: Option, + /// Number of deleted vectors pending cleanup. + pub deleted_count: usize, + /// Resize count (number of times the index grew). + pub resize_count: usize, +} + +/// HNSW-specific debug information. +#[derive(Debug, Clone)] +pub struct HnswDebugInfo { + /// M parameter (max connections per node). + pub m: usize, + /// Max M for level 0. + pub m_max: usize, + /// Max M for higher levels. + pub m_max0: usize, + /// ef_construction parameter. + pub ef_construction: usize, + /// Current ef_runtime parameter. + pub ef_runtime: usize, + /// Entry point node ID. + pub entry_point: Option, + /// Maximum level in the graph. + pub max_level: usize, + /// Distribution of nodes per level. + pub level_distribution: Vec, +} + +/// SVS/Vamana-specific debug information. +#[derive(Debug, Clone)] +pub struct SvsDebugInfo { + /// Alpha parameter for graph construction. + pub alpha: f32, + /// Maximum neighbors per node (R). + pub max_neighbors: usize, + /// Search window size (L). + pub search_window_size: usize, + /// Medoid (entry point) node ID. + pub medoid: Option, + /// Average degree in the graph. + pub avg_degree: f32, +} + +/// Combined debug information for an index. +#[derive(Debug, Clone)] +pub struct DebugInfo { + /// Basic immutable info. + pub basic: BasicIndexInfo, + /// Mutable statistics. + pub stats: IndexStats, + /// HNSW-specific info (if applicable). + pub hnsw: Option, + /// SVS-specific info (if applicable). + pub svs: Option, + /// Last search mode used. + pub last_search_mode: SearchMode, +} + +/// Index type enumeration for debug info. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum IndexType { + BruteForceSingle, + BruteForceMulti, + HnswSingle, + HnswMulti, + SvsSingle, + SvsMulti, + TieredSingle, + TieredMulti, + TieredSvsSingle, + TieredSvsMulti, + DiskIndex, +} + +impl std::fmt::Display for IndexType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + IndexType::BruteForceSingle => write!(f, "BruteForceSingle"), + IndexType::BruteForceMulti => write!(f, "BruteForceMulti"), + IndexType::HnswSingle => write!(f, "HnswSingle"), + IndexType::HnswMulti => write!(f, "HnswMulti"), + IndexType::SvsSingle => write!(f, "SvsSingle"), + IndexType::SvsMulti => write!(f, "SvsMulti"), + IndexType::TieredSingle => write!(f, "TieredSingle"), + IndexType::TieredMulti => write!(f, "TieredMulti"), + IndexType::TieredSvsSingle => write!(f, "TieredSvsSingle"), + IndexType::TieredSvsMulti => write!(f, "TieredSvsMulti"), + IndexType::DiskIndex => write!(f, "DiskIndex"), + } + } +} + +/// Trait for indices that support debug inspection. +pub trait Debuggable { + /// Get comprehensive debug information. + fn debug_info(&self) -> DebugInfo; + + /// Get basic immutable information. + fn basic_info(&self) -> BasicIndexInfo; + + /// Get current statistics. + fn stats(&self) -> IndexStats; + + /// Get the last search mode used. + fn last_search_mode(&self) -> SearchMode; +} + +/// Trait for inspecting graph structure (HNSW, SVS). +pub trait GraphInspector { + /// Get neighbors of a node at a specific level. + /// + /// For HNSW, level 0 is the base level. + /// For SVS, there's only one level (use level 0). + fn get_neighbors(&self, node_id: IdType, level: usize) -> Vec; + + /// Get the number of levels for a node (HNSW only). + fn get_node_level(&self, node_id: IdType) -> usize; + + /// Get the entry point of the graph. + fn entry_point(&self) -> Option; + + /// Get the maximum level in the graph. + fn max_level(&self) -> usize; + + /// Get the label for a node ID. + fn get_label(&self, node_id: IdType) -> Option; + + /// Check if a node is deleted. + fn is_deleted(&self, node_id: IdType) -> bool; +} + +/// Search mode tracker for hybrid queries. +/// +/// This tracks the search strategy used across queries. +#[derive(Debug, Default)] +pub struct SearchModeTracker { + last_mode: std::sync::atomic::AtomicU8, + adhoc_count: AtomicUsize, + batch_count: AtomicUsize, + switch_count: AtomicUsize, +} + +impl SearchModeTracker { + /// Create a new tracker. + pub fn new() -> Self { + Self::default() + } + + /// Record a search mode. + pub fn record(&self, mode: SearchMode) { + let mode_u8 = match mode { + SearchMode::Empty => 0, + SearchMode::StandardKnn => 1, + SearchMode::HybridAdhocBf => 2, + SearchMode::HybridBatches => 3, + SearchMode::HybridBatchesToAdhocBf => 4, + SearchMode::RangeQuery => 5, + }; + self.last_mode.store(mode_u8, Ordering::Relaxed); + + match mode { + SearchMode::HybridAdhocBf => { + self.adhoc_count.fetch_add(1, Ordering::Relaxed); + } + SearchMode::HybridBatches => { + self.batch_count.fetch_add(1, Ordering::Relaxed); + } + SearchMode::HybridBatchesToAdhocBf => { + self.switch_count.fetch_add(1, Ordering::Relaxed); + } + _ => {} + } + } + + /// Get the last search mode. 
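+ ///
+ /// Sketch mirroring the tracker unit test below:
+ ///
+ /// ```rust,ignore
+ /// let tracker = SearchModeTracker::new();
+ /// tracker.record(SearchMode::HybridAdhocBf);
+ /// assert!(matches!(tracker.last_mode(), SearchMode::HybridAdhocBf));
+ /// ```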
+ pub fn last_mode(&self) -> SearchMode { + match self.last_mode.load(Ordering::Relaxed) { + 0 => SearchMode::Empty, + 1 => SearchMode::StandardKnn, + 2 => SearchMode::HybridAdhocBf, + 3 => SearchMode::HybridBatches, + 4 => SearchMode::HybridBatchesToAdhocBf, + 5 => SearchMode::RangeQuery, + _ => SearchMode::Empty, + } + } + + /// Get count of ad-hoc searches. + pub fn adhoc_count(&self) -> usize { + self.adhoc_count.load(Ordering::Relaxed) + } + + /// Get count of batch searches. + pub fn batch_count(&self) -> usize { + self.batch_count.load(Ordering::Relaxed) + } + + /// Get count of switches from batch to ad-hoc. + pub fn switch_count(&self) -> usize { + self.switch_count.load(Ordering::Relaxed) + } + + /// Get total hybrid search count. + pub fn total_hybrid_count(&self) -> usize { + self.adhoc_count() + self.batch_count() + self.switch_count() + } +} + +/// Iterator over debug information entries. +pub struct DebugInfoIterator<'a> { + entries: std::vec::IntoIter<(&'a str, String)>, +} + +impl<'a> DebugInfoIterator<'a> { + /// Create from debug info. + pub fn from_debug_info(info: &'a DebugInfo) -> Self { + let mut entries = Vec::new(); + + // Basic info + entries.push(("index_type", info.basic.index_type.to_string())); + entries.push(("dim", info.basic.dim.to_string())); + entries.push(("metric", format!("{:?}", info.basic.metric))); + entries.push(("is_multi", info.basic.is_multi.to_string())); + entries.push(("block_size", info.basic.block_size.to_string())); + + // Stats + entries.push(("index_size", info.stats.index_size.to_string())); + entries.push(("label_count", info.stats.label_count.to_string())); + if let Some(mem) = info.stats.memory_usage { + entries.push(("memory_usage", mem.to_string())); + } + entries.push(("deleted_count", info.stats.deleted_count.to_string())); + + // HNSW-specific + if let Some(ref hnsw) = info.hnsw { + entries.push(("hnsw_m", hnsw.m.to_string())); + entries.push(("hnsw_ef_construction", hnsw.ef_construction.to_string())); + entries.push(("hnsw_ef_runtime", hnsw.ef_runtime.to_string())); + entries.push(("hnsw_max_level", hnsw.max_level.to_string())); + if let Some(ep) = hnsw.entry_point { + entries.push(("hnsw_entry_point", ep.to_string())); + } + } + + // SVS-specific + if let Some(ref svs) = info.svs { + entries.push(("svs_alpha", svs.alpha.to_string())); + entries.push(("svs_max_neighbors", svs.max_neighbors.to_string())); + entries.push(("svs_search_window_size", svs.search_window_size.to_string())); + entries.push(("svs_avg_degree", svs.avg_degree.to_string())); + if let Some(medoid) = svs.medoid { + entries.push(("svs_medoid", medoid.to_string())); + } + } + + // Search mode + entries.push(("last_search_mode", format!("{:?}", info.last_search_mode))); + + Self { + entries: entries.into_iter(), + } + } +} + +impl<'a> Iterator for DebugInfoIterator<'a> { + type Item = (&'a str, String); + + fn next(&mut self) -> Option { + self.entries.next() + } +} + +/// Element neighbors structure for graph inspection. +#[derive(Debug, Clone)] +pub struct ElementNeighbors { + /// Node ID. + pub node_id: IdType, + /// Label of the node. + pub label: LabelType, + /// Neighbors at each level (index 0 = level 0). + pub neighbors_per_level: Vec>, +} + +/// Mapping from internal ID to label. +#[derive(Debug, Clone, Default)] +pub struct IdLabelMapping { + id_to_label: HashMap, + label_to_ids: HashMap>, +} + +impl IdLabelMapping { + /// Create a new empty mapping. + pub fn new() -> Self { + Self::default() + } + + /// Insert a mapping. 
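+ ///
+ /// One label may map to several internal IDs, as in the unit test below:
+ ///
+ /// ```rust,ignore
+ /// let mut mapping = IdLabelMapping::new();
+ /// mapping.insert(0, 100);
+ /// mapping.insert(1, 100);
+ /// assert_eq!(mapping.get_ids(100), Some(&[0, 1][..]));
+ /// ```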
+ pub fn insert(&mut self, id: IdType, label: LabelType) { + self.id_to_label.insert(id, label); + self.label_to_ids.entry(label).or_default().push(id); + } + + /// Get label for an ID. + pub fn get_label(&self, id: IdType) -> Option { + self.id_to_label.get(&id).copied() + } + + /// Get IDs for a label. + pub fn get_ids(&self, label: LabelType) -> Option<&[IdType]> { + self.label_to_ids.get(&label).map(|v| v.as_slice()) + } + + /// Get total number of mappings. + pub fn len(&self) -> usize { + self.id_to_label.len() + } + + /// Check if empty. + pub fn is_empty(&self) -> bool { + self.id_to_label.is_empty() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_search_mode_tracker() { + let tracker = SearchModeTracker::new(); + + tracker.record(SearchMode::HybridAdhocBf); + tracker.record(SearchMode::HybridBatches); + tracker.record(SearchMode::HybridAdhocBf); + + assert_eq!(tracker.adhoc_count(), 2); + assert_eq!(tracker.batch_count(), 1); + assert_eq!(tracker.total_hybrid_count(), 3); + } + + #[test] + fn test_id_label_mapping() { + let mut mapping = IdLabelMapping::new(); + mapping.insert(0, 100); + mapping.insert(1, 100); + mapping.insert(2, 200); + + assert_eq!(mapping.get_label(0), Some(100)); + assert_eq!(mapping.get_label(2), Some(200)); + assert_eq!(mapping.get_ids(100), Some(&[0, 1][..])); + assert_eq!(mapping.len(), 3); + } + + #[test] + fn test_index_type_display() { + assert_eq!(IndexType::HnswSingle.to_string(), "HnswSingle"); + assert_eq!(IndexType::BruteForceMulti.to_string(), "BruteForceMulti"); + } + + #[test] + fn test_debug_info_iterator() { + let info = DebugInfo { + basic: BasicIndexInfo { + index_type: IndexType::HnswSingle, + dim: 128, + metric: Metric::L2, + is_multi: false, + block_size: 1024, + }, + stats: IndexStats { + index_size: 1000, + label_count: 1000, + memory_usage: Some(1024 * 1024), + deleted_count: 0, + resize_count: 2, + }, + hnsw: Some(HnswDebugInfo { + m: 16, + m_max: 16, + m_max0: 32, + ef_construction: 200, + ef_runtime: 10, + entry_point: Some(0), + max_level: 3, + level_distribution: vec![1000, 50, 5, 1], + }), + svs: None, + last_search_mode: SearchMode::StandardKnn, + }; + + let iter = DebugInfoIterator::from_debug_info(&info); + let entries: Vec<_> = iter.collect(); + + assert!(entries.iter().any(|(k, _)| *k == "index_type")); + assert!(entries.iter().any(|(k, _)| *k == "hnsw_m")); + assert!(entries.iter().any(|(k, v)| *k == "dim" && v == "128")); + } +} diff --git a/rust/vecsim/src/index/disk/mod.rs b/rust/vecsim/src/index/disk/mod.rs new file mode 100644 index 000000000..f4088f51c --- /dev/null +++ b/rust/vecsim/src/index/disk/mod.rs @@ -0,0 +1,102 @@ +//! Disk-based vector index using memory-mapped storage. +//! +//! This module provides persistent vector indices that use memory-mapped files +//! for vector storage, allowing datasets larger than RAM. +//! +//! # Backend Options +//! +//! - **BruteForce**: Linear scan over all vectors (exact results, O(n)) +//! - **Vamana**: SVS graph structure for approximate search (fast, O(log n)) + +pub mod single; + +pub use single::DiskIndexSingle; + +use crate::distance::Metric; +use std::path::PathBuf; + +/// Backend type for the disk index. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum DiskBackend { + /// Linear scan (exact results). + #[default] + BruteForce, + /// Vamana graph (approximate, fast). + Vamana, +} + +/// Parameters for creating a disk-based index. +#[derive(Debug, Clone)] +pub struct DiskIndexParams { + /// Vector dimension. 
+ pub dim: usize, + /// Distance metric. + pub metric: Metric, + /// Path to the data file. + pub data_path: PathBuf, + /// Backend algorithm. + pub backend: DiskBackend, + /// Initial capacity. + pub initial_capacity: usize, + /// Graph max degree (for Vamana backend). + pub graph_max_degree: usize, + /// Alpha parameter (for Vamana backend). + pub alpha: f32, + /// Construction window size (for Vamana backend). + pub construction_l: usize, + /// Search window size (for Vamana backend). + pub search_l: usize, +} + +impl DiskIndexParams { + /// Create new disk index parameters. + pub fn new>(dim: usize, metric: Metric, data_path: P) -> Self { + Self { + dim, + metric, + data_path: data_path.into(), + backend: DiskBackend::BruteForce, + initial_capacity: 10_000, + graph_max_degree: 32, + alpha: 1.2, + construction_l: 200, + search_l: 100, + } + } + + /// Set the backend algorithm. + pub fn with_backend(mut self, backend: DiskBackend) -> Self { + self.backend = backend; + self + } + + /// Set initial capacity. + pub fn with_capacity(mut self, capacity: usize) -> Self { + self.initial_capacity = capacity; + self + } + + /// Set graph max degree (for Vamana backend). + pub fn with_graph_degree(mut self, degree: usize) -> Self { + self.graph_max_degree = degree; + self + } + + /// Set alpha parameter (for Vamana backend). + pub fn with_alpha(mut self, alpha: f32) -> Self { + self.alpha = alpha; + self + } + + /// Set construction window size (for Vamana backend). + pub fn with_construction_l(mut self, l: usize) -> Self { + self.construction_l = l; + self + } + + /// Set search window size (for Vamana backend). + pub fn with_search_l(mut self, l: usize) -> Self { + self.search_l = l; + self + } +} diff --git a/rust/vecsim/src/index/disk/single.rs b/rust/vecsim/src/index/disk/single.rs new file mode 100644 index 000000000..6f20470a0 --- /dev/null +++ b/rust/vecsim/src/index/disk/single.rs @@ -0,0 +1,2131 @@ +//! Single-value disk-based index implementation. +//! +//! This index stores one vector per label using memory-mapped storage. + +use super::{DiskBackend, DiskIndexParams}; +use crate::containers::MmapDataBlocks; +use crate::distance::{create_distance_function, DistanceFunction}; +use crate::index::hnsw::VisitedNodesHandlerPool; +use crate::index::svs::graph::VamanaGraph; +use crate::index::svs::search::{greedy_beam_search_filtered, robust_prune}; +use crate::index::traits::{BatchIterator, IndexError, IndexInfo, QueryError, VecSimIndex}; +use crate::query::{QueryParams, QueryReply, QueryResult}; +use crate::types::{DistanceType, IdType, LabelType, VectorElement, INVALID_ID}; +use parking_lot::RwLock; +use std::collections::HashMap; +use std::io; +use std::path::Path; +use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering}; + +/// Statistics for disk index. +#[derive(Debug, Clone)] +pub struct DiskIndexStats { + /// Number of vectors. + pub vector_count: usize, + /// Vector dimension. + pub dimension: usize, + /// Backend type. + pub backend: DiskBackend, + /// Data file path. + pub data_path: String, + /// Fragmentation ratio. + pub fragmentation: f64, + /// Approximate memory usage (for in-memory structures). + pub memory_bytes: usize, +} + +/// State for Vamana graph-based search. +struct VamanaState { + /// Graph structure for neighbor relationships. + graph: VamanaGraph, + /// Medoid (entry point for search). + medoid: AtomicU32, + /// Pool of visited handlers for concurrent searches. + visited_pool: VisitedNodesHandlerPool, +} + +impl VamanaState { + /// Create a new Vamana state. 
+ fn new(initial_capacity: usize, max_degree: usize) -> Self { + Self { + graph: VamanaGraph::new(initial_capacity, max_degree), + medoid: AtomicU32::new(INVALID_ID), + visited_pool: VisitedNodesHandlerPool::new(initial_capacity), + } + } +} + +/// Single-value disk-based index. +/// +/// Stores vectors in a memory-mapped file for persistence. +/// Supports BruteForce (exact) or Vamana (approximate) search. +pub struct DiskIndexSingle<T: VectorElement> { + /// Memory-mapped vector storage. + data: RwLock<MmapDataBlocks<T>>, + /// Distance function. + dist_fn: Box<dyn DistanceFunction<T>>, + /// Label to internal ID mapping. + label_to_id: RwLock<HashMap<LabelType, IdType>>, + /// Internal ID to label mapping. + id_to_label: RwLock<HashMap<IdType, LabelType>>, + /// Parameters. + params: DiskIndexParams, + /// Vector count. + count: AtomicUsize, + /// Capacity (if bounded). + capacity: Option<usize>, + /// Vamana graph state (for Vamana backend). + vamana_state: Option<RwLock<VamanaState>>, +} + +impl<T: VectorElement> DiskIndexSingle<T> { + /// Create a new disk-based index. + pub fn new(params: DiskIndexParams) -> io::Result<Self> { + let data = MmapDataBlocks::new(&params.data_path, params.dim, params.initial_capacity)?; + let dist_fn = create_distance_function(params.metric, params.dim); + + // Build label mappings from existing data + let label_to_id = HashMap::new(); + let id_to_label = HashMap::new(); + let count = data.len(); + + // Initialize Vamana state if using Vamana backend + let vamana_state = match params.backend { + DiskBackend::Vamana => Some(RwLock::new(VamanaState::new( + params.initial_capacity, + params.graph_max_degree, + ))), + DiskBackend::BruteForce => None, + }; + + // For now, we don't persist label mappings - this would need separate storage + // New indices start fresh, existing data without labels won't work well + // A full implementation would store labels in the mmap file + + Ok(Self { + data: RwLock::new(data), + dist_fn, + label_to_id: RwLock::new(label_to_id), + id_to_label: RwLock::new(id_to_label), + params, + count: AtomicUsize::new(count), + capacity: None, + vamana_state, + }) + } + + /// Create with a maximum capacity. + pub fn with_capacity(params: DiskIndexParams, max_capacity: usize) -> io::Result<Self> { + let mut index = Self::new(params)?; + index.capacity = Some(max_capacity); + Ok(index) + } + + /// Get the data file path. + pub fn data_path(&self) -> &Path { + &self.params.data_path + } + + /// Get the backend type. + pub fn backend(&self) -> DiskBackend { + self.params.backend + } + + /// Get index statistics. + pub fn stats(&self) -> DiskIndexStats { + let data = self.data.read(); + + DiskIndexStats { + vector_count: self.count.load(Ordering::Relaxed), + dimension: self.params.dim, + backend: self.params.backend, + data_path: self.params.data_path.to_string_lossy().to_string(), + fragmentation: data.fragmentation(), + memory_bytes: self.memory_usage(), + } + } + + /// Get the fragmentation ratio. + pub fn fragmentation(&self) -> f64 { + self.data.read().fragmentation() + } + + /// Estimate memory usage (in-memory structures only). + pub fn memory_usage(&self) -> usize { + let label_size = self.label_to_id.read().capacity() + * (std::mem::size_of::<LabelType>() + std::mem::size_of::<IdType>()) + + self.id_to_label.read().capacity() + * (std::mem::size_of::<IdType>() + std::mem::size_of::<LabelType>()); + + label_size + std::mem::size_of::<Self>() + } + + /// Flush changes to disk. + pub fn flush(&self) -> io::Result<()> { + self.data.read().flush() + } + + /// Get a copy of a vector by label. 
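+ ///
+ /// Hypothetical usage on an index populated elsewhere:
+ ///
+ /// ```rust,ignore
+ /// if let Some(v) = index.get_vector(42) {
+ ///     assert_eq!(v.len(), index.dimension());
+ /// }
+ /// ```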
+ pub fn get_vector(&self, label: LabelType) -> Option<Vec<T>> { + let id = *self.label_to_id.read().get(&label)?; + self.data.read().get(id).map(|v| v.to_vec()) + } + + /// Clear all vectors. + pub fn clear(&mut self) { + self.data.write().clear(); + self.label_to_id.write().clear(); + self.id_to_label.write().clear(); + self.count.store(0, Ordering::Relaxed); + + // Clear Vamana state if present + if let Some(ref vamana) = self.vamana_state { + let mut state = vamana.write(); + state.graph = VamanaGraph::new( + self.params.initial_capacity, + self.params.graph_max_degree, + ); + state.medoid.store(INVALID_ID, Ordering::Release); + } + } + + // ========== Vamana-specific methods ========== + + /// Insert a vector into the Vamana graph. + /// + /// This method is called after the vector has been added to the mmap storage. + fn insert_into_graph(&self, id: IdType, label: LabelType) { + let vamana = match &self.vamana_state { + Some(v) => v, + None => return, + }; + + let mut state = vamana.write(); + let data = self.data.read(); + + // Ensure graph has space + state.graph.ensure_capacity(id as usize + 1); + state.graph.set_label(id, label); + + // Update visited pool if needed + if (id as usize) >= state.visited_pool.current_capacity() { + state.visited_pool.resize(id as usize + 1024); + } + + let medoid = state.medoid.load(Ordering::Acquire); + + if medoid == INVALID_ID { + // First element becomes the medoid + state.medoid.store(id, Ordering::Release); + return; + } + + // Get query vector + let query = match data.get(id) { + Some(v) => v, + None => return, + }; + + // Search for neighbors using greedy beam search + let selected = { + let mut visited = state.visited_pool.get(); + visited.reset(); + + let neighbors = greedy_beam_search_filtered::<T, fn(IdType) -> bool>( + medoid, + query, + self.params.construction_l, + &state.graph, + |nid| data.get(nid), + self.dist_fn.as_ref(), + self.params.dim, + &visited, + None, + ); + + // Select neighbors using robust pruning + robust_prune( + id, + &neighbors, + self.params.graph_max_degree, + self.params.alpha, + |nid| data.get(nid), + self.dist_fn.as_ref(), + self.params.dim, + ) + }; + + // Set outgoing edges + state.graph.set_neighbors(id, &selected); + + // Add bidirectional edges + for &neighbor_id in &selected { + Self::add_bidirectional_link_inner( + &mut state.graph, + &data, + self.dist_fn.as_ref(), + self.params.dim, + self.params.graph_max_degree, + self.params.alpha, + neighbor_id, + id, + ); + } + } + + /// Add a bidirectional link from one node to another in the graph. + fn add_bidirectional_link_inner( + graph: &mut VamanaGraph, + data: &MmapDataBlocks<T>, + dist_fn: &dyn DistanceFunction<T>, + dim: usize, + max_degree: usize, + alpha: f32, + from: IdType, + to: IdType, + ) { + let mut current_neighbors = graph.get_neighbors(from); + if current_neighbors.contains(&to) { + return; + } + + current_neighbors.push(to); + + // Check if we need to prune + if current_neighbors.len() > max_degree { + if let Some(from_data) = data.get(from) { + let candidates: Vec<_> = current_neighbors + .iter() + .filter_map(|&n| { + data.get(n).map(|d| { + let dist = dist_fn.compute(d, from_data, dim); + (n, dist) + }) + }) + .collect(); + + let selected = robust_prune( + from, + &candidates, + max_degree, + alpha, + |id| data.get(id), + dist_fn, + dim, + ); + + graph.set_neighbors(from, &selected); + } + } else { + graph.set_neighbors(from, &current_neighbors); + } + } + + /// Mark a vector as deleted in the Vamana graph. 
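+ ///
+ /// Deletion here is logical: the node is only marked deleted, and if it was
+ /// the medoid a replacement entry point is chosen from the remaining nodes
+ /// (or the entry point is cleared when the graph empties).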
+ fn delete_from_graph(&self, id: IdType) { + let vamana = match &self.vamana_state { + Some(v) => v, + None => return, + }; + + let mut state = vamana.write(); + state.graph.mark_deleted(id); + + // Update medoid if needed + if state.medoid.load(Ordering::Acquire) == id { + if let Some(new_medoid) = self.find_new_medoid(&state) { + state.medoid.store(new_medoid, Ordering::Release); + } else { + state.medoid.store(INVALID_ID, Ordering::Release); + } + } + } + + /// Find a new medoid for the graph (excluding deleted nodes). + fn find_new_medoid(&self, state: &VamanaState) -> Option { + let data = self.data.read(); + let id_to_label = self.id_to_label.read(); + + // Get all active IDs + let ids: Vec = data + .iter_ids() + .filter(|&id| !state.graph.is_deleted(id) && id_to_label.contains_key(&id)) + .collect(); + + if ids.is_empty() { + return None; + } + if ids.len() == 1 { + return Some(ids[0]); + } + + // Sample for medoid computation + let sample_size = ids.len().min(100); + let sample: Vec = if ids.len() <= sample_size { + ids.clone() + } else { + use rand::seq::SliceRandom; + let mut rng = rand::thread_rng(); + let mut shuffled = ids.clone(); + shuffled.shuffle(&mut rng); + shuffled.into_iter().take(sample_size).collect() + }; + + // Find vector with minimum total distance + let mut best_id = sample[0]; + let mut best_total_dist = f64::MAX; + + for &candidate in &sample { + if let Some(candidate_data) = data.get(candidate) { + let total_dist: f64 = sample + .iter() + .filter(|&&id| id != candidate) + .filter_map(|&id| data.get(id)) + .map(|other_data| { + self.dist_fn + .compute(candidate_data, other_data, self.params.dim) + .to_f64() + }) + .sum(); + + if total_dist < best_total_dist { + best_total_dist = total_dist; + best_id = candidate; + } + } + } + + Some(best_id) + } + + /// Perform Vamana graph-based search. + fn vamana_search( + &self, + query: &[T], + k: usize, + filter: Option<&(dyn Fn(LabelType) -> bool + Send + Sync)>, + ) -> Vec<(IdType, LabelType, T::DistanceType)> { + let vamana = match &self.vamana_state { + Some(v) => v, + None => return Vec::new(), + }; + + let state = vamana.read(); + let medoid = state.medoid.load(Ordering::Acquire); + + if medoid == INVALID_ID { + return Vec::new(); + } + + let data = self.data.read(); + let id_to_label = self.id_to_label.read(); + + let mut visited = state.visited_pool.get(); + visited.reset(); + + let search_l = self.params.search_l.max(k); + + // Perform search with internal ID filter if needed + let results = if let Some(f) = filter { + // Create filter that works with internal IDs + let id_filter = |id: IdType| -> bool { + if let Some(&label) = id_to_label.get(&id) { + f(label) + } else { + false + } + }; + + greedy_beam_search_filtered( + medoid, + query, + search_l, + &state.graph, + |id| data.get(id), + self.dist_fn.as_ref(), + self.params.dim, + &visited, + Some(&id_filter), + ) + } else { + greedy_beam_search_filtered:: bool>( + medoid, + query, + search_l, + &state.graph, + |id| data.get(id), + self.dist_fn.as_ref(), + self.params.dim, + &visited, + None, + ) + }; + + // Convert to final format with labels + results + .into_iter() + .take(k) + .filter_map(|(id, dist)| { + let label = *id_to_label.get(&id)?; + Some((id, label, dist)) + }) + .collect() + } + + /// Find the medoid (approximate centroid) of all vectors. + /// + /// This is used as the entry point for Vamana search. 
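+ ///
+ /// With more than 1000 vectors the medoid is approximated on a random
+ /// sample, so repeated calls over the same data may return different,
+ /// though comparably central, nodes.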
+ pub fn find_medoid(&self) -> Option { + let data = self.data.read(); + let id_to_label = self.id_to_label.read(); + + let ids: Vec = data + .iter_ids() + .filter(|id| id_to_label.contains_key(id)) + .collect(); + + if ids.is_empty() { + return None; + } + if ids.len() == 1 { + return Some(ids[0]); + } + + // Sample at most 1000 vectors for medoid computation + let sample_size = ids.len().min(1000); + let sample: Vec = if ids.len() <= sample_size { + ids.clone() + } else { + use rand::seq::SliceRandom; + let mut rng = rand::thread_rng(); + let mut shuffled = ids.clone(); + shuffled.shuffle(&mut rng); + shuffled.into_iter().take(sample_size).collect() + }; + + // Find vector with minimum total distance + let mut best_id = sample[0]; + let mut best_total_dist = f64::MAX; + + for &candidate in &sample { + if let Some(candidate_data) = data.get(candidate) { + let total_dist: f64 = sample + .iter() + .filter(|&&id| id != candidate) + .filter_map(|&id| data.get(id)) + .map(|other_data| { + self.dist_fn + .compute(candidate_data, other_data, self.params.dim) + .to_f64() + }) + .sum(); + + if total_dist < best_total_dist { + best_total_dist = total_dist; + best_id = candidate; + } + } + } + + Some(best_id) + } + + /// Build or rebuild the Vamana graph. + /// + /// This performs a two-pass construction for better recall: + /// 1. First pass with alpha = 1.0 + /// 2. Second pass with configured alpha + pub fn build(&mut self) { + if self.params.backend != DiskBackend::Vamana { + return; + } + + let vamana = match &self.vamana_state { + Some(v) => v, + None => return, + }; + + // Get all IDs and labels + let ids_and_labels: Vec<(IdType, LabelType)> = { + let id_to_label = self.id_to_label.read(); + self.data + .read() + .iter_ids() + .filter_map(|id| id_to_label.get(&id).map(|&label| (id, label))) + .collect() + }; + + if ids_and_labels.is_empty() { + return; + } + + // Clear and reinitialize graph + { + let mut state = vamana.write(); + state.graph = VamanaGraph::new( + self.params.initial_capacity, + self.params.graph_max_degree, + ); + state.medoid.store(INVALID_ID, Ordering::Release); + } + + // Find and set medoid + if let Some(medoid) = self.find_medoid() { + vamana.write().medoid.store(medoid, Ordering::Release); + } + + // First pass: insert all nodes with alpha = 1.0 + let original_alpha = self.params.alpha; + + // Pass 1 with alpha = 1.0 (no diversity pruning) + for &(id, label) in &ids_and_labels { + self.insert_into_graph_with_alpha(id, label, 1.0); + } + + // Pass 2: rebuild with configured alpha for better diversity + if original_alpha > 1.0 { + self.rebuild_graph_with_alpha(original_alpha, &ids_and_labels); + } + } + + /// Insert a vector into the graph with a specific alpha value. 
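+ ///
+ /// With `alpha = 1.0`, `robust_prune` keeps only the closest diverse
+ /// neighbors; `alpha > 1.0` relaxes pruning to retain some longer-range
+ /// edges, which is why `build` runs a first pass at 1.0 and a second pass
+ /// at the configured alpha.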
+ fn insert_into_graph_with_alpha(&self, id: IdType, label: LabelType, alpha: f32) { + let vamana = match &self.vamana_state { + Some(v) => v, + None => return, + }; + + let mut state = vamana.write(); + let data = self.data.read(); + + // Ensure graph has space + state.graph.ensure_capacity(id as usize + 1); + state.graph.set_label(id, label); + + // Update visited pool if needed + if (id as usize) >= state.visited_pool.current_capacity() { + state.visited_pool.resize(id as usize + 1024); + } + + let medoid = state.medoid.load(Ordering::Acquire); + + if medoid == INVALID_ID { + state.medoid.store(id, Ordering::Release); + return; + } + + let query = match data.get(id) { + Some(v) => v, + None => return, + }; + + let selected = { + let mut visited = state.visited_pool.get(); + visited.reset(); + + let neighbors = greedy_beam_search_filtered:: bool>( + medoid, + query, + self.params.construction_l, + &state.graph, + |nid| data.get(nid), + self.dist_fn.as_ref(), + self.params.dim, + &visited, + None, + ); + + robust_prune( + id, + &neighbors, + self.params.graph_max_degree, + alpha, + |nid| data.get(nid), + self.dist_fn.as_ref(), + self.params.dim, + ) + }; + + state.graph.set_neighbors(id, &selected); + + for &neighbor_id in &selected { + Self::add_bidirectional_link_inner( + &mut state.graph, + &data, + self.dist_fn.as_ref(), + self.params.dim, + self.params.graph_max_degree, + alpha, + neighbor_id, + id, + ); + } + } + + /// Rebuild the graph with a specific alpha value. + fn rebuild_graph_with_alpha(&self, alpha: f32, ids_and_labels: &[(IdType, LabelType)]) { + let vamana = match &self.vamana_state { + Some(v) => v, + None => return, + }; + + // Clear existing neighbors + { + let mut state = vamana.write(); + for &(id, _) in ids_and_labels { + state.graph.clear_neighbors(id); + } + } + + // Update medoid + if let Some(new_medoid) = self.find_medoid() { + vamana.write().medoid.store(new_medoid, Ordering::Release); + } + + // Reinsert all nodes with configured alpha + for &(id, label) in ids_and_labels { + self.insert_into_graph_with_alpha(id, label, alpha); + } + } + + /// Perform brute-force search. + fn brute_force_search( + &self, + query: &[T], + k: usize, + filter: Option<&(dyn Fn(LabelType) -> bool + Send + Sync)>, + ) -> Vec<(IdType, LabelType, T::DistanceType)> { + let data = self.data.read(); + let id_to_label = self.id_to_label.read(); + + let mut results: Vec<(IdType, LabelType, T::DistanceType)> = data + .iter_ids() + .filter_map(|id| { + let label = *id_to_label.get(&id)?; + + // Apply filter if present + if let Some(f) = filter { + if !f(label) { + return None; + } + } + + let vec_data = data.get(id)?; + let dist = self.dist_fn.compute(query, vec_data, self.params.dim); + Some((id, label, dist)) + }) + .collect(); + + // Sort by distance + results.sort_by(|a, b| { + a.2.to_f64() + .partial_cmp(&b.2.to_f64()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + results.truncate(k); + results + } + + /// Perform range search. 
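+ ///
+ /// Exact linear scan over all live IDs, O(n) in index size; the Vamana
+ /// backend instead approximates range queries with a wide top-k search
+ /// (see `range_query`).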
+ fn brute_force_range( + &self, + query: &[T], + radius: T::DistanceType, + filter: Option<&(dyn Fn(LabelType) -> bool + Send + Sync)>, + ) -> Vec<(IdType, LabelType, T::DistanceType)> { + let data = self.data.read(); + let id_to_label = self.id_to_label.read(); + let radius_f64 = radius.to_f64(); + + let mut results: Vec<(IdType, LabelType, T::DistanceType)> = data + .iter_ids() + .filter_map(|id| { + let label = *id_to_label.get(&id)?; + + if let Some(f) = filter { + if !f(label) { + return None; + } + } + + let vec_data = data.get(id)?; + let dist = self.dist_fn.compute(query, vec_data, self.params.dim); + + if dist.to_f64() <= radius_f64 { + Some((id, label, dist)) + } else { + None + } + }) + .collect(); + + results.sort_by(|a, b| { + a.2.to_f64() + .partial_cmp(&b.2.to_f64()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + results + } +} + +impl VecSimIndex for DiskIndexSingle { + type DataType = T; + type DistType = T::DistanceType; + + fn add_vector(&mut self, vector: &[T], label: LabelType) -> Result { + if vector.len() != self.params.dim { + return Err(IndexError::DimensionMismatch { + expected: self.params.dim, + got: vector.len(), + }); + } + + // Check capacity + if let Some(cap) = self.capacity { + if self.count.load(Ordering::Relaxed) >= cap { + return Err(IndexError::CapacityExceeded { capacity: cap }); + } + } + + let mut label_to_id = self.label_to_id.write(); + let mut id_to_label = self.id_to_label.write(); + let mut data = self.data.write(); + + // Check if label exists (replace) + if let Some(&existing_id) = label_to_id.get(&label) { + // Mark old entry as deleted in both storage and graph + data.mark_deleted(existing_id); + id_to_label.remove(&existing_id); + + // Mark deleted in graph for Vamana backend and handle medoid update + if self.params.backend == DiskBackend::Vamana { + if let Some(ref vamana) = self.vamana_state { + let mut state = vamana.write(); + state.graph.mark_deleted(existing_id); + // If the medoid was deleted, set to INVALID_ID (will be set to new node) + if state.medoid.load(Ordering::Acquire) == existing_id { + state.medoid.store(INVALID_ID, Ordering::Release); + } + } + } + + let new_id = data.add(vector).ok_or_else(|| { + IndexError::Internal("Failed to add vector to storage".to_string()) + })?; + + label_to_id.insert(label, new_id); + id_to_label.insert(new_id, label); + + // Drop locks before calling insert_into_graph + drop(data); + drop(label_to_id); + drop(id_to_label); + + // Insert into Vamana graph + if self.params.backend == DiskBackend::Vamana { + self.insert_into_graph(new_id, label); + } + + return Ok(0); // Replacement + } + + // Add new vector + let id = data.add(vector).ok_or_else(|| { + IndexError::Internal("Failed to add vector to storage".to_string()) + })?; + + label_to_id.insert(label, id); + id_to_label.insert(id, label); + self.count.fetch_add(1, Ordering::Relaxed); + + // Drop locks before calling insert_into_graph + drop(data); + drop(label_to_id); + drop(id_to_label); + + // Insert into Vamana graph + if self.params.backend == DiskBackend::Vamana { + self.insert_into_graph(id, label); + } + + Ok(1) + } + + fn delete_vector(&mut self, label: LabelType) -> Result { + let mut label_to_id = self.label_to_id.write(); + let mut id_to_label = self.id_to_label.write(); + let mut data = self.data.write(); + + if let Some(id) = label_to_id.remove(&label) { + data.mark_deleted(id); + id_to_label.remove(&id); + self.count.fetch_sub(1, Ordering::Relaxed); + + // Drop locks before calling delete_from_graph + drop(data); + 
drop(label_to_id); + drop(id_to_label); + + // Delete from Vamana graph + if self.params.backend == DiskBackend::Vamana { + self.delete_from_graph(id); + } + + Ok(1) + } else { + Err(IndexError::LabelNotFound(label)) + } + } + + fn top_k_query( + &self, + query: &[Self::DataType], + k: usize, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + if k == 0 { + return Ok(QueryReply::new()); + } + + // Build filter if provided + let filter = params.and_then(|p| p.filter.as_ref().map(|f| f.as_ref())); + + let results = match self.params.backend { + DiskBackend::BruteForce => self.brute_force_search(query, k, filter), + DiskBackend::Vamana => self.vamana_search(query, k, filter), + }; + + let mut reply = QueryReply::with_capacity(results.len()); + for (_, label, dist) in results { + reply.push(QueryResult::new(label, dist)); + } + + Ok(reply) + } + + fn range_query( + &self, + query: &[Self::DataType], + radius: Self::DistType, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + let filter = params.and_then(|p| p.filter.as_ref().map(|f| f.as_ref())); + + let results = match self.params.backend { + DiskBackend::BruteForce => self.brute_force_range(query, radius, filter), + DiskBackend::Vamana => { + // For Vamana, search with a large k and filter by radius + // Use search_l as the search window, results are approximate + let search_results = self.vamana_search(query, self.params.search_l, filter); + let radius_f64 = radius.to_f64(); + + search_results + .into_iter() + .filter(|(_, _, dist)| dist.to_f64() <= radius_f64) + .collect() + } + }; + + let mut reply = QueryReply::with_capacity(results.len()); + for (_, label, dist) in results { + reply.push(QueryResult::new(label, dist)); + } + + Ok(reply) + } + + fn index_size(&self) -> usize { + self.count.load(Ordering::Relaxed) + } + + fn index_capacity(&self) -> Option { + self.capacity + } + + fn dimension(&self) -> usize { + self.params.dim + } + + fn batch_iterator<'a>( + &'a self, + query: &[Self::DataType], + params: Option<&QueryParams>, + ) -> Result + 'a>, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + Ok(Box::new(DiskIndexBatchIterator::new( + self, + query.to_vec(), + params.cloned(), + ))) + } + + fn info(&self) -> IndexInfo { + IndexInfo { + size: self.index_size(), + capacity: self.capacity, + dimension: self.params.dim, + index_type: "DiskIndexSingle", + memory_bytes: self.memory_usage(), + } + } + + fn contains(&self, label: LabelType) -> bool { + self.label_to_id.read().contains_key(&label) + } + + fn label_count(&self, label: LabelType) -> usize { + if self.contains(label) { 1 } else { 0 } + } +} + +// Safety: DiskIndexSingle is thread-safe with RwLock protection +unsafe impl Send for DiskIndexSingle {} +unsafe impl Sync for DiskIndexSingle {} + +/// Batch iterator for DiskIndexSingle. +pub struct DiskIndexBatchIterator<'a, T: VectorElement> { + index: &'a DiskIndexSingle, + query: Vec, + params: Option, + results: Option>, + position: usize, +} + +impl<'a, T: VectorElement> DiskIndexBatchIterator<'a, T> { + /// Create a new batch iterator. 
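+    ///
+    /// Results are computed lazily: the full result set is materialized on the
+    /// first call to `next_batch` (via a brute-force scan) and then served in
+    /// batch-sized slices; `reset` only rewinds the cursor without recomputing.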
+ pub fn new( + index: &'a DiskIndexSingle, + query: Vec, + params: Option, + ) -> Self { + Self { + index, + query, + params, + results: None, + position: 0, + } + } + + fn ensure_results(&mut self) { + if self.results.is_some() { + return; + } + + let filter = self + .params + .as_ref() + .and_then(|p| p.filter.as_ref().map(|f| f.as_ref())); + + let count = self.index.count.load(Ordering::Relaxed); + let results = self.index.brute_force_search(&self.query, count, filter); + + self.results = Some(results); + } +} + +impl<'a, T: VectorElement> BatchIterator for DiskIndexBatchIterator<'a, T> { + type DistType = T::DistanceType; + + fn has_next(&self) -> bool { + match &self.results { + Some(results) => self.position < results.len(), + None => true, + } + } + + fn next_batch( + &mut self, + batch_size: usize, + ) -> Option> { + self.ensure_results(); + + let results = self.results.as_ref()?; + if self.position >= results.len() { + return None; + } + + let end = (self.position + batch_size).min(results.len()); + let batch = results[self.position..end].to_vec(); + self.position = end; + + if batch.is_empty() { + None + } else { + Some(batch) + } + } + + fn reset(&mut self) { + self.position = 0; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::Metric; + use std::fs; + + fn temp_path() -> std::path::PathBuf { + let mut path = std::env::temp_dir(); + path.push(format!("disk_index_test_{}.dat", rand::random::())); + path + } + + #[test] + fn test_disk_index_basic() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + let v1 = vec![1.0, 0.0, 0.0, 0.0]; + let v2 = vec![0.0, 1.0, 0.0, 0.0]; + let v3 = vec![0.0, 0.0, 1.0, 0.0]; + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + index.add_vector(&v3, 3).unwrap(); + + assert_eq!(index.index_size(), 3); + + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 2, None).unwrap(); + + assert_eq!(results.len(), 2); + assert_eq!(results.results[0].label, 1); + } + + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_delete() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + assert_eq!(index.index_size(), 2); + + index.delete_vector(1).unwrap(); + assert_eq!(index.index_size(), 1); + assert!(!index.contains(1)); + assert!(index.contains(2)); + } + + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_replace() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + assert_eq!(index.index_size(), 1); + + // Replace + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 1).unwrap(); + assert_eq!(index.index_size(), 1); + + // Query should return new vector + let query = vec![0.0, 1.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 1, None).unwrap(); + assert!(results.results[0].distance < 0.001); + } + + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_range_query() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[0.0, 0.0, 0.0, 0.0], 1).unwrap(); + 
index.add_vector(&[1.0, 0.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&[2.0, 0.0, 0.0, 0.0], 3).unwrap(); + index.add_vector(&[10.0, 0.0, 0.0, 0.0], 4).unwrap(); + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let results = index.range_query(&query, 5.0, None).unwrap(); + + // Should find first 3 vectors (distances: 0, 1, 4) + assert_eq!(results.len(), 3); + } + + fs::remove_file(&path).ok(); + } + + // ========== Vamana Backend Tests ========== + + #[test] + fn test_disk_index_vamana_basic() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path) + .with_backend(DiskBackend::Vamana) + .with_graph_degree(8) + .with_construction_l(50) + .with_search_l(20); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + let v1 = vec![1.0, 0.0, 0.0, 0.0]; + let v2 = vec![0.0, 1.0, 0.0, 0.0]; + let v3 = vec![0.0, 0.0, 1.0, 0.0]; + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + index.add_vector(&v3, 3).unwrap(); + + assert_eq!(index.index_size(), 3); + + // Query for exact match + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 2, None).unwrap(); + + assert_eq!(results.len(), 2); + // Exact match should be first + assert_eq!(results.results[0].label, 1); + assert!(results.results[0].distance < 0.001); + } + + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_vamana_scaling() { + let path = temp_path(); + { + let dim = 16; + let n = 1000; + + let params = DiskIndexParams::new(dim, Metric::L2, &path) + .with_backend(DiskBackend::Vamana) + .with_capacity(n + 100) + .with_graph_degree(16) + .with_construction_l(100) + .with_search_l(50); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + // Add vectors in a grid pattern + for i in 0..n { + let mut vec = vec![0.0f32; dim]; + vec[i % dim] = (i / dim + 1) as f32; + index.add_vector(&vec, i as u64).unwrap(); + } + + assert_eq!(index.index_size(), n); + + // Query for a vector that should exist + let mut query = vec![0.0f32; dim]; + query[0] = 1.0; + let results = index.top_k_query(&query, 10, None).unwrap(); + + assert_eq!(results.len(), 10); + // First result should be label 0 (exact match) + assert_eq!(results.results[0].label, 0); + assert!(results.results[0].distance < 0.001); + } + + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_vamana_delete() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path) + .with_backend(DiskBackend::Vamana) + .with_graph_degree(8) + .with_search_l(20); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&[0.0, 0.0, 1.0, 0.0], 3).unwrap(); + + assert_eq!(index.index_size(), 3); + + // Delete label 1 + index.delete_vector(1).unwrap(); + assert_eq!(index.index_size(), 2); + assert!(!index.contains(1)); + + // Query should not return deleted vector + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 3, None).unwrap(); + + // Should only get 2 results (labels 2 and 3) + assert_eq!(results.len(), 2); + for result in &results.results { + assert_ne!(result.label, 1); + } + } + + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_vamana_update() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path) + .with_backend(DiskBackend::Vamana) + .with_graph_degree(8) + .with_search_l(20); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + 
// Add initial vector + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + assert_eq!(index.index_size(), 1); + + // Query for original + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 1, None).unwrap(); + assert!(results.results[0].distance < 0.001); + + // Replace with new vector + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 1).unwrap(); + assert_eq!(index.index_size(), 1); // Still 1 (replacement) + + // Query for new vector + let query2 = vec![0.0, 1.0, 0.0, 0.0]; + let results2 = index.top_k_query(&query2, 1, None).unwrap(); + assert_eq!(results2.results[0].label, 1); + assert!(results2.results[0].distance < 0.001); + } + + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_vamana_range_query() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path) + .with_backend(DiskBackend::Vamana) + .with_graph_degree(8) + .with_search_l(50); // Higher search_l for range query + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[0.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&[2.0, 0.0, 0.0, 0.0], 3).unwrap(); + index.add_vector(&[10.0, 0.0, 0.0, 0.0], 4).unwrap(); + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let results = index.range_query(&query, 5.0, None).unwrap(); + + // Should find first 3 vectors (distances: 0, 1, 4) + assert!(results.len() >= 2); // At least 2 with Vamana (approximate) + assert!(results.len() <= 4); // At most all 4 + } + + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_vamana_filtered() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path) + .with_backend(DiskBackend::Vamana) + .with_graph_degree(8) + .with_search_l(30); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + // Add vectors with labels 1-10 + for i in 1..=10 { + let vec = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&vec, i).unwrap(); + } + + // Query with filter to only accept even labels + let query = vec![5.0, 0.0, 0.0, 0.0]; + let filter_fn = |label: LabelType| label % 2 == 0; + let params = QueryParams::default().with_filter(filter_fn); + + let results = index.top_k_query(&query, 5, Some(¶ms)).unwrap(); + + // All results should have even labels + for result in &results.results { + assert_eq!(result.label % 2, 0); + } + } + + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_vamana_build() { + let path = temp_path(); + { + let dim = 8; + let n = 100; + + let params = DiskIndexParams::new(dim, Metric::L2, &path) + .with_backend(DiskBackend::Vamana) + .with_graph_degree(8) + .with_alpha(1.2) + .with_construction_l(50) + .with_search_l(30); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + // Add vectors + for i in 0..n { + let mut vec = vec![0.0f32; dim]; + vec[i % dim] = (i / dim + 1) as f32; + index.add_vector(&vec, i as u64).unwrap(); + } + + // Rebuild the graph (two-pass construction) + index.build(); + + // Query should still work correctly + let mut query = vec![0.0f32; dim]; + query[0] = 1.0; + let results = index.top_k_query(&query, 5, None).unwrap(); + + assert!(!results.is_empty()); + // First result should be label 0 (exact match) + assert_eq!(results.results[0].label, 0); + } + + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_vamana_recall() { + let path = temp_path(); + { + let dim = 8; + let n = 200; + + let params = DiskIndexParams::new(dim, Metric::L2, &path) + .with_backend(DiskBackend::Vamana) + 
.with_graph_degree(16) + .with_construction_l(100) + .with_search_l(50); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + // Create brute force index for ground truth + let bf_path = temp_path(); + let bf_params = DiskIndexParams::new(dim, Metric::L2, &bf_path); + let mut bf_index = DiskIndexSingle::::new(bf_params).unwrap(); + + // Add same vectors to both + for i in 0..n { + let vec: Vec = (0..dim).map(|j| ((i * dim + j) % 100) as f32 / 10.0).collect(); + index.add_vector(&vec, i as u64).unwrap(); + bf_index.add_vector(&vec, i as u64).unwrap(); + } + + // Test recall with several queries + let k = 10; + let mut total_recall = 0.0; + let num_queries = 10; + + for q in 0..num_queries { + let query: Vec = (0..dim).map(|j| ((q * 7 + j * 3) % 100) as f32 / 10.0).collect(); + + let vamana_results = index.top_k_query(&query, k, None).unwrap(); + let bf_results = bf_index.top_k_query(&query, k, None).unwrap(); + + // Count how many Vamana results are in ground truth + let vamana_labels: std::collections::HashSet<_> = + vamana_results.results.iter().map(|r| r.label).collect(); + let bf_labels: std::collections::HashSet<_> = + bf_results.results.iter().map(|r| r.label).collect(); + + let intersection = vamana_labels.intersection(&bf_labels).count(); + total_recall += intersection as f64 / k as f64; + } + + let avg_recall = total_recall / num_queries as f64; + // With reasonable parameters, we should get at least 60% recall + assert!( + avg_recall >= 0.6, + "Average recall {} is too low", + avg_recall + ); + + fs::remove_file(&bf_path).ok(); + } + + fs::remove_file(&path).ok(); + } + + // ========== DiskIndexParams Tests ========== + + #[test] + fn test_disk_index_params_new() { + let path = temp_path(); + let params = DiskIndexParams::new(128, Metric::L2, &path); + + assert_eq!(params.dim, 128); + assert_eq!(params.metric, Metric::L2); + assert_eq!(params.backend, DiskBackend::BruteForce); + assert_eq!(params.initial_capacity, 10_000); + assert_eq!(params.graph_max_degree, 32); + assert!((params.alpha - 1.2).abs() < 0.01); + assert_eq!(params.construction_l, 200); + assert_eq!(params.search_l, 100); + } + + #[test] + fn test_disk_index_params_builder() { + let path = temp_path(); + let params = DiskIndexParams::new(64, Metric::InnerProduct, &path) + .with_backend(DiskBackend::Vamana) + .with_capacity(5000) + .with_graph_degree(64) + .with_alpha(1.5) + .with_construction_l(150) + .with_search_l(75); + + assert_eq!(params.dim, 64); + assert_eq!(params.metric, Metric::InnerProduct); + assert_eq!(params.backend, DiskBackend::Vamana); + assert_eq!(params.initial_capacity, 5000); + assert_eq!(params.graph_max_degree, 64); + assert!((params.alpha - 1.5).abs() < 0.01); + assert_eq!(params.construction_l, 150); + assert_eq!(params.search_l, 75); + } + + #[test] + fn test_disk_backend_default() { + let backend = DiskBackend::default(); + assert_eq!(backend, DiskBackend::BruteForce); + } + + #[test] + fn test_disk_backend_debug() { + assert_eq!(format!("{:?}", DiskBackend::BruteForce), "BruteForce"); + assert_eq!(format!("{:?}", DiskBackend::Vamana), "Vamana"); + } + + // ========== DiskIndexSingle Utility Method Tests ========== + + #[test] + fn test_disk_index_get_vector() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + let v1 = vec![1.0, 2.0, 3.0, 4.0]; + index.add_vector(&v1, 1).unwrap(); + + // Get existing vector + let retrieved = index.get_vector(1).unwrap(); + 
assert_eq!(retrieved.len(), 4); + for (a, b) in v1.iter().zip(retrieved.iter()) { + assert!((a - b).abs() < 0.001); + } + + // Get non-existing vector + assert!(index.get_vector(999).is_none()); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_stats() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + let stats = index.stats(); + assert_eq!(stats.vector_count, 2); + assert_eq!(stats.dimension, 4); + assert_eq!(stats.backend, DiskBackend::BruteForce); + assert!(stats.data_path.contains("disk_index_test_")); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_fragmentation() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + // Initially no fragmentation + let frag = index.fragmentation(); + assert!(frag >= 0.0); + + // Add and delete to create fragmentation + for i in 0..10 { + index.add_vector(&[i as f32, 0.0, 0.0, 0.0], i).unwrap(); + } + for i in 0..5 { + index.delete_vector(i).unwrap(); + } + + // Now there should be some fragmentation + let frag_after = index.fragmentation(); + assert!(frag_after >= 0.0); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_clear() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + assert_eq!(index.index_size(), 2); + + index.clear(); + + assert_eq!(index.index_size(), 0); + assert!(!index.contains(1)); + assert!(!index.contains(2)); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_vamana_clear() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path) + .with_backend(DiskBackend::Vamana); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + index.clear(); + + assert_eq!(index.index_size(), 0); + + // Should be able to add new vectors after clear + index.add_vector(&[0.0, 0.0, 1.0, 0.0], 3).unwrap(); + assert_eq!(index.index_size(), 1); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_data_path() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let index = DiskIndexSingle::::new(params).unwrap(); + + assert_eq!(index.data_path(), &path); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_backend() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let index = DiskIndexSingle::::new(params).unwrap(); + assert_eq!(index.backend(), DiskBackend::BruteForce); + } + fs::remove_file(&path).ok(); + + let path2 = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path2) + .with_backend(DiskBackend::Vamana); + let index = DiskIndexSingle::::new(params).unwrap(); + assert_eq!(index.backend(), DiskBackend::Vamana); + } + fs::remove_file(&path2).ok(); + } + + #[test] + fn test_disk_index_memory_usage() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + let 
initial_mem = index.memory_usage(); + assert!(initial_mem > 0); + + // Add vectors - memory usage should increase + for i in 0..100 { + index.add_vector(&[i as f32, 0.0, 0.0, 0.0], i).unwrap(); + } + + let final_mem = index.memory_usage(); + assert!(final_mem >= initial_mem); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_info() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + let info = index.info(); + assert_eq!(info.size, 1); + assert_eq!(info.dimension, 4); + assert_eq!(info.index_type, "DiskIndexSingle"); + assert!(info.memory_bytes > 0); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_contains_and_label_count() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + assert!(index.contains(1)); + assert!(index.contains(2)); + assert!(!index.contains(3)); + + assert_eq!(index.label_count(1), 1); + assert_eq!(index.label_count(2), 1); + assert_eq!(index.label_count(3), 0); + } + fs::remove_file(&path).ok(); + } + + // ========== Batch Iterator Tests ========== + + #[test] + fn test_disk_index_batch_iterator() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + for i in 0..10 { + index.add_vector(&[i as f32, 0.0, 0.0, 0.0], i).unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let mut iterator = index.batch_iterator(&query, None).unwrap(); + + assert!(iterator.has_next()); + + let batch1 = iterator.next_batch(5).unwrap(); + assert_eq!(batch1.len(), 5); + + let batch2 = iterator.next_batch(5).unwrap(); + assert_eq!(batch2.len(), 5); + + // No more results + assert!(!iterator.has_next()); + assert!(iterator.next_batch(5).is_none()); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_batch_iterator_reset() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + for i in 0..5 { + index.add_vector(&[i as f32, 0.0, 0.0, 0.0], i).unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let mut iterator = index.batch_iterator(&query, None).unwrap(); + + // Consume all + let _ = iterator.next_batch(10); + assert!(!iterator.has_next()); + + // Reset + iterator.reset(); + assert!(iterator.has_next()); + + // Can iterate again + let batch = iterator.next_batch(5).unwrap(); + assert_eq!(batch.len(), 5); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_batch_iterator_ordering() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + // Add vectors at increasing distances from origin + index.add_vector(&[0.0, 0.0, 0.0, 0.0], 0).unwrap(); + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[2.0, 0.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&[3.0, 0.0, 0.0, 0.0], 3).unwrap(); + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let mut iterator = index.batch_iterator(&query, None).unwrap(); + + let batch = iterator.next_batch(4).unwrap(); + + // Should be ordered by distance (label 0 first) + assert_eq!(batch[0].1, 0); + 
assert_eq!(batch[1].1, 1); + assert_eq!(batch[2].1, 2); + assert_eq!(batch[3].1, 3); + } + fs::remove_file(&path).ok(); + } + + // ========== Error Handling Tests ========== + + #[test] + fn test_disk_index_dimension_mismatch_add() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + // Wrong dimension + let result = index.add_vector(&[1.0, 2.0, 3.0], 1); + assert!(result.is_err()); + + match result { + Err(IndexError::DimensionMismatch { expected, got }) => { + assert_eq!(expected, 4); + assert_eq!(got, 3); + } + _ => panic!("Expected DimensionMismatch error"), + } + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_dimension_mismatch_query() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + // Wrong dimension query + let result = index.top_k_query(&[1.0, 0.0, 0.0], 1, None); + assert!(result.is_err()); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_delete_not_found() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + // Delete non-existing label + let result = index.delete_vector(999); + assert!(result.is_err()); + + match result { + Err(IndexError::LabelNotFound(label)) => { + assert_eq!(label, 999); + } + _ => panic!("Expected LabelNotFound error"), + } + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_capacity_exceeded() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::with_capacity(params, 2).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + // Third should fail + let result = index.add_vector(&[0.0, 0.0, 1.0, 0.0], 3); + assert!(result.is_err()); + + match result { + Err(IndexError::CapacityExceeded { capacity }) => { + assert_eq!(capacity, 2); + } + _ => panic!("Expected CapacityExceeded error"), + } + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_empty_query() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let index = DiskIndexSingle::::new(params).unwrap(); + + // Query on empty index + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 10, None).unwrap(); + + assert!(results.is_empty()); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_zero_k_query() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + // Query with k=0 + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 0, None).unwrap(); + + assert!(results.is_empty()); + } + fs::remove_file(&path).ok(); + } + + // ========== Different Metrics Tests ========== + + #[test] + fn test_disk_index_inner_product() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::InnerProduct, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + // For IP, higher dot product = more similar (lower distance when negated) + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 
1).unwrap(); + index.add_vector(&[0.5, 0.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 3).unwrap(); + + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 3, None).unwrap(); + + assert_eq!(results.len(), 3); + // Label 1 should be first (highest IP with query) + assert_eq!(results.results[0].label, 1); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_cosine() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::Cosine, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + // Same direction should have distance ~0 + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[2.0, 0.0, 0.0, 0.0], 2).unwrap(); // Same direction, different magnitude + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 3).unwrap(); // Orthogonal + + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 3, None).unwrap(); + + assert_eq!(results.len(), 3); + // Labels 1 and 2 should have similar (small) distances + assert!(results.results[0].distance < 0.1); + assert!(results.results[1].distance < 0.1); + // Label 3 (orthogonal) should have distance ~1 + assert!(results.results[2].distance > 0.9); + } + fs::remove_file(&path).ok(); + } + + // ========== Vamana-specific Additional Tests ========== + + #[test] + fn test_disk_index_vamana_find_medoid() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path) + .with_backend(DiskBackend::Vamana); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + // Add vectors + index.add_vector(&[0.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&[0.5, 0.0, 0.0, 0.0], 3).unwrap(); + + let medoid = index.find_medoid(); + assert!(medoid.is_some()); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_vamana_empty_find_medoid() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path) + .with_backend(DiskBackend::Vamana); + let index = DiskIndexSingle::::new(params).unwrap(); + + // Empty index should return None + let medoid = index.find_medoid(); + assert!(medoid.is_none()); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_vamana_single_vector() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path) + .with_backend(DiskBackend::Vamana); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 1, None).unwrap(); + + assert_eq!(results.len(), 1); + assert_eq!(results.results[0].label, 1); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_vamana_delete_all() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path) + .with_backend(DiskBackend::Vamana); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + index.delete_vector(1).unwrap(); + index.delete_vector(2).unwrap(); + + assert_eq!(index.index_size(), 0); + + // Query on empty index + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 10, None).unwrap(); + assert!(results.is_empty()); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_flush() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, 
Metric::L2, &path); + let mut index = DiskIndexSingle::::new(params).unwrap(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + // Flush should succeed + let result = index.flush(); + assert!(result.is_ok()); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_with_capacity() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let index = DiskIndexSingle::::with_capacity(params, 100).unwrap(); + + assert_eq!(index.index_capacity(), Some(100)); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_no_capacity() { + let path = temp_path(); + { + let params = DiskIndexParams::new(4, Metric::L2, &path); + let index = DiskIndexSingle::::new(params).unwrap(); + + assert_eq!(index.index_capacity(), None); + } + fs::remove_file(&path).ok(); + } + + #[test] + fn test_disk_index_dimension() { + let path = temp_path(); + { + let params = DiskIndexParams::new(128, Metric::L2, &path); + let index = DiskIndexSingle::::new(params).unwrap(); + + assert_eq!(index.dimension(), 128); + } + fs::remove_file(&path).ok(); + } +} diff --git a/rust/vecsim/src/index/hnsw/OPTIMIZATIONS.md b/rust/vecsim/src/index/hnsw/OPTIMIZATIONS.md new file mode 100644 index 000000000..7085b53f5 --- /dev/null +++ b/rust/vecsim/src/index/hnsw/OPTIMIZATIONS.md @@ -0,0 +1,821 @@ +# HNSW Performance Optimization Opportunities + +This document analyzes specific optimization opportunities for the Rust HNSW implementation, +with code examples and implementation suggestions based on analysis of the current codebase. + +## Table of Contents + +1. [Batch Distance Computation](#1-batch-distance-computation) +2. [Memory Layout Improvements](#2-memory-layout-improvements) +3. [Product Quantization Integration](#3-product-quantization-integration) +4. [Parallel Search Improvements](#4-parallel-search-improvements) +5. [Adaptive Parameters](#5-adaptive-parameters) +6. [Graph Construction Optimizations](#6-graph-construction-optimizations) + +--- + +## 1. Batch Distance Computation + +### Current Implementation Analysis + +**Location**: `rust/vecsim/src/index/hnsw/search.rs:172-193` + +The current `search_layer` function computes distances one at a time in the inner loop: + +```rust +for neighbor in element.iter_neighbors(level) { + if visited.visit(neighbor) { + continue; + } + // ... + if let Some(data) = data_getter(neighbor) { + let dist = dist_fn.compute(data, query, dim); // One distance at a time + // ... + } +} +``` + +### Optimization Opportunity + +**Expected Impact**: 15-30% speedup in search operations +**Complexity**: Medium + +Batch multiple neighbors for SIMD-efficient distance computation. Instead of computing +distances one at a time, collect neighbor vectors and compute distances in batches. + +### Suggested Implementation + +```rust +// New batch distance function in rust/vecsim/src/distance/mod.rs +pub trait BatchDistanceFunction { + /// Compute distances from query to multiple vectors. + /// Returns distances in the same order as input vectors. 
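+    ///
+    /// A portable fallback could simply map the scalar `compute` over
+    /// `vectors`; SIMD-capable backends would override this with a batched
+    /// kernel like the AVX2 sketch below.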
+ fn compute_batch( + &self, + query: &[T], + vectors: &[&[T]], // Multiple vectors + dim: usize, + ) -> Vec; +} + +// Example AVX2 implementation for L2 (rust/vecsim/src/distance/simd/avx2.rs) +#[target_feature(enable = "avx2", enable = "fma")] +unsafe fn l2_squared_batch_f32_avx2( + query: *const f32, + vectors: &[*const f32], // Pointers to multiple vectors + dim: usize, +) -> Vec { + let mut results = Vec::with_capacity(vectors.len()); + + // Process in groups of 4 vectors (optimal for AVX2 256-bit registers) + for chunk in vectors.chunks(4) { + // Interleave loading from 4 vectors for better memory bandwidth + let mut sums = [_mm256_setzero_ps(); 4]; + + for d in (0..dim).step_by(8) { + let q = _mm256_loadu_ps(query.add(d)); + + for (i, &vec_ptr) in chunk.iter().enumerate() { + let v = _mm256_loadu_ps(vec_ptr.add(d)); + let diff = _mm256_sub_ps(q, v); + sums[i] = _mm256_fmadd_ps(diff, diff, sums[i]); + } + } + + // Horizontal sum and store results + for (i, sum) in sums.iter().take(chunk.len()).enumerate() { + results.push(hsum256_ps(*sum)); + } + } + results +} +``` + +### Modified search_layer + +```rust +// In rust/vecsim/src/index/hnsw/search.rs +pub fn search_layer_batched<'a, T, D, F, P, G>(/* ... */) -> SearchResult { + // ... setup ... + + while let Some(candidate) = candidates.pop() { + if let Some(element) = graph.get(candidate.id) { + // Collect non-visited neighbors + let mut batch_ids: Vec = Vec::with_capacity(32); + let mut batch_data: Vec<&[T]> = Vec::with_capacity(32); + + for neighbor in element.iter_neighbors(level) { + if visited.visit(neighbor) { + continue; + } + if let Some(data) = data_getter(neighbor) { + batch_ids.push(neighbor); + batch_data.push(data); + } + } + + // Batch distance computation + if !batch_data.is_empty() { + let distances = dist_fn.compute_batch(query, &batch_data, dim); + + for (i, dist) in distances.into_iter().enumerate() { + let neighbor = batch_ids[i]; + // ... rest of distance processing ... + } + } + } + } +} +``` + +--- + +## 2. Memory Layout Improvements + +### Current Implementation Analysis + +**Location**: `rust/vecsim/src/index/hnsw/graph.rs:42-52, 194-222` + +Current `LevelLinks` structure: + +```rust +pub struct LevelLinks { + neighbors: Vec, // Dynamic allocation, pointer indirection + count: AtomicU32, // 4 bytes + capacity: usize, // 8 bytes +} + +pub struct ElementGraphData { + pub meta: ElementMetaData, // 10+ bytes (label u64 + level u8 + deleted bool) + pub levels: Vec, // Dynamic Vec, pointer indirection + pub lock: Mutex<()>, // ~40+ bytes (parking_lot Mutex) +} +``` + +### Issues + +1. **Cache Misses**: Multiple pointer indirections (ElementGraphData → Vec → LevelLinks → Vec) +2. **Memory Fragmentation**: Small allocations scattered across heap +3. **Lock Overhead**: Per-element mutex is ~40 bytes overhead per element + +### Optimization Opportunity + +**Expected Impact**: 10-25% speedup in graph traversal +**Complexity**: Hard + +### Suggested Implementation: Cache-Line Aligned Compact Structure + +```rust +// rust/vecsim/src/index/hnsw/graph_compact.rs + +/// Compact neighbor list that fits in cache lines. +/// For M=16, this is exactly 64 bytes (one cache line). 
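+/// (15 four-byte neighbor IDs + 1-byte count + 1-byte capacity + 2 bytes of
+/// padding = 64 bytes, so a single cache-line load fetches the whole list.)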
+#[repr(C, align(64))] +pub struct CompactLevelLinks { + /// Neighbor IDs stored inline (no pointer indirection) + neighbors: [u32; 15], // 60 bytes - supports up to M=15 + count: u8, // 1 byte + capacity: u8, // 1 byte + _padding: [u8; 2], // 2 bytes padding for alignment +} + +/// For level 0 with M_max_0=32, use two cache lines +#[repr(C, align(64))] +pub struct CompactLevelLinks0 { + neighbors: [u32; 31], // 124 bytes + count: u8, // 1 byte + capacity: u8, // 1 byte + _padding: [u8; 2], // 2 bytes +} // Total: 128 bytes = 2 cache lines + +impl CompactLevelLinks { + #[inline] + pub fn iter_neighbors(&self) -> impl Iterator + '_ { + self.neighbors[..self.count as usize].iter().copied() + } +} +``` + +### Structure-of-Arrays Layout for Vector Data + +**Location**: `rust/vecsim/src/containers/data_blocks.rs` + +Consider a Structure-of-Arrays (SoA) layout for better SIMD efficiency: + +```rust +// Current: Array of Structures (AoS) +// vectors[0] = [x0, y0, z0, w0] +// vectors[1] = [x1, y1, z1, w1] +// ... + +// Proposed: Structure of Arrays (SoA) for specific dimensions +pub struct SoAVectorBlock { + x: Vec, // All x components contiguous + y: Vec, // All y components contiguous + z: Vec, // All z components contiguous + w: Vec, // All w components contiguous +} + +// This enables computing distances to 8 vectors simultaneously with AVX2: +// Load 8 x-components, compute 8 differences, etc. +``` + +### Compressed Neighbor IDs + +For indices with < 65536 vectors, use u16 IDs: + +```rust +/// Adaptive ID type based on index size +pub enum CompressedLinks { + /// For small indices (< 65536 elements) + Small(SmallLinks), + /// For large indices + Large(LargeLinks), +} + +#[repr(C, align(64))] +pub struct SmallLinks { + neighbors: [u16; 30], // 60 bytes - double the capacity! + count: u8, + capacity: u8, + _padding: [u8; 2], +} // Still fits in one cache line, but 2x capacity +``` + +--- + +## 3. 
Product Quantization (PQ) Integration + +### Current Implementation Analysis + +**Location**: `rust/vecsim/src/quantization/mod.rs` + +The codebase already has quantization support: +- `SQ8`: Scalar quantization to 8-bit +- `LVQ`: Learned Vector Quantization (4-bit/8-bit) +- `LeanVec`: Dimension reduction with two-level quantization + +### Optimization Opportunity + +**Expected Impact**: 2-4x memory reduction, 20-50% faster search +**Complexity**: Hard + +### Suggested PQ Implementation + +```rust +// rust/vecsim/src/quantization/pq.rs + +/// Product Quantization codec +pub struct PQCodec { + /// Number of subspaces (typically dim/4 to dim/8) + m: usize, + /// Bits per subspace (typically 8 for 256 centroids) + nbits: usize, + /// Dimension of original vectors + dim: usize, + /// Dimension of each subspace + dsub: usize, + /// Centroids for each subspace: [m][2^nbits][dsub] + centroids: Vec>>, +} + +impl PQCodec { + /// Encode a vector to PQ codes + pub fn encode(&self, vector: &[f32]) -> Vec { + let mut codes = Vec::with_capacity(self.m); + + for (i, chunk) in vector.chunks(self.dsub).enumerate() { + // Find nearest centroid for this subspace + let mut best_idx = 0; + let mut best_dist = f32::MAX; + + for (j, centroid) in self.centroids[i].iter().enumerate() { + let dist = l2_distance(chunk, centroid); + if dist < best_dist { + best_dist = dist; + best_idx = j; + } + } + codes.push(best_idx as u8); + } + codes + } + + /// Asymmetric distance computation using precomputed lookup tables + pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 { + // Precompute distance tables: for each subspace, distance to all centroids + let tables: Vec> = (0..self.m) + .map(|i| { + let query_sub = &query[i * self.dsub..(i + 1) * self.dsub]; + self.centroids[i] + .iter() + .map(|c| l2_squared(query_sub, c)) + .collect() + }) + .collect(); + + // Sum distances from lookup tables (very fast!) + codes.iter() + .enumerate() + .map(|(i, &code)| tables[i][code as usize]) + .sum() + } +} +``` + +### HNSW Integration Points + +**Location**: `rust/vecsim/src/index/hnsw/mod.rs:716-724` + +```rust +// In HnswCore::compute_distance +#[inline] +fn compute_distance(&self, id: IdType, query: &[T]) -> T::DistanceType { + // Option 1: Two-stage search with PQ + if let Some(pq) = &self.pq_codec { + // Fast PQ distance for initial filtering + let pq_codes = self.pq_data.get(id); + let approx_dist = pq.asymmetric_distance(query, pq_codes); + + // Only compute exact distance if promising + if approx_dist < self.rerank_threshold { + if let Some(data) = self.data.get(id) { + return self.dist_fn.compute(data, query, self.params.dim); + } + } + return T::DistanceType::from_f64(approx_dist as f64); + } + + // Original exact distance + if let Some(data) = self.data.get(id) { + self.dist_fn.compute(data, query, self.params.dim) + } else { + T::DistanceType::infinity() + } +} +``` + +--- + +## 4. 
Parallel Search Improvements + +### Current Implementation Analysis + +**Location**: `rust/vecsim/src/index/hnsw/concurrent_graph.rs:85-93` + +Current locking structure: +```rust +pub struct ConcurrentGraph { + segments: RwLock>, // Global lock for growth + segment_size: usize, + len: AtomicUsize, +} +``` + +**Location**: `rust/vecsim/src/index/hnsw/graph.rs:194-202` + +Per-element lock: +```rust +pub struct ElementGraphData { + pub lock: Mutex<()>, // Per-element lock for neighbor modification +} +``` + +### Optimization Opportunities + +**Expected Impact**: 2-4x throughput for concurrent queries +**Complexity**: Medium to Hard + +### 4.1 Lock-Free Read Path + +The current read path acquires a read lock on segments. For search-only workloads, +we can use epoch-based reclamation: + +```rust +// rust/vecsim/src/index/hnsw/concurrent_graph_lockfree.rs +use crossbeam_epoch::{self as epoch, Atomic, Owned, Shared}; + +pub struct LockFreeGraph { + /// Segments using atomic pointer + segments: Atomic>, + segment_size: usize, + len: AtomicUsize, +} + +impl LockFreeGraph { + /// Lock-free read + #[inline] + pub fn get(&self, id: IdType) -> Option<&ElementGraphData> { + let guard = epoch::pin(); + let segments = unsafe { self.segments.load(Ordering::Acquire, &guard).deref() }; + + let (seg_idx, offset) = self.id_to_indices(id); + if seg_idx >= segments.len() { + return None; + } + + unsafe { segments[seg_idx].get(offset) } + } +} +``` + +### 4.2 Batch Query Parallelization + +```rust +// rust/vecsim/src/index/hnsw/single.rs + +impl HnswSingle { + /// Process multiple queries in parallel + pub fn batch_search( + &self, + queries: &[Vec], + k: usize, + ef: usize, + ) -> Vec> { + use rayon::prelude::*; + + queries + .par_iter() + .map(|query| { + self.core.search(query, k, ef, None) + }) + .collect() + } + + /// Optimized batch search with shared state + pub fn batch_search_optimized( + &self, + queries: &[Vec], + k: usize, + ef: usize, + ) -> Vec> { + use rayon::prelude::*; + + // Pre-allocate visited handlers for all queries + let num_queries = queries.len(); + let handlers: Vec<_> = (0..rayon::current_num_threads()) + .map(|_| self.core.visited_pool.get()) + .collect(); + + queries + .par_iter() + .enumerate() + .map(|(i, query)| { + let thread_id = rayon::current_thread_index().unwrap_or(0); + // Reuse handler for this thread + self.core.search_with_handler(query, k, ef, None, &handlers[thread_id]) + }) + .collect() + } +} +``` + +### 4.3 SIMD-Parallel Candidate Evaluation + +```rust +/// Evaluate multiple candidates in parallel using SIMD +fn evaluate_candidates_simd( + candidates: &[(IdType, &[T])], + query: &[T], + dist_fn: &dyn DistanceFunction, + dim: usize, +) -> Vec<(IdType, T::DistanceType)> { + // Process 4 candidates at a time with AVX2 + let mut results = Vec::with_capacity(candidates.len()); + + for chunk in candidates.chunks(4) { + let distances = compute_distances_4way( + query, + chunk.iter().map(|(_, v)| *v).collect::>().as_slice(), + dim, + ); + + for (i, (id, _)) in chunk.iter().enumerate() { + results.push((*id, distances[i])); + } + } + + results +} +``` + +--- + +## 5. Adaptive Parameters + +### Current Implementation Analysis + +**Location**: `rust/vecsim/src/index/hnsw/mod.rs:100-120` + +Current static parameters: +```rust +pub struct HnswParams { + pub dim: usize, + pub m: usize, + pub m_max_0: usize, + pub ef_construction: usize, + pub ef_runtime: usize, + // ... 
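+    // All of these are fixed at index construction time; Section 5.1 below
+    // sketches a controller that adapts the effective ef_runtime per query.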
+} +``` + +### 5.1 Dynamic ef Adjustment + +**Expected Impact**: Better quality-latency tradeoff +**Complexity**: Easy + +```rust +// rust/vecsim/src/index/hnsw/adaptive.rs + +/// Adaptive ef controller based on result quality +pub struct AdaptiveEfController { + /// Minimum ef value + min_ef: usize, + /// Maximum ef value + max_ef: usize, + /// Current ef value + current_ef: AtomicUsize, + /// Quality threshold (e.g., recall@k target) + quality_threshold: f32, + /// Moving average of observed quality + quality_ema: AtomicU32, // Stored as fixed-point +} + +impl AdaptiveEfController { + /// Update ef based on measured quality + pub fn update(&self, measured_recall: f32) { + let current = self.current_ef.load(Ordering::Relaxed); + + if measured_recall < self.quality_threshold { + // Increase ef to improve quality + let new_ef = (current * 12 / 10).min(self.max_ef); + self.current_ef.store(new_ef, Ordering::Relaxed); + } else if measured_recall > self.quality_threshold + 0.05 { + // Decrease ef to reduce latency + let new_ef = (current * 9 / 10).max(self.min_ef); + self.current_ef.store(new_ef, Ordering::Relaxed); + } + + // Update EMA + self.update_quality_ema(measured_recall); + } + + /// Get current adaptive ef value + pub fn get_ef(&self) -> usize { + self.current_ef.load(Ordering::Relaxed) + } +} +``` + +### 5.2 Adaptive Prefetch Based on Vector Size (SVS Approach) + +**Location**: `rust/vecsim/src/index/hnsw/search.rs:82-98` + +Reference: The SVS (Scalable Vector Search) paper suggests adapting prefetch +distance based on vector size to hide memory latency effectively. + +```rust +// rust/vecsim/src/utils/prefetch.rs + +/// Adaptive prefetch parameters based on vector characteristics +pub struct PrefetchConfig { + /// Number of vectors to prefetch ahead + pub prefetch_depth: usize, + /// Whether to prefetch graph structure + pub prefetch_graph: bool, +} + +impl PrefetchConfig { + /// Create config based on vector size and cache characteristics + pub fn for_vector_size(dim: usize, element_size: usize) -> Self { + let vector_bytes = dim * element_size; + let l1_cache_size = 32 * 1024; // 32KB typical L1D + let cache_lines = (vector_bytes + 63) / 64; // 64-byte cache lines + + // Prefetch more aggressively for smaller vectors + let prefetch_depth = if vector_bytes <= 256 { + 4 // Small vectors: prefetch 4 ahead + } else if vector_bytes <= 1024 { + 2 // Medium vectors: prefetch 2 ahead + } else { + 1 // Large vectors: prefetch 1 ahead + }; + + Self { + prefetch_depth, + prefetch_graph: vector_bytes <= 512, + } + } +} + +/// Enhanced prefetch for multiple cache lines +#[inline] +pub fn prefetch_vector(data: &[T], prefetch_config: &PrefetchConfig) { + let bytes = std::mem::size_of_val(data); + let cache_lines = (bytes + 63) / 64; + let ptr = data.as_ptr() as *const i8; + + for i in 0..cache_lines.min(8) { // Max 8 cache lines + prefetch_read(unsafe { ptr.add(i * 64) } as *const T); + } +} +``` + +### 5.3 Early Termination Heuristics + +```rust +// In search_layer function + +/// Check if we can terminate early based on distance distribution +fn should_terminate_early( + results: &MaxHeap, + candidates: &MinHeap, + early_termination_factor: f32, +) -> bool { + if !results.is_full() { + return false; + } + + let worst_result = results.top_distance().unwrap(); + if let Some(best_candidate) = candidates.peek() { + // If best remaining candidate is significantly worse than + // our worst result, we can stop + let threshold = worst_result.to_f64() * (1.0 + early_termination_factor as f64); + 
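+        // Terminate once even the nearest unexplored candidate lies more than
+        // the `early_termination_factor` relative slack beyond the worst result
+        // we are currently keeping.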
best_candidate.distance.to_f64() > threshold + } else { + true + } +} +``` + +--- + +## 6. Graph Construction Optimizations + +### Current Implementation Analysis + +**Location**: `rust/vecsim/src/index/hnsw/mod.rs:286-466` + +Current insertion flow: +1. Generate random level +2. Find entry point via greedy search +3. For each level: search_layer → select_neighbors → mutually_connect + +### 6.1 Parallel Index Building + +**Expected Impact**: Near-linear speedup with cores +**Complexity**: Medium + +```rust +// rust/vecsim/src/index/hnsw/single.rs + +impl HnswSingle { + /// Build index from vectors in parallel + pub fn build_parallel( + params: HnswParams, + vectors: &[T], + labels: &[LabelType], + ) -> Result { + let dim = params.dim; + let num_vectors = vectors.len() / dim; + + // Phase 1: Sequential insertion of first few elements (establish structure) + let mut index = Self::new(params.clone()); + let bootstrap_count = (num_vectors / 100).max(100).min(num_vectors); + + for i in 0..bootstrap_count { + let vec_start = i * dim; + let vector = &vectors[vec_start..vec_start + dim]; + index.add_vector(vector, labels[i])?; + } + + // Phase 2: Parallel insertion of remaining elements + let remaining: Vec<(usize, LabelType)> = (bootstrap_count..num_vectors) + .map(|i| (i, labels[i])) + .collect(); + + remaining.par_iter().try_for_each(|&(i, label)| { + let vec_start = i * dim; + let vector = &vectors[vec_start..vec_start + dim]; + index.add_vector_concurrent(vector, label) + })?; + + Ok(index) + } +} +``` + +### 6.2 Incremental Neighbor Selection + +Instead of recomputing all neighbors, incrementally update: + +```rust +/// Incrementally update neighbor selection when a new vector is added +fn update_neighbors_incremental( + &self, + existing_neighbors: &[(IdType, T::DistanceType)], + new_candidate: (IdType, T::DistanceType), + max_neighbors: usize, +) -> Vec { + if existing_neighbors.len() < max_neighbors { + // Simply add if below capacity + let mut result: Vec<_> = existing_neighbors.iter().map(|(id, _)| *id).collect(); + result.push(new_candidate.0); + return result; + } + + // Check if new candidate should replace worst existing neighbor + let worst_existing = existing_neighbors + .iter() + .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) + .unwrap(); + + if new_candidate.1 < worst_existing.1 { + // Replace worst with new candidate + existing_neighbors + .iter() + .filter(|(id, _)| *id != worst_existing.0) + .map(|(id, _)| *id) + .chain(std::iter::once(new_candidate.0)) + .collect() + } else { + // Keep existing neighbors + existing_neighbors.iter().map(|(id, _)| *id).collect() + } +} +``` + +### 6.3 Lazy Re-linking + +Defer expensive neighbor updates to background thread: + +```rust +// rust/vecsim/src/index/hnsw/lazy_relink.rs + +pub struct LazyRelinkQueue { + /// Queue of pending relink operations + pending: Mutex>, + /// Background worker handle + worker: Option>, +} + +struct RelinkTask { + node_id: IdType, + level: usize, + new_neighbors: Vec, +} + +impl LazyRelinkQueue { + /// Schedule a relink operation (non-blocking) + pub fn schedule_relink(&self, node_id: IdType, level: usize, neighbors: Vec) { + self.pending.lock().push_back(RelinkTask { + node_id, + level, + new_neighbors: neighbors, + }); + } + + /// Process pending relinks in background + fn process_pending(&self, graph: &ConcurrentGraph) { + while let Some(task) = self.pending.lock().pop_front() { + if let Some(element) = graph.get(task.node_id) { + let _lock = element.lock.lock(); + element.set_neighbors(task.level, 
&task.new_neighbors); + } + } + } +} +``` + +--- + +## Summary: Prioritized Implementation Roadmap + +| Optimization | Impact | Complexity | Priority | +|--------------|--------|------------|----------| +| Adaptive Prefetch | 5-15% | Easy | High | +| Early Termination | 5-10% | Easy | High | +| Batch Distance Computation | 15-30% | Medium | High | +| Compact Memory Layout | 10-25% | Hard | Medium | +| Parallel Index Building | 3-8x build | Medium | Medium | +| Dynamic ef Adjustment | Quality++ | Easy | Medium | +| Lock-Free Read Path | 2-4x QPS | Hard | Low | +| PQ Integration | 2-4x memory | Hard | Low | +| Lazy Re-linking | 10-20% build | Medium | Low | + +### Recommended Implementation Order + +1. **Quick wins** (1-2 days each): + - Adaptive prefetch parameters + - Early termination heuristics + - Dynamic ef adjustment + +2. **Medium effort** (1-2 weeks each): + - Batch distance computation + - Parallel index building improvements + +3. **Major refactoring** (2-4 weeks each): + - Compact memory layout + - PQ integration + - Lock-free concurrent graph + diff --git a/rust/vecsim/src/index/hnsw/batch_iterator.rs b/rust/vecsim/src/index/hnsw/batch_iterator.rs new file mode 100644 index 000000000..f99950e2a --- /dev/null +++ b/rust/vecsim/src/index/hnsw/batch_iterator.rs @@ -0,0 +1,544 @@ +//! Batch iterator implementations for HNSW indices. + +use super::multi::HnswMulti; +use super::single::HnswSingle; +use crate::index::traits::{BatchIterator, VecSimIndex}; +use crate::query::QueryParams; +use crate::types::{IdType, LabelType, VectorElement}; + +/// Batch iterator for single-value HNSW index. +pub struct HnswSingleBatchIterator<'a, T: VectorElement> { + index: &'a HnswSingle, + query: Vec, + params: Option, + results: Vec<(IdType, LabelType, T::DistanceType)>, + position: usize, + computed: bool, +} + +impl<'a, T: VectorElement> HnswSingleBatchIterator<'a, T> { + pub fn new( + index: &'a HnswSingle, + query: Vec, + params: Option, + ) -> Self { + Self { + index, + query, + params, + results: Vec::new(), + position: 0, + computed: false, + } + } + + fn compute_results(&mut self) { + if self.computed { + return; + } + + let core = &self.index.core; + + let ef = self.params + .as_ref() + .and_then(|p| p.ef_runtime) + .unwrap_or(core.params.ef_runtime) + .max(1000); // Use large ef for batch iteration + + let count = self.index.index_size(); + + // Clone id_to_label map for filter closure to avoid borrow-after-move + let filter_fn: Option bool + '_>> = + if let Some(ref p) = self.params { + if let Some(ref f) = p.filter { + // Collect DashMap entries into a regular HashMap for the closure + let id_to_label_for_filter: std::collections::HashMap = + self.index.id_to_label.iter().map(|r| (*r.key(), *r.value())).collect(); + Some(Box::new(move |id: IdType| { + id_to_label_for_filter.get(&id).is_some_and(|&label| f(label)) + })) + } else { + None + } + } else { + None + }; + + let search_results = core.search( + &self.query, + count, + ef, + filter_fn.as_ref().map(|f| f.as_ref()), + ); + + // Process results using DashMap + self.results = search_results + .into_iter() + .filter_map(|(id, dist)| { + self.index.id_to_label.get(&id).map(|label_ref| (id, *label_ref, dist)) + }) + .collect(); + + self.computed = true; + } +} + +impl<'a, T: VectorElement> BatchIterator for HnswSingleBatchIterator<'a, T> { + type DistType = T::DistanceType; + + fn has_next(&self) -> bool { + if !self.computed { + return true; // Haven't computed yet + } + self.position < self.results.len() + } + + fn next_batch( + &mut 
self, + batch_size: usize, + ) -> Option> { + self.compute_results(); + + if self.position >= self.results.len() { + return None; + } + + let end = (self.position + batch_size).min(self.results.len()); + let batch = self.results[self.position..end].to_vec(); + self.position = end; + + Some(batch) + } + + fn reset(&mut self) { + self.position = 0; + } +} + +/// Batch iterator for multi-value HNSW index. +pub struct HnswMultiBatchIterator<'a, T: VectorElement> { + index: &'a HnswMulti, + query: Vec, + params: Option, + results: Vec<(IdType, LabelType, T::DistanceType)>, + position: usize, + computed: bool, +} + +impl<'a, T: VectorElement> HnswMultiBatchIterator<'a, T> { + pub fn new( + index: &'a HnswMulti, + query: Vec, + params: Option, + ) -> Self { + Self { + index, + query, + params, + results: Vec::new(), + position: 0, + computed: false, + } + } + + fn compute_results(&mut self) { + if self.computed { + return; + } + + let core = &self.index.core; + + let ef = self.params + .as_ref() + .and_then(|p| p.ef_runtime) + .unwrap_or(core.params.ef_runtime) + .max(1000); + + let count = self.index.index_size(); + + // Clone id_to_label map for filter closure to avoid borrow-after-move + let filter_fn: Option bool + '_>> = + if let Some(ref p) = self.params { + if let Some(ref f) = p.filter { + // Collect DashMap entries into a regular HashMap for the closure + let id_to_label_for_filter: std::collections::HashMap = + self.index.id_to_label.iter().map(|r| (*r.key(), *r.value())).collect(); + Some(Box::new(move |id: IdType| { + id_to_label_for_filter.get(&id).is_some_and(|&label| f(label)) + })) + } else { + None + } + } else { + None + }; + + let search_results = core.search( + &self.query, + count, + ef, + filter_fn.as_ref().map(|f| f.as_ref()), + ); + + // Process results using DashMap + self.results = search_results + .into_iter() + .filter_map(|(id, dist)| { + self.index.id_to_label.get(&id).map(|label_ref| (id, *label_ref, dist)) + }) + .collect(); + + self.computed = true; + } +} + +impl<'a, T: VectorElement> BatchIterator for HnswMultiBatchIterator<'a, T> { + type DistType = T::DistanceType; + + fn has_next(&self) -> bool { + if !self.computed { + return true; + } + self.position < self.results.len() + } + + fn next_batch( + &mut self, + batch_size: usize, + ) -> Option> { + self.compute_results(); + + if self.position >= self.results.len() { + return None; + } + + let end = (self.position + batch_size).min(self.results.len()); + let batch = self.results[self.position..end].to_vec(); + self.position = end; + + Some(batch) + } + + fn reset(&mut self) { + self.position = 0; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::Metric; + use crate::index::hnsw::HnswParams; + use crate::index::VecSimIndex; + use crate::types::DistanceType; + + #[test] + fn test_hnsw_batch_iterator() { + let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let mut index = HnswSingle::::new(params); + + for i in 0..10 { + index + .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64) + .unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let mut iter = index.batch_iterator(&query, None).unwrap(); + + assert!(iter.has_next()); + + let batch1 = iter.next_batch(3).unwrap(); + assert!(!batch1.is_empty()); + + let mut total = batch1.len(); + while let Some(batch) = iter.next_batch(3) { + total += batch.len(); + } + + // Should have gotten all vectors + assert!(total <= 10); + } + + #[test] + fn test_hnsw_batch_iterator_empty_index() { + let params = 
HnswParams::new(4, Metric::L2)
+            .with_m(4)
+            .with_ef_construction(20);
+        let index = HnswSingle::<f32>::new(params);
+
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let mut iter = index.batch_iterator(&query, None).unwrap();
+
+        // has_next returns true because results haven't been computed yet
+        assert!(iter.has_next());
+
+        // But next_batch returns None after computing empty results
+        let batch = iter.next_batch(10);
+        assert!(batch.is_none());
+    }
+
+    #[test]
+    fn test_hnsw_batch_iterator_reset() {
+        let params = HnswParams::new(4, Metric::L2)
+            .with_m(4)
+            .with_ef_construction(20);
+        let mut index = HnswSingle::<f32>::new(params);
+
+        for i in 0..5 {
+            index
+                .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64)
+                .unwrap();
+        }
+
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let mut iter = index.batch_iterator(&query, None).unwrap();
+
+        // Consume all
+        while iter.next_batch(10).is_some() {}
+        assert!(!iter.has_next());
+
+        // Reset
+        iter.reset();
+
+        // Position is reset but the computed flag stays true
+        let batch = iter.next_batch(10);
+        assert!(batch.is_some());
+    }
+
+    #[test]
+    fn test_hnsw_batch_iterator_with_query_params() {
+        let params = HnswParams::new(4, Metric::L2)
+            .with_m(4)
+            .with_ef_construction(20);
+        let mut index = HnswSingle::<f32>::new(params);
+
+        for i in 0..10 {
+            index
+                .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64)
+                .unwrap();
+        }
+
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let query_params = QueryParams::new().with_ef_runtime(100);
+        let mut iter = index.batch_iterator(&query, Some(&query_params)).unwrap();
+
+        assert!(iter.has_next());
+
+        let batch = iter.next_batch(5).unwrap();
+        assert!(!batch.is_empty());
+    }
+
+    #[test]
+    fn test_hnsw_batch_iterator_with_params_ef_runtime() {
+        // Note: Filter functionality cannot be tested with batch_iterator because
+        // QueryParams::clone() sets filter to None (closures can't be cloned).
+        // This test verifies ef_runtime parameter passing instead.
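+        // compute_results clamps ef to at least 1000 via `.max(1000)`, so an
+        // ef_runtime below that floor does not change batch-iteration behavior.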
+ let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let mut index = HnswSingle::::new(params); + + for i in 0..10 { + index + .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64) + .unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + // Test with high ef_runtime for better recall + let query_params = QueryParams::new().with_ef_runtime(200); + let mut iter = index.batch_iterator(&query, Some(&query_params)).unwrap(); + + let mut all_results = Vec::new(); + while let Some(batch) = iter.next_batch(10) { + all_results.extend(batch); + } + + // Should get all results with high ef_runtime + assert!(!all_results.is_empty()); + } + + #[test] + fn test_hnsw_batch_iterator_sorted_by_distance() { + let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let mut index = HnswSingle::::new(params); + + for i in 0..20 { + index + .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64) + .unwrap(); + } + + let query = vec![5.0, 0.0, 0.0, 0.0]; + let mut iter = index.batch_iterator(&query, None).unwrap(); + + let mut all_results = Vec::new(); + while let Some(batch) = iter.next_batch(5) { + all_results.extend(batch); + } + + // Verify sorted by distance + for i in 1..all_results.len() { + assert!(all_results[i - 1].2.to_f64() <= all_results[i].2.to_f64()); + } + } + + #[test] + fn test_hnsw_multi_batch_iterator() { + let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let mut index = HnswMulti::::new(params); + + // Add vectors with same label (multi-value) + for i in 0..5 { + index + .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], 100) + .unwrap(); + } + for i in 0..5 { + index + .add_vector(&vec![i as f32 + 10.0, 0.0, 0.0, 0.0], 200) + .unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let mut iter = index.batch_iterator(&query, None).unwrap(); + + assert!(iter.has_next()); + + let mut total = 0; + while let Some(batch) = iter.next_batch(3) { + total += batch.len(); + } + + assert!(total <= 10); + } + + #[test] + fn test_hnsw_multi_batch_iterator_empty() { + let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let index = HnswMulti::::new(params); + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let mut iter = index.batch_iterator(&query, None).unwrap(); + + // Has_next returns true initially + assert!(iter.has_next()); + + // But next_batch returns None for empty index + assert!(iter.next_batch(10).is_none()); + } + + #[test] + fn test_hnsw_multi_batch_iterator_reset() { + let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let mut index = HnswMulti::::new(params); + + for i in 0..5 { + index + .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64) + .unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let mut iter = index.batch_iterator(&query, None).unwrap(); + + // Consume + while iter.next_batch(10).is_some() {} + assert!(!iter.has_next()); + + // Reset + iter.reset(); + + // Can iterate again + assert!(iter.next_batch(10).is_some()); + } + + #[test] + fn test_hnsw_multi_batch_iterator_with_params() { + // Note: Filter functionality cannot be tested with batch_iterator because + // QueryParams::clone() sets filter to None (closures can't be cloned). + // This test verifies ef_runtime parameter passing instead. 
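+        // The multi-value iterator still yields per-vector (id, label, distance)
+        // tuples; when several vectors share a label, that label can repeat
+        // across batches.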
+ let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let mut index = HnswMulti::::new(params); + + for i in 0..10 { + index + .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64) + .unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let query_params = QueryParams::new().with_ef_runtime(200); + let mut iter = index.batch_iterator(&query, Some(&query_params)).unwrap(); + + let mut all_results = Vec::new(); + while let Some(batch) = iter.next_batch(10) { + all_results.extend(batch); + } + + // Should get results with high ef_runtime + assert!(!all_results.is_empty()); + } + + #[test] + fn test_hnsw_batch_iterator_large_batch_size() { + let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let mut index = HnswSingle::::new(params); + + for i in 0..5 { + index + .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64) + .unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let mut iter = index.batch_iterator(&query, None).unwrap(); + + // Request much larger batch than available + let batch = iter.next_batch(1000).unwrap(); + assert!(batch.len() <= 5); + + // No more results + assert!(!iter.has_next()); + } + + #[test] + fn test_hnsw_batch_iterator_batch_size_one() { + let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let mut index = HnswSingle::::new(params); + + for i in 0..3 { + index + .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64) + .unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let mut iter = index.batch_iterator(&query, None).unwrap(); + + let mut count = 0; + while let Some(batch) = iter.next_batch(1) { + assert_eq!(batch.len(), 1); + count += 1; + } + + assert!(count <= 3); + } +} diff --git a/rust/vecsim/src/index/hnsw/concurrent_graph.rs b/rust/vecsim/src/index/hnsw/concurrent_graph.rs new file mode 100644 index 000000000..ef28ddb5f --- /dev/null +++ b/rust/vecsim/src/index/hnsw/concurrent_graph.rs @@ -0,0 +1,373 @@ +//! Lock-free concurrent graph storage for HNSW. +//! +//! This module provides a concurrent graph structure that allows multiple threads +//! to read and write graph data without blocking each other. It uses a segmented +//! approach where each segment is a fixed-size array of graph elements. +//! +//! The key insight from the C++ HNSW implementation is that the graph array itself +//! doesn't need locks - only individual node modifications need synchronization, +//! which is handled by the per-node locks in `ElementGraphData`. + +use super::ElementGraphData; +use crate::types::IdType; +use parking_lot::RwLock; +use std::cell::UnsafeCell; +use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering}; + +/// Flags for element status (matches C++ Flags enum). +#[allow(dead_code)] +mod flags { + /// Element is logically deleted but still exists in the graph. + pub const DELETE_MARK: u8 = 0x1; + /// Element is being inserted into the graph. + pub const IN_PROCESS: u8 = 0x2; +} + +/// Default segment size (number of elements per segment). +const SEGMENT_SIZE: usize = 4096; + +/// A segment of graph elements. +/// +/// Each slot uses `UnsafeCell` to allow interior mutability. Thread safety +/// is ensured by: +/// 1. Each element has its own lock (`ElementGraphData.lock`) +/// 2. Element initialization is atomic (write once, then read-only structure) +/// 3. Neighbor modifications use the per-element lock +struct GraphSegment { + data: Box<[UnsafeCell>]>, +} + +impl GraphSegment { + /// Create a new segment with the given size. 
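+    /// All slots start as `None` and are written at most once, when an element
+    /// is first stored via `set`.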
+ fn new(size: usize) -> Self { + let mut vec = Vec::with_capacity(size); + for _ in 0..size { + vec.push(UnsafeCell::new(None)); + } + Self { + data: vec.into_boxed_slice(), + } + } + + /// Get a reference to an element. + /// + /// # Safety + /// Caller must ensure no mutable reference exists to this element. + #[inline] + unsafe fn get(&self, index: usize) -> Option<&ElementGraphData> { + if index >= self.data.len() { + return None; + } + (*self.data[index].get()).as_ref() + } + + /// Set an element. + /// + /// # Safety + /// Caller must ensure no other thread is writing to this exact index. + #[inline] + unsafe fn set(&self, index: usize, value: ElementGraphData) { + if index < self.data.len() { + *self.data[index].get() = Some(value); + } + } + + /// Get the number of slots in this segment. + #[inline] + fn len(&self) -> usize { + self.data.len() + } +} + +// Safety: GraphSegment is safe to share across threads because: +// 1. The Box<[UnsafeCell<...>]> is never reallocated after creation +// 2. Individual elements use UnsafeCell for interior mutability +// 3. Thread safety is ensured by callers using per-element locks +unsafe impl Send for GraphSegment {} +unsafe impl Sync for GraphSegment {} + +/// A concurrent graph structure for HNSW. +/// +/// This structure allows lock-free reads and writes to graph elements. +/// Growth (adding new segments) uses a read-write lock but this is rare +/// since segments are large (4096 elements by default). +pub struct ConcurrentGraph { + /// Segments of graph elements. + /// RwLock is only acquired for growth (very rare). + segments: RwLock>, + /// Number of elements per segment. + segment_size: usize, + /// Total number of initialized elements (approximate, may be slightly stale). + len: AtomicUsize, + /// Atomic flags for each element (DELETE_MARK, IN_PROCESS). + /// Stored separately for O(1) access without acquiring segment lock. + /// This matches the C++ idToMetaData approach. + element_flags: RwLock>, +} + +impl ConcurrentGraph { + /// Create a new concurrent graph with initial capacity. + pub fn new(initial_capacity: usize) -> Self { + let segment_size = SEGMENT_SIZE; + let num_segments = initial_capacity.div_ceil(segment_size).max(1); + + let segments: Vec<_> = (0..num_segments) + .map(|_| GraphSegment::new(segment_size)) + .collect(); + + // Pre-allocate flags array for initial capacity + let flags: Vec = (0..initial_capacity) + .map(|_| AtomicU8::new(0)) + .collect(); + + Self { + segments: RwLock::new(segments), + segment_size, + len: AtomicUsize::new(0), + element_flags: RwLock::new(flags), + } + } + + /// Get the segment and offset for an ID. + #[inline] + fn id_to_indices(&self, id: IdType) -> (usize, usize) { + let id = id as usize; + (id / self.segment_size, id % self.segment_size) + } + + /// Get a reference to an element by ID. + /// + /// Returns `None` if the ID is out of bounds or the element is not initialized. + #[inline] + pub fn get(&self, id: IdType) -> Option<&ElementGraphData> { + let (seg_idx, offset) = self.id_to_indices(id); + let segments = self.segments.read(); + if seg_idx >= segments.len() { + return None; + } + // Safety: We hold the read lock on segments, ensuring the segment exists. + // The returned reference is safe because: + // 1. Segments are never removed or reallocated + // 2. Individual elements use their own lock for modifications + // 3. We're returning an immutable reference to the ElementGraphData + unsafe { + // Transmute lifetime to 'static then back to our lifetime. 
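+        // (i.e., the returned borrow is detached from the `segments` read-guard
+        // taken above.)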
+ // This is safe because segments are never removed and elements + // are never moved once created. + let segment = &segments[seg_idx]; + let result = segment.get(offset); + std::mem::transmute::, Option<&ElementGraphData>>(result) + } + } + + /// Set an element at the given ID. + /// + /// This method ensures the graph has capacity for the ID before setting. + /// It's safe to call from multiple threads for different IDs. + pub fn set(&self, id: IdType, value: ElementGraphData) { + let (seg_idx, offset) = self.id_to_indices(id); + + // Ensure we have enough segments + self.ensure_capacity(id); + + // Set the element + let segments = self.segments.read(); + // Safety: We've ensured capacity above, and each ID is only written once + // during insertion. The ElementGraphData's own lock handles concurrent + // modifications to neighbors. + unsafe { + segments[seg_idx].set(offset, value); + } + + // Update length (best-effort, may be slightly stale) + let id_plus_one = (id as usize) + 1; + loop { + let current = self.len.load(Ordering::Relaxed); + if id_plus_one <= current { + break; + } + if self + .len + .compare_exchange_weak(current, id_plus_one, Ordering::Release, Ordering::Relaxed) + .is_ok() + { + break; + } + } + } + + /// Ensure the graph has capacity for the given ID. + fn ensure_capacity(&self, id: IdType) { + let (seg_idx, _) = self.id_to_indices(id); + + // Fast path - check with read lock + { + let segments = self.segments.read(); + if seg_idx < segments.len() { + return; + } + } + + // Slow path - need to grow + let mut segments = self.segments.write(); + // Double-check after acquiring write lock + while seg_idx >= segments.len() { + segments.push(GraphSegment::new(self.segment_size)); + } + } + + /// Ensure the flags array has capacity for the given ID. + fn ensure_flags_capacity(&self, id: IdType) { + let id_usize = id as usize; + + // Fast path - check with read lock + { + let flags = self.element_flags.read(); + if id_usize < flags.len() { + return; + } + } + + // Slow path - need to grow + let mut flags = self.element_flags.write(); + // Double-check after acquiring write lock + let needed = id_usize + 1; + let current_len = flags.len(); + if needed > current_len { + // Grow by at least one segment worth + let new_capacity = needed.max(current_len + self.segment_size); + let additional = new_capacity - current_len; + flags.reserve(additional); + for _ in 0..additional { + flags.push(AtomicU8::new(0)); + } + } + } + + /// Check if an element is marked as deleted (O(1) lock-free after initial read lock). + #[inline] + pub fn is_marked_deleted(&self, id: IdType) -> bool { + let id_usize = id as usize; + let flags = self.element_flags.read(); + if id_usize < flags.len() { + flags[id_usize].load(Ordering::Acquire) & flags::DELETE_MARK != 0 + } else { + false + } + } + + /// Mark an element as deleted atomically. + pub fn mark_deleted(&self, id: IdType) { + self.ensure_flags_capacity(id); + let flags = self.element_flags.read(); + let id_usize = id as usize; + if id_usize < flags.len() { + flags[id_usize].fetch_or(flags::DELETE_MARK, Ordering::Release); + } + } + + /// Get the approximate number of elements. + #[inline] + pub fn len(&self) -> usize { + self.len.load(Ordering::Acquire) + } + + /// Check if the graph is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Get the total capacity. + #[inline] + pub fn capacity(&self) -> usize { + self.segments.read().len() * self.segment_size + } + + /// Iterate over all initialized elements. 
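+    /// Visits IDs in ascending order up to the current length, skipping
+    /// uninitialized slots.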
+ /// + /// Note: This provides a snapshot view. Elements may be added during iteration. + pub fn iter(&self) -> impl Iterator { + let len = self.len(); + (0..len as IdType).filter_map(move |id| self.get(id).map(|e| (id, e))) + } + + /// Clear all elements from the graph. + /// + /// This resets the graph to empty state while keeping allocated memory. + pub fn clear(&self) { + // Reset length + self.len.store(0, Ordering::Release); + + // Clear all segments by resetting each slot to None + let segments = self.segments.read(); + for segment in segments.iter() { + for i in 0..segment.len() { + // Safety: We're the only writer during clear() + unsafe { + *segment.data[i].get() = None; + } + } + } + } + + /// Replace the graph contents with a new vector of elements. + /// + /// This is used during compaction to rebuild the graph. + pub fn replace(&self, new_elements: Vec>) { + // Reset length + self.len.store(0, Ordering::Release); + + // Clear existing segments + let segments = self.segments.read(); + for segment in segments.iter() { + for i in 0..segment.len() { + unsafe { + *segment.data[i].get() = None; + } + } + } + drop(segments); + + // Reset flags array (clear all flags) + { + let mut flags = self.element_flags.write(); + // Clear existing flags + for flag in flags.iter() { + flag.store(0, Ordering::Release); + } + // Resize if needed + let needed = new_elements.len(); + let current_len = flags.len(); + if needed > current_len { + let additional = needed - current_len; + flags.reserve(additional); + for _ in 0..additional { + flags.push(AtomicU8::new(0)); + } + } + } + + // Set new elements and update flags for deleted elements + for (id, element) in new_elements.into_iter().enumerate() { + if let Some(data) = element { + // If the element is marked as deleted in metadata, update flags + if data.meta.deleted { + let flags = self.element_flags.read(); + if id < flags.len() { + flags[id].fetch_or(flags::DELETE_MARK, Ordering::Release); + } + } + self.set(id as IdType, data); + } + } + } + + /// Get an iterator over indices with their elements (for compaction). + /// + /// Returns (index, &ElementGraphData) for each initialized slot. + pub fn indexed_iter(&self) -> impl Iterator { + let len = self.len(); + (0..len).filter_map(move |idx| self.get(idx as IdType).map(|e| (idx, e))) + } +} diff --git a/rust/vecsim/src/index/hnsw/graph.rs b/rust/vecsim/src/index/hnsw/graph.rs new file mode 100644 index 000000000..5ef24d8f3 --- /dev/null +++ b/rust/vecsim/src/index/hnsw/graph.rs @@ -0,0 +1,575 @@ +//! Graph data structures for HNSW index. +//! +//! This module provides the core data structures for storing the HNSW graph: +//! - `ElementMetaData`: Metadata for each element (label, level) +//! - `LevelLinks`: Neighbor connections for a single level +//! - `ElementGraphData`: Complete graph data for an element across all levels + +use crate::types::{IdType, LabelType, INVALID_ID}; +use parking_lot::Mutex; +use std::sync::atomic::{AtomicU32, Ordering}; + +/// Maximum number of neighbors at level 0. +pub const DEFAULT_M: usize = 16; + +/// Maximum number of neighbors at levels > 0. +pub const DEFAULT_M_MAX: usize = 16; + +/// Maximum number of neighbors at level 0 (typically 2*M). +pub const DEFAULT_M_MAX_0: usize = 32; + +/// Metadata for a single element in the index. +#[derive(Debug)] +pub struct ElementMetaData { + /// External label. + pub label: LabelType, + /// Maximum level this element appears in. + pub level: u8, + /// Whether this element has been deleted (tombstone). 
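+    /// Tombstoned elements keep their links and remain reachable until a
+    /// compaction pass rebuilds the graph (see `ConcurrentGraph::replace`).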
+ pub deleted: bool, +} + +impl ElementMetaData { + pub fn new(label: LabelType, level: u8) -> Self { + Self { + label, + level, + deleted: false, + } + } +} + +/// Neighbor connections for a single level. +/// +/// Uses a fixed-size array for cache efficiency. +pub struct LevelLinks { + /// Neighbor IDs. INVALID_ID marks empty slots. + neighbors: Vec, + /// Current number of neighbors. + count: AtomicU32, + /// Maximum capacity. + capacity: usize, +} + +impl LevelLinks { + /// Create new level links with given capacity. + pub fn new(capacity: usize) -> Self { + let neighbors: Vec<_> = (0..capacity).map(|_| AtomicU32::new(INVALID_ID)).collect(); + Self { + neighbors, + count: AtomicU32::new(0), + capacity, + } + } + + /// Get the number of neighbors. + #[inline] + pub fn len(&self) -> usize { + self.count.load(Ordering::Acquire) as usize + } + + /// Check if empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Get the capacity. + #[inline] + pub fn capacity(&self) -> usize { + self.capacity + } + + /// Get all neighbor IDs. + pub fn get_neighbors(&self) -> Vec { + let count = self.len(); + let mut result = Vec::with_capacity(count); + for i in 0..count { + let id = self.neighbors[i].load(Ordering::Acquire); + if id != INVALID_ID { + result.push(id); + } + } + result + } + + /// Iterate over neighbor IDs without allocation. + #[inline] + pub fn iter_neighbors(&self) -> impl Iterator + '_ { + let count = self.len(); // Cache the count once + let neighbors = &self.neighbors; + std::iter::from_fn({ + let mut idx = 0usize; + move || { + while idx < count { + let id = neighbors[idx].load(Ordering::Acquire); + idx += 1; + if id != INVALID_ID { + return Some(id); + } + } + None + } + }) + } + + /// Get neighbor at a specific index (0-based). + /// Returns None if index is out of bounds or if the slot contains INVALID_ID. + #[inline] + pub fn get_neighbor_at(&self, index: usize) -> Option { + if index >= self.len() { + return None; + } + let id = self.neighbors[index].load(Ordering::Acquire); + if id != INVALID_ID { + Some(id) + } else { + None + } + } + + /// Check if a neighbor exists in the list. + #[inline] + pub fn has_neighbor(&self, neighbor: IdType) -> bool { + let count = self.len(); + for i in 0..count { + if self.neighbors[i].load(Ordering::Acquire) == neighbor { + return true; + } + } + false + } + + /// Add a neighbor if there's space. + /// Returns true if added, false if full. + pub fn try_add(&self, neighbor: IdType) -> bool { + let mut current = self.count.load(Ordering::Acquire); + loop { + if current as usize >= self.capacity { + return false; + } + match self.count.compare_exchange_weak( + current, + current + 1, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + self.neighbors[current as usize].store(neighbor, Ordering::Release); + return true; + } + Err(c) => current = c, + } + } + } + + /// Set neighbors from a slice. Clears existing neighbors. + pub fn set_neighbors(&self, neighbors: &[IdType]) { + // Clear existing + for i in 0..self.capacity { + self.neighbors[i].store(INVALID_ID, Ordering::Release); + } + + let count = neighbors.len().min(self.capacity); + for (i, &n) in neighbors.iter().take(count).enumerate() { + self.neighbors[i].store(n, Ordering::Release); + } + self.count.store(count as u32, Ordering::Release); + } + + /// Remove a neighbor by ID. 
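+    /// Uses swap-remove: the last neighbor moves into the vacated slot, so
+    /// neighbor order is not preserved. Returns true if the neighbor was found.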
+ pub fn remove(&self, neighbor: IdType) -> bool { + let count = self.len(); + for i in 0..count { + if self.neighbors[i].load(Ordering::Acquire) == neighbor { + // Swap with last and decrement count + let last_idx = count - 1; + if i < last_idx { + let last = self.neighbors[last_idx].load(Ordering::Acquire); + self.neighbors[i].store(last, Ordering::Release); + } + self.neighbors[last_idx].store(INVALID_ID, Ordering::Release); + self.count.fetch_sub(1, Ordering::AcqRel); + return true; + } + } + false + } + + /// Check if a neighbor exists. + pub fn contains(&self, neighbor: IdType) -> bool { + let count = self.len(); + for i in 0..count { + if self.neighbors[i].load(Ordering::Acquire) == neighbor { + return true; + } + } + false + } +} + +impl std::fmt::Debug for LevelLinks { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LevelLinks") + .field("count", &self.len()) + .field("capacity", &self.capacity) + .field("neighbors", &self.get_neighbors()) + .finish() + } +} + +/// Complete graph data for an element across all levels. +pub struct ElementGraphData { + /// Metadata (label, level, deleted flag). + pub meta: ElementMetaData, + /// Neighbor links for each level (level 0 to max_level). + pub levels: Vec, + /// Lock for modifying this element's neighbors. + pub lock: Mutex<()>, +} + +impl ElementGraphData { + /// Create new graph data for an element. + pub fn new(label: LabelType, level: u8, m_max_0: usize, m_max: usize) -> Self { + let mut levels = Vec::with_capacity(level as usize + 1); + + // Level 0 has m_max_0 capacity + levels.push(LevelLinks::new(m_max_0)); + + // Higher levels have m_max capacity + for _ in 1..=level { + levels.push(LevelLinks::new(m_max)); + } + + Self { + meta: ElementMetaData::new(label, level), + levels, + lock: Mutex::new(()), + } + } + + /// Get neighbors at a specific level. + pub fn get_neighbors(&self, level: usize) -> Vec { + if level < self.levels.len() { + self.levels[level].get_neighbors() + } else { + Vec::new() + } + } + + /// Iterate over neighbors at a specific level without allocation. + #[inline] + pub fn iter_neighbors(&self, level: usize) -> impl Iterator + '_ { + let (level_links, count) = if level < self.levels.len() { + let links = &self.levels[level]; + (Some(links), links.len()) + } else { + (None, 0) + }; + std::iter::from_fn({ + let mut idx = 0usize; + move || { + let links = level_links?; + while idx < count { + let id = links.neighbors[idx].load(Ordering::Acquire); + idx += 1; + if id != INVALID_ID { + return Some(id); + } + } + None + } + }) + } + + /// Get the number of neighbors at a specific level. + #[inline] + pub fn neighbor_count(&self, level: usize) -> usize { + if level < self.levels.len() { + self.levels[level].len() + } else { + 0 + } + } + + /// Get neighbor at a specific index within a level. + #[inline] + pub fn get_neighbor_at(&self, level: usize, index: usize) -> Option { + if level < self.levels.len() { + self.levels[level].get_neighbor_at(index) + } else { + None + } + } + + /// Check if a neighbor exists at a specific level. + #[inline] + pub fn has_neighbor(&self, level: usize, neighbor: IdType) -> bool { + if level < self.levels.len() { + self.levels[level].has_neighbor(neighbor) + } else { + false + } + } + + /// Try to add a neighbor at a specific level. + /// Returns true if added, false if the level is full. 
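+    /// Also returns false when `level` is above this element's top level.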
+ #[inline] + pub fn try_add_neighbor(&self, level: usize, neighbor: IdType) -> bool { + if level < self.levels.len() { + self.levels[level].try_add(neighbor) + } else { + false + } + } + + /// Set neighbors at a specific level. + pub fn set_neighbors(&self, level: usize, neighbors: &[IdType]) { + if level < self.levels.len() { + self.levels[level].set_neighbors(neighbors); + } + } + + /// Get the maximum level for this element. + #[inline] + pub fn max_level(&self) -> u8 { + self.meta.level + } +} + +impl std::fmt::Debug for ElementGraphData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ElementGraphData") + .field("meta", &self.meta) + .field("levels", &self.levels) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_level_links() { + let links = LevelLinks::new(4); + assert!(links.is_empty()); + + assert!(links.try_add(1)); + assert!(links.try_add(2)); + assert!(links.try_add(3)); + assert!(links.try_add(4)); + assert!(!links.try_add(5)); // Full + + assert_eq!(links.len(), 4); + assert!(links.contains(2)); + + assert!(links.remove(2)); + assert_eq!(links.len(), 3); + assert!(!links.contains(2)); + } + + #[test] + fn test_element_graph_data() { + let data = ElementGraphData::new(42, 2, 32, 16); + + assert_eq!(data.meta.label, 42); + assert_eq!(data.max_level(), 2); + assert_eq!(data.levels.len(), 3); + assert_eq!(data.levels[0].capacity(), 32); + assert_eq!(data.levels[1].capacity(), 16); + assert_eq!(data.levels[2].capacity(), 16); + } + + #[test] + fn test_level_links_empty() { + let links = LevelLinks::new(4); + assert!(links.is_empty()); + assert_eq!(links.len(), 0); + assert_eq!(links.get_neighbors(), Vec::::new()); + assert!(!links.contains(1)); + assert!(!links.remove(1)); + } + + #[test] + fn test_level_links_set_neighbors() { + let links = LevelLinks::new(4); + + links.set_neighbors(&[1, 2, 3]); + assert_eq!(links.len(), 3); + assert_eq!(links.get_neighbors(), vec![1, 2, 3]); + + // Overwrite with different neighbors + links.set_neighbors(&[10, 20]); + assert_eq!(links.len(), 2); + assert_eq!(links.get_neighbors(), vec![10, 20]); + + // Clear all + links.set_neighbors(&[]); + assert!(links.is_empty()); + } + + #[test] + fn test_level_links_set_neighbors_truncates() { + let links = LevelLinks::new(3); + + // Try to set more than capacity + links.set_neighbors(&[1, 2, 3, 4, 5, 6]); + assert_eq!(links.len(), 3); + assert_eq!(links.get_neighbors(), vec![1, 2, 3]); + } + + #[test] + fn test_level_links_remove_first() { + let links = LevelLinks::new(4); + links.set_neighbors(&[1, 2, 3, 4]); + + assert!(links.remove(1)); + assert_eq!(links.len(), 3); + assert!(!links.contains(1)); + // First element is now swapped with last + let neighbors = links.get_neighbors(); + assert!(neighbors.contains(&4)); + assert!(neighbors.contains(&2)); + assert!(neighbors.contains(&3)); + } + + #[test] + fn test_level_links_remove_last() { + let links = LevelLinks::new(4); + links.set_neighbors(&[1, 2, 3, 4]); + + assert!(links.remove(4)); + assert_eq!(links.len(), 3); + assert_eq!(links.get_neighbors(), vec![1, 2, 3]); + } + + #[test] + fn test_level_links_remove_middle() { + let links = LevelLinks::new(4); + links.set_neighbors(&[1, 2, 3, 4]); + + assert!(links.remove(2)); + assert_eq!(links.len(), 3); + // 2 is replaced by 4 (swap with last) + let neighbors = links.get_neighbors(); + assert!(neighbors.contains(&1)); + assert!(neighbors.contains(&4)); + assert!(neighbors.contains(&3)); + } + + #[test] + fn 
test_level_links_remove_nonexistent() { + let links = LevelLinks::new(4); + links.set_neighbors(&[1, 2, 3]); + + assert!(!links.remove(99)); + assert_eq!(links.len(), 3); + } + + #[test] + fn test_level_links_concurrent_add() { + use std::sync::Arc; + use std::thread; + + let links = Arc::new(LevelLinks::new(100)); + + let mut handles = vec![]; + for t in 0..4 { + let links_clone = Arc::clone(&links); + handles.push(thread::spawn(move || { + for i in 0..25 { + links_clone.try_add((t * 25 + i) as u32); + } + })); + } + + for handle in handles { + handle.join().unwrap(); + } + + assert_eq!(links.len(), 100); + } + + #[test] + fn test_element_metadata() { + let meta = ElementMetaData::new(123, 5); + assert_eq!(meta.label, 123); + assert_eq!(meta.level, 5); + assert!(!meta.deleted); + } + + #[test] + fn test_element_graph_data_level_zero_only() { + let data = ElementGraphData::new(1, 0, 32, 16); + + assert_eq!(data.meta.label, 1); + assert_eq!(data.max_level(), 0); + assert_eq!(data.levels.len(), 1); + assert_eq!(data.levels[0].capacity(), 32); + } + + #[test] + fn test_element_graph_data_get_set_neighbors() { + let data = ElementGraphData::new(1, 2, 32, 16); + + // Level 0 + data.set_neighbors(0, &[10, 20, 30]); + assert_eq!(data.get_neighbors(0), vec![10, 20, 30]); + + // Level 1 + data.set_neighbors(1, &[100, 200]); + assert_eq!(data.get_neighbors(1), vec![100, 200]); + + // Level 2 + data.set_neighbors(2, &[1000]); + assert_eq!(data.get_neighbors(2), vec![1000]); + + // Out of bounds level returns empty + assert_eq!(data.get_neighbors(5), Vec::::new()); + + // Setting out of bounds level does nothing + data.set_neighbors(5, &[1, 2, 3]); + assert_eq!(data.get_neighbors(5), Vec::::new()); + } + + #[test] + fn test_element_graph_data_debug() { + let data = ElementGraphData::new(42, 1, 4, 2); + data.set_neighbors(0, &[1, 2]); + data.set_neighbors(1, &[3]); + + let debug_str = format!("{:?}", data); + assert!(debug_str.contains("ElementGraphData")); + assert!(debug_str.contains("meta")); + assert!(debug_str.contains("levels")); + } + + #[test] + fn test_level_links_debug() { + let links = LevelLinks::new(4); + links.set_neighbors(&[1, 2, 3]); + + let debug_str = format!("{:?}", links); + assert!(debug_str.contains("LevelLinks")); + assert!(debug_str.contains("count")); + assert!(debug_str.contains("capacity")); + assert!(debug_str.contains("neighbors")); + } + + #[test] + fn test_element_graph_data_lock() { + let data = ElementGraphData::new(1, 0, 32, 16); + + // Test lock can be acquired and released + { + let _guard = data.lock.lock(); + // Lock held + } + // Lock released + { + let _guard2 = data.lock.lock(); + // Lock can be acquired again + } + } +} diff --git a/rust/vecsim/src/index/hnsw/mod.rs b/rust/vecsim/src/index/hnsw/mod.rs new file mode 100644 index 000000000..9cf47ffd5 --- /dev/null +++ b/rust/vecsim/src/index/hnsw/mod.rs @@ -0,0 +1,1236 @@ +//! HNSW (Hierarchical Navigable Small World) index implementation. +//! +//! HNSW is an approximate nearest neighbor algorithm that provides +//! logarithmic query complexity with high recall. It constructs a +//! multi-layer graph where each layer is a proximity graph. +//! +//! Key parameters: +//! - `M`: Maximum number of connections per element per layer +//! - `ef_construction`: Size of dynamic candidate list during construction +//! 
- `ef_runtime`: Size of dynamic candidate list during search (runtime) + +pub mod batch_iterator; +pub mod concurrent_graph; +pub mod graph; +pub mod multi; +pub mod search; +pub mod single; +pub mod visited; + +pub use batch_iterator::{HnswSingleBatchIterator, HnswMultiBatchIterator}; +/// Type alias for HNSW batch iterator. +pub type HnswBatchIterator<'a, T> = HnswSingleBatchIterator<'a, T>; +pub use graph::{ElementGraphData, DEFAULT_M, DEFAULT_M_MAX, DEFAULT_M_MAX_0}; +pub use multi::HnswMulti; +pub use single::{HnswSingle, HnswStats}; +pub use visited::{VisitedNodesHandler, VisitedNodesHandlerPool}; + +use crate::containers::DataBlocks; +use crate::distance::{create_distance_function, DistanceFunction, Metric}; +use crate::types::{DistanceType, IdType, LabelType, VectorElement, INVALID_ID}; +use concurrent_graph::ConcurrentGraph; +use rand::Rng; +use std::collections::HashMap; +use std::sync::atomic::{AtomicU32, Ordering}; + +// Profiling support +#[cfg(feature = "profile")] +use std::cell::RefCell; +#[cfg(feature = "profile")] +use std::time::Instant; + +#[cfg(feature = "profile")] +thread_local! { + pub static PROFILE_STATS: RefCell = RefCell::new(ProfileStats::default()); +} + +#[cfg(feature = "profile")] +#[derive(Default)] +pub struct ProfileStats { + pub search_layer_ns: u64, + pub select_neighbors_ns: u64, + pub add_links_ns: u64, + pub visited_pool_ns: u64, + pub greedy_search_ns: u64, + pub calls: u64, + // Detailed add_links breakdown + pub add_links_lock_ns: u64, + pub add_links_get_neighbors_ns: u64, + pub add_links_contains_ns: u64, + pub add_links_prune_ns: u64, + pub add_links_set_ns: u64, + pub add_links_prune_count: u64, +} + +#[cfg(feature = "profile")] +impl ProfileStats { + pub fn print_and_reset(&mut self) { + if self.calls > 0 { + let total = self.search_layer_ns + self.select_neighbors_ns + self.add_links_ns + + self.visited_pool_ns + self.greedy_search_ns; + println!("Profile stats ({} insertions):", self.calls); + println!(" search_layer: {:>8.2}ms ({:>5.1}%)", + self.search_layer_ns as f64 / 1_000_000.0, + 100.0 * self.search_layer_ns as f64 / total as f64); + println!(" select_neighbors: {:>8.2}ms ({:>5.1}%)", + self.select_neighbors_ns as f64 / 1_000_000.0, + 100.0 * self.select_neighbors_ns as f64 / total as f64); + println!(" add_links: {:>8.2}ms ({:>5.1}%)", + self.add_links_ns as f64 / 1_000_000.0, + 100.0 * self.add_links_ns as f64 / total as f64); + println!(" visited_pool: {:>8.2}ms ({:>5.1}%)", + self.visited_pool_ns as f64 / 1_000_000.0, + 100.0 * self.visited_pool_ns as f64 / total as f64); + println!(" greedy_search: {:>8.2}ms ({:>5.1}%)", + self.greedy_search_ns as f64 / 1_000_000.0, + 100.0 * self.greedy_search_ns as f64 / total as f64); + + // Detailed add_links breakdown + if self.add_links_ns > 0 { + println!("\n add_links breakdown:"); + println!(" lock: {:>8.2}ms ({:>5.1}%)", + self.add_links_lock_ns as f64 / 1_000_000.0, + 100.0 * self.add_links_lock_ns as f64 / self.add_links_ns as f64); + println!(" get_neighbors: {:>8.2}ms ({:>5.1}%)", + self.add_links_get_neighbors_ns as f64 / 1_000_000.0, + 100.0 * self.add_links_get_neighbors_ns as f64 / self.add_links_ns as f64); + println!(" contains: {:>8.2}ms ({:>5.1}%)", + self.add_links_contains_ns as f64 / 1_000_000.0, + 100.0 * self.add_links_contains_ns as f64 / self.add_links_ns as f64); + println!(" prune: {:>8.2}ms ({:>5.1}%) [{} calls]", + self.add_links_prune_ns as f64 / 1_000_000.0, + 100.0 * self.add_links_prune_ns as f64 / self.add_links_ns as f64, + self.add_links_prune_count); + 
println!(" set_neighbors: {:>8.2}ms ({:>5.1}%)", + self.add_links_set_ns as f64 / 1_000_000.0, + 100.0 * self.add_links_set_ns as f64 / self.add_links_ns as f64); + } + } + *self = ProfileStats::default(); + } +} + +/// Parameters for creating an HNSW index. +#[derive(Debug, Clone)] +pub struct HnswParams { + /// Vector dimension. + pub dim: usize, + /// Distance metric. + pub metric: Metric, + /// Maximum number of connections per element (default: 16). + pub m: usize, + /// Maximum connections at level 0 (default: 2*M). + pub m_max_0: usize, + /// Size of dynamic candidate list during construction. + pub ef_construction: usize, + /// Size of dynamic candidate list during search (default value). + pub ef_runtime: usize, + /// Initial capacity (number of vectors). + pub initial_capacity: usize, + /// Enable diverse neighbor selection heuristic. + pub enable_heuristic: bool, + /// Random seed for reproducible level generation (None = random). + pub seed: Option, +} + +impl HnswParams { + /// Create new parameters with required fields. + pub fn new(dim: usize, metric: Metric) -> Self { + Self { + dim, + metric, + m: DEFAULT_M, + m_max_0: DEFAULT_M_MAX_0, + ef_construction: 200, + ef_runtime: 10, + initial_capacity: 1024, + enable_heuristic: true, + seed: None, + } + } + + /// Set M parameter. + pub fn with_m(mut self, m: usize) -> Self { + self.m = m; + self.m_max_0 = m * 2; + self + } + + /// Set M_max_0 parameter independently (overrides 2*M default). + pub fn with_m_max_0(mut self, m_max_0: usize) -> Self { + self.m_max_0 = m_max_0; + self + } + + /// Set ef_construction. + pub fn with_ef_construction(mut self, ef: usize) -> Self { + self.ef_construction = ef; + self + } + + /// Set ef_runtime. + pub fn with_ef_runtime(mut self, ef: usize) -> Self { + self.ef_runtime = ef; + self + } + + /// Set initial capacity. + pub fn with_capacity(mut self, capacity: usize) -> Self { + self.initial_capacity = capacity; + self + } + + /// Enable/disable heuristic neighbor selection. + pub fn with_heuristic(mut self, enable: bool) -> Self { + self.enable_heuristic = enable; + self + } + + /// Set random seed for reproducible level generation. + /// + /// When set, the same sequence of insertions will produce + /// the same graph structure, useful for testing and debugging. + pub fn with_seed(mut self, seed: u64) -> Self { + self.seed = Some(seed); + self + } +} + +/// Core HNSW implementation shared between single and multi variants. +/// +/// This structure supports concurrent access for parallel insertion using +/// a lock-free graph structure. Only per-node locks are used for neighbor +/// modifications, matching the C++ implementation approach. +pub(crate) struct HnswCore { + /// Vector storage (thread-safe with interior mutability). + pub data: DataBlocks, + /// Graph structure for each element (lock-free concurrent access). + pub graph: ConcurrentGraph, + /// Distance function. + pub dist_fn: Box>, + /// Entry point to the graph (top level). + pub entry_point: AtomicU32, + /// Current maximum level in the graph. + pub max_level: AtomicU32, + /// Pool of visited handlers for concurrent searches. + pub visited_pool: VisitedNodesHandlerPool, + /// Parameters. + pub params: HnswParams, + /// Multiplier for random level generation (1/ln(M)). + pub level_mult: f64, + /// Random number generator for level selection. + rng: parking_lot::Mutex, +} + +impl HnswCore { + /// Create a new HNSW core. 
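+    ///
+    /// The level multiplier is 1/ln(M), giving the standard HNSW geometric
+    /// level distribution; a caller-provided seed makes level generation (and
+    /// therefore graph construction) reproducible.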
+ pub fn new(params: HnswParams) -> Self { + use rand::SeedableRng; + + let data = DataBlocks::new(params.dim, params.initial_capacity); + let dist_fn = create_distance_function(params.metric, params.dim); + let visited_pool = VisitedNodesHandlerPool::new(params.initial_capacity); + + // Level multiplier: 1/ln(M) + let level_mult = 1.0 / (params.m as f64).ln(); + + // Initialize RNG with seed if provided, otherwise use entropy + let rng = match params.seed { + Some(seed) => rand::rngs::StdRng::seed_from_u64(seed), + None => rand::rngs::StdRng::from_entropy(), + }; + + Self { + data, + graph: ConcurrentGraph::new(params.initial_capacity), + dist_fn, + entry_point: AtomicU32::new(INVALID_ID), + max_level: AtomicU32::new(0), + visited_pool, + level_mult, + rng: parking_lot::Mutex::new(rng), + params, + } + } + + /// Generate a random level for a new element. + pub fn generate_random_level(&self) -> u8 { + let mut rng = self.rng.lock(); + let r: f64 = rng.gen(); + let level = (-r.ln() * self.level_mult).floor() as u8; + level.min(32) // Cap at reasonable level + } + + /// Add a vector and return its internal ID. + /// + /// Returns `None` if the vector dimension doesn't match. + pub fn add_vector(&mut self, vector: &[T]) -> Option { + self.add_vector_concurrent(vector) + } + + /// Add a vector concurrently and return its internal ID. + /// + /// This method is thread-safe and can be called from multiple threads. + /// Returns `None` if the vector dimension doesn't match. + pub fn add_vector_concurrent(&self, vector: &[T]) -> Option { + let processed = self.dist_fn.preprocess(vector, self.params.dim); + self.data.add_concurrent(&processed) + } + + /// Get vector data by ID. + #[inline] + pub fn get_vector(&self, id: IdType) -> Option<&[T]> { + self.data.get(id) + } + + /// Insert a new element into the graph. + #[cfg(not(feature = "profile"))] + pub fn insert(&mut self, id: IdType, label: LabelType) { + self.insert_concurrent(id, label); + } + + /// Insert a new element into the graph (profiled version). + #[cfg(feature = "profile")] + pub fn insert(&mut self, id: IdType, label: LabelType) { + self.insert_concurrent(id, label); + PROFILE_STATS.with(|s| s.borrow_mut().calls += 1); + } + + /// Insert a new element into the graph concurrently. + /// + /// This method is thread-safe and can be called from multiple threads. + /// It uses fine-grained locking to allow concurrent insertions while + /// maintaining graph consistency. + #[cfg(not(feature = "profile"))] + pub fn insert_concurrent(&self, id: IdType, label: LabelType) { + self.insert_concurrent_impl(id, label); + } + + /// Insert a new element into the graph concurrently (profiled version). + #[cfg(feature = "profile")] + pub fn insert_concurrent(&self, id: IdType, label: LabelType) { + self.insert_concurrent_impl(id, label); + PROFILE_STATS.with(|s| s.borrow_mut().calls += 1); + } + + /// Concurrent insert implementation. 
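+    ///
+    /// Outline: draw a random level, publish the element's graph data, greedily
+    /// descend the layers above that level, then search and connect at each
+    /// level from min(level, current max) down to 0.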
+ fn insert_concurrent_impl(&self, id: IdType, label: LabelType) { + let level = self.generate_random_level(); + + // Create graph data for this element + let graph_data = ElementGraphData::new( + label, + level, + self.params.m_max_0, + self.params.m, + ); + + // Set the graph data (ConcurrentGraph handles capacity automatically) + let id_usize = id as usize; + self.graph.set(id, graph_data); + + // Update visited pool if needed + if id_usize >= self.visited_pool.current_capacity() { + self.visited_pool.resize(id_usize + 1024); + } + + let entry_point = self.entry_point.load(Ordering::Acquire); + + if entry_point == INVALID_ID { + // First element - use CAS to avoid race + if self.entry_point.compare_exchange( + INVALID_ID, id, + Ordering::AcqRel, Ordering::Relaxed + ).is_ok() { + self.max_level.store(level as u32, Ordering::Release); + return; + } + // Another thread beat us, continue with normal insertion + } + + // Get query vector + let query = match self.get_vector(id) { + Some(v) => v, + None => return, + }; + + // Search from entry point to find insertion point + let current_max = self.max_level.load(Ordering::Acquire) as usize; + let mut current_entry = self.entry_point.load(Ordering::Acquire); + + // Traverse upper layers with greedy search + #[cfg(feature = "profile")] + let greedy_start = Instant::now(); + + for l in (level as usize + 1..=current_max).rev() { + let (new_entry, _) = search::greedy_search( + current_entry, + query, + l, + &self.graph, + |id| self.data.get(id), + self.dist_fn.as_ref(), + self.params.dim, + ); + current_entry = new_entry; + } + + #[cfg(feature = "profile")] + PROFILE_STATS.with(|s| s.borrow_mut().greedy_search_ns += greedy_start.elapsed().as_nanos() as u64); + + // Insert at each level from min(level, max_level) down to 0 + let start_level = level.min(current_max as u8); + let mut entry_points = vec![(current_entry, self.compute_distance(current_entry, query))]; + + for l in (0..=start_level as usize).rev() { + #[cfg(feature = "profile")] + let pool_start = Instant::now(); + + let mut visited = self.visited_pool.get(); + visited.reset(); + + #[cfg(feature = "profile")] + PROFILE_STATS.with(|s| s.borrow_mut().visited_pool_ns += pool_start.elapsed().as_nanos() as u64); + + // Search this layer + #[cfg(feature = "profile")] + let search_start = Instant::now(); + + let neighbors = search::search_layer:: bool, _>( + &entry_points, + query, + l, + self.params.ef_construction, + &self.graph, + |id| self.data.get(id), + self.dist_fn.as_ref(), + self.params.dim, + &visited, + None, + ); + + #[cfg(feature = "profile")] + PROFILE_STATS.with(|s| s.borrow_mut().search_layer_ns += search_start.elapsed().as_nanos() as u64); + + // Select neighbors + #[cfg(feature = "profile")] + let select_start = Instant::now(); + + let m = if l == 0 { self.params.m_max_0 } else { self.params.m }; + let selected = if self.params.enable_heuristic { + search::select_neighbors_heuristic( + id, + &neighbors, + m, + |id| self.data.get(id), + self.dist_fn.as_ref(), + self.params.dim, + false, + true, + ) + } else { + search::select_neighbors_simple(&neighbors, m) + }; + + #[cfg(feature = "profile")] + PROFILE_STATS.with(|s| s.borrow_mut().select_neighbors_ns += select_start.elapsed().as_nanos() as u64); + + // Build selected neighbors with distances for mutually_connect_new_element + let selected_with_distances: Vec<(IdType, T::DistanceType)> = selected + .iter() + .filter_map(|&neighbor_id| { + // Find the distance from the search results + neighbors.iter() + .find(|&&(nid, _)| nid 
== neighbor_id) + .map(|&(nid, dist)| (nid, dist)) + }) + .collect(); + + // Mutually connect new element with its neighbors using lock ordering + #[cfg(feature = "profile")] + let links_start = Instant::now(); + + self.mutually_connect_new_element(id, &selected_with_distances, l); + + #[cfg(feature = "profile")] + PROFILE_STATS.with(|s| s.borrow_mut().add_links_ns += links_start.elapsed().as_nanos() as u64); + + // Use neighbors as entry points for next level + if !neighbors.is_empty() { + entry_points = neighbors; + } + } + + // Update entry point and max level if needed using CAS + self.maybe_update_entry_point(id, level); + } + + /// Maybe update entry point and max level using CAS for thread safety. + fn maybe_update_entry_point(&self, new_id: IdType, new_level: u8) { + loop { + let current_max = self.max_level.load(Ordering::Acquire); + if (new_level as u32) <= current_max { + return; + } + if self.max_level.compare_exchange( + current_max, new_level as u32, + Ordering::AcqRel, Ordering::Relaxed + ).is_ok() { + self.entry_point.store(new_id, Ordering::Release); + return; + } + } + } + + /// Mutually connect a new element with its selected neighbors at a given level. + /// + /// This method implements the C++ lock ordering strategy: + /// 1. Lock nodes in sorted ID order to prevent deadlocks + /// 2. Fast path: if neighbor has space, append link directly + /// 3. Slow path: if neighbor is full, call revisit_neighbor_connections + /// + /// This is the key function for achieving true parallel insertion with high recall. + fn mutually_connect_new_element( + &self, + new_node_id: IdType, + selected_neighbors: &[(IdType, T::DistanceType)], + level: usize, + ) { + let max_m = if level == 0 { self.params.m_max_0 } else { self.params.m }; + + for &(neighbor_id, dist_to_neighbor) in selected_neighbors { + // Get both elements + let (new_element, neighbor_element) = match ( + self.graph.get(new_node_id), + self.graph.get(neighbor_id), + ) { + (Some(n), Some(e)) => (n, e), + _ => continue, + }; + + // Check level bounds + if level >= new_element.levels.len() || level >= neighbor_element.levels.len() { + continue; + } + + // Lock in sorted ID order to prevent deadlocks (C++ pattern) + let (_lock1, _lock2) = if new_node_id < neighbor_id { + (new_element.lock.lock(), neighbor_element.lock.lock()) + } else { + (neighbor_element.lock.lock(), new_element.lock.lock()) + }; + + // Check if new node can still add neighbors (may have changed between iterations) + // Use neighbor_count to avoid Vec allocation + if new_element.neighbor_count(level) >= max_m { + // New node is full, skip remaining neighbors + break; + } + + // Check if connection already exists (no allocation) + if new_element.has_neighbor(level, neighbor_id) { + continue; + } + + // Check if neighbor has space for the new node (no allocation) + if neighbor_element.neighbor_count(level) < max_m { + // Fast path: neighbor has space, make bidirectional connection + // Use try_add_neighbor for O(1) append instead of get→push→set + new_element.try_add_neighbor(level, neighbor_id); + neighbor_element.try_add_neighbor(level, new_node_id); + } else { + // Slow path: neighbor is full, need to revisit its connections + // First add new_node -> neighbor (new node has space, we checked above) + new_element.try_add_neighbor(level, neighbor_id); + + // Now revisit neighbor's connections to possibly include new_node + self.revisit_neighbor_connections_locked( + new_node_id, + neighbor_id, + dist_to_neighbor, + neighbor_element, + level, + max_m, + ); + } 
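+            // Both per-node locks drop at the end of this iteration, before the
+            // next neighbor is processed.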
+ } + } + + /// Revisit a neighbor's connections when it's at capacity. + /// + /// This is called while holding both the new_node and neighbor locks. + /// It re-evaluates which neighbors the neighbor should keep, possibly + /// including the new node if it's closer than existing neighbors. + fn revisit_neighbor_connections_locked( + &self, + new_node_id: IdType, + neighbor_id: IdType, + dist_to_new_node: T::DistanceType, + neighbor_element: &ElementGraphData, + level: usize, + max_m: usize, + ) { + // Collect all candidates: existing neighbors + new node + let neighbor_data = match self.data.get(neighbor_id) { + Some(d) => d, + None => return, + }; + + let current_neighbors = neighbor_element.get_neighbors(level); + let mut candidates: Vec<(IdType, T::DistanceType)> = Vec::with_capacity(current_neighbors.len() + 1); + + // Add new node as candidate with its pre-computed distance + candidates.push((new_node_id, dist_to_new_node)); + + // Add existing neighbors with their distances to the neighbor + for &existing_neighbor_id in ¤t_neighbors { + if let Some(existing_data) = self.data.get(existing_neighbor_id) { + let dist = self.dist_fn.compute(existing_data, neighbor_data, self.params.dim); + candidates.push((existing_neighbor_id, dist)); + } + } + + // Select best M neighbors using simple selection (M closest) + if candidates.len() > max_m { + candidates.select_nth_unstable_by(max_m - 1, |a, b| { + a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) + }); + candidates.truncate(max_m); + } + + let selected: Vec = candidates.iter().map(|&(id, _)| id).collect(); + + // Update neighbor's connections + neighbor_element.set_neighbors(level, &selected); + } + + /// Add a bidirectional link between two elements at a given level (concurrent version). + /// + /// This method uses the per-node lock in ElementGraphData for thread safety. + /// NOTE: This is the legacy single-lock version, kept for compatibility. + /// For parallel insertion, use mutually_connect_new_element instead. 
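+    /// Despite its name, it locks and updates only `from`'s neighbor list; the
+    /// reverse edge requires a second call with the arguments swapped.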
+ #[allow(dead_code)] + fn add_bidirectional_link_concurrent(&self, from: IdType, to: IdType, level: usize) { + if let Some(from_element) = self.graph.get(from) { + if level < from_element.levels.len() { + #[cfg(feature = "profile")] + let lock_start = Instant::now(); + + let _lock = from_element.lock.lock(); + + #[cfg(feature = "profile")] + PROFILE_STATS.with(|s| s.borrow_mut().add_links_lock_ns += lock_start.elapsed().as_nanos() as u64); + + #[cfg(feature = "profile")] + let get_start = Instant::now(); + + let mut current_neighbors = from_element.get_neighbors(level); + + #[cfg(feature = "profile")] + PROFILE_STATS.with(|s| s.borrow_mut().add_links_get_neighbors_ns += get_start.elapsed().as_nanos() as u64); + + #[cfg(feature = "profile")] + let contains_start = Instant::now(); + + if current_neighbors.contains(&to) { + #[cfg(feature = "profile")] + PROFILE_STATS.with(|s| s.borrow_mut().add_links_contains_ns += contains_start.elapsed().as_nanos() as u64); + return; + } + + #[cfg(feature = "profile")] + PROFILE_STATS.with(|s| s.borrow_mut().add_links_contains_ns += contains_start.elapsed().as_nanos() as u64); + + current_neighbors.push(to); + + // Check if we need to prune + let m = if level == 0 { self.params.m_max_0 } else { self.params.m }; + + if current_neighbors.len() > m { + #[cfg(feature = "profile")] + let prune_start = Instant::now(); + + // Need to select best neighbors - use simple selection (M closest) + // This is faster than the heuristic and still maintains good graph quality + let query = match self.data.get(from) { + Some(v) => v, + None => return, + }; + + let mut candidates: Vec<_> = current_neighbors + .iter() + .filter_map(|&n| { + self.data.get(n).map(|data| { + let dist = self.dist_fn.compute(data, query, self.params.dim); + (n, dist) + }) + }) + .collect(); + + // Use partial sort to find M closest - O(n) instead of O(n log n) + // select_nth_unstable partitions so elements [0..m] are the m smallest + if candidates.len() > m { + candidates.select_nth_unstable_by(m - 1, |a, b| { + a.1.partial_cmp(&b.1).unwrap() + }); + candidates.truncate(m); + } + let selected: Vec<_> = candidates.iter().map(|&(id, _)| id).collect(); + + #[cfg(feature = "profile")] + { + PROFILE_STATS.with(|s| { + let mut stats = s.borrow_mut(); + stats.add_links_prune_ns += prune_start.elapsed().as_nanos() as u64; + stats.add_links_prune_count += 1; + }); + } + + #[cfg(feature = "profile")] + let set_start = Instant::now(); + + from_element.set_neighbors(level, &selected); + + #[cfg(feature = "profile")] + PROFILE_STATS.with(|s| s.borrow_mut().add_links_set_ns += set_start.elapsed().as_nanos() as u64); + } else { + #[cfg(feature = "profile")] + let set_start = Instant::now(); + + from_element.set_neighbors(level, ¤t_neighbors); + + #[cfg(feature = "profile")] + PROFILE_STATS.with(|s| s.borrow_mut().add_links_set_ns += set_start.elapsed().as_nanos() as u64); + } + } + } + } + + /// Compute distance between two elements. + /// Uses get_unchecked_deleted for faster access during search. + #[inline] + fn compute_distance(&self, id: IdType, query: &[T]) -> T::DistanceType { + if let Some(data) = self.data.get_unchecked_deleted(id) { + self.dist_fn.compute(data, query, self.params.dim) + } else { + T::DistanceType::infinity() + } + } + + /// Mark an element as deleted. + pub fn mark_deleted(&mut self, id: IdType) { + self.mark_deleted_concurrent(id); + } + + /// Mark an element as deleted (concurrent version). 
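+    ///
+    /// Sets the atomic DELETE_MARK flag first, so searches can skip the element
+    /// with an O(1) check, then updates the element's metadata under its lock.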
+    /// Mark an element as deleted (concurrent version).
+    pub fn mark_deleted_concurrent(&self, id: IdType) {
+        // Update the flags array first (O(1) atomic operation for fast deleted checks)
+        self.graph.mark_deleted(id);
+
+        // Also update the element's metadata for consistency
+        if let Some(element) = self.graph.get(id) {
+            // ElementMetaData.deleted is not atomic, but this is a best-effort
+            // tombstone - reads may see stale state briefly, which is acceptable
+            // for deletion semantics. Use a separate lock if stronger guarantees are needed.
+            let _lock = element.lock.lock();
+            // SAFETY: We hold the element's lock, so this is the only writer
+            unsafe {
+                let meta = &element.meta as *const _ as *mut graph::ElementMetaData;
+                (*meta).deleted = true;
+            }
+        }
+        self.data.mark_deleted_concurrent(id);
+    }
+
+    /// Search for nearest neighbors.
+    pub fn search(
+        &self,
+        query: &[T],
+        k: usize,
+        ef: usize,
+        filter: Option<&dyn Fn(IdType) -> bool>,
+    ) -> Vec<(IdType, T::DistanceType)> {
+        let mut visited = self.visited_pool.get();
+        visited.reset();
+        self.search_with_visited(query, k, ef, filter, &mut visited)
+    }
+
+    /// Search for nearest neighbors using a provided visited handler.
+    ///
+    /// This variant allows callers to provide their own visited handler,
+    /// which is useful for batch queries to avoid pool contention.
+    pub fn search_with_visited(
+        &self,
+        query: &[T],
+        k: usize,
+        ef: usize,
+        filter: Option<&dyn Fn(IdType) -> bool>,
+        visited: &mut VisitedNodesHandler,
+    ) -> Vec<(IdType, T::DistanceType)> {
+        let entry_point = self.entry_point.load(Ordering::Acquire);
+        if entry_point == INVALID_ID {
+            return Vec::new();
+        }
+
+        let current_max = self.max_level.load(Ordering::Acquire) as usize;
+        let mut current_entry = entry_point;
+
+        // Greedy search through upper layers.
+        // Use get_unchecked_deleted for faster access during search.
+        for l in (1..=current_max).rev() {
+            let (new_entry, _) = search::greedy_search(
+                current_entry,
+                query,
+                l,
+                &self.graph,
+                |id| self.data.get_unchecked_deleted(id),
+                self.dist_fn.as_ref(),
+                self.params.dim,
+            );
+            current_entry = new_entry;
+        }
+
+        let entry_dist = self.compute_distance(current_entry, query);
+        #[cfg(debug_assertions)]
+        eprintln!("KNN search: entry_point={}, entry_dist={}, k={}, ef={}",
+            current_entry, entry_dist.to_f64(), k, ef);
+        let entry_points = vec![(current_entry, entry_dist)];
+
+        let results = if let Some(f) = filter {
+            search::search_layer(
+                &entry_points,
+                query,
+                0,
+                ef.max(k),
+                &self.graph,
+                |id| self.data.get_unchecked_deleted(id),
+                self.dist_fn.as_ref(),
+                self.params.dim,
+                visited,
+                Some(f),
+            )
+        } else {
+            search::search_layer::<fn(IdType) -> bool, _>(
+                &entry_points,
+                query,
+                0,
+                ef.max(k),
+                &self.graph,
+                |id| self.data.get_unchecked_deleted(id),
+                self.dist_fn.as_ref(),
+                self.params.dim,
+                visited,
+                None,
+            )
+        };
+
+        // Return top k
+        results.into_iter().take(k).collect()
+    }
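+    // Illustrative batch-query sketch (hypothetical caller code, not part of
+    // this API): one handler can serve many queries if it is reset between them:
+    //
+    //     let mut visited = core.visited_pool.get();
+    //     for q in &queries {
+    //         visited.reset();
+    //         let hits = core.search_with_visited(q, k, ef, None, &mut visited);
+    //     }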
+    /// Range search for all vectors within a radius.
+    ///
+    /// Uses the epsilon-neighborhood search algorithm for efficient graph traversal.
+    /// The search terminates early when no candidates are within the dynamic range.
+    ///
+    /// # Arguments
+    /// * `query` - Query vector
+    /// * `radius` - Maximum distance for results
+    /// * `epsilon` - Search boundary expansion factor (e.g., 0.01 = 1% expansion)
+    /// * `filter` - Optional filter function
+    pub fn search_range(
+        &self,
+        query: &[T],
+        radius: T::DistanceType,
+        epsilon: f64,
+        filter: Option<&dyn Fn(IdType) -> bool>,
+    ) -> Vec<(IdType, T::DistanceType)> {
+        let entry_point = self.entry_point.load(Ordering::Acquire);
+        if entry_point == INVALID_ID {
+            return Vec::new();
+        }
+
+        let current_max = self.max_level.load(Ordering::Acquire) as usize;
+        let mut current_entry = entry_point;
+
+        // Greedy search through upper layers to find the entry point for layer 0.
+        // Use get_unchecked_deleted for faster access during search.
+        for l in (1..=current_max).rev() {
+            let (new_entry, _) = search::greedy_search(
+                current_entry,
+                query,
+                l,
+                &self.graph,
+                |id| self.data.get_unchecked_deleted(id),
+                self.dist_fn.as_ref(),
+                self.params.dim,
+            );
+            current_entry = new_entry;
+        }
+
+        // Range search on layer 0
+        let mut visited = self.visited_pool.get();
+        visited.reset();
+
+        let entry_dist = self.compute_distance(current_entry, query);
+        let entry_points = vec![(current_entry, entry_dist)];
+
+        // Use get_unchecked_deleted for faster access during search
+        // (skips the Mutex lock on free_slots)
+        if let Some(f) = filter {
+            search::search_layer_range(
+                &entry_points,
+                query,
+                0,
+                radius,
+                epsilon,
+                &self.graph,
+                |id| self.data.get_unchecked_deleted(id),
+                self.dist_fn.as_ref(),
+                self.params.dim,
+                &visited,
+                Some(f),
+            )
+        } else {
+            search::search_layer_range::<fn(IdType) -> bool, _>(
+                &entry_points,
+                query,
+                0,
+                radius,
+                epsilon,
+                &self.graph,
+                |id| self.data.get_unchecked_deleted(id),
+                self.dist_fn.as_ref(),
+                self.params.dim,
+                &visited,
+                None,
+            )
+        }
+    }
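+    // Worked example for `search_range`'s epsilon (assuming the usual rule,
+    // traversal boundary = radius * (1 + epsilon)): with radius = 25.0 and
+    // epsilon = 0.01 the traversal keeps expanding through candidates up to
+    // distance 25.25, but only nodes with distance <= 25.0 are reported;
+    // a larger epsilon explores more of the graph for the same result set.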
+    /// Search for k nearest unique labels (for multi-value indices).
+    ///
+    /// This method does label-aware search during graph traversal,
+    /// ensuring we find the k closest unique labels rather than the
+    /// k closest vectors (which might share labels).
+    pub fn search_multi(
+        &self,
+        query: &[T],
+        k: usize,
+        ef: usize,
+        id_to_label: &HashMap<IdType, LabelType>,
+        filter: Option<&dyn Fn(LabelType) -> bool>,
+    ) -> Vec<(LabelType, T::DistanceType)> {
+        let entry_point = self.entry_point.load(Ordering::Acquire);
+        if entry_point == INVALID_ID {
+            return Vec::new();
+        }
+
+        let current_max = self.max_level.load(Ordering::Acquire) as usize;
+        let mut current_entry = entry_point;
+
+        // Greedy search through upper layers.
+        // Use get_unchecked_deleted for faster access during search.
+        for l in (1..=current_max).rev() {
+            let (new_entry, _) = search::greedy_search(
+                current_entry,
+                query,
+                l,
+                &self.graph,
+                |id| self.data.get_unchecked_deleted(id),
+                self.dist_fn.as_ref(),
+                self.params.dim,
+            );
+            current_entry = new_entry;
+        }
+
+        // Search layer 0 with label-aware search
+        let mut visited = self.visited_pool.get();
+        visited.reset();
+
+        let entry_dist = self.compute_distance(current_entry, query);
+        let entry_points = vec![(current_entry, entry_dist)];
+
+        if let Some(f) = filter {
+            search::search_layer_multi(
+                &entry_points,
+                query,
+                0,
+                k,
+                ef.max(k),
+                &self.graph,
+                |id| self.data.get_unchecked_deleted(id),
+                self.dist_fn.as_ref(),
+                self.params.dim,
+                &visited,
+                id_to_label,
+                Some(f),
+            )
+        } else {
+            search::search_layer_multi::<fn(LabelType) -> bool, _>(
+                &entry_points,
+                query,
+                0,
+                k,
+                ef.max(k),
+                &self.graph,
+                |id| self.data.get_unchecked_deleted(id),
+                self.dist_fn.as_ref(),
+                self.params.dim,
+                &visited,
+                id_to_label,
+                None,
+            )
+        }
+    }
+
+    // =========================================================================
+    // Batch Construction Methods
+    // =========================================================================
+
+    /// Generate N random levels in batch (reduces lock contention on the RNG).
+    pub fn generate_random_levels_batch(&self, n: usize) -> Vec<u8> {
+        let mut rng = self.rng.lock();
+        (0..n)
+            .map(|_| {
+                let r: f64 = rng.gen();
+                let level = (-r.ln() * self.level_mult).floor() as u8;
+                level.min(32)
+            })
+            .collect()
+    }
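+    // The sampling above draws levels geometrically: for uniform r in (0, 1),
+    // P(level >= l) = P(-ln(r) * level_mult >= l) = exp(-l / level_mult).
+    // With the standard HNSW choice level_mult = 1 / ln(M), that is M^(-l),
+    // so each extra layer holds roughly a 1/M fraction of the nodes below it.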
+    /// Batch insert multiple elements.
+    ///
+    /// NOTE: True parallel HNSW construction with high recall is fundamentally limited
+    /// because each new vector must search the existing graph to find neighbors.
+    /// This method provides correct recall (~99%) but limited speedup (~1.0x).
+    ///
+    /// For significant speedup (3-4x), graph partitioning with merge would be needed,
+    /// which requires k-means clustering and is a much more complex implementation.
+    ///
+    /// # Arguments
+    /// * `assignments` - Vec of (id, level, label) tuples for elements already added to data storage
+    pub fn insert_batch(&self, assignments: &[(IdType, u8, LabelType)]) {
+        // Simply insert each element using the concurrent method.
+        // This maintains correct recall and uses lock ordering for thread safety.
+        for &(id, level, label) in assignments {
+            // Create graph node
+            let graph_data = ElementGraphData::new(
+                label,
+                level,
+                self.params.m_max_0,
+                self.params.m,
+            );
+            self.graph.set(id, graph_data);
+
+            // Update visited pool if needed
+            let id_usize = id as usize;
+            if id_usize >= self.visited_pool.current_capacity() {
+                self.visited_pool.resize(id_usize + 1024);
+            }
+
+            // Insert into graph using the concurrent method
+            self.insert_concurrent_impl(id, label);
+        }
+    }
+
+    /// Search for neighbors at a specific layer for batch construction.
+    ///
+    /// This is a read-only operation that can run in parallel across vectors.
+    fn search_neighbors_for_layer(
+        &self,
+        id: IdType,
+        layer: usize,
+    ) -> Vec<(IdType, T::DistanceType)> {
+        let query = match self.get_vector(id) {
+            Some(v) => v,
+            None => return Vec::new(),
+        };
+
+        let entry_point = self.entry_point.load(Ordering::Acquire);
+        if entry_point == INVALID_ID || entry_point == id {
+            return Vec::new();
+        }
+
+        let current_max = self.max_level.load(Ordering::Acquire) as usize;
+
+        // Greedy search from the entry point down to the target layer.
+        // Use get_unchecked_deleted for faster access during search.
+        let mut current_entry = entry_point;
+        for l in (layer + 1..=current_max).rev() {
+            let (new_entry, _) = search::greedy_search(
+                current_entry,
+                query,
+                l,
+                &self.graph,
+                |id| self.data.get_unchecked_deleted(id),
+                self.dist_fn.as_ref(),
+                self.params.dim,
+            );
+            current_entry = new_entry;
+        }
+
+        // Search at target layer
+        let mut visited = self.visited_pool.get();
+        visited.reset();
+
+        let entry_dist = self.compute_distance(current_entry, query);
+        let entry_points = vec![(current_entry, entry_dist)];
+
+        let neighbors = search::search_layer::<fn(IdType) -> bool, _>(
+            &entry_points,
+            query,
+            layer,
+            self.params.ef_construction,
+            &self.graph,
+            |nid| self.data.get_unchecked_deleted(nid),
+            self.dist_fn.as_ref(),
+            self.params.dim,
+            &visited,
+            None,
+        );
+
+        // Select best neighbors using the heuristic
+        let m = if layer == 0 { self.params.m_max_0 } else { self.params.m };
+        if self.params.enable_heuristic {
+            search::select_neighbors_heuristic(
+                id,
+                &neighbors,
+                m,
+                |nid| self.data.get_unchecked_deleted(nid),
+                self.dist_fn.as_ref(),
+                self.params.dim,
+                false,
+                true,
+            )
+            .into_iter()
+            .filter_map(|nid| {
+                neighbors.iter().find(|&&(n, _)| n == nid).map(|&(n, d)| (n, d))
+            })
+            .collect()
+        } else {
+            search::select_neighbors_simple(&neighbors, m)
+                .into_iter()
+                .filter_map(|nid| {
+                    neighbors.iter().find(|&&(n, _)| n == nid).map(|&(n, d)| (n, d))
+                })
+                .collect()
+        }
+    }
+
+    /// Connect multiple nodes with their neighbors in batch.
+    ///
+    /// This method groups connections to minimize lock acquisitions and
+    /// uses lock ordering to prevent deadlocks.
+    fn batch_connect_neighbors(
+        &self,
+        candidates: &[(IdType, Vec<(IdType, T::DistanceType)>)],
+        layer: usize,
+    ) {
+        let max_m = if layer == 0 { self.params.m_max_0 } else { self.params.m };
+
+        for &(id, ref neighbors) in candidates {
+            if neighbors.is_empty() {
+                continue;
+            }
+
+            // Set outgoing edges for this node
+            if let Some(element) = self.graph.get(id) {
+                if layer < element.levels.len() {
+                    let _lock = element.lock.lock();
+                    let neighbor_ids: Vec<IdType> = neighbors.iter().map(|&(n, _)| n).collect();
+                    element.set_neighbors(layer, &neighbor_ids);
+                }
+            }
+
+            // Add reverse edges (with lock ordering)
+            for &(neighbor_id, dist) in neighbors {
+                self.add_reverse_edge_batch(neighbor_id, id, dist, layer, max_m);
+            }
+        }
+    }
+    /// Add a reverse edge during batch construction.
+    ///
+    /// This adds `from_id` to `to_id`'s neighbor list, with pruning if needed.
+    fn add_reverse_edge_batch(
+        &self,
+        to_id: IdType,
+        from_id: IdType,
+        _dist: T::DistanceType,
+        layer: usize,
+        max_m: usize,
+    ) {
+        if let Some(to_element) = self.graph.get(to_id) {
+            if layer >= to_element.levels.len() {
+                return;
+            }
+
+            let _lock = to_element.lock.lock();
+            let mut current_neighbors = to_element.get_neighbors(layer);
+
+            // Skip if already connected
+            if current_neighbors.contains(&from_id) {
+                return;
+            }
+
+            current_neighbors.push(from_id);
+
+            // Prune if over capacity
+            if current_neighbors.len() > max_m {
+                let to_data = match self.data.get(to_id) {
+                    Some(d) => d,
+                    None => return,
+                };
+
+                let mut candidates: Vec<(IdType, T::DistanceType)> = current_neighbors
+                    .iter()
+                    .filter_map(|&n| {
+                        self.data.get(n).map(|data| {
+                            let d = self.dist_fn.compute(data, to_data, self.params.dim);
+                            (n, d)
+                        })
+                    })
+                    .collect();
+
+                if candidates.len() > max_m {
+                    candidates.select_nth_unstable_by(max_m - 1, |a, b| {
+                        a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
+                    });
+                    candidates.truncate(max_m);
+                }
+
+                let selected: Vec<IdType> = candidates.iter().map(|&(id, _)| id).collect();
+                to_element.set_neighbors(layer, &selected);
+            } else {
+                to_element.set_neighbors(layer, &current_neighbors);
+            }
+        }
+    }
+
+    /// Get the neighbors of an element at all levels by internal ID.
+    ///
+    /// Returns None if the ID doesn't exist in the index.
+    /// Returns Some(Vec<Vec<LabelType>>) where each inner Vec contains the neighbor labels at that level.
+    pub fn get_element_neighbors_by_id(&self, id: IdType) -> Option<Vec<Vec<LabelType>>> {
+        // Get the graph element
+        let element = self.graph.get(id)?;
+
+        // Collect neighbors at each level, converting internal IDs to labels
+        let mut result = Vec::with_capacity(element.levels.len());
+
+        for level in 0..element.levels.len() {
+            let neighbor_ids = element.get_neighbors(level);
+            let neighbor_labels: Vec<LabelType> = neighbor_ids
+                .iter()
+                .filter_map(|&neighbor_id| {
+                    self.graph.get(neighbor_id).map(|e| e.meta.label)
+                })
+                .collect();
+            result.push(neighbor_labels);
+        }
+
+        Some(result)
+    }
+}
diff --git a/rust/vecsim/src/index/hnsw/multi.rs b/rust/vecsim/src/index/hnsw/multi.rs
new file mode 100644
index 000000000..69956c2a1
--- /dev/null
+++ b/rust/vecsim/src/index/hnsw/multi.rs
@@ -0,0 +1,1527 @@
+//! Multi-value HNSW index implementation.
+//!
+//! This index allows multiple vectors per label.
+
+use super::{ElementGraphData, HnswCore, HnswParams};
+use crate::index::traits::{BatchIterator, IndexError, IndexInfo, QueryError, VecSimIndex};
+use crate::query::{QueryParams, QueryReply, QueryResult};
+use crate::types::{DistanceType, IdType, LabelType, VectorElement};
+use dashmap::{DashMap, DashSet};
+use std::collections::HashMap;
+
+/// Multi-value HNSW index.
+///
+/// Each label can have multiple associated vectors.
+pub struct HnswMulti<T: VectorElement> {
+    /// Core HNSW implementation.
+    pub(crate) core: HnswCore<T>,
+    /// Label to set of internal IDs mapping.
+    label_to_ids: DashMap<LabelType, DashSet<IdType>>,
+    /// Internal ID to label mapping.
+    pub(crate) id_to_label: DashMap<IdType, LabelType>,
+    /// Number of vectors.
+    count: std::sync::atomic::AtomicUsize,
+    /// Maximum capacity (if set).
+    capacity: Option<usize>,
+}
+
+impl<T: VectorElement> HnswMulti<T> {
+    /// Create a new multi-value HNSW index.
+    pub fn new(params: HnswParams) -> Self {
+        let initial_capacity = params.initial_capacity;
+        let core = HnswCore::new(params);
+
+        Self {
+            core,
+            label_to_ids: DashMap::with_capacity(initial_capacity / 2),
+            id_to_label: DashMap::with_capacity(initial_capacity),
+            count: std::sync::atomic::AtomicUsize::new(0),
+            capacity: None,
+        }
+    }
+
+    /// Create with a maximum capacity.
+    pub fn with_capacity(params: HnswParams, max_capacity: usize) -> Self {
+        let mut index = Self::new(params);
+        index.capacity = Some(max_capacity);
+        index
+    }
+
+    /// Get the distance metric.
+    pub fn metric(&self) -> crate::distance::Metric {
+        self.core.params.metric
+    }
+
+    /// Get the ef_runtime parameter.
+    pub fn ef_runtime(&self) -> usize {
+        self.core.params.ef_runtime
+    }
+
+    /// Set the ef_runtime parameter.
+    pub fn set_ef_runtime(&mut self, ef: usize) {
+        self.core.params.ef_runtime = ef;
+    }
+
+    /// Get copies of all vectors stored for a given label.
+    ///
+    /// Returns `None` if the label doesn't exist in the index.
+    pub fn get_vectors(&self, label: LabelType) -> Option<Vec<Vec<T>>> {
+        let ids_ref = self.label_to_ids.get(&label)?;
+        let vectors: Vec<Vec<T>> = ids_ref
+            .iter()
+            .filter_map(|id| self.core.data.get(*id).map(|v| v.to_vec()))
+            .collect();
+        if vectors.is_empty() {
+            None
+        } else {
+            Some(vectors)
+        }
+    }
+
+    /// Get all labels currently in the index.
+    pub fn get_labels(&self) -> Vec<LabelType> {
+        self.label_to_ids.iter().map(|r| *r.key()).collect()
+    }
+
+    /// Compute the minimum distance between any stored vector for a label and a query vector.
+    ///
+    /// Returns `None` if the label doesn't exist.
+    pub fn compute_distance(&self, label: LabelType, query: &[T]) -> Option<T::DistanceType> {
+        let ids_ref = self.label_to_ids.get(&label)?;
+        ids_ref
+            .iter()
+            .filter_map(|id| {
+                self.core.data.get(*id).map(|stored| {
+                    self.core.dist_fn.compute(stored, query, self.core.params.dim)
+                })
+            })
+            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
+    }
+
+    /// Get the memory usage in bytes.
+    pub fn memory_usage(&self) -> usize {
+        let count = self.count.load(std::sync::atomic::Ordering::Relaxed);
+
+        // Vector data storage
+        let vector_storage = count * self.core.params.dim * std::mem::size_of::<T>();
+
+        // Graph structure (rough estimate)
+        let graph_overhead = self.core.graph.len()
+            * std::mem::size_of::<Option<ElementGraphData>>();
+
+        // Label mappings (rough estimate with DashMap)
+        let label_maps = self.label_to_ids.len()
+            * std::mem::size_of::<(LabelType, DashSet<IdType>)>()
+            + self.id_to_label.len() * std::mem::size_of::<(IdType, LabelType)>();
+
+        vector_storage + graph_overhead + label_maps
+    }
+
+    /// Clear all vectors from the index, resetting it to an empty state.
+    pub fn clear(&mut self) {
+        use std::sync::atomic::Ordering;
+
+        self.core.data.clear();
+        self.core.graph.clear();
+        self.core.entry_point.store(crate::types::INVALID_ID, Ordering::Relaxed);
+        self.core.max_level.store(0, Ordering::Relaxed);
+        self.label_to_ids.clear();
+        self.id_to_label.clear();
+        self.count.store(0, Ordering::Relaxed);
+    }
+    /// Get the neighbors of an element in the HNSW graph by label.
+    ///
+    /// For multi-value indices, this returns the neighbors of the first internal ID
+    /// associated with the label. Returns a vector of vectors, where each inner vector
+    /// contains the neighbor labels for that level (level 0 first).
+    /// Returns None if the label doesn't exist.
+    pub fn get_element_neighbors(&self, label: LabelType) -> Option<Vec<Vec<LabelType>>> {
+        // Get the first internal ID for this label
+        let ids = self.label_to_ids.get(&label)?;
+        let first_id = *ids.iter().next()?;
+        self.core.get_element_neighbors_by_id(first_id)
+    }
+
+    /// Compact the index by removing gaps from deleted vectors.
+    ///
+    /// This reorganizes the internal storage and graph structure to reclaim space
+    /// from deleted vectors. After compaction, all vectors are stored contiguously.
+    ///
+    /// **Note**: For HNSW indices, compaction also rebuilds the graph neighbor links
+    /// with updated IDs, which can be computationally expensive for large indices.
+    ///
+    /// # Arguments
+    /// * `shrink` - If true, also release unused memory blocks
+    ///
+    /// # Returns
+    /// The number of bytes reclaimed (approximate).
+    pub fn compact(&mut self, shrink: bool) -> usize {
+        use std::sync::atomic::Ordering;
+
+        let old_capacity = self.core.data.capacity();
+        let id_mapping = self.core.data.compact(shrink);
+
+        // Rebuild graph with new IDs
+        let mut new_graph: Vec<Option<ElementGraphData>> = (0..id_mapping.len()).map(|_| None).collect();
+
+        for (&old_id, &new_id) in &id_mapping {
+            if let Some(old_graph_data) = self.core.graph.get(old_id) {
+                // Clone the graph data and update neighbor IDs
+                let mut new_graph_data = ElementGraphData::new(
+                    old_graph_data.meta.label,
+                    old_graph_data.meta.level,
+                    self.core.params.m_max_0,
+                    self.core.params.m,
+                );
+                new_graph_data.meta.deleted = old_graph_data.meta.deleted;
+
+                // Update neighbor IDs in each level
+                for (level_idx, level_link) in old_graph_data.levels.iter().enumerate() {
+                    let old_neighbors = level_link.get_neighbors();
+                    let new_neighbors: Vec<IdType> = old_neighbors
+                        .iter()
+                        .filter_map(|&neighbor_id| id_mapping.get(&neighbor_id).copied())
+                        .collect();
+
+                    if level_idx < new_graph_data.levels.len() {
+                        new_graph_data.levels[level_idx].set_neighbors(&new_neighbors);
+                    }
+                }
+
+                new_graph[new_id as usize] = Some(new_graph_data);
+            }
+        }
+
+        self.core.graph.replace(new_graph);
+
+        // Update entry point
+        let old_entry = self.core.entry_point.load(Ordering::Relaxed);
+        if old_entry != crate::types::INVALID_ID {
+            if let Some(&new_entry) = id_mapping.get(&old_entry) {
+                self.core.entry_point.store(new_entry, Ordering::Relaxed);
+            } else {
+                // Entry point was deleted, find a new one
+                let new_entry = self.core.graph.iter()
+                    .filter(|(_, g)| !g.meta.deleted)
+                    .map(|(id, _)| id)
+                    .next()
+                    .unwrap_or(crate::types::INVALID_ID);
+                self.core.entry_point.store(new_entry, Ordering::Relaxed);
+            }
+        }
+
+        // Update the label_to_ids mapping - collect keys first to avoid holding the iterator
+        let labels: Vec<LabelType> = self.label_to_ids.iter().map(|r| *r.key()).collect();
+        for label in labels {
+            if let Some(mut ids_ref) = self.label_to_ids.get_mut(&label) {
+                let new_ids: DashSet<IdType> = ids_ref
+                    .iter()
+                    .filter_map(|id| id_mapping.get(&*id).copied())
+                    .collect();
+                *ids_ref = new_ids;
+            }
+        }
+
+        // Remove labels with no remaining vectors
+        self.label_to_ids.retain(|_, ids| !ids.is_empty());
+
+        // Rebuild the id_to_label mapping
+        let old_id_to_label: HashMap<IdType, LabelType> = self.id_to_label.iter()
+            .map(|r| (*r.key(), *r.value()))
+            .collect();
+        self.id_to_label.clear();
+        for (&old_id, &new_id) in &id_mapping {
+            if let Some(&label) = old_id_to_label.get(&old_id) {
+                self.id_to_label.insert(new_id, label);
+            }
+        }
+
+        // Resize the visited pool
+        if !id_mapping.is_empty() {
+            let max_id = id_mapping.values().max().copied().unwrap_or(0) as usize;
+            self.core.visited_pool.resize(max_id + 1);
+        }
+
+        let new_capacity = self.core.data.capacity();
+        let dim = self.core.params.dim;
+        let bytes_per_vector = dim * std::mem::size_of::<T>();
+
+        (old_capacity.saturating_sub(new_capacity)) * bytes_per_vector
+    }
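+    // For example (illustrative): with live slots {0, 2, 4} after deleting 1
+    // and 3, `data.compact` yields the mapping {0 -> 0, 2 -> 1, 4 -> 2}; every
+    // stored neighbor list, the entry point, and both label maps above are
+    // rewritten through that mapping so no stale IDs survive compaction.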
+    /// Get the fragmentation ratio of the index.
+    ///
+    /// Returns a value between 0.0 (no fragmentation) and 1.0 (all slots are deleted).
+    pub fn fragmentation(&self) -> f64 {
+        self.core.data.fragmentation()
+    }
+
+    /// Add multiple vectors at once.
+    ///
+    /// Returns the number of vectors successfully added.
+    /// Stops on the first error and returns what was added so far.
+    pub fn add_vectors(
+        &mut self,
+        vectors: &[(&[T], LabelType)],
+    ) -> Result<usize, IndexError> {
+        let mut added = 0;
+        for &(vector, label) in vectors {
+            self.add_vector(vector, label)?;
+            added += 1;
+        }
+        Ok(added)
+    }
+
+    /// Add a vector with parallel-safe semantics.
+    ///
+    /// Can be called from multiple threads simultaneously.
+    pub fn add_vector_concurrent(&self, vector: &[T], label: LabelType) -> Result<usize, IndexError> {
+        if vector.len() != self.core.params.dim {
+            return Err(IndexError::DimensionMismatch {
+                expected: self.core.params.dim,
+                got: vector.len(),
+            });
+        }
+
+        // Check capacity
+        if let Some(cap) = self.capacity {
+            if self.count.load(std::sync::atomic::Ordering::Relaxed) >= cap {
+                return Err(IndexError::CapacityExceeded { capacity: cap });
+            }
+        }
+
+        // Add the vector concurrently
+        let id = self.core
+            .add_vector_concurrent(vector)
+            .ok_or_else(|| IndexError::Internal("Failed to add vector to storage".to_string()))?;
+        self.core.insert_concurrent(id, label);
+
+        // Update mappings using DashMap
+        self.label_to_ids
+            .entry(label)
+            .or_insert_with(DashSet::new)
+            .insert(id);
+        self.id_to_label.insert(id, label);
+
+        self.count
+            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+        Ok(1)
+    }
+}
+
+impl<T: VectorElement> VecSimIndex for HnswMulti<T> {
+    type DataType = T;
+    type DistType = T::DistanceType;
+
+    fn add_vector(&mut self, vector: &[T], label: LabelType) -> Result<usize, IndexError> {
+        if vector.len() != self.core.params.dim {
+            return Err(IndexError::DimensionMismatch {
+                expected: self.core.params.dim,
+                got: vector.len(),
+            });
+        }
+
+        // Check capacity
+        if let Some(cap) = self.capacity {
+            if self.count.load(std::sync::atomic::Ordering::Relaxed) >= cap {
+                return Err(IndexError::CapacityExceeded { capacity: cap });
+            }
+        }
+
+        // Add the vector
+        let id = self.core
+            .add_vector(vector)
+            .ok_or_else(|| IndexError::Internal("Failed to add vector to storage".to_string()))?;
+        self.core.insert(id, label);
+
+        // Update mappings using DashMap
+        self.label_to_ids
+            .entry(label)
+            .or_insert_with(DashSet::new)
+            .insert(id);
+        self.id_to_label.insert(id, label);
+
+        self.count
+            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+        Ok(1)
+    }
+
+    fn delete_vector(&mut self, label: LabelType) -> Result<usize, IndexError> {
+        if let Some((_, ids)) = self.label_to_ids.remove(&label) {
+            let count = ids.len();
+
+            for id in ids.iter() {
+                self.core.mark_deleted(*id);
+                self.id_to_label.remove(&*id);
+            }
+
+            self.count
+                .fetch_sub(count, std::sync::atomic::Ordering::Relaxed);
+            Ok(count)
+        } else {
+            Err(IndexError::LabelNotFound(label))
+        }
+    }
+
+    fn top_k_query(
+        &self,
+        query: &[T],
+        k: usize,
+        params: Option<&QueryParams>,
+    ) -> Result<QueryReply<Self::DistType>, QueryError> {
+        if query.len() != self.core.params.dim {
+            return Err(QueryError::DimensionMismatch {
+                expected: self.core.params.dim,
+                got: query.len(),
+            });
+        }
+
+        let base_ef = params
+            .and_then(|p| p.ef_runtime)
+            .unwrap_or(self.core.params.ef_runtime);
+
+        // Get the id_to_label mapping for label-aware search
+        let id_to_label: HashMap<IdType, LabelType> = self.id_to_label
+            .iter()
+            .map(|r| (*r.key(), *r.value()))
+            .collect();
+
+        // For multi-value indices, we need a higher ef to explore enough unique labels.
+        // The label heap will track unique labels, so ef controls how many labels we track.
+        // Use ef * avg_per_label to compensate for label clustering.
+        let total_vectors = self.count.load(std::sync::atomic::Ordering::Relaxed);
+        let num_labels = self.label_to_ids.len();
+        let avg_per_label = if num_labels > 0 {
+            (total_vectors / num_labels).max(1)
+        } else {
+            1
+        };
+        // Scale ef by avg_per_label * 2 to ensure we find enough unique labels.
+        // The extra 2x factor helps compensate for graph clustering effects.
+        let ef = (base_ef * avg_per_label * 2).min(num_labels).max(base_ef);
+
+        // Build the filter function by wrapping the reference
+        let filter_ref: Option<&dyn Fn(LabelType) -> bool> = if let Some(p) = params {
+            p.filter.as_ref().map(|f| f.as_ref() as &dyn Fn(LabelType) -> bool)
+        } else {
+            None
+        };
+
+        // Use label-aware search that tracks unique labels during graph traversal
+        let results = self.core.search_multi(
+            query,
+            k,
+            ef,
+            &id_to_label,
+            filter_ref,
+        );
+
+        // Convert to QueryReply
+        let mut reply = QueryReply::with_capacity(results.len());
+        for (label, dist) in results {
+            reply.push(QueryResult::new(label, dist));
+        }
+
+        Ok(reply)
+    }
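+    // Worked example of the ef scaling in top_k_query (illustrative numbers):
+    // with 1,000 vectors across 100 labels, avg_per_label = 10; for base_ef = 50,
+    // ef = min(50 * 10 * 2, 100).max(50) = 100, i.e. the search is widened up to
+    // the number of distinct labels, but never narrowed below the caller's ef.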
+    fn range_query(
+        &self,
+        query: &[T],
+        radius: T::DistanceType,
+        params: Option<&QueryParams>,
+    ) -> Result<QueryReply<Self::DistType>, QueryError> {
+        if query.len() != self.core.params.dim {
+            return Err(QueryError::DimensionMismatch {
+                expected: self.core.params.dim,
+                got: query.len(),
+            });
+        }
+
+        // Get epsilon from params or use the default (0.01 = 1% expansion).
+        // This matches the C++ HNSW_DEFAULT_EPSILON value for consistent behavior.
+        let epsilon = params
+            .and_then(|p| p.epsilon)
+            .unwrap_or(0.01);
+
+        // Build the filter if needed
+        let filter_fn: Option<Box<dyn Fn(IdType) -> bool + '_>> = if let Some(p) = params {
+            if let Some(ref f) = p.filter {
+                // Use a reference to the DashMap directly - avoids an O(n) copy
+                let id_to_label_ref = &self.id_to_label;
+                Some(Box::new(move |id: IdType| {
+                    id_to_label_ref.get(&id).is_some_and(|label_ref| f(*label_ref))
+                }))
+            } else {
+                None
+            }
+        } else {
+            None
+        };
+
+        // Use epsilon-neighborhood search for an efficient range query
+        let results = self.core.search_range(
+            query,
+            radius,
+            epsilon,
+            filter_fn.as_ref().map(|f| f.as_ref()),
+        );
+
+        // For a multi-value index, deduplicate by label and keep the best distance per label
+        let mut label_best: HashMap<LabelType, T::DistanceType> = HashMap::new();
+
+        for (id, dist) in results {
+            if let Some(label_ref) = self.id_to_label.get(&id) {
+                let label = *label_ref;
+                label_best
+                    .entry(label)
+                    .and_modify(|best| {
+                        if dist.to_f64() < best.to_f64() {
+                            *best = dist;
+                        }
+                    })
+                    .or_insert(dist);
+            }
+        }
+
+        // Convert to QueryReply
+        let mut reply = QueryReply::with_capacity(label_best.len());
+        for (label, dist) in label_best {
+            reply.push(QueryResult::new(label, dist));
+        }
+
+        reply.sort_by_distance();
+        Ok(reply)
+    }
+
+    fn index_size(&self) -> usize {
+        self.count.load(std::sync::atomic::Ordering::Relaxed)
+    }
+
+    fn index_capacity(&self) -> Option<usize> {
+        self.capacity
+    }
+
+    fn dimension(&self) -> usize {
+        self.core.params.dim
+    }
+
+    fn batch_iterator<'a>(
+        &'a self,
+        query: &[T],
+        params: Option<&QueryParams>,
+    ) -> Result<Box<dyn BatchIterator<Self::DistType> + 'a>, QueryError> {
+        if query.len() != self.core.params.dim {
+            return Err(QueryError::DimensionMismatch {
+                expected: self.core.params.dim,
+                got: query.len(),
+            });
+        }
+
+        Ok(Box::new(
+            super::batch_iterator::HnswMultiBatchIterator::new(self, query.to_vec(), params.cloned()),
+        ))
+    }
+
+    fn info(&self) -> IndexInfo {
+        let count = self.count.load(std::sync::atomic::Ordering::Relaxed);
+
+        // Base overhead for the struct itself and internal data structures
+        let base_overhead = std::mem::size_of::<Self>() + std::mem::size_of::<HnswCore<T>>();
+
+        IndexInfo {
+            size: count,
+            capacity: self.capacity,
+            dimension: self.core.params.dim,
+            index_type: "HnswMulti",
+            memory_bytes: base_overhead
+                + count * self.core.params.dim * std::mem::size_of::<T>()
+                + self.core.graph.len() * std::mem::size_of::<Option<ElementGraphData>>()
+                + self.label_to_ids.len()
+                    * std::mem::size_of::<(LabelType, DashSet<IdType>)>(),
+        }
+    }
+
+    fn contains(&self, label: LabelType) -> bool {
+        self.label_to_ids.contains_key(&label)
+    }
+
+    fn label_count(&self, label: LabelType) -> usize {
+        self.label_to_ids
+            .get(&label)
+            .map_or(0, |ids| ids.len())
+    }
+}
+
+unsafe impl<T: VectorElement> Send for HnswMulti<T> {}
+unsafe impl<T: VectorElement> Sync for HnswMulti<T> {}
+
+// Serialization support
+impl<T: VectorElement> HnswMulti<T> {
+    /// Save the index to a writer.
+    pub fn save<W: std::io::Write>(
+        &self,
+        writer: &mut W,
+    ) -> crate::serialization::SerializationResult<()> {
+        use crate::serialization::*;
+        use std::sync::atomic::Ordering;
+
+        let graph_len = self.core.graph.len();
+        let count = self.count.load(Ordering::Relaxed);
+
+        // Write header
+        let header = IndexHeader::new(
+            IndexTypeId::HnswMulti,
+            T::data_type_id(),
+            self.core.params.metric,
+            self.core.params.dim,
+            count,
+        );
+        header.write(writer)?;
+
+        // Write HNSW-specific params
+        write_usize(writer, self.core.params.m)?;
+        write_usize(writer, self.core.params.m_max_0)?;
+        write_usize(writer, self.core.params.ef_construction)?;
+        write_usize(writer, self.core.params.ef_runtime)?;
+        write_u8(writer, if self.core.params.enable_heuristic { 1 } else { 0 })?;
+
+        // Write graph metadata
+        let entry_point = self.core.entry_point.load(Ordering::Relaxed);
+        let max_level = self.core.max_level.load(Ordering::Relaxed);
+        write_u32(writer, entry_point)?;
+        write_u32(writer, max_level)?;
+
+        // Write capacity
+        write_u8(writer, if self.capacity.is_some() { 1 } else { 0 })?;
+        if let Some(cap) = self.capacity {
+            write_usize(writer, cap)?;
+        }
+
+        // Write the label_to_ids mapping (label -> set of IDs)
+        write_usize(writer, self.label_to_ids.len())?;
+        for entry in self.label_to_ids.iter() {
+            let label = *entry.key();
+            let ids = entry.value();
+            write_u64(writer, label)?;
+            write_usize(writer, ids.len())?;
+            for id in ids.iter() {
+                write_u32(writer, *id)?;
+            }
+        }
+
+        // Write the graph structure
+        write_usize(writer, graph_len)?;
+        for id in 0..graph_len as u32 {
+            if let Some(graph_data) = self.core.graph.get(id) {
+                write_u8(writer, 1)?; // Present flag
+
+                // Write metadata
+                write_u64(writer, graph_data.meta.label)?;
+                write_u8(writer, graph_data.meta.level)?;
+                write_u8(writer, if graph_data.meta.deleted { 1 } else { 0 })?;
+
+                // Write levels
+                write_usize(writer, graph_data.levels.len())?;
+                for level_links in &graph_data.levels {
+                    let neighbors = level_links.get_neighbors();
+                    write_usize(writer, neighbors.len())?;
+                    for neighbor in neighbors {
+                        write_u32(writer, neighbor)?;
+                    }
+                }
+
+                // Write vector data
+                if let Some(vector) = self.core.data.get(id) {
+                    for v in vector {
+                        v.write_to(writer)?;
+                    }
+                }
+            } else {
+                write_u8(writer, 0)?; // Not present
+            }
+        }
+
+        Ok(())
+    }
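+    // On-disk layout produced by `save` (in order): IndexHeader, the five HNSW
+    // params (m, m_max_0, ef_construction, ef_runtime, heuristic flag), entry
+    // point + max level, optional capacity, the label -> IDs table, then one
+    // record per graph slot (present flag, metadata, per-level neighbor lists,
+    // raw vector data). `load` below reads the same fields in the same order.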
+    /// Load the index from a reader.
+    pub fn load<R: std::io::Read>(reader: &mut R) -> crate::serialization::SerializationResult<Self> {
+        use crate::serialization::*;
+        use super::graph::ElementGraphData;
+        use std::sync::atomic::Ordering;
+
+        // Read and validate the header
+        let header = IndexHeader::read(reader)?;
+
+        if header.index_type != IndexTypeId::HnswMulti {
+            return Err(SerializationError::IndexTypeMismatch {
+                expected: "HnswMulti".to_string(),
+                got: header.index_type.as_str().to_string(),
+            });
+        }
+
+        if header.data_type != T::data_type_id() {
+            return Err(SerializationError::InvalidData(
+                format!("Expected {:?} data type, got {:?}", T::data_type_id(), header.data_type),
+            ));
+        }
+
+        // Read HNSW-specific params
+        let m = read_usize(reader)?;
+        let m_max_0 = read_usize(reader)?;
+        let ef_construction = read_usize(reader)?;
+        let ef_runtime = read_usize(reader)?;
+        let enable_heuristic = read_u8(reader)? != 0;
+
+        // Read graph metadata
+        let entry_point = read_u32(reader)?;
+        let max_level = read_u32(reader)?;
+
+        // Create params and the index
+        let params = HnswParams {
+            dim: header.dimension,
+            metric: header.metric,
+            m,
+            m_max_0,
+            ef_construction,
+            ef_runtime,
+            initial_capacity: header.count.max(1024),
+            enable_heuristic,
+            seed: None, // Seed not preserved in serialization
+        };
+        let mut index = Self::new(params);
+
+        // Read capacity
+        let has_capacity = read_u8(reader)? != 0;
+        if has_capacity {
+            index.capacity = Some(read_usize(reader)?);
+        }
+
+        // Read the label_to_ids mapping into a DashMap
+        let label_to_ids_len = read_usize(reader)?;
+        for _ in 0..label_to_ids_len {
+            let label = read_u64(reader)?;
+            let num_ids = read_usize(reader)?;
+            let ids: DashSet<IdType> = DashSet::with_capacity(num_ids);
+            for _ in 0..num_ids {
+                ids.insert(read_u32(reader)?);
+            }
+            index.label_to_ids.insert(label, ids);
+        }
+
+        // Build id_to_label from label_to_ids
+        for entry in index.label_to_ids.iter() {
+            let label = *entry.key();
+            for id in entry.value().iter() {
+                index.id_to_label.insert(*id, label);
+            }
+        }
+
+        // Read the graph structure
+        let graph_len = read_usize(reader)?;
+        let dim = header.dimension;
+
+        // Set entry point and max level
+        index.core.entry_point.store(entry_point, Ordering::Relaxed);
+        index.core.max_level.store(max_level, Ordering::Relaxed);
+
+        for id in 0..graph_len {
+            let present = read_u8(reader)? != 0;
+            if !present {
+                continue;
+            }
+
+            // Read metadata
+            let label = read_u64(reader)?;
+            let level = read_u8(reader)?;
+            let deleted = read_u8(reader)? != 0;
+
+            // Read levels
+            let num_levels = read_usize(reader)?;
+            let mut graph_data = ElementGraphData::new(label, level, m_max_0, m);
+            graph_data.meta.deleted = deleted;
+
+            for level_idx in 0..num_levels {
+                let num_neighbors = read_usize(reader)?;
+                let mut neighbors = Vec::with_capacity(num_neighbors);
+                for _ in 0..num_neighbors {
+                    neighbors.push(read_u32(reader)?);
+                }
+                if level_idx < graph_data.levels.len() {
+                    graph_data.levels[level_idx].set_neighbors(&neighbors);
+                }
+            }
+
+            // Read vector data
+            let mut vector = vec![T::zero(); dim];
+            for v in &mut vector {
+                *v = T::read_from(reader)?;
+            }
+
+            // Add the vector to data storage
+            index.core.data.add(&vector).ok_or_else(|| {
+                SerializationError::DataCorruption("Failed to add vector during deserialization".to_string())
+            })?;
+
+            // Store graph data
+            index.core.graph.set(id as IdType, graph_data);
+        }
+
+        // Resize the visited pool
+        if graph_len > 0 {
+            index.core.visited_pool.resize(graph_len);
+        }
+
+        index.count.store(header.count, Ordering::Relaxed);
+
+        Ok(index)
+    }
+
+    /// Save the index to a file.
+    pub fn save_to_file<P: AsRef<std::path::Path>>(
+        &self,
+        path: P,
+    ) -> crate::serialization::SerializationResult<()> {
+        let file = std::fs::File::create(path)?;
+        let mut writer = std::io::BufWriter::new(file);
+        self.save(&mut writer)
+    }
+    /// Load the index from a file.
+    pub fn load_from_file<P: AsRef<std::path::Path>>(
+        path: P,
+    ) -> crate::serialization::SerializationResult<Self> {
+        let file = std::fs::File::open(path)?;
+        let mut reader = std::io::BufReader::new(file);
+        Self::load(&mut reader)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::distance::Metric;
+
+    #[test]
+    fn test_hnsw_multi_basic() {
+        let params = HnswParams::new(4, Metric::L2)
+            .with_m(4)
+            .with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // Add multiple vectors with the same label
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        assert_eq!(index.index_size(), 3);
+        assert_eq!(index.label_count(1), 2);
+        assert_eq!(index.label_count(2), 1);
+    }
+
+    #[test]
+    fn test_hnsw_multi_delete() {
+        let params = HnswParams::new(4, Metric::L2);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        assert_eq!(index.index_size(), 3);
+
+        let deleted = index.delete_vector(1).unwrap();
+        assert_eq!(deleted, 2);
+        assert_eq!(index.index_size(), 1);
+    }
+
+    #[test]
+    fn test_hnsw_multi_serialization() {
+        use std::io::Cursor;
+
+        let params = HnswParams::new(4, Metric::L2)
+            .with_m(4)
+            .with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // Add multiple vectors with the same label
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+        index.add_vector(&vec![0.0, 0.0, 1.0, 0.0], 3).unwrap();
+
+        assert_eq!(index.index_size(), 4);
+        assert_eq!(index.label_count(1), 2);
+
+        // Serialize
+        let mut buffer = Vec::new();
+        index.save(&mut buffer).unwrap();
+
+        // Deserialize
+        let mut cursor = Cursor::new(buffer);
+        let loaded = HnswMulti::<f32>::load(&mut cursor).unwrap();
+
+        // Verify
+        assert_eq!(loaded.index_size(), 4);
+        assert_eq!(loaded.dimension(), 4);
+        assert_eq!(loaded.label_count(1), 2);
+        assert_eq!(loaded.label_count(2), 1);
+        assert_eq!(loaded.label_count(3), 1);
+
+        // Queries should work the same
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = loaded.top_k_query(&query, 3, None).unwrap();
+        assert!(!results.is_empty());
+        // First result should be one of the label 1 vectors
+        assert_eq!(results.results[0].label, 1);
+    }
+    #[test]
+    fn test_hnsw_multi_compact() {
+        let params = HnswParams::new(4, Metric::L2)
+            .with_m(4)
+            .with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // Add multiple vectors, some with the same label
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+        index.add_vector(&vec![0.0, 0.0, 1.0, 0.0], 3).unwrap();
+        index.add_vector(&vec![0.0, 0.0, 0.0, 1.0], 4).unwrap();
+
+        assert_eq!(index.index_size(), 5);
+        assert_eq!(index.label_count(1), 2);
+
+        // Delete labels 2 and 3
+        index.delete_vector(2).unwrap();
+        index.delete_vector(3).unwrap();
+
+        assert_eq!(index.index_size(), 3);
+        assert!(index.fragmentation() > 0.0);
+
+        // Compact the index
+        index.compact(true);
+
+        // Verify fragmentation is gone
+        assert!((index.fragmentation() - 0.0).abs() < 0.01);
+        assert_eq!(index.index_size(), 3);
+        assert_eq!(index.label_count(1), 2); // Both vectors for label 1 remain
+
+        // Verify queries still work correctly
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 3, None).unwrap();
+
+        // After deduplication by label, we can only have 2 unique labels (1 and 4)
+        assert_eq!(results.len(), 2);
+        // Top results should include label 1 vectors
+        assert_eq!(results.results[0].label, 1);
+
+        // Verify we can still find v4
+        let query2 = vec![0.0, 0.0, 0.0, 1.0];
+        let results2 = index.top_k_query(&query2, 1, None).unwrap();
+        assert_eq!(results2.results[0].label, 4);
+
+        // Deleted labels should not appear
+        let all_results = index.top_k_query(&query, 10, None).unwrap();
+        for result in &all_results.results {
+            assert!(result.label == 1 || result.label == 4);
+        }
+    }
+
+    #[test]
+    fn test_hnsw_multi_multiple_labels() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // Add multiple vectors per label for multiple labels
+        for label in 1..=5u64 {
+            for i in 0..3 {
+                let v = vec![label as f32, i as f32, 0.0, 0.0];
+                index.add_vector(&v, label).unwrap();
+            }
+        }
+
+        assert_eq!(index.index_size(), 15);
+        for label in 1..=5u64 {
+            assert_eq!(index.label_count(label), 3);
+        }
+    }
+
+    #[test]
+    fn test_hnsw_multi_get_vectors() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        let v1 = vec![1.0, 0.0, 0.0, 0.0];
+        let v2 = vec![2.0, 0.0, 0.0, 0.0];
+        index.add_vector(&v1, 1).unwrap();
+        index.add_vector(&v2, 1).unwrap();
+
+        let vectors = index.get_vectors(1).unwrap();
+        assert_eq!(vectors.len(), 2);
+        // Check that both vectors are returned
+        let mut found_v1 = false;
+        let mut found_v2 = false;
+        for v in &vectors {
+            if v[0] == 1.0 {
+                found_v1 = true;
+            }
+            if v[0] == 2.0 {
+                found_v2 = true;
+            }
+        }
+        assert!(found_v1 && found_v2);
+
+        // Non-existent label
+        assert!(index.get_vectors(999).is_none());
+    }
+
+    #[test]
+    fn test_hnsw_multi_get_labels() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 10).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 20).unwrap();
+        index.add_vector(&vec![3.0, 0.0, 0.0, 0.0], 30).unwrap();
+
+        let labels = index.get_labels();
+        assert_eq!(labels.len(), 3);
+        assert!(labels.contains(&10));
+        assert!(labels.contains(&20));
+        assert!(labels.contains(&30));
+    }
+
+    #[test]
+    fn test_hnsw_multi_compute_distance() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // Add two vectors for label 1, at different distances from the query
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![3.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        // The query should return the minimum distance among all vectors for the label
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let dist = index.compute_distance(1, &query).unwrap();
+        // Distance to [1.0, 0.0, 0.0, 0.0] is 1.0 (L2 squared)
+        assert!((dist - 1.0).abs() < 0.001);
+
+        // Non-existent label
+        assert!(index.compute_distance(999, &query).is_none());
+    }
+
+    #[test]
+    fn test_hnsw_multi_contains() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        assert!(index.contains(1));
+        assert!(!index.contains(2));
+    }
+    #[test]
+    fn test_hnsw_multi_range_query() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // Add vectors at different distances
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 2).unwrap();
+        index.add_vector(&vec![3.0, 0.0, 0.0, 0.0], 3).unwrap();
+        index.add_vector(&vec![10.0, 0.0, 0.0, 0.0], 4).unwrap();
+
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        // L2 squared: dist to [1,0,0,0]=1, [2,0,0,0]=4, [3,0,0,0]=9, [10,0,0,0]=100
+        let results = index.range_query(&query, 10.0, None).unwrap();
+
+        assert_eq!(results.len(), 3); // labels 1, 2, 3 are within radius 10
+        for r in &results.results {
+            assert!(r.label != 4); // label 4 should not be included
+        }
+    }
+
+    #[test]
+    fn test_hnsw_multi_filtered_query() {
+        use crate::query::QueryParams;
+
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        for label in 1..=5u64 {
+            index.add_vector(&vec![label as f32, 0.0, 0.0, 0.0], label).unwrap();
+        }
+
+        // Filter to only allow even labels
+        let query_params = QueryParams::new().with_filter(|label| label % 2 == 0);
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 10, Some(&query_params)).unwrap();
+
+        assert_eq!(results.len(), 2); // Only labels 2 and 4
+        for r in &results.results {
+            assert!(r.label % 2 == 0);
+        }
+    }
+
+    #[test]
+    fn test_hnsw_multi_batch_iterator() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        for i in 0..10u64 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        let query = vec![5.0, 0.0, 0.0, 0.0];
+        let mut iter = index.batch_iterator(&query, None).unwrap();
+
+        let mut all_results = Vec::new();
+        while let Some(batch) = iter.next_batch(3) {
+            all_results.extend(batch);
+        }
+
+        assert!(!all_results.is_empty());
+    }
+
+    #[test]
+    fn test_hnsw_multi_dimension_mismatch() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // Wrong dimension on add
+        let result = index.add_vector(&vec![1.0, 2.0], 1);
+        assert!(matches!(result, Err(IndexError::DimensionMismatch { expected: 4, got: 2 })));
+
+        // Add a valid vector first
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        // Wrong dimension on query
+        let result = index.top_k_query(&vec![1.0, 2.0], 1, None);
+        assert!(matches!(result, Err(QueryError::DimensionMismatch { expected: 4, got: 2 })));
+    }
+
+    #[test]
+    fn test_hnsw_multi_capacity_exceeded() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::with_capacity(params, 2);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 2).unwrap();
+
+        // The third should fail
+        let result = index.add_vector(&vec![3.0, 0.0, 0.0, 0.0], 3);
+        assert!(matches!(result, Err(IndexError::CapacityExceeded { capacity: 2 })));
+    }
+
+    #[test]
+    fn test_hnsw_multi_delete_not_found() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        let result = index.delete_vector(999);
+        assert!(matches!(result, Err(IndexError::LabelNotFound(999))));
+    }
+    #[test]
+    fn test_hnsw_multi_memory_usage() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        let initial_memory = index.memory_usage();
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 2).unwrap();
+
+        let after_memory = index.memory_usage();
+        assert!(after_memory > initial_memory);
+    }
+
+    #[test]
+    fn test_hnsw_multi_info() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        let info = index.info();
+        assert_eq!(info.size, 2);
+        assert_eq!(info.dimension, 4);
+        assert_eq!(info.index_type, "HnswMulti");
+        assert!(info.memory_bytes > 0);
+    }
+
+    #[test]
+    fn test_hnsw_multi_clear() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 2).unwrap();
+
+        assert_eq!(index.index_size(), 2);
+
+        index.clear();
+
+        assert_eq!(index.index_size(), 0);
+        assert!(index.get_labels().is_empty());
+        assert!(!index.contains(1));
+        assert!(!index.contains(2));
+    }
+
+    #[test]
+    fn test_hnsw_multi_add_vectors_batch() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        let vectors: Vec<(&[f32], LabelType)> = vec![
+            (&[1.0, 0.0, 0.0, 0.0], 1),
+            (&[2.0, 0.0, 0.0, 0.0], 1),
+            (&[3.0, 0.0, 0.0, 0.0], 2),
+        ];
+
+        let added = index.add_vectors(&vectors).unwrap();
+        assert_eq!(added, 3);
+        assert_eq!(index.index_size(), 3);
+        assert_eq!(index.label_count(1), 2);
+        assert_eq!(index.label_count(2), 1);
+    }
+
+    #[test]
+    fn test_hnsw_multi_with_capacity() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let index = HnswMulti::<f32>::with_capacity(params, 100);
+
+        assert_eq!(index.index_capacity(), Some(100));
+    }
+
+    #[test]
+    fn test_hnsw_multi_inner_product() {
+        let params = HnswParams::new(4, Metric::InnerProduct).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // For InnerProduct, a higher dot product means a lower "distance"
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 2, None).unwrap();
+
+        // Label 1 has perfect alignment with the query
+        assert_eq!(results.results[0].label, 1);
+    }
+
+    #[test]
+    fn test_hnsw_multi_cosine() {
+        let params = HnswParams::new(4, Metric::Cosine).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // Cosine similarity is direction-based
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 3).unwrap(); // Same direction as label 1
+
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 3, None).unwrap();
+
+        // Labels 1 and 3 have the same direction (cosine=1), so they should be the top results
+        assert!(results.results[0].label == 1 || results.results[0].label == 3);
+        assert!(results.results[1].label == 1 || results.results[1].label == 3);
+    }
+    #[test]
+    fn test_hnsw_multi_metric_getter() {
+        let params = HnswParams::new(4, Metric::Cosine).with_m(4).with_ef_construction(20);
+        let index = HnswMulti::<f32>::new(params);
+
+        assert_eq!(index.metric(), Metric::Cosine);
+    }
+
+    #[test]
+    fn test_hnsw_multi_ef_runtime() {
+        let params = HnswParams::new(4, Metric::L2)
+            .with_m(4)
+            .with_ef_construction(20)
+            .with_ef_runtime(50);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        assert_eq!(index.ef_runtime(), 50);
+
+        // Modify ef_runtime
+        index.set_ef_runtime(100);
+        assert_eq!(index.ef_runtime(), 100);
+    }
+
+    #[test]
+    fn test_hnsw_multi_query_with_ef_runtime() {
+        use crate::query::QueryParams;
+
+        let params = HnswParams::new(4, Metric::L2)
+            .with_m(4)
+            .with_ef_construction(20)
+            .with_ef_runtime(10);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        for i in 0..50u64 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        let query = vec![25.0, 0.0, 0.0, 0.0];
+
+        // Query with the default ef_runtime
+        let results1 = index.top_k_query(&query, 5, None).unwrap();
+        assert!(!results1.is_empty());
+
+        // Query with a higher ef_runtime
+        let query_params = QueryParams::new().with_ef_runtime(100);
+        let results2 = index.top_k_query(&query, 5, Some(&query_params)).unwrap();
+        assert!(!results2.is_empty());
+    }
+
+    #[test]
+    fn test_hnsw_multi_filtered_range_query() {
+        use crate::query::QueryParams;
+
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        for label in 1..=10u64 {
+            index.add_vector(&vec![label as f32, 0.0, 0.0, 0.0], label).unwrap();
+        }
+
+        // Filter to only allow labels 1-5, with range 50 (covers labels 1-7 by distance)
+        let query_params = QueryParams::new().with_filter(|label| label <= 5);
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        let results = index.range_query(&query, 50.0, Some(&query_params)).unwrap();
+
+        // Should have labels 1-5 (filtered) that are within range 50
+        assert_eq!(results.len(), 5);
+        for r in &results.results {
+            assert!(r.label <= 5);
+        }
+    }
+
+    #[test]
+    fn test_hnsw_multi_empty_query() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let index = HnswMulti::<f32>::new(params);
+
+        // Query on an empty index
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 10, None).unwrap();
+        assert!(results.is_empty());
+    }
+
+    #[test]
+    fn test_hnsw_multi_fragmentation() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        for i in 1..=10u64 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        // Initially no fragmentation
+        assert!((index.fragmentation() - 0.0).abs() < 0.01);
+
+        // Delete half the vectors
+        for i in 1..=5u64 {
+            index.delete_vector(i).unwrap();
+        }
+
+        // Now there should be fragmentation
+        assert!(index.fragmentation() > 0.3);
+    }
+
+    #[test]
+    fn test_hnsw_multi_serialization_with_capacity() {
+        use std::io::Cursor;
+
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::with_capacity(params, 1000);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        // Serialize
+        let mut buffer = Vec::new();
+        index.save(&mut buffer).unwrap();
+
+        // Deserialize
+        let mut cursor = Cursor::new(buffer);
+        let loaded = HnswMulti::<f32>::load(&mut cursor).unwrap();
+
+        // Capacity should be preserved
+        assert_eq!(loaded.index_capacity(), Some(1000));
+        assert_eq!(loaded.index_size(), 2);
+        assert_eq!(loaded.label_count(1), 2);
+    }
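+    // Illustrative extra test (not part of the original suite): exercises
+    // get_element_neighbors on a small index. Exact neighbor sets depend on
+    // construction order, so it only checks shape, not contents.
+    #[test]
+    fn test_hnsw_multi_get_element_neighbors_shape() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+        index.add_vector(&vec![0.0, 0.0, 1.0, 0.0], 3).unwrap();
+
+        // Every stored element has at least level 0
+        let levels = index.get_element_neighbors(1).unwrap();
+        assert!(!levels.is_empty());
+        // With three connected points, level 0 should link label 1 to someone
+        assert!(!levels[0].is_empty());
+
+        // Unknown labels yield None
+        assert!(index.get_element_neighbors(999).is_none());
+    }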
+    #[test]
+    fn test_hnsw_multi_query_after_compact() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // Add many vectors
+        for i in 1..=20u64 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        // Delete odd labels
+        for i in (1..=20u64).step_by(2) {
+            index.delete_vector(i).unwrap();
+        }
+
+        // Compact
+        index.compact(true);
+
+        // Queries should still work
+        let query = vec![4.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 5, None).unwrap();
+
+        assert_eq!(results.len(), 5);
+        // First result should be label 4 (exact match, and it's even so not deleted)
+        assert_eq!(results.results[0].label, 4);
+
+        // All results should be even labels
+        for r in &results.results {
+            assert!(r.label % 2 == 0);
+        }
+    }
+
+    #[test]
+    fn test_hnsw_multi_larger_scale() {
+        let params = HnswParams::new(8, Metric::L2)
+            .with_m(16)
+            .with_ef_construction(100)
+            .with_ef_runtime(50);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // Add 1000 vectors, 10 per label
+        for label in 0..100u64 {
+            for i in 0..10 {
+                let v = vec![
+                    label as f32 + i as f32 * 0.01,
+                    0.0,
+                    0.0,
+                    0.0,
+                    0.0,
+                    0.0,
+                    0.0,
+                    0.0,
+                ];
+                index.add_vector(&v, label).unwrap();
+            }
+        }
+
+        assert_eq!(index.index_size(), 1000);
+        assert_eq!(index.label_count(50), 10);
+
+        // Queries should work
+        let query = vec![50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 10, None).unwrap();
+
+        assert!(!results.is_empty());
+        // First result should be label 50 (closest)
+        assert_eq!(results.results[0].label, 50);
+    }
+
+    #[test]
+    fn test_hnsw_multi_heuristic_mode() {
+        let params = HnswParams::new(4, Metric::L2)
+            .with_m(4)
+            .with_ef_construction(20)
+            .with_heuristic(true);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        for i in 0..20u64 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        let query = vec![10.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 5, None).unwrap();
+
+        assert_eq!(results.len(), 5);
+        assert_eq!(results.results[0].label, 10);
+    }
+
+    #[test]
+    fn test_hnsw_multi_batch_iterator_with_params() {
+        use crate::query::QueryParams;
+
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        for i in 1..=10u64 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        let query = vec![5.0, 0.0, 0.0, 0.0];
+        // Note: filters cannot be cloned, so QueryParams with filters
+        // won't preserve the filter when cloned for batch_iterator
+        let query_params = QueryParams::new().with_ef_runtime(50);
+
+        // The batch iterator should work with params
+        let mut iter = index.batch_iterator(&query, Some(&query_params)).unwrap();
+        let mut all_results = Vec::new();
+        while let Some(batch) = iter.next_batch(100) {
+            all_results.extend(batch);
+        }
+
+        // Should have all 10 results
+        assert_eq!(all_results.len(), 10);
+    }
+        let results_default = index.range_query(&query, radius, None).unwrap();
+        let expected_labels: Vec<u64> = (0..=5).collect();
+        assert_eq!(results_default.len(), expected_labels.len());
+
+        // Test with custom epsilon (0.01 - same as default)
+        let query_params = QueryParams::new().with_epsilon(0.01);
+        let results_eps_001 = index.range_query(&query, radius, Some(&query_params)).unwrap();
+        assert_eq!(results_eps_001.len(), expected_labels.len());
+
+        // Test with larger epsilon (1.0 = 100% expansion)
+        let query_params_large = QueryParams::new().with_epsilon(1.0);
+        let results_eps_100 = index.range_query(&query, radius, Some(&query_params_large)).unwrap();
+        assert_eq!(results_eps_100.len(), expected_labels.len());
+    }
+
+    #[test]
+    fn test_hnsw_multi_range_query_empty_result() {
+        let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20);
+        let mut index = HnswMulti::<f32>::new(params);
+
+        // Add vectors far from origin
+        index.add_vector(&vec![100.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![200.0, 0.0, 0.0, 0.0], 2).unwrap();
+
+        let query = vec![0.0, 0.0, 0.0, 0.0];
+        // Very small radius should return no results
+        let results = index.range_query(&query, 1.0, None).unwrap();
+        assert_eq!(results.len(), 0);
+    }
+}
diff --git a/rust/vecsim/src/index/hnsw/search.rs b/rust/vecsim/src/index/hnsw/search.rs
new file mode 100644
index 000000000..0b905db78
--- /dev/null
+++ b/rust/vecsim/src/index/hnsw/search.rs
@@ -0,0 +1,1186 @@
+//! Search algorithms for HNSW graph traversal.
+//!
+//! This module provides the core search algorithms:
+//! - `greedy_search`: Single entry point search (for upper layers)
+//! - `search_layer`: Full layer search with candidate exploration
+
+use super::concurrent_graph::ConcurrentGraph;
+use super::graph::ElementGraphData;
+use super::visited::VisitedNodesHandler;
+use crate::distance::DistanceFunction;
+use crate::types::{DistanceType, IdType, VectorElement};
+use crate::utils::prefetch::prefetch_slice;
+use crate::utils::{MaxHeap, MinHeap};
+
+/// Trait for graph access abstraction.
+///
+/// This allows search functions to work with both slice-based graphs
+/// (for tests) and the concurrent graph structure (for production).
+pub trait GraphAccess {
+    /// Get an element by ID.
+    fn get(&self, id: IdType) -> Option<&ElementGraphData>;
+
+    /// Check if an element is marked as deleted.
+    /// Default implementation uses get(), but can be overridden for efficiency.
+    #[inline]
+    fn is_deleted(&self, id: IdType) -> bool {
+        self.get(id).is_some_and(|e| e.meta.deleted)
+    }
+}
+
+/// Implementation for slice-based graphs (used in tests).
+impl GraphAccess for [Option<ElementGraphData>] {
+    #[inline]
+    fn get(&self, id: IdType) -> Option<&ElementGraphData> {
+        <[Option<ElementGraphData>]>::get(self, id as usize).and_then(|x| x.as_ref())
+    }
+}
+
+/// Implementation for Vec-based graphs (used in tests).
+impl GraphAccess for Vec<Option<ElementGraphData>> {
+    #[inline]
+    fn get(&self, id: IdType) -> Option<&ElementGraphData> {
+        self.as_slice().get(id as usize).and_then(|x| x.as_ref())
+    }
+}
+
+/// Implementation for ConcurrentGraph.
+impl GraphAccess for ConcurrentGraph {
+    #[inline]
+    fn get(&self, id: IdType) -> Option<&ElementGraphData> {
+        ConcurrentGraph::get(self, id)
+    }
+
+    /// O(1) deleted check using separate flags array.
+    /// This avoids acquiring a segment read lock for every neighbor check.
+    #[inline]
+    fn is_deleted(&self, id: IdType) -> bool {
+        self.is_marked_deleted(id)
+    }
+}
+
+/// Result of a layer search: (id, distance) pairs.
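+// Illustrative sketch (not part of the original source): the `GraphAccess`
+// abstraction above lets the same search code run against a plain
+// `Vec<Option<ElementGraphData>>` in unit tests and against the lock-based
+// `ConcurrentGraph` in production. For example, a hypothetical helper
+//
+//     fn count_live<G: GraphAccess + ?Sized>(graph: &G, ids: &[IdType]) -> usize {
+//         ids.iter().filter(|&&id| !graph.is_deleted(id)).count()
+//     }
+//
+// works unchanged for both graph representations.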
+pub type SearchResult<D> = Vec<(IdType, D)>;
+
+
+/// Greedy search to find the single closest element at a given layer.
+///
+/// This is used to traverse upper layers where we just need to find
+/// the best entry point for the next layer.
+pub fn greedy_search<'a, T, D, F, G>(
+    entry_point: IdType,
+    query: &[T],
+    level: usize,
+    graph: &G,
+    data_getter: F,
+    dist_fn: &dyn DistanceFunction<T, D>,
+    dim: usize,
+) -> (IdType, D)
+where
+    T: VectorElement,
+    D: DistanceType,
+    F: Fn(IdType) -> Option<&'a [T]>,
+    G: GraphAccess + ?Sized,
+{
+    let mut current = entry_point;
+    let mut current_dist = if let Some(data) = data_getter(entry_point) {
+        dist_fn.compute(data, query, dim)
+    } else {
+        D::infinity()
+    };
+
+    loop {
+        let mut changed = false;
+
+        if let Some(element) = graph.get(current) {
+            // Get neighbor count without allocation
+            let neighbor_count = element.neighbor_count(level);
+
+            // Prefetch first neighbor's data
+            if neighbor_count > 0 {
+                if let Some(first_neighbor) = element.get_neighbor_at(level, 0) {
+                    if let Some(first_data) = data_getter(first_neighbor) {
+                        prefetch_slice(first_data);
+                    }
+                }
+            }
+
+            // Iterate using indexed access to avoid Vec allocation
+            for i in 0..neighbor_count {
+                let neighbor = match element.get_neighbor_at(level, i) {
+                    Some(n) => n,
+                    None => continue,
+                };
+
+                // Prefetch next neighbor's data while processing current
+                if i + 1 < neighbor_count {
+                    if let Some(next_neighbor) = element.get_neighbor_at(level, i + 1) {
+                        if let Some(next_data) = data_getter(next_neighbor) {
+                            prefetch_slice(next_data);
+                        }
+                    }
+                }
+
+                if let Some(data) = data_getter(neighbor) {
+                    let dist = dist_fn.compute(data, query, dim);
+                    if dist < current_dist {
+                        current = neighbor;
+                        current_dist = dist;
+                        changed = true;
+                    }
+                }
+            }
+        }
+
+        if !changed {
+            break;
+        }
+    }
+
+    (current, current_dist)
+}
+
+/// Search a layer to find the ef closest elements.
+///
+/// This is the main search algorithm for finding nearest neighbors
+/// at a given layer.
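+// Usage sketch (an assumption, mirroring the standard HNSW descent the doc
+// comment above describes): upper layers are traversed with `greedy_search`,
+// feeding each layer's best node into the next layer as its entry point:
+//
+//     let mut ep = (entry_point, D::infinity());
+//     for level in (1..=max_level).rev() {
+//         ep = greedy_search(ep.0, query, level, graph, &data_getter, dist_fn, dim);
+//     }
+//     // `ep` then seeds the full `search_layer` pass at level 0.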
+#[allow(clippy::too_many_arguments)]
+pub fn search_layer<'a, T, D, F, P, G>(
+    entry_points: &[(IdType, D)],
+    query: &[T],
+    level: usize,
+    ef: usize,
+    graph: &G,
+    data_getter: F,
+    dist_fn: &dyn DistanceFunction<T, D>,
+    dim: usize,
+    visited: &VisitedNodesHandler,
+    filter: Option<&P>,
+) -> SearchResult<D>
+where
+    T: VectorElement,
+    D: DistanceType,
+    F: Fn(IdType) -> Option<&'a [T]>,
+    P: Fn(IdType) -> bool + ?Sized,
+    G: GraphAccess + ?Sized,
+{
+    // Candidates to explore (min-heap: closest first)
+    let mut candidates = MinHeap::<D>::with_capacity(ef * 2);
+
+    // Results (max-heap: keeps k closest, largest at top)
+    let mut results = MaxHeap::<D>::new(ef);
+
+    // Initialize with entry points
+    for &(id, dist) in entry_points {
+        if !visited.visit(id) {
+            candidates.push(id, dist);
+
+            // Check filter for results
+            let passes = filter.is_none_or(|f| f(id));
+            if passes {
+                results.insert(id, dist);
+            }
+        }
+    }
+
+    // Explore candidates
+    while let Some(candidate) = candidates.pop() {
+        // Check if we can stop (candidate is further than worst result)
+        if results.is_full() {
+            if let Some(worst_dist) = results.top_distance() {
+                if candidate.distance > worst_dist {
+                    break;
+                }
+            }
+        }
+
+        // Get neighbors of this candidate
+        if let Some(element) = graph.get(candidate.id) {
+            if element.meta.deleted {
+                continue;
+            }
+
+            // Get neighbor count without allocation
+            let neighbor_count = element.neighbor_count(level);
+
+            // Prefetch first neighbor's data and visited tag
+            if neighbor_count > 0 {
+                if let Some(first_neighbor) = element.get_neighbor_at(level, 0) {
+                    visited.prefetch(first_neighbor);
+                    if let Some(first_data) = data_getter(first_neighbor) {
+                        prefetch_slice(first_data);
+                    }
+                }
+            }
+
+            // Iterate using indexed access to avoid Vec allocation
+            for i in 0..neighbor_count {
+                // Get current neighbor
+                let neighbor = match element.get_neighbor_at(level, i) {
+                    Some(n) => n,
+                    None => continue,
+                };
+
+                // Prefetch next neighbor's data and visited tag while processing current
+                if i + 1 < neighbor_count {
+                    if let Some(next_neighbor) = element.get_neighbor_at(level, i + 1) {
+                        visited.prefetch(next_neighbor);
+                        if let Some(next_data) = data_getter(next_neighbor) {
+                            prefetch_slice(next_data);
+                        }
+                    }
+                }
+
+                if visited.visit(neighbor) {
+                    continue; // Already visited
+                }
+
+                // Compute distance to neighbor (don't check deleted here - we check at result insertion)
+                // This matches C++ behavior: deleted nodes are still explored for graph connectivity
+                if let Some(data) = data_getter(neighbor) {
+                    let dist = dist_fn.compute(data, query, dim);
+
+                    // Check if close enough to consider
+                    let dominated = results.is_full()
+                        && dist >= results.top_distance().unwrap();
+
+                    if !dominated {
+                        // Add to candidates for exploration (even for deleted nodes)
+                        candidates.push(neighbor, dist);
+
+                        // Only add to results if not deleted and passes filter
+                        if !graph.is_deleted(neighbor) {
+                            let passes = filter.is_none_or(|f| f(neighbor));
+                            if passes {
+                                results.try_insert(neighbor, dist);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    // Convert results to vector (optimized to minimize allocations)
+    results.into_sorted_pairs()
+}
+
+use crate::types::LabelType;
+use std::collections::HashMap;
+
+/// Search result for multi-value indices: (label, distance) pairs.
+pub type MultiSearchResult<D> = Vec<(LabelType, D)>;
+
+/// Search a layer to find the k closest unique labels.
+///
+/// This is a label-aware search for multi-value indices. It explores the graph
+/// as usual while tracking labels separately: vectors from already-seen
+/// labels are still explored (their neighbors might lead to new labels) but
+/// only count once in the label results. This prevents early termination from
+/// cutting off exploration when vectors cluster by label.
+#[allow(clippy::too_many_arguments)]
+pub fn search_layer_multi<'a, T, D, F, P, G>(
+    entry_points: &[(IdType, D)],
+    query: &[T],
+    level: usize,
+    k: usize,
+    ef: usize,
+    graph: &G,
+    data_getter: F,
+    dist_fn: &dyn DistanceFunction<T, D>,
+    dim: usize,
+    visited: &VisitedNodesHandler,
+    id_to_label: &HashMap<IdType, LabelType>,
+    filter: Option<&P>,
+) -> MultiSearchResult<D>
+where
+    T: VectorElement,
+    D: DistanceType,
+    F: Fn(IdType) -> Option<&'a [T]>,
+    P: Fn(LabelType) -> bool + ?Sized,
+    G: GraphAccess + ?Sized,
+{
+    // Track labels we've found and their best distances
+    let mut label_best: HashMap<LabelType, D> = HashMap::with_capacity(k * 2);
+
+    // Candidates to explore (min-heap: closest first)
+    let mut candidates = MinHeap::<D>::with_capacity(ef * 2);
+
+    // Helper to get the worst (largest) distance among the top-ef labels
+    let get_ef_worst_dist = |label_best: &HashMap<LabelType, D>, ef: usize| -> Option<D> {
+        if label_best.len() < ef {
+            return None;
+        }
+        let mut dists: Vec<D> = label_best.values().copied().collect();
+        dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
+        dists.get(ef - 1).copied()
+    };
+
+    // Initialize with entry points
+    for &(id, dist) in entry_points {
+        if !visited.visit(id) {
+            candidates.push(id, dist);
+
+            // Check filter and update label tracking
+            if let Some(&label) = id_to_label.get(&id) {
+                let passes = filter.is_none_or(|f| f(label));
+                if passes {
+                    label_best
+                        .entry(label)
+                        .and_modify(|best| {
+                            if dist < *best {
+                                *best = dist;
+                            }
+                        })
+                        .or_insert(dist);
+                }
+            }
+        }
+    }
+
+    // Explore candidates with label-aware early termination
+    while let Some(candidate) = candidates.pop() {
+        // Early termination: stop when we have ef labels AND
+        // candidate is further than the ef-th best label distance
+        if label_best.len() >= ef {
+            if let Some(ef_worst) = get_ef_worst_dist(&label_best, ef) {
+                if candidate.distance > ef_worst {
+                    break;
+                }
+            }
+        }
+
+        // Get neighbors of this candidate
+        if let Some(element) = graph.get(candidate.id) {
+            if element.meta.deleted {
+                continue;
+            }
+
+            // Get neighbor count without allocation
+            let neighbor_count = element.neighbor_count(level);
+
+            // Prefetch first neighbor's data and visited tag
+            if neighbor_count > 0 {
+                if let Some(first_neighbor) = element.get_neighbor_at(level, 0) {
+                    visited.prefetch(first_neighbor);
+                    if let Some(first_data) = data_getter(first_neighbor) {
+                        prefetch_slice(first_data);
+                    }
+                }
+            }
+
+            // Iterate using indexed access to avoid Vec allocation
+            for i in 0..neighbor_count {
+                let neighbor = match element.get_neighbor_at(level, i) {
+                    Some(n) => n,
+                    None => continue,
+                };
+
+                // Prefetch next neighbor's data and visited tag while processing current
+                if i + 1 < neighbor_count {
+                    if let Some(next_neighbor) = element.get_neighbor_at(level, i + 1) {
+                        visited.prefetch(next_neighbor);
+                        if let Some(next_data) = data_getter(next_neighbor) {
+                            prefetch_slice(next_data);
+                        }
+                    }
+                }
+
+                if visited.visit(neighbor) {
+                    continue; // Already visited
+                }
+
+                // Compute distance to neighbor (don't check deleted here - we check at result insertion)
+                if let Some(data) = data_getter(neighbor) {
+                    let dist = dist_fn.compute(data, query, dim);
+
+                    // Prune exploration only once we already have ef labels and this
+                    // distance is no better than the ef-th best label distance
+                    let should_explore = if label_best.len() >= ef {
+                        get_ef_worst_dist(&label_best, ef)
+                            .map_or(true, |worst| dist < worst)
+                    } else {
+                        true
+                    };
+
+                    if should_explore {
+                        candidates.push(neighbor, dist);
+                    }
+
+                    // Only update label tracking if not deleted
+                    if !graph.is_deleted(neighbor) {
+                        if let Some(&label) = id_to_label.get(&neighbor) {
+                            let passes = filter.is_none_or(|f| f(label));
+                            if passes {
+                                label_best
+                                    .entry(label)
+                                    .and_modify(|best| {
+                                        if dist < *best {
+                                            *best = dist;
+                                        }
+                                    })
+                                    .or_insert(dist);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    // Convert to sorted vector of (label, distance)
+    let mut results_vec: Vec<_> = label_best.into_iter().collect();
+    results_vec.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
+    results_vec.truncate(k);
+    results_vec
+}
+
+/// Range search on a layer to find all elements within a radius.
+///
+/// This implements the epsilon-neighborhood search algorithm from the C++ implementation:
+/// - Uses a dynamic range that shrinks as closer candidates are found
+/// - Terminates early when no candidates are within range * (1 + epsilon)
+/// - Returns all vectors within the radius
+///
+/// # Arguments
+/// * `entry_points` - Initial entry points with their distances
+/// * `query` - Query vector
+/// * `level` - Graph level to search (typically 0 for range queries)
+/// * `radius` - Maximum distance for results
+/// * `epsilon` - Expansion factor for search boundaries (e.g., 0.01 = 1% expansion)
+/// * `graph` - Graph structure
+/// * `data_getter` - Function to get vector data by ID
+/// * `dist_fn` - Distance function
+/// * `dim` - Vector dimension
+/// * `visited` - Visited nodes handler for deduplication
+/// * `filter` - Optional filter function for results
+#[allow(clippy::too_many_arguments)]
+pub fn search_layer_range<'a, T, D, F, P, G>(
+    entry_points: &[(IdType, D)],
+    query: &[T],
+    level: usize,
+    radius: D,
+    epsilon: f64,
+    graph: &G,
+    data_getter: F,
+    dist_fn: &dyn DistanceFunction<T, D>,
+    dim: usize,
+    visited: &VisitedNodesHandler,
+    filter: Option<&P>,
+) -> SearchResult<D>
+where
+    T: VectorElement,
+    D: DistanceType,
+    F: Fn(IdType) -> Option<&'a [T]>,
+    P: Fn(IdType) -> bool + ?Sized,
+    G: GraphAccess + ?Sized,
+{
+    // Results container
+    let mut results: Vec<(IdType, D)> = Vec::new();
+
+    // Candidates to explore (min-heap: the closest candidate is popped first)
+    let mut candidates = MinHeap::<D>::with_capacity(256);
+
+    // Initialize dynamic range based on entry point
+    let mut dynamic_range = D::infinity();
+
+    // Initialize with entry points
+    for &(id, dist) in entry_points {
+        if !visited.visit(id) {
+            candidates.push(id, dist);
+
+            // Check if entry point is within radius
+            let passes_filter = filter.is_none_or(|f| f(id));
+            if passes_filter && dist.to_f64() <= radius.to_f64() {
+                results.push((id, dist));
+            }
+
+            // Update dynamic range
+            if dist.to_f64() < dynamic_range.to_f64() {
+                dynamic_range = dist;
+            }
+        }
+    }
+
+    // Ensure dynamic_range >= radius (we need to explore at least to the radius)
+    if dynamic_range.to_f64() < radius.to_f64() {
+        dynamic_range = radius;
+    }
+
+    // Search boundary includes epsilon expansion
+    let compute_boundary = |range: D| -> f64 { range.to_f64() * (1.0 + epsilon) };
+
+    // Explore candidates
+    // Compute initial boundary (matching C++ behavior: boundary is computed BEFORE shrinking)
+    let mut current_boundary = compute_boundary(dynamic_range);
+
+    while let Some(candidate) = candidates.pop() {
+        // Early termination: stop if best candidate is outside the dynamic search boundary
+        if candidate.distance.to_f64() > current_boundary {
+            break;
+        }
+
+        // Shrink dynamic range if this candidate is closer but still >= radius
+        // Update boundary AFTER shrinking (matching C++ behavior)
+        let cand_dist = candidate.distance.to_f64();
+        if cand_dist < dynamic_range.to_f64() && cand_dist >= radius.to_f64() {
+            dynamic_range = candidate.distance;
+            current_boundary = compute_boundary(dynamic_range);
+        }
+
+        // Get neighbors of this candidate
+        if let Some(element) = graph.get(candidate.id) {
+            if element.meta.deleted {
+                continue;
+            }
+
+            // Get neighbor count without allocation
+            let neighbor_count = element.neighbor_count(level);
+
+            // Prefetch first neighbor's data and visited tag
+            if neighbor_count > 0 {
+                if let Some(first_neighbor) = element.get_neighbor_at(level, 0) {
+                    visited.prefetch(first_neighbor);
+                    if let Some(first_data) = data_getter(first_neighbor) {
+                        prefetch_slice(first_data);
+                    }
+                }
+            }
+
+            // Iterate using indexed access to avoid Vec allocation
+            for i in 0..neighbor_count {
+                let neighbor = match element.get_neighbor_at(level, i) {
+                    Some(n) => n,
+                    None => continue,
+                };
+
+                // Prefetch next neighbor's data and visited tag while processing current
+                if i + 1 < neighbor_count {
+                    if let Some(next_neighbor) = element.get_neighbor_at(level, i + 1) {
+                        visited.prefetch(next_neighbor);
+                        if let Some(next_data) = data_getter(next_neighbor) {
+                            prefetch_slice(next_data);
+                        }
+                    }
+                }
+
+                if visited.visit(neighbor) {
+                    continue; // Already visited
+                }
+
+                // Compute distance to neighbor (don't check deleted here - we check at result insertion)
+                if let Some(data) = data_getter(neighbor) {
+                    let dist = dist_fn.compute(data, query, dim);
+                    let dist_f64 = dist.to_f64();
+
+                    // Add to candidates if within dynamic search boundary (even for deleted nodes)
+                    if dist_f64 < current_boundary {
+                        candidates.push(neighbor, dist);
+
+                        // Only add to results if not deleted, within radius, and passes filter
+                        if dist_f64 <= radius.to_f64() && !graph.is_deleted(neighbor) {
+                            let passes = filter.is_none_or(|f| f(neighbor));
+                            if passes {
+                                results.push((neighbor, dist));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    // Sort results by distance
+    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
+    results
+}
+
+/// Select neighbors using the simple heuristic (just keep closest).
+pub fn select_neighbors_simple<D: DistanceType>(candidates: &[(IdType, D)], m: usize) -> Vec<IdType> {
+    let mut sorted: Vec<_> = candidates.to_vec();
+    sorted.sort_by(|a, b| {
+        a.1.partial_cmp(&b.1)
+            .unwrap_or(std::cmp::Ordering::Equal)
+    });
+    sorted.into_iter().take(m).map(|(id, _)| id).collect()
+}
+
+/// Select neighbors using the heuristic from the HNSW paper.
+///
+/// This heuristic ensures diversity in the selected neighbors by checking
+/// that each candidate is closer to the target than to any already-selected
+/// neighbor. Uses batch distance computation for better cache efficiency.
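+// Worked micro-example (added for illustration, using the data from the tests
+// below): with the target at the origin and candidates id 1 (dist 1.0),
+// id 2 (dist 1.21), id 3 (dist 1.0), candidate 2 lies at squared distance 0.01
+// from already-selected candidate 1 -- closer to a selected neighbor than to
+// the target (0.01 < 1.21) -- so the diversity check rejects it; with
+// `keep_pruned` enabled it is re-admitted afterwards to fill up to m slots.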
+#[allow(clippy::too_many_arguments)]
+pub fn select_neighbors_heuristic<'a, T, D, F>(
+    target: IdType,
+    candidates: &[(IdType, D)],
+    m: usize,
+    data_getter: F,
+    dist_fn: &dyn DistanceFunction<T, D>,
+    dim: usize,
+    _extend_candidates: bool,
+    keep_pruned: bool,
+) -> Vec<IdType>
+where
+    T: VectorElement,
+    D: DistanceType,
+    F: Fn(IdType) -> Option<&'a [T]>,
+{
+    use crate::distance::batch::check_candidate_diversity;
+
+    if candidates.is_empty() {
+        return Vec::new();
+    }
+
+    let _target_data = match data_getter(target) {
+        Some(d) => d,
+        None => return select_neighbors_simple(candidates, m),
+    };
+
+    // Sort candidates by distance
+    let mut working: Vec<_> = candidates.to_vec();
+    working.sort_by(|a, b| {
+        a.1.partial_cmp(&b.1)
+            .unwrap_or(std::cmp::Ordering::Equal)
+    });
+
+    let mut selected: Vec<IdType> = Vec::with_capacity(m);
+    let mut pruned: Vec<(IdType, D)> = Vec::new();
+
+    for (candidate_id, candidate_dist) in working {
+        if selected.len() >= m {
+            if keep_pruned {
+                pruned.push((candidate_id, candidate_dist));
+            }
+            continue;
+        }
+
+        // Use batch diversity check - this computes distances from candidate
+        // to all selected neighbors and checks if any is closer than candidate_dist
+        let is_good = check_candidate_diversity(
+            candidate_id,
+            candidate_dist,
+            &selected,
+            &data_getter,
+            dist_fn,
+            dim,
+        );
+
+        if is_good {
+            selected.push(candidate_id);
+        } else if keep_pruned {
+            pruned.push((candidate_id, candidate_dist));
+        }
+    }
+
+    // Fill remaining slots with pruned candidates if needed
+    if keep_pruned && selected.len() < m {
+        for (id, _) in pruned {
+            if selected.len() >= m {
+                break;
+            }
+            selected.push(id);
+        }
+    }
+
+    selected
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::distance::{l2::L2Distance, DistanceFunction};
+
+    #[test]
+    fn test_select_neighbors_simple() {
+        let candidates = vec![(1, 1.0f32), (2, 0.5), (3, 2.0), (4, 0.3)];
+
+        let selected = select_neighbors_simple(&candidates, 2);
+        assert_eq!(selected.len(), 2);
+        assert_eq!(selected[0], 4); // Closest
+        assert_eq!(selected[1], 2);
+    }
+
+    #[test]
+    fn test_select_neighbors_simple_empty() {
+        let candidates: Vec<(IdType, f32)> = vec![];
+        let selected = select_neighbors_simple(&candidates, 5);
+        assert!(selected.is_empty());
+    }
+
+    #[test]
+    fn test_select_neighbors_simple_fewer_than_m() {
+        let candidates = vec![(1, 1.0f32), (2, 0.5)];
+        let selected = select_neighbors_simple(&candidates, 10);
+        assert_eq!(selected.len(), 2);
+    }
+
+    #[test]
+    fn test_select_neighbors_simple_equal_distances() {
+        let candidates = vec![(1, 1.0f32), (2, 1.0), (3, 1.0)];
+        let selected = select_neighbors_simple(&candidates, 2);
+        assert_eq!(selected.len(), 2);
+        // All have same distance, so first two in sorted order
+    }
+
+    #[test]
+    fn test_select_neighbors_heuristic_empty() {
+        let candidates: Vec<(IdType, f32)> = vec![];
+        let dist_fn = L2Distance::<f32>::new(4);
+        let data_getter = |_id: IdType| -> Option<&[f32]> { None };
+
+        let selected = select_neighbors_heuristic(
+            0,
+            &candidates,
+            5,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, f32>,
+            4,
+            false,
+            false,
+        );
+        assert!(selected.is_empty());
+    }
+
+    #[test]
+    fn test_select_neighbors_heuristic_basic() {
+        // Create simple test vectors
+        let vectors: Vec<Vec<f32>> = vec![
+            vec![0.0, 0.0, 0.0, 0.0], // target (id 0)
+            vec![1.0, 0.0, 0.0, 0.0], // id 1
+            vec![0.0, 1.0, 0.0, 0.0], // id 2
+            vec![0.0, 0.0, 1.0, 0.0], // id 3
+            vec![0.5, 0.5, 0.0, 0.0], // id 4 (between 1 and 2)
+        ];
+
+        let dist_fn = L2Distance::<f32>::new(4);
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            if (id as usize) < vectors.len() {
+                Some(&vectors[id as usize])
+            } else {
+                None
+            }
+        };
+
+        // Distances from target (0,0,0,0):
+        // id 1: 1.0 (1^2)
+        // id 2: 1.0
+        // id 3: 1.0
+        // id 4: 0.5 (0.5^2 + 0.5^2 = 0.5)
+        let candidates = vec![(1, 1.0f32), (2, 1.0f32), (3, 1.0f32), (4, 0.5f32)];
+
+        let selected = select_neighbors_heuristic(
+            0,
+            &candidates,
+            3,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, f32>,
+            4,
+            false,
+            false,
+        );
+
+        // Should select 4 (closest) and some diverse set
+        assert!(!selected.is_empty());
+        assert!(selected.len() <= 3);
+        assert_eq!(selected[0], 4); // Closest first
+    }
+
+    #[test]
+    fn test_select_neighbors_heuristic_with_keep_pruned() {
+        let vectors: Vec<Vec<f32>> = vec![
+            vec![0.0, 0.0], // target
+            vec![1.0, 0.0], // id 1
+            vec![1.1, 0.0], // id 2 (very close to 1)
+            vec![0.0, 1.0], // id 3 (diverse direction)
+        ];
+
+        let dist_fn = L2Distance::<f32>::new(2);
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            if (id as usize) < vectors.len() {
+                Some(&vectors[id as usize])
+            } else {
+                None
+            }
+        };
+
+        // id 2 is closer to id 1 than to target, so might be pruned
+        let candidates = vec![(1, 1.0f32), (2, 1.21f32), (3, 1.0f32)];
+
+        let selected = select_neighbors_heuristic(
+            0,
+            &candidates,
+            3,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, f32>,
+            2,
+            false,
+            true, // keep_pruned = true
+        );
+
+        assert_eq!(selected.len(), 3);
+    }
+
+    #[test]
+    fn test_greedy_search_single_node() {
+        let dist_fn = L2Distance::<f32>::new(4);
+        let vectors: Vec<Vec<f32>> = vec![vec![1.0, 0.0, 0.0, 0.0]];
+
+        // Create single-node graph
+        let mut graph: Vec<Option<ElementGraphData>> = Vec::new();
+        graph.push(Some(ElementGraphData::new(1, 0, 4, 2)));
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            if (id as usize) < vectors.len() {
+                Some(&vectors[id as usize])
+            } else {
+                None
+            }
+        };
+
+        let query = [1.0, 0.0, 0.0, 0.0];
+
+        let (result_id, result_dist) = greedy_search(
+            0, // entry point
+            &query,
+            0, // level
+            &graph,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, f32>,
+            4,
+        );
+
+        assert_eq!(result_id, 0);
+        assert!(result_dist < 0.001); // Should be very close to 0
+    }
+
+    #[test]
+    fn test_greedy_search_finds_closest() {
+        let dist_fn = L2Distance::<f32>::new(4);
+        let vectors: Vec<Vec<f32>> = vec![
+            vec![0.0, 0.0, 0.0, 0.0], // id 0
+            vec![1.0, 0.0, 0.0, 0.0], // id 1
+            vec![2.0, 0.0, 0.0, 0.0], // id 2
+            vec![3.0, 0.0, 0.0, 0.0], // id 3
+        ];
+
+        // Create linear graph: 0 -> 1 -> 2 -> 3
+        let mut graph: Vec<Option<ElementGraphData>> = Vec::new();
+        for i in 0..4 {
+            graph.push(Some(ElementGraphData::new(i as u64, 0, 4, 2)));
+        }
+        graph[0].as_ref().unwrap().set_neighbors(0, &[1]);
+        graph[1].as_ref().unwrap().set_neighbors(0, &[0, 2]);
+        graph[2].as_ref().unwrap().set_neighbors(0, &[1, 3]);
+        graph[3].as_ref().unwrap().set_neighbors(0, &[2]);
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            if (id as usize) < vectors.len() {
+                Some(&vectors[id as usize])
+            } else {
+                None
+            }
+        };
+
+        // Query close to id 3
+        let query = [3.0, 0.0, 0.0, 0.0];
+
+        let (result_id, result_dist) = greedy_search(
+            0, // start from entry point 0
+            &query,
+            0,
+            &graph,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, f32>,
+            4,
+        );
+
+        assert_eq!(result_id, 3);
+        assert!(result_dist < 0.001);
+    }
+
+    #[test]
+    fn test_greedy_search_invalid_entry_point() {
+        let dist_fn = L2Distance::<f32>::new(4);
+        let vectors: Vec<Vec<f32>> = vec![vec![1.0, 0.0, 0.0, 0.0]];
+
+        let graph: Vec<Option<ElementGraphData>> = vec![Some(ElementGraphData::new(1, 0, 4, 2))];
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            if (id as usize) < vectors.len() {
+                Some(&vectors[id as usize])
+            } else {
+                None
+            }
+        };
+
+        let query = [1.0, 0.0, 0.0, 0.0];
+
+        // Entry point 99 doesn't exist
+        let (result_id, result_dist) = greedy_search(
+            99,
+            &query,
+            0,
+            &graph,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, f32>,
+            4,
+        );
+
+        // Should stay at entry point with infinity distance
+        assert_eq!(result_id, 99);
+        assert!(result_dist.is_infinite());
+    }
+
+    #[test]
+    fn test_search_layer_basic() {
+        use crate::index::hnsw::VisitedNodesHandlerPool;
+
+        let dist_fn = L2Distance::<f32>::new(2);
+        let vectors: Vec<Vec<f32>> = vec![
+            vec![0.0, 0.0], // id 0
+            vec![1.0, 0.0], // id 1 - closest to query
+            vec![2.0, 0.0], // id 2
+            vec![3.0, 0.0], // id 3
+        ];
+
+        // Create connected graph
+        let mut graph: Vec<Option<ElementGraphData>> = Vec::new();
+        for i in 0..4 {
+            graph.push(Some(ElementGraphData::new(i as u64, 0, 4, 2)));
+        }
+        // Full connectivity at level 0
+        graph[0].as_ref().unwrap().set_neighbors(0, &[1, 2, 3]);
+        graph[1].as_ref().unwrap().set_neighbors(0, &[0, 2, 3]);
+        graph[2].as_ref().unwrap().set_neighbors(0, &[0, 1, 3]);
+        graph[3].as_ref().unwrap().set_neighbors(0, &[0, 1, 2]);
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            if (id as usize) < vectors.len() {
+                Some(&vectors[id as usize])
+            } else {
+                None
+            }
+        };
+
+        let pool = VisitedNodesHandlerPool::new(100);
+        let visited = pool.get();
+
+        let query = [1.0, 0.0]; // Closest to id 1
+        let entry_points = vec![(0u32, 1.0f32)]; // Start from id 0
+
+        let results = search_layer::<f32, f32, _, fn(IdType) -> bool, _>(
+            &entry_points,
+            &query,
+            0,
+            10,
+            &graph,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, f32>,
+            2,
+            &visited,
+            None,
+        );
+
+        assert!(!results.is_empty());
+        // First result should be id 1 (exact match)
+        assert_eq!(results[0].0, 1);
+        assert!(results[0].1 < 0.001);
+    }
+
+    #[test]
+    fn test_search_layer_with_filter() {
+        use crate::index::hnsw::VisitedNodesHandlerPool;
+
+        let dist_fn = L2Distance::<f32>::new(2);
+        let vectors: Vec<Vec<f32>> = vec![
+            vec![0.0, 0.0], // id 0
+            vec![1.0, 0.0], // id 1
+            vec![2.0, 0.0], // id 2
+            vec![3.0, 0.0], // id 3
+        ];
+
+        let mut graph: Vec<Option<ElementGraphData>> = Vec::new();
+        for i in 0..4 {
+            graph.push(Some(ElementGraphData::new(i as u64, 0, 4, 2)));
+        }
+        graph[0].as_ref().unwrap().set_neighbors(0, &[1, 2, 3]);
+        graph[1].as_ref().unwrap().set_neighbors(0, &[0, 2, 3]);
+        graph[2].as_ref().unwrap().set_neighbors(0, &[0, 1, 3]);
+        graph[3].as_ref().unwrap().set_neighbors(0, &[0, 1, 2]);
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            if (id as usize) < vectors.len() {
+                Some(&vectors[id as usize])
+            } else {
+                None
+            }
+        };
+
+        let pool = VisitedNodesHandlerPool::new(100);
+        let visited = pool.get();
+
+        let query = [1.0, 0.0];
+        let entry_points = vec![(0u32, 1.0f32)];
+
+        // Filter: only accept even IDs
+        let filter = |id: IdType| -> bool { id % 2 == 0 };
+
+        let results = search_layer(
+            &entry_points,
+            &query,
+            0,
+            10,
+            &graph,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, f32>,
+            2,
+            &visited,
+            Some(&filter),
+        );
+
+        // All results should have even IDs
+        for (id, _) in &results {
+            assert_eq!(id % 2, 0, "Filter should only allow even IDs");
+        }
+    }
+
+    #[test]
+    fn test_search_layer_skips_deleted() {
+        use crate::index::hnsw::VisitedNodesHandlerPool;
+
+        let dist_fn = L2Distance::<f32>::new(2);
+        let vectors: Vec<Vec<f32>> = vec![
+            vec![0.0, 0.0], // id 0
+            vec![1.0, 0.0], // id 1 - closest but deleted
+            vec![2.0, 0.0], // id 2
+        ];
+
+        let mut graph: Vec<Option<ElementGraphData>> = Vec::new();
+        for i in 0..3 {
+            graph.push(Some(ElementGraphData::new(i as u64, 0, 4, 2)));
+        }
+        graph[0].as_ref().unwrap().set_neighbors(0, &[1, 2]);
+        graph[1].as_ref().unwrap().set_neighbors(0, &[0, 2]);
+        graph[2].as_ref().unwrap().set_neighbors(0, &[0, 1]);
+
+        // Mark id 1 as deleted
+        graph[1].as_mut().unwrap().meta.deleted = true;
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            if (id as usize) < vectors.len() {
+                Some(&vectors[id as usize])
+            } else {
+                None
+            }
+        };
+
+        let pool = VisitedNodesHandlerPool::new(100);
+        let visited = pool.get();
+
+        let query = [1.0, 0.0]; // Closest to id 1 (which is deleted)
+        let entry_points = vec![(0u32, 1.0f32)];
+
+        let results = search_layer::<f32, f32, _, fn(IdType) -> bool, _>(
+            &entry_points,
+            &query,
+            0,
+            10,
+            &graph,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, f32>,
+            2,
+            &visited,
+            None,
+        );
+
+        // Should not include deleted id 1
+        for (id, _) in &results {
+            assert_ne!(*id, 1, "Deleted node should not appear in results");
+        }
+    }
+
+    #[test]
+    fn test_search_layer_empty_graph() {
+        use crate::index::hnsw::VisitedNodesHandlerPool;
+
+        let dist_fn = L2Distance::<f32>::new(2);
+        let graph: Vec<Option<ElementGraphData>> = Vec::new();
+
+        let data_getter = |_id: IdType| -> Option<&[f32]> { None };
+
+        let pool = VisitedNodesHandlerPool::new(100);
+        let visited = pool.get();
+
+        let query = [1.0, 0.0];
+        let entry_points: Vec<(IdType, f32)> = vec![];
+
+        let results = search_layer::<f32, f32, _, fn(IdType) -> bool, _>(
+            &entry_points,
+            &query,
+            0,
+            10,
+            &graph,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, f32>,
+            2,
+            &visited,
+            None,
+        );
+
+        assert!(results.is_empty());
+    }
+
+    #[test]
+    fn test_search_layer_respects_ef() {
+        use crate::index::hnsw::VisitedNodesHandlerPool;
+
+        let dist_fn = L2Distance::<f32>::new(2);
+        let vectors: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32, 0.0]).collect();
+
+        let mut graph: Vec<Option<ElementGraphData>> = Vec::new();
+        for i in 0..10 {
+            graph.push(Some(ElementGraphData::new(i as u64, 0, 10, 5)));
+        }
+
+        // Fully connected
+        for i in 0..10 {
+            let neighbors: Vec<IdType> = (0..10).filter(|&j| j != i).map(|j| j as u32).collect();
+            graph[i].as_ref().unwrap().set_neighbors(0, &neighbors);
+        }
+
+        let data_getter = |id: IdType| -> Option<&[f32]> {
+            if (id as usize) < vectors.len() {
+                Some(&vectors[id as usize])
+            } else {
+                None
+            }
+        };
+
+        let pool = VisitedNodesHandlerPool::new(100);
+        let visited = pool.get();
+
+        let query = [0.0, 0.0];
+        let entry_points = vec![(5u32, 25.0f32)];
+
+        // Set ef = 3, should return at most 3 results
+        let results = search_layer::<f32, f32, _, fn(IdType) -> bool, _>(
+            &entry_points,
+            &query,
+            0,
+            3, // ef = 3
+            &graph,
+            data_getter,
+            &dist_fn as &dyn DistanceFunction<f32, f32>,
+            2,
+            &visited,
+            None,
+        );
+
+        assert!(results.len() <= 3);
+    }
+}
diff --git a/rust/vecsim/src/index/hnsw/single.rs b/rust/vecsim/src/index/hnsw/single.rs
new file mode 100644
index 000000000..d16bcd785
--- /dev/null
+++ b/rust/vecsim/src/index/hnsw/single.rs
@@ -0,0 +1,1361 @@
+//! Single-value HNSW index implementation.
+//!
+//! This index stores one vector per label. When adding a vector with
+//! an existing label, the old vector is replaced.
+
+use super::{ElementGraphData, HnswCore, HnswParams};
+use crate::index::traits::{BatchIterator, IndexError, IndexInfo, QueryError, VecSimIndex};
+use crate::query::{QueryParams, QueryReply, QueryResult};
+use crate::types::{IdType, LabelType, VectorElement};
+use dashmap::DashMap;
+
+/// Statistics about an HNSW index.
+#[derive(Debug, Clone)]
+pub struct HnswStats {
+    /// Number of vectors in the index.
+    pub size: usize,
+    /// Number of deleted (but not yet removed) elements.
+    pub deleted_count: usize,
+    /// Maximum level in the graph.
+    pub max_level: usize,
+    /// Number of elements at each level.
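+    // Note (added for clarity): `stats()` below counts each element at every
+    // layer up to its own top level, so `level_counts[l]` is the number of
+    // non-deleted elements whose maximum level is >= l (level 0 therefore
+    // counts all non-deleted elements).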
+    pub level_counts: Vec<usize>,
+    /// Total number of connections in the graph.
+    pub total_connections: usize,
+    /// Average connections per element.
+    pub avg_connections_per_element: f64,
+    /// Approximate memory usage in bytes.
+    pub memory_bytes: usize,
+}
+
+/// Single-value HNSW index.
+///
+/// Each label has exactly one associated vector.
+///
+/// This index now supports concurrent access for parallel insertion
+/// via the `add_vector_concurrent` method.
+pub struct HnswSingle<T: VectorElement> {
+    /// Core HNSW implementation (internally thread-safe).
+    pub(crate) core: HnswCore<T>,
+    /// Label to internal ID mapping (concurrent hash map).
+    label_to_id: DashMap<LabelType, IdType>,
+    /// Internal ID to label mapping (concurrent hash map).
+    pub(crate) id_to_label: DashMap<IdType, LabelType>,
+    /// Number of vectors.
+    count: std::sync::atomic::AtomicUsize,
+    /// Maximum capacity (if set).
+    capacity: Option<usize>,
+}
+
+impl<T: VectorElement> HnswSingle<T> {
+    /// Create a new single-value HNSW index.
+    pub fn new(params: HnswParams) -> Self {
+        let initial_capacity = params.initial_capacity;
+        let core = HnswCore::new(params);
+
+        Self {
+            core,
+            label_to_id: DashMap::with_capacity(initial_capacity),
+            id_to_label: DashMap::with_capacity(initial_capacity),
+            count: std::sync::atomic::AtomicUsize::new(0),
+            capacity: None,
+        }
+    }
+
+    /// Create with a maximum capacity.
+    pub fn with_capacity(params: HnswParams, max_capacity: usize) -> Self {
+        let mut index = Self::new(params);
+        index.capacity = Some(max_capacity);
+        index
+    }
+
+    /// Process multiple queries sequentially.
+    ///
+    /// This method processes multiple queries and returns results in the same order.
+    /// For parallel processing, consider using rayon at the application level with
+    /// multiple index instances or accepting the RwLock contention overhead.
+    ///
+    /// # Arguments
+    /// * `queries` - Slice of query vectors (each must have length == dimension)
+    /// * `k` - Number of nearest neighbors to return per query
+    /// * `params` - Optional query parameters (applied to all queries)
+    ///
+    /// # Returns
+    /// A vector of QueryReply, one per input query, in the same order.
+    pub fn batch_search(
+        &self,
+        queries: &[Vec<T>],
+        k: usize,
+        params: Option<&QueryParams>,
+    ) -> Vec<Result<QueryReply<T::DistanceType>, QueryError>> {
+        queries
+            .iter()
+            .map(|query| self.top_k_query(query, k, params))
+            .collect()
+    }
+
+    /// Process multiple queries with a filter.
+    ///
+    /// Similar to `batch_search` but applies a filter function to all queries.
+ /// + /// # Arguments + /// * `queries` - Slice of query vectors + /// * `k` - Number of nearest neighbors to return per query + /// * `ef` - Search expansion factor (higher = better recall, slower) + /// * `filter` - Filter function applied to candidate labels + pub fn batch_search_filtered( + &self, + queries: &[Vec], + k: usize, + ef: usize, + filter: F, + ) -> Vec, QueryError>> + where + F: Fn(LabelType) -> bool, + { + queries + .iter() + .map(|query| { + if query.len() != self.core.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.core.params.dim, + got: query.len(), + }); + } + + // Build filter closure for this search + let id_to_label_ref = &self.id_to_label; + let filter_fn = |id: IdType| -> bool { + id_to_label_ref + .get(&id) + .is_some_and(|label_ref| filter(*label_ref)) + }; + + let results = self.core.search(query, k, ef, Some(&filter_fn)); + + // Look up labels for results + let mut reply = QueryReply::with_capacity(results.len()); + for (id, dist) in results { + if let Some(label_ref) = self.id_to_label.get(&id) { + reply.push(QueryResult::new(*label_ref, dist)); + } + } + + Ok(reply) + }) + .collect() + } + + /// Process queries from a contiguous slice of vectors. + /// + /// This is optimized for the case where queries are stored contiguously + /// in memory (e.g., from a matrix where each row is a query). + /// + /// # Arguments + /// * `query_data` - Contiguous slice containing all query vectors + /// * `num_queries` - Number of queries (query_data.len() / dim) + /// * `k` - Number of nearest neighbors per query + /// * `params` - Optional query parameters + pub fn batch_search_contiguous( + &self, + query_data: &[T], + num_queries: usize, + k: usize, + params: Option<&QueryParams>, + ) -> Vec, QueryError>> { + let dim = self.core.params.dim; + + if query_data.len() != num_queries * dim { + return vec![Err(QueryError::DimensionMismatch { + expected: num_queries * dim, + got: query_data.len(), + })]; + } + + (0..num_queries) + .map(|i| { + let query = &query_data[i * dim..(i + 1) * dim]; + self.top_k_query(query, k, params) + }) + .collect() + } + + /// Get the distance metric. + pub fn metric(&self) -> crate::distance::Metric { + self.core.params.metric + } + + /// Get the ef_runtime parameter. + pub fn ef_runtime(&self) -> usize { + self.core.params.ef_runtime + } + + /// Set the ef_runtime parameter. + pub fn set_ef_runtime(&self, ef: usize) { + // This is a race condition but acceptable for runtime parameter updates + // A more robust solution would use an AtomicUsize for ef_runtime + unsafe { + let params = &self.core.params as *const _ as *mut super::HnswParams; + (*params).ef_runtime = ef; + } + } + + /// Get the M parameter (max connections per element per layer). + pub fn m(&self) -> usize { + self.core.params.m + } + + /// Get the M_max_0 parameter (max connections at layer 0). + pub fn m_max_0(&self) -> usize { + self.core.params.m_max_0 + } + + /// Get the ef_construction parameter. + pub fn ef_construction(&self) -> usize { + self.core.params.ef_construction + } + + /// Check if heuristic neighbor selection is enabled. + pub fn is_heuristic_enabled(&self) -> bool { + self.core.params.enable_heuristic + } + + /// Get the current entry point ID (top-level node). + pub fn entry_point(&self) -> Option { + let ep = self.core.entry_point.load(std::sync::atomic::Ordering::Relaxed); + if ep == crate::types::INVALID_ID { + None + } else { + Some(ep) + } + } + + /// Get the current maximum level in the graph. 
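+// Sketch (an assumption, following the suggestion in the `set_ef_runtime`
+// comment above): the unsafe const-cast could be avoided by storing the
+// runtime parameter atomically, e.g. a hypothetical `ef_runtime: AtomicUsize`
+// field updated with
+//
+//     self.ef_runtime.store(ef, std::sync::atomic::Ordering::Relaxed);
+//
+// and read with `load(Ordering::Relaxed)` in the query path, making the
+// documented race disappear without taking a lock.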
+ pub fn max_level(&self) -> usize { + self.core.max_level.load(std::sync::atomic::Ordering::Relaxed) as usize + } + + /// Get the number of deleted (but not yet removed) elements. + pub fn deleted_count(&self) -> usize { + self.core.graph.iter() + .filter(|(_, g)| g.meta.deleted) + .count() + } + + /// Get detailed statistics about the index. + pub fn stats(&self) -> HnswStats { + let count = self.count.load(std::sync::atomic::Ordering::Relaxed); + + let mut level_counts = vec![0usize; self.core.max_level.load(std::sync::atomic::Ordering::Relaxed) as usize + 1]; + let mut total_connections = 0usize; + let mut deleted_count = 0usize; + + for (_, element) in self.core.graph.iter() { + if element.meta.deleted { + deleted_count += 1; + continue; + } + let level = element.meta.level as usize; + for l in 0..=level { + if l < level_counts.len() { + level_counts[l] += 1; + } + } + for level_link in &element.levels { + total_connections += level_link.get_neighbors().len(); + } + } + + let active_count = count.saturating_sub(deleted_count); + let avg_connections = if active_count > 0 { + total_connections as f64 / active_count as f64 + } else { + 0.0 + }; + + HnswStats { + size: count, + deleted_count, + max_level: self.core.max_level.load(std::sync::atomic::Ordering::Relaxed) as usize, + level_counts, + total_connections, + avg_connections_per_element: avg_connections, + memory_bytes: self.memory_usage(), + } + } + + /// Get the memory usage in bytes. + pub fn memory_usage(&self) -> usize { + let count = self.count.load(std::sync::atomic::Ordering::Relaxed); + + // Vector data storage + let vector_storage = count * self.core.params.dim * std::mem::size_of::(); + + // Graph structure (rough estimate) + let graph_overhead = self.core.graph.len() + * std::mem::size_of::>(); + + // Label mappings + let label_maps = self.label_to_id.len() + * std::mem::size_of::<(LabelType, IdType)>() + * 2; + + vector_storage + graph_overhead + label_maps + } + + /// Get a copy of the vector stored for a given label. + /// + /// Returns `None` if the label doesn't exist in the index. + pub fn get_vector(&self, label: LabelType) -> Option> { + let id = *self.label_to_id.get(&label)?; + self.core.data.get(id).map(|v| v.to_vec()) + } + + /// Get all labels currently in the index. + pub fn get_labels(&self) -> Vec { + self.label_to_id.iter().map(|r| *r.key()).collect() + } + + /// Compute the distance between a stored vector and a query vector. + /// + /// Returns `None` if the label doesn't exist. + pub fn compute_distance(&self, label: LabelType, query: &[T]) -> Option { + let id = *self.label_to_id.get(&label)?; + if let Some(stored) = self.core.data.get(id) { + Some(self.core.dist_fn.compute(stored, query, self.core.params.dim)) + } else { + None + } + } + + /// Get the neighbors of an element at all levels. + /// + /// Returns None if the label doesn't exist in the index. + /// Returns Some(Vec>) where each inner Vec contains the neighbor labels at that level. + pub fn get_element_neighbors(&self, label: LabelType) -> Option>> { + let id = *self.label_to_id.get(&label)?; + self.core.get_element_neighbors_by_id(id) + } + + /// Clear all vectors from the index, resetting it to empty state. 
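+// Usage sketch (illustrative, using only accessors defined above): the
+// introspection methods can drive a simple health check before deciding
+// whether an index needs maintenance:
+//
+//     let stats = index.stats();
+//     if stats.deleted_count * 2 > stats.size {
+//         // more than half the slots are tombstones; consider compacting
+//     }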
+ pub fn clear(&mut self) { + use std::sync::atomic::Ordering; + + self.core.data.clear(); + self.core.graph.clear(); + self.core.entry_point.store(crate::types::INVALID_ID, Ordering::Relaxed); + self.core.max_level.store(0, Ordering::Relaxed); + self.label_to_id.clear(); + self.id_to_label.clear(); + self.count.store(0, Ordering::Relaxed); + } + + /// Compact the index by removing gaps from deleted vectors. + /// + /// This reorganizes the internal storage and graph structure to reclaim space + /// from deleted vectors. After compaction, all vectors are stored contiguously. + /// + /// **Note**: For HNSW indices, compaction also rebuilds the graph neighbor links + /// with updated IDs, which can be computationally expensive for large indices. + /// + /// # Arguments + /// * `shrink` - If true, also release unused memory blocks + /// + /// # Returns + /// The number of bytes reclaimed (approximate). + pub fn compact(&mut self, shrink: bool) -> usize { + use std::sync::atomic::Ordering; + + let old_capacity = self.core.data.capacity(); + let id_mapping = self.core.data.compact(shrink); + + // Rebuild graph with new IDs + let mut new_graph: Vec> = (0..id_mapping.len()).map(|_| None).collect(); + + for (&old_id, &new_id) in &id_mapping { + if let Some(old_graph_data) = self.core.graph.get(old_id) { + // Clone the graph data and update neighbor IDs + let mut new_graph_data = ElementGraphData::new( + old_graph_data.meta.label, + old_graph_data.meta.level, + self.core.params.m_max_0, + self.core.params.m, + ); + new_graph_data.meta.deleted = old_graph_data.meta.deleted; + + // Update neighbor IDs in each level + for (level_idx, level_link) in old_graph_data.levels.iter().enumerate() { + let old_neighbors = level_link.get_neighbors(); + let new_neighbors: Vec = old_neighbors + .iter() + .filter_map(|&neighbor_id| id_mapping.get(&neighbor_id).copied()) + .collect(); + + if level_idx < new_graph_data.levels.len() { + new_graph_data.levels[level_idx].set_neighbors(&new_neighbors); + } + } + + new_graph[new_id as usize] = Some(new_graph_data); + } + } + + self.core.graph.replace(new_graph); + + // Update entry point + let old_entry = self.core.entry_point.load(Ordering::Relaxed); + if old_entry != crate::types::INVALID_ID { + if let Some(&new_entry) = id_mapping.get(&old_entry) { + self.core.entry_point.store(new_entry, Ordering::Relaxed); + } else { + // Entry point was deleted, find a new one + let new_entry = self.core.graph.iter() + .filter(|(_, g)| !g.meta.deleted) + .map(|(id, _)| id) + .next() + .unwrap_or(crate::types::INVALID_ID); + self.core.entry_point.store(new_entry, Ordering::Relaxed); + } + } + + // Update label_to_id mapping + for mut entry in self.label_to_id.iter_mut() { + if let Some(&new_id) = id_mapping.get(entry.value()) { + *entry.value_mut() = new_id; + } + } + + // Rebuild id_to_label mapping + self.id_to_label.clear(); + for entry in self.label_to_id.iter() { + self.id_to_label.insert(*entry.value(), *entry.key()); + } + + // Resize visited pool + if !id_mapping.is_empty() { + let max_id = id_mapping.values().max().copied().unwrap_or(0) as usize; + self.core.visited_pool.resize(max_id + 1); + } + + let new_capacity = self.core.data.capacity(); + let dim = self.core.params.dim; + let bytes_per_vector = dim * std::mem::size_of::(); + + (old_capacity.saturating_sub(new_capacity)) * bytes_per_vector + } + + /// Get the fragmentation ratio of the index. + /// + /// Returns a value between 0.0 (no fragmentation) and 1.0 (all slots are deleted). 
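+// Usage sketch (illustrative): `fragmentation()` below pairs naturally with
+// `compact()` above for periodic maintenance after heavy deletion:
+//
+//     if index.fragmentation() > 0.3 {
+//         let reclaimed = index.compact(true); // true = also shrink storage
+//         println!("compaction reclaimed ~{reclaimed} bytes");
+//     }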
+ pub fn fragmentation(&self) -> f64 { + self.core.data.fragmentation() + } + + /// Add multiple vectors at once. + /// + /// Returns the number of vectors successfully added. + /// Stops on first error and returns what was added so far. + pub fn add_vectors( + &mut self, + vectors: &[(&[T], LabelType)], + ) -> Result { + let mut added = 0; + for &(vector, label) in vectors { + self.add_vector(vector, label)?; + added += 1; + } + Ok(added) + } + + /// Add a vector with parallel-safe semantics. + /// + /// This method can be called from multiple threads simultaneously. + /// For single-value indices, if the label already exists, this returns + /// an error (unlike `add_vector` which replaces the vector). + /// + /// # Arguments + /// * `vector` - The vector to add + /// * `label` - The label for this vector + /// + /// # Returns + /// * `Ok(1)` if the vector was added successfully + /// * `Err(IndexError::DuplicateLabel)` if the label already exists + /// * `Err(IndexError::DimensionMismatch)` if vector dimension is wrong + /// * `Err(IndexError::CapacityExceeded)` if index is full + pub fn add_vector_concurrent(&self, vector: &[T], label: LabelType) -> Result { + if vector.len() != self.core.params.dim { + return Err(IndexError::DimensionMismatch { + expected: self.core.params.dim, + got: vector.len(), + }); + } + + // Check if label already exists (atomic check) + if self.label_to_id.contains_key(&label) { + return Err(IndexError::DuplicateLabel(label)); + } + + // Check capacity + if let Some(cap) = self.capacity { + if self.count.load(std::sync::atomic::Ordering::Relaxed) >= cap { + return Err(IndexError::CapacityExceeded { capacity: cap }); + } + } + + // Add the vector to storage + let id = self.core + .add_vector_concurrent(vector) + .ok_or_else(|| IndexError::Internal("Failed to add vector to storage".to_string()))?; + + // Try to insert the label mapping atomically + // DashMap's entry API provides atomic insert-if-not-exists + use dashmap::mapref::entry::Entry; + match self.label_to_id.entry(label) { + Entry::Occupied(_) => { + // Another thread inserted this label, rollback + self.core.data.mark_deleted_concurrent(id); + return Err(IndexError::DuplicateLabel(label)); + } + Entry::Vacant(entry) => { + entry.insert(id); + } + } + + // Insert into id_to_label + self.id_to_label.insert(id, label); + + // Insert into graph + self.core.insert_concurrent(id, label); + + self.count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + Ok(1) + } + + /// Batch insert multiple vectors using layer-at-a-time construction. + /// + /// This method provides significant speedup for bulk insertions by: + /// 1. Adding all vectors to storage in parallel + /// 2. Pre-assigning random levels for all vectors + /// 3. 
Building the graph layer by layer with parallel neighbor search + /// + /// # Arguments + /// * `vectors` - Slice of vector data (flattened, each vector has `dim` elements) + /// * `labels` - Slice of labels corresponding to each vector + /// * `dim` - Dimension of each vector + /// + /// # Returns + /// * `Ok(count)` - Number of vectors successfully inserted + /// * `Err` - If dimension mismatch or other errors occur + pub fn add_vectors_batch( + &self, + vectors: &[T], + labels: &[LabelType], + dim: usize, + ) -> Result { + use rayon::prelude::*; + + if dim != self.core.params.dim { + return Err(IndexError::DimensionMismatch { + expected: self.core.params.dim, + got: dim, + }); + } + + let num_vectors = labels.len(); + if vectors.len() != num_vectors * dim { + return Err(IndexError::Internal(format!( + "Vector data length {} does not match num_vectors {} * dim {}", + vectors.len(), num_vectors, dim + ))); + } + + if num_vectors == 0 { + return Ok(0); + } + + // Check capacity + if let Some(cap) = self.capacity { + let current = self.count.load(std::sync::atomic::Ordering::Relaxed); + if current + num_vectors > cap { + return Err(IndexError::CapacityExceeded { capacity: cap }); + } + } + + // Phase 1: Pre-generate random levels (single lock acquisition) + let levels = self.core.generate_random_levels_batch(num_vectors); + + // Phase 2: Add all vectors to storage and create assignments (parallel) + let assignments: Vec> = (0..num_vectors) + .into_par_iter() + .map(|i| { + let label = labels[i]; + + // Check for duplicate label + if self.label_to_id.contains_key(&label) { + return Err(IndexError::DuplicateLabel(label)); + } + + // Add vector to storage + let vec_start = i * dim; + let vec_end = vec_start + dim; + let vector = &vectors[vec_start..vec_end]; + + let id = self.core + .add_vector_concurrent(vector) + .ok_or_else(|| IndexError::Internal("Failed to add vector to storage".to_string()))?; + + // Insert label mappings + use dashmap::mapref::entry::Entry; + match self.label_to_id.entry(label) { + Entry::Occupied(_) => { + self.core.data.mark_deleted_concurrent(id); + return Err(IndexError::DuplicateLabel(label)); + } + Entry::Vacant(entry) => { + entry.insert(id); + } + } + self.id_to_label.insert(id, label); + + Ok((id, levels[i], label)) + }) + .collect(); + + // Filter successful assignments and count errors + let mut successful: Vec<(IdType, u8, LabelType)> = Vec::with_capacity(num_vectors); + let mut errors = 0; + for result in assignments { + match result { + Ok(assignment) => successful.push(assignment), + Err(_) => errors += 1, + } + } + + if successful.is_empty() { + return Ok(0); + } + + // Phase 3: Build graph using batch construction + self.core.insert_batch(&successful); + + // Update count + let inserted = successful.len(); + self.count.fetch_add(inserted, std::sync::atomic::Ordering::Relaxed); + + if errors > 0 { + // Some vectors failed but others succeeded + Ok(inserted) + } else { + Ok(inserted) + } + } +} + +impl VecSimIndex for HnswSingle { + type DataType = T; + type DistType = T::DistanceType; + + fn add_vector(&mut self, vector: &[T], label: LabelType) -> Result { + if vector.len() != self.core.params.dim { + return Err(IndexError::DimensionMismatch { + expected: self.core.params.dim, + got: vector.len(), + }); + } + + // Check if label already exists + if let Some(existing_id) = self.label_to_id.get(&label).map(|r| *r) { + // Mark old vector as deleted + self.core.mark_deleted_concurrent(existing_id); + self.id_to_label.remove(&existing_id); + + // Add new 
vector + let new_id = self.core + .add_vector_concurrent(vector) + .ok_or_else(|| IndexError::Internal("Failed to add vector to storage".to_string()))?; + self.core.insert_concurrent(new_id, label); + + // Update mappings + self.label_to_id.insert(label, new_id); + self.id_to_label.insert(new_id, label); + + return Ok(0); // Replacement, not a new vector + } + + // Check capacity + if let Some(cap) = self.capacity { + if self.count.load(std::sync::atomic::Ordering::Relaxed) >= cap { + return Err(IndexError::CapacityExceeded { capacity: cap }); + } + } + + // Add new vector + let id = self.core + .add_vector_concurrent(vector) + .ok_or_else(|| IndexError::Internal("Failed to add vector to storage".to_string()))?; + self.core.insert_concurrent(id, label); + + // Update mappings + self.label_to_id.insert(label, id); + self.id_to_label.insert(id, label); + + self.count + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + Ok(1) + } + + fn delete_vector(&mut self, label: LabelType) -> Result { + if let Some((_, id)) = self.label_to_id.remove(&label) { + self.core.mark_deleted_concurrent(id); + self.id_to_label.remove(&id); + self.count + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + Ok(1) + } else { + Err(IndexError::LabelNotFound(label)) + } + } + + fn top_k_query( + &self, + query: &[T], + k: usize, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.core.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.core.params.dim, + got: query.len(), + }); + } + + let ef = params + .and_then(|p| p.ef_runtime) + .unwrap_or(self.core.params.ef_runtime); + + // Build filter if needed - use DashMap directly instead of copying + let filter_fn: Option bool + '_>> = if let Some(p) = params { + if let Some(ref f) = p.filter { + // Use reference to DashMap directly - avoids O(n) copy + let id_to_label_ref = &self.id_to_label; + Some(Box::new(move |id: IdType| { + id_to_label_ref.get(&id).is_some_and(|label_ref| f(*label_ref)) + })) + } else { + None + } + } else { + None + }; + + let results = self.core.search(query, k, ef, filter_fn.as_ref().map(|f| f.as_ref())); + + // Look up labels for results + let mut reply = QueryReply::with_capacity(results.len()); + for (id, dist) in results { + if let Some(label_ref) = self.id_to_label.get(&id) { + reply.push(QueryResult::new(*label_ref, dist)); + } + } + + Ok(reply) + } + + fn range_query( + &self, + query: &[T], + radius: T::DistanceType, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.core.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.core.params.dim, + got: query.len(), + }); + } + + // Get epsilon from params or use default (0.01 = 1% expansion) + // This matches the C++ HNSW_DEFAULT_EPSILON value for consistent behavior. 
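+        // Worked example (added for clarity): with radius 25.0 and the default
+        // epsilon of 0.01, search_layer_range explores candidates out to a
+        // boundary of 25.0 * (1.0 + 0.01) = 25.25 before terminating, so
+        // results just inside the radius are not missed due to graph locality.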
+ let epsilon = params + .and_then(|p| p.epsilon) + .unwrap_or(0.01); + + // Build filter if needed - use DashMap directly instead of copying + let filter_fn: Option bool + '_>> = if let Some(p) = params { + if let Some(ref f) = p.filter { + // Use reference to DashMap directly - avoids O(n) copy + let id_to_label_ref = &self.id_to_label; + Some(Box::new(move |id: IdType| { + id_to_label_ref.get(&id).is_some_and(|label_ref| f(*label_ref)) + })) + } else { + None + } + } else { + None + }; + + // Use epsilon-neighborhood search for efficient range query + let results = self.core.search_range( + query, + radius, + epsilon, + filter_fn.as_ref().map(|f| f.as_ref()), + ); + + // Look up labels for results + let mut reply = QueryReply::with_capacity(results.len()); + for (id, dist) in results { + if let Some(label_ref) = self.id_to_label.get(&id) { + reply.push(QueryResult::new(*label_ref, dist)); + } + } + + // Results are already sorted by search_range, but ensure order + reply.sort_by_distance(); + Ok(reply) + } + + fn index_size(&self) -> usize { + self.count.load(std::sync::atomic::Ordering::Relaxed) + } + + fn index_capacity(&self) -> Option { + self.capacity + } + + fn dimension(&self) -> usize { + self.core.params.dim + } + + fn batch_iterator<'a>( + &'a self, + query: &[T], + params: Option<&QueryParams>, + ) -> Result + 'a>, QueryError> { + if query.len() != self.core.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.core.params.dim, + got: query.len(), + }); + } + + Ok(Box::new( + super::batch_iterator::HnswSingleBatchIterator::new(self, query.to_vec(), params.cloned()), + )) + } + + fn info(&self) -> IndexInfo { + let count = self.count.load(std::sync::atomic::Ordering::Relaxed); + + // Base overhead for the struct itself and internal data structures + let base_overhead = std::mem::size_of::() + std::mem::size_of::>(); + + IndexInfo { + size: count, + capacity: self.capacity, + dimension: self.core.params.dim, + index_type: "HnswSingle", + memory_bytes: base_overhead + + count * self.core.params.dim * std::mem::size_of::() + + self.core.graph.len() * std::mem::size_of::>() + + self.label_to_id.len() * std::mem::size_of::<(LabelType, IdType)>(), + } + } + + fn contains(&self, label: LabelType) -> bool { + self.label_to_id.contains_key(&label) + } + + fn label_count(&self, label: LabelType) -> usize { + if self.contains(label) { + 1 + } else { + 0 + } + } +} + +unsafe impl Send for HnswSingle {} +unsafe impl Sync for HnswSingle {} + +// Serialization support +impl HnswSingle { + /// Save the index to a writer. 
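+// Round-trip sketch (mirrors the serialization tests earlier in this diff):
+//
+//     let mut buffer = Vec::new();
+//     index.save(&mut buffer).unwrap();
+//     let mut cursor = std::io::Cursor::new(buffer);
+//     let loaded = HnswSingle::<f32>::load(&mut cursor).unwrap();
+//     assert_eq!(loaded.index_size(), index.index_size());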
+ pub fn save( + &self, + writer: &mut W, + ) -> crate::serialization::SerializationResult<()> { + use crate::serialization::*; + use std::sync::atomic::Ordering; + + let count = self.count.load(Ordering::Relaxed); + + // Write header + let header = IndexHeader::new( + IndexTypeId::HnswSingle, + T::data_type_id(), + self.core.params.metric, + self.core.params.dim, + count, + ); + header.write(writer)?; + + // Write HNSW-specific params + write_usize(writer, self.core.params.m)?; + write_usize(writer, self.core.params.m_max_0)?; + write_usize(writer, self.core.params.ef_construction)?; + write_usize(writer, self.core.params.ef_runtime)?; + write_u8(writer, if self.core.params.enable_heuristic { 1 } else { 0 })?; + + // Write graph metadata + let entry_point = self.core.entry_point.load(Ordering::Relaxed); + let max_level = self.core.max_level.load(Ordering::Relaxed); + write_u32(writer, entry_point)?; + write_u32(writer, max_level)?; + + // Write capacity + write_u8(writer, if self.capacity.is_some() { 1 } else { 0 })?; + if let Some(cap) = self.capacity { + write_usize(writer, cap)?; + } + + // Write label_to_id mapping + write_usize(writer, self.label_to_id.len())?; + for entry in self.label_to_id.iter() { + write_u64(writer, *entry.key())?; + write_u32(writer, *entry.value())?; + } + + // Write graph structure + let graph_len = self.core.graph.len(); + write_usize(writer, graph_len)?; + for id in 0..graph_len as u32 { + if let Some(graph_data) = self.core.graph.get(id) { + write_u8(writer, 1)?; // Present flag + + // Write metadata + write_u64(writer, graph_data.meta.label)?; + write_u8(writer, graph_data.meta.level)?; + write_u8(writer, if graph_data.meta.deleted { 1 } else { 0 })?; + + // Write levels + write_usize(writer, graph_data.levels.len())?; + for level_links in &graph_data.levels { + let neighbors = level_links.get_neighbors(); + write_usize(writer, neighbors.len())?; + for neighbor in neighbors { + write_u32(writer, neighbor)?; + } + } + + // Write vector data + if let Some(vector) = self.core.data.get(id) { + for v in vector { + v.write_to(writer)?; + } + } + } else { + write_u8(writer, 0)?; // Not present + } + } + + Ok(()) + } + + /// Load the index from a reader. + pub fn load(reader: &mut R) -> crate::serialization::SerializationResult { + use crate::serialization::*; + use super::graph::ElementGraphData; + use std::sync::atomic::Ordering; + + // Read and validate header + let header = IndexHeader::read(reader)?; + + if header.index_type != IndexTypeId::HnswSingle { + return Err(SerializationError::IndexTypeMismatch { + expected: "HnswSingle".to_string(), + got: header.index_type.as_str().to_string(), + }); + } + + if header.data_type != T::data_type_id() { + return Err(SerializationError::InvalidData( + format!("Expected {:?} data type, got {:?}", T::data_type_id(), header.data_type), + )); + } + + // Read HNSW-specific params + let m = read_usize(reader)?; + let m_max_0 = read_usize(reader)?; + let ef_construction = read_usize(reader)?; + let ef_runtime = read_usize(reader)?; + let enable_heuristic = read_u8(reader)? 
!= 0; + + // Read graph metadata + let entry_point = read_u32(reader)?; + let max_level = read_u32(reader)?; + + // Create params and index + let params = HnswParams { + dim: header.dimension, + metric: header.metric, + m, + m_max_0, + ef_construction, + ef_runtime, + initial_capacity: header.count.max(1024), + enable_heuristic, + seed: None, // Seed not preserved in serialization + }; + let mut index = Self::new(params); + + // Read capacity + let has_capacity = read_u8(reader)? != 0; + if has_capacity { + index.capacity = Some(read_usize(reader)?); + } + + // Read label_to_id mapping + let label_to_id_len = read_usize(reader)?; + for _ in 0..label_to_id_len { + let label = read_u64(reader)?; + let id = read_u32(reader)?; + index.label_to_id.insert(label, id); + index.id_to_label.insert(id, label); + } + + // Read graph structure + let graph_len = read_usize(reader)?; + let dim = header.dimension; + + // Set entry point and max level + index.core.entry_point.store(entry_point, Ordering::Relaxed); + index.core.max_level.store(max_level, Ordering::Relaxed); + + for id in 0..graph_len { + let present = read_u8(reader)? != 0; + if !present { + continue; + } + + // Read metadata + let label = read_u64(reader)?; + let level = read_u8(reader)?; + let deleted = read_u8(reader)? != 0; + + // Read levels + let num_levels = read_usize(reader)?; + let mut graph_data = ElementGraphData::new(label, level, m_max_0, m); + graph_data.meta.deleted = deleted; + + for level_idx in 0..num_levels { + let num_neighbors = read_usize(reader)?; + let mut neighbors = Vec::with_capacity(num_neighbors); + for _ in 0..num_neighbors { + neighbors.push(read_u32(reader)?); + } + if level_idx < graph_data.levels.len() { + graph_data.levels[level_idx].set_neighbors(&neighbors); + } + } + + // Read vector data + let mut vector = vec![T::zero(); dim]; + for v in &mut vector { + *v = T::read_from(reader)?; + } + + // Add vector to data storage + index.core.data.add_concurrent(&vector).ok_or_else(|| { + SerializationError::DataCorruption("Failed to add vector during deserialization".to_string()) + })?; + + // Store graph data + index.core.graph.set(id as IdType, graph_data); + } + + // Resize visited pool + if graph_len > 0 { + index.core.visited_pool.resize(graph_len); + } + + index.count.store(header.count, Ordering::Relaxed); + + Ok(index) + } + + /// Save the index to a file. + pub fn save_to_file>( + &self, + path: P, + ) -> crate::serialization::SerializationResult<()> { + let file = std::fs::File::create(path)?; + let mut writer = std::io::BufWriter::new(file); + self.save(&mut writer) + } + + /// Load the index from a file. 
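+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (marked `ignore`; the file path is hypothetical):
+    ///
+    /// ```ignore
+    /// let index = HnswSingle::<f32>::load_from_file("index.hnsw")?;
+    /// println!("loaded {} vectors", index.index_size());
+    /// ```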
+ pub fn load_from_file>( + path: P, + ) -> crate::serialization::SerializationResult { + let file = std::fs::File::open(path)?; + let mut reader = std::io::BufReader::new(file); + Self::load(&mut reader) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::Metric; + + #[test] + fn test_hnsw_single_basic() { + let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let mut index = HnswSingle::::new(params); + + let v1 = vec![1.0, 0.0, 0.0, 0.0]; + let v2 = vec![0.0, 1.0, 0.0, 0.0]; + let v3 = vec![0.0, 0.0, 1.0, 0.0]; + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + index.add_vector(&v3, 3).unwrap(); + + assert_eq!(index.index_size(), 3); + + // Query for nearest neighbor + let query = vec![1.0, 0.1, 0.0, 0.0]; + let results = index.top_k_query(&query, 2, None).unwrap(); + + assert!(!results.is_empty()); + assert_eq!(results.results[0].label, 1); // Closest to v1 + } + + #[test] + fn test_hnsw_single_large() { + let params = HnswParams::new(4, Metric::L2) + .with_m(8) + .with_ef_construction(50) + .with_ef_runtime(20); + let mut index = HnswSingle::::new(params); + + // Add many vectors + for i in 0..100 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i as u64).unwrap(); + } + + assert_eq!(index.index_size(), 100); + + // Query should find closest + let query = vec![50.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 5, None).unwrap(); + + assert_eq!(results.len(), 5); + // Exact match should be first + assert_eq!(results.results[0].label, 50); + } + + #[test] + fn test_hnsw_single_delete() { + let params = HnswParams::new(4, Metric::L2); + let mut index = HnswSingle::::new(params); + + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + assert_eq!(index.index_size(), 2); + + index.delete_vector(1).unwrap(); + assert_eq!(index.index_size(), 1); + + // Deleted vector should not appear in results + let results = index + .top_k_query(&vec![1.0, 0.0, 0.0, 0.0], 10, None) + .unwrap(); + + for result in &results.results { + assert_ne!(result.label, 1); + } + } + + #[test] + fn test_hnsw_single_serialization() { + use std::io::Cursor; + + let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let mut index = HnswSingle::::new(params); + + // Add some vectors + for i in 0..20 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i as u64).unwrap(); + } + + assert_eq!(index.index_size(), 20); + + // Serialize + let mut buffer = Vec::new(); + index.save(&mut buffer).unwrap(); + + // Deserialize + let mut cursor = Cursor::new(buffer); + let loaded = HnswSingle::::load(&mut cursor).unwrap(); + + // Verify + assert_eq!(loaded.index_size(), 20); + assert_eq!(loaded.dimension(), 4); + + // Query should work the same + let query = vec![5.0, 0.0, 0.0, 0.0]; + let results = loaded.top_k_query(&query, 3, None).unwrap(); + assert!(!results.is_empty()); + // Label 5 should be closest + assert_eq!(results.results[0].label, 5); + } + + #[test] + fn test_hnsw_single_compact() { + let params = HnswParams::new(4, Metric::L2) + .with_m(4) + .with_ef_construction(20); + let mut index = HnswSingle::::new(params); + + // Add vectors + for i in 0..10 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i as u64).unwrap(); + } + + assert_eq!(index.index_size(), 10); + + // Delete some vectors (labels 2, 4, 6, 8) + index.delete_vector(2).unwrap(); + index.delete_vector(4).unwrap(); + 
index.delete_vector(6).unwrap(); + index.delete_vector(8).unwrap(); + + assert_eq!(index.index_size(), 6); + assert!(index.fragmentation() > 0.0); + + // Compact the index + index.compact(true); + + // Verify fragmentation is gone + assert!((index.fragmentation() - 0.0).abs() < 0.01); + assert_eq!(index.index_size(), 6); + + // Verify queries still work correctly + let query = vec![5.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 3, None).unwrap(); + + assert!(!results.is_empty()); + // Label 5 should be closest + assert_eq!(results.results[0].label, 5); + + // Verify we can find other remaining vectors + let query2 = vec![0.0, 0.0, 0.0, 0.0]; + let results2 = index.top_k_query(&query2, 1, None).unwrap(); + assert_eq!(results2.results[0].label, 0); + + // Deleted labels should not appear in any query + let all_results = index.top_k_query(&query, 10, None).unwrap(); + for result in &all_results.results { + assert!(result.label != 2 && result.label != 4 && result.label != 6 && result.label != 8); + } + } + + #[test] + fn test_hnsw_single_range_query_basic() { + let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20); + let mut index = HnswSingle::::new(params); + + // Add vectors at different distances from origin + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![2.0, 0.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&vec![3.0, 0.0, 0.0, 0.0], 3).unwrap(); + index.add_vector(&vec![10.0, 0.0, 0.0, 0.0], 4).unwrap(); + + let query = vec![0.0, 0.0, 0.0, 0.0]; + // L2 squared: dist to [1,0,0,0]=1, [2,0,0,0]=4, [3,0,0,0]=9, [10,0,0,0]=100 + // Radius 10 should include labels 1, 2, 3 (distances 1, 4, 9) + let results = index.range_query(&query, 10.0, None).unwrap(); + + assert_eq!(results.len(), 3); + for r in &results.results { + assert!(r.label != 4); // label 4 should not be included (distance=100) + } + } + + #[test] + fn test_hnsw_single_range_query_with_custom_epsilon() { + use crate::query::QueryParams; + + let params = HnswParams::new(4, Metric::L2).with_m(8).with_ef_construction(50); + let mut index = HnswSingle::::new(params); + + // Add vectors at increasing distances + for i in 0..50 { + index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64).unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let radius = 25.0; // Should include labels 0-5 (distances 0, 1, 4, 9, 16, 25) + + // Test with default epsilon (0.01) + let results_default = index.range_query(&query, radius, None).unwrap(); + let expected_labels: Vec = (0..=5).collect(); + assert_eq!(results_default.len(), expected_labels.len()); + + // Test with custom epsilon (0.01 - same as default, should get same results) + let query_params = QueryParams::new().with_epsilon(0.01); + let results_eps_001 = index.range_query(&query, radius, Some(&query_params)).unwrap(); + assert_eq!(results_eps_001.len(), expected_labels.len()); + + // Test with larger epsilon (1.0 = 100% expansion) + // This should still find the same results, just potentially explore more candidates + let query_params_large = QueryParams::new().with_epsilon(1.0); + let results_eps_100 = index.range_query(&query, radius, Some(&query_params_large)).unwrap(); + // Should still find all vectors within radius + assert_eq!(results_eps_100.len(), expected_labels.len()); + } + + #[test] + fn test_hnsw_single_range_query_empty_result() { + let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20); + let mut index = HnswSingle::::new(params); + + // Add vectors far from origin + 
index.add_vector(&vec![100.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![200.0, 0.0, 0.0, 0.0], 2).unwrap(); + + let query = vec![0.0, 0.0, 0.0, 0.0]; + // Very small radius should return no results + let results = index.range_query(&query, 1.0, None).unwrap(); + assert_eq!(results.len(), 0); + } + + #[test] + fn test_hnsw_single_range_query_all_within_radius() { + let params = HnswParams::new(4, Metric::L2).with_m(4).with_ef_construction(20); + let mut index = HnswSingle::::new(params); + + // Add vectors close to origin + for i in 0..10 { + let val = (i as f32) * 0.1; + index.add_vector(&vec![val, 0.0, 0.0, 0.0], i as u64).unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + // Large radius should return all vectors + // Max distance is 0.9^2 = 0.81 + let results = index.range_query(&query, 10.0, None).unwrap(); + assert_eq!(results.len(), 10); + } +} diff --git a/rust/vecsim/src/index/hnsw/visited.rs b/rust/vecsim/src/index/hnsw/visited.rs new file mode 100644 index 000000000..dd569189b --- /dev/null +++ b/rust/vecsim/src/index/hnsw/visited.rs @@ -0,0 +1,243 @@ +//! Visited nodes tracking for HNSW graph traversal. +//! +//! This module provides efficient tracking of visited nodes during graph search: +//! - `VisitedNodesHandler`: Tag-based visited tracking (no clearing needed) +//! - `VisitedNodesHandlerPool`: Pool of handlers for concurrent searches + +use crate::types::IdType; +use parking_lot::Mutex; +use std::sync::atomic::{AtomicU32, Ordering}; + +/// A handler for tracking visited nodes during graph traversal. +/// +/// Uses a tag-based approach: instead of clearing the visited array between +/// searches, we increment a tag. A node is considered visited if its stored +/// tag matches the current tag. +pub struct VisitedNodesHandler { + /// Tag for each node (node is visited if tag matches current_tag). + tags: Vec, + /// Current search tag. + current_tag: u32, + /// Capacity (maximum number of nodes). + capacity: usize, +} + +impl VisitedNodesHandler { + /// Create a new handler with given capacity. + pub fn new(capacity: usize) -> Self { + let tags = (0..capacity).map(|_| AtomicU32::new(0)).collect(); + Self { + tags, + current_tag: 1, + capacity, + } + } + + /// Reset for a new search. O(1) operation. + #[inline] + pub fn reset(&mut self) { + self.current_tag = self.current_tag.wrapping_add(1); + // Handle wrap-around by clearing all tags + if self.current_tag == 0 { + for tag in &self.tags { + tag.store(0, Ordering::Relaxed); + } + self.current_tag = 1; + } + } + + /// Mark a node as visited. Returns true if it was already visited. + /// + /// Uses Relaxed ordering since the tags array is only used within a single + /// search operation that starts with reset(). + #[inline] + pub fn visit(&self, id: IdType) -> bool { + let idx = id as usize; + if idx >= self.capacity { + return false; + } + + // Use Relaxed ordering - this is safe because: + // 1. The tags array is reset at the start of each search + // 2. Only the current search thread modifies tags during the search + // 3. We don't need synchronization with other threads + let old = self.tags[idx].swap(self.current_tag, Ordering::Relaxed); + old == self.current_tag + } + + /// Check if a node has been visited without marking it. + #[inline] + pub fn is_visited(&self, id: IdType) -> bool { + let idx = id as usize; + if idx >= self.capacity { + return false; + } + self.tags[idx].load(Ordering::Relaxed) == self.current_tag + } + + /// Prefetch the visited tag for a node. 
+ /// + /// Call this before visit() to hide memory latency. + #[inline] + pub fn prefetch(&self, id: IdType) { + let idx = id as usize; + if idx < self.capacity { + crate::utils::prefetch::prefetch_read(self.tags[idx].as_ptr()); + } + } + + /// Get the capacity. + #[inline] + pub fn capacity(&self) -> usize { + self.capacity + } + + /// Resize the handler to accommodate more nodes. + pub fn resize(&mut self, new_capacity: usize) { + if new_capacity > self.capacity { + self.tags + .resize_with(new_capacity, || AtomicU32::new(0)); + self.capacity = new_capacity; + } + } +} + +/// Pool of visited nodes handlers for concurrent searches. +/// +/// Allows multiple threads to perform searches simultaneously by +/// checking out handlers from the pool. +pub struct VisitedNodesHandlerPool { + /// Available handlers. + handlers: Mutex>, + /// Default capacity for new handlers. + default_capacity: std::sync::atomic::AtomicUsize, +} + +impl VisitedNodesHandlerPool { + /// Create a new pool. + pub fn new(default_capacity: usize) -> Self { + Self { + handlers: Mutex::new(Vec::new()), + default_capacity: std::sync::atomic::AtomicUsize::new(default_capacity), + } + } + + /// Get a handler from the pool, creating one if necessary. + pub fn get(&self) -> PooledHandler<'_> { + let cap = self.default_capacity.load(std::sync::atomic::Ordering::Acquire); + let handler = self.handlers.lock().pop().unwrap_or_else(|| { + VisitedNodesHandler::new(cap) + }); + PooledHandler { + handler: Some(handler), + pool: self, + } + } + + /// Return a handler to the pool. + fn return_handler(&self, mut handler: VisitedNodesHandler) { + handler.reset(); + self.handlers.lock().push(handler); + } + + /// Resize all handlers in the pool and update default capacity. + pub fn resize(&self, new_capacity: usize) { + self.default_capacity.store(new_capacity, std::sync::atomic::Ordering::Release); + let mut handlers = self.handlers.lock(); + for handler in handlers.iter_mut() { + handler.resize(new_capacity); + } + } + + /// Get the current default capacity. + pub fn current_capacity(&self) -> usize { + self.default_capacity.load(std::sync::atomic::Ordering::Acquire) + } +} + +/// A handler checked out from the pool. +/// +/// Automatically returns the handler when dropped. +pub struct PooledHandler<'a> { + handler: Option, + pool: &'a VisitedNodesHandlerPool, +} + +impl<'a> PooledHandler<'a> { + /// Get a mutable reference to the handler. + #[inline] + pub fn get_mut(&mut self) -> &mut VisitedNodesHandler { + self.handler.as_mut().expect("Handler already returned") + } + + /// Get an immutable reference to the handler. 
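+    ///
+    /// `PooledHandler` also implements `Deref`/`DerefMut` to
+    /// `VisitedNodesHandler` (see below), so most call sites can use the
+    /// pooled handler directly instead of calling `get`/`get_mut`.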
+ #[inline] + pub fn get(&self) -> &VisitedNodesHandler { + self.handler.as_ref().expect("Handler already returned") + } +} + +impl<'a> Drop for PooledHandler<'a> { + fn drop(&mut self) { + if let Some(handler) = self.handler.take() { + self.pool.return_handler(handler); + } + } +} + +impl<'a> std::ops::Deref for PooledHandler<'a> { + type Target = VisitedNodesHandler; + + fn deref(&self) -> &Self::Target { + self.get() + } +} + +impl<'a> std::ops::DerefMut for PooledHandler<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.get_mut() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_visited_nodes_handler() { + let mut handler = VisitedNodesHandler::new(100); + + // First visit should return false + assert!(!handler.visit(5)); + assert!(handler.is_visited(5)); + + // Second visit should return true + assert!(handler.visit(5)); + + // Reset + handler.reset(); + assert!(!handler.is_visited(5)); + + // Can visit again + assert!(!handler.visit(5)); + } + + #[test] + fn test_pool() { + let pool = VisitedNodesHandlerPool::new(100); + + { + let h1 = pool.get(); + h1.visit(10); + assert!(h1.is_visited(10)); + } + + // Handler should be returned to pool and reset + { + let h2 = pool.get(); + // After reset, node should not be visited + // (unless tag wrapped around, which is unlikely) + assert!(!h2.is_visited(10)); + } + } +} diff --git a/rust/vecsim/src/index/mod.rs b/rust/vecsim/src/index/mod.rs new file mode 100644 index 000000000..b7105b792 --- /dev/null +++ b/rust/vecsim/src/index/mod.rs @@ -0,0 +1,123 @@ +//! Vector similarity index implementations. +//! +//! This module provides different index types for vector similarity search: +//! - `brute_force`: Linear scan over all vectors (exact results) +//! - `hnsw`: Hierarchical Navigable Small World graphs (approximate, fast) +//! - `tiered`: Two-tier index combining BruteForce frontend with HNSW backend +//! - `svs`: Single-layer Vamana graph (alternative to HNSW) +//! - `tiered_svs`: Two-tier index combining BruteForce frontend with SVS backend +//! - `disk`: Disk-based index with memory-mapped storage +//! - `debug`: Debug and introspection API + +pub mod brute_force; +pub mod debug; +pub mod disk; +pub mod hnsw; +pub mod svs; +pub mod tiered; +pub mod tiered_svs; +pub mod traits; + +// Re-export traits +pub use traits::{ + AsyncIndex, BatchIterator, BlockSizeConfigurable, GarbageCollectable, IndexError, IndexInfo, + IndexType, MemoryFittable, MultiValue, QueryError, VecSimIndex, DEFAULT_BLOCK_SIZE, +}; + +// Re-export BruteForce types +pub use brute_force::{ + BruteForceParams, BruteForceSingle, BruteForceMulti, BruteForceBatchIterator, BruteForceStats, +}; + +// Re-export HNSW types +pub use hnsw::{ + HnswParams, HnswSingle, HnswMulti, HnswBatchIterator, HnswStats, +}; + +// Re-export Tiered types +pub use tiered::{ + TieredParams, TieredSingle, TieredMulti, TieredBatchIterator, WriteMode, +}; + +// Re-export SVS types +pub use svs::{ + SvsParams, SvsSingle, SvsMulti, SvsStats, +}; + +// Re-export Tiered SVS types +pub use tiered_svs::{ + TieredSvsParams, TieredSvsSingle, TieredSvsMulti, TieredSvsBatchIterator, SvsWriteMode, +}; + +// Re-export Disk index types +pub use disk::{ + DiskIndexParams, DiskIndexSingle, DiskBackend, +}; + +/// Estimate the initial memory size for a BruteForce index. +/// +/// This estimates the memory needed before any vectors are added. +/// Uses saturating arithmetic to avoid overflow when initial_capacity is SIZE_MAX. 
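+///
+/// # Example
+///
+/// A minimal sketch (marked `ignore`; assumes the function is generic over the
+/// element type `T`, as the `size_of` calls in the body suggest):
+///
+/// ```ignore
+/// // Rough budget for a 128-dimensional f32 index sized for 10_000 vectors.
+/// let bytes = estimate_brute_force_initial_size::<f32>(128, 10_000);
+/// assert!(bytes >= 128 * 4 * 10_000); // at least the raw vector data
+/// ```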
+pub fn estimate_brute_force_initial_size(dim: usize, initial_capacity: usize) -> usize { + // Base struct overhead + let base = std::mem::size_of::>(); + // Data storage (use saturating arithmetic to avoid overflow) + let data = dim + .saturating_mul(std::mem::size_of::()) + .saturating_mul(initial_capacity); + // Label maps + let maps = initial_capacity + .saturating_mul(std::mem::size_of::<(u64, u32)>()) + .saturating_mul(2); + base.saturating_add(data).saturating_add(maps) +} + +/// Estimate the memory size per element for a BruteForce index. +pub fn estimate_brute_force_element_size(dim: usize) -> usize { + // Vector data + let vector = dim * std::mem::size_of::(); + // Label entry overhead + let label_overhead = std::mem::size_of::<(u64, u32)>() + std::mem::size_of::<(u64, bool)>(); + vector + label_overhead +} + +/// Estimate the initial memory size for an HNSW index. +/// +/// This estimates the memory needed before any vectors are added. +/// Uses saturating arithmetic to avoid overflow when initial_capacity is SIZE_MAX. +pub fn estimate_hnsw_initial_size(dim: usize, initial_capacity: usize, m: usize) -> usize { + // Base struct overhead + let base = std::mem::size_of::>(); + // Data storage (use saturating arithmetic to avoid overflow) + let data = dim + .saturating_mul(std::mem::size_of::()) + .saturating_mul(initial_capacity); + // Graph overhead per node (rough estimate: neighbors at level 0 + higher levels) + let graph = initial_capacity + .saturating_mul(m.saturating_mul(2).saturating_add(m)) + .saturating_mul(std::mem::size_of::()); + // Label maps + let maps = initial_capacity + .saturating_mul(std::mem::size_of::<(u64, u32)>()) + .saturating_mul(2); + // Visited pool + let visited = initial_capacity.saturating_mul(std::mem::size_of::()); + base.saturating_add(data) + .saturating_add(graph) + .saturating_add(maps) + .saturating_add(visited) +} + +/// Estimate the memory size per element for an HNSW index. +pub fn estimate_hnsw_element_size(dim: usize, m: usize) -> usize { + // Vector data + let vector = dim * std::mem::size_of::(); + // Graph connections (average across levels) + // Level 0 has 2*M neighbors, higher levels have M each + // Average number of levels is ~1.3 for typical M values + let avg_levels = 1.3f64; + let graph = ((m * 2) as f64 + (avg_levels * m as f64)) as usize * std::mem::size_of::(); + // Label entry overhead + let label_overhead = std::mem::size_of::<(u64, u32)>() * 2; + vector + graph + label_overhead +} diff --git a/rust/vecsim/src/index/svs/graph.rs b/rust/vecsim/src/index/svs/graph.rs new file mode 100644 index 000000000..18f631e87 --- /dev/null +++ b/rust/vecsim/src/index/svs/graph.rs @@ -0,0 +1,490 @@ +//! Flat graph data structure for Vamana index. +//! +//! Unlike HNSW's multi-layer graph, Vamana uses a single flat layer. + +use crate::types::{IdType, LabelType}; +use parking_lot::RwLock; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; + +/// Data for a single node in the Vamana graph. +pub struct VamanaGraphData { + /// External label. + pub label: AtomicU64, + /// Whether this node has been deleted. + pub deleted: AtomicBool, + /// Neighbor IDs (protected by RwLock for concurrent access). + neighbors: RwLock>, + /// Maximum number of neighbors. + max_neighbors: usize, +} + +impl VamanaGraphData { + /// Create a new graph node. 
+ pub fn new(max_neighbors: usize) -> Self { + Self { + label: AtomicU64::new(0), + deleted: AtomicBool::new(false), + neighbors: RwLock::new(Vec::with_capacity(max_neighbors)), + max_neighbors, + } + } + + /// Get the label. + #[inline] + pub fn get_label(&self) -> LabelType { + self.label.load(Ordering::Acquire) + } + + /// Set the label. + #[inline] + pub fn set_label(&self, label: LabelType) { + self.label.store(label, Ordering::Release); + } + + /// Check if deleted. + #[inline] + pub fn is_deleted(&self) -> bool { + self.deleted.load(Ordering::Acquire) + } + + /// Mark as deleted. + #[inline] + pub fn mark_deleted(&self) { + self.deleted.store(true, Ordering::Release); + } + + /// Get all neighbors. + pub fn get_neighbors(&self) -> Vec { + self.neighbors.read().clone() + } + + /// Set neighbors (replaces existing). + pub fn set_neighbors(&self, new_neighbors: &[IdType]) { + let mut guard = self.neighbors.write(); + guard.clear(); + guard.extend(new_neighbors.iter().take(self.max_neighbors).copied()); + } + + /// Clear all neighbors. + pub fn clear_neighbors(&self) { + self.neighbors.write().clear(); + } + + /// Get the number of neighbors. + pub fn neighbor_count(&self) -> usize { + self.neighbors.read().len() + } +} + +impl std::fmt::Debug for VamanaGraphData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("VamanaGraphData") + .field("label", &self.get_label()) + .field("deleted", &self.is_deleted()) + .field("neighbor_count", &self.neighbor_count()) + .finish() + } +} + +/// Flat graph for Vamana index. +pub struct VamanaGraph { + /// Per-node data. + nodes: Vec>, + /// Maximum neighbors per node. + max_neighbors: usize, +} + +impl VamanaGraph { + /// Create a new graph with given capacity. + pub fn new(initial_capacity: usize, max_neighbors: usize) -> Self { + Self { + nodes: Vec::with_capacity(initial_capacity), + max_neighbors, + } + } + + /// Ensure capacity for at least n nodes. + pub fn ensure_capacity(&mut self, n: usize) { + if n > self.nodes.len() { + self.nodes.resize_with(n, || None); + } + } + + /// Get or create node data for an ID. + fn get_or_create(&mut self, id: IdType) -> &VamanaGraphData { + let idx = id as usize; + if idx >= self.nodes.len() { + self.ensure_capacity(idx + 1); + } + if self.nodes[idx].is_none() { + self.nodes[idx] = Some(VamanaGraphData::new(self.max_neighbors)); + } + self.nodes[idx].as_ref().unwrap() + } + + /// Get node data (immutable). + pub fn get(&self, id: IdType) -> Option<&VamanaGraphData> { + let idx = id as usize; + if idx < self.nodes.len() { + self.nodes[idx].as_ref() + } else { + None + } + } + + /// Get label for an ID. + pub fn get_label(&self, id: IdType) -> LabelType { + self.get(id).map(|n| n.get_label()).unwrap_or(0) + } + + /// Set label for an ID. + pub fn set_label(&mut self, id: IdType, label: LabelType) { + self.get_or_create(id).set_label(label); + } + + /// Check if a node is deleted. + pub fn is_deleted(&self, id: IdType) -> bool { + self.get(id).map(|n| n.is_deleted()).unwrap_or(true) + } + + /// Mark a node as deleted. + pub fn mark_deleted(&mut self, id: IdType) { + if let Some(Some(node)) = self.nodes.get_mut(id as usize) { + node.mark_deleted(); + } + } + + /// Get neighbors of a node. + pub fn get_neighbors(&self, id: IdType) -> Vec { + self.get(id) + .map(|n| n.get_neighbors()) + .unwrap_or_default() + } + + /// Set neighbors of a node. 
+ pub fn set_neighbors(&mut self, id: IdType, neighbors: &[IdType]) { + self.get_or_create(id).set_neighbors(neighbors); + } + + /// Clear neighbors of a node. + pub fn clear_neighbors(&mut self, id: IdType) { + if let Some(Some(node)) = self.nodes.get_mut(id as usize) { + node.clear_neighbors(); + } + } + + /// Get the total number of nodes (including deleted). + pub fn len(&self) -> usize { + self.nodes.iter().filter(|n| n.is_some()).count() + } + + /// Check if the graph is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Get average out-degree (for stats). + pub fn average_degree(&self) -> f64 { + let active_nodes: Vec<_> = self + .nodes + .iter() + .filter_map(|n| n.as_ref()) + .filter(|n| !n.is_deleted()) + .collect(); + + if active_nodes.is_empty() { + return 0.0; + } + + let total_edges: usize = active_nodes.iter().map(|n| n.neighbor_count()).sum(); + total_edges as f64 / active_nodes.len() as f64 + } + + /// Get maximum out-degree. + pub fn max_degree(&self) -> usize { + self.nodes + .iter() + .filter_map(|n| n.as_ref()) + .filter(|n| !n.is_deleted()) + .map(|n| n.neighbor_count()) + .max() + .unwrap_or(0) + } + + /// Get minimum out-degree. + pub fn min_degree(&self) -> usize { + self.nodes + .iter() + .filter_map(|n| n.as_ref()) + .filter(|n| !n.is_deleted()) + .map(|n| n.neighbor_count()) + .min() + .unwrap_or(0) + } +} + +impl std::fmt::Debug for VamanaGraph { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("VamanaGraph") + .field("node_count", &self.len()) + .field("max_neighbors", &self.max_neighbors) + .field("avg_degree", &self.average_degree()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_vamana_graph_basic() { + let mut graph = VamanaGraph::new(10, 4); + + graph.set_label(0, 100); + graph.set_label(1, 101); + graph.set_label(2, 102); + + assert_eq!(graph.get_label(0), 100); + assert_eq!(graph.get_label(1), 101); + assert_eq!(graph.get_label(2), 102); + } + + #[test] + fn test_vamana_graph_neighbors() { + let mut graph = VamanaGraph::new(10, 4); + + graph.set_neighbors(0, &[1, 2, 3]); + graph.set_neighbors(1, &[0, 2]); + + assert_eq!(graph.get_neighbors(0), vec![1, 2, 3]); + assert_eq!(graph.get_neighbors(1), vec![0, 2]); + assert!(graph.get_neighbors(2).is_empty()); + } + + #[test] + fn test_vamana_graph_max_neighbors() { + let mut graph = VamanaGraph::new(10, 3); + + // Try to set more neighbors than max + graph.set_neighbors(0, &[1, 2, 3, 4, 5]); + + // Should be truncated + assert_eq!(graph.get_neighbors(0).len(), 3); + } + + #[test] + fn test_vamana_graph_delete() { + let mut graph = VamanaGraph::new(10, 4); + + graph.set_label(0, 100); + assert!(!graph.is_deleted(0)); + + graph.mark_deleted(0); + assert!(graph.is_deleted(0)); + } + + #[test] + fn test_vamana_graph_stats() { + let mut graph = VamanaGraph::new(10, 4); + + graph.set_neighbors(0, &[1, 2]); + graph.set_neighbors(1, &[0, 2, 3]); + graph.set_neighbors(2, &[0]); + + assert!((graph.average_degree() - 2.0).abs() < 0.01); + assert_eq!(graph.max_degree(), 3); + assert_eq!(graph.min_degree(), 1); + } + + #[test] + fn test_vamana_graph_data_new() { + let node = VamanaGraphData::new(8); + assert_eq!(node.get_label(), 0); + assert!(!node.is_deleted()); + assert_eq!(node.neighbor_count(), 0); + assert!(node.get_neighbors().is_empty()); + } + + #[test] + fn test_vamana_graph_data_label() { + let node = VamanaGraphData::new(4); + + node.set_label(12345); + assert_eq!(node.get_label(), 12345); + + 
node.set_label(u64::MAX); + assert_eq!(node.get_label(), u64::MAX); + } + + #[test] + fn test_vamana_graph_data_neighbors() { + let node = VamanaGraphData::new(4); + + node.set_neighbors(&[1, 2, 3]); + assert_eq!(node.neighbor_count(), 3); + assert_eq!(node.get_neighbors(), vec![1, 2, 3]); + + // Clear and set new neighbors + node.clear_neighbors(); + assert_eq!(node.neighbor_count(), 0); + + node.set_neighbors(&[10, 20]); + assert_eq!(node.get_neighbors(), vec![10, 20]); + } + + #[test] + fn test_vamana_graph_data_debug() { + let node = VamanaGraphData::new(4); + node.set_label(42); + node.set_neighbors(&[1, 2]); + + let debug_str = format!("{:?}", node); + assert!(debug_str.contains("VamanaGraphData")); + assert!(debug_str.contains("label")); + assert!(debug_str.contains("deleted")); + assert!(debug_str.contains("neighbor_count")); + } + + #[test] + fn test_vamana_graph_empty() { + let graph = VamanaGraph::new(10, 4); + assert!(graph.is_empty()); + assert_eq!(graph.len(), 0); + assert_eq!(graph.average_degree(), 0.0); + assert_eq!(graph.max_degree(), 0); + assert_eq!(graph.min_degree(), 0); + } + + #[test] + fn test_vamana_graph_get_nonexistent() { + let graph = VamanaGraph::new(10, 4); + + assert!(graph.get(0).is_none()); + assert!(graph.get(100).is_none()); + assert_eq!(graph.get_label(0), 0); + assert!(graph.is_deleted(100)); // Non-existent treated as deleted + assert!(graph.get_neighbors(50).is_empty()); + } + + #[test] + fn test_vamana_graph_clear_neighbors_nonexistent() { + let mut graph = VamanaGraph::new(10, 4); + // Should not panic on non-existent node + graph.clear_neighbors(0); + graph.clear_neighbors(100); + } + + #[test] + fn test_vamana_graph_mark_deleted_nonexistent() { + let mut graph = VamanaGraph::new(10, 4); + // Should not panic on non-existent node + graph.mark_deleted(0); + graph.mark_deleted(100); + } + + #[test] + fn test_vamana_graph_ensure_capacity() { + let mut graph = VamanaGraph::new(5, 4); + + // Initially capacity is 5 + graph.set_label(10, 100); + + // Should have auto-expanded + assert_eq!(graph.get_label(10), 100); + } + + #[test] + fn test_vamana_graph_deleted_not_counted_in_stats() { + let mut graph = VamanaGraph::new(10, 4); + + graph.set_neighbors(0, &[1, 2, 3, 4]); + graph.set_neighbors(1, &[0]); + graph.set_neighbors(2, &[0, 1]); + + // Before deletion + let avg_before = graph.average_degree(); + + // Delete node 0 with 4 neighbors + graph.mark_deleted(0); + + // After deletion, stats should only count non-deleted nodes + let avg_after = graph.average_degree(); + + // With node 0 deleted, we have nodes 1 and 2 with degrees 1 and 2 + assert!((avg_after - 1.5).abs() < 0.01); + assert!(avg_before > avg_after); // Should be lower after removing high-degree node + } + + #[test] + fn test_vamana_graph_debug() { + let mut graph = VamanaGraph::new(10, 4); + graph.set_neighbors(0, &[1, 2]); + graph.set_neighbors(1, &[0]); + + let debug_str = format!("{:?}", graph); + assert!(debug_str.contains("VamanaGraph")); + assert!(debug_str.contains("node_count")); + assert!(debug_str.contains("max_neighbors")); + assert!(debug_str.contains("avg_degree")); + } + + #[test] + fn test_vamana_graph_concurrent_read() { + use std::sync::Arc; + use std::thread; + + let mut graph = VamanaGraph::new(100, 10); + + // Populate graph + for i in 0..100 { + graph.set_label(i, i as u64 * 10); + let neighbors: Vec = (0..10).filter(|&j| j != i).take(5).collect(); + graph.set_neighbors(i, &neighbors); + } + + let graph = Arc::new(graph); + + // Concurrent reads + let mut handles = 
vec![];
+        for _ in 0..4 {
+            let graph_clone = Arc::clone(&graph);
+            handles.push(thread::spawn(move || {
+                for i in 0..100 {
+                    let _ = graph_clone.get_label(i);
+                    let _ = graph_clone.get_neighbors(i);
+                    let _ = graph_clone.is_deleted(i);
+                }
+            }));
+        }
+
+        for handle in handles {
+            handle.join().unwrap();
+        }
+    }
+
+    #[test]
+    fn test_vamana_graph_len_with_gaps() {
+        let mut graph = VamanaGraph::new(10, 4);
+
+        // Create nodes with gaps
+        graph.set_label(0, 100);
+        graph.set_label(5, 500);
+        graph.set_label(9, 900);
+
+        assert_eq!(graph.len(), 3);
+    }
+
+    #[test]
+    fn test_vamana_graph_data_neighbors_replace() {
+        let node = VamanaGraphData::new(4);
+
+        node.set_neighbors(&[1, 2, 3, 4]);
+        assert_eq!(node.get_neighbors(), vec![1, 2, 3, 4]);
+
+        // Replace with different neighbors
+        node.set_neighbors(&[10, 20]);
+        assert_eq!(node.get_neighbors(), vec![10, 20]);
+        assert_eq!(node.neighbor_count(), 2);
+    }
+}
diff --git a/rust/vecsim/src/index/svs/mod.rs b/rust/vecsim/src/index/svs/mod.rs
new file mode 100644
index 000000000..b3280aa4b
--- /dev/null
+++ b/rust/vecsim/src/index/svs/mod.rs
@@ -0,0 +1,450 @@
+//! SVS (Vamana) index implementation.
+//!
+//! SVS (Scalable Vector Search) is based on the Vamana algorithm, a graph-based
+//! approximate nearest neighbor index that provides high recall with efficient
+//! construction and query performance.
+//!
+//! ## Key Differences from HNSW
+//!
+//! | Aspect | HNSW | Vamana |
+//! |--------|------|--------|
+//! | Structure | Multi-layer hierarchical | Single flat layer |
+//! | Entry point | Random high-level node | Medoid (centroid) |
+//! | Construction | Random level + single pass | Two passes with alpha |
+//! | Search | Layer traversal | Greedy beam search |
+//!
+//! ## Key Parameters
+//!
+//! - `graph_max_degree` (R): Maximum number of neighbors per node (default: 32)
+//! - `alpha`: Pruning parameter for robust neighbor selection (default: 1.2)
+//! - `construction_window_size` (L): Beam width during construction (default: 200)
+//! - `search_window_size`: Beam width during search (runtime)
+//!
+//! ## Algorithm Overview
+//!
+//! ### Construction (Two-Pass)
+//!
+//! 1. Find medoid as entry point (approximate centroid)
+//! 2. Pass 1: Build initial graph with alpha = 1.0
+//! 3. Pass 2: Refine graph with configured alpha (e.g., 1.2)
+//!
+//! ### Robust Pruning (Alpha)
+//!
+//! A neighbor is selected only if:
+//! `dist(neighbor, any_selected) * alpha >= dist(neighbor, target)`
+//!
+//! This ensures neighbors are diverse (not all clustered together).
+//!
+//! ### Search
+//!
+//! Greedy beam search from medoid entry point.
+
+pub mod graph;
+pub mod search;
+pub mod single;
+pub mod multi;
+
+pub use graph::{VamanaGraph, VamanaGraphData};
+pub use single::{SvsSingle, SvsStats, SvsSingleBatchIterator};
+pub use multi::{SvsMulti, SvsMultiBatchIterator};
+
+use crate::containers::DataBlocks;
+use crate::distance::{create_distance_function, DistanceFunction, Metric};
+use crate::types::{DistanceType, IdType, LabelType, VectorElement, INVALID_ID};
+use crate::index::hnsw::VisitedNodesHandlerPool;
+use std::sync::atomic::{AtomicU32, Ordering};
+
+/// Default maximum graph degree (R).
+pub const DEFAULT_GRAPH_DEGREE: usize = 32;
+
+/// Default alpha for robust pruning.
+pub const DEFAULT_ALPHA: f32 = 1.2;
+
+/// Default construction window size (L).
+pub const DEFAULT_CONSTRUCTION_L: usize = 200;
+
+/// Default search window size.
+pub const DEFAULT_SEARCH_L: usize = 100;
+
+/// Parameters for creating an SVS (Vamana) index.
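+///
+/// # Example
+///
+/// Builder-style construction (marked `ignore`; the parameter values are
+/// illustrative, not recommendations):
+///
+/// ```ignore
+/// let params = SvsParams::new(128, Metric::L2)
+///     .with_graph_degree(64)
+///     .with_alpha(1.2)
+///     .with_construction_l(200)
+///     .with_search_l(100)
+///     .with_capacity(100_000);
+/// ```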
+#[derive(Debug, Clone)] +pub struct SvsParams { + /// Vector dimension. + pub dim: usize, + /// Distance metric. + pub metric: Metric, + /// Maximum number of neighbors per node (R). + pub graph_max_degree: usize, + /// Alpha parameter for robust pruning (1.0 = no diversity, 1.2 = moderate). + pub alpha: f32, + /// Beam width during construction (L). + pub construction_window_size: usize, + /// Default beam width during search. + pub search_window_size: usize, + /// Initial capacity (number of vectors). + pub initial_capacity: usize, + /// Use two-pass construction (recommended for better recall). + pub two_pass_construction: bool, +} + +impl SvsParams { + /// Create new parameters with required fields. + pub fn new(dim: usize, metric: Metric) -> Self { + Self { + dim, + metric, + graph_max_degree: DEFAULT_GRAPH_DEGREE, + alpha: DEFAULT_ALPHA, + construction_window_size: DEFAULT_CONSTRUCTION_L, + search_window_size: DEFAULT_SEARCH_L, + initial_capacity: 1024, + two_pass_construction: true, + } + } + + /// Set maximum graph degree (R). + pub fn with_graph_degree(mut self, r: usize) -> Self { + self.graph_max_degree = r; + self + } + + /// Set alpha parameter for robust pruning. + pub fn with_alpha(mut self, alpha: f32) -> Self { + self.alpha = alpha; + self + } + + /// Set construction window size (L). + pub fn with_construction_l(mut self, l: usize) -> Self { + self.construction_window_size = l; + self + } + + /// Set search window size. + pub fn with_search_l(mut self, l: usize) -> Self { + self.search_window_size = l; + self + } + + /// Set initial capacity. + pub fn with_capacity(mut self, capacity: usize) -> Self { + self.initial_capacity = capacity; + self + } + + /// Enable/disable two-pass construction. + pub fn with_two_pass(mut self, enable: bool) -> Self { + self.two_pass_construction = enable; + self + } +} + +/// Core SVS implementation shared between single and multi variants. +pub(crate) struct SvsCore { + /// Vector storage. + pub data: DataBlocks, + /// Graph structure. + pub graph: VamanaGraph, + /// Distance function. + pub dist_fn: Box>, + /// Medoid (entry point). + pub medoid: AtomicU32, + /// Pool of visited handlers for concurrent searches. + pub visited_pool: VisitedNodesHandlerPool, + /// Parameters. + pub params: SvsParams, +} + +impl SvsCore { + /// Create a new SVS core. + pub fn new(params: SvsParams) -> Self { + let data = DataBlocks::new(params.dim, params.initial_capacity); + let dist_fn = create_distance_function(params.metric, params.dim); + let graph = VamanaGraph::new(params.initial_capacity, params.graph_max_degree); + let visited_pool = VisitedNodesHandlerPool::new(params.initial_capacity); + + Self { + data, + graph, + dist_fn, + medoid: AtomicU32::new(INVALID_ID), + visited_pool, + params, + } + } + + /// Add a vector and return its internal ID. + pub fn add_vector(&mut self, vector: &[T]) -> Option { + let processed = self.dist_fn.preprocess(vector, self.params.dim); + self.data.add(&processed) + } + + /// Get vector data by ID. + #[inline] + pub fn get_vector(&self, id: IdType) -> Option<&[T]> { + self.data.get(id) + } + + /// Find the medoid (approximate centroid) of all vectors. 
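+    ///
+    /// Samples at most 1000 stored vectors and returns the sampled ID whose
+    /// summed distance to the other samples is smallest; returns `None` when
+    /// the index is empty.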
+ pub fn find_medoid(&self) -> Option { + let ids: Vec = self.data.iter_ids().collect(); + if ids.is_empty() { + return None; + } + if ids.len() == 1 { + return Some(ids[0]); + } + + // Sample at most 1000 vectors for medoid computation + let sample_size = ids.len().min(1000); + let sample: Vec = if ids.len() <= sample_size { + ids.clone() + } else { + use rand::seq::SliceRandom; + let mut rng = rand::thread_rng(); + let mut shuffled = ids.clone(); + shuffled.shuffle(&mut rng); + shuffled.into_iter().take(sample_size).collect() + }; + + // Find vector with minimum total distance + let mut best_id = sample[0]; + let mut best_total_dist = f64::MAX; + + for &candidate in &sample { + if let Some(candidate_data) = self.get_vector(candidate) { + let total_dist: f64 = sample + .iter() + .filter(|&&id| id != candidate) + .filter_map(|&id| self.get_vector(id)) + .map(|other_data| { + self.dist_fn + .compute(candidate_data, other_data, self.params.dim) + .to_f64() + }) + .sum(); + + if total_dist < best_total_dist { + best_total_dist = total_dist; + best_id = candidate; + } + } + } + + Some(best_id) + } + + /// Insert a new element into the graph. + pub fn insert(&mut self, id: IdType, label: LabelType) { + // Ensure graph has space + self.graph.ensure_capacity(id as usize + 1); + self.graph.set_label(id, label); + + // Update visited pool if needed + if (id as usize) >= self.visited_pool.current_capacity() { + self.visited_pool.resize(id as usize + 1024); + } + + let medoid = self.medoid.load(Ordering::Acquire); + + if medoid == INVALID_ID { + // First element becomes the medoid + self.medoid.store(id, Ordering::Release); + return; + } + + // Get query vector + let query = match self.get_vector(id) { + Some(v) => v, + None => return, + }; + + // Search for neighbors and select using robust pruning + // Use a block to limit the lifetime of `visited` before mutable operations + let selected = { + let mut visited = self.visited_pool.get(); + visited.reset(); + + let neighbors = search::greedy_beam_search( + medoid, + query, + self.params.construction_window_size, + &self.graph, + |nid| self.data.get(nid), + self.dist_fn.as_ref(), + self.params.dim, + &visited, + ); + + search::robust_prune( + id, + &neighbors, + self.params.graph_max_degree, + self.params.alpha, + |nid| self.data.get(nid), + self.dist_fn.as_ref(), + self.params.dim, + ) + }; + + // Set outgoing edges + self.graph.set_neighbors(id, &selected); + + // Add bidirectional edges (after visited is dropped) + for &neighbor_id in &selected { + self.add_bidirectional_link(neighbor_id, id); + } + } + + /// Add a bidirectional link from one node to another. 
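+    ///
+    /// If the new edge would push `from` past `graph_max_degree`, the neighbor
+    /// list of `from` is re-selected with `robust_prune`; otherwise the edge
+    /// is simply appended.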
+ fn add_bidirectional_link(&mut self, from: IdType, to: IdType) { + let mut current_neighbors = self.graph.get_neighbors(from); + if current_neighbors.contains(&to) { + return; + } + + current_neighbors.push(to); + + // Check if we need to prune + if current_neighbors.len() > self.params.graph_max_degree { + if let Some(from_data) = self.data.get(from) { + let candidates: Vec<_> = current_neighbors + .iter() + .filter_map(|&n| { + self.data.get(n).map(|data| { + let dist = self.dist_fn.compute(data, from_data, self.params.dim); + (n, dist) + }) + }) + .collect(); + + let selected = search::robust_prune( + from, + &candidates, + self.params.graph_max_degree, + self.params.alpha, + |id| self.data.get(id), + self.dist_fn.as_ref(), + self.params.dim, + ); + + self.graph.set_neighbors(from, &selected); + } + } else { + self.graph.set_neighbors(from, ¤t_neighbors); + } + } + + /// Rebuild the graph (second pass of two-pass construction). + pub fn rebuild_graph(&mut self) { + let ids: Vec = self.data.iter_ids().collect(); + + // Update medoid + if let Some(new_medoid) = self.find_medoid() { + self.medoid.store(new_medoid, Ordering::Release); + } + + // Clear existing neighbors + for &id in &ids { + self.graph.clear_neighbors(id); + } + + // Reinsert all nodes with current alpha + for &id in &ids { + let label = self.graph.get_label(id); + self.insert_without_label_update(id, label); + } + } + + /// Insert without updating the label (used during rebuild). + fn insert_without_label_update(&mut self, id: IdType, _label: LabelType) { + let medoid = self.medoid.load(Ordering::Acquire); + if medoid == INVALID_ID || medoid == id { + return; + } + + let query = match self.get_vector(id) { + Some(v) => v, + None => return, + }; + + // Use a block to limit the lifetime of `visited` before mutable operations + let selected = { + let mut visited = self.visited_pool.get(); + visited.reset(); + + let neighbors = search::greedy_beam_search( + medoid, + query, + self.params.construction_window_size, + &self.graph, + |nid| self.data.get(nid), + self.dist_fn.as_ref(), + self.params.dim, + &visited, + ); + + search::robust_prune( + id, + &neighbors, + self.params.graph_max_degree, + self.params.alpha, + |nid| self.data.get(nid), + self.dist_fn.as_ref(), + self.params.dim, + ) + }; + + self.graph.set_neighbors(id, &selected); + + for &neighbor_id in &selected { + self.add_bidirectional_link(neighbor_id, id); + } + } + + /// Mark an element as deleted. + pub fn mark_deleted(&mut self, id: IdType) { + self.graph.mark_deleted(id); + self.data.mark_deleted(id); + + // Update medoid if needed + if self.medoid.load(Ordering::Acquire) == id { + if let Some(new_medoid) = self.find_medoid() { + self.medoid.store(new_medoid, Ordering::Release); + } else { + self.medoid.store(INVALID_ID, Ordering::Release); + } + } + } + + /// Search for nearest neighbors. 
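+    ///
+    /// Runs a greedy beam search from the medoid with beam width
+    /// `search_l.max(k)` and returns at most `k` `(id, distance)` pairs;
+    /// an empty index (no medoid yet) yields an empty result.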
+ pub fn search( + &self, + query: &[T], + k: usize, + search_l: usize, + filter: Option<&dyn Fn(IdType) -> bool>, + ) -> Vec<(IdType, T::DistanceType)> { + let medoid = self.medoid.load(Ordering::Acquire); + if medoid == INVALID_ID { + return Vec::new(); + } + + let mut visited = self.visited_pool.get(); + visited.reset(); + + let results = search::greedy_beam_search_filtered( + medoid, + query, + search_l.max(k), + &self.graph, + |id| self.data.get(id), + self.dist_fn.as_ref(), + self.params.dim, + &visited, + filter, + ); + + results.into_iter().take(k).collect() + } +} diff --git a/rust/vecsim/src/index/svs/multi.rs b/rust/vecsim/src/index/svs/multi.rs new file mode 100644 index 000000000..3f05e67cc --- /dev/null +++ b/rust/vecsim/src/index/svs/multi.rs @@ -0,0 +1,1048 @@ +//! Multi-value SVS (Vamana) index implementation. +//! +//! This index allows multiple vectors per label. + +use super::{SvsCore, SvsParams}; +use crate::index::traits::{BatchIterator, IndexError, IndexInfo, QueryError, VecSimIndex}; +use crate::query::{QueryParams, QueryReply, QueryResult}; +use crate::types::{DistanceType, IdType, LabelType, VectorElement, INVALID_ID}; +use parking_lot::RwLock; +use std::collections::{HashMap, HashSet}; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Multi-value SVS (Vamana) index. +/// +/// Each label can have multiple associated vectors. +pub struct SvsMulti { + /// Core SVS implementation. + core: RwLock>, + /// Label to internal IDs mapping. + label_to_ids: RwLock>>, + /// Internal ID to label mapping. + id_to_label: RwLock>, + /// Number of vectors. + count: AtomicUsize, + /// Maximum capacity (if set). + capacity: Option, + /// Construction completed flag. + construction_done: RwLock, +} + +impl SvsMulti { + /// Create a new multi-value SVS index. + pub fn new(params: SvsParams) -> Self { + let initial_capacity = params.initial_capacity; + Self { + core: RwLock::new(SvsCore::new(params)), + label_to_ids: RwLock::new(HashMap::with_capacity(initial_capacity)), + id_to_label: RwLock::new(HashMap::with_capacity(initial_capacity)), + count: AtomicUsize::new(0), + capacity: None, + construction_done: RwLock::new(false), + } + } + + /// Create with a maximum capacity. + pub fn with_capacity(params: SvsParams, max_capacity: usize) -> Self { + let mut index = Self::new(params); + index.capacity = Some(max_capacity); + index + } + + /// Get the distance metric. + pub fn metric(&self) -> crate::distance::Metric { + self.core.read().params.metric + } + + /// Get all internal IDs for a label. + pub fn get_ids(&self, label: LabelType) -> Option> { + self.label_to_ids.read().get(&label).cloned() + } + + /// Build the index after all vectors have been added. + pub fn build(&self) { + let mut done = self.construction_done.write(); + if *done { + return; + } + + let mut core = self.core.write(); + if core.params.two_pass_construction && core.data.len() > 1 { + core.rebuild_graph(); + } + + *done = true; + } + + /// Get the number of unique labels. + pub fn unique_labels(&self) -> usize { + self.label_to_ids.read().len() + } + + /// Get all vectors for a label. + pub fn get_all_vectors(&self, label: LabelType) -> Option>> { + let ids = self.label_to_ids.read().get(&label)?.clone(); + let core = self.core.read(); + + let vectors: Vec<_> = ids + .iter() + .filter_map(|&id| core.data.get(id).map(|v| v.to_vec())) + .collect(); + + if vectors.is_empty() { + None + } else { + Some(vectors) + } + } + + /// Estimate memory usage. 
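+    ///
+    /// This is a rough estimate: it sums vector storage, graph adjacency
+    /// lists, and both label maps, rather than measuring actual allocations.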
+ pub fn memory_usage(&self) -> usize { + let core = self.core.read(); + let count = self.count.load(Ordering::Relaxed); + + let vector_size = count * core.params.dim * std::mem::size_of::(); + let graph_size = count * core.params.graph_max_degree * std::mem::size_of::(); + + let label_size: usize = self + .label_to_ids + .read() + .values() + .map(|ids| std::mem::size_of::() + ids.len() * std::mem::size_of::()) + .sum(); + + let id_label_size = self.id_to_label.read().capacity() + * (std::mem::size_of::() + std::mem::size_of::()); + + vector_size + graph_size + label_size + id_label_size + } + + /// Get the medoid (entry point) ID. + pub fn medoid(&self) -> Option { + let ep = self.core.read().medoid.load(Ordering::Relaxed); + if ep == INVALID_ID { + None + } else { + Some(ep) + } + } + + /// Get the fragmentation ratio (0.0 = none, 1.0 = all deleted). + pub fn fragmentation(&self) -> f64 { + self.core.read().data.fragmentation() + } + + /// Compact the index to reclaim space from deleted vectors. + /// + /// Note: SVS doesn't support true compaction without rebuilding. + /// This method returns 0 as a placeholder. + pub fn compact(&mut self, _shrink: bool) -> usize { + // SVS requires rebuild for true compaction + 0 + } + + /// Clear all vectors from the index. + pub fn clear(&mut self) { + let mut core = self.core.write(); + let params = core.params.clone(); + *core = SvsCore::new(params); + + self.label_to_ids.write().clear(); + self.id_to_label.write().clear(); + self.count.store(0, Ordering::Relaxed); + *self.construction_done.write() = false; + } +} + +impl VecSimIndex for SvsMulti { + type DataType = T; + type DistType = T::DistanceType; + + fn add_vector(&mut self, vector: &[T], label: LabelType) -> Result { + let mut core = self.core.write(); + + if vector.len() != core.params.dim { + return Err(IndexError::DimensionMismatch { + expected: core.params.dim, + got: vector.len(), + }); + } + + // Check capacity + if let Some(cap) = self.capacity { + if self.count.load(Ordering::Relaxed) >= cap { + return Err(IndexError::CapacityExceeded { capacity: cap }); + } + } + + // Add vector + let id = core + .add_vector(vector) + .ok_or_else(|| IndexError::Internal("Failed to add vector to storage".to_string()))?; + core.insert(id, label); + + // Update mappings + self.label_to_ids.write().entry(label).or_default().insert(id); + self.id_to_label.write().insert(id, label); + + self.count.fetch_add(1, Ordering::Relaxed); + *self.construction_done.write() = false; + + Ok(1) + } + + fn delete_vector(&mut self, label: LabelType) -> Result { + let mut core = self.core.write(); + let mut label_to_ids = self.label_to_ids.write(); + let mut id_to_label = self.id_to_label.write(); + + let ids = label_to_ids + .remove(&label) + .ok_or(IndexError::LabelNotFound(label))?; + + let count = ids.len(); + for id in ids { + core.mark_deleted(id); + id_to_label.remove(&id); + } + + self.count.fetch_sub(count, Ordering::Relaxed); + Ok(count) + } + + fn top_k_query( + &self, + query: &[T], + k: usize, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + let core = self.core.read(); + + if query.len() != core.params.dim { + return Err(QueryError::DimensionMismatch { + expected: core.params.dim, + got: query.len(), + }); + } + + if k == 0 { + return Ok(QueryReply::new()); + } + + let search_l = params + .and_then(|p| p.ef_runtime) + .unwrap_or(core.params.search_window_size); + + // Build filter if needed + let filter_fn: Option bool + '_>> = if let Some(p) = params { + if let Some(ref f) = p.filter { + // 
Use reference to RwLock directly - avoids O(n) copy + let id_to_label_ref = &self.id_to_label; + Some(Box::new(move |id: IdType| { + id_to_label_ref.read().get(&id).is_some_and(|&label| f(label)) + })) + } else { + None + } + } else { + None + }; + + // Request more results to account for duplicates when multiple vectors share labels + // We need to search wider because many results may map to the same label + let count = self.count.load(Ordering::Relaxed); + let unique_labels = self.label_to_ids.read().len(); + let avg_vectors_per_label = if unique_labels > 0 { + count / unique_labels + } else { + 1 + }; + // Expand k by the average multiplicity plus some margin to ensure we find k unique labels + let expanded_k = (k * (avg_vectors_per_label + 2)).min(count).max(k); + let results = core.search( + query, + expanded_k, + search_l.max(expanded_k), + filter_fn.as_ref().map(|f| f.as_ref()), + ); + + // Deduplicate by label, keeping the best (lowest distance) result for each label + let id_to_label = self.id_to_label.read(); + let mut best_by_label: HashMap = HashMap::with_capacity(k); + for (id, dist) in results { + if let Some(&label) = id_to_label.get(&id) { + best_by_label + .entry(label) + .and_modify(|existing| { + if dist.to_f64() < existing.to_f64() { + *existing = dist; + } + }) + .or_insert(dist); + } + } + + // Convert to reply and sort by distance + let mut reply = QueryReply::with_capacity(best_by_label.len().min(k)); + for (label, dist) in best_by_label { + reply.push(QueryResult::new(label, dist)); + } + reply.sort_by_distance(); + reply.truncate(k); + + Ok(reply) + } + + fn range_query( + &self, + query: &[T], + radius: T::DistanceType, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + let core = self.core.read(); + + if query.len() != core.params.dim { + return Err(QueryError::DimensionMismatch { + expected: core.params.dim, + got: query.len(), + }); + } + + let count = self.count.load(Ordering::Relaxed); + let search_l = params + .and_then(|p| p.ef_runtime) + .unwrap_or(core.params.search_window_size) + .max(count.min(1000)); + + // Build filter if needed + let filter_fn: Option bool + '_>> = if let Some(p) = params { + if let Some(ref f) = p.filter { + // Use reference to RwLock directly - avoids O(n) copy + let id_to_label_ref = &self.id_to_label; + Some(Box::new(move |id: IdType| { + id_to_label_ref.read().get(&id).is_some_and(|&label| f(label)) + })) + } else { + None + } + } else { + None + }; + + let results = core.search(query, count, search_l, filter_fn.as_ref().map(|f| f.as_ref())); + + // Deduplicate by label, keeping the best (lowest distance) result for each label + // and filter by radius + let id_to_label = self.id_to_label.read(); + let mut best_by_label: HashMap = HashMap::new(); + for (id, dist) in results { + if dist.to_f64() <= radius.to_f64() { + if let Some(&label) = id_to_label.get(&id) { + best_by_label + .entry(label) + .and_modify(|existing| { + if dist.to_f64() < existing.to_f64() { + *existing = dist; + } + }) + .or_insert(dist); + } + } + } + + // Convert to reply and sort by distance + let mut reply = QueryReply::with_capacity(best_by_label.len()); + for (label, dist) in best_by_label { + reply.push(QueryResult::new(label, dist)); + } + reply.sort_by_distance(); + Ok(reply) + } + + fn index_size(&self) -> usize { + self.count.load(Ordering::Relaxed) + } + + fn index_capacity(&self) -> Option { + self.capacity + } + + fn dimension(&self) -> usize { + self.core.read().params.dim + } + + fn batch_iterator<'a>( + &'a self, + query: &[T], + 
params: Option<&QueryParams>, + ) -> Result + 'a>, QueryError> { + let core = self.core.read(); + if query.len() != core.params.dim { + return Err(QueryError::DimensionMismatch { + expected: core.params.dim, + got: query.len(), + }); + } + + let search_l = params + .and_then(|p| p.ef_runtime) + .unwrap_or(core.params.search_window_size); + + let count = self.count.load(Ordering::Relaxed); + + // Build filter if needed + let filter_fn: Option bool + '_>> = if let Some(p) = params { + if let Some(ref f) = p.filter { + // Use reference to RwLock directly - avoids O(n) copy + let id_to_label_ref = &self.id_to_label; + Some(Box::new(move |id: IdType| { + id_to_label_ref.read().get(&id).is_some_and(|&label| f(label)) + })) + } else { + None + } + } else { + None + }; + + // Use the specified search_l to affect search quality + // A smaller search_l means more greedy search with potentially worse results + let raw_results = core.search( + query, + count, + search_l, + filter_fn.as_ref().map(|f| f.as_ref()), + ); + + // Batch iterator returns all vectors (not deduplicated by label) + let id_to_label = self.id_to_label.read(); + let results: Vec<_> = raw_results + .into_iter() + .filter_map(|(id, dist)| { + id_to_label.get(&id).map(|&label| (id, label, dist)) + }) + .collect(); + + Ok(Box::new(SvsMultiBatchIterator::::new(results))) + } + + fn info(&self) -> IndexInfo { + let core = self.core.read(); + let count = self.count.load(Ordering::Relaxed); + + IndexInfo { + size: count, + capacity: self.capacity, + dimension: core.params.dim, + index_type: "SvsMulti", + memory_bytes: self.memory_usage(), + } + } + + fn contains(&self, label: LabelType) -> bool { + self.label_to_ids.read().contains_key(&label) + } + + fn label_count(&self, label: LabelType) -> usize { + self.label_to_ids + .read() + .get(&label) + .map(|ids| ids.len()) + .unwrap_or(0) + } +} + +// Allow read-only concurrent access for queries +unsafe impl Send for SvsMulti {} +unsafe impl Sync for SvsMulti {} + +/// Batch iterator for SvsMulti. +pub struct SvsMultiBatchIterator { + results: Vec<(IdType, LabelType, T::DistanceType)>, + current_idx: usize, +} + +impl SvsMultiBatchIterator { + /// Create a new batch iterator with pre-computed results. 
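+    ///
+    /// The search itself runs eagerly in `batch_iterator`, so this type is
+    /// just a cursor over an already-sorted result vector: `next_batch`
+    /// slices off the next `batch_size` entries and `reset` rewinds to the
+    /// start without re-running the search.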
+ pub fn new(results: Vec<(IdType, LabelType, T::DistanceType)>) -> Self { + Self { + results, + current_idx: 0, + } + } +} + +impl BatchIterator for SvsMultiBatchIterator { + type DistType = T::DistanceType; + + fn has_next(&self) -> bool { + self.current_idx < self.results.len() + } + + fn next_batch(&mut self, batch_size: usize) -> Option> { + if self.current_idx >= self.results.len() { + return None; + } + + let end = (self.current_idx + batch_size).min(self.results.len()); + let batch = self.results[self.current_idx..end].to_vec(); + self.current_idx = end; + + if batch.is_empty() { + None + } else { + Some(batch) + } + } + + fn reset(&mut self) { + self.current_idx = 0; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::Metric; + + #[test] + fn test_svs_multi_basic() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + // Add multiple vectors per label + assert_eq!(index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(), 1); + assert_eq!(index.add_vector(&[0.9, 0.1, 0.0, 0.0], 1).unwrap(), 1); + assert_eq!(index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(), 1); + + assert_eq!(index.index_size(), 3); + assert_eq!(index.unique_labels(), 2); + assert_eq!(index.label_count(1), 2); + assert_eq!(index.label_count(2), 1); + } + + #[test] + fn test_svs_multi_delete() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.9, 0.1, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + // Delete label 1 (removes both vectors) + assert_eq!(index.delete_vector(1).unwrap(), 2); + + assert_eq!(index.unique_labels(), 1); + assert_eq!(index.label_count(1), 0); + assert_eq!(index.label_count(2), 1); + } + + #[test] + fn test_svs_multi_get_all_vectors() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.9, 0.1, 0.0, 0.0], 1).unwrap(); + + let vectors = index.get_all_vectors(1).unwrap(); + assert_eq!(vectors.len(), 2); + } + + #[test] + fn test_svs_multi_query() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + for i in 0..50 { + let mut v = vec![0.0f32; 4]; + v[i % 4] = 1.0; + index.add_vector(&v, (i / 10) as u64).unwrap(); + } + + let query = [1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 5, None).unwrap(); + + assert!(!results.results.is_empty()); + } + + #[test] + fn test_svs_multi_batch_iterator_with_filter() { + use crate::query::QueryParams; + + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + // Add vectors: 2 vectors each for labels 1-5 + for i in 1..=5u64 { + let v1 = vec![i as f32, 0.0, 0.0, 0.0]; + let v2 = vec![i as f32, 0.1, 0.0, 0.0]; + index.add_vector(&v1, i).unwrap(); + index.add_vector(&v2, i).unwrap(); + } + + assert_eq!(index.index_size(), 10); // 5 labels * 2 vectors each + + // Create a filter that only allows labels 2 and 4 + let query_params = QueryParams::new().with_filter(|label| label == 2 || label == 4); + + let query = vec![3.0, 0.0, 0.0, 0.0]; + + // Test batch iterator with filter + let mut iter = index.batch_iterator(&query, Some(&query_params)).unwrap(); + let mut all_results = Vec::new(); + while let Some(batch) = iter.next_batch(100) { + all_results.extend(batch); + } + + // Should have 4 vectors (2 for label 2, 2 for label 4) + 
assert_eq!(all_results.len(), 4); + for (_, label, _) in &all_results { + assert!( + *label == 2 || *label == 4, + "Expected only labels 2 or 4, got {}", + label + ); + } + } + + #[test] + fn test_svs_multi_batch_iterator_no_filter() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + // Add vectors: 2 vectors each for labels 1-3 + for i in 1..=3u64 { + let v1 = vec![i as f32, 0.0, 0.0, 0.0]; + let v2 = vec![i as f32, 0.1, 0.0, 0.0]; + index.add_vector(&v1, i).unwrap(); + index.add_vector(&v2, i).unwrap(); + } + + let query = vec![2.0, 0.0, 0.0, 0.0]; + + // Test batch iterator without filter + let mut iter = index.batch_iterator(&query, None).unwrap(); + let mut all_results = Vec::new(); + while let Some(batch) = iter.next_batch(100) { + all_results.extend(batch); + } + + // Should have all 6 vectors + assert_eq!(all_results.len(), 6); + } + + #[test] + fn test_svs_multi_multiple_labels() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + // Add multiple vectors per label for multiple labels + for label in 1..=5u64 { + for i in 0..3 { + let v = vec![label as f32, i as f32, 0.0, 0.0]; + index.add_vector(&v, label).unwrap(); + } + } + + assert_eq!(index.index_size(), 15); + for label in 1..=5u64 { + assert_eq!(index.label_count(label), 3); + } + assert_eq!(index.unique_labels(), 5); + } + + #[test] + fn test_svs_multi_contains() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + assert!(index.contains(1)); + assert!(!index.contains(2)); + } + + #[test] + fn test_svs_multi_dimension_mismatch() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + // Wrong dimension on add + let result = index.add_vector(&[1.0, 2.0], 1); + assert!(matches!(result, Err(IndexError::DimensionMismatch { expected: 4, got: 2 }))); + + // Add a valid vector first + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + // Wrong dimension on query + let result = index.top_k_query(&[1.0, 2.0], 1, None); + assert!(matches!(result, Err(QueryError::DimensionMismatch { expected: 4, got: 2 }))); + } + + #[test] + fn test_svs_multi_capacity_exceeded() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::with_capacity(params, 2); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[2.0, 0.0, 0.0, 0.0], 2).unwrap(); + + // Third should fail + let result = index.add_vector(&[3.0, 0.0, 0.0, 0.0], 3); + assert!(matches!(result, Err(IndexError::CapacityExceeded { capacity: 2 }))); + } + + #[test] + fn test_svs_multi_delete_not_found() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + let result = index.delete_vector(999); + assert!(matches!(result, Err(IndexError::LabelNotFound(999)))); + } + + #[test] + fn test_svs_multi_memory_usage() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + let initial_memory = index.memory_usage(); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[2.0, 0.0, 0.0, 0.0], 2).unwrap(); + + let after_memory = index.memory_usage(); + assert!(after_memory > initial_memory); + } + + #[test] + fn test_svs_multi_info() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); 
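+        // Two vectors under the same label: info().size counts vectors, not labels.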
+ index.add_vector(&[2.0, 0.0, 0.0, 0.0], 1).unwrap(); + + let info = index.info(); + assert_eq!(info.size, 2); + assert_eq!(info.dimension, 4); + assert_eq!(info.index_type, "SvsMulti"); + assert!(info.memory_bytes > 0); + } + + #[test] + fn test_svs_multi_clear() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[2.0, 0.0, 0.0, 0.0], 2).unwrap(); + + assert_eq!(index.index_size(), 2); + + index.clear(); + + assert_eq!(index.index_size(), 0); + assert_eq!(index.unique_labels(), 0); + assert!(!index.contains(1)); + assert!(!index.contains(2)); + } + + #[test] + fn test_svs_multi_with_capacity() { + let params = SvsParams::new(4, Metric::L2); + let index = SvsMulti::::with_capacity(params, 100); + + assert_eq!(index.index_capacity(), Some(100)); + } + + #[test] + fn test_svs_multi_inner_product() { + let params = SvsParams::new(4, Metric::InnerProduct); + let mut index = SvsMulti::::new(params); + + // For InnerProduct, higher dot product = lower "distance" + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + let query = [1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 2, None).unwrap(); + + // Label 1 has perfect alignment with query + assert_eq!(results.results[0].label, 1); + } + + #[test] + fn test_svs_multi_cosine() { + let params = SvsParams::new(4, Metric::Cosine); + let mut index = SvsMulti::::new(params); + + // Cosine similarity is direction-based + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&[2.0, 0.0, 0.0, 0.0], 3).unwrap(); // Same direction as label 1 + + let query = [1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 3, None).unwrap(); + + // Labels 1 and 3 have the same direction (cosine=1), should be top results + assert!(results.results[0].label == 1 || results.results[0].label == 3); + assert!(results.results[1].label == 1 || results.results[1].label == 3); + } + + #[test] + fn test_svs_multi_metric_getter() { + let params = SvsParams::new(4, Metric::Cosine); + let index = SvsMulti::::new(params); + + assert_eq!(index.metric(), Metric::Cosine); + } + + #[test] + fn test_svs_multi_range_query() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + // Add vectors at different distances + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[2.0, 0.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&[3.0, 0.0, 0.0, 0.0], 3).unwrap(); + index.add_vector(&[10.0, 0.0, 0.0, 0.0], 4).unwrap(); + + let query = [0.0, 0.0, 0.0, 0.0]; + // L2 squared: dist to [1,0,0,0]=1, [2,0,0,0]=4, [3,0,0,0]=9, [10,0,0,0]=100 + let results = index.range_query(&query, 10.0, None).unwrap(); + + assert_eq!(results.len(), 3); // labels 1, 2, 3 are within radius 10 + for r in &results.results { + assert!(r.label != 4); // label 4 should not be included + } + } + + #[test] + fn test_svs_multi_filtered_query() { + use crate::query::QueryParams; + + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + for label in 1..=10u64 { + index.add_vector(&[label as f32, 0.0, 0.0, 0.0], label).unwrap(); + } + + // Filter to only allow even labels + let query_params = QueryParams::new().with_filter(|label| label % 2 == 0); + let query = [0.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 10, Some(&query_params)).unwrap(); + + 
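+        // The filter is applied during graph traversal, so odd labels never
+        // enter the reply at all.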
+        assert_eq!(results.len(), 5); // Only labels 2, 4, 6, 8, 10
+        for r in &results.results {
+            assert!(r.label % 2 == 0);
+        }
+    }
+
+    #[test]
+    fn test_svs_multi_empty_query() {
+        let params = SvsParams::new(4, Metric::L2);
+        let index = SvsMulti::<f32>::new(params);
+
+        // Query on empty index
+        let query = [1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 10, None).unwrap();
+        assert!(results.is_empty());
+    }
+
+    #[test]
+    fn test_svs_multi_fragmentation() {
+        let params = SvsParams::new(4, Metric::L2);
+        let mut index = SvsMulti::<f32>::new(params);
+
+        for i in 1..=10u64 {
+            index.add_vector(&[i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        // Initially no fragmentation
+        assert!((index.fragmentation() - 0.0).abs() < 0.01);
+
+        // Delete half the vectors
+        for i in 1..=5u64 {
+            index.delete_vector(i).unwrap();
+        }
+
+        // Now there should be fragmentation
+        assert!(index.fragmentation() > 0.3);
+    }
+
+    #[test]
+    fn test_svs_multi_medoid() {
+        let params = SvsParams::new(4, Metric::L2);
+        let mut index = SvsMulti::<f32>::new(params);
+
+        // Empty index has no medoid
+        assert!(index.medoid().is_none());
+
+        // Add a vector
+        index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        // Now should have a medoid
+        assert!(index.medoid().is_some());
+    }
+
+    #[test]
+    fn test_svs_multi_build() {
+        let params = SvsParams::new(4, Metric::L2);
+        let mut index = SvsMulti::<f32>::new(params);
+
+        // Add vectors
+        for i in 0..20u64 {
+            index.add_vector(&[i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        // Build the index (this should complete without error)
+        index.build();
+
+        // Query should work
+        let query = [10.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 5, None).unwrap();
+        assert!(!results.is_empty());
+    }
+
+    #[test]
+    fn test_svs_multi_get_ids() {
+        let params = SvsParams::new(4, Metric::L2);
+        let mut index = SvsMulti::<f32>::new(params);
+
+        index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&[2.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        let ids = index.get_ids(1).unwrap();
+        assert_eq!(ids.len(), 2);
+
+        // Non-existent label
+        assert!(index.get_ids(999).is_none());
+    }
+
+    #[test]
+    fn test_svs_multi_filtered_range_query() {
+        use crate::query::QueryParams;
+
+        let params = SvsParams::new(4, Metric::L2);
+        let mut index = SvsMulti::<f32>::new(params);
+
+        for label in 1..=10u64 {
+            index.add_vector(&[label as f32, 0.0, 0.0, 0.0], label).unwrap();
+        }
+
+        // Filter to only allow labels 1-5, with range 50 (covers labels 1-7 by distance)
+        let query_params = QueryParams::new().with_filter(|label| label <= 5);
+        let query = [0.0, 0.0, 0.0, 0.0];
+        let results = index.range_query(&query, 50.0, Some(&query_params)).unwrap();
+
+        // Should have labels 1-5 (filtered) that are within range 50
+        assert_eq!(results.len(), 5);
+        for r in &results.results {
+            assert!(r.label <= 5);
+        }
+    }
+
+    #[test]
+    fn test_svs_multi_compact() {
+        let params = SvsParams::new(4, Metric::L2);
+        let mut index = SvsMulti::<f32>::new(params);
+
+        for i in 1..=10u64 {
+            index.add_vector(&[i as f32, 0.0, 0.0, 0.0], i).unwrap();
+        }
+
+        // Delete some vectors
+        for i in 1..=5u64 {
+            index.delete_vector(i).unwrap();
+        }
+
+        // Compact returns 0 for SVS (placeholder)
+        let reclaimed = index.compact(true);
+        assert_eq!(reclaimed, 0);
+
+        // Queries should still work
+        let query = [8.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 3, None).unwrap();
+        assert!(!results.is_empty());
+    }
+
+    #[test]
+    fn test_svs_multi_query_with_ef_runtime() {
+        use crate::query::QueryParams;
+
+        let params = SvsParams::new(4,
Metric::L2); + let mut index = SvsMulti::::new(params); + + for i in 0..50u64 { + index.add_vector(&[i as f32, 0.0, 0.0, 0.0], i).unwrap(); + } + + let query = [25.0, 0.0, 0.0, 0.0]; + + // Query with default search window + let results1 = index.top_k_query(&query, 5, None).unwrap(); + assert!(!results1.is_empty()); + + // Query with larger search window (via ef_runtime) + let query_params = QueryParams::new().with_ef_runtime(100); + let results2 = index.top_k_query(&query, 5, Some(&query_params)).unwrap(); + assert!(!results2.is_empty()); + } + + #[test] + fn test_svs_multi_larger_scale() { + let params = SvsParams::new(8, Metric::L2); + let mut index = SvsMulti::::new(params); + + // Add 500 vectors, 5 per label + for label in 0..100u64 { + for i in 0..5 { + let v = vec![ + label as f32 + i as f32 * 0.01, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ]; + index.add_vector(&v, label).unwrap(); + } + } + + assert_eq!(index.index_size(), 500); + assert_eq!(index.label_count(50), 5); + + // Build the index + index.build(); + + // Query should work + let query = vec![50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 10, None).unwrap(); + + assert!(!results.is_empty()); + // First result should be label 50 (closest) + assert_eq!(results.results[0].label, 50); + } + + #[test] + fn test_svs_multi_zero_k_query() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsMulti::::new(params); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + // Query with k=0 should return empty + let query = [1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 0, None).unwrap(); + assert!(results.is_empty()); + } +} diff --git a/rust/vecsim/src/index/svs/search.rs b/rust/vecsim/src/index/svs/search.rs new file mode 100644 index 000000000..1b0da3b63 --- /dev/null +++ b/rust/vecsim/src/index/svs/search.rs @@ -0,0 +1,832 @@ +//! Search algorithms for Vamana graph traversal. +//! +//! This module provides: +//! - `greedy_beam_search`: Beam search from entry point +//! - `robust_prune`: Alpha-based diverse neighbor selection + +use super::graph::VamanaGraph; +use crate::distance::DistanceFunction; +use crate::index::hnsw::VisitedNodesHandler; +use crate::types::{DistanceType, IdType, VectorElement}; +use crate::utils::{MaxHeap, MinHeap}; + +/// Result of a search: (id, distance) pairs sorted by distance. +pub type SearchResult = Vec<(IdType, D)>; + +/// Greedy beam search from entry point. +/// +/// Explores the graph using a beam of candidates, maintaining the +/// best candidates found so far. +#[allow(clippy::too_many_arguments)] +pub fn greedy_beam_search<'a, T, D, F>( + entry_point: IdType, + query: &[T], + beam_width: usize, + graph: &VamanaGraph, + data_getter: F, + dist_fn: &dyn DistanceFunction, + dim: usize, + visited: &VisitedNodesHandler, +) -> SearchResult +where + T: VectorElement, + D: DistanceType, + F: Fn(IdType) -> Option<&'a [T]>, +{ + greedy_beam_search_filtered:: bool>( + entry_point, + query, + beam_width, + graph, + data_getter, + dist_fn, + dim, + visited, + None, + ) +} + +/// Greedy beam search with optional filter. 
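+///
+/// Note that a node failing the filter is still expanded (its neighbors are
+/// pushed onto the candidate heap); it is only kept out of the result set.
+/// This keeps the traversal from stalling inside filtered-out regions.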
+#[allow(clippy::too_many_arguments)] +pub fn greedy_beam_search_filtered<'a, T, D, F, P>( + entry_point: IdType, + query: &[T], + beam_width: usize, + graph: &VamanaGraph, + data_getter: F, + dist_fn: &dyn DistanceFunction, + dim: usize, + visited: &VisitedNodesHandler, + filter: Option<&P>, +) -> SearchResult +where + T: VectorElement, + D: DistanceType, + F: Fn(IdType) -> Option<&'a [T]>, + P: Fn(IdType) -> bool + ?Sized, +{ + // Candidates to explore (min-heap: closest first) + let mut candidates = MinHeap::::with_capacity(beam_width * 2); + + // Results (max-heap: keeps L closest, largest at top) + let mut results = MaxHeap::::new(beam_width); + + // Initialize with entry point + visited.visit(entry_point); + + if let Some(entry_data) = data_getter(entry_point) { + let dist = dist_fn.compute(entry_data, query, dim); + candidates.push(entry_point, dist); + + if !graph.is_deleted(entry_point) { + let passes = filter.is_none_or(|f| f(entry_point)); + if passes { + results.insert(entry_point, dist); + } + } + } + + // Explore candidates + while let Some(candidate) = candidates.pop() { + // Check if we can stop early + if results.is_full() { + if let Some(worst_dist) = results.top_distance() { + if candidate.distance.to_f64() > worst_dist.to_f64() { + break; + } + } + } + + // Skip deleted nodes + if graph.is_deleted(candidate.id) { + continue; + } + + // Explore neighbors + for neighbor in graph.get_neighbors(candidate.id) { + if visited.visit(neighbor) { + continue; // Already visited + } + + // Skip deleted neighbors + if graph.is_deleted(neighbor) { + continue; + } + + // Compute distance + if let Some(neighbor_data) = data_getter(neighbor) { + let dist = dist_fn.compute(neighbor_data, query, dim); + + // Check filter + let passes = filter.is_none_or(|f| f(neighbor)); + + // Add to results if it passes filter and is close enough + if passes + && (!results.is_full() + || dist.to_f64() < results.top_distance().unwrap().to_f64()) + { + results.try_insert(neighbor, dist); + } + + // Add to candidates for exploration if close enough + if !results.is_full() || dist.to_f64() < results.top_distance().unwrap().to_f64() { + candidates.push(neighbor, dist); + } + } + } + } + + // Convert to sorted vector + results + .into_sorted_vec() + .into_iter() + .map(|e| (e.id, e.distance)) + .collect() +} + +/// Robust pruning for diverse neighbor selection. +/// +/// Implements the alpha-based pruning from the Vamana paper: +/// A neighbor is selected only if for all already-selected neighbors s: +/// dist(neighbor, s) * alpha >= dist(neighbor, target) +/// +/// This ensures neighbors are diverse and not all clustered together. 
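+///
+/// For example, with alpha = 1.2: a candidate at distance 1.0 from the target
+/// that lies at distance 0.5 from an already-selected neighbor is pruned,
+/// since 0.5 * 1.2 = 0.6 < 1.0, meaning the selected neighbor already covers
+/// that region of space.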
+#[allow(clippy::too_many_arguments)] +pub fn robust_prune<'a, T, D, F>( + target: IdType, + candidates: &[(IdType, D)], + max_degree: usize, + alpha: f32, + data_getter: F, + dist_fn: &dyn DistanceFunction, + dim: usize, +) -> Vec +where + T: VectorElement, + D: DistanceType, + F: Fn(IdType) -> Option<&'a [T]>, +{ + if candidates.is_empty() { + return Vec::new(); + } + + let _target_data = match data_getter(target) { + Some(d) => d, + None => return select_closest(candidates, max_degree), + }; + + // Sort candidates by distance + let mut sorted: Vec<_> = candidates.to_vec(); + sorted.sort_by(|a, b| { + a.1.to_f64() + .partial_cmp(&b.1.to_f64()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let mut selected: Vec = Vec::with_capacity(max_degree); + let alpha_f64 = alpha as f64; + + for &(candidate_id, candidate_dist) in &sorted { + if selected.len() >= max_degree { + break; + } + + // Skip if candidate is the target itself + if candidate_id == target { + continue; + } + + let candidate_data = match data_getter(candidate_id) { + Some(d) => d, + None => continue, + }; + + // Check alpha condition against all selected neighbors + let mut is_diverse = true; + let candidate_dist_f64 = candidate_dist.to_f64(); + + for &selected_id in &selected { + if let Some(selected_data) = data_getter(selected_id) { + let dist_to_selected = dist_fn.compute(candidate_data, selected_data, dim); + // If candidate is closer to a selected neighbor than to target * alpha, + // it's not diverse enough + if dist_to_selected.to_f64() * alpha_f64 < candidate_dist_f64 { + is_diverse = false; + break; + } + } + } + + if is_diverse { + selected.push(candidate_id); + } + } + + // If we couldn't fill to max_degree due to alpha pruning, + // add remaining closest candidates from sorted list (fallback) + if selected.len() < max_degree { + for &(candidate_id, _) in &sorted { + if selected.len() >= max_degree { + break; + } + if !selected.contains(&candidate_id) && candidate_id != target { + selected.push(candidate_id); + } + } + } + + selected +} + +/// Simple selection of closest candidates (no diversity). 
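+///
+/// This is the fallback used by `robust_prune` when the target's vector data
+/// is unavailable and the alpha diversity condition cannot be evaluated.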
+fn select_closest(candidates: &[(IdType, D)], max_degree: usize) -> Vec { + let mut sorted: Vec<_> = candidates.to_vec(); + sorted.sort_by(|a, b| { + a.1.to_f64() + .partial_cmp(&b.1.to_f64()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + sorted.into_iter().take(max_degree).map(|(id, _)| id).collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::{l2::L2Distance, DistanceFunction}; + use crate::index::hnsw::VisitedNodesHandlerPool; + + #[test] + fn test_select_closest() { + let candidates = vec![(1u32, 1.0f32), (2, 0.5), (3, 2.0), (4, 0.3)]; + let selected = select_closest(&candidates, 2); + assert_eq!(selected.len(), 2); + assert_eq!(selected[0], 4); // Closest + assert_eq!(selected[1], 2); + } + + #[test] + fn test_select_closest_max() { + let candidates = vec![(1u32, 1.0f32), (2, 0.5)]; + let selected = select_closest(&candidates, 10); + assert_eq!(selected.len(), 2); + } + + #[test] + fn test_select_closest_empty() { + let candidates: Vec<(IdType, f32)> = vec![]; + let selected = select_closest(&candidates, 5); + assert!(selected.is_empty()); + } + + #[test] + fn test_select_closest_equal_distances() { + let candidates = vec![(1u32, 1.0f32), (2, 1.0), (3, 1.0)]; + let selected = select_closest(&candidates, 2); + assert_eq!(selected.len(), 2); + } + + #[test] + fn test_greedy_beam_search_single_node() { + let dist_fn = L2Distance::::new(4); + let vectors: Vec> = vec![vec![1.0, 0.0, 0.0, 0.0]]; + + let mut graph = VamanaGraph::new(10, 4); + graph.set_label(0, 100); + + let pool = VisitedNodesHandlerPool::new(100); + let visited = pool.get(); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + let query = [1.0, 0.0, 0.0, 0.0]; + + let results = greedy_beam_search( + 0, + &query, + 10, + &graph, + data_getter, + &dist_fn as &dyn DistanceFunction, + 4, + &visited, + ); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, 0); + assert!(results[0].1 < 0.001); + } + + #[test] + fn test_greedy_beam_search_finds_closest() { + let dist_fn = L2Distance::::new(2); + let vectors: Vec> = vec![ + vec![0.0, 0.0], // id 0 + vec![1.0, 0.0], // id 1 + vec![2.0, 0.0], // id 2 + vec![3.0, 0.0], // id 3 - closest to query + ]; + + let mut graph = VamanaGraph::new(10, 4); + for i in 0..4 { + graph.set_label(i, i as u64 * 10); + } + // Linear chain: 0 -> 1 -> 2 -> 3 + graph.set_neighbors(0, &[1]); + graph.set_neighbors(1, &[0, 2]); + graph.set_neighbors(2, &[1, 3]); + graph.set_neighbors(3, &[2]); + + let pool = VisitedNodesHandlerPool::new(100); + let visited = pool.get(); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + let query = [3.0, 0.0]; + + let results = greedy_beam_search( + 0, // start from 0 + &query, + 10, + &graph, + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + &visited, + ); + + // First result should be closest (id 3) + assert!(!results.is_empty()); + assert_eq!(results[0].0, 3); + } + + #[test] + fn test_greedy_beam_search_respects_beam_width() { + let dist_fn = L2Distance::::new(2); + let vectors: Vec> = (0..10).map(|i| vec![i as f32, 0.0]).collect(); + + let mut graph = VamanaGraph::new(10, 10); + for i in 0..10 { + graph.set_label(i, i as u64); + } + // Fully connected + for i in 0..10 { + let neighbors: Vec = (0..10).filter(|&j| j != i).collect(); + graph.set_neighbors(i, &neighbors); + } + + let pool = VisitedNodesHandlerPool::new(100); + 
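+        // Borrow a reusable visited-nodes handler from the pool for this search.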
let visited = pool.get(); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + let query = [0.0, 0.0]; + + // Beam width = 3 + let results = greedy_beam_search( + 5, // start from middle + &query, + 3, + &graph, + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + &visited, + ); + + assert!(results.len() <= 3); + } + + #[test] + fn test_greedy_beam_search_skips_deleted() { + let dist_fn = L2Distance::::new(2); + let vectors: Vec> = vec![ + vec![0.0, 0.0], // id 0 + vec![1.0, 0.0], // id 1 - closest to query but deleted + vec![2.0, 0.0], // id 2 + ]; + + let mut graph = VamanaGraph::new(10, 4); + for i in 0..3 { + graph.set_label(i, i as u64 * 10); + } + graph.set_neighbors(0, &[1, 2]); + graph.set_neighbors(1, &[0, 2]); + graph.set_neighbors(2, &[0, 1]); + + // Mark id 1 as deleted + graph.mark_deleted(1); + + let pool = VisitedNodesHandlerPool::new(100); + let visited = pool.get(); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + let query = [1.0, 0.0]; // Closest to deleted node + + let results = greedy_beam_search( + 0, + &query, + 10, + &graph, + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + &visited, + ); + + // Should not contain deleted id + for (id, _) in &results { + assert_ne!(*id, 1, "Deleted node should not appear in results"); + } + } + + #[test] + fn test_greedy_beam_search_filtered() { + let dist_fn = L2Distance::::new(2); + let vectors: Vec> = vec![ + vec![0.0, 0.0], // id 0 + vec![1.0, 0.0], // id 1 + vec![2.0, 0.0], // id 2 + vec![3.0, 0.0], // id 3 + ]; + + let mut graph = VamanaGraph::new(10, 4); + for i in 0..4 { + graph.set_label(i, i as u64); + } + // Fully connected + graph.set_neighbors(0, &[1, 2, 3]); + graph.set_neighbors(1, &[0, 2, 3]); + graph.set_neighbors(2, &[0, 1, 3]); + graph.set_neighbors(3, &[0, 1, 2]); + + let pool = VisitedNodesHandlerPool::new(100); + let visited = pool.get(); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + let query = [1.5, 0.0]; // Between 1 and 2 + + // Filter: only accept even IDs + let filter = |id: IdType| -> bool { id % 2 == 0 }; + + let results = greedy_beam_search_filtered( + 0, + &query, + 10, + &graph, + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + &visited, + Some(&filter), + ); + + // All results should have even IDs + for (id, _) in &results { + assert_eq!(id % 2, 0, "Filter should only allow even IDs"); + } + } + + #[test] + fn test_robust_prune_empty() { + let dist_fn = L2Distance::::new(4); + let candidates: Vec<(IdType, f32)> = vec![]; + let data_getter = |_id: IdType| -> Option<&[f32]> { None }; + + let selected = robust_prune( + 0, + &candidates, + 5, + 1.2, + data_getter, + &dist_fn as &dyn DistanceFunction, + 4, + ); + + assert!(selected.is_empty()); + } + + #[test] + fn test_robust_prune_basic() { + let vectors: Vec> = vec![ + vec![0.0, 0.0], // target (id 0) + vec![1.0, 0.0], // id 1 + vec![0.0, 1.0], // id 2 (orthogonal to 1, diverse) + vec![1.0, 0.0], // id 3 (same as 1, not diverse) + ]; + + let dist_fn = L2Distance::::new(2); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + let candidates = vec![(1, 1.0f32), (2, 1.0f32), (3, 1.0f32)]; + + let selected = robust_prune( + 0, + 
&candidates, + 2, + 1.2, + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + ); + + // Should prefer diverse neighbors + assert_eq!(selected.len(), 2); + // Should include orthogonal vector 2 for diversity + assert!(selected.contains(&1) || selected.contains(&2)); + } + + #[test] + fn test_robust_prune_excludes_target() { + let vectors: Vec> = vec![ + vec![0.0, 0.0], // target (id 0) + vec![1.0, 0.0], // id 1 + vec![2.0, 0.0], // id 2 + ]; + + let dist_fn = L2Distance::::new(2); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + // Include target in candidates + let candidates = vec![(0, 0.0f32), (1, 1.0f32), (2, 4.0f32)]; + + let selected = robust_prune( + 0, + &candidates, + 2, + 1.2, + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + ); + + // Should not include target itself + assert!(!selected.contains(&0)); + } + + #[test] + fn test_robust_prune_fallback_when_all_pruned() { + // Set up vectors where alpha pruning would reject everything + let vectors: Vec> = vec![ + vec![0.0, 0.0], // target (id 0) + vec![1.0, 0.0], // id 1 + vec![1.1, 0.0], // id 2 (very close to 1) + vec![1.2, 0.0], // id 3 (very close to 1 and 2) + ]; + + let dist_fn = L2Distance::::new(2); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + // All very close together + let candidates = vec![(1, 1.0f32), (2, 1.21f32), (3, 1.44f32)]; + + // Use high alpha that would prune most candidates + let selected = robust_prune( + 0, + &candidates, + 3, + 2.0, // High alpha + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + ); + + // Should fill to max_degree using fallback + assert_eq!(selected.len(), 3); + } + + #[test] + fn test_robust_prune_fewer_candidates_than_degree() { + let vectors: Vec> = vec![ + vec![0.0, 0.0], // target (id 0) + vec![1.0, 0.0], // id 1 + ]; + + let dist_fn = L2Distance::::new(2); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + let candidates = vec![(1, 1.0f32)]; + + let selected = robust_prune( + 0, + &candidates, + 10, // Want 10, only have 1 + 1.2, + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + ); + + assert_eq!(selected.len(), 1); + assert_eq!(selected[0], 1); + } + + #[test] + fn test_robust_prune_alpha_effect() { + // Test that higher alpha allows more similar neighbors + let vectors: Vec> = vec![ + vec![0.0, 0.0], // target (id 0) + vec![1.0, 0.0], // id 1 + vec![1.5, 0.0], // id 2 (in same direction as 1) + vec![0.0, 2.0], // id 3 (different direction) + ]; + + let dist_fn = L2Distance::::new(2); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + let candidates = vec![(1, 1.0f32), (2, 2.25f32), (3, 4.0f32)]; + + // Low alpha (more strict diversity) + let selected_low = robust_prune( + 0, + &candidates, + 3, + 1.0, + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + ); + + // High alpha (less strict diversity) + let selected_high = robust_prune( + 0, + &candidates, + 3, + 2.0, + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + ); + + // Both should return results + assert!(!selected_low.is_empty()); + assert!(!selected_high.is_empty()); + } + + #[test] + fn test_greedy_beam_search_disconnected_entry() { + let dist_fn = L2Distance::::new(2); 
+ let vectors: Vec> = vec![ + vec![0.0, 0.0], // id 0 - isolated + vec![1.0, 0.0], // id 1 + vec![2.0, 0.0], // id 2 + ]; + + let mut graph = VamanaGraph::new(10, 4); + for i in 0..3 { + graph.set_label(i, i as u64); + } + // Node 0 is isolated (no neighbors) + graph.set_neighbors(1, &[2]); + graph.set_neighbors(2, &[1]); + + let pool = VisitedNodesHandlerPool::new(100); + let visited = pool.get(); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + let query = [1.5, 0.0]; + + // Start from isolated node + let results = greedy_beam_search( + 0, + &query, + 10, + &graph, + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + &visited, + ); + + // Should return at least the entry point + assert!(!results.is_empty()); + assert_eq!(results[0].0, 0); + } + + #[test] + fn test_greedy_beam_search_results_sorted() { + let dist_fn = L2Distance::::new(2); + let vectors: Vec> = vec![ + vec![0.0, 0.0], // id 0 + vec![1.0, 0.0], // id 1 + vec![2.0, 0.0], // id 2 + vec![3.0, 0.0], // id 3 + vec![4.0, 0.0], // id 4 + ]; + + let mut graph = VamanaGraph::new(10, 5); + for i in 0..5 { + graph.set_label(i, i as u64); + } + // Fully connected + for i in 0..5 { + let neighbors: Vec = (0..5).filter(|&j| j != i).collect(); + graph.set_neighbors(i, &neighbors); + } + + let pool = VisitedNodesHandlerPool::new(100); + let visited = pool.get(); + + let data_getter = |id: IdType| -> Option<&[f32]> { + if (id as usize) < vectors.len() { + Some(&vectors[id as usize]) + } else { + None + } + }; + + let query = [2.5, 0.0]; + + let results = greedy_beam_search( + 0, + &query, + 5, + &graph, + data_getter, + &dist_fn as &dyn DistanceFunction, + 2, + &visited, + ); + + // Results should be sorted by distance (ascending) + for i in 1..results.len() { + assert!( + results[i - 1].1 <= results[i].1, + "Results should be sorted by distance" + ); + } + } +} diff --git a/rust/vecsim/src/index/svs/single.rs b/rust/vecsim/src/index/svs/single.rs new file mode 100644 index 000000000..1ffef3d91 --- /dev/null +++ b/rust/vecsim/src/index/svs/single.rs @@ -0,0 +1,979 @@ +//! Single-value SVS (Vamana) index implementation. +//! +//! This index stores one vector per label. When adding a vector with +//! an existing label, the old vector is replaced. + +use super::{SvsCore, SvsParams}; +use crate::index::traits::{BatchIterator, IndexError, IndexInfo, QueryError, VecSimIndex}; +use crate::query::{QueryParams, QueryReply, QueryResult}; +use crate::types::{DistanceType, IdType, LabelType, VectorElement, INVALID_ID}; +use parking_lot::RwLock; +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Statistics for SVS index. +#[derive(Debug, Clone)] +pub struct SvsStats { + /// Number of vectors in the index. + pub vector_count: usize, + /// Number of labels in the index. + pub label_count: usize, + /// Dimension of vectors. + pub dimension: usize, + /// Distance metric. + pub metric: crate::distance::Metric, + /// Maximum graph degree. + pub graph_max_degree: usize, + /// Alpha parameter. + pub alpha: f32, + /// Average out-degree. + pub avg_degree: f64, + /// Maximum out-degree. + pub max_degree: usize, + /// Minimum out-degree. + pub min_degree: usize, + /// Memory usage (approximate). + pub memory_bytes: usize, +} + +/// Single-value SVS (Vamana) index. +/// +/// Each label has exactly one associated vector. +pub struct SvsSingle { + /// Core SVS implementation. 
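+    /// Wrapped in an `RwLock` so concurrent queries can share read access
+    /// while `add_vector`/`delete_vector` take exclusive write access.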
+ core: RwLock>, + /// Label to internal ID mapping. + label_to_id: RwLock>, + /// Internal ID to label mapping. + id_to_label: RwLock>, + /// Number of vectors. + count: AtomicUsize, + /// Maximum capacity (if set). + capacity: Option, + /// Construction completed flag (for two-pass). + construction_done: RwLock, +} + +impl SvsSingle { + /// Create a new SVS index with the given parameters. + pub fn new(params: SvsParams) -> Self { + let initial_capacity = params.initial_capacity; + Self { + core: RwLock::new(SvsCore::new(params)), + label_to_id: RwLock::new(HashMap::with_capacity(initial_capacity)), + id_to_label: RwLock::new(HashMap::with_capacity(initial_capacity)), + count: AtomicUsize::new(0), + capacity: None, + construction_done: RwLock::new(false), + } + } + + /// Create with a maximum capacity. + pub fn with_capacity(params: SvsParams, max_capacity: usize) -> Self { + let mut index = Self::new(params); + index.capacity = Some(max_capacity); + index + } + + /// Get the distance metric. + pub fn metric(&self) -> crate::distance::Metric { + self.core.read().params.metric + } + + /// Get index statistics. + pub fn stats(&self) -> SvsStats { + let core = self.core.read(); + let count = self.count.load(Ordering::Relaxed); + + SvsStats { + vector_count: count, + label_count: self.label_to_id.read().len(), + dimension: core.params.dim, + metric: core.params.metric, + graph_max_degree: core.params.graph_max_degree, + alpha: core.params.alpha, + avg_degree: core.graph.average_degree(), + max_degree: core.graph.max_degree(), + min_degree: core.graph.min_degree(), + memory_bytes: self.memory_usage(), + } + } + + /// Estimate memory usage in bytes. + pub fn memory_usage(&self) -> usize { + let core = self.core.read(); + let count = self.count.load(Ordering::Relaxed); + + // Vector data + let vector_size = count * core.params.dim * std::mem::size_of::(); + + // Graph (rough estimate) + let graph_size = count * core.params.graph_max_degree * std::mem::size_of::(); + + // Label mapping + let label_size = self.label_to_id.read().capacity() + * (std::mem::size_of::() + std::mem::size_of::()) + + self.id_to_label.read().capacity() + * (std::mem::size_of::() + std::mem::size_of::()); + + vector_size + graph_size + label_size + } + + /// Build the index after all vectors have been added. + /// + /// For two-pass construction, this performs the second pass. + /// Call this after adding all vectors for best recall. + pub fn build(&self) { + let mut done = self.construction_done.write(); + if *done { + return; + } + + let mut core = self.core.write(); + if core.params.two_pass_construction && core.data.len() > 1 { + core.rebuild_graph(); + } + + *done = true; + } + + /// Get the medoid (entry point) ID. + pub fn medoid(&self) -> Option { + let ep = self.core.read().medoid.load(Ordering::Relaxed); + if ep == INVALID_ID { + None + } else { + Some(ep) + } + } + + /// Get the search window size parameter. + pub fn search_l(&self) -> usize { + self.core.read().params.search_window_size + } + + /// Set the search window size parameter. + pub fn set_search_l(&self, l: usize) { + self.core.write().params.search_window_size = l; + } + + /// Get the graph max degree parameter. + pub fn graph_degree(&self) -> usize { + self.core.read().params.graph_max_degree + } + + /// Get the alpha parameter. + pub fn alpha(&self) -> f32 { + self.core.read().params.alpha + } + + /// Get the fragmentation ratio (0.0 = none, 1.0 = all deleted). 
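+    ///
+    /// Deletions in SVS only mark nodes as removed, so this ratio keeps
+    /// growing under deletes; `compact` is currently a placeholder and space
+    /// is only reclaimed by rebuilding the graph.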
+ pub fn fragmentation(&self) -> f64 { + self.core.read().data.fragmentation() + } + + /// Compact the index to reclaim space from deleted vectors. + /// + /// Note: SVS doesn't support true compaction without rebuilding. + /// This method returns 0 as a placeholder. + pub fn compact(&mut self, _shrink: bool) -> usize { + // SVS requires rebuild for true compaction + // For now, we don't support in-place compaction + 0 + } + + /// Get a copy of the vector stored for a given label. + pub fn get_vector(&self, label: LabelType) -> Option> { + let id = *self.label_to_id.read().get(&label)?; + let core = self.core.read(); + core.data.get(id).map(|v| v.to_vec()) + } + + /// Clear all vectors from the index. + pub fn clear(&mut self) { + let mut core = self.core.write(); + let params = core.params.clone(); + *core = SvsCore::new(params); + + self.label_to_id.write().clear(); + self.id_to_label.write().clear(); + self.count.store(0, Ordering::Relaxed); + *self.construction_done.write() = false; + } +} + +impl VecSimIndex for SvsSingle { + type DataType = T; + type DistType = T::DistanceType; + + fn add_vector(&mut self, vector: &[T], label: LabelType) -> Result { + let mut core = self.core.write(); + + if vector.len() != core.params.dim { + return Err(IndexError::DimensionMismatch { + expected: core.params.dim, + got: vector.len(), + }); + } + + let mut label_to_id = self.label_to_id.write(); + let mut id_to_label = self.id_to_label.write(); + + // Check if label already exists + if let Some(&existing_id) = label_to_id.get(&label) { + // Mark old vector as deleted + core.mark_deleted(existing_id); + id_to_label.remove(&existing_id); + + // Add new vector + let new_id = core + .add_vector(vector) + .ok_or_else(|| IndexError::Internal("Failed to add vector to storage".to_string()))?; + core.insert(new_id, label); + + // Update mappings + label_to_id.insert(label, new_id); + id_to_label.insert(new_id, label); + + // Reset construction flag + *self.construction_done.write() = false; + + return Ok(0); // Replacement, not a new vector + } + + // Check capacity + if let Some(cap) = self.capacity { + if self.count.load(Ordering::Relaxed) >= cap { + return Err(IndexError::CapacityExceeded { capacity: cap }); + } + } + + // Add new vector + let id = core + .add_vector(vector) + .ok_or_else(|| IndexError::Internal("Failed to add vector to storage".to_string()))?; + core.insert(id, label); + + // Update mappings + label_to_id.insert(label, id); + id_to_label.insert(id, label); + + self.count.fetch_add(1, Ordering::Relaxed); + *self.construction_done.write() = false; + + Ok(1) + } + + fn delete_vector(&mut self, label: LabelType) -> Result { + let mut core = self.core.write(); + let mut label_to_id = self.label_to_id.write(); + let mut id_to_label = self.id_to_label.write(); + + if let Some(id) = label_to_id.remove(&label) { + core.mark_deleted(id); + id_to_label.remove(&id); + self.count.fetch_sub(1, Ordering::Relaxed); + Ok(1) + } else { + Err(IndexError::LabelNotFound(label)) + } + } + + fn top_k_query( + &self, + query: &[T], + k: usize, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + let core = self.core.read(); + + if query.len() != core.params.dim { + return Err(QueryError::DimensionMismatch { + expected: core.params.dim, + got: query.len(), + }); + } + + if k == 0 { + return Ok(QueryReply::new()); + } + + let search_l = params + .and_then(|p| p.ef_runtime) + .unwrap_or(core.params.search_window_size); + + // Build filter if needed + let filter_fn: Option bool + '_>> = if let Some(p) = 
params { + if let Some(ref f) = p.filter { + // Use reference to RwLock directly - avoids O(n) copy + let id_to_label_ref = &self.id_to_label; + Some(Box::new(move |id: IdType| { + id_to_label_ref.read().get(&id).is_some_and(|&label| f(label)) + })) + } else { + None + } + } else { + None + }; + + let results = core.search(query, k, search_l, filter_fn.as_ref().map(|f| f.as_ref())); + + // Look up labels for results + let id_to_label = self.id_to_label.read(); + let mut reply = QueryReply::with_capacity(results.len()); + for (id, dist) in results { + if let Some(&label) = id_to_label.get(&id) { + reply.push(QueryResult::new(label, dist)); + } + } + + Ok(reply) + } + + fn range_query( + &self, + query: &[T], + radius: T::DistanceType, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + let core = self.core.read(); + + if query.len() != core.params.dim { + return Err(QueryError::DimensionMismatch { + expected: core.params.dim, + got: query.len(), + }); + } + + let count = self.count.load(Ordering::Relaxed); + let search_l = params + .and_then(|p| p.ef_runtime) + .unwrap_or(core.params.search_window_size) + .max(count.min(1000)); + + // Build filter if needed + let filter_fn: Option bool + '_>> = if let Some(p) = params { + if let Some(ref f) = p.filter { + // Use reference to RwLock directly - avoids O(n) copy + let id_to_label_ref = &self.id_to_label; + Some(Box::new(move |id: IdType| { + id_to_label_ref.read().get(&id).is_some_and(|&label| f(label)) + })) + } else { + None + } + } else { + None + }; + + let results = core.search(query, count, search_l, filter_fn.as_ref().map(|f| f.as_ref())); + + // Look up labels and filter by radius + let id_to_label = self.id_to_label.read(); + let mut reply = QueryReply::new(); + for (id, dist) in results { + if dist.to_f64() <= radius.to_f64() { + if let Some(&label) = id_to_label.get(&id) { + reply.push(QueryResult::new(label, dist)); + } + } + } + + reply.sort_by_distance(); + Ok(reply) + } + + fn index_size(&self) -> usize { + self.count.load(Ordering::Relaxed) + } + + fn index_capacity(&self) -> Option { + self.capacity + } + + fn dimension(&self) -> usize { + self.core.read().params.dim + } + + fn batch_iterator<'a>( + &'a self, + query: &[T], + params: Option<&QueryParams>, + ) -> Result + 'a>, QueryError> { + let core = self.core.read(); + if query.len() != core.params.dim { + return Err(QueryError::DimensionMismatch { + expected: core.params.dim, + got: query.len(), + }); + } + + let search_l = params + .and_then(|p| p.ef_runtime) + .unwrap_or(core.params.search_window_size); + + let count = self.count.load(Ordering::Relaxed); + + // Build filter if needed + let filter_fn: Option bool + '_>> = if let Some(p) = params { + if let Some(ref f) = p.filter { + // Use reference to RwLock directly - avoids O(n) copy + let id_to_label_ref = &self.id_to_label; + Some(Box::new(move |id: IdType| { + id_to_label_ref.read().get(&id).is_some_and(|&label| f(label)) + })) + } else { + None + } + } else { + None + }; + + // Use the specified search_l to affect search quality + // A smaller search_l means more greedy search with potentially worse results + let raw_results = core.search( + query, + count, + search_l, + filter_fn.as_ref().map(|f| f.as_ref()), + ); + + let id_to_label = self.id_to_label.read(); + let results: Vec<_> = raw_results + .into_iter() + .filter_map(|(id, dist)| { + id_to_label.get(&id).map(|&label| (id, label, dist)) + }) + .collect(); + + Ok(Box::new(SvsSingleBatchIterator::::new(results))) + } + + fn info(&self) -> IndexInfo 
{ + let core = self.core.read(); + let count = self.count.load(Ordering::Relaxed); + + IndexInfo { + size: count, + capacity: self.capacity, + dimension: core.params.dim, + index_type: "SvsSingle", + memory_bytes: self.memory_usage(), + } + } + + fn contains(&self, label: LabelType) -> bool { + self.label_to_id.read().contains_key(&label) + } + + fn label_count(&self, label: LabelType) -> usize { + if self.contains(label) { 1 } else { 0 } + } +} + +// Allow read-only concurrent access for queries +unsafe impl Send for SvsSingle {} +unsafe impl Sync for SvsSingle {} + +/// Batch iterator for SvsSingle. +pub struct SvsSingleBatchIterator { + results: Vec<(IdType, LabelType, T::DistanceType)>, + current_idx: usize, +} + +impl SvsSingleBatchIterator { + /// Create a new batch iterator with pre-computed results. + pub fn new(results: Vec<(IdType, LabelType, T::DistanceType)>) -> Self { + Self { + results, + current_idx: 0, + } + } +} + +impl BatchIterator for SvsSingleBatchIterator { + type DistType = T::DistanceType; + + fn has_next(&self) -> bool { + self.current_idx < self.results.len() + } + + fn next_batch(&mut self, batch_size: usize) -> Option> { + if self.current_idx >= self.results.len() { + return None; + } + + let end = (self.current_idx + batch_size).min(self.results.len()); + let batch = self.results[self.current_idx..end].to_vec(); + self.current_idx = end; + + if batch.is_empty() { + None + } else { + Some(batch) + } + } + + fn reset(&mut self) { + self.current_idx = 0; + } +} + +// Serialization support for SvsSingle +impl SvsSingle { + /// Save the index to a writer. + pub fn save( + &self, + writer: &mut W, + ) -> crate::serialization::SerializationResult<()> { + use crate::serialization::*; + use std::sync::atomic::Ordering; + + let core = self.core.read(); + let label_to_id = self.label_to_id.read(); + let id_to_label = self.id_to_label.read(); + let count = self.count.load(Ordering::Relaxed); + + // Write header + let header = IndexHeader::new( + IndexTypeId::SvsSingle, + DataTypeId::F32, + core.params.metric, + core.params.dim, + count, + ); + header.write(writer)?; + + // Write SVS-specific parameters + write_usize(writer, core.params.graph_max_degree)?; + write_f32(writer, core.params.alpha)?; + write_usize(writer, core.params.construction_window_size)?; + write_usize(writer, core.params.search_window_size)?; + write_u8(writer, if core.params.two_pass_construction { 1 } else { 0 })?; + + // Write capacity + write_u8(writer, if self.capacity.is_some() { 1 } else { 0 })?; + if let Some(cap) = self.capacity { + write_usize(writer, cap)?; + } + + // Write construction_done flag + write_u8(writer, if *self.construction_done.read() { 1 } else { 0 })?; + + // Write medoid + write_u32(writer, core.medoid.load(Ordering::Relaxed))?; + + // Write label_to_id mapping + write_usize(writer, label_to_id.len())?; + for (&label, &id) in label_to_id.iter() { + write_u64(writer, label)?; + write_u32(writer, id)?; + } + + // Write id_to_label mapping + write_usize(writer, id_to_label.len())?; + for (&id, &label) in id_to_label.iter() { + write_u32(writer, id)?; + write_u64(writer, label)?; + } + + // Write vectors - collect valid IDs first + let valid_ids: Vec = core.data.iter_ids().collect(); + write_usize(writer, valid_ids.len())?; + for id in &valid_ids { + write_u32(writer, *id)?; + if let Some(vector) = core.data.get(*id) { + for &v in vector { + write_f32(writer, v)?; + } + } + } + + // Write graph structure + // For each valid ID, write its neighbors + for id in &valid_ids { + let 
neighbors = core.graph.get_neighbors(*id); + let label = core.graph.get_label(*id); + let deleted = core.graph.is_deleted(*id); + + write_u64(writer, label)?; + write_u8(writer, if deleted { 1 } else { 0 })?; + write_usize(writer, neighbors.len())?; + for &n in &neighbors { + write_u32(writer, n)?; + } + } + + Ok(()) + } + + /// Load the index from a reader. + pub fn load(reader: &mut R) -> crate::serialization::SerializationResult { + use crate::serialization::*; + use std::sync::atomic::Ordering; + + // Read and validate header + let header = IndexHeader::read(reader)?; + + if header.index_type != IndexTypeId::SvsSingle { + return Err(SerializationError::IndexTypeMismatch { + expected: "SvsSingle".to_string(), + got: header.index_type.as_str().to_string(), + }); + } + + if header.data_type != DataTypeId::F32 { + return Err(SerializationError::InvalidData( + "Expected f32 data type".to_string(), + )); + } + + // Read SVS-specific parameters + let graph_max_degree = read_usize(reader)?; + let alpha = read_f32(reader)?; + let construction_window_size = read_usize(reader)?; + let search_window_size = read_usize(reader)?; + let two_pass_construction = read_u8(reader)? != 0; + + // Create params + let params = SvsParams { + dim: header.dimension, + metric: header.metric, + graph_max_degree, + alpha, + construction_window_size, + search_window_size, + initial_capacity: header.count.max(1024), + two_pass_construction, + }; + + // Create the index + let mut index = Self::new(params); + + // Read capacity + let has_capacity = read_u8(reader)? != 0; + if has_capacity { + index.capacity = Some(read_usize(reader)?); + } + + // Read construction_done flag + let construction_done = read_u8(reader)? != 0; + *index.construction_done.write() = construction_done; + + // Read medoid + let medoid = read_u32(reader)?; + + // Read label_to_id mapping + let label_to_id_len = read_usize(reader)?; + let mut label_to_id = HashMap::with_capacity(label_to_id_len); + for _ in 0..label_to_id_len { + let label = read_u64(reader)?; + let id = read_u32(reader)?; + label_to_id.insert(label, id); + } + + // Read id_to_label mapping + let id_to_label_len = read_usize(reader)?; + let mut id_to_label = HashMap::with_capacity(id_to_label_len); + for _ in 0..id_to_label_len { + let id = read_u32(reader)?; + let label = read_u64(reader)?; + id_to_label.insert(id, label); + } + + // Read vectors + let num_vectors = read_usize(reader)?; + let dim = header.dimension; + let mut vector_ids = Vec::with_capacity(num_vectors); + + { + let mut core = index.core.write(); + core.data.reserve(num_vectors); + + for _ in 0..num_vectors { + let id = read_u32(reader)?; + let mut vector = vec![0.0f32; dim]; + for v in &mut vector { + *v = read_f32(reader)?; + } + + // Add vector at specific position + let added_id = core.data.add(&vector).ok_or_else(|| { + SerializationError::DataCorruption( + "Failed to add vector during deserialization".to_string(), + ) + })?; + + // Track the ID for graph restoration + vector_ids.push((id, added_id)); + } + + // Read and restore graph structure + core.graph.ensure_capacity(num_vectors); + for (original_id, _) in &vector_ids { + let label = read_u64(reader)?; + let deleted = read_u8(reader)? 
!= 0; + let num_neighbors = read_usize(reader)?; + + let mut neighbors = Vec::with_capacity(num_neighbors); + for _ in 0..num_neighbors { + neighbors.push(read_u32(reader)?); + } + + core.graph.set_label(*original_id, label); + if deleted { + core.graph.mark_deleted(*original_id); + } + core.graph.set_neighbors(*original_id, &neighbors); + } + + // Restore medoid + core.medoid.store(medoid, Ordering::Release); + + // Resize visited pool + if num_vectors > 0 { + core.visited_pool.resize(num_vectors + 1024); + } + } + + // Restore label mappings + *index.label_to_id.write() = label_to_id; + *index.id_to_label.write() = id_to_label; + index.count.store(header.count, Ordering::Relaxed); + + Ok(index) + } + + /// Save the index to a file. + pub fn save_to_file>( + &self, + path: P, + ) -> crate::serialization::SerializationResult<()> { + let file = std::fs::File::create(path)?; + let mut writer = std::io::BufWriter::new(file); + self.save(&mut writer) + } + + /// Load the index from a file. + pub fn load_from_file>( + path: P, + ) -> crate::serialization::SerializationResult { + let file = std::fs::File::open(path)?; + let mut reader = std::io::BufReader::new(file); + Self::load(&mut reader) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::Metric; + + #[test] + fn test_svs_single_serialization() { + use std::io::Cursor; + + let params = SvsParams::new(4, Metric::L2) + .with_graph_degree(8) + .with_alpha(1.2); + let mut index = SvsSingle::::new(params); + + // Add vectors + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); + index.add_vector(&[0.0, 0.0, 1.0, 0.0], 3).unwrap(); + index.add_vector(&[0.0, 0.0, 0.0, 1.0], 4).unwrap(); + + // Build to ensure graph is populated + index.build(); + + // Serialize + let mut buffer = Vec::new(); + index.save(&mut buffer).unwrap(); + + // Deserialize + let mut cursor = Cursor::new(buffer); + let loaded = SvsSingle::::load(&mut cursor).unwrap(); + + // Verify + assert_eq!(loaded.index_size(), 4); + assert_eq!(loaded.dimension(), 4); + + // Query should work the same + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = loaded.top_k_query(&query, 4, None).unwrap(); + assert_eq!(results.results[0].label, 1); + } + + #[test] + fn test_svs_single_basic() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsSingle::::new(params); + + // Add vectors + assert_eq!(index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(), 1); + assert_eq!(index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(), 1); + assert_eq!(index.add_vector(&[0.0, 0.0, 1.0, 0.0], 3).unwrap(), 1); + assert_eq!(index.add_vector(&[0.0, 0.0, 0.0, 1.0], 4).unwrap(), 1); + + assert_eq!(index.index_size(), 4); + + // Query + let query = [1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 2, None).unwrap(); + + assert_eq!(results.results.len(), 2); + assert_eq!(results.results[0].label, 1); + } + + #[test] + fn test_svs_single_update() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsSingle::::new(params); + + // Add initial vector + assert_eq!(index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(), 1); + assert_eq!(index.index_size(), 1); + + // Update same label (should return 0 for replacement) + assert_eq!(index.add_vector(&[0.0, 1.0, 0.0, 0.0], 1).unwrap(), 0); + assert_eq!(index.index_size(), 1); + } + + #[test] + fn test_svs_single_delete() { + let params = SvsParams::new(4, Metric::L2); + let mut index = SvsSingle::::new(params); + + index.add_vector(&[1.0, 0.0, 0.0, 0.0], 
1).unwrap();
+        index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        assert_eq!(index.index_size(), 2);
+
+        assert_eq!(index.delete_vector(1).unwrap(), 1);
+        assert_eq!(index.index_size(), 1);
+        assert!(!index.contains(1));
+        assert!(index.contains(2));
+    }
+
+    #[test]
+    fn test_svs_single_build() {
+        let params = SvsParams::new(4, Metric::L2).with_two_pass(true);
+        let mut index = SvsSingle::<f32>::new(params);
+
+        // Add vectors
+        for i in 0..100 {
+            let mut v = vec![0.0f32; 4];
+            v[i % 4] = 1.0;
+            index.add_vector(&v, i as u64).unwrap();
+        }
+
+        // Build (second pass)
+        index.build();
+
+        // Query should still work
+        let query = [1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 10, None).unwrap();
+
+        assert!(!results.results.is_empty());
+    }
+
+    #[test]
+    fn test_svs_single_stats() {
+        let params = SvsParams::new(4, Metric::L2)
+            .with_graph_degree(16)
+            .with_alpha(1.2);
+        let mut index = SvsSingle::<f32>::new(params);
+
+        for i in 0..50 {
+            let mut v = vec![0.0f32; 4];
+            v[i % 4] = 1.0;
+            index.add_vector(&v, i as u64).unwrap();
+        }
+
+        let stats = index.stats();
+        assert_eq!(stats.vector_count, 50);
+        assert_eq!(stats.label_count, 50);
+        assert_eq!(stats.dimension, 4);
+        assert_eq!(stats.graph_max_degree, 16);
+        assert!((stats.alpha - 1.2).abs() < 0.01);
+    }
+
+    #[test]
+    fn test_svs_single_capacity() {
+        let params = SvsParams::new(4, Metric::L2);
+        let mut index = SvsSingle::<f32>::with_capacity(params, 2);
+
+        assert_eq!(index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(), 1);
+        assert_eq!(index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(), 1);
+
+        // Should fail due to capacity
+        let result = index.add_vector(&[0.0, 0.0, 1.0, 0.0], 3);
+        assert!(matches!(result, Err(IndexError::CapacityExceeded { capacity: 2 })));
+    }
+
+    #[test]
+    fn test_svs_single_batch_iterator_with_filter() {
+        let params = SvsParams::new(4, Metric::L2);
+        let mut index = SvsSingle::<f32>::new(params);
+
+        // Add vectors with labels 1-10
+        for i in 1..=10u64 {
+            let v = vec![i as f32, 0.0, 0.0, 0.0];
+            index.add_vector(&v, i).unwrap();
+        }
+
+        assert_eq!(index.index_size(), 10);
+
+        // Create a filter that only allows even labels
+        let query_params = QueryParams::new().with_filter(|label| label % 2 == 0);
+
+        let query = vec![5.0, 0.0, 0.0, 0.0];
+
+        // Test batch iterator with filter
+        let mut iter = index.batch_iterator(&query, Some(&query_params)).unwrap();
+        let mut all_results = Vec::new();
+        while let Some(batch) = iter.next_batch(100) {
+            all_results.extend(batch);
+        }
+
+        // Should only have even labels (2, 4, 6, 8, 10)
+        assert_eq!(all_results.len(), 5);
+        for (_, label, _) in &all_results {
+            assert_eq!(label % 2, 0, "Expected only even labels, got {}", label);
+        }
+
+        // Verify specific labels are present
+        let labels: Vec<_> = all_results.iter().map(|(_, l, _)| *l).collect();
+        assert!(labels.contains(&2));
+        assert!(labels.contains(&4));
+        assert!(labels.contains(&6));
+        assert!(labels.contains(&8));
+        assert!(labels.contains(&10));
+    }
+
+    #[test]
+    fn test_svs_single_batch_iterator_no_filter() {
+        let params = SvsParams::new(4, Metric::L2);
+        let mut index = SvsSingle::<f32>::new(params);
+
+        // Add vectors
+        for i in 1..=5u64 {
+            let v = vec![i as f32, 0.0, 0.0, 0.0];
+            index.add_vector(&v, i).unwrap();
+        }
+
+        let query = vec![3.0, 0.0, 0.0, 0.0];
+
+        // Test batch iterator without filter
+        let mut iter = index.batch_iterator(&query, None).unwrap();
+        let mut all_results = Vec::new();
+        while let Some(batch) = iter.next_batch(100) {
+            all_results.extend(batch);
+        }
+
+        // Should have all 5 vectors
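+        // (The iterator drains its result set batch_size entries per call,
+        // so the union over all batches covers every stored vector.)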
+ assert_eq!(all_results.len(), 5); + } +} diff --git a/rust/vecsim/src/index/tiered/batch_iterator.rs b/rust/vecsim/src/index/tiered/batch_iterator.rs new file mode 100644 index 000000000..b7f4f3ef6 --- /dev/null +++ b/rust/vecsim/src/index/tiered/batch_iterator.rs @@ -0,0 +1,261 @@ +//! Batch iterator implementations for tiered indices. +//! +//! These iterators merge results from both flat buffer and HNSW backend, +//! returning them in sorted order by distance. + +use super::multi::TieredMulti; +use super::single::TieredSingle; +use crate::index::traits::{BatchIterator, VecSimIndex}; +use crate::query::QueryParams; +use crate::types::{DistanceType, IdType, LabelType, VectorElement}; +use std::cmp::Ordering; + +/// Batch iterator for single-value tiered index. +/// +/// Merges and sorts results from both flat buffer and HNSW backend. +pub struct TieredBatchIterator<'a, T: VectorElement> { + /// Reference to the tiered index. + index: &'a TieredSingle, + /// The query vector. + query: Vec, + /// Query parameters. + params: Option, + /// All results sorted by distance. + results: Vec<(IdType, LabelType, T::DistanceType)>, + /// Current position in results. + position: usize, + /// Whether results have been computed. + computed: bool, +} + +impl<'a, T: VectorElement> TieredBatchIterator<'a, T> { + /// Create a new batch iterator. + pub fn new( + index: &'a TieredSingle, + query: Vec, + params: Option, + ) -> Self { + Self { + index, + query, + params, + results: Vec::new(), + position: 0, + computed: false, + } + } + + /// Compute all results from both tiers. + fn compute_results(&mut self) { + if self.computed { + return; + } + + // Get results from flat buffer + let flat = self.index.flat.read(); + if let Ok(mut iter) = flat.batch_iterator(&self.query, self.params.as_ref()) { + while let Some(batch) = iter.next_batch(1000) { + self.results.extend(batch); + } + } + drop(flat); + + // Get results from HNSW + let hnsw = self.index.hnsw.read(); + if let Ok(mut iter) = hnsw.batch_iterator(&self.query, self.params.as_ref()) { + while let Some(batch) = iter.next_batch(1000) { + self.results.extend(batch); + } + } + drop(hnsw); + + // Sort by distance + self.results.sort_by(|a, b| { + a.2.to_f64() + .partial_cmp(&b.2.to_f64()) + .unwrap_or(Ordering::Equal) + }); + + self.computed = true; + } +} + +impl<'a, T: VectorElement> BatchIterator for TieredBatchIterator<'a, T> { + type DistType = T::DistanceType; + + fn has_next(&self) -> bool { + !self.computed || self.position < self.results.len() + } + + fn next_batch( + &mut self, + batch_size: usize, + ) -> Option> { + self.compute_results(); + + if self.position >= self.results.len() { + return None; + } + + let end = (self.position + batch_size).min(self.results.len()); + let batch = self.results[self.position..end].to_vec(); + self.position = end; + + Some(batch) + } + + fn reset(&mut self) { + self.position = 0; + } +} + +/// Batch iterator for multi-value tiered index. +pub struct TieredMultiBatchIterator<'a, T: VectorElement> { + /// Reference to the tiered index. + index: &'a TieredMulti, + /// The query vector. + query: Vec, + /// Query parameters. + params: Option, + /// All results sorted by distance. + results: Vec<(IdType, LabelType, T::DistanceType)>, + /// Current position in results. + position: usize, + /// Whether results have been computed. + computed: bool, +} + +impl<'a, T: VectorElement> TieredMultiBatchIterator<'a, T> { + /// Create a new batch iterator. 
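+    ///
+    /// Note: this is normally constructed via `TieredMulti::batch_iterator`,
+    /// which clones the query and parameters; results from both tiers are
+    /// computed lazily on the first call to `next_batch`.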
+ pub fn new( + index: &'a TieredMulti, + query: Vec, + params: Option, + ) -> Self { + Self { + index, + query, + params, + results: Vec::new(), + position: 0, + computed: false, + } + } + + /// Compute all results from both tiers. + fn compute_results(&mut self) { + if self.computed { + return; + } + + // Get results from flat buffer + let flat = self.index.flat.read(); + if let Ok(mut iter) = flat.batch_iterator(&self.query, self.params.as_ref()) { + while let Some(batch) = iter.next_batch(1000) { + self.results.extend(batch); + } + } + drop(flat); + + // Get results from HNSW + let hnsw = self.index.hnsw.read(); + if let Ok(mut iter) = hnsw.batch_iterator(&self.query, self.params.as_ref()) { + while let Some(batch) = iter.next_batch(1000) { + self.results.extend(batch); + } + } + drop(hnsw); + + // Sort by distance + self.results.sort_by(|a, b| { + a.2.to_f64() + .partial_cmp(&b.2.to_f64()) + .unwrap_or(Ordering::Equal) + }); + + self.computed = true; + } +} + +impl<'a, T: VectorElement> BatchIterator for TieredMultiBatchIterator<'a, T> { + type DistType = T::DistanceType; + + fn has_next(&self) -> bool { + !self.computed || self.position < self.results.len() + } + + fn next_batch( + &mut self, + batch_size: usize, + ) -> Option> { + self.compute_results(); + + if self.position >= self.results.len() { + return None; + } + + let end = (self.position + batch_size).min(self.results.len()); + let batch = self.results[self.position..end].to_vec(); + self.position = end; + + Some(batch) + } + + fn reset(&mut self) { + self.position = 0; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::Metric; + use crate::index::tiered::TieredParams; + use crate::index::VecSimIndex; + + #[test] + fn test_tiered_batch_iterator() { + let params = TieredParams::new(4, Metric::L2); + let mut index = TieredSingle::::new(params); + + // Add vectors to flat + for i in 0..5 { + index + .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64) + .unwrap(); + } + + // Flush to HNSW + index.flush().unwrap(); + + // Add more to flat + for i in 5..10 { + index + .add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64) + .unwrap(); + } + + let query = vec![0.0, 0.0, 0.0, 0.0]; + let mut iter = index.batch_iterator(&query, None).unwrap(); + + assert!(iter.has_next()); + + // Get first batch + let batch1 = iter.next_batch(3).unwrap(); + assert_eq!(batch1.len(), 3); + + // Verify ordering + assert!(batch1[0].2.to_f64() <= batch1[1].2.to_f64()); + + // Get all remaining + let mut total = batch1.len(); + while let Some(batch) = iter.next_batch(3) { + total += batch.len(); + } + assert_eq!(total, 10); + + // Reset + iter.reset(); + assert!(iter.has_next()); + } +} diff --git a/rust/vecsim/src/index/tiered/mod.rs b/rust/vecsim/src/index/tiered/mod.rs new file mode 100644 index 000000000..93525e01c --- /dev/null +++ b/rust/vecsim/src/index/tiered/mod.rs @@ -0,0 +1,242 @@ +//! Tiered index combining BruteForce frontend with HNSW backend. +//! +//! The tiered architecture optimizes for: +//! - **Fast writes**: New vectors go to flat BruteForce buffer +//! - **Efficient queries**: HNSW provides logarithmic search complexity +//! - **Flexible migration**: Vectors can be flushed from flat to HNSW +//! +//! # Architecture +//! +//! ```text +//! TieredIndex +//! ├── Frontend: BruteForce index (fast write buffer) +//! ├── Backend: HNSW index (efficient approximate search) +//! └── Query: Searches both tiers, merges results +//! ``` +//! +//! # Write Modes +//! +//! 
- **Async**: Vectors added to flat buffer, later migrated to HNSW via `flush()` +//! - **InPlace**: Vectors added directly to HNSW (used when buffer is full) + +pub mod batch_iterator; +pub mod multi; +pub mod single; + +pub use batch_iterator::TieredBatchIterator; +pub use multi::TieredMulti; +pub use single::TieredSingle; + +use crate::distance::Metric; +use crate::index::hnsw::HnswParams; + +/// Write mode for the tiered index. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum WriteMode { + /// Async mode: vectors go to flat buffer, migrated to HNSW via flush(). + #[default] + Async, + /// InPlace mode: vectors go directly to HNSW. + InPlace, +} + +/// Configuration parameters for TieredIndex. +#[derive(Debug, Clone)] +pub struct TieredParams { + /// Vector dimension. + pub dim: usize, + /// Distance metric. + pub metric: Metric, + /// Parameters for the HNSW backend index. + pub hnsw_params: HnswParams, + /// Maximum size of the flat buffer before forcing in-place writes. + /// When flat buffer reaches this limit, new writes go directly to HNSW. + pub flat_buffer_limit: usize, + /// Write mode for the index. + pub write_mode: WriteMode, + /// Initial capacity hint. + pub initial_capacity: usize, +} + +impl TieredParams { + /// Create new tiered index parameters. + /// + /// # Arguments + /// * `dim` - Vector dimension + /// * `metric` - Distance metric (L2, InnerProduct, Cosine) + pub fn new(dim: usize, metric: Metric) -> Self { + Self { + dim, + metric, + hnsw_params: HnswParams::new(dim, metric), + flat_buffer_limit: 10_000, + write_mode: WriteMode::Async, + initial_capacity: 1000, + } + } + + /// Set the HNSW M parameter (max connections per node). + pub fn with_m(mut self, m: usize) -> Self { + self.hnsw_params = self.hnsw_params.with_m(m); + self + } + + /// Set the ef_construction parameter for HNSW. + pub fn with_ef_construction(mut self, ef: usize) -> Self { + self.hnsw_params = self.hnsw_params.with_ef_construction(ef); + self + } + + /// Set the ef_runtime parameter for HNSW queries. + pub fn with_ef_runtime(mut self, ef: usize) -> Self { + self.hnsw_params = self.hnsw_params.with_ef_runtime(ef); + self + } + + /// Set the flat buffer limit. + /// + /// When the flat buffer reaches this size, the index switches to + /// in-place writes (directly to HNSW) until flush() is called. + pub fn with_flat_buffer_limit(mut self, limit: usize) -> Self { + self.flat_buffer_limit = limit; + self + } + + /// Set the write mode. + pub fn with_write_mode(mut self, mode: WriteMode) -> Self { + self.write_mode = mode; + self + } + + /// Set the initial capacity hint. + pub fn with_initial_capacity(mut self, capacity: usize) -> Self { + self.initial_capacity = capacity; + self.hnsw_params = self.hnsw_params.with_capacity(capacity); + self + } +} + +/// Merge two sorted query replies, keeping the top k results. +/// +/// Both replies are assumed to be sorted by distance (ascending). +/// For single-value indices, no deduplication is needed. 
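+///
+/// A worked sketch of the expected behavior (labels and distances are
+/// hypothetical):
+///
+/// ```text
+/// flat: [(label 7, d 0.1), (label 9, d 0.4)]
+/// hnsw: [(label 3, d 0.2), (label 5, d 0.3)]
+/// merge_top_k(flat, hnsw, 3) => [(7, 0.1), (3, 0.2), (5, 0.3)]
+/// ```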
+pub fn merge_top_k( + flat_results: crate::query::QueryReply, + hnsw_results: crate::query::QueryReply, + k: usize, +) -> crate::query::QueryReply { + use crate::query::QueryReply; + + // Fast paths + if flat_results.is_empty() { + let mut results = hnsw_results; + results.results.truncate(k); + return results; + } + + if hnsw_results.is_empty() { + let mut results = flat_results; + results.results.truncate(k); + return results; + } + + // Merge both result sets + let mut merged = QueryReply::with_capacity(flat_results.len() + hnsw_results.len()); + + // Add all results + for result in flat_results.results { + merged.push(result); + } + for result in hnsw_results.results { + merged.push(result); + } + + // Sort by distance and truncate + merged.sort_by_distance(); + merged.results.truncate(k); + + merged +} + +/// Merge two sorted query replies for multi-value indices, keeping the top k unique labels. +/// +/// Both replies are assumed to be sorted by distance (ascending). +/// Deduplicates by label, keeping only the best (minimum) distance per label. +pub fn merge_top_k_multi( + flat_results: crate::query::QueryReply, + hnsw_results: crate::query::QueryReply, + k: usize, +) -> crate::query::QueryReply { + use crate::query::{QueryReply, QueryResult}; + use crate::types::LabelType; + use std::collections::HashMap; + + // Fast paths + if flat_results.is_empty() { + let mut results = hnsw_results; + results.results.truncate(k); + return results; + } + + if hnsw_results.is_empty() { + let mut results = flat_results; + results.results.truncate(k); + return results; + } + + // Deduplicate by label, keeping best distance per label + let mut label_best: HashMap = HashMap::new(); + + for result in flat_results.results.into_iter().chain(hnsw_results.results.into_iter()) { + label_best + .entry(result.label) + .and_modify(|best| { + if result.distance.to_f64() < best.to_f64() { + *best = result.distance; + } + }) + .or_insert(result.distance); + } + + // Convert to reply + let mut merged = QueryReply::with_capacity(label_best.len().min(k)); + for (label, dist) in label_best { + merged.push(QueryResult::new(label, dist)); + } + + // Sort by distance and truncate + merged.sort_by_distance(); + merged.results.truncate(k); + + merged +} + +/// Merge two query replies for range queries. +/// +/// Combines all results from both tiers. +pub fn merge_range( + flat_results: crate::query::QueryReply, + hnsw_results: crate::query::QueryReply, +) -> crate::query::QueryReply { + use crate::query::QueryReply; + + if flat_results.is_empty() { + return hnsw_results; + } + + if hnsw_results.is_empty() { + return flat_results; + } + + let mut merged = QueryReply::with_capacity(flat_results.len() + hnsw_results.len()); + + for result in flat_results.results { + merged.push(result); + } + for result in hnsw_results.results { + merged.push(result); + } + + merged.sort_by_distance(); + merged +} diff --git a/rust/vecsim/src/index/tiered/multi.rs b/rust/vecsim/src/index/tiered/multi.rs new file mode 100644 index 000000000..d4cf5417c --- /dev/null +++ b/rust/vecsim/src/index/tiered/multi.rs @@ -0,0 +1,735 @@ +//! Multi-value tiered index implementation. +//! +//! This index allows multiple vectors per label, combining a BruteForce frontend +//! (for fast writes) with an HNSW backend (for efficient queries). 
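+//!
+//! A minimal usage sketch (the external `vecsim` paths are assumptions based
+//! on this diff's layout; error handling elided):
+//!
+//! ```ignore
+//! use vecsim::distance::Metric;
+//! use vecsim::index::tiered::{TieredMulti, TieredParams};
+//! use vecsim::index::VecSimIndex;
+//!
+//! let params = TieredParams::new(4, Metric::L2).with_flat_buffer_limit(1_000);
+//! let mut index = TieredMulti::<f32>::new(params);
+//! index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); // goes to the flat buffer
+//! index.add_vector(&[0.9, 0.1, 0.0, 0.0], 1).unwrap(); // same label, second vector
+//! let migrated = index.flush().unwrap();               // migrate both into HNSW
+//! let reply = index.top_k_query(&[1.0, 0.0, 0.0, 0.0], 2, None).unwrap();
+//! ```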
+ +use super::{merge_range, merge_top_k_multi, TieredParams, WriteMode}; +use crate::index::brute_force::{BruteForceMulti, BruteForceParams}; +use crate::index::hnsw::HnswMulti; +use crate::index::traits::{BatchIterator, IndexError, IndexInfo, QueryError, VecSimIndex}; +use crate::query::{QueryParams, QueryReply}; +use crate::types::{LabelType, VectorElement}; +use parking_lot::RwLock; +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Multi-value tiered index combining BruteForce frontend with HNSW backend. +/// +/// Each label can have multiple associated vectors. The tiered architecture +/// provides fast writes to the flat buffer and efficient queries from HNSW. +pub struct TieredMulti { + /// Flat BruteForce buffer (frontend). + pub(crate) flat: RwLock>, + /// HNSW index (backend). + pub(crate) hnsw: RwLock>, + /// Label counts in flat buffer. + flat_label_counts: RwLock>, + /// Label counts in HNSW. + hnsw_label_counts: RwLock>, + /// Configuration parameters. + params: TieredParams, + /// Total vector count. + count: AtomicUsize, + /// Maximum capacity (if set). + capacity: Option, +} + +impl TieredMulti { + /// Create a new multi-value tiered index. + pub fn new(params: TieredParams) -> Self { + let bf_params = BruteForceParams::new(params.dim, params.metric) + .with_capacity(params.flat_buffer_limit.min(params.initial_capacity)); + let flat = BruteForceMulti::new(bf_params); + let hnsw = HnswMulti::new(params.hnsw_params.clone()); + + Self { + flat: RwLock::new(flat), + hnsw: RwLock::new(hnsw), + flat_label_counts: RwLock::new(HashMap::new()), + hnsw_label_counts: RwLock::new(HashMap::new()), + params, + count: AtomicUsize::new(0), + capacity: None, + } + } + + /// Create with a maximum capacity. + pub fn with_capacity(params: TieredParams, max_capacity: usize) -> Self { + let mut index = Self::new(params); + index.capacity = Some(max_capacity); + index + } + + /// Get the current write mode. + pub fn write_mode(&self) -> WriteMode { + if self.params.write_mode == WriteMode::InPlace { + return WriteMode::InPlace; + } + if self.flat_size() >= self.params.flat_buffer_limit { + WriteMode::InPlace + } else { + WriteMode::Async + } + } + + /// Get the number of vectors in the flat buffer. + pub fn flat_size(&self) -> usize { + self.flat.read().index_size() + } + + /// Get the number of vectors in the HNSW backend. + pub fn hnsw_size(&self) -> usize { + self.hnsw.read().index_size() + } + + /// Get the memory usage in bytes. + pub fn memory_usage(&self) -> usize { + self.flat.read().memory_usage() + self.hnsw.read().memory_usage() + } + + /// Get the count of vectors for a label in the flat buffer. + pub fn flat_label_count(&self, label: LabelType) -> usize { + *self.flat_label_counts.read().get(&label).unwrap_or(&0) + } + + /// Get the count of vectors for a label in the HNSW backend. + pub fn hnsw_label_count(&self, label: LabelType) -> usize { + *self.hnsw_label_counts.read().get(&label).unwrap_or(&0) + } + + /// Flush all vectors from flat buffer to HNSW. + /// + /// # Returns + /// The number of vectors migrated. 
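+    ///
+    /// Sketch (mirrors the flush test in this module):
+    ///
+    /// ```ignore
+    /// let migrated = index.flush().unwrap();
+    /// assert_eq!(index.flat_size(), 0); // buffer is empty after a flush
+    /// ```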
+ pub fn flush(&mut self) -> Result { + let flat_labels: Vec = self.flat_label_counts.read().keys().copied().collect(); + + if flat_labels.is_empty() { + return Ok(0); + } + + let mut migrated = 0; + + // For multi-value, we need to get all vectors for each label + let vectors: Vec<(LabelType, Vec>)> = { + let flat = self.flat.read(); + flat_labels + .iter() + .filter_map(|&label| { + flat.get_vectors(label).map(|vecs| (label, vecs)) + }) + .collect() + }; + + // Add to HNSW + { + let mut hnsw = self.hnsw.write(); + let mut hnsw_label_counts = self.hnsw_label_counts.write(); + + for (label, vecs) in &vectors { + for vec in vecs { + match hnsw.add_vector(vec, *label) { + Ok(_) => { + *hnsw_label_counts.entry(*label).or_insert(0) += 1; + migrated += 1; + } + Err(e) => { + eprintln!("Failed to migrate label {label} to HNSW: {e:?}"); + } + } + } + } + } + + // Clear flat buffer + { + let mut flat = self.flat.write(); + let mut flat_label_counts = self.flat_label_counts.write(); + + flat.clear(); + flat_label_counts.clear(); + } + + Ok(migrated) + } + + /// Compact both tiers to reclaim space. + pub fn compact(&mut self, shrink: bool) -> usize { + let flat_reclaimed = self.flat.write().compact(shrink); + let hnsw_reclaimed = self.hnsw.write().compact(shrink); + flat_reclaimed + hnsw_reclaimed + } + + /// Get the fragmentation ratio. + pub fn fragmentation(&self) -> f64 { + let flat = self.flat.read(); + let hnsw = self.hnsw.read(); + let flat_frag = flat.fragmentation(); + let hnsw_frag = hnsw.fragmentation(); + + let flat_size = flat.index_size() as f64; + let hnsw_size = hnsw.index_size() as f64; + let total = flat_size + hnsw_size; + + if total == 0.0 { + 0.0 + } else { + (flat_frag * flat_size + hnsw_frag * hnsw_size) / total + } + } + + /// Clear all vectors from both tiers. + pub fn clear(&mut self) { + self.flat.write().clear(); + self.hnsw.write().clear(); + self.flat_label_counts.write().clear(); + self.hnsw_label_counts.write().clear(); + self.count.store(0, Ordering::Relaxed); + } + + /// Get the neighbors of an element in the HNSW graph by label. + /// + /// Returns a vector of vectors, where each inner vector contains the neighbor labels + /// for that level (level 0 first). Returns None if the label is not in the HNSW backend. 
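+    ///
+    /// Sketch (the label is illustrative):
+    ///
+    /// ```ignore
+    /// if let Some(levels) = index.get_element_neighbors(42) {
+    ///     let level0_neighbors = &levels[0]; // level-0 connections
+    /// }
+    /// ```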
+ pub fn get_element_neighbors(&self, label: LabelType) -> Option>> { + // Only check HNSW backend - flat buffer doesn't have graph structure + if self.hnsw_label_counts.read().contains_key(&label) { + return self.hnsw.read().get_element_neighbors(label); + } + None + } +} + +impl VecSimIndex for TieredMulti { + type DataType = T; + type DistType = T::DistanceType; + + fn add_vector(&mut self, vector: &[T], label: LabelType) -> Result { + if vector.len() != self.params.dim { + return Err(IndexError::DimensionMismatch { + expected: self.params.dim, + got: vector.len(), + }); + } + + if let Some(cap) = self.capacity { + if self.count.load(Ordering::Relaxed) >= cap { + return Err(IndexError::CapacityExceeded { capacity: cap }); + } + } + + // For multi-value, always add (don't replace) + match self.write_mode() { + WriteMode::Async => { + let mut flat = self.flat.write(); + let mut flat_label_counts = self.flat_label_counts.write(); + flat.add_vector(vector, label)?; + *flat_label_counts.entry(label).or_insert(0) += 1; + } + WriteMode::InPlace => { + let mut hnsw = self.hnsw.write(); + let mut hnsw_label_counts = self.hnsw_label_counts.write(); + hnsw.add_vector(vector, label)?; + *hnsw_label_counts.entry(label).or_insert(0) += 1; + } + } + + self.count.fetch_add(1, Ordering::Relaxed); + Ok(1) + } + + fn delete_vector(&mut self, label: LabelType) -> Result { + let flat_count = self.flat_label_count(label); + let hnsw_count = self.hnsw_label_count(label); + + if flat_count == 0 && hnsw_count == 0 { + return Err(IndexError::LabelNotFound(label)); + } + + let mut deleted = 0; + + if flat_count > 0 { + let mut flat = self.flat.write(); + let mut flat_label_counts = self.flat_label_counts.write(); + if let Ok(n) = flat.delete_vector(label) { + deleted += n; + flat_label_counts.remove(&label); + } + } + + if hnsw_count > 0 { + let mut hnsw = self.hnsw.write(); + let mut hnsw_label_counts = self.hnsw_label_counts.write(); + if let Ok(n) = hnsw.delete_vector(label) { + deleted += n; + hnsw_label_counts.remove(&label); + } + } + + self.count.fetch_sub(deleted, Ordering::Relaxed); + Ok(deleted) + } + + fn top_k_query( + &self, + query: &[Self::DataType], + k: usize, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + let flat = self.flat.read(); + let hnsw = self.hnsw.read(); + + let flat_results = flat.top_k_query(query, k, params)?; + let hnsw_results = hnsw.top_k_query(query, k, params)?; + + Ok(merge_top_k_multi(flat_results, hnsw_results, k)) + } + + fn range_query( + &self, + query: &[Self::DataType], + radius: Self::DistType, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + let flat = self.flat.read(); + let hnsw = self.hnsw.read(); + + let flat_results = flat.range_query(query, radius, params)?; + let hnsw_results = hnsw.range_query(query, radius, params)?; + + Ok(merge_range(flat_results, hnsw_results)) + } + + fn index_size(&self) -> usize { + self.count.load(Ordering::Relaxed) + } + + fn index_capacity(&self) -> Option { + self.capacity + } + + fn dimension(&self) -> usize { + self.params.dim + } + + fn batch_iterator<'a>( + &'a self, + query: &[Self::DataType], + params: Option<&QueryParams>, + ) -> Result + 'a>, QueryError> { + if query.len() != self.params.dim { + return 
Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + Ok(Box::new( + super::batch_iterator::TieredMultiBatchIterator::new(self, query.to_vec(), params.cloned()), + )) + } + + fn info(&self) -> IndexInfo { + IndexInfo { + size: self.index_size(), + capacity: self.capacity, + dimension: self.params.dim, + index_type: "TieredMulti", + memory_bytes: self.memory_usage(), + } + } + + fn contains(&self, label: LabelType) -> bool { + self.flat_label_counts.read().contains_key(&label) + || self.hnsw_label_counts.read().contains_key(&label) + } + + fn label_count(&self, label: LabelType) -> usize { + self.flat_label_count(label) + self.hnsw_label_count(label) + } +} + +// Serialization implementation for f32 +impl TieredMulti { + /// Save the index to a writer. + /// + /// The serialization format saves both tiers (flat buffer and HNSW) independently, + /// preserving the exact state of the index. + pub fn save( + &self, + writer: &mut W, + ) -> crate::serialization::SerializationResult<()> { + use crate::serialization::*; + + let count = self.count.load(Ordering::Relaxed); + + // Write header + let header = IndexHeader::new( + IndexTypeId::TieredMulti, + DataTypeId::F32, + self.params.metric, + self.params.dim, + count, + ); + header.write(writer)?; + + // Write tiered params + write_usize(writer, self.params.flat_buffer_limit)?; + write_u8(writer, if self.params.write_mode == WriteMode::InPlace { 1 } else { 0 })?; + write_usize(writer, self.params.initial_capacity)?; + + // Write HNSW params + write_usize(writer, self.params.hnsw_params.m)?; + write_usize(writer, self.params.hnsw_params.m_max_0)?; + write_usize(writer, self.params.hnsw_params.ef_construction)?; + write_usize(writer, self.params.hnsw_params.ef_runtime)?; + write_u8(writer, if self.params.hnsw_params.enable_heuristic { 1 } else { 0 })?; + + // Write capacity + write_u8(writer, if self.capacity.is_some() { 1 } else { 0 })?; + if let Some(cap) = self.capacity { + write_usize(writer, cap)?; + } + + // Write flat buffer state + let flat = self.flat.read(); + let flat_label_counts = self.flat_label_counts.read(); + + // Write number of unique labels in flat + write_usize(writer, flat_label_counts.len())?; + for (&label, &label_count) in flat_label_counts.iter() { + if let Some(vecs) = flat.get_vectors(label) { + write_u64(writer, label)?; + write_usize(writer, vecs.len())?; + for vec in vecs { + for &v in &vec { + write_f32(writer, v)?; + } + } + } else { + // Should not happen, but handle gracefully + write_u64(writer, label)?; + write_usize(writer, 0)?; + } + // Silence warning about unused label_count + let _ = label_count; + } + drop(flat); + drop(flat_label_counts); + + // Write HNSW state + let hnsw = self.hnsw.read(); + let hnsw_label_counts = self.hnsw_label_counts.read(); + + // Write number of unique labels in HNSW + write_usize(writer, hnsw_label_counts.len())?; + for (&label, &label_count) in hnsw_label_counts.iter() { + if let Some(vecs) = hnsw.get_vectors(label) { + write_u64(writer, label)?; + write_usize(writer, vecs.len())?; + for vec in vecs { + for &v in &vec { + write_f32(writer, v)?; + } + } + } else { + write_u64(writer, label)?; + write_usize(writer, 0)?; + } + let _ = label_count; + } + + Ok(()) + } + + /// Load the index from a reader. 
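+    ///
+    /// Round-trip sketch (mirrors the serialization test in this module):
+    ///
+    /// ```ignore
+    /// let mut buffer = Vec::new();
+    /// index.save(&mut buffer).unwrap();
+    /// let loaded = TieredMulti::<f32>::load(&mut std::io::Cursor::new(buffer)).unwrap();
+    /// ```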
+ pub fn load(reader: &mut R) -> crate::serialization::SerializationResult { + use crate::serialization::*; + + // Read header + let header = IndexHeader::read(reader)?; + + if header.index_type != IndexTypeId::TieredMulti { + return Err(SerializationError::IndexTypeMismatch { + expected: "TieredMulti".to_string(), + got: header.index_type.as_str().to_string(), + }); + } + + if header.data_type != DataTypeId::F32 { + return Err(SerializationError::InvalidData( + "Expected f32 data type".to_string(), + )); + } + + // Read tiered params + let flat_buffer_limit = read_usize(reader)?; + let write_mode = if read_u8(reader)? != 0 { + WriteMode::InPlace + } else { + WriteMode::Async + }; + let initial_capacity = read_usize(reader)?; + + // Read HNSW params + let m = read_usize(reader)?; + let m_max_0 = read_usize(reader)?; + let ef_construction = read_usize(reader)?; + let ef_runtime = read_usize(reader)?; + let enable_heuristic = read_u8(reader)? != 0; + + // Build params + let mut hnsw_params = crate::index::hnsw::HnswParams::new(header.dimension, header.metric) + .with_m(m) + .with_ef_construction(ef_construction) + .with_ef_runtime(ef_runtime); + hnsw_params.m_max_0 = m_max_0; + hnsw_params.enable_heuristic = enable_heuristic; + + let params = TieredParams { + dim: header.dimension, + metric: header.metric, + hnsw_params, + flat_buffer_limit, + write_mode, + initial_capacity, + }; + + let mut index = Self::new(params); + + // Read capacity + let has_capacity = read_u8(reader)? != 0; + if has_capacity { + index.capacity = Some(read_usize(reader)?); + } + + let mut total_count = 0; + + // Read flat buffer vectors + let flat_label_count = read_usize(reader)?; + { + let mut flat = index.flat.write(); + let mut flat_label_counts = index.flat_label_counts.write(); + for _ in 0..flat_label_count { + let label = read_u64(reader)?; + let vec_count = read_usize(reader)?; + for _ in 0..vec_count { + let mut vec = vec![0.0f32; header.dimension]; + for v in &mut vec { + *v = read_f32(reader)?; + } + flat.add_vector(&vec, label).map_err(|e| { + SerializationError::DataCorruption(format!("Failed to add vector: {e:?}")) + })?; + *flat_label_counts.entry(label).or_insert(0) += 1; + total_count += 1; + } + } + } + + // Read HNSW vectors + let hnsw_label_count = read_usize(reader)?; + { + let mut hnsw = index.hnsw.write(); + let mut hnsw_label_counts = index.hnsw_label_counts.write(); + for _ in 0..hnsw_label_count { + let label = read_u64(reader)?; + let vec_count = read_usize(reader)?; + for _ in 0..vec_count { + let mut vec = vec![0.0f32; header.dimension]; + for v in &mut vec { + *v = read_f32(reader)?; + } + hnsw.add_vector(&vec, label).map_err(|e| { + SerializationError::DataCorruption(format!("Failed to add vector: {e:?}")) + })?; + *hnsw_label_counts.entry(label).or_insert(0) += 1; + total_count += 1; + } + } + } + + // Set total count + index.count.store(total_count, Ordering::Relaxed); + + Ok(index) + } + + /// Save the index to a file. + pub fn save_to_file>( + &self, + path: P, + ) -> crate::serialization::SerializationResult<()> { + let file = std::fs::File::create(path)?; + let mut writer = std::io::BufWriter::new(file); + self.save(&mut writer) + } + + /// Load the index from a file. 
+ pub fn load_from_file>( + path: P, + ) -> crate::serialization::SerializationResult { + let file = std::fs::File::open(path)?; + let mut reader = std::io::BufReader::new(file); + Self::load(&mut reader) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::Metric; + + #[test] + fn test_tiered_multi_basic() { + let params = TieredParams::new(4, Metric::L2); + let mut index = TieredMulti::::new(params); + + // Add multiple vectors with same label + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + assert_eq!(index.index_size(), 3); + assert_eq!(index.label_count(1), 2); + assert_eq!(index.label_count(2), 1); + } + + #[test] + fn test_tiered_multi_delete() { + let params = TieredParams::new(4, Metric::L2); + let mut index = TieredMulti::::new(params); + + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + // Delete all vectors for label 1 + let deleted = index.delete_vector(1).unwrap(); + assert_eq!(deleted, 2); + assert_eq!(index.index_size(), 1); + assert_eq!(index.label_count(1), 0); + } + + #[test] + fn test_tiered_multi_flush() { + let params = TieredParams::new(4, Metric::L2); + let mut index = TieredMulti::::new(params); + + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + assert_eq!(index.flat_size(), 3); + assert_eq!(index.hnsw_size(), 0); + + let migrated = index.flush().unwrap(); + assert_eq!(migrated, 3); + + assert_eq!(index.flat_size(), 0); + assert_eq!(index.hnsw_size(), 3); + assert_eq!(index.label_count(1), 2); + } + + #[test] + fn test_tiered_multi_serialization() { + use std::io::Cursor; + + let params = TieredParams::new(4, Metric::L2) + .with_flat_buffer_limit(10) + .with_m(8) + .with_ef_construction(50); + let mut index = TieredMulti::::new(params); + + // Add multiple vectors per label to flat buffer + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.9, 0.1, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + // Flush to HNSW + index.flush().unwrap(); + + // Add more to flat + index.add_vector(&vec![0.0, 0.0, 1.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.0, 0.0, 0.0, 1.0], 3).unwrap(); + + assert_eq!(index.flat_size(), 2); + assert_eq!(index.hnsw_size(), 3); + assert_eq!(index.index_size(), 5); + assert_eq!(index.label_count(1), 3); // 2 in HNSW + 1 in flat + + // Serialize + let mut buffer = Vec::new(); + index.save(&mut buffer).unwrap(); + + // Deserialize + let mut cursor = Cursor::new(buffer); + let loaded = TieredMulti::::load(&mut cursor).unwrap(); + + // Verify state + assert_eq!(loaded.index_size(), 5); + assert_eq!(loaded.flat_size(), 2); + assert_eq!(loaded.hnsw_size(), 3); + assert_eq!(loaded.dimension(), 4); + assert_eq!(loaded.label_count(1), 3); + assert_eq!(loaded.label_count(2), 1); + assert_eq!(loaded.label_count(3), 1); + + // Verify vectors can be queried + // Multi-value indices deduplicate by label, so we get 3 unique labels (1, 2, 3) + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = loaded.top_k_query(&query, 5, None).unwrap(); + assert_eq!(results.len(), 3); + } + + #[test] + fn test_tiered_multi_serialization_file() { + use std::fs; + + let params = 
TieredParams::new(4, Metric::L2); + let mut index = TieredMulti::::new(params); + + // Add multiple vectors per label + for i in 0..5 { + for j in 0..3 { + let v = vec![i as f32 + j as f32 * 0.1, 0.0, 0.0, 0.0]; + index.add_vector(&v, i as u64).unwrap(); + } + } + + assert_eq!(index.index_size(), 15); + + // Flush + index.flush().unwrap(); + + // Add more + index.add_vector(&vec![10.0, 0.0, 0.0, 0.0], 10).unwrap(); + + let path = "/tmp/tiered_multi_test.idx"; + index.save_to_file(path).unwrap(); + + let loaded = TieredMulti::::load_from_file(path).unwrap(); + + assert_eq!(loaded.index_size(), 16); + assert_eq!(loaded.label_count(0), 3); + assert_eq!(loaded.label_count(10), 1); + assert!(loaded.contains(4)); + + // Cleanup + fs::remove_file(path).ok(); + } +} diff --git a/rust/vecsim/src/index/tiered/single.rs b/rust/vecsim/src/index/tiered/single.rs new file mode 100644 index 000000000..560f13c27 --- /dev/null +++ b/rust/vecsim/src/index/tiered/single.rs @@ -0,0 +1,914 @@ +//! Single-value tiered index implementation. +//! +//! This index stores one vector per label, combining a BruteForce frontend +//! (for fast writes) with an HNSW backend (for efficient queries). + +use super::{merge_range, merge_top_k, TieredParams, WriteMode}; +use crate::index::brute_force::{BruteForceParams, BruteForceSingle}; +use crate::index::hnsw::HnswSingle; +use crate::index::traits::{BatchIterator, IndexError, IndexInfo, QueryError, VecSimIndex}; +use crate::query::{QueryParams, QueryReply}; +use crate::types::{LabelType, VectorElement}; +use parking_lot::RwLock; +use std::collections::HashSet; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Statistics about a tiered index. +#[derive(Debug, Clone)] +pub struct TieredStats { + /// Number of vectors in flat buffer. + pub flat_size: usize, + /// Number of vectors in HNSW backend. + pub hnsw_size: usize, + /// Total vector count. + pub total_size: usize, + /// Current write mode. + pub write_mode: WriteMode, + /// Flat buffer fragmentation. + pub flat_fragmentation: f64, + /// HNSW fragmentation. + pub hnsw_fragmentation: f64, + /// Approximate memory usage in bytes. + pub memory_bytes: usize, +} + +/// Single-value tiered index combining BruteForce frontend with HNSW backend. +/// +/// The tiered architecture provides: +/// - **Fast writes**: New vectors go to the flat BruteForce buffer +/// - **Efficient queries**: Results are merged from both tiers +/// - **Flexible migration**: Use `flush()` to migrate vectors to HNSW +/// +/// # Write Behavior +/// +/// In **Async** mode (default): +/// - New vectors are added to the flat buffer +/// - When buffer reaches `flat_buffer_limit`, mode switches to InPlace +/// - Call `flush()` to migrate all vectors to HNSW and reset buffer +/// +/// In **InPlace** mode: +/// - Vectors are added directly to HNSW +/// - Use this when you don't need the write buffer +pub struct TieredSingle { + /// Flat BruteForce buffer (frontend). + pub(crate) flat: RwLock>, + /// HNSW index (backend). + pub(crate) hnsw: RwLock>, + /// Labels currently in flat buffer. + flat_labels: RwLock>, + /// Labels currently in HNSW. + hnsw_labels: RwLock>, + /// Configuration parameters. + params: TieredParams, + /// Total vector count. + count: AtomicUsize, + /// Maximum capacity (if set). + capacity: Option, +} + +impl TieredSingle { + /// Create a new single-value tiered index. 
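+    ///
+    /// Sketch (dimension and metric are illustrative):
+    ///
+    /// ```ignore
+    /// let params = TieredParams::new(128, Metric::Cosine);
+    /// let mut index = TieredSingle::<f32>::new(params);
+    /// ```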
+ pub fn new(params: TieredParams) -> Self { + let bf_params = BruteForceParams::new(params.dim, params.metric) + .with_capacity(params.flat_buffer_limit.min(params.initial_capacity)); + let flat = BruteForceSingle::new(bf_params); + let hnsw = HnswSingle::new(params.hnsw_params.clone()); + + Self { + flat: RwLock::new(flat), + hnsw: RwLock::new(hnsw), + flat_labels: RwLock::new(HashSet::new()), + hnsw_labels: RwLock::new(HashSet::new()), + params, + count: AtomicUsize::new(0), + capacity: None, + } + } + + /// Create with a maximum capacity. + pub fn with_capacity(params: TieredParams, max_capacity: usize) -> Self { + let mut index = Self::new(params); + index.capacity = Some(max_capacity); + index + } + + /// Get the current write mode. + pub fn write_mode(&self) -> WriteMode { + if self.params.write_mode == WriteMode::InPlace { + return WriteMode::InPlace; + } + // In Async mode, switch to InPlace if flat buffer is full + if self.flat_size() >= self.params.flat_buffer_limit { + WriteMode::InPlace + } else { + WriteMode::Async + } + } + + /// Get the number of vectors in the flat buffer. + pub fn flat_size(&self) -> usize { + self.flat.read().index_size() + } + + /// Get the number of vectors in the HNSW backend. + pub fn hnsw_size(&self) -> usize { + self.hnsw.read().index_size() + } + + /// Get detailed statistics about the index. + pub fn stats(&self) -> TieredStats { + let flat = self.flat.read(); + let hnsw = self.hnsw.read(); + + TieredStats { + flat_size: flat.index_size(), + hnsw_size: hnsw.index_size(), + total_size: self.count.load(Ordering::Relaxed), + write_mode: self.write_mode(), + flat_fragmentation: flat.fragmentation(), + hnsw_fragmentation: hnsw.fragmentation(), + memory_bytes: flat.memory_usage() + hnsw.memory_usage(), + } + } + + /// Get the memory usage in bytes. + pub fn memory_usage(&self) -> usize { + self.flat.read().memory_usage() + self.hnsw.read().memory_usage() + } + + /// Check if a label exists in the flat buffer. + pub fn is_in_flat(&self, label: LabelType) -> bool { + self.flat_labels.read().contains(&label) + } + + /// Check if a label exists in the HNSW backend. + pub fn is_in_hnsw(&self, label: LabelType) -> bool { + self.hnsw_labels.read().contains(&label) + } + + /// Flush all vectors from flat buffer to HNSW. + /// + /// This migrates all vectors from the flat buffer to the HNSW backend, + /// clearing the flat buffer afterward. + /// + /// # Returns + /// The number of vectors migrated. + pub fn flush(&mut self) -> Result { + let flat_labels: Vec = self.flat_labels.read().iter().copied().collect(); + + if flat_labels.is_empty() { + return Ok(0); + } + + let mut migrated = 0; + + // Collect vectors from flat buffer + let vectors: Vec<(LabelType, Vec)> = { + let flat = self.flat.read(); + flat_labels + .iter() + .filter_map(|&label| flat.get_vector(label).map(|v| (label, v))) + .collect() + }; + + // Add to HNSW + { + let mut hnsw = self.hnsw.write(); + let mut hnsw_labels = self.hnsw_labels.write(); + + for (label, vector) in &vectors { + match hnsw.add_vector(vector, *label) { + Ok(_) => { + hnsw_labels.insert(*label); + migrated += 1; + } + Err(e) => { + // Log error but continue + eprintln!("Failed to migrate label {label} to HNSW: {e:?}"); + } + } + } + } + + // Clear flat buffer + { + let mut flat = self.flat.write(); + let mut flat_labels = self.flat_labels.write(); + + flat.clear(); + flat_labels.clear(); + } + + Ok(migrated) + } + + /// Compact both tiers to reclaim space from deleted vectors. 
+ /// + /// # Arguments + /// * `shrink` - If true, also release unused memory + /// + /// # Returns + /// Approximate bytes reclaimed. + pub fn compact(&mut self, shrink: bool) -> usize { + let flat_reclaimed = self.flat.write().compact(shrink); + let hnsw_reclaimed = self.hnsw.write().compact(shrink); + flat_reclaimed + hnsw_reclaimed + } + + /// Get the fragmentation ratio (0.0 = none, 1.0 = all deleted). + pub fn fragmentation(&self) -> f64 { + let flat = self.flat.read(); + let hnsw = self.hnsw.read(); + let flat_frag = flat.fragmentation(); + let hnsw_frag = hnsw.fragmentation(); + + // Weighted average by size + let flat_size = flat.index_size() as f64; + let hnsw_size = hnsw.index_size() as f64; + let total = flat_size + hnsw_size; + + if total == 0.0 { + 0.0 + } else { + (flat_frag * flat_size + hnsw_frag * hnsw_size) / total + } + } + + /// Get a copy of the vector stored for a given label. + pub fn get_vector(&self, label: LabelType) -> Option> { + // Check flat first + if self.flat_labels.read().contains(&label) { + return self.flat.read().get_vector(label); + } + // Then check HNSW + if self.hnsw_labels.read().contains(&label) { + return self.hnsw.read().get_vector(label); + } + None + } + + /// Clear all vectors from both tiers. + pub fn clear(&mut self) { + self.flat.write().clear(); + self.hnsw.write().clear(); + self.flat_labels.write().clear(); + self.hnsw_labels.write().clear(); + self.count.store(0, Ordering::Relaxed); + } + + /// Get the neighbors of an element in the HNSW graph by label. + /// + /// Returns a vector of vectors, where each inner vector contains the neighbor labels + /// for that level (level 0 first). Returns None if the label is not in the HNSW backend. + pub fn get_element_neighbors(&self, label: LabelType) -> Option>> { + // Only check HNSW backend - flat buffer doesn't have graph structure + if self.hnsw_labels.read().contains(&label) { + return self.hnsw.read().get_element_neighbors(label); + } + None + } +} + +impl VecSimIndex for TieredSingle { + type DataType = T; + type DistType = T::DistanceType; + + fn add_vector(&mut self, vector: &[T], label: LabelType) -> Result { + // Validate dimension + if vector.len() != self.params.dim { + return Err(IndexError::DimensionMismatch { + expected: self.params.dim, + got: vector.len(), + }); + } + + // Check capacity + if let Some(cap) = self.capacity { + if self.count.load(Ordering::Relaxed) >= cap { + return Err(IndexError::CapacityExceeded { capacity: cap }); + } + } + + let in_flat = self.flat_labels.read().contains(&label); + let in_hnsw = self.hnsw_labels.read().contains(&label); + + // Update existing vector + if in_flat { + // Update in flat buffer (replace) + let mut flat = self.flat.write(); + flat.add_vector(vector, label)?; + return Ok(0); // No new vector + } + + if in_hnsw { + // For single-value: update means replace + // Delete from HNSW and add to flat (or direct to HNSW based on mode) + { + let mut hnsw = self.hnsw.write(); + let mut hnsw_labels = self.hnsw_labels.write(); + hnsw.delete_vector(label)?; + hnsw_labels.remove(&label); + } + self.count.fetch_sub(1, Ordering::Relaxed); + + // Now add new vector (fall through to new vector logic) + } + + // Add new vector based on write mode + match self.write_mode() { + WriteMode::Async => { + let mut flat = self.flat.write(); + let mut flat_labels = self.flat_labels.write(); + flat.add_vector(vector, label)?; + flat_labels.insert(label); + } + WriteMode::InPlace => { + let mut hnsw = self.hnsw.write(); + let mut hnsw_labels = 
self.hnsw_labels.write(); + hnsw.add_vector(vector, label)?; + hnsw_labels.insert(label); + } + } + + self.count.fetch_add(1, Ordering::Relaxed); + Ok(1) + } + + fn delete_vector(&mut self, label: LabelType) -> Result { + let in_flat = self.flat_labels.read().contains(&label); + let in_hnsw = self.hnsw_labels.read().contains(&label); + + if !in_flat && !in_hnsw { + return Err(IndexError::LabelNotFound(label)); + } + + let mut deleted = 0; + + if in_flat { + let mut flat = self.flat.write(); + let mut flat_labels = self.flat_labels.write(); + if flat.delete_vector(label).is_ok() { + flat_labels.remove(&label); + deleted += 1; + } + } + + if in_hnsw { + let mut hnsw = self.hnsw.write(); + let mut hnsw_labels = self.hnsw_labels.write(); + if hnsw.delete_vector(label).is_ok() { + hnsw_labels.remove(&label); + deleted += 1; + } + } + + self.count.fetch_sub(deleted, Ordering::Relaxed); + Ok(deleted) + } + + fn top_k_query( + &self, + query: &[Self::DataType], + k: usize, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + // Lock ordering: flat first, then HNSW + let flat = self.flat.read(); + let hnsw = self.hnsw.read(); + + // Query both tiers + let flat_results = flat.top_k_query(query, k, params)?; + let hnsw_results = hnsw.top_k_query(query, k, params)?; + + // Merge results + Ok(merge_top_k(flat_results, hnsw_results, k)) + } + + fn range_query( + &self, + query: &[Self::DataType], + radius: Self::DistType, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + // Lock ordering: flat first, then HNSW + let flat = self.flat.read(); + let hnsw = self.hnsw.read(); + + // Query both tiers + let flat_results = flat.range_query(query, radius, params)?; + let hnsw_results = hnsw.range_query(query, radius, params)?; + + // Merge results + Ok(merge_range(flat_results, hnsw_results)) + } + + fn index_size(&self) -> usize { + self.count.load(Ordering::Relaxed) + } + + fn index_capacity(&self) -> Option { + self.capacity + } + + fn dimension(&self) -> usize { + self.params.dim + } + + fn batch_iterator<'a>( + &'a self, + query: &[Self::DataType], + params: Option<&QueryParams>, + ) -> Result + 'a>, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + Ok(Box::new(super::batch_iterator::TieredBatchIterator::new( + self, + query.to_vec(), + params.cloned(), + ))) + } + + fn info(&self) -> IndexInfo { + IndexInfo { + size: self.index_size(), + capacity: self.capacity, + dimension: self.params.dim, + index_type: "TieredSingle", + memory_bytes: self.memory_usage(), + } + } + + fn contains(&self, label: LabelType) -> bool { + self.flat_labels.read().contains(&label) || self.hnsw_labels.read().contains(&label) + } + + fn label_count(&self, label: LabelType) -> usize { + if self.contains(label) { + 1 + } else { + 0 + } + } +} + +// Serialization implementation for f32 +impl TieredSingle { + /// Save the index to a writer. + /// + /// The serialization format saves both tiers (flat buffer and HNSW) independently, + /// preserving the exact state of the index. 
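+    ///
+    /// File-based round-trip sketch (see the tests below; the path is
+    /// illustrative):
+    ///
+    /// ```ignore
+    /// index.save_to_file("/tmp/index.idx").unwrap();
+    /// let loaded = TieredSingle::<f32>::load_from_file("/tmp/index.idx").unwrap();
+    /// ```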
+ pub fn save( + &self, + writer: &mut W, + ) -> crate::serialization::SerializationResult<()> { + use crate::serialization::*; + + let count = self.count.load(Ordering::Relaxed); + + // Write header + let header = IndexHeader::new( + IndexTypeId::TieredSingle, + DataTypeId::F32, + self.params.metric, + self.params.dim, + count, + ); + header.write(writer)?; + + // Write tiered params + write_usize(writer, self.params.flat_buffer_limit)?; + write_u8(writer, if self.params.write_mode == WriteMode::InPlace { 1 } else { 0 })?; + write_usize(writer, self.params.initial_capacity)?; + + // Write HNSW params + write_usize(writer, self.params.hnsw_params.m)?; + write_usize(writer, self.params.hnsw_params.m_max_0)?; + write_usize(writer, self.params.hnsw_params.ef_construction)?; + write_usize(writer, self.params.hnsw_params.ef_runtime)?; + write_u8(writer, if self.params.hnsw_params.enable_heuristic { 1 } else { 0 })?; + + // Write capacity + write_u8(writer, if self.capacity.is_some() { 1 } else { 0 })?; + if let Some(cap) = self.capacity { + write_usize(writer, cap)?; + } + + // Write flat buffer state + let flat = self.flat.read(); + let flat_labels = self.flat_labels.read(); + let flat_count = flat.index_size(); + + write_usize(writer, flat_count)?; + for &label in flat_labels.iter() { + if let Some(vec) = flat.get_vector(label) { + write_u64(writer, label)?; + for &v in &vec { + write_f32(writer, v)?; + } + } + } + drop(flat); + drop(flat_labels); + + // Write HNSW state + let hnsw = self.hnsw.read(); + let hnsw_labels = self.hnsw_labels.read(); + let hnsw_count = hnsw.index_size(); + + write_usize(writer, hnsw_count)?; + for &label in hnsw_labels.iter() { + if let Some(vec) = hnsw.get_vector(label) { + write_u64(writer, label)?; + for &v in &vec { + write_f32(writer, v)?; + } + } + } + + Ok(()) + } + + /// Load the index from a reader. + pub fn load(reader: &mut R) -> crate::serialization::SerializationResult { + use crate::serialization::*; + + // Read header + let header = IndexHeader::read(reader)?; + + if header.index_type != IndexTypeId::TieredSingle { + return Err(SerializationError::IndexTypeMismatch { + expected: "TieredSingle".to_string(), + got: header.index_type.as_str().to_string(), + }); + } + + if header.data_type != DataTypeId::F32 { + return Err(SerializationError::InvalidData( + "Expected f32 data type".to_string(), + )); + } + + // Read tiered params + let flat_buffer_limit = read_usize(reader)?; + let write_mode = if read_u8(reader)? != 0 { + WriteMode::InPlace + } else { + WriteMode::Async + }; + let initial_capacity = read_usize(reader)?; + + // Read HNSW params + let m = read_usize(reader)?; + let m_max_0 = read_usize(reader)?; + let ef_construction = read_usize(reader)?; + let ef_runtime = read_usize(reader)?; + let enable_heuristic = read_u8(reader)? != 0; + + // Build params + let mut hnsw_params = crate::index::hnsw::HnswParams::new(header.dimension, header.metric) + .with_m(m) + .with_ef_construction(ef_construction) + .with_ef_runtime(ef_runtime); + hnsw_params.m_max_0 = m_max_0; + hnsw_params.enable_heuristic = enable_heuristic; + + let params = TieredParams { + dim: header.dimension, + metric: header.metric, + hnsw_params, + flat_buffer_limit, + write_mode, + initial_capacity, + }; + + let mut index = Self::new(params); + + // Read capacity + let has_capacity = read_u8(reader)? 
!= 0;
+        if has_capacity {
+            index.capacity = Some(read_usize(reader)?);
+        }
+
+        // Read flat buffer vectors
+        let flat_count = read_usize(reader)?;
+        {
+            let mut flat = index.flat.write();
+            let mut flat_labels = index.flat_labels.write();
+            for _ in 0..flat_count {
+                let label = read_u64(reader)?;
+                let mut vec = vec![0.0f32; header.dimension];
+                for v in &mut vec {
+                    *v = read_f32(reader)?;
+                }
+                flat.add_vector(&vec, label).map_err(|e| {
+                    SerializationError::DataCorruption(format!("Failed to add vector: {e:?}"))
+                })?;
+                flat_labels.insert(label);
+            }
+        }
+
+        // Read HNSW vectors
+        let hnsw_count = read_usize(reader)?;
+        {
+            let mut hnsw = index.hnsw.write();
+            let mut hnsw_labels = index.hnsw_labels.write();
+            for _ in 0..hnsw_count {
+                let label = read_u64(reader)?;
+                let mut vec = vec![0.0f32; header.dimension];
+                for v in &mut vec {
+                    *v = read_f32(reader)?;
+                }
+                hnsw.add_vector(&vec, label).map_err(|e| {
+                    SerializationError::DataCorruption(format!("Failed to add vector: {e:?}"))
+                })?;
+                hnsw_labels.insert(label);
+            }
+        }
+
+        // Set total count
+        index.count.store(flat_count + hnsw_count, Ordering::Relaxed);
+
+        Ok(index)
+    }
+
+    /// Save the index to a file.
+    pub fn save_to_file<P: AsRef<std::path::Path>>(
+        &self,
+        path: P,
+    ) -> crate::serialization::SerializationResult<()> {
+        let file = std::fs::File::create(path)?;
+        let mut writer = std::io::BufWriter::new(file);
+        self.save(&mut writer)
+    }
+
+    /// Load the index from a file.
+    pub fn load_from_file<P: AsRef<std::path::Path>>(
+        path: P,
+    ) -> crate::serialization::SerializationResult<Self> {
+        let file = std::fs::File::open(path)?;
+        let mut reader = std::io::BufReader::new(file);
+        Self::load(&mut reader)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::distance::Metric;
+
+    #[test]
+    fn test_tiered_single_basic() {
+        let params = TieredParams::new(4, Metric::L2).with_flat_buffer_limit(10);
+        let mut index = TieredSingle::<f32>::new(params);
+
+        let v1 = vec![1.0, 0.0, 0.0, 0.0];
+        let v2 = vec![0.0, 1.0, 0.0, 0.0];
+        let v3 = vec![0.0, 0.0, 1.0, 0.0];
+
+        index.add_vector(&v1, 1).unwrap();
+        index.add_vector(&v2, 2).unwrap();
+        index.add_vector(&v3, 3).unwrap();
+
+        assert_eq!(index.index_size(), 3);
+        assert_eq!(index.flat_size(), 3);
+        assert_eq!(index.hnsw_size(), 0);
+
+        // Query should find all vectors
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 3, None).unwrap();
+
+        assert_eq!(results.len(), 3);
+        assert_eq!(results.results[0].label, 1); // Closest
+    }
+
+    #[test]
+    fn test_tiered_single_query_both_tiers() {
+        let params = TieredParams::new(4, Metric::L2).with_flat_buffer_limit(10);
+        let mut index = TieredSingle::<f32>::new(params);
+
+        // Add to flat
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+
+        // Flush to HNSW
+        let migrated = index.flush().unwrap();
+        assert_eq!(migrated, 1);
+
+        // Add more to flat
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        assert_eq!(index.flat_size(), 1);
+        assert_eq!(index.hnsw_size(), 1);
+        assert_eq!(index.index_size(), 2);
+
+        // Query should find both
+        let query = vec![0.5, 0.5, 0.0, 0.0];
+        let results = index.top_k_query(&query, 2, None).unwrap();
+        assert_eq!(results.len(), 2);
+    }
+
+    #[test]
+    fn test_tiered_single_delete() {
+        let params = TieredParams::new(4, Metric::L2);
+        let mut index = TieredSingle::<f32>::new(params);
+
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        assert_eq!(index.index_size(), 2);
+
+        // Delete from flat
+        index.delete_vector(1).unwrap();
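+        // delete_vector reports how many vectors were removed across both
+        // tiers; for a single-value index that is 1 per existing label.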
+        assert_eq!(index.index_size(), 1);
+        assert!(!index.contains(1));
+        assert!(index.contains(2));
+    }
+
+    #[test]
+    fn test_tiered_single_flush() {
+        let params = TieredParams::new(4, Metric::L2).with_flat_buffer_limit(100);
+        let mut index = TieredSingle::<f32>::new(params);
+
+        // Add several vectors
+        for i in 0..10 {
+            let v = vec![i as f32, 0.0, 0.0, 0.0];
+            index.add_vector(&v, i as u64).unwrap();
+        }
+
+        assert_eq!(index.flat_size(), 10);
+        assert_eq!(index.hnsw_size(), 0);
+
+        // Flush
+        let migrated = index.flush().unwrap();
+        assert_eq!(migrated, 10);
+
+        assert_eq!(index.flat_size(), 0);
+        assert_eq!(index.hnsw_size(), 10);
+        assert_eq!(index.index_size(), 10);
+
+        // Query should still work
+        let results = index.top_k_query(&vec![5.0, 0.0, 0.0, 0.0], 3, None).unwrap();
+        assert_eq!(results.len(), 3);
+        assert_eq!(results.results[0].label, 5);
+    }
+
+    #[test]
+    fn test_tiered_single_in_place_mode() {
+        let params = TieredParams::new(4, Metric::L2)
+            .with_flat_buffer_limit(2)
+            .with_write_mode(WriteMode::Async);
+        let mut index = TieredSingle::<f32>::new(params);
+
+        // Fill flat buffer
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        assert_eq!(index.write_mode(), WriteMode::InPlace);
+
+        // Next add should go directly to HNSW
+        index.add_vector(&vec![0.0, 0.0, 1.0, 0.0], 3).unwrap();
+
+        assert_eq!(index.flat_size(), 2);
+        assert_eq!(index.hnsw_size(), 1);
+        assert_eq!(index.index_size(), 3);
+    }
+
+    #[test]
+    fn test_tiered_single_compact() {
+        let params = TieredParams::new(4, Metric::L2);
+        let mut index = TieredSingle::<f32>::new(params);
+
+        // Add and delete
+        for i in 0..10 {
+            index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64).unwrap();
+        }
+
+        for i in (0..10).step_by(2) {
+            index.delete_vector(i as u64).unwrap();
+        }
+
+        assert!(index.fragmentation() > 0.0);
+
+        // Compact
+        index.compact(true);
+
+        assert!((index.fragmentation() - 0.0).abs() < 0.01);
+    }
+
+    #[test]
+    fn test_tiered_single_replace() {
+        let params = TieredParams::new(4, Metric::L2);
+        let mut index = TieredSingle::<f32>::new(params);
+
+        let v1 = vec![1.0, 0.0, 0.0, 0.0];
+        let v2 = vec![0.0, 1.0, 0.0, 0.0];
+
+        index.add_vector(&v1, 1).unwrap();
+        assert_eq!(index.index_size(), 1);
+
+        // Replace with new vector
+        index.add_vector(&v2, 1).unwrap();
+        assert_eq!(index.index_size(), 1);
+
+        // Should return the new vector
+        let query = vec![0.0, 1.0, 0.0, 0.0];
+        let results = index.top_k_query(&query, 1, None).unwrap();
+        assert!((results.results[0].distance as f64) < 0.001);
+    }
+
+    #[test]
+    fn test_tiered_single_serialization() {
+        use std::io::Cursor;
+
+        let params = TieredParams::new(4, Metric::L2)
+            .with_flat_buffer_limit(5)
+            .with_m(8)
+            .with_ef_construction(50);
+        let mut index = TieredSingle::<f32>::new(params);
+
+        // Add vectors to flat buffer
+        index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        // Flush some to HNSW
+        index.flush().unwrap();
+
+        // Add more to flat
+        index.add_vector(&vec![0.0, 0.0, 1.0, 0.0], 3).unwrap();
+
+        assert_eq!(index.flat_size(), 1);
+        assert_eq!(index.hnsw_size(), 2);
+        assert_eq!(index.index_size(), 3);
+
+        // Serialize
+        let mut buffer = Vec::new();
+        index.save(&mut buffer).unwrap();
+
+        // Deserialize
+        let mut cursor = Cursor::new(buffer);
+        let loaded = TieredSingle::<f32>::load(&mut cursor).unwrap();
+
+        // Verify state
+        assert_eq!(loaded.index_size(), 3);
+        assert_eq!(loaded.flat_size(), 1);
+        assert_eq!(loaded.hnsw_size(), 2);
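+        // Tier membership survives the round trip: save() records the flat
+        // buffer and the HNSW backend separately, so per-tier counts match.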
+        assert_eq!(loaded.dimension(), 4);
+
+        // Verify vectors can be queried
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+        let results = loaded.top_k_query(&query, 3, None).unwrap();
+        assert_eq!(results.len(), 3);
+        assert_eq!(results.results[0].label, 1); // Closest to query
+    }
+
+    #[test]
+    fn test_tiered_single_serialization_file() {
+        use std::fs;
+
+        let params = TieredParams::new(4, Metric::L2);
+        let mut index = TieredSingle::<f32>::new(params);
+
+        for i in 0..10 {
+            let v = vec![i as f32, 0.0, 0.0, 0.0];
+            index.add_vector(&v, i as u64).unwrap();
+        }
+
+        // Flush half
+        index.flush().unwrap();
+
+        for i in 10..15 {
+            let v = vec![i as f32, 0.0, 0.0, 0.0];
+            index.add_vector(&v, i as u64).unwrap();
+        }
+
+        let path = "/tmp/tiered_single_test.idx";
+        index.save_to_file(path).unwrap();
+
+        let loaded = TieredSingle::<f32>::load_from_file(path).unwrap();
+
+        assert_eq!(loaded.index_size(), 15);
+        assert!(loaded.contains(0));
+        assert!(loaded.contains(14));
+
+        // Cleanup
+        fs::remove_file(path).ok();
+    }
+}
diff --git a/rust/vecsim/src/index/tiered_svs/batch_iterator.rs b/rust/vecsim/src/index/tiered_svs/batch_iterator.rs
new file mode 100644
index 000000000..a7881e913
--- /dev/null
+++ b/rust/vecsim/src/index/tiered_svs/batch_iterator.rs
@@ -0,0 +1,101 @@
+//! Batch iterator implementations for tiered SVS indices.
+//!
+//! These iterators hold pre-computed results from both the flat buffer and the
+//! SVS backend, returning them in sorted order by distance.
+
+use crate::index::traits::BatchIterator;
+use crate::types::{IdType, LabelType, VectorElement};
+
+/// Batch iterator for the single-value tiered SVS index.
+///
+/// Holds pre-computed, sorted results from both the flat buffer and the SVS backend.
+pub struct TieredSvsBatchIterator<T: VectorElement> {
+    /// All results sorted by distance.
+    results: Vec<(IdType, LabelType, T::DistanceType)>,
+    /// Current position in results.
+    position: usize,
+}
+
+impl<T: VectorElement> TieredSvsBatchIterator<T> {
+    /// Create a new batch iterator with pre-computed results.
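+    ///
+    /// The results are expected to be sorted by distance (ascending) already;
+    /// `next_batch` hands out slices of this list in order and never re-sorts.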
+ pub fn new(results: Vec<(IdType, LabelType, T::DistanceType)>) -> Self { + Self { + results, + position: 0, + } + } +} + +impl BatchIterator for TieredSvsMultiBatchIterator { + type DistType = T::DistanceType; + + fn has_next(&self) -> bool { + self.position < self.results.len() + } + + fn next_batch( + &mut self, + batch_size: usize, + ) -> Option> { + if self.position >= self.results.len() { + return None; + } + + let end = (self.position + batch_size).min(self.results.len()); + let batch = self.results[self.position..end].to_vec(); + self.position = end; + + Some(batch) + } + + fn reset(&mut self) { + self.position = 0; + } +} diff --git a/rust/vecsim/src/index/tiered_svs/mod.rs b/rust/vecsim/src/index/tiered_svs/mod.rs new file mode 100644 index 000000000..260669aa7 --- /dev/null +++ b/rust/vecsim/src/index/tiered_svs/mod.rs @@ -0,0 +1,198 @@ +//! Tiered index combining BruteForce frontend with SVS (Vamana) backend. +//! +//! Similar to the HNSW-backed tiered index, but uses the single-layer Vamana +//! graph as the backend for approximate nearest neighbor search. +//! +//! # Architecture +//! +//! ```text +//! TieredSvsIndex +//! ├── Frontend: BruteForce index (fast write buffer) +//! ├── Backend: SVS (Vamana) index (efficient approximate search) +//! └── Query: Searches both tiers, merges results +//! ``` +//! +//! # Write Modes +//! +//! - **Async**: Vectors added to flat buffer, later migrated to SVS via `flush()` +//! - **InPlace**: Vectors added directly to SVS (used when buffer is full) + +pub mod batch_iterator; +pub mod multi; +pub mod single; + +pub use batch_iterator::TieredSvsBatchIterator; +pub use multi::TieredSvsMulti; +pub use single::TieredSvsSingle; + +use crate::distance::Metric; +use crate::index::svs::SvsParams; + +/// Write mode for the tiered SVS index. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum SvsWriteMode { + /// Async mode: vectors go to flat buffer, migrated to SVS via flush(). + #[default] + Async, + /// InPlace mode: vectors go directly to SVS. + InPlace, +} + +/// Configuration parameters for TieredSvsIndex. +#[derive(Debug, Clone)] +pub struct TieredSvsParams { + /// Vector dimension. + pub dim: usize, + /// Distance metric. + pub metric: Metric, + /// Parameters for the SVS backend index. + pub svs_params: SvsParams, + /// Maximum size of the flat buffer before forcing in-place writes. + /// When flat buffer reaches this limit, new writes go directly to SVS. + pub flat_buffer_limit: usize, + /// Write mode for the index. + pub write_mode: SvsWriteMode, + /// Initial capacity hint. + pub initial_capacity: usize, +} + +impl TieredSvsParams { + /// Create new tiered SVS index parameters. + /// + /// # Arguments + /// * `dim` - Vector dimension + /// * `metric` - Distance metric (L2, InnerProduct, Cosine) + pub fn new(dim: usize, metric: Metric) -> Self { + Self { + dim, + metric, + svs_params: SvsParams::new(dim, metric), + flat_buffer_limit: 10_000, + write_mode: SvsWriteMode::Async, + initial_capacity: 1000, + } + } + + /// Set the graph max degree (R) parameter for SVS. + pub fn with_graph_degree(mut self, r: usize) -> Self { + self.svs_params = self.svs_params.with_graph_degree(r); + self + } + + /// Set the alpha parameter for robust pruning. + pub fn with_alpha(mut self, alpha: f32) -> Self { + self.svs_params = self.svs_params.with_alpha(alpha); + self + } + + /// Set the construction window size (L) for SVS. 
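+    ///
+    /// Larger windows generally improve graph quality at the cost of build time.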
+ pub fn with_construction_l(mut self, l: usize) -> Self { + self.svs_params = self.svs_params.with_construction_l(l); + self + } + + /// Set the search window size for SVS queries. + pub fn with_search_l(mut self, l: usize) -> Self { + self.svs_params = self.svs_params.with_search_l(l); + self + } + + /// Set the flat buffer limit. + /// + /// When the flat buffer reaches this size, the index switches to + /// in-place writes (directly to SVS) until flush() is called. + pub fn with_flat_buffer_limit(mut self, limit: usize) -> Self { + self.flat_buffer_limit = limit; + self + } + + /// Set the write mode. + pub fn with_write_mode(mut self, mode: SvsWriteMode) -> Self { + self.write_mode = mode; + self + } + + /// Set the initial capacity hint. + pub fn with_initial_capacity(mut self, capacity: usize) -> Self { + self.initial_capacity = capacity; + self.svs_params = self.svs_params.with_capacity(capacity); + self + } + + /// Enable/disable two-pass construction for SVS. + pub fn with_two_pass(mut self, enable: bool) -> Self { + self.svs_params = self.svs_params.with_two_pass(enable); + self + } +} + +/// Merge two sorted query replies, keeping the top k results. +/// +/// Both replies are assumed to be sorted by distance (ascending). +pub fn merge_top_k( + flat_results: crate::query::QueryReply, + svs_results: crate::query::QueryReply, + k: usize, +) -> crate::query::QueryReply { + use crate::query::QueryReply; + + // Fast paths + if flat_results.is_empty() { + let mut results = svs_results; + results.results.truncate(k); + return results; + } + + if svs_results.is_empty() { + let mut results = flat_results; + results.results.truncate(k); + return results; + } + + // Merge both result sets + let mut merged = QueryReply::with_capacity(flat_results.len() + svs_results.len()); + + // Add all results + for result in flat_results.results { + merged.push(result); + } + for result in svs_results.results { + merged.push(result); + } + + // Sort by distance and truncate + merged.sort_by_distance(); + merged.results.truncate(k); + + merged +} + +/// Merge two query replies for range queries. +/// +/// Combines all results from both tiers. +pub fn merge_range( + flat_results: crate::query::QueryReply, + svs_results: crate::query::QueryReply, +) -> crate::query::QueryReply { + use crate::query::QueryReply; + + if flat_results.is_empty() { + return svs_results; + } + + if svs_results.is_empty() { + return flat_results; + } + + let mut merged = QueryReply::with_capacity(flat_results.len() + svs_results.len()); + + for result in flat_results.results { + merged.push(result); + } + for result in svs_results.results { + merged.push(result); + } + + merged.sort_by_distance(); + merged +} diff --git a/rust/vecsim/src/index/tiered_svs/multi.rs b/rust/vecsim/src/index/tiered_svs/multi.rs new file mode 100644 index 000000000..1211af94a --- /dev/null +++ b/rust/vecsim/src/index/tiered_svs/multi.rs @@ -0,0 +1,676 @@ +//! Multi-value tiered SVS index implementation. +//! +//! This index allows multiple vectors per label, combining a BruteForce frontend +//! (for fast writes) with an SVS (Vamana) backend (for efficient queries). 
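+//!
+//! # Example
+//!
+//! A minimal usage sketch (values are illustrative; marked `ignore` since the
+//! exact public re-export paths are assumed from this module's imports):
+//!
+//! ```rust,ignore
+//! use vecsim::distance::Metric;
+//! use vecsim::index::tiered_svs::{TieredSvsMulti, TieredSvsParams};
+//! use vecsim::index::traits::VecSimIndex;
+//!
+//! let params = TieredSvsParams::new(4, Metric::L2).with_flat_buffer_limit(100);
+//! let mut index = TieredSvsMulti::<f32>::new(params);
+//!
+//! // Writes land in the flat buffer first (Async mode)...
+//! index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
+//! index.add_vector(&[0.9, 0.1, 0.0, 0.0], 1).unwrap(); // same label is allowed
+//!
+//! // ...and can be migrated to the SVS backend explicitly.
+//! index.flush().unwrap();
+//!
+//! // Queries merge results from both tiers.
+//! let reply = index.top_k_query(&[1.0, 0.0, 0.0, 0.0], 2, None).unwrap();
+//! assert_eq!(reply.len(), 2);
+//! ```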
+
+use super::{merge_range, merge_top_k, SvsWriteMode, TieredSvsParams};
+use crate::index::brute_force::{BruteForceMulti, BruteForceParams};
+use crate::index::svs::SvsMulti;
+use crate::index::traits::{BatchIterator, IndexError, IndexInfo, QueryError, VecSimIndex};
+use crate::query::{QueryParams, QueryReply};
+use crate::types::{DistanceType, LabelType, VectorElement};
+use parking_lot::RwLock;
+use std::collections::HashMap;
+use std::sync::atomic::{AtomicUsize, Ordering};
+
+/// Statistics about a tiered SVS multi index.
+#[derive(Debug, Clone)]
+pub struct TieredSvsMultiStats {
+    /// Number of vectors in the flat buffer.
+    pub flat_size: usize,
+    /// Number of vectors in the SVS backend.
+    pub svs_size: usize,
+    /// Total vector count.
+    pub total_size: usize,
+    /// Current write mode.
+    pub write_mode: SvsWriteMode,
+    /// Flat buffer fragmentation.
+    pub flat_fragmentation: f64,
+    /// SVS fragmentation.
+    pub svs_fragmentation: f64,
+    /// Approximate memory usage in bytes.
+    pub memory_bytes: usize,
+}
+
+/// Multi-value tiered SVS index combining a BruteForce frontend with an SVS backend.
+///
+/// Unlike `TieredSvsSingle`, this allows multiple vectors per label.
+/// When deleting, all vectors with that label are removed from both tiers.
+pub struct TieredSvsMulti<T: VectorElement> {
+    /// Flat BruteForce buffer (frontend).
+    pub(crate) flat: RwLock<BruteForceMulti<T>>,
+    /// SVS index (backend).
+    pub(crate) svs: RwLock<SvsMulti<T>>,
+    /// Count of vectors per label in the flat buffer.
+    flat_label_counts: RwLock<HashMap<LabelType, usize>>,
+    /// Count of vectors per label in SVS.
+    svs_label_counts: RwLock<HashMap<LabelType, usize>>,
+    /// Configuration parameters.
+    params: TieredSvsParams,
+    /// Total vector count.
+    count: AtomicUsize,
+    /// Maximum capacity (if set).
+    capacity: Option<usize>,
+}
+
+impl<T: VectorElement> TieredSvsMulti<T> {
+    /// Create a new multi-value tiered SVS index.
+    pub fn new(params: TieredSvsParams) -> Self {
+        let bf_params = BruteForceParams::new(params.dim, params.metric)
+            .with_capacity(params.flat_buffer_limit.min(params.initial_capacity));
+        let flat = BruteForceMulti::new(bf_params);
+        let svs = SvsMulti::new(params.svs_params.clone());
+
+        Self {
+            flat: RwLock::new(flat),
+            svs: RwLock::new(svs),
+            flat_label_counts: RwLock::new(HashMap::new()),
+            svs_label_counts: RwLock::new(HashMap::new()),
+            params,
+            count: AtomicUsize::new(0),
+            capacity: None,
+        }
+    }
+
+    /// Create with a maximum capacity.
+    pub fn with_capacity(params: TieredSvsParams, max_capacity: usize) -> Self {
+        let mut index = Self::new(params);
+        index.capacity = Some(max_capacity);
+        index
+    }
+
+    /// Get the current write mode.
+    pub fn write_mode(&self) -> SvsWriteMode {
+        if self.params.write_mode == SvsWriteMode::InPlace {
+            return SvsWriteMode::InPlace;
+        }
+        // In Async mode, switch to InPlace if the flat buffer is full
+        if self.flat_size() >= self.params.flat_buffer_limit {
+            SvsWriteMode::InPlace
+        } else {
+            SvsWriteMode::Async
+        }
+    }
+
+    /// Get the number of vectors in the flat buffer.
+    pub fn flat_size(&self) -> usize {
+        self.flat.read().index_size()
+    }
+
+    /// Get the number of vectors in the SVS backend.
+    pub fn svs_size(&self) -> usize {
+        self.svs.read().index_size()
+    }
+
+    /// Get detailed statistics about the index.
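+    ///
+    /// A quick sketch of reading the snapshot (assumes an existing `index`):
+    ///
+    /// ```rust,ignore
+    /// let s = index.stats();
+    /// println!("flat={} svs={} mode={:?}", s.flat_size, s.svs_size, s.write_mode);
+    /// ```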
+ pub fn stats(&self) -> TieredSvsMultiStats { + let flat = self.flat.read(); + let svs = self.svs.read(); + + TieredSvsMultiStats { + flat_size: flat.index_size(), + svs_size: svs.index_size(), + total_size: self.count.load(Ordering::Relaxed), + write_mode: self.write_mode(), + flat_fragmentation: flat.fragmentation(), + svs_fragmentation: svs.fragmentation(), + memory_bytes: flat.memory_usage() + svs.memory_usage(), + } + } + + /// Get the memory usage in bytes. + pub fn memory_usage(&self) -> usize { + self.flat.read().memory_usage() + self.svs.read().memory_usage() + } + + /// Check if a label exists in the flat buffer. + pub fn is_in_flat(&self, label: LabelType) -> bool { + self.flat_label_counts + .read() + .get(&label) + .is_some_and(|&c| c > 0) + } + + /// Check if a label exists in the SVS backend. + pub fn is_in_svs(&self, label: LabelType) -> bool { + self.svs_label_counts + .read() + .get(&label) + .is_some_and(|&c| c > 0) + } + + /// Flush all vectors from flat buffer to SVS. + /// + /// This migrates all vectors from the flat buffer to the SVS backend, + /// clearing the flat buffer afterward. + /// + /// # Returns + /// The number of vectors migrated. + pub fn flush(&mut self) -> Result { + let flat_labels: Vec<(LabelType, usize)> = self + .flat_label_counts + .read() + .iter() + .filter(|(_, &count)| count > 0) + .map(|(&l, &c)| (l, c)) + .collect(); + + if flat_labels.is_empty() { + return Ok(0); + } + + let mut migrated = 0; + + // Collect all vectors from flat buffer + let vectors: Vec<(LabelType, Vec>)> = { + let flat = self.flat.read(); + flat_labels + .iter() + .filter_map(|(label, _)| { + flat.get_vectors(*label).map(|vecs| (*label, vecs)) + }) + .collect() + }; + + // Add to SVS + { + let mut svs = self.svs.write(); + let mut svs_label_counts = self.svs_label_counts.write(); + + for (label, vecs) in &vectors { + for vec in vecs { + match svs.add_vector(vec, *label) { + Ok(added) => { + if added > 0 { + *svs_label_counts.entry(*label).or_insert(0) += added; + migrated += added; + } + } + Err(e) => { + eprintln!("Failed to migrate vector for label {label} to SVS: {e:?}"); + } + } + } + } + } + + // Clear flat buffer + { + let mut flat = self.flat.write(); + let mut flat_label_counts = self.flat_label_counts.write(); + + flat.clear(); + flat_label_counts.clear(); + } + + Ok(migrated) + } + + /// Compact both tiers to reclaim space from deleted vectors. + /// + /// # Arguments + /// * `shrink` - If true, also release unused memory + /// + /// # Returns + /// Approximate bytes reclaimed. + pub fn compact(&mut self, shrink: bool) -> usize { + let flat_reclaimed = self.flat.write().compact(shrink); + let svs_reclaimed = self.svs.write().compact(shrink); + flat_reclaimed + svs_reclaimed + } + + /// Get the fragmentation ratio (0.0 = none, 1.0 = all deleted). + pub fn fragmentation(&self) -> f64 { + let flat = self.flat.read(); + let svs = self.svs.read(); + let flat_frag = flat.fragmentation(); + let svs_frag = svs.fragmentation(); + + // Weighted average by size + let flat_size = flat.index_size() as f64; + let svs_size = svs.index_size() as f64; + let total = flat_size + svs_size; + + if total == 0.0 { + 0.0 + } else { + (flat_frag * flat_size + svs_frag * svs_size) / total + } + } + + /// Get all vectors stored for a given label. 
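+    ///
+    /// A label's vectors may be split across the flat buffer and the SVS
+    /// backend (e.g. after a flush followed by new writes), so both tiers
+    /// are consulted.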
+ pub fn get_all_vectors(&self, label: LabelType) -> Vec> { + let mut results = Vec::new(); + + // Get from flat buffer + if self.is_in_flat(label) { + if let Some(vecs) = self.flat.read().get_vectors(label) { + results.extend(vecs); + } + } + + // Get from SVS + if self.is_in_svs(label) { + if let Some(vecs) = self.svs.read().get_all_vectors(label) { + results.extend(vecs); + } + } + + results + } + + /// Clear all vectors from both tiers. + pub fn clear(&mut self) { + self.flat.write().clear(); + self.svs.write().clear(); + self.flat_label_counts.write().clear(); + self.svs_label_counts.write().clear(); + self.count.store(0, Ordering::Relaxed); + } +} + +impl VecSimIndex for TieredSvsMulti { + type DataType = T; + type DistType = T::DistanceType; + + fn add_vector(&mut self, vector: &[T], label: LabelType) -> Result { + // Validate dimension + if vector.len() != self.params.dim { + return Err(IndexError::DimensionMismatch { + expected: self.params.dim, + got: vector.len(), + }); + } + + // Check capacity + if let Some(cap) = self.capacity { + if self.count.load(Ordering::Relaxed) >= cap { + return Err(IndexError::CapacityExceeded { capacity: cap }); + } + } + + // Add new vector based on write mode + let added = match self.write_mode() { + SvsWriteMode::Async => { + let mut flat = self.flat.write(); + let mut flat_label_counts = self.flat_label_counts.write(); + let added = flat.add_vector(vector, label)?; + *flat_label_counts.entry(label).or_insert(0) += added; + added + } + SvsWriteMode::InPlace => { + let mut svs = self.svs.write(); + let mut svs_label_counts = self.svs_label_counts.write(); + let added = svs.add_vector(vector, label)?; + *svs_label_counts.entry(label).or_insert(0) += added; + added + } + }; + + self.count.fetch_add(added, Ordering::Relaxed); + Ok(added) + } + + fn delete_vector(&mut self, label: LabelType) -> Result { + let in_flat = self.is_in_flat(label); + let in_svs = self.is_in_svs(label); + + if !in_flat && !in_svs { + return Err(IndexError::LabelNotFound(label)); + } + + let mut deleted = 0; + + if in_flat { + let mut flat = self.flat.write(); + let mut flat_label_counts = self.flat_label_counts.write(); + if let Ok(count) = flat.delete_vector(label) { + flat_label_counts.remove(&label); + deleted += count; + } + } + + if in_svs { + let mut svs = self.svs.write(); + let mut svs_label_counts = self.svs_label_counts.write(); + if let Ok(count) = svs.delete_vector(label) { + svs_label_counts.remove(&label); + deleted += count; + } + } + + self.count.fetch_sub(deleted, Ordering::Relaxed); + Ok(deleted) + } + + fn top_k_query( + &self, + query: &[Self::DataType], + k: usize, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + // Lock ordering: flat first, then SVS + let flat = self.flat.read(); + let svs = self.svs.read(); + + // Query both tiers + let flat_results = flat.top_k_query(query, k, params)?; + let svs_results = svs.top_k_query(query, k, params)?; + + // Merge results + Ok(merge_top_k(flat_results, svs_results, k)) + } + + fn range_query( + &self, + query: &[Self::DataType], + radius: Self::DistType, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + // Lock ordering: flat first, then SVS + let flat = self.flat.read(); + let svs = self.svs.read(); + 
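+        // Both read guards are held across the merge so the two tiers are
+        // sampled consistently; the flat-then-SVS order matches top_k_query
+        // and avoids lock-order inversions.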
+ // Query both tiers + let flat_results = flat.range_query(query, radius, params)?; + let svs_results = svs.range_query(query, radius, params)?; + + // Merge results + Ok(merge_range(flat_results, svs_results)) + } + + fn index_size(&self) -> usize { + self.count.load(Ordering::Relaxed) + } + + fn index_capacity(&self) -> Option { + self.capacity + } + + fn dimension(&self) -> usize { + self.params.dim + } + + fn batch_iterator<'a>( + &'a self, + query: &[Self::DataType], + params: Option<&QueryParams>, + ) -> Result + 'a>, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + // Compute results immediately to preserve filter from params + let mut results = Vec::new(); + + // Get results from flat buffer + let flat = self.flat.read(); + if let Ok(mut iter) = flat.batch_iterator(query, params) { + while let Some(batch) = iter.next_batch(1000) { + results.extend(batch); + } + } + drop(flat); + + // Get results from SVS + let svs = self.svs.read(); + if let Ok(mut iter) = svs.batch_iterator(query, params) { + while let Some(batch) = iter.next_batch(1000) { + results.extend(batch); + } + } + drop(svs); + + // Sort by distance + results.sort_by(|a, b| { + a.2.to_f64() + .partial_cmp(&b.2.to_f64()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + Ok(Box::new( + super::batch_iterator::TieredSvsMultiBatchIterator::::new(results), + )) + } + + fn info(&self) -> IndexInfo { + IndexInfo { + size: self.index_size(), + capacity: self.capacity, + dimension: self.params.dim, + index_type: "TieredSvsMulti", + memory_bytes: self.memory_usage(), + } + } + + fn contains(&self, label: LabelType) -> bool { + self.is_in_flat(label) || self.is_in_svs(label) + } + + fn label_count(&self, label: LabelType) -> usize { + let flat_count = self + .flat_label_counts + .read() + .get(&label) + .copied() + .unwrap_or(0); + let svs_count = self + .svs_label_counts + .read() + .get(&label) + .copied() + .unwrap_or(0); + flat_count + svs_count + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::Metric; + + #[test] + fn test_tiered_svs_multi_basic() { + let params = TieredSvsParams::new(4, Metric::L2).with_flat_buffer_limit(10); + let mut index = TieredSvsMulti::::new(params); + + let v1 = vec![1.0, 0.0, 0.0, 0.0]; + let v2 = vec![0.0, 1.0, 0.0, 0.0]; + let v3 = vec![0.0, 0.0, 1.0, 0.0]; + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 1).unwrap(); // Same label + index.add_vector(&v3, 2).unwrap(); + + assert_eq!(index.index_size(), 3); + assert_eq!(index.label_count(1), 2); + assert_eq!(index.label_count(2), 1); + } + + #[test] + fn test_tiered_svs_multi_flush() { + let params = TieredSvsParams::new(4, Metric::L2).with_flat_buffer_limit(100); + let mut index = TieredSvsMulti::::new(params); + + // Add vectors with same label + for i in 0..5 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, 1).unwrap(); + } + + assert_eq!(index.flat_size(), 5); + assert_eq!(index.svs_size(), 0); + + // Flush + let migrated = index.flush().unwrap(); + assert_eq!(migrated, 5); + + assert_eq!(index.flat_size(), 0); + assert_eq!(index.svs_size(), 5); + assert_eq!(index.label_count(1), 5); + } + + #[test] + fn test_tiered_svs_multi_delete() { + let params = TieredSvsParams::new(4, Metric::L2); + let mut index = TieredSvsMulti::::new(params); + + // Add multiple vectors with same label + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 
1).unwrap(); + index.add_vector(&vec![0.0, 0.0, 1.0, 0.0], 2).unwrap(); + + assert_eq!(index.index_size(), 3); + + // Delete label 1 (should remove 2 vectors) + let deleted = index.delete_vector(1).unwrap(); + assert_eq!(deleted, 2); + assert_eq!(index.index_size(), 1); + assert!(!index.contains(1)); + assert!(index.contains(2)); + } + + #[test] + fn test_tiered_svs_multi_get_all_vectors() { + let params = TieredSvsParams::new(4, Metric::L2).with_flat_buffer_limit(10); + let mut index = TieredSvsMulti::::new(params); + + // Add to flat + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + // Flush to SVS + index.flush().unwrap(); + + // Add more to flat + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 1).unwrap(); + + // Get all vectors for label 1 (from both tiers) + let vectors = index.get_all_vectors(1); + assert_eq!(vectors.len(), 2); + } + + #[test] + fn test_tiered_svs_multi_query_both_tiers() { + let params = TieredSvsParams::new(4, Metric::L2).with_flat_buffer_limit(10); + let mut index = TieredSvsMulti::::new(params); + + // Add to flat + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + // Flush to SVS + index.flush().unwrap(); + + // Add more to flat + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + assert_eq!(index.flat_size(), 1); + assert_eq!(index.svs_size(), 1); + + // Query should find both + let query = vec![0.5, 0.5, 0.0, 0.0]; + let results = index.top_k_query(&query, 2, None).unwrap(); + assert_eq!(results.len(), 2); + } + + #[test] + fn test_tiered_svs_multi_batch_iterator_with_filter() { + use crate::query::QueryParams; + + let params = TieredSvsParams::new(4, Metric::L2).with_flat_buffer_limit(10); + let mut index = TieredSvsMulti::::new(params); + + // Add 2 vectors each for labels 1-3 to flat buffer + for i in 1..=3u64 { + let v1 = vec![i as f32, 0.0, 0.0, 0.0]; + let v2 = vec![i as f32, 0.1, 0.0, 0.0]; + index.add_vector(&v1, i).unwrap(); + index.add_vector(&v2, i).unwrap(); + } + + // Flush to SVS + index.flush().unwrap(); + + // Add 2 vectors each for labels 4-6 to flat buffer + for i in 4..=6u64 { + let v1 = vec![i as f32, 0.0, 0.0, 0.0]; + let v2 = vec![i as f32, 0.1, 0.0, 0.0]; + index.add_vector(&v1, i).unwrap(); + index.add_vector(&v2, i).unwrap(); + } + + assert_eq!(index.flat_size(), 6); // Labels 4-6, 2 each + assert_eq!(index.svs_size(), 6); // Labels 1-3, 2 each + + // Create a filter that only allows even labels (2, 4, 6) + let query_params = QueryParams::new().with_filter(|label| label % 2 == 0); + + let query = vec![4.0, 0.0, 0.0, 0.0]; + + // Test batch iterator with filter + let mut iter = index.batch_iterator(&query, Some(&query_params)).unwrap(); + let mut all_results = Vec::new(); + while let Some(batch) = iter.next_batch(100) { + all_results.extend(batch); + } + + // Should have 6 vectors (2 each for labels 2, 4, 6) + assert_eq!(all_results.len(), 6); + for (_, label, _) in &all_results { + assert_eq!(label % 2, 0, "Expected only even labels, got {}", label); + } + } + + #[test] + fn test_tiered_svs_multi_batch_iterator_no_filter() { + let params = TieredSvsParams::new(4, Metric::L2).with_flat_buffer_limit(10); + let mut index = TieredSvsMulti::::new(params); + + // Add 2 vectors each for labels 1-2 to flat buffer + for i in 1..=2u64 { + let v1 = vec![i as f32, 0.0, 0.0, 0.0]; + let v2 = vec![i as f32, 0.1, 0.0, 0.0]; + index.add_vector(&v1, i).unwrap(); + index.add_vector(&v2, i).unwrap(); + } + + // Flush to SVS + index.flush().unwrap(); + + // Add 2 vectors each for labels 3-4 to flat buffer + for i in 
3..=4u64 { + let v1 = vec![i as f32, 0.0, 0.0, 0.0]; + let v2 = vec![i as f32, 0.1, 0.0, 0.0]; + index.add_vector(&v1, i).unwrap(); + index.add_vector(&v2, i).unwrap(); + } + + let query = vec![2.0, 0.0, 0.0, 0.0]; + + // Test batch iterator without filter + let mut iter = index.batch_iterator(&query, None).unwrap(); + let mut all_results = Vec::new(); + while let Some(batch) = iter.next_batch(100) { + all_results.extend(batch); + } + + // Should have all 8 vectors from both tiers (4 labels * 2 vectors each) + assert_eq!(all_results.len(), 8); + } +} diff --git a/rust/vecsim/src/index/tiered_svs/single.rs b/rust/vecsim/src/index/tiered_svs/single.rs new file mode 100644 index 000000000..0fd93d33d --- /dev/null +++ b/rust/vecsim/src/index/tiered_svs/single.rs @@ -0,0 +1,731 @@ +//! Single-value tiered SVS index implementation. +//! +//! This index stores one vector per label, combining a BruteForce frontend +//! (for fast writes) with an SVS (Vamana) backend (for efficient queries). + +use super::{merge_range, merge_top_k, TieredSvsParams, SvsWriteMode}; +use crate::index::brute_force::{BruteForceParams, BruteForceSingle}; +use crate::index::svs::SvsSingle; +use crate::index::traits::{BatchIterator, IndexError, IndexInfo, QueryError, VecSimIndex}; +use crate::query::{QueryParams, QueryReply}; +use crate::types::{DistanceType, LabelType, VectorElement}; +use parking_lot::RwLock; +use std::collections::HashSet; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Statistics about a tiered SVS index. +#[derive(Debug, Clone)] +pub struct TieredSvsStats { + /// Number of vectors in flat buffer. + pub flat_size: usize, + /// Number of vectors in SVS backend. + pub svs_size: usize, + /// Total vector count. + pub total_size: usize, + /// Current write mode. + pub write_mode: SvsWriteMode, + /// Flat buffer fragmentation. + pub flat_fragmentation: f64, + /// SVS fragmentation. + pub svs_fragmentation: f64, + /// Approximate memory usage in bytes. + pub memory_bytes: usize, +} + +/// Single-value tiered SVS index combining BruteForce frontend with SVS backend. +/// +/// The tiered architecture provides: +/// - **Fast writes**: New vectors go to the flat BruteForce buffer +/// - **Efficient queries**: Results are merged from both tiers +/// - **Flexible migration**: Use `flush()` to migrate vectors to SVS +/// +/// # Write Behavior +/// +/// In **Async** mode (default): +/// - New vectors are added to the flat buffer +/// - When buffer reaches `flat_buffer_limit`, mode switches to InPlace +/// - Call `flush()` to migrate all vectors to SVS and reset buffer +/// +/// In **InPlace** mode: +/// - Vectors are added directly to SVS +/// - Use this when you don't need the write buffer +pub struct TieredSvsSingle { + /// Flat BruteForce buffer (frontend). + pub(crate) flat: RwLock>, + /// SVS index (backend). + pub(crate) svs: RwLock>, + /// Labels currently in flat buffer. + flat_labels: RwLock>, + /// Labels currently in SVS. + svs_labels: RwLock>, + /// Configuration parameters. + params: TieredSvsParams, + /// Total vector count. + count: AtomicUsize, + /// Maximum capacity (if set). + capacity: Option, +} + +impl TieredSvsSingle { + /// Create a new single-value tiered SVS index. 
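+    ///
+    /// The flat buffer capacity is seeded with
+    /// `min(flat_buffer_limit, initial_capacity)`; the SVS backend is built
+    /// from `params.svs_params`.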
+ pub fn new(params: TieredSvsParams) -> Self { + let bf_params = BruteForceParams::new(params.dim, params.metric) + .with_capacity(params.flat_buffer_limit.min(params.initial_capacity)); + let flat = BruteForceSingle::new(bf_params); + let svs = SvsSingle::new(params.svs_params.clone()); + + Self { + flat: RwLock::new(flat), + svs: RwLock::new(svs), + flat_labels: RwLock::new(HashSet::new()), + svs_labels: RwLock::new(HashSet::new()), + params, + count: AtomicUsize::new(0), + capacity: None, + } + } + + /// Create with a maximum capacity. + pub fn with_capacity(params: TieredSvsParams, max_capacity: usize) -> Self { + let mut index = Self::new(params); + index.capacity = Some(max_capacity); + index + } + + /// Get the current write mode. + pub fn write_mode(&self) -> SvsWriteMode { + if self.params.write_mode == SvsWriteMode::InPlace { + return SvsWriteMode::InPlace; + } + // In Async mode, switch to InPlace if flat buffer is full + if self.flat_size() >= self.params.flat_buffer_limit { + SvsWriteMode::InPlace + } else { + SvsWriteMode::Async + } + } + + /// Get the number of vectors in the flat buffer. + pub fn flat_size(&self) -> usize { + self.flat.read().index_size() + } + + /// Get the number of vectors in the SVS backend. + pub fn svs_size(&self) -> usize { + self.svs.read().index_size() + } + + /// Get detailed statistics about the index. + pub fn stats(&self) -> TieredSvsStats { + let flat = self.flat.read(); + let svs = self.svs.read(); + + TieredSvsStats { + flat_size: flat.index_size(), + svs_size: svs.index_size(), + total_size: self.count.load(Ordering::Relaxed), + write_mode: self.write_mode(), + flat_fragmentation: flat.fragmentation(), + svs_fragmentation: svs.fragmentation(), + memory_bytes: flat.memory_usage() + svs.memory_usage(), + } + } + + /// Get the memory usage in bytes. + pub fn memory_usage(&self) -> usize { + self.flat.read().memory_usage() + self.svs.read().memory_usage() + } + + /// Check if a label exists in the flat buffer. + pub fn is_in_flat(&self, label: LabelType) -> bool { + self.flat_labels.read().contains(&label) + } + + /// Check if a label exists in the SVS backend. + pub fn is_in_svs(&self, label: LabelType) -> bool { + self.svs_labels.read().contains(&label) + } + + /// Flush all vectors from flat buffer to SVS. + /// + /// This migrates all vectors from the flat buffer to the SVS backend, + /// clearing the flat buffer afterward. + /// + /// # Returns + /// The number of vectors migrated. + pub fn flush(&mut self) -> Result { + let flat_labels: Vec = self.flat_labels.read().iter().copied().collect(); + + if flat_labels.is_empty() { + return Ok(0); + } + + let mut migrated = 0; + + // Collect vectors from flat buffer + let vectors: Vec<(LabelType, Vec)> = { + let flat = self.flat.read(); + flat_labels + .iter() + .filter_map(|&label| flat.get_vector(label).map(|v| (label, v))) + .collect() + }; + + // Add to SVS + { + let mut svs = self.svs.write(); + let mut svs_labels = self.svs_labels.write(); + + for (label, vector) in &vectors { + match svs.add_vector(vector, *label) { + Ok(_) => { + svs_labels.insert(*label); + migrated += 1; + } + Err(e) => { + // Log error but continue + eprintln!("Failed to migrate label {label} to SVS: {e:?}"); + } + } + } + } + + // Clear flat buffer + { + let mut flat = self.flat.write(); + let mut flat_labels = self.flat_labels.write(); + + flat.clear(); + flat_labels.clear(); + } + + Ok(migrated) + } + + /// Compact both tiers to reclaim space from deleted vectors. 
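+    /// Compaction is delegated to each tier and the reclaimed byte counts
+    /// are summed.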
+ /// + /// # Arguments + /// * `shrink` - If true, also release unused memory + /// + /// # Returns + /// Approximate bytes reclaimed. + pub fn compact(&mut self, shrink: bool) -> usize { + let flat_reclaimed = self.flat.write().compact(shrink); + let svs_reclaimed = self.svs.write().compact(shrink); + flat_reclaimed + svs_reclaimed + } + + /// Get the fragmentation ratio (0.0 = none, 1.0 = all deleted). + pub fn fragmentation(&self) -> f64 { + let flat = self.flat.read(); + let svs = self.svs.read(); + let flat_frag = flat.fragmentation(); + let svs_frag = svs.fragmentation(); + + // Weighted average by size + let flat_size = flat.index_size() as f64; + let svs_size = svs.index_size() as f64; + let total = flat_size + svs_size; + + if total == 0.0 { + 0.0 + } else { + (flat_frag * flat_size + svs_frag * svs_size) / total + } + } + + /// Get a copy of the vector stored for a given label. + pub fn get_vector(&self, label: LabelType) -> Option> { + // Check flat first + if self.flat_labels.read().contains(&label) { + return self.flat.read().get_vector(label); + } + // Then check SVS + if self.svs_labels.read().contains(&label) { + return self.svs.read().get_vector(label); + } + None + } + + /// Clear all vectors from both tiers. + pub fn clear(&mut self) { + self.flat.write().clear(); + self.svs.write().clear(); + self.flat_labels.write().clear(); + self.svs_labels.write().clear(); + self.count.store(0, Ordering::Relaxed); + } +} + +impl VecSimIndex for TieredSvsSingle { + type DataType = T; + type DistType = T::DistanceType; + + fn add_vector(&mut self, vector: &[T], label: LabelType) -> Result { + // Validate dimension + if vector.len() != self.params.dim { + return Err(IndexError::DimensionMismatch { + expected: self.params.dim, + got: vector.len(), + }); + } + + // Check capacity + if let Some(cap) = self.capacity { + if self.count.load(Ordering::Relaxed) >= cap { + return Err(IndexError::CapacityExceeded { capacity: cap }); + } + } + + let in_flat = self.flat_labels.read().contains(&label); + let in_svs = self.svs_labels.read().contains(&label); + + // Update existing vector + if in_flat { + // Update in flat buffer (replace) + let mut flat = self.flat.write(); + flat.add_vector(vector, label)?; + return Ok(0); // No new vector + } + + // Track if this is a replacement (to return 0 instead of 1) + let is_replacement = in_svs; + + if in_svs { + // For single-value: update means replace + // Delete from SVS and add to flat (or direct to SVS based on mode) + { + let mut svs = self.svs.write(); + let mut svs_labels = self.svs_labels.write(); + svs.delete_vector(label)?; + svs_labels.remove(&label); + } + self.count.fetch_sub(1, Ordering::Relaxed); + + // Now add new vector (fall through to new vector logic) + } + + // Add new vector based on write mode + match self.write_mode() { + SvsWriteMode::Async => { + let mut flat = self.flat.write(); + let mut flat_labels = self.flat_labels.write(); + flat.add_vector(vector, label)?; + flat_labels.insert(label); + } + SvsWriteMode::InPlace => { + let mut svs = self.svs.write(); + let mut svs_labels = self.svs_labels.write(); + svs.add_vector(vector, label)?; + svs_labels.insert(label); + } + } + + self.count.fetch_add(1, Ordering::Relaxed); + Ok(if is_replacement { 0 } else { 1 }) + } + + fn delete_vector(&mut self, label: LabelType) -> Result { + let in_flat = self.flat_labels.read().contains(&label); + let in_svs = self.svs_labels.read().contains(&label); + + if !in_flat && !in_svs { + return Err(IndexError::LabelNotFound(label)); + } + + let mut 
deleted = 0; + + if in_flat { + let mut flat = self.flat.write(); + let mut flat_labels = self.flat_labels.write(); + if flat.delete_vector(label).is_ok() { + flat_labels.remove(&label); + deleted += 1; + } + } + + if in_svs { + let mut svs = self.svs.write(); + let mut svs_labels = self.svs_labels.write(); + if svs.delete_vector(label).is_ok() { + svs_labels.remove(&label); + deleted += 1; + } + } + + self.count.fetch_sub(deleted, Ordering::Relaxed); + Ok(deleted) + } + + fn top_k_query( + &self, + query: &[Self::DataType], + k: usize, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + // Lock ordering: flat first, then SVS + let flat = self.flat.read(); + let svs = self.svs.read(); + + // Query both tiers + let flat_results = flat.top_k_query(query, k, params)?; + let svs_results = svs.top_k_query(query, k, params)?; + + // Merge results + Ok(merge_top_k(flat_results, svs_results, k)) + } + + fn range_query( + &self, + query: &[Self::DataType], + radius: Self::DistType, + params: Option<&QueryParams>, + ) -> Result, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + // Lock ordering: flat first, then SVS + let flat = self.flat.read(); + let svs = self.svs.read(); + + // Query both tiers + let flat_results = flat.range_query(query, radius, params)?; + let svs_results = svs.range_query(query, radius, params)?; + + // Merge results + Ok(merge_range(flat_results, svs_results)) + } + + fn index_size(&self) -> usize { + self.count.load(Ordering::Relaxed) + } + + fn index_capacity(&self) -> Option { + self.capacity + } + + fn dimension(&self) -> usize { + self.params.dim + } + + fn batch_iterator<'a>( + &'a self, + query: &[Self::DataType], + params: Option<&QueryParams>, + ) -> Result + 'a>, QueryError> { + if query.len() != self.params.dim { + return Err(QueryError::DimensionMismatch { + expected: self.params.dim, + got: query.len(), + }); + } + + // Compute results immediately to preserve filter from params + let mut results = Vec::new(); + + // Get results from flat buffer + let flat = self.flat.read(); + if let Ok(mut iter) = flat.batch_iterator(query, params) { + while let Some(batch) = iter.next_batch(1000) { + results.extend(batch); + } + } + drop(flat); + + // Get results from SVS + let svs = self.svs.read(); + if let Ok(mut iter) = svs.batch_iterator(query, params) { + while let Some(batch) = iter.next_batch(1000) { + results.extend(batch); + } + } + drop(svs); + + // Sort by distance + results.sort_by(|a, b| { + a.2.to_f64() + .partial_cmp(&b.2.to_f64()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + Ok(Box::new(super::batch_iterator::TieredSvsBatchIterator::::new(results))) + } + + fn info(&self) -> IndexInfo { + IndexInfo { + size: self.index_size(), + capacity: self.capacity, + dimension: self.params.dim, + index_type: "TieredSvsSingle", + memory_bytes: self.memory_usage(), + } + } + + fn contains(&self, label: LabelType) -> bool { + self.flat_labels.read().contains(&label) || self.svs_labels.read().contains(&label) + } + + fn label_count(&self, label: LabelType) -> usize { + if self.contains(label) { + 1 + } else { + 0 + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distance::Metric; + + #[test] + fn test_tiered_svs_single_basic() { + let params = TieredSvsParams::new(4, 
Metric::L2).with_flat_buffer_limit(10); + let mut index = TieredSvsSingle::::new(params); + + let v1 = vec![1.0, 0.0, 0.0, 0.0]; + let v2 = vec![0.0, 1.0, 0.0, 0.0]; + let v3 = vec![0.0, 0.0, 1.0, 0.0]; + + index.add_vector(&v1, 1).unwrap(); + index.add_vector(&v2, 2).unwrap(); + index.add_vector(&v3, 3).unwrap(); + + assert_eq!(index.index_size(), 3); + assert_eq!(index.flat_size(), 3); + assert_eq!(index.svs_size(), 0); + + // Query should find all vectors + let query = vec![1.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 3, None).unwrap(); + + assert_eq!(results.len(), 3); + assert_eq!(results.results[0].label, 1); // Closest + } + + #[test] + fn test_tiered_svs_single_query_both_tiers() { + let params = TieredSvsParams::new(4, Metric::L2).with_flat_buffer_limit(10); + let mut index = TieredSvsSingle::::new(params); + + // Add to flat + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + + // Flush to SVS + let migrated = index.flush().unwrap(); + assert_eq!(migrated, 1); + + // Add more to flat + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + assert_eq!(index.flat_size(), 1); + assert_eq!(index.svs_size(), 1); + assert_eq!(index.index_size(), 2); + + // Query should find both + let query = vec![0.5, 0.5, 0.0, 0.0]; + let results = index.top_k_query(&query, 2, None).unwrap(); + assert_eq!(results.len(), 2); + } + + #[test] + fn test_tiered_svs_single_delete() { + let params = TieredSvsParams::new(4, Metric::L2); + let mut index = TieredSvsSingle::::new(params); + + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + assert_eq!(index.index_size(), 2); + + // Delete from flat + index.delete_vector(1).unwrap(); + assert_eq!(index.index_size(), 1); + assert!(!index.contains(1)); + assert!(index.contains(2)); + } + + #[test] + fn test_tiered_svs_single_flush() { + let params = TieredSvsParams::new(4, Metric::L2).with_flat_buffer_limit(100); + let mut index = TieredSvsSingle::::new(params); + + // Add several vectors + for i in 0..10 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i as u64).unwrap(); + } + + assert_eq!(index.flat_size(), 10); + assert_eq!(index.svs_size(), 0); + + // Flush + let migrated = index.flush().unwrap(); + assert_eq!(migrated, 10); + + assert_eq!(index.flat_size(), 0); + assert_eq!(index.svs_size(), 10); + assert_eq!(index.index_size(), 10); + + // Query should still work + let results = index.top_k_query(&vec![5.0, 0.0, 0.0, 0.0], 3, None).unwrap(); + assert_eq!(results.len(), 3); + assert_eq!(results.results[0].label, 5); + } + + #[test] + fn test_tiered_svs_single_in_place_mode() { + let params = TieredSvsParams::new(4, Metric::L2) + .with_flat_buffer_limit(2) + .with_write_mode(SvsWriteMode::Async); + let mut index = TieredSvsSingle::::new(params); + + // Fill flat buffer + index.add_vector(&vec![1.0, 0.0, 0.0, 0.0], 1).unwrap(); + index.add_vector(&vec![0.0, 1.0, 0.0, 0.0], 2).unwrap(); + + assert_eq!(index.write_mode(), SvsWriteMode::InPlace); + + // Next add should go directly to SVS + index.add_vector(&vec![0.0, 0.0, 1.0, 0.0], 3).unwrap(); + + assert_eq!(index.flat_size(), 2); + assert_eq!(index.svs_size(), 1); + assert_eq!(index.index_size(), 3); + } + + #[test] + fn test_tiered_svs_single_compact() { + let params = TieredSvsParams::new(4, Metric::L2); + let mut index = TieredSvsSingle::::new(params); + + // Add and delete + for i in 0..10 { + index.add_vector(&vec![i as f32, 0.0, 0.0, 0.0], i as u64).unwrap(); + } + + for i in 
(0..10).step_by(2) { + index.delete_vector(i as u64).unwrap(); + } + + assert!(index.fragmentation() > 0.0); + + // Compact + index.compact(true); + + assert!((index.fragmentation() - 0.0).abs() < 0.01); + } + + #[test] + fn test_tiered_svs_single_replace() { + let params = TieredSvsParams::new(4, Metric::L2); + let mut index = TieredSvsSingle::::new(params); + + let v1 = vec![1.0, 0.0, 0.0, 0.0]; + let v2 = vec![0.0, 1.0, 0.0, 0.0]; + + index.add_vector(&v1, 1).unwrap(); + assert_eq!(index.index_size(), 1); + + // Replace with new vector + index.add_vector(&v2, 1).unwrap(); + assert_eq!(index.index_size(), 1); + + // Should return the new vector + let query = vec![0.0, 1.0, 0.0, 0.0]; + let results = index.top_k_query(&query, 1, None).unwrap(); + assert!((results.results[0].distance as f64) < 0.001); + } + + #[test] + fn test_tiered_svs_single_batch_iterator_with_filter() { + let params = TieredSvsParams::new(4, Metric::L2).with_flat_buffer_limit(5); + let mut index = TieredSvsSingle::::new(params); + + // Add vectors to flat buffer (labels 1-5) + for i in 1..=5u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i).unwrap(); + } + + // Flush to SVS + index.flush().unwrap(); + + // Add more to flat buffer (labels 6-10) + for i in 6..=10u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i).unwrap(); + } + + assert_eq!(index.flat_size(), 5); // Labels 6-10 in flat + assert_eq!(index.svs_size(), 5); // Labels 1-5 in SVS + + // Create a filter that only allows labels divisible by 3 + let query_params = QueryParams::new().with_filter(|label| label % 3 == 0); + + let query = vec![5.0, 0.0, 0.0, 0.0]; + + // Test batch iterator with filter + let mut iter = index.batch_iterator(&query, Some(&query_params)).unwrap(); + let mut all_results = Vec::new(); + while let Some(batch) = iter.next_batch(100) { + all_results.extend(batch); + } + + // Should have labels 3, 6, 9 (3 results) + assert_eq!(all_results.len(), 3); + for (_, label, _) in &all_results { + assert_eq!(label % 3, 0, "Expected only labels divisible by 3, got {}", label); + } + + let labels: Vec<_> = all_results.iter().map(|(_, l, _)| *l).collect(); + assert!(labels.contains(&3)); + assert!(labels.contains(&6)); + assert!(labels.contains(&9)); + } + + #[test] + fn test_tiered_svs_single_batch_iterator_no_filter() { + let params = TieredSvsParams::new(4, Metric::L2).with_flat_buffer_limit(3); + let mut index = TieredSvsSingle::::new(params); + + // Add to flat buffer + for i in 1..=3u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i).unwrap(); + } + + // Flush to SVS + index.flush().unwrap(); + + // Add more to flat buffer + for i in 4..=6u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i).unwrap(); + } + + let query = vec![3.0, 0.0, 0.0, 0.0]; + + // Test batch iterator without filter + let mut iter = index.batch_iterator(&query, None).unwrap(); + let mut all_results = Vec::new(); + while let Some(batch) = iter.next_batch(100) { + all_results.extend(batch); + } + + // Should have all 6 vectors from both tiers + assert_eq!(all_results.len(), 6); + } +} diff --git a/rust/vecsim/src/index/traits.rs b/rust/vecsim/src/index/traits.rs new file mode 100644 index 000000000..ecd708f71 --- /dev/null +++ b/rust/vecsim/src/index/traits.rs @@ -0,0 +1,296 @@ +//! Core index traits defining the vector similarity interface. +//! +//! This module defines the abstract interfaces that all index implementations must support: +//! 
- `VecSimIndex`: The main index trait for vector storage and search +//! - `BatchIterator`: Iterator for streaming query results in batches + +use crate::query::{QueryParams, QueryReply}; +use crate::types::{DistanceType, IdType, LabelType, VectorElement}; +use thiserror::Error; + +/// Errors that can occur during index operations. +#[derive(Error, Debug)] +pub enum IndexError { + #[error("Vector dimension mismatch: expected {expected}, got {got}")] + DimensionMismatch { expected: usize, got: usize }, + + #[error("Index is at capacity ({capacity})")] + CapacityExceeded { capacity: usize }, + + #[error("Label {0} not found in index")] + LabelNotFound(LabelType), + + #[error("Invalid parameter: {0}")] + InvalidParameter(String), + + #[error("Index is empty")] + EmptyIndex, + + #[error("Internal error: {0}")] + Internal(String), + + #[error("Index data corruption detected: {0}")] + Corruption(String), + + #[error("Memory allocation failed")] + MemoryExhausted, + + #[error("Duplicate label: {0} already exists")] + DuplicateLabel(LabelType), +} + +/// Errors that can occur during query operations. +#[derive(Error, Debug)] +pub enum QueryError { + #[error("Vector dimension mismatch: expected {expected}, got {got}")] + DimensionMismatch { expected: usize, got: usize }, + + #[error("Invalid query parameter: {0}")] + InvalidParameter(String), + + #[error("Query cancelled")] + Cancelled, + + #[error("Query timed out after {0}ms")] + Timeout(u64), + + #[error("Empty query vector")] + EmptyVector, + + #[error("Vector is not normalized (required for this metric)")] + NotNormalized, +} + +/// Information about the index. +#[derive(Debug, Clone)] +pub struct IndexInfo { + /// Number of vectors currently stored. + pub size: usize, + /// Maximum capacity (if bounded). + pub capacity: Option, + /// Vector dimension. + pub dimension: usize, + /// Index type name. + pub index_type: &'static str, + /// Memory usage in bytes (approximate). + pub memory_bytes: usize, +} + +/// The main trait for vector similarity indices. +/// +/// This trait defines the core operations supported by all index types: +/// - Adding and removing vectors +/// - KNN (top-k) queries +/// - Range queries +/// - Batch iteration +pub trait VecSimIndex: Send + Sync { + /// The vector element type (f32, f64, Float16, BFloat16). + type DataType: VectorElement; + + /// The distance computation result type. + type DistType: DistanceType; + + /// Add a vector with the given label to the index. + /// + /// Returns the number of vectors added (1 for single-value indices, + /// potentially more for multi-value indices that store duplicates). + /// + /// # Errors + /// - `DimensionMismatch` if the vector has the wrong dimension + /// - `CapacityExceeded` if the index is full + fn add_vector( + &mut self, + vector: &[Self::DataType], + label: LabelType, + ) -> Result; + + /// Delete all vectors with the given label. + /// + /// Returns the number of vectors deleted. + /// + /// # Errors + /// - `LabelNotFound` if no vectors have this label + fn delete_vector(&mut self, label: LabelType) -> Result; + + /// Perform a top-k nearest neighbor query. + /// + /// Returns up to `k` vectors closest to the query vector. + /// + /// # Arguments + /// * `query` - The query vector + /// * `k` - Maximum number of results to return + /// * `params` - Optional query parameters + fn top_k_query( + &self, + query: &[Self::DataType], + k: usize, + params: Option<&QueryParams>, + ) -> Result, QueryError>; + + /// Perform a range query. 
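+    ///
+    /// Note: `radius` is interpreted in the metric's native distance scale
+    /// (e.g. squared Euclidean for `Metric::L2`).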
+ /// + /// Returns all vectors within the given radius of the query vector. + /// + /// # Arguments + /// * `query` - The query vector + /// * `radius` - Maximum distance from query + /// * `params` - Optional query parameters + fn range_query( + &self, + query: &[Self::DataType], + radius: Self::DistType, + params: Option<&QueryParams>, + ) -> Result, QueryError>; + + /// Get the current number of vectors in the index. + fn index_size(&self) -> usize; + + /// Get the maximum capacity of the index (if bounded). + fn index_capacity(&self) -> Option; + + /// Get the vector dimension. + fn dimension(&self) -> usize; + + /// Get a batch iterator for streaming query results. + /// + /// This is useful for processing large result sets incrementally + /// without loading all results into memory at once. + fn batch_iterator<'a>( + &'a self, + query: &[Self::DataType], + params: Option<&QueryParams>, + ) -> Result + 'a>, QueryError>; + + /// Get information about the index. + fn info(&self) -> IndexInfo; + + /// Check if a label exists in the index. + fn contains(&self, label: LabelType) -> bool; + + /// Get the number of vectors associated with a label. + fn label_count(&self, label: LabelType) -> usize; +} + +/// Iterator for streaming query results in batches. +/// +/// This trait allows processing query results incrementally, which is useful +/// for large result sets or when results need to be processed progressively. +pub trait BatchIterator: Send { + /// The distance type for results. + type DistType: DistanceType; + + /// Check if there are more results available. + fn has_next(&self) -> bool; + + /// Get the next batch of results. + /// + /// Returns `None` when no more results are available. + fn next_batch(&mut self, batch_size: usize) + -> Option>; + + /// Reset the iterator to the beginning. + fn reset(&mut self); +} + +/// Index type enumeration. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum IndexType { + BruteForce, + HNSW, + Tiered, + Svs, + TieredSvs, + DiskIndex, +} + +impl std::fmt::Display for IndexType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + IndexType::BruteForce => write!(f, "BruteForce"), + IndexType::HNSW => write!(f, "HNSW"), + IndexType::Tiered => write!(f, "Tiered"), + IndexType::Svs => write!(f, "Svs"), + IndexType::TieredSvs => write!(f, "TieredSvs"), + IndexType::DiskIndex => write!(f, "DiskIndex"), + } + } +} + +/// Whether the index supports multiple vectors per label. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MultiValue { + /// Single vector per label (new vector replaces existing). + Single, + /// Multiple vectors allowed per label. + Multi, +} + +/// Trait for indices that support garbage collection. +/// +/// Tiered indices accumulate deleted vectors over time. The GC process +/// removes these deleted vectors from the backend index to reclaim memory. +pub trait GarbageCollectable { + /// Run garbage collection to remove deleted vectors. + /// + /// Returns the number of vectors removed. + fn run_gc(&mut self) -> usize; + + /// Check if garbage collection is needed. + /// + /// This is a heuristic based on the ratio of deleted to total vectors. + fn needs_gc(&self) -> bool; + + /// Get the number of deleted vectors pending cleanup. + fn deleted_count(&self) -> usize; + + /// Get the GC threshold ratio (deleted/total that triggers GC). + fn gc_threshold(&self) -> f64 { + 0.1 // Default: trigger GC when 10% of vectors are deleted + } +} + +/// Trait for indices with configurable block sizes. 
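+/// Indices typically use `DEFAULT_BLOCK_SIZE` (1024) unless configured otherwise.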
+/// +/// Block size affects memory allocation patterns and can impact performance: +/// - Larger blocks: fewer allocations, better cache locality, more wasted space +/// - Smaller blocks: more allocations, less wasted space, potentially more fragmentation +pub trait BlockSizeConfigurable { + /// Get the current block size. + fn block_size(&self) -> usize; + + /// Set the block size for future allocations. + /// + /// This does not affect already-allocated blocks. + fn set_block_size(&mut self, size: usize); + + /// Get the default block size for this index type. + fn default_block_size() -> usize; +} + +/// Default block size used by indices. +pub const DEFAULT_BLOCK_SIZE: usize = 1024; + +/// Trait for indices that support memory fitting/compaction. +/// +/// Memory fitting releases unused capacity to reduce memory usage. +pub trait MemoryFittable { + /// Fit memory to actual usage, releasing unused capacity. + /// + /// Returns the number of bytes freed. + fn fit_memory(&mut self) -> usize; + + /// Get the current memory overhead (unused allocated bytes). + fn memory_overhead(&self) -> usize; +} + +/// Trait for indices that support async operations. +pub trait AsyncIndex { + /// Check if there are pending async operations. + fn has_pending_operations(&self) -> bool; + + /// Wait for all pending operations to complete. + fn wait_for_completion(&self); + + /// Get the number of pending operations. + fn pending_count(&self) -> usize; +} diff --git a/rust/vecsim/src/lib.rs b/rust/vecsim/src/lib.rs new file mode 100644 index 000000000..e9ab55ff9 --- /dev/null +++ b/rust/vecsim/src/lib.rs @@ -0,0 +1,203 @@ +//! VecSim - High-performance vector similarity search library. +//! +//! This library provides efficient implementations of vector similarity indices +//! for nearest neighbor search. It supports multiple index types and distance metrics, +//! with SIMD-optimized distance computations for high performance. +//! +//! # Index Types +//! +//! - **BruteForce**: Linear scan over all vectors. Provides exact results but O(n) query time. +//! Best for small datasets or when exact results are required. +//! +//! - **HNSW**: Hierarchical Navigable Small World graphs. Provides approximate results with +//! logarithmic query time. Best for large datasets where some recall loss is acceptable. +//! +//! # Distance Metrics +//! +//! - **L2** (Euclidean): Squared Euclidean distance. Lower values = more similar. +//! - **Inner Product**: Dot product (negated for distance). For normalized vectors. +//! - **Cosine**: 1 - cosine similarity. Measures angle between vectors. +//! +//! # Data Types +//! +//! - `f32`: Single precision float (default) +//! - `f64`: Double precision float +//! - `Float16`: Half precision float (IEEE 754-2008) +//! - `BFloat16`: Brain float (same exponent range as f32) +//! +//! # Examples +//! +//! ## BruteForce Index +//! +//! ```rust +//! use vecsim::prelude::*; +//! +//! // Create a BruteForce index +//! let params = BruteForceParams::new(4, Metric::L2); +//! let mut index = BruteForceSingle::::new(params); +//! +//! // Add vectors +//! index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap(); +//! index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap(); +//! index.add_vector(&[0.0, 0.0, 1.0, 0.0], 3).unwrap(); +//! +//! // Query for nearest neighbors +//! let query = [1.0, 0.1, 0.0, 0.0]; +//! let results = index.top_k_query(&query, 2, None).unwrap(); +//! +//! assert_eq!(results.results[0].label, 1); // Closest to query +//! ``` +//! +//! ## HNSW Index +//! +//! ```rust +//! 
+//! use vecsim::prelude::*;
+//!
+//! // Create an HNSW index
+//! let params = HnswParams::new(128, Metric::Cosine)
+//!     .with_m(16)
+//!     .with_ef_construction(200);
+//! let mut index = HnswSingle::<f32>::new(params);
+//!
+//! // Add vectors (would normally add many more)
+//! for i in 0..1000 {
+//!     let mut v = vec![0.0f32; 128];
+//!     v[i % 128] = 1.0;
+//!     index.add_vector(&v, i as u64).unwrap();
+//! }
+//!
+//! // Query with custom ef_runtime
+//! let query = vec![1.0f32; 128];
+//! let params = QueryParams::new().with_ef_runtime(50);
+//! let results = index.top_k_query(&query, 10, Some(&params)).unwrap();
+//! ```
+//!
+//! ## Multi-value Index
+//!
+//! ```rust
+//! use vecsim::prelude::*;
+//!
+//! // Create a multi-value index (multiple vectors per label)
+//! let params = BruteForceParams::new(4, Metric::L2);
+//! let mut index = BruteForceMulti::<f32>::new(params);
+//!
+//! // Add multiple vectors with the same label
+//! index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
+//! index.add_vector(&[0.9, 0.1, 0.0, 0.0], 1).unwrap(); // Same label
+//! index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap();
+//!
+//! assert_eq!(index.label_count(1), 2);
+//! ```
+
+pub mod containers;
+pub mod distance;
+pub mod index;
+pub mod memory;
+pub mod preprocessing;
+pub mod quantization;
+pub mod query;
+pub mod serialization;
+pub mod types;
+pub mod utils;
+
+/// Prelude module for convenient imports.
+///
+/// Use `use vecsim::prelude::*;` to import commonly used types.
+pub mod prelude {
+    // Types
+    pub use crate::types::{
+        BFloat16, DistanceType, Float16, IdType, Int32, Int64, Int8, LabelType, UInt8,
+        VectorElement, INVALID_ID,
+    };
+
+    // Quantization
+    pub use crate::quantization::{Sq8Codec, Sq8VectorMeta};
+
+    // Distance
+    pub use crate::distance::{
+        batch_normalize, cosine_similarity, dot_product, euclidean_distance, l2_norm, l2_squared,
+        normalize, normalize_in_place, Metric,
+    };
+
+    // Query
+    pub use crate::query::{QueryParams, QueryReply, QueryResult};
+
+    // Index traits and errors
+    pub use crate::index::{
+        BatchIterator, IndexError, IndexInfo, IndexType, MultiValue, QueryError, VecSimIndex,
+    };
+
+    // BruteForce
+    pub use crate::index::{BruteForceMulti, BruteForceParams, BruteForceSingle, BruteForceStats};
+
+    // Index estimation
+    pub use crate::index::{
+        estimate_brute_force_element_size, estimate_brute_force_initial_size,
+        estimate_hnsw_element_size, estimate_hnsw_initial_size,
+    };
+
+    // HNSW
+    pub use crate::index::{HnswMulti, HnswParams, HnswSingle, HnswStats};
+
+    // Serialization
+    pub use crate::serialization::{Deserializable, Serializable, SerializationError};
+}
+
+/// Create a BruteForce index with the given parameters.
+///
+/// This is a convenience function for creating single-value BruteForce indices.
+pub fn create_brute_force<T: types::VectorElement>(
+    dim: usize,
+    metric: distance::Metric,
+) -> index::BruteForceSingle<T> {
+    let params = index::BruteForceParams::new(dim, metric);
+    index::BruteForceSingle::new(params)
+}
+
+/// Create an HNSW index with the given parameters.
+///
+/// This is a convenience function for creating single-value HNSW indices.
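+///
+/// A minimal usage sketch (the `f32` type parameter mirrors the examples above):
+///
+/// ```rust,ignore
+/// let mut index = vecsim::create_hnsw::<f32>(128, vecsim::distance::Metric::L2);
+/// index.add_vector(&vec![0.0f32; 128], 1).unwrap();
+/// ```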
+pub fn create_hnsw<T: types::VectorElement>(
+    dim: usize,
+    metric: distance::Metric,
+) -> index::HnswSingle<T> {
+    let params = index::HnswParams::new(dim, metric);
+    index::HnswSingle::new(params)
+}
+
+#[cfg(test)]
+mod parallel_stress_tests;
+
+#[cfg(test)]
+mod data_type_tests;
+
+#[cfg(test)]
+mod e2e_tests;
+
+#[cfg(test)]
+mod tests {
+    use super::prelude::*;
+
+    #[test]
+    fn test_prelude_imports() {
+        let params = BruteForceParams::new(4, Metric::L2);
+        let mut index = BruteForceSingle::<f32>::new(params);
+
+        index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        index.add_vector(&[0.0, 1.0, 0.0, 0.0], 2).unwrap();
+
+        let results = index.top_k_query(&[1.0, 0.0, 0.0, 0.0], 2, None).unwrap();
+        assert_eq!(results.results[0].label, 1);
+    }
+
+    #[test]
+    fn test_convenience_functions() {
+        let mut bf_index = super::create_brute_force::<f32>(4, Metric::L2);
+        bf_index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        assert_eq!(bf_index.index_size(), 1);
+
+        let mut hnsw_index = super::create_hnsw::<f32>(4, Metric::L2);
+        hnsw_index.add_vector(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
+        assert_eq!(hnsw_index.index_size(), 1);
+    }
+}
diff --git a/rust/vecsim/src/memory/mod.rs b/rust/vecsim/src/memory/mod.rs
new file mode 100644
index 000000000..1452a3212
--- /dev/null
+++ b/rust/vecsim/src/memory/mod.rs
@@ -0,0 +1,516 @@
+//! Custom memory allocator interface for tracking and managing memory.
+//!
+//! This module provides a custom memory allocator that:
+//! - Tracks total allocated memory via atomic counters
+//! - Supports aligned allocations
+//! - Allows custom memory functions for integration with external systems
+//! - Provides RAII-style scoped memory management
+//!
+//! ## Usage
+//!
+//! ```rust,ignore
+//! use vecsim::memory::{VecSimAllocator, AllocatorRef};
+//!
+//! // Create an allocator
+//! let allocator = VecSimAllocator::new();
+//!
+//! // Allocate memory
+//! let ptr = allocator.allocate(1024).unwrap();
+//!
+//! // Check allocation size
+//! println!("Allocated: {} bytes", allocator.allocation_size());
+//!
+//! // Deallocate
+//! allocator.deallocate(ptr);
+//! ```
+
+use std::alloc::{alloc, alloc_zeroed, dealloc, Layout};
+use std::ptr::NonNull;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::Arc;
+
+/// Function pointer type for malloc-style allocation.
+pub type AllocFn = fn(usize) -> *mut u8;
+
+/// Function pointer type for calloc-style allocation.
+pub type CallocFn = fn(usize, usize) -> *mut u8;
+
+/// Function pointer type for realloc-style reallocation.
+pub type ReallocFn = fn(*mut u8, usize, usize) -> *mut u8;
+
+/// Function pointer type for free-style deallocation.
+pub type FreeFn = fn(*mut u8, usize);
+
+/// Custom memory functions for integration with external systems.
+#[derive(Clone, Copy)]
+pub struct MemoryFunctions {
+    /// Allocation function (malloc-style).
+    pub alloc: AllocFn,
+    /// Zero-initialized allocation function (calloc-style).
+    pub calloc: CallocFn,
+    /// Reallocation function (takes old_size for proper deallocation).
+    pub realloc: ReallocFn,
+    /// Deallocation function (takes size for proper deallocation).
+    pub free: FreeFn,
+}
+
+impl Default for MemoryFunctions {
+    fn default() -> Self {
+        Self {
+            alloc: default_alloc,
+            calloc: default_calloc,
+            realloc: default_realloc,
+            free: default_free,
+        }
+    }
+}
+
+// Default memory functions using Rust's global allocator
+fn default_alloc(size: usize) -> *mut u8 {
+    if size == 0 {
+        return std::ptr::null_mut();
+    }
+    unsafe {
+        let layout = Layout::from_size_align_unchecked(size, 8);
+        alloc(layout)
+    }
+}
+
+fn default_calloc(count: usize, size: usize) -> *mut u8 {
+    let total = count.saturating_mul(size);
+    if total == 0 {
+        return std::ptr::null_mut();
+    }
+    unsafe {
+        let layout = Layout::from_size_align_unchecked(total, 8);
+        alloc_zeroed(layout)
+    }
+}
+
+fn default_realloc(ptr: *mut u8, old_size: usize, new_size: usize) -> *mut u8 {
+    if ptr.is_null() {
+        return default_alloc(new_size);
+    }
+    if new_size == 0 {
+        default_free(ptr, old_size);
+        return std::ptr::null_mut();
+    }
+
+    let new_ptr = default_alloc(new_size);
+    if !new_ptr.is_null() {
+        let copy_size = old_size.min(new_size);
+        unsafe {
+            std::ptr::copy_nonoverlapping(ptr, new_ptr, copy_size);
+        }
+        default_free(ptr, old_size);
+    }
+    new_ptr
+}
+
+fn default_free(ptr: *mut u8, size: usize) {
+    if ptr.is_null() || size == 0 {
+        return;
+    }
+    unsafe {
+        let layout = Layout::from_size_align_unchecked(size, 8);
+        dealloc(ptr, layout);
+    }
+}
+
+/// Header stored before each allocation for tracking.
+/// This header stores information needed for deallocation.
+#[repr(C)]
+struct AllocationHeader {
+    /// Original raw pointer from allocation (for deallocation).
+    raw_ptr: *mut u8,
+    /// Total size of the raw allocation.
+    total_size: u64,
+    /// User-requested size.
+    user_size: u64,
+}
+
+impl AllocationHeader {
+    const SIZE: usize = std::mem::size_of::<Self>();
+    const ALIGN: usize = std::mem::align_of::<Self>();
+}
+
+/// Thread-safe memory allocator with tracking capabilities.
+///
+/// This allocator tracks total memory allocated and supports custom memory functions
+/// for integration with external systems like Redis.
+pub struct VecSimAllocator {
+    /// Total bytes currently allocated.
+    allocated: AtomicU64,
+    /// Custom memory functions.
+    mem_functions: MemoryFunctions,
+}
+
+impl VecSimAllocator {
+    /// Create a new allocator with default memory functions.
+    pub fn new() -> Arc<Self> {
+        Self::with_memory_functions(MemoryFunctions::default())
+    }
+
+    /// Create a new allocator with custom memory functions.
+    pub fn with_memory_functions(mem_functions: MemoryFunctions) -> Arc<Self> {
+        Arc::new(Self {
+            allocated: AtomicU64::new(0),
+            mem_functions,
+        })
+    }
+
+    /// Get the total bytes currently allocated.
+    pub fn allocation_size(&self) -> u64 {
+        self.allocated.load(Ordering::Relaxed)
+    }
+
+    /// Allocate memory with tracking.
+    ///
+    /// Returns a pointer to the allocated memory, or `None` if allocation failed.
+    pub fn allocate(&self, size: usize) -> Option<NonNull<u8>> {
+        self.allocate_aligned(size, 8)
+    }
+
+    /// Allocate memory with specific alignment.
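+    ///
+    /// The returned pointer is aligned to `alignment` (raised to at least the
+    /// header's own alignment); the tracking header is placed in the padding
+    /// immediately before it.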
+    pub fn allocate_aligned(&self, size: usize, alignment: usize) -> Option<NonNull<u8>> {
+        if size == 0 {
+            return None;
+        }
+
+        // Ensure alignment is at least header alignment
+        let alignment = alignment.max(AllocationHeader::ALIGN);
+
+        // Calculate total size: header + padding for alignment + user data
+        // We need enough space so that after placing header, we can align the user pointer
+        let header_size = AllocationHeader::SIZE;
+        let total_size = header_size + alignment + size;
+
+        // Allocate raw memory
+        let raw_ptr = (self.mem_functions.alloc)(total_size);
+        if raw_ptr.is_null() {
+            return None;
+        }
+
+        // Calculate where to place user data (aligned)
+        // User data starts at: raw_ptr + header_size, then aligned up
+        let user_ptr_unaligned = unsafe { raw_ptr.add(header_size) };
+        let offset = user_ptr_unaligned.align_offset(alignment);
+        let user_ptr = unsafe { user_ptr_unaligned.add(offset) };
+
+        // Place header just before user pointer
+        let header_ptr = unsafe { user_ptr.sub(header_size) } as *mut AllocationHeader;
+        unsafe {
+            (*header_ptr).raw_ptr = raw_ptr;
+            (*header_ptr).total_size = total_size as u64;
+            (*header_ptr).user_size = size as u64;
+        }
+
+        // Track allocation
+        self.allocated.fetch_add(total_size as u64, Ordering::Relaxed);
+
+        NonNull::new(user_ptr)
+    }
+
+    /// Allocate zero-initialized memory.
+    pub fn callocate(&self, size: usize) -> Option<NonNull<u8>> {
+        let ptr = self.allocate(size)?;
+        unsafe {
+            std::ptr::write_bytes(ptr.as_ptr(), 0, size);
+        }
+        Some(ptr)
+    }
+
+    /// Reallocate memory to a new size.
+    ///
+    /// This copies data from the old allocation to the new one.
+    pub fn reallocate(&self, ptr: NonNull<u8>, new_size: usize) -> Option<NonNull<u8>> {
+        // Get old info from header
+        let header_ptr =
+            unsafe { ptr.as_ptr().sub(AllocationHeader::SIZE) } as *const AllocationHeader;
+        let old_user_size = unsafe { (*header_ptr).user_size } as usize;
+
+        // Allocate new memory
+        let new_ptr = self.allocate(new_size)?;
+
+        // Copy data
+        let copy_size = old_user_size.min(new_size);
+        unsafe {
+            std::ptr::copy_nonoverlapping(ptr.as_ptr(), new_ptr.as_ptr(), copy_size);
+        }
+
+        // Free old memory
+        self.deallocate(ptr);
+
+        Some(new_ptr)
+    }
+
+    /// Deallocate memory.
+    pub fn deallocate(&self, ptr: NonNull<u8>) {
+        // Get header
+        let header_ptr =
+            unsafe { ptr.as_ptr().sub(AllocationHeader::SIZE) } as *const AllocationHeader;
+        let raw_ptr = unsafe { (*header_ptr).raw_ptr };
+        let total_size = unsafe { (*header_ptr).total_size } as usize;
+
+        // Free the raw allocation
+        (self.mem_functions.free)(raw_ptr, total_size);
+
+        // Update tracking
+        self.allocated.fetch_sub(total_size as u64, Ordering::Relaxed);
+    }
+
+    /// Allocate with RAII wrapper that automatically deallocates on drop.
+    pub fn allocate_scoped(self: &Arc<Self>, size: usize) -> Option<ScopedAllocation> {
+        let ptr = self.allocate(size)?;
+        Some(ScopedAllocation {
+            ptr,
+            size,
+            allocator: Arc::clone(self),
+        })
+    }
+
+    /// Allocate aligned memory with RAII wrapper.
+    pub fn allocate_aligned_scoped(
+        self: &Arc<Self>,
+        size: usize,
+        alignment: usize,
+    ) -> Option<ScopedAllocation> {
+        let ptr = self.allocate_aligned(size, alignment)?;
+        Some(ScopedAllocation {
+            ptr,
+            size,
+            allocator: Arc::clone(self),
+        })
+    }
+}
+
+impl Default for VecSimAllocator {
+    fn default() -> Self {
+        Self {
+            allocated: AtomicU64::new(0),
+            mem_functions: MemoryFunctions::default(),
+        }
+    }
+}
+
+/// Reference-counted allocator handle.
+pub type AllocatorRef = Arc<VecSimAllocator>;
+
+/// RAII wrapper for scoped allocations.
+///
+/// The memory is automatically deallocated when this struct is dropped.
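+///
+/// A minimal usage sketch of the scoped API defined above:
+///
+/// ```rust,ignore
+/// let allocator = VecSimAllocator::new();
+/// {
+///     let mut buf = allocator.allocate_scoped(64).unwrap();
+///     buf.as_mut_slice().fill(0xAB); // use the memory
+/// } // dropped here; the bytes are returned to the allocator
+/// assert_eq!(allocator.allocation_size(), 0);
+/// ```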
+pub struct ScopedAllocation {
+    ptr: NonNull<u8>,
+    size: usize,
+    allocator: Arc<VecSimAllocator>,
+}
+
+impl ScopedAllocation {
+    /// Get the raw pointer.
+    pub fn as_ptr(&self) -> *mut u8 {
+        self.ptr.as_ptr()
+    }
+
+    /// Get the NonNull pointer.
+    pub fn as_non_null(&self) -> NonNull<u8> {
+        self.ptr
+    }
+
+    /// Get the size of the allocation.
+    pub fn size(&self) -> usize {
+        self.size
+    }
+
+    /// Get a slice view of the allocation.
+    pub fn as_slice(&self) -> &[u8] {
+        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) }
+    }
+
+    /// Get a mutable slice view of the allocation.
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
+        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) }
+    }
+}
+
+impl Drop for ScopedAllocation {
+    fn drop(&mut self) {
+        self.allocator.deallocate(self.ptr);
+    }
+}
+
+/// Trait for types that can provide their allocator.
+pub trait HasAllocator {
+    /// Get a reference to the allocator.
+    fn allocator(&self) -> &Arc<VecSimAllocator>;
+
+    /// Get the total memory allocated by this allocator.
+    fn memory_usage(&self) -> u64 {
+        self.allocator().allocation_size()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_allocator_basic() {
+        let allocator = VecSimAllocator::new();
+
+        // Initial allocation should be 0
+        assert_eq!(allocator.allocation_size(), 0);
+
+        // Allocate some memory
+        let ptr = allocator.allocate(1024).expect("allocation failed");
+
+        // Should have tracked the allocation (with header overhead)
+        assert!(allocator.allocation_size() > 1024);
+
+        // Deallocate
+        allocator.deallocate(ptr);
+
+        // Should be back to 0
+        assert_eq!(allocator.allocation_size(), 0);
+    }
+
+    #[test]
+    fn test_allocator_aligned() {
+        let allocator = VecSimAllocator::new();
+
+        // Allocate with 64-byte alignment
+        let ptr = allocator
+            .allocate_aligned(1024, 64)
+            .expect("allocation failed");
+
+        // Check alignment
+        assert_eq!(ptr.as_ptr() as usize % 64, 0);
+
+        allocator.deallocate(ptr);
+        assert_eq!(allocator.allocation_size(), 0);
+    }
+
+    #[test]
+    fn test_allocator_callocate() {
+        let allocator = VecSimAllocator::new();
+
+        let ptr = allocator.callocate(1024).expect("allocation failed");
+
+        // Check that memory is zeroed
+        unsafe {
+            let slice = std::slice::from_raw_parts(ptr.as_ptr(), 1024);
+            assert!(slice.iter().all(|&b| b == 0));
+        }
+
+        allocator.deallocate(ptr);
+        assert_eq!(allocator.allocation_size(), 0);
+    }
+
+    #[test]
+    fn test_scoped_allocation() {
+        let allocator = VecSimAllocator::new();
+
+        {
+            let alloc = allocator.allocate_scoped(1024).expect("allocation failed");
+            assert!(allocator.allocation_size() > 0);
+            assert_eq!(alloc.size(), 1024);
+        }
+
+        // Should be deallocated after scope ends
+        assert_eq!(allocator.allocation_size(), 0);
+    }
+
+    #[test]
+    fn test_multiple_allocations() {
+        let allocator = VecSimAllocator::new();
+
+        let ptr1 = allocator.allocate(1024).expect("allocation failed");
+        let ptr2 = allocator.allocate(2048).expect("allocation failed");
+        let ptr3 = allocator.allocate(512).expect("allocation failed");
+
+        // Total should be sum of all allocations (plus overhead)
+        assert!(allocator.allocation_size() > 1024 + 2048 + 512);
+
+        allocator.deallocate(ptr1);
+        allocator.deallocate(ptr2);
+        allocator.deallocate(ptr3);
+
+        assert_eq!(allocator.allocation_size(), 0);
+    }
+
+    #[test]
+    fn test_reallocate() {
+        let allocator = VecSimAllocator::new();
+
+        let ptr = allocator.allocate(1024).expect("allocation failed");
+
+        // Write some data
+        unsafe {
+            let slice = std::slice::from_raw_parts_mut(ptr.as_ptr(), 1024);
+            for
(i, byte) in slice.iter_mut().enumerate() { + *byte = (i % 256) as u8; + } + } + + // Reallocate to larger size + let new_ptr = allocator.reallocate(ptr, 2048).expect("reallocation failed"); + + // Check data was preserved + unsafe { + let slice = std::slice::from_raw_parts(new_ptr.as_ptr(), 1024); + for (i, &byte) in slice.iter().enumerate() { + assert_eq!(byte, (i % 256) as u8); + } + } + + allocator.deallocate(new_ptr); + assert_eq!(allocator.allocation_size(), 0); + } + + #[test] + fn test_thread_safety() { + use std::thread; + + let allocator = VecSimAllocator::new(); + + let handles: Vec<_> = (0..10) + .map(|_| { + let alloc = Arc::clone(&allocator); + thread::spawn(move || { + for _ in 0..100 { + let ptr = alloc.allocate(1024).expect("allocation failed"); + alloc.deallocate(ptr); + } + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + assert_eq!(allocator.allocation_size(), 0); + } + + #[test] + fn test_scoped_slice_access() { + let allocator = VecSimAllocator::new(); + + let mut alloc = allocator.allocate_scoped(256).expect("allocation failed"); + + // Write via mutable slice + { + let slice = alloc.as_mut_slice(); + for (i, byte) in slice.iter_mut().enumerate() { + *byte = i as u8; + } + } + + // Read via immutable slice + { + let slice = alloc.as_slice(); + for (i, &byte) in slice.iter().enumerate() { + assert_eq!(byte, i as u8); + } + } + } +} diff --git a/rust/vecsim/src/parallel_stress_tests.rs b/rust/vecsim/src/parallel_stress_tests.rs new file mode 100644 index 000000000..b38254f92 --- /dev/null +++ b/rust/vecsim/src/parallel_stress_tests.rs @@ -0,0 +1,1230 @@ +//! Parallelism stress tests for concurrent index operations. +//! +//! These tests verify thread safety and correctness under high contention +//! scenarios with multiple threads performing concurrent operations. 
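+//!
+//! Read-only tests share an index across threads via `Arc`; tests that mix
+//! reads and writes wrap the index in a `parking_lot::RwLock` instead.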
+ +use crate::distance::Metric; +use crate::index::brute_force::{BruteForceMulti, BruteForceParams, BruteForceSingle}; +use crate::index::hnsw::{HnswMulti, HnswParams, HnswSingle}; +use crate::index::svs::{SvsParams, SvsSingle}; +use crate::index::traits::VecSimIndex; +use crate::query::{CancellationToken, QueryParams}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +// ============================================================================ +// Concurrent Query Tests +// ============================================================================ + +#[test] +fn test_brute_force_concurrent_queries() { + let params = BruteForceParams::new(8, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + // Add 1000 vectors + for i in 0..1000u64 { + let mut v = vec![0.0f32; 8]; + v[0] = i as f32; + index.add_vector(&v, i).unwrap(); + } + + let index = Arc::new(index); + let num_threads = 8; + let queries_per_thread = 100; + + let handles: Vec<_> = (0..num_threads) + .map(|t| { + let index = Arc::clone(&index); + thread::spawn(move || { + for i in 0..queries_per_thread { + let q = vec![(t * queries_per_thread + i) as f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 5, None).unwrap(); + assert!(!results.is_empty()); + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } +} + +#[test] +fn test_brute_force_parallel_queries_rayon() { + let params = BruteForceParams::new(8, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + // Add enough vectors to trigger parallel execution (>1000) + for i in 0..2000u64 { + let mut v = vec![0.0f32; 8]; + v[0] = i as f32; + index.add_vector(&v, i).unwrap(); + } + + let index = Arc::new(index); + let num_threads = 8; + let queries_per_thread = 50; + + let handles: Vec<_> = (0..num_threads) + .map(|t| { + let index = Arc::clone(&index); + thread::spawn(move || { + for i in 0..queries_per_thread { + let q = vec![(t * 100 + i) as f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + let params = QueryParams::new().with_parallel(true); + let results = index.top_k_query(&q, 10, Some(¶ms)).unwrap(); + assert_eq!(results.len(), 10); + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } +} + +#[test] +fn test_hnsw_concurrent_queries() { + let params = HnswParams::new(8, Metric::L2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(50); + let mut index = HnswSingle::::new(params); + + // Add 1000 vectors + for i in 0..1000u64 { + let mut v = vec![0.0f32; 8]; + v[0] = i as f32; + v[1] = (i % 100) as f32; + index.add_vector(&v, i).unwrap(); + } + + let index = Arc::new(index); + let num_threads = 8; + let queries_per_thread = 100; + + let handles: Vec<_> = (0..num_threads) + .map(|t| { + let index = Arc::clone(&index); + thread::spawn(move || { + for i in 0..queries_per_thread { + let q = vec![(t * 50 + i) as f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 5, None).unwrap(); + assert!(!results.is_empty()); + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } +} + +#[test] +fn test_svs_concurrent_queries() { + let params = SvsParams::new(8, Metric::L2); + let mut index = SvsSingle::::new(params); + + // Add 500 vectors + for i in 0..500u64 { + let mut v = vec![0.0f32; 8]; + v[0] = i as f32; + v[1] = (i % 50) as f32; + index.add_vector(&v, i).unwrap(); + } + + // Build for 
optimal search
+    index.build();
+
+    let index = Arc::new(index);
+    let num_threads = 4;
+    let queries_per_thread = 50;
+
+    let handles: Vec<_> = (0..num_threads)
+        .map(|t| {
+            let index = Arc::clone(&index);
+            thread::spawn(move || {
+                for i in 0..queries_per_thread {
+                    let q = vec![(t * 50 + i) as f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
+                    let results = index.top_k_query(&q, 5, None).unwrap();
+                    assert!(!results.is_empty());
+                }
+            })
+        })
+        .collect();
+
+    for handle in handles {
+        handle.join().expect("Thread panicked");
+    }
+}
+
+// ============================================================================
+// Concurrent Query + Range Query Tests
+// ============================================================================
+
+#[test]
+fn test_brute_force_concurrent_mixed_queries() {
+    let params = BruteForceParams::new(4, Metric::L2);
+    let mut index = BruteForceSingle::<f32>::new(params);
+
+    for i in 0..500u64 {
+        let v = vec![i as f32, 0.0, 0.0, 0.0];
+        index.add_vector(&v, i).unwrap();
+    }
+
+    let index = Arc::new(index);
+    let num_threads = 6;
+
+    let handles: Vec<_> = (0..num_threads)
+        .map(|t| {
+            let index = Arc::clone(&index);
+            thread::spawn(move || {
+                for i in 0..50 {
+                    let q = vec![(t * 50 + i) as f32, 0.0, 0.0, 0.0];
+                    if i % 2 == 0 {
+                        // Top-k query
+                        let results = index.top_k_query(&q, 5, None).unwrap();
+                        assert!(!results.is_empty());
+                    } else {
+                        // Range query
+                        let results = index.range_query(&q, 100.0, None).unwrap();
+                        assert!(!results.is_empty());
+                    }
+                }
+            })
+        })
+        .collect();
+
+    for handle in handles {
+        handle.join().expect("Thread panicked");
+    }
+}
+
+// ============================================================================
+// Concurrent Filtered Query Tests
+// ============================================================================
+
+#[test]
+fn test_brute_force_concurrent_filtered_queries() {
+    let params = BruteForceParams::new(4, Metric::L2);
+    let mut index = BruteForceSingle::<f32>::new(params);
+
+    for i in 0..1000u64 {
+        let v = vec![i as f32, 0.0, 0.0, 0.0];
+        index.add_vector(&v, i).unwrap();
+    }
+
+    let index = Arc::new(index);
+    let num_threads = 4;
+
+    let handles: Vec<_> = (0..num_threads)
+        .map(|t| {
+            let index = Arc::clone(&index);
+            thread::spawn(move || {
+                for i in 0..100 {
+                    let q = vec![(t * 100 + i) as f32, 0.0, 0.0, 0.0];
+                    // Filter to only even labels in some queries
+                    let params = if i % 2 == 0 {
+                        QueryParams::new().with_filter(|label| label % 2 == 0)
+                    } else {
+                        QueryParams::new()
+                    };
+                    let results = index.top_k_query(&q, 10, Some(&params)).unwrap();
+                    if i % 2 == 0 {
+                        for r in &results.results {
+                            assert!(r.label % 2 == 0);
+                        }
+                    }
+                }
+            })
+        })
+        .collect();
+
+    for handle in handles {
+        handle.join().expect("Thread panicked");
+    }
+}
+
+#[test]
+fn test_hnsw_concurrent_filtered_queries() {
+    let params = HnswParams::new(4, Metric::L2)
+        .with_m(16)
+        .with_ef_construction(50);
+    let mut index = HnswSingle::<f32>::new(params);
+
+    for i in 0..500u64 {
+        let v = vec![i as f32, 0.0, 0.0, 0.0];
+        index.add_vector(&v, i).unwrap();
+    }
+
+    let index = Arc::new(index);
+    let num_threads = 4;
+
+    let handles: Vec<_> = (0..num_threads)
+        .map(|t| {
+            let index = Arc::clone(&index);
+            thread::spawn(move || {
+                for i in 0..50 {
+                    let q = vec![(t * 50 + i) as f32, 0.0, 0.0, 0.0];
+                    let divisor = (t + 2) as u64; // Different filter per thread
+                    let params = QueryParams::new().with_filter(move |label| label % divisor == 0);
+                    let results = index.top_k_query(&q, 10, Some(&params)).unwrap();
+                    for r in &results.results {
+                        assert!(r.label
% divisor == 0); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } +} + +// ============================================================================ +// Cancellation Token Tests +// ============================================================================ + +#[test] +fn test_cancellation_token_concurrent_cancel() { + let token = CancellationToken::new(); + let num_readers = 10; + + // Clone tokens for reader threads + let reader_tokens: Vec<_> = (0..num_readers).map(|_| token.clone()).collect(); + + // Spawn reader threads that check cancellation + let check_count = Arc::new(AtomicUsize::new(0)); + let handles: Vec<_> = reader_tokens + .into_iter() + .map(|t| { + let count = Arc::clone(&check_count); + thread::spawn(move || { + while !t.is_cancelled() { + count.fetch_add(1, Ordering::Relaxed); + thread::yield_now(); + } + }) + }) + .collect(); + + // Let readers run for a bit + thread::sleep(Duration::from_millis(10)); + + // Cancel + token.cancel(); + + // All threads should exit + for handle in handles { + handle.join().expect("Thread panicked"); + } + + // Verify cancellation was detected + assert!(token.is_cancelled()); + assert!(check_count.load(Ordering::Relaxed) > 0); +} + +#[test] +fn test_cancellation_token_with_query() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..100u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i).unwrap(); + } + + let token = CancellationToken::new(); + let query_params = QueryParams::new().with_timeout_callback(token.as_callback()); + + // Query before cancellation should work + let q = vec![50.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 5, Some(&query_params)).unwrap(); + assert!(!results.is_empty()); + + // Cancel the token + token.cancel(); + + // The timeout_callback now returns true, but the query still completes + // (it just checks the callback during execution) + let results2 = index.top_k_query(&q, 5, Some(&query_params)).unwrap(); + // Results may be partial or complete depending on implementation + assert!(results2.len() <= 5); +} + +// ============================================================================ +// Multi-Index Concurrent Tests +// ============================================================================ + +#[test] +fn test_brute_force_multi_concurrent_queries() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceMulti::::new(params); + + // Add multiple vectors per label + for label in 0..100u64 { + for i in 0..5 { + let v = vec![label as f32 + i as f32 * 0.01, 0.0, 0.0, 0.0]; + index.add_vector(&v, label).unwrap(); + } + } + + let index = Arc::new(index); + let num_threads = 4; + + let handles: Vec<_> = (0..num_threads) + .map(|t| { + let index = Arc::clone(&index); + thread::spawn(move || { + for i in 0..50 { + let q = vec![(t * 25 + i) as f32, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 10, None).unwrap(); + assert!(!results.is_empty()); + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } +} + +#[test] +fn test_hnsw_multi_concurrent_queries() { + let params = HnswParams::new(4, Metric::L2) + .with_m(8) + .with_ef_construction(50); + let mut index = HnswMulti::::new(params); + + // Add multiple vectors per label + for label in 0..100u64 { + for i in 0..3 { + let v = vec![label as f32 + i as f32 * 0.01, 0.0, 0.0, 0.0]; + index.add_vector(&v, label).unwrap(); + } + 
} + + let index = Arc::new(index); + let num_threads = 4; + + let handles: Vec<_> = (0..num_threads) + .map(|t| { + let index = Arc::clone(&index); + thread::spawn(move || { + for i in 0..50 { + let q = vec![(t * 25 + i) as f32, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 10, None).unwrap(); + assert!(!results.is_empty()); + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } +} + +// ============================================================================ +// High Contention Query Tests +// ============================================================================ + +#[test] +fn test_brute_force_high_contention_queries() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..200u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i).unwrap(); + } + + let index = Arc::new(index); + let num_threads = 16; // High thread count for contention + let queries_per_thread = 200; + let success_count = Arc::new(AtomicUsize::new(0)); + + let handles: Vec<_> = (0..num_threads) + .map(|_| { + let index = Arc::clone(&index); + let count = Arc::clone(&success_count); + thread::spawn(move || { + for i in 0..queries_per_thread { + let q = vec![(i % 200) as f32, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 5, None).unwrap(); + if !results.is_empty() { + count.fetch_add(1, Ordering::Relaxed); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + // All queries should succeed + assert_eq!( + success_count.load(Ordering::Relaxed), + num_threads * queries_per_thread + ); +} + +#[test] +fn test_hnsw_high_contention_queries() { + let params = HnswParams::new(8, Metric::L2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(50); + let mut index = HnswSingle::::new(params); + + for i in 0..500u64 { + let mut v = vec![0.0f32; 8]; + v[0] = i as f32; + index.add_vector(&v, i).unwrap(); + } + + let index = Arc::new(index); + let num_threads = 16; + let queries_per_thread = 100; + let success_count = Arc::new(AtomicUsize::new(0)); + + let handles: Vec<_> = (0..num_threads) + .map(|_| { + let index = Arc::clone(&index); + let count = Arc::clone(&success_count); + thread::spawn(move || { + for i in 0..queries_per_thread { + let q = vec![(i % 500) as f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 10, None).unwrap(); + if !results.is_empty() { + count.fetch_add(1, Ordering::Relaxed); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + assert_eq!( + success_count.load(Ordering::Relaxed), + num_threads * queries_per_thread + ); +} + +// ============================================================================ +// Batch Iterator Concurrent Tests +// ============================================================================ + +#[test] +fn test_brute_force_concurrent_batch_iterators() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..200u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i).unwrap(); + } + + let index = Arc::new(index); + let num_threads = 4; + + let handles: Vec<_> = (0..num_threads) + .map(|t| { + let index = Arc::clone(&index); + thread::spawn(move || { + for _ in 0..10 { + let q = vec![(t * 50) as f32, 0.0, 0.0, 0.0]; + let mut iter = index.batch_iterator(&q, None).unwrap(); + let mut total = 0; + while 
let Some(batch) = iter.next_batch(20) { + total += batch.len(); + } + assert_eq!(total, 200); + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } +} + +// ============================================================================ +// Different Metrics Concurrent Tests +// ============================================================================ + +#[test] +fn test_concurrent_queries_different_metrics() { + // Test L2 + let params_l2 = BruteForceParams::new(4, Metric::L2); + let mut index_l2 = BruteForceSingle::::new(params_l2); + + // Test InnerProduct + let params_ip = BruteForceParams::new(4, Metric::InnerProduct); + let mut index_ip = BruteForceSingle::::new(params_ip); + + // Test Cosine + let params_cos = BruteForceParams::new(4, Metric::Cosine); + let mut index_cos = BruteForceSingle::::new(params_cos); + + for i in 0..200u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index_l2.add_vector(&v, i).unwrap(); + index_ip.add_vector(&v, i).unwrap(); + index_cos.add_vector(&v, i).unwrap(); + } + + let index_l2 = Arc::new(index_l2); + let index_ip = Arc::new(index_ip); + let index_cos = Arc::new(index_cos); + + let handles: Vec<_> = vec![ + { + let index = Arc::clone(&index_l2); + thread::spawn(move || { + for i in 0..100 { + let q = vec![i as f32, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 5, None).unwrap(); + assert!(!results.is_empty()); + } + }) + }, + { + let index = Arc::clone(&index_ip); + thread::spawn(move || { + for i in 0..100 { + let q = vec![i as f32, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 5, None).unwrap(); + assert!(!results.is_empty()); + } + }) + }, + { + let index = Arc::clone(&index_cos); + thread::spawn(move || { + for i in 0..100 { + let q = vec![i as f32, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 5, None).unwrap(); + assert!(!results.is_empty()); + } + }) + }, + ]; + + for handle in handles { + handle.join().expect("Thread panicked"); + } +} + +// ============================================================================ +// Stress Test with Long Duration +// ============================================================================ + +#[test] +fn test_sustained_concurrent_load() { + let params = HnswParams::new(4, Metric::L2) + .with_m(8) + .with_ef_construction(50); + let mut index = HnswSingle::::new(params); + + for i in 0..300u64 { + let v = vec![i as f32, (i % 50) as f32, 0.0, 0.0]; + index.add_vector(&v, i).unwrap(); + } + + let index = Arc::new(index); + let duration = Duration::from_millis(500); + let num_threads = 8; + let query_counts: Vec<_> = (0..num_threads) + .map(|_| Arc::new(AtomicUsize::new(0))) + .collect(); + + let handles: Vec<_> = (0..num_threads) + .map(|t| { + let index = Arc::clone(&index); + let count = Arc::clone(&query_counts[t]); + let start = std::time::Instant::now(); + thread::spawn(move || { + let mut i = 0u64; + while start.elapsed() < duration { + let q = vec![(i % 300) as f32, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 5, None).unwrap(); + assert!(!results.is_empty()); + count.fetch_add(1, Ordering::Relaxed); + i += 1; + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + // Verify all threads performed queries + let total: usize = query_counts.iter().map(|c| c.load(Ordering::Relaxed)).sum(); + assert!(total > num_threads * 10, "Expected more queries to complete"); +} + +// ============================================================================ +// Memory Ordering Tests 
+// ============================================================================ + +#[test] +fn test_atomic_count_consistency() { + let params = BruteForceParams::new(4, Metric::L2); + let index = Arc::new(parking_lot::RwLock::new(BruteForceSingle::::new(params))); + + let num_threads = 4; + let adds_per_thread = 50; + + let handles: Vec<_> = (0..num_threads) + .map(|t| { + let index = Arc::clone(&index); + thread::spawn(move || { + for i in 0..adds_per_thread { + let label = (t * adds_per_thread + i) as u64; + let v = vec![label as f32, 0.0, 0.0, 0.0]; + index.write().add_vector(&v, label).unwrap(); + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + // Final count should be exact + let final_count = index.read().index_size(); + assert_eq!(final_count, num_threads * adds_per_thread); +} + +// ============================================================================ +// Query Result Correctness Under Concurrency +// ============================================================================ + +#[test] +fn test_query_result_correctness_concurrent() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + // Add vectors with known distances + // Vector at position i has value [i, 0, 0, 0] + for i in 0..100u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i).unwrap(); + } + + let index = Arc::new(index); + let num_threads = 8; + let error_count = Arc::new(AtomicUsize::new(0)); + + let handles: Vec<_> = (0..num_threads) + .map(|_| { + let index = Arc::clone(&index); + let errors = Arc::clone(&error_count); + thread::spawn(move || { + for target in 0..100u64 { + let q = vec![target as f32, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 1, None).unwrap(); + // The closest vector should be the exact match + if results.results[0].label != target { + errors.fetch_add(1, Ordering::Relaxed); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + // No errors should occur - results should always be correct + assert_eq!(error_count.load(Ordering::Relaxed), 0); +} + +#[test] +fn test_hnsw_query_result_stability() { + let params = HnswParams::new(4, Metric::L2) + .with_m(16) + .with_ef_construction(100) + .with_ef_runtime(100); + let mut index = HnswSingle::::new(params); + + for i in 0..100u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i).unwrap(); + } + + let index = Arc::new(index); + let num_threads = 8; + let error_count = Arc::new(AtomicUsize::new(0)); + + let handles: Vec<_> = (0..num_threads) + .map(|_| { + let index = Arc::clone(&index); + let errors = Arc::clone(&error_count); + thread::spawn(move || { + for target in 0..100u64 { + let q = vec![target as f32, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 1, None).unwrap(); + // With high ef_runtime, exact match should be found + if results.results[0].label != target { + errors.fetch_add(1, Ordering::Relaxed); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + // Allow very few errors due to HNSW approximation + let errors = error_count.load(Ordering::Relaxed); + assert!(errors < 10, "Too many incorrect results: {}", errors); +} + +// ============================================================================ +// Visited Nodes Handler Tests (via HNSW search) +// ============================================================================ + +#[test] +fn 
test_hnsw_visited_handler_concurrent_searches() {
+    // This tests the visited handler pool indirectly through concurrent HNSW searches
+    let params = HnswParams::new(8, Metric::L2)
+        .with_m(16)
+        .with_ef_construction(100)
+        .with_ef_runtime(100);
+    let mut index = HnswSingle::<f32>::new(params);
+
+    // Build a larger index to stress the visited handler
+    for i in 0..1000u64 {
+        let mut v = vec![0.0f32; 8];
+        v[0] = (i % 100) as f32;
+        v[1] = (i / 100) as f32;
+        index.add_vector(&v, i).unwrap();
+    }
+
+    let index = Arc::new(index);
+    let num_threads = 8;
+    let queries_per_thread = 200;
+
+    // Many concurrent searches will stress the visited handler pool
+    let handles: Vec<_> = (0..num_threads)
+        .map(|t| {
+            let index = Arc::clone(&index);
+            thread::spawn(move || {
+                for i in 0..queries_per_thread {
+                    let q = vec![
+                        ((t * queries_per_thread + i) % 100) as f32,
+                        ((t * queries_per_thread + i) / 100) as f32,
+                        0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+                    ];
+                    let results = index.top_k_query(&q, 10, None).unwrap();
+                    assert!(!results.is_empty());
+                }
+            })
+        })
+        .collect();
+
+    for handle in handles {
+        handle.join().expect("Thread panicked");
+    }
+}
+
+#[test]
+fn test_hnsw_visited_handler_with_varying_ef() {
+    // Test with different ef_runtime values to stress visited handler resizing
+    let params = HnswParams::new(4, Metric::L2)
+        .with_m(8)
+        .with_ef_construction(50);
+    let mut index = HnswSingle::<f32>::new(params);
+
+    for i in 0..500u64 {
+        let v = vec![i as f32, 0.0, 0.0, 0.0];
+        index.add_vector(&v, i).unwrap();
+    }
+
+    let index = Arc::new(index);
+    let num_threads = 4;
+
+    let handles: Vec<_> = (0..num_threads)
+        .map(|t| {
+            let index = Arc::clone(&index);
+            thread::spawn(move || {
+                for i in 0..100 {
+                    let q = vec![(t * 100 + i) as f32, 0.0, 0.0, 0.0];
+                    // Vary ef_runtime to stress handler
+                    let ef = 20 + (i % 80);
+                    let params = QueryParams::new().with_ef_runtime(ef);
+                    let results = index.top_k_query(&q, 10, Some(&params)).unwrap();
+                    assert!(!results.is_empty());
+                }
+            })
+        })
+        .collect();
+
+    for handle in handles {
+        handle.join().expect("Thread panicked");
+    }
+}
+
+// ============================================================================
+// Concurrent Modification Tests (Read-Write)
+// ============================================================================
+
+#[test]
+fn test_brute_force_concurrent_read_write() {
+    let params = BruteForceParams::new(4, Metric::L2);
+    let index = Arc::new(parking_lot::RwLock::new(BruteForceSingle::<f32>::new(params)));
+
+    // Pre-populate with some data
+    {
+        let mut idx = index.write();
+        for i in 0..100u64 {
+            let v = vec![i as f32, 0.0, 0.0, 0.0];
+            idx.add_vector(&v, i).unwrap();
+        }
+    }
+
+    let num_readers = 4;
+    let num_writers = 2;
+    let read_count = Arc::new(AtomicUsize::new(0));
+    let write_count = Arc::new(AtomicUsize::new(0));
+
+    // Spawn reader threads
+    let reader_handles: Vec<_> = (0..num_readers)
+        .map(|_| {
+            let index = Arc::clone(&index);
+            let count = Arc::clone(&read_count);
+            thread::spawn(move || {
+                for i in 0..100 {
+                    let idx = index.read();
+                    let q = vec![(i % 100) as f32, 0.0, 0.0, 0.0];
+                    let results = idx.top_k_query(&q, 5, None).unwrap();
+                    assert!(!results.is_empty());
+                    count.fetch_add(1, Ordering::Relaxed);
+                }
+            })
+        })
+        .collect();
+
+    // Spawn writer threads
+    let writer_handles: Vec<_> = (0..num_writers)
+        .map(|t| {
+            let index = Arc::clone(&index);
+            let count = Arc::clone(&write_count);
+            thread::spawn(move || {
+                for i in 0..50 {
+                    let label = 1000 + (t as u64) * 100 + (i as u64);
+                    let v = vec![label as f32,
0.0, 0.0, 0.0]; + index.write().add_vector(&v, label).unwrap(); + count.fetch_add(1, Ordering::Relaxed); + thread::yield_now(); + } + }) + }) + .collect(); + + for handle in reader_handles { + handle.join().expect("Reader thread panicked"); + } + for handle in writer_handles { + handle.join().expect("Writer thread panicked"); + } + + // Verify counts + assert_eq!(read_count.load(Ordering::Relaxed), num_readers * 100); + assert_eq!(write_count.load(Ordering::Relaxed), num_writers * 50); + + // Verify final index state + let final_size = index.read().index_size(); + assert_eq!(final_size, 100 + num_writers * 50); +} + +#[test] +fn test_hnsw_concurrent_read_write() { + let params = HnswParams::new(4, Metric::L2) + .with_m(8) + .with_ef_construction(50); + let index = Arc::new(parking_lot::RwLock::new(HnswSingle::::new(params))); + + // Pre-populate + { + let mut idx = index.write(); + for i in 0..100u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + idx.add_vector(&v, i).unwrap(); + } + } + + let num_readers = 4; + let num_writers = 2; + + let reader_handles: Vec<_> = (0..num_readers) + .map(|_| { + let index = Arc::clone(&index); + thread::spawn(move || { + for i in 0..50 { + let idx = index.read(); + let q = vec![(i % 100) as f32, 0.0, 0.0, 0.0]; + let results = idx.top_k_query(&q, 5, None).unwrap(); + assert!(!results.is_empty()); + } + }) + }) + .collect(); + + let writer_handles: Vec<_> = (0..num_writers) + .map(|t| { + let index = Arc::clone(&index); + thread::spawn(move || { + for i in 0..25 { + let label = 1000 + (t as u64) * 100 + (i as u64); + let v = vec![label as f32, 0.0, 0.0, 0.0]; + index.write().add_vector(&v, label).unwrap(); + thread::yield_now(); + } + }) + }) + .collect(); + + for handle in reader_handles { + handle.join().expect("Reader thread panicked"); + } + for handle in writer_handles { + handle.join().expect("Writer thread panicked"); + } + + let final_size = index.read().index_size(); + assert_eq!(final_size, 100 + num_writers * 25); +} + +#[test] +fn test_concurrent_add_delete() { + let params = BruteForceParams::new(4, Metric::L2); + let index = Arc::new(parking_lot::RwLock::new(BruteForceSingle::::new(params))); + + // Pre-populate + { + let mut idx = index.write(); + for i in 0..200u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + idx.add_vector(&v, i).unwrap(); + } + } + + let add_count = Arc::new(AtomicUsize::new(0)); + let delete_count = Arc::new(AtomicUsize::new(0)); + + // Adder thread + let add_idx = Arc::clone(&index); + let add_cnt = Arc::clone(&add_count); + let add_handle = thread::spawn(move || { + for i in 200..300u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + add_idx.write().add_vector(&v, i).unwrap(); + add_cnt.fetch_add(1, Ordering::Relaxed); + } + }); + + // Deleter thread + let del_idx = Arc::clone(&index); + let del_cnt = Arc::clone(&delete_count); + let delete_handle = thread::spawn(move || { + for i in 0..100u64 { + // Deletes might fail if already deleted + if del_idx.write().delete_vector(i).is_ok() { + del_cnt.fetch_add(1, Ordering::Relaxed); + } + } + }); + + // Query thread while modifications happen + let query_idx = Arc::clone(&index); + let query_handle = thread::spawn(move || { + for i in 0..100 { + let q = vec![(i + 100) as f32, 0.0, 0.0, 0.0]; + let idx = query_idx.read(); + let results = idx.top_k_query(&q, 5, None); + // Query should not panic + assert!(results.is_ok()); + } + }); + + add_handle.join().expect("Add thread panicked"); + delete_handle.join().expect("Delete thread panicked"); + query_handle.join().expect("Query 
thread panicked"); + + // Verify final state + let adds = add_count.load(Ordering::Relaxed); + let deletes = delete_count.load(Ordering::Relaxed); + let final_size = index.read().index_size(); + + assert_eq!(adds, 100); + assert_eq!(deletes, 100); + assert_eq!(final_size, 200 + adds - deletes); // 200 initial + 100 adds - 100 deletes = 200 +} + +// ============================================================================ +// Multi-Index Concurrent Modification Tests +// ============================================================================ + +#[test] +fn test_brute_force_multi_concurrent_read_write() { + let params = BruteForceParams::new(4, Metric::L2); + let index = Arc::new(parking_lot::RwLock::new(BruteForceMulti::::new(params))); + + // Pre-populate + { + let mut idx = index.write(); + for label in 0..50u64 { + for i in 0..3 { + let v = vec![label as f32 + i as f32 * 0.01, 0.0, 0.0, 0.0]; + idx.add_vector(&v, label).unwrap(); + } + } + } + + let num_readers = 4; + let num_writers = 2; + + let reader_handles: Vec<_> = (0..num_readers) + .map(|_| { + let index = Arc::clone(&index); + thread::spawn(move || { + for i in 0..50 { + let idx = index.read(); + let q = vec![(i % 50) as f32, 0.0, 0.0, 0.0]; + let results = idx.top_k_query(&q, 5, None).unwrap(); + assert!(!results.is_empty()); + } + }) + }) + .collect(); + + let writer_handles: Vec<_> = (0..num_writers) + .map(|t| { + let index = Arc::clone(&index); + thread::spawn(move || { + for i in 0..25 { + let label = 100 + (t as u64) * 50 + (i as u64); + let v = vec![label as f32, 0.0, 0.0, 0.0]; + index.write().add_vector(&v, label).unwrap(); + thread::yield_now(); + } + }) + }) + .collect(); + + for handle in reader_handles { + handle.join().expect("Reader thread panicked"); + } + for handle in writer_handles { + handle.join().expect("Writer thread panicked"); + } + + let final_size = index.read().index_size(); + assert_eq!(final_size, 50 * 3 + num_writers * 25); // 150 initial + 50 added +} + +// ============================================================================ +// Extreme Contention Tests +// ============================================================================ + +#[test] +fn test_extreme_read_contention() { + let params = BruteForceParams::new(4, Metric::L2); + let mut index = BruteForceSingle::::new(params); + + for i in 0..100u64 { + let v = vec![i as f32, 0.0, 0.0, 0.0]; + index.add_vector(&v, i).unwrap(); + } + + let index = Arc::new(index); + let num_threads = 32; // Many threads + let queries_per_thread = 500; + let success_count = Arc::new(AtomicUsize::new(0)); + + let handles: Vec<_> = (0..num_threads) + .map(|_| { + let index = Arc::clone(&index); + let count = Arc::clone(&success_count); + thread::spawn(move || { + for i in 0..queries_per_thread { + let q = vec![(i % 100) as f32, 0.0, 0.0, 0.0]; + let results = index.top_k_query(&q, 3, None).unwrap(); + if !results.is_empty() { + count.fetch_add(1, Ordering::Relaxed); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + assert_eq!( + success_count.load(Ordering::Relaxed), + num_threads * queries_per_thread + ); +} + +#[test] +fn test_rapid_add_query_interleave() { + let params = BruteForceParams::new(4, Metric::L2); + let index = Arc::new(parking_lot::RwLock::new(BruteForceSingle::::new(params))); + + let num_iterations = 100; + let operations_complete = Arc::new(AtomicUsize::new(0)); + + // Thread that adds and immediately queries + let handles: Vec<_> = (0..4) + .map(|t| { + let index = 
Arc::clone(&index);
+            let ops = Arc::clone(&operations_complete);
+            thread::spawn(move || {
+                for i in 0..num_iterations {
+                    let label = (t * num_iterations + i) as u64;
+                    let v = vec![label as f32, 0.0, 0.0, 0.0];
+
+                    // Add
+                    index.write().add_vector(&v, label).unwrap();
+
+                    // Immediately query
+                    let q = vec![label as f32, 0.0, 0.0, 0.0];
+                    let results = index.read().top_k_query(&q, 1, None).unwrap();
+                    assert!(!results.is_empty());
+
+                    ops.fetch_add(1, Ordering::Relaxed);
+                }
+            })
+        })
+        .collect();
+
+    for handle in handles {
+        handle.join().expect("Thread panicked");
+    }
+
+    assert_eq!(operations_complete.load(Ordering::Relaxed), 4 * num_iterations);
+}
diff --git a/rust/vecsim/src/preprocessing/mod.rs b/rust/vecsim/src/preprocessing/mod.rs
new file mode 100644
index 000000000..1b7743cda
--- /dev/null
+++ b/rust/vecsim/src/preprocessing/mod.rs
@@ -0,0 +1,554 @@
+//! Vector preprocessing pipeline.
+//!
+//! This module provides preprocessing transformations applied to vectors before
+//! storage or querying:
+//!
+//! - **Normalization**: L2-normalize vectors for cosine similarity
+//! - **Quantization**: Convert to quantized representation for storage
+//!
+//! ## Asymmetric Preprocessing
+//!
+//! Storage and query vectors may be processed differently:
+//! - Storage vectors may be quantized to reduce memory
+//! - Query vectors remain in original format with precomputed metadata
+//!
+//! ## Example
+//!
+//! ```rust,ignore
+//! use vecsim::preprocessing::{PreprocessorChain, CosinePreprocessor};
+//!
+//! let preprocessor = CosinePreprocessor::new(4);
+//! let vector = vec![1.0f32, 2.0, 3.0, 4.0];
+//!
+//! let processed = preprocessor.preprocess_storage(&vector);
+//! ```
+
+use crate::distance::normalize_in_place;
+
+/// Trait for vector preprocessors.
+///
+/// Preprocessors transform vectors before storage or querying. They may apply
+/// different transformations to storage vs query vectors (asymmetric preprocessing).
+pub trait Preprocessor: Send + Sync {
+    /// Preprocess a vector for storage.
+    ///
+    /// Returns the processed vector data as bytes. The output may be a different
+    /// size or format than the input.
+    fn preprocess_storage(&self, vector: &[f32]) -> Vec<u8>;
+
+    /// Preprocess a vector for querying.
+    ///
+    /// Returns the processed query data as bytes. For asymmetric preprocessing,
+    /// this may include precomputed metadata for faster distance computation.
+    fn preprocess_query(&self, vector: &[f32]) -> Vec<u8>;
+
+    /// Preprocess a vector in-place for storage.
+    ///
+    /// This modifies the vector directly when possible, avoiding allocation.
+    /// Returns true if preprocessing was successful.
+    fn preprocess_storage_in_place(&self, vector: &mut [f32]) -> bool;
+
+    /// Get the storage size in bytes for a vector of given dimension.
+    fn storage_size(&self, dim: usize) -> usize;
+
+    /// Get the query size in bytes for a vector of given dimension.
+    fn query_size(&self, dim: usize) -> usize;
+
+    /// Get the name of this preprocessor.
+    fn name(&self) -> &'static str;
+}
+
+/// Identity preprocessor that performs no transformation.
+#[derive(Debug, Clone, Default)]
+pub struct IdentityPreprocessor {
+    #[allow(dead_code)]
+    dim: usize,
+}
+
+impl IdentityPreprocessor {
+    /// Create a new identity preprocessor.
+    pub fn new(dim: usize) -> Self {
+        Self { dim }
+    }
+}
+
+impl Preprocessor for IdentityPreprocessor {
+    fn preprocess_storage(&self, vector: &[f32]) -> Vec<u8> {
+        // Just copy the raw bytes
+        let bytes: &[u8] = unsafe {
+            std::slice::from_raw_parts(vector.as_ptr() as *const u8, vector.len() * 4)
+        };
+        bytes.to_vec()
+    }
+
+    fn preprocess_query(&self, vector: &[f32]) -> Vec<u8> {
+        self.preprocess_storage(vector)
+    }
+
+    fn preprocess_storage_in_place(&self, _vector: &mut [f32]) -> bool {
+        true // No-op
+    }
+
+    fn storage_size(&self, dim: usize) -> usize {
+        dim * std::mem::size_of::<f32>()
+    }
+
+    fn query_size(&self, dim: usize) -> usize {
+        dim * std::mem::size_of::<f32>()
+    }
+
+    fn name(&self) -> &'static str {
+        "Identity"
+    }
+}
+
+/// Cosine preprocessor that L2-normalizes vectors.
+///
+/// After normalization, cosine similarity can be computed as a simple dot product.
+#[derive(Debug, Clone)]
+pub struct CosinePreprocessor {
+    #[allow(dead_code)]
+    dim: usize,
+}
+
+impl CosinePreprocessor {
+    /// Create a new cosine preprocessor.
+    pub fn new(dim: usize) -> Self {
+        Self { dim }
+    }
+}
+
+impl Preprocessor for CosinePreprocessor {
+    fn preprocess_storage(&self, vector: &[f32]) -> Vec<u8> {
+        let mut normalized = vector.to_vec();
+        normalize_in_place(&mut normalized);
+
+        let bytes: &[u8] = unsafe {
+            std::slice::from_raw_parts(normalized.as_ptr() as *const u8, normalized.len() * 4)
+        };
+        bytes.to_vec()
+    }
+
+    fn preprocess_query(&self, vector: &[f32]) -> Vec<u8> {
+        self.preprocess_storage(vector)
+    }
+
+    fn preprocess_storage_in_place(&self, vector: &mut [f32]) -> bool {
+        normalize_in_place(vector);
+        true
+    }
+
+    fn storage_size(&self, dim: usize) -> usize {
+        dim * std::mem::size_of::<f32>()
+    }
+
+    fn query_size(&self, dim: usize) -> usize {
+        dim * std::mem::size_of::<f32>()
+    }
+
+    fn name(&self) -> &'static str {
+        "Cosine"
+    }
+}
+
+/// Quantization preprocessor for asymmetric SQ8 encoding.
+///
+/// Storage vectors are quantized to uint8, while query vectors remain as f32
+/// with precomputed metadata for efficient asymmetric distance computation.
+#[derive(Debug, Clone)]
+pub struct QuantPreprocessor {
+    #[allow(dead_code)]
+    dim: usize,
+    /// Whether to include sum of squares (needed for L2).
+    include_sum_squares: bool,
+}
+
+/// Metadata stored alongside quantized vectors.
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct QuantMeta {
+    /// Minimum value for dequantization.
+    pub min_val: f32,
+    /// Scale factor for dequantization.
+    pub delta: f32,
+    /// Sum of original values (for IP computation).
+    pub x_sum: f32,
+    /// Sum of squared values (for L2 computation).
+    pub x_sum_squares: f32,
+}
+
+impl QuantPreprocessor {
+    /// Create a quantization preprocessor.
+    ///
+    /// # Arguments
+    /// * `dim` - Vector dimension
+    /// * `include_sum_squares` - Include sum of squares for L2 distance
+    pub fn new(dim: usize, include_sum_squares: bool) -> Self {
+        Self {
+            dim,
+            include_sum_squares,
+        }
+    }
+
+    /// Create for L2 distance (includes sum of squares).
+    pub fn for_l2(dim: usize) -> Self {
+        Self::new(dim, true)
+    }
+
+    /// Create for inner product (excludes sum of squares).
+    pub fn for_ip(dim: usize) -> Self {
+        Self::new(dim, false)
+    }
+
+    /// Quantize a vector to uint8.
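+    ///
+    /// Each component is encoded as `q = round((x - min_val) / delta)` with
+    /// `delta = (max_val - min_val) / 255`, so a value can be approximated
+    /// back as `x ≈ min_val + q * delta`.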
+    fn quantize(&self, vector: &[f32]) -> (Vec<u8>, QuantMeta) {
+        let dim = vector.len();
+
+        // Find min and max
+        let mut min_val = f32::MAX;
+        let mut max_val = f32::MIN;
+        for &v in vector {
+            min_val = min_val.min(v);
+            max_val = max_val.max(v);
+        }
+
+        // Compute delta
+        let range = max_val - min_val;
+        let delta = if range > 1e-10 { range / 255.0 } else { 1.0 };
+        let inv_delta = 1.0 / delta;
+
+        // Quantize and compute sums
+        let mut quantized = vec![0u8; dim];
+        let mut x_sum = 0.0f32;
+        let mut x_sum_squares = 0.0f32;
+
+        // Process in chunks of 4 so the compiler can unroll the inner loop
+        let chunks = dim / 4;
+        for i in 0..chunks {
+            let base = i * 4;
+            for j in 0..4 {
+                let idx = base + j;
+                let v = vector[idx];
+                let q = ((v - min_val) * inv_delta).round() as u8;
+                quantized[idx] = q.min(255);
+                x_sum += v;
+                x_sum_squares += v * v;
+            }
+        }
+
+        // Handle remainder
+        for idx in (chunks * 4)..dim {
+            let v = vector[idx];
+            let q = ((v - min_val) * inv_delta).round() as u8;
+            quantized[idx] = q.min(255);
+            x_sum += v;
+            x_sum_squares += v * v;
+        }
+
+        let meta = QuantMeta {
+            min_val,
+            delta,
+            x_sum,
+            x_sum_squares,
+        };
+
+        (quantized, meta)
+    }
+}
+
+impl Preprocessor for QuantPreprocessor {
+    fn preprocess_storage(&self, vector: &[f32]) -> Vec<u8> {
+        let (quantized, meta) = self.quantize(vector);
+
+        // Build storage blob: quantized values + metadata
+        let meta_size = if self.include_sum_squares { 16 } else { 12 };
+        let mut blob = Vec::with_capacity(quantized.len() + meta_size);
+        blob.extend_from_slice(&quantized);
+        blob.extend_from_slice(&meta.min_val.to_le_bytes());
+        blob.extend_from_slice(&meta.delta.to_le_bytes());
+        blob.extend_from_slice(&meta.x_sum.to_le_bytes());
+        if self.include_sum_squares {
+            blob.extend_from_slice(&meta.x_sum_squares.to_le_bytes());
+        }
+
+        blob
+    }
+
+    fn preprocess_query(&self, vector: &[f32]) -> Vec<u8> {
+        // Query remains as f32 with precomputed sums
+        let mut y_sum = 0.0f32;
+        let mut y_sum_squares = 0.0f32;
+
+        for &v in vector {
+            y_sum += v;
+            y_sum_squares += v * v;
+        }
+
+        // Build query blob: f32 values + sums
+        let meta_size = if self.include_sum_squares { 8 } else { 4 };
+        let mut blob = Vec::with_capacity(vector.len() * 4 + meta_size);
+
+        for &v in vector {
+            blob.extend_from_slice(&v.to_le_bytes());
+        }
+        blob.extend_from_slice(&y_sum.to_le_bytes());
+        if self.include_sum_squares {
+            blob.extend_from_slice(&y_sum_squares.to_le_bytes());
+        }
+
+        blob
+    }
+
+    fn preprocess_storage_in_place(&self, _vector: &mut [f32]) -> bool {
+        // Can't do in-place quantization to a different element type
+        false
+    }
+
+    fn storage_size(&self, dim: usize) -> usize {
+        let meta_size = if self.include_sum_squares { 16 } else { 12 };
+        dim + meta_size
+    }
+
+    fn query_size(&self, dim: usize) -> usize {
+        let meta_size = if self.include_sum_squares { 8 } else { 4 };
+        dim * std::mem::size_of::<f32>() + meta_size
+    }
+
+    fn name(&self) -> &'static str {
+        "Quant"
+    }
+}
+
+/// Chain of preprocessors applied sequentially.
+pub struct PreprocessorChain {
+    preprocessors: Vec<Box<dyn Preprocessor>>,
+    dim: usize,
+}
+
+impl PreprocessorChain {
+    /// Create a new empty preprocessor chain.
+    pub fn new(dim: usize) -> Self {
+        Self {
+            preprocessors: Vec::new(),
+            dim,
+        }
+    }
+
+    /// Add a preprocessor to the chain.
+    pub fn add<P: Preprocessor + 'static>(mut self, preprocessor: P) -> Self {
+        self.preprocessors.push(Box::new(preprocessor));
+        self
+    }
+
+    /// Check if the chain is empty.
+    pub fn is_empty(&self) -> bool {
+        self.preprocessors.is_empty()
+    }
+
+    /// Get the number of preprocessors in the chain.
+/// Chain of preprocessors applied sequentially.
+pub struct PreprocessorChain {
+    preprocessors: Vec<Box<dyn Preprocessor>>,
+    dim: usize,
+}
+
+impl PreprocessorChain {
+    /// Create a new empty preprocessor chain.
+    pub fn new(dim: usize) -> Self {
+        Self {
+            preprocessors: Vec::new(),
+            dim,
+        }
+    }
+
+    /// Add a preprocessor to the chain.
+    pub fn add<P: Preprocessor + 'static>(mut self, preprocessor: P) -> Self {
+        self.preprocessors.push(Box::new(preprocessor));
+        self
+    }
+
+    /// Check if the chain is empty.
+    pub fn is_empty(&self) -> bool {
+        self.preprocessors.is_empty()
+    }
+
+    /// Get the number of preprocessors in the chain.
+    pub fn len(&self) -> usize {
+        self.preprocessors.len()
+    }
+
+    /// Preprocess for storage through the chain.
+    pub fn preprocess_storage(&self, vector: &[f32]) -> Vec<u8> {
+        if self.preprocessors.is_empty() {
+            // No preprocessing - just return raw bytes
+            let bytes: &[u8] = unsafe {
+                std::slice::from_raw_parts(vector.as_ptr() as *const u8, vector.len() * 4)
+            };
+            return bytes.to_vec();
+        }
+
+        // Apply first preprocessor
+        let result = self.preprocessors[0].preprocess_storage(vector);
+
+        // Chain remaining preprocessors (if any)
+        // Note: Chaining is complex when types change; for now, just use first
+        result
+    }
+
+    /// Preprocess for query through the chain.
+    pub fn preprocess_query(&self, vector: &[f32]) -> Vec<u8> {
+        if self.preprocessors.is_empty() {
+            let bytes: &[u8] = unsafe {
+                std::slice::from_raw_parts(vector.as_ptr() as *const u8, vector.len() * 4)
+            };
+            return bytes.to_vec();
+        }
+
+        self.preprocessors[0].preprocess_query(vector)
+    }
+
+    /// Preprocess in-place for storage.
+    pub fn preprocess_storage_in_place(&self, vector: &mut [f32]) -> bool {
+        for pp in &self.preprocessors {
+            if !pp.preprocess_storage_in_place(vector) {
+                return false;
+            }
+        }
+        true
+    }
+
+    /// Get the final storage size.
+    pub fn storage_size(&self) -> usize {
+        if let Some(last) = self.preprocessors.last() {
+            last.storage_size(self.dim)
+        } else {
+            self.dim * std::mem::size_of::<f32>()
+        }
+    }
+
+    /// Get the final query size.
+    pub fn query_size(&self) -> usize {
+        if let Some(last) = self.preprocessors.last() {
+            last.query_size(self.dim)
+        } else {
+            self.dim * std::mem::size_of::<f32>()
+        }
+    }
+}
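+// Illustrative usage sketch (an added example): the builder-style `add`
+// composes stages, and the size accessors reflect the last stage's layout.
+#[cfg(test)]
+mod chain_sizes_sketch {
+    use super::*;
+
+    #[test]
+    fn quant_chain_sizes() {
+        let chain = PreprocessorChain::new(8).add(QuantPreprocessor::for_l2(8));
+        // SQ8 storage: 8 quantized bytes + 16 metadata bytes.
+        assert_eq!(chain.storage_size(), 8 + 16);
+        // Query stays f32: 8 * 4 bytes + (y_sum, y_sum_squares).
+        assert_eq!(chain.query_size(), 8 * 4 + 8);
+    }
+}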
+/// Processed vector blobs for storage and query.
+#[derive(Debug)]
+pub struct ProcessedBlobs {
+    /// Processed storage data.
+    pub storage: Vec<u8>,
+    /// Processed query data.
+    pub query: Vec<u8>,
+}
+
+impl ProcessedBlobs {
+    /// Create from separate storage and query blobs.
+    pub fn new(storage: Vec<u8>, query: Vec<u8>) -> Self {
+        Self { storage, query }
+    }
+
+    /// Create when storage and query are identical.
+    pub fn symmetric(data: Vec<u8>) -> Self {
+        Self {
+            query: data.clone(),
+            storage: data,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_identity_preprocessor() {
+        let pp = IdentityPreprocessor::new(4);
+        let vector = vec![1.0f32, 2.0, 3.0, 4.0];
+
+        let storage = pp.preprocess_storage(&vector);
+        assert_eq!(storage.len(), 16); // 4 * 4 bytes
+
+        let query = pp.preprocess_query(&vector);
+        assert_eq!(query.len(), 16);
+
+        assert_eq!(storage, query);
+    }
+
+    #[test]
+    fn test_cosine_preprocessor() {
+        let pp = CosinePreprocessor::new(4);
+        let vector = vec![1.0f32, 0.0, 0.0, 0.0];
+
+        let storage = pp.preprocess_storage(&vector);
+        assert_eq!(storage.len(), 16);
+
+        // Parse back to f32
+        let normalized: Vec<f32> = storage
+            .chunks_exact(4)
+            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
+            .collect();
+
+        // Should be normalized (unit vector)
+        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
+        assert!((norm - 1.0).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_cosine_preprocessor_in_place() {
+        let pp = CosinePreprocessor::new(4);
+        let mut vector = vec![3.0f32, 4.0, 0.0, 0.0];
+
+        assert!(pp.preprocess_storage_in_place(&mut vector));
+
+        // Should be normalized
+        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
+        assert!((norm - 1.0).abs() < 0.001);
+
+        // Original direction preserved
+        assert!((vector[0] - 0.6).abs() < 0.001);
+        assert!((vector[1] - 0.8).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_quant_preprocessor() {
+        let pp = QuantPreprocessor::for_l2(4);
+        let vector = vec![0.0f32, 0.5, 0.75, 1.0];
+
+        let storage = pp.preprocess_storage(&vector);
+        // 4 bytes quantized + 16 bytes metadata
+        assert_eq!(storage.len(), 20);
+
+        // Check quantized values
+        assert_eq!(storage[0], 0); // 0.0 -> 0
+        assert_eq!(storage[3], 255); // 1.0 -> 255
+    }
+
+    #[test]
+    fn test_quant_preprocessor_query() {
+        let pp = QuantPreprocessor::for_l2(4);
+        let vector = vec![1.0f32, 2.0, 3.0, 4.0];
+
+        let query = pp.preprocess_query(&vector);
+        // 16 bytes f32 values + 8 bytes (y_sum + y_sum_squares)
+        assert_eq!(query.len(), 24);
+    }
+
+    #[test]
+    fn test_preprocessor_chain() {
+        let chain = PreprocessorChain::new(4).add(CosinePreprocessor::new(4));
+
+        assert_eq!(chain.len(), 1);
+        assert!(!chain.is_empty());
+
+        let vector = vec![3.0f32, 4.0, 0.0, 0.0];
+        let storage = chain.preprocess_storage(&vector);
+
+        // Parse and verify normalization
+        let normalized: Vec<f32> = storage
+            .chunks_exact(4)
+            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
+            .collect();
+
+        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
+        assert!((norm - 1.0).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_empty_chain() {
+        let chain = PreprocessorChain::new(4);
+        assert!(chain.is_empty());
+
+        let vector = vec![1.0f32, 2.0, 3.0, 4.0];
+        let storage = chain.preprocess_storage(&vector);
+
+        // Should be unchanged
+        let parsed: Vec<f32> = storage
+            .chunks_exact(4)
+            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
+            .collect();
+        assert_eq!(parsed, vector);
+    }
+
+    #[test]
+    fn test_quant_meta_layout() {
+        // Verify QuantMeta is 16 bytes as expected
+        assert_eq!(std::mem::size_of::<QuantMeta>(), 16);
+    }
+}
diff --git a/rust/vecsim/src/quantization/leanvec.rs b/rust/vecsim/src/quantization/leanvec.rs
new file mode 100644
index 000000000..4919d651c
--- /dev/null
+++ b/rust/vecsim/src/quantization/leanvec.rs
@@ -0,0 +1,983 @@
+//! LeanVec quantization with dimension reduction.
+//!
+//! LeanVec combines dimension reduction with two-level quantization for efficient
+//! vector compression while maintaining search accuracy:
+//!
+//! - **4x8 LeanVec**: 4-bit primary (on D/2 dims) + 8-bit residual (on D dims)
+//! - **8x8 LeanVec**: 8-bit primary (on D/2 dims) + 8-bit residual (on D dims)
+//!
+//! ## Storage Layout
+//!
+//! For 4x8 LeanVec with D=128 and leanvec_dim=64:
+//! ```text
+//! Metadata (24 bytes):
+//!   - min_primary, delta_primary (8 bytes)
+//!   - min_residual, delta_residual (8 bytes)
+//!   - leanvec_dim (8 bytes)
+//!
+//! Primary data: leanvec_dim * primary_bits / 8 bytes
+//!   - For 4x8: 64 * 4 / 8 = 32 bytes
+//!   - For 8x8: 64 * 8 / 8 = 64 bytes
+//!
+//! Residual data: dim * 8 / 8 bytes = 128 bytes
+//! ```
+//!
+//! ## Memory Efficiency
+//!
+//! For D=128:
+//! - f32 uncompressed: 512 bytes
+//! - 4x8 LeanVec: 24 + 32 + 128 = 184 bytes (~2.8x compression)
+//! - 8x8 LeanVec: 24 + 64 + 128 = 216 bytes (~2.4x compression)
+
+use std::fmt;
+
+/// LeanVec configuration specifying primary and residual quantization bits.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum LeanVecBits {
+    /// 4-bit primary (reduced dims) + 8-bit residual (full dims).
+    LeanVec4x8,
+    /// 8-bit primary (reduced dims) + 8-bit residual (full dims).
+    LeanVec8x8,
+}
+
+impl LeanVecBits {
+    /// Get primary quantization bits.
+    pub fn primary_bits(&self) -> usize {
+        match self {
+            LeanVecBits::LeanVec4x8 => 4,
+            LeanVecBits::LeanVec8x8 => 8,
+        }
+    }
+
+    /// Get residual quantization bits.
+    pub fn residual_bits(&self) -> usize {
+        8 // Always 8-bit for LeanVec
+    }
+
+    /// Get number of primary quantization levels.
+    pub fn primary_levels(&self) -> usize {
+        1 << self.primary_bits()
+    }
+
+    /// Calculate primary data size in bytes.
+    pub fn primary_data_size(&self, leanvec_dim: usize) -> usize {
+        (leanvec_dim * self.primary_bits() + 7) / 8
+    }
+
+    /// Calculate residual data size in bytes.
+    pub fn residual_data_size(&self, full_dim: usize) -> usize {
+        full_dim // 8 bits per dimension = 1 byte
+    }
+
+    /// Calculate total encoded size (excluding metadata).
+    pub fn encoded_size(&self, full_dim: usize, leanvec_dim: usize) -> usize {
+        self.primary_data_size(leanvec_dim) + self.residual_data_size(full_dim)
+    }
+}
+
+/// Metadata for a LeanVec-encoded vector.
+#[derive(Clone, Copy)]
+#[repr(C)]
+pub struct LeanVecMeta {
+    /// Minimum value for primary quantization.
+    pub min_primary: f32,
+    /// Scale factor for primary quantization.
+    pub delta_primary: f32,
+    /// Minimum value for residual quantization.
+    pub min_residual: f32,
+    /// Scale factor for residual quantization.
+    pub delta_residual: f32,
+    /// Reduced dimension used for primary quantization.
+    pub leanvec_dim: u32,
+    /// Padding for alignment.
+    _pad: u32,
+}
+
+impl fmt::Debug for LeanVecMeta {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("LeanVecMeta")
+            .field("min_primary", &self.min_primary)
+            .field("delta_primary", &self.delta_primary)
+            .field("min_residual", &self.min_residual)
+            .field("delta_residual", &self.delta_residual)
+            .field("leanvec_dim", &self.leanvec_dim)
+            .finish()
+    }
+}
+
+impl Default for LeanVecMeta {
+    fn default() -> Self {
+        Self {
+            min_primary: 0.0,
+            delta_primary: 1.0,
+            min_residual: 0.0,
+            delta_residual: 1.0,
+            leanvec_dim: 0,
+            _pad: 0,
+        }
+    }
+}
+
+impl LeanVecMeta {
+    /// Size of metadata in bytes.
+    pub const SIZE: usize = 24;
+}
+
+/// LeanVec codec for encoding and decoding vectors with dimension reduction.
+pub struct LeanVecCodec {
+    /// Full dimensionality of vectors.
+    dim: usize,
+    /// Reduced dimensionality for primary quantization.
+    leanvec_dim: usize,
+    /// Quantization configuration.
+    bits: LeanVecBits,
+    /// Indices of dimensions selected for primary quantization.
+    /// These are the top-variance dimensions (or simply first D/2 for simplicity).
+    selected_dims: Vec<usize>,
+}
+
+impl LeanVecCodec {
+    /// Create a new LeanVec codec with default reduced dimension (dim/2).
+    pub fn new(dim: usize, bits: LeanVecBits) -> Self {
+        Self::with_leanvec_dim(dim, bits, 0)
+    }
+
+    /// Create a new LeanVec codec with custom reduced dimension.
+    ///
+    /// # Arguments
+    /// * `dim` - Full dimensionality of vectors
+    /// * `bits` - Quantization configuration
+    /// * `leanvec_dim` - Reduced dimension (0 = default dim/2)
+    pub fn with_leanvec_dim(dim: usize, bits: LeanVecBits, leanvec_dim: usize) -> Self {
+        let leanvec_dim = if leanvec_dim == 0 {
+            dim / 2
+        } else {
+            leanvec_dim.min(dim)
+        };
+
+        // Default: select first leanvec_dim dimensions
+        // In a full implementation, this would be learned from data
+        let selected_dims: Vec<usize> = (0..leanvec_dim).collect();
+
+        Self {
+            dim,
+            leanvec_dim,
+            bits,
+            selected_dims,
+        }
+    }
+
+    /// Get the full dimension.
+    pub fn dim(&self) -> usize {
+        self.dim
+    }
+
+    /// Get the reduced dimension.
+    pub fn leanvec_dim(&self) -> usize {
+        self.leanvec_dim
+    }
+
+    /// Get the bits configuration.
+    pub fn bits(&self) -> LeanVecBits {
+        self.bits
+    }
+
+    /// Get total encoded size including metadata.
+    pub fn total_size(&self) -> usize {
+        LeanVecMeta::SIZE + self.bits.encoded_size(self.dim, self.leanvec_dim)
+    }
+
+    /// Get the size of just the encoded data (without metadata).
+    pub fn encoded_size(&self) -> usize {
+        self.bits.encoded_size(self.dim, self.leanvec_dim)
+    }
+
+    /// Set custom dimension selection order (e.g., from variance analysis).
+    pub fn set_selected_dims(&mut self, dims: Vec<usize>) {
+        assert!(dims.len() >= self.leanvec_dim);
+        self.selected_dims = dims;
+    }
+
+    /// Extract reduced-dimension vector for primary quantization.
+    fn extract_reduced(&self, vector: &[f32]) -> Vec<f32> {
+        self.selected_dims
+            .iter()
+            .take(self.leanvec_dim)
+            .map(|&i| vector[i])
+            .collect()
+    }
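+    // Packing note for the 4-bit encoder below: two codes share one byte, with
+    // even indices in the low nibble and odd indices in the high nibble, so
+    // code i lives in byte i / 2.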
+    /// Encode a vector using LeanVec4x8 (4-bit primary + 8-bit residual).
+    pub fn encode_4x8(&self, vector: &[f32]) -> (LeanVecMeta, Vec<u8>) {
+        debug_assert_eq!(vector.len(), self.dim);
+
+        // Extract reduced dimensions for primary quantization
+        let reduced = self.extract_reduced(vector);
+
+        // Primary quantization (4-bit) on reduced dimensions
+        let (min_primary, max_primary) = reduced
+            .iter()
+            .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
+
+        let range_primary = max_primary - min_primary;
+        let delta_primary = if range_primary > 1e-10 {
+            range_primary / 15.0 // 4-bit = 16 levels
+        } else {
+            1.0
+        };
+        let inv_delta_primary = 1.0 / delta_primary;
+
+        // Encode primary (4-bit packed)
+        let primary_bytes = (self.leanvec_dim + 1) / 2;
+        let mut primary_encoded = vec![0u8; primary_bytes];
+
+        // Track the primary reconstruction; the simplified residual below does
+        // not consume it, but a full implementation would.
+        let mut _primary_reconstructed = vec![0.0f32; self.leanvec_dim];
+
+        for i in 0..self.leanvec_dim {
+            let normalized = (reduced[i] - min_primary) * inv_delta_primary;
+            let q = (normalized.round() as u8).min(15);
+
+            let byte_idx = i / 2;
+            if i % 2 == 0 {
+                primary_encoded[byte_idx] |= q;
+            } else {
+                primary_encoded[byte_idx] |= q << 4;
+            }
+
+            _primary_reconstructed[i] = min_primary + (q as f32) * delta_primary;
+        }
+
+        // Residual level: for simplicity this re-quantizes the full vector at
+        // 8 bits rather than the true error after primary reconstruction; a
+        // full implementation would quantize vector[i] minus the primary
+        // reconstruction.
+        let (min_residual, max_residual) = vector
+            .iter()
+            .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
+
+        let range_residual = max_residual - min_residual;
+        let delta_residual = if range_residual > 1e-10 {
+            range_residual / 255.0 // 8-bit = 256 levels
+        } else {
+            1.0
+        };
+        let inv_delta_residual = 1.0 / delta_residual;
+
+        // Encode residuals (8-bit, full dimension)
+        let mut residual_encoded = vec![0u8; self.dim];
+        for i in 0..self.dim {
+            let normalized = (vector[i] - min_residual) * inv_delta_residual;
+            residual_encoded[i] = (normalized.round() as u8).min(255);
+        }
+
+        // Combine primary and residual data
+        let mut encoded = primary_encoded;
+        encoded.extend(residual_encoded);
+
+        let meta = LeanVecMeta {
+            min_primary,
+            delta_primary,
+            min_residual,
+            delta_residual,
+            leanvec_dim: self.leanvec_dim as u32,
+            _pad: 0,
+        };
+
+        (meta, encoded)
+    }
+
+    /// Decode a LeanVec4x8 encoded vector.
+    pub fn decode_4x8(&self, meta: &LeanVecMeta, encoded: &[u8]) -> Vec<f32> {
+        let leanvec_dim = meta.leanvec_dim as usize;
+        let primary_bytes = (leanvec_dim + 1) / 2;
+
+        // For decoding, we primarily use the residual (full precision) part
+        // The primary is for fast approximate search
+        let mut decoded = Vec::with_capacity(self.dim);
+
+        for i in 0..self.dim {
+            let residual_q = encoded[primary_bytes + i];
+            let value = meta.min_residual + (residual_q as f32) * meta.delta_residual;
+            decoded.push(value);
+        }
+
+        decoded
+    }
+
+    /// Encode a vector using LeanVec8x8 (8-bit primary + 8-bit residual).
+    pub fn encode_8x8(&self, vector: &[f32]) -> (LeanVecMeta, Vec<u8>) {
+        debug_assert_eq!(vector.len(), self.dim);
+
+        // Extract reduced dimensions for primary quantization
+        let reduced = self.extract_reduced(vector);
+
+        // Primary quantization (8-bit) on reduced dimensions
+        let (min_primary, max_primary) = reduced
+            .iter()
+            .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
+
+        let range_primary = max_primary - min_primary;
+        let delta_primary = if range_primary > 1e-10 {
+            range_primary / 255.0 // 8-bit = 256 levels
+        } else {
+            1.0
+        };
+        let inv_delta_primary = 1.0 / delta_primary;
+
+        // Encode primary (8-bit, one byte per dim)
+        let mut primary_encoded = vec![0u8; self.leanvec_dim];
+        for i in 0..self.leanvec_dim {
+            let normalized = (reduced[i] - min_primary) * inv_delta_primary;
+            primary_encoded[i] = (normalized.round() as u8).min(255);
+        }
+
+        // Residual quantization (8-bit, full dimension)
+        let (min_residual, max_residual) = vector
+            .iter()
+            .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
+
+        let range_residual = max_residual - min_residual;
+        let delta_residual = if range_residual > 1e-10 {
+            range_residual / 255.0
+        } else {
+            1.0
+        };
+        let inv_delta_residual = 1.0 / delta_residual;
+
+        let mut residual_encoded = vec![0u8; self.dim];
+        for i in 0..self.dim {
+            let normalized = (vector[i] - min_residual) * inv_delta_residual;
+            residual_encoded[i] = (normalized.round() as u8).min(255);
+        }
+
+        // Combine
+        let mut encoded = primary_encoded;
+        encoded.extend(residual_encoded);
+
+        let meta = LeanVecMeta {
+            min_primary,
+            delta_primary,
+            min_residual,
+            delta_residual,
+            leanvec_dim: self.leanvec_dim as u32,
+            _pad: 0,
+        };
+
+        (meta, encoded)
+    }
+
+    /// Decode a LeanVec8x8 encoded vector.
+    pub fn decode_8x8(&self, meta: &LeanVecMeta, encoded: &[u8]) -> Vec<f32> {
+        let leanvec_dim = meta.leanvec_dim as usize;
+
+        // Use residual part for full reconstruction
+        let mut decoded = Vec::with_capacity(self.dim);
+        for i in 0..self.dim {
+            let residual_q = encoded[leanvec_dim + i];
+            let value = meta.min_residual + (residual_q as f32) * meta.delta_residual;
+            decoded.push(value);
+        }
+
+        decoded
+    }
+
+    /// Encode using configured bits.
+    pub fn encode(&self, vector: &[f32]) -> (LeanVecMeta, Vec<u8>) {
+        match self.bits {
+            LeanVecBits::LeanVec4x8 => self.encode_4x8(vector),
+            LeanVecBits::LeanVec8x8 => self.encode_8x8(vector),
+        }
+    }
+
+    /// Decode using configured bits.
+    pub fn decode(&self, meta: &LeanVecMeta, encoded: &[u8]) -> Vec<f32> {
+        match self.bits {
+            LeanVecBits::LeanVec4x8 => self.decode_4x8(meta, encoded),
+            LeanVecBits::LeanVec8x8 => self.decode_8x8(meta, encoded),
+        }
+    }
+}
+
+/// Compute asymmetric L2 squared distance using primary (reduced) dimensions only.
+///
+/// This is a fast approximate distance for initial filtering.
+#[inline]
+pub fn leanvec_primary_l2_squared_4bit(
+    query: &[f32],
+    encoded: &[u8],
+    meta: &LeanVecMeta,
+    selected_dims: &[usize],
+) -> f32 {
+    let leanvec_dim = meta.leanvec_dim as usize;
+    let mut sum = 0.0f32;
+
+    for i in 0..leanvec_dim {
+        let byte_idx = i / 2;
+        let q = if i % 2 == 0 {
+            encoded[byte_idx] & 0x0F
+        } else {
+            (encoded[byte_idx] >> 4) & 0x0F
+        };
+
+        let stored = meta.min_primary + (q as f32) * meta.delta_primary;
+        let diff = query[selected_dims[i]] - stored;
+        sum += diff * diff;
+    }
+
+    sum
+}
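+// Illustrative two-stage search sketch (an added example; the function name
+// and shortlist size are hypothetical, but it only uses the distance helpers
+// defined in this file). The cheap reduced-dimension primary distance
+// shortlists candidates, then the full-dimension residual distance reranks.
+#[allow(dead_code)]
+fn two_stage_nearest_sketch(
+    query: &[f32],
+    stored: &[(LeanVecMeta, Vec<u8>)],
+    selected_dims: &[usize],
+    shortlist: usize,
+) -> Option<usize> {
+    // Stage 1: shortlist by primary (4-bit, reduced dims) distance.
+    let mut candidates: Vec<(usize, f32)> = stored
+        .iter()
+        .enumerate()
+        .map(|(i, (meta, enc))| {
+            (i, leanvec_primary_l2_squared_4bit(query, enc, meta, selected_dims))
+        })
+        .collect();
+    candidates.sort_by(|a, b| a.1.total_cmp(&b.1));
+    candidates.truncate(shortlist);
+
+    // Stage 2: rerank the shortlist with the residual (full dim) distance.
+    let primary_bytes = (selected_dims.len() + 1) / 2; // 4-bit packed
+    candidates
+        .into_iter()
+        .map(|(i, _)| {
+            let (meta, enc) = &stored[i];
+            (i, leanvec_residual_l2_squared(query, enc, meta, primary_bytes))
+        })
+        .min_by(|a, b| a.1.total_cmp(&b.1))
+        .map(|(i, _)| i)
+}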
+/// Compute asymmetric L2 squared distance using primary (reduced) dimensions only.
+///
+/// This is a fast approximate distance for initial filtering. (8-bit version)
+#[inline]
+pub fn leanvec_primary_l2_squared_8bit(
+    query: &[f32],
+    encoded: &[u8],
+    meta: &LeanVecMeta,
+    selected_dims: &[usize],
+) -> f32 {
+    let leanvec_dim = meta.leanvec_dim as usize;
+    let mut sum = 0.0f32;
+
+    for i in 0..leanvec_dim {
+        let stored = meta.min_primary + (encoded[i] as f32) * meta.delta_primary;
+        let diff = query[selected_dims[i]] - stored;
+        sum += diff * diff;
+    }
+
+    sum
+}
+
+/// Compute full asymmetric L2 squared distance using residual (full) dimensions.
+#[inline]
+pub fn leanvec_residual_l2_squared(
+    query: &[f32],
+    encoded: &[u8],
+    meta: &LeanVecMeta,
+    primary_bytes: usize,
+) -> f32 {
+    let dim = query.len();
+    let mut sum = 0.0f32;
+
+    for i in 0..dim {
+        let residual_q = encoded[primary_bytes + i];
+        let stored = meta.min_residual + (residual_q as f32) * meta.delta_residual;
+        let diff = query[i] - stored;
+        sum += diff * diff;
+    }
+
+    sum
+}
+
+/// Compute full asymmetric inner product using residual (full) dimensions.
+#[inline]
+pub fn leanvec_residual_inner_product(
+    query: &[f32],
+    encoded: &[u8],
+    meta: &LeanVecMeta,
+    primary_bytes: usize,
+) -> f32 {
+    let dim = query.len();
+    let mut sum = 0.0f32;
+
+    for i in 0..dim {
+        let residual_q = encoded[primary_bytes + i];
+        let stored = meta.min_residual + (residual_q as f32) * meta.delta_residual;
+        sum += query[i] * stored;
+    }
+
+    sum
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_leanvec_4x8_encode_decode() {
+        let codec = LeanVecCodec::new(128, LeanVecBits::LeanVec4x8);
+        assert_eq!(codec.leanvec_dim(), 64); // Default D/2
+
+        // Create test vector
+        let vector: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
+
+        let (meta, encoded) = codec.encode(&vector);
+        let decoded = codec.decode(&meta, &encoded);
+
+        // Check approximate equality
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.01,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_leanvec_8x8_encode_decode() {
+        let codec = LeanVecCodec::new(128, LeanVecBits::LeanVec8x8);
+        assert_eq!(codec.leanvec_dim(), 64);
+
+        let vector: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
+
+        let (meta, encoded) = codec.encode(&vector);
+        let decoded = codec.decode(&meta, &encoded);
+
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.01,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_leanvec_custom_dim() {
+        let codec = LeanVecCodec::with_leanvec_dim(128, LeanVecBits::LeanVec4x8, 32);
+        assert_eq!(codec.leanvec_dim(), 32);
+        assert_eq!(codec.dim(), 128);
+    }
+
+    #[test]
+    fn test_leanvec_memory_efficiency() {
+        let dim = 128;
+
+        // f32 uncompressed
+        let f32_bytes = dim * 4; // 512 bytes
+
+        // 4x8 LeanVec: primary(64*4/8=32) + residual(128) + meta(24)
+        let leanvec_4x8 = LeanVecMeta::SIZE + LeanVecBits::LeanVec4x8.encoded_size(dim, dim / 2);
+
+        // 8x8 LeanVec: primary(64*8/8=64) + residual(128) + meta(24)
+        let leanvec_8x8 = LeanVecMeta::SIZE + LeanVecBits::LeanVec8x8.encoded_size(dim, dim / 2);
+
+        println!("Memory for {} dimensions:", dim);
+        println!("  f32: {} bytes", f32_bytes);
+        println!(
+            "  LeanVec4x8: {} bytes ({:.2}x compression)",
+            leanvec_4x8,
+            f32_bytes as f32 / leanvec_4x8 as f32
+        );
+        println!(
+            "  LeanVec8x8: {} bytes ({:.2}x compression)",
+            leanvec_8x8,
+            f32_bytes as f32 / leanvec_8x8 as f32
+        );
+
+        // Verify compression ratios
+        assert!(leanvec_4x8 < f32_bytes / 2); // >2x compression
+        assert!(leanvec_8x8 < f32_bytes / 2); // >2x compression
+    }
+
+    #[test]
+    fn test_leanvec_primary_distance() {
+        let codec = LeanVecCodec::new(8, LeanVecBits::LeanVec4x8);
+        let vector: Vec<f32> = vec![1.0, 0.5, 0.25, 0.125, 0.0, 0.0, 0.0, 0.0];
+        let query = vector.clone();
+
+        let (meta, encoded) = codec.encode(&vector);
+
+        // Primary distance should be small for self-query
+        let selected_dims: Vec<usize> = (0..4).collect();
+        let primary_dist =
+            leanvec_primary_l2_squared_4bit(&query, &encoded, &meta, &selected_dims);
+
+        assert!(
+            primary_dist < 0.1,
+            "Primary self-distance should be small, got {}",
+            primary_dist
+        );
+    }
+
+    #[test]
+    fn test_leanvec_residual_distance() {
+        let codec = LeanVecCodec::new(8, LeanVecBits::LeanVec4x8);
+        let vector: Vec<f32> = vec![1.0, 0.5, 0.25, 0.125, 0.0, 0.0, 0.0, 0.0];
+        let query = vector.clone();
+
+        let (meta, encoded) = codec.encode(&vector);
+
+        let primary_bytes = LeanVecBits::LeanVec4x8.primary_data_size(4);
+        let residual_dist = leanvec_residual_l2_squared(&query, &encoded, &meta, primary_bytes);
+
+        assert!(
+            residual_dist < 0.01,
+            "Residual self-distance should be very small, got {}",
+            residual_dist
+        );
+    }
+
+    #[test]
+    fn test_leanvec_bits_config() {
+        assert_eq!(LeanVecBits::LeanVec4x8.primary_bits(), 4);
+        assert_eq!(LeanVecBits::LeanVec4x8.residual_bits(), 8);
+        assert_eq!(LeanVecBits::LeanVec4x8.primary_levels(), 16);
+
+        assert_eq!(LeanVecBits::LeanVec8x8.primary_bits(), 8);
+        assert_eq!(LeanVecBits::LeanVec8x8.residual_bits(), 8);
+        assert_eq!(LeanVecBits::LeanVec8x8.primary_levels(), 256);
+    }
+
+    #[test]
+    fn test_leanvec_bits_data_sizes() {
+        // 4x8: 4 bits primary, 8 bits residual
+        assert_eq!(LeanVecBits::LeanVec4x8.primary_data_size(64), 32); // 64 * 4 / 8
+        assert_eq!(LeanVecBits::LeanVec4x8.residual_data_size(128), 128); // 1 byte per dim
+
+        // 8x8: 8 bits both
+        assert_eq!(LeanVecBits::LeanVec8x8.primary_data_size(64), 64); // 64 * 8 / 8
+        assert_eq!(LeanVecBits::LeanVec8x8.residual_data_size(128), 128);
+
+        // Total encoded size
+        assert_eq!(LeanVecBits::LeanVec4x8.encoded_size(128, 64), 32 + 128);
+        assert_eq!(LeanVecBits::LeanVec8x8.encoded_size(128, 64), 64 + 128);
+    }
+
+    #[test]
+    fn test_leanvec_meta_default() {
+        let meta = LeanVecMeta::default();
+        assert_eq!(meta.min_primary, 0.0);
+        assert_eq!(meta.delta_primary, 1.0);
+        assert_eq!(meta.min_residual, 0.0);
+        assert_eq!(meta.delta_residual, 1.0);
+        assert_eq!(meta.leanvec_dim, 0);
+    }
+
+    #[test]
+    fn test_leanvec_meta_size() {
+        assert_eq!(LeanVecMeta::SIZE, 24); // 4 f32 + 1 u32 + 1 padding
+    }
+
+    #[test]
+    fn test_leanvec_codec_accessors() {
+        let codec = LeanVecCodec::new(128, LeanVecBits::LeanVec4x8);
+        assert_eq!(codec.dim(), 128);
+        assert_eq!(codec.leanvec_dim(), 64); // Default D/2
+        assert_eq!(codec.bits(), LeanVecBits::LeanVec4x8);
+    }
+
+    #[test]
+    fn test_leanvec_codec_sizes() {
+        let codec = LeanVecCodec::new(128, LeanVecBits::LeanVec4x8);
+        // 4x8 with D=128, leanvec_dim=64:
+        // primary = 64*4/8 = 32 bytes, residual = 128 bytes
+        assert_eq!(codec.encoded_size(), 32 + 128);
+        assert_eq!(codec.total_size(), LeanVecMeta::SIZE + 32 + 128);
+    }
+
+    #[test]
+    fn test_leanvec_custom_leanvec_dim() {
+        // Test various custom reduced dimensions
+        let codec = LeanVecCodec::with_leanvec_dim(128, LeanVecBits::LeanVec4x8, 32);
+        assert_eq!(codec.leanvec_dim(), 32);
+
+        // 0 should default to dim/2
+        let codec_default = LeanVecCodec::with_leanvec_dim(128, LeanVecBits::LeanVec4x8, 0);
+        assert_eq!(codec_default.leanvec_dim(), 64);
+
+        // Value > dim should be clamped to dim
+        let codec_clamped = LeanVecCodec::with_leanvec_dim(128, LeanVecBits::LeanVec4x8, 256);
+        assert_eq!(codec_clamped.leanvec_dim(), 128);
+    }
+
+    #[test]
+    fn test_leanvec_4x8_negative_values() {
+        let codec = LeanVecCodec::new(8, LeanVecBits::LeanVec4x8);
+        let vector: Vec<f32> = vec![-0.5, -0.25, 0.0, 0.25, 0.5, 0.75, -0.1, -0.9];
+
+        let (meta, encoded) = codec.encode(&vector);
+        let decoded = codec.decode(&meta, &encoded);
+
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.02,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_leanvec_8x8_negative_values() {
+        let codec = LeanVecCodec::new(8, LeanVecBits::LeanVec8x8);
+        let vector: Vec<f32> = vec![-0.5, -0.25, 0.0, 0.25, 0.5, 0.75, -0.1, -0.9];
+
+        let (meta, encoded) = codec.encode(&vector);
+        let decoded = codec.decode(&meta, &encoded);
+
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.02,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_leanvec_residual_inner_product() {
+        let codec = LeanVecCodec::new(8, LeanVecBits::LeanVec4x8);
+        let stored: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+        let query: Vec<f32> = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
+
+        let (meta, encoded) = codec.encode(&stored);
+
+        let primary_bytes = LeanVecBits::LeanVec4x8.primary_data_size(4);
+        let ip = leanvec_residual_inner_product(&query, &encoded, &meta, primary_bytes);
+
+        // IP = 1+2+3+4+5+6+7+8 = 36
+        assert!(
+            (ip - 36.0).abs() < 2.0,
+            "IP should be ~36, got {}",
+            ip
+        );
+    }
+
+    #[test]
+    fn test_leanvec_8x8_primary_distance() {
+        let codec = LeanVecCodec::new(8, LeanVecBits::LeanVec8x8);
+        let vector: Vec<f32> = vec![1.0, 0.5, 0.25, 0.125, 0.0, 0.0, 0.0, 0.0];
+        let query = vector.clone();
+
+        let (meta, encoded) = codec.encode(&vector);
+
+        let selected_dims: Vec<usize> = (0..4).collect();
+        let primary_dist =
+            leanvec_primary_l2_squared_8bit(&query, &encoded, &meta, &selected_dims);
+
+        assert!(
+            primary_dist < 0.1,
+            "Primary self-distance should be small, got {}",
+            primary_dist
+        );
+    }
+
+    #[test]
+    fn test_leanvec_distance_ordering() {
+        let codec = LeanVecCodec::new(8, LeanVecBits::LeanVec4x8);
+
+        let v1 = vec![1.0, 0.5, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0];
+        let v2 = vec![0.5, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
+        let v3 = vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.5, 0.25, 0.0];
+        let query = vec![1.0, 0.5, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0];
+
+        let (meta1, enc1) = codec.encode(&v1);
+        let (meta2, enc2) = codec.encode(&v2);
+        let (meta3, enc3) = codec.encode(&v3);
+
+        let primary_bytes = LeanVecBits::LeanVec4x8.primary_data_size(4);
+
+        let d1 = leanvec_residual_l2_squared(&query, &enc1, &meta1, primary_bytes);
+        let d2 = leanvec_residual_l2_squared(&query, &enc2, &meta2, primary_bytes);
+        let d3 = leanvec_residual_l2_squared(&query, &enc3, &meta3, primary_bytes);
+
+        // d1 should be smallest (self), d3 should be largest
+        assert!(d1 < d2, "d1={} should be < d2={}", d1, d2);
+        assert!(d2 < d3, "d2={} should be < d3={}", d2, d3);
+    }
+
+    #[test]
+    fn test_leanvec_large_vectors() {
+        let codec = LeanVecCodec::new(512, LeanVecBits::LeanVec4x8);
+        let vector: Vec<f32> = (0..512).map(|i| ((i as f32) / 512.0) - 0.5).collect();
+
+        let (meta, encoded) = codec.encode(&vector);
+        let decoded = codec.decode(&meta, &encoded);
+
+        let max_error = vector
+            .iter()
+            .zip(decoded.iter())
+            .map(|(orig, dec)| (orig - dec).abs())
+            .fold(0.0f32, f32::max);
+
+        assert!(
+            max_error < 0.02,
+            "Large vector should decode well, max_error={}",
+            max_error
+        );
+    }
+
+    #[test]
+    fn test_leanvec_zero_vector() {
+        let codec = LeanVecCodec::new(8, LeanVecBits::LeanVec4x8);
+        let vector: Vec<f32> = vec![0.0; 8];
+
+        let (meta, encoded) = codec.encode(&vector);
+        let decoded = codec.decode(&meta, &encoded);
+
+        for dec in decoded.iter() {
+            assert!(dec.abs() < 0.001, "expected ~0, got {}", dec);
+        }
+    }
+
+    #[test]
+    fn test_leanvec_uniform_vector() {
+        let codec = LeanVecCodec::new(8, LeanVecBits::LeanVec4x8);
+        let vector: Vec<f32> = vec![0.5; 8];
+
+        let (meta, encoded) = codec.encode(&vector);
+        let decoded = codec.decode(&meta, &encoded);
+
+        for dec in decoded.iter() {
+            assert!(
+                (dec - 0.5).abs() < 0.01,
+                "expected ~0.5, got {}",
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_leanvec_4x8_vs_8x8_precision() {
+        let codec_4x8 = LeanVecCodec::new(64, LeanVecBits::LeanVec4x8);
+        let codec_8x8 = LeanVecCodec::new(64, LeanVecBits::LeanVec8x8);
+        let vector: Vec<f32> = (0..64).map(|i| (i as f32) / 64.0).collect();
+
+        let (_, encoded_4x8) = codec_4x8.encode(&vector);
+        let (_, encoded_8x8) = codec_8x8.encode(&vector);
+
+        // 8x8 should use more bytes for primary (8-bit vs 4-bit)
+        // 4x8: primary = 32 * 4 / 8 = 16 bytes
+        // 8x8: primary = 32 * 8 / 8 = 32 bytes
+        assert!(
+            encoded_8x8.len() > encoded_4x8.len(),
+            "8x8 should use more bytes: 4x8={}, 8x8={}",
+            encoded_4x8.len(),
+            encoded_8x8.len()
+        );
+    }
+
+    #[test]
+    fn test_leanvec_custom_dim_selection() {
+        let mut codec = LeanVecCodec::new(8, LeanVecBits::LeanVec4x8);
+
+        // Set custom dimension selection (use last 4 dims instead of first 4)
+        let custom_dims: Vec<usize> = vec![4, 5, 6, 7, 0, 1, 2, 3];
+        codec.set_selected_dims(custom_dims);
+
+        let vector: Vec<f32> = vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.5, 0.25, 0.125];
+
+        let (meta, encoded) = codec.encode(&vector);
+
+        // The primary quantization should now work on dims 4-7
+        // Primary distance should still work with the selected dims
+        let selected_dims: Vec<usize> = vec![4, 5, 6, 7];
+        let dist = leanvec_primary_l2_squared_4bit(&vector, &encoded, &meta, &selected_dims);
+
+        // Self-distance should be small
+        assert!(
+            dist < 0.5,
+            "Self distance should be small, got {}",
+            dist
+        );
+    }
+
+    #[test]
+    fn test_leanvec_odd_dimension() {
+        // LeanVec with odd dimension
+        let codec = LeanVecCodec::new(7, LeanVecBits::LeanVec4x8);
+        assert_eq!(codec.leanvec_dim(), 3); // 7/2 = 3
+
+        let vector: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
+
+        let (meta, encoded) = codec.encode(&vector);
+        let decoded = codec.decode(&meta, &encoded);
+
+        assert_eq!(decoded.len(), 7);
+
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.02,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_leanvec_primary_4bit_packing() {
+        let codec = LeanVecCodec::new(8, LeanVecBits::LeanVec4x8);
+        let vector: Vec<f32> = vec![0.0, 0.25, 0.5, 0.75, 1.0, 0.9, 0.8, 0.7];
+
+        let (_meta, encoded) = codec.encode(&vector);
+
+        // Primary data should be (leanvec_dim + 1) / 2 = 2 bytes (4 dims * 4 bits / 8)
+        let primary_bytes = LeanVecBits::LeanVec4x8.primary_data_size(4);
+        assert_eq!(primary_bytes, 2);
+
+        // Residual data should be 8 bytes
+        let residual_bytes = LeanVecBits::LeanVec4x8.residual_data_size(8);
+        assert_eq!(residual_bytes, 8);
+
+        // Total encoded
+        assert_eq!(encoded.len(), primary_bytes + residual_bytes);
+    }
+
+    #[test]
+    fn test_leanvec_meta_debug() {
+        let meta = LeanVecMeta {
+            min_primary: 1.0,
+            delta_primary: 0.5,
+            min_residual: -1.0,
+            delta_residual: 0.25,
+            leanvec_dim: 32,
+            _pad: 0,
+        };
+
+        let debug_str = format!("{:?}", meta);
+        assert!(debug_str.contains("LeanVecMeta"));
+        assert!(debug_str.contains("min_primary"));
+        assert!(debug_str.contains("leanvec_dim"));
+    }
+
+    #[test]
+    fn test_leanvec_generic_encode_decode_dispatch() {
+        // Test that generic encode/decode correctly dispatches
+        for bits in [LeanVecBits::LeanVec4x8, LeanVecBits::LeanVec8x8] {
+            let codec = LeanVecCodec::new(16, bits);
+            let vector: Vec<f32> = (0..16).map(|i| (i as f32) / 16.0).collect();
+
+            let (meta, encoded) = codec.encode(&vector);
+            let decoded = codec.decode(&meta, &encoded);
+
+            assert_eq!(decoded.len(), 16);
+
+            for (orig, dec) in vector.iter().zip(decoded.iter()) {
+                assert!(
+                    (orig - dec).abs() < 0.02,
+                    "bits={:?}: orig={}, dec={}",
+                    bits,
+                    orig,
+                    dec
+                );
+            }
+        }
+    }
+
+    #[test]
+    fn test_leanvec_primary_reduces_dimension() {
+        let codec = LeanVecCodec::new(128, LeanVecBits::LeanVec4x8);
+
+        // The primary quantization operates on leanvec_dim=64
+        let selected_dims: Vec<usize> = (0..64).collect();
+
+        let vector: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
+        let (meta, encoded) = codec.encode(&vector);
+
+        // Primary distance only uses the first leanvec_dim dimensions
+        let dist = leanvec_primary_l2_squared_4bit(&vector, &encoded, &meta, &selected_dims);
+
+        // Should be able to compute distance without error
+        assert!(dist.is_finite(), "Distance should be finite");
+        assert!(dist >= 0.0, "Distance should be non-negative");
+    }
+}
diff --git a/rust/vecsim/src/quantization/lvq.rs b/rust/vecsim/src/quantization/lvq.rs
new file mode 100644
index 000000000..11ad41f14
--- /dev/null
+++ b/rust/vecsim/src/quantization/lvq.rs
@@ -0,0 +1,912 @@
+//! Learned Vector Quantization (LVQ) support.
+//!
+//! LVQ provides memory-efficient vector storage using learned quantization:
+//! - Primary quantization: 4-bit or 8-bit per dimension
+//! - Optional residual quantization: additional bits for error correction
+//! - Per-vector scaling factors for optimal range utilization
+//!
+//! ## Two-Level LVQ (LVQ4x4, LVQ8x8)
+//!
+//! Two-level quantization stores:
+//! 1. Primary quantized values (e.g., 4 bits)
+//! 2. Residual error quantized values (e.g., 4 bits)
+//!
+//! This provides better accuracy than single-level quantization while
+//! maintaining memory efficiency.
+//!
+//! ## Memory Layout
+//!
+//! For LVQ4x4 (4-bit primary + 4-bit residual):
+//! - Meta: min_primary, delta_primary, min_residual, delta_residual (16 bytes)
+//! - Packed data: dim bytes (primary in the low nibble, residual in the high nibble)
+//! - Total: 16 + dim bytes per vector
+
+/// Number of quantization levels for 4-bit (16 levels).
+pub const LEVELS_4BIT: usize = 16;
+
+/// Number of quantization levels for 8-bit (256 levels).
+pub const LEVELS_8BIT: usize = 256;
+
+/// LVQ quantization bits configuration.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum LvqBits {
+    /// 4-bit primary quantization.
+    Lvq4,
+    /// 8-bit primary quantization.
+    Lvq8,
+    /// 4-bit primary + 4-bit residual (8 bits total).
+    Lvq4x4,
+    /// 8-bit primary + 8-bit residual (16 bits total).
+    Lvq8x8,
+}
+
+impl LvqBits {
+    /// Get primary bits.
+    pub fn primary_bits(&self) -> usize {
+        match self {
+            LvqBits::Lvq4 | LvqBits::Lvq4x4 => 4,
+            LvqBits::Lvq8 | LvqBits::Lvq8x8 => 8,
+        }
+    }
+
+    /// Get residual bits (0 if single-level).
+    pub fn residual_bits(&self) -> usize {
+        match self {
+            LvqBits::Lvq4 | LvqBits::Lvq8 => 0,
+            LvqBits::Lvq4x4 => 4,
+            LvqBits::Lvq8x8 => 8,
+        }
+    }
+
+    /// Check if two-level quantization.
+    pub fn is_two_level(&self) -> bool {
+        matches!(self, LvqBits::Lvq4x4 | LvqBits::Lvq8x8)
+    }
+
+    /// Get total bits per dimension.
+    pub fn total_bits(&self) -> usize {
+        self.primary_bits() + self.residual_bits()
+    }
+
+    /// Get bytes per vector (excluding metadata).
+    pub fn data_bytes(&self, dim: usize) -> usize {
+        let total_bits = self.total_bits() * dim;
+        (total_bits + 7) / 8
+    }
+}
+
+/// Metadata for a quantized vector.
+#[derive(Debug, Clone, Copy)]
+#[repr(C)]
+pub struct LvqVectorMeta {
+    /// Minimum value for primary quantization.
+    pub min_primary: f32,
+    /// Scale factor for primary quantization.
+    pub delta_primary: f32,
+    /// Minimum value for residual quantization (0 if single-level).
+    pub min_residual: f32,
+    /// Scale factor for residual quantization (0 if single-level).
+    pub delta_residual: f32,
+}
+
+impl Default for LvqVectorMeta {
+    fn default() -> Self {
+        Self {
+            min_primary: 0.0,
+            delta_primary: 1.0,
+            min_residual: 0.0,
+            delta_residual: 0.0,
+        }
+    }
+}
+
+impl LvqVectorMeta {
+    /// Size of metadata in bytes.
+    pub const SIZE: usize = 16;
+}
+
+/// LVQ codec for encoding and decoding vectors.
+pub struct LvqCodec {
+    dim: usize,
+    bits: LvqBits,
+}
+
+impl LvqCodec {
+    /// Create a new LVQ codec.
+    pub fn new(dim: usize, bits: LvqBits) -> Self {
+        Self { dim, bits }
+    }
+
+    /// Get the dimension.
+    pub fn dim(&self) -> usize {
+        self.dim
+    }
+
+    /// Get the bits configuration.
+    pub fn bits(&self) -> LvqBits {
+        self.bits
+    }
+
+    /// Get the size of encoded data in bytes (excluding metadata).
+    pub fn encoded_size(&self) -> usize {
+        self.bits.data_bytes(self.dim)
+    }
+
+    /// Get total size including metadata.
+    pub fn total_size(&self) -> usize {
+        LvqVectorMeta::SIZE + self.encoded_size()
+    }
+
+    /// Encode a vector using LVQ4 (4-bit single-level).
+    pub fn encode_lvq4(&self, vector: &[f32]) -> (LvqVectorMeta, Vec<u8>) {
+        debug_assert_eq!(vector.len(), self.dim);
+
+        // Find min and max
+        let (min_val, max_val) = vector
+            .iter()
+            .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
+
+        let range = max_val - min_val;
+        let delta = if range > 1e-10 {
+            range / 15.0 // 4-bit = 16 levels (0-15)
+        } else {
+            1.0
+        };
+        let inv_delta = 1.0 / delta;
+
+        // Encode to 4-bit (pack 2 values per byte)
+        let mut encoded = vec![0u8; (self.dim + 1) / 2];
+
+        for i in 0..self.dim {
+            let normalized = (vector[i] - min_val) * inv_delta;
+            let quantized = (normalized.round() as u8).min(15);
+
+            let byte_idx = i / 2;
+            if i % 2 == 0 {
+                encoded[byte_idx] |= quantized;
+            } else {
+                encoded[byte_idx] |= quantized << 4;
+            }
+        }
+
+        let meta = LvqVectorMeta {
+            min_primary: min_val,
+            delta_primary: delta,
+            min_residual: 0.0,
+            delta_residual: 0.0,
+        };
+
+        (meta, encoded)
+    }
+
+    /// Decode an LVQ4 encoded vector.
+    pub fn decode_lvq4(&self, meta: &LvqVectorMeta, encoded: &[u8]) -> Vec<f32> {
+        let mut decoded = Vec::with_capacity(self.dim);
+
+        for i in 0..self.dim {
+            let byte_idx = i / 2;
+            let quantized = if i % 2 == 0 {
+                encoded[byte_idx] & 0x0F
+            } else {
+                (encoded[byte_idx] >> 4) & 0x0F
+            };
+
+            let value = meta.min_primary + (quantized as f32) * meta.delta_primary;
+            decoded.push(value);
+        }
+
+        decoded
+    }
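+    // Two-level note (describes encode_lvq4x4 below): the residual
+    // r_i = v_i - (min_primary + q_i * delta_primary) gets its own 4-bit
+    // min/delta pair, and both codes share one byte (primary low nibble,
+    // residual high nibble); decoding sums the two reconstructions.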
+    /// Encode a vector using LVQ4x4 (4-bit primary + 4-bit residual).
+    pub fn encode_lvq4x4(&self, vector: &[f32]) -> (LvqVectorMeta, Vec<u8>) {
+        debug_assert_eq!(vector.len(), self.dim);
+
+        // First pass: primary quantization
+        let (min_primary, max_primary) = vector
+            .iter()
+            .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
+
+        let range_primary = max_primary - min_primary;
+        let delta_primary = if range_primary > 1e-10 {
+            range_primary / 15.0
+        } else {
+            1.0
+        };
+        let inv_delta_primary = 1.0 / delta_primary;
+
+        // Compute primary quantization and residuals
+        let mut primary_quantized = vec![0u8; self.dim];
+        let mut residuals = vec![0.0f32; self.dim];
+
+        for i in 0..self.dim {
+            let normalized = (vector[i] - min_primary) * inv_delta_primary;
+            let q = (normalized.round() as u8).min(15);
+            primary_quantized[i] = q;
+
+            // Compute residual (original - reconstructed)
+            let reconstructed = min_primary + (q as f32) * delta_primary;
+            residuals[i] = vector[i] - reconstructed;
+        }
+
+        // Second pass: residual quantization
+        let (min_residual, max_residual) = residuals
+            .iter()
+            .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
+
+        let range_residual = max_residual - min_residual;
+        let delta_residual = if range_residual > 1e-10 {
+            range_residual / 15.0
+        } else {
+            1.0
+        };
+        let inv_delta_residual = 1.0 / delta_residual;
+
+        // Encode both primary and residual (interleaved: primary in low nibble, residual in high)
+        let mut encoded = vec![0u8; self.dim];
+
+        for i in 0..self.dim {
+            let residual_normalized = (residuals[i] - min_residual) * inv_delta_residual;
+            let residual_q = (residual_normalized.round() as u8).min(15);
+
+            // Pack primary (low nibble) and residual (high nibble)
+            encoded[i] = primary_quantized[i] | (residual_q << 4);
+        }
+
+        let meta = LvqVectorMeta {
+            min_primary,
+            delta_primary,
+            min_residual,
+            delta_residual,
+        };
+
+        (meta, encoded)
+    }
+
+    /// Decode an LVQ4x4 encoded vector.
+    pub fn decode_lvq4x4(&self, meta: &LvqVectorMeta, encoded: &[u8]) -> Vec<f32> {
+        let mut decoded = Vec::with_capacity(self.dim);
+
+        for i in 0..self.dim {
+            let primary_q = encoded[i] & 0x0F;
+            let residual_q = (encoded[i] >> 4) & 0x0F;
+
+            let primary_value = meta.min_primary + (primary_q as f32) * meta.delta_primary;
+            let residual_value = meta.min_residual + (residual_q as f32) * meta.delta_residual;
+
+            decoded.push(primary_value + residual_value);
+        }
+
+        decoded
+    }
+
+    /// Encode a vector using LVQ8 (8-bit single-level).
+    pub fn encode_lvq8(&self, vector: &[f32]) -> (LvqVectorMeta, Vec<u8>) {
+        debug_assert_eq!(vector.len(), self.dim);
+
+        let (min_val, max_val) = vector
+            .iter()
+            .fold((f32::MAX, f32::MIN), |(min, max), &v| (min.min(v), max.max(v)));
+
+        let range = max_val - min_val;
+        let delta = if range > 1e-10 {
+            range / 255.0 // 8-bit = 256 levels (0-255)
+        } else {
+            1.0
+        };
+        let inv_delta = 1.0 / delta;
+
+        let mut encoded = vec![0u8; self.dim];
+
+        for i in 0..self.dim {
+            let normalized = (vector[i] - min_val) * inv_delta;
+            encoded[i] = (normalized.round() as u8).min(255);
+        }
+
+        let meta = LvqVectorMeta {
+            min_primary: min_val,
+            delta_primary: delta,
+            min_residual: 0.0,
+            delta_residual: 0.0,
+        };
+
+        (meta, encoded)
+    }
+
+    /// Decode an LVQ8 encoded vector.
+    pub fn decode_lvq8(&self, meta: &LvqVectorMeta, encoded: &[u8]) -> Vec<f32> {
+        let mut decoded = Vec::with_capacity(self.dim);
+
+        for i in 0..self.dim {
+            let value = meta.min_primary + (encoded[i] as f32) * meta.delta_primary;
+            decoded.push(value);
+        }
+
+        decoded
+    }
+
+    /// Encode a vector based on configured bits.
+    pub fn encode(&self, vector: &[f32]) -> (LvqVectorMeta, Vec<u8>) {
+        match self.bits {
+            LvqBits::Lvq4 => self.encode_lvq4(vector),
+            LvqBits::Lvq4x4 => self.encode_lvq4x4(vector),
+            LvqBits::Lvq8 => self.encode_lvq8(vector),
+            LvqBits::Lvq8x8 => {
+                // LVQ8x8 not implemented yet, fall back to LVQ8
+                self.encode_lvq8(vector)
+            }
+        }
+    }
+
+    /// Decode a vector based on configured bits.
+    pub fn decode(&self, meta: &LvqVectorMeta, encoded: &[u8]) -> Vec<f32> {
+        match self.bits {
+            LvqBits::Lvq4 => self.decode_lvq4(meta, encoded),
+            LvqBits::Lvq4x4 => self.decode_lvq4x4(meta, encoded),
+            LvqBits::Lvq8 => self.decode_lvq8(meta, encoded),
+            LvqBits::Lvq8x8 => self.decode_lvq8(meta, encoded),
+        }
+    }
+}
+
+/// Compute asymmetric L2 squared distance between f32 query and LVQ4 quantized vector.
+#[inline]
+pub fn lvq4_asymmetric_l2_squared(
+    query: &[f32],
+    quantized: &[u8],
+    meta: &LvqVectorMeta,
+    dim: usize,
+) -> f32 {
+    let mut sum = 0.0f32;
+
+    for i in 0..dim {
+        let byte_idx = i / 2;
+        let q = if i % 2 == 0 {
+            quantized[byte_idx] & 0x0F
+        } else {
+            (quantized[byte_idx] >> 4) & 0x0F
+        };
+
+        let stored = meta.min_primary + (q as f32) * meta.delta_primary;
+        let diff = query[i] - stored;
+        sum += diff * diff;
+    }
+
+    sum
+}
+
+/// Compute asymmetric L2 squared distance between f32 query and LVQ4x4 quantized vector.
+#[inline]
+pub fn lvq4x4_asymmetric_l2_squared(
+    query: &[f32],
+    quantized: &[u8],
+    meta: &LvqVectorMeta,
+    dim: usize,
+) -> f32 {
+    let mut sum = 0.0f32;
+
+    for i in 0..dim {
+        let primary_q = quantized[i] & 0x0F;
+        let residual_q = (quantized[i] >> 4) & 0x0F;
+
+        let primary_value = meta.min_primary + (primary_q as f32) * meta.delta_primary;
+        let residual_value = meta.min_residual + (residual_q as f32) * meta.delta_residual;
+        let stored = primary_value + residual_value;
+
+        let diff = query[i] - stored;
+        sum += diff * diff;
+    }
+
+    sum
+}
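+// Illustrative consistency sketch (an added example; the function name is
+// hypothetical): the asymmetric distance dequantizes per element exactly like
+// `decode_lvq4x4`, so it should agree with decoding first and computing a
+// plain L2 distance, up to float rounding.
+#[allow(dead_code)]
+fn lvq4x4_distance_consistency_sketch(query: &[f32], vector: &[f32]) -> bool {
+    let dim = vector.len();
+    let codec = LvqCodec::new(dim, LvqBits::Lvq4x4);
+    let (meta, encoded) = codec.encode_lvq4x4(vector);
+
+    // Asymmetric: f32 query vs packed nibbles.
+    let asym = lvq4x4_asymmetric_l2_squared(query, &encoded, &meta, dim);
+
+    // Decode-then-compare reference.
+    let decoded = codec.decode_lvq4x4(&meta, &encoded);
+    let direct: f32 = query
+        .iter()
+        .zip(decoded.iter())
+        .map(|(q, d)| (q - d) * (q - d))
+        .sum();
+
+    (asym - direct).abs() < 1e-4 * (1.0 + direct)
+}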
+/// Compute asymmetric inner product between f32 query and LVQ4x4 quantized vector.
+#[inline]
+pub fn lvq4x4_asymmetric_inner_product(
+    query: &[f32],
+    quantized: &[u8],
+    meta: &LvqVectorMeta,
+    dim: usize,
+) -> f32 {
+    let mut sum = 0.0f32;
+
+    for i in 0..dim {
+        let primary_q = quantized[i] & 0x0F;
+        let residual_q = (quantized[i] >> 4) & 0x0F;
+
+        let primary_value = meta.min_primary + (primary_q as f32) * meta.delta_primary;
+        let residual_value = meta.min_residual + (residual_q as f32) * meta.delta_residual;
+        let stored = primary_value + residual_value;
+
+        sum += query[i] * stored;
+    }
+
+    sum
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_lvq4_encode_decode() {
+        let codec = LvqCodec::new(8, LvqBits::Lvq4);
+        let vector: Vec<f32> = vec![0.1, 0.5, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4];
+
+        let (meta, encoded) = codec.encode_lvq4(&vector);
+        let decoded = codec.decode_lvq4(&meta, &encoded);
+
+        // Check approximate equality (4-bit has limited precision)
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.1,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_lvq4x4_encode_decode() {
+        let codec = LvqCodec::new(8, LvqBits::Lvq4x4);
+        let vector: Vec<f32> = vec![0.1, 0.5, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4];
+
+        let (meta, encoded) = codec.encode_lvq4x4(&vector);
+        let decoded = codec.decode_lvq4x4(&meta, &encoded);
+
+        // Two-level should have better precision
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.05,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_lvq8_encode_decode() {
+        let codec = LvqCodec::new(8, LvqBits::Lvq8);
+        let vector: Vec<f32> = vec![0.1, 0.5, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4];
+
+        let (meta, encoded) = codec.encode_lvq8(&vector);
+        let decoded = codec.decode_lvq8(&meta, &encoded);
+
+        // 8-bit should be very accurate
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.01,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_lvq4_asymmetric_distance() {
+        let codec = LvqCodec::new(4, LvqBits::Lvq4);
+        let stored = vec![1.0, 0.0, 0.0, 0.0];
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+
+        let (meta, encoded) = codec.encode_lvq4(&stored);
+        let dist = lvq4_asymmetric_l2_squared(&query, &encoded, &meta, 4);
+
+        // Self-distance should be near 0
+        assert!(dist < 0.1, "Self distance should be near 0, got {}", dist);
+    }
+
+    #[test]
+    fn test_lvq4x4_asymmetric_distance() {
+        let codec = LvqCodec::new(4, LvqBits::Lvq4x4);
+        let stored = vec![1.0, 0.0, 0.0, 0.0];
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+
+        let (meta, encoded) = codec.encode_lvq4x4(&stored);
+        let dist = lvq4x4_asymmetric_l2_squared(&query, &encoded, &meta, 4);
+
+        // Self-distance should be very near 0 with two-level
+        assert!(dist < 0.01, "Self distance should be near 0, got {}", dist);
+    }
+
+    #[test]
+    fn test_lvq_memory_savings() {
+        let dim = 128;
+
+        // f32 storage
+        let f32_bytes = dim * 4;
+
+        // LVQ4 storage (4 bits per dim + meta)
+        let lvq4_bytes = (dim + 1) / 2 + LvqVectorMeta::SIZE;
+
+        // LVQ4x4 storage (8 bits per dim + meta)
+        let lvq4x4_bytes = dim + LvqVectorMeta::SIZE;
+
+        // LVQ8 storage (8 bits per dim + meta)
+        let lvq8_bytes = dim + LvqVectorMeta::SIZE;
+
+        println!("Memory for {} dimensions:", dim);
+        println!("  f32: {} bytes", f32_bytes);
+        println!(
+            "  LVQ4: {} bytes ({}x compression)",
+            lvq4_bytes,
+            f32_bytes as f32 / lvq4_bytes as f32
+        );
+        println!(
+            "  LVQ4x4: {} bytes ({}x compression)",
+            lvq4x4_bytes,
+            f32_bytes as f32 / lvq4x4_bytes as f32
+        );
+        println!(
+            "  LVQ8: {} bytes ({}x compression)",
+            lvq8_bytes,
+            f32_bytes as f32 / lvq8_bytes as f32
+        );
+
+        // Verify compression ratios
+        assert!(lvq4_bytes < f32_bytes / 4); // >4x compression
+        assert!(lvq4x4_bytes < f32_bytes / 2); // >2x compression
+    }
+
+    #[test]
+    fn test_lvq_bits_config() {
+        assert_eq!(LvqBits::Lvq4.primary_bits(), 4);
+        assert_eq!(LvqBits::Lvq4.residual_bits(), 0);
+        assert!(!LvqBits::Lvq4.is_two_level());
+
+        assert_eq!(LvqBits::Lvq4x4.primary_bits(), 4);
+        assert_eq!(LvqBits::Lvq4x4.residual_bits(), 4);
+        assert!(LvqBits::Lvq4x4.is_two_level());
+
+        assert_eq!(LvqBits::Lvq8.primary_bits(), 8);
+        assert_eq!(LvqBits::Lvq8.residual_bits(), 0);
+    }
+
+    #[test]
+    fn test_lvq_bits_total_bits() {
+        assert_eq!(LvqBits::Lvq4.total_bits(), 4);
+        assert_eq!(LvqBits::Lvq8.total_bits(), 8);
+        assert_eq!(LvqBits::Lvq4x4.total_bits(), 8);
+        assert_eq!(LvqBits::Lvq8x8.total_bits(), 16);
+    }
+
+    #[test]
+    fn test_lvq_bits_data_bytes() {
+        // LVQ4: 4 bits per dim, so for 8 dims = 32 bits = 4 bytes
+        assert_eq!(LvqBits::Lvq4.data_bytes(8), 4);
+        // Odd dimension: 9 dims = 36 bits = 5 bytes (rounded up)
+        assert_eq!(LvqBits::Lvq4.data_bytes(9), 5);
+
+        // LVQ8: 8 bits per dim = 1 byte per dim
+        assert_eq!(LvqBits::Lvq8.data_bytes(8), 8);
+
+        // LVQ4x4: 8 bits total per dim
+        assert_eq!(LvqBits::Lvq4x4.data_bytes(8), 8);
+    }
+
+    #[test]
+    fn test_lvq4_negative_values() {
+        let codec = LvqCodec::new(8, LvqBits::Lvq4);
+        let vector: Vec<f32> = vec![-0.5, -0.25, 0.0, 0.25, 0.5, 0.75, -0.1, -0.9];
+
+        let (meta, encoded) = codec.encode_lvq4(&vector);
+        let decoded = codec.decode_lvq4(&meta, &encoded);
+
+        // 4-bit has limited precision but should be reasonably close
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.15,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_lvq4_odd_dimension() {
+        // Test odd dimension where 4-bit packing doesn't align perfectly
+        let codec = LvqCodec::new(7, LvqBits::Lvq4);
+        let vector: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
+
+        let (meta, encoded) = codec.encode_lvq4(&vector);
+        // Encoded should be (7 + 1) / 2 = 4 bytes
+        assert_eq!(encoded.len(), 4);
+
+        let decoded = codec.decode_lvq4(&meta, &encoded);
+        assert_eq!(decoded.len(), 7);
+
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.1,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_lvq4_uniform_vector() {
+        let codec = LvqCodec::new(8, LvqBits::Lvq4);
+        let vector: Vec<f32> = vec![0.5; 8];
+
+        let (meta, encoded) = codec.encode_lvq4(&vector);
+        let decoded = codec.decode_lvq4(&meta, &encoded);
+
+        // All values should decode to approximately the same
+        for dec in decoded.iter() {
+            assert!(
+                (dec - 0.5).abs() < 0.01,
+                "expected ~0.5, got {}",
+                dec
+            );
+        }
+
+        // All encoded values should be 0 (since all at min)
+        for byte in encoded.iter() {
+            assert_eq!(*byte, 0, "uniform vector should encode to 0");
+        }
+    }
+
+    #[test]
+    fn test_lvq8_high_precision() {
+        let codec = LvqCodec::new(128, LvqBits::Lvq8);
+        let vector: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
+
+        let (meta, encoded) = codec.encode_lvq8(&vector);
+        let decoded = codec.decode_lvq8(&meta, &encoded);
+
+        // 8-bit should have very good precision
+        let max_error = vector
+            .iter()
+            .zip(decoded.iter())
+            .map(|(orig, dec)| (orig - dec).abs())
+            .fold(0.0f32, f32::max);
+
+        assert!(
+            max_error < 0.005,
+            "8-bit should have high precision, max_error={}",
+            max_error
+        );
+    }
+
+    #[test]
+    fn test_lvq4x4_negative_values() {
+        let codec = LvqCodec::new(8, LvqBits::Lvq4x4);
+        let vector: Vec<f32> = vec![-0.5, -0.25, 0.0, 0.25, 0.5, 0.75, -0.1, -0.9];
+
+        let (meta, encoded) = codec.encode_lvq4x4(&vector);
+        let decoded = codec.decode_lvq4x4(&meta, &encoded);
+
+        // Two-level should be better than single-level
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.1,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_lvq4x4_residual_improves_accuracy() {
+        let codec4 = LvqCodec::new(8, LvqBits::Lvq4);
+        let codec4x4 = LvqCodec::new(8, LvqBits::Lvq4x4);
+        let vector: Vec<f32> = vec![0.1, 0.5, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4];
+
+        let (meta4, encoded4) = codec4.encode_lvq4(&vector);
+        let decoded4 = codec4.decode_lvq4(&meta4, &encoded4);
+
+        let (meta4x4, encoded4x4) = codec4x4.encode_lvq4x4(&vector);
+        let decoded4x4 = codec4x4.decode_lvq4x4(&meta4x4, &encoded4x4);
+
+        // Calculate MSE for both
+        let mse4: f32 = vector
+            .iter()
+            .zip(decoded4.iter())
+            .map(|(orig, dec)| (orig - dec).powi(2))
+            .sum::<f32>()
+            / vector.len() as f32;
+
+        let mse4x4: f32 = vector
+            .iter()
+            .zip(decoded4x4.iter())
+            .map(|(orig, dec)| (orig - dec).powi(2))
+            .sum::<f32>()
+            / vector.len() as f32;
+
+        // Two-level quantization should have lower error
+        assert!(
+            mse4x4 <= mse4,
+            "LVQ4x4 should have lower or equal MSE: mse4={}, mse4x4={}",
+            mse4,
+            mse4x4
+        );
+    }
+
+    #[test]
+    fn test_lvq_codec_generic_encode_decode() {
+        // Test the generic encode/decode methods that dispatch based on bits
+        for bits in [LvqBits::Lvq4, LvqBits::Lvq4x4, LvqBits::Lvq8] {
+            let codec = LvqCodec::new(8, bits);
+            let vector: Vec<f32> = vec![0.1, 0.5, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4];
+
+            let (meta, encoded) = codec.encode(&vector);
+            let decoded = codec.decode(&meta, &encoded);
+
+            // All should reasonably decode
+            for (orig, dec) in vector.iter().zip(decoded.iter()) {
+                assert!(
+                    (orig - dec).abs() < 0.15,
+                    "bits={:?}: orig={}, dec={}",
+                    bits,
+                    orig,
+                    dec
+                );
+            }
+        }
+    }
+
+    #[test]
+    fn test_lvq4_distance_ordering() {
+        let codec = LvqCodec::new(4, LvqBits::Lvq4);
+
+        // Three stored vectors
+        let v1 = vec![1.0, 0.0, 0.0, 0.0];
+        let v2 = vec![0.5, 0.5, 0.0, 0.0];
+        let v3 = vec![0.0, 0.0, 0.0, 1.0];
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+
+        let (meta1, enc1) = codec.encode_lvq4(&v1);
+        let (meta2, enc2) = codec.encode_lvq4(&v2);
+        let (meta3, enc3) = codec.encode_lvq4(&v3);
+
+        let d1 = lvq4_asymmetric_l2_squared(&query, &enc1, &meta1, 4);
+        let d2 = lvq4_asymmetric_l2_squared(&query, &enc2, &meta2, 4);
+        let d3 = lvq4_asymmetric_l2_squared(&query, &enc3, &meta3, 4);
+
+        // d1 should be smallest (self), d3 should be largest
+        assert!(d1 < d2, "d1={} should be < d2={}", d1, d2);
+        assert!(d2 < d3, "d2={} should be < d3={}", d2, d3);
+    }
+
+    #[test]
+    fn test_lvq4x4_distance_ordering() {
+        let codec = LvqCodec::new(4, LvqBits::Lvq4x4);
+
+        let v1 = vec![1.0, 0.0, 0.0, 0.0];
+        let v2 = vec![0.5, 0.5, 0.0, 0.0];
+        let v3 = vec![0.0, 0.0, 0.0, 1.0];
+        let query = vec![1.0, 0.0, 0.0, 0.0];
+
+        let (meta1, enc1) = codec.encode_lvq4x4(&v1);
+        let (meta2, enc2) = codec.encode_lvq4x4(&v2);
+        let (meta3, enc3) = codec.encode_lvq4x4(&v3);
+
+        let d1 = lvq4x4_asymmetric_l2_squared(&query, &enc1, &meta1, 4);
+        let d2 = lvq4x4_asymmetric_l2_squared(&query, &enc2, &meta2, 4);
+        let d3 = lvq4x4_asymmetric_l2_squared(&query, &enc3, &meta3, 4);
+
+        assert!(d1 < d2, "d1={} should be < d2={}", d1, d2);
+        assert!(d2 < d3, "d2={} should be < d3={}", d2, d3);
+    }
+
+    #[test]
+    fn test_lvq4x4_inner_product() {
+        let codec = LvqCodec::new(4, LvqBits::Lvq4x4);
+
+        let stored = vec![1.0, 2.0, 3.0, 4.0];
+        let query = vec![1.0, 1.0, 1.0, 1.0];
+
+        let (meta, encoded) = codec.encode_lvq4x4(&stored);
+
+        // IP = 1*1 + 1*2 + 1*3 + 1*4 = 10
+        let ip = lvq4x4_asymmetric_inner_product(&query, &encoded, &meta, 4);
+
+        assert!(
+            (ip - 10.0).abs() < 1.0,
+            "IP should be ~10, got {}",
+            ip
+        );
+    }
+
+    #[test]
+    fn test_lvq_vector_meta_default() {
+        let meta = LvqVectorMeta::default();
+        assert_eq!(meta.min_primary, 0.0);
+        assert_eq!(meta.delta_primary, 1.0);
+        assert_eq!(meta.min_residual, 0.0);
+        assert_eq!(meta.delta_residual, 0.0);
+    }
+
+    #[test]
+    fn test_lvq_vector_meta_size() {
+        assert_eq!(LvqVectorMeta::SIZE, 16); // 4 f32 values
+    }
+
+    #[test]
+    fn test_lvq_codec_accessors() {
+        let codec = LvqCodec::new(64, LvqBits::Lvq4x4);
+        assert_eq!(codec.dim(), 64);
+        assert_eq!(codec.bits(), LvqBits::Lvq4x4);
+        assert_eq!(codec.encoded_size(), 64); // 8 bits per dim for 4x4
+        assert_eq!(codec.total_size(), LvqVectorMeta::SIZE + 64);
+    }
+
+    #[test]
+    fn test_lvq4_zero_vector() {
+        let codec = LvqCodec::new(8, LvqBits::Lvq4);
+        let vector: Vec<f32> = vec![0.0; 8];
+
+        let (meta, encoded) = codec.encode_lvq4(&vector);
+        let decoded = codec.decode_lvq4(&meta, &encoded);
+
+        for dec in decoded.iter() {
+            assert!(dec.abs() < 0.001, "expected ~0, got {}", dec);
+        }
+    }
+
+    #[test]
+    fn test_lvq8_negative_values() {
+        let codec = LvqCodec::new(8, LvqBits::Lvq8);
+        let vector: Vec<f32> = vec![-1.0, -0.5, 0.0, 0.5, 1.0, -0.25, 0.25, 0.75];
+
+        let (meta, encoded) = codec.encode_lvq8(&vector);
+        let decoded = codec.decode_lvq8(&meta, &encoded);
+
+        // 8-bit should have very good precision
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.02,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_lvq4x4_large_vectors() {
+        let codec = LvqCodec::new(512, LvqBits::Lvq4x4);
+        let vector: Vec<f32> = (0..512).map(|i| ((i as f32) / 512.0) - 0.5).collect();
+
+        let (meta, encoded) = codec.encode_lvq4x4(&vector);
+        let decoded = codec.decode_lvq4x4(&meta, &encoded);
+
+        let max_error = vector
+            .iter()
+            .zip(decoded.iter())
+            .map(|(orig, dec)| (orig - dec).abs())
+            .fold(0.0f32, f32::max);
+
+        assert!(
+            max_error < 0.05,
+            "Large vector should decode well, max_error={}",
+            max_error
+        );
+    }
+
+    #[test]
+    fn test_lvq4_packed_nibble_roundtrip() {
+        let codec = LvqCodec::new(4, LvqBits::Lvq4);
+        // Create values that should map to specific nibbles
+        let vector: Vec<f32> = vec![0.0, 0.33, 0.67, 1.0];
+
+        let (meta, encoded) = codec.encode_lvq4(&vector);
+
+        // Check encoded size
+        assert_eq!(encoded.len(), 2); // 4 values * 4 bits / 8 = 2 bytes
+
+        // Decode and verify
+        let decoded = codec.decode_lvq4(&meta, &encoded);
+        assert_eq!(decoded.len(), 4);
+
+        for (orig, dec) in vector.iter().zip(decoded.iter()) {
+            assert!(
+                (orig - dec).abs() < 0.15,
+                "orig={}, dec={}",
+                orig,
+                dec
+            );
+        }
+    }
+
+    #[test]
+    fn test_lvq_levels_constants() {
+        assert_eq!(LEVELS_4BIT, 16);
+        assert_eq!(LEVELS_8BIT, 256);
+    }
+}
diff --git a/rust/vecsim/src/quantization/mod.rs b/rust/vecsim/src/quantization/mod.rs
new file mode 100644
index 000000000..6dea372d2
--- /dev/null
+++ b/rust/vecsim/src/quantization/mod.rs
@@ -0,0 +1,24 @@
+//! Quantization support for vector compression.
+//!
+//! This module provides quantization methods to compress vectors for efficient storage
+//! and faster distance computations:
+//! - `SQ8`: Scalar quantization to 8-bit unsigned integers with per-vector scaling
- `sq8_simd`: SIMD-optimized asymmetric distance functions for SQ8 +//! - `LVQ`: Learned Vector Quantization with 4-bit/8-bit and two-level support +//! - `LeanVec`: Dimension reduction with two-level quantization (4x8, 8x8) + +pub mod leanvec; +pub mod lvq; +pub mod sq8; +pub mod sq8_simd; + +pub use leanvec::{ + leanvec_primary_l2_squared_4bit, leanvec_primary_l2_squared_8bit, leanvec_residual_inner_product, + leanvec_residual_l2_squared, LeanVecBits, LeanVecCodec, LeanVecMeta, +}; +pub use lvq::{ + lvq4_asymmetric_l2_squared, lvq4x4_asymmetric_inner_product, lvq4x4_asymmetric_l2_squared, + LvqBits, LvqCodec, LvqVectorMeta, +}; +pub use sq8::{Sq8Codec, Sq8VectorMeta}; +pub use sq8_simd::{sq8_cosine_simd, sq8_inner_product_simd, sq8_l2_squared_simd}; diff --git a/rust/vecsim/src/quantization/sq8.rs b/rust/vecsim/src/quantization/sq8.rs new file mode 100644 index 000000000..69e454d2b --- /dev/null +++ b/rust/vecsim/src/quantization/sq8.rs @@ -0,0 +1,698 @@ +//! Scalar Quantization to 8-bit unsigned integers (SQ8). +//! +//! SQ8 provides a simple yet effective compression scheme that reduces f32 vectors +//! to u8 vectors with per-vector metadata for dequantization. +//! +//! ## How it works +//! +//! For each vector: +//! 1. Find min and max values +//! 2. Compute delta = (max - min) / 255 +//! 3. Quantize: q[i] = round((v[i] - min) / delta) +//! 4. Store metadata for dequantization and asymmetric distance computation +//! +//! ## Asymmetric Distance +//! +//! For L2 squared distance: ||q - v||² = ||q||² + ||v||² - 2*IP(q, v) +//! - Query vector stays in f32 +//! - Stored vector is quantized +//! - Precomputed ||v||² avoids full dequantization + +use crate::types::UInt8; + +/// Per-vector metadata for SQ8 dequantization and asymmetric distance. +#[derive(Debug, Clone, Copy, Default)] +pub struct Sq8VectorMeta { + /// Minimum value in the original vector. + pub min: f32, + /// Scale factor: (max - min) / 255. + pub delta: f32, + /// Precomputed sum of squares ||v||² for asymmetric L2. + pub sum_sq: f32, + /// Precomputed sum of elements for asymmetric IP. + pub sum: f32, +} + +impl Sq8VectorMeta { + /// Create new metadata. + pub fn new(min: f32, delta: f32, sum_sq: f32, sum: f32) -> Self { + Self { + min, + delta, + sum_sq, + sum, + } + } + + /// Dequantize a single value. + #[inline(always)] + pub fn dequantize(&self, q: u8) -> f32 { + self.min + (q as f32) * self.delta + } + + /// Size in bytes when serialized. + pub const SERIALIZED_SIZE: usize = 4 * 4; // 4 f32 values +} + +/// SQ8 encoder/decoder. +#[derive(Debug, Clone)] +pub struct Sq8Codec { + dim: usize, +} + +impl Sq8Codec { + /// Create a new SQ8 codec for vectors of the given dimension. + pub fn new(dim: usize) -> Self { + Self { dim } + } + + /// Get the dimension. + #[inline] + pub fn dimension(&self) -> usize { + self.dim + } + + /// Encode an f32 vector to SQ8 format. + /// + /// Returns the quantized vector and per-vector metadata. 
+ pub fn encode(&self, vector: &[f32]) -> (Vec<UInt8>, Sq8VectorMeta) { + assert_eq!(vector.len(), self.dim, "Vector dimension mismatch"); + + // Find min and max + let mut min = f32::INFINITY; + let mut max = f32::NEG_INFINITY; + let mut sum = 0.0f32; + let mut sum_sq = 0.0f32; + + for &v in vector.iter() { + min = min.min(v); + max = max.max(v); + sum += v; + sum_sq += v * v; + } + + // Handle degenerate case where all values are the same + let delta = if (max - min).abs() < f32::EPSILON { + 1.0 // Avoid division by zero; all values will quantize to 0 + } else { + (max - min) / 255.0 + }; + + // Quantize + let inv_delta = 1.0 / delta; + let quantized: Vec<UInt8> = vector + .iter() + .map(|&v| { + let q = ((v - min) * inv_delta).round().clamp(0.0, 255.0) as u8; + UInt8::new(q) + }) + .collect(); + + let meta = Sq8VectorMeta::new(min, delta, sum_sq, sum); + (quantized, meta) + } + + /// Decode an SQ8 vector back to f32 format. + pub fn decode(&self, quantized: &[UInt8], meta: &Sq8VectorMeta) -> Vec<f32> { + assert_eq!(quantized.len(), self.dim, "Quantized vector dimension mismatch"); + + quantized + .iter() + .map(|&q| meta.dequantize(q.get())) + .collect() + } + + /// Encode straight to raw u8 bytes (avoids the UInt8 wrapper when writing raw buffers). + pub fn encode_from_f32_slice(&self, vector: &[f32]) -> (Vec<u8>, Sq8VectorMeta) { + let (quantized, meta) = self.encode(vector); + let bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect(); + (bytes, meta) + } + + /// Decode from raw u8 slice. + pub fn decode_from_u8_slice(&self, quantized: &[u8], meta: &Sq8VectorMeta) -> Vec<f32> { + assert_eq!(quantized.len(), self.dim, "Quantized vector dimension mismatch"); + + quantized + .iter() + .map(|&q| meta.dequantize(q)) + .collect() + } + + /// Compute quantization error (mean squared error) for a vector. + pub fn quantization_error(&self, original: &[f32], quantized: &[UInt8], meta: &Sq8VectorMeta) -> f32 { + assert_eq!(original.len(), self.dim); + assert_eq!(quantized.len(), self.dim); + + let mse: f32 = original + .iter() + .zip(quantized.iter()) + .map(|(&orig, &q)| { + let reconstructed = meta.dequantize(q.get()); + let diff = orig - reconstructed; + diff * diff + }) + .sum(); + + mse / self.dim as f32 + } + + /// Encode a batch of vectors. + pub fn encode_batch(&self, vectors: &[Vec<f32>]) -> Vec<(Vec<UInt8>, Sq8VectorMeta)> { + vectors.iter().map(|v| self.encode(v)).collect() + } +} + +/// Compute asymmetric L2 squared distance between f32 query and SQ8 stored vector. +/// +/// Uses the formula: ||q - v||² = ||q||² + ||v||² - 2*IP(q, v) +/// where ||v||² is precomputed in metadata. +#[inline] +pub fn sq8_asymmetric_l2_squared( + query: &[f32], + quantized: &[u8], + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + debug_assert_eq!(query.len(), dim); + debug_assert_eq!(quantized.len(), dim); + + // Compute ||q||² and IP(q, v) in one pass + let mut query_sq = 0.0f32; + let mut ip = 0.0f32; + + for i in 0..dim { + let q = query[i]; + let v = meta.min + (quantized[i] as f32) * meta.delta; + query_sq += q * q; + ip += q * v; + } + + // ||q - v||² = ||q||² + ||v||² - 2*IP(q, v) + (query_sq + meta.sum_sq - 2.0 * ip).max(0.0) +}
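// Editorial sketch: typical usage of the SQ8 API defined above -- quantize the
// stored side once, keep the query in f32, and let sq8_asymmetric_l2_squared
// use the precomputed ||v||^2 from the metadata. Illustrative only.
fn sq8_usage_sketch() {
    let codec = Sq8Codec::new(4);
    let stored = vec![0.1f32, 0.4, 0.7, 1.0];
    let (quantized, meta) = codec.encode(&stored);
    let bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect();

    // Only the stored vector is quantized (asymmetric), so the query loses
    // no precision; distance error comes from the stored side alone.
    let query = vec![0.2f32, 0.4, 0.6, 0.8];
    let dist = sq8_asymmetric_l2_squared(&query, &bytes, &meta, 4);
    assert!(dist >= 0.0);
}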
+/// Compute asymmetric inner product between f32 query and SQ8 stored vector. +/// +/// Returns negative inner product (for use as distance where lower is better). +#[inline] +pub fn sq8_asymmetric_inner_product( + query: &[f32], + quantized: &[u8], + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + debug_assert_eq!(query.len(), dim); + debug_assert_eq!(quantized.len(), dim); + + let mut ip = 0.0f32; + + for i in 0..dim { + let q = query[i]; + let v = meta.min + (quantized[i] as f32) * meta.delta; + ip += q * v; + } + + -ip // Negative for distance ordering +} + +/// Compute asymmetric cosine distance between f32 query and SQ8 stored vector. +/// +/// Returns 1 - cosine_similarity. +#[inline] +pub fn sq8_asymmetric_cosine( + query: &[f32], + quantized: &[u8], + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + debug_assert_eq!(query.len(), dim); + debug_assert_eq!(quantized.len(), dim); + + let mut query_sq = 0.0f32; + let mut ip = 0.0f32; + + for i in 0..dim { + let q = query[i]; + let v = meta.min + (quantized[i] as f32) * meta.delta; + query_sq += q * q; + ip += q * v; + } + + let query_norm = query_sq.sqrt(); + let stored_norm = meta.sum_sq.sqrt(); + + if query_norm < 1e-30 || stored_norm < 1e-30 { + return 1.0; + } + + let cosine_sim = (ip / (query_norm * stored_norm)).clamp(-1.0, 1.0); + 1.0 - cosine_sim +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sq8_encode_decode_roundtrip() { + let codec = Sq8Codec::new(4); + let original = vec![0.1, 0.5, 0.3, 0.9]; + + let (quantized, meta) = codec.encode(&original); + let decoded = codec.decode(&quantized, &meta); + + // Check that decoded values are close to original + for (orig, dec) in original.iter().zip(decoded.iter()) { + assert!( + (orig - dec).abs() < 0.01, + "orig={}, dec={}, diff={}", + orig, + dec, + (orig - dec).abs() + ); + } + } + + #[test] + fn test_sq8_uniform_vector() { + let codec = Sq8Codec::new(4); + let original = vec![0.5, 0.5, 0.5, 0.5]; + + let (quantized, meta) = codec.encode(&original); + let decoded = codec.decode(&quantized, &meta); + + // All values should decode to approximately the same value + for dec in decoded.iter() { + assert!((dec - 0.5).abs() < 0.01); + } + + // All quantized values should be 0 (since all are at min) + for q in quantized.iter() { + assert_eq!(q.get(), 0); + } + } + + #[test] + fn test_sq8_metadata() { + let codec = Sq8Codec::new(4); + let original = vec![0.0, 0.5, 1.0, 0.25]; + + let (_, meta) = codec.encode(&original); + + assert_eq!(meta.min, 0.0); + assert!((meta.delta - (1.0 / 255.0)).abs() < 0.0001); + assert!((meta.sum - 1.75).abs() < 0.0001); + // sum_sq = 0 + 0.25 + 1 + 0.0625 = 1.3125 + assert!((meta.sum_sq - 1.3125).abs() < 0.0001); + } + + #[test] + fn test_sq8_asymmetric_l2() { + let codec = Sq8Codec::new(4); + let stored = vec![1.0, 2.0, 3.0, 4.0]; + let query = vec![1.0, 2.0, 3.0, 4.0]; + + let (quantized, meta) = codec.encode(&stored); + let quantized_bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect(); + + let dist = sq8_asymmetric_l2_squared(&query, &quantized_bytes, &meta, 4); + + // Distance to self should be very small (due to quantization error) + assert!(dist < 0.1, "dist={}", dist); + } + + #[test] + fn test_sq8_asymmetric_l2_different() { + let codec = Sq8Codec::new(4); + let stored = vec![0.0, 0.0, 0.0, 0.0]; + let query = vec![1.0, 1.0, 1.0, 1.0]; + + let (quantized, meta) = codec.encode(&stored); + let quantized_bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect(); + + let dist = sq8_asymmetric_l2_squared(&query, &quantized_bytes, &meta, 4); + + // Distance should be approximately 4 (sum of (1-0)^2 = 4) + assert!((dist - 4.0).abs() < 0.1, "dist={}", dist); + }
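// Editorial sketch: a hand check of the quantization formula used by `encode`,
// matching the metadata asserted in test_sq8_metadata above. Standalone
// arithmetic, illustrative only.
fn sq8_formula_check() {
    let (min, max) = (0.0f32, 1.0f32);
    let delta = (max - min) / 255.0; // 1/255, as in test_sq8_metadata
    let code = ((0.5 - min) / delta).round() as u8; // 127.5 rounds away from zero to 128
    let back = min + (code as f32) * delta; // 128/255 ~= 0.50196
    assert_eq!(code, 128);
    assert!((back - 0.5f32).abs() <= delta * 0.51); // reconstruction error is at most ~delta/2
}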
+ #[test] + fn test_sq8_asymmetric_ip() { + let codec = Sq8Codec::new(4); + let stored = vec![1.0, 2.0, 3.0, 4.0]; + let query = vec![1.0, 1.0, 1.0, 1.0]; + + let (quantized, meta) = codec.encode(&stored); + let quantized_bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect(); + + let dist = sq8_asymmetric_inner_product(&query, &quantized_bytes, &meta, 4); + + // IP = 1*1 + 1*2 + 1*3 + 1*4 = 10, so distance = -10 + assert!((dist + 10.0).abs() < 0.5, "dist={}", dist); + } + + #[test] + fn test_sq8_asymmetric_cosine() { + let codec = Sq8Codec::new(3); + let stored = vec![1.0, 0.0, 0.0]; + let query = vec![1.0, 0.0, 0.0]; + + let (quantized, meta) = codec.encode(&stored); + let quantized_bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect(); + + let dist = sq8_asymmetric_cosine(&query, &quantized_bytes, &meta, 3); + + // Same direction should have cosine distance near 0 + assert!(dist < 0.1, "dist={}", dist); + } + + #[test] + fn test_sq8_quantization_error() { + let codec = Sq8Codec::new(4); + let original = vec![0.1, 0.5, 0.3, 0.9]; + + let (quantized, meta) = codec.encode(&original); + let mse = codec.quantization_error(&original, &quantized, &meta); + + // MSE should be small + assert!(mse < 0.001, "mse={}", mse); + } + + #[test] + fn test_sq8_negative_values() { + let codec = Sq8Codec::new(4); + let original = vec![-1.0, -0.5, 0.0, 0.5]; + + let (quantized, meta) = codec.encode(&original); + let decoded = codec.decode(&quantized, &meta); + + for (orig, dec) in original.iter().zip(decoded.iter()) { + assert!( + (orig - dec).abs() < 0.01, + "orig={}, dec={}", + orig, + dec + ); + } + } + + #[test] + fn test_sq8_large_range() { + let codec = Sq8Codec::new(4); + let original = vec![-100.0, 0.0, 50.0, 100.0]; + + let (quantized, meta) = codec.encode(&original); + let decoded = codec.decode(&quantized, &meta); + + // With large range, quantization error will be larger + let max_error = 200.0 / 255.0; // delta + for (orig, dec) in original.iter().zip(decoded.iter()) { + assert!( + (orig - dec).abs() <= max_error, + "orig={}, dec={}, max_error={}", + orig, + dec, + max_error + ); + } + } + + #[test] + fn test_sq8_zero_vector() { + let codec = Sq8Codec::new(4); + let original = vec![0.0, 0.0, 0.0, 0.0]; + + let (quantized, meta) = codec.encode(&original); + let decoded = codec.decode(&quantized, &meta); + + // All values should decode to 0 + for dec in decoded.iter() { + assert!(dec.abs() < 0.001, "expected ~0, got {}", dec); + } + + // Metadata should reflect zero vector + assert!((meta.sum - 0.0).abs() < 0.001); + assert!((meta.sum_sq - 0.0).abs() < 0.001); + } + + #[test] + fn test_sq8_very_small_values() { + let codec = Sq8Codec::new(4); + let original = vec![1e-6, 2e-6, 3e-6, 4e-6]; + + let (quantized, meta) = codec.encode(&original); + let decoded = codec.decode(&quantized, &meta); + + // Very small values should still be preserved approximately + for (orig, dec) in original.iter().zip(decoded.iter()) { + assert!( + (orig - dec).abs() < 1e-5, + "orig={}, dec={}", + orig, + dec + ); + } + } + + #[test] + fn test_sq8_very_large_values() { + let codec = Sq8Codec::new(4); + let original = vec![1e6, 2e6, 3e6, 4e6]; + + let (quantized, meta) = codec.encode(&original); + let decoded = codec.decode(&quantized, &meta); + + // With large range, relative error should be reasonable + for (orig, dec) in original.iter().zip(decoded.iter()) { + let relative_error = (orig - dec).abs() / orig; + assert!( + relative_error < 0.01, + "orig={}, dec={}, relative_error={}", + orig, + dec, + relative_error + ); + } + }
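// Editorial sketch: the asymmetric L2 functions above rest on the identity
// ||q - v||^2 = ||q||^2 + ||v||^2 - 2*IP(q, v), with ||v||^2 precomputed in
// Sq8VectorMeta. A standalone numeric check of that algebra (illustrative only):
fn l2_identity_check() {
    let q = [1.0f32, 2.0, 3.0];
    let v = [0.5f32, 1.5, 2.5];
    let direct: f32 = q.iter().zip(&v).map(|(a, b)| (a - b) * (a - b)).sum();
    let q_sq: f32 = q.iter().map(|a| a * a).sum();
    let v_sq: f32 = v.iter().map(|b| b * b).sum();
    let ip: f32 = q.iter().zip(&v).map(|(a, b)| a * b).sum();
    assert!((direct - (q_sq + v_sq - 2.0 * ip)).abs() < 1e-5); // both sides equal 0.75
}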
+ #[test] + fn test_sq8_encode_from_f32_slice() { + let codec = Sq8Codec::new(4); + let original = vec![0.1, 0.5, 0.3, 0.9]; + + let (bytes, meta) = codec.encode_from_f32_slice(&original); + + // Bytes should have the same length as dimension + assert_eq!(bytes.len(), 4); + + // Should be able to decode back + let decoded = codec.decode_from_u8_slice(&bytes, &meta); + + for (orig, dec) in original.iter().zip(decoded.iter()) { + assert!( + (orig - dec).abs() < 0.01, + "orig={}, dec={}", + orig, + dec + ); + } + } + + #[test] + fn test_sq8_batch_encode() { + let codec = Sq8Codec::new(4); + let vectors = vec![ + vec![0.1, 0.2, 0.3, 0.4], + vec![0.5, 0.6, 0.7, 0.8], + vec![0.9, 0.8, 0.7, 0.6], + ]; + + let batch_result = codec.encode_batch(&vectors); + + assert_eq!(batch_result.len(), 3); + + // Verify each result + for (i, (quantized, meta)) in batch_result.iter().enumerate() { + let decoded = codec.decode(quantized, meta); + for (orig, dec) in vectors[i].iter().zip(decoded.iter()) { + assert!( + (orig - dec).abs() < 0.01, + "vector {}: orig={}, dec={}", + i, + orig, + dec + ); + } + } + } + + #[test] + fn test_sq8_asymmetric_l2_distance_ordering() { + let codec = Sq8Codec::new(4); + + // Three stored vectors: v1 is closest to query, v3 is farthest + let v1 = vec![1.0, 0.0, 0.0, 0.0]; + let v2 = vec![0.5, 0.5, 0.0, 0.0]; + let v3 = vec![0.0, 0.0, 0.0, 1.0]; + let query = vec![1.0, 0.0, 0.0, 0.0]; + + let (q1, m1) = codec.encode(&v1); + let (q2, m2) = codec.encode(&v2); + let (q3, m3) = codec.encode(&v3); + + let b1: Vec<u8> = q1.iter().map(|q| q.get()).collect(); + let b2: Vec<u8> = q2.iter().map(|q| q.get()).collect(); + let b3: Vec<u8> = q3.iter().map(|q| q.get()).collect(); + + let d1 = sq8_asymmetric_l2_squared(&query, &b1, &m1, 4); + let d2 = sq8_asymmetric_l2_squared(&query, &b2, &m2, 4); + let d3 = sq8_asymmetric_l2_squared(&query, &b3, &m3, 4); + + // d1 should be smallest (near 0), d3 should be largest + assert!(d1 < d2, "d1={} should be < d2={}", d1, d2); + assert!(d2 < d3, "d2={} should be < d3={}", d2, d3); + } + + #[test] + fn test_sq8_asymmetric_ip_distance_ordering() { + let codec = Sq8Codec::new(4); + + // For inner product, higher IP means more similar + let v1 = vec![1.0, 0.0, 0.0, 0.0]; // IP with query = 1 + let v2 = vec![0.5, 0.0, 0.0, 0.0]; // IP with query = 0.5 + let v3 = vec![0.0, 1.0, 0.0, 0.0]; // IP with query = 0 + let query = vec![1.0, 0.0, 0.0, 0.0]; + + let (q1, m1) = codec.encode(&v1); + let (q2, m2) = codec.encode(&v2); + let (q3, m3) = codec.encode(&v3); + + let b1: Vec<u8> = q1.iter().map(|q| q.get()).collect(); + let b2: Vec<u8> = q2.iter().map(|q| q.get()).collect(); + let b3: Vec<u8> = q3.iter().map(|q| q.get()).collect(); + + // Note: asymmetric_inner_product returns negative IP + let d1 = sq8_asymmetric_inner_product(&query, &b1, &m1, 4); + let d2 = sq8_asymmetric_inner_product(&query, &b2, &m2, 4); + let d3 = sq8_asymmetric_inner_product(&query, &b3, &m3, 4); + + // d1 should be most negative (closest), d3 should be near 0 (farthest) + assert!(d1 < d2, "d1={} should be < d2={}", d1, d2); + assert!(d2 < d3, "d2={} should be < d3={}", d2, d3); + } + + #[test] + fn test_sq8_asymmetric_cosine_orthogonal() { + let codec = Sq8Codec::new(4); + + // Two orthogonal vectors + let v1 = vec![1.0, 0.0, 0.0, 0.0]; + let v2 = vec![0.0, 1.0, 0.0, 0.0]; + + let (q1, m1) = codec.encode(&v1); + let (q2, m2) = codec.encode(&v2); + + let b1: Vec<u8> = q1.iter().map(|q| q.get()).collect(); + let b2: Vec<u8> = q2.iter().map(|q| q.get()).collect(); + + // v1 vs v2 (orthogonal) - cosine distance should be ~1
+ let dist = sq8_asymmetric_cosine(&v1, &b2, &m2, 4); + assert!( + (dist - 1.0).abs() < 0.1, + "Orthogonal vectors should have cosine distance ~1, got {}", + dist + ); + + // v1 vs v1 (same direction) - cosine distance should be ~0 + let self_dist = sq8_asymmetric_cosine(&v1, &b1, &m1, 4); + assert!( + self_dist < 0.1, + "Same direction should have cosine distance ~0, got {}", + self_dist + ); + } + + #[test] + fn test_sq8_meta_dequantize() { + let meta = Sq8VectorMeta::new(0.0, 0.1, 0.0, 0.0); + + // Dequantize value 0 should give min + assert!((meta.dequantize(0) - 0.0).abs() < 0.001); + + // Dequantize value 10 should give 0 + 10 * 0.1 = 1.0 + assert!((meta.dequantize(10) - 1.0).abs() < 0.001); + + // Dequantize value 255 should give max + assert!((meta.dequantize(255) - 25.5).abs() < 0.001); + } + + #[test] + fn test_sq8_meta_serialized_size() { + // Should be 4 f32 values = 16 bytes + assert_eq!(Sq8VectorMeta::SERIALIZED_SIZE, 16); + } + + #[test] + fn test_sq8_codec_dimension() { + let codec = Sq8Codec::new(128); + assert_eq!(codec.dimension(), 128); + } + + #[test] + fn test_sq8_high_dimension() { + // Test with a high-dimension vector (common for embeddings) + let codec = Sq8Codec::new(768); + let original: Vec<f32> = (0..768).map(|i| (i as f32) / 768.0).collect(); + + let (quantized, meta) = codec.encode(&original); + assert_eq!(quantized.len(), 768); + + let decoded = codec.decode(&quantized, &meta); + assert_eq!(decoded.len(), 768); + + // Check quantization error + let mse = codec.quantization_error(&original, &quantized, &meta); + assert!(mse < 0.001, "MSE should be small for normalized data, got {}", mse); + } + + #[test] + fn test_sq8_mixed_positive_negative() { + let codec = Sq8Codec::new(4); + let original = vec![-0.5, -0.1, 0.2, 0.7]; + + let (quantized, meta) = codec.encode(&original); + let decoded = codec.decode(&quantized, &meta); + + for (orig, dec) in original.iter().zip(decoded.iter()) { + assert!( + (orig - dec).abs() < 0.01, + "orig={}, dec={}", + orig, + dec + ); + } + } + + #[test] + fn test_sq8_extreme_range() { + let codec = Sq8Codec::new(4); + // Mix of very small and very large values + let original = vec![-1e5, 1e-5, 0.0, 1e5]; + + let (quantized, meta) = codec.encode(&original); + let decoded = codec.decode(&quantized, &meta); + + // With such extreme range, precision is limited + let delta = meta.delta; + for (orig, dec) in original.iter().zip(decoded.iter()) { + assert!( + (orig - dec).abs() <= delta * 1.5, + "orig={}, dec={}, delta={}", + orig, + dec, + delta + ); + } + } +} diff --git a/rust/vecsim/src/quantization/sq8_simd.rs b/rust/vecsim/src/quantization/sq8_simd.rs new file mode 100644 index 000000000..05c9dc16f --- /dev/null +++ b/rust/vecsim/src/quantization/sq8_simd.rs @@ -0,0 +1,898 @@ +//! SIMD-optimized SQ8 asymmetric distance functions. +//! +//! These functions compute distances between f32 queries and SQ8-quantized vectors +//! using SIMD instructions for significant performance improvements. + +use super::sq8::Sq8VectorMeta; + +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +// ============================================================================= +// NEON implementations (ARM) +// ============================================================================= + +/// NEON-optimized asymmetric L2 squared distance. +/// +/// Computes ||query - stored||² where stored is SQ8-quantized.
+/// +/// # Safety +/// - Pointers must be valid for `dim` elements +/// - NEON must be available +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn sq8_asymmetric_l2_squared_neon( + query: *const f32, + quantized: *const u8, + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + let min_vec = vdupq_n_f32(meta.min); + let delta_vec = vdupq_n_f32(meta.delta); + + let mut query_sq_sum = vdupq_n_f32(0.0); + let mut ip_sum = vdupq_n_f32(0.0); + + let chunks = dim / 8; + let remainder = dim % 8; + + // Process 8 elements at a time + for i in 0..chunks { + let offset = i * 8; + + // Load 8 query values (two sets of 4) + let q0 = vld1q_f32(query.add(offset)); + let q1 = vld1q_f32(query.add(offset + 4)); + + // Load 8 quantized values and convert to f32 + let quant_u8 = vld1_u8(quantized.add(offset)); + + // Widen u8 to u16 + let quant_u16 = vmovl_u8(quant_u8); + + // Split and widen u16 to u32 + let quant_lo_u32 = vmovl_u16(vget_low_u16(quant_u16)); + let quant_hi_u32 = vmovl_u16(vget_high_u16(quant_u16)); + + // Convert u32 to f32 + let quant_lo_f32 = vcvtq_f32_u32(quant_lo_u32); + let quant_hi_f32 = vcvtq_f32_u32(quant_hi_u32); + + // Dequantize: stored = min + quant * delta + let stored0 = vfmaq_f32(min_vec, quant_lo_f32, delta_vec); + let stored1 = vfmaq_f32(min_vec, quant_hi_f32, delta_vec); + + // Accumulate query squared + query_sq_sum = vfmaq_f32(query_sq_sum, q0, q0); + query_sq_sum = vfmaq_f32(query_sq_sum, q1, q1); + + // Accumulate inner product + ip_sum = vfmaq_f32(ip_sum, q0, stored0); + ip_sum = vfmaq_f32(ip_sum, q1, stored1); + } + + // Reduce to scalar + let mut query_sq = vaddvq_f32(query_sq_sum); + let mut ip = vaddvq_f32(ip_sum); + + // Handle remainder + let base = chunks * 8; + for i in 0..remainder { + let q = *query.add(base + i); + let v = meta.min + (*quantized.add(base + i) as f32) * meta.delta; + query_sq += q * q; + ip += q * v; + } + + // ||q - v||² = ||q||² + ||v||² - 2*IP(q, v) + (query_sq + meta.sum_sq - 2.0 * ip).max(0.0) +} + +/// NEON-optimized asymmetric inner product. +/// +/// # Safety +/// - Pointers must be valid for `dim` elements +/// - NEON must be available +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn sq8_asymmetric_inner_product_neon( + query: *const f32, + quantized: *const u8, + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + let min_vec = vdupq_n_f32(meta.min); + let delta_vec = vdupq_n_f32(meta.delta); + + let mut ip_sum = vdupq_n_f32(0.0); + + let chunks = dim / 8; + let remainder = dim % 8; + + for i in 0..chunks { + let offset = i * 8; + + let q0 = vld1q_f32(query.add(offset)); + let q1 = vld1q_f32(query.add(offset + 4)); + + let quant_u8 = vld1_u8(quantized.add(offset)); + let quant_u16 = vmovl_u8(quant_u8); + let quant_lo_u32 = vmovl_u16(vget_low_u16(quant_u16)); + let quant_hi_u32 = vmovl_u16(vget_high_u16(quant_u16)); + let quant_lo_f32 = vcvtq_f32_u32(quant_lo_u32); + let quant_hi_f32 = vcvtq_f32_u32(quant_hi_u32); + + let stored0 = vfmaq_f32(min_vec, quant_lo_f32, delta_vec); + let stored1 = vfmaq_f32(min_vec, quant_hi_f32, delta_vec); + + ip_sum = vfmaq_f32(ip_sum, q0, stored0); + ip_sum = vfmaq_f32(ip_sum, q1, stored1); + } + + let mut ip = vaddvq_f32(ip_sum); + + let base = chunks * 8; + for i in 0..remainder { + let q = *query.add(base + i); + let v = meta.min + (*quantized.add(base + i) as f32) * meta.delta; + ip += q * v; + } + + -ip // Negative for distance ordering +} + +/// NEON-optimized asymmetric cosine distance. 
+/// +/// # Safety +/// - Pointers must be valid for `dim` elements +/// - NEON must be available +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[inline] +pub unsafe fn sq8_asymmetric_cosine_neon( + query: *const f32, + quantized: *const u8, + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + let min_vec = vdupq_n_f32(meta.min); + let delta_vec = vdupq_n_f32(meta.delta); + + let mut query_sq_sum = vdupq_n_f32(0.0); + let mut ip_sum = vdupq_n_f32(0.0); + + let chunks = dim / 8; + let remainder = dim % 8; + + for i in 0..chunks { + let offset = i * 8; + + let q0 = vld1q_f32(query.add(offset)); + let q1 = vld1q_f32(query.add(offset + 4)); + + let quant_u8 = vld1_u8(quantized.add(offset)); + let quant_u16 = vmovl_u8(quant_u8); + let quant_lo_u32 = vmovl_u16(vget_low_u16(quant_u16)); + let quant_hi_u32 = vmovl_u16(vget_high_u16(quant_u16)); + let quant_lo_f32 = vcvtq_f32_u32(quant_lo_u32); + let quant_hi_f32 = vcvtq_f32_u32(quant_hi_u32); + + let stored0 = vfmaq_f32(min_vec, quant_lo_f32, delta_vec); + let stored1 = vfmaq_f32(min_vec, quant_hi_f32, delta_vec); + + query_sq_sum = vfmaq_f32(query_sq_sum, q0, q0); + query_sq_sum = vfmaq_f32(query_sq_sum, q1, q1); + ip_sum = vfmaq_f32(ip_sum, q0, stored0); + ip_sum = vfmaq_f32(ip_sum, q1, stored1); + } + + let mut query_sq = vaddvq_f32(query_sq_sum); + let mut ip = vaddvq_f32(ip_sum); + + let base = chunks * 8; + for i in 0..remainder { + let q = *query.add(base + i); + let v = meta.min + (*quantized.add(base + i) as f32) * meta.delta; + query_sq += q * q; + ip += q * v; + } + + let query_norm = query_sq.sqrt(); + let stored_norm = meta.sum_sq.sqrt(); + + if query_norm < 1e-30 || stored_norm < 1e-30 { + return 1.0; + } + + let cosine_sim = (ip / (query_norm * stored_norm)).clamp(-1.0, 1.0); + 1.0 - cosine_sim +} + +// ============================================================================= +// AVX2 implementations (x86_64) +// ============================================================================= + +/// AVX2-optimized asymmetric L2 squared distance. 
+/// +/// # Safety +/// - Pointers must be valid for `dim` elements +/// - AVX2 must be available +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn sq8_asymmetric_l2_squared_avx2( + query: *const f32, + quantized: *const u8, + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + let min_vec = _mm256_set1_ps(meta.min); + let delta_vec = _mm256_set1_ps(meta.delta); + + let mut query_sq_sum = _mm256_setzero_ps(); + let mut ip_sum = _mm256_setzero_ps(); + + let chunks = dim / 8; + let remainder = dim % 8; + + for i in 0..chunks { + let offset = i * 8; + + // Load 8 query values + let q = _mm256_loadu_ps(query.add(offset)); + + // Load 8 quantized values + let quant_u8 = _mm_loadl_epi64(quantized.add(offset) as *const __m128i); + + // Convert u8 to i32 (zero extend) + let quant_i32 = _mm256_cvtepu8_epi32(quant_u8); + + // Convert i32 to f32 + let quant_f32 = _mm256_cvtepi32_ps(quant_i32); + + // Dequantize: stored = min + quant * delta + let stored = _mm256_fmadd_ps(quant_f32, delta_vec, min_vec); + + // Accumulate query squared + query_sq_sum = _mm256_fmadd_ps(q, q, query_sq_sum); + + // Accumulate inner product + ip_sum = _mm256_fmadd_ps(q, stored, ip_sum); + } + + // Horizontal sum + let mut query_sq = hsum256_ps_avx2(query_sq_sum); + let mut ip = hsum256_ps_avx2(ip_sum); + + // Handle remainder + let base = chunks * 8; + for i in 0..remainder { + let q = *query.add(base + i); + let v = meta.min + (*quantized.add(base + i) as f32) * meta.delta; + query_sq += q * q; + ip += q * v; + } + + (query_sq + meta.sum_sq - 2.0 * ip).max(0.0) +} + +/// AVX2-optimized asymmetric inner product. +/// +/// # Safety +/// - Pointers must be valid for `dim` elements +/// - AVX2 must be available +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn sq8_asymmetric_inner_product_avx2( + query: *const f32, + quantized: *const u8, + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + let min_vec = _mm256_set1_ps(meta.min); + let delta_vec = _mm256_set1_ps(meta.delta); + + let mut ip_sum = _mm256_setzero_ps(); + + let chunks = dim / 8; + let remainder = dim % 8; + + for i in 0..chunks { + let offset = i * 8; + + let q = _mm256_loadu_ps(query.add(offset)); + let quant_u8 = _mm_loadl_epi64(quantized.add(offset) as *const __m128i); + let quant_i32 = _mm256_cvtepu8_epi32(quant_u8); + let quant_f32 = _mm256_cvtepi32_ps(quant_i32); + let stored = _mm256_fmadd_ps(quant_f32, delta_vec, min_vec); + + ip_sum = _mm256_fmadd_ps(q, stored, ip_sum); + } + + let mut ip = hsum256_ps_avx2(ip_sum); + + let base = chunks * 8; + for i in 0..remainder { + let q = *query.add(base + i); + let v = meta.min + (*quantized.add(base + i) as f32) * meta.delta; + ip += q * v; + } + + -ip +} + +/// AVX2-optimized asymmetric cosine distance. 
+/// +/// # Safety +/// - Pointers must be valid for `dim` elements +/// - AVX2 must be available +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn sq8_asymmetric_cosine_avx2( + query: *const f32, + quantized: *const u8, + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + let min_vec = _mm256_set1_ps(meta.min); + let delta_vec = _mm256_set1_ps(meta.delta); + + let mut query_sq_sum = _mm256_setzero_ps(); + let mut ip_sum = _mm256_setzero_ps(); + + let chunks = dim / 8; + let remainder = dim % 8; + + for i in 0..chunks { + let offset = i * 8; + + let q = _mm256_loadu_ps(query.add(offset)); + let quant_u8 = _mm_loadl_epi64(quantized.add(offset) as *const __m128i); + let quant_i32 = _mm256_cvtepu8_epi32(quant_u8); + let quant_f32 = _mm256_cvtepi32_ps(quant_i32); + let stored = _mm256_fmadd_ps(quant_f32, delta_vec, min_vec); + + query_sq_sum = _mm256_fmadd_ps(q, q, query_sq_sum); + ip_sum = _mm256_fmadd_ps(q, stored, ip_sum); + } + + let mut query_sq = hsum256_ps_avx2(query_sq_sum); + let mut ip = hsum256_ps_avx2(ip_sum); + + let base = chunks * 8; + for i in 0..remainder { + let q = *query.add(base + i); + let v = meta.min + (*quantized.add(base + i) as f32) * meta.delta; + query_sq += q * q; + ip += q * v; + } + + let query_norm = query_sq.sqrt(); + let stored_norm = meta.sum_sq.sqrt(); + + if query_norm < 1e-30 || stored_norm < 1e-30 { + return 1.0; + } + + let cosine_sim = (ip / (query_norm * stored_norm)).clamp(-1.0, 1.0); + 1.0 - cosine_sim +} + +/// Horizontal sum helper for AVX2. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +#[inline] +unsafe fn hsum256_ps_avx2(v: __m256) -> f32 { + let high = _mm256_extractf128_ps(v, 1); + let low = _mm256_castps256_ps128(v); + let sum128 = _mm_add_ps(high, low); + let shuf = _mm_movehdup_ps(sum128); + let sums = _mm_add_ps(sum128, shuf); + let shuf = _mm_movehl_ps(sums, sums); + let sums = _mm_add_ss(sums, shuf); + _mm_cvtss_f32(sums) +} + +// ============================================================================= +// Safe wrappers with runtime dispatch +// ============================================================================= + +/// Compute SQ8 asymmetric L2 squared distance with SIMD acceleration. +#[inline] +pub fn sq8_l2_squared_simd( + query: &[f32], + quantized: &[u8], + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + debug_assert_eq!(query.len(), dim); + debug_assert_eq!(quantized.len(), dim); + + #[cfg(target_arch = "aarch64")] + { + unsafe { sq8_asymmetric_l2_squared_neon(query.as_ptr(), quantized.as_ptr(), meta, dim) } + } + + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + unsafe { + sq8_asymmetric_l2_squared_avx2(query.as_ptr(), quantized.as_ptr(), meta, dim) + } + } else { + super::sq8::sq8_asymmetric_l2_squared(query, quantized, meta, dim) + } + } + + #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))] + { + super::sq8::sq8_asymmetric_l2_squared(query, quantized, meta, dim) + } +} + +/// Compute SQ8 asymmetric inner product with SIMD acceleration. 
+#[inline] +pub fn sq8_inner_product_simd( + query: &[f32], + quantized: &[u8], + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + debug_assert_eq!(query.len(), dim); + debug_assert_eq!(quantized.len(), dim); + + #[cfg(target_arch = "aarch64")] + { + unsafe { sq8_asymmetric_inner_product_neon(query.as_ptr(), quantized.as_ptr(), meta, dim) } + } + + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + unsafe { + sq8_asymmetric_inner_product_avx2(query.as_ptr(), quantized.as_ptr(), meta, dim) + } + } else { + super::sq8::sq8_asymmetric_inner_product(query, quantized, meta, dim) + } + } + + #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))] + { + super::sq8::sq8_asymmetric_inner_product(query, quantized, meta, dim) + } +} + +/// Compute SQ8 asymmetric cosine distance with SIMD acceleration. +#[inline] +pub fn sq8_cosine_simd( + query: &[f32], + quantized: &[u8], + meta: &Sq8VectorMeta, + dim: usize, +) -> f32 { + debug_assert_eq!(query.len(), dim); + debug_assert_eq!(quantized.len(), dim); + + #[cfg(target_arch = "aarch64")] + { + unsafe { sq8_asymmetric_cosine_neon(query.as_ptr(), quantized.as_ptr(), meta, dim) } + } + + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { + unsafe { sq8_asymmetric_cosine_avx2(query.as_ptr(), quantized.as_ptr(), meta, dim) } + } else { + super::sq8::sq8_asymmetric_cosine(query, quantized, meta, dim) + } + } + + #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))] + { + super::sq8::sq8_asymmetric_cosine(query, quantized, meta, dim) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::quantization::sq8::Sq8Codec; + + fn create_test_data(dim: usize) -> (Vec<f32>, Vec<u8>, Sq8VectorMeta) { + let codec = Sq8Codec::new(dim); + let stored: Vec<f32> = (0..dim).map(|i| (i as f32) / (dim as f32)).collect(); + let (quantized, meta) = codec.encode(&stored); + let quantized_bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect(); + (stored, quantized_bytes, meta) + } + + #[test] + fn test_sq8_simd_l2_accuracy() { + let dim = 128; + let (_stored, quantized, meta) = create_test_data(dim); + let query: Vec<f32> = (0..dim).map(|i| (i as f32 + 0.5) / (dim as f32)).collect(); + + let scalar_result = + crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &quantized, &meta, dim); + let simd_result = sq8_l2_squared_simd(&query, &quantized, &meta, dim); + + assert!( + (scalar_result - simd_result).abs() < 0.001, + "scalar={}, simd={}", + scalar_result, + simd_result + ); + } + + #[test] + fn test_sq8_simd_ip_accuracy() { + let dim = 128; + let (_, quantized, meta) = create_test_data(dim); + let query: Vec<f32> = (0..dim).map(|i| (i as f32) / (dim as f32)).collect(); + + let scalar_result = + crate::quantization::sq8::sq8_asymmetric_inner_product(&query, &quantized, &meta, dim); + let simd_result = sq8_inner_product_simd(&query, &quantized, &meta, dim); + + assert!( + (scalar_result - simd_result).abs() < 0.001, + "scalar={}, simd={}", + scalar_result, + simd_result + ); + } + + #[test] + fn test_sq8_simd_cosine_accuracy() { + let dim = 128; + let (_, quantized, meta) = create_test_data(dim); + let query: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) / (dim as f32)).collect(); + + let scalar_result = + crate::quantization::sq8::sq8_asymmetric_cosine(&query, &quantized, &meta, dim); + let simd_result = sq8_cosine_simd(&query, &quantized, &meta, dim); + + assert!( + (scalar_result - simd_result).abs() < 0.001, + "scalar={}, simd={}", + scalar_result,
+ simd_result + ); + } + + #[test] + fn test_sq8_simd_remainder_handling() { + // Test with dimensions that don't align to SIMD width + for dim in [17, 33, 65, 100, 127] { + let (_, quantized, meta) = create_test_data(dim); + let query: Vec<f32> = (0..dim).map(|i| (i as f32) / (dim as f32)).collect(); + + let scalar_result = + crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &quantized, &meta, dim); + let simd_result = sq8_l2_squared_simd(&query, &quantized, &meta, dim); + + assert!( + (scalar_result - simd_result).abs() < 0.001, + "dim={}: scalar={}, simd={}", + dim, + scalar_result, + simd_result + ); + } + } + + #[test] + fn test_sq8_simd_self_distance() { + let dim = 128; + let codec = Sq8Codec::new(dim); + let stored: Vec<f32> = (0..dim).map(|i| (i as f32) / (dim as f32)).collect(); + let (quantized, meta) = codec.encode(&stored); + let quantized_bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect(); + + // Distance to self should be small (due to quantization error) + let dist = sq8_l2_squared_simd(&stored, &quantized_bytes, &meta, dim); + assert!(dist < 0.01, "Self distance should be small, got {}", dist); + } + + #[test] + fn test_sq8_simd_negative_values() { + let dim = 64; + let codec = Sq8Codec::new(dim); + let stored: Vec<f32> = (0..dim).map(|i| ((i as f32) / (dim as f32)) - 0.5).collect(); + let (quantized, meta) = codec.encode(&stored); + let quantized_bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect(); + + let query: Vec<f32> = (0..dim).map(|i| ((i as f32 + 0.25) / (dim as f32)) - 0.5).collect(); + + let scalar_l2 = + crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &quantized_bytes, &meta, dim); + let simd_l2 = sq8_l2_squared_simd(&query, &quantized_bytes, &meta, dim); + + assert!( + (scalar_l2 - simd_l2).abs() < 0.001, + "L2: scalar={}, simd={}", + scalar_l2, + simd_l2 + ); + + let scalar_ip = + crate::quantization::sq8::sq8_asymmetric_inner_product(&query, &quantized_bytes, &meta, dim); + let simd_ip = sq8_inner_product_simd(&query, &quantized_bytes, &meta, dim); + + assert!( + (scalar_ip - simd_ip).abs() < 0.001, + "IP: scalar={}, simd={}", + scalar_ip, + simd_ip + ); + + let scalar_cos = + crate::quantization::sq8::sq8_asymmetric_cosine(&query, &quantized_bytes, &meta, dim); + let simd_cos = sq8_cosine_simd(&query, &quantized_bytes, &meta, dim); + + assert!( + (scalar_cos - simd_cos).abs() < 0.001, + "Cosine: scalar={}, simd={}", + scalar_cos, + simd_cos + ); + } + + #[test] + fn test_sq8_simd_large_vectors() { + let dim = 768; // Common embedding dimension + let (_, quantized, meta) = create_test_data(dim); + let query: Vec<f32> = (0..dim).map(|i| (i as f32) / (dim as f32)).collect(); + + let scalar_result = + crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &quantized, &meta, dim); + let simd_result = sq8_l2_squared_simd(&query, &quantized, &meta, dim); + + assert!( + (scalar_result - simd_result).abs() < 0.01, + "dim={}: scalar={}, simd={}", + dim, + scalar_result, + simd_result + ); + }
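// Editorial sketch: every SIMD kernel above uses the same decomposition that
// test_sq8_simd_remainder_handling exercises -- a vectorized loop over 8-wide
// chunks plus a scalar tail for the dim % 8 leftovers. A scalar model of that
// split (illustrative; no intrinsics):
fn chunked_sum(data: &[f32]) -> f32 {
    let chunks = data.len() / 8;
    let mut acc = 0.0f32;
    for c in 0..chunks {
        // One iteration here corresponds to one NEON/AVX2 step in the kernels.
        for i in 0..8 {
            acc += data[c * 8 + i];
        }
    }
    // Scalar tail: dims like 17, 33, 65, 100, 127 from the test land here.
    for &x in &data[chunks * 8..] {
        acc += x;
    }
    acc
}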
+ #[test] + fn test_sq8_simd_zero_vector() { + let dim = 32; + let codec = Sq8Codec::new(dim); + let stored: Vec<f32> = vec![0.0; dim]; + let (quantized, meta) = codec.encode(&stored); + let quantized_bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect(); + + let query: Vec<f32> = vec![1.0; dim]; + + let scalar_result = + crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &quantized_bytes, &meta, dim); + let simd_result = sq8_l2_squared_simd(&query, &quantized_bytes, &meta, dim); + + assert!( + (scalar_result - simd_result).abs() < 0.001, + "scalar={}, simd={}", + scalar_result, + simd_result + ); + + // Distance should be approximately dim (sum of 1^2) + assert!((simd_result - dim as f32).abs() < 0.1, "Expected ~{}, got {}", dim, simd_result); + } + + #[test] + fn test_sq8_simd_cosine_orthogonal() { + let dim = 64; + let codec = Sq8Codec::new(dim); + + // Vector pointing along first half of dimensions + let mut v1 = vec![0.0; dim]; + for i in 0..dim/2 { + v1[i] = 1.0; + } + + // Vector pointing along second half of dimensions + let mut v2 = vec![0.0; dim]; + for i in dim/2..dim { + v2[i] = 1.0; + } + + let (quantized1, meta1) = codec.encode(&v1); + let quantized_bytes1: Vec<u8> = quantized1.iter().map(|q| q.get()).collect(); + + let (quantized2, meta2) = codec.encode(&v2); + let quantized_bytes2: Vec<u8> = quantized2.iter().map(|q| q.get()).collect(); + + // v1 vs quantized v2 should have cosine distance ~1 (orthogonal) + let scalar_cos = crate::quantization::sq8::sq8_asymmetric_cosine(&v1, &quantized_bytes2, &meta2, dim); + let simd_cos = sq8_cosine_simd(&v1, &quantized_bytes2, &meta2, dim); + + assert!( + (scalar_cos - simd_cos).abs() < 0.01, + "Orthogonal cosine: scalar={}, simd={}", + scalar_cos, + simd_cos + ); + assert!(simd_cos > 0.9, "Orthogonal vectors should have cosine distance near 1, got {}", simd_cos); + + // v1 vs quantized v1 should have cosine distance ~0 (same direction) + let scalar_self = crate::quantization::sq8::sq8_asymmetric_cosine(&v1, &quantized_bytes1, &meta1, dim); + let simd_self = sq8_cosine_simd(&v1, &quantized_bytes1, &meta1, dim); + + assert!( + (scalar_self - simd_self).abs() < 0.01, + "Self cosine: scalar={}, simd={}", + scalar_self, + simd_self + ); + assert!(simd_self < 0.1, "Same direction should have cosine distance near 0, got {}", simd_self); + } + + #[test] + fn test_sq8_simd_distance_ordering_consistency() { + let dim = 128; + let codec = Sq8Codec::new(dim); + + let v1: Vec<f32> = (0..dim).map(|i| (i as f32) / (dim as f32)).collect(); + let v2: Vec<f32> = (0..dim).map(|i| ((i + dim/4) as f32 % dim as f32) / (dim as f32)).collect(); + let v3: Vec<f32> = (0..dim).map(|i| (1.0 - (i as f32) / (dim as f32))).collect(); + + let query = v1.clone(); + + let (q1, m1) = codec.encode(&v1); + let (q2, m2) = codec.encode(&v2); + let (q3, m3) = codec.encode(&v3); + + let b1: Vec<u8> = q1.iter().map(|q| q.get()).collect(); + let b2: Vec<u8> = q2.iter().map(|q| q.get()).collect(); + let b3: Vec<u8> = q3.iter().map(|q| q.get()).collect(); + + // Test L2 + let d1_scalar = crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &b1, &m1, dim); + let d2_scalar = crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &b2, &m2, dim); + let d3_scalar = crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &b3, &m3, dim); + + let d1_simd = sq8_l2_squared_simd(&query, &b1, &m1, dim); + let d2_simd = sq8_l2_squared_simd(&query, &b2, &m2, dim); + let d3_simd = sq8_l2_squared_simd(&query, &b3, &m3, dim); + + // Ordering should be preserved + if d1_scalar < d2_scalar { + assert!(d1_simd < d2_simd, "Ordering mismatch: d1_simd={}, d2_simd={}", d1_simd, d2_simd); + } + if d2_scalar < d3_scalar { + assert!(d2_simd < d3_simd, "Ordering mismatch: d2_simd={}, d3_simd={}", d2_simd, d3_simd); + } + } + + #[test] + fn test_sq8_simd_small_dimensions() { + // Test dimensions smaller than SIMD width + for dim in [4, 7, 8, 15, 16] { + let (_, quantized, meta) = create_test_data(dim); + let query: Vec<f32> = (0..dim).map(|i| (i as f32 + 0.5) / (dim as f32)).collect(); + + let scalar_l2 = + crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &quantized, &meta,
dim); + let simd_l2 = sq8_l2_squared_simd(&query, &quantized, &meta, dim); + + assert!( + (scalar_l2 - simd_l2).abs() < 0.001, + "dim={}: L2 scalar={}, simd={}", + dim, + scalar_l2, + simd_l2 + ); + + let scalar_ip = + crate::quantization::sq8::sq8_asymmetric_inner_product(&query, &quantized, &meta, dim); + let simd_ip = sq8_inner_product_simd(&query, &quantized, &meta, dim); + + assert!( + (scalar_ip - simd_ip).abs() < 0.001, + "dim={}: IP scalar={}, simd={}", + dim, + scalar_ip, + simd_ip + ); + } + } + + #[test] + fn test_sq8_simd_uniform_values() { + let dim = 64; + let codec = Sq8Codec::new(dim); + let stored: Vec<f32> = vec![0.5; dim]; + let (quantized, meta) = codec.encode(&stored); + let quantized_bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect(); + + let query: Vec<f32> = vec![0.6; dim]; + + let scalar_l2 = + crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &quantized_bytes, &meta, dim); + let simd_l2 = sq8_l2_squared_simd(&query, &quantized_bytes, &meta, dim); + + assert!( + (scalar_l2 - simd_l2).abs() < 0.001, + "Uniform L2: scalar={}, simd={}", + scalar_l2, + simd_l2 + ); + } + + #[test] + fn test_sq8_simd_large_values() { + let dim = 64; + let codec = Sq8Codec::new(dim); + let stored: Vec<f32> = (0..dim).map(|i| (i as f32) * 100.0).collect(); + let (quantized, meta) = codec.encode(&stored); + let quantized_bytes: Vec<u8> = quantized.iter().map(|q| q.get()).collect(); + + let query: Vec<f32> = (0..dim).map(|i| (i as f32 + 0.5) * 100.0).collect(); + + let scalar_l2 = + crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &quantized_bytes, &meta, dim); + let simd_l2 = sq8_l2_squared_simd(&query, &quantized_bytes, &meta, dim); + + // For large values, allow slightly larger tolerance + let relative_error = (scalar_l2 - simd_l2).abs() / scalar_l2.max(1.0); + assert!( + relative_error < 0.001, + "Large values L2: scalar={}, simd={}, relative_error={}", + scalar_l2, + simd_l2, + relative_error + ); + } + + #[test] + fn test_sq8_simd_all_distances_consistency() { + // Test that all three distance functions are consistent between SIMD and scalar + let dim = 256; + let (_, quantized, meta) = create_test_data(dim); + let query: Vec<f32> = (0..dim).map(|i| (i as f32 + 0.3) / (dim as f32)).collect(); + + // L2 squared + let scalar_l2 = + crate::quantization::sq8::sq8_asymmetric_l2_squared(&query, &quantized, &meta, dim); + let simd_l2 = sq8_l2_squared_simd(&query, &quantized, &meta, dim); + assert!( + (scalar_l2 - simd_l2).abs() < 0.01, + "L2: scalar={}, simd={}", + scalar_l2, + simd_l2 + ); + + // Inner product + let scalar_ip = + crate::quantization::sq8::sq8_asymmetric_inner_product(&query, &quantized, &meta, dim); + let simd_ip = sq8_inner_product_simd(&query, &quantized, &meta, dim); + assert!( + (scalar_ip - simd_ip).abs() < 0.01, + "IP: scalar={}, simd={}", + scalar_ip, + simd_ip + ); + + // Cosine + let scalar_cos = + crate::quantization::sq8::sq8_asymmetric_cosine(&query, &quantized, &meta, dim); + let simd_cos = sq8_cosine_simd(&query, &quantized, &meta, dim); + assert!( + (scalar_cos - simd_cos).abs() < 0.001, + "Cosine: scalar={}, simd={}", + scalar_cos, + simd_cos + ); + } +} diff --git a/rust/vecsim/src/query/hybrid.rs b/rust/vecsim/src/query/hybrid.rs new file mode 100644 index 000000000..37d14b525 --- /dev/null +++ b/rust/vecsim/src/query/hybrid.rs @@ -0,0 +1,462 @@ +//! Hybrid search heuristics for choosing between ad-hoc and batch search. +//! +//! When performing filtered queries (queries that combine a filter with similarity search), +//! there are two strategies: +//! +//!
1. **Ad-hoc brute force**: Iterate through all filtered vectors and compute distances directly +//! 2. **Batch iterator**: Use the index's batch iterator to progressively retrieve results +//! +//! The `prefer_adhoc_search` heuristic uses a decision tree classifier to choose the optimal +//! strategy based on index characteristics and query parameters. +//! +//! ## Decision Factors +//! +//! - `index_size`: Total number of vectors in the index +//! - `subset_size`: Estimated number of vectors passing the filter +//! - `k`: Number of results requested +//! - `dim`: Vector dimensionality +//! - `m`: HNSW connectivity parameter (for HNSW indices) + +/// Search mode tracking for hybrid queries. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum SearchMode { + /// Initial state - no search performed yet. + #[default] + Empty, + /// Standard k-NN query (no filtering). + StandardKnn, + /// Hybrid query using ad-hoc brute force from the start. + HybridAdhocBf, + /// Hybrid query using batch iterator. + HybridBatches, + /// Hybrid query that switched from batches to ad-hoc brute force. + HybridBatchesToAdhocBf, + /// Range query mode. + RangeQuery, +} + +/// Result ordering options for query results. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum QueryResultOrder { + /// Sort by distance score (ascending - closest first). + #[default] + ByScore, + /// Sort by vector ID. + ById, + /// Primary sort by score, secondary by ID. + ByScoreThenId, +} + +/// Hybrid search policy for filtered queries. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum HybridPolicy { + /// Automatically choose the best strategy. + #[default] + Auto, + /// Force ad-hoc brute force search. + ForceAdhocBf, + /// Force batch iterator search. + ForceBatches, +} + +/// Parameters for hybrid search decision. +#[derive(Debug, Clone)] +pub struct HybridSearchParams { + /// Total vectors in the index. + pub index_size: usize, + /// Estimated vectors passing the filter. + pub subset_size: usize, + /// Number of results requested. + pub k: usize, + /// Vector dimensionality. + pub dim: usize, + /// Whether this is the initial check (true) or re-evaluation (false). + pub initial_check: bool, +} + +/// Decision tree heuristic for BruteForce indices. +/// +/// This implements a 10-leaf decision tree trained on empirical data. +/// Returns true if ad-hoc brute force is preferred, false for batch iterator. 
+pub fn prefer_adhoc_brute_force(params: &HybridSearchParams, label_count: usize) -> bool { + let index_size = params.index_size; + let dim = params.dim; + + // Calculate ratio r = subset_size / label_count + let r = if label_count == 0 { + 0.0 + } else { + (params.subset_size.min(label_count) as f64) / (label_count as f64) + }; + + // Decision tree based on C++ implementation + // (sklearn DecisionTreeClassifier with 10 leaves) + + if index_size <= 5500 { + return true; // Ad-hoc for small indices + } + + if dim <= 300 { + if r <= 0.15 { + return true; // Ad-hoc for small filter ratios + } + if r <= 0.35 { + if dim <= 75 { + return false; // Batches for very low dimensions + } + if index_size <= 550_000 { + return true; // Ad-hoc for medium index size + } + return false; // Batches for large indices + } + return false; // Batches for large filter ratios + } + + // dim > 300 + if r <= 0.025 { + return true; // Ad-hoc for very small filter ratios + } + if r <= 0.55 { + if index_size <= 55_000 { + return true; // Ad-hoc for smaller indices + } + if r <= 0.045 { + return true; // Ad-hoc for small filter ratios + } + return false; // Batches otherwise + } + + false // Batches for large filter ratios +} + +/// Decision tree heuristic for HNSW indices. +/// +/// This implements a 20-leaf decision tree trained on empirical data. +/// HNSW considers additional parameters like k and M (connectivity). +pub fn prefer_adhoc_hnsw(params: &HybridSearchParams, label_count: usize, m: usize) -> bool { + let index_size = params.index_size; + let dim = params.dim; + let k = params.k; + + // Calculate ratio + let r = if label_count == 0 { + 0.0 + } else { + (params.subset_size.min(label_count) as f64) / (label_count as f64) + }; + + // Decision tree based on C++ implementation (20 leaves) + + if index_size <= 30_000 { + // Small to medium index + if index_size <= 5500 { + return true; // Always ad-hoc for very small indices + } + if r <= 0.17 { + return true; // Ad-hoc for small filter ratios + } + if k <= 12 { + if dim <= 55 { + return false; // Batches for small dimensions + } + if m <= 10 { + return false; // Batches for low connectivity + } + return true; // Ad-hoc otherwise + } + return true; // Ad-hoc for larger k + } + + // Large index (> 30000) + if r <= 0.025 { + // Very small filter ratio + if index_size <= 750_000 { + return true; + } + if r <= 0.0035 { + return true; + } + if k <= 45 { + return false; + } + return true; + } + + if r <= 0.085 { + // Small filter ratio + if index_size <= 75_000 { + return true; + } + if k <= 7 { + if dim <= 25 { + return false; + } + if m <= 10 { + return false; + } + return true; + } + if dim <= 55 { + return false; + } + if m <= 24 { + if index_size <= 550_000 { + return true; + } + return false; + } + return true; + } + + if r <= 0.175 { + // Medium filter ratio + if index_size <= 75_000 { + if k <= 7 { + if dim <= 25 { + return false; + } + return true; + } + return true; + } + if k <= 45 { + return false; + } + return true; + } + + // Large filter ratio (> 0.175) + if index_size <= 75_000 { + if k <= 12 { + if dim <= 55 { + return false; + } + if m <= 10 { + return false; + } + return true; + } + return true; + } + + false // Batches for large indices with large filter ratios +} + +/// Threshold-based heuristic for SVS indices. +/// +/// SVS uses a simpler threshold-based approach rather than a full decision tree. 
+pub fn prefer_adhoc_svs(params: &HybridSearchParams) -> bool { + let index_size = params.index_size; + let k = params.k; + + // Calculate ratio + let subset_ratio = if index_size == 0 { + 0.0 + } else { + (params.subset_size as f64) / (index_size as f64) + }; + + // Thresholds from C++ implementation + const SMALL_SUBSET_THRESHOLD: f64 = 0.07; // 7% of index + const LARGE_SUBSET_THRESHOLD: f64 = 0.21; // 21% of index + const SMALL_INDEX_THRESHOLD: usize = 75_000; // 75k vectors + const LARGE_INDEX_THRESHOLD: usize = 750_000; // 750k vectors + + if subset_ratio < SMALL_SUBSET_THRESHOLD { + // Small subset: ad-hoc if index is not large + index_size < LARGE_INDEX_THRESHOLD + } else if subset_ratio < LARGE_SUBSET_THRESHOLD { + // Medium subset: ad-hoc if index is small OR k is large + index_size < SMALL_INDEX_THRESHOLD || k > 12 + } else { + // Large subset: ad-hoc only if index is small + index_size < SMALL_INDEX_THRESHOLD + } +} + +/// Heuristic for Tiered indices (uses the larger sub-index's heuristic). +pub fn prefer_adhoc_tiered<F, B>( + params: &HybridSearchParams, + frontend_size: usize, + backend_size: usize, + frontend_heuristic: F, + backend_heuristic: B, +) -> bool +where + F: FnOnce(&HybridSearchParams) -> bool, + B: FnOnce(&HybridSearchParams) -> bool, +{ + if backend_size > frontend_size { + backend_heuristic(params) + } else { + frontend_heuristic(params) + } +} + +/// Determine the search mode based on the decision and initial_check flag. +pub fn determine_search_mode(prefer_adhoc: bool, initial_check: bool) -> SearchMode { + if prefer_adhoc { + if initial_check { + SearchMode::HybridAdhocBf + } else { + SearchMode::HybridBatchesToAdhocBf + } + } else { + SearchMode::HybridBatches + } +}
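// Editorial sketch: how the pieces above compose -- estimate the filtered
// subset, ask the per-index heuristic, then record the chosen mode.
// Illustrative only; the index types that drive this are not in this diff.
fn choose_strategy_sketch() {
    let params = HybridSearchParams {
        index_size: 200_000,
        subset_size: 10_000, // e.g. a 5% filter-cardinality estimate
        k: 10,
        dim: 128,
        initial_check: true,
    };
    // 5% is below the 7% SMALL_SUBSET_THRESHOLD and 200k < 750k, so the SVS
    // heuristic prefers ad-hoc brute force here.
    let adhoc = prefer_adhoc_svs(&params);
    let mode = determine_search_mode(adhoc, params.initial_check);
    assert_eq!(mode, SearchMode::HybridAdhocBf);
}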
+#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_brute_force_small_index() { + let params = HybridSearchParams { + index_size: 1000, + subset_size: 500, + k: 10, + dim: 128, + initial_check: true, + }; + + // Small index should prefer ad-hoc + assert!(prefer_adhoc_brute_force(&params, 1000)); + } + + #[test] + fn test_brute_force_large_index_small_ratio() { + let params = HybridSearchParams { + index_size: 1_000_000, + subset_size: 10_000, // 1% ratio + k: 10, + dim: 128, + initial_check: true, + }; + + // Large index with small ratio - check decision + let result = prefer_adhoc_brute_force(&params, 1_000_000); + // With 1% ratio and dim=128, should prefer ad-hoc + assert!(result); + } + + #[test] + fn test_brute_force_large_index_large_ratio() { + let params = HybridSearchParams { + index_size: 1_000_000, + subset_size: 500_000, // 50% ratio + k: 10, + dim: 128, + initial_check: true, + }; + + // Large index with large ratio should prefer batches + assert!(!prefer_adhoc_brute_force(&params, 1_000_000)); + } + + #[test] + fn test_hnsw_small_index() { + let params = HybridSearchParams { + index_size: 1000, + subset_size: 500, + k: 10, + dim: 128, + initial_check: true, + }; + + // Small index should prefer ad-hoc + assert!(prefer_adhoc_hnsw(&params, 1000, 16)); + } + + #[test] + fn test_hnsw_considers_k() { + let params_small_k = HybridSearchParams { + index_size: 100_000, + subset_size: 10_000, // 10% ratio + k: 5, + dim: 128, + initial_check: true, + }; + + let params_large_k = HybridSearchParams { + index_size: 100_000, + subset_size: 10_000, + k: 100, + dim: 128, + initial_check: true, + }; + + // k affects the decision for HNSW + let _result_small_k = prefer_adhoc_hnsw(&params_small_k, 100_000, 16); + let _result_large_k = prefer_adhoc_hnsw(&params_large_k, 100_000, 16); + // Both are valid results; k does affect the decision + } + + #[test] + fn test_svs_thresholds() { + // Small subset ratio (< 7%) + let params_small = HybridSearchParams { + index_size: 100_000, + subset_size: 5_000, // 5% + k: 10, + dim: 128, + initial_check: true, + }; + assert!(prefer_adhoc_svs(&params_small)); // Should prefer ad-hoc + + // Medium subset ratio (7-21%) + let params_medium = HybridSearchParams { + index_size: 100_000, + subset_size: 15_000, // 15% + k: 10, + dim: 128, + initial_check: true, + }; + // Medium with large k should prefer ad-hoc + let params_medium_large_k = HybridSearchParams { + k: 20, + ..params_medium + }; + assert!(prefer_adhoc_svs(&params_medium_large_k)); + + // Large subset ratio (> 21%) with small index + let params_large = HybridSearchParams { + index_size: 50_000, + subset_size: 25_000, // 50% + k: 10, + dim: 128, + initial_check: true, + }; + assert!(prefer_adhoc_svs(&params_large)); // Small index prefers ad-hoc + } + + #[test] + fn test_search_mode_determination() { + assert_eq!( + determine_search_mode(true, true), + SearchMode::HybridAdhocBf + ); + assert_eq!( + determine_search_mode(true, false), + SearchMode::HybridBatchesToAdhocBf + ); + assert_eq!( + determine_search_mode(false, true), + SearchMode::HybridBatches + ); + assert_eq!( + determine_search_mode(false, false), + SearchMode::HybridBatches + ); + } + + #[test] + fn test_query_result_order() { + assert_eq!(QueryResultOrder::default(), QueryResultOrder::ByScore); + } + + #[test] + fn test_hybrid_policy() { + assert_eq!(HybridPolicy::default(), HybridPolicy::Auto); + } +} diff --git a/rust/vecsim/src/query/mod.rs b/rust/vecsim/src/query/mod.rs new file mode 100644 index 000000000..3381c3571 --- /dev/null +++ b/rust/vecsim/src/query/mod.rs @@ -0,0 +1,20 @@ +//! Query parameter and result types. +//! +//! This module provides types for configuring queries and handling results: +//! - `QueryParams`: Configuration for query execution +//! - `QueryResult`: A single result (label + distance) +//! - `QueryReply`: Collection of query results +//! - `TimeoutChecker`: Efficient timeout checking during search +//! - `CancellationToken`: Thread-safe cancellation mechanism +//! - `hybrid`: Heuristics for choosing between ad-hoc and batch search + +pub mod hybrid; +pub mod params; +pub mod results; + +pub use hybrid::{ + determine_search_mode, prefer_adhoc_brute_force, prefer_adhoc_hnsw, prefer_adhoc_svs, + HybridPolicy, HybridSearchParams, QueryResultOrder, SearchMode, +}; +pub use params::{CancellationToken, QueryParams, TimeoutChecker}; +pub use results::{QueryReply, QueryResult}; diff --git a/rust/vecsim/src/query/params.rs b/rust/vecsim/src/query/params.rs new file mode 100644 index 000000000..b6c00b355 --- /dev/null +++ b/rust/vecsim/src/query/params.rs @@ -0,0 +1,591 @@ +//! Query parameter configuration. + +use crate::types::LabelType; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +/// Parameters for controlling query execution. +#[derive(Default)] +pub struct QueryParams { + /// For HNSW: the size of the dynamic candidate list during search (ef). + /// Higher values improve recall at the cost of speed. + /// If None, uses the index's default ef_runtime value. + pub ef_runtime: Option<usize>, + + /// Maximum number of results to return. + /// For batch iterators, this may be used to hint at batch sizes. + pub batch_size: Option<usize>, + + /// Filter function to exclude certain labels from results. + /// If Some, only vectors whose labels pass the filter are included.
+ pub filter: Option bool + Send + Sync>>, + + /// Enable parallel query execution if supported. + pub parallel: bool, + + /// Timeout callback function. + /// Returns true if the query should be cancelled. + /// This is checked periodically during search operations. + pub timeout_callback: Option bool + Send + Sync>>, + + /// Query timeout duration. + /// If set, creates an automatic timeout based on elapsed time. + pub timeout: Option, + + /// Epsilon for range query search boundary expansion. + /// Controls how far beyond the current best distance to search. + /// E.g., 0.01 means search 1% beyond the dynamic range. + /// If None, uses default of 0.01. + pub epsilon: Option, +} + +impl std::fmt::Debug for QueryParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QueryParams") + .field("ef_runtime", &self.ef_runtime) + .field("batch_size", &self.batch_size) + .field("filter", &self.filter.as_ref().map(|_| "")) + .field("parallel", &self.parallel) + .field( + "timeout_callback", + &self.timeout_callback.as_ref().map(|_| ""), + ) + .field("timeout", &self.timeout) + .finish() + } +} + +impl Clone for QueryParams { + fn clone(&self) -> Self { + Self { + ef_runtime: self.ef_runtime, + batch_size: self.batch_size, + filter: None, // Filter cannot be cloned + parallel: self.parallel, + timeout_callback: None, // Callback cannot be cloned + timeout: self.timeout, + epsilon: self.epsilon, + } + } +} + + +impl QueryParams { + /// Create new query parameters with default values. + pub fn new() -> Self { + Self::default() + } + + /// Set the ef_runtime parameter for HNSW search. + pub fn with_ef_runtime(mut self, ef: usize) -> Self { + self.ef_runtime = Some(ef); + self + } + + /// Set the batch size hint. + pub fn with_batch_size(mut self, size: usize) -> Self { + self.batch_size = Some(size); + self + } + + /// Set a filter function. + pub fn with_filter(mut self, filter: F) -> Self + where + F: Fn(LabelType) -> bool + Send + Sync + 'static, + { + self.filter = Some(Box::new(filter)); + self + } + + /// Enable parallel query execution. + pub fn with_parallel(mut self, parallel: bool) -> Self { + self.parallel = parallel; + self + } + + /// Check if a label passes the filter (if any). + #[inline] + pub fn passes_filter(&self, label: LabelType) -> bool { + self.filter.as_ref().is_none_or(|f| f(label)) + } + + /// Set a timeout callback function. + /// + /// The callback is invoked periodically during search. If it returns `true`, + /// the search is cancelled and returns with partial results or an error. + pub fn with_timeout_callback(mut self, callback: F) -> Self + where + F: Fn() -> bool + Send + Sync + 'static, + { + self.timeout_callback = Some(Box::new(callback)); + self + } + + /// Set a timeout duration. + /// + /// The query will be cancelled if it exceeds this duration. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } + + /// Set a timeout in milliseconds. + pub fn with_timeout_ms(self, ms: u64) -> Self { + self.with_timeout(Duration::from_millis(ms)) + } + + /// Set epsilon for range query search boundary expansion. + /// + /// Controls how far beyond the current best distance to search. + /// E.g., 0.01 means search 1% beyond the dynamic range. + pub fn with_epsilon(mut self, epsilon: f64) -> Self { + self.epsilon = Some(epsilon); + self + } + + /// Create a timeout checker that can be used during search. + /// + /// Returns a TimeoutChecker if a timeout duration is set. 
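+
+// Illustrative sketch (not part of the original change): a typical parameter
+// set for a filtered, time-bounded query, built with the methods defined on
+// `QueryParams` above. The label predicate is an arbitrary example.
+#[cfg(test)]
+mod builder_example {
+    use super::*;
+
+    #[test]
+    fn filtered_query_with_deadline() {
+        let params = QueryParams::new()
+            .with_ef_runtime(200)
+            .with_filter(|label| label < 1_000) // keep only low label ids
+            .with_timeout_ms(50);
+
+        assert!(params.passes_filter(1));
+        assert!(!params.passes_filter(1_000));
+        assert_eq!(params.timeout, Some(Duration::from_millis(50)));
+    }
+}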
+
+/// Helper struct for efficient timeout checking during search.
+///
+/// This struct caches the start time and provides efficient timeout checking
+/// with configurable check intervals to minimize overhead.
+pub struct TimeoutChecker {
+    start_time: Instant,
+    timeout: Option<Duration>,
+    check_interval: usize,
+    check_counter: usize,
+    timed_out: bool,
+}
+
+impl TimeoutChecker {
+    /// Create a new timeout checker with a duration.
+    pub fn with_duration(timeout: Duration) -> Self {
+        Self {
+            start_time: Instant::now(),
+            timeout: Some(timeout),
+            check_interval: 64, // Check every 64 iterations
+            check_counter: 0,
+            timed_out: false,
+        }
+    }
+
+    /// Create a new timeout checker from query params.
+    pub fn from_params(params: &QueryParams) -> Option<Self> {
+        params.timeout.map(Self::with_duration)
+    }
+
+    /// Check if the query should time out.
+    ///
+    /// This method is optimized to only perform the actual check every N iterations
+    /// to minimize overhead in tight loops.
+    #[inline]
+    pub fn check(&mut self) -> bool {
+        if self.timed_out {
+            return true;
+        }
+
+        self.check_counter += 1;
+        if self.check_counter < self.check_interval {
+            return false;
+        }
+
+        self.check_counter = 0;
+        self.timed_out = self.check_now();
+        self.timed_out
+    }
+
+    /// Force an immediate timeout check.
+    #[inline]
+    pub fn check_now(&self) -> bool {
+        if let Some(timeout) = self.timeout {
+            if self.start_time.elapsed() >= timeout {
+                return true;
+            }
+        }
+        false
+    }
+
+    /// Get the elapsed time since the checker was created.
+    pub fn elapsed(&self) -> Duration {
+        self.start_time.elapsed()
+    }
+
+    /// Check if the timeout has already been triggered.
+    pub fn is_timed_out(&self) -> bool {
+        self.timed_out
+    }
+
+    /// Get the elapsed time in milliseconds.
+    pub fn elapsed_ms(&self) -> u64 {
+        self.start_time.elapsed().as_millis() as u64
+    }
+}
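+
+// Illustrative sketch (not part of the original change): how a search loop is
+// expected to consult the checker. `check()` is cheap on most iterations
+// because it only reads the clock every `check_interval` calls; the candidate
+// loop below is a stand-in for real traversal work.
+#[cfg(test)]
+mod checker_example {
+    use super::*;
+
+    #[test]
+    fn search_loop_polls_checker() {
+        let mut checker = TimeoutChecker::with_duration(Duration::from_secs(10));
+        let mut visited = 0;
+        for _candidate in 0..1_000 {
+            if checker.check() {
+                break; // abandon the search once the deadline passes
+            }
+            visited += 1;
+        }
+        // A 10s budget is not exhausted by a trivial loop.
+        assert_eq!(visited, 1_000);
+    }
+}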
+
+/// A cancellation token that can be shared across threads.
+///
+/// This is useful for implementing query cancellation from external code.
+#[derive(Clone)]
+pub struct CancellationToken {
+    cancelled: Arc<AtomicBool>,
+}
+
+impl CancellationToken {
+    /// Create a new cancellation token.
+    pub fn new() -> Self {
+        Self {
+            cancelled: Arc::new(AtomicBool::new(false)),
+        }
+    }
+
+    /// Cancel the associated operation.
+    pub fn cancel(&self) {
+        self.cancelled.store(true, Ordering::Release);
+    }
+
+    /// Check if cancellation has been requested.
+    pub fn is_cancelled(&self) -> bool {
+        self.cancelled.load(Ordering::Acquire)
+    }
+
+    /// Create a timeout callback from this token.
+    ///
+    /// This can be passed to `QueryParams::with_timeout_callback`.
+    pub fn as_callback(&self) -> impl Fn() -> bool + Send + Sync + 'static {
+        let cancelled = Arc::clone(&self.cancelled);
+        move || cancelled.load(Ordering::Acquire)
+    }
+}
+
+impl Default for CancellationToken {
+    fn default() -> Self {
+        Self::new()
+    }
+}
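+
+// Illustrative sketch (not part of the original change): cancelling a query
+// from another thread. The token is cloned into the worker, and
+// `as_callback()` adapts it to the `QueryParams` timeout-callback interface
+// shown above.
+#[cfg(test)]
+mod token_example {
+    use super::*;
+
+    #[test]
+    fn cancel_from_another_thread() {
+        let token = CancellationToken::new();
+        let params = QueryParams::new().with_timeout_callback(token.as_callback());
+
+        let canceller = token.clone();
+        let handle = std::thread::spawn(move || canceller.cancel());
+        handle.join().unwrap();
+
+        assert!(params.is_timed_out(Instant::now()));
+    }
+}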
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_query_params_default() {
+        let params = QueryParams::new();
+        assert_eq!(params.ef_runtime, None);
+        assert_eq!(params.batch_size, None);
+        assert!(params.filter.is_none());
+        assert!(!params.parallel);
+        assert!(params.timeout_callback.is_none());
+        assert!(params.timeout.is_none());
+    }
+
+    #[test]
+    fn test_query_params_builder() {
+        let params = QueryParams::new()
+            .with_ef_runtime(100)
+            .with_batch_size(50)
+            .with_parallel(true)
+            .with_timeout_ms(1000);
+
+        assert_eq!(params.ef_runtime, Some(100));
+        assert_eq!(params.batch_size, Some(50));
+        assert!(params.parallel);
+        assert_eq!(params.timeout, Some(Duration::from_millis(1000)));
+    }
+
+    #[test]
+    fn test_query_params_with_filter() {
+        let params = QueryParams::new().with_filter(|label| label % 2 == 0);
+
+        assert!(params.filter.is_some());
+        assert!(params.passes_filter(2));
+        assert!(params.passes_filter(4));
+        assert!(!params.passes_filter(1));
+        assert!(!params.passes_filter(3));
+    }
+
+    #[test]
+    fn test_query_params_passes_filter_no_filter() {
+        let params = QueryParams::new();
+        // Without a filter, all labels should pass
+        assert!(params.passes_filter(0));
+        assert!(params.passes_filter(100));
+        assert!(params.passes_filter(u64::MAX));
+    }
+
+    #[test]
+    fn test_query_params_with_timeout_callback() {
+        let should_cancel = Arc::new(AtomicBool::new(false));
+        let should_cancel_clone = Arc::clone(&should_cancel);
+
+        let params = QueryParams::new()
+            .with_timeout_callback(move || should_cancel_clone.load(Ordering::Acquire));
+
+        let start = Instant::now();
+
+        // Initially not timed out
+        assert!(!params.is_timed_out(start));
+
+        // Set the cancel flag
+        should_cancel.store(true, Ordering::Release);
+
+        // Now should be timed out
+        assert!(params.is_timed_out(start));
+    }
+
+    #[test]
+    fn test_query_params_with_timeout_duration() {
+        let params = QueryParams::new().with_timeout(Duration::from_millis(10));
+
+        let start = Instant::now();
+
+        // Initially not timed out
+        assert!(!params.is_timed_out(start));
+
+        // Wait for timeout to expire
+        std::thread::sleep(Duration::from_millis(15));
+
+        // Now should be timed out
+        assert!(params.is_timed_out(start));
+    }
+
+    #[test]
+    fn test_query_params_clone() {
+        let params = QueryParams::new()
+            .with_ef_runtime(50)
+            .with_batch_size(25)
+            .with_parallel(true)
+            .with_timeout_ms(500)
+            .with_filter(|_| true);
+
+        let cloned = params.clone();
+
+        assert_eq!(cloned.ef_runtime, Some(50));
+        assert_eq!(cloned.batch_size, Some(25));
+        assert!(cloned.parallel);
+        assert_eq!(cloned.timeout, Some(Duration::from_millis(500)));
+        // Filter cannot be cloned
+        assert!(cloned.filter.is_none());
+    }
+
+    #[test]
+    fn test_query_params_debug() {
+        let params = QueryParams::new()
+            .with_ef_runtime(100)
+            .with_filter(|_| true)
+            .with_timeout_callback(|| false);
+
+        let debug_str = format!("{:?}", params);
+        assert!(debug_str.contains("QueryParams"));
+        assert!(debug_str.contains("ef_runtime"));
+        assert!(debug_str.contains("<filter>"));
+        assert!(debug_str.contains("<callback>"));
+    }
+
+    #[test]
+    fn test_query_params_create_timeout_checker() {
+        let params_no_timeout = QueryParams::new();
+        assert!(params_no_timeout.create_timeout_checker().is_none());
+
+        let params_with_timeout = QueryParams::new().with_timeout_ms(100);
+        assert!(params_with_timeout.create_timeout_checker().is_some());
+    }
+
+    #[test]
+    fn test_timeout_checker_with_duration() {
+        let mut checker = TimeoutChecker::with_duration(Duration::from_millis(50));
+
+        // Initially not timed out
+        assert!(!checker.is_timed_out());
+        assert!(!checker.check_now());
+
+        // Should not timeout immediately
+        for _ in 0..100 {
+            if checker.check() {
+                break;
+            }
+        }
+
+        // Wait for timeout
+        std::thread::sleep(Duration::from_millis(60));
+
+        // Now should timeout on check_now
+        assert!(checker.check_now());
+
+        // Force check() to do actual time check by calling enough times
+        // to reach the check_interval (64)
+        for _ in 0..64 {
+            if checker.check() {
+                break;
+            }
+        }
+        // After enough iterations, check() will have detected timeout
+        assert!(checker.is_timed_out());
+    }
+
+    #[test]
+    fn test_timeout_checker_check_interval() {
+        let mut checker = TimeoutChecker::with_duration(Duration::from_secs(100)); // Long timeout
+
+        // The first N-1 checks should return false (doesn't actually check time)
+        for _ in 0..63 {
+            assert!(!checker.check());
+        }
+
+        // The Nth check (64th) triggers actual time check
+        // Since timeout is 100s, should still be false
+        assert!(!checker.check());
+    }
+
+    #[test]
+    fn test_timeout_checker_elapsed() {
+        let checker = TimeoutChecker::with_duration(Duration::from_secs(10));
+
+        std::thread::sleep(Duration::from_millis(10));
+
+        let elapsed = checker.elapsed();
+        assert!(elapsed >= Duration::from_millis(10));
+
+        let elapsed_ms = checker.elapsed_ms();
+        assert!(elapsed_ms >= 10);
+    }
+
+    #[test]
+    fn test_timeout_checker_from_params() {
+        let params = QueryParams::new().with_timeout_ms(100);
+        let checker = TimeoutChecker::from_params(&params);
+        assert!(checker.is_some());
+
+        let params_no_timeout = QueryParams::new();
+        let checker_none = TimeoutChecker::from_params(&params_no_timeout);
+        assert!(checker_none.is_none());
+    }
+
+    #[test]
+    fn test_cancellation_token_basic() {
+        let token = CancellationToken::new();
+
+        assert!(!token.is_cancelled());
+
+        token.cancel();
+
+        assert!(token.is_cancelled());
+    }
+
+    #[test]
+    fn test_cancellation_token_clone() {
+        let token = CancellationToken::new();
+        let token_clone = token.clone();
+
+        assert!(!token.is_cancelled());
+        assert!(!token_clone.is_cancelled());
+
+        // Cancel through original
+        token.cancel();
+
+        // Both should see cancellation
+        assert!(token.is_cancelled());
+        assert!(token_clone.is_cancelled());
+    }
+
+    #[test]
+    fn test_cancellation_token_as_callback() {
+        let token = CancellationToken::new();
+        let callback = token.as_callback();
+
+        assert!(!callback());
+
+        token.cancel();
+
+        assert!(callback());
+    }
+
+    #[test]
+    fn test_cancellation_token_with_query_params() {
+        let token = CancellationToken::new();
+        let params = QueryParams::new().with_timeout_callback(token.as_callback());
+
+        let start = Instant::now();
+
+        assert!(!params.is_timed_out(start));
+
+        token.cancel();
+
+        assert!(params.is_timed_out(start));
+    }
+
+    #[test]
+    fn test_cancellation_token_thread_safety() {
+        use std::thread;
+
+        let token = CancellationToken::new();
+        let token_clone = token.clone();
+
+        let handle = thread::spawn(move || {
+            // Wait a bit then cancel
+            std::thread::sleep(Duration::from_millis(10));
+            token_clone.cancel();
+        });
+
+        // Poll until cancelled
+        while !token.is_cancelled() {
+            std::thread::sleep(Duration::from_millis(1));
+        }
+
+        handle.join().unwrap();
+        assert!(token.is_cancelled());
+    }
+
+    #[test]
+    fn test_cancellation_token_default() {
+        let token = CancellationToken::default();
+        assert!(!token.is_cancelled());
+    }
+
+    #[test]
+    fn test_query_params_combined_timeout_check() {
+        // Test with both duration and callback
+        let should_cancel = Arc::new(AtomicBool::new(false));
+        let should_cancel_clone = Arc::clone(&should_cancel);
+
+        let params = QueryParams::new()
+            .with_timeout(Duration::from_millis(100))
+            .with_timeout_callback(move || should_cancel_clone.load(Ordering::Acquire));
+
+        let start = Instant::now();
+
+        // Neither triggered
+        assert!(!params.is_timed_out(start));
+
+        // Trigger callback
+        should_cancel.store(true, Ordering::Release);
+        assert!(params.is_timed_out(start));
+    }
+}
diff --git a/rust/vecsim/src/query/results.rs b/rust/vecsim/src/query/results.rs
new file mode 100644
index 000000000..a6cb5b90e
--- /dev/null
+++ b/rust/vecsim/src/query/results.rs
@@ -0,0 +1,653 @@
+//! Query result types.
+
+use crate::types::{DistanceType, LabelType};
+use std::cmp::Ordering;
+
+/// A single query result containing a label and its distance to the query.
+#[derive(Debug, Clone, Copy)]
+pub struct QueryResult<D: DistanceType> {
+    /// The external label of the matching vector.
+    pub label: LabelType,
+    /// The distance from the query vector to this result.
+    pub distance: D,
+}
+
+impl<D: DistanceType> QueryResult<D> {
+    /// Create a new query result.
+    #[inline]
+    pub fn new(label: LabelType, distance: D) -> Self {
+        Self { label, distance }
+    }
+}
+
+impl<D: DistanceType> PartialEq for QueryResult<D> {
+    fn eq(&self, other: &Self) -> bool {
+        self.label == other.label && self.distance.to_f64() == other.distance.to_f64()
+    }
+}
+
+impl<D: DistanceType> Eq for QueryResult<D> {}
+
+impl<D: DistanceType> PartialOrd for QueryResult<D> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl<D: DistanceType> Ord for QueryResult<D> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        // Compare by distance first, then by label for tie-breaking
+        match self.distance.to_f64().partial_cmp(&other.distance.to_f64()) {
+            Some(Ordering::Equal) | None => self.label.cmp(&other.label),
+            Some(ord) => ord,
+        }
+    }
+}
+
+/// A collection of query results.
+#[derive(Debug, Clone)]
+pub struct QueryReply<D: DistanceType> {
+    /// The results, sorted by distance (ascending for most metrics).
+    pub results: Vec<QueryResult<D>>,
+}
+
+impl<D: DistanceType> QueryReply<D> {
+    /// Create a new empty query reply.
+    pub fn new() -> Self {
+        Self {
+            results: Vec::new(),
+        }
+    }
+
+    /// Create a query reply with pre-allocated capacity.
+    pub fn with_capacity(capacity: usize) -> Self {
+        Self {
+            results: Vec::with_capacity(capacity),
+        }
+    }
+
+    /// Create a query reply from a vector of results.
+    pub fn from_results(results: Vec<QueryResult<D>>) -> Self {
+        Self { results }
+    }
+
+    /// Add a result to the reply.
+    #[inline]
+    pub fn push(&mut self, result: QueryResult<D>) {
+        self.results.push(result);
+    }
+
+    /// Get the number of results.
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.results.len()
+    }
+
+    /// Check if there are no results.
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.results.is_empty()
+    }
+
+    /// Sort results by distance (ascending).
+    pub fn sort_by_distance(&mut self) {
+        self.results.sort();
+    }
+
+    /// Sort results by distance (descending) - useful for inner product.
+    pub fn sort_by_distance_desc(&mut self) {
+        self.results.sort_by(|a, b| b.cmp(a));
+    }
+
+    /// Truncate to at most k results.
+    pub fn truncate(&mut self, k: usize) {
+        self.results.truncate(k);
+    }
+
+    /// Iterate over results.
+    pub fn iter(&self) -> impl Iterator<Item = &QueryResult<D>> {
+        self.results.iter()
+    }
+
+    /// Get the best (closest) result, if any.
+    pub fn best(&self) -> Option<&QueryResult<D>> {
+        self.results.first()
+    }
+
+    /// Sort results by label (ascending).
+    pub fn sort_by_label(&mut self) {
+        self.results.sort_by_key(|r| r.label);
+    }
+
+    /// Sort results by distance, then by label for ties.
+    ///
+    /// This is the default sort behavior but provided explicitly.
+    pub fn sort_by_distance_then_label(&mut self) {
+        self.sort_by_distance();
+    }
+
+    /// Remove duplicate labels, keeping only the best (closest) result for each label.
+    ///
+    /// Results should be sorted by distance first for deterministic behavior.
+    pub fn deduplicate_by_label(&mut self) {
+        if self.results.is_empty() {
+            return;
+        }
+
+        // Sort by distance first to ensure we keep the best per label
+        self.sort_by_distance();
+
+        let mut seen = std::collections::HashSet::new();
+        self.results.retain(|r| seen.insert(r.label));
+    }
+
+    /// Filter results to only include those within the given distance threshold.
+    pub fn filter_by_distance(&mut self, max_distance: D) {
+        let threshold = max_distance.to_f64();
+        self.results.retain(|r| r.distance.to_f64() <= threshold);
+    }
+
+    /// Get top-k results (sorts by distance and truncates).
+    pub fn top_k(&mut self, k: usize) {
+        self.sort_by_distance();
+        self.truncate(k);
+    }
+
+    /// Skip the first n results (for pagination).
+    pub fn skip(&mut self, n: usize) {
+        if n >= self.results.len() {
+            self.results.clear();
+        } else {
+            self.results = self.results.split_off(n);
+        }
+    }
+
+    /// Convert distances to similarity scores.
+    ///
+    /// For metrics where lower distance = more similar (L2, Cosine),
+    /// this returns `1 / (1 + distance)` giving values in (0, 1].
+    ///
+    /// Returns a vector of (label, similarity) pairs.
+    pub fn to_similarities(&self) -> Vec<(LabelType, f64)> {
+        self.results
+            .iter()
+            .map(|r| {
+                let dist = r.distance.to_f64();
+                let similarity = 1.0 / (1.0 + dist);
+                (r.label, similarity)
+            })
+            .collect()
+    }
+
+    /// Convert a single distance to a similarity score.
+    pub fn distance_to_similarity(distance: D) -> f64 {
+        1.0 / (1.0 + distance.to_f64())
+    }
+
+    /// Get results within a percentage of the best distance.
+    ///
+    /// For example, `threshold_percent = 0.2` keeps results within 20% of the best.
+    pub fn filter_by_relative_distance(&mut self, threshold_percent: f64) {
+        if self.results.is_empty() {
+            return;
+        }
+
+        self.sort_by_distance();
+        let best_dist = self.results[0].distance.to_f64();
+        let threshold = best_dist * (1.0 + threshold_percent);
+        self.results.retain(|r| r.distance.to_f64() <= threshold);
+    }
+}
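+
+// Illustrative sketch (not part of the original change): a typical
+// post-processing pipeline over raw results — dedupe, keep the k best, then
+// convert to similarity scores. All methods are defined on `QueryReply` above.
+#[cfg(test)]
+mod pipeline_example {
+    use super::*;
+
+    #[test]
+    fn dedupe_then_top_k_then_scores() {
+        let mut reply = QueryReply::<f32>::from_results(vec![
+            QueryResult::new(7, 0.9),
+            QueryResult::new(7, 0.2), // duplicate label; the closer hit wins
+            QueryResult::new(3, 0.5),
+            QueryResult::new(9, 1.4),
+        ]);
+
+        reply.deduplicate_by_label();
+        reply.top_k(2);
+
+        let sims = reply.to_similarities();
+        assert_eq!(sims[0].0, 7); // distance 0.2 -> similarity 1/1.2
+        assert_eq!(sims[1].0, 3); // distance 0.5 -> similarity 1/1.5
+        assert!(sims[0].1 > sims[1].1);
+    }
+}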
+
+impl<D: DistanceType> Default for QueryReply<D> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<D: DistanceType> IntoIterator for QueryReply<D> {
+    type Item = QueryResult<D>;
+    type IntoIter = std::vec::IntoIter<QueryResult<D>>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.results.into_iter()
+    }
+}
+
+impl<'a, D: DistanceType> IntoIterator for &'a QueryReply<D> {
+    type Item = &'a QueryResult<D>;
+    type IntoIter = std::slice::Iter<'a, QueryResult<D>>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.results.iter()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_query_result_ordering() {
+        let r1 = QueryResult::<f32>::new(1, 0.5);
+        let r2 = QueryResult::<f32>::new(2, 1.0);
+        let r3 = QueryResult::<f32>::new(3, 0.5);
+
+        assert!(r1 < r2);
+        assert!(r1 < r3); // Same distance, but label 1 < 3
+    }
+
+    #[test]
+    fn test_query_reply_sort() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 1.0));
+        reply.push(QueryResult::new(2, 0.5));
+        reply.push(QueryResult::new(3, 0.75));
+
+        reply.sort_by_distance();
+
+        assert_eq!(reply.results[0].label, 2);
+        assert_eq!(reply.results[1].label, 3);
+        assert_eq!(reply.results[2].label, 1);
+    }
+
+    #[test]
+    fn test_query_result_new() {
+        let result = QueryResult::<f32>::new(42, 1.5);
+        assert_eq!(result.label, 42);
+        assert!((result.distance - 1.5).abs() < f32::EPSILON);
+    }
+
+    #[test]
+    fn test_query_result_equality() {
+        let r1 = QueryResult::<f32>::new(1, 0.5);
+        let r2 = QueryResult::<f32>::new(1, 0.5);
+        let r3 = QueryResult::<f32>::new(1, 0.6);
+        let r4 = QueryResult::<f32>::new(2, 0.5);
+
+        assert_eq!(r1, r2);
+        assert_ne!(r1, r3); // Different distance
+        assert_ne!(r1, r4); // Different label
+    }
+
+    #[test]
+    fn test_query_result_clone() {
+        let r1 = QueryResult::<f32>::new(42, 1.5);
+        let r2 = r1;
+        assert_eq!(r1.label, r2.label);
+        assert!((r1.distance - r2.distance).abs() < f32::EPSILON);
+    }
+
+    #[test]
+    fn test_query_result_ordering_tie_breaking() {
+        let r1 = QueryResult::<f32>::new(1, 0.5);
+        let r2 = QueryResult::<f32>::new(2, 0.5);
+        let r3 = QueryResult::<f32>::new(3, 0.5);
+
+        // Same distance, so compare by label
+        assert!(r1 < r2);
+        assert!(r2 < r3);
+        assert!(r1 < r3);
+    }
+
+    #[test]
+    fn test_query_reply_new() {
+        let reply = QueryReply::<f32>::new();
+        assert!(reply.is_empty());
+        assert_eq!(reply.len(), 0);
+    }
+
+    #[test]
+    fn test_query_reply_with_capacity() {
+        let reply = QueryReply::<f32>::with_capacity(10);
+        assert!(reply.is_empty());
+        assert!(reply.results.capacity() >= 10);
+    }
+
+    #[test]
+    fn test_query_reply_from_results() {
+        let results = vec![
+            QueryResult::new(1, 0.5),
+            QueryResult::new(2, 1.0),
+        ];
+        let reply = QueryReply::from_results(results);
+        assert_eq!(reply.len(), 2);
+    }
+
+    #[test]
+    fn test_query_reply_push() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(2, 1.0));
+        assert_eq!(reply.len(), 2);
+    }
+
+    #[test]
+    fn test_query_reply_sort_by_distance_desc() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(2, 1.0));
+        reply.push(QueryResult::new(3, 0.75));
+
+        reply.sort_by_distance_desc();
+
+        assert_eq!(reply.results[0].label, 2); // 1.0
+        assert_eq!(reply.results[1].label, 3); // 0.75
+        assert_eq!(reply.results[2].label, 1); // 0.5
+    }
+
+    #[test]
+    fn test_query_reply_truncate() {
+        let mut reply = QueryReply::<f32>::new();
+        for i in 0..10 {
+            reply.push(QueryResult::new(i, i as f32 * 0.1));
+        }
+
+        reply.truncate(5);
+        assert_eq!(reply.len(), 5);
+
+        // Truncate to larger than current size does nothing
+        reply.truncate(100);
+        assert_eq!(reply.len(), 5);
+    }
+
+    #[test]
+    fn test_query_reply_best() {
+        let mut reply = QueryReply::<f32>::new();
+        assert!(reply.best().is_none());
+
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(2, 0.3));
+
+        reply.sort_by_distance();
+
+        let best = reply.best().unwrap();
+        assert_eq!(best.label, 2);
+    }
+
+    #[test]
+    fn test_query_reply_sort_by_label() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(3, 0.5));
+        reply.push(QueryResult::new(1, 1.0));
+        reply.push(QueryResult::new(2, 0.75));
+
+        reply.sort_by_label();
+
+        assert_eq!(reply.results[0].label, 1);
+        assert_eq!(reply.results[1].label, 2);
+        assert_eq!(reply.results[2].label, 3);
+    }
+
+    #[test]
+    fn test_query_reply_deduplicate_by_label() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(1, 0.3)); // Same label, closer
+        reply.push(QueryResult::new(2, 1.0));
+        reply.push(QueryResult::new(2, 1.5)); // Same label, farther
+
+        reply.deduplicate_by_label();
+
+        assert_eq!(reply.len(), 2);
+        // Should keep the closer one for each label
+        let labels: Vec<_> = reply.results.iter().map(|r| r.label).collect();
+        assert!(labels.contains(&1));
+        assert!(labels.contains(&2));
+    }
+
+    #[test]
+    fn test_query_reply_filter_by_distance() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(2, 1.0));
+        reply.push(QueryResult::new(3, 1.5));
+        reply.push(QueryResult::new(4, 2.0));
+
+        reply.filter_by_distance(1.0);
+
+        assert_eq!(reply.len(), 2);
+        assert!(reply.results.iter().all(|r| r.distance <= 1.0));
+    }
+
+    #[test]
+    fn test_query_reply_top_k() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(3, 1.5));
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(4, 2.0));
+        reply.push(QueryResult::new(2, 1.0));
+
+        reply.top_k(2);
+
+        assert_eq!(reply.len(), 2);
+        assert_eq!(reply.results[0].label, 1); // 0.5
+        assert_eq!(reply.results[1].label, 2); // 1.0
+    }
+
+    #[test]
+    fn test_query_reply_skip() {
+        let mut reply = QueryReply::<f32>::new();
+        for i in 0..5 {
+            reply.push(QueryResult::new(i, i as f32));
+        }
+
+        reply.skip(2);
+
+        assert_eq!(reply.len(), 3);
+        assert_eq!(reply.results[0].label, 2);
+    }
+
+    #[test]
+    fn test_query_reply_skip_all() {
+        let mut reply = QueryReply::<f32>::new();
+        for i in 0..5 {
+            reply.push(QueryResult::new(i, i as f32));
+        }
+
+        reply.skip(10);
+
+        assert!(reply.is_empty());
+    }
+
+    #[test]
+    fn test_query_reply_skip_exact() {
+        let mut reply = QueryReply::<f32>::new();
+        for i in 0..5 {
+            reply.push(QueryResult::new(i, i as f32));
+        }
+
+        reply.skip(5);
+
+        assert!(reply.is_empty());
+    }
+
+    #[test]
+    fn test_query_reply_to_similarities() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.0)); // similarity = 1.0
+        reply.push(QueryResult::new(2, 1.0)); // similarity = 0.5
+
+        let sims = reply.to_similarities();
+
+        assert_eq!(sims.len(), 2);
+        assert!((sims[0].1 - 1.0).abs() < 0.001);
+        assert!((sims[1].1 - 0.5).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_query_reply_distance_to_similarity() {
+        assert!((QueryReply::distance_to_similarity(0.0f32) - 1.0).abs() < 0.001);
+        assert!((QueryReply::distance_to_similarity(1.0f32) - 0.5).abs() < 0.001);
+        assert!((QueryReply::distance_to_similarity(3.0f32) - 0.25).abs() < 0.001);
+    }
+
+    #[test]
+    fn test_query_reply_filter_by_relative_distance() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 1.0));
+        reply.push(QueryResult::new(2, 1.1)); // Within 20%
+        reply.push(QueryResult::new(3, 1.15)); // Within 20% (threshold is 1.2)
+        reply.push(QueryResult::new(4, 2.0)); // Beyond 20%
+
+        reply.filter_by_relative_distance(0.2);
+
+        assert_eq!(reply.len(), 3);
+        assert!(reply.results.iter().all(|r| r.distance <= 1.2 + 0.001));
+    }
+
+    #[test]
+    fn test_query_reply_filter_by_relative_distance_empty() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.filter_by_relative_distance(0.2);
+        assert!(reply.is_empty());
+    }
+
+    #[test]
+    fn test_query_reply_default() {
+        let reply: QueryReply<f32> = QueryReply::default();
+        assert!(reply.is_empty());
+    }
+
+    #[test]
+    fn test_query_reply_into_iterator() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(2, 1.0));
+
+        let collected: Vec<_> = reply.into_iter().collect();
+        assert_eq!(collected.len(), 2);
+    }
+
+    #[test]
+    fn test_query_reply_ref_iterator() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(2, 1.0));
+
+        let mut count = 0;
+        for _ in &reply {
+            count += 1;
+        }
+        assert_eq!(count, 2);
+        // reply is still valid
+        assert_eq!(reply.len(), 2);
+    }
+
+    #[test]
+    fn test_query_reply_iter() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(2, 1.0));
+
+        let labels: Vec<_> = reply.iter().map(|r| r.label).collect();
+        assert_eq!(labels, vec![1, 2]);
+    }
+
+    #[test]
+    fn test_query_reply_clone() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(2, 1.0));
+
+        let cloned = reply.clone();
+        assert_eq!(cloned.len(), 2);
+        assert_eq!(cloned.results[0].label, 1);
+        assert_eq!(cloned.results[1].label, 2);
+    }
+
+    #[test]
+    fn test_query_reply_debug() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.5));
+
+        let debug_str = format!("{:?}", reply);
+        assert!(debug_str.contains("QueryReply"));
+        assert!(debug_str.contains("results"));
+    }
+
+    #[test]
+    fn test_query_result_debug() {
+        let result = QueryResult::<f32>::new(42, 1.5);
+        let debug_str = format!("{:?}", result);
+        assert!(debug_str.contains("QueryResult"));
+        assert!(debug_str.contains("label"));
+        assert!(debug_str.contains("distance"));
+    }
+
+    #[test]
+    fn test_query_reply_sort_by_distance_then_label() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(3, 0.5));
+        reply.push(QueryResult::new(1, 0.5)); // Same distance
+        reply.push(QueryResult::new(2, 1.0));
+
+        reply.sort_by_distance_then_label();
+
+        assert_eq!(reply.results[0].label, 1); // 0.5, label 1
+        assert_eq!(reply.results[1].label, 3); // 0.5, label 3
+        assert_eq!(reply.results[2].label, 2); // 1.0
+    }
+
+    #[test]
+    fn test_query_reply_deduplicate_empty() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.deduplicate_by_label();
+        assert!(reply.is_empty());
+    }
+
+    #[test]
+    fn test_query_reply_deduplicate_single() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.5));
+        reply.deduplicate_by_label();
+        assert_eq!(reply.len(), 1);
+    }
+
+    #[test]
+    fn test_query_result_with_f64() {
+        let r1 = QueryResult::<f64>::new(1, 0.5);
+        let r2 = QueryResult::<f64>::new(2, 1.0);
+
+        assert!(r1 < r2);
+        assert_eq!(r1.label, 1);
+        assert!((r1.distance - 0.5).abs() < f64::EPSILON);
+    }
+
+    #[test]
+    fn test_query_reply_with_f64() {
+        let mut reply = QueryReply::<f64>::new();
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(2, 1.0));
+
+        reply.sort_by_distance();
+
+        assert_eq!(reply.results[0].label, 1);
+        assert_eq!(reply.len(), 2);
+    }
+
+    #[test]
+    fn test_query_reply_top_k_more_than_available() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.5));
+        reply.push(QueryResult::new(2, 1.0));
+
+        reply.top_k(10);
+
+        // Should have all results, sorted
+        assert_eq!(reply.len(), 2);
+        assert_eq!(reply.results[0].label, 1);
+    }
+
+    #[test]
+    fn test_query_reply_filter_by_distance_zero() {
+        let mut reply = QueryReply::<f32>::new();
+        reply.push(QueryResult::new(1, 0.0));
+        reply.push(QueryResult::new(2, 0.5));
+
+        reply.filter_by_distance(0.0);
+
+        assert_eq!(reply.len(), 1);
+        assert_eq!(reply.results[0].label, 1);
+    }
+}
diff --git a/rust/vecsim/src/serialization/mod.rs b/rust/vecsim/src/serialization/mod.rs
new file mode 100644
index 000000000..0b0166bbd
--- /dev/null
+++ b/rust/vecsim/src/serialization/mod.rs
@@ -0,0 +1,434 @@
+//! Serialization and deserialization for vector indices.
+//!
+//! This module provides functionality to save and load indices to/from disk.
+//! Versioned encoding is used to support backward compatibility.
+
+mod version;
+
+pub use version::{SerializationVersion, CURRENT_VERSION};
+
+use crate::distance::Metric;
+use std::io::{self, Read, Write};
+use thiserror::Error;
+
+/// Magic number for vecsim index files.
+pub const MAGIC_NUMBER: u32 = 0x5645_4353; // "VECS" in hex
+
+/// Errors that can occur during serialization.
+#[derive(Error, Debug)]
+pub enum SerializationError {
+    #[error("IO error: {0}")]
+    Io(#[from] io::Error),
+
+    #[error("Invalid magic number: expected {expected:#x}, got {got:#x}")]
+    InvalidMagicNumber { expected: u32, got: u32 },
+
+    #[error("Unsupported version: {0}")]
+    UnsupportedVersion(u32),
+
+    #[error("Index type mismatch: expected {expected}, got {got}")]
+    IndexTypeMismatch { expected: String, got: String },
+
+    #[error("Dimension mismatch: expected {expected}, got {got}")]
+    DimensionMismatch { expected: usize, got: usize },
+
+    #[error("Metric mismatch: expected {expected:?}, got {got:?}")]
+    MetricMismatch { expected: Metric, got: Metric },
+
+    #[error("Data corruption: {0}")]
+    DataCorruption(String),
+
+    #[error("Invalid data: {0}")]
+    InvalidData(String),
+}
+
+/// Result type for serialization operations.
+pub type SerializationResult<T> = Result<T, SerializationError>;
+
+/// Index type identifier for serialization.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[repr(u8)]
+pub enum IndexTypeId {
+    BruteForceSingle = 1,
+    BruteForceMulti = 2,
+    HnswSingle = 3,
+    HnswMulti = 4,
+    TieredSingle = 5,
+    TieredMulti = 6,
+    SvsSingle = 7,
+    SvsMulti = 8,
+    TieredSvsSingle = 9,
+    TieredSvsMulti = 10,
+}
+
+impl IndexTypeId {
+    pub fn from_u8(value: u8) -> Option<Self> {
+        match value {
+            1 => Some(IndexTypeId::BruteForceSingle),
+            2 => Some(IndexTypeId::BruteForceMulti),
+            3 => Some(IndexTypeId::HnswSingle),
+            4 => Some(IndexTypeId::HnswMulti),
+            5 => Some(IndexTypeId::TieredSingle),
+            6 => Some(IndexTypeId::TieredMulti),
+            7 => Some(IndexTypeId::SvsSingle),
+            8 => Some(IndexTypeId::SvsMulti),
+            9 => Some(IndexTypeId::TieredSvsSingle),
+            10 => Some(IndexTypeId::TieredSvsMulti),
+            _ => None,
+        }
+    }
+
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            IndexTypeId::BruteForceSingle => "BruteForceSingle",
+            IndexTypeId::BruteForceMulti => "BruteForceMulti",
+            IndexTypeId::HnswSingle => "HnswSingle",
+            IndexTypeId::HnswMulti => "HnswMulti",
+            IndexTypeId::TieredSingle => "TieredSingle",
+            IndexTypeId::TieredMulti => "TieredMulti",
+            IndexTypeId::SvsSingle => "SvsSingle",
+            IndexTypeId::SvsMulti => "SvsMulti",
+            IndexTypeId::TieredSvsSingle => "TieredSvsSingle",
+            IndexTypeId::TieredSvsMulti => "TieredSvsMulti",
+        }
+    }
+}
+
+/// Data type identifier for serialization.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[repr(u8)]
+pub enum DataTypeId {
+    F32 = 1,
+    F64 = 2,
+    Float16 = 3,
+    BFloat16 = 4,
+    Int8 = 5,
+    UInt8 = 6,
+    Int32 = 7,
+    Int64 = 8,
+}
+
+impl DataTypeId {
+    pub fn from_u8(value: u8) -> Option<Self> {
+        match value {
+            1 => Some(DataTypeId::F32),
+            2 => Some(DataTypeId::F64),
+            3 => Some(DataTypeId::Float16),
+            4 => Some(DataTypeId::BFloat16),
+            5 => Some(DataTypeId::Int8),
+            6 => Some(DataTypeId::UInt8),
+            7 => Some(DataTypeId::Int32),
+            8 => Some(DataTypeId::Int64),
+            _ => None,
+        }
+    }
+}
+
+/// Header for serialized index files.
+#[derive(Debug, Clone)]
+pub struct IndexHeader {
+    pub magic: u32,
+    pub version: u32,
+    pub index_type: IndexTypeId,
+    pub data_type: DataTypeId,
+    pub metric: Metric,
+    pub dimension: usize,
+    pub count: usize,
+}
+
+impl IndexHeader {
+    pub fn new(
+        index_type: IndexTypeId,
+        data_type: DataTypeId,
+        metric: Metric,
+        dimension: usize,
+        count: usize,
+    ) -> Self {
+        Self {
+            magic: MAGIC_NUMBER,
+            version: CURRENT_VERSION,
+            index_type,
+            data_type,
+            metric,
+            dimension,
+            count,
+        }
+    }
+
+    pub fn write<W: Write>(&self, writer: &mut W) -> SerializationResult<()> {
+        write_u32(writer, self.magic)?;
+        write_u32(writer, self.version)?;
+        write_u8(writer, self.index_type as u8)?;
+        write_u8(writer, self.data_type as u8)?;
+        write_u8(writer, metric_to_u8(self.metric))?;
+        write_usize(writer, self.dimension)?;
+        write_usize(writer, self.count)?;
+        Ok(())
+    }
+
+    pub fn read<R: Read>(reader: &mut R) -> SerializationResult<Self> {
+        let magic = read_u32(reader)?;
+        if magic != MAGIC_NUMBER {
+            return Err(SerializationError::InvalidMagicNumber {
+                expected: MAGIC_NUMBER,
+                got: magic,
+            });
+        }
+
+        let version = read_u32(reader)?;
+        if version > CURRENT_VERSION {
+            return Err(SerializationError::UnsupportedVersion(version));
+        }
+
+        let index_type = IndexTypeId::from_u8(read_u8(reader)?)
+            .ok_or_else(|| SerializationError::InvalidData("Invalid index type".to_string()))?;
+
+        let data_type = DataTypeId::from_u8(read_u8(reader)?)
+            .ok_or_else(|| SerializationError::InvalidData("Invalid data type".to_string()))?;
+
+        let metric = metric_from_u8(read_u8(reader)?)?;
+        let dimension = read_usize(reader)?;
+        let count = read_usize(reader)?;
+
+        Ok(Self {
+            magic,
+            version,
+            index_type,
+            data_type,
+            metric,
+            dimension,
+            count,
+        })
+    }
+}
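+
+// On-disk layout of the header, as implied by `write` above (all integers
+// little-endian):
+//
+//   offset  size  field
+//   0       4     magic        (0x5645_4353, "VECS")
+//   4       4     version
+//   8       1     index_type
+//   9       1     data_type
+//   10      1     metric
+//   11      8     dimension    (u64)
+//   19      8     count        (u64)
+//
+// Total: 27 bytes.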
+
+// Helper functions for binary I/O
+
+#[inline]
+pub fn write_u8<W: Write>(writer: &mut W, value: u8) -> io::Result<()> {
+    writer.write_all(&[value])
+}
+
+#[inline]
+pub fn read_u8<R: Read>(reader: &mut R) -> io::Result<u8> {
+    let mut buf = [0u8; 1];
+    reader.read_exact(&mut buf)?;
+    Ok(buf[0])
+}
+
+#[inline]
+pub fn write_u32<W: Write>(writer: &mut W, value: u32) -> io::Result<()> {
+    writer.write_all(&value.to_le_bytes())
+}
+
+#[inline]
+pub fn read_u32<R: Read>(reader: &mut R) -> io::Result<u32> {
+    let mut buf = [0u8; 4];
+    reader.read_exact(&mut buf)?;
+    Ok(u32::from_le_bytes(buf))
+}
+
+#[inline]
+pub fn write_u64<W: Write>(writer: &mut W, value: u64) -> io::Result<()> {
+    writer.write_all(&value.to_le_bytes())
+}
+
+#[inline]
+pub fn read_u64<R: Read>(reader: &mut R) -> io::Result<u64> {
+    let mut buf = [0u8; 8];
+    reader.read_exact(&mut buf)?;
+    Ok(u64::from_le_bytes(buf))
+}
+
+#[inline]
+pub fn write_usize<W: Write>(writer: &mut W, value: usize) -> io::Result<()> {
+    write_u64(writer, value as u64)
+}
+
+#[inline]
+pub fn read_usize<R: Read>(reader: &mut R) -> io::Result<usize> {
+    Ok(read_u64(reader)? as usize)
+}
+
+#[inline]
+pub fn write_f32<W: Write>(writer: &mut W, value: f32) -> io::Result<()> {
+    writer.write_all(&value.to_le_bytes())
+}
+
+#[inline]
+pub fn read_f32<R: Read>(reader: &mut R) -> io::Result<f32> {
+    let mut buf = [0u8; 4];
+    reader.read_exact(&mut buf)?;
+    Ok(f32::from_le_bytes(buf))
+}
+
+#[inline]
+pub fn write_f64<W: Write>(writer: &mut W, value: f64) -> io::Result<()> {
+    writer.write_all(&value.to_le_bytes())
+}
+
+#[inline]
+pub fn read_f64<R: Read>(reader: &mut R) -> io::Result<f64> {
+    let mut buf = [0u8; 8];
+    reader.read_exact(&mut buf)?;
+    Ok(f64::from_le_bytes(buf))
+}
+
+/// Write a vector of f32 values.
+pub fn write_f32_slice<W: Write>(writer: &mut W, data: &[f32]) -> io::Result<()> {
+    write_usize(writer, data.len())?;
+    for &value in data {
+        write_f32(writer, value)?;
+    }
+    Ok(())
+}
+
+/// Read a vector of f32 values.
+pub fn read_f32_vec<R: Read>(reader: &mut R) -> io::Result<Vec<f32>> {
+    let len = read_usize(reader)?;
+    let mut data = Vec::with_capacity(len);
+    for _ in 0..len {
+        data.push(read_f32(reader)?);
+    }
+    Ok(data)
+}
+
+/// Write a vector of f64 values.
+pub fn write_f64_slice<W: Write>(writer: &mut W, data: &[f64]) -> io::Result<()> {
+    write_usize(writer, data.len())?;
+    for &value in data {
+        write_f64(writer, value)?;
+    }
+    Ok(())
+}
+
+/// Read a vector of f64 values.
+pub fn read_f64_vec<R: Read>(reader: &mut R) -> io::Result<Vec<f64>> {
+    let len = read_usize(reader)?;
+    let mut data = Vec::with_capacity(len);
+    for _ in 0..len {
+        data.push(read_f64(reader)?);
+    }
+    Ok(data)
+}
+
+/// Write a u16 value (for Float16, BFloat16).
+#[inline]
+pub fn write_u16<W: Write>(writer: &mut W, value: u16) -> io::Result<()> {
+    writer.write_all(&value.to_le_bytes())
+}
+
+/// Read a u16 value.
+#[inline]
+pub fn read_u16<R: Read>(reader: &mut R) -> io::Result<u16> {
+    let mut buf = [0u8; 2];
+    reader.read_exact(&mut buf)?;
+    Ok(u16::from_le_bytes(buf))
+}
+
+/// Write an i8 value.
+#[inline]
+pub fn write_i8<W: Write>(writer: &mut W, value: i8) -> io::Result<()> {
+    writer.write_all(&value.to_le_bytes())
+}
+
+/// Read an i8 value.
+#[inline]
+pub fn read_i8<R: Read>(reader: &mut R) -> io::Result<i8> {
+    let mut buf = [0u8; 1];
+    reader.read_exact(&mut buf)?;
+    Ok(i8::from_le_bytes(buf))
+}
+
+/// Write an i32 value.
+#[inline]
+pub fn write_i32<W: Write>(writer: &mut W, value: i32) -> io::Result<()> {
+    writer.write_all(&value.to_le_bytes())
+}
+
+/// Read an i32 value.
+#[inline]
+pub fn read_i32<R: Read>(reader: &mut R) -> io::Result<i32> {
+    let mut buf = [0u8; 4];
+    reader.read_exact(&mut buf)?;
+    Ok(i32::from_le_bytes(buf))
+}
+
+/// Write an i64 value.
+#[inline]
+pub fn write_i64<W: Write>(writer: &mut W, value: i64) -> io::Result<()> {
+    writer.write_all(&value.to_le_bytes())
+}
+
+/// Read an i64 value.
+#[inline]
+pub fn read_i64<R: Read>(reader: &mut R) -> io::Result<i64> {
+    let mut buf = [0u8; 8];
+    reader.read_exact(&mut buf)?;
+    Ok(i64::from_le_bytes(buf))
+}
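+
+// Illustrative sketch (not part of the original change): the helpers compose
+// into simple length-prefixed encodings; a round-trip through an in-memory
+// buffer demonstrates the convention.
+#[cfg(test)]
+mod io_helpers_example {
+    use super::*;
+    use std::io::Cursor;
+
+    #[test]
+    fn f32_slice_roundtrip() {
+        let mut buf = Vec::new();
+        write_f32_slice(&mut buf, &[1.0, 2.5, -3.0]).unwrap();
+        // 8-byte length prefix + 3 * 4 bytes of payload.
+        assert_eq!(buf.len(), 8 + 12);
+
+        let mut cursor = Cursor::new(buf);
+        assert_eq!(read_f32_vec(&mut cursor).unwrap(), vec![1.0, 2.5, -3.0]);
+    }
+}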
+
+fn metric_to_u8(metric: Metric) -> u8 {
+    match metric {
+        Metric::L2 => 1,
+        Metric::InnerProduct => 2,
+        Metric::Cosine => 3,
+    }
+}
+
+fn metric_from_u8(value: u8) -> SerializationResult<Metric> {
+    match value {
+        1 => Ok(Metric::L2),
+        2 => Ok(Metric::InnerProduct),
+        3 => Ok(Metric::Cosine),
+        _ => Err(SerializationError::InvalidData(format!(
+            "Invalid metric value: {value}"
+        ))),
+    }
+}
+
+/// Trait for serializable indices.
+pub trait Serializable {
+    /// Save the index to a writer.
+    fn save<W: Write>(&self, writer: &mut W) -> SerializationResult<()>;
+
+    /// Get the size in bytes when serialized.
+    fn serialized_size(&self) -> usize;
+}
+
+/// Trait for deserializable indices.
+pub trait Deserializable: Sized {
+    /// Load the index from a reader.
+    fn load<R: Read>(reader: &mut R) -> SerializationResult<Self>;
+}
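+
+// Illustrative sketch (not part of the original change): how an index type is
+// expected to wire these traits together. `CountOnly` is a made-up stand-in,
+// not a real index in this crate.
+#[cfg(test)]
+mod trait_example {
+    use super::*;
+    use std::io::Cursor;
+
+    struct CountOnly {
+        count: usize,
+    }
+
+    impl Serializable for CountOnly {
+        fn save<W: Write>(&self, writer: &mut W) -> SerializationResult<()> {
+            // io::Error converts into SerializationError via `#[from]`.
+            write_usize(writer, self.count)?;
+            Ok(())
+        }
+
+        fn serialized_size(&self) -> usize {
+            std::mem::size_of::<u64>()
+        }
+    }
+
+    impl Deserializable for CountOnly {
+        fn load<R: Read>(reader: &mut R) -> SerializationResult<Self> {
+            Ok(Self {
+                count: read_usize(reader)?,
+            })
+        }
+    }
+
+    #[test]
+    fn save_load_roundtrip() {
+        let original = CountOnly { count: 42 };
+        let mut buf = Vec::new();
+        original.save(&mut buf).unwrap();
+
+        let loaded = CountOnly::load(&mut Cursor::new(buf)).unwrap();
+        assert_eq!(loaded.count, 42);
+    }
+}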
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SerializationVersion { + pub major: u16, + pub minor: u16, +} + +impl SerializationVersion { + pub const fn new(major: u16, minor: u16) -> Self { + Self { major, minor } + } + + pub const fn current() -> Self { + Self::new(1, 0) + } + + pub fn to_u32(self) -> u32 { + ((self.major as u32) << 16) | (self.minor as u32) + } + + pub fn from_u32(value: u32) -> Self { + Self { + major: (value >> 16) as u16, + minor: (value & 0xFFFF) as u16, + } + } + + pub fn is_compatible(self, other: Self) -> bool { + // Major version must match for compatibility + self.major == other.major + } +} + +impl Default for SerializationVersion { + fn default() -> Self { + Self::current() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_version_roundtrip() { + let version = SerializationVersion::new(1, 5); + let encoded = version.to_u32(); + let decoded = SerializationVersion::from_u32(encoded); + assert_eq!(version, decoded); + } + + #[test] + fn test_version_compatibility() { + let v1_0 = SerializationVersion::new(1, 0); + let v1_5 = SerializationVersion::new(1, 5); + let v2_0 = SerializationVersion::new(2, 0); + + assert!(v1_0.is_compatible(v1_5)); + assert!(!v1_0.is_compatible(v2_0)); + } +} diff --git a/rust/vecsim/src/types/bf16.rs b/rust/vecsim/src/types/bf16.rs new file mode 100644 index 000000000..9b5f3477b --- /dev/null +++ b/rust/vecsim/src/types/bf16.rs @@ -0,0 +1,159 @@ +//! Brain floating point (BF16/BFloat16) support. +//! +//! This module provides a wrapper around the `half` crate's `bf16` type, +//! implementing the `VectorElement` trait for use in vector similarity operations. + +use super::VectorElement; +use crate::serialization::DataTypeId; +use std::fmt; + +/// Brain floating point number (bfloat16). +/// +/// BFloat16 is a 16-bit floating point format that uses the same exponent +/// range as f32 but with reduced mantissa precision. This format is +/// particularly popular in machine learning applications. +/// +/// BF16 provides: +/// - 1 sign bit +/// - 8 exponent bits (same as f32) +/// - 7 mantissa bits +/// - Range: ~1.2e-38 to ~3.4e38 (same as f32) +/// - Precision: ~2 decimal digits +#[derive(Copy, Clone, Default, PartialEq, PartialOrd)] +#[repr(transparent)] +pub struct BFloat16(half::bf16); + +impl BFloat16 { + /// Create a new BFloat16 from raw bits. + #[inline(always)] + pub const fn from_bits(bits: u16) -> Self { + Self(half::bf16::from_bits(bits)) + } + + /// Get the raw bits of this BFloat16. + #[inline(always)] + pub const fn to_bits(self) -> u16 { + self.0.to_bits() + } + + /// Create a BFloat16 from an f32. + #[inline(always)] + pub fn from_f32(v: f32) -> Self { + Self(half::bf16::from_f32(v)) + } + + /// Convert to f32. + #[inline(always)] + pub fn to_f32(self) -> f32 { + self.0.to_f32() + } + + /// Zero value. + pub const ZERO: Self = Self(half::bf16::ZERO); + + /// One value. + pub const ONE: Self = Self(half::bf16::ONE); + + /// Positive infinity. + pub const INFINITY: Self = Self(half::bf16::INFINITY); + + /// Negative infinity. + pub const NEG_INFINITY: Self = Self(half::bf16::NEG_INFINITY); + + /// Not a number. 
+ pub const NAN: Self = Self(half::bf16::NAN); +} + +impl fmt::Debug for BFloat16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "BFloat16({})", self.to_f32()) + } +} + +impl fmt::Display for BFloat16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.to_f32()) + } +} + +impl From for BFloat16 { + #[inline(always)] + fn from(v: f32) -> Self { + Self::from_f32(v) + } +} + +impl From for f32 { + #[inline(always)] + fn from(v: BFloat16) -> Self { + v.to_f32() + } +} + +impl VectorElement for BFloat16 { + type DistanceType = f32; + + #[inline(always)] + fn to_f32(self) -> f32 { + self.0.to_f32() + } + + #[inline(always)] + fn from_f32(v: f32) -> Self { + Self(half::bf16::from_f32(v)) + } + + #[inline(always)] + fn zero() -> Self { + Self::ZERO + } + + #[inline(always)] + fn alignment() -> usize { + 32 // AVX alignment for f32 intermediate calculations + } + + #[inline] + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_all(&self.to_bits().to_le_bytes()) + } + + #[inline] + fn read_from(reader: &mut R) -> std::io::Result { + let mut buf = [0u8; 2]; + reader.read_exact(&mut buf)?; + Ok(Self::from_bits(u16::from_le_bytes(buf))) + } + + fn data_type_id() -> DataTypeId { + DataTypeId::BFloat16 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bfloat16_roundtrip() { + let values = [0.0f32, 1.0, -1.0, 0.5, 100.0, -100.0, 1e10, -1e10]; + for v in values { + let bf16 = BFloat16::from_f32(v); + let back = bf16.to_f32(); + // BF16 has the same range as f32 but limited precision + if v == 0.0 { + assert_eq!(back, 0.0); + } else { + let rel_error = ((back - v) / v).abs(); + assert!(rel_error < 0.01, "rel_error = {} for v = {}", rel_error, v); + } + } + } + + #[test] + fn test_bfloat16_vector_element() { + let bf16 = BFloat16::from_f32(2.5); + assert!((VectorElement::to_f32(bf16) - 2.5).abs() < 0.1); + assert_eq!(BFloat16::zero().to_f32(), 0.0); + } +} diff --git a/rust/vecsim/src/types/fp16.rs b/rust/vecsim/src/types/fp16.rs new file mode 100644 index 000000000..f9a258672 --- /dev/null +++ b/rust/vecsim/src/types/fp16.rs @@ -0,0 +1,151 @@ +//! Half-precision floating point (FP16/Float16) support. +//! +//! This module provides a wrapper around the `half` crate's `f16` type, +//! implementing the `VectorElement` trait for use in vector similarity operations. + +use super::VectorElement; +use crate::serialization::DataTypeId; +use std::fmt; + +/// Half-precision floating point number (IEEE 754-2008 binary16). +/// +/// This is a wrapper around `half::f16` that implements `VectorElement`. +/// FP16 provides: +/// - 1 sign bit +/// - 5 exponent bits +/// - 10 mantissa bits +/// - Range: ~6.0e-5 to 65504 +/// - Precision: ~3 decimal digits +#[derive(Copy, Clone, Default, PartialEq, PartialOrd)] +#[repr(transparent)] +pub struct Float16(half::f16); + +impl Float16 { + /// Create a new Float16 from raw bits. + #[inline(always)] + pub const fn from_bits(bits: u16) -> Self { + Self(half::f16::from_bits(bits)) + } + + /// Get the raw bits of this Float16. + #[inline(always)] + pub const fn to_bits(self) -> u16 { + self.0.to_bits() + } + + /// Create a Float16 from an f32. + #[inline(always)] + pub fn from_f32(v: f32) -> Self { + Self(half::f16::from_f32(v)) + } + + /// Convert to f32. + #[inline(always)] + pub fn to_f32(self) -> f32 { + self.0.to_f32() + } + + /// Zero value. + pub const ZERO: Self = Self(half::f16::ZERO); + + /// One value. 
+ pub const ONE: Self = Self(half::f16::ONE); + + /// Positive infinity. + pub const INFINITY: Self = Self(half::f16::INFINITY); + + /// Negative infinity. + pub const NEG_INFINITY: Self = Self(half::f16::NEG_INFINITY); + + /// Not a number. + pub const NAN: Self = Self(half::f16::NAN); +} + +impl fmt::Debug for Float16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Float16({})", self.to_f32()) + } +} + +impl fmt::Display for Float16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.to_f32()) + } +} + +impl From for Float16 { + #[inline(always)] + fn from(v: f32) -> Self { + Self::from_f32(v) + } +} + +impl From for f32 { + #[inline(always)] + fn from(v: Float16) -> Self { + v.to_f32() + } +} + +impl VectorElement for Float16 { + type DistanceType = f32; + + #[inline(always)] + fn to_f32(self) -> f32 { + self.0.to_f32() + } + + #[inline(always)] + fn from_f32(v: f32) -> Self { + Self(half::f16::from_f32(v)) + } + + #[inline(always)] + fn zero() -> Self { + Self::ZERO + } + + #[inline(always)] + fn alignment() -> usize { + 32 // AVX alignment for f32 intermediate calculations + } + + #[inline] + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_all(&self.to_bits().to_le_bytes()) + } + + #[inline] + fn read_from(reader: &mut R) -> std::io::Result { + let mut buf = [0u8; 2]; + reader.read_exact(&mut buf)?; + Ok(Self::from_bits(u16::from_le_bytes(buf))) + } + + fn data_type_id() -> DataTypeId { + DataTypeId::Float16 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_float16_roundtrip() { + let values = [0.0f32, 1.0, -1.0, 0.5, 100.0, -100.0]; + for v in values { + let fp16 = Float16::from_f32(v); + let back = fp16.to_f32(); + // FP16 has limited precision, so we check approximate equality + assert!((back - v).abs() < 0.01 * v.abs().max(1.0)); + } + } + + #[test] + fn test_float16_vector_element() { + let fp16 = Float16::from_f32(2.5); + assert!((VectorElement::to_f32(fp16) - 2.5).abs() < 0.01); + assert_eq!(Float16::zero().to_f32(), 0.0); + } +} diff --git a/rust/vecsim/src/types/int32.rs b/rust/vecsim/src/types/int32.rs new file mode 100644 index 000000000..ff21ca6af --- /dev/null +++ b/rust/vecsim/src/types/int32.rs @@ -0,0 +1,183 @@ +//! Signed 32-bit integer (INT32) support. +//! +//! This module provides an `Int32` wrapper type implementing the `VectorElement` trait +//! for use in vector similarity operations with 32-bit signed integer vectors. + +use super::VectorElement; +use crate::serialization::DataTypeId; +use std::fmt; + +/// Signed 32-bit integer for vector storage. +/// +/// This type wraps `i32` and implements `VectorElement` for use in vector indices. +/// INT32 provides: +/// - Range: -2,147,483,648 to 2,147,483,647 +/// - Useful for large integer-based features or sparse indices +/// +/// Distance calculations are performed in f64 for precision with large values. +#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct Int32(pub i32); + +impl Int32 { + /// Create a new Int32 from a raw i32 value. + #[inline(always)] + pub const fn new(v: i32) -> Self { + Self(v) + } + + /// Get the raw i32 value. + #[inline(always)] + pub const fn get(self) -> i32 { + self.0 + } + + /// Zero value. + pub const ZERO: Self = Self(0); + + /// Maximum value. + pub const MAX: Self = Self(i32::MAX); + + /// Minimum value. 
+ pub const MIN: Self = Self(i32::MIN); +} + +impl fmt::Debug for Int32 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Int32({})", self.0) + } +} + +impl fmt::Display for Int32 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for Int32 { + #[inline(always)] + fn from(v: i32) -> Self { + Self(v) + } +} + +impl From for i32 { + #[inline(always)] + fn from(v: Int32) -> Self { + v.0 + } +} + +impl From for f32 { + #[inline(always)] + fn from(v: Int32) -> Self { + v.0 as f32 + } +} + +impl From for f64 { + #[inline(always)] + fn from(v: Int32) -> Self { + v.0 as f64 + } +} + +impl VectorElement for Int32 { + type DistanceType = f32; + + #[inline(always)] + fn to_f32(self) -> f32 { + self.0 as f32 + } + + #[inline(always)] + fn from_f32(v: f32) -> Self { + // Clamp to i32 range and round + Self(v.round().clamp(i32::MIN as f32, i32::MAX as f32) as i32) + } + + #[inline(always)] + fn zero() -> Self { + Self::ZERO + } + + #[inline(always)] + fn alignment() -> usize { + 32 // AVX alignment + } + + #[inline(always)] + fn can_normalize() -> bool { + // Int32 cannot be meaningfully normalized - normalized values round to 0 + false + } + + #[inline] + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_all(&self.0.to_le_bytes()) + } + + #[inline] + fn read_from(reader: &mut R) -> std::io::Result { + let mut buf = [0u8; 4]; + reader.read_exact(&mut buf)?; + Ok(Self(i32::from_le_bytes(buf))) + } + + fn data_type_id() -> DataTypeId { + DataTypeId::Int32 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_int32_roundtrip() { + let values = [0i32, 1, -1, 1000, -1000, i32::MAX, i32::MIN]; + for v in values { + let int32 = Int32::new(v); + assert_eq!(int32.get(), v); + } + } + + #[test] + fn test_int32_from_f32() { + // Exact values + assert_eq!(Int32::from_f32(0.0).get(), 0); + assert_eq!(Int32::from_f32(1000.0).get(), 1000); + assert_eq!(Int32::from_f32(-1000.0).get(), -1000); + + // Rounding + assert_eq!(Int32::from_f32(500.4).get(), 500); + assert_eq!(Int32::from_f32(500.6).get(), 501); + assert_eq!(Int32::from_f32(-500.4).get(), -500); + assert_eq!(Int32::from_f32(-500.6).get(), -501); + } + + #[test] + fn test_int32_vector_element() { + let int32 = Int32::new(42); + assert_eq!(VectorElement::to_f32(int32), 42.0); + assert_eq!(Int32::zero().get(), 0); + } + + #[test] + fn test_int32_traits() { + // Test Copy, Clone + let a = Int32::new(10); + let b = a; + let c = a.clone(); + assert_eq!(a, b); + assert_eq!(a, c); + + // Test Ord + assert!(Int32::new(10) > Int32::new(5)); + assert!(Int32::new(-5) < Int32::new(5)); + + // Test Default + let d: Int32 = Default::default(); + assert_eq!(d.get(), 0); + } +} diff --git a/rust/vecsim/src/types/int64.rs b/rust/vecsim/src/types/int64.rs new file mode 100644 index 000000000..0d9e66daf --- /dev/null +++ b/rust/vecsim/src/types/int64.rs @@ -0,0 +1,190 @@ +//! Signed 64-bit integer (INT64) support. +//! +//! This module provides an `Int64` wrapper type implementing the `VectorElement` trait +//! for use in vector similarity operations with 64-bit signed integer vectors. + +use super::VectorElement; +use crate::serialization::DataTypeId; +use std::fmt; + +/// Signed 64-bit integer for vector storage. +/// +/// This type wraps `i64` and implements `VectorElement` for use in vector indices. 
+/// INT64 provides: +/// - Range: -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 +/// - Useful for very large integer-based features or identifiers +/// +/// Note: Conversion to f32 may lose precision for large values. +/// Distance calculations use f32 for compatibility with other types. +#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct Int64(pub i64); + +impl Int64 { + /// Create a new Int64 from a raw i64 value. + #[inline(always)] + pub const fn new(v: i64) -> Self { + Self(v) + } + + /// Get the raw i64 value. + #[inline(always)] + pub const fn get(self) -> i64 { + self.0 + } + + /// Zero value. + pub const ZERO: Self = Self(0); + + /// Maximum value. + pub const MAX: Self = Self(i64::MAX); + + /// Minimum value. + pub const MIN: Self = Self(i64::MIN); +} + +impl fmt::Debug for Int64 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Int64({})", self.0) + } +} + +impl fmt::Display for Int64 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for Int64 { + #[inline(always)] + fn from(v: i64) -> Self { + Self(v) + } +} + +impl From for i64 { + #[inline(always)] + fn from(v: Int64) -> Self { + v.0 + } +} + +impl From for f32 { + #[inline(always)] + fn from(v: Int64) -> Self { + v.0 as f32 + } +} + +impl From for f64 { + #[inline(always)] + fn from(v: Int64) -> Self { + v.0 as f64 + } +} + +impl VectorElement for Int64 { + type DistanceType = f32; + + #[inline(always)] + fn to_f32(self) -> f32 { + self.0 as f32 + } + + #[inline(always)] + fn from_f32(v: f32) -> Self { + // Clamp to i64 range (limited by f32 precision) + // f32 can exactly represent integers up to 2^24 + Self(v.round() as i64) + } + + #[inline(always)] + fn zero() -> Self { + Self::ZERO + } + + #[inline(always)] + fn alignment() -> usize { + 64 // AVX-512 alignment + } + + #[inline(always)] + fn can_normalize() -> bool { + // Int64 cannot be meaningfully normalized - normalized values round to 0 + false + } + + #[inline] + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_all(&self.0.to_le_bytes()) + } + + #[inline] + fn read_from(reader: &mut R) -> std::io::Result { + let mut buf = [0u8; 8]; + reader.read_exact(&mut buf)?; + Ok(Self(i64::from_le_bytes(buf))) + } + + fn data_type_id() -> DataTypeId { + DataTypeId::Int64 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_int64_roundtrip() { + let values = [0i64, 1, -1, 1_000_000, -1_000_000]; + for v in values { + let int64 = Int64::new(v); + assert_eq!(int64.get(), v); + } + } + + #[test] + fn test_int64_from_f32() { + // Exact values (within f32 precision) + assert_eq!(Int64::from_f32(0.0).get(), 0); + assert_eq!(Int64::from_f32(1000.0).get(), 1000); + assert_eq!(Int64::from_f32(-1000.0).get(), -1000); + + // Rounding + assert_eq!(Int64::from_f32(500.4).get(), 500); + assert_eq!(Int64::from_f32(500.6).get(), 501); + } + + #[test] + fn test_int64_vector_element() { + let int64 = Int64::new(42); + assert_eq!(VectorElement::to_f32(int64), 42.0); + assert_eq!(Int64::zero().get(), 0); + } + + #[test] + fn test_int64_traits() { + // Test Copy, Clone + let a = Int64::new(10); + let b = a; + let c = a.clone(); + assert_eq!(a, b); + assert_eq!(a, c); + + // Test Ord + assert!(Int64::new(10) > Int64::new(5)); + assert!(Int64::new(-5) < Int64::new(5)); + + // Test Default + let d: Int64 = Default::default(); + assert_eq!(d.get(), 0); + } + + #[test] + fn test_int64_large_values() { + // Test 
diff --git a/rust/vecsim/src/types/int8.rs b/rust/vecsim/src/types/int8.rs
new file mode 100644
index 000000000..32e0ed443
--- /dev/null
+++ b/rust/vecsim/src/types/int8.rs
@@ -0,0 +1,182 @@
+//! Signed 8-bit integer (INT8) support.
+//!
+//! This module provides an `Int8` wrapper type implementing the `VectorElement` trait
+//! for use in vector similarity operations with 8-bit signed integer vectors.
+
+use super::VectorElement;
+use crate::serialization::DataTypeId;
+use std::fmt;
+
+/// Signed 8-bit integer for vector storage.
+///
+/// This type wraps `i8` and implements `VectorElement` for use in vector indices.
+/// INT8 provides:
+/// - Range: -128 to 127
+/// - Memory efficient: 4x smaller than f32
+/// - Useful for quantized embeddings or integer-based features
+///
+/// Distance calculations are performed in f32 for precision.
+#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[repr(transparent)]
+pub struct Int8(pub i8);
+
+impl Int8 {
+    /// Create a new Int8 from a raw i8 value.
+    #[inline(always)]
+    pub const fn new(v: i8) -> Self {
+        Self(v)
+    }
+
+    /// Get the raw i8 value.
+    #[inline(always)]
+    pub const fn get(self) -> i8 {
+        self.0
+    }
+
+    /// Zero value.
+    pub const ZERO: Self = Self(0);
+
+    /// Maximum value (127).
+    pub const MAX: Self = Self(i8::MAX);
+
+    /// Minimum value (-128).
+    pub const MIN: Self = Self(i8::MIN);
+}
+
+impl fmt::Debug for Int8 {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "Int8({})", self.0)
+    }
+}
+
+impl fmt::Display for Int8 {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl From<i8> for Int8 {
+    #[inline(always)]
+    fn from(v: i8) -> Self {
+        Self(v)
+    }
+}
+
+impl From<Int8> for i8 {
+    #[inline(always)]
+    fn from(v: Int8) -> Self {
+        v.0
+    }
+}
+
+impl From<Int8> for f32 {
+    #[inline(always)]
+    fn from(v: Int8) -> Self {
+        v.0 as f32
+    }
+}
+
+impl VectorElement for Int8 {
+    type DistanceType = f32;
+
+    #[inline(always)]
+    fn to_f32(self) -> f32 {
+        self.0 as f32
+    }
+
+    #[inline(always)]
+    fn from_f32(v: f32) -> Self {
+        // Clamp to i8 range and round
+        Self(v.round().clamp(-128.0, 127.0) as i8)
+    }
+
+    #[inline(always)]
+    fn zero() -> Self {
+        Self::ZERO
+    }
+
+    #[inline(always)]
+    fn alignment() -> usize {
+        32 // AVX alignment for f32 intermediate calculations
+    }
+
+    #[inline(always)]
+    fn can_normalize() -> bool {
+        // Int8 cannot be meaningfully normalized - normalized values round to 0
+        false
+    }
+
+    #[inline]
+    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
+        writer.write_all(&self.0.to_le_bytes())
+    }
+
+    #[inline]
+    fn read_from<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
+        let mut buf = [0u8; 1];
+        reader.read_exact(&mut buf)?;
+        Ok(Self(i8::from_le_bytes(buf)))
+    }
+
+    fn data_type_id() -> DataTypeId {
+        DataTypeId::Int8
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_int8_roundtrip() {
+        let values = [0i8, 1, -1, 50, -50, 127, -128];
+        for v in values {
+            let int8 = Int8::new(v);
+            assert_eq!(int8.get(), v);
+            assert_eq!(int8.to_f32() as i8, v);
+        }
+    }
+
+    #[test]
+    fn test_int8_from_f32() {
+        // Exact values
+        assert_eq!(Int8::from_f32(0.0).get(), 0);
+        assert_eq!(Int8::from_f32(100.0).get(), 100);
+        assert_eq!(Int8::from_f32(-100.0).get(), -100);
+
+        // Rounding
+        assert_eq!(Int8::from_f32(50.4).get(), 50);
+        assert_eq!(Int8::from_f32(50.6).get(), 51);
+        assert_eq!(Int8::from_f32(-50.4).get(), -50);
+        assert_eq!(Int8::from_f32(-50.6).get(), -51);
+
+        // Clamping
+        assert_eq!(Int8::from_f32(200.0).get(), 127);
+        assert_eq!(Int8::from_f32(-200.0).get(), -128);
+    }
+
+    #[test]
+    fn test_int8_vector_element() {
+        let int8 = Int8::new(42);
+        assert_eq!(VectorElement::to_f32(int8), 42.0);
+        assert_eq!(Int8::zero().get(), 0);
+    }
+
+    #[test]
+    fn test_int8_traits() {
+        // Test Copy, Clone
+        let a = Int8::new(10);
+        let b = a;
+        let c = a.clone();
+        assert_eq!(a, b);
+        assert_eq!(a, c);
+
+        // Test Ord
+        assert!(Int8::new(10) > Int8::new(5));
+        assert!(Int8::new(-5) < Int8::new(5));
+
+        // Test Default
+        let d: Int8 = Default::default();
+        assert_eq!(d.get(), 0);
+    }
+}
diff --git a/rust/vecsim/src/types/mod.rs b/rust/vecsim/src/types/mod.rs
new file mode 100644
index 000000000..beba91550
--- /dev/null
+++ b/rust/vecsim/src/types/mod.rs
@@ -0,0 +1,248 @@
+//! Core type definitions for the vector similarity library.
+//!
+//! This module defines the fundamental types used throughout the library:
+//! - `LabelType`: External label for vectors (user-provided identifier)
+//! - `IdType`: Internal vector identifier
+//! - `VectorElement`: Trait for vector element types (f32, f64, Float16, BFloat16, Int8, UInt8, Int32, Int64)
+//! - `DistanceType`: Trait for distance computation result types
+
+pub mod bf16;
+pub mod fp16;
+pub mod int8;
+pub mod int32;
+pub mod int64;
+pub mod uint8;
+
+pub use bf16::BFloat16;
+pub use fp16::Float16;
+pub use int8::Int8;
+pub use int32::Int32;
+pub use int64::Int64;
+pub use uint8::UInt8;
+
+use crate::serialization::DataTypeId;
+use num_traits::Float;
+use std::fmt::Debug;
+
+/// External label type for vectors (user-provided identifier).
+pub type LabelType = u64;
+
+/// Internal vector identifier.
+pub type IdType = u32;
+
+/// Invalid/sentinel value for internal IDs.
+pub const INVALID_ID: IdType = IdType::MAX;
+
+/// Trait for types that can be used as vector elements.
+///
+/// This trait abstracts over the different numeric types that can be stored
+/// in vectors: f32, f64, Float16, BFloat16, and the integer wrappers
+/// (Int8, UInt8, Int32, Int64).
+pub trait VectorElement: Copy + Clone + Debug + Send + Sync + 'static {
+    /// The type used for distance calculations (typically f32 or f64).
+    type DistanceType: DistanceType;
+
+    /// Convert to f32 for distance calculations.
+    fn to_f32(self) -> f32;
+
+    /// Create from f32.
+    fn from_f32(v: f32) -> Self;
+
+    /// Zero value.
+    fn zero() -> Self;
+
+    /// Alignment requirement in bytes for SIMD operations.
+    fn alignment() -> usize {
+        std::mem::align_of::<Self>()
+    }
+
+    /// Whether this type can be meaningfully normalized for cosine distance.
+    ///
+    /// Returns `true` for floating-point types that can represent values in [-1, 1].
+    /// Returns `false` for integer types where normalization would lose precision.
+    fn can_normalize() -> bool {
+        true
+    }
+
+    /// Write this value to a writer.
+    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()>;
+
+    /// Read a value from a reader.
+    fn read_from<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self>;
+
+    /// Size of this type in bytes when serialized.
+    fn serialized_size() -> usize {
+        std::mem::size_of::<Self>()
+    }
+
+    /// Get the data type identifier for serialization.
+    fn data_type_id() -> DataTypeId;
+}
+
+/// Trait for distance computation result types.
+pub trait DistanceType:
+    Copy + Clone + Debug + Send + Sync + PartialOrd + 'static + Default
+{
+    /// Zero distance.
+    fn zero() -> Self;
+
+    /// Maximum possible distance (infinity).
+    fn infinity() -> Self;
+
+    /// Minimum possible distance (negative infinity or zero depending on metric).
+    fn neg_infinity() -> Self;
+
+    /// Convert to f64 for comparisons.
+    fn to_f64(self) -> f64;
+
+    /// Create from f64.
+    fn from_f64(v: f64) -> Self;
+
+    /// Square root (for L2 distance normalization).
+    fn sqrt(self) -> Self;
+}
+
+// Implementation for f32
+impl VectorElement for f32 {
+    type DistanceType = f32;
+
+    #[inline(always)]
+    fn to_f32(self) -> f32 {
+        self
+    }
+
+    #[inline(always)]
+    fn from_f32(v: f32) -> Self {
+        v
+    }
+
+    #[inline(always)]
+    fn zero() -> Self {
+        0.0
+    }
+
+    #[inline(always)]
+    fn alignment() -> usize {
+        32 // AVX alignment
+    }
+
+    #[inline]
+    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
+        writer.write_all(&self.to_le_bytes())
+    }
+
+    #[inline]
+    fn read_from<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
+        let mut buf = [0u8; 4];
+        reader.read_exact(&mut buf)?;
+        Ok(f32::from_le_bytes(buf))
+    }
+
+    fn data_type_id() -> DataTypeId {
+        DataTypeId::F32
+    }
+}
+
+impl DistanceType for f32 {
+    #[inline(always)]
+    fn zero() -> Self {
+        0.0
+    }
+
+    #[inline(always)]
+    fn infinity() -> Self {
+        f32::INFINITY
+    }
+
+    #[inline(always)]
+    fn neg_infinity() -> Self {
+        f32::NEG_INFINITY
+    }
+
+    #[inline(always)]
+    fn to_f64(self) -> f64 {
+        self as f64
+    }
+
+    #[inline(always)]
+    fn from_f64(v: f64) -> Self {
+        v as f32
+    }
+
+    #[inline(always)]
+    fn sqrt(self) -> Self {
+        Float::sqrt(self)
+    }
+}
+
+// Implementation for f64
+impl VectorElement for f64 {
+    type DistanceType = f64;
+
+    #[inline(always)]
+    fn to_f32(self) -> f32 {
+        self as f32
+    }
+
+    #[inline(always)]
+    fn from_f32(v: f32) -> Self {
+        v as f64
+    }
+
+    #[inline(always)]
+    fn zero() -> Self {
+        0.0
+    }
+
+    #[inline(always)]
+    fn alignment() -> usize {
+        64 // AVX-512 alignment
+    }
+
+    #[inline]
+    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
+        writer.write_all(&self.to_le_bytes())
+    }
+
+    #[inline]
+    fn read_from<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
+        let mut buf = [0u8; 8];
+        reader.read_exact(&mut buf)?;
+        Ok(f64::from_le_bytes(buf))
+    }
+
+    fn data_type_id() -> DataTypeId {
+        DataTypeId::F64
+    }
+}
+
+impl DistanceType for f64 {
+    #[inline(always)]
+    fn zero() -> Self {
+        0.0
+    }
+
+    #[inline(always)]
+    fn infinity() -> Self {
+        f64::INFINITY
+    }
+
+    #[inline(always)]
+    fn neg_infinity() -> Self {
+        f64::NEG_INFINITY
+    }
+
+    #[inline(always)]
+    fn to_f64(self) -> f64 {
+        self
+    }
+
+    #[inline(always)]
+    fn from_f64(v: f64) -> Self {
+        v
+    }
+
+    #[inline(always)]
+    fn sqrt(self) -> Self {
+        Float::sqrt(self)
+    }
+}
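To show how the two traits are meant to compose, here is a minimal scalar distance kernel written against `VectorElement`. It is a sketch only; the crate's real kernels (including any SIMD paths) live elsewhere, and `squared_l2` is an illustrative name.

// Works uniformly over f32, f64, Float16, Int8, etc., by converting each
// element to f32 as the trait intends.
fn squared_l2<T: VectorElement>(a: &[T], b: &[T]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let d = x.to_f32() - y.to_f32();
            d * d
        })
        .sum()
}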
diff --git a/rust/vecsim/src/types/uint8.rs b/rust/vecsim/src/types/uint8.rs
new file mode 100644
index 000000000..948efd9c7
--- /dev/null
+++ b/rust/vecsim/src/types/uint8.rs
@@ -0,0 +1,180 @@
+//! Unsigned 8-bit integer (UINT8) support.
+//!
+//! This module provides a `UInt8` wrapper type implementing the `VectorElement` trait
+//! for use in vector similarity operations with 8-bit unsigned integer vectors.
+
+use super::VectorElement;
+use crate::serialization::DataTypeId;
+use std::fmt;
+
+/// Unsigned 8-bit integer for vector storage.
+///
+/// This type wraps `u8` and implements `VectorElement` for use in vector indices.
+/// UINT8 provides:
+/// - Range: 0 to 255
+/// - Memory efficient: 4x smaller than f32
+/// - Useful for quantized embeddings, image features, or SQ8 storage
+///
+/// Distance calculations are performed in f32 for precision.
+#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[repr(transparent)]
+pub struct UInt8(pub u8);
+
+impl UInt8 {
+    /// Create a new UInt8 from a raw u8 value.
+    #[inline(always)]
+    pub const fn new(v: u8) -> Self {
+        Self(v)
+    }
+
+    /// Get the raw u8 value.
+    #[inline(always)]
+    pub const fn get(self) -> u8 {
+        self.0
+    }
+
+    /// Zero value.
+    pub const ZERO: Self = Self(0);
+
+    /// Maximum value (255).
+    pub const MAX: Self = Self(u8::MAX);
+
+    /// Minimum value (0).
+    pub const MIN: Self = Self(0);
+}
+
+impl fmt::Debug for UInt8 {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "UInt8({})", self.0)
+    }
+}
+
+impl fmt::Display for UInt8 {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl From<u8> for UInt8 {
+    #[inline(always)]
+    fn from(v: u8) -> Self {
+        Self(v)
+    }
+}
+
+impl From<UInt8> for u8 {
+    #[inline(always)]
+    fn from(v: UInt8) -> Self {
+        v.0
+    }
+}
+
+impl From<UInt8> for f32 {
+    #[inline(always)]
+    fn from(v: UInt8) -> Self {
+        v.0 as f32
+    }
+}
+
+impl VectorElement for UInt8 {
+    type DistanceType = f32;
+
+    #[inline(always)]
+    fn to_f32(self) -> f32 {
+        self.0 as f32
+    }
+
+    #[inline(always)]
+    fn from_f32(v: f32) -> Self {
+        // Clamp to u8 range and round
+        Self(v.round().clamp(0.0, 255.0) as u8)
+    }
+
+    #[inline(always)]
+    fn zero() -> Self {
+        Self::ZERO
+    }
+
+    #[inline(always)]
+    fn alignment() -> usize {
+        32 // AVX alignment for f32 intermediate calculations
+    }
+
+    #[inline(always)]
+    fn can_normalize() -> bool {
+        // UInt8 cannot be meaningfully normalized - normalized values round to 0
+        false
+    }
+
+    #[inline]
+    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
+        writer.write_all(&[self.0])
+    }
+
+    #[inline]
+    fn read_from<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
+        let mut buf = [0u8; 1];
+        reader.read_exact(&mut buf)?;
+        Ok(Self(buf[0]))
+    }
+
+    fn data_type_id() -> DataTypeId {
+        DataTypeId::UInt8
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_uint8_roundtrip() {
+        let values = [0u8, 1, 50, 100, 200, 255];
+        for v in values {
+            let uint8 = UInt8::new(v);
+            assert_eq!(uint8.get(), v);
+            assert_eq!(uint8.to_f32() as u8, v);
+        }
+    }
+
+    #[test]
+    fn test_uint8_from_f32() {
+        // Exact values
+        assert_eq!(UInt8::from_f32(0.0).get(), 0);
+        assert_eq!(UInt8::from_f32(100.0).get(), 100);
+        assert_eq!(UInt8::from_f32(255.0).get(), 255);
+
+        // Rounding
+        assert_eq!(UInt8::from_f32(50.4).get(), 50);
+        assert_eq!(UInt8::from_f32(50.6).get(), 51);
+
+        // Clamping
+        assert_eq!(UInt8::from_f32(300.0).get(), 255);
+        assert_eq!(UInt8::from_f32(-50.0).get(), 0);
+    }
+
+    #[test]
+    fn test_uint8_vector_element() {
+        let uint8 = UInt8::new(42);
+        assert_eq!(VectorElement::to_f32(uint8), 42.0);
+        assert_eq!(UInt8::zero().get(), 0);
+    }
+
+    #[test]
+    fn test_uint8_traits() {
+        // Test Copy, Clone
+        let a = UInt8::new(10);
+        let b = a;
+        let c = a.clone();
+        assert_eq!(a, b);
+        assert_eq!(a, c);
+
+        // Test Ord
+        assert!(UInt8::new(10) > UInt8::new(5));
+        assert!(UInt8::new(100) < UInt8::new(200));
+
+        // Test Default
+        let d: UInt8 = Default::default();
+        assert_eq!(d.get(), 0);
+    }
+}
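A small round-trip sketch for the `write_to`/`read_from` hooks defined in the types module, using `std::io::Cursor` as a stand-in for an index file; the helper name and the extra `PartialEq` bound are illustrative.

use std::io::Cursor;

// Serialize a vector element-by-element in little-endian form, then read
// it back and compare, generically over any VectorElement.
fn roundtrip<T: VectorElement + PartialEq>(vector: &[T]) -> std::io::Result<bool> {
    let mut buf = Vec::with_capacity(vector.len() * T::serialized_size());
    for v in vector {
        v.write_to(&mut buf)?; // Vec<u8> implements std::io::Write
    }
    let mut reader = Cursor::new(buf);
    for v in vector {
        if T::read_from(&mut reader)? != *v {
            return Ok(false);
        }
    }
    Ok(true)
}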
diff --git a/rust/vecsim/src/utils/heap.rs b/rust/vecsim/src/utils/heap.rs
new file mode 100644
index 000000000..71dcac12b
--- /dev/null
+++ b/rust/vecsim/src/utils/heap.rs
@@ -0,0 +1,348 @@
+//! Priority queue implementations for KNN search.
+//!
+//! This module provides specialized heap implementations optimized for
+//! vector similarity search operations:
+//! - `MaxHeap`: Used to maintain top-k results (evict largest distance)
+//! - `MinHeap`: Used for candidate exploration (process smallest distance first)
+
+use crate::types::{DistanceType, IdType};
+use std::cmp::Ordering;
+use std::collections::BinaryHeap;
+
+/// An entry in the priority queue containing an ID and distance.
+#[derive(Debug, Clone, Copy)]
+pub struct HeapEntry<D: DistanceType> {
+    pub id: IdType,
+    pub distance: D,
+}
+
+impl<D: DistanceType> HeapEntry<D> {
+    #[inline]
+    pub fn new(id: IdType, distance: D) -> Self {
+        Self { id, distance }
+    }
+}
+
+/// Wrapper for max-heap ordering (largest distance at top).
+#[derive(Debug, Clone, Copy)]
+struct MaxHeapEntry<D: DistanceType>(HeapEntry<D>);
+
+impl<D: DistanceType> PartialEq for MaxHeapEntry<D> {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.distance.partial_cmp(&other.0.distance) == Some(Ordering::Equal)
+    }
+}
+
+impl<D: DistanceType> Eq for MaxHeapEntry<D> {}
+
+impl<D: DistanceType> PartialOrd for MaxHeapEntry<D> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl<D: DistanceType> Ord for MaxHeapEntry<D> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        // Natural ordering for max-heap: larger distances come first
+        self.0
+            .distance
+            .partial_cmp(&other.0.distance)
+            .unwrap_or(Ordering::Equal)
+    }
+}
+
+/// Wrapper for min-heap ordering (smallest distance at top).
+#[derive(Debug, Clone, Copy)]
+struct MinHeapEntry<D: DistanceType>(HeapEntry<D>);
+
+impl<D: DistanceType> PartialEq for MinHeapEntry<D> {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.distance.partial_cmp(&other.0.distance) == Some(Ordering::Equal)
+    }
+}
+
+impl<D: DistanceType> Eq for MinHeapEntry<D> {}
+
+impl<D: DistanceType> PartialOrd for MinHeapEntry<D> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl<D: DistanceType> Ord for MinHeapEntry<D> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        // Reverse ordering for min-heap: smaller distances come first
+        other
+            .0
+            .distance
+            .partial_cmp(&self.0.distance)
+            .unwrap_or(Ordering::Equal)
+    }
+}
+
+/// A max-heap that keeps track of the k smallest elements by evicting the largest.
+///
+/// This is used to maintain the current top-k results during KNN search.
+/// The largest distance is at the top, making it efficient to check if a new
+/// candidate should replace an existing result.
+#[derive(Debug)]
+pub struct MaxHeap<D: DistanceType> {
+    heap: BinaryHeap<MaxHeapEntry<D>>,
+    capacity: usize,
+}
+
+impl<D: DistanceType> MaxHeap<D> {
+    /// Create a new max-heap with the given capacity.
+    pub fn new(capacity: usize) -> Self {
+        Self {
+            heap: BinaryHeap::with_capacity(capacity + 1),
+            capacity,
+        }
+    }
+
+    /// Get the number of elements in the heap.
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.heap.len()
+    }
+
+    /// Check if the heap is empty.
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.heap.is_empty()
+    }
+
+    /// Check if the heap is full.
+    #[inline]
+    pub fn is_full(&self) -> bool {
+        self.heap.len() >= self.capacity
+    }
+
+    /// Get the largest distance in the heap (top element).
+    #[inline]
+    pub fn top_distance(&self) -> Option<D> {
+        self.heap.peek().map(|e| e.0.distance)
+    }
+
+    /// Get the top entry without removing it.
+    #[inline]
+    pub fn peek(&self) -> Option<HeapEntry<D>> {
+        self.heap.peek().map(|e| e.0)
+    }
+
+    /// Try to insert an element. Returns true if inserted.
+    ///
+    /// If the heap is not full, always inserts.
+    /// If full, only inserts if the distance is smaller than the current maximum.
+    #[inline]
+    pub fn try_insert(&mut self, id: IdType, distance: D) -> bool {
+        if self.heap.len() < self.capacity {
+            self.heap.push(MaxHeapEntry(HeapEntry::new(id, distance)));
+            true
+        } else if let Some(mut top) = self.heap.peek_mut() {
+            if distance < top.0.distance {
+                // Replace the current maximum in place through PeekMut;
+                // dropping the guard sifts the new entry down, which is
+                // cheaper than a separate pop() followed by push().
+                *top = MaxHeapEntry(HeapEntry::new(id, distance));
+                // Drop the PeekMut to trigger the sift-down
+                drop(top);
+                true
+            } else {
+                false
+            }
+        } else {
+            false
+        }
+    }
+
+    /// Insert an element unconditionally, maintaining capacity.
+    #[inline]
+    pub fn insert(&mut self, id: IdType, distance: D) {
+        self.heap.push(MaxHeapEntry(HeapEntry::new(id, distance)));
+        if self.heap.len() > self.capacity {
+            self.heap.pop();
+        }
+    }
+
+    /// Pop the largest element.
+    #[inline]
+    pub fn pop(&mut self) -> Option<HeapEntry<D>> {
+        self.heap.pop().map(|e| e.0)
+    }
+
+    /// Convert to a sorted vector (smallest distance first).
+    pub fn into_sorted_vec(self) -> Vec<HeapEntry<D>> {
+        // Pre-allocate with exact capacity to avoid reallocations
+        let len = self.heap.len();
+        let mut entries = Vec::with_capacity(len);
+        entries.extend(self.heap.into_iter().map(|e| e.0));
+        entries.sort_by(|a, b| {
+            a.distance
+                .partial_cmp(&b.distance)
+                .unwrap_or(Ordering::Equal)
+        });
+        entries
+    }
+
+    /// Convert to a sorted vector of (id, distance) pairs.
+    /// This sorts the heap's backing buffer directly rather than popping
+    /// elements one at a time.
+    #[inline]
+    pub fn into_sorted_pairs(self) -> Vec<(IdType, D)> {
+        let len = self.heap.len();
+        let mut result = Vec::with_capacity(len);
+        // Reuse the heap's buffer for the sort, then map directly to the
+        // output format.
+        let mut entries: Vec<_> = self.heap.into_vec();
+        entries.sort_by(|a, b| {
+            a.0.distance
+                .partial_cmp(&b.0.distance)
+                .unwrap_or(Ordering::Equal)
+        });
+        result.extend(entries.into_iter().map(|e| (e.0.id, e.0.distance)));
+        result
+    }
+
+    /// Convert to a vector (unordered).
+    pub fn into_vec(self) -> Vec<HeapEntry<D>> {
+        self.heap.into_iter().map(|e| e.0).collect()
+    }
+
+    /// Iterate over entries (unordered).
+    pub fn iter(&self) -> impl Iterator<Item = HeapEntry<D>> + '_ {
+        self.heap.iter().map(|e| e.0)
+    }
+
+    /// Clear the heap.
+    pub fn clear(&mut self) {
+        self.heap.clear();
+    }
+}
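A usage sketch for `MaxHeap` as a top-k accumulator, matching the `try_insert` semantics documented above; `distances` stands in for (id, distance) pairs produced by a query scan.

// Keep only the k nearest results while scanning candidate distances.
fn top_k(distances: &[(u32, f32)], k: usize) -> Vec<(u32, f32)> {
    let mut results = MaxHeap::<f32>::new(k);
    for &(id, dist) in distances {
        // try_insert rejects candidates worse than the current k-th best,
        // so the heap never grows beyond k entries.
        results.try_insert(id, dist);
    }
    results.into_sorted_pairs() // ascending by distance
}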
+
+/// A min-heap for processing candidates in order of increasing distance.
+///
+/// This is used for the candidate list during HNSW graph traversal,
+/// where we want to process the closest candidates first.
+#[derive(Debug)]
+pub struct MinHeap<D: DistanceType> {
+    heap: BinaryHeap<MinHeapEntry<D>>,
+}
+
+impl<D: DistanceType> MinHeap<D> {
+    /// Create a new empty min-heap.
+    pub fn new() -> Self {
+        Self {
+            heap: BinaryHeap::new(),
+        }
+    }
+
+    /// Create a new min-heap with pre-allocated capacity.
+    pub fn with_capacity(capacity: usize) -> Self {
+        Self {
+            heap: BinaryHeap::with_capacity(capacity),
+        }
+    }
+
+    /// Get the number of elements.
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.heap.len()
+    }
+
+    /// Check if empty.
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.heap.is_empty()
+    }
+
+    /// Get the smallest distance without removing.
+    #[inline]
+    pub fn top_distance(&self) -> Option<D> {
+        self.heap.peek().map(|e| e.0.distance)
+    }
+
+    /// Peek at the smallest entry.
+    #[inline]
+    pub fn peek(&self) -> Option<HeapEntry<D>> {
+        self.heap.peek().map(|e| e.0)
+    }
+
+    /// Insert an element.
+    #[inline]
+    pub fn push(&mut self, id: IdType, distance: D) {
+        self.heap.push(MinHeapEntry(HeapEntry::new(id, distance)));
+    }
+
+    /// Pop the smallest element.
+    #[inline]
+    pub fn pop(&mut self) -> Option<HeapEntry<D>> {
+        self.heap.pop().map(|e| e.0)
+    }
+
+    /// Clear the heap.
+    pub fn clear(&mut self) {
+        self.heap.clear();
+    }
+
+    /// Convert to a sorted vector (smallest distance first).
+    pub fn into_sorted_vec(mut self) -> Vec<HeapEntry<D>> {
+        let mut result = Vec::with_capacity(self.heap.len());
+        while let Some(entry) = self.pop() {
+            result.push(entry);
+        }
+        result
+    }
+}
+
+impl<D: DistanceType> Default for MinHeap<D> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_max_heap_topk() {
+        let mut heap = MaxHeap::<f32>::new(3);
+
+        heap.try_insert(1, 5.0);
+        heap.try_insert(2, 3.0);
+        heap.try_insert(3, 7.0);
+        assert!(heap.is_full());
+        assert_eq!(heap.top_distance(), Some(7.0));
+
+        // This should replace 7.0
+        assert!(heap.try_insert(4, 2.0));
+        assert_eq!(heap.top_distance(), Some(5.0));
+
+        // This should not be inserted (8.0 > 5.0)
+        assert!(!heap.try_insert(5, 8.0));
+
+        let sorted = heap.into_sorted_vec();
+        assert_eq!(sorted.len(), 3);
+        assert_eq!(sorted[0].id, 4); // distance 2.0
+        assert_eq!(sorted[1].id, 2); // distance 3.0
+        assert_eq!(sorted[2].id, 1); // distance 5.0
+    }
+
+    #[test]
+    fn test_min_heap() {
+        let mut heap = MinHeap::<f32>::new();
+
+        heap.push(1, 5.0);
+        heap.push(2, 3.0);
+        heap.push(3, 7.0);
+        heap.push(4, 1.0);
+
+        // Should come out in ascending order
+        assert_eq!(heap.pop().unwrap().id, 4); // 1.0
+        assert_eq!(heap.pop().unwrap().id, 2); // 3.0
+        assert_eq!(heap.pop().unwrap().id, 1); // 5.0
+        assert_eq!(heap.pop().unwrap().id, 3); // 7.0
+        assert!(heap.pop().is_none());
+    }
+}
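The two heaps are designed to cooperate in best-first traversal, as the `MinHeap` docs note. A schematic (non-HNSW) sketch under that assumption, with `neighbors` and `distance` closures standing in for real graph and vector access:

// Best-first search: explore the closest open candidate, keep the k best.
fn greedy_search(
    entry: u32,
    k: usize,
    neighbors: impl Fn(u32) -> Vec<u32>,
    distance: impl Fn(u32) -> f32,
) -> Vec<(u32, f32)> {
    let mut candidates = MinHeap::<f32>::new(); // closest-first frontier
    let mut results = MaxHeap::<f32>::new(k);   // k best seen so far
    let mut visited = std::collections::HashSet::new();

    candidates.push(entry, distance(entry));
    visited.insert(entry);

    while let Some(cand) = candidates.pop() {
        // Stop once the closest open candidate cannot improve the results
        // (Some(x) > None is true, so this also terminates when k == 0).
        if results.is_full() && Some(cand.distance) > results.top_distance() {
            break;
        }
        results.try_insert(cand.id, cand.distance);
        for n in neighbors(cand.id) {
            if visited.insert(n) {
                candidates.push(n, distance(n));
            }
        }
    }
    results.into_sorted_pairs()
}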
diff --git a/rust/vecsim/src/utils/mod.rs b/rust/vecsim/src/utils/mod.rs
new file mode 100644
index 000000000..3325395b8
--- /dev/null
+++ b/rust/vecsim/src/utils/mod.rs
@@ -0,0 +1,10 @@
+//! Utility types and functions.
+//!
+//! This module provides utility data structures used throughout the library:
+//! - Priority queues (max-heap and min-heap) for KNN search
+//! - Memory prefetching utilities for cache optimization
+
+pub mod heap;
+pub mod prefetch;
+
+pub use heap::{MaxHeap, MinHeap};
diff --git a/rust/vecsim/src/utils/prefetch.rs b/rust/vecsim/src/utils/prefetch.rs
new file mode 100644
index 000000000..e0de1b42c
--- /dev/null
+++ b/rust/vecsim/src/utils/prefetch.rs
@@ -0,0 +1,239 @@
+//! Memory prefetching utilities for optimizing cache behavior.
+//!
+//! This module provides platform-specific memory prefetching hints to help
+//! hide memory latency during graph traversal in HNSW search.
+
+/// Prefetch data into L1 cache for reading.
+///
+/// This is a hint to the processor that the data at the given pointer
+/// will be needed soon. On x86_64, this uses the `_mm_prefetch` intrinsic.
+/// On aarch64, this uses inline assembly with the `prfm pldl1keep` instruction.
+#[inline]
+pub fn prefetch_read<T>(ptr: *const T) {
+    #[cfg(target_arch = "x86_64")]
+    {
+        // Safety: _mm_prefetch is safe to call with any pointer.
+        // If the pointer is invalid, the prefetch is simply ignored.
+        unsafe {
+            use std::arch::x86_64::*;
+            _mm_prefetch(ptr as *const i8, _MM_HINT_T0);
+        }
+    }
+
+    #[cfg(target_arch = "aarch64")]
+    {
+        // Use inline assembly for aarch64 prefetch (prfm pldl1keep)
+        // PLDL1KEEP = Prefetch for Load, L1 cache, temporal (keep in cache)
+        // Safety: prfm is a hint instruction that is safe to call with any pointer.
+        // If the pointer is invalid or unmapped, the prefetch is silently ignored.
+        unsafe {
+            std::arch::asm!(
+                "prfm pldl1keep, [{ptr}]",
+                ptr = in(reg) ptr,
+                options(nostack, preserves_flags)
+            );
+        }
+    }
+
+    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
+    {
+        let _ = ptr; // Suppress unused warning on other architectures
+    }
+}
+
+/// Prefetch a slice of data into cache.
+///
+/// This prefetches the beginning of the slice, which is typically
+/// sufficient for vector data that fits in a few cache lines.
+#[inline]
+pub fn prefetch_slice<T>(slice: &[T]) {
+    if !slice.is_empty() {
+        prefetch_read(slice.as_ptr());
+    }
+}
+
+/// Configuration for adaptive prefetching based on vector characteristics.
+#[derive(Debug, Clone, Copy)]
+pub struct PrefetchConfig {
+    /// Number of vectors to prefetch ahead in the search path.
+    pub prefetch_depth: usize,
+    /// Whether to also prefetch graph structure.
+    pub prefetch_graph: bool,
+}
+
+impl Default for PrefetchConfig {
+    fn default() -> Self {
+        Self {
+            prefetch_depth: 2,
+            prefetch_graph: true,
+        }
+    }
+}
+
+impl PrefetchConfig {
+    /// Create config optimized for the given vector size.
+    ///
+    /// Smaller vectors benefit from more aggressive prefetching since more
+    /// can fit in cache. Larger vectors need less prefetching to avoid
+    /// cache pollution.
+    pub fn for_vector_size(dim: usize, element_size: usize) -> Self {
+        let vector_bytes = dim * element_size;
+
+        // Adjust prefetch depth based on vector size
+        // Smaller vectors: prefetch more aggressively
+        // Larger vectors: prefetch less to avoid cache pollution
+        let prefetch_depth = if vector_bytes <= 256 {
+            4 // Small vectors (e.g., dim=64 f32): prefetch 4 ahead
+        } else if vector_bytes <= 1024 {
+            2 // Medium vectors (e.g., dim=256 f32): prefetch 2 ahead
+        } else if vector_bytes <= 4096 {
+            1 // Large vectors (e.g., dim=1024 f32): prefetch 1 ahead
+        } else {
+            0 // Very large vectors: no prefetch (would pollute cache)
+        };
+
+        // Only prefetch graph structure for smaller vectors
+        let prefetch_graph = vector_bytes <= 512;
+
+        Self {
+            prefetch_depth,
+            prefetch_graph,
+        }
+    }
+
+    /// Create config for a specific vector type.
+    pub fn for_type<T>(dim: usize) -> Self {
+        Self::for_vector_size(dim, std::mem::size_of::<T>())
+    }
+}
+
+/// Prefetch multiple cache lines of a vector.
+///
+/// For larger vectors spanning multiple cache lines, this prefetches
+/// all cache lines to ensure the entire vector is in cache.
+#[inline]
+pub fn prefetch_vector<T>(data: &[T]) {
+    if data.is_empty() {
+        return;
+    }
+
+    let bytes = std::mem::size_of_val(data);
+    let cache_line_size = 64;
+    let cache_lines = (bytes + cache_line_size - 1) / cache_line_size;
+    let ptr = data.as_ptr() as *const u8;
+
+    // Prefetch up to 8 cache lines (512 bytes)
+    for i in 0..cache_lines.min(8) {
+        prefetch_read(unsafe { ptr.add(i * cache_line_size) } as *const T);
+    }
+}
+
+/// Prefetch data for multiple neighbors ahead in the search path.
+///
+/// This function prefetches vector data for neighbors that will be
+/// visited in upcoming iterations, hiding memory latency.
+#[inline]
+pub fn prefetch_neighbors<'a, T, F>(
+    neighbors: &[u32],
+    current_idx: usize,
+    config: &PrefetchConfig,
+    data_getter: &F,
+) where
+    T: 'a,
+    F: Fn(u32) -> Option<&'a [T]>,
+{
+    if config.prefetch_depth == 0 {
+        return;
+    }
+
+    // Prefetch data for neighbors ahead in the iteration
+    for offset in 1..=config.prefetch_depth {
+        let prefetch_idx = current_idx + offset;
+        if prefetch_idx < neighbors.len() {
+            if let Some(data) = data_getter(neighbors[prefetch_idx]) {
+                prefetch_vector(data);
+            }
+        }
+    }
+}
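A sketch of how `prefetch_neighbors` might be driven from a neighbor-scan loop; the `vectors` storage and the accumulation are placeholders, not this crate's real index layout.

// While computing on the current neighbor, hint the next
// `prefetch_depth` neighbors' vector data into cache.
fn scan_neighbors(neighbors: &[u32], vectors: &[Vec<f32>], dim: usize) -> f32 {
    let config = PrefetchConfig::for_type::<f32>(dim);
    let get = |id: u32| vectors.get(id as usize).map(|v| v.as_slice());
    let mut acc = 0.0;
    for (i, &id) in neighbors.iter().enumerate() {
        prefetch_neighbors(neighbors, i, &config, &get);
        if let Some(v) = get(id) {
            acc += v.iter().sum::<f32>(); // placeholder for a distance kernel
        }
    }
    acc
}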
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_prefetch_read_null() {
+        // Should not crash on null pointer
+        prefetch_read(std::ptr::null::<f32>());
+    }
+
+    #[test]
+    fn test_prefetch_read_valid() {
+        let data = vec![1.0f32, 2.0, 3.0, 4.0];
+        prefetch_read(data.as_ptr());
+    }
+
+    #[test]
+    fn test_prefetch_slice_empty() {
+        let empty: &[f32] = &[];
+        prefetch_slice(empty);
+    }
+
+    #[test]
+    fn test_prefetch_slice_valid() {
+        let data = vec![1.0f32, 2.0, 3.0, 4.0];
+        prefetch_slice(&data);
+    }
+
+    #[test]
+    fn test_prefetch_config_default() {
+        let config = PrefetchConfig::default();
+        assert_eq!(config.prefetch_depth, 2);
+        assert!(config.prefetch_graph);
+    }
+
+    #[test]
+    fn test_prefetch_config_for_vector_size() {
+        // Small vectors (256 bytes = 64 f32)
+        let small = PrefetchConfig::for_vector_size(64, 4);
+        assert_eq!(small.prefetch_depth, 4);
+        assert!(small.prefetch_graph);
+
+        // Medium vectors (1024 bytes = 256 f32)
+        let medium = PrefetchConfig::for_vector_size(256, 4);
+        assert_eq!(medium.prefetch_depth, 2);
+        assert!(!medium.prefetch_graph); // 1024 > 512, so no graph prefetch
+
+        // Large vectors (4096 bytes = 1024 f32)
+        let large = PrefetchConfig::for_vector_size(1024, 4);
+        assert_eq!(large.prefetch_depth, 1);
+        assert!(!large.prefetch_graph);
+
+        // Very large vectors (8192 bytes = 2048 f32)
+        let very_large = PrefetchConfig::for_vector_size(2048, 4);
+        assert_eq!(very_large.prefetch_depth, 0);
+        assert!(!very_large.prefetch_graph);
+    }
+
+    #[test]
+    fn test_prefetch_config_for_type() {
+        let config = PrefetchConfig::for_type::<f32>(128);
+        // 128 * 4 = 512 bytes
+        assert_eq!(config.prefetch_depth, 2);
+        assert!(config.prefetch_graph);
+    }
+
+    #[test]
+    fn test_prefetch_vector_multiple_cache_lines() {
+        // Vector spanning multiple cache lines
+        let data: Vec<f32> = (0..256).map(|i| i as f32).collect(); // 1024 bytes = 16 cache lines
+        prefetch_vector(&data);
+    }
+
+    #[test]
+    fn test_prefetch_vector_empty() {
+        let empty: &[f32] = &[];
+        prefetch_vector(empty);
+    }
+}
diff --git a/src/VecSim/CMakeLists.txt b/src/VecSim/CMakeLists.txt
deleted file mode 100644
index b459e377d..000000000
--- a/src/VecSim/CMakeLists.txt
+++ /dev/null
@@ -1,60 +0,0 @@
-cmake_minimum_required(VERSION 3.10)
-cmake_policy(SET CMP0077 NEW)
-set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
-
-set(CMAKE_CXX_STANDARD 20)
-
-project(VecsimLib)
-
-file(GLOB_RECURSE headers ./**.h)
-set(HEADER_LIST "${headers}")
-
-include_directories(../)
-
-set(SVS_CXX_STANDARD 20)
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=gnu++20")
-
-add_subdirectory(spaces)
-
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall")
-
-if(USE_SVS)
-    set(svs_factory_file "index_factories/svs_factory.cpp")
-endif()
-
-add_library(VectorSimilarity ${VECSIM_LIBTYPE}
-    index_factories/brute_force_factory.cpp
-    index_factories/hnsw_factory.cpp
-    index_factories/tiered_factory.cpp
-    index_factories/svs_factory.cpp
-    index_factories/index_factory.cpp
-    algorithms/hnsw/visited_nodes_handler.cpp
-    vec_sim.cpp
-    vec_sim_debug.cpp
-    vec_sim_interface.cpp
-    query_results.cpp
-    info_iterator.cpp
-    utils/vec_utils.cpp
-    containers/data_block.cpp
-    containers/data_blocks_container.cpp
-    memory/vecsim_malloc.cpp
-    memory/vecsim_base.cpp
-    ${HEADER_LIST}
-)
-
-target_link_libraries(VectorSimilarity VectorSimilaritySpaces)
-
-if (TARGET svs::svs)
-    target_link_libraries(VectorSimilarity svs::svs)
-    if(TARGET svs::svs_static_library)
-        target_link_libraries(VectorSimilarity svs::svs_static_library)
-    endif()
-endif()
-
-if(VECSIM_BUILD_TESTS)
-    add_library(VectorSimilaritySerializer
-        algorithms/hnsw/hnsw_serializer.cpp
-        algorithms/svs/svs_serializer.cpp
-    )
-    target_link_libraries(VectorSimilarity VectorSimilaritySerializer)
-endif()
diff --git a/src/VecSim/__init__.py b/src/VecSim/__init__.py
deleted file mode 100644
index 244c7ba6f..000000000
--- a/src/VecSim/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright Redis Ltd. 2021 - present
-# Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
-# the Server Side Public License v1 (SSPLv1).
-
-pass # needed for poetry to consider this to be a package
diff --git a/src/VecSim/algorithms/brute_force/bf_batch_iterator.h b/src/VecSim/algorithms/brute_force/bf_batch_iterator.h
deleted file mode 100644
index 5477db21a..000000000
--- a/src/VecSim/algorithms/brute_force/bf_batch_iterator.h
+++ /dev/null
@@ -1,214 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-
-#include "VecSim/batch_iterator.h"
-#include "VecSim/algorithms/brute_force/brute_force.h"
-
-#include
-#include
-#include <algorithm> //nth_element
-#include
-#include
-#include
-
-using std::pair;
-
-template <typename DataType, typename DistType>
-class BF_BatchIterator : public VecSimBatchIterator {
-protected:
-    const BruteForceIndex<DataType, DistType> *index;
-    size_t index_label_count; // number of labels in the index when calculating the scores,
-                              // which is the only time we access the index.
-    vecsim_stl::vector<pair<DistType, labelType>> scores; // vector of scores for every label.
-    size_t scores_valid_start_pos; // the first index in the scores vector that contains a vector
-                                   // that hasn't been returned already.
-
-    VecSimQueryReply *searchByHeuristics(size_t n_res, VecSimQueryReply_Order order);
-    VecSimQueryReply *selectBasedSearch(size_t n_res);
-    VecSimQueryReply *heapBasedSearch(size_t n_res);
-    void swapScores(const vecsim_stl::unordered_map<labelType, size_t> &TopCandidatesIndices,
-                    size_t res_num);
-
-    virtual inline VecSimQueryReply_Code calculateScores() = 0;
-
-public:
-    BF_BatchIterator(void *query_vector, const BruteForceIndex<DataType, DistType> *bf_index,
-                     VecSimQueryParams *queryParams, std::shared_ptr<VecSimAllocator> allocator);
-
-    VecSimQueryReply *getNextResults(size_t n_res, VecSimQueryReply_Order order) override;
-
-    bool isDepleted() override;
-
-    void reset() override;
-
-    ~BF_BatchIterator() override = default;
-};
-
-/******************** Implementation **************/
-
-// heuristics: decide if using heap or select search, based on the ratio between the
-// number of remaining results and the index size.
-template <typename DataType, typename DistType>
-VecSimQueryReply *
-BF_BatchIterator<DataType, DistType>::searchByHeuristics(size_t n_res,
-                                                         VecSimQueryReply_Order order) {
-    if ((this->index_label_count - this->getResultsCount()) / 1000 > n_res) {
-        // Heap based search always returns the results ordered by score
-        return this->heapBasedSearch(n_res);
-    }
-    VecSimQueryReply *rep = this->selectBasedSearch(n_res);
-    if (order == BY_SCORE) {
-        sort_results_by_score(rep);
-    } else if (order == BY_SCORE_THEN_ID) {
-        sort_results_by_score_then_id(rep);
-    }
-    return rep;
-}
-
-template <typename DataType, typename DistType>
-void BF_BatchIterator<DataType, DistType>::swapScores(
-    const vecsim_stl::unordered_map<labelType, size_t> &TopCandidatesIndices, size_t res_num) {
-    // Create a set of the indices in the scores array for every results that we return.
-    vecsim_stl::set<size_t> indices(this->allocator);
-    for (auto pos : TopCandidatesIndices) {
-        indices.insert(pos.second);
-    }
-    // Get the first valid position in the next iteration.
-    size_t next_scores_valid_start_pos = this->scores_valid_start_pos + res_num;
-    // Get the first index of a results in this iteration which is greater or equal to
-    // next_scores_valid_start_pos.
-    auto reuse_index_it = indices.lower_bound(next_scores_valid_start_pos);
-    auto it = indices.begin();
-    size_t ind = this->scores_valid_start_pos;
-    // Swap elements which are in the first res_num positions in the scores array, and place them
-    // in indices of results that we return now (reuse these indices).
-    while (ind < next_scores_valid_start_pos) {
-        // don't swap if there is a result in one of the heading indices which will be invalid from
-        // next iteration.
-        if (*it == ind) {
-            it++;
-        } else {
-            this->scores[*reuse_index_it] = this->scores[ind];
-            reuse_index_it++;
-        }
-        ind++;
-    }
-    this->scores_valid_start_pos = next_scores_valid_start_pos;
-}
-
-template <typename DataType, typename DistType>
-VecSimQueryReply *BF_BatchIterator<DataType, DistType>::heapBasedSearch(size_t n_res) {
-    auto rep = new VecSimQueryReply(this->allocator);
-    DistType upperBound = std::numeric_limits<DistType>::lowest();
-    vecsim_stl::max_priority_queue<DistType, labelType> TopCandidates(this->allocator);
-    // map vector's label to its index in the scores vector.
-    vecsim_stl::unordered_map<labelType, size_t> TopCandidatesIndices(n_res, this->allocator);
-    for (size_t i = this->scores_valid_start_pos; i < this->scores.size(); i++) {
-        if (TopCandidates.size() >= n_res) {
-            if (this->scores[i].first < upperBound) {
-                // remove the furthest vector from the candidates and from the label->index mappings
-                // we first remove the worst candidate so we wont exceed the allocated size
-                TopCandidatesIndices.erase(TopCandidates.top().second);
-                TopCandidates.pop();
-            } else {
-                continue;
-            }
-        }
-        // top candidate heap size is smaller than n either because we didn't reach n_res yet,
-        // or we popped the heap top since the the current score is closer
-        TopCandidates.emplace(this->scores[i].first, this->scores[i].second);
-        TopCandidatesIndices[this->scores[i].second] = i;
-        upperBound = TopCandidates.top().first;
-    }
-
-    // Save the top results to return.
-    rep->results.resize(TopCandidates.size());
-    for (auto result = rep->results.rbegin(); result != rep->results.rend(); result++) {
-        std::tie(result->score, result->id) = TopCandidates.top();
-        TopCandidates.pop();
-    }
-    swapScores(TopCandidatesIndices, rep->results.size());
-    return rep;
-}
-
-template <typename DataType, typename DistType>
-VecSimQueryReply *BF_BatchIterator<DataType, DistType>::selectBasedSearch(size_t n_res) {
-    auto rep = new VecSimQueryReply(this->allocator);
-    size_t remaining_vectors_count = this->scores.size() - this->scores_valid_start_pos;
-    // Get an iterator to the effective first element in the scores array, which is the first
-    // element that hasn't been returned in previous iterations.
-    auto valid_begin_it = this->scores.begin() + (int)(this->scores_valid_start_pos);
-    // We return up to n_res vectors, the remaining vectors size is an upper bound.
-    if (n_res > remaining_vectors_count) {
-        n_res = remaining_vectors_count;
-    }
-    auto n_th_element_pos = valid_begin_it + (int)n_res;
-    // This will perform an in-place partition of the elements in the slice of the array that
-    // contains valid results, based on the n-th element as the pivot - every element with a lower
-    // will be placed before it, and all the rest will be placed after.
-    std::nth_element(valid_begin_it, n_th_element_pos, this->scores.end());
-
-    rep->results.reserve(n_res);
-    for (size_t i = this->scores_valid_start_pos; i < this->scores_valid_start_pos + n_res; i++) {
-        rep->results.push_back(VecSimQueryResult{this->scores[i].second, this->scores[i].first});
-    }
-    // Update the valid results start position after returning the results.
-    this->scores_valid_start_pos += rep->results.size();
-    return rep;
-}
-
-template <typename DataType, typename DistType>
-BF_BatchIterator<DataType, DistType>::BF_BatchIterator(
-    void *query_vector, const BruteForceIndex<DataType, DistType> *bf_index,
-    VecSimQueryParams *queryParams, std::shared_ptr<VecSimAllocator> allocator)
-    : VecSimBatchIterator(query_vector, queryParams ? queryParams->timeoutCtx : nullptr, allocator),
-      index(bf_index), index_label_count(index->indexLabelCount()), scores(allocator),
-      scores_valid_start_pos(0) {}
-
-template <typename DataType, typename DistType>
-VecSimQueryReply *
-BF_BatchIterator<DataType, DistType>::getNextResults(size_t n_res, VecSimQueryReply_Order order) {
-    // Only in the first iteration we need to compute all the scores
-    if (this->scores.empty()) {
-        assert(getResultsCount() == 0);
-
-        // The only time we access the index. This function also updates the iterator's label count.
-        auto rc = calculateScores();
-
-        if (VecSim_OK != rc) {
-            return new VecSimQueryReply(this->allocator, rc);
-        }
-    }
-    if (VECSIM_TIMEOUT(this->getTimeoutCtx())) {
-        return new VecSimQueryReply(this->allocator, VecSim_QueryReply_TimedOut);
-    }
-    VecSimQueryReply *rep = searchByHeuristics(n_res, order);
-
-    this->updateResultsCount(VecSimQueryReply_Len(rep));
-    if (order == BY_ID) {
-        sort_results_by_id(rep);
-    }
-    return rep;
-}
-
-template <typename DataType, typename DistType>
-bool BF_BatchIterator<DataType, DistType>::isDepleted() {
-    assert(this->getResultsCount() <= this->index_label_count);
-    bool depleted = this->getResultsCount() == this->index_label_count;
-    return depleted;
-}
-
-template <typename DataType, typename DistType>
-void BF_BatchIterator<DataType, DistType>::reset() {
-    this->scores.clear();
-    this->resetResultsCount();
-    this->scores_valid_start_pos = 0;
-}
diff --git a/src/VecSim/algorithms/brute_force/bfm_batch_iterator.h b/src/VecSim/algorithms/brute_force/bfm_batch_iterator.h
deleted file mode 100644
index 466b97e72..000000000
--- a/src/VecSim/algorithms/brute_force/bfm_batch_iterator.h
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-#include "bf_batch_iterator.h"
-
-#include
-
-template <typename DataType, typename DistType>
-class BFM_BatchIterator : public BF_BatchIterator<DataType, DistType> {
-public:
-    BFM_BatchIterator(void *query_vector, const BruteForceIndex<DataType, DistType> *index,
-                      VecSimQueryParams *queryParams, std::shared_ptr<VecSimAllocator> allocator)
-        : BF_BatchIterator<DataType, DistType>(query_vector, index, queryParams, allocator) {}
-
-    ~BFM_BatchIterator() override = default;
-
-private:
-    VecSimQueryReply_Code calculateScores() override {
-        this->index_label_count = this->index->indexLabelCount();
-        this->scores.reserve(this->index_label_count);
-        vecsim_stl::unordered_map<labelType, DistType> tmp_scores(this->index_label_count,
-                                                                  this->allocator);
-
-        idType curr_id = 0;
-        auto vectors_it = this->index->getVectorsIterator();
-        while (auto *vector = vectors_it->next()) {
-            // Compute the scores for every vector and extend the scores array.
-            if (VECSIM_TIMEOUT(this->getTimeoutCtx())) {
-                return VecSim_QueryReply_TimedOut;
-            }
-            auto score = this->index->calcDistance(vector, this->getQueryBlob());
-            labelType curr_label = this->index->getVectorLabel(curr_id);
-            auto curr_pair = tmp_scores.find(curr_label);
-            // For each score, emplace or update the score of the label.
-            if (curr_pair == tmp_scores.end()) {
-                tmp_scores.emplace(curr_label, score);
-            } else if (curr_pair->second > score) {
-                curr_pair->second = score;
-            }
-            ++curr_id;
-        }
-        assert(curr_id == this->index->indexSize());
-        for (auto p : tmp_scores) {
-            this->scores.emplace_back(p.second, p.first);
-        }
-        return VecSim_QueryReply_OK;
-    }
-};
diff --git a/src/VecSim/algorithms/brute_force/bfs_batch_iterator.h b/src/VecSim/algorithms/brute_force/bfs_batch_iterator.h
deleted file mode 100644
index 03ca10515..000000000
--- a/src/VecSim/algorithms/brute_force/bfs_batch_iterator.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-#include "bf_batch_iterator.h"
-
-#include
-
-template <typename DataType, typename DistType>
-class BFS_BatchIterator : public BF_BatchIterator<DataType, DistType> {
-public:
-    BFS_BatchIterator(void *query_vector, const BruteForceIndex<DataType, DistType> *index,
-                      VecSimQueryParams *queryParams, std::shared_ptr<VecSimAllocator> allocator)
-        : BF_BatchIterator<DataType, DistType>(query_vector, index, queryParams, allocator) {}
-
-    ~BFS_BatchIterator() override = default;
-
-private:
-    VecSimQueryReply_Code calculateScores() override {
-        this->index_label_count = this->index->indexLabelCount();
-        this->scores.reserve(this->index_label_count);
-
-        idType curr_id = 0;
-        auto vectors_it = this->index->getVectorsIterator();
-        while (auto *vector = vectors_it->next()) {
-            // Compute the scores for every vector and extend the scores array.
-            if (VECSIM_TIMEOUT(this->getTimeoutCtx())) {
-                return VecSim_QueryReply_TimedOut;
-            }
-            auto score = this->index->calcDistance(vector, this->getQueryBlob());
-            this->scores.emplace_back(score, this->index->getVectorLabel(curr_id));
-            ++curr_id;
-        }
-        assert(curr_id == this->index->indexSize());
-        return VecSim_QueryReply_OK;
-    }
-};
diff --git a/src/VecSim/algorithms/brute_force/brute_force.h b/src/VecSim/algorithms/brute_force/brute_force.h
deleted file mode 100644
index 3be453024..000000000
--- a/src/VecSim/algorithms/brute_force/brute_force.h
+++ /dev/null
@@ -1,448 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-
-#include "VecSim/containers/data_block.h"
-#include "VecSim/vec_sim_index.h"
-#include "VecSim/spaces/spaces.h"
-#include "VecSim/utils/vecsim_stl.h"
-#include "VecSim/containers/vecsim_results_container.h"
-#include "VecSim/index_factories/brute_force_factory.h"
-#include "VecSim/spaces/spaces.h"
-#include "VecSim/query_result_definitions.h"
-#include "VecSim/containers/data_blocks_container.h"
-#include "VecSim/containers/raw_data_container_interface.h"
-#include "VecSim/utils/vec_utils.h"
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-using spaces::dist_func_t;
-
-template <typename DataType, typename DistType>
-class BruteForceIndex : public VecSimIndexAbstract<DataType, DistType> {
-protected:
-    vecsim_stl::vector<labelType> idToLabelMapping;
-    idType count;
-
-public:
-    BruteForceIndex(const BFParams *params, const AbstractIndexInitParams &abstractInitParams,
-                    const IndexComponents<DataType, DistType> &components);
-
-    size_t indexSize() const override;
-    size_t indexCapacity() const override;
-    std::unique_ptr<RawDataContainer::Iterator> getVectorsIterator() const;
-    const DataType *getDataByInternalId(idType id) const {
-        return reinterpret_cast<const DataType *>(this->vectors->getElement(id));
-    }
-    VecSimQueryReply *topKQuery(const void *queryBlob, size_t k,
-                                VecSimQueryParams *queryParams) const override;
-    VecSimQueryReply *rangeQuery(const void *queryBlob, double radius,
-                                 VecSimQueryParams *queryParams) const override;
-    VecSimIndexDebugInfo debugInfo() const override;
-    VecSimDebugInfoIterator *debugInfoIterator() const override;
-    VecSimIndexBasicInfo basicInfo() const override;
-    VecSimBatchIterator *newBatchIterator(const void *queryBlob,
-                                          VecSimQueryParams *queryParams) const override;
-    bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) const override;
-    labelType getVectorLabel(idType id) const { return idToLabelMapping.at(id); }
-
-    const RawDataContainer *getVectorsContainer() const { return this->vectors; }
-
-    // Remove a specific vector that is stored under a label from the index by its internal id.
-    virtual int deleteVectorById(labelType label, idType id) = 0;
-    // Remove a vector and return a map between internal ids and the original internal ids of the
-    // vector that they hold as a result of the overall removals and swaps, along with its label.
-    virtual std::unordered_map<idType, std::pair<idType, labelType>>
-    deleteVectorAndGetUpdatedIds(labelType label) = 0;
-    // Check if a certain label exists in the index.
-    virtual bool isLabelExists(labelType label) = 0;
-
-    // Unsafe (assume index data guard is held in MT mode).
-    virtual vecsim_stl::vector<idType> getElementIds(size_t label) const = 0;
-
-    virtual ~BruteForceIndex() = default;
-#ifdef BUILD_TESTS
-    void fitMemory() override {
-        if (count == 0) {
-            return;
-        }
-        idToLabelMapping.shrink_to_fit();
-        resizeLabelLookup(idToLabelMapping.size());
-    }
-
-    size_t indexMetaDataCapacity() const override { return idToLabelMapping.capacity(); }
-#endif
-
-protected:
-    // Private internal function that implements generic single vector insertion.
-    virtual void appendVector(const void *vector_data, labelType label);
-
-    // Private internal function that implements generic single vector deletion.
-    virtual void removeVector(idType id);
-
-    void resizeIndexCommon(size_t new_max_elements) {
-        assert(new_max_elements % this->blockSize == 0 &&
-               "new_max_elements must be a multiple of blockSize");
-        this->log(VecSimCommonStrings::LOG_VERBOSE_STRING, "Resizing FLAT index from %zu to %zu",
-                  idToLabelMapping.capacity(), new_max_elements);
-        assert(idToLabelMapping.capacity() == idToLabelMapping.size());
-        idToLabelMapping.resize(new_max_elements);
-        idToLabelMapping.shrink_to_fit();
-        assert(idToLabelMapping.capacity() == idToLabelMapping.size());
-        resizeLabelLookup(new_max_elements);
-    }
-
-    void growByBlock() {
-        assert(indexCapacity() == idToLabelMapping.capacity());
-        assert(indexCapacity() % this->blockSize == 0);
-        assert(indexCapacity() == indexSize());
-        assert((dynamic_cast<DataBlocksContainer *>(this->vectors)->numBlocks() ==
-                (indexSize()) / this->blockSize));
-
-        resizeIndexCommon(indexCapacity() + this->blockSize);
-    }
-
-    void shrinkByBlock() {
-        assert(indexCapacity() >= this->blockSize);
-        assert(indexCapacity() % this->blockSize == 0);
-        assert(dynamic_cast<DataBlocksContainer *>(this->vectors)->numBlocks() ==
-               indexSize() / this->blockSize);
-
-        if (indexSize() == 0) {
-            resizeIndexCommon(0);
-        } else if (indexCapacity() >= (indexSize() + 2 * this->blockSize)) {
-
-            assert(indexCapacity() == idToLabelMapping.capacity());
-            assert(idToLabelMapping.size() == idToLabelMapping.capacity());
-            assert(dynamic_cast<DataBlocksContainer *>(this->vectors)->size() +
-                       2 * this->blockSize ==
-                   idToLabelMapping.capacity());
-            resizeIndexCommon(indexCapacity() - this->blockSize);
-        }
-    }
-
-    void setVectorLabel(idType id, labelType new_label) { idToLabelMapping.at(id) = new_label; }
-    // inline priority queue getter that need to be implemented by derived class
-    virtual vecsim_stl::abstract_priority_queue<DistType, labelType> *
-    getNewMaxPriorityQueue() const = 0;
-
-    // inline label to id setters that need to be implemented by derived class
-    virtual std::unique_ptr<vecsim_stl::abstract_results_container>
-    getNewResultsContainer(size_t cap) const = 0;
-
-    // inline label to id setters that need to be implemented by derived class
-    virtual void replaceIdOfLabel(labelType label, idType new_id, idType old_id) = 0;
-    virtual void setVectorId(labelType label, idType id) = 0;
-    virtual void resizeLabelLookup(size_t new_max_elements) = 0;
-
-    virtual VecSimBatchIterator *
-    newBatchIterator_Instance(void *queryBlob, VecSimQueryParams *queryParams) const = 0;
-
-#ifdef BUILD_TESTS
-#include "VecSim/algorithms/brute_force/brute_force_friend_tests.h"
-#endif
-};
-
-/******************************* Implementation **********************************/
-
-/******************** Ctor / Dtor **************/
-template <typename DataType, typename DistType>
-BruteForceIndex<DataType, DistType>::BruteForceIndex(
-    const BFParams *params, const AbstractIndexInitParams &abstractInitParams,
-    const IndexComponents<DataType, DistType> &components)
-    : VecSimIndexAbstract<DataType, DistType>(abstractInitParams, components),
-      idToLabelMapping(this->allocator), count(0) {
-    assert(VecSimType_sizeof(this->vecType) == sizeof(DataType));
-}
-
-/******************** Implementation **************/
-
-template <typename DataType, typename DistType>
-void BruteForceIndex<DataType, DistType>::appendVector(const void *vector_data, labelType label) {
-    // Resize the index meta data structures if needed
-    if (indexSize() >= indexCapacity()) {
-        growByBlock();
-    }
-
-    auto processed_blob = this->preprocessForStorage(vector_data);
-    // Give the vector new id and increase count.
-    idType id = this->count++;
-
-    // add vector data to vector raw data container
-    this->vectors->addElement(processed_blob.get(), id);
-
-    // add label to idToLabelMapping
-    setVectorLabel(id, label);
-
-    // add id to label:id map
-    setVectorId(label, id);
-}
-
-template <typename DataType, typename DistType>
-void BruteForceIndex<DataType, DistType>::removeVector(idType id_to_delete) {
-
-    // Get last vector id and label
-    idType last_idx = --this->count;
-    labelType last_idx_label = getVectorLabel(last_idx);
-
-    // If we are *not* trying to remove the last vector, update mapping and move
-    // the data of the last vector in the index in place of the deleted vector.
-    if (id_to_delete != last_idx) {
-        assert(id_to_delete < last_idx);
-        // Update idToLabelMapping.
-        // Put the label of the last_id in the deleted_id.
-        setVectorLabel(id_to_delete, last_idx_label);
-
-        // Update label2id mapping.
-        // Update this id in label:id pair of last index.
-        replaceIdOfLabel(last_idx_label, id_to_delete, last_idx);
-
-        // Put data of last vector inplace of the deleted vector.
-        const char *last_vector_data = this->vectors->getElement(last_idx);
-        this->vectors->updateElement(id_to_delete, last_vector_data);
-    }
-    this->vectors->removeElement(last_idx);
-
-    // If we reached to a multiply of a block size, we can reduce meta data structures size.
-    if (this->count % this->blockSize == 0) {
-        shrinkByBlock();
-    }
-}
-
-template <typename DataType, typename DistType>
-size_t BruteForceIndex<DataType, DistType>::indexSize() const {
-    return this->count;
-}
-
-template <typename DataType, typename DistType>
-size_t BruteForceIndex<DataType, DistType>::indexCapacity() const {
-    return this->idToLabelMapping.size();
-}
-
-template <typename DataType, typename DistType>
-std::unique_ptr<RawDataContainer::Iterator>
-BruteForceIndex<DataType, DistType>::getVectorsIterator() const {
-    return this->vectors->getIterator();
-}
-
-template <typename DataType, typename DistType>
-VecSimQueryReply *
-BruteForceIndex<DataType, DistType>::topKQuery(const void *queryBlob, size_t k,
-                                               VecSimQueryParams *queryParams) const {
-
-    auto rep = new VecSimQueryReply(this->allocator);
-    void *timeoutCtx = queryParams ? queryParams->timeoutCtx : NULL;
-    this->lastMode = STANDARD_KNN;
-
-    if (0 == k) {
-        return rep;
-    }
-
-    auto processed_query_ptr = this->preprocessQuery(queryBlob);
-    const void *processed_query = processed_query_ptr.get();
-    DistType upperBound = std::numeric_limits<DistType>::lowest();
-    vecsim_stl::abstract_priority_queue<DistType, labelType> *TopCandidates =
-        getNewMaxPriorityQueue();
-
-    // For vector, compute its scores and update the Top candidates max heap
-    auto vectors_it = this->vectors->getIterator();
-    idType curr_id = 0;
-    while (auto *vector = vectors_it->next()) {
-        if (VECSIM_TIMEOUT(timeoutCtx)) {
-            rep->code = VecSim_QueryReply_TimedOut;
-            delete TopCandidates;
-            return rep;
-        }
-        auto score = this->calcDistance(vector, processed_query);
-        // If we have less than k or a better score, insert it.
-        if (score < upperBound || TopCandidates->size() < k) {
-            TopCandidates->emplace(score, getVectorLabel(curr_id));
-            if (TopCandidates->size() > k) {
-                // If we now have more than k results, pop the worst one.
-                TopCandidates->pop();
-            }
-            upperBound = TopCandidates->top().first;
-        }
-        ++curr_id;
-    }
-    assert(curr_id == this->count);
-
-    rep->results.resize(TopCandidates->size());
-    for (auto &result : std::ranges::reverse_view(rep->results)) {
-        std::tie(result.score, result.id) = TopCandidates->top();
-        TopCandidates->pop();
-    }
-    delete TopCandidates;
-    return rep;
-}
-
-template <typename DataType, typename DistType>
-VecSimQueryReply *
-BruteForceIndex<DataType, DistType>::rangeQuery(const void *queryBlob, double radius,
-                                                VecSimQueryParams *queryParams) const {
-    auto processed_query_ptr = this->preprocessQuery(queryBlob);
-    auto rep = new VecSimQueryReply(this->allocator);
-    void *timeoutCtx = queryParams ? queryParams->timeoutCtx : nullptr;
-    this->lastMode = RANGE_QUERY;
-
-    // Compute scores in every block and save results that are within the range.
-    auto res_container =
-        getNewResultsContainer(10); // Use 10 as the initial capacity for the dynamic array.
-
-    DistType radius_ = DistType(radius);
-    auto vectors_it = this->vectors->getIterator();
-    idType curr_id = 0;
-    const void *processed_query = processed_query_ptr.get();
-    while (vectors_it->hasNext()) {
-        if (VECSIM_TIMEOUT(timeoutCtx)) {
-            rep->code = VecSim_QueryReply_TimedOut;
-            break;
-        }
-        auto score = this->calcDistance(vectors_it->next(), processed_query);
-        if (score <= radius_) {
-            res_container->emplace(getVectorLabel(curr_id), score);
-        }
-        ++curr_id;
-    }
-    // assert only if the loop finished iterating all the ids (we didn't get rep->code !=
-    // VecSim_OK).
-    assert((rep->code != VecSim_OK || curr_id == this->count));
-    rep->results = res_container->get_results();
-    return rep;
-}
-
-template <typename DataType, typename DistType>
-VecSimIndexDebugInfo BruteForceIndex<DataType, DistType>::debugInfo() const {
-
-    VecSimIndexDebugInfo info;
-    info.commonInfo = this->getCommonInfo();
-    info.commonInfo.basicInfo.algo = VecSimAlgo_BF;
-
-    return info;
-}
-
-template <typename DataType, typename DistType>
-VecSimIndexBasicInfo BruteForceIndex<DataType, DistType>::basicInfo() const {
-
-    VecSimIndexBasicInfo info = this->getBasicInfo();
-    info.algo = VecSimAlgo_BF;
-    info.isTiered = false;
-    return info;
-}
-
-template <typename DataType, typename DistType>
-VecSimDebugInfoIterator *BruteForceIndex<DataType, DistType>::debugInfoIterator() const {
-    VecSimIndexDebugInfo info = this->debugInfo();
-    // For readability. Update this number when needed.
-    size_t numberOfInfoFields = 10;
-    auto *infoIterator = new VecSimDebugInfoIterator(numberOfInfoFields, this->allocator);
-
-    infoIterator->addInfoField(
-        VecSim_InfoField{.fieldName = VecSimCommonStrings::ALGORITHM_STRING,
-                         .fieldType = INFOFIELD_STRING,
-                         .fieldValue = {FieldValue{
-                             .stringValue = VecSimAlgo_ToString(info.commonInfo.basicInfo.algo)}}});
-    this->addCommonInfoToIterator(infoIterator, info.commonInfo);
-    infoIterator->addInfoField(VecSim_InfoField{
-        .fieldName = VecSimCommonStrings::BLOCK_SIZE_STRING,
-        .fieldType = INFOFIELD_UINT64,
-        .fieldValue = {FieldValue{.uintegerValue = info.commonInfo.basicInfo.blockSize}}});
-    return infoIterator;
-}
-
-template <typename DataType, typename DistType>
-VecSimBatchIterator *
-BruteForceIndex<DataType, DistType>::newBatchIterator(const void *queryBlob,
-                                                      VecSimQueryParams *queryParams) const {
-    // force_copy == true.
-    auto queryBlobCopy = this->preprocessQuery(queryBlob, true);
-
-    // take ownership of the blob copy and pass it to the batch iterator.
-    auto *queryBlobCopyPtr = queryBlobCopy.release();
-    // Ownership of queryBlobCopy moves to BF_BatchIterator that will free it at the end.
- return newBatchIterator_Instance(queryBlobCopyPtr, queryParams); -} - -template <typename DataType, typename DistType> -bool BruteForceIndex<DataType, DistType>::preferAdHocSearch(size_t subsetSize, size_t k, - bool initial_check) const { - // This heuristic is based on an sklearn decision tree classifier (with 10 leaf nodes) - - // see scripts/BF_batches_clf.py - size_t index_size = this->indexSize(); - // Treat a subset size that exceeds the index size as if it were the maximum possible size. - subsetSize = std::min(subsetSize, index_size); - - size_t d = this->dim; - float r = (index_size == 0) ? 0.0f : (float)(subsetSize) / (float)this->indexLabelCount(); - bool res; - if (index_size <= 5500) { - // node 1 - res = true; - } else { - // node 2 - if (d <= 300) { - // node 3 - if (r <= 0.15) { - // node 5 - res = true; - } else { - // node 6 - if (r <= 0.35) { - // node 9 - if (d <= 75) { - // node 11 - res = false; - } else { - // node 12 - if (index_size <= 550000) { - // node 17 - res = true; - } else { - // node 18 - res = false; - } - } - } else { - // node 10 - res = false; - } - } - } else { - // node 4 - if (r <= 0.55) { - // node 7 - res = true; - } else { - // node 8 - if (d <= 750) { - // node 13 - res = false; - } else { - // node 14 - if (r <= 0.75) { - // node 15 - res = true; - } else { - // node 16 - res = false; - } - } - } - } - } - // Set the mode - if this isn't the initial check, we switched mode from batches to ad-hoc. - this->lastMode = - res ? (initial_check ? HYBRID_ADHOC_BF : HYBRID_BATCHES_TO_ADHOC_BF) : HYBRID_BATCHES; - return res; -} diff --git a/src/VecSim/algorithms/brute_force/brute_force_friend_tests.h b/src/VecSim/algorithms/brute_force/brute_force_friend_tests.h deleted file mode 100644 index 97318c69a..000000000 --- a/src/VecSim/algorithms/brute_force/brute_force_friend_tests.h +++ /dev/null @@ -1,21 +0,0 @@ - -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/friend_test_decl.h" -// Allow the following tests to access the index private members. -INDEX_TEST_FRIEND_CLASS(BruteForceTest_brute_force_vector_update_test_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceTest_resize_and_align_index_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceTest_brute_force_empty_index_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceTest_brute_force_reindexing_same_vector_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceTest_test_delete_swap_block_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceTest_test_dynamic_bf_info_iterator_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceTest_brute_force_zero_minimal_capacity_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceTest_preferAdHocOptimization_Test) -INDEX_TEST_FRIEND_CLASS(IndexAllocatorTest_test_bf_index_block_size_1_Test) -INDEX_TEST_FRIEND_CLASS(BM_VecSimBasics) diff --git a/src/VecSim/algorithms/brute_force/brute_force_multi.h b/src/VecSim/algorithms/brute_force/brute_force_multi.h deleted file mode 100644 index 343faea6b..000000000 --- a/src/VecSim/algorithms/brute_force/brute_force_multi.h +++ /dev/null @@ -1,277 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3).
- */ -#pragma once - -#include "brute_force.h" -#include "bfm_batch_iterator.h" -#include "VecSim/utils/updatable_heap.h" -#include "VecSim/utils/vec_utils.h" - -template -class BruteForceIndex_Multi : public BruteForceIndex { -private: - vecsim_stl::unordered_map> labelToIdsLookup; - -public: - BruteForceIndex_Multi(const BFParams *params, const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components) - : BruteForceIndex(params, abstractInitParams, components), - labelToIdsLookup(this->allocator) {} - - ~BruteForceIndex_Multi() = default; - - int addVector(const void *vector_data, labelType label) override; - int deleteVector(labelType labelType) override; - int deleteVectorById(labelType label, idType id) override; - double getDistanceFrom_Unsafe(labelType label, const void *vector_data) const override; - inline size_t indexLabelCount() const override { return this->labelToIdsLookup.size(); } - - inline std::unique_ptr - getNewResultsContainer(size_t cap) const override { - return std::unique_ptr( - new (this->allocator) vecsim_stl::unique_results_container(cap, this->allocator)); - } - std::unordered_map> - deleteVectorAndGetUpdatedIds(labelType label) override; -#ifdef BUILD_TESTS - void getDataByLabel(labelType label, - std::vector> &vectors_output) const override { - - auto ids = labelToIdsLookup.find(label); - - for (idType id : ids->second) { - auto vec = std::vector(this->dim); - // Only copy the vector data (dim * sizeof(DataType)), not any additional metadata like - // the norm - memcpy(vec.data(), this->getDataByInternalId(id), this->dim * sizeof(DataType)); - vectors_output.push_back(vec); - } - } - - std::vector> getStoredVectorDataByLabel(labelType label) const override { - std::vector> vectors_output; - auto ids = labelToIdsLookup.find(label); - - for (idType id : ids->second) { - // Get the data pointer - need to cast to char* for memcpy - const char *data = reinterpret_cast(this->getDataByInternalId(id)); - - // Create a vector with the full data (including any metadata like norms) - std::vector vec(this->getStoredDataSize()); - memcpy(vec.data(), data, this->getStoredDataSize()); - vectors_output.push_back(std::move(vec)); - } - - return vectors_output; - } - -#endif -private: - // inline definitions - - inline void setVectorId(labelType label, idType id) override; - - inline void replaceIdOfLabel(labelType label, idType new_id, idType old_id) override; - - inline void resizeLabelLookup(size_t new_max_elements) override { - labelToIdsLookup.reserve(new_max_elements); - } - - inline bool isLabelExists(labelType label) override { - return labelToIdsLookup.find(label) != labelToIdsLookup.end(); - } - // Return a set of all labels that are stored in the index (helper for computing label count - // without duplicates in tiered index). Caller should hold the flat buffer lock for read. 
- inline vecsim_stl::set getLabelsSet() const override { - vecsim_stl::set keys(this->allocator); - for (auto &it : labelToIdsLookup) { - keys.insert(it.first); - } - return keys; - } - - inline vecsim_stl::vector getElementIds(size_t label) const override { - auto it = labelToIdsLookup.find(label); - if (it == labelToIdsLookup.end()) { - return vecsim_stl::vector{this->allocator}; // return an empty collection - } - return it->second; - } - - inline vecsim_stl::abstract_priority_queue * - getNewMaxPriorityQueue() const override { - return new (this->allocator) - vecsim_stl::updatable_max_heap(this->allocator); - } - - inline BF_BatchIterator * - newBatchIterator_Instance(void *queryBlob, VecSimQueryParams *queryParams) const override { - return new (this->allocator) - BFM_BatchIterator(queryBlob, this, queryParams, this->allocator); - } - -#ifdef BUILD_TESTS -#include "VecSim/algorithms/brute_force/brute_force_multi_tests_friends.h" -#endif -}; - -/******************************* Implementation **********************************/ - -template -int BruteForceIndex_Multi::addVector(const void *vector_data, labelType label) { - this->appendVector(vector_data, label); - return 1; -} - -template -int BruteForceIndex_Multi::deleteVector(labelType label) { - int ret = 0; - - // Find the id to delete. - auto deleted_label_ids_pair = this->labelToIdsLookup.find(label); - if (deleted_label_ids_pair == this->labelToIdsLookup.end()) { - // Nothing to delete. - return ret; - } - - // Deletes all vectors under the given label. - for (auto &ids = deleted_label_ids_pair->second; idType id_to_delete : ids) { - this->removeVector(id_to_delete); - ret++; - } - - // Remove the pair of the deleted vector. - labelToIdsLookup.erase(label); - return ret; -} - -template -std::unordered_map> -BruteForceIndex_Multi::deleteVectorAndGetUpdatedIds(labelType label) { - // Hold a mapping from ids that are removed and changed to the original ids that were swapped - // into it. For example, if we have ids 0, 1, 2, 3, 4 and are about to remove ids 1, 3, 4, we - // should get the following scenario: {1->4} => {1->4} => {1->2}. - // Explanation: first we delete 1 and swap it with 4. Then, we remove 3 and have no swap since 3 - // is the last id. Lastly, we delete the original 4 which is now in id 1, and swap it with 2. - // Eventually, in id 1 we should have the original vector whose id was 2. - std::unordered_map> updated_ids; - - // Find the id to delete. - auto deleted_label_ids_pair = this->labelToIdsLookup.find(label); - if (deleted_label_ids_pair == this->labelToIdsLookup.end()) { - // Nothing to delete. - return updated_ids; - } - - // Deletes all vectors under the given label. - for (size_t i = 0; i < deleted_label_ids_pair->second.size(); i++) { - idType cur_id_to_delete = deleted_label_ids_pair->second[i]; - // The removal take into consideration the current internal id to remove, even if it is not - // the original id, and it has swapped into this id after previous swap of another id that - // belongs to this label. - labelType last_id_label = this->idToLabelMapping[this->count - 1]; - this->removeVector(cur_id_to_delete); - // If cur_id_to_delete exists in the map, remove it as it is no longer valid, whether it - // will get a new value due to a swap, or it is the last element in the index. - updated_ids.erase(cur_id_to_delete); - // If a swap was made, update who was the original id that now resides in cur_id_to_delete. 
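// Note that removeVector() has already decremented this->count at this point, so this->count now
// equals the id that the last vector occupied before the call; cur_id_to_delete == this->count
// therefore means the deleted id itself was the last one and no swap took place.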
- if (cur_id_to_delete != this->count) { - if (updated_ids.find(this->count) != updated_ids.end()) { - updated_ids[cur_id_to_delete] = updated_ids[this->count]; - updated_ids.erase(this->count); - } else { - // Otherwise, the last id now resides where the deleted id was. - updated_ids[cur_id_to_delete] = {this->count, last_id_label}; - } - } - } - // Remove the pair of the deleted vector. - labelToIdsLookup.erase(label); - return updated_ids; -} - -template -int BruteForceIndex_Multi::deleteVectorById(labelType label, idType id) { - // Find the id to delete. - auto deleted_label_ids_pair = this->labelToIdsLookup.find(label); - if (deleted_label_ids_pair == this->labelToIdsLookup.end()) { - // Nothing to delete. - return 0; - } - - // Delete the specific vector id which is under the given label. - auto &ids = deleted_label_ids_pair->second; - for (size_t i = 0; i < ids.size(); i++) { - if (ids[i] == id) { - this->removeVector(id); - ids.erase(ids.begin() + i); - if (ids.empty()) { - labelToIdsLookup.erase(label); - } - return 1; - } - } - assert(false && "id to delete was not found under the given label"); - return 0; -} - -template -double -BruteForceIndex_Multi::getDistanceFrom_Unsafe(labelType label, - const void *vector_data) const { - - auto IDs = this->labelToIdsLookup.find(label); - if (IDs == this->labelToIdsLookup.end()) { - return INVALID_SCORE; - } - - DistType dist = std::numeric_limits::infinity(); - for (auto id : IDs->second) { - DistType d = this->calcDistance(this->getDataByInternalId(id), vector_data); - dist = (dist < d) ? dist : d; - } - - return dist; -} - -template -void BruteForceIndex_Multi::replaceIdOfLabel(labelType label, idType new_id, - idType old_id) { - assert(labelToIdsLookup.find(label) != labelToIdsLookup.end()); - // *Non-trivial code here* - in every iteration we replace the internal id of the previous last - // id that has been swapped with the deleted id. Note that if the old and the new replaced ids - // both belong to the same label, then we are going to delete the new id later on as well, since - // we are currently iterating on this exact array of ids in 'deleteVector'. Hence, the relevant - // part of the vector that should be updated is the "tail" that comes after the position of - // old_id, while the "head" may contain old occurrences of old_id that are irrelevant for the - // future deletions. Therefore, we iterate from end to beginning. For example, assuming we are - // deleting a label that contains the only 3 ids that exist in the index. Hence, we would - // expect the following scenario w.r.t. the ids array: - // [|1, 0, 2] -> [1, |0, 1] -> [1, 0, |0] (where | marks the current position) - auto &ids = labelToIdsLookup.at(label); - for (int i = ids.size() - 1; i >= 0; i--) { - if (ids[i] == old_id) { - ids[i] = new_id; - return; - } - } - assert(!"should have found the old id"); -} - -template -void BruteForceIndex_Multi::setVectorId(labelType label, idType id) { - auto ids = labelToIdsLookup.find(label); - if (ids != labelToIdsLookup.end()) { - ids->second.push_back(id); - } else { - // Initial capacity is 1. We can consider increasing this value or having it as a - // parameter. 
- labelToIdsLookup.emplace(label, vecsim_stl::vector{1, id, this->allocator}); - } -} diff --git a/src/VecSim/algorithms/brute_force/brute_force_multi_tests_friends.h b/src/VecSim/algorithms/brute_force/brute_force_multi_tests_friends.h deleted file mode 100644 index 9e20b78fd..000000000 --- a/src/VecSim/algorithms/brute_force/brute_force_multi_tests_friends.h +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/friend_test_decl.h" - -// Allow the following tests to access the index private members. -INDEX_TEST_FRIEND_CLASS(BruteForceMultiTest_resize_and_align_index_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceMultiTest_empty_index_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceMultiTest_search_more_than_there_is_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceMultiTest_indexing_same_vector_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceMultiTest_test_delete_swap_block_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceMultiTest_test_dynamic_bf_info_iterator_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceMultiTest_remove_vector_after_replacing_block_Test) -INDEX_TEST_FRIEND_CLASS(BruteForceMultiTest_removeVectorWithSwaps_Test) diff --git a/src/VecSim/algorithms/brute_force/brute_force_single.h b/src/VecSim/algorithms/brute_force/brute_force_single.h deleted file mode 100644 index 9afe46ed3..000000000 --- a/src/VecSim/algorithms/brute_force/brute_force_single.h +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "brute_force.h" -#include "bfs_batch_iterator.h" -#include "VecSim/utils/vec_utils.h" - -template -class BruteForceIndex_Single : public BruteForceIndex { - -protected: - vecsim_stl::unordered_map labelToIdLookup; - -public: - BruteForceIndex_Single(const BFParams *params, - const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components); - ~BruteForceIndex_Single() = default; - - int addVector(const void *vector_data, labelType label) override; - int deleteVector(labelType label) override; - int deleteVectorById(labelType label, idType id) override; - double getDistanceFrom_Unsafe(labelType label, const void *vector_data) const override; - - std::unique_ptr - getNewResultsContainer(size_t cap) const override { - return std::unique_ptr( - new (this->allocator) vecsim_stl::default_results_container(cap, this->allocator)); - } - - size_t indexLabelCount() const override { return this->count; } - std::unordered_map> - deleteVectorAndGetUpdatedIds(labelType label) override; - - // We call this when we KNOW that the label exists in the index. 
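// Dereferencing the find() result without comparing it to end() is undefined behavior when the
// label is absent, so the precondition stated above is load-bearing, not advisory.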
- idType getIdOfLabel(labelType label) const { return labelToIdLookup.find(label)->second; } - -#ifdef BUILD_TESTS - void getDataByLabel(labelType label, - std::vector> &vectors_output) const override { - - auto id = labelToIdLookup.at(label); - - auto vec = std::vector(this->dim); - // Only copy the vector data (dim * sizeof(DataType)), not any additional metadata like the - // norm - memcpy(vec.data(), this->getDataByInternalId(id), this->dim * sizeof(DataType)); - vectors_output.push_back(vec); - } - - std::vector> getStoredVectorDataByLabel(labelType label) const override { - std::vector> vectors_output; - auto id = labelToIdLookup.at(label); - - // Get the data pointer - need to cast to char* for memcpy - const char *data = reinterpret_cast(this->getDataByInternalId(id)); - - // Create a vector with the full data (including any metadata like norms) - std::vector vec(this->getStoredDataSize()); - memcpy(vec.data(), data, this->getStoredDataSize()); - vectors_output.push_back(std::move(vec)); - - return vectors_output; - } -#endif -protected: - // inline definitions - void setVectorId(labelType label, idType id) override { labelToIdLookup.emplace(label, id); } - - void replaceIdOfLabel(labelType label, idType new_id, idType old_id) override { - labelToIdLookup.at(label) = new_id; - } - - void resizeLabelLookup(size_t new_max_elements) override { - labelToIdLookup.reserve(new_max_elements); - } - - bool isLabelExists(labelType label) override { - return labelToIdLookup.find(label) != labelToIdLookup.end(); - } - // Return a set of all labels that are stored in the index (helper for computing label count - // without duplicates in tiered index). Caller should hold the flat buffer lock for read. - vecsim_stl::set getLabelsSet() const override { - vecsim_stl::set keys(this->allocator); - for (auto &it : labelToIdLookup) { - keys.insert(it.first); - } - return keys; - } - - vecsim_stl::vector getElementIds(size_t label) const override { - vecsim_stl::vector ids(this->allocator); - auto it = labelToIdLookup.find(label); - if (it != labelToIdLookup.end()) { - ids.push_back(it->second); - } - return ids; - } - - vecsim_stl::abstract_priority_queue * - getNewMaxPriorityQueue() const override { - return new (this->allocator) - vecsim_stl::max_priority_queue(this->allocator); - } - - BF_BatchIterator * - newBatchIterator_Instance(void *queryBlob, VecSimQueryParams *queryParams) const override { - return new (this->allocator) - BFS_BatchIterator(queryBlob, this, queryParams, this->allocator); - } - -#ifdef BUILD_TESTS -#include "VecSim/algorithms/brute_force/brute_force_friend_tests.h" - -#endif -}; - -/******************************* Implementation **********************************/ - -template -BruteForceIndex_Single::BruteForceIndex_Single( - const BFParams *params, const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components) - : BruteForceIndex(params, abstractInitParams, components), - labelToIdLookup(this->allocator) {} - -template -int BruteForceIndex_Single::addVector(const void *vector_data, - labelType label) { - - auto optionalID = this->labelToIdLookup.find(label); - // Check if label already exists, so it is an update operation. - if (optionalID != this->labelToIdLookup.end()) { - idType id = optionalID->second; - this->vectors->updateElement(id, vector_data); - return 0; - } - - this->appendVector(vector_data, label); - return 1; -} - -template -int BruteForceIndex_Single::deleteVector(labelType label) { - - // Find the id to delete. 
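// As in the multi-label variant, the return value is the number of vectors actually removed:
// 0 when the label is unknown, 1 otherwise.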
- auto deleted_label_id_pair = this->labelToIdLookup.find(label); - if (deleted_label_id_pair == this->labelToIdLookup.end()) { - // Nothing to delete. - return 0; - } - - // Get deleted vector id. - idType id_to_delete = deleted_label_id_pair->second; - - // Remove the pair of the deleted vector. - labelToIdLookup.erase(label); - - this->removeVector(id_to_delete); - return 1; -} - -template -std::unordered_map> -BruteForceIndex_Single::deleteVectorAndGetUpdatedIds(labelType label) { - - std::unordered_map> updated_ids; - // Find the id to delete. - auto deleted_label_id_pair = this->labelToIdLookup.find(label); - if (deleted_label_id_pair == this->labelToIdLookup.end()) { - // Nothing to delete. - return updated_ids; - } - - // Get deleted vector id. - idType id_to_delete = deleted_label_id_pair->second; - - // Remove the pair of the deleted vector. - labelToIdLookup.erase(label); - labelType last_id_label = this->idToLabelMapping[this->count - 1]; - this->removeVector(id_to_delete); // this will decrease this->count and make the swap - if (id_to_delete != this->count) { - updated_ids[id_to_delete] = {this->count, last_id_label}; - } - return updated_ids; -} - -template -int BruteForceIndex_Single::deleteVectorById(labelType label, idType id) { - return deleteVector(label); -} - -template -double -BruteForceIndex_Single::getDistanceFrom_Unsafe(labelType label, - const void *vector_data) const { - - auto optionalId = this->labelToIdLookup.find(label); - if (optionalId == this->labelToIdLookup.end()) { - return INVALID_SCORE; - } - idType id = optionalId->second; - - return this->calcDistance(this->getDataByInternalId(id), vector_data); -} diff --git a/src/VecSim/algorithms/hnsw/graph_data.h b/src/VecSim/algorithms/hnsw/graph_data.h deleted file mode 100644 index 28df1167b..000000000 --- a/src/VecSim/algorithms/hnsw/graph_data.h +++ /dev/null @@ -1,126 +0,0 @@ - -#pragma once - -#include -#include -#include -#include "VecSim/utils/vec_utils.h" - -template -using candidatesList = vecsim_stl::vector>; - -typedef uint16_t linkListSize; - -struct ElementLevelData { - // A list of ids that are pointing to the node where each edge is *unidirectional* - vecsim_stl::vector *incomingUnidirectionalEdges; - linkListSize numLinks; - // Flexible array member - https://en.wikipedia.org/wiki/Flexible_array_member - // Using this trick, we can have the links list as part of the ElementLevelData struct, and - // avoid the need to dereference a pointer to get to the links list. We have to calculate the - // size of the struct manually, as `sizeof(ElementLevelData)` will not include this member. We - // do so in the constructor of the index, under the name `levelDataSize` (and - // `elementGraphDataSize`). Notice that this member must be the last member of the struct and - // all nesting structs. - idType links[]; - - explicit ElementLevelData(std::shared_ptr allocator) - : incomingUnidirectionalEdges(new(allocator) vecsim_stl::vector(allocator)), - numLinks(0) {} - - linkListSize getNumLinks() const { return this->numLinks; } - idType getLinkAtPos(size_t pos) const { - assert(pos < numLinks); - return this->links[pos]; - } - const vecsim_stl::vector &getIncomingEdges() const { - return *incomingUnidirectionalEdges; - } - std::vector copyLinks() { - std::vector links_copy; - links_copy.assign(links, links + numLinks); - return links_copy; - } - // Sets the outgoing links of the current element. - // Assumes that the object has the capacity to hold all the links. 
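// "Capacity" here is fixed at allocation time by the flexible-array-member layout described
// above: each level is carved out of a single allocation of (roughly)
// levelDataSize = sizeof(ElementLevelData) + maxNumLinks * sizeof(idType) bytes, where
// maxNumLinks is M0 for level 0 and M for higher levels, so setLinks() can memcpy into
// this->links without any further bounds management.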
- void setLinks(vecsim_stl::vector &links) { - numLinks = links.size(); - memcpy(this->links, links.data(), numLinks * sizeof(idType)); - } - template - void setLinks(candidatesList &links) { - numLinks = 0; - for (auto &link : links) { - this->links[numLinks++] = link.second; - } - } - void popLink() { this->numLinks--; } - void setNumLinks(linkListSize num) { this->numLinks = num; } - void setLinkAtPos(size_t pos, idType node_id) { this->links[pos] = node_id; } - void appendLink(idType node_id) { this->links[this->numLinks++] = node_id; } - void removeLink(idType node_id) { - size_t i = 0; - for (; i < numLinks; i++) { - if (links[i] == node_id) { - links[i] = links[numLinks - 1]; - break; - } - } - assert(i < numLinks && "Corruption in HNSW index"); // node_id not found - error - numLinks--; - } - void newIncomingUnidirectionalEdge(idType node_id) { - this->incomingUnidirectionalEdges->push_back(node_id); - } - bool removeIncomingUnidirectionalEdgeIfExists(idType node_id) { - return this->incomingUnidirectionalEdges->remove(node_id); - } - void swapNodeIdInIncomingEdges(idType id_before, idType id_after) { - auto it = std::find(this->incomingUnidirectionalEdges->begin(), - this->incomingUnidirectionalEdges->end(), id_before); - // This should always succeed - assert(it != this->incomingUnidirectionalEdges->end()); - *it = id_after; - } -}; - -struct ElementGraphData { - size_t toplevel; - std::mutex neighborsGuard; - ElementLevelData *others; - ElementLevelData level0; - - ElementGraphData(size_t maxLevel, size_t high_level_size, - std::shared_ptr allocator) - : toplevel(maxLevel), others(nullptr), level0(allocator) { - if (toplevel > 0) { - others = (ElementLevelData *)allocator->callocate(high_level_size * toplevel); - if (others == nullptr) { - throw std::runtime_error("VecSim index low memory error"); - } - for (size_t i = 0; i < maxLevel; i++) { - new ((char *)others + i * high_level_size) ElementLevelData(allocator); - } - } - } - ~ElementGraphData() = delete; // should be destroyed using `destroy' - - void destroy(size_t levelDataSize, std::shared_ptr allocator) { - delete this->level0.incomingUnidirectionalEdges; - ElementLevelData *cur_ld = this->others; - for (size_t i = 0; i < this->toplevel; i++) { - delete cur_ld->incomingUnidirectionalEdges; - cur_ld = reinterpret_cast(reinterpret_cast(cur_ld) + - levelDataSize); - } - allocator->free_allocation(this->others); - } - ElementLevelData &getElementLevelData(size_t level, size_t levelDataSize) { - assert(level <= this->toplevel); - if (level == 0) { - return this->level0; - } - return *reinterpret_cast(reinterpret_cast(this->others) + - (level - 1) * levelDataSize); - } -}; diff --git a/src/VecSim/algorithms/hnsw/hnsw.h b/src/VecSim/algorithms/hnsw/hnsw.h deleted file mode 100644 index e5994d314..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw.h +++ /dev/null @@ -1,2354 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#pragma once - -#include "graph_data.h" -#include "visited_nodes_handler.h" -#include "VecSim/memory/vecsim_malloc.h" -#include "VecSim/utils/vecsim_stl.h" -#include "VecSim/utils/vec_utils.h" -#include "VecSim/containers/data_block.h" -#include "VecSim/containers/raw_data_container_interface.h" -#include "VecSim/containers/data_blocks_container.h" -#include "VecSim/containers/vecsim_results_container.h" -#include "VecSim/query_result_definitions.h" -#include "VecSim/vec_sim_common.h" -#include "VecSim/vec_sim_index.h" -#include "VecSim/tombstone_interface.h" - -#ifdef BUILD_TESTS -#include "hnsw_serialization_utils.h" -#include "VecSim/utils/serializer.h" -#include "hnsw_serializer.h" -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using std::pair; - -typedef uint8_t elementFlags; - -template -using candidatesMaxHeap = vecsim_stl::max_priority_queue; -template -using candidatesList = vecsim_stl::vector>; -template -using candidatesLabelsMaxHeap = vecsim_stl::abstract_priority_queue; -using graphNodeType = pair; // represented as: (element_id, level) - -////////////////////////////////////// Auxiliary HNSW structs ////////////////////////////////////// - -// Vectors flags (for marking a specific vector) -typedef enum { - DELETE_MARK = 0x1, // element is logically deleted, but still exists in the graph - IN_PROCESS = 0x2, // element is being inserted into the graph -} Flags; - -// The state of the index and the newly stored vector to be passed to indexVector. -struct HNSWAddVectorState { - idType newElementId; - int elementMaxLevel; - idType currEntryPoint; - int currMaxLevel; -}; - -#pragma pack(1) -struct ElementMetaData { - labelType label; - elementFlags flags; - - explicit ElementMetaData(labelType label = SIZE_MAX) noexcept - : label(label), flags(IN_PROCESS) {} -}; -#pragma pack() // restore default packing - -//////////////////////////////////// HNSW index implementation //////////////////////////////////// - -template -class HNSWIndex : public VecSimIndexAbstract, - public VecSimIndexTombstone -#ifdef BUILD_TESTS - , - public HNSWSerializer -#endif -{ -protected: - // Index build parameters - size_t maxElements; - size_t M; - size_t M0; - size_t efConstruction; - - // Index search parameter - size_t ef; - double epsilon; - - // Index meta-data (based on the data dimensionality and index parameters) - size_t elementGraphDataSize; - size_t levelDataSize; - double mult; - - // Index level generator of the top level for a new element - std::default_random_engine levelGenerator; - - // Index global state - these should be guarded by the indexDataGuard lock in - // multithreaded scenario. - size_t curElementCount; - idType entrypointNode; - size_t maxLevel; // this is the top level of the entry point's element - - // Index data - vecsim_stl::vector graphDataBlocks; - vecsim_stl::vector idToMetaData; - - // Used for marking the visited nodes in graph scans (the pool supports parallel graph scans). - // This is mutable since the object changes upon search operations as well (which are const). - mutable VisitedNodesHandlerPool visitedNodesHandlerPool; - mutable std::shared_mutex indexDataGuard; - -#ifdef BUILD_TESTS -#include "VecSim/algorithms/hnsw/hnsw_base_tests_friends.h" - -#include "hnsw_serializer_declarations.h" -#endif - -protected: - HNSWIndex() = delete; // default constructor is disabled. - HNSWIndex(const HNSWIndex &) = delete; // default (shallow) copy constructor is disabled. 
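// getRandomLevel() below draws an element's top level from a geometric-like distribution:
// level = floor(-ln(U) * mult) for U ~ Uniform(0, 1), so each additional level is exponentially
// less likely; mult (conventionally 1 / ln(M) in HNSW) controls the decay rate.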
- size_t getRandomLevel(double reverse_size); - template // Either idType or labelType - void processCandidate(idType curNodeId, const void *data_point, size_t layer, size_t ef, - tag_t *elements_tags, tag_t visited_tag, - vecsim_stl::abstract_priority_queue &top_candidates, - candidatesMaxHeap &candidates_set, DistType &lowerBound) const; - void processCandidate_RangeSearch( - idType curNodeId, const void *data_point, size_t layer, double epsilon, - tag_t *elements_tags, tag_t visited_tag, - std::unique_ptr &top_candidates, - candidatesMaxHeap &candidate_set, DistType lowerBound, DistType radius) const; - candidatesMaxHeap searchLayer(idType ep_id, const void *data_point, size_t layer, - size_t ef) const; - candidatesLabelsMaxHeap * - searchBottomLayer_WithTimeout(idType ep_id, const void *data_point, size_t ef, size_t k, - void *timeoutCtx, VecSimQueryReply_Code *rc) const; - VecSimQueryResultContainer searchRangeBottomLayer_WithTimeout(idType ep_id, - const void *data_point, - double epsilon, DistType radius, - void *timeoutCtx, - VecSimQueryReply_Code *rc) const; - idType getNeighborsByHeuristic2(candidatesList &top_candidates, size_t M) const; - void getNeighborsByHeuristic2(candidatesList &top_candidates, size_t M, - vecsim_stl::vector ¬_chosen_candidates) const; - template - void getNeighborsByHeuristic2_internal( - candidatesList &top_candidates, size_t M, - vecsim_stl::vector *removed_candidates = nullptr) const; - // Helper function for re-selecting node's neighbors which was selected as a neighbor for - // a newly inserted node. Also, responsible for mutually connect the new node and the neighbor - // (unidirectional or bidirectional connection). - // *Note that node_lock and neighbor_lock should be locked upon calling this function* - void revisitNeighborConnections(size_t level, idType new_node_id, - const std::pair &neighbor_data, - ElementLevelData &new_node_level, - ElementLevelData &neighbor_level); - idType mutuallyConnectNewElement(idType new_node_id, - candidatesMaxHeap &top_candidates, size_t level); - void mutuallyUpdateForRepairedNode(idType node_id, size_t level, - vecsim_stl::vector &nodes_to_update, - vecsim_stl::vector &chosen_neighbors, - size_t max_M_cur); - - template - void greedySearchLevel(const void *vector_data, size_t level, idType &curObj, DistType &curDist, - void *timeoutCtx = nullptr, VecSimQueryReply_Code *rc = nullptr) const; - void repairConnectionsForDeletion(idType element_internal_id, idType neighbour_id, - ElementLevelData &node_level, - ElementLevelData &neighbor_level, size_t level, - vecsim_stl::vector &neighbours_bitmap); - void replaceEntryPoint(); - - void SwapLastIdWithDeletedId(idType element_internal_id, ElementGraphData *last_element, - const void *last_element_data); - - /** Add vector functions */ - // Protected internal function that implements generic single vector insertion. - - void appendVector(const void *vector_data, labelType label); - - HNSWAddVectorState storeVector(const void *vector_data, const labelType label); - - // Protected internal functions for index resizing. - void growByBlock(); - void shrinkByBlock(); - // DO NOT USE DIRECTLY. Use `[grow|shrink]ByBlock` instead. 
- void resizeIndexCommon(size_t new_max_elements); - - void emplaceToHeap(vecsim_stl::abstract_priority_queue &heap, DistType dist, - idType id) const; - void emplaceToHeap(vecsim_stl::abstract_priority_queue &heap, - DistType dist, idType id) const; - void removeAndSwap(idType internalId); - - size_t getVectorRelativeIndex(idType id) const { return id % this->blockSize; } - - // Flagging API - template - void markAs(idType internalId) { - __atomic_fetch_or(&idToMetaData[internalId].flags, FLAG, 0); - } - template - void unmarkAs(idType internalId) { - __atomic_fetch_and(&idToMetaData[internalId].flags, ~FLAG, 0); - } - template - bool isMarkedAs(idType internalId) const { - return idToMetaData[internalId].flags & FLAG; - } - void mutuallyRemoveNeighborAtPos(ElementLevelData &node_level, size_t level, idType node_id, - size_t pos); - -public: - HNSWIndex(const HNSWParams *params, const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components, size_t random_seed = 100); - virtual ~HNSWIndex(); - - void setEf(size_t ef); - size_t getEf() const; - void setEpsilon(double epsilon); - double getEpsilon() const; - size_t indexSize() const override; - size_t indexCapacity() const override; - /** - * Checks if the index capacity is full to hint the caller a resize is needed. - * @note Must be called with indexDataGuard locked. - */ - size_t isCapacityFull() const; - size_t getEfConstruction() const; - size_t getM() const; - size_t getMaxLevel() const; - labelType getEntryPointLabel() const; - labelType getExternalLabel(idType internal_id) const { return idToMetaData[internal_id].label; } - auto safeGetEntryPointState() const; - void lockIndexDataGuard() const; - void unlockIndexDataGuard() const; - void lockSharedIndexDataGuard() const; - void unlockSharedIndexDataGuard() const; - void lockNodeLinks(idType node_id) const; - void unlockNodeLinks(idType node_id) const; - void lockNodeLinks(ElementGraphData *node_data) const; - void unlockNodeLinks(ElementGraphData *node_data) const; - VisitedNodesHandler *getVisitedList() const; - void returnVisitedList(VisitedNodesHandler *visited_nodes_handler) const; - VecSimIndexDebugInfo debugInfo() const override; - VecSimIndexBasicInfo basicInfo() const override; - VecSimDebugInfoIterator *debugInfoIterator() const override; - bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) const override; - const char *getDataByInternalId(idType internal_id) const; - ElementGraphData *getGraphDataByInternalId(idType internal_id) const; - ElementLevelData &getElementLevelData(idType internal_id, size_t level) const; - ElementLevelData &getElementLevelData(ElementGraphData *element, size_t level) const; - idType searchBottomLayerEP(const void *query_data, void *timeoutCtx, - VecSimQueryReply_Code *rc) const; - - void indexVector(const void *vector_data, const labelType label, - const HNSWAddVectorState &state); - VecSimQueryReply *topKQuery(const void *query_data, size_t k, - VecSimQueryParams *queryParams) const override; - VecSimQueryReply *rangeQuery(const void *query_data, double radius, - VecSimQueryParams *queryParams) const override; - - void markDeletedInternal(idType internalId); - bool isMarkedDeleted(idType internalId) const; - bool isInProcess(idType internalId) const; - void unmarkInProcess(idType internalId); - HNSWAddVectorState storeNewElement(labelType label, const void *vector_data); - void removeAndSwapMarkDeletedElement(idType internalId); - void repairNodeConnections(idType node_id, size_t level); - // For 
prefetching only. - const ElementMetaData *getMetaDataAddress(idType internal_id) const { - return idToMetaData.data() + internal_id; - } - vecsim_stl::vector safeCollectAllNodeIncomingNeighbors(idType node_id) const; - VecSimDebugCommandCode getHNSWElementNeighbors(size_t label, int ***neighborsData); - void insertElementToGraph(idType element_id, size_t element_max_level, idType entry_point, - size_t global_max_level, const void *vector_data); - void removeVectorInPlace(idType id); - - /*************************** Labels lookup API ***************************/ - - // Inline priority queue getter that need to be implemented by derived class. - virtual inline candidatesLabelsMaxHeap *getNewMaxPriorityQueue() const = 0; - - // Unsafe (assume index data guard is held in MT mode). - virtual vecsim_stl::vector getElementIds(size_t label) = 0; - - // Remove label from the index. - virtual int removeLabel(labelType label) = 0; - -#ifdef BUILD_TESTS - void fitMemory() override { - if (maxElements > 0) { - idToMetaData.shrink_to_fit(); - resizeLabelLookup(idToMetaData.size()); - } - } - - size_t indexMetaDataCapacity() const override { return idToMetaData.capacity(); } -#endif - -protected: - // inline label to id setters that need to be implemented by derived class - virtual std::unique_ptr - getNewResultsContainer(size_t cap) const = 0; - virtual void replaceIdOfLabel(labelType label, idType new_id, idType old_id) = 0; - virtual void setVectorId(labelType label, idType id) = 0; - virtual void resizeLabelLookup(size_t new_max_elements) = 0; -}; - -/** - * getters and setters of index data - */ - -template -void HNSWIndex::setEf(size_t ef) { - this->ef = ef; -} - -template -size_t HNSWIndex::getEf() const { - return this->ef; -} - -template -void HNSWIndex::setEpsilon(double epsilon) { - this->epsilon = epsilon; -} - -template -double HNSWIndex::getEpsilon() const { - return this->epsilon; -} - -template -size_t HNSWIndex::indexSize() const { - return this->curElementCount; -} - -template -size_t HNSWIndex::indexCapacity() const { - return this->maxElements; -} - -template -size_t HNSWIndex::isCapacityFull() const { - return indexSize() == this->maxElements; -} - -template -size_t HNSWIndex::getEfConstruction() const { - return this->efConstruction; -} - -template -size_t HNSWIndex::getM() const { - return this->M; -} - -template -size_t HNSWIndex::getMaxLevel() const { - return this->maxLevel; -} - -template -labelType HNSWIndex::getEntryPointLabel() const { - if (entrypointNode != INVALID_ID) - return getExternalLabel(entrypointNode); - return SIZE_MAX; -} - -template -const char *HNSWIndex::getDataByInternalId(idType internal_id) const { - return this->vectors->getElement(internal_id); -} - -template -ElementGraphData * -HNSWIndex::getGraphDataByInternalId(idType internal_id) const { - return (ElementGraphData *)graphDataBlocks[internal_id / this->blockSize].getElement( - internal_id % this->blockSize); -} - -template -size_t HNSWIndex::getRandomLevel(double reverse_size) { - std::uniform_real_distribution distribution(0.0, 1.0); - double r = -log(distribution(levelGenerator)) * reverse_size; - return (size_t)r; -} - -template -ElementLevelData &HNSWIndex::getElementLevelData(idType internal_id, - size_t level) const { - return getGraphDataByInternalId(internal_id)->getElementLevelData(level, this->levelDataSize); -} - -template -ElementLevelData &HNSWIndex::getElementLevelData(ElementGraphData *graph_data, - size_t level) const { - return graph_data->getElementLevelData(level, 
this->levelDataSize); -} - -template -VisitedNodesHandler *HNSWIndex::getVisitedList() const { - return visitedNodesHandlerPool.getAvailableVisitedNodesHandler(); -} - -template -void HNSWIndex::returnVisitedList( - VisitedNodesHandler *visited_nodes_handler) const { - visitedNodesHandlerPool.returnVisitedNodesHandlerToPool(visited_nodes_handler); -} - -template -void HNSWIndex::markDeletedInternal(idType internalId) { - // Here we are holding the global index data guard (and the main index lock of the tiered index - // for shared ownership). - assert(internalId < this->curElementCount); - if (!isMarkedDeleted(internalId)) { - if (internalId == entrypointNode) { - // Internally, we hold and release the entrypoint neighbors lock. - replaceEntryPoint(); - } - // Atomically set the deletion mark flag (note that other parallel threads may set the flags - // at the same time (for changing the IN_PROCESS flag). - markAs(internalId); - this->numMarkedDeleted++; - } -} - -template -bool HNSWIndex::isMarkedDeleted(idType internalId) const { - return isMarkedAs(internalId); -} - -template -bool HNSWIndex::isInProcess(idType internalId) const { - return isMarkedAs(internalId); -} - -template -void HNSWIndex::unmarkInProcess(idType internalId) { - // Atomically unset the IN_PROCESS mark flag (note that other parallel threads may set the flags - // at the same time (for marking the element with IN_PROCCESS flag). - unmarkAs(internalId); -} - -template -void HNSWIndex::lockIndexDataGuard() const { - indexDataGuard.lock(); -} - -template -void HNSWIndex::unlockIndexDataGuard() const { - indexDataGuard.unlock(); -} - -template -void HNSWIndex::lockSharedIndexDataGuard() const { - indexDataGuard.lock_shared(); -} - -template -void HNSWIndex::unlockSharedIndexDataGuard() const { - indexDataGuard.unlock_shared(); -} - -template -void HNSWIndex::lockNodeLinks(ElementGraphData *node_data) const { - node_data->neighborsGuard.lock(); -} - -template -void HNSWIndex::unlockNodeLinks(ElementGraphData *node_data) const { - node_data->neighborsGuard.unlock(); -} - -template -void HNSWIndex::lockNodeLinks(idType node_id) const { - lockNodeLinks(getGraphDataByInternalId(node_id)); -} - -template -void HNSWIndex::unlockNodeLinks(idType node_id) const { - unlockNodeLinks(getGraphDataByInternalId(node_id)); -} - -/** - * helper functions - */ - -template -void HNSWIndex::emplaceToHeap( - vecsim_stl::abstract_priority_queue &heap, DistType dist, idType id) const { - heap.emplace(dist, id); -} - -template -void HNSWIndex::emplaceToHeap( - vecsim_stl::abstract_priority_queue &heap, DistType dist, - idType id) const { - heap.emplace(dist, getExternalLabel(id)); -} - -// This function handles both label heaps and internal ids heaps. It uses the `emplaceToHeap` -// overloading to emplace correctly for both cases. -template -template -void HNSWIndex::processCandidate( - idType curNodeId, const void *query_data, size_t layer, size_t ef, tag_t *elements_tags, - tag_t visited_tag, vecsim_stl::abstract_priority_queue &top_candidates, - candidatesMaxHeap &candidate_set, DistType &lowerBound) const { - - ElementGraphData *cur_element = getGraphDataByInternalId(curNodeId); - lockNodeLinks(cur_element); - ElementLevelData &node_level = getElementLevelData(cur_element, layer); - linkListSize num_links = node_level.getNumLinks(); - if (num_links > 0) { - - const char *cur_data, *next_data; - // Pre-fetch first candidate tag address. 
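// The scan below is software-pipelined: while candidate j is being scored, the visited-tag and
// vector data of candidate j + 1 are prefetched, hiding the cache misses of a random-access
// graph walk; the last link is handled outside the loop so no prefetch reads past the link list.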
- __builtin_prefetch(elements_tags + node_level.getLinkAtPos(0)); - // Pre-fetch first candidate data block address. - next_data = getDataByInternalId(node_level.getLinkAtPos(0)); - __builtin_prefetch(next_data); - - for (linkListSize j = 0; j < num_links - 1; j++) { - idType candidate_id = node_level.getLinkAtPos(j); - cur_data = next_data; - - // Pre-fetch next candidate tag address. - __builtin_prefetch(elements_tags + node_level.getLinkAtPos(j + 1)); - // Pre-fetch next candidate data block address. - next_data = getDataByInternalId(node_level.getLinkAtPos(j + 1)); - __builtin_prefetch(next_data); - - if (elements_tags[candidate_id] == visited_tag || isInProcess(candidate_id)) - continue; - - elements_tags[candidate_id] = visited_tag; - - DistType cur_dist = this->calcDistance(query_data, cur_data); - if (lowerBound > cur_dist || top_candidates.size() < ef) { - - candidate_set.emplace(-cur_dist, candidate_id); - - // Insert the candidate to the top candidates heap only if it is not marked as - // deleted. - if (!isMarkedDeleted(candidate_id)) - emplaceToHeap(top_candidates, cur_dist, candidate_id); - - if (top_candidates.size() > ef) - top_candidates.pop(); - - // If we have marked deleted elements, we need to verify that `top_candidates` is - // not empty (since we might have not added any non-deleted element yet). - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } - } - - // Running the last neighbor outside the loop to avoid prefetching invalid neighbor - idType candidate_id = node_level.getLinkAtPos(num_links - 1); - cur_data = next_data; - - if (elements_tags[candidate_id] != visited_tag && !isInProcess(candidate_id)) { - - elements_tags[candidate_id] = visited_tag; - - DistType cur_dist = this->calcDistance(query_data, cur_data); - if (lowerBound > cur_dist || top_candidates.size() < ef) { - candidate_set.emplace(-cur_dist, candidate_id); - - // Insert the candidate to the top candidates heap only if it is not marked as - // deleted. - if (!isMarkedDeleted(candidate_id)) - emplaceToHeap(top_candidates, cur_dist, candidate_id); - - if (top_candidates.size() > ef) - top_candidates.pop(); - - // If we have marked deleted elements, we need to verify that `top_candidates` is - // not empty (since we might have not added any non-deleted element yet). - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } - } - } - unlockNodeLinks(cur_element); -} - -template -void HNSWIndex::processCandidate_RangeSearch( - idType curNodeId, const void *query_data, size_t layer, double epsilon, tag_t *elements_tags, - tag_t visited_tag, std::unique_ptr &results, - candidatesMaxHeap &candidate_set, DistType dyn_range, DistType radius) const { - - auto *cur_element = getGraphDataByInternalId(curNodeId); - lockNodeLinks(cur_element); - ElementLevelData &node_level = getElementLevelData(cur_element, layer); - linkListSize num_links = node_level.getNumLinks(); - - if (num_links > 0) { - - const char *cur_data, *next_data; - // Pre-fetch first candidate tag address. - __builtin_prefetch(elements_tags + node_level.getLinkAtPos(0)); - // Pre-fetch first candidate data block address. - next_data = getDataByInternalId(node_level.getLinkAtPos(0)); - __builtin_prefetch(next_data); - - for (linkListSize j = 0; j < num_links - 1; j++) { - idType candidate_id = node_level.getLinkAtPos(j); - cur_data = next_data; - - // Pre-fetch next candidate tag address. 
- __builtin_prefetch(elements_tags + node_level.getLinkAtPos(j + 1)); - // Pre-fetch next candidate data block address. - next_data = getDataByInternalId(node_level.getLinkAtPos(j + 1)); - __builtin_prefetch(next_data); - - if (elements_tags[candidate_id] == visited_tag || isInProcess(candidate_id)) - continue; - - elements_tags[candidate_id] = visited_tag; - - DistType cur_dist = this->calcDistance(query_data, cur_data); - if (cur_dist < dyn_range) { - candidate_set.emplace(-cur_dist, candidate_id); - - // If the new candidate is in the requested radius, add it to the results set. - if (cur_dist <= radius && !isMarkedDeleted(candidate_id)) { - results->emplace(getExternalLabel(candidate_id), cur_dist); - } - } - } - // Running the last candidate outside the loop to avoid prefetching invalid candidate - idType candidate_id = node_level.getLinkAtPos(num_links - 1); - cur_data = next_data; - - if (elements_tags[candidate_id] != visited_tag && !isInProcess(candidate_id)) { - - elements_tags[candidate_id] = visited_tag; - - DistType cur_dist = this->calcDistance(query_data, cur_data); - if (cur_dist < dyn_range) { - candidate_set.emplace(-cur_dist, candidate_id); - - // If the new candidate is in the requested radius, add it to the results set. - if (cur_dist <= radius && !isMarkedDeleted(candidate_id)) { - results->emplace(getExternalLabel(candidate_id), cur_dist); - } - } - } - } - unlockNodeLinks(cur_element); -} - -template -candidatesMaxHeap -HNSWIndex::searchLayer(idType ep_id, const void *data_point, size_t layer, - size_t ef) const { - - auto *visited_nodes_handler = getVisitedList(); - tag_t visited_tag = visited_nodes_handler->getFreshTag(); - - candidatesMaxHeap top_candidates(this->allocator); - candidatesMaxHeap candidate_set(this->allocator); - - DistType lowerBound; - if (!isMarkedDeleted(ep_id)) { - DistType dist = this->calcDistance(data_point, getDataByInternalId(ep_id)); - lowerBound = dist; - top_candidates.emplace(dist, ep_id); - candidate_set.emplace(-dist, ep_id); - } else { - lowerBound = std::numeric_limits::max(); - candidate_set.emplace(-lowerBound, ep_id); - } - - visited_nodes_handler->tagNode(ep_id, visited_tag); - - while (!candidate_set.empty()) { - pair curr_el_pair = candidate_set.top(); - - if ((-curr_el_pair.first) > lowerBound && top_candidates.size() >= ef) { - break; - } - candidate_set.pop(); - - processCandidate(curr_el_pair.second, data_point, layer, ef, - visited_nodes_handler->getElementsTags(), visited_tag, top_candidates, - candidate_set, lowerBound); - } - - returnVisitedList(visited_nodes_handler); - return top_candidates; -} - -template -idType -HNSWIndex::getNeighborsByHeuristic2(candidatesList &top_candidates, - const size_t M) const { - if (top_candidates.size() < M) { - return std::min_element(top_candidates.begin(), top_candidates.end(), - [](const auto &a, const auto &b) { return a.first < b.first; }) - ->second; - } - getNeighborsByHeuristic2_internal(top_candidates, M, nullptr); - return top_candidates.front().second; -} - -template -void HNSWIndex::getNeighborsByHeuristic2( - candidatesList &top_candidates, const size_t M, - vecsim_stl::vector &removed_candidates) const { - getNeighborsByHeuristic2_internal(top_candidates, M, &removed_candidates); -} - -template -template -void HNSWIndex::getNeighborsByHeuristic2_internal( - candidatesList &top_candidates, const size_t M, - vecsim_stl::vector *removed_candidates) const { - if (top_candidates.size() < M) { - return; - } - - candidatesList return_list(this->allocator); - 
vecsim_stl::vector cached_vectors(this->allocator); - return_list.reserve(M); - cached_vectors.reserve(M); - if constexpr (record_removed) { - removed_candidates->reserve(top_candidates.size()); - } - - // Sort the candidates by their distance (we don't mind the secondary order (the internal id)) - std::sort(top_candidates.begin(), top_candidates.end(), - [](const auto &a, const auto &b) { return a.first < b.first; }); - - auto current_pair = top_candidates.begin(); - for (; current_pair != top_candidates.end() && return_list.size() < M; ++current_pair) { - DistType candidate_to_query_dist = current_pair->first; - bool good = true; - const void *curr_vector = getDataByInternalId(current_pair->second); - - // a candidate is "good" to become a neighbour, unless we find - // another item that was already selected to the neighbours set which is closer - // to both q and the candidate than the distance between the candidate and q. - for (size_t i = 0; i < return_list.size(); i++) { - DistType candidate_to_selected_dist = - this->calcDistance(cached_vectors[i], curr_vector); - if (candidate_to_selected_dist < candidate_to_query_dist) { - if constexpr (record_removed) { - removed_candidates->push_back(current_pair->second); - } - good = false; - break; - } - } - if (good) { - cached_vectors.push_back(curr_vector); - return_list.push_back(*current_pair); - } - } - - if constexpr (record_removed) { - for (; current_pair != top_candidates.end(); ++current_pair) { - removed_candidates->push_back(current_pair->second); - } - } - - top_candidates.swap(return_list); -} - -template -void HNSWIndex::revisitNeighborConnections( - size_t level, idType new_node_id, const std::pair &neighbor_data, - ElementLevelData &new_node_level, ElementLevelData &neighbor_level) { - // Note - expect that node_lock and neighbor_lock are locked at that point. - - // Collect the existing neighbors and the new node as the neighbor's neighbors candidates. - candidatesList candidates(this->allocator); - candidates.reserve(neighbor_level.getNumLinks() + 1); - // Add the new node along with the pre-calculated distance to the current neighbor, - candidates.emplace_back(neighbor_data.first, new_node_id); - - idType selected_neighbor = neighbor_data.second; - const void *selected_neighbor_data = getDataByInternalId(selected_neighbor); - for (size_t j = 0; j < neighbor_level.getNumLinks(); j++) { - candidates.emplace_back( - this->calcDistance(getDataByInternalId(neighbor_level.getLinkAtPos(j)), - selected_neighbor_data), - neighbor_level.getLinkAtPos(j)); - } - - // Candidates will store the newly selected neighbours (for the neighbor). - size_t max_M_cur = level ? M : M0; - vecsim_stl::vector nodes_to_update(this->allocator); - getNeighborsByHeuristic2(candidates, max_M_cur, nodes_to_update); - - // Acquire all relevant locks for making the updates for the selected neighbor - all its removed - // neighbors, along with the neighbors itself and the cur node. - // but first, we release the node and neighbors lock to avoid deadlocks. - unlockNodeLinks(new_node_id); - unlockNodeLinks(selected_neighbor); - - // Check if the new node was selected as a neighbor for the current neighbor. - // Make sure to add the cur node to the list of nodes to update if it was selected. 
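// The locks below are acquired in ascending id order (the list is sorted first), giving every
// thread the same global lock-acquisition order and ruling out deadlocks between concurrent
// updates that touch overlapping neighbor sets.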
- bool cur_node_chosen; - auto new_node_iter = std::find(nodes_to_update.begin(), nodes_to_update.end(), new_node_id); - if (new_node_iter != nodes_to_update.end()) { - cur_node_chosen = false; - } else { - cur_node_chosen = true; - nodes_to_update.push_back(new_node_id); - } - nodes_to_update.push_back(selected_neighbor); - - std::sort(nodes_to_update.begin(), nodes_to_update.end()); - size_t nodes_to_update_count = nodes_to_update.size(); - for (size_t i = 0; i < nodes_to_update_count; i++) { - lockNodeLinks(nodes_to_update[i]); - } - size_t neighbour_neighbours_idx = 0; - bool update_cur_node_required = true; - for (size_t i = 0; i < neighbor_level.getNumLinks(); i++) { - if (!std::binary_search(nodes_to_update.begin(), nodes_to_update.end(), - neighbor_level.getLinkAtPos(i))) { - // The neighbor is not in the "to_update" nodes list - leave it as is. - neighbor_level.setLinkAtPos(neighbour_neighbours_idx++, neighbor_level.getLinkAtPos(i)); - continue; - } - if (neighbor_level.getLinkAtPos(i) == new_node_id) { - // The new node already got into the neighbor's neighbours - this means another thread - // made an update between the locks being released and reacquired - leave it - // as is. - neighbor_level.setLinkAtPos(neighbour_neighbours_idx++, neighbor_level.getLinkAtPos(i)); - update_cur_node_required = false; - continue; - } - // Now we know that we are looking at a node to be removed from the neighbor's neighbors. - mutuallyRemoveNeighborAtPos(neighbor_level, level, selected_neighbor, i); - } - - if (update_cur_node_required && new_node_level.getNumLinks() < max_M_cur && - !isMarkedDeleted(new_node_id) && !isMarkedDeleted(selected_neighbor)) { - // Update the connection between the new node and the neighbor. - new_node_level.appendLink(selected_neighbor); - if (cur_node_chosen && neighbour_neighbours_idx < max_M_cur) { - // The connection is mutual - the new node and the selected neighbor appear in each - // other's lists. - neighbor_level.setLinkAtPos(neighbour_neighbours_idx++, new_node_id); - } else { - // Unidirectional connection - put the new node in the neighbour's incoming edges. - neighbor_level.newIncomingUnidirectionalEdge(new_node_id); - } - } - // Done updating the neighbor's neighbors. - neighbor_level.setNumLinks(neighbour_neighbours_idx); - for (size_t i = 0; i < nodes_to_update_count; i++) { - unlockNodeLinks(nodes_to_update[i]); - } -} - -template <typename DataType, typename DistType> -idType HNSWIndex<DataType, DistType>::mutuallyConnectNewElement( - idType new_node_id, candidatesMaxHeap<DistType> &top_candidates, size_t level) { - - // The maximum number of neighbors allowed for an existing neighbor (not a new one). - size_t max_M_cur = level ? M : M0; - - // Filter the top candidates down to the selected neighbors using the algorithm's heuristic. - // First, we need to copy the top candidates to a vector. - candidatesList<DistType> top_candidates_list(this->allocator); - top_candidates_list.insert(top_candidates_list.end(), top_candidates.begin(), - top_candidates.end()); - // Use the heuristic to filter the top candidates, and get the next closest entry point.
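// After the heuristic runs, the surviving candidates are sorted by distance, so the front
// element is the closest one; its id is returned and used as the entry point when this element
// is inserted into the level below.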
- idType next_closest_entry_point = getNeighborsByHeuristic2(top_candidates_list, M); - assert(top_candidates_list.size() <= M && - "No more than M candidates should be returned by the heuristic"); - - auto *new_node_level = getGraphDataByInternalId(new_node_id); - ElementLevelData &new_node_level_data = getElementLevelData(new_node_level, level); - assert(new_node_level_data.getNumLinks() == 0 && - "The newly inserted element should have a blank link list"); - - for (auto &neighbor_data : top_candidates_list) { - idType selected_neighbor = neighbor_data.second; // neighbor's id - auto *neighbor_graph_data = getGraphDataByInternalId(selected_neighbor); - if (new_node_id < selected_neighbor) { - lockNodeLinks(new_node_level); - lockNodeLinks(neighbor_graph_data); - } else { - lockNodeLinks(neighbor_graph_data); - lockNodeLinks(new_node_level); - } - - // validations... - assert(new_node_level_data.getNumLinks() <= max_M_cur && "Neighbors number exceeds limit"); - assert(selected_neighbor != new_node_id && "Trying to connect an element to itself"); - - // Revalidate the updated count - it may change between iterations since the lock is - // released. - if (new_node_level_data.getNumLinks() == max_M_cur) { - // The new node cannot add more neighbors. - this->log(VecSimCommonStrings::LOG_DEBUG_STRING, - "Couldn't add all chosen neighbors upon inserting a new node"); - unlockNodeLinks(new_node_level); - unlockNodeLinks(neighbor_graph_data); - break; - } - - // If one of the two nodes has already been deleted - skip the operation. - if (isMarkedDeleted(new_node_id) || isMarkedDeleted(selected_neighbor)) { - unlockNodeLinks(new_node_level); - unlockNodeLinks(neighbor_graph_data); - continue; - } - - ElementLevelData &neighbor_level_data = getElementLevelData(neighbor_graph_data, level); - - // If the neighbor's neighbors list has the capacity to add the new node, make the update - // and finish. - if (neighbor_level_data.getNumLinks() < max_M_cur) { - new_node_level_data.appendLink(selected_neighbor); - neighbor_level_data.appendLink(new_node_id); - unlockNodeLinks(new_node_level); - unlockNodeLinks(neighbor_graph_data); - continue; - } - - // Otherwise - we need to re-evaluate the neighbor's neighbors. - // We collect all the existing neighbors and the new node as candidates, and mutually update - // the neighbor's neighbors. We also release the acquired locks inside this call. - revisitNeighborConnections(level, new_node_id, neighbor_data, new_node_level_data, - neighbor_level_data); - } - return next_closest_entry_point; -} - -template <typename DataType, typename DistType> -void HNSWIndex<DataType, DistType>::repairConnectionsForDeletion( - idType element_internal_id, idType neighbour_id, ElementLevelData &node_level, - ElementLevelData &neighbor_level, size_t level, vecsim_stl::vector<bool> &neighbours_bitmap) { - - if (isMarkedDeleted(neighbour_id)) { - // Just remove the deleted element from the neighbor's neighbors list. No need to repair, as - // this change is temporary - this neighbor is about to be removed from the graph as well. - neighbor_level.removeLink(element_internal_id); - return; - } - - // Add the deleted element's neighbour's original neighbors to the candidates.
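// Repair strategy: the candidate pool is the union of the neighbor's current links and the
// deleted element's links (excluding the deleted element itself and avoiding duplicates); if the
// pool exceeds the level's link capacity, the same selection heuristic used at insertion time
// prunes it back, so repaired nodes keep the same diversity properties as freshly inserted ones.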
-    vecsim_stl::vector<idType> candidate_ids(this->allocator);
-    candidate_ids.reserve(node_level.getNumLinks() + neighbor_level.getNumLinks());
-    vecsim_stl::vector<bool> neighbour_orig_neighbours_set(curElementCount, false,
-                                                           this->allocator);
-    for (size_t j = 0; j < neighbor_level.getNumLinks(); j++) {
-        idType cand = neighbor_level.getLinkAtPos(j);
-        neighbour_orig_neighbours_set[cand] = true;
-        // Don't add the removed element to the candidates, nor nodes that are neighbors of the
-        // original deleted element and will also be added to the candidates set.
-        if (cand != element_internal_id && !neighbours_bitmap[cand]) {
-            candidate_ids.push_back(cand);
-        }
-    }
-    // Put the deleted element's neighbours in the candidates.
-    for (size_t j = 0; j < node_level.getNumLinks(); j++) {
-        // Don't put the neighbor itself in its own candidates, nor marked-deleted elements that
-        // were not neighbors before.
-        idType cand = node_level.getLinkAtPos(j);
-        if (cand != neighbour_id &&
-            (!isMarkedDeleted(cand) || neighbour_orig_neighbours_set[cand])) {
-            candidate_ids.push_back(cand);
-        }
-    }
-
-    size_t Mcurmax = level ? M : M0;
-    if (candidate_ids.size() > Mcurmax) {
-        // We need to filter the candidates by the heuristic.
-        candidatesList<DistType> candidates(this->allocator);
-        candidates.reserve(candidate_ids.size());
-        auto neighbours_data = getDataByInternalId(neighbour_id);
-        for (auto candidate_id : candidate_ids) {
-            candidates.emplace_back(
-                this->calcDistance(getDataByInternalId(candidate_id), neighbours_data),
-                candidate_id);
-        }
-
-        candidate_ids.clear();
-        auto &not_chosen_candidates = candidate_ids; // rename and reuse the vector
-        getNeighborsByHeuristic2(candidates, Mcurmax, not_chosen_candidates);
-
-        neighbor_level.setLinks(candidates);
-
-        // Update unidirectional incoming edges w.r.t. the edges that were removed.
-        for (auto node_id : not_chosen_candidates) {
-            if (neighbour_orig_neighbours_set[node_id]) {
-                // If the node id (the neighbour's neighbour to be removed)
-                // wasn't pointing to the neighbour (the edge was unidirectional),
-                // we should remove it from the node's incoming edges.
-                // Otherwise, the edge turned from bidirectional to unidirectional,
-                // and it should be saved in the neighbor's incoming edges.
-                auto &node_level_data = getElementLevelData(node_id, level);
-                if (!node_level_data.removeIncomingUnidirectionalEdgeIfExists(neighbour_id)) {
-                    neighbor_level.newIncomingUnidirectionalEdge(node_id);
-                }
-            }
-        }
-    } else {
-        // We don't need to filter the candidates - just update the edges.
-        neighbor_level.setLinks(candidate_ids);
-    }
-
-    // Updates for the new edges created.
-    for (size_t i = 0; i < neighbor_level.getNumLinks(); i++) {
-        idType node_id = neighbor_level.getLinkAtPos(i);
-        if (!neighbour_orig_neighbours_set[node_id]) {
-            ElementLevelData &node_level = getElementLevelData(node_id, level);
-            // If the node has an edge to the neighbour as well, remove it from the incoming nodes
-            // of the neighbour. Otherwise, we need to update the edge as unidirectional incoming.
-            bool bidirectional_edge = false;
-            for (size_t j = 0; j < node_level.getNumLinks(); j++) {
-                if (node_level.getLinkAtPos(j) == neighbour_id) {
-                    // Swap the last element with the current one (equivalent to removing the
-                    // neighbor from the list) - this should always succeed and return true.
-                    bool res = neighbor_level.removeIncomingUnidirectionalEdgeIfExists(node_id);
-                    (void)res;
-                    assert(res && "The edge should be in the incoming unidirectional edges");
-                    bidirectional_edge = true;
-                    break;
-                }
-            }
-            if (!bidirectional_edge) {
-                node_level.newIncomingUnidirectionalEdge(neighbour_id);
-            }
-        }
-    }
-}
-
-template <typename DataType, typename DistType>
-void HNSWIndex<DataType, DistType>::replaceEntryPoint() {
-    idType old_entry_point_id = entrypointNode;
-    auto *old_entry_point = getGraphDataByInternalId(old_entry_point_id);
-
-    // Sets an (arbitrary) new entry point, after deleting the current entry point.
-    while (old_entry_point_id == entrypointNode) {
-        // Use volatile for this variable, so that in case we would have to busy wait for this
-        // element to finish its indexing, the compiler will not use optimizations. Otherwise,
-        // the compiler might evaluate 'isInProcess(candidate_in_process)' once instead of calling
-        // it multiple times in a busy wait manner, and we'll run into an infinite loop if the
-        // candidate is in process when we reach the loop.
-        volatile idType candidate_in_process = INVALID_ID;
-
-        // Go over the entry point's neighbors at the top level.
-        lockNodeLinks(old_entry_point);
-        ElementLevelData &old_ep_level = getElementLevelData(old_entry_point, maxLevel);
-        // Try to set the (arbitrary) first non-deleted neighbor as the entry point,
-        // if one exists.
-        for (size_t i = 0; i < old_ep_level.getNumLinks(); i++) {
-            if (!isMarkedDeleted(old_ep_level.getLinkAtPos(i))) {
-                if (!isInProcess(old_ep_level.getLinkAtPos(i))) {
-                    entrypointNode = old_ep_level.getLinkAtPos(i);
-                    unlockNodeLinks(old_entry_point);
-                    return;
-                } else {
-                    // Store this candidate, which is currently being inserted into the graph, in
-                    // case we won't find another candidate at the top level.
-                    candidate_in_process = old_ep_level.getLinkAtPos(i);
-                }
-            }
-        }
-        unlockNodeLinks(old_entry_point);
-
-        // If there are no neighbors in the current level, check for any vector at
-        // this level to be the new entry point.
-        idType cur_id = 0;
-        for (DataBlock &graph_data_block : graphDataBlocks) {
-            size_t size = graph_data_block.getLength();
-            for (size_t i = 0; i < size; i++) {
-                auto cur_element = (ElementGraphData *)graph_data_block.getElement(i);
-                if (cur_element->toplevel == maxLevel && cur_id != old_entry_point_id &&
-                    !isMarkedDeleted(cur_id)) {
-                    // Found a non-deleted element at the current max level.
-                    if (!isInProcess(cur_id)) {
-                        entrypointNode = cur_id;
-                        return;
-                    } else if (candidate_in_process == INVALID_ID) {
-                        // This element is still in process, and no other in-process candidate
-                        // has been found at this level yet.
-                        candidate_in_process = cur_id;
-                    }
-                }
-                cur_id++;
-            }
-        }
-        // If we only found candidates which are in process at this level, do busy wait until they
-        // are done being processed (this should happen in very rare cases...). Since
-        // candidate_in_process was declared volatile, we can be sure that isInProcess is called in
-        // every iteration.
-        if (candidate_in_process != INVALID_ID) {
-            while (isInProcess(candidate_in_process))
-                ;
-            entrypointNode = candidate_in_process;
-            return;
-        }
-        // If we didn't find any vector at the top level, decrease the maxLevel and try again,
-        // until we find a new entry point, or the index is empty.
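The volatile qualifier on candidate_in_process above exists solely so that the busy-wait
loop re-reads the variable (and thus re-calls isInProcess) on every iteration, instead of
being optimized into a single hoisted check. A sketch of the same spin-wait expressed
with std::atomic, which gives the equivalent guarantee with well-defined semantics
(illustrative only; the index's in-process flag is not actually an atomic member):

    #include <atomic>

    std::atomic<bool> in_process{true}; // cleared by the indexing thread when it finishes

    void waitUntilIndexed() {
        // The acquire load cannot be hoisted out of the loop, so progress made
        // by the indexing thread is always observed.
        while (in_process.load(std::memory_order_acquire)) {
            // busy wait; expected to be extremely rare and short-lived
        }
    }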
- assert(old_entry_point_id == entrypointNode); - maxLevel--; - if ((int)maxLevel < 0) { - maxLevel = HNSW_INVALID_LEVEL; - entrypointNode = INVALID_ID; - } - } -} - -template -void HNSWIndex::SwapLastIdWithDeletedId(idType element_internal_id, - ElementGraphData *last_element, - const void *last_element_data) { - // Swap label - this is relevant when the last element's label exists (it is not marked as - // deleted). - if (!isMarkedDeleted(curElementCount)) { - replaceIdOfLabel(getExternalLabel(curElementCount), element_internal_id, curElementCount); - } - - // Swap neighbours - for (size_t level = 0; level <= last_element->toplevel; level++) { - auto &cur_level = getElementLevelData(last_element, level); - - // Go over the neighbours that also points back to the last element whose is going to - // change, and update the id. - for (size_t i = 0; i < cur_level.getNumLinks(); i++) { - idType neighbour_id = cur_level.getLinkAtPos(i); - ElementLevelData &neighbor_level = getElementLevelData(neighbour_id, level); - - bool bidirectional_edge = false; - for (size_t j = 0; j < neighbor_level.getNumLinks(); j++) { - // if the edge is bidirectional, update for this neighbor - if (neighbor_level.getLinkAtPos(j) == curElementCount) { - bidirectional_edge = true; - neighbor_level.setLinkAtPos(j, element_internal_id); - break; - } - } - - // If this edge is uni-directional, we should update the id in the neighbor's - // incoming edges. - if (!bidirectional_edge) { - neighbor_level.swapNodeIdInIncomingEdges(curElementCount, element_internal_id); - } - } - - // Next, go over the rest of incoming edges (the ones that are not bidirectional) and make - // updates. - for (auto incoming_edge : cur_level.getIncomingEdges()) { - ElementLevelData &incoming_neighbor_level = getElementLevelData(incoming_edge, level); - for (size_t j = 0; j < incoming_neighbor_level.getNumLinks(); j++) { - if (incoming_neighbor_level.getLinkAtPos(j) == curElementCount) { - incoming_neighbor_level.setLinkAtPos(j, element_internal_id); - break; - } - } - } - } - - // Move the last element's data to the deleted element's place - auto element = getGraphDataByInternalId(element_internal_id); - memcpy((void *)element, last_element, this->elementGraphDataSize); - - auto data = getDataByInternalId(element_internal_id); - memcpy((void *)data, last_element_data, this->getStoredDataSize()); - - this->idToMetaData[element_internal_id] = this->idToMetaData[curElementCount]; - - if (curElementCount == this->entrypointNode) { - this->entrypointNode = element_internal_id; - } -} - -// This function is greedily searching for the closest candidate to the given data point at the -// given level, starting at the given node. It sets `curObj` to the closest node found, and -// `curDist` to the distance to this node. If `running_query` is true, the search will check for -// timeout and return if it has occurred. `timeoutCtx` and `rc` must be valid if `running_query` is -// true. *Note that we assume that level is higher than 0*. Also, if we're not running a query (we -// are searching neighbors for a new vector), then bestCand should be a non-deleted element! -template -template -void HNSWIndex::greedySearchLevel(const void *vector_data, size_t level, - idType &bestCand, DistType &curDist, - void *timeoutCtx, - VecSimQueryReply_Code *rc) const { - bool changed; - // Don't allow choosing a deleted node as an entry point upon searching for neighbors - // candidates (that is, we're NOT running a query, but inserting a new vector). 
- idType bestNonDeletedCand = bestCand; - - do { - if (running_query && VECSIM_TIMEOUT(timeoutCtx)) { - *rc = VecSim_QueryReply_TimedOut; - bestCand = INVALID_ID; - return; - } - - changed = false; - auto *element = getGraphDataByInternalId(bestCand); - lockNodeLinks(element); - ElementLevelData &node_level_data = getElementLevelData(element, level); - - for (int i = 0; i < node_level_data.getNumLinks(); i++) { - idType candidate = node_level_data.getLinkAtPos(i); - assert(candidate < this->curElementCount && "candidate error: out of index range"); - if (isInProcess(candidate)) { - continue; - } - DistType d = this->calcDistance(vector_data, getDataByInternalId(candidate)); - if (d < curDist) { - curDist = d; - bestCand = candidate; - changed = true; - // Run this code only for non-query code - update the best non deleted cand as well. - // Upon running a query, we don't mind having a deleted element as an entry point - // for the next level, as eventually we return non-deleted elements in level 0. - if (!running_query && !isMarkedDeleted(candidate)) { - bestNonDeletedCand = bestCand; - } - } - } - unlockNodeLinks(element); - } while (changed); - if (!running_query) { - bestCand = bestNonDeletedCand; - } -} - -template -vecsim_stl::vector -HNSWIndex::safeCollectAllNodeIncomingNeighbors(idType node_id) const { - vecsim_stl::vector incoming_neighbors(this->allocator); - - auto element = getGraphDataByInternalId(node_id); - for (size_t level = 0; level <= element->toplevel; level++) { - // Save the node neighbor's in the current level while holding its neighbors lock. - lockNodeLinks(element); - auto &node_level_data = getElementLevelData(element, level); - // Store the deleted element's neighbours. - auto neighbors_copy = node_level_data.copyLinks(); - unlockNodeLinks(element); - - // Go over the neighbours and collect tho ones that also points back to the removed node. - for (auto neighbour_id : neighbors_copy) { - // Hold the neighbor's lock while we are going over its neighbors. - auto *neighbor = getGraphDataByInternalId(neighbour_id); - lockNodeLinks(neighbor); - ElementLevelData &neighbour_level_data = getElementLevelData(neighbor, level); - - for (size_t j = 0; j < neighbour_level_data.getNumLinks(); j++) { - // A bidirectional edge was found - this connection should be repaired. - if (neighbour_level_data.getLinkAtPos(j) == node_id) { - incoming_neighbors.emplace_back(neighbour_id, (unsigned short)level); - break; - } - } - unlockNodeLinks(neighbor); - } - - // Next, collect the rest of incoming edges (the ones that are not bidirectional) in the - // current level to repair them. 
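The greedySearchLevel loop above is the generic greedy-descent primitive HNSW uses on all
levels above 0: move to the closest neighbor of the current node until no neighbor
improves the distance. Stripped of locking, timeouts, and deletion handling, the control
flow looks like this (a sketch with hypothetical distance/neighbors callables, not the
index's API):

    #include <vector>

    using NodeId = unsigned int;

    template <typename DistFn, typename NeighborsFn>
    NodeId greedyDescend(NodeId start, DistFn distance, NeighborsFn neighbors) {
        NodeId best = start;
        double bestDist = distance(start);
        bool changed = true;
        while (changed) {
            changed = false;
            // neighbors() is assumed to return an iterable of NodeId, e.g. std::vector<NodeId>.
            for (NodeId cand : neighbors(best)) {
                double d = distance(cand);
                if (d < bestDist) { // strict improvement only, so the loop terminates
                    bestDist = d;
                    best = cand;
                    changed = true;
                }
            }
        }
        return best;
    }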
- lockNodeLinks(element); - for (auto incoming_edge : node_level_data.getIncomingEdges()) { - incoming_neighbors.emplace_back(incoming_edge, (unsigned short)level); - } - unlockNodeLinks(element); - } - return incoming_neighbors; -} - -template -void HNSWIndex::resizeIndexCommon(size_t new_max_elements) { - assert(new_max_elements % this->blockSize == 0 && - "new_max_elements must be a multiple of blockSize"); - this->log(VecSimCommonStrings::LOG_VERBOSE_STRING, "Resizing HNSW index from %zu to %zu", - idToMetaData.capacity(), new_max_elements); - resizeLabelLookup(new_max_elements); - visitedNodesHandlerPool.resize(new_max_elements); - assert(idToMetaData.capacity() == idToMetaData.size()); - idToMetaData.resize(new_max_elements); - idToMetaData.shrink_to_fit(); - assert(idToMetaData.capacity() == idToMetaData.size()); -} - -template -void HNSWIndex::growByBlock() { - assert(this->maxElements % this->blockSize == 0); - assert(this->maxElements == indexSize()); - assert(graphDataBlocks.size() == this->maxElements / this->blockSize); - assert(idToMetaData.capacity() == maxElements || - idToMetaData.capacity() == maxElements + this->blockSize); - - this->log(VecSimCommonStrings::LOG_VERBOSE_STRING, - "Updating HNSW index capacity from %zu to %zu", maxElements, - maxElements + this->blockSize); - maxElements += this->blockSize; - - graphDataBlocks.emplace_back(this->blockSize, this->elementGraphDataSize, this->allocator); - - if (idToMetaData.capacity() == indexSize()) { - resizeIndexCommon(maxElements); - } -} - -template -void HNSWIndex::shrinkByBlock() { - assert(this->maxElements >= this->blockSize); - assert(this->maxElements % this->blockSize == 0); - - if (indexSize() % this->blockSize == 0) { - this->log(VecSimCommonStrings::LOG_VERBOSE_STRING, - "Updating HNSW index capacity from %zu to %zu", maxElements, - maxElements - this->blockSize); - graphDataBlocks.pop_back(); - assert(graphDataBlocks.size() == indexSize() / this->blockSize); - - // assuming idToMetaData reflects the capacity of the heavy reallocation containers. - if (indexSize() == 0) { - resizeIndexCommon(0); - } else if (idToMetaData.capacity() >= (indexSize() + 2 * this->blockSize)) { - assert(this->maxElements + this->blockSize == idToMetaData.capacity()); - resizeIndexCommon(idToMetaData.capacity() - this->blockSize); - } - - // Take the lower bound into account. - maxElements -= this->blockSize; - } -} - -template -void HNSWIndex::mutuallyUpdateForRepairedNode( - idType node_id, size_t level, vecsim_stl::vector &nodes_to_update, - vecsim_stl::vector &chosen_neighbors, size_t max_M_cur) { - - // Acquire the required locks for the updates, after sorting the nodes to update - // (to avoid deadlocks) - nodes_to_update.push_back(node_id); - std::sort(nodes_to_update.begin(), nodes_to_update.end()); - size_t nodes_to_update_count = nodes_to_update.size(); - for (size_t i = 0; i < nodes_to_update_count; i++) { - lockNodeLinks(nodes_to_update[i]); - } - - ElementLevelData &node_level = getElementLevelData(node_id, level); - - // Perform mutual updates: go over the node's neighbors and overwrite the neighbors to remove - // that are still exist. - size_t node_neighbors_idx = 0; - for (size_t i = 0; i < node_level.getNumLinks(); i++) { - if (!std::binary_search(nodes_to_update.begin(), nodes_to_update.end(), - node_level.getLinkAtPos(i))) { - // The repaired node added a new neighbor that we didn't account for before in the - // meantime - leave it as is. 
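The growByBlock/shrinkByBlock pair above keeps the index capacity an exact multiple of
blockSize and deliberately leaves spare slack before shrinking, so a workload oscillating
around a block boundary does not reallocate on every add/delete. A toy model of that
policy (the struct and member names are illustrative, not the index's):

    #include <cstddef>

    struct BlockCapacity {
        size_t blockSize;
        size_t size = 0;
        size_t capacity = 0;

        void onAdd() {
            if (size == capacity)
                capacity += blockSize; // grow by exactly one block when full
            ++size;
        }

        void onRemove() {
            --size;
            // Shrink only at a block boundary, and only while enough slack
            // (two spare blocks' worth) would remain afterwards.
            if (size % blockSize == 0 && capacity >= size + 2 * blockSize)
                capacity -= blockSize;
        }
    };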
- node_level.setLinkAtPos(node_neighbors_idx++, node_level.getLinkAtPos(i)); - continue; - } - // Check if the current neighbor is in the chosen neighbors list, and remove it from there - // if so. - if (chosen_neighbors.remove(node_level.getLinkAtPos(i))) { - // A chosen neighbor is already connected to the node - leave it as is. - node_level.setLinkAtPos(node_neighbors_idx++, node_level.getLinkAtPos(i)); - continue; - } - // Now we know that we are looking at a neighbor that needs to be removed. - mutuallyRemoveNeighborAtPos(node_level, level, node_id, i); - } - - // Go over the chosen new neighbors that are not connected yet and perform updates. - for (auto chosen_id : chosen_neighbors) { - if (node_neighbors_idx == max_M_cur) { - // Cannot add more new neighbors, we reached the capacity. - this->log(VecSimCommonStrings::LOG_DEBUG_STRING, - "Couldn't add all the chosen new nodes upon updating %u, as we reached the" - " maximum number of neighbors per node", - node_id); - break; - } - // We don't add new neighbors for deleted nodes - if node_id is deleted we can finish. - // Also, don't add new neighbors to a node who is currently being indexed in parallel, as it - // may choose the same element as its neighbor right after the repair is done and connect it - // to it, and have a duplicate neighbor as a result. - if (isMarkedDeleted(node_id) || isInProcess(node_id)) { - break; - } - // If this specific new neighbor is deleted, we don't add this connection and continue. - // Also, don't add a new node whose being indexed in parallel, as it may choose this node - // as its neighbor and create a double connection (then this node will have a duplicate - // neighbor). - if (isMarkedDeleted(chosen_id) || isInProcess(chosen_id)) { - continue; - } - node_level.setLinkAtPos(node_neighbors_idx++, chosen_id); - // If the node is in the chosen new node incoming edges, there is a unidirectional - // connection from the chosen node to the repaired node that turns into bidirectional. Then, - // remove it from the incoming edges set. Otherwise, the edge is created unidirectional, so - // we add it to the unidirectional edges set. Note: we assume that all updates occur - // mutually and atomically, then can rely on this assumption. - auto &chosen_node_level_data = getElementLevelData(chosen_id, level); - if (!node_level.removeIncomingUnidirectionalEdgeIfExists(chosen_id)) { - chosen_node_level_data.newIncomingUnidirectionalEdge(node_id); - } - } - // Done updating the node's neighbors. - node_level.setNumLinks(node_neighbors_idx); - for (size_t i = 0; i < nodes_to_update_count; i++) { - unlockNodeLinks(nodes_to_update[i]); - } -} - -template -void HNSWIndex::repairNodeConnections(idType node_id, size_t level) { - - vecsim_stl::vector neighbors_candidate_ids(this->allocator); - // Use bitmaps for fast accesses: - // node_orig_neighbours_set is used to differentiate between the neighbors that will *not* be - // selected by the heuristics - only the ones that were originally neighbors should be removed. - vecsim_stl::vector node_orig_neighbours_set(maxElements, false, this->allocator); - // neighbors_candidates_set is used to store the nodes that were already collected as - // candidates, so we will not collect them again as candidates if we run into them from another - // path. 
- vecsim_stl::vector neighbors_candidates_set(maxElements, false, this->allocator); - vecsim_stl::vector deleted_neighbors(this->allocator); - - // Go over the repaired node neighbors, collect the non-deleted ones to be neighbors candidates - // after the repair as well. - auto *element = getGraphDataByInternalId(node_id); - lockNodeLinks(element); - ElementLevelData &node_level_data = getElementLevelData(element, level); - for (size_t j = 0; j < node_level_data.getNumLinks(); j++) { - node_orig_neighbours_set[node_level_data.getLinkAtPos(j)] = true; - // Don't add the removed element to the candidates. - if (isMarkedDeleted(node_level_data.getLinkAtPos(j))) { - deleted_neighbors.push_back(node_level_data.getLinkAtPos(j)); - continue; - } - neighbors_candidates_set[node_level_data.getLinkAtPos(j)] = true; - neighbors_candidate_ids.push_back(node_level_data.getLinkAtPos(j)); - } - unlockNodeLinks(element); - - // If there are not deleted neighbors at that point the repair job has already been made by - // another parallel job, and there is no need to repair the node anymore. - if (deleted_neighbors.empty()) { - return; - } - - // Hold 3 sets of nodes - all the original neighbors at that point to later (potentially) - // update, subset of these which are the chosen neighbors nodes, and a subset of the original - // neighbors that are going to be removed. - vecsim_stl::vector nodes_to_update(this->allocator); - vecsim_stl::vector chosen_neighbors(this->allocator); - - // Go over the deleted nodes and collect their neighbors to the candidates set. - for (idType deleted_neighbor_id : deleted_neighbors) { - nodes_to_update.push_back(deleted_neighbor_id); - - auto *neighbor = getGraphDataByInternalId(deleted_neighbor_id); - lockNodeLinks(neighbor); - ElementLevelData &neighbor_level_data = getElementLevelData(neighbor, level); - - for (size_t j = 0; j < neighbor_level_data.getNumLinks(); j++) { - // Don't add removed elements to the candidates, nor nodes that are already in the - // candidates set, nor the original node to repair itself. - if (isMarkedDeleted(neighbor_level_data.getLinkAtPos(j)) || - neighbors_candidates_set[neighbor_level_data.getLinkAtPos(j)] || - neighbor_level_data.getLinkAtPos(j) == node_id) { - continue; - } - neighbors_candidates_set[neighbor_level_data.getLinkAtPos(j)] = true; - neighbors_candidate_ids.push_back(neighbor_level_data.getLinkAtPos(j)); - } - unlockNodeLinks(neighbor); - } - - size_t max_M_cur = level ? M : M0; - if (neighbors_candidate_ids.size() > max_M_cur) { - // We have more candidates than the maximum number of neighbors, so we need to select which - // ones to keep. We use the heuristic to select the neighbors, and then remove the ones that - // were not originally neighbors. 
- candidatesList neighbors_candidates(this->allocator); - neighbors_candidates.reserve(neighbors_candidate_ids.size()); - const void *node_data = getDataByInternalId(node_id); - for (idType candidate : neighbors_candidate_ids) { - neighbors_candidates.emplace_back( - this->calcDistance(getDataByInternalId(candidate), node_data), candidate); - } - vecsim_stl::vector not_chosen_neighbors(this->allocator); - getNeighborsByHeuristic2(neighbors_candidates, max_M_cur, not_chosen_neighbors); - - for (idType not_chosen_neighbor : not_chosen_neighbors) { - if (node_orig_neighbours_set[not_chosen_neighbor]) { - nodes_to_update.push_back(not_chosen_neighbor); - } - } - - for (auto &neighbor : neighbors_candidates) { - chosen_neighbors.push_back(neighbor.second); - nodes_to_update.push_back(neighbor.second); - } - } else { - // We have less candidates than the maximum number of neighbors, so we choose them all, and - // extend the nodes to update with them. - chosen_neighbors.swap(neighbors_candidate_ids); - nodes_to_update.insert(nodes_to_update.end(), chosen_neighbors.begin(), - chosen_neighbors.end()); - } - - // Perform the actual updates for the node and the impacted neighbors while holding the nodes' - // locks. - mutuallyUpdateForRepairedNode(node_id, level, nodes_to_update, chosen_neighbors, max_M_cur); -} - -template -void HNSWIndex::mutuallyRemoveNeighborAtPos(ElementLevelData &node_level, - size_t level, idType node_id, - size_t pos) { - // Now we know that we are looking at a neighbor that needs to be removed. - auto removed_node = node_level.getLinkAtPos(pos); - ElementLevelData &removed_node_level = getElementLevelData(removed_node, level); - // Perform the mutual update: - // if the removed node id (the node's neighbour to be removed) - // wasn't pointing to the node (i.e., the edge was uni-directional), - // we should remove the current neighbor from the node's incoming edges. - // otherwise, the edge turned from bidirectional to uni-directional, so we insert it to the - // neighbour's incoming edges set. Note: we assume that every update is performed atomically - // mutually, so it should be sufficient to look at the removed node's incoming edges set - // alone. - if (!removed_node_level.removeIncomingUnidirectionalEdgeIfExists(node_id)) { - node_level.newIncomingUnidirectionalEdge(removed_node); - } -} - -template -void HNSWIndex::insertElementToGraph(idType element_id, - size_t element_max_level, - idType entry_point, - size_t global_max_level, - const void *vector_data) { - - idType curr_element = entry_point; - DistType cur_dist = std::numeric_limits::max(); - size_t max_common_level; - if (element_max_level < global_max_level) { - max_common_level = element_max_level; - cur_dist = this->calcDistance(vector_data, getDataByInternalId(curr_element)); - for (auto level = static_cast(global_max_level); - level > static_cast(element_max_level); level--) { - // this is done for the levels which are above the max level - // to which we are going to insert the new element. We do - // a greedy search in the graph starting from the entry point - // at each level, and move on with the closest element we can find. - // When there is no improvement to do, we take a step down. 
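mutuallyRemoveNeighborAtPos above maintains the index's central edge invariant: every
level stores, besides its outgoing links, the set of incoming edges that are *not*
reciprocated, and dropping a link either cancels such a record or creates one. A toy
model of that bookkeeping (std::set stands in for the index's edge containers; the names
are illustrative):

    #include <set>

    struct LevelEdges {
        std::set<unsigned> incoming_unidirectional;
    };

    // u drops its outgoing link to v.
    void onLinkRemoved(unsigned u, LevelEdges &u_edges, unsigned v, LevelEdges &v_edges) {
        if (v_edges.incoming_unidirectional.erase(u) == 0) {
            // The edge was bidirectional; v's surviving link back to u is now
            // one-way, so u records v as an incoming unidirectional edge.
            u_edges.incoming_unidirectional.insert(v);
        }
        // Otherwise u -> v was already one-way and its record is simply erased.
    }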
- greedySearchLevel(vector_data, level, curr_element, cur_dist); - } - } else { - max_common_level = global_max_level; - } - - for (auto level = static_cast(max_common_level); level >= 0; level--) { - candidatesMaxHeap top_candidates = - searchLayer(curr_element, vector_data, level, efConstruction); - // If the entry point was marked deleted between iterations, we may receive an empty - // candidates set. - if (!top_candidates.empty()) { - curr_element = mutuallyConnectNewElement(element_id, top_candidates, level); - } - } -} - -/** - * Ctor / Dtor - */ -/* typedef struct { - VecSimType type; // Datatype to index. - size_t dim; // Vector's dimension. - VecSimMetric metric; // Distance metric to use in the index. - size_t initialCapacity; // Deprecated and not respected. - size_t blockSize; - size_t M; - size_t efConstruction; - size_t efRuntime; - double epsilon; -} HNSWParams; */ -template -HNSWIndex::HNSWIndex(const HNSWParams *params, - const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components, - size_t random_seed) - : VecSimIndexAbstract(abstractInitParams, components), - VecSimIndexTombstone(), maxElements(0), graphDataBlocks(this->allocator), - idToMetaData(this->allocator), visitedNodesHandlerPool(0, this->allocator) { - - M = params->M ? params->M : HNSW_DEFAULT_M; - M0 = M * 2; - if (M0 > UINT16_MAX) - throw std::runtime_error("HNSW index parameter M is too large: argument overflow"); - - efConstruction = params->efConstruction ? params->efConstruction : HNSW_DEFAULT_EF_C; - efConstruction = std::max(efConstruction, M); - ef = params->efRuntime ? params->efRuntime : HNSW_DEFAULT_EF_RT; - epsilon = params->epsilon > 0.0 ? params->epsilon : HNSW_DEFAULT_EPSILON; - - curElementCount = 0; - numMarkedDeleted = 0; - - // initializations for special treatment of the first node - entrypointNode = INVALID_ID; - maxLevel = HNSW_INVALID_LEVEL; - - if (M <= 1) - throw std::runtime_error("HNSW index parameter M cannot be 1"); - mult = 1 / log(1.0 * M); - levelGenerator.seed(random_seed); - - elementGraphDataSize = sizeof(ElementGraphData) + sizeof(idType) * M0; - levelDataSize = sizeof(ElementLevelData) + sizeof(idType) * M; -} - -template -HNSWIndex::~HNSWIndex() { - for (idType id = 0; id < curElementCount; id++) { - getGraphDataByInternalId(id)->destroy(this->levelDataSize, this->allocator); - } -} - -/** - * Index API functions - */ - -template -void HNSWIndex::removeAndSwap(idType internalId) { - // Sanity check - the id to remove cannot be the entry point, as it should have been replaced - // upon marking it as deleted. - assert(entrypointNode != internalId); - auto element = getGraphDataByInternalId(internalId); - - // Remove the deleted id form the relevant incoming edges sets in which it appears. - for (size_t level = 0; level <= element->toplevel; level++) { - ElementLevelData &cur_level = getElementLevelData(element, level); - for (size_t i = 0; i < cur_level.getNumLinks(); i++) { - ElementLevelData &neighbour = getElementLevelData(cur_level.getLinkAtPos(i), level); - // Note that in case of in-place delete, we might have not accounted for this edge in - // in the unidirectional edges, since there is no point in keeping it there temporarily - // (we know we will get here and remove this deleted id permanently). - // However, upon asynchronous delete, this should always succeed since we do update - // the incoming edges in the mutual update even for deleted elements. 
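The constructor's mult = 1 / log(M) above feeds the standard HNSW level draw: with U
uniform in (0, 1), floor(-log(U) * mult) yields a geometric-like distribution in which
each level is roughly M times sparser than the one below. A sketch of the draw (assumed
to match what getRandomLevel does; the guard against U == 0 is defensive):

    #include <cmath>
    #include <limits>
    #include <random>

    size_t randomLevel(std::mt19937 &gen, double mult) {
        std::uniform_real_distribution<double> dist(0.0, 1.0);
        double u = dist(gen);
        if (u <= 0.0)
            u = std::numeric_limits<double>::min(); // avoid log(0)
        return static_cast<size_t>(-std::log(u) * mult);
    }

For M = 16 this puts about 93.75% of the elements at level 0 and roughly 5.9% at level 1,
with each further level another factor of M rarer.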
-            bool res = neighbour.removeIncomingUnidirectionalEdgeIfExists(internalId);
-            // Assert the logical condition of: is_marked_deleted(id) => res == true.
-            (void)res;
-            assert((!isMarkedDeleted(internalId) || res) && "The edge should be in the incoming "
-                                                            "unidirectional edges");
-        }
-    }
-
-    // Free the element's resources.
-    element->destroy(this->levelDataSize, this->allocator);
-
-    // We can now say that the element has been removed completely from the index.
-    --curElementCount;
-
-    // Get the last element's metadata and data.
-    // If we are deleting the last element, we already destroyed its metadata.
-    auto *last_element_data = getDataByInternalId(curElementCount);
-    DataBlock &last_gd_block = graphDataBlocks.back();
-    auto last_element = (ElementGraphData *)last_gd_block.removeAndFetchLastElement();
-
-    // Swap the last id with the deleted one, and invalidate the last id data.
-    if (curElementCount != internalId) {
-        SwapLastIdWithDeletedId(internalId, last_element, last_element_data);
-    }
-
-    // If we need to free a complete block and there is at least one block between the
-    // capacity and the size.
-    this->vectors->removeElement(curElementCount);
-    shrinkByBlock();
-}
-
-template <typename DataType, typename DistType>
-void HNSWIndex<DataType, DistType>::removeAndSwapMarkDeletedElement(idType internalId) {
-    removeAndSwap(internalId);
-    // The element is permanently removed from the index; it is no longer counted as marked
-    // deleted.
-    --numMarkedDeleted;
-}
-
-template <typename DataType, typename DistType>
-void HNSWIndex<DataType, DistType>::removeVectorInPlace(const idType element_internal_id) {
-
-    vecsim_stl::vector<bool> neighbours_bitmap(this->allocator);
-
-    // Go over the element's nodes at every level and repair the affected connections.
-    auto element = getGraphDataByInternalId(element_internal_id);
-    for (size_t level = 0; level <= element->toplevel; level++) {
-        ElementLevelData &cur_level = getElementLevelData(element, level);
-        // Reset the neighbours' bitmap for the current level.
-        neighbours_bitmap.assign(curElementCount, false);
-        // Store the deleted element's neighbours set in a bitmap for fast access.
-        for (size_t j = 0; j < cur_level.getNumLinks(); j++) {
-            neighbours_bitmap[cur_level.getLinkAtPos(j)] = true;
-        }
-        // Go over the neighbours that also point back to the removed point and make a local
-        // repair.
-        for (size_t i = 0; i < cur_level.getNumLinks(); i++) {
-            idType neighbour_id = cur_level.getLinkAtPos(i);
-            ElementLevelData &neighbor_level = getElementLevelData(neighbour_id, level);
-
-            bool bidirectional_edge = false;
-            for (size_t j = 0; j < neighbor_level.getNumLinks(); j++) {
-                // If the edge is bidirectional, do a repair for this neighbor.
-                if (neighbor_level.getLinkAtPos(j) == element_internal_id) {
-                    bidirectional_edge = true;
-                    repairConnectionsForDeletion(element_internal_id, neighbour_id, cur_level,
-                                                 neighbor_level, level, neighbours_bitmap);
-                    break;
-                }
-            }
-
-            // If this edge is unidirectional, we should remove the element from the neighbor's
-            // incoming edges.
-            if (!bidirectional_edge) {
-                // This should always return true (the removal should succeed).
-                bool res =
-                    neighbor_level.removeIncomingUnidirectionalEdgeIfExists(element_internal_id);
-                (void)res;
-                assert(res && "The edge should be in the incoming unidirectional edges");
-            }
-        }
-
-        // Next, go over the rest of the incoming edges (the ones that are not bidirectional) and
-        // make repairs.
-        for (auto incoming_edge : cur_level.getIncomingEdges()) {
-            repairConnectionsForDeletion(element_internal_id, incoming_edge, cur_level,
-                                         getElementLevelData(incoming_edge, level), level,
-                                         neighbours_bitmap);
-        }
-    }
-    if (entrypointNode == element_internal_id) {
-        // Replace the entry point if needed.
-        assert(element->toplevel == maxLevel);
-        replaceEntryPoint();
-    }
-    // Finally, remove the element from the index and make a swap with the last internal id to
-    // avoid fragmentation and reclaim memory when needed.
-    removeAndSwap(element_internal_id);
-}
-
-// Store the new element in the global data structures and keep the new state. In a multithreaded
-// scenario, the index data guard should be held by the caller (exclusive lock).
-template <typename DataType, typename DistType>
-HNSWAddVectorState HNSWIndex<DataType, DistType>::storeNewElement(labelType label,
-                                                                  const void *vector_data) {
-    if (isCapacityFull()) {
-        growByBlock();
-    }
-    HNSWAddVectorState state{};
-
-    // Randomly choose the maximum level in which the new element will be in the index.
-    state.elementMaxLevel = getRandomLevel(mult);
-
-    // Access and update the index global data structures with the new element meta-data.
-    state.newElementId = curElementCount++;
-
-    // Create the new element's graph metadata.
-    // We must manually allocate enough memory, and not just declare an `ElementGraphData`
-    // variable, since it has a flexible array member.
-    auto tmpData = this->allocator->allocate_unique(this->elementGraphDataSize);
-    memset(tmpData.get(), 0, this->elementGraphDataSize);
-    ElementGraphData *cur_egd = (ElementGraphData *)(tmpData.get());
-    // Allocate memory (inside the `ElementGraphData` constructor) for the links in higher levels
-    // and initialize this memory to zeros. The reason for doing it here is that we might mark this
-    // vector as deleted BEFORE we finish its indexing. In that case, we will collect the incoming
-    // edges to this element in every level, and try to access its link lists in higher levels.
-    // Therefore, we allocate it here and initialize it with zeros (otherwise we might crash...).
-    try {
-        new (cur_egd) ElementGraphData(state.elementMaxLevel, levelDataSize, this->allocator);
-    } catch (std::runtime_error &e) {
-        this->log(VecSimCommonStrings::LOG_WARNING_STRING,
-                  "Error - allocating memory for new element failed due to low memory");
-        throw e;
-    }
-
-    // Insert the new element to the data block.
-    this->vectors->addElement(vector_data, state.newElementId);
-    this->graphDataBlocks.back().addElement(cur_egd);
-    // We mark the id as in process *before* we set it in the label lookup, so that the IN_PROCESS
-    // flag is already set when checking whether the label exists.
-    this->idToMetaData[state.newElementId] = ElementMetaData(label);
-    setVectorId(label, state.newElementId);
-
-    state.currMaxLevel = (int)maxLevel;
-    state.currEntryPoint = entrypointNode;
-    if (state.elementMaxLevel > state.currMaxLevel) {
-        if (entrypointNode == INVALID_ID && maxLevel != HNSW_INVALID_LEVEL) {
-            throw std::runtime_error("Internal error - inserting the first element to the graph,"
-                                     " but the current max level is not INVALID");
-        }
-        // If the new element's max level is higher than the maximum level that currently exists
-        // in the graph, update the max level and set the new element as the entry point.
- entrypointNode = state.newElementId; - maxLevel = state.elementMaxLevel; - } - return state; -} - -template -HNSWAddVectorState HNSWIndex::storeVector(const void *vector_data, - const labelType label) { - HNSWAddVectorState state{}; - - this->lockIndexDataGuard(); - state = storeNewElement(label, vector_data); - if (state.currMaxLevel >= state.elementMaxLevel) { - this->unlockIndexDataGuard(); - } - - return state; -} - -template -void HNSWIndex::indexVector(const void *vector_data, const labelType label, - const HNSWAddVectorState &state) { - // Deconstruct the state variables from the auxiliaryCtx. prev_entry_point and prev_max_level - // are the entry point and index max level at the point of time when the element was stored, and - // they may (or may not) have changed due to the insertion. - auto [new_element_id, element_max_level, prev_entry_point, prev_max_level] = state; - - // This condition only means that we are not inserting the first (non-deleted) element (for the - // first element we do nothing - we don't need to connect to it). - if (prev_entry_point != INVALID_ID) { - // Start scanning the graph from the current entry point. - insertElementToGraph(new_element_id, element_max_level, prev_entry_point, prev_max_level, - vector_data); - } - unmarkInProcess(new_element_id); -} - -template -void HNSWIndex::appendVector(const void *vector_data, const labelType label) { - - ProcessedBlobs processedBlobs = this->preprocess(vector_data); - HNSWAddVectorState state = this->storeVector(processedBlobs.getStorageBlob(), label); - - this->indexVector(processedBlobs.getQueryBlob(), label, state); - - if (state.currMaxLevel < state.elementMaxLevel) { - // No external auxiliaryCtx, so it's this function responsibility to release the lock. - this->unlockIndexDataGuard(); - } -} - -template -auto HNSWIndex::safeGetEntryPointState() const { - std::shared_lock lock(indexDataGuard); - return std::make_pair(entrypointNode, maxLevel); -} - -template -idType HNSWIndex::searchBottomLayerEP(const void *query_data, void *timeoutCtx, - VecSimQueryReply_Code *rc) const { - *rc = VecSim_QueryReply_OK; - - auto [curr_element, max_level] = safeGetEntryPointState(); - if (curr_element == INVALID_ID) - return curr_element; // index is empty. 
- - DistType cur_dist = this->calcDistance(query_data, getDataByInternalId(curr_element)); - for (size_t level = max_level; level > 0 && curr_element != INVALID_ID; --level) { - greedySearchLevel(query_data, level, curr_element, cur_dist, timeoutCtx, rc); - } - return curr_element; -} - -template -candidatesLabelsMaxHeap * -HNSWIndex::searchBottomLayer_WithTimeout(idType ep_id, const void *data_point, - size_t ef, size_t k, void *timeoutCtx, - VecSimQueryReply_Code *rc) const { - - auto *visited_nodes_handler = getVisitedList(); - tag_t visited_tag = visited_nodes_handler->getFreshTag(); - - candidatesLabelsMaxHeap *top_candidates = getNewMaxPriorityQueue(); - candidatesMaxHeap candidate_set(this->allocator); - - DistType lowerBound; - if (!isMarkedDeleted(ep_id)) { - // If ep is not marked as deleted, get its distance and set lower bound and heaps - // accordingly - DistType dist = this->calcDistance(data_point, getDataByInternalId(ep_id)); - lowerBound = dist; - top_candidates->emplace(dist, getExternalLabel(ep_id)); - candidate_set.emplace(-dist, ep_id); - } else { - // If ep is marked as deleted, set initial lower bound to max, and don't insert to top - // candidates heap - lowerBound = std::numeric_limits::max(); - candidate_set.emplace(-lowerBound, ep_id); - } - - visited_nodes_handler->tagNode(ep_id, visited_tag); - - while (!candidate_set.empty()) { - pair curr_el_pair = candidate_set.top(); - - if ((-curr_el_pair.first) > lowerBound && top_candidates->size() >= ef) { - break; - } - if (VECSIM_TIMEOUT(timeoutCtx)) { - returnVisitedList(visited_nodes_handler); - *rc = VecSim_QueryReply_TimedOut; - return top_candidates; - } - candidate_set.pop(); - - processCandidate(curr_el_pair.second, data_point, 0, ef, - visited_nodes_handler->getElementsTags(), visited_tag, *top_candidates, - candidate_set, lowerBound); - } - returnVisitedList(visited_nodes_handler); - while (top_candidates->size() > k) { - top_candidates->pop(); - } - *rc = VecSim_QueryReply_OK; - return top_candidates; -} - -template -VecSimQueryReply *HNSWIndex::topKQuery(const void *query_data, size_t k, - VecSimQueryParams *queryParams) const { - - auto rep = new VecSimQueryReply(this->allocator); - this->lastMode = STANDARD_KNN; - - if (curElementCount == 0 || k == 0) { - return rep; - } - - auto processed_query_ptr = this->preprocessQuery(query_data); - const void *processed_query = processed_query_ptr.get(); - void *timeoutCtx = nullptr; - - // Get original efRuntime and store it. - size_t query_ef = this->ef; - - if (queryParams) { - timeoutCtx = queryParams->timeoutCtx; - if (queryParams->hnswRuntimeParams.efRuntime != 0) { - query_ef = queryParams->hnswRuntimeParams.efRuntime; - } - } - - idType bottom_layer_ep = searchBottomLayerEP(processed_query, timeoutCtx, &rep->code); - if (VecSim_OK != rep->code || bottom_layer_ep == INVALID_ID) { - // Although we checked that the index is not empty (curElementCount == 0), it might be - // that another thread deleted all the elements or didn't finish inserting the first element - // yet. Anyway, we observed that the index is empty, so we return an empty result list. 
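Throughout the scan above, top_candidates is a max-heap ordered by distance, so the worst
of the best ef results sits at the top: once the heap is full, a new candidate only
enters if it beats that worst element, and the final loop trims the heap from ef down to
the k results actually requested. The core pattern in isolation (a sketch; Result, offer,
and trimToK are illustrative names, not the index's API):

    #include <cstddef>
    #include <queue>
    #include <utility>

    using Result = std::pair<float, unsigned>; // (distance, label)

    void offer(std::priority_queue<Result> &heap, Result r, size_t ef) {
        if (heap.size() < ef) {
            heap.push(r);
        } else if (r.first < heap.top().first) {
            heap.pop(); // evict the current worst result
            heap.push(r);
        }
    }

    void trimToK(std::priority_queue<Result> &heap, size_t k) {
        while (heap.size() > k)
            heap.pop();
    }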
-        return rep;
-    }
-
-    // We now own the results heap, so we need to free (delete) it when we are done.
-    candidatesLabelsMaxHeap<DistType> *results;
-    results = searchBottomLayer_WithTimeout(bottom_layer_ep, processed_query, std::max(query_ef, k),
-                                            k, timeoutCtx, &rep->code);
-
-    if (VecSim_OK == rep->code) {
-        rep->results.resize(results->size());
-        for (auto result = rep->results.rbegin(); result != rep->results.rend(); result++) {
-            std::tie(result->score, result->id) = results->top();
-            results->pop();
-        }
-    }
-    delete results;
-    return rep;
-}
-
-template <typename DataType, typename DistType>
-VecSimQueryResultContainer HNSWIndex<DataType, DistType>::searchRangeBottomLayer_WithTimeout(
-    idType ep_id, const void *data_point, double epsilon, DistType radius, void *timeoutCtx,
-    VecSimQueryReply_Code *rc) const {
-
-    *rc = VecSim_QueryReply_OK;
-    auto res_container = getNewResultsContainer(10); // arbitrary initial cap.
-
-    auto *visited_nodes_handler = getVisitedList();
-    tag_t visited_tag = visited_nodes_handler->getFreshTag();
-
-    candidatesMaxHeap<DistType> candidate_set(this->allocator);
-
-    // Set the initial effective range to be at least the distance from the entry point.
-    DistType ep_dist, dynamic_range, dynamic_range_search_boundaries;
-    if (isMarkedDeleted(ep_id)) {
-        // If the ep is marked as deleted, set the initial ranges to max.
-        ep_dist = std::numeric_limits<DistType>::max();
-        dynamic_range_search_boundaries = dynamic_range = ep_dist;
-    } else {
-        // If the ep is not marked as deleted, get its distance and set the ranges accordingly.
-        ep_dist = this->calcDistance(data_point, getDataByInternalId(ep_id));
-        dynamic_range = ep_dist;
-        if (ep_dist <= radius) {
-            // The entry point is within the radius - add it to the results.
-            res_container->emplace(getExternalLabel(ep_id), ep_dist);
-            dynamic_range = radius; // to ensure that dyn_range >= radius.
-        }
-        dynamic_range_search_boundaries = dynamic_range * (1.0 + epsilon);
-    }
-
-    candidate_set.emplace(-ep_dist, ep_id);
-    visited_nodes_handler->tagNode(ep_id, visited_tag);
-
-    while (!candidate_set.empty()) {
-        pair<DistType, idType> curr_el_pair = candidate_set.top();
-        // If the best candidate is outside the dynamic range by more than epsilon (relatively) -
-        // we finish the search.
-
-        if ((-curr_el_pair.first) > dynamic_range_search_boundaries) {
-            break;
-        }
-        if (VECSIM_TIMEOUT(timeoutCtx)) {
-            *rc = VecSim_QueryReply_TimedOut;
-            break;
-        }
-        candidate_set.pop();
-
-        // Decrease the effective range, but keep dyn_range >= radius.
-        if (-curr_el_pair.first < dynamic_range && -curr_el_pair.first >= radius) {
-            dynamic_range = -curr_el_pair.first;
-            dynamic_range_search_boundaries = dynamic_range * (1.0 + epsilon);
-        }
-
-        // Go over the candidate's neighbours: add them to the candidates list if they are within
-        // the epsilon environment of the dynamic range, and add them to the results if they are
-        // within the requested radius.
-        // Here we send the radius as a double to match the function arguments type.
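The dynamic-range bookkeeping above can be followed with a small worked example: with
radius = 0.5 and epsilon = 0.01, an entry point at distance 0.8 starts the search with a
boundary of 0.808, and a visited candidate at distance 0.6 tightens it to 0.606; the
range never drops below the radius itself.

    double radius = 0.5, epsilon = 0.01;
    double dynamic_range = 0.8;                        // distance of the entry point
    double boundary = dynamic_range * (1.0 + epsilon); // 0.808
    double cand = 0.6;                                 // closer candidate, still outside the radius
    if (cand < dynamic_range && cand >= radius) {
        dynamic_range = cand;                          // shrink, but never below radius
        boundary = dynamic_range * (1.0 + epsilon);    // 0.606
    }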
- processCandidate_RangeSearch( - curr_el_pair.second, data_point, 0, epsilon, visited_nodes_handler->getElementsTags(), - visited_tag, res_container, candidate_set, dynamic_range_search_boundaries, radius); - } - returnVisitedList(visited_nodes_handler); - return res_container->get_results(); -} - -template -VecSimQueryReply *HNSWIndex::rangeQuery(const void *query_data, double radius, - VecSimQueryParams *queryParams) const { - - auto rep = new VecSimQueryReply(this->allocator); - this->lastMode = RANGE_QUERY; - - if (curElementCount == 0) { - return rep; - } - auto processed_query_ptr = this->preprocessQuery(query_data); - const void *processed_query = processed_query_ptr.get(); - void *timeoutCtx = nullptr; - - double query_epsilon = epsilon; - if (queryParams) { - timeoutCtx = queryParams->timeoutCtx; - if (queryParams->hnswRuntimeParams.epsilon != 0.0) { - query_epsilon = queryParams->hnswRuntimeParams.epsilon; - } - } - - idType bottom_layer_ep = searchBottomLayerEP(processed_query, timeoutCtx, &rep->code); - // Although we checked that the index is not empty (curElementCount == 0), it might be - // that another thread deleted all the elements or didn't finish inserting the first element - // yet. Anyway, we observed that the index is empty, so we return an empty result list. - if (VecSim_OK != rep->code || bottom_layer_ep == INVALID_ID) { - return rep; - } - - // search bottom layer - // Here we send the radius as double to match the function arguments type. - rep->results = searchRangeBottomLayer_WithTimeout( - bottom_layer_ep, processed_query, query_epsilon, radius, timeoutCtx, &rep->code); - return rep; -} - -template -VecSimIndexDebugInfo HNSWIndex::debugInfo() const { - - VecSimIndexDebugInfo info; - info.commonInfo = this->getCommonInfo(); - auto [ep_id, max_level] = this->safeGetEntryPointState(); - - info.commonInfo.basicInfo.algo = VecSimAlgo_HNSWLIB; - info.hnswInfo.M = this->getM(); - info.hnswInfo.efConstruction = this->getEfConstruction(); - info.hnswInfo.efRuntime = this->getEf(); - info.hnswInfo.epsilon = this->epsilon; - info.hnswInfo.max_level = max_level; - info.hnswInfo.entrypoint = ep_id != INVALID_ID ? getExternalLabel(ep_id) : INVALID_LABEL; - info.hnswInfo.visitedNodesPoolSize = this->visitedNodesHandlerPool.getPoolSize(); - info.hnswInfo.numberOfMarkedDeletedNodes = this->getNumMarkedDeleted(); - return info; -} - -template -VecSimIndexBasicInfo HNSWIndex::basicInfo() const { - VecSimIndexBasicInfo info = this->getBasicInfo(); - info.algo = VecSimAlgo_HNSWLIB; - info.isTiered = false; - return info; -} - -template -VecSimDebugInfoIterator *HNSWIndex::debugInfoIterator() const { - VecSimIndexDebugInfo info = this->debugInfo(); - // For readability. Update this number when needed. 
- size_t numberOfInfoFields = 17; - auto *infoIterator = new VecSimDebugInfoIterator(numberOfInfoFields, this->allocator); - - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::ALGORITHM_STRING, - .fieldType = INFOFIELD_STRING, - .fieldValue = {FieldValue{ - .stringValue = VecSimAlgo_ToString(info.commonInfo.basicInfo.algo)}}}); - - this->addCommonInfoToIterator(infoIterator, info.commonInfo); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::BLOCK_SIZE_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.commonInfo.basicInfo.blockSize}}}); - - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::HNSW_M_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.hnswInfo.M}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::HNSW_EF_CONSTRUCTION_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.hnswInfo.efConstruction}}}); - - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::HNSW_EF_RUNTIME_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.hnswInfo.efRuntime}}}); - - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::HNSW_MAX_LEVEL, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.hnswInfo.max_level}}}); - - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::HNSW_ENTRYPOINT, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.hnswInfo.entrypoint}}}); - - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::EPSILON_STRING, - .fieldType = INFOFIELD_FLOAT64, - .fieldValue = {FieldValue{.floatingPointValue = info.hnswInfo.epsilon}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::NUM_MARKED_DELETED, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.hnswInfo.numberOfMarkedDeletedNodes}}}); - - return infoIterator; -} - -template -bool HNSWIndex::preferAdHocSearch(size_t subsetSize, size_t k, - bool initial_check) const { - // This heuristic is based on sklearn decision tree classifier (with 20 leaves nodes) - - // see scripts/HNSW_batches_clf.py - size_t index_size = this->indexSize(); - // Referring to too large subset size as if it was the maximum possible size. - subsetSize = std::min(subsetSize, index_size); - - size_t d = this->dim; - size_t M = this->getM(); - float r = (index_size == 0) ? 
0.0f : (float)(subsetSize) / (float)this->indexLabelCount(); - bool res; - - // node 0 - if (index_size <= 30000) { - // node 1 - if (index_size <= 5500) { - // node 5 - res = true; - } else { - // node 6 - if (r <= 0.17) { - // node 11 - res = true; - } else { - // node 12 - if (k <= 12) { - // node 13 - if (d <= 55) { - // node 17 - res = false; - } else { - // node 18 - if (M <= 10) { - // node 19 - res = false; - } else { - // node 20 - res = true; - } - } - } else { - // node 14 - res = true; - } - } - } - } else { - // node 2 - if (r < 0.07) { - // node 3 - if (index_size <= 750000) { - // node 15 - res = true; - } else { - // node 16 - if (k <= 7) { - // node 21 - res = false; - } else { - // node 22 - if (r <= 0.03) { - // node 23 - res = true; - } else { - // node 24 - res = false; - } - } - } - } else { - // node 4 - if (d <= 75) { - // node 7 - res = false; - } else { - // node 8 - if (k <= 12) { - // node 9 - if (r <= 0.21) { - // node 27 - if (M <= 57) { - // node 29 - if (index_size <= 75000) { - // node 31 - res = true; - } else { - // node 32 - res = false; - } - } else { - // node 30 - res = true; - } - } else { - // node 28 - res = false; - } - } else { - // node 10 - if (M <= 10) { - // node 25 - if (r <= 0.17) { - // node 33 - res = true; - } else { - // node 34 - res = false; - } - } else { - // node 26 - if (index_size <= 300000) { - // node 35 - res = true; - } else { - // node 36 - if (r <= 0.17) { - // node 37 - res = true; - } else { - // node 38 - res = false; - } - } - } - } - } - } - } - // Set the mode - if this isn't the initial check, we switched mode form batches to ad-hoc. - this->lastMode = - res ? (initial_check ? HYBRID_ADHOC_BF : HYBRID_BATCHES_TO_ADHOC_BF) : HYBRID_BATCHES; - return res; -} - -/********************************************** Debug commands ******************************/ - -template -VecSimDebugCommandCode -HNSWIndex::getHNSWElementNeighbors(size_t label, int ***neighborsData) { - std::shared_lock lock(indexDataGuard); - // Assume single value index. TODO: support for multi as well. - if (this->isMultiValue()) { - return VecSimDebugCommandCode_MultiNotSupported; - } - auto ids = this->getElementIds(label); - if (ids.empty()) { - return VecSimDebugCommandCode_LabelNotExists; - } - idType id = ids[0]; - auto graph_data = this->getGraphDataByInternalId(id); - lockNodeLinks(graph_data); - *neighborsData = new int *[graph_data->toplevel + 2]; - for (size_t level = 0; level <= graph_data->toplevel; level++) { - auto &level_data = this->getElementLevelData(graph_data, level); - assert(level_data.getNumLinks() <= (level > 0 ? this->getM() : 2 * this->getM())); - (*neighborsData)[level] = new int[level_data.getNumLinks() + 1]; - (*neighborsData)[level][0] = level_data.getNumLinks(); - for (size_t i = 0; i < level_data.getNumLinks(); i++) { - (*neighborsData)[level][i + 1] = (int)idToMetaData.at(level_data.getLinkAtPos(i)).label; - } - } - (*neighborsData)[graph_data->toplevel + 1] = nullptr; - unlockNodeLinks(graph_data); - return VecSimDebugCommandCode_OK; -} - -#ifdef BUILD_TESTS -#include "hnsw_serializer_impl.h" -#endif diff --git a/src/VecSim/algorithms/hnsw/hnsw_base_tests_friends.h b/src/VecSim/algorithms/hnsw/hnsw_base_tests_friends.h deleted file mode 100644 index cf289dffa..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_base_tests_friends.h +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. 
- * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/friend_test_decl.h" -INDEX_TEST_FRIEND_CLASS(HNSWTest_test_dynamic_hnsw_info_iterator_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTest_preferAdHocOptimization_Test) -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_test_dynamic_hnsw_info_iterator_Test) -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_preferAdHocOptimization_Test) -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_MultiBatchIteratorHeapLogic_Test) -INDEX_TEST_FRIEND_CLASS(IndexAllocatorTest_testIncomingEdgesSet_Test) -INDEX_TEST_FRIEND_CLASS(IndexAllocatorTest_test_hnsw_reclaim_memory_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTest_markDelete_Test) -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_markDelete_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTest_allMarkedDeletedLevel_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTestParallel) -INDEX_TEST_FRIEND_CLASS(HNSWTestParallel_parallelSearchKnn_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTestParallel_parallelSearchCombined_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteFromHNSWMultiLevels_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_deleteFromHNSWWithRepairJobExec_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_swapJobBasic_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteInplaceAvoidUpdatedMarkedDeleted_Test) diff --git a/src/VecSim/algorithms/hnsw/hnsw_batch_iterator.h b/src/VecSim/algorithms/hnsw/hnsw_batch_iterator.h deleted file mode 100644 index e99642868..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_batch_iterator.h +++ /dev/null @@ -1,267 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/batch_iterator.h" -#include "hnsw.h" -#include "VecSim/spaces/spaces.h" -#include "VecSim/query_result_definitions.h" -#include "VecSim/utils/vec_utils.h" -#include "VecSim/algorithms/hnsw/visited_nodes_handler.h" -#include - -using spaces::dist_func_t; - -template -class HNSW_BatchIterator : public VecSimBatchIterator { -protected: - const HNSWIndex *index; - VisitedNodesHandler *visited_list; // Pointer to the hnsw visitedList structure. - tag_t visited_tag; // Used to mark nodes that were scanned. - idType entry_point; // Internal id of the node to begin the scan from. - bool depleted; - size_t ef; // EF Runtime value for this query. - - // Data structure that holds the search state between iterations. 
- template - using candidatesMinHeap = vecsim_stl::min_priority_queue; - - DistType lower_bound; - candidatesMinHeap top_candidates_extras; - candidatesMinHeap candidates; - - VecSimQueryReply_Code scanGraphInternal(candidatesLabelsMaxHeap *top_candidates); - candidatesLabelsMaxHeap *scanGraph(VecSimQueryReply_Code *rc); - virtual inline void prepareResults(VecSimQueryReply *rep, - candidatesLabelsMaxHeap *top_candidates, - size_t n_res) = 0; - inline void visitNode(idType node_id) { - this->visited_list->tagNode(node_id, this->visited_tag); - } - inline bool hasVisitedNode(idType node_id) const { - return this->visited_list->getNodeTag(node_id) == this->visited_tag; - } - - virtual inline void fillFromExtras(candidatesLabelsMaxHeap *top_candidates) = 0; - virtual inline void updateHeaps(candidatesLabelsMaxHeap *top_candidates, - DistType dist, idType id) = 0; - -public: - HNSW_BatchIterator(void *query_vector, const HNSWIndex *index, - VecSimQueryParams *queryParams, std::shared_ptr allocator); - - VecSimQueryReply *getNextResults(size_t n_res, VecSimQueryReply_Order order) override; - - bool isDepleted() override; - - void reset() override; - - virtual ~HNSW_BatchIterator() { index->returnVisitedList(this->visited_list); } -}; - -/******************** Ctor / Dtor **************/ - -template -HNSW_BatchIterator::HNSW_BatchIterator( - void *query_vector, const HNSWIndex *index, VecSimQueryParams *queryParams, - std::shared_ptr allocator) - : VecSimBatchIterator(query_vector, queryParams ? queryParams->timeoutCtx : nullptr, - std::move(allocator)), - index(index), depleted(false), top_candidates_extras(this->allocator), - candidates(this->allocator) { - - this->entry_point = INVALID_ID; // temporary until we store the entry point to level 0. - // Use "fresh" tag to mark nodes that were visited along the search in some iteration. - this->visited_list = index->getVisitedList(); - this->visited_tag = this->visited_list->getFreshTag(); - - if (queryParams && queryParams->hnswRuntimeParams.efRuntime > 0) { - this->ef = queryParams->hnswRuntimeParams.efRuntime; - } else { - this->ef = this->index->getEf(); - } -} - -/******************** Implementation **************/ - -template -VecSimQueryReply_Code HNSW_BatchIterator::scanGraphInternal( - candidatesLabelsMaxHeap *top_candidates) { - while (!candidates.empty()) { - DistType curr_node_dist = candidates.top().first; - idType curr_node_id = candidates.top().second; - - __builtin_prefetch(this->index->getGraphDataByInternalId(curr_node_id)); - __builtin_prefetch(this->index->getMetaDataAddress(curr_node_id)); - // If the closest element in the candidates set is further than the furthest element in the - // top candidates set, and we have enough results, we finish the search. - if (curr_node_dist > this->lower_bound && top_candidates->size() >= this->ef) { - break; - } - if (VECSIM_TIMEOUT(this->getTimeoutCtx())) { - return VecSim_QueryReply_TimedOut; - } - // Checks if we need to add the current id to the top_candidates heap, - // and updates the extras heap accordingly. - if (!index->isMarkedDeleted(curr_node_id)) - updateHeaps(top_candidates, curr_node_dist, curr_node_id); - - // Take the current node out of the candidates queue and go over his neighbours. 
- candidates.pop(); - auto *node_graph_data = this->index->getGraphDataByInternalId(curr_node_id); - this->index->lockNodeLinks(node_graph_data); - ElementLevelData &node_level_data = this->index->getElementLevelData(node_graph_data, 0); - if (node_level_data.numLinks > 0) { - - // Pre-fetch first candidate tag address. - __builtin_prefetch(visited_list->getElementsTags() + node_level_data.links[0]); - // Pre-fetch first candidate data block address. - __builtin_prefetch(index->getDataByInternalId(node_level_data.links[0])); - - for (linkListSize j = 0; j < node_level_data.numLinks - 1; j++) { - idType candidate_id = node_level_data.links[j]; - - // Pre-fetch next candidate tag address. - __builtin_prefetch(visited_list->getElementsTags() + node_level_data.links[j + 1]); - // Pre-fetch next candidate data block address. - __builtin_prefetch(index->getDataByInternalId(node_level_data.links[j + 1])); - - if (this->hasVisitedNode(candidate_id)) { - continue; - } - this->visitNode(candidate_id); - - const char *candidate_data = this->index->getDataByInternalId(candidate_id); - DistType candidate_dist = - this->index->calcDistance(this->getQueryBlob(), (const void *)candidate_data); - - candidates.emplace(candidate_dist, candidate_id); - } - // Run the last candidate outside the loop to avoid prefetching an invalid candidate. - idType candidate_id = node_level_data.links[node_level_data.numLinks - 1]; - - if (!this->hasVisitedNode(candidate_id)) { - this->visitNode(candidate_id); - - const char *candidate_data = this->index->getDataByInternalId(candidate_id); - DistType candidate_dist = - this->index->calcDistance(this->getQueryBlob(), (const void *)candidate_data); - - candidates.emplace(candidate_dist, candidate_id); - } - } - this->index->unlockNodeLinks(curr_node_id); - } - return VecSim_QueryReply_OK; -} - -template <typename DataType, typename DistType> -candidatesLabelsMaxHeap<DistType> * -HNSW_BatchIterator<DataType, DistType>::scanGraph(VecSimQueryReply_Code *rc) { - - candidatesLabelsMaxHeap<DistType> *top_candidates = this->index->getNewMaxPriorityQueue(); - if (this->entry_point == INVALID_ID) { - this->depleted = true; - return top_candidates; - } - - // In the first iteration, add the entry point to the empty candidates set. - if (this->getResultsCount() == 0 && this->top_candidates_extras.empty() && - this->candidates.empty()) { - if (!index->isMarkedDeleted(this->entry_point)) { - this->lower_bound = this->index->calcDistance( - this->getQueryBlob(), this->index->getDataByInternalId(this->entry_point)); - } else { - this->lower_bound = std::numeric_limits<DistType>::max(); - } - this->visitNode(this->entry_point); - candidates.emplace(this->lower_bound, this->entry_point); - } - // Check that we didn't get a timeout between iterations. - if (VECSIM_TIMEOUT(this->getTimeoutCtx())) { - *rc = VecSim_QueryReply_TimedOut; - return top_candidates; - } - - // Move extras from the previous iteration to the top candidates. - fillFromExtras(top_candidates); - if (top_candidates->size() == this->ef) { - return top_candidates; - } - *rc = this->scanGraphInternal(top_candidates); - - // If we found fewer results than wanted, mark the search as depleted. - if (top_candidates->size() < this->ef) { - this->depleted = true; - } - return top_candidates; -} - -template <typename DataType, typename DistType> -VecSimQueryReply * -HNSW_BatchIterator<DataType, DistType>::getNextResults(size_t n_res, VecSimQueryReply_Order order) { - - auto batch = new VecSimQueryReply(this->allocator); - // If ef_runtime is lower than the number of results to return, increase it.
Therefore, we assume - // that the number of results that return from the graph scan is at least n_res (if exist). - size_t orig_ef = this->ef; - if (orig_ef < n_res) { - this->ef = n_res; - } - - // In the first iteration, we search the graph from top bottom to find the initial entry point, - // and then we scan the graph to get results (layer 0). - if (this->getResultsCount() == 0) { - idType bottom_layer_ep = this->index->searchBottomLayerEP( - this->getQueryBlob(), this->getTimeoutCtx(), &batch->code); - if (VecSim_OK != batch->code) { - return batch; - } - this->entry_point = bottom_layer_ep; - } - // We ask for at least n_res candidate from the scan. In fact, at most ef results will return, - // and it could be that ef > n_res. - auto *top_candidates = this->scanGraph(&batch->code); - if (VecSim_OK != batch->code) { - delete top_candidates; - return batch; - } - // Move the spare results to the "extras" queue if needed, and create the batch results array. - this->prepareResults(batch, top_candidates, n_res); - delete top_candidates; - - this->updateResultsCount(VecSimQueryReply_Len(batch)); - if (this->getResultsCount() == this->index->indexLabelCount()) { - this->depleted = true; - } - // By default, results are ordered by score. - if (order == BY_ID) { - sort_results_by_id(batch); - } - this->ef = orig_ef; - return batch; -} - -template -bool HNSW_BatchIterator::isDepleted() { - return this->depleted && this->top_candidates_extras.empty(); -} - -template -void HNSW_BatchIterator::reset() { - this->resetResultsCount(); - this->depleted = false; - - // Reset the visited nodes handler. - this->visited_tag = this->visited_list->getFreshTag(); - this->lower_bound = std::numeric_limits::infinity(); - // Clear the queues. - this->candidates = candidatesMinHeap(this->allocator); - this->top_candidates_extras = candidatesMinHeap(this->allocator); -} diff --git a/src/VecSim/algorithms/hnsw/hnsw_multi.h b/src/VecSim/algorithms/hnsw/hnsw_multi.h deleted file mode 100644 index 50ff1a37d..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_multi.h +++ /dev/null @@ -1,247 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "hnsw.h" -#include "hnsw_multi_batch_iterator.h" -#include "VecSim/utils/updatable_heap.h" - -template -class HNSWIndex_Multi : public HNSWIndex { -private: - // Index global state - this should be guarded by the indexDataGuard lock in - // multithreaded scenario. - vecsim_stl::unordered_map> labelLookup; - -#ifdef BUILD_TESTS -#include "VecSim/algorithms/hnsw/hnsw_multi_tests_friends.h" -#endif - - inline void replaceIdOfLabel(labelType label, idType new_id, idType old_id) override; - inline void setVectorId(labelType label, idType id) override { - // Checking if an element with the given label already exists. - // if not, add an empty vector under the new label. 
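
This lookup is the defining difference of the multi-value index: one external label owns a vector of internal ids, and insertion appends instead of overwriting. The same semantics with standard containers (hypothetical sketch, not the index's actual types):

    #include <cstdio>
    #include <unordered_map>
    #include <vector>

    using labelType = unsigned long;
    using idType = unsigned int;

    std::unordered_map<labelType, std::vector<idType>> labelToIds;

    void setVectorId(labelType label, idType id) {
        // operator[] default-constructs an empty vector on first use, which is
        // exactly the "add an empty vector under the new label" step above.
        labelToIds[label].push_back(id);
    }

    int main() {
        setVectorId(7, 0);
        setVectorId(7, 1); // Same label, second internal id.
        std::printf("label 7 owns %zu ids\n", labelToIds[7].size());
    }
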
- if (labelLookup.find(label) == labelLookup.end()) { - labelLookup.emplace(label, vecsim_stl::vector{this->allocator}); - } - labelLookup.at(label).push_back(id); - } - inline vecsim_stl::vector getElementIds(size_t label) override { - auto it = labelLookup.find(label); - if (it == labelLookup.end()) { - return vecsim_stl::vector{this->allocator}; // return an empty collection - } - return it->second; - } - inline void resizeLabelLookup(size_t new_max_elements) override; - - // Return all the labels in the index - this should be used for computing the number of distinct - // labels in a tiered index. - inline vecsim_stl::set getLabelsSet() const override { - std::shared_lock index_data_lock(this->indexDataGuard); - vecsim_stl::set keys(this->allocator); - for (auto &it : labelLookup) { - keys.insert(it.first); - } - return keys; - }; - - inline double getDistanceFromInternal(labelType label, const void *vector_data) const; - -public: - HNSWIndex_Multi(const HNSWParams *params, const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components, size_t random_seed = 100) - : HNSWIndex(params, abstractInitParams, components, random_seed), - labelLookup(this->allocator) {} -#ifdef BUILD_TESTS - // Ctor to be used before loading a serialized index. Can be used from v2 and up. - HNSWIndex_Multi(std::ifstream &input, const HNSWParams *params, - const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components, - HNSWSerializer::EncodingVersion version) - : HNSWIndex(input, params, abstractInitParams, components, version), - labelLookup(this->maxElements, this->allocator) {} - - void getDataByLabel(labelType label, - std::vector> &vectors_output) const override { - - auto ids = labelLookup.find(label); - - for (idType id : ids->second) { - auto vec = std::vector(this->dim); - // Only copy the vector data (dim * sizeof(DataType)), not any additional metadata like - // the norm - memcpy(vec.data(), this->getDataByInternalId(id), this->dim * sizeof(DataType)); - vectors_output.push_back(vec); - } - } - - std::vector> getStoredVectorDataByLabel(labelType label) const override { - std::vector> vectors_output; - auto ids = labelLookup.find(label); - - for (idType id : ids->second) { - const char *data = this->getDataByInternalId(id); - - // Create a vector with the full data (including any metadata like norms) - std::vector vec(this->getStoredDataSize()); - memcpy(vec.data(), data, this->getStoredDataSize()); - vectors_output.push_back(std::move(vec)); - } - - return vectors_output; - } -#endif - ~HNSWIndex_Multi() = default; - - inline candidatesLabelsMaxHeap *getNewMaxPriorityQueue() const override { - return new (this->allocator) - vecsim_stl::updatable_max_heap(this->allocator); - } - inline std::unique_ptr - getNewResultsContainer(size_t cap) const override { - return std::unique_ptr( - new (this->allocator) vecsim_stl::unique_results_container(cap, this->allocator)); - } - - inline size_t indexLabelCount() const override; - VecSimBatchIterator *newBatchIterator(const void *queryBlob, - VecSimQueryParams *queryParams) const override; - - int deleteVector(labelType label) override; - int addVector(const void *vector_data, labelType label) override; - vecsim_stl::vector markDelete(labelType label) override; - double getDistanceFrom_Unsafe(labelType label, const void *vector_data) const override { - return getDistanceFromInternal(label, vector_data); - } - int removeLabel(labelType label) override { return labelLookup.erase(label); } -}; - -/** - * getters and 
setters of index data - */ - -template -size_t HNSWIndex_Multi::indexLabelCount() const { - return labelLookup.size(); -} - -/** - * helper functions - */ - -template -double HNSWIndex_Multi::getDistanceFromInternal(labelType label, - const void *vector_data) const { - DistType dist = INVALID_SCORE; - - // Check if the label exists in the index, return invalid score if not. - auto it = this->labelLookup.find(label); - if (it == this->labelLookup.end()) { - return dist; - } - - // Get the vector of ids associated with the label. - // Get a copy if `Safe` is true, otherwise get a reference. - auto &IDs = it->second; - - // Iterate over the ids and find the minimum distance. - for (auto id : IDs) { - DistType d = this->calcDistance(this->getDataByInternalId(id), vector_data); - dist = std::fmin(dist, d); - } - - return dist; -} - -template -void HNSWIndex_Multi::replaceIdOfLabel(labelType label, idType new_id, - idType old_id) { - assert(labelLookup.find(label) != labelLookup.end()); - // *Non-trivial code here* - in every iteration we replace the internal id of the previous last - // id that has been swapped with the deleted id. Note that if the old and the new replaced ids - // both belong to the same label, then we are going to delete the new id later on as well, since - // we are currently iterating on this exact array of ids in 'deleteVector'. Hence, the relevant - // part of the vector that should be updated is the "tail" that comes after the position of - // old_id, while the "head" may contain old occurrences of old_id that are irrelevant for the - // future deletions. Therefore, we iterate from end to beginning. For example, assuming we are - // deleting a label that contains the only 3 ids that exist in the index. Hence, we would - // expect the following scenario w.r.t. the ids array: - // [|1, 0, 2] -> [1, |0, 1] -> [1, 0, |0] (where | marks the current position) - auto &ids = labelLookup.at(label); - for (int i = ids.size() - 1; i >= 0; i--) { - if (ids[i] == old_id) { - ids[i] = new_id; - return; - } - } - assert(!"should have found the old id"); -} - -template -void HNSWIndex_Multi::resizeLabelLookup(size_t new_max_elements) { - labelLookup.reserve(new_max_elements); -} - -/** - * Index API functions - */ - -template -int HNSWIndex_Multi::deleteVector(const labelType label) { - int ret = 0; - // check that the label actually exists in the graph, and update the number of elements. - auto ids_it = labelLookup.find(label); - if (ids_it == labelLookup.end()) { - return ret; - } - for (auto &ids = ids_it->second; idType id : ids) { - this->removeVectorInPlace(id); - ret++; - } - labelLookup.erase(label); - return ret; -} - -template -int HNSWIndex_Multi::addVector(const void *vector_data, const labelType label) { - - this->appendVector(vector_data, label); - return 1; // We always add the vector, no overrides in multi. -} - -template -VecSimBatchIterator * -HNSWIndex_Multi::newBatchIterator(const void *queryBlob, - VecSimQueryParams *queryParams) const { - // force_copy == true. - auto queryBlobCopy = this->preprocessQuery(queryBlob, true); - - // take ownership of the blob copy and pass it to the batch iterator. - auto *queryBlobCopyPtr = queryBlobCopy.release(); - // Ownership of queryBlobCopy moves to HNSW_BatchIterator that will free it at the end. - return new (this->allocator) HNSWMulti_BatchIterator( - queryBlobCopyPtr, this, queryParams, this->allocator); -} - -/** - * Marks an element with the given label deleted, does NOT really change the current graph. 
- * @param label - */ -template -vecsim_stl::vector HNSWIndex_Multi::markDelete(labelType label) { - std::unique_lock index_data_lock(this->indexDataGuard); - - auto ids = this->getElementIds(label); - for (idType id : ids) { - this->markDeletedInternal(id); - } - labelLookup.erase(label); - return ids; -} diff --git a/src/VecSim/algorithms/hnsw/hnsw_multi_batch_iterator.h b/src/VecSim/algorithms/hnsw/hnsw_multi_batch_iterator.h deleted file mode 100644 index 8459d4234..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_multi_batch_iterator.h +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "hnsw_batch_iterator.h" - -template -class HNSWMulti_BatchIterator : public HNSW_BatchIterator { -private: - vecsim_stl::unordered_set returned; - - inline void fillFromExtras(candidatesLabelsMaxHeap *top_candidates) override; - inline void prepareResults(VecSimQueryReply *rep, - candidatesLabelsMaxHeap *top_candidates, - size_t n_res) override; - inline void updateHeaps(candidatesLabelsMaxHeap *top_candidates, DistType dist, - idType id) override; - -public: - HNSWMulti_BatchIterator(void *query_vector, const HNSWIndex *index, - VecSimQueryParams *queryParams, - std::shared_ptr allocator) - : HNSW_BatchIterator(query_vector, index, queryParams, allocator), - returned(this->index->indexSize(), this->allocator) {} - - ~HNSWMulti_BatchIterator() override = default; - - void reset() override; -}; - -/******************** Implementation **************/ - -template -void HNSWMulti_BatchIterator::prepareResults( - VecSimQueryReply *rep, candidatesLabelsMaxHeap *top_candidates, size_t n_res) { - - // Put the "spare" results (if exist) in the extra candidates heap. - while (top_candidates->size() > n_res) { - this->top_candidates_extras.emplace(top_candidates->top().first, - top_candidates->top().second); // (distance, label) - top_candidates->pop(); - } - // Return results from the top candidates heap, put them in reverse order in the batch results - // array. - rep->results.resize(top_candidates->size()); - for (auto result = rep->results.rbegin(); result != rep->results.rend(); ++result) { - std::tie(result->score, result->id) = top_candidates->top(); - this->returned.insert(result->id); - top_candidates->pop(); - } -} - -template -void HNSWMulti_BatchIterator::fillFromExtras( - candidatesLabelsMaxHeap *top_candidates) { - while (top_candidates->size() < this->ef && !this->top_candidates_extras.empty()) { - if (returned.find(this->top_candidates_extras.top().second) == returned.end()) { - top_candidates->emplace(this->top_candidates_extras.top().first, - this->top_candidates_extras.top().second); - } - this->top_candidates_extras.pop(); - } -} - -template -void HNSWMulti_BatchIterator::updateHeaps( - candidatesLabelsMaxHeap *top_candidates, DistType dist, idType id) { - - if (this->lower_bound > dist || top_candidates->size() < this->ef) { - labelType label = this->index->getExternalLabel(id); - if (returned.find(label) == returned.end()) { - top_candidates->emplace(dist, label); - if (top_candidates->size() > this->ef) { - // If the top candidates queue is full, pass the "worst" results to the "extras", - // for the next iterations. 
- this->top_candidates_extras.emplace(top_candidates->top().first, - top_candidates->top().second); - top_candidates->pop(); - } - this->lower_bound = top_candidates->top().first; - } - } -} - -template -void HNSWMulti_BatchIterator::reset() { - this->returned.clear(); - HNSW_BatchIterator::reset(); -} diff --git a/src/VecSim/algorithms/hnsw/hnsw_multi_tests_friends.h b/src/VecSim/algorithms/hnsw/hnsw_multi_tests_friends.h deleted file mode 100644 index dcc87d434..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_multi_tests_friends.h +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/friend_test_decl.h" -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_empty_index_Test) -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_search_more_than_there_is_Test) -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_indexing_same_vector_Test) -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_test_dynamic_hnsw_info_iterator_Test) -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_preferAdHocOptimization_Test) -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_testSizeEstimation_Test) -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_markDelete_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_testSizeEstimation_Test) -INDEX_TEST_FRIEND_CLASS(HNSWMultiTest_removeVectorWithSwaps_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_swapJobBasic_Test) diff --git a/src/VecSim/algorithms/hnsw/hnsw_serialization_utils.h b/src/VecSim/algorithms/hnsw/hnsw_serialization_utils.h deleted file mode 100644 index b807c2977..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_serialization_utils.h +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once - -#include - -#define HNSW_INVALID_META_DATA SIZE_MAX - -typedef struct { - bool valid_state; - long memory_usage; // in bytes - size_t double_connections; - size_t unidirectional_connections; - size_t min_in_degree; - size_t max_in_degree; - size_t connections_to_repair; -} HNSWIndexMetaData; diff --git a/src/VecSim/algorithms/hnsw/hnsw_serializer.cpp b/src/VecSim/algorithms/hnsw/hnsw_serializer.cpp deleted file mode 100644 index 4a0fac9c6..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_serializer.cpp +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ - -#include "hnsw_serializer.h" - -HNSWSerializer::HNSWSerializer(EncodingVersion version) : m_version(version) {} - -HNSWSerializer::EncodingVersion HNSWSerializer::ReadVersion(std::ifstream &input) { - input.seekg(0, std::ifstream::beg); - - EncodingVersion version = EncodingVersion::INVALID; - readBinaryPOD(input, version); - - if (version <= EncodingVersion::DEPRECATED) { - input.close(); - throw std::runtime_error("Cannot load index: deprecated encoding version: " + - std::to_string(static_cast<int>(version))); - } else if (version >= EncodingVersion::INVALID) { - input.close(); - throw std::runtime_error("Cannot load index: bad encoding version: " + - std::to_string(static_cast<int>(version))); - } - return version; -} - -void HNSWSerializer::saveIndex(const std::string &location) { - EncodingVersion version = EncodingVersion::V4; - std::ofstream output(location, std::ios::binary); - writeBinaryPOD(output, version); - saveIndexIMP(output); - output.close(); -} - -HNSWSerializer::EncodingVersion HNSWSerializer::getVersion() const { return m_version; } diff --git a/src/VecSim/algorithms/hnsw/hnsw_serializer.h b/src/VecSim/algorithms/hnsw/hnsw_serializer.h deleted file mode 100644 index af1ae2871..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_serializer.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once - -#include -#include -#include "VecSim/utils/serializer.h" - -// Middle layer for HNSW serialization -// Abstract functions should be implemented by the templated HNSW index - -class HNSWSerializer : public Serializer { -public: - enum class EncodingVersion { - DEPRECATED = 2, // Last deprecated version - V3, - V4, - INVALID - }; - - explicit HNSWSerializer(EncodingVersion version = EncodingVersion::V4); - - static EncodingVersion ReadVersion(std::ifstream &input); - - void saveIndex(const std::string &location); - - EncodingVersion getVersion() const; - -protected: - EncodingVersion m_version; - -private: - virtual void saveIndexFields(std::ofstream &output) const = 0; -}; diff --git a/src/VecSim/algorithms/hnsw/hnsw_serializer_declarations.h b/src/VecSim/algorithms/hnsw/hnsw_serializer_declarations.h deleted file mode 100644 index 9a86133a5..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_serializer_declarations.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once - -// Serialization and test functions. -public: -HNSWIndex(std::ifstream &input, const HNSWParams *params, - const AbstractIndexInitParams &abstractInitParams, - const IndexComponents<DataType, DistType> &components, - HNSWSerializer::EncodingVersion version); - -// Validates the connections between vectors. -HNSWIndexMetaData checkIntegrity() const; - -// Index memory size might change during index saving. -virtual void saveIndexIMP(std::ofstream &output) override; - -// Used by the index factory to load node connections. -void restoreGraph(std::ifstream &input, HNSWSerializer::EncodingVersion version); - -private: -// Functions for index saving.
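
ReadVersion above is the standard guard for an evolving binary format: the version tag is written first, read back before anything else, and rejected when it falls outside the supported window, so a stale file fails fast instead of being misparsed. The guard in isolation (simplified sketch, illustrative types):

    #include <fstream>
    #include <stdexcept>
    #include <string>

    enum class Version : int { Deprecated = 2, V3, V4, Invalid };

    template <typename T>
    void readPOD(std::ifstream &in, T &out) { // Mirrors readBinaryPOD.
        in.read(reinterpret_cast<char *>(&out), sizeof(T));
    }

    Version readVersion(std::ifstream &in) {
        in.seekg(0, std::ifstream::beg);
        Version v = Version::Invalid;
        readPOD(in, v);
        // Accept only versions strictly between Deprecated and Invalid.
        if (v <= Version::Deprecated || v >= Version::Invalid)
            throw std::runtime_error("unsupported encoding version: " +
                                     std::to_string(static_cast<int>(v)));
        return v;
    }
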
-void saveIndexFields(std::ofstream &output) const override; - -void saveGraph(std::ofstream &output) const; - -void saveLevel(std::ofstream &output, ElementLevelData &data) const; -void restoreLevel(std::ifstream &input, ElementLevelData &data, - HNSWSerializer::EncodingVersion version); -void computeIndegreeForAll(); - -// Functions for index loading. -void restoreIndexFields(std::ifstream &input); -void fieldsValidation() const; diff --git a/src/VecSim/algorithms/hnsw/hnsw_serializer_impl.h b/src/VecSim/algorithms/hnsw/hnsw_serializer_impl.h deleted file mode 100644 index 5f9dd8cbf..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_serializer_impl.h +++ /dev/null @@ -1,321 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once - -#include "hnsw_serializer.h" - -template -HNSWIndex::HNSWIndex(std::ifstream &input, const HNSWParams *params, - const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components, - HNSWSerializer::EncodingVersion version) - : VecSimIndexAbstract(abstractInitParams, components), - HNSWSerializer(version), epsilon(params->epsilon), graphDataBlocks(this->allocator), - idToMetaData(this->allocator), visitedNodesHandlerPool(0, this->allocator) { - - this->restoreIndexFields(input); - this->fieldsValidation(); - - // Since level generator is implementation-defined, we dont read its value from the file. - // We use seed = 200 and not the default value (100) to get different sequence of - // levels value than the loaded index. - levelGenerator.seed(200); - - // Set the initial capacity based on the number of elements in the loaded index. - maxElements = RoundUpInitialCapacity(this->curElementCount, this->blockSize); - this->idToMetaData.resize(maxElements); - this->visitedNodesHandlerPool.resize(maxElements); - - size_t initial_vector_size = maxElements / this->blockSize; - graphDataBlocks.reserve(initial_vector_size); -} - -template -void HNSWIndex::saveIndexIMP(std::ofstream &output) { - this->saveIndexFields(output); - this->saveGraph(output); -} - -template -void HNSWIndex::fieldsValidation() const { - if (this->M > UINT16_MAX / 2) - throw std::runtime_error("HNSW index parameter M is too large: argument overflow"); - if (this->M <= 1) - throw std::runtime_error("HNSW index parameter M cannot be 1 or 0"); -} - -template -HNSWIndexMetaData HNSWIndex::checkIntegrity() const { - HNSWIndexMetaData res = {.valid_state = false, - .memory_usage = -1, - .double_connections = HNSW_INVALID_META_DATA, - .unidirectional_connections = HNSW_INVALID_META_DATA, - .min_in_degree = HNSW_INVALID_META_DATA, - .max_in_degree = HNSW_INVALID_META_DATA, - .connections_to_repair = 0}; - - // Save the current memory usage (before we use additional memory for the integrity check). 
- res.memory_usage = this->getAllocationSize(); - size_t connections_checked = 0, double_connections = 0, num_deleted = 0, - min_in_degree = SIZE_MAX, max_in_degree = 0; - size_t max_level_in_graph = 0; // including marked deleted elements - for (size_t i = 0; i < this->curElementCount; i++) { - if (this->isMarkedDeleted(i)) { - num_deleted++; - } - if (getGraphDataByInternalId(i)->toplevel > max_level_in_graph) { - max_level_in_graph = getGraphDataByInternalId(i)->toplevel; - } - } - std::vector> inbound_connections_num( - this->curElementCount, std::vector(max_level_in_graph + 1, 0)); - size_t incoming_edges_sets_sizes = 0; - for (size_t i = 0; i < this->curElementCount; i++) { - for (size_t l = 0; l <= getGraphDataByInternalId(i)->toplevel; l++) { - ElementLevelData &cur = this->getElementLevelData(i, l); - std::set s; - for (unsigned int j = 0; j < cur.numLinks; j++) { - // Check if we found an invalid neighbor. - if (cur.links[j] >= this->curElementCount || cur.links[j] == i) { - return res; - } - // If the neighbor has deleted, then this connection should be repaired. - if (isMarkedDeleted(cur.links[j])) { - res.connections_to_repair++; - } - inbound_connections_num[cur.links[j]][l]++; - s.insert(cur.links[j]); - connections_checked++; - - // Check if this connection is bidirectional. - ElementLevelData &other = this->getElementLevelData(cur.links[j], l); - for (int r = 0; r < other.numLinks; r++) { - if (other.links[r] == (idType)i) { - double_connections++; - break; - } - } - } - // Check if a certain neighbor appeared more than once. - if (s.size() != cur.numLinks) { - return res; - } - incoming_edges_sets_sizes += cur.incomingUnidirectionalEdges->size(); - } - } - if (num_deleted != this->numMarkedDeleted) { - return res; - } - - // Validate that each node's in-degree is coherent with the in-degree observed by the - // outgoing edges. 
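
The invariant this function enforces: every outgoing edge is either half of a bidirectional pair or registered in the target's unidirectional-incoming set, so the two counts must sum to the total number of edges checked. The edge accounting on a plain adjacency list (toy sketch; the real check also compares against the incoming-edge sets the index maintains):

    #include <cstddef>
    #include <vector>

    // Classify each directed edge u->v as reciprocated (v->u exists) or one-way.
    void countEdgeKinds(const std::vector<std::vector<int>> &adj,
                        size_t &bidirectional, size_t &oneWay) {
        bidirectional = oneWay = 0;
        for (size_t u = 0; u < adj.size(); ++u) {
            for (int v : adj[u]) {
                bool reciprocated = false;
                for (int w : adj[(size_t)v]) {
                    if (w == (int)u) {
                        reciprocated = true;
                        break;
                    }
                }
                if (reciprocated) ++bidirectional; else ++oneWay;
            }
        }
    }
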
- for (size_t i = 0; i < this->curElementCount; i++) { - for (size_t l = 0; l <= getGraphDataByInternalId(i)->toplevel; l++) { - if (inbound_connections_num[i][l] > max_in_degree) { - max_in_degree = inbound_connections_num[i][l]; - } - if (inbound_connections_num[i][l] < min_in_degree) { - min_in_degree = inbound_connections_num[i][l]; - } - } - } - - res.double_connections = double_connections; - res.unidirectional_connections = incoming_edges_sets_sizes; - res.min_in_degree = min_in_degree; - res.max_in_degree = max_in_degree; - if (incoming_edges_sets_sizes + double_connections != connections_checked) { - return res; - } - - res.valid_state = true; - return res; -} - -template <typename DataType, typename DistType> -void HNSWIndex<DataType, DistType>::restoreIndexFields(std::ifstream &input) { - // Restore index build parameters - readBinaryPOD(input, this->M); - readBinaryPOD(input, this->M0); - readBinaryPOD(input, this->efConstruction); - - // Restore index search parameters - readBinaryPOD(input, this->ef); - readBinaryPOD(input, this->epsilon); - - // Restore index meta-data - this->elementGraphDataSize = sizeof(ElementGraphData) + sizeof(idType) * this->M0; - this->levelDataSize = sizeof(ElementLevelData) + sizeof(idType) * this->M; - readBinaryPOD(input, this->mult); - - // Restore index state - readBinaryPOD(input, this->curElementCount); - readBinaryPOD(input, this->numMarkedDeleted); - readBinaryPOD(input, this->maxLevel); - readBinaryPOD(input, this->entrypointNode); -} - -template <typename DataType, typename DistType> -void HNSWIndex<DataType, DistType>::restoreGraph(std::ifstream &input, - HNSWSerializer::EncodingVersion version) { - // Restore id to metadata vector - labelType label = 0; - elementFlags flags = 0; - for (idType id = 0; id < this->curElementCount; id++) { - readBinaryPOD(input, label); - readBinaryPOD(input, flags); - this->idToMetaData[id].label = label; - this->idToMetaData[id].flags = flags; - - // Restore label lookup by getting the label from data_level0_memory_ - setVectorId(label, id); - } - - // Todo: create vector data container and load the stored data based on the index storage params - // when other storage types become available. - dynamic_cast(this->vectors) - ->restoreBlocks(input, this->curElementCount, - static_cast(m_version)); - - // Get graph data blocks - ElementGraphData *cur_egt; - auto tmpData = this->getAllocator()->allocate_unique(this->elementGraphDataSize); - size_t toplevel = 0; - size_t num_blocks = dynamic_cast(this->vectors)->numBlocks(); - for (size_t i = 0; i < num_blocks; i++) { - this->graphDataBlocks.emplace_back(this->blockSize, this->elementGraphDataSize, - this->allocator); - unsigned int block_len = 0; - readBinaryPOD(input, block_len); - for (size_t j = 0; j < block_len; j++) { - // Reset tmpData - memset(tmpData.get(), 0, this->elementGraphDataSize); - // Read the current element's top level - readBinaryPOD(input, toplevel); - // Allocate space and structs for the current element - try { - new (tmpData.get()) - ElementGraphData(toplevel, this->levelDataSize, this->allocator); - } catch (std::runtime_error &e) { - this->log(VecSimCommonStrings::LOG_WARNING_STRING, - "Error - allocating memory for new element failed due to low memory"); - throw e; - } - // Add the current element to the current block, and update cur_egt to point to it.
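
restoreGraph reuses a single scratch allocation for every element it loads: clear the buffer, placement-new the element into it, then copy the result into block storage. That reuse pattern reduced to standard C++ (illustrative sketch):

    #include <cstddef>
    #include <cstring>
    #include <new>
    #include <vector>

    struct Header {
        size_t toplevel;
        explicit Header(size_t t) : toplevel(t) {}
    };

    // Build `count` headers through one reused scratch buffer, copying each
    // into contiguous storage, as restoreGraph does per graph element.
    std::vector<Header> restoreHeaders(const size_t *levels, size_t count) {
        std::vector<Header> storage;
        storage.reserve(count);
        alignas(Header) unsigned char scratch[sizeof(Header)];
        for (size_t i = 0; i < count; ++i) {
            std::memset(scratch, 0, sizeof(scratch));    // Reset the scratch space.
            Header *h = new (scratch) Header(levels[i]); // Construct in place.
            storage.push_back(*h);                       // Copy into block storage.
            h->~Header();                                // Scratch is reused next round.
        }
        return storage;
    }
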
- this->graphDataBlocks.back().addElement(tmpData.get()); - cur_egt = (ElementGraphData *)this->graphDataBlocks.back().getElement(j); - - // Restore the current element's graph data - for (size_t k = 0; k <= toplevel; k++) { - restoreLevel(input, getElementLevelData(cur_egt, k), version); - } - } - } -} - -template -void HNSWIndex::restoreLevel(std::ifstream &input, ElementLevelData &data, - HNSWSerializer::EncodingVersion version) { - readBinaryPOD(input, data.numLinks); - for (size_t i = 0; i < data.numLinks; i++) { - readBinaryPOD(input, data.links[i]); - } - - // Restore the incoming edges of the current element - unsigned int size; - readBinaryPOD(input, size); - data.incomingUnidirectionalEdges->reserve(size); - idType id = INVALID_ID; - for (size_t i = 0; i < size; i++) { - readBinaryPOD(input, id); - data.incomingUnidirectionalEdges->push_back(id); - } -} - -template -void HNSWIndex::saveIndexFields(std::ofstream &output) const { - // Save index type - writeBinaryPOD(output, VecSimAlgo_HNSWLIB); - - // Save VecSimIndex fields - writeBinaryPOD(output, this->dim); - writeBinaryPOD(output, this->vecType); - writeBinaryPOD(output, this->metric); - writeBinaryPOD(output, this->blockSize); - writeBinaryPOD(output, this->isMulti); - writeBinaryPOD(output, this->maxElements); // This will be used to restore the index initial - // capacity - - // Save index build parameters - writeBinaryPOD(output, this->M); - writeBinaryPOD(output, this->M0); - writeBinaryPOD(output, this->efConstruction); - - // Save index search parameter - writeBinaryPOD(output, this->ef); - writeBinaryPOD(output, this->epsilon); - - // Save index meta-data - writeBinaryPOD(output, this->mult); - - // Save index state - writeBinaryPOD(output, this->curElementCount); - writeBinaryPOD(output, this->numMarkedDeleted); - writeBinaryPOD(output, this->maxLevel); - writeBinaryPOD(output, this->entrypointNode); -} - -template -void HNSWIndex::saveGraph(std::ofstream &output) const { - // Save id to metadata vector - for (idType id = 0; id < this->curElementCount; id++) { - labelType label = this->idToMetaData[id].label; - elementFlags flags = this->idToMetaData[id].flags; - writeBinaryPOD(output, label); - writeBinaryPOD(output, flags); - } - - this->vectors->saveVectorsData(output); - - // Save graph data blocks - for (size_t i = 0; i < this->graphDataBlocks.size(); i++) { - auto &block = this->graphDataBlocks[i]; - unsigned int block_len = block.getLength(); - writeBinaryPOD(output, block_len); - for (size_t j = 0; j < block_len; j++) { - ElementGraphData *cur_element = (ElementGraphData *)block.getElement(j); - writeBinaryPOD(output, cur_element->toplevel); - - // Save all the levels of the current element - for (size_t level = 0; level <= cur_element->toplevel; level++) { - saveLevel(output, getElementLevelData(cur_element, level)); - } - } - } -} - -template -void HNSWIndex::saveLevel(std::ofstream &output, ElementLevelData &data) const { - // Save the links of the current element - writeBinaryPOD(output, data.numLinks); - for (size_t i = 0; i < data.numLinks; i++) { - writeBinaryPOD(output, data.links[i]); - } - - // Save the incoming edges of the current element - unsigned int size = data.incomingUnidirectionalEdges->size(); - writeBinaryPOD(output, size); - for (idType id : *data.incomingUnidirectionalEdges) { - writeBinaryPOD(output, id); - } - - // Shrink the incoming edges vector for integrity check - data.incomingUnidirectionalEdges->shrink_to_fit(); -} diff --git a/src/VecSim/algorithms/hnsw/hnsw_single.h 
b/src/VecSim/algorithms/hnsw/hnsw_single.h deleted file mode 100644 index 61899a142..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_single.h +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "hnsw.h" -#include "hnsw_single_batch_iterator.h" - -template -class HNSWIndex_Single : public HNSWIndex { -private: - // Index global state - this should be guarded by the indexDataGuard lock in - // multithreaded scenario. - vecsim_stl::unordered_map labelLookup; - -#ifdef BUILD_TESTS -#include "VecSim/algorithms/hnsw/hnsw_single_tests_friends.h" -#endif - - inline void replaceIdOfLabel(labelType label, idType new_id, idType old_id) override; - inline void setVectorId(labelType label, idType id) override { labelLookup[label] = id; } - inline void resizeLabelLookup(size_t new_max_elements) override; - inline vecsim_stl::set getLabelsSet() const override; - inline vecsim_stl::vector getElementIds(size_t label) override; - inline double getDistanceFromInternal(labelType label, const void *vector_data) const; - -public: - HNSWIndex_Single(const HNSWParams *params, const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components, - size_t random_seed = 100) - : HNSWIndex(params, abstractInitParams, components, random_seed), - labelLookup(this->allocator) {} -#ifdef BUILD_TESTS - // Ctor to be used before loading a serialized index. Can be used from v2 and up. - HNSWIndex_Single(std::ifstream &input, const HNSWParams *params, - const AbstractIndexInitParams &abstractInitParams, - const IndexComponents &components, - HNSWSerializer::EncodingVersion version) - : HNSWIndex(input, params, abstractInitParams, components, version), - labelLookup(this->maxElements, this->allocator) {} - - void getDataByLabel(labelType label, - std::vector> &vectors_output) const override { - - auto id = labelLookup.at(label); - - auto vec = std::vector(this->dim); - // Only copy the vector data (dim * sizeof(DataType)), not any additional metadata like the - // norm - memcpy(vec.data(), this->getDataByInternalId(id), this->dim * sizeof(DataType)); - vectors_output.push_back(vec); - } - - std::vector> getStoredVectorDataByLabel(labelType label) const override { - std::vector> vectors_output; - auto id = labelLookup.at(label); - const char *data = this->getDataByInternalId(id); - - // Create a vector with the full data (including any metadata like norms) - std::vector vec(this->getStoredDataSize()); - memcpy(vec.data(), data, this->getStoredDataSize()); - vectors_output.push_back(std::move(vec)); - - return vectors_output; - } -#endif - ~HNSWIndex_Single() = default; - - candidatesLabelsMaxHeap *getNewMaxPriorityQueue() const override { - return new (this->allocator) - vecsim_stl::max_priority_queue(this->allocator); - } - std::unique_ptr - getNewResultsContainer(size_t cap) const override { - return std::unique_ptr( - new (this->allocator) vecsim_stl::default_results_container(cap, this->allocator)); - } - size_t indexLabelCount() const override; - VecSimBatchIterator *newBatchIterator(const void *queryBlob, - VecSimQueryParams *queryParams) const override; - - int deleteVector(labelType label) override; - int addVector(const void *vector_data, labelType label) override; - vecsim_stl::vector markDelete(labelType label) 
override; - double getDistanceFrom_Unsafe(labelType label, const void *vector_data) const override { - return getDistanceFromInternal(label, vector_data); - } - int removeLabel(labelType label) override { return labelLookup.erase(label); } -}; - -/** - * getters and setters of index data - */ - -template -size_t HNSWIndex_Single::indexLabelCount() const { - return labelLookup.size(); -} - -/** - * helper functions - */ - -// Return all the labels in the index - this should be used for computing the number of distinct -// labels in a tiered index. -template -vecsim_stl::set HNSWIndex_Single::getLabelsSet() const { - std::shared_lock index_data_lock(this->indexDataGuard); - vecsim_stl::set keys(this->allocator); - for (auto &it : labelLookup) { - keys.insert(it.first); - } - return keys; -} - -template -double -HNSWIndex_Single::getDistanceFromInternal(labelType label, - const void *vector_data) const { - - auto it = labelLookup.find(label); - if (it == labelLookup.end()) { - return INVALID_SCORE; - } - idType id = it->second; - - return this->calcDistance(vector_data, this->getDataByInternalId(id)); -} - -template -void HNSWIndex_Single::replaceIdOfLabel(labelType label, idType new_id, - idType old_id) { - labelLookup[label] = new_id; -} - -template -void HNSWIndex_Single::resizeLabelLookup(size_t new_max_elements) { - labelLookup.reserve(new_max_elements); -} - -/** - * Index API functions - */ - -template -int HNSWIndex_Single::deleteVector(const labelType label) { - // Check that the label actually exists in the graph, and update the number of elements. - if (labelLookup.find(label) == labelLookup.end()) { - return 0; - } - idType element_internal_id = labelLookup[label]; - labelLookup.erase(label); - this->removeVectorInPlace(element_internal_id); - return 1; -} - -template -int HNSWIndex_Single::addVector(const void *vector_data, - const labelType label) { - // Checking if an element with the given label already exists. - bool label_exists = labelLookup.find(label) != labelLookup.end(); - if (label_exists) { - // Remove the vector in place if override allowed (in non-async scenario). - deleteVector(label); - } - - this->appendVector(vector_data, label); - return label_exists ? 0 : 1; -} - -template -VecSimBatchIterator * -HNSWIndex_Single::newBatchIterator(const void *queryBlob, - VecSimQueryParams *queryParams) const { - // force_copy == true. - auto queryBlobCopy = this->preprocessQuery(queryBlob, true); - - // take ownership of the blob copy and pass it to the batch iterator. - auto *queryBlobCopyPtr = queryBlobCopy.release(); - // Ownership of queryBlobCopy moves to HNSW_BatchIterator that will free it at the end. - return new (this->allocator) HNSWSingle_BatchIterator( - queryBlobCopyPtr, this, queryParams, this->allocator); -} - -/** - * Marks an element with the given label deleted, does NOT really change the current graph. 
- * @param label - */ -template -vecsim_stl::vector HNSWIndex_Single::markDelete(labelType label) { - std::unique_lock index_data_lock(this->indexDataGuard); - auto internal_ids = this->getElementIds(label); - if (!internal_ids.empty()) { - assert(internal_ids.size() == 1); // expect to have only one id in index of type "single" - this->markDeletedInternal(internal_ids[0]); - labelLookup.erase(label); - } - return internal_ids; -} - -template -inline vecsim_stl::vector -HNSWIndex_Single::getElementIds(size_t label) { - vecsim_stl::vector ids(this->allocator); - auto it = labelLookup.find(label); - if (it == labelLookup.end()) { - return ids; - } - ids.push_back(it->second); - return ids; -} diff --git a/src/VecSim/algorithms/hnsw/hnsw_single_batch_iterator.h b/src/VecSim/algorithms/hnsw/hnsw_single_batch_iterator.h deleted file mode 100644 index ded88ee8e..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_single_batch_iterator.h +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "hnsw_batch_iterator.h" - -template -class HNSWSingle_BatchIterator : public HNSW_BatchIterator { -private: - inline void fillFromExtras(candidatesLabelsMaxHeap *top_candidates) override; - inline void prepareResults(VecSimQueryReply *rep, - candidatesLabelsMaxHeap *top_candidates, - size_t n_res) override; - - inline void updateHeaps(candidatesLabelsMaxHeap *top_candidates, DistType dist, - idType id) override; - -public: - HNSWSingle_BatchIterator(void *query_vector, const HNSWIndex *index, - VecSimQueryParams *queryParams, - std::shared_ptr allocator) - : HNSW_BatchIterator(query_vector, index, queryParams, allocator) {} - - ~HNSWSingle_BatchIterator() override = default; -}; - -/******************** Implementation **************/ - -template -void HNSWSingle_BatchIterator::prepareResults( - VecSimQueryReply *rep, candidatesLabelsMaxHeap *top_candidates, size_t n_res) { - - // Put the "spare" results (if exist) in the extra candidates heap. - while (top_candidates->size() > n_res) { - this->top_candidates_extras.emplace(top_candidates->top()); // (distance, label) - top_candidates->pop(); - } - // Return results from the top candidates heap, put them in reverse order in the batch results - // array. 
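
Popping a max-heap yields worst-first order, so prepareResults writes through a reverse iterator and ends up with best-first output without an extra sort. The trick in isolation (standalone sketch):

    #include <queue>
    #include <vector>

    // Drain a max-heap into a vector so that out[0] holds the smallest value.
    std::vector<int> drainAscending(std::priority_queue<int> heap) {
        std::vector<int> out(heap.size());
        for (auto it = out.rbegin(); it != out.rend(); ++it) {
            *it = heap.top(); // Largest values land at the back.
            heap.pop();
        }
        return out;
    }
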
- rep->results.resize(top_candidates->size()); - for (auto result = rep->results.rbegin(); result != rep->results.rend(); ++result) { - std::tie(result->score, result->id) = top_candidates->top(); - top_candidates->pop(); - } -} - -template -void HNSWSingle_BatchIterator::fillFromExtras( - candidatesLabelsMaxHeap *top_candidates) { - while (top_candidates->size() < this->ef && !this->top_candidates_extras.empty()) { - top_candidates->emplace(this->top_candidates_extras.top().first, - this->top_candidates_extras.top().second); - this->top_candidates_extras.pop(); - } -} - -template -void HNSWSingle_BatchIterator::updateHeaps( - candidatesLabelsMaxHeap *top_candidates, DistType dist, idType id) { - if (top_candidates->size() < this->ef) { - top_candidates->emplace(dist, this->index->getExternalLabel(id)); - this->lower_bound = top_candidates->top().first; - } else if (this->lower_bound > dist) { - top_candidates->emplace(dist, this->index->getExternalLabel(id)); - // If the top candidates queue is full, pass the "worst" results to the "extras", - // for the next iterations. - this->top_candidates_extras.emplace(top_candidates->top().first, - top_candidates->top().second); - top_candidates->pop(); - this->lower_bound = top_candidates->top().first; - } -} diff --git a/src/VecSim/algorithms/hnsw/hnsw_single_tests_friends.h b/src/VecSim/algorithms/hnsw/hnsw_single_tests_friends.h deleted file mode 100644 index 9c27adead..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_single_tests_friends.h +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/friend_test_decl.h" -INDEX_TEST_FRIEND_CLASS(HNSWTest_test_dynamic_hnsw_info_iterator_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTest_preferAdHocOptimization_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTest_testSizeEstimation_Test) -INDEX_TEST_FRIEND_CLASS(IndexAllocatorTest_testIncomingEdgesSet_Test) -INDEX_TEST_FRIEND_CLASS(IndexAllocatorTest_test_hnsw_reclaim_memory_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTestParallel_parallelInsertSearch_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_testSizeEstimation_Test) -INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_swapJobBasic_Test) -friend class BF16HNSWTest_testSizeEstimation_Test; -friend class BF16TieredTest_testSizeEstimation_Test; -friend class FP16HNSWTest_testSizeEstimation_Test; -friend class FP16TieredTest_testSizeEstimation_Test; diff --git a/src/VecSim/algorithms/hnsw/hnsw_tiered.h b/src/VecSim/algorithms/hnsw/hnsw_tiered.h deleted file mode 100644 index f9d94dc52..000000000 --- a/src/VecSim/algorithms/hnsw/hnsw_tiered.h +++ /dev/null @@ -1,1198 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once - -#include "VecSim/algorithms/brute_force/brute_force_single.h" -#include "VecSim/vec_sim_tiered_index.h" -#include "hnsw.h" -#include "VecSim/index_factories/hnsw_factory.h" - -/** - * Definition of a job that inserts a new vector from flat into HNSW Index. 
- */ -struct HNSWInsertJob : public AsyncJob { - labelType label; - idType id; - - HNSWInsertJob(std::shared_ptr allocator, labelType label_, idType id_, - JobCallback insertCb, VecSimIndex *index_) - : AsyncJob(allocator, HNSW_INSERT_VECTOR_JOB, insertCb, index_), label(label_), id(id_) {} -}; - -/** - * Definition of a job that swaps last id with a deleted id in HNSW Index after delete operation. - */ -struct HNSWSwapJob : public VecsimBaseObject { - idType deleted_id; - std::atomic_int - pending_repair_jobs_counter; // number of repair jobs left to complete before this job - // is ready to be executed (atomic counter). - HNSWSwapJob(std::shared_ptr allocator, idType deletedId) - : VecsimBaseObject(allocator), deleted_id(deletedId), pending_repair_jobs_counter(0) {} - void setRepairJobsNum(long num_repair_jobs) { pending_repair_jobs_counter = num_repair_jobs; } - int atomicDecreasePendingJobsNum() { - int ret = --pending_repair_jobs_counter; - assert(pending_repair_jobs_counter >= 0); - return ret; - } -}; - -static const size_t DEFAULT_PENDING_SWAP_JOBS_THRESHOLD = DEFAULT_BLOCK_SIZE; -static const size_t MAX_PENDING_SWAP_JOBS_THRESHOLD = 100000; - -/** - * Definition of a job that repairs a certain node's connection in HNSW Index after delete - * operation. - */ -struct HNSWRepairJob : public AsyncJob { - idType node_id; - unsigned short level; - vecsim_stl::vector associatedSwapJobs; - - HNSWRepairJob(std::shared_ptr allocator, idType id_, unsigned short level_, - JobCallback repairCb, VecSimIndex *index_, HNSWSwapJob *swapJob) - : AsyncJob(allocator, HNSW_REPAIR_NODE_CONNECTIONS_JOB, repairCb, index_), node_id(id_), - level(level_), - // Insert the first swap job from which this repair job was created. - associatedSwapJobs(1, swapJob, this->allocator) {} - // In case that a repair job is required for deleting another neighbor of the node, save a - // reference to additional swap job. - void appendAnotherAssociatedSwapJob(HNSWSwapJob *swapJob) { - associatedSwapJobs.push_back(swapJob); - } -}; - -template -class TieredHNSWIndex : public VecSimTieredIndex { -private: - /// Mappings from id/label to associated jobs, for invalidating and update ids if necessary. - // In MULTI, we can have more than one insert job pending per label. - // **This map is protected with the flat buffer lock** - vecsim_stl::unordered_map> labelToInsertJobs; - vecsim_stl::unordered_map> idToRepairJobs; - vecsim_stl::unordered_map idToSwapJob; - - // A mapping to hold invalid jobs, so we can dispose them upon index deletion. - vecsim_stl::unordered_map invalidJobs; - idType currInvalidJobId; // A unique arbitrary identifier for accessing invalid jobs - std::mutex invalidJobsLookupGuard; - - // This threshold is tested upon deleting a label from HNSW, and once the number of deleted - // vectors reached this limit, we apply swap jobs *only for vectors that has no more pending - // repair jobs*, and are ready to be removed from the graph. - size_t pendingSwapJobsThreshold; - size_t readySwapJobs; - - // Protect the both idToRepairJobs lookup and the pending_repair_jobs_counter for the - // associated swap jobs. - std::mutex idToRepairJobsGuard; - - void executeInsertJob(HNSWInsertJob *job); - void executeRepairJob(HNSWRepairJob *job); - - // To be executed synchronously upon deleting a vector, doesn't require a wrapper. Main HNSW - // lock is assumed to be held exclusive here. 
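
A swap job becomes runnable only when its atomic counter of outstanding repair jobs reaches zero; each completed repair decrements the counter, and exactly one caller observes the transition to zero and hands the job to the GC. The counter protocol by itself (sketch; illustrative names):

    #include <atomic>
    #include <cassert>

    struct SwapJob {
        std::atomic_int pendingRepairs{0};

        // Called once per finished repair job. Returns true for exactly the
        // caller that completed the last outstanding repair.
        bool onRepairDone() {
            int left = --pendingRepairs; // Atomic decrement, returns the new value.
            assert(left >= 0);
            return left == 0;
        }
    };
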
- void executeSwapJob(idType deleted_id, vecsim_stl::vector &idsToRemove); - - // Execute the ready swap jobs, run no more than 'maxSwapsToRun' jobs (run all of them for -1). - void executeReadySwapJobs(size_t maxSwapsToRun = -1); - - // Wrappers static functions to be sent as callbacks upon creating the jobs (since members - // functions cannot serve as callback, this serve as the "gateway" to the appropriate index). - static void executeInsertJobWrapper(AsyncJob *job); - static void executeRepairJobWrapper(AsyncJob *job); - - inline HNSWIndex *getHNSWIndex() const; - - // Helper function for deleting a vector from the flat buffer (after it has already been - // ingested into HNSW or deleted). This includes removing the corresponding insert job from the - // label-to-insert-jobs lookup. Also, since deletion a vector triggers swapping of the - // internal last id with the deleted vector id, here we update the pending insert job(s) for the - // last id (if needed). This should be called while *flat lock is held* (exclusive lock). - void updateInsertJobInternalId(idType prev_id, idType new_id, labelType label); - - // Helper function for performing in place mark delete of vector(s) associated with a label - // and creating the appropriate repair jobs for the effected connections. This should be called - // while *HNSW shared lock is held* (shared locked). - int deleteLabelFromHNSW(labelType label); - - // Insert a single vector to HNSW. This can be called in both write modes - insert async and - // in-place. For the async mode, we have to release the flat index guard that is held for shared - // ownership (we do it right after we update the HNSW global data and receive the new state). - template - void insertVectorToHNSW(HNSWIndex *hnsw_index, labelType label, - const void *blob); - - // Set an insert/repair job as invalid, put the job pointer in the invalid jobs lookup under - // the current available id, increase it and return it (while holding invalidJobsLookupGuard). - // Returns the id that the job was stored under (to be set in the job id field). - idType setAndSaveInvalidJob(AsyncJob *job); - - // Handle deletion of vector inplace considering that async deletion might occurred beforehand. - int deleteLabelFromHNSWInplace(labelType label); - -#ifdef BUILD_TESTS -#include "VecSim/algorithms/hnsw/hnsw_tiered_tests_friends.h" -#endif - -public: - class TieredHNSW_BatchIterator : public VecSimBatchIterator { - private: - const TieredHNSWIndex *index; - VecSimQueryParams *queryParams; - - VecSimQueryResultContainer flat_results; - VecSimQueryResultContainer hnsw_results; - - VecSimBatchIterator *flat_iterator; - VecSimBatchIterator *hnsw_iterator; - - // On single value indices, this set holds the IDs of the results that were returned from - // the flat buffer. - // On multi value indices, this set holds the IDs of all the results that were returned. - // The difference between the two cases is that on multi value indices, the same ID can - // appear in both indexes and results with different scores, and therefore we can't tell in - // advance when we expect a possibility of a duplicate. - // On single value indices, a duplicate may appear at the same batch (and we will handle it - // when merging the results) Or it may appear in a different batches, first from the flat - // buffer and then from the HNSW, in the cases where a better result if found later in HNSW - // because of the approximate nature of the algorithm. 
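
That policy boils down to one rule: remember every id already handed to the caller and drop it if either tier surfaces it again. The merge step under that rule (hypothetical types):

    #include <unordered_set>
    #include <vector>

    struct Result {
        unsigned id;
        float score;
    };

    // Append `incoming` to `batch`, skipping ids returned in this or any
    // earlier batch; `returned` persists across batches.
    void mergeUnique(std::vector<Result> &batch, const std::vector<Result> &incoming,
                     std::unordered_set<unsigned> &returned) {
        for (const Result &r : incoming) {
            if (returned.insert(r.id).second) { // True only on first sighting.
                batch.push_back(r);
            }
        }
    }
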
- vecsim_stl::unordered_set returned_results_set; - - private: - template - inline VecSimQueryReply *compute_current_batch(size_t n_res); - inline void filter_irrelevant_results(VecSimQueryResultContainer &); - - public: - TieredHNSW_BatchIterator(const void *query_vector, - const TieredHNSWIndex *index, - VecSimQueryParams *queryParams, - std::shared_ptr allocator); - - ~TieredHNSW_BatchIterator(); - - const void *getQueryBlob() const override { return flat_iterator->getQueryBlob(); } - - VecSimQueryReply *getNextResults(size_t n_res, VecSimQueryReply_Order order) override; - - bool isDepleted() override; - - void reset() override; - -#ifdef BUILD_TESTS - VecSimBatchIterator *getHNSWIterator() { return hnsw_iterator; } -#endif - }; - -public: - TieredHNSWIndex(HNSWIndex *hnsw_index, - BruteForceIndex *bf_index, - const TieredIndexParams &tieredParams, - std::shared_ptr allocator); - virtual ~TieredHNSWIndex(); - - int addVector(const void *blob, labelType label) override; - int deleteVector(labelType label) override; - size_t getNumMarkedDeleted() const override { - return this->getHNSWIndex()->getNumMarkedDeleted(); - } - size_t indexSize() const override; - size_t indexCapacity() const override; - double getDistanceFrom_Unsafe(labelType label, const void *blob) const override; - // Do nothing here, each tier (flat buffer and HNSW) should increase capacity for itself when - // needed. - VecSimIndexDebugInfo debugInfo() const override; - VecSimIndexBasicInfo basicInfo() const override; - VecSimDebugInfoIterator *debugInfoIterator() const override; - VecSimBatchIterator *newBatchIterator(const void *queryBlob, - VecSimQueryParams *queryParams) const override { - // The query blob will be processed and copied by the internal indexes's batch iterator. - return new (this->allocator) - TieredHNSW_BatchIterator(queryBlob, this, queryParams, this->allocator); - } - inline void setLastSearchMode(VecSearchMode mode) override { - return this->backendIndex->setLastSearchMode(mode); - } - void runGC() override { - // Run no more than pendingSwapJobsThreshold value jobs. 
- TIERED_LOG(VecSimCommonStrings::LOG_VERBOSE_STRING, - "running asynchronous GC for tiered HNSW index"); - this->executeReadySwapJobs(this->pendingSwapJobsThreshold); - } - void acquireSharedLocks() override { - this->flatIndexGuard.lock_shared(); - this->mainIndexGuard.lock_shared(); - this->getHNSWIndex()->lockSharedIndexDataGuard(); - } - - void releaseSharedLocks() override { - this->flatIndexGuard.unlock_shared(); - this->mainIndexGuard.unlock_shared(); - this->getHNSWIndex()->unlockSharedIndexDataGuard(); - } - - VecSimDebugCommandCode getHNSWElementNeighbors(size_t label, int ***neighborsData) { - this->mainIndexGuard.lock_shared(); - auto res = this->getHNSWIndex()->getHNSWElementNeighbors(label, neighborsData); - this->mainIndexGuard.unlock_shared(); - return res; - } - -#ifdef BUILD_TESTS - void getDataByLabel(labelType label, std::vector> &vectors_output) const; - size_t indexMetaDataCapacity() const override { - return this->backendIndex->indexMetaDataCapacity() + - this->frontendIndex->indexMetaDataCapacity(); - } -#endif -}; - -/** - ******************************* Implementation ************************** - */ - -/* Helper methods */ -template -void TieredHNSWIndex::executeInsertJobWrapper(AsyncJob *job) { - auto *insert_job = reinterpret_cast(job); - auto *job_index = reinterpret_cast *>(insert_job->index); - job_index->executeInsertJob(insert_job); - delete job; -} - -template -void TieredHNSWIndex::executeRepairJobWrapper(AsyncJob *job) { - auto *repair_job = reinterpret_cast(job); - auto *job_index = reinterpret_cast *>(repair_job->index); - job_index->executeRepairJob(repair_job); - delete job; -} - -template -void TieredHNSWIndex::executeSwapJob(idType deleted_id, - vecsim_stl::vector &idsToRemove) { - // Get the id that was last and was had been swapped with the job's deleted id. - idType prev_last_id = this->getHNSWIndex()->indexSize(); - - // Invalidate repair jobs for the disposed id (if exist), and update the associated swap jobs. - if (idToRepairJobs.find(deleted_id) != idToRepairJobs.end()) { - for (auto &job_it : idToRepairJobs.at(deleted_id)) { - job_it->node_id = this->setAndSaveInvalidJob(job_it); - for (auto &swap_job_it : job_it->associatedSwapJobs) { - if (swap_job_it->atomicDecreasePendingJobsNum() == 0) { - readySwapJobs++; - } - } - } - idToRepairJobs.erase(deleted_id); - } - // Swap the ids in the pending jobs for the current last id (if exist). - if (idToRepairJobs.find(prev_last_id) != idToRepairJobs.end()) { - for (auto &job_it : idToRepairJobs.at(prev_last_id)) { - job_it->node_id = deleted_id; - } - idToRepairJobs.insert({deleted_id, idToRepairJobs.at(prev_last_id)}); - idToRepairJobs.erase(prev_last_id); - } - // Update the swap jobs if the last id also needs a swap, otherwise just collect to deleted id - // to be removed from the swap jobs. - if (prev_last_id != deleted_id && idToSwapJob.find(prev_last_id) != idToSwapJob.end() && - std::find(idsToRemove.begin(), idsToRemove.end(), prev_last_id) == idsToRemove.end()) { - // Update the curr_last_id pending swap job id after the removal that renamed curr_last_id - // with the deleted id. - idsToRemove.push_back(prev_last_id); - idToSwapJob.at(prev_last_id)->deleted_id = deleted_id; - // If id was deleted in-place and there is no swap job for it, this will create a new entry - // in idToSwapJob for the swapped id, otherwise it will update the existing entry. 
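
Because removal swaps the last internal id into the deleted slot, every auxiliary map keyed by internal id must be re-keyed at the same time, which is what the surrounding bookkeeping does for repair and swap jobs. The map surgery in isolation (sketch):

    #include <unordered_map>
    #include <utility>

    // After the element with id `last` was moved into slot `deleted`, re-file
    // any bookkeeping recorded under `last` so it is found under its new id.
    template <typename V>
    void rekeyAfterSwap(std::unordered_map<unsigned, V> &byId, unsigned deleted,
                        unsigned last) {
        auto it = byId.find(last);
        if (it == byId.end())
            return; // Nothing was tracked for the moved element.
        V moved = std::move(it->second);
        byId.erase(it);
        byId[deleted] = std::move(moved); // Reachable under the new id.
    }
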
-template <typename DataType, typename DistType>
-void TieredHNSWIndex<DataType, DistType>::executeSwapJob(idType deleted_id,
-                                                         vecsim_stl::vector<idType> &idsToRemove) {
-    // Get the id that was last, and had been swapped with the job's deleted id.
-    idType prev_last_id = this->getHNSWIndex()->indexSize();
-
-    // Invalidate repair jobs for the disposed id (if any exist), and update the associated swap
-    // jobs.
-    if (idToRepairJobs.find(deleted_id) != idToRepairJobs.end()) {
-        for (auto &job_it : idToRepairJobs.at(deleted_id)) {
-            job_it->node_id = this->setAndSaveInvalidJob(job_it);
-            for (auto &swap_job_it : job_it->associatedSwapJobs) {
-                if (swap_job_it->atomicDecreasePendingJobsNum() == 0) {
-                    readySwapJobs++;
-                }
-            }
-        }
-        idToRepairJobs.erase(deleted_id);
-    }
-    // Swap the ids in the pending jobs for the current last id (if any exist).
-    if (idToRepairJobs.find(prev_last_id) != idToRepairJobs.end()) {
-        for (auto &job_it : idToRepairJobs.at(prev_last_id)) {
-            job_it->node_id = deleted_id;
-        }
-        idToRepairJobs.insert({deleted_id, idToRepairJobs.at(prev_last_id)});
-        idToRepairJobs.erase(prev_last_id);
-    }
-    // Update the swap jobs if the last id also needs a swap; otherwise, just collect the deleted
-    // id to be removed from the swap jobs.
-    if (prev_last_id != deleted_id && idToSwapJob.find(prev_last_id) != idToSwapJob.end() &&
-        std::find(idsToRemove.begin(), idsToRemove.end(), prev_last_id) == idsToRemove.end()) {
-        // Update the pending swap job of prev_last_id after the removal renamed prev_last_id to
-        // the deleted id.
-        idsToRemove.push_back(prev_last_id);
-        idToSwapJob.at(prev_last_id)->deleted_id = deleted_id;
-        // If the id was deleted in-place and there is no swap job for it, this will create a new
-        // entry in idToSwapJob for the swapped id; otherwise, it will update the existing entry.
-        idToSwapJob[deleted_id] = idToSwapJob.at(prev_last_id);
-    } else {
-        idsToRemove.push_back(deleted_id);
-    }
-}
-
-template <typename DataType, typename DistType>
-HNSWIndex<DataType, DistType> *TieredHNSWIndex<DataType, DistType>::getHNSWIndex() const {
-    return dynamic_cast<HNSWIndex<DataType, DistType> *>(this->backendIndex);
-}
-
-template <typename DataType, typename DistType>
-void TieredHNSWIndex<DataType, DistType>::executeReadySwapJobs(size_t maxJobsToRun) {
-
-    // Execute swap jobs - acquire the hnsw write lock.
-    this->lockMainIndexGuard();
-    TIERED_LOG(VecSimCommonStrings::LOG_VERBOSE_STRING,
-               "Tiered HNSW index GC: there are %zu ready swap jobs. Start executing %zu swap jobs",
-               readySwapJobs, std::min(readySwapJobs, maxJobsToRun));
-
-    vecsim_stl::vector<idType> idsToRemove(this->allocator);
-    idsToRemove.reserve(idToSwapJob.size());
-    for (auto &it : idToSwapJob) {
-        auto *swap_job = it.second;
-        if (swap_job->pending_repair_jobs_counter.load() == 0) {
-            // The swap job is ready for execution - execute and delete it.
-            this->getHNSWIndex()->removeAndSwapMarkDeletedElement(swap_job->deleted_id);
-            this->executeSwapJob(swap_job->deleted_id, idsToRemove);
-            delete swap_job;
-        }
-        if (maxJobsToRun > 0 && idsToRemove.size() >= maxJobsToRun) {
-            break;
-        }
-    }
-    for (idType id : idsToRemove) {
-        idToSwapJob.erase(id);
-    }
-    readySwapJobs -= idsToRemove.size();
-    TIERED_LOG(VecSimCommonStrings::LOG_VERBOSE_STRING,
-               "Tiered HNSW index GC: done executing %zu swap jobs", idsToRemove.size());
-    this->unlockMainIndexGuard();
-}
-template <typename DataType, typename DistType>
-int TieredHNSWIndex<DataType, DistType>::deleteLabelFromHNSW(labelType label) {
-    auto *hnsw_index = getHNSWIndex();
-    this->mainIndexGuard.lock_shared();
-
-    // Get the required data about the relevant ids to delete.
-    // Internally, this will hold the index data lock.
-    auto internal_ids = hnsw_index->markDelete(label);
-
-    for (size_t i = 0; i < internal_ids.size(); i++) {
-        idType id = internal_ids[i];
-        vecsim_stl::vector<HNSWRepairJob *> repair_jobs(this->allocator);
-        auto *swap_job = new (this->allocator) HNSWSwapJob(this->allocator, id);
-
-        // Go over all the deleted element's links in every level and create repair jobs.
-        auto incomingEdges = hnsw_index->safeCollectAllNodeIncomingNeighbors(id);
-
-        // Protect the id->repair_jobs lookup while we update it with the new jobs.
-        this->idToRepairJobsGuard.lock();
-        for (auto &node : incomingEdges) {
-            bool repair_job_exists = false;
-            HNSWRepairJob *repair_job = nullptr;
-            if (idToRepairJobs.find(node.first) != idToRepairJobs.end()) {
-                for (auto it : idToRepairJobs.at(node.first)) {
-                    if (it->level == node.second) {
-                        // There is already an existing pending repair job for this node due to
-                        // the deletion of another node - avoid creating another job.
-                        repair_job_exists = true;
-                        repair_job = it;
-                        break;
-                    }
-                }
-            } else {
-                // There are no repair jobs at all for this element; create a new array for it.
-                idToRepairJobs.insert(
-                    {node.first, vecsim_stl::vector<HNSWRepairJob *>(this->allocator)});
-            }
-            if (repair_job_exists) {
-                repair_job->appendAnotherAssociatedSwapJob(swap_job);
-            } else {
-                repair_job =
-                    new (this->allocator) HNSWRepairJob(this->allocator, node.first, node.second,
-                                                        executeRepairJobWrapper, this, swap_job);
-                repair_jobs.emplace_back(repair_job);
-                idToRepairJobs.at(node.first).push_back(repair_job);
-            }
-        }
-        swap_job->setRepairJobsNum(incomingEdges.size());
-        if (incomingEdges.size() == 0) {
-            // No pending repair jobs, so the swap job is ready from the beginning.
-            readySwapJobs++;
-        }
-        this->idToRepairJobsGuard.unlock();
-
-        this->submitJobs(repair_jobs);
-        // Insert the swap job into the swap jobs lookup (for a fast update in case the node id is
-        // changed due to a swap job).
-        assert(idToSwapJob.find(id) == idToSwapJob.end());
-        idToSwapJob[id] = swap_job;
-    }
-    this->mainIndexGuard.unlock_shared();
-    return internal_ids.size();
-}
-
-template <typename DataType, typename DistType>
-void TieredHNSWIndex<DataType, DistType>::updateInsertJobInternalId(idType prev_id, idType new_id,
-                                                                    labelType label) {
-    // Update the pending job id, due to a swap that was caused after the removal of new_id.
-    assert(new_id != INVALID_ID && prev_id != INVALID_ID);
-    auto it = this->labelToInsertJobs.find(label);
-    if (it != this->labelToInsertJobs.end()) {
-        // There is a pending job for the label of the swapped last id - update its id.
-        for (HNSWInsertJob *job_it : it->second) {
-            if (job_it->id == prev_id) {
-                job_it->id = new_id;
-            }
-        }
-    }
-}
-
-template <typename DataType, typename DistType>
-template <bool releaseFlatGuard>
-void TieredHNSWIndex<DataType, DistType>::insertVectorToHNSW(
-    HNSWIndex<DataType, DistType> *hnsw_index, labelType label, const void *blob) {
-
-    // Preprocess for storage and indexing in the hnsw index.
-    ProcessedBlobs processed_blobs = hnsw_index->preprocess(blob);
-    const void *processed_storage_blob = processed_blobs.getStorageBlob();
-    const void *processed_for_index = processed_blobs.getQueryBlob();
-
-    // Acquire the index data lock, so we know what the exact index size is at this time. Acquire
-    // the main r/w lock beforehand to avoid deadlocks.
-    this->mainIndexGuard.lock_shared();
-    hnsw_index->lockIndexDataGuard();
-    // Check if resizing is needed for the HNSW index (requires the write lock).
-    if (hnsw_index->isCapacityFull()) {
-        // Release the inner HNSW data lock before we re-acquire the global HNSW lock.
-        this->mainIndexGuard.unlock_shared();
-        hnsw_index->unlockIndexDataGuard();
-        this->lockMainIndexGuard();
-        hnsw_index->lockIndexDataGuard();
-
-        // Hold the index data lock while we store the new element. If the new node's max level is
-        // higher than the current one, hold the lock through the entire insertion to ensure that
-        // graph scans will not occur, as they will try to access the entry point's neighbors.
-        // If an index resize is still needed, `storeNewElement` will perform it. This is OK since
-        // we hold the main index lock for exclusive access.
-        auto state = hnsw_index->storeNewElement(label, processed_storage_blob);
-        if constexpr (releaseFlatGuard) {
-            this->flatIndexGuard.unlock_shared();
-        }
-
-        // If we're still holding the index data guard, we cannot take the main index lock for
-        // shared ownership as it may cause deadlocks, and we also cannot release the main index
-        // lock in between, since we cannot allow swap jobs to happen, as they would make the
-        // saved state invalid. Hence, we insert the vector with the current exclusive lock held.
-        if (state.elementMaxLevel <= state.currMaxLevel) {
-            hnsw_index->unlockIndexDataGuard();
-        }
-        // Take the vector from the flat buffer and insert it into HNSW (an overwrite should not
-        // occur).
-        hnsw_index->indexVector(processed_for_index, label, state);
-        if (state.elementMaxLevel > state.currMaxLevel) {
-            hnsw_index->unlockIndexDataGuard();
-        }
-        this->unlockMainIndexGuard();
-    } else {
-        // Do the same as above, except for changing the capacity, but with the *shared* lock
-        // held: hold the index data lock while we store the new element. If the new node's max
-        // level is higher than the current one, hold the lock through the entire insertion to
-        // ensure that graph scans will not occur, as they will try to access the entry point's
-        // neighbors. At this point we are certain that the index has enough capacity for the new
-        // element, and this call will not resize the index.
-        auto state = hnsw_index->storeNewElement(label, processed_storage_blob);
-        if constexpr (releaseFlatGuard) {
-            this->flatIndexGuard.unlock_shared();
-        }
-
-        if (state.elementMaxLevel <= state.currMaxLevel) {
-            hnsw_index->unlockIndexDataGuard();
-        }
-        // Take the vector from the flat buffer and insert it into HNSW (an overwrite should not
-        // occur).
-        hnsw_index->indexVector(processed_for_index, label, state);
-        if (state.elementMaxLevel > state.currMaxLevel) {
-            hnsw_index->unlockIndexDataGuard();
-        }
-        this->mainIndexGuard.unlock_shared();
-    }
-}
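The branch above is the classic shared-to-exclusive "upgrade" dance: a shared mutex cannot be upgraded atomically, so the shared lock is dropped, the exclusive lock is taken, and the condition is re-checked because another writer may have run in the window between the two locks. A generic, self-contained sketch of the pattern (editor's illustration; the predicate and actions are hypothetical stand-ins for isCapacityFull(), the resize, and the insertion):

#include <shared_mutex>

std::shared_mutex guard;
bool capacity_full = true;                 // stands in for isCapacityFull()
void do_resize() { capacity_full = false; }
void do_insert() { /* indexing work happens here */ }

void insert_with_upgrade() {
    guard.lock_shared();
    if (capacity_full) {
        guard.unlock_shared();
        guard.lock();          // exclusive: safe to resize
        if (capacity_full) {   // re-check: state may have changed while unlocked
            do_resize();
        }
        do_insert();
        guard.unlock();
    } else {
        do_insert();
        guard.unlock_shared();
    }
}

int main() { insert_with_upgrade(); }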
-template <typename DataType, typename DistType>
-idType TieredHNSWIndex<DataType, DistType>::setAndSaveInvalidJob(AsyncJob *job) {
-    this->invalidJobsLookupGuard.lock();
-    job->isValid = false;
-    idType curInvalidId = currInvalidJobId++;
-    this->invalidJobs.insert({curInvalidId, job});
-    this->invalidJobsLookupGuard.unlock();
-    return curInvalidId;
-}
-
-template <typename DataType, typename DistType>
-int TieredHNSWIndex<DataType, DistType>::deleteLabelFromHNSWInplace(labelType label) {
-    auto *hnsw_index = this->getHNSWIndex();
-
-    auto ids = hnsw_index->getElementIds(label);
-    // Dispose pending repair and swap jobs for the removed ids.
-    vecsim_stl::vector<idType> idsToRemove(this->allocator);
-    idsToRemove.reserve(ids.size());
-    readySwapJobs += ids.size(); // account for the current ids that are going to be removed.
-    for (size_t id_ind = 0; id_ind < ids.size(); id_ind++) {
-        // Get the id on every iteration, since the ids can be swapped in every iteration.
-        idType id = hnsw_index->getElementIds(label).at(id_ind);
-        hnsw_index->removeVectorInPlace(id);
-        this->executeSwapJob(id, idsToRemove);
-    }
-    hnsw_index->removeLabel(label);
-    for (idType id : idsToRemove) {
-        idToSwapJob.erase(id);
-    }
-    readySwapJobs -= idsToRemove.size();
-    return ids.size();
-}
-
-/******************** Job's callbacks **********************************/
-template <typename DataType, typename DistType>
-void TieredHNSWIndex<DataType, DistType>::executeInsertJob(HNSWInsertJob *job) {
-    // Note that accessing the job fields should occur with the flat index guard held (here and
-    // later).
-    this->flatIndexGuard.lock_shared();
-    if (!job->isValid) {
-        this->flatIndexGuard.unlock_shared();
-        // The job has been invalidated in the meantime - nothing to execute; remove it from the
-        // lookup.
-        this->invalidJobsLookupGuard.lock();
-        this->invalidJobs.erase(job->id);
-        this->invalidJobsLookupGuard.unlock();
-        return;
-    }
-
-    HNSWIndex<DataType, DistType> *hnsw_index = this->getHNSWIndex();
-    // Copy the vector blob from the flat buffer, so we can release the flat lock while we are
-    // indexing the vector into the HNSW index.
-    size_t data_size = this->frontendIndex->getStoredDataSize();
-    auto blob_copy = this->getAllocator()->allocate_unique(data_size);
-    // Assuming the size of the blob stored in the frontend index matches the size of the blob
-    // stored in the HNSW index.
-    memcpy(blob_copy.get(), this->frontendIndex->getDataByInternalId(job->id), data_size);
-
-    this->insertVectorToHNSW<true>(hnsw_index, job->label, blob_copy.get());
-
-    // Remove the vector and the insert job from the flat buffer.
-    this->flatIndexGuard.lock();
-    // The job might have been invalidated due to an overwrite in the meantime. In this case,
-    // it was already deleted and the job has been evicted. Otherwise, we need to do it now.
-    if (job->isValid) {
-        // Remove the job pointer from the labelToInsertJobs mapping.
-        auto &jobs = labelToInsertJobs.at(job->label);
-        for (size_t i = 0; i < jobs.size(); i++) {
-            if (jobs[i]->id == job->id) {
-                jobs.erase(jobs.begin() + (long)i);
-                break;
-            }
-        }
-        if (labelToInsertJobs.at(job->label).empty()) {
-            labelToInsertJobs.erase(job->label);
-        }
-        // Remove the vector from the flat buffer. This may cause the last vector id to swap with
-        // the deleted id. Hold the label of the last id, so we can later update its
-        // corresponding job id. Note that after calling deleteVectorById, the last id's label
-        // shouldn't be available, since it is removed from the lookup.
-        labelType last_vec_label =
-            this->frontendIndex->getVectorLabel(this->frontendIndex->indexSize() - 1);
-        int deleted = this->frontendIndex->deleteVectorById(job->label, job->id);
-        if (deleted && job->id != this->frontendIndex->indexSize()) {
-            // If the vector removal caused a swap with the last id, update the relevant insert
-            // job.
-            this->updateInsertJobInternalId(this->frontendIndex->indexSize(), job->id,
-                                            last_vec_label);
-        }
-    } else {
-        // Remove the current job from the invalid jobs' lookup, as we are about to delete it now.
-        this->invalidJobsLookupGuard.lock();
-        this->invalidJobs.erase(job->id);
-        this->invalidJobsLookupGuard.unlock();
-    }
-    this->flatIndexGuard.unlock();
-}
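executeInsertJob above uses a copy-then-release pattern: the shared lock is held only long enough to copy the blob out of the guarded buffer, and the expensive graph insertion runs without it. Reduced to its essentials (editor's illustration, not VecSim code):

#include <cstring>
#include <shared_mutex>
#include <vector>

std::shared_mutex buffer_guard;
std::vector<float> guarded_buffer(128, 1.0f); // stands in for the flat buffer

void slow_indexing(const std::vector<float> &blob) { (void)blob; /* expensive graph work */ }

void ingest() {
    std::vector<float> blob_copy(guarded_buffer.size());
    {
        std::shared_lock<std::shared_mutex> lk(buffer_guard);
        std::memcpy(blob_copy.data(), guarded_buffer.data(),
                    guarded_buffer.size() * sizeof(float));
    } // shared lock released here
    slow_indexing(blob_copy); // runs without blocking writers to the flat buffer
}

int main() { ingest(); }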
-
-template <typename DataType, typename DistType>
-void TieredHNSWIndex<DataType, DistType>::executeRepairJob(HNSWRepairJob *job) {
-    // Take the HNSW shared lock before accessing its internals.
-    this->mainIndexGuard.lock_shared();
-    if (!job->isValid) {
-        this->mainIndexGuard.unlock_shared();
-        // The current node has already been removed and disposed.
-        this->invalidJobsLookupGuard.lock();
-        this->invalidJobs.erase(job->node_id);
-        this->invalidJobsLookupGuard.unlock();
-        return;
-    }
-    HNSWIndex<DataType, DistType> *hnsw_index = this->getHNSWIndex();
-
-    // Remove this job pointer from the repair jobs lookup BEFORE it is executed. Had we removed
-    // it after executing the repair job, we might see a pending repair job for this node id when
-    // deleting another neighbor of this node, and skip creating another repair job even though
-    // *this one has already been executed*.
-    this->idToRepairJobsGuard.lock();
-    auto &repair_jobs = this->idToRepairJobs.at(job->node_id);
-    assert(repair_jobs.size() > 0);
-    if (repair_jobs.size() == 1) {
-        // This was the only pending repair job for this id.
-        this->idToRepairJobs.erase(job->node_id);
-    } else {
-        // There are more pending jobs for the current id; remove just this job from the pending
-        // repair jobs list for this element id by replacing it with the last one (and trimming
-        // the last job in the list).
-        auto it = std::find(repair_jobs.begin(), repair_jobs.end(), job);
-        assert(it != repair_jobs.end());
-        *it = repair_jobs.back();
-        repair_jobs.pop_back();
-    }
-    for (auto &it : job->associatedSwapJobs) {
-        if (it->atomicDecreasePendingJobsNum() == 0) {
-            readySwapJobs++;
-        }
-    }
-    this->idToRepairJobsGuard.unlock();
-
-    hnsw_index->repairNodeConnections(job->node_id, job->level);
-
-    this->mainIndexGuard.unlock_shared();
-}
-
-/******************** Index API ****************************************/
-
-template <typename DataType, typename DistType>
-TieredHNSWIndex<DataType, DistType>::TieredHNSWIndex(HNSWIndex<DataType, DistType> *hnsw_index,
-                                                     BruteForceIndex<DataType, DistType> *bf_index,
-                                                     const TieredIndexParams &tiered_index_params,
-                                                     std::shared_ptr<VecSimAllocator> allocator)
-    : VecSimTieredIndex<DataType, DistType>(hnsw_index, bf_index, tiered_index_params, allocator),
-      labelToInsertJobs(this->allocator), idToRepairJobs(this->allocator),
-      idToSwapJob(this->allocator), invalidJobs(this->allocator), currInvalidJobId(0),
-      readySwapJobs(0) {
-    // If the param for swapJobThreshold is 0, use the default value; if it exceeds the maximum
-    // allowed, use the maximum value.
-    this->pendingSwapJobsThreshold =
-        tiered_index_params.specificParams.tieredHnswParams.swapJobThreshold == 0
-            ? DEFAULT_PENDING_SWAP_JOBS_THRESHOLD
-            : std::min(tiered_index_params.specificParams.tieredHnswParams.swapJobThreshold,
-                       MAX_PENDING_SWAP_JOBS_THRESHOLD);
-}
-
-template <typename DataType, typename DistType>
-TieredHNSWIndex<DataType, DistType>::~TieredHNSWIndex() {
-    // Delete all the pending insert jobs.
-    for (auto &jobs : this->labelToInsertJobs) {
-        for (auto *job : jobs.second) {
-            delete job;
-        }
-    }
-    // Delete all the pending repair jobs.
-    for (auto &jobs : this->idToRepairJobs) {
-        for (auto *job : jobs.second) {
-            delete job;
-        }
-    }
-    // Delete all the pending swap jobs.
-    for (auto &it : this->idToSwapJob) {
-        delete it.second;
-    }
-    // Delete all the pending invalid jobs.
-    for (auto &it : this->invalidJobs) {
-        delete it.second;
-    }
-}
-
-template <typename DataType, typename DistType>
-size_t TieredHNSWIndex<DataType, DistType>::indexSize() const {
-    this->flatIndexGuard.lock_shared();
-    this->getHNSWIndex()->lockSharedIndexDataGuard();
-    size_t res = this->backendIndex->indexSize() + this->frontendIndex->indexSize();
-    this->getHNSWIndex()->unlockSharedIndexDataGuard();
-    this->flatIndexGuard.unlock_shared();
-    return res;
-}
-
-template <typename DataType, typename DistType>
-size_t TieredHNSWIndex<DataType, DistType>::indexCapacity() const {
-    return this->backendIndex->indexCapacity() + this->frontendIndex->indexCapacity();
-}
-
-// In the tiered index, we assume that the blobs are processed by the flat buffer
-// before being transferred to the HNSW index.
-// When inserting vectors directly into the HNSW index (such as in VecSim_WriteInPlace mode), or
-// when the flat buffer is full, we must manually preprocess the blob according to the
-// **frontend** index parameters.
-template <typename DataType, typename DistType>
-int TieredHNSWIndex<DataType, DistType>::addVector(const void *blob, labelType label) {
-    int ret = 1;
-    auto hnsw_index = this->getHNSWIndex();
-    // writeMode is not protected since it is assumed to be called only from the "main thread"
-    // (that is, the thread that is exclusively calling add/delete vector).
-    if (this->getWriteMode() == VecSim_WriteInPlace) {
-        // First, check if we need to overwrite the vector in-place for single (from both
-        // indexes).
-        if (!this->backendIndex->isMultiValue()) {
-            ret -= this->deleteVector(label);
-        }
-
-        // Use the frontend parameters to manually prepare the blob for its transfer to the HNSW
-        // index.
-        auto storage_blob = this->frontendIndex->preprocessForStorage(blob);
-        // Insert the vector into the HNSW index. Internally, we will never have to overwrite the
-        // label since we already checked it outside.
-        this->lockMainIndexGuard();
-        hnsw_index->addVector(storage_blob.get(), label);
-        this->unlockMainIndexGuard();
-        return ret;
-    }
-    if (this->frontendIndex->indexSize() >= this->flatBufferLimit) {
-        // Handle the overwrite situation.
-        if (!this->backendIndex->isMultiValue()) {
-            // This will do nothing (and return 0) if this label doesn't exist. Otherwise, it may
-            // remove the vector from the flat buffer and/or the HNSW index.
-            ret -= this->deleteVector(label);
-        }
-        if (this->frontendIndex->indexSize() >= this->flatBufferLimit) {
-            // We didn't remove a vector from the flat buffer due to an overwrite; insert the new
-            // vector directly into HNSW. Since the flat buffer guard was not held, there is no
-            // need to release it internally.
-            // Use the frontend parameters to manually prepare the blob for its transfer to the
-            // HNSW index.
-            auto storage_blob = this->frontendIndex->preprocessForStorage(blob);
-            this->insertVectorToHNSW<false>(hnsw_index, label, storage_blob.get());
-            return ret;
-        }
-        // Otherwise, we fall back to the "regular" insertion into the flat buffer
-        // (since it is not full anymore after removing the previous vector stored under the
-        // label).
-    }
-    this->flatIndexGuard.lock();
-    idType new_flat_id = this->frontendIndex->indexSize();
-    if (this->frontendIndex->isLabelExists(label) && !this->frontendIndex->isMultiValue()) {
-        // Overwrite the vector and invalidate its only pending job (since we are not in MULTI).
-        auto *old_job = this->labelToInsertJobs.at(label).at(0);
-        old_job->id = this->setAndSaveInvalidJob(old_job);
-        this->labelToInsertJobs.erase(label);
-        ret = 0;
-        // We are going to update the internal id that currently holds the vector associated with
-        // the given label.
-        new_flat_id =
-            dynamic_cast<BruteForceIndex_Single<DataType, DistType> *>(this->frontendIndex)
-                ->getIdOfLabel(label);
-        // If we are adding a new element (rather than updating an existing one), we may need to
-        // increase the index capacity.
-    }
-    // If this label already exists, this will do an overwrite.
-    this->frontendIndex->addVector(blob, label);
-
-    AsyncJob *new_insert_job = new (this->allocator)
-        HNSWInsertJob(this->allocator, label, new_flat_id, executeInsertJobWrapper, this);
-    // Save a pointer to the job, so that if the vector is overwritten, we'll have an indication.
-    if (this->labelToInsertJobs.find(label) != this->labelToInsertJobs.end()) {
-        // There's already a pending insert job for this label; add another one (without an
-        // overwrite, only possible in a multi index).
-        assert(this->backendIndex->isMultiValue());
-        this->labelToInsertJobs.at(label).push_back((HNSWInsertJob *)new_insert_job);
-    } else {
-        vecsim_stl::vector<HNSWInsertJob *> new_jobs_vec(1, (HNSWInsertJob *)new_insert_job,
-                                                         this->allocator);
-        this->labelToInsertJobs.insert({label, new_jobs_vec});
-    }
-    this->flatIndexGuard.unlock();
-
-    // Here, a worker might ingest the previous vector that was stored under "label"
-    // (in case of an overwrite in a non-MULTI index) - so if it's there, we remove it (and
-    // create the required repair jobs), *before* we submit the insert job.
-    if (!this->backendIndex->isMultiValue()) {
-        // If we removed the previous vector from both HNSW and flat in the overwrite process,
-        // we still return 0 (not -1).
-        ret = std::max(ret - this->deleteLabelFromHNSW(label), 0);
-    }
-    // Apply the ready swap jobs if the number of deleted vectors reached the threshold (under an
-    // exclusive lock of the main index guard).
-    // If the number of ready swap jobs is equal to or larger than the threshold, go over the
-    // swap jobs and execute a batch of jobs for which all pending repair jobs were executed
-    // (otherwise finish and return).
-    if (readySwapJobs >= this->pendingSwapJobsThreshold) {
-        this->executeReadySwapJobs(this->pendingSwapJobsThreshold);
-    }
-
-    // Insert the job into the queue and signal the workers' updater.
-    this->submitSingleJob(new_insert_job);
-    return ret;
-}
-
-template <typename DataType, typename DistType>
-int TieredHNSWIndex<DataType, DistType>::deleteVector(labelType label) {
-    int num_deleted_vectors = 0;
-    this->flatIndexGuard.lock_shared();
-    if (this->frontendIndex->isLabelExists(label)) {
-        this->flatIndexGuard.unlock_shared();
-        this->flatIndexGuard.lock();
-        // Check again if the label exists, as it may have been removed while we released the
-        // lock.
-        if (this->frontendIndex->isLabelExists(label)) {
-            // Invalidate the pending insert job(s) into HNSW associated with this label.
-            auto &insert_jobs = this->labelToInsertJobs.at(label);
-            for (auto *job : insert_jobs) {
-                job->id = this->setAndSaveInvalidJob(job);
-            }
-            num_deleted_vectors += insert_jobs.size();
-            // Remove the pending insert job(s) from the labelToInsertJobs mapping.
-            this->labelToInsertJobs.erase(label);
-            // Go over every id that corresponds to the label and remove it from the flat buffer.
-            // Every delete may cause a swap of the deleted id with the last id, and we return a
-            // mapping from each id to the original id that resides in it after the deletion(s)
-            // (see an example in this function's implementation in the MULTI index).
-            auto updated_ids = this->frontendIndex->deleteVectorAndGetUpdatedIds(label);
-            for (auto &it : updated_ids) {
-                idType prev_id = it.second.first;
-                labelType updated_vec_label = it.second.second;
-                this->updateInsertJobInternalId(prev_id, it.first, updated_vec_label);
-            }
-        }
-        this->flatIndexGuard.unlock();
-    } else {
-        this->flatIndexGuard.unlock_shared();
-    }
-
-    // Next, check if there are vector(s) stored under the given label in HNSW and delete them as
-    // well. Note that we may remove the same vector that has been removed from the flat index,
-    // if it was being ingested at that time.
-    // writeMode is not protected since it is assumed to be called only from the "main thread"
-    // (that is, the thread that is exclusively calling add/delete vector).
-    if (this->getWriteMode() == VecSim_WriteAsync) {
-        num_deleted_vectors += this->deleteLabelFromHNSW(label);
-        // Apply the ready swap jobs if the number of deleted vectors reached the threshold
-        // (under an exclusive lock of the main index guard).
-        if (readySwapJobs >= this->pendingSwapJobsThreshold) {
-            this->executeReadySwapJobs(this->pendingSwapJobsThreshold);
-        }
-    } else {
-        // Delete in place.
-        this->lockMainIndexGuard();
-        num_deleted_vectors += this->deleteLabelFromHNSWInplace(label);
-        this->unlockMainIndexGuard();
-    }
-
-    return num_deleted_vectors;
-}
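A usage sketch of the two deletion paths above. The global write-mode setter is taken from the public VecSim API as I understand it; treat the exact entry-point names as an assumption, and `index`/`label` as placeholders:

#include "VecSim/vec_sim.h"

void delete_both_ways(VecSimIndex *index, size_t label) {
    VecSim_SetWriteMode(VecSim_WriteAsync);   // async path: mark-deleted + schedule repair jobs
    VecSimIndex_DeleteVector(index, label);

    VecSim_SetWriteMode(VecSim_WriteInPlace); // in-place path: repairs the graph synchronously
    VecSimIndex_DeleteVector(index, label);
}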
-
-// `getDistanceFrom` returns the minimum distance between the given blob and the vector with the
-// given label. If the label doesn't exist, the distance will be NaN.
-// Therefore, it's better to just call `getDistanceFrom` on both indexes and return the minimum,
-// instead of checking if the label exists in each index. We first try to get the distance from
-// the flat buffer, as vectors in the buffer might move to the Main index while we're "between"
-// the locks.
-// Behavior for a single (regular) index:
-// 1. The label doesn't exist in either index - return NaN.
-// 2. The label exists in one of the indexes only - return the distance from that index (which is
-//    valid).
-// 3. The label exists in both indexes - return the value from the flat buffer (which is valid
-//    and equal to the value from the Main index), saving us from locking the Main index.
-// Behavior for a multi index:
-// 1. The label doesn't exist in either index - return NaN.
-// 2. The label exists in one of the indexes only - return the distance from that index (which is
-//    valid).
-// 3. The label exists in both indexes - we may have some of the vectors with the same label in
-//    the flat buffer only and some in the Main index only (and maybe temporal duplications).
-//    So, we get the distance from both indexes and return the minimum.
-
-// IMPORTANT: this should be called when the *tiered index locks are locked for shared
-// ownership*, along with the HNSW index data guard lock. That is since the internal
-// getDistanceFrom calls access the indexes' data, and it is not safe to run an insert/delete
-// operation in parallel. Also, we avoid acquiring the locks internally, since this is usually
-// called for every vector individually, and the overhead of acquiring and releasing the locks is
-// significant in that case.
-template <typename DataType, typename DistType>
-double TieredHNSWIndex<DataType, DistType>::getDistanceFrom_Unsafe(labelType label,
-                                                                   const void *blob) const {
-    // Try to get the distance from the flat buffer.
-    // If the label doesn't exist, the distance will be NaN.
-    auto flat_dist = this->frontendIndex->getDistanceFrom_Unsafe(label, blob);
-
-    // Optimization. TODO: consider having different implementations for single and multi
-    // indexes, to avoid checking the index type on every query.
-    if (!this->backendIndex->isMultiValue() && !std::isnan(flat_dist)) {
-        // If the index is single-value, and we got a valid distance from the flat buffer,
-        // we can return the distance without querying the Main index.
-        return flat_dist;
-    }
-
-    // Try to get the distance from the Main index.
-    auto hnsw_dist = getHNSWIndex()->getDistanceFrom_Unsafe(label, blob);
-
-    // Return the minimum distance that is not NaN.
-    return std::fmin(flat_dist, hnsw_dist);
-}
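The single return statement above leans on IEEE-754 fmin semantics, where a NaN operand is treated as "missing" rather than propagated, which is exactly the label-not-found encoding used here. A quick self-contained check (editor's illustration):

#include <cassert>
#include <cmath>

int main() {
    double missing = std::nan("");
    assert(std::fmin(missing, 0.5) == 0.5);          // label only in the HNSW tier
    assert(std::fmin(0.25, missing) == 0.25);        // label only in the flat buffer
    assert(std::isnan(std::fmin(missing, missing))); // label in neither tier -> NaN
}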
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-//                                    TieredHNSW_BatchIterator                                     //
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-/******************** Ctor / Dtor *****************/
-
-// Defining special values for the hnsw_iterator field, to indicate whether the iterator is
-// uninitialized or depleted when we don't have a valid iterator.
-#define UNINITIALIZED ((VecSimBatchIterator *)0)
-#define DEPLETED ((VecSimBatchIterator *)1)
-
-template <typename DataType, typename DistType>
-TieredHNSWIndex<DataType, DistType>::TieredHNSW_BatchIterator::TieredHNSW_BatchIterator(
-    const void *query_vector, const TieredHNSWIndex<DataType, DistType> *index,
-    VecSimQueryParams *queryParams, std::shared_ptr<VecSimAllocator> allocator)
-    // The tiered batch iterator doesn't hold its own copy of the query vector.
-    // Instead, each internal batch iterator (flat_iterator and hnsw_iterator) creates its own
-    // copy: the flat_iterator copy is created during TieredHNSW_BatchIterator construction.
-    // When TieredHNSW_BatchIterator::getNextResults() is called and hnsw_iterator is not
-    // initialized, it retrieves the blob from flat_iterator.
-    : VecSimBatchIterator(nullptr, queryParams ? queryParams->timeoutCtx : nullptr,
-                          std::move(allocator)),
-      index(index), flat_results(this->allocator), hnsw_results(this->allocator),
-      flat_iterator(this->index->frontendIndex->newBatchIterator(query_vector, queryParams)),
-      hnsw_iterator(UNINITIALIZED), returned_results_set(this->allocator) {
-    // Save a copy of the query params to initialize the HNSW iterator with (on the first batch,
-    // and on the first batch after a reset).
-    if (queryParams) {
-        this->queryParams =
-            (VecSimQueryParams *)this->allocator->allocate(sizeof(VecSimQueryParams));
-        *this->queryParams = *queryParams;
-    } else {
-        this->queryParams = nullptr;
-    }
-}
-
-template <typename DataType, typename DistType>
-TieredHNSWIndex<DataType, DistType>::TieredHNSW_BatchIterator::~TieredHNSW_BatchIterator() {
-    delete this->flat_iterator;
-
-    if (this->hnsw_iterator != UNINITIALIZED && this->hnsw_iterator != DEPLETED) {
-        delete this->hnsw_iterator;
-        this->index->mainIndexGuard.unlock_shared();
-    }
-
-    this->allocator->free_allocation(this->queryParams);
-}
-
-/******************** Implementation **************/
-
-template <typename DataType, typename DistType>
-VecSimQueryReply *TieredHNSWIndex<DataType, DistType>::TieredHNSW_BatchIterator::getNextResults(
-    size_t n_res, VecSimQueryReply_Order order) {
-
-    const bool isMulti = this->index->backendIndex->isMultiValue();
-    auto hnsw_code = VecSim_QueryReply_OK;
-
-    if (this->hnsw_iterator == UNINITIALIZED) {
-        // First call to getNextResults. The call to the BF iterator will include calculating all
-        // the distances and accessing the BF index. We take the lock on this call.
-        this->index->flatIndexGuard.lock_shared();
-        auto cur_flat_results = this->flat_iterator->getNextResults(n_res, BY_SCORE_THEN_ID);
-        this->index->flatIndexGuard.unlock_shared();
-        // This is also the only time `getNextResults` on the BF iterator can fail.
-        if (VecSim_OK != cur_flat_results->code) {
-            return cur_flat_results;
-        }
-        this->flat_results.swap(cur_flat_results->results);
-        VecSimQueryReply_Free(cur_flat_results);
-        // We also take the lock on the main index on the first call to getNextResults, and we
-        // hold it until the iterator is depleted or freed.
-        this->index->mainIndexGuard.lock_shared();
-        this->hnsw_iterator = this->index->backendIndex->newBatchIterator(
-            this->flat_iterator->getQueryBlob(), queryParams);
-        auto cur_hnsw_results = this->hnsw_iterator->getNextResults(n_res, BY_SCORE_THEN_ID);
-        hnsw_code = cur_hnsw_results->code;
-        this->hnsw_results.swap(cur_hnsw_results->results);
-        VecSimQueryReply_Free(cur_hnsw_results);
-        if (this->hnsw_iterator->isDepleted()) {
-            delete this->hnsw_iterator;
-            this->hnsw_iterator = DEPLETED;
-            this->index->mainIndexGuard.unlock_shared();
-        }
-    } else {
-        while (this->flat_results.size() < n_res && !this->flat_iterator->isDepleted()) {
-            auto tail = this->flat_iterator->getNextResults(n_res - this->flat_results.size(),
-                                                            BY_SCORE_THEN_ID);
-            this->flat_results.insert(this->flat_results.end(), tail->results.begin(),
-                                      tail->results.end());
-            VecSimQueryReply_Free(tail);
-
-            if (!isMulti) {
-                // On single-value indexes, duplicates will never appear in the hnsw results
-                // before they appear in the flat results (at the same time or later, if the
-                // approximation misses), so we don't need to try and filter the flat results
-                // (and recheck the conditions).
-                break;
-            } else {
-                // On multi-value indexes, the flat results may contain results that were already
-                // returned from the hnsw index. We need to filter them out.
-                filter_irrelevant_results(this->flat_results);
-            }
-        }
-
-        while (this->hnsw_results.size() < n_res && this->hnsw_iterator != DEPLETED &&
-               hnsw_code == VecSim_OK) {
-            auto tail = this->hnsw_iterator->getNextResults(n_res - this->hnsw_results.size(),
-                                                            BY_SCORE_THEN_ID);
-            hnsw_code = tail->code; // Set the hnsw_results code to the last `getNextResults`
-                                    // code.
-            // A new batch may contain better results than the previous batch, so we need to
-            // merge. We don't expect duplications (hence the <false>), as the iterator
-            // guarantees that no result is returned twice.
-            VecSimQueryResultContainer cur_hnsw_results(this->allocator);
-            merge_results<false>(cur_hnsw_results, this->hnsw_results, tail->results, n_res);
-            VecSimQueryReply_Free(tail);
-            this->hnsw_results.swap(cur_hnsw_results);
-            filter_irrelevant_results(this->hnsw_results);
-            if (this->hnsw_iterator->isDepleted()) {
-                delete this->hnsw_iterator;
-                this->hnsw_iterator = DEPLETED;
-                this->index->mainIndexGuard.unlock_shared();
-            }
-        }
-    }
-
-    if (VecSim_OK != hnsw_code) {
-        return new VecSimQueryReply(this->allocator, hnsw_code);
-    }
-
-    VecSimQueryReply *batch;
-    if (isMulti)
-        batch = compute_current_batch<true>(n_res);
-    else
-        batch = compute_current_batch<false>(n_res);
-
-    if (order == BY_ID) {
-        sort_results_by_id(batch);
-    }
-    size_t batch_len = VecSimQueryReply_Len(batch);
-    this->updateResultsCount(batch_len);
-
-    return batch;
-}
-
-// DISCLAIMER: After the last batch, one of the iterators may report that it is not depleted,
-// while all of its remaining results were already returned by the other iterator.
-// (On single-value indexes, this can happen to the hnsw iterator only; on multi-value
-// indexes, this can happen to both iterators.)
-// The next call to `getNextResults` will return an empty batch, and then the iterators will
-// correctly report that they are depleted.
-template <typename DataType, typename DistType>
-bool TieredHNSWIndex<DataType, DistType>::TieredHNSW_BatchIterator::isDepleted() {
-    return this->flat_results.empty() && this->flat_iterator->isDepleted() &&
-           this->hnsw_results.empty() && this->hnsw_iterator == DEPLETED;
-}
-
-template <typename DataType, typename DistType>
-void TieredHNSWIndex<DataType, DistType>::TieredHNSW_BatchIterator::reset() {
-    if (this->hnsw_iterator != UNINITIALIZED && this->hnsw_iterator != DEPLETED) {
-        delete this->hnsw_iterator;
-        this->index->mainIndexGuard.unlock_shared();
-    }
-    this->resetResultsCount();
-    this->flat_iterator->reset();
-    this->hnsw_iterator = UNINITIALIZED;
-    this->flat_results.clear();
-    this->hnsw_results.clear();
-    returned_results_set.clear();
-}
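The helper that follows relies on merge_results to combine two score-sorted lists and report how many entries were consumed from each. A simplified, self-contained version of that operation (editor's sketch, not the VecSim implementation):

#include <cstddef>
#include <utility>
#include <vector>

struct Result { size_t id; double score; };

// Keep the n best results overall from two ascending-by-score lists; return how
// many were consumed from `a` and `b` so the caller can erase exactly those.
std::pair<size_t, size_t> merge_best(const std::vector<Result> &a, const std::vector<Result> &b,
                                     size_t n, std::vector<Result> &out) {
    size_t i = 0, j = 0;
    while (out.size() < n && (i < a.size() || j < b.size())) {
        if (j == b.size() || (i < a.size() && a[i].score <= b[j].score)) {
            out.push_back(a[i++]); // take from `a`: `b` exhausted, or `a` has the better score
        } else {
            out.push_back(b[j++]);
        }
    }
    return {i, j};
}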
-
-/****************** Helper Functions **************/
-
-template <typename DataType, typename DistType>
-template <bool isMultiValue>
-VecSimQueryReply *
-TieredHNSWIndex<DataType, DistType>::TieredHNSW_BatchIterator::compute_current_batch(
-    size_t n_res) {
-    // Merge the results. The returned pair reports how many results were consumed from
-    // `hnsw_results` and `flat_results`, respectively.
-    auto batch_res = new VecSimQueryReply(this->allocator);
-    std::pair<size_t, size_t> p;
-    if (isMultiValue) {
-        p = merge_results<true>(batch_res->results, this->hnsw_results, this->flat_results,
-                                n_res);
-    } else {
-        p = merge_results<false>(batch_res->results, this->hnsw_results, this->flat_results,
-                                 n_res);
-    }
-    auto [from_hnsw, from_flat] = p;
-
-    if (!isMultiValue) {
-        // If we're on a single-value index, update the set of results returned from the FLAT
-        // index before popping them, to prevent them from being returned from the HNSW index in
-        // later batches.
-        for (size_t i = 0; i < from_flat; ++i) {
-            this->returned_results_set.insert(this->flat_results[i].id);
-        }
-    } else {
-        // If we're on a multi-value index, update the set of results returned (from
-        // `batch_res`).
-        for (size_t i = 0; i < batch_res->results.size(); ++i) {
-            this->returned_results_set.insert(batch_res->results[i].id);
-        }
-    }
-
-    // Update the results.
-    this->flat_results.erase(this->flat_results.begin(), this->flat_results.begin() + from_flat);
-    this->hnsw_results.erase(this->hnsw_results.begin(), this->hnsw_results.begin() + from_hnsw);
-
-    // Clean up the results.
-    // On multi-value indexes, one (or both) of the results lists may contain results that were
-    // already returned from the other list (with a different score). We need to filter them out.
-    if (isMultiValue) {
-        filter_irrelevant_results(this->flat_results);
-        filter_irrelevant_results(this->hnsw_results);
-    }
-
-    // Return the current batch.
-    return batch_res;
-}
-
-template <typename DataType, typename DistType>
-void TieredHNSWIndex<DataType, DistType>::TieredHNSW_BatchIterator::filter_irrelevant_results(
-    VecSimQueryResultContainer &results) {
-    // Filter out results that were already returned.
-    auto it = results.begin();
-    const auto end = results.end();
-    // Skip results that were not returned yet.
-    while (it != end && this->returned_results_set.count(it->id) == 0) {
-        ++it;
-    }
-    // If none of the results were returned, return.
-    if (it == end) {
-        return;
-    }
-    // Mark the current result as the first result to be filtered.
-    auto cur_end = it;
-    ++it;
-    // "Append" all the results that were not yet returned.
-    while (it != end) {
-        if (this->returned_results_set.count(it->id) == 0) {
-            *cur_end = *it;
-            ++cur_end;
-        }
-        ++it;
-    }
-    // Update the number of results (pop the tail).
-    results.resize(cur_end - results.begin());
-}
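The manual compaction above is essentially the erase-remove idiom; the same effect can be written with std::remove_if. A standalone equivalent (editor's sketch); the hand-rolled version has the small advantage of not rewriting elements before the first filtered one, which remove_if also does internally:

#include <algorithm>
#include <unordered_set>
#include <vector>

void filter_returned(std::vector<int> &results, const std::unordered_set<int> &returned) {
    // Drop every result whose id was already returned in a previous batch.
    results.erase(std::remove_if(results.begin(), results.end(),
                                 [&](int id) { return returned.count(id) != 0; }),
                  results.end());
}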
-
-template <typename DataType, typename DistType>
-VecSimIndexDebugInfo TieredHNSWIndex<DataType, DistType>::debugInfo() const {
-    auto info = VecSimTieredIndex<DataType, DistType>::debugInfo();
-
-    HnswTieredInfo hnswTieredInfo = {.pendingSwapJobsThreshold = this->pendingSwapJobsThreshold};
-    info.tieredInfo.specificTieredBackendInfo.hnswTieredInfo = hnswTieredInfo;
-
-    info.tieredInfo.backgroundIndexing =
-        info.tieredInfo.frontendCommonInfo.indexSize > 0 ? VecSimBool_TRUE : VecSimBool_FALSE;
-
-    return info;
-}
-
-template <typename DataType, typename DistType>
-VecSimDebugInfoIterator *TieredHNSWIndex<DataType, DistType>::debugInfoIterator() const {
-    VecSimIndexDebugInfo info = this->debugInfo();
-    // Get the base tiered fields.
-    auto *infoIterator = VecSimTieredIndex<DataType, DistType>::debugInfoIterator();
-
-    // Tiered HNSW specific param.
-    infoIterator->addInfoField(VecSim_InfoField{
-        .fieldName = VecSimCommonStrings::TIERED_HNSW_SWAP_JOBS_THRESHOLD_STRING,
-        .fieldType = INFOFIELD_UINT64,
-        .fieldValue = {FieldValue{.uintegerValue = info.tieredInfo.specificTieredBackendInfo
-                                                       .hnswTieredInfo.pendingSwapJobsThreshold}}});
-
-    return infoIterator;
-}
-
-template <typename DataType, typename DistType>
-VecSimIndexBasicInfo TieredHNSWIndex<DataType, DistType>::basicInfo() const {
-    VecSimIndexBasicInfo info = this->backendIndex->getBasicInfo();
-    info.isTiered = true;
-    info.algo = VecSimAlgo_HNSWLIB;
-    return info;
-}
-
-#ifdef BUILD_TESTS
-template <typename DataType, typename DistType>
-void TieredHNSWIndex<DataType, DistType>::getDataByLabel(
-    labelType label, std::vector<std::vector<DataType>> &vectors_output) const {
-    this->getHNSWIndex()->getDataByLabel(label, vectors_output);
-}
-
-#endif
diff --git a/src/VecSim/algorithms/hnsw/hnsw_tiered_tests_friends.h b/src/VecSim/algorithms/hnsw/hnsw_tiered_tests_friends.h
deleted file mode 100644
index 772d72724..000000000
--- a/src/VecSim/algorithms/hnsw/hnsw_tiered_tests_friends.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-
-#include "VecSim/friend_test_decl.h"
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_CreateIndexInstance_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_addVector_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_manageIndexOwnership_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_insertJob_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_parallelSearch_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_parallelInsertSearch_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_deleteFromHNSWBasic_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_deleteFromHNSWWithRepairJobExec_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_manageIndexOwnershipWithPendingJobs_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_parallelInsertAdHoc_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_deleteVector_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_deleteVectorAndRepairAsync_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_alternateInsertDeleteAsync_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_swapJobBasic_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_swapJobBasic2_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_deleteVectorsAndSwapSync_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_BatchIterator_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_BatchIteratorAdvanced_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_BatchIteratorSize1_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_BatchIteratorReset_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_BatchIteratorWithOverlaps_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_parallelBatchIteratorSearch_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_testInfo_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_testInfoIterator_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_writeInPlaceMode_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_switchWriteModes_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_bufferLimit_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_bufferLimitAsync_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_RangeSearch_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTest_parallelRangeSearch_Test)
-
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_insertJobAsync_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_insertJobAsyncMulti_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_KNNSearch_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_MergeMulti_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteFromHNSWMulti_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteFromHNSWMultiLevels_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_AdHocSingle_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_AdHocMulti_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteVectorMulti_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteVectorMultiFromFlatAdvanced_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_BatchIteratorWithOverlaps_SpacialMultiCases_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteVectorMulti_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteVectorMultiFromFlatAdvanced_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_overwriteVectorBasic_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_overwriteVectorAsync_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_preferAdHocOptimization_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_runGCAPI_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_FitMemoryTest_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteBothAsyncAndInplace_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteBothAsyncAndInplaceMulti_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteInplaceMultiSwapId_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_deleteInplaceAvoidUpdatedMarkedDeleted_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_switchDeleteModes_Test)
-INDEX_TEST_FRIEND_CLASS(HNSWTieredIndexTestBasic_HNSWResize_Test)
-
-friend class CommonAPITest_SearchDifferentScores_Test;
-friend class BF16TieredTest;
-friend class FP16TieredTest;
-friend class INT8TieredTest;
-friend class UINT8TieredTest;
-friend class CommonTypeMetricTieredTests_TestDataSizeTieredHNSW_Test;
-
-INDEX_TEST_FRIEND_CLASS(BM_VecSimBasics)
-INDEX_TEST_FRIEND_CLASS(BM_VecSimCommon)
diff --git a/src/VecSim/algorithms/hnsw/visited_nodes_handler.cpp b/src/VecSim/algorithms/hnsw/visited_nodes_handler.cpp
deleted file mode 100644
index 289ee0b8e..000000000
--- a/src/VecSim/algorithms/hnsw/visited_nodes_handler.cpp
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include <cstring>
-#include "visited_nodes_handler.h"
-
-VisitedNodesHandler::VisitedNodesHandler(unsigned int cap,
-                                         const std::shared_ptr<VecSimAllocator> &allocator)
-    : VecsimBaseObject(allocator) {
-    cur_tag = 0;
-    num_elements = cap;
-    elements_tags = reinterpret_cast<tag_t *>(allocator->callocate(sizeof(tag_t) * num_elements));
-}
-
-void VisitedNodesHandler::reset() {
-    memset(elements_tags, 0, sizeof(tag_t) * num_elements);
-    cur_tag = 0;
-}
-
-void VisitedNodesHandler::resize(size_t new_size) {
-    this->num_elements = new_size;
-    this->elements_tags = reinterpret_cast<tag_t *>(
-        allocator->reallocate(this->elements_tags, sizeof(tag_t) * new_size));
-    this->reset();
-}
-
-tag_t VisitedNodesHandler::getFreshTag() {
-    cur_tag++;
-    if (cur_tag == 0) {
-        this->reset();
-        cur_tag++;
-    }
-    return cur_tag;
-}
-
-VisitedNodesHandler::~VisitedNodesHandler() noexcept { allocator->free_allocation(elements_tags); }
-
-/**
- * VisitedNodesHandlerPool methods to enable parallel graph scans.
- */
-VisitedNodesHandlerPool::VisitedNodesHandlerPool(int cap,
-                                                 const std::shared_ptr<VecSimAllocator> &allocator)
-    : VecsimBaseObject(allocator), pool(allocator), num_elements(cap), total_handlers_in_use(0) {}
-
-VisitedNodesHandler *VisitedNodesHandlerPool::getAvailableVisitedNodesHandler() {
-    VisitedNodesHandler *handler;
-    std::unique_lock<std::mutex> lock(pool_guard);
-    if (!pool.empty()) {
-        handler = pool.back();
-        pool.pop_back();
-    } else {
-        handler = new (allocator) VisitedNodesHandler(this->num_elements, this->allocator);
-        total_handlers_in_use++;
-    }
-    return handler;
-}
-
-void VisitedNodesHandlerPool::returnVisitedNodesHandlerToPool(VisitedNodesHandler *handler) {
-    std::unique_lock<std::mutex> lock(pool_guard);
-    pool.push_back(handler);
-    pool.shrink_to_fit();
-}
-
-void VisitedNodesHandlerPool::resize(size_t new_size) {
-    assert(total_handlers_in_use ==
-           pool.size()); // validate that there are no handlers in use outside the pool.
-    this->num_elements = new_size;
-    for (auto &handler : this->pool) {
-        handler->resize(new_size);
-    }
-}
-
-void VisitedNodesHandlerPool::clearPool() {
-    for (auto &handler : pool) {
-        delete handler;
-    }
-    pool.clear();
-    pool.shrink_to_fit();
-}
-
-VisitedNodesHandlerPool::~VisitedNodesHandlerPool() { clearPool(); }
diff --git a/src/VecSim/algorithms/hnsw/visited_nodes_handler.h b/src/VecSim/algorithms/hnsw/visited_nodes_handler.h
deleted file mode 100644
index fc6babe5a..000000000
--- a/src/VecSim/algorithms/hnsw/visited_nodes_handler.h
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-
-#include <vector>
-#include <mutex>
-#include "VecSim/memory/vecsim_malloc.h"
-#include "VecSim/memory/vecsim_base.h"
-
-typedef unsigned short tag_t;
-
-/**
- * Used as a singleton that is responsible for marking nodes that were visited in a graph scan.
- * Every scan has a "pseudo unique" tag which is associated with this specific scan, and nodes
- * that were visited in this particular scan are tagged with this tag. The tags range from
- * 1-MAX_USHORT, and we reset the tags after we complete MAX_USHORT scans.
- */
-class VisitedNodesHandler : public VecsimBaseObject {
-private:
-    tag_t cur_tag;
-    tag_t *elements_tags;
-    unsigned int num_elements;
-
-public:
-    VisitedNodesHandler(unsigned int cap, const std::shared_ptr<VecSimAllocator> &allocator);
-
-    // Return an unused tag for marking the visited nodes. The tags are cyclic, so whenever we
-    // reach zero, we reset the tags of all the nodes (and use 1 as the fresh tag).
-    tag_t getFreshTag();
-
-    inline tag_t *getElementsTags() { return elements_tags; }
-
-    void reset();
-
-    void resize(size_t new_size);
-
-    // Mark node_id with the tag, to indicate that this node has been visited.
-    inline void tagNode(unsigned int node_id, tag_t tag) { elements_tags[node_id] = tag; }
-
-    // Get the tag with which node_id is currently marked.
-    inline tag_t getNodeTag(unsigned int node_id) { return elements_tags[node_id]; }
-
-    ~VisitedNodesHandler() noexcept override;
-};
-
-/**
- * A wrapper class for using a pool of VisitedNodesHandler (relevant for parallel graph scans).
- */
-class VisitedNodesHandlerPool : public VecsimBaseObject {
-private:
-    std::vector<VisitedNodesHandler *, VecsimSTLAllocator<VisitedNodesHandler *>> pool;
-    std::mutex pool_guard;
-    unsigned int num_elements;
-    unsigned short total_handlers_in_use;
-
-public:
-    VisitedNodesHandlerPool(int cap, const std::shared_ptr<VecSimAllocator> &allocator);
-
-    VisitedNodesHandler *getAvailableVisitedNodesHandler();
-
-    void returnVisitedNodesHandlerToPool(VisitedNodesHandler *handler);
-
-    // This should be called under a guarded section only (NOT in parallel).
-    void resize(size_t new_size);
-
-    size_t getPoolSize() { return pool.size(); }
-
-    void clearPool();
-
-    ~VisitedNodesHandlerPool() override;
-};
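The trick this handler implements is worth seeing in isolation: instead of clearing a "visited" array before every graph scan, each scan bumps a tag and compares against it, and the array is only wiped when the 16-bit tag wraps around. A self-contained illustration (editor's sketch, not the class above):

#include <cassert>
#include <cstdint>
#include <cstring>

uint16_t tags[1000] = {0};
uint16_t cur_tag = 0;

uint16_t fresh_tag() {
    if (++cur_tag == 0) {            // wrapped: reset all tags once per 65535 scans
        std::memset(tags, 0, sizeof(tags));
        cur_tag = 1;
    }
    return cur_tag;
}

int main() {
    uint16_t tag = fresh_tag();
    tags[42] = tag;                   // mark node 42 visited in this scan
    assert(tags[42] == tag && tags[7] != tag);
    uint16_t next = fresh_tag();      // new scan: nothing counts as visited yet
    assert(tags[42] != next);
}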
diff --git a/src/VecSim/algorithms/svs/svs.h b/src/VecSim/algorithms/svs/svs.h
deleted file mode 100644
index 59b056d3c..000000000
--- a/src/VecSim/algorithms/svs/svs.h
+++ /dev/null
@@ -1,753 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-
-#pragma once
-#include "VecSim/vec_sim_index.h"
-#include "VecSim/query_result_definitions.h"
-#include "VecSim/utils/vec_utils.h"
-
-#include
-#include
-#include
-#include
-#include
-
-#include "svs/index/vamana/dynamic_index.h"
-#include "svs/index/vamana/multi.h"
-#include "spdlog/sinks/callback_sink.h"
-
-#include "VecSim/algorithms/svs/svs_utils.h"
-#include "VecSim/algorithms/svs/svs_batch_iterator.h"
-#include "VecSim/algorithms/svs/svs_extensions.h"
-
-#ifdef BUILD_TESTS
-#include "svs_serializer.h"
-#endif
-
-struct SVSIndexBase
-#ifdef BUILD_TESTS
-    : public SVSSerializer
-#endif
-{
-    SVSIndexBase() : num_marked_deleted{0} {};
-    virtual ~SVSIndexBase() = default;
-    virtual int addVectors(const void *vectors_data, const labelType *labels, size_t n) = 0;
-    virtual int deleteVectors(const labelType *labels, size_t n) = 0;
-    virtual size_t indexStorageSize() const = 0;
-    virtual size_t getNumThreads() const = 0;
-    virtual void setNumThreads(size_t numThreads) = 0;
-    virtual size_t getThreadPoolCapacity() const = 0;
-    virtual bool isCompressed() const = 0;
-    size_t getNumMarkedDeleted() const { return num_marked_deleted; }
-#ifdef BUILD_TESTS
-    virtual svs::logging::logger_ptr getLogger() const = 0;
-#endif
-protected:
-    // Marked-deleted vectors counter, used to initiate reindexing if it exceeds a threshold.
-    // markIndexUpdate() manages this counter.
-    size_t num_marked_deleted;
-};
-
-/** Thread Management Strategy:
- *  - addVector(): Requires numThreads == 1
- *  - addVectors(): Allows any numThreads value, but prohibits n=1 with numThreads>1
- *  - Callers are responsible for setting appropriate thread counts
- **/
-template <typename MetricType, typename DataType, size_t QuantBits, size_t ResidualBits,
-          bool IsLeanVec, bool isMulti>
-class SVSIndex : public VecSimIndexAbstract<svs_details::vecsim_dt<DataType>, float>,
-                 public SVSIndexBase {
-protected:
-    using data_type = DataType;
-    using distance_f = MetricType;
-    using Base = VecSimIndexAbstract<svs_details::vecsim_dt<DataType>, float>;
-    using index_component_t = IndexComponents<svs_details::vecsim_dt<DataType>, float>;
-
-    using storage_traits_t = SVSStorageTraits<DataType, QuantBits, ResidualBits, IsLeanVec>;
-    using index_storage_type = typename storage_traits_t::index_storage_type;
-
-    using graph_builder_t = SVSGraphBuilder;
-    using graph_type = typename graph_builder_t::graph_type;
-
-    using impl_type = std::conditional_t<
-        isMulti,
-        svs::index::vamana::MultiMutableVamanaIndex<graph_type, index_storage_type, distance_f>,
-        svs::index::vamana::MutableVamanaIndex<graph_type, index_storage_type, distance_f>>;
-
-    bool forcePreprocessing;
-
-    // Index build parameters
-    svs::index::vamana::VamanaBuildParameters buildParams;
-
-    // Index search parameters
-    size_t search_window_size;
-    size_t search_buffer_capacity;
-    // LeanVec dataset dimension.
-    // This parameter allows tuning the LeanVec dimension if LeanVec is enabled.
-    size_t leanvec_dim;
-    double epsilon;
-
-    // Whether the dataset is two-level LVQ.
-    // This allows tuning the default window capacity during search.
-    bool is_two_level_lvq;
-
-    // SVS thread pool
-    VecSimSVSThreadPool threadpool_;
-    svs::logging::logger_ptr logger_;
-    // SVS index implementation instance
-    std::unique_ptr<impl_type> impl_;
-
-    static double toVecSimDistance(float v) { return svs_details::toVecSimDistance(v); }
-
-    svs::logging::logger_ptr makeLogger() {
-        spdlog::custom_log_callback callback = [this](const spdlog::details::log_msg &msg) {
-            if (!VecSimIndexInterface::logCallback) {
-                return; // No callback function provided
-            }
-            // Custom callback implementation
-            const char *vecsim_level = [msg]() {
-                switch (msg.level) {
-                case spdlog::level::trace:
-                    return VecSimCommonStrings::LOG_DEBUG_STRING;
-                case spdlog::level::debug:
-                    return VecSimCommonStrings::LOG_VERBOSE_STRING;
-                case spdlog::level::info:
-                    return VecSimCommonStrings::LOG_NOTICE_STRING;
-                case spdlog::level::warn:
-                case spdlog::level::err:
-                case spdlog::level::critical:
-                    return VecSimCommonStrings::LOG_WARNING_STRING;
-                default:
-                    return "UNKNOWN";
-                }
-            }();
-
-            std::string msg_str{msg.payload.data(), msg.payload.size()};
-            // Log the message using the custom callback
-            VecSimIndexInterface::logCallback(this->logCallbackCtx, vecsim_level, msg_str.c_str());
-        };
-
-        // Create a logger with the custom callback
-        auto sink = std::make_shared<spdlog::sinks::callback_sink_mt>(callback);
-        auto logger = std::make_shared<spdlog::logger>("SVSIndex", sink);
-        // Sink all messages to VecSim
-        logger->set_level(spdlog::level::trace);
-        return logger;
-    }
-
-    // Create the SVS index instance with initial data.
-    // The data should not be empty.
-    template <typename Dataset>
-    void initImpl(const Dataset &points, std::span<const labelType> ids) {
-        svs::threads::ThreadPoolHandle threadpool_handle{VecSimSVSThreadPool{threadpool_}};
-
-        // Construct the SVS index's initial storage, with compression if needed.
-        auto data = storage_traits_t::create_storage(points, this->blockSize, threadpool_handle,
-                                                     this->getAllocator(), this->leanvec_dim);
-        // Compute the entry point.
-        auto entry_point =
-            svs::index::vamana::extensions::compute_entry_point(data, threadpool_handle);
-
-        // Perform graph construction.
-        auto distance = distance_f{};
-        const auto &parameters = this->buildParams;
-
-        // Construct the initial Vamana graph.
-        auto graph =
-            graph_builder_t::build_graph(parameters, data, distance, threadpool_, entry_point,
-                                         this->blockSize, this->getAllocator(), logger_);
-
-        // Create the SVS MutableIndex instance.
-        impl_ = std::make_unique<impl_type>(std::move(graph), std::move(data), entry_point,
-                                            std::move(distance), ids, threadpool_, logger_);
-
-        // Set the SVS MutableIndex build parameters to be used in future updates.
-        impl_->set_construction_window_size(parameters.window_size);
-        impl_->set_max_candidates(parameters.max_candidate_pool_size);
-        impl_->set_prune_to(parameters.prune_to);
-        impl_->set_alpha(parameters.alpha);
-        impl_->set_full_search_history(parameters.use_full_search_history);
-
-        // Configure the default search parameters.
-        auto sp = impl_->get_search_parameters();
-        sp.buffer_config({this->search_window_size, this->search_buffer_capacity});
-        impl_->set_search_parameters(sp);
-        impl_->reset_performance_parameters();
-    }
-
-    // Preprocess a batch of vectors.
-    MemoryUtils::unique_blob preprocessForBatchStorage(const void *original_data, size_t n) const {
-        // Buffer alignment isn't necessary for storage, since the SVS index will copy the data.
-        if (!this->forcePreprocessing) {
-            return MemoryUtils::unique_blob{const_cast<void *>(original_data), [](void *) {}};
-        }
-
-        const auto data_size = this->getStoredDataSize() * n;
-
-        auto processed_blob =
-            MemoryUtils::unique_blob{this->allocator->allocate(data_size),
-                                     [this](void *ptr) { this->allocator->free_allocation(ptr); }};
-        // Assuming the original data size equals the processed data size.
-        assert(this->getInputBlobSize() == this->getStoredDataSize());
-        memcpy(processed_blob.get(), original_data, data_size);
-        // Preprocess each vector in place.
-        for (size_t i = 0; i < n; i++) {
-            this->preprocessStorageInPlace(static_cast<DataType *>(processed_blob.get()) +
-                                           i * this->dim);
-        }
-        return processed_blob;
-    }
-
-    // This function assumes that the caller has already set numThreads to the appropriate value
-    // for the operation.
-    // Important NOTE: For single-vector operations (n=1), numThreads should be 1.
-    // For bulk operations (n>1), numThreads should reflect the number of available threads.
-    int addVectorsImpl(const void *vectors_data, const labelType *labels, size_t n) {
-        if (n == 0) {
-            return 0;
-        }
-
-        int deleted_num = 0;
-        if constexpr (!isMulti) {
-            // The SVS index does not support overriding vectors with the same label,
-            // so we have to delete them first if needed.
-            deleted_num = deleteVectorsImpl(labels, n);
-        }
-
-        std::span ids(labels, n);
-        auto processed_blob = this->preprocessForBatchStorage(vectors_data, n);
-        auto typed_vectors_data = static_cast<const DataType *>(processed_blob.get());
-        // Wrap the data in an SVS SimpleDataView for the SVS API.
-        auto points = svs::data::SimpleDataView{typed_vectors_data, n, this->dim};
-
-        if (!impl_) {
-            // The SVS index instance cannot be empty, so we have to construct it with the first
-            // rows.
-            initImpl(points, ids);
-        } else {
-            // Add the new points to the existing SVS index.
-            impl_->add_points(points, ids);
-        }
-
-        return n - deleted_num;
-    }
-
-    int deleteVectorsImpl(const labelType *labels, size_t n) {
-        if (indexLabelCount() == 0) {
-            return 0;
-        }
-
-        // SVS fails if we try to delete non-existing entries.
-        std::vector<labelType> entries_to_delete;
-        entries_to_delete.reserve(n);
-        for (size_t i = 0; i < n; i++) {
-            if (impl_->has_id(labels[i])) {
-                entries_to_delete.push_back(labels[i]);
-            }
-        }
-
-        if (entries_to_delete.size() == 0) {
-            return 0;
-        }
-
-        // If entries_to_delete.size() == 1, we should ensure single-threading.
-        const size_t current_num_threads = getNumThreads();
-        if (n == 1 && current_num_threads > 1) {
-            setNumThreads(1);
-        }
-
-        const auto deleted_num = impl_->delete_entries(entries_to_delete);
-
-        // Restore multi-threading if needed.
-        if (n == 1 && current_num_threads > 1) {
-            setNumThreads(current_num_threads);
-        }
-
-        this->markIndexUpdate(deleted_num);
-        return deleted_num;
-    }
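A usage sketch of the thread-count discipline laid out in the strategy comment above, written against the virtuals that SVSIndexBase declares. This is an editor's illustration; `svs_index` is any SVSIndexBase-derived instance, and labelType comes from the VecSim common headers:

void add_single_then_bulk(SVSIndexBase *svs_index, const void *one_vec, labelType one_label,
                          const void *bulk_vecs, const labelType *bulk_labels, size_t n) {
    const size_t pool_cap = svs_index->getThreadPoolCapacity();
    svs_index->setNumThreads(1);         // single-vector work must be single-threaded
    svs_index->addVectors(one_vec, &one_label, 1);
    svs_index->setNumThreads(pool_cap);  // bulk work may use the whole pool
    svs_index->addVectors(bulk_vecs, bulk_labels, n);
}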
storage_traits_t::storage_capacity(impl_->view_data()) : 0; - } - - size_t indexLabelCount() const override { - if constexpr (isMulti) { - return impl_ ? impl_->labelcount() : 0; - } else { - return impl_ ? impl_->size() : 0; - } - } - - vecsim_stl::set getLabelsSet() const override { - vecsim_stl::set labels(this->allocator); - if (impl_) { - impl_->on_ids([&labels](size_t label) { labels.insert(label); }); - } - return labels; - } - - VecSimIndexBasicInfo basicInfo() const override { - VecSimIndexBasicInfo info = this->getBasicInfo(); - info.algo = VecSimAlgo_SVS; - info.isTiered = false; - return info; - } - - VecSimIndexDebugInfo debugInfo() const override { - VecSimIndexDebugInfo info; - info.commonInfo = this->getCommonInfo(); - info.commonInfo.basicInfo.algo = VecSimAlgo_SVS; - - info.svsInfo = - svsInfoStruct{.quantBits = getCompressionMode(), - .alpha = this->buildParams.alpha, - .graphMaxDegree = this->buildParams.graph_max_degree, - .constructionWindowSize = this->buildParams.window_size, - .maxCandidatePoolSize = this->buildParams.max_candidate_pool_size, - .pruneTo = this->buildParams.prune_to, - .useSearchHistory = this->buildParams.use_full_search_history, - .numThreads = this->getThreadPoolCapacity(), - .lastReservedThreads = this->getNumThreads(), - .numberOfMarkedDeletedNodes = this->num_marked_deleted, - .searchWindowSize = this->search_window_size, - .searchBufferCapacity = this->search_buffer_capacity, - .leanvecDim = this->leanvec_dim, - .epsilon = this->epsilon}; - return info; - } - - VecSimDebugInfoIterator *debugInfoIterator() const override { - VecSimIndexDebugInfo info = this->debugInfo(); - // For readability. Update this number when needed. - size_t numberOfInfoFields = 23; - VecSimDebugInfoIterator *infoIterator = - new VecSimDebugInfoIterator(numberOfInfoFields, this->allocator); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::ALGORITHM_STRING, - .fieldType = INFOFIELD_STRING, - .fieldValue = { - FieldValue{.stringValue = VecSimAlgo_ToString(info.commonInfo.basicInfo.algo)}}}); - this->addCommonInfoToIterator(infoIterator, info.commonInfo); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::BLOCK_SIZE_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.commonInfo.basicInfo.blockSize}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::SVS_QUANT_BITS_STRING, - .fieldType = INFOFIELD_STRING, - .fieldValue = { - FieldValue{.stringValue = VecSimQuantBits_ToString(info.svsInfo.quantBits)}}}); - - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::SVS_ALPHA_STRING, - .fieldType = INFOFIELD_FLOAT64, - .fieldValue = {FieldValue{.floatingPointValue = info.svsInfo.alpha}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::SVS_GRAPH_MAX_DEGREE_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.svsInfo.graphMaxDegree}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::SVS_CONSTRUCTION_WS_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.svsInfo.constructionWindowSize}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::SVS_MAX_CANDIDATE_POOL_SIZE_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.svsInfo.maxCandidatePoolSize}}}); - - infoIterator->addInfoField( - 
VecSim_InfoField{.fieldName = VecSimCommonStrings::SVS_PRUNE_TO_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.svsInfo.pruneTo}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::SVS_USE_SEARCH_HISTORY_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.svsInfo.useSearchHistory}}}); - - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::SVS_NUM_THREADS_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.svsInfo.numThreads}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::SVS_LAST_RESERVED_THREADS_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.svsInfo.lastReservedThreads}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::NUM_MARKED_DELETED, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.svsInfo.numberOfMarkedDeletedNodes}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::SVS_SEARCH_WS_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.svsInfo.searchWindowSize}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::SVS_SEARCH_BC_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.svsInfo.searchBufferCapacity}}}); - - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::SVS_LEANVEC_DIM_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.svsInfo.leanvecDim}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::EPSILON_STRING, - .fieldType = INFOFIELD_FLOAT64, - .fieldValue = {FieldValue{.floatingPointValue = info.svsInfo.epsilon}}}); - - return infoIterator; - } - - int addVector(const void *vector_data, labelType label) override { - // Enforce single-threaded execution for single vector operations to ensure optimal - // performance and consistent behavior. Callers must set numThreads=1 before calling this - // method. - assert(getNumThreads() == 1 && "Can't use more than one thread to insert a single vector"); - return addVectorsImpl(vector_data, &label, 1); - } - - int addVectors(const void *vectors_data, const labelType *labels, size_t n) override { - // Prevent misuse: single vector operations should use addVector(), not addVectors() with - // n=1. This ensures proper thread management and API contract enforcement, as sketched below.
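The caller-side threading contract implied here can be sketched as follows (illustrative only; `index`, `vec`, `vecs`, `labels`, and `n` are hypothetical, and the methods used are the ones declared in this file):

    // Single insertion: size the thread pool down to one thread first.
    index->setNumThreads(1);
    index->addVector(vec, label);

    // Bulk insertion: the pool may be sized up to its full capacity.
    index->setNumThreads(index->getThreadPoolCapacity());
    index->addVectors(vecs, labels, n);

The assertions below enforce exactly this contract.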
- assert(!(n == 1 && getNumThreads() > 1) && - "Can't use more than one thread to insert a single vector"); - return addVectorsImpl(vectors_data, labels, n); - } - - int deleteVector(labelType label) override { return deleteVectorsImpl(&label, 1); } - - int deleteVectors(const labelType *labels, size_t n) override { - return deleteVectorsImpl(labels, n); - } - - size_t getNumThreads() const override { return threadpool_.size(); } - void setNumThreads(size_t numThreads) override { threadpool_.resize(numThreads); } - - size_t getThreadPoolCapacity() const override { return threadpool_.capacity(); } - - bool isCompressed() const override { return storage_traits_t::is_compressed(); } - - VecSimSvsQuantBits getCompressionMode() const { - return storage_traits_t::get_compression_mode(); - } - - double getDistanceFrom_Unsafe(labelType label, const void *vector_data) const override { - if (!impl_ || !impl_->has_id(label)) { - return std::numeric_limits::quiet_NaN(); - }; - - auto query_datum = std::span{static_cast(vector_data), this->dim}; - auto dist = impl_->get_distance(label, query_datum); - return toVecSimDistance(dist); - } - - VecSimQueryReply *topKQuery(const void *queryBlob, size_t k, - VecSimQueryParams *queryParams) const override { - auto rep = new VecSimQueryReply(this->allocator); - this->lastMode = STANDARD_KNN; - if (k == 0 || this->indexLabelCount() == 0) { - return rep; - } - - // limit result size to index size - k = std::min(k, this->indexLabelCount()); - - auto processed_query_ptr = this->preprocessQuery(queryBlob); - const void *processed_query = processed_query_ptr.get(); - - auto query = svs::data::ConstSimpleDataView{ - static_cast(processed_query), 1, this->dim}; - auto result = svs::QueryResult{query.size(), k}; - auto sp = svs_details::joinSearchParams(impl_->get_search_parameters(), queryParams, - is_two_level_lvq); - - auto timeoutCtx = queryParams ? queryParams->timeoutCtx : nullptr; - auto cancel = [timeoutCtx]() { return VECSIM_TIMEOUT(timeoutCtx); }; - - impl_->search(result.view(), query, sp, cancel); - if (cancel()) { - rep->code = VecSim_QueryReply_TimedOut; - return rep; - } - - assert(result.n_queries() == 1); - - const auto n_neighbors = result.n_neighbors(); - rep->results.reserve(n_neighbors); - - for (size_t i = 0; i < n_neighbors; i++) { - rep->results.push_back( - VecSimQueryResult{result.index(0, i), toVecSimDistance(result.distance(0, i))}); - } - // Workaround for VecSim merge_results() that expects results to be sorted - // by score, then by id from both indices. - // TODO: remove this workaround when merge_results() is fixed. - sort_results_by_score_then_id(rep); - return rep; - } - - VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, - VecSimQueryParams *queryParams) const override { - auto rep = new VecSimQueryReply(this->allocator); - this->lastMode = RANGE_QUERY; - if (radius == 0 || this->indexLabelCount() == 0) { - return rep; - } - - auto timeoutCtx = queryParams ? queryParams->timeoutCtx : nullptr; - auto cancel = [timeoutCtx]() { return VECSIM_TIMEOUT(timeoutCtx); }; - - // Prepare query blob for SVS - auto processed_query_ptr = this->preprocessQuery(queryBlob); - const void *processed_query = processed_query_ptr.get(); - std::span query{static_cast(processed_query), - this->dim}; - - // Base search parameters for the SVS iterator schedule. 
- auto sp = svs_details::joinSearchParams(impl_->get_search_parameters(), queryParams, - is_two_level_lvq); - // SVS BatchIterator handles the search in batches - // The batch size is set to the index search window size by default - const size_t batch_size = sp.buffer_config_.get_search_window_size(); - - // Create SVS BatchIterator for range search - // Search result is cached in the iterator and can be accessed by the user - auto svs_it = impl_->make_batch_iterator(query); - svs_it.next(batch_size, cancel); - if (cancel()) { - rep->code = VecSim_QueryReply_TimedOut; - return rep; - } - - // Range search using epsilon - const auto epsilon = queryParams && queryParams->svsRuntimeParams.epsilon != 0 - ? queryParams->svsRuntimeParams.epsilon - : this->epsilon; - - const auto range_search_boundaries = radius * (1.0 + std::abs(epsilon)); - // e.g., radius = 10 and epsilon = 0.01 keep scanning until a distance exceeds 10.1 - bool keep_searching = true; - - // Loop while iterator cache is not empty and search radius + epsilon is not exceeded - while (keep_searching && svs_it.size() > 0) { - // Iterate over the cached search results - for (auto &neighbor : svs_it) { - const auto dist = toVecSimDistance(neighbor.distance()); - if (dist <= radius) { - rep->results.push_back(VecSimQueryResult{neighbor.id(), dist}); - } else if (dist > range_search_boundaries) { - keep_searching = false; - } - } - // If search radius + epsilon is not exceeded, request SVS BatchIterator for the next - // batch - if (keep_searching) { - svs_it.next(batch_size, cancel); - if (cancel()) { - rep->code = VecSim_QueryReply_TimedOut; - return rep; - } - } - } - // Workaround for VecSim merge_results() that expects results to be sorted - // by score, then by id from both indices. - // TODO: remove this workaround when merge_results() is fixed. - sort_results_by_score_then_id(rep); - return rep; - } - - VecSimBatchIterator *newBatchIterator(const void *queryBlob, - VecSimQueryParams *queryParams) const override { - // force_copy == true. - auto queryBlobCopy = this->preprocessQuery(queryBlob, true); - - // take ownership of the blob copy and pass it to the batch iterator. - auto *queryBlobCopyPtr = queryBlobCopy.release(); - // Ownership of queryBlobCopy moves to VecSimBatchIterator that will free it at the end. - if (indexLabelCount() == 0) { - return new (this->getAllocator()) - NullSVS_BatchIterator(queryBlobCopyPtr, queryParams, this->getAllocator()); - } else { - return new (this->getAllocator()) SVS_BatchIterator( - queryBlobCopyPtr, impl_.get(), queryParams, this->getAllocator(), is_two_level_lvq); - } - } - - bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) const override { - size_t index_size = this->indexLabelCount(); - - // Calculate the ratio of the subset size to the total index size. - double subsetRatio = (index_size == 0) ? 0.f : static_cast<double>(subsetSize) / index_size; - - // Heuristic thresholds - const double smallSubsetThreshold = 0.07; // Subset is small if less than 7% of index. - const double largeSubsetThreshold = 0.21; // Subset is large if more than 21% of index. - const double smallIndexThreshold = 75000; // Index is small if size is less than 75k. - const double largeIndexThreshold = 750000; // Index is large if size is more than 750k. - - bool res = false; - if (subsetRatio < smallSubsetThreshold) { - // For small subsets, ad-hoc if index is not large. - res = (index_size < largeIndexThreshold); - } else if (subsetRatio < largeSubsetThreshold) { - // For medium subsets, ad-hoc if index is small or k is big.
- res = (index_size < smallIndexThreshold) || (k > 12); - } else { - // For large subsets, ad-hoc only if index is small. - res = (index_size < smallIndexThreshold); - } - - this->lastMode = - res ? (initial_check ? HYBRID_ADHOC_BF : HYBRID_BATCHES_TO_ADHOC_BF) : HYBRID_BATCHES; - return res; - } - - void runGC() override { - if (impl_) { - // There is documentation for consolidate(): - // https://intel.github.io/ScalableVectorSearch/python/dynamic.html#svs.DynamicVamana.consolidate - impl_->consolidate(); - // There is documentation for compact(): - // https://intel.github.io/ScalableVectorSearch/python/dynamic.html#svs.DynamicVamana.compact - impl_->compact(); - } - num_marked_deleted = 0; - } - -#ifdef BUILD_TESTS - -private: - void saveIndexIMP(std::ofstream &output) override; - void impl_save(const std::string &location) override; - void saveIndexFields(std::ofstream &output) const override; - - bool compareMetadataFile(const std::string &metadataFilePath) const override; - void loadIndex(const std::string &folder_path) override; - bool checkIntegrity() const override; - -public: - void fitMemory() override {} - size_t indexMetaDataCapacity() const override { return this->indexCapacity(); } - std::vector> getStoredVectorDataByLabel(labelType label) const override { - - // For compressed/quantized indices, this function is not meaningful - // since the stored data is in compressed format and not directly accessible - if constexpr (QuantBits > 0 || ResidualBits > 0) { - throw std::runtime_error( - "getStoredVectorDataByLabel is not supported for compressed/quantized indices"); - } else { - - std::vector> vectors_output; - - if constexpr (isMulti) { - // Multi-index case: get all vectors for this label - auto it = impl_->get_label_to_external_lookup().find(label); - if (it != impl_->get_label_to_external_lookup().end()) { - const auto &external_ids = it->second; - for (auto external_id : external_ids) { - auto indexed_span = impl_->get_parent_index().get_datum(external_id); - - // For uncompressed data, indexed_span should be a simple span - const char *data_ptr = reinterpret_cast(indexed_span.data()); - std::vector vec_data(this->getStoredDataSize()); - std::memcpy(vec_data.data(), data_ptr, this->getStoredDataSize()); - vectors_output.push_back(std::move(vec_data)); - } - } - } else { - // Single-index case - auto indexed_span = impl_->get_datum(label); - - // For uncompressed data, indexed_span should be a simple span - const char *data_ptr = reinterpret_cast(indexed_span.data()); - std::vector vec_data(this->getStoredDataSize()); - std::memcpy(vec_data.data(), data_ptr, this->getStoredDataSize()); - vectors_output.push_back(std::move(vec_data)); - } - - return vectors_output; - } - } - void getDataByLabel( - labelType label, - std::vector>> &vectors_output) const override { - assert(false && "Not implemented"); - } - - svs::logging::logger_ptr getLogger() const override { return logger_; } -#endif -}; - -#ifdef BUILD_TESTS -// Including implementations for Serializer base -#include "svs_serializer_impl.h" -#endif diff --git a/src/VecSim/algorithms/svs/svs_batch_iterator.h b/src/VecSim/algorithms/svs/svs_batch_iterator.h deleted file mode 100644 index 116546150..000000000 --- a/src/VecSim/algorithms/svs/svs_batch_iterator.h +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. 
- * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once - -#include "VecSim/batch_iterator.h" -#include "VecSim/utils/vec_utils.h" -#include "VecSim/query_results.h" - -#include -#include - -#include "svs/index/vamana/iterator.h" - -#include "VecSim/algorithms/svs/svs_utils.h" - -template -class SVS_BatchIterator : public VecSimBatchIterator { -private: - using query_type = std::span; - using mkbi_t = decltype(&Index::template make_batch_iterator); - using impl_type = std::invoke_result_t; - - using dist_type = typename Index::distance_type; - size_t dim; - impl_type impl_; - typename impl_type::const_iterator curr_it; - size_t batch_size; - - VecSimQueryReply *getNextResultsImpl(size_t n_res) { - auto rep = new VecSimQueryReply(this->allocator); - rep->results.reserve(n_res); - auto timeoutCtx = this->getTimeoutCtx(); - auto cancel = [timeoutCtx]() { return VECSIM_TIMEOUT(timeoutCtx); }; - - if (cancel()) { - rep->code = VecSim_QueryReply_TimedOut; - return rep; - } - - const auto bs = std::max(n_res, batch_size); - - for (size_t i = 0; i < n_res; i++) { - if (curr_it == impl_.end()) { - impl_.next(bs, cancel); - if (cancel()) { - rep->code = VecSim_QueryReply_TimedOut; - rep->results.clear(); - return rep; - } - curr_it = impl_.begin(); - if (impl_.size() == 0) { - return rep; - } - } - rep->results.push_back(VecSimQueryResult{ - curr_it->id(), svs_details::toVecSimDistance(curr_it->distance())}); - ++curr_it; - } - return rep; - } - -public: - SVS_BatchIterator(void *query_vector, const Index *index, const VecSimQueryParams *queryParams, - std::shared_ptr allocator, bool is_two_level_lvq) - : VecSimBatchIterator{query_vector, queryParams ? queryParams->timeoutCtx : nullptr, - std::move(allocator)}, - dim{index->dimensions()}, impl_{index->make_batch_iterator(std::span{ - static_cast(query_vector), dim})}, - curr_it{impl_.begin()} { - auto sp = svs_details::joinSearchParams(index->get_search_parameters(), queryParams, - is_two_level_lvq); - batch_size = queryParams && queryParams->batchSize - ? queryParams->batchSize - : sp.buffer_config_.get_search_window_size(); - } - - VecSimQueryReply *getNextResults(size_t n_res, VecSimQueryReply_Order order) override { - auto rep = getNextResultsImpl(n_res); - this->updateResultsCount(VecSimQueryReply_Len(rep)); - sort_results(rep, order); - return rep; - } - - bool isDepleted() override { return curr_it == impl_.end() && impl_.done(); } - - void reset() override { - impl_.update(std::span{static_cast(this->getQueryBlob()), dim}); - curr_it = impl_.begin(); - } -}; - -// Empty index iterator -class NullSVS_BatchIterator : public VecSimBatchIterator { -private: -public: - NullSVS_BatchIterator(void *query_vector, const VecSimQueryParams *queryParams, - std::shared_ptr allocator) - : VecSimBatchIterator{query_vector, queryParams ? queryParams->timeoutCtx : nullptr, - allocator} {} - - VecSimQueryReply *getNextResults(size_t n_res, VecSimQueryReply_Order order) override { - return new VecSimQueryReply(this->allocator); - } - - bool isDepleted() override { return true; } - - void reset() override {} -}; diff --git a/src/VecSim/algorithms/svs/svs_extensions.h b/src/VecSim/algorithms/svs/svs_extensions.h deleted file mode 100644 index 3903d2289..000000000 --- a/src/VecSim/algorithms/svs/svs_extensions.h +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. 
- * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once -#include "VecSim/algorithms/svs/svs_utils.h" -#include "svs/extensions/vamana/scalar.h" - -#if HAVE_SVS_LVQ -#include SVS_LVQ_HEADER -#include SVS_LEANVEC_HEADER -#endif // HAVE_SVS_LVQ - -// Scalar Quantization traits for SVS -template -struct SVSStorageTraits { - using element_type = std::int8_t; - using allocator_type = svs_details::SVSAllocator; - using blocked_type = svs::data::Blocked>; - using index_storage_type = - svs::quantization::scalar::SQDataset; - - static constexpr bool is_compressed() { return true; } - - static auto make_blocked_allocator(size_t block_size, size_t dim, - std::shared_ptr allocator) { - // SVS block size is a power of two, so we can use it directly - auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dim)); - allocator_type data_allocator{std::move(allocator)}; - return svs::make_blocked_allocator_handle({svs_bs}, data_allocator); - } - - static constexpr VecSimSvsQuantBits get_compression_mode() { return VecSimSvsQuant_Scalar; } - - template - static index_storage_type create_storage(const Dataset &data, size_t block_size, Pool &pool, - std::shared_ptr allocator, - size_t /*leanvec_dim*/) { - const auto dim = data.dimensions(); - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - return index_storage_type::compress(data, pool, blocked_alloc); - } - - static index_storage_type load(const svs::lib::LoadTable &table, size_t block_size, size_t dim, - std::shared_ptr allocator) { - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - return index_storage_type::load(table, blocked_alloc); - } - - static index_storage_type load(const std::string &path, size_t block_size, size_t dim, - std::shared_ptr allocator) { - assert(svs::data::detail::is_likely_reload(path)); // TODO implement auto_load for SQDataset - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - // Load the data from disk - return svs::lib::load_from_disk(path, blocked_alloc); - } - - static constexpr size_t element_size(size_t dims, size_t alignment = 0, - size_t /*leanvec_dim*/ = 0) { - return dims * sizeof(element_type); - } - - static size_t storage_capacity(const index_storage_type &storage) { - // SQDataset does not provide a capacity method - return storage.size(); - } -}; - -#if HAVE_SVS_LVQ -namespace svs_details { -template -struct LVQSelector { - using strategy = svs::quantization::lvq::Sequential; -}; - -template <> -struct LVQSelector<4> { - using strategy = svs::quantization::lvq::Turbo<16, 8>; -}; -} // namespace svs_details - -// LVQDataset traits for SVS -template -struct SVSStorageTraits 1)>> { - using allocator_type = svs_details::SVSAllocator; - using blocked_type = svs::data::Blocked>; - using strategy_type = typename svs_details::LVQSelector::strategy; - using index_storage_type = - svs::quantization::lvq::LVQDataset; - - static constexpr bool is_compressed() { return true; } - - static constexpr VecSimSvsQuantBits get_compression_mode() { - if constexpr (QuantBits == 4 && ResidualBits == 0) { - return VecSimSvsQuant_4; - } else if constexpr (QuantBits == 8 && ResidualBits == 0) { - return VecSimSvsQuant_8; - } else if constexpr (QuantBits == 4 && ResidualBits == 4) { - return VecSimSvsQuant_4x4; - } else if constexpr 
(QuantBits == 4 && ResidualBits == 8) { - return VecSimSvsQuant_4x8; - } else { - assert(false && "Unsupported quantization mode"); - return VecSimSvsQuant_NONE; // Unsupported case - } - } - - static auto make_blocked_allocator(size_t block_size, size_t dim, - std::shared_ptr allocator) { - // SVS block size is a power of two, so we can use it directly - auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dim)); - allocator_type data_allocator{std::move(allocator)}; - return svs::make_blocked_allocator_handle({svs_bs}, data_allocator); - } - - template - static index_storage_type create_storage(const Dataset &data, size_t block_size, Pool &pool, - std::shared_ptr allocator, - size_t /*leanvec_dim*/) { - const auto dim = data.dimensions(); - - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - return index_storage_type::compress(data, pool, 0, blocked_alloc); - } - - static index_storage_type load(const svs::lib::LoadTable &table, size_t block_size, size_t dim, - std::shared_ptr allocator) { - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - return index_storage_type::load(table, /*alignment=*/0, blocked_alloc); - } - - static index_storage_type load(const std::string &path, size_t block_size, size_t dim, - std::shared_ptr allocator) { - assert(svs::data::detail::is_likely_reload(path)); // TODO implement auto_load for LVQ - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - // Load the data from disk - return svs::lib::load_from_disk(path, /*alignment=*/0, blocked_alloc); - } - - static constexpr size_t element_size(size_t dims, size_t alignment = 0, - size_t /*leanvec_dim*/ = 0) { - using primary_type = typename index_storage_type::primary_type; - using layout_type = typename primary_type::helper_type; - using layout_dims_type = svs::lib::MaybeStatic; - const auto layout_dims = layout_dims_type{dims}; - return primary_type::compute_data_dimensions(layout_type{layout_dims}, alignment); - } - - static size_t storage_capacity(const index_storage_type &storage) { - // LVQDataset does not provide a capacity method - return storage.size(); - } -}; - -// LeanVec dataset traits for SVS -template -struct SVSStorageTraits { - using allocator_type = svs_details::SVSAllocator; - using blocked_type = svs::data::Blocked>; - using index_storage_type = svs::leanvec::LeanDataset, - svs::leanvec::UsingLVQ, - svs::Dynamic, svs::Dynamic, blocked_type>; - - static size_t check_leanvec_dim(size_t dims, size_t leanvec_dim) { - if (leanvec_dim == 0) { - return dims / 2; /* default LeanVec dimension */ - } - return leanvec_dim; - } - - static constexpr bool is_compressed() { return true; } - - static constexpr auto get_compression_mode() { - if constexpr (QuantBits == 4 && ResidualBits == 8) { - return VecSimSvsQuant_4x8_LeanVec; - } else if constexpr (QuantBits == 8 && ResidualBits == 8) { - return VecSimSvsQuant_8x8_LeanVec; - } else { - assert(false && "Unsupported quantization mode"); - return VecSimSvsQuant_NONE; // Unsupported case - } - } - - static auto make_blocked_allocator(size_t block_size, size_t dim, - std::shared_ptr allocator) { - // SVS block size is a power of two, so we can use it directly - auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dim)); - allocator_type data_allocator{std::move(allocator)}; - return svs::make_blocked_allocator_handle({svs_bs}, data_allocator); - } - - template - static index_storage_type create_storage(const Dataset &data, size_t 
block_size, Pool &pool, - std::shared_ptr allocator, - size_t leanvec_dim) { - const auto dim = data.dimensions(); - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - - return index_storage_type::reduce( - data, std::nullopt, pool, 0, - svs::lib::MaybeStatic(check_leanvec_dim(dim, leanvec_dim)), - blocked_alloc); - } - - static index_storage_type load(const svs::lib::LoadTable &table, size_t block_size, size_t dim, - std::shared_ptr allocator) { - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - return index_storage_type::load(table, /*alignment=*/0, blocked_alloc); - } - - static index_storage_type load(const std::string &path, size_t block_size, size_t dim, - std::shared_ptr allocator) { - assert(svs::data::detail::is_likely_reload(path)); // TODO implement auto_load for LeanVec - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - // Load the data from disk - return svs::lib::load_from_disk(path, /*alignment=*/0, blocked_alloc); - } - - static constexpr size_t element_size(size_t dims, size_t alignment = 0, - size_t leanvec_dim = 0) { - return SVSStorageTraits::element_size( - check_leanvec_dim(dims, leanvec_dim), alignment) + - SVSStorageTraits::element_size(dims, alignment); - } - - static size_t storage_capacity(const index_storage_type &storage) { - // LeanDataset does not provide a capacity method - return storage.size(); - } -}; -#else -#pragma message "SVS LVQ is not available" -#endif // HAVE_SVS_LVQ diff --git a/src/VecSim/algorithms/svs/svs_serializer.cpp b/src/VecSim/algorithms/svs/svs_serializer.cpp deleted file mode 100644 index 58ef82ebe..000000000 --- a/src/VecSim/algorithms/svs/svs_serializer.cpp +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#include "svs_serializer.h" - -namespace fs = std::filesystem; - -SVSSerializer::SVSSerializer(EncodingVersion version) : m_version(version) {} - -SVSSerializer::EncodingVersion SVSSerializer::ReadVersion(std::ifstream &input) { - input.seekg(0, std::ifstream::beg); - - EncodingVersion version = EncodingVersion::INVALID; - readBinaryPOD(input, version); - - if (version >= EncodingVersion::INVALID) { - input.close(); - throw std::runtime_error("Cannot load index: bad encoding version: " + - std::to_string(static_cast(version))); - } - return version; -} - -void SVSSerializer::saveIndex(const std::string &location) { - EncodingVersion version = EncodingVersion::V0; - auto metadata_path = fs::path(location) / "metadata"; - std::ofstream output(metadata_path, std::ios::binary); - writeBinaryPOD(output, version); - saveIndexIMP(output); - output.close(); - impl_save(location); -} - -SVSSerializer::EncodingVersion SVSSerializer::getVersion() const { return m_version; } diff --git a/src/VecSim/algorithms/svs/svs_serializer.h b/src/VecSim/algorithms/svs/svs_serializer.h deleted file mode 100644 index d66644364..000000000 --- a/src/VecSim/algorithms/svs/svs_serializer.h +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ - -#pragma once - -#include -#include -#include -#include "VecSim/utils/serializer.h" -#include - -typedef struct { - bool valid_state; - long memory_usage; // in bytes - size_t index_size; - size_t storage_size; - size_t label_count; - size_t capacity; - size_t changes_count; - bool is_compressed; - bool is_multi; -} SVSIndexMetaData; - -// Middle layer for SVS serialization -// Abstract functions should be implemented by the templated SVS index - -class SVSSerializer : public Serializer { -public: - enum class EncodingVersion { V0, INVALID }; - - explicit SVSSerializer(EncodingVersion version = EncodingVersion::V0); - - static EncodingVersion ReadVersion(std::ifstream &input); - - void saveIndex(const std::string &location) override; - - virtual void loadIndex(const std::string &location) = 0; - - EncodingVersion getVersion() const; - - virtual bool checkIntegrity() const = 0; - -protected: - EncodingVersion m_version; - - virtual void impl_save(const std::string &location) = 0; - - // Helper function to compare the svs index fields with the metadata file - template - static void compareField(std::istream &in, const T &expected, const std::string &fieldName); - -private: - virtual bool compareMetadataFile(const std::string &metadataFilePath) const = 0; -}; - -// Implement << operator for enum class -inline std::ostream &operator<<(std::ostream &os, SVSSerializer::EncodingVersion version) { - return os << static_cast(version); -} - -template -void SVSSerializer::compareField(std::istream &in, const T &expected, - const std::string &fieldName) { - T actual; - Serializer::readBinaryPOD(in, actual); - if (!in.good()) { - throw std::runtime_error("Failed to read field: " + fieldName); - } - if (actual != expected) { - std::ostringstream msg; - msg << "Field mismatch in \"" << fieldName << "\": expected " << expected << ", got " - << actual; - throw std::runtime_error(msg.str()); - } -} diff --git a/src/VecSim/algorithms/svs/svs_serializer_impl.h b/src/VecSim/algorithms/svs/svs_serializer_impl.h deleted file mode 100644 index 2780d3457..000000000 --- a/src/VecSim/algorithms/svs/svs_serializer_impl.h +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#pragma once - -#include "svs_serializer.h" -#include "svs/index/vamana/dynamic_index.h" -#include "svs/index/vamana/multi.h" - -// Saves all relevant fields of SVSIndex to the output stream. -// This function saves all template parameters and instance fields needed to reconstruct -// an SVSIndex. -template -void SVSIndex::saveIndexFields( - std::ofstream &output) const { - // Save base class fields from VecSimIndexAbstract - // Note: this->vecType corresponds to DataType template parameter - // Note: this->metric corresponds to MetricType template parameter - writeBinaryPOD(output, this->dim); - writeBinaryPOD(output, this->vecType); // DataType template parameter (as VecSimType enum) - writeBinaryPOD(output, this->getStoredDataSize()); - writeBinaryPOD(output, this->metric); // MetricType template parameter (as VecSimMetric enum) - writeBinaryPOD(output, this->blockSize); - writeBinaryPOD(output, this->isMulti); - - // Save SVS-specific configuration fields - writeBinaryPOD(output, this->forcePreprocessing); - - // Save build parameters - writeBinaryPOD(output, this->buildParams.alpha); - writeBinaryPOD(output, this->buildParams.graph_max_degree); - writeBinaryPOD(output, this->buildParams.window_size); - writeBinaryPOD(output, this->buildParams.max_candidate_pool_size); - writeBinaryPOD(output, this->buildParams.prune_to); - writeBinaryPOD(output, this->buildParams.use_full_search_history); - - // Save search parameters - writeBinaryPOD(output, this->search_window_size); - writeBinaryPOD(output, this->epsilon); - - // Save template parameters as metadata for validation during loading - writeBinaryPOD(output, getCompressionMode()); - - // QuantBits, ResidualBits, and IsLeanVec information - - // Save additional template parameter constants for complete reconstruction - writeBinaryPOD(output, static_cast(QuantBits)); // Template parameter QuantBits - writeBinaryPOD(output, static_cast(ResidualBits)); // Template parameter ResidualBits - writeBinaryPOD(output, static_cast(IsLeanVec)); // Template parameter IsLeanVec - writeBinaryPOD(output, static_cast(isMulti)); // Template parameter isMulti - - // Save additional metadata for validation during loading - writeBinaryPOD(output, this->lastMode); // Last search mode -} - -// Saves the index field metadata to satisfy the Serializer interface. -// The full index is saved separately in saveIndex() using file paths. -template -void SVSIndex::saveIndexIMP( - std::ofstream &output) { - - // Save all index fields using the dedicated function - saveIndexFields(output); -} - -// Saves the SVS index implementation itself (config, graph, and data files) -// under the given location. -template -void SVSIndex::impl_save( - const std::string &location) { - impl_->save(location + "/config", location + "/graph", location + "/data"); -} - -// This function loads the serialized SVS index from the given folder path. -// It should be called after the index is created with the same parameters as the -// original index. The index fields and template parameters will be validated before loading. After -// successful loading, the graph can be validated with checkIntegrity.
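Together with svs_serializer.cpp above, these functions implement the save/load round trip. A minimal sketch, assuming `source` and `target` are SVSIndex instances created with identical template and build parameters:

    // Save: SVSSerializer::saveIndex() writes the metadata file, then
    // impl_save() persists the config/graph/data files under the folder.
    source->saveIndex("/tmp/svs_dump");

    // Load: compareMetadataFile() throws on any field mismatch before the
    // graph and data are assembled; checkIntegrity() can then validate
    // the reloaded graph.
    target->loadIndex("/tmp/svs_dump");
    assert(target->checkIntegrity());

loadIndex(), defined next, implements the load side of this flow.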
-template -void SVSIndex::loadIndex( - const std::string &folder_path) { - svs::threads::ThreadPoolHandle threadpool_handle{VecSimSVSThreadPool{threadpool_}}; - - // Verify metadata compatibility, will throw runtime exception if not compatible - compareMetadataFile(folder_path + "/metadata"); - - if constexpr (isMulti) { - auto loaded = svs::index::vamana::auto_multi_dynamic_assemble( - folder_path + "/config", - SVS_LAZY(graph_builder_t::load(folder_path + "/graph", this->blockSize, - this->buildParams, this->getAllocator())), - SVS_LAZY(storage_traits_t::load(folder_path + "/data", this->blockSize, this->dim, - this->getAllocator())), - distance_f(), std::move(threadpool_handle), - svs::index::vamana::MultiMutableVamanaLoad::FROM_MULTI, logger_); - impl_ = std::make_unique(std::move(loaded)); - } else { - auto loaded = svs::index::vamana::auto_dynamic_assemble( - folder_path + "/config", - SVS_LAZY(graph_builder_t::load(folder_path + "/graph", this->blockSize, - this->buildParams, this->getAllocator())), - SVS_LAZY(storage_traits_t::load(folder_path + "/data", this->blockSize, this->dim, - this->getAllocator())), - distance_f(), std::move(threadpool_handle), false, logger_); - impl_ = std::make_unique(std::move(loaded)); - } -} - -template -bool SVSIndex::compareMetadataFile(const std::string &metadataFilePath) const { - std::ifstream input(metadataFilePath, std::ios::binary); - if (!input.is_open()) { - throw std::runtime_error("Failed to open metadata file: " + metadataFilePath); - } - - // To check version, use ReadVersion - SVSSerializer::ReadVersion(input); - - compareField(input, this->dim, "dim"); - compareField(input, this->vecType, "vecType"); - compareField(input, this->getStoredDataSize(), "dataSize"); - compareField(input, this->metric, "metric"); - compareField(input, this->blockSize, "blockSize"); - compareField(input, this->isMulti, "isMulti"); - - compareField(input, this->forcePreprocessing, "forcePreprocessing"); - - compareField(input, this->buildParams.alpha, "buildParams.alpha"); - compareField(input, this->buildParams.graph_max_degree, "buildParams.graph_max_degree"); - compareField(input, this->buildParams.window_size, "buildParams.window_size"); - compareField(input, this->buildParams.max_candidate_pool_size, - "buildParams.max_candidate_pool_size"); - compareField(input, this->buildParams.prune_to, "buildParams.prune_to"); - compareField(input, this->buildParams.use_full_search_history, - "buildParams.use_full_search_history"); - - compareField(input, this->search_window_size, "search_window_size"); - compareField(input, this->epsilon, "epsilon"); - - auto compressionMode = getCompressionMode(); - compareField(input, compressionMode, "compression_mode"); - - compareField(input, static_cast(QuantBits), "QuantBits"); - compareField(input, static_cast(ResidualBits), "ResidualBits"); - compareField(input, static_cast(IsLeanVec), "IsLeanVec"); - compareField(input, static_cast(isMulti), "isMulti (template param)"); - - return true; -} - -template -bool SVSIndex::checkIntegrity() - const { - if (!impl_) { - throw std::runtime_error( - "SVSIndex integrity check failed: index implementation (impl_) is null."); - } - - try { - // SVS internal index integrity validation - if constexpr (isMulti) { - impl_->get_parent_index().debug_check_invariants(true); - } else { - impl_->debug_check_invariants(true); - } - } - // debug_check_invariants throws svs::lib::ANNException : public std::runtime_error in case of - // fail. - catch (...) 
{ - throw; - } - - try { - size_t index_size = impl_->size(); - size_t storage_size = impl_->view_data().size(); - size_t capacity = storage_traits_t::storage_capacity(impl_->view_data()); - size_t label_count = this->indexLabelCount(); - - // Storage size must match index size - if (storage_size != index_size) { - throw std::runtime_error( - "SVSIndex integrity check failed: storage_size != index_size."); - } - - // Capacity must be at least index size - if (capacity < index_size) { - throw std::runtime_error("SVSIndex integrity check failed: capacity < index_size."); - } - - // Binary label validation: verify label iteration and count consistency - size_t labels_counted = 0; - bool label_validation_passed = true; - - try { - impl_->on_ids([&](size_t label) { labels_counted++; }); - - // Validate label count consistency - label_validation_passed = (labels_counted == label_count); - - // For multi-index, also ensure label count doesn't exceed index size - if constexpr (isMulti) { - label_validation_passed = label_validation_passed && (label_count <= index_size); - } - } catch (...) { - label_validation_passed = false; - } - - if (!label_validation_passed) { - throw std::runtime_error("SVSIndex integrity check failed: label validation failed."); - } - - return true; - - } catch (const std::exception &e) { - throw std::runtime_error(std::string("SVSIndex integrity check failed with exception: ") + - e.what()); - } catch (...) { - throw std::runtime_error("SVSIndex integrity check failed with unknown exception."); - } -} diff --git a/src/VecSim/algorithms/svs/svs_tiered.h b/src/VecSim/algorithms/svs/svs_tiered.h deleted file mode 100644 index 351122367..000000000 --- a/src/VecSim/algorithms/svs/svs_tiered.h +++ /dev/null @@ -1,1045 +0,0 @@ -#pragma once -#include "VecSim/vec_sim_common.h" -#include "VecSim/algorithms/brute_force/brute_force_single.h" -#include "VecSim/vec_sim_tiered_index.h" -#include "VecSim/algorithms/svs/svs.h" -#include "VecSim/index_factories/svs_factory.h" - -#include -#include -#include -#include -#include -#include -#include - -/** - * @class SVSMultiThreadJob - * @brief Represents a multi-threaded asynchronous job for the SVS algorithm. - * - * This class is responsible for managing multi-threaded jobs, including thread reservation, - * synchronization, and execution of tasks. It uses a control block to coordinate threads - * and ensure proper execution of the job. - * - * @details - * The SVSMultiThreadJob class supports creating multiple threads for a task and ensures - * synchronization between them. It uses a nested ControlBlock class to manage thread - * reservations and job completion. Additionally, it includes a nested ReserveThreadJob - * class to handle individual thread reservations. - * - * The main job executes a user-defined task with the number of reserved threads, while - * additional threads wait for the main job to complete. - * - * @note This class is designed to work with the AsyncJob framework. 
- */ -class SVSMultiThreadJob : public AsyncJob { -public: - class JobsRegistry { - vecsim_stl::unordered_set jobs; - std::mutex m_jobs; - - public: - JobsRegistry(const std::shared_ptr &allocator) : jobs(allocator) {} - - ~JobsRegistry() { - std::lock_guard lock{m_jobs}; - for (auto job : jobs) { - delete job; - } - jobs.clear(); - } - - void register_jobs(const vecsim_stl::vector &jobs) { - std::lock_guard lock{m_jobs}; - this->jobs.insert(jobs.begin(), jobs.end()); - } - - void delete_job(AsyncJob *job) { - { - std::lock_guard lock{m_jobs}; - jobs.erase(job); - } - delete job; - } - }; - -private: - // Thread reservation control block shared between all threads - // to reserve threads and wait for the job to be done - // actual reserved threads can be less than requested if timeout is reached - class ControlBlock { - const size_t requestedThreads; // number of threads requested to reserve - const std::chrono::microseconds timeout; // timeout for threads reservation - size_t reservedThreads; // number of threads reserved - bool jobDone; - std::mutex m_reserve; - std::condition_variable cv_reserve; - std::mutex m_done; - std::condition_variable cv_done; - - public: - template - ControlBlock(size_t requested_threads, - std::chrono::duration threads_wait_timeout) - : requestedThreads{requested_threads}, timeout{threads_wait_timeout}, - reservedThreads{0}, jobDone{false} {} - - // reserve a thread and wait for the job to be done - void reserveThreadAndWait() { - // count current thread - { - std::unique_lock lock{m_reserve}; - ++reservedThreads; - } - cv_reserve.notify_one(); - std::unique_lock lock{m_done}; - // Wait until the job is marked as done, handling potential spurious wakeups. - cv_done.wait(lock, [&] { return jobDone; }); - } - - // wait for threads to be reserved - // return actual number of reserved threads - size_t waitForThreads() { - std::unique_lock lock{m_reserve}; - ++reservedThreads; // count current thread - cv_reserve.wait_for(lock, timeout, [&] { return reservedThreads >= requestedThreads; }); - return reservedThreads; - } - - // mark the whole job as done - void markJobDone() { - { - std::lock_guard lock{m_done}; - jobDone = true; - } - cv_done.notify_all(); - } - }; - - // Job to reserve a thread and wait for the job to be done - class ReserveThreadJob : public AsyncJob { - std::weak_ptr controlBlock; // control block is owned by the main job and can - // be destroyed before this job is started - JobsRegistry *jobsRegistry; - - static void ExecuteReserveThreadImpl(AsyncJob *job) { - auto *jobPtr = static_cast(job); - // if control block is already destroyed by the update job, just delete the job - auto controlBlock = jobPtr->controlBlock.lock(); - if (controlBlock) { - controlBlock->reserveThreadAndWait(); - } - jobPtr->jobsRegistry->delete_job(job); - } - - public: - ReserveThreadJob(std::shared_ptr allocator, JobType jobType, - VecSimIndex *index, std::weak_ptr controlBlock, - JobsRegistry *registry) - : AsyncJob(std::move(allocator), jobType, ExecuteReserveThreadImpl, index), - controlBlock(std::move(controlBlock)), jobsRegistry(registry) {} - }; - - using task_type = std::function; - task_type task; - std::shared_ptr controlBlock; - JobsRegistry *jobsRegistry; - - static void ExecuteMultiThreadJobImpl(AsyncJob *job) { - auto *jobPtr = static_cast(job); - auto controlBlock = jobPtr->controlBlock; - size_t num_threads = 1; - if (controlBlock) { - num_threads = controlBlock->waitForThreads(); - } - assert(num_threads > 0); - jobPtr->task(jobPtr->index, num_threads); 
- if (controlBlock) { - jobPtr->controlBlock->markJobDone(); - } - jobPtr->jobsRegistry->delete_job(job); - } - - SVSMultiThreadJob(std::shared_ptr allocator, JobType jobType, - task_type callback, VecSimIndex *index, - std::shared_ptr controlBlock, JobsRegistry *registry) - : AsyncJob(std::move(allocator), jobType, ExecuteMultiThreadJobImpl, index), - task(std::move(callback)), controlBlock(std::move(controlBlock)), jobsRegistry(registry) { - } - -public: - template - static vecsim_stl::vector<AsyncJob *> - createJobs(const std::shared_ptr &allocator, JobType jobType, - std::function callback, VecSimIndex *index, - size_t num_threads, std::chrono::duration threads_wait_timeout, - JobsRegistry *registry) { - assert(num_threads > 0); - std::shared_ptr controlBlock = - num_threads == 1 ? nullptr - : std::make_shared(num_threads, threads_wait_timeout); - - vecsim_stl::vector<AsyncJob *> jobs(num_threads, allocator); - jobs[0] = new (allocator) - SVSMultiThreadJob(allocator, jobType, callback, index, controlBlock, registry); - for (size_t i = 1; i < num_threads; ++i) { - jobs[i] = - new (allocator) ReserveThreadJob(allocator, jobType, index, controlBlock, registry); - } - registry->register_jobs(jobs); - return jobs; - } - -#ifdef BUILD_TESTS -public: - static constexpr size_t estimateSize(size_t num_threads) { - return sizeof(SVSMultiThreadJob) + (num_threads - 1) * sizeof(ReserveThreadJob); - } -#endif -}; - -template <typename DataType> -class TieredSVSIndex : public VecSimTieredIndex { - using Self = TieredSVSIndex; - using Base = VecSimTieredIndex; - using flat_index_t = BruteForceIndex; - using backend_index_t = VecSimIndexAbstract; - using svs_index_t = SVSIndexBase; - - // swaps_journal is used by updateSVSIndex() to track vector swap operations that were done in - // the Flat index during SVS index updating. - // The journal contains tuples of (label, oldId, newId). - // oldId is the index of the label in flat index before the swap. - // newId is the index of the label in flat index after the swap. - // if oldId == newId, it means that the vector was not moved in the Flat index, but was removed - // from the end of the Flat index (no id swaps occurred internally). - // if label == SKIP_LABEL, it means that the vector was not moved in the Flat index, but - // updated in-place (hence was already removed and no need to remove it again) - using swap_record = std::tuple; - constexpr static size_t SKIP_LABEL = std::numeric_limits<size_t>::max(); - std::vector<swap_record> swaps_journal; - - size_t trainingTriggerThreshold; - size_t updateTriggerThreshold; - size_t updateJobWaitTime; - // Used to prevent scheduling multiple index update jobs at the same time. - // Since the update job performs a batch update, the job queue should hold at most one such job at a time. - std::atomic_flag indexUpdateScheduled = ATOMIC_FLAG_INIT; - // Used to prevent scheduling multiple index GC jobs at the same time. - std::atomic_flag indexGCScheduled = ATOMIC_FLAG_INIT; - // Used to prevent running multiple index update jobs in parallel. - // Even if update jobs are scheduled sequentially, they can start running in parallel.
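For orientation, the scheduling flow guarded by these flags looks roughly like this (a sketch with invented values; job-queue mechanics belong to the AsyncJob framework):

    // createJobs() returns one main job plus (num_threads - 1) reservation jobs.
    auto jobs = SVSMultiThreadJob::createJobs(
        allocator, SVS_BATCH_UPDATE_JOB,
        [](VecSimIndex *idx, size_t threads) { /* batch work using `threads` */ },
        index, /*num_threads=*/4, std::chrono::microseconds(100), &registry);
    // Only jobs[0] executes the task; the other three reserve a thread each and
    // block on the control block until the task calls markJobDone().

The updateJobMutex member that follows provides the in-parallel protection mentioned above.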
- mutable std::mutex updateJobMutex; - - // This container exists only so that jobs which were never executed are properly destroyed - SVSMultiThreadJob::JobsRegistry uncompletedJobs; - - /// - //////////////////////////////////////////////////////////////////////////////////////////////////// - // TieredSVS_BatchIterator // - //////////////////////////////////////////////////////////////////////////////////////////////////// - - class TieredSVS_BatchIterator : public VecSimBatchIterator { - // Defining special values for the svs_iterator field, to indicate if the iterator is - // uninitialized or depleted when we don't have a valid iterator. - static constexpr VecSimBatchIterator *depleted() { - constexpr VecSimBatchIterator *p = nullptr; - return p + 1; - } - - private: - using Index = TieredSVSIndex; - const Index *index; - VecSimQueryParams *queryParams; - - VecSimQueryResultContainer flat_results; - VecSimQueryResultContainer svs_results; - - VecSimBatchIterator *flat_iterator; - VecSimBatchIterator *svs_iterator; - std::shared_lock svs_lock; - - // On single value indices, this set holds the IDs of the results that were returned from - // the flat buffer. - // On multi value indices, this set holds the IDs of all the results that were returned. - // The difference between the two cases is that on multi value indices, the same ID can - // appear in both indexes and results with different scores, and therefore we can't tell in - // advance when we expect a possibility of a duplicate. - // On single value indices, a duplicate may appear in the same batch (and we will handle it - // when merging the results), or it may appear in different batches, first from the flat - // buffer and then from the SVS, in cases where a better result is found later in SVS - // because of the approximate nature of the algorithm. - vecsim_stl::unordered_set returned_results_set; - - VecSimQueryReply *compute_current_batch(size_t n_res, bool isMultiValue) { - // Merge results - // This call will update `svs_results` and `flat_results` to point to the end of the merged - // results. - auto batch_res = new VecSimQueryReply(allocator); - // VecSim and SVS distance computation is implemented differently, so we always have to - // merge results with a set. - auto [from_svs, from_flat] = - merge_results(batch_res->results, svs_results, flat_results, n_res); - - if (!isMultiValue) { - // If we're on a single-value index, update the set of results returned from the - // FLAT index before popping them, to prevent them from being returned from the SVS index - // in later batches. - for (size_t i = 0; i < from_flat; ++i) { - this->returned_results_set.insert(this->flat_results[i].id); - } - } else { - // If we're on a multi-value index, update the set of results returned (from - // `batch_res`) - for (size_t i = 0; i < batch_res->results.size(); ++i) { - this->returned_results_set.insert(batch_res->results[i].id); - } - } - - // Update results - flat_results.erase(flat_results.begin(), flat_results.begin() + from_flat); - svs_results.erase(svs_results.begin(), svs_results.begin() + from_svs); - - // clean up the results - // On multi-value indexes, one (or both) results lists may contain results that are - // already returned from the other list (with a different score). We need to filter them - // out.
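A small worked example of that filtering (invented values, multi-value index):

    // a batch returns {id 7, score 0.10}  ->  7 enters returned_results_set
    // svs_results still holds {id 7, score 0.12} from an earlier fetch
    // filter_irrelevant_results() drops it, so id 7 is never reported twice

The calls below apply exactly this filtering to both result lists.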
- if (isMultiValue) { - filter_irrelevant_results(this->flat_results); - filter_irrelevant_results(this->svs_results); - } - - // Return current batch - return batch_res; - } - - void filter_irrelevant_results(VecSimQueryResultContainer &results) { - // Filter out results that were already returned. - const auto it = std::remove_if(results.begin(), results.end(), [this](const auto &r) { - return returned_results_set.count(r.id) != 0; - }); - results.erase(it, results.end()); - } - - void acquire_svs_iterator() { - assert(svs_iterator == nullptr); - this->index->mainIndexGuard.lock_shared(); - svs_iterator = index->backendIndex->newBatchIterator( - this->flat_iterator->getQueryBlob(), queryParams); - } - - void release_svs_iterator() { - if (svs_iterator != nullptr && svs_iterator != depleted()) { - delete svs_iterator; - svs_iterator = nullptr; - this->index->mainIndexGuard.unlock_shared(); - } - } - - void handle_svs_depletion() { - assert(svs_iterator != depleted()); - if (svs_iterator->isDepleted()) { - release_svs_iterator(); - svs_iterator = depleted(); - } - } - - public: - TieredSVS_BatchIterator(const void *query_vector, const Index *index, - VecSimQueryParams *queryParams, - std::shared_ptr allocator) - // Tiered batch iterator doesn't hold its own copy of the query vector. - // Instead, each internal batch iterator (flat_iterator and svs_iterator) creates its - // own copy: the flat_iterator copy is created during TieredSVS_BatchIterator - // construction. When TieredSVS_BatchIterator::getNextResults() is called and - // svs_iterator is not yet initialized, it retrieves the blob from flat_iterator. - : VecSimBatchIterator(nullptr, queryParams ? queryParams->timeoutCtx : nullptr, - std::move(allocator)), - index(index), flat_results(this->allocator), svs_results(this->allocator), - flat_iterator(index->frontendIndex->newBatchIterator(query_vector, queryParams)), - svs_iterator(nullptr), svs_lock(index->mainIndexGuard, std::defer_lock), - returned_results_set(this->allocator) { - if (queryParams) { - this->queryParams = - (VecSimQueryParams *)this->allocator->allocate(sizeof(VecSimQueryParams)); - *this->queryParams = *queryParams; - } else { - this->queryParams = nullptr; - } - } - - ~TieredSVS_BatchIterator() { - release_svs_iterator(); - if (queryParams) { - this->allocator->free_allocation(queryParams); - } - delete flat_iterator; - } - - VecSimQueryReply *getNextResults(size_t n_res, VecSimQueryReply_Order order) override { - auto svs_code = VecSim_QueryReply_OK; - - const bool isMulti = this->index->backendIndex->isMultiValue(); - if (svs_iterator == nullptr) { // first call - // First call to getNextResults. The call to the BF iterator will include - // calculating all the distances and accessing the BF index. We take the lock on this - // call. - auto cur_flat_results = [this, n_res]() { - std::shared_lock flat_lock{index->flatIndexGuard}; - return flat_iterator->getNextResults(n_res, BY_SCORE_THEN_ID); - }(); - // This is also the only time `getNextResults` on the BF iterator can fail. - if (VecSim_OK != cur_flat_results->code) { - return cur_flat_results; - } - flat_results.swap(cur_flat_results->results); - VecSimQueryReply_Free(cur_flat_results); - // We also take the lock on the main index on the first call to getNextResults, and - // we hold it until the iterator is depleted or freed.
- acquire_svs_iterator(); - auto cur_svs_results = svs_iterator->getNextResults(n_res, BY_SCORE_THEN_ID); - svs_code = cur_svs_results->code; - svs_results.swap(cur_svs_results->results); - VecSimQueryReply_Free(cur_svs_results); - handle_svs_depletion(); - } else { - while (flat_results.size() < n_res && !flat_iterator->isDepleted()) { - auto tail = flat_iterator->getNextResults(n_res - flat_results.size(), - BY_SCORE_THEN_ID); - flat_results.insert(flat_results.end(), tail->results.begin(), - tail->results.end()); - VecSimQueryReply_Free(tail); - - if (!isMulti) { - // On single-value indexes, duplicates will never appear in the SVS results - // before they appear in the flat results (at the same time or later if the - // approximation misses) so we don't need to try and filter the flat results - // (and recheck conditions). - break; - } else { - // On multi-value indexes, the flat results may contain results that are - // already returned from the SVS index. We need to filter them out. - filter_irrelevant_results(this->flat_results); - } - } - - while (svs_results.size() < n_res && svs_iterator != depleted() && - svs_code == VecSim_OK) { - auto tail = - svs_iterator->getNextResults(n_res - svs_results.size(), BY_SCORE_THEN_ID); - svs_code = - tail->code; // Set the svs_results code to the last `getNextResults` code. - // New batch may contain better results than the previous batch, so we need to - // merge. We don't expect duplications here, as the iterator - // guarantees that no result is returned twice. - VecSimQueryResultContainer cur_svs_results(this->allocator); - merge_results(cur_svs_results, svs_results, tail->results, n_res); - VecSimQueryReply_Free(tail); - svs_results.swap(cur_svs_results); - filter_irrelevant_results(svs_results); - handle_svs_depletion(); - } - } - - if (VecSim_OK != svs_code) { - return new VecSimQueryReply(this->allocator, svs_code); - } - - VecSimQueryReply *batch = compute_current_batch(n_res, isMulti); - - if (order == BY_ID) { - sort_results_by_id(batch); - } - size_t batch_len = VecSimQueryReply_Len(batch); - this->updateResultsCount(batch_len); - - return batch; - } - - // DISCLAIMER: After the last batch, one of the iterators may report that it is not - // depleted, while all of its remaining results were already returned from the other - // iterator. (On single-value indexes, this can happen to the svs iterator only, on - // multi-value indexes, this can happen to both iterators). - // The next call to `getNextResults` will return an empty batch, and then the iterators will - // correctly report that they are depleted.
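A typical consumption loop over this iterator, consistent with the disclaimer above (a sketch; `it` is a TieredSVS_BatchIterator, `process` is a hypothetical consumer, and error handling is elided):

    while (!it->isDepleted()) {
        VecSimQueryReply *batch = it->getNextResults(n_res, BY_SCORE_THEN_ID);
        // The final batch may be empty; only after it is returned do both
        // underlying iterators report depletion.
        process(batch); // hypothetical consumer
        VecSimQueryReply_Free(batch);
    }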
- bool isDepleted() override { - return flat_results.empty() && flat_iterator->isDepleted() && svs_results.empty() && - svs_iterator == depleted(); - } - - void reset() override { - release_svs_iterator(); - resetResultsCount(); - flat_iterator->reset(); - svs_iterator = nullptr; - flat_results.clear(); - svs_results.clear(); - returned_results_set.clear(); - } - }; - - /// - -#ifdef BUILD_TESTS -public: -#endif - flat_index_t *GetFlatIndex() { - auto result = dynamic_cast<flat_index_t *>(this->frontendIndex); - assert(result); - return result; - } - - svs_index_t *GetSVSIndex() const { - auto result = dynamic_cast<svs_index_t *>(this->backendIndex); - assert(result); - return result; - } - -#ifdef BUILD_TESTS -public: - backend_index_t *GetBackendIndex() { return this->backendIndex; } - void submitSingleJob(AsyncJob *job) { Base::submitSingleJob(job); } - void submitJobs(vecsim_stl::vector<AsyncJob *> &jobs) { Base::submitJobs(jobs); } - - // Tracing helpers can be used to trace/inject code in the index update process. - std::map<std::string, std::function<void()>> tracingCallbacks; - void registerTracingCallback(const std::string &name, std::function<void()> callback) { - tracingCallbacks[name] = std::move(callback); - } - void executeTracingCallback(const std::string &name) const { - auto it = tracingCallbacks.find(name); - if (it != tracingCallbacks.end()) { - it->second(); - } - } - size_t indexMetaDataCapacity() const override { - std::shared_lock flat_lock(this->flatIndexGuard); - std::shared_lock main_lock(this->mainIndexGuard); - return this->frontendIndex->indexMetaDataCapacity() + - this->backendIndex->indexMetaDataCapacity(); - } -#else - void executeTracingCallback(const std::string &) const { - // In production, we do nothing. - } -#endif - -private: - /** - * @brief Updates the SVS index in a thread-safe manner. - * - * This static wrapper function performs the following actions: - * - Acquires a lock on the index's updateJobMutex to prevent concurrent updates. - * - Clears the indexUpdateScheduled flag to allow future scheduling. - * - Configures the number of threads for the underlying SVS index update operation. - * - Calls the updateSVSIndex method to perform the actual index update. - * - * @param idx Pointer to the VecSimIndex to be updated. - * @param availableThreads The number of threads available for the update operation. The - * current thread is used as well, so the minimal value is 1. - */ - static void updateSVSIndexWrapper(VecSimIndex *idx, size_t availableThreads) { - assert(availableThreads > 0); - auto index = static_cast<TieredSVSIndex *>(idx); - assert(index); - // prevent parallel updates - std::lock_guard lock(index->updateJobMutex); - // Release the scheduled flag to allow scheduling again - index->indexUpdateScheduled.clear(); - // Update the SVS index - index->updateSVSIndex(availableThreads); - } - - /** - * @brief Run SVS index GC in a thread-safe manner. - * - * This static wrapper function performs the following actions: - * - Acquires a lock on the index's mainIndexGuard to ensure thread safety during the GC. - * - Configures the number of threads for the underlying SVS index GC operation. - * - Calls the SVSIndex::runGC() method to perform the actual garbage collection. - * - Clears the indexGCScheduled flag to allow future scheduling. - * - * @param idx Pointer to the VecSimIndex to be updated. - * @param availableThreads The number of threads available for the GC operation. The - * current thread is used as well, so the minimal value is 1. - * @note no need to implement an extra non-static method, as the GC logic is simple enough to be done - * here.
- */ - static void SVSIndexGCWrapper(VecSimIndex *idx, size_t availableThreads) { - assert(availableThreads > 0); - auto index = static_cast<TieredSVSIndex *>(idx); - assert(index); - - std::lock_guard lock{index->mainIndexGuard}; - // Release the scheduled flag to allow scheduling again - index->indexGCScheduled.clear(); - - // Do SVS index GC - index->backendIndex->log(VecSimCommonStrings::LOG_VERBOSE_STRING, - "running asynchronous GC for tiered SVS index"); - auto svs_index = index->GetSVSIndex(); - if (index->backendIndex->indexSize() == 0) { - // No need to run GC on an empty index. - return; - } - svs_index->setNumThreads(std::min(availableThreads, index->backendIndex->indexSize())); - // VecSimIndexAbstract::runGC() is protected - static_cast(index->backendIndex)->runGC(); - } - -#ifdef BUILD_TESTS -public: -#endif - void scheduleSVSIndexUpdate() { - // do not schedule if scheduled already - if (indexUpdateScheduled.test_and_set()) { - return; - } - - auto total_threads = this->GetSVSIndex()->getThreadPoolCapacity(); - auto jobs = SVSMultiThreadJob::createJobs( - this->allocator, SVS_BATCH_UPDATE_JOB, updateSVSIndexWrapper, this, total_threads, - std::chrono::microseconds(updateJobWaitTime), &uncompletedJobs); - this->submitJobs(jobs); - } - - void scheduleSVSIndexGC() { - // do not schedule if scheduled already - if (indexGCScheduled.test_and_set()) { - return; - } - - auto total_threads = this->GetSVSIndex()->getThreadPoolCapacity(); - auto jobs = SVSMultiThreadJob::createJobs( - this->allocator, SVS_GC_JOB, SVSIndexGCWrapper, this, total_threads, - std::chrono::microseconds(updateJobWaitTime), &uncompletedJobs); - this->submitJobs(jobs); - } - -private: - static void applySwapsToLabelsArray(std::vector<labelType> &labels, - const std::vector<std::tuple<labelType, idType, idType>> &swaps) { - // Enumerate the journal and reflect the swaps in the labels array. - // The journal contains tuples of (label, oldId, newId). - // oldId is the index of the label in the flat index before the swap. - // newId is the index of the label in the flat index after the swap. - for (const auto &p : swaps) { - auto oldId = std::get<1>(p); - auto newId = std::get<2>(p); - - if (oldId == newId || oldId >= labels.size()) { - // If oldId == newId, it means that the vector was not moved in the flat index, - // but was removed or updated in-place. - // If oldId is out of bounds, a new vector was added and swapped in the meantime. - // In both cases, we should not touch the flat index at the new position. - if (newId < labels.size()) { - labels[newId] = SKIP_LABEL; - } - continue; // Next swap record. - } - - // Real swap case. - // If oldId != newId and oldId is in bounds, it means that the tracked vector was moved. - // So, move the label to the new position and skip the old position from deletions. - // NOTE: labels[oldId] can be SKIP_LABEL, but we still need to move it to newId. - labels[newId] = labels[oldId]; - labels[oldId] = SKIP_LABEL; - } - }
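To make the journal semantics concrete, here is a small worked example of the replay logic above. The aliases `labelType`, `idType` and the `SKIP_LABEL` sentinel mimic the library's types, and the concrete journal records are made up for illustration:

```cpp
#include <cstdint>
#include <tuple>
#include <vector>

using labelType = uint64_t;
using idType = uint32_t;
constexpr labelType SKIP_LABEL = ~labelType{0}; // sentinel: "do not delete this slot"

// Same replay logic as applySwapsToLabelsArray, applied to a concrete journal.
void applySwaps(std::vector<labelType> &labels,
                const std::vector<std::tuple<labelType, idType, idType>> &swaps) {
    for (const auto &rec : swaps) {
        auto oldId = std::get<1>(rec);
        auto newId = std::get<2>(rec);
        if (oldId == newId || oldId >= labels.size()) {
            if (newId < labels.size())
                labels[newId] = SKIP_LABEL; // removed or updated in place
            continue;
        }
        labels[newId] = labels[oldId]; // the tracked vector moved
        labels[oldId] = SKIP_LABEL;
    }
}

int main() {
    // Snapshot taken by the update job: ids 0..3 hold labels 10, 11, 12, 13.
    std::vector<labelType> labels = {10, 11, 12, 13};
    // Meanwhile, label 11 (id 1) was deleted and the last vector (label 13,
    // id 3) was swapped into its slot: journal record (13, 3, 1). Label 10
    // was then updated in place: record (SKIP_LABEL, 0, 0).
    std::vector<std::tuple<labelType, idType, idType>> journal = {
        {13, 3, 1}, {SKIP_LABEL, 0, 0}};
    applySwaps(labels, journal);
    // labels is now {SKIP, 13, 12, SKIP}: only labels 13 and 12 should still
    // be deleted from the frontend index, at their updated positions.
}
```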
- - void updateSVSIndex(size_t availableThreads) { - std::vector<labelType> labels_to_move; - std::vector<DataType> vectors_to_move; - - { // lock frontendIndex from modifications - std::shared_lock flat_lock{this->flatIndexGuard}; - - auto flat_index = this->GetFlatIndex(); - const auto frontend_index_size = this->frontendIndex->indexSize(); - const size_t dim = flat_index->getDim(); - labels_to_move.reserve(frontend_index_size); - vectors_to_move.reserve(frontend_index_size * dim); - - for (idType i = 0; i < frontend_index_size; i++) { - labels_to_move.push_back(flat_index->getVectorLabel(i)); - auto data = flat_index->getDataByInternalId(i); - vectors_to_move.insert(vectors_to_move.end(), data, data + dim); - } - // reset journal to the current frontend index state - swaps_journal.clear(); - } // release frontend index - - executeTracingCallback("UpdateJob::before_add_to_svs"); - { // lock backend index for writing and add vectors there - std::lock_guard lock(this->mainIndexGuard); - auto svs_index = GetSVSIndex(); - assert(labels_to_move.size() == vectors_to_move.size() / this->frontendIndex->getDim()); - svs_index->setNumThreads(std::min(availableThreads, labels_to_move.size())); - svs_index->addVectors(vectors_to_move.data(), labels_to_move.data(), - labels_to_move.size()); - } - executeTracingCallback("UpdateJob::after_add_to_svs"); - // clean-up frontend index - { // lock frontend index for writing and delete moved vectors - std::lock_guard lock(this->flatIndexGuard); - - // Apply swaps from the journal to labels_to_move to reflect changes made in the meantime. - applySwapsToLabelsArray(labels_to_move, this->swaps_journal); - - // Delete vectors from the frontend index in reverse order; this increases the chance - // of avoiding swaps in the frontend index and improves performance. - int deleted = 0; - idType id = labels_to_move.size(); - while (id-- > 0) { - auto label = labels_to_move[id]; - // Delete the vector from the frontend index if not in-place updated. - if (label != SKIP_LABEL) { - deleted += this->frontendIndex->deleteVectorById(label, id); - } - } - assert(deleted == std::count_if(labels_to_move.begin(), labels_to_move.end(), - [](labelType label) { return label != SKIP_LABEL; }) && - "Deleted vectors count does not match the number of labels to delete"); - } - } - -public: - TieredSVSIndex(VecSimIndexAbstract *svs_index, flat_index_t *bf_index, - const TieredIndexParams &tiered_index_params, - std::shared_ptr<VecSimAllocator> allocator) - : Base(svs_index, bf_index, tiered_index_params, allocator), - uncompletedJobs(this->allocator) { - const auto &tiered_svs_params = tiered_index_params.specificParams.tieredSVSParams; - - // If flatBufferLimit is not initialized (0), use the default update threshold. - const size_t flat_buffer_bound = tiered_index_params.flatBufferLimit == 0 - ? SVS_VAMANA_DEFAULT_UPDATE_THRESHOLD - : tiered_index_params.flatBufferLimit; - - this->updateTriggerThreshold = - tiered_svs_params.updateTriggerThreshold == 0 - ? SVS_VAMANA_DEFAULT_UPDATE_THRESHOLD - : std::min({tiered_svs_params.updateTriggerThreshold, flat_buffer_bound, - static_cast<size_t>(SVS_VAMANA_DEFAULT_UPDATE_THRESHOLD)}); - - const size_t default_training_threshold = this->GetSVSIndex()->isCompressed() - ? SVS_VAMANA_DEFAULT_TRAINING_THRESHOLD - : this->updateTriggerThreshold; - - this->trainingTriggerThreshold = - tiered_svs_params.trainingTriggerThreshold == 0 - ?
default_training_threshold - : std::min(tiered_svs_params.trainingTriggerThreshold, SVS_MAX_TRAINING_THRESHOLD); - - this->updateJobWaitTime = tiered_svs_params.updateJobWaitTime == 0 - ? SVS_DEFAULT_UPDATE_JOB_WAIT_TIME - : tiered_svs_params.updateJobWaitTime; - - // Reserve space for the journal to avoid reallocation. - this->swaps_journal.reserve(this->trainingTriggerThreshold); - } - - int addVector(const void *blob, labelType label) override { - int ret = 0; - auto svs_index = GetSVSIndex(); - size_t update_threshold = 0; - size_t frontend_index_size = 0; - - // In-place mode - add the vector synchronously to the backend index. - if (this->getWriteMode() == VecSim_WriteInPlace) { - // It is ok to lock everything at once for in-place mode, - // but we will have to unlock everything before calling updateSVSIndexWrapper(), - // so take the minimal needed lock here. - std::shared_lock backend_shared_lock(this->mainIndexGuard); - // Backend index initialization data has to be buffered for proper - // compression/training. - if (this->backendIndex->indexSize() == 0) { - // If the backend index size is 0, first collect vectors in the frontend index. - // Lock in scope to ensure that these will be released before - // updateSVSIndexWrapper() is called. - { - std::lock_guard lock(this->flatIndexGuard); - ret = this->frontendIndex->addVector(blob, label); - // If the frontend size exceeds the update job threshold, ... - frontend_index_size = this->frontendIndex->indexSize(); - } - // ... move vectors to the backend index. - if (frontend_index_size >= this->trainingTriggerThreshold) { - // updateSVSIndexWrapper() acquires its own locks - backend_shared_lock.unlock(); - // initialize the SVS index synchronously using the current thread only - updateSVSIndexWrapper(this, 1); - } - return ret; - } else { - // backend index is initialized - we can add the vector directly - backend_shared_lock.unlock(); - auto storage_blob = this->frontendIndex->preprocessForStorage(blob); - // prevent the update job from running in parallel and lock any access to the backend - // index - std::scoped_lock lock(this->updateJobMutex, this->mainIndexGuard); - // Set available thread count to 1 for a single-vector write-in-place operation. - // This maintains the contract that single vector operations use exactly one thread. - // TODO: Replace this setNumThreads(1) call with an assertion once we establish - // a contract that write-in-place mode guarantees numThreads == 1. - svs_index->setNumThreads(1); - return this->backendIndex->addVector(storage_blob.get(), label); - } - } - assert(this->getWriteMode() != VecSim_WriteInPlace && "InPlace mode returns early"); - - // Async mode - add the vector to the frontend index and schedule an update job if needed. - if (!this->backendIndex->isMultiValue()) { - { - std::shared_lock lock(this->flatIndexGuard); - // If the label already exists in the frontend index, we should count it - // to prevent the case when an existing vector is moved in the meantime by the update job. - if (this->frontendIndex->isLabelExists(label)) { - ret = -1; - } - } - // Remove the vector from the backend index if it exists, in the non-MULTI case. - std::lock_guard lock(this->mainIndexGuard); - ret -= svs_index->deleteVectors(&label, 1); - } - { // Add vector to the frontend index. - std::lock_guard lock(this->flatIndexGuard); - const auto ft_ret = this->frontendIndex->addVector(blob, label); - - if (ft_ret == 0) { // Vector was overridden - add a 'skipping' swap to the journal.
- assert(!this->backendIndex->isMultiValue() && - "addVector() may return 0 for single value indices only"); - for (auto id : this->frontendIndex->getElementIds(label)) { - this->swaps_journal.emplace_back(SKIP_LABEL, id, id); - } - } - ret = std::max(ret + ft_ret, 0); - // Check the frontend index size to determine if an update job schedule is needed. - frontend_index_size = this->frontendIndex->indexSize(); - } - { - // If the main index is empty then update_threshold is trainingTriggerThreshold, - // otherwise it is updateTriggerThreshold. - std::shared_lock lock(this->mainIndexGuard); - update_threshold = this->backendIndex->indexSize() == 0 ? this->trainingTriggerThreshold - : this->updateTriggerThreshold; - } - if (frontend_index_size >= update_threshold) { - scheduleSVSIndexUpdate(); - } - - return ret; - } - - int deleteAndRecordSwaps_Unsafe(labelType label) { - auto deleting_ids = this->frontendIndex->getElementIds(label); - - // Assert that all elements of deleting_ids are unique. - assert(std::set(deleting_ids.begin(), deleting_ids.end()).size() == deleting_ids.size() && - "deleting_ids should contain unique ids"); - - // Sort deleting_ids by id in descending order. - std::sort(deleting_ids.begin(), deleting_ids.end(), - [](const auto &a, const auto &b) { return a > b; }); - - // Delete the vector from the frontend index. - auto updated_ids = this->frontendIndex->deleteVectorAndGetUpdatedIds(label); - - assert(std::all_of(updated_ids.begin(), updated_ids.end(), - [&deleting_ids](const auto &pair) { - return std::find(deleting_ids.begin(), deleting_ids.end(), - pair.first) != deleting_ids.end(); - }) && - "updated_ids should be a subset of deleting_ids"); - - // Record swaps in the journal. - for (auto id : deleting_ids) { - auto it = updated_ids.find(id); - if (it != updated_ids.end()) { - assert(id == it->first && "id in updated_ids should match the id in deleting_ids"); - auto newId = id; - auto oldId = it->second.first; - auto oldLabel = it->second.second; - this->swaps_journal.emplace_back(oldLabel, oldId, newId); - } else { - // No swap; a plain delete is marked by oldId == newId == deleted id - this->swaps_journal.emplace_back(SKIP_LABEL, id, id); - } - } - - return deleting_ids.size(); - }
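The descending-order sort above is what lets bulk deletion avoid most element moves: a flat store that reclaims space by swapping the last element into the deleted slot never has to move a surviving element if deletions proceed from the highest id down. A toy model of the effect (`ToyFlatStore` is hypothetical, not the VecSim flat index):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy flat store: deleting slot i moves the last element into i
// (swap-with-last), which is how id-compact vector stores reclaim space.
struct ToyFlatStore {
    std::vector<uint64_t> labels;
    size_t swaps = 0;
    void deleteAt(size_t i) {
        if (i != labels.size() - 1) {
            labels[i] = labels.back(); // a surviving element changes its id
            ++swaps;
        }
        labels.pop_back();
    }
};

int main() {
    // Deleting ids {2, 3} in descending order: id 3 is the last slot (no
    // swap), then id 2 becomes the last slot (again no swap). Ascending
    // order would first swap label 'D' into slot 2 and then have to chase
    // the moved id.
    ToyFlatStore s{{'A', 'B', 'C', 'D'}};
    s.deleteAt(3);
    s.deleteAt(2);
    // s.swaps == 0 here; the same deletions in ascending order would swap.
}
```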
- - int deleteVector(labelType label) override { - int ret = 0; - auto svs_index = GetSVSIndex(); - // Backend index deletions must be synchronized with the frontend index; otherwise there is a - // risk of label duplication in both indices, which can lead to wrong results of topK - // queries. In that case we should behave as if InPlace mode were always set. - bool label_exists = [&]() { - std::shared_lock lock(this->flatIndexGuard); - return this->frontendIndex->isLabelExists(label); - }(); - - if (label_exists) { - std::lock_guard lock(this->flatIndexGuard); - ret = this->deleteAndRecordSwaps_Unsafe(label); - } - { - std::lock_guard lock(this->mainIndexGuard); - ret += svs_index->deleteVectors(&label, 1); - } - return ret; - } - size_t getNumMarkedDeleted() const override { - return this->GetSVSIndex()->getNumMarkedDeleted(); - } - - size_t indexSize() const override { - std::shared_lock flat_lock(this->flatIndexGuard); - std::shared_lock main_lock(this->mainIndexGuard); - return this->frontendIndex->indexSize() + this->backendIndex->indexSize(); - } - - size_t indexCapacity() const override { - std::shared_lock flat_lock(this->flatIndexGuard); - std::shared_lock main_lock(this->mainIndexGuard); - return this->frontendIndex->indexCapacity() + this->backendIndex->indexCapacity(); - } - - double getDistanceFrom_Unsafe(labelType label, const void *blob) const override { - // Try to get the distance from the flat buffer. - // If the label doesn't exist, the distance will be NaN. - auto flat_dist = this->frontendIndex->getDistanceFrom_Unsafe(label, blob); - - // Optimization. TODO: consider having different implementations for single and multi - // indexes, to avoid checking the index type on every query. - if (!this->backendIndex->isMultiValue() && !std::isnan(flat_dist)) { - // If the index is single value, and we got a valid distance from the flat buffer, - // we can return the distance without querying the Main index. - return flat_dist; - } - - // Try to get the distance from the Main index. - auto svs_dist = this->backendIndex->getDistanceFrom_Unsafe(label, blob); - - // Return the minimum distance that is not NaN. - return std::fmin(flat_dist, svs_dist); - } - - VecSimIndexDebugInfo debugInfo() const override { - auto info = Base::debugInfo(); - - SvsTieredInfo svsTieredInfo = { - .trainingTriggerThreshold = this->trainingTriggerThreshold, - .updateTriggerThreshold = this->updateTriggerThreshold, - .updateJobWaitTime = this->updateJobWaitTime, - }; - { - std::lock_guard lock(this->updateJobMutex); - svsTieredInfo.indexUpdateScheduled = - this->indexUpdateScheduled.test() == VecSimBool_TRUE; - } - info.tieredInfo.specificTieredBackendInfo.svsTieredInfo = svsTieredInfo; - info.tieredInfo.backgroundIndexing = - svsTieredInfo.indexUpdateScheduled && info.tieredInfo.frontendCommonInfo.indexSize > 0 - ? VecSimBool_TRUE - : VecSimBool_FALSE; - return info; - } - - VecSimIndexBasicInfo basicInfo() const override { - VecSimIndexBasicInfo info = this->backendIndex->getBasicInfo(); - info.isTiered = true; - info.algo = VecSimAlgo_SVS; - return info; - }
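The `std::fmin` call above works because of IEEE NaN handling: when exactly one argument is NaN, `fmin` returns the other argument, so a label missing from one tier simply drops out of the comparison. A short, runnable illustration:

```cpp
#include <cassert>
#include <cmath>
#include <limits>

int main() {
    double nan = std::numeric_limits<double>::quiet_NaN();
    // std::fmin treats NaN as "missing": if exactly one argument is NaN,
    // it returns the other one, which is why the tiered lookup can combine
    // the two per-tier distances without explicit NaN checks.
    assert(std::fmin(nan, 0.5) == 0.5);
    assert(std::fmin(0.25, nan) == 0.25);
    assert(std::isnan(std::fmin(nan, nan))); // both missing -> NaN
}
```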
- - VecSimDebugInfoIterator *debugInfoIterator() const override { - // Get the base tiered fields. - auto *infoIterator = Base::debugInfoIterator(); - VecSimIndexDebugInfo info = this->debugInfo(); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::TIERED_SVS_TRAINING_THRESHOLD_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = { - FieldValue{.uintegerValue = info.tieredInfo.specificTieredBackendInfo.svsTieredInfo - .trainingTriggerThreshold}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::TIERED_SVS_UPDATE_THRESHOLD_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = { - FieldValue{.uintegerValue = info.tieredInfo.specificTieredBackendInfo.svsTieredInfo - .updateTriggerThreshold}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::TIERED_SVS_THREADS_RESERVE_TIMEOUT_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{ - .uintegerValue = - info.tieredInfo.specificTieredBackendInfo.svsTieredInfo.updateJobWaitTime}}}); - return infoIterator; - } - - VecSimQueryReply *topKQuery(const void *queryBlob, size_t k, - VecSimQueryParams *queryParams) const override { - // SVS implements its own distance computation functions, which may produce slightly different - // distance values than the VecSim flat index does, so we always have to merge results with a set. - return this->template topKQueryImp(queryBlob, k, queryParams); - } - - VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, - VecSimQueryParams *queryParams, - VecSimQueryReply_Order order) const override { - // SVS implements its own distance computation functions, which may produce slightly different - // distance values than the VecSim flat index does, so we always have to merge results with a set. - return this->template rangeQueryImp(queryBlob, radius, queryParams, order); - } - - VecSimBatchIterator *newBatchIterator(const void *queryBlob, - VecSimQueryParams *queryParams) const override { - // The query blob will be processed and copied by the internal indexes' batch iterators. - return new (this->allocator) - TieredSVS_BatchIterator(queryBlob, this, queryParams, this->allocator); - } - - void setLastSearchMode(VecSearchMode mode) override { - return this->backendIndex->setLastSearchMode(mode); - } - - void runGC() override { - if (this->getWriteMode() == VecSim_WriteInPlace) { - TIERED_LOG(VecSimCommonStrings::LOG_VERBOSE_STRING, - "running synchronous GC for tiered SVS index in write-in-place mode"); - // In write-in-place mode, we run GC synchronously. - std::lock_guard lock{this->mainIndexGuard}; - if (this->backendIndex->indexSize() == 0) { - // No need to run GC on an empty index. - return; - } - // Force single thread for write-in-place mode. - this->GetSVSIndex()->setNumThreads(1); - // VecSimIndexAbstract::runGC() is protected - static_cast(this->backendIndex)->runGC(); - return; - } - TIERED_LOG(VecSimCommonStrings::LOG_VERBOSE_STRING, - "scheduling asynchronous GC for tiered SVS index"); - scheduleSVSIndexGC(); - } - - void acquireSharedLocks() override { - this->flatIndexGuard.lock_shared(); - this->mainIndexGuard.lock_shared(); - } - - void releaseSharedLocks() override { - this->mainIndexGuard.unlock_shared(); - this->flatIndexGuard.unlock_shared(); - } -};
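Note that the shared locks above are always taken in the order flat-then-main and released in reverse; keeping one global ordering across all call sites is what rules out deadlock between readers and the background update job. A minimal RAII sketch of that discipline, assuming two `std::shared_mutex` guards as in this class (`TieredSharedGuard` is a hypothetical helper, not part of the library):

```cpp
#include <shared_mutex>

// Two-guard RAII helper: acquire flat before main, release in reverse order,
// mirroring acquireSharedLocks()/releaseSharedLocks() above.
class TieredSharedGuard {
    std::shared_mutex &flat_;
    std::shared_mutex &main_;

public:
    TieredSharedGuard(std::shared_mutex &flat, std::shared_mutex &main)
        : flat_(flat), main_(main) {
        flat_.lock_shared();
        main_.lock_shared();
    }
    ~TieredSharedGuard() {
        main_.unlock_shared();
        flat_.unlock_shared();
    }
    TieredSharedGuard(const TieredSharedGuard &) = delete;
    TieredSharedGuard &operator=(const TieredSharedGuard &) = delete;
};
```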
diff --git a/src/VecSim/algorithms/svs/svs_utils.h b/src/VecSim/algorithms/svs/svs_utils.h deleted file mode 100644 index 2e240358f..000000000 --- a/src/VecSim/algorithms/svs/svs_utils.h +++ /dev/null @@ -1,450 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once -#include "VecSim/query_results.h" -#include "VecSim/types/float16.h" - -#include "svs/core/distance.h" -#include "svs/lib/float16.h" -#include "svs/index/vamana/dynamic_index.h" - -#if HAVE_SVS_LVQ -#include "svs/cpuid.h" -#endif - -#include -#include -#include -#include - -// Maximum training threshold for the SVS index, used to limit the size of the training data -constexpr size_t SVS_MAX_TRAINING_THRESHOLD = 100 * DEFAULT_BLOCK_SIZE; // 100 * 1024 vectors -// Default wait time for the update job in microseconds -constexpr size_t SVS_DEFAULT_UPDATE_JOB_WAIT_TIME = 5000; // 5 ms - -namespace svs_details { -// VecSim->SVS data type conversion -template -struct vecsim_dtype; - -template <> -struct vecsim_dtype { - using type = float; -}; - -template <> -struct vecsim_dtype { - using type = vecsim_types::float16; -}; - -template -using vecsim_dt = typename vecsim_dtype::type; - -// SVS->VecSim distance conversion -template <typename DistType> -double toVecSimDistance(float); - -template <> -inline double toVecSimDistance(float v) { - return static_cast<double>(v); -} - -template <> -inline double toVecSimDistance(float v) { - return 1.0 - static_cast<double>(v); -} - -template <> -inline double toVecSimDistance(float v) { - return 1.0 - static_cast<double>(v); -} - -// VecSim allocator wrapper for SVS containers -template <typename T> -using SVSAllocator = VecsimSTLAllocator<T>; - -template <typename T, typename U> -static T getOrDefault(T v, U def) { - return v != T{} ? v : static_cast<T>(def); -} - -inline svs::index::vamana::VamanaBuildParameters -makeVamanaBuildParameters(const SVSParams &params) { - // clang-format off - // evaluate optimal default parameters; current assumption: - // * alpha (1.2 or 0.95) depends on metric: L2: > 1.0, IP, Cosine: < 1.0 - // In the Vamana algorithm implementation in SVS, the choice of alpha value - // depends on the type of similarity measure used. For L2, which minimizes distance, - // an alpha value greater than 1 is needed, typically around 1.2. - // For Inner Product and Cosine, which maximize similarity or distance, - // the alpha value should be less than 1; usually 0.9 or 0.95 works. - // * construction_window_size (200): similar to HNSW_EF_CONSTRUCTION - // * graph_max_degree (32): similar to HNSW_M * 2 - // * max_candidate_pool_size (600): =~ construction_window_size * 3 - // * prune_to (28): < graph_max_degree, optimal = graph_max_degree - 4 - // The prune_to parameter is a performance feature designed to enhance build time - // by setting a small difference between this value and the maximum graph degree. - // This acts as a threshold for how much pruning can reduce the number of neighbors. - // Typically, a small gap of 4 or 8 is sufficient to improve build time - // without compromising the quality of the graph.
- // * use_search_history (true): now: enabled unless explicitly disabled; - // future: default value based on other index parameters - const auto construction_window_size = getOrDefault(params.construction_window_size, SVS_VAMANA_DEFAULT_CONSTRUCTION_WINDOW_SIZE); - const auto graph_max_degree = getOrDefault(params.graph_max_degree, SVS_VAMANA_DEFAULT_GRAPH_MAX_DEGREE); - - // More info about VamanaBuildParameters can be found here: - // https://intel.github.io/ScalableVectorSearch/python/vamana.html#svs.VamanaBuildParameters - return svs::index::vamana::VamanaBuildParameters{ - getOrDefault(params.alpha, (params.metric == VecSimMetric_L2 ? - SVS_VAMANA_DEFAULT_ALPHA_L2 : SVS_VAMANA_DEFAULT_ALPHA_IP)), - graph_max_degree, - construction_window_size, - getOrDefault(params.max_candidate_pool_size, construction_window_size * 3), - getOrDefault(params.prune_to, graph_max_degree - 4), - params.use_search_history == VecSimOption_AUTO ? SVS_VAMANA_DEFAULT_USE_SEARCH_HISTORY : - params.use_search_history == VecSimOption_ENABLE, - }; - // clang-format on -} - -// Join default SVS search parameters with VecSim query runtime parameters -inline svs::index::vamana::VamanaSearchParameters -joinSearchParams(svs::index::vamana::VamanaSearchParameters &&sp, - const VecSimQueryParams *queryParams, bool is_two_level_lvq) { - if (queryParams == nullptr) { - return std::move(sp); - } - - auto &rt_params = queryParams->svsRuntimeParams; - size_t sws = sp.buffer_config_.get_search_window_size(); - size_t sbc = sp.buffer_config_.get_total_capacity(); - - // buffer capacity is changed only if the window size is changed - if (rt_params.windowSize > 0) { - sws = rt_params.windowSize; - if (rt_params.bufferCapacity > 0) { - // case 1: change both window size and buffer capacity - sbc = rt_params.bufferCapacity; - } else { - // case 2: change only window size - // In this case, set buffer capacity based on window size - if (!is_two_level_lvq) { - // set buffer capacity to windowSize - sbc = rt_params.windowSize; - } else { - // set buffer capacity to windowSize * 1.5 for two-level LVQ - sbc = static_cast<size_t>(rt_params.windowSize * 1.5); - } - } - } - sp.buffer_config({sws, sbc}); - switch (rt_params.searchHistory) { - case VecSimOption_ENABLE: - sp.search_buffer_visited_set(true); - break; - case VecSimOption_DISABLE: - sp.search_buffer_visited_set(false); - break; - default: // AUTO mode, let the algorithm decide - break; - } - return std::move(sp); -} - -// @brief Block size for SVS storage, required to be a power of two -// @param bs VecSim block size -// @param elem_size SVS storage element size -// @return block size as SVS `PowerOfTwo` -inline svs::lib::PowerOfTwo SVSBlockSize(size_t bs, size_t elem_size) { - auto svs_bs = svs::lib::prevpow2(bs * elem_size); - // block size should not be less than element size - while (svs_bs.value() < elem_size) { - svs_bs = svs::lib::PowerOfTwo{svs_bs.raw() + 1}; - } - return svs_bs; -}
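A worked example of the rounding `SVSBlockSize` performs, with a local `prevPow2` standing in for `svs::lib::prevpow2` (the real function returns a `PowerOfTwo` wrapper, elided here):

```cpp
#include <cassert>
#include <cstddef>

// Largest power of two <= x (x > 0); a stand-in for svs::lib::prevpow2.
constexpr size_t prevPow2(size_t x) {
    size_t p = 1;
    while (p * 2 <= x)
        p *= 2;
    return p;
}

int main() {
    // 1024 elements per block, 6 bytes per stored element:
    // 6144 bytes rounds down to a 4096-byte power-of-two block.
    assert(prevPow2(1024 * 6) == 4096);
    // The "not less than element size" loop only matters for huge elements:
    // for a 6000-byte element with a tiny requested block size, 4096 would
    // be doubled to 8192 so that at least one element fits.
    assert(prevPow2(6000) == 4096);
}
```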
- -// Check if the SVS implementation supports the requested quantization mode -// @param quant_bits requested SVS quantization mode -// @return pair -// @note even if VecSimSvsQuantBits is a simple enum value, -// in theory, it can be a complex type with a combination of modes: -// - primary bits, secondary/residual bits, dimensionality reduction, etc. -// which can be incompatible with each other. -inline std::pair<VecSimSvsQuantBits, bool> isSVSQuantBitsSupported(VecSimSvsQuantBits quant_bits) { - switch (quant_bits) { - // non-quantized mode and scalar quantization are always supported - case VecSimSvsQuant_NONE: - case VecSimSvsQuant_Scalar: - return std::make_pair(quant_bits, true); - default: - // fall back to scalar quantization if LVQ support is not compiled in - // or if the CPU doesn't support it -#if HAVE_SVS_LVQ - return svs::detail::intel_enabled() ? std::make_pair(quant_bits, true) - : std::make_pair(VecSimSvsQuant_Scalar, true); -#else - return std::make_pair(VecSimSvsQuant_Scalar, true); -#endif - } - assert(false && "Should never reach here"); - // unreachable code, but to avoid compiler warning - return std::make_pair(VecSimSvsQuant_NONE, false); -} -} // namespace svs_details - -template <typename DataType> -struct SVSStorageTraits { - using allocator_type = svs_details::SVSAllocator<DataType>; - // In SVS, the default allocator is designed for static indices, - // where the size of the data or graph is known in advance, - // allowing all structures to be allocated at once. In contrast, - // the Blocked allocator supports dynamic allocations, - // enabling memory to be allocated in blocks as needed when the index size grows. - using blocked_type = svs::data::Blocked; // Used in creating storage - // svs::Dynamic means runtime dimensionality, as opposed to compile-time dimensionality - using index_storage_type = svs::data::BlockedData; - - static constexpr bool is_compressed() { return false; } - - static constexpr VecSimSvsQuantBits get_compression_mode() { - return VecSimSvsQuant_NONE; // No compression for this storage - } - - static blocked_type make_blocked_allocator(size_t block_size, size_t dim, - std::shared_ptr<VecSimAllocator> allocator) { - // SVS storage element size and block size can differ from VecSim's - auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(dim)); - allocator_type data_allocator{std::move(allocator)}; - return blocked_type{{svs_bs}, data_allocator}; - } - - template <typename Dataset, typename Pool> - static index_storage_type create_storage(const Dataset &data, size_t block_size, Pool &pool, - std::shared_ptr<VecSimAllocator> allocator, - size_t /* leanvec_dim */) { - const auto dim = data.dimensions(); - const auto size = data.size(); - // Allocate initial SVS storage for index - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - index_storage_type init_data{size, dim, blocked_alloc}; - // Copy data to allocated storage - svs::threads::parallel_for(pool, svs::threads::StaticPartition(data.eachindex()), - [&](auto is, auto SVS_UNUSED(tid)) { - for (auto i : is) { - init_data.set_datum(i, data.get_datum(i)); - } - }); - return init_data; - } - - static index_storage_type load(const svs::lib::LoadTable &table, size_t block_size, size_t dim, - std::shared_ptr<VecSimAllocator> allocator) { - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - // Load the data from disk - return index_storage_type::load(table, blocked_alloc); - } - - static index_storage_type load(const std::string &path, size_t block_size, size_t dim, - std::shared_ptr<VecSimAllocator> allocator) { - auto blocked_alloc = make_blocked_allocator(block_size, dim, std::move(allocator)); - // Load the data from disk - return index_storage_type::load(path, blocked_alloc); - } - - // SVS storage element size can differ from VecSim's DataSize - static constexpr size_t element_size(size_t dims, size_t /*alignment*/ = 0, - size_t /*leanvec_dim*/ = 0) { - return dims * sizeof(DataType); - } - - static size_t storage_capacity(const index_storage_type &storage)
{ return storage.capacity(); } -}; - -template -struct SVSGraphBuilder { - using allocator_type = svs_details::SVSAllocator; - using blocked_type = svs::data::Blocked; - using graph_data_type = svs::data::BlockedData; - using graph_type = svs::graphs::SimpleGraph; - - static blocked_type make_blocked_allocator(size_t block_size, size_t graph_max_degree, - std::shared_ptr allocator) { - // SVS block size is a power of two, so we can use it directly - auto svs_bs = svs_details::SVSBlockSize(block_size, element_size(graph_max_degree)); - allocator_type data_allocator{std::move(allocator)}; - return blocked_type{{svs_bs}, data_allocator}; - } - - // Build SVS Graph using custom allocator - // The logic has been taken from one of `MutableVamanaIndex` constructors - // See: - // https://github.com/intel/ScalableVectorSearch/blob/main/include/svs/index/vamana/dynamic_index.h#L189 - template - static graph_type build_graph(const svs::index::vamana::VamanaBuildParameters ¶meters, - const Data &data, DistType distance, Pool &threadpool, - SVSIdType entry_point, size_t block_size, - std::shared_ptr allocator, - const svs::logging::logger_ptr &logger) { - // Perform graph construction. - auto blocked_alloc = - make_blocked_allocator(block_size, parameters.graph_max_degree, std::move(allocator)); - auto graph = graph_type{data.size(), parameters.graph_max_degree, blocked_alloc}; - // SVS incorporates an advanced software prefetching scheme with two parameters: step and - // lookahead. These parameters determine how far ahead to prefetch data vectors - // and how many items to prefetch at a time. We have set default values for these parameters - // based on the data types, which we found to perform better through heuristic analysis. - auto prefetch_parameters = - svs::index::vamana::extensions::estimate_prefetch_parameters(data); - auto builder = svs::index::vamana::VamanaBuilder( - graph, data, std::move(distance), parameters, threadpool, prefetch_parameters, logger); - - // Specific to the Vamana algorithm: - // It builds in two rounds, one with alpha=1 and the second time with the user/config - // provided alpha value. 
- builder.construct(1.0f, entry_point, svs::logging::Level::Trace, logger); - builder.construct(parameters.alpha, entry_point, svs::logging::Level::Trace, logger); - return graph; - } - - static graph_type load(const svs::lib::LoadTable &table, size_t block_size, - const svs::index::vamana::VamanaBuildParameters ¶meters, - std::shared_ptr allocator) { - auto blocked_alloc = - make_blocked_allocator(block_size, parameters.graph_max_degree, std::move(allocator)); - // Load the graph from disk - return graph_type::load(table, blocked_alloc); - } - - static graph_type load(const std::string &path, size_t block_size, - const svs::index::vamana::VamanaBuildParameters ¶meters, - std::shared_ptr allocator) { - auto blocked_alloc = - make_blocked_allocator(block_size, parameters.graph_max_degree, std::move(allocator)); - // Load the graph from disk - return graph_type::load(path, blocked_alloc); - } - - // SVS Vamana graph element size - static constexpr size_t element_size(size_t graph_max_degree, size_t alignment = 0) { - // For every Vamana graph node SVS allocates a record with current node ID and - // graph_max_degree neighbors - return sizeof(SVSIdType) * (graph_max_degree + 1); - } -}; - -// Custom thread pool for SVS index -// Based on svs::threads::NativeThreadPoolBase with changes: -// * Number of threads is fixed on construction time -// * Pool is resizable in bounds of pre-allocated threads -class VecSimSVSThreadPoolImpl { -public: - // Allocate `num_threads - 1` threads since the main thread participates in the work - // as well. - explicit VecSimSVSThreadPoolImpl(size_t num_threads = 1) - : size_{num_threads}, threads_(num_threads - 1) {} - - size_t capacity() const { return threads_.size() + 1; } - size_t size() const { return size_; } - - // Support resize - do not modify threads container just limit the size - void resize(size_t new_size) { - std::lock_guard lock{use_mutex_}; - size_ = std::clamp(new_size, size_t{1}, threads_.size() + 1); - } - - void parallel_for(std::function f, size_t n) { - if (n > size_) { - throw svs::threads::ThreadingException("Number of tasks exceeds the thread pool size"); - } - if (n == 0) { - return; - } else if (n == 1) { - // Run on the main function. - try { - f(0); - } catch (const std::exception &error) { - manage_exception_during_run(error.what()); - } - return; - } else { - std::lock_guard lock{use_mutex_}; - for (size_t i = 0; i < n - 1; ++i) { - threads_[i].assign({&f, i + 1}); - } - // Run on the main function. - try { - f(0); - } catch (const std::exception &error) { - manage_exception_during_run(error.what()); - } - - // Wait until all threads are done. - // If any thread fails, then we're throwing. - for (size_t i = 0; i < size_ - 1; ++i) { - auto &thread = threads_[i]; - thread.wait(); - if (!thread.is_okay()) { - manage_exception_during_run(); - } - } - } - } - - void manage_exception_during_run(const std::string &thread_0_message = {}) { - auto message = std::string{}; - auto inserter = std::back_inserter(message); - if (!thread_0_message.empty()) { - fmt::format_to(inserter, "Thread 0: {}\n", thread_0_message); - } - - // Manage all other exceptions thrown, restarting crashed threads. - for (size_t i = 0; i < size_ - 1; ++i) { - auto &thread = threads_[i]; - thread.wait(); - if (!thread.is_okay()) { - try { - thread.unsafe_get_exception(); - } catch (const std::exception &error) { - fmt::format_to(inserter, "Thread {}: {}\n", i + 1, error.what()); - } - // Restart the thread. 
- threads_[i].shutdown(); - threads_[i] = svs::threads::Thread{}; - } - } - throw svs::threads::ThreadingException{std::move(message)}; - } - -private: - std::mutex use_mutex_; - size_t size_; - std::vector<svs::threads::Thread> threads_; -}; - -// Copy-movable wrapper for VecSimSVSThreadPoolImpl -class VecSimSVSThreadPool { -private: - std::shared_ptr<VecSimSVSThreadPoolImpl> pool_; - -public: - explicit VecSimSVSThreadPool(size_t num_threads = 1) - : pool_{std::make_shared<VecSimSVSThreadPoolImpl>(num_threads)} {} - - size_t capacity() const { return pool_->capacity(); } - size_t size() const { return pool_->size(); } - - void parallel_for(std::function<void(size_t)> f, size_t n) { - pool_->parallel_for(std::move(f), n); - } - - void resize(size_t new_size) { pool_->resize(new_size); } -}; diff --git a/src/VecSim/batch_iterator.h b/src/VecSim/batch_iterator.h deleted file mode 100644 index 9e2791130..000000000 --- a/src/VecSim/batch_iterator.h +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/vec_sim.h" -#include "VecSim/memory/vecsim_base.h" - -/** - * An abstract class for performing search in batches. Every index type should implement its own - * batch iterator class. - * A batch iterator instance is NOT meant to be shared between threads, but the iterated index can - * be, in which case the iterator should be able to iterate the index concurrently and safely. - */ -struct VecSimBatchIterator : public VecsimBaseObject { -private: - void *query_vector; - size_t returned_results_count; - void *timeoutCtx; - -public: - explicit VecSimBatchIterator(void *query_vector, void *tctx, - std::shared_ptr<VecSimAllocator> allocator) - : VecsimBaseObject(allocator), query_vector(query_vector), returned_results_count(0), - timeoutCtx(tctx) {}; - - virtual inline const void *getQueryBlob() const { return query_vector; } - - inline void *getTimeoutCtx() const { return timeoutCtx; } - - inline size_t getResultsCount() const { return returned_results_count; } - - inline void updateResultsCount(size_t num) { returned_results_count += num; } - - inline void resetResultsCount() { returned_results_count = 0; } - - // Returns the top n_res results that *haven't been returned* in previous calls. - // The implementation is specific to the underlying index algorithm. - virtual VecSimQueryReply *getNextResults(size_t n_res, VecSimQueryReply_Order order) = 0; - - // Indicates whether there are additional results from the index to return - virtual bool isDepleted() = 0; - - // Reset the iterator to the initial state, before any results have been returned. - virtual void reset() = 0; - - virtual ~VecSimBatchIterator() noexcept { allocator->free_allocation(this->query_vector); }; -};
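A typical consumption loop against this interface might look as follows; `drainInBatches` is an illustrative sketch, not an API of the library, and it assumes the caller obtained the iterator from the owning index's `newBatchIterator()` and owns both the iterator and the replies:

```cpp
#include <cstddef>
#include "VecSim/batch_iterator.h"

// Drain an index in fixed-size batches, stopping on error or depletion.
void drainInBatches(VecSimBatchIterator *it, size_t batchSize) {
    while (!it->isDepleted()) {
        VecSimQueryReply *batch = it->getNextResults(batchSize, BY_SCORE);
        if (batch->code != VecSim_OK) { // timeout or error from the index
            VecSimQueryReply_Free(batch);
            break;
        }
        // ... consume VecSimQueryReply_Len(batch) results here ...
        VecSimQueryReply_Free(batch);
    }
}
```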
diff --git a/src/VecSim/containers/data_block.cpp b/src/VecSim/containers/data_block.cpp deleted file mode 100644 index 898a2f085..000000000 --- a/src/VecSim/containers/data_block.cpp +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "data_block.h" -#include "VecSim/memory/vecsim_malloc.h" -#include <cstring> - -DataBlock::DataBlock(size_t blockSize, size_t elementBytesCount, - std::shared_ptr<VecSimAllocator> allocator, unsigned char alignment) - : VecsimBaseObject(allocator), element_bytes_count(elementBytesCount), length(0), - data((char *)this->allocator->allocate_aligned(blockSize * elementBytesCount, alignment)) {} - -DataBlock::DataBlock(DataBlock &&other) noexcept - : VecsimBaseObject(other.allocator), element_bytes_count(other.element_bytes_count), - length(other.length), data(other.data) { - other.data = nullptr; // take ownership of the data -} - -DataBlock::~DataBlock() noexcept { this->allocator->free_allocation(data); } - -void DataBlock::addElement(const void *element) { - - // Copy element data and update block size. - memcpy(this->data + (this->length * element_bytes_count), element, element_bytes_count); - this->length++; -} - -void DataBlock::updateElement(size_t index, const void *new_element) { - char *destination = (char *)getElement(index); - memcpy(destination, new_element, element_bytes_count); -} diff --git a/src/VecSim/containers/data_block.h b/src/VecSim/containers/data_block.h deleted file mode 100644 index efef04749..000000000 --- a/src/VecSim/containers/data_block.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/memory/vecsim_base.h" - -#include <cassert> - -struct DataBlock : public VecsimBaseObject { - -public: - DataBlock(size_t blockSize, size_t elementBytesCount, - std::shared_ptr<VecSimAllocator> allocator, unsigned char alignment = 0); - // Move constructor - // We need to implement this because we want to have a vector of DataBlocks, and we want it to - // use the move constructor upon resizing (instead of the copy constructor). We also need to - // mark it as noexcept so the vector will use it. - DataBlock(DataBlock &&other) noexcept; - ~DataBlock() noexcept; - // Delete the copy constructor so we won't have a vector of DataBlocks that uses the copy - // constructor. - DataBlock(const DataBlock &other) = delete; - - DataBlock &operator=(DataBlock &&other) noexcept { - allocator = other.allocator; - element_bytes_count = other.element_bytes_count; - length = other.length; - // take ownership of the data - data = other.data; - other.data = nullptr; - return *this; - }; - - void addElement(const void *element); - - void updateElement(size_t index, const void *new_element); - - const char *getElement(size_t index) const { - return this->data + (index * element_bytes_count); - } - - void popLastElement() { - assert(this->length > 0); - this->length--; - } - - char *removeAndFetchLastElement() { - return this->data + ((--this->length) * element_bytes_count); - } - - size_t getLength() const { return length; } - -private: - // Element size in bytes - size_t element_bytes_count; - // Current block length. - size_t length; - // Elements hosted in the block. - char *data; -};
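The move constructor's `noexcept` marking matters because `std::vector` only moves elements during reallocation when the move cannot throw. A self-contained toy (`MovableBlock` is illustrative, not the real `DataBlock`) demonstrating the property the comment in `data_block.h` relies on:

```cpp
#include <type_traits>
#include <vector>

// A move-only owner of a raw buffer, like DataBlock but without allocators.
struct MovableBlock {
    char *data = nullptr;
    MovableBlock() = default;
    MovableBlock(MovableBlock &&other) noexcept : data(other.data) {
        other.data = nullptr; // take ownership, like DataBlock's move ctor
    }
    MovableBlock(const MovableBlock &) = delete;
    ~MovableBlock() { delete[] data; }
};

static_assert(std::is_nothrow_move_constructible<MovableBlock>::value,
              "vector growth will move, not copy, these blocks");

int main() {
    std::vector<MovableBlock> blocks;
    blocks.reserve(1);
    blocks.emplace_back();
    blocks.emplace_back(); // reallocation moves the existing block
}
```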
diff --git a/src/VecSim/containers/data_blocks_container.cpp b/src/VecSim/containers/data_blocks_container.cpp deleted file mode 100644 index bf63b683a..000000000 --- a/src/VecSim/containers/data_blocks_container.cpp +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#include "data_blocks_container.h" -#include "VecSim/algorithms/hnsw/hnsw_serializer.h" -#include <cmath> - -DataBlocksContainer::DataBlocksContainer(size_t blockSize, size_t elementBytesCount, - std::shared_ptr<VecSimAllocator> allocator, - unsigned char _alignment) - : VecsimBaseObject(allocator), RawDataContainer(), element_bytes_count(elementBytesCount), - element_count(0), blocks(allocator), block_size(blockSize), alignment(_alignment) {} - -DataBlocksContainer::~DataBlocksContainer() = default; - -size_t DataBlocksContainer::size() const { return element_count; } - -size_t DataBlocksContainer::capacity() const { return blocks.capacity(); } - -size_t DataBlocksContainer::blockSize() const { return block_size; } - -size_t DataBlocksContainer::elementByteCount() const { return element_bytes_count; } - -RawDataContainer::Status DataBlocksContainer::addElement(const void *element, size_t id) { - assert(id == element_count); // we can only append new elements - if (element_count % block_size == 0) { - blocks.emplace_back(this->block_size, this->element_bytes_count, this->allocator, - this->alignment); - } - blocks.back().addElement(element); - element_count++; - return Status::OK; - } - -const char *DataBlocksContainer::getElement(size_t id) const { - assert(id < element_count); - return blocks.at(id / this->block_size).getElement(id % this->block_size); -} - -RawDataContainer::Status DataBlocksContainer::removeElement(size_t id) { - assert(id == element_count - 1); // only the last element can be removed - blocks.back().popLastElement(); - if (blocks.back().getLength() == 0) { - blocks.pop_back(); - } - element_count--; - return Status::OK; -} - -RawDataContainer::Status DataBlocksContainer::updateElement(size_t id, const void *element) { - assert(id < element_count); - auto &block = blocks.at(id / this->block_size); - block.updateElement(id % block_size, element); // update the relative index in the block - return Status::OK; -} - -std::unique_ptr<RawDataContainer::Iterator> DataBlocksContainer::getIterator() const { - return std::make_unique<Iterator>(*this); -} - -size_t DataBlocksContainer::numBlocks() const { return this->blocks.size(); } - -#ifdef BUILD_TESTS -void DataBlocksContainer::saveVectorsData(std::ostream &output) const { - // Save data blocks - for (size_t i = 0; i < this->numBlocks(); i++) { - auto &block = this->blocks[i]; - unsigned int block_len = block.getLength(); - for (size_t j = 0; j < block_len; j++) { - output.write(block.getElement(j), this->element_bytes_count); - } - } -} - -void DataBlocksContainer::restoreBlocks(std::istream &input, size_t num_vectors, - Serializer::EncodingVersion version) { - - // Get number of blocks - unsigned int num_blocks = 0; - HNSWSerializer::EncodingVersion hnsw_version = - static_cast<HNSWSerializer::EncodingVersion>(version); - if (hnsw_version == HNSWSerializer::EncodingVersion::V3) { - // In V3, the number of blocks is serialized, so we need to read it from the file. - Serializer::readBinaryPOD(input, num_blocks); - } else { - // Otherwise, calculate the number of blocks based on the number of vectors.
- num_blocks = std::ceil((float)num_vectors / this->block_size); - } - this->blocks.reserve(num_blocks); - - // Get data blocks - for (size_t i = 0; i < num_blocks; i++) { - this->blocks.emplace_back(this->block_size, this->element_bytes_count, this->allocator, - this->alignment); - unsigned int block_len = 0; - if (hnsw_version == HNSWSerializer::EncodingVersion::V3) { - // In V3, the length of each block is serialized, so we need to read it from the file. - Serializer::readBinaryPOD(input, block_len); - } else { - size_t vectors_left = num_vectors - this->element_count; - block_len = vectors_left > this->block_size ? this->block_size : vectors_left; - } - for (size_t j = 0; j < block_len; j++) { - auto cur_vec = this->getAllocator()->allocate_unique(this->element_bytes_count); - input.read(static_cast<char *>(cur_vec.get()), - (std::streamsize)this->element_bytes_count); - this->blocks.back().addElement(cur_vec.get()); - this->element_count++; - } - } -} - -void DataBlocksContainer::shrinkToFit() { this->blocks.shrink_to_fit(); } - -#endif -/********************************** Iterator API ************************************************/ - -DataBlocksContainer::Iterator::Iterator(const DataBlocksContainer &container_) - : RawDataContainer::Iterator(), cur_id(0), cur_element(nullptr), container(container_) {} - -bool DataBlocksContainer::Iterator::hasNext() const { - return this->cur_id != this->container.size(); -} - -const char *DataBlocksContainer::Iterator::next() { - if (!this->hasNext()) { - return nullptr; - } - // Advance the pointer to the next element in the current block, or in the next block. - if (this->cur_id % container.blockSize() == 0) { - this->cur_element = container.getElement(this->cur_id); - } else { - this->cur_element += container.elementByteCount(); - } - this->cur_id++; - return this->cur_element; -} - -void DataBlocksContainer::Iterator::reset() { - this->cur_id = 0; - this->cur_element = nullptr; -}
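The iterator above advances element-by-element, but random access in `DataBlocksContainer::getElement` reduces to simple division/remainder addressing over the block grid. A compile-time worked example of that mapping (names here are illustrative):

```cpp
#include <cstddef>

// A global element id maps to (block index, offset within block).
struct BlockAddress {
    size_t block;
    size_t offset;
};

constexpr BlockAddress locate(size_t id, size_t blockSize) {
    return {id / blockSize, id % blockSize};
}

int main() {
    constexpr size_t blockSize = 1024;
    constexpr auto a = locate(0, blockSize);
    static_assert(a.block == 0 && a.offset == 0, "first element");
    constexpr auto b = locate(1024, blockSize);
    static_assert(b.block == 1 && b.offset == 0, "first element of 2nd block");
    constexpr auto c = locate(2500, blockSize);
    static_assert(c.block == 2 && c.offset == 452, "2500 = 2*1024 + 452");
}
```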
diff --git a/src/VecSim/containers/data_blocks_container.h b/src/VecSim/containers/data_blocks_container.h deleted file mode 100644 index c375590f2..000000000 --- a/src/VecSim/containers/data_blocks_container.h +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once - -#include "data_block.h" -#include "raw_data_container_interface.h" -#include "VecSim/memory/vecsim_malloc.h" -#include "VecSim/utils/serializer.h" -#include "VecSim/utils/vecsim_stl.h" - -class DataBlocksContainer : public VecsimBaseObject, public RawDataContainer { - - size_t element_bytes_count; // Element size in bytes - size_t element_count; // Number of items in the container - vecsim_stl::vector<DataBlock> blocks; // data blocks - size_t block_size; // number of elements in a block - unsigned char alignment; // alignment for data allocation in each block - -public: - DataBlocksContainer(size_t blockSize, size_t elementBytesCount, - std::shared_ptr<VecSimAllocator> allocator, unsigned char alignment = 0); - ~DataBlocksContainer(); - - // Number of elements in the container. - size_t size() const override; - - // Number of blocks allocated. - size_t capacity() const; - - size_t blockSize() const; - - size_t elementByteCount() const; - - Status addElement(const void *element, size_t id) override; - - const char *getElement(size_t id) const override; - - Status removeElement(size_t id) override; - - Status updateElement(size_t id, const void *element) override; - - std::unique_ptr<RawDataContainer::Iterator> getIterator() const override; - - size_t numBlocks() const; -#ifdef BUILD_TESTS - void saveVectorsData(std::ostream &output) const override; - // Use this in deserialization when the file was created with an old version (v3) that serialized - // the blocks themselves and not just the raw vector data. - void restoreBlocks(std::istream &input, size_t num_vectors, - Serializer::EncodingVersion version); - void shrinkToFit(); -#endif - - class Iterator : public RawDataContainer::Iterator { - size_t cur_id; - const char *cur_element; - const DataBlocksContainer &container; - - public: - explicit Iterator(const DataBlocksContainer &container); - ~Iterator() override = default; - - bool hasNext() const override; - const char *next() override; - void reset() override; - }; -}; diff --git a/src/VecSim/containers/raw_data_container_interface.h b/src/VecSim/containers/raw_data_container_interface.h deleted file mode 100644 index 29a76f74e..000000000 --- a/src/VecSim/containers/raw_data_container_interface.h +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once - -struct RawDataContainer { - enum class Status { OK = 0, ID_ALREADY_EXIST, ID_NOT_EXIST, ERR }; - /** - * This is an abstract interface; constructor/destructor should be implemented by the derived - * classes - */ - RawDataContainer() = default; - virtual ~RawDataContainer() = default; - - /** - * @return number of elements in the container - */ - virtual size_t size() const = 0; - /** - * @param element element's raw data to be added into the container - * @param id of the new element - * @return status - */ - virtual Status addElement(const void *element, size_t id) = 0; - /** - * @param id of the element to return - * @return Immutable reference to the element's data, NULL if id doesn't exist - */ - virtual const char *getElement(size_t id) const = 0; - /** - * @param id of the element to remove - * @return status - */ - virtual Status removeElement(size_t id) = 0; - /** - * @param id to change its associated data - * @param element the new raw data to associate with id - * @return status - */ - virtual Status updateElement(size_t id, const void *element) = 0; - - struct Iterator { - /** - * This is an abstract interface; constructor/destructor should be implemented by the - * derived classes - */ - Iterator() = default; - virtual ~Iterator() = default; - - /** - * The basic iterator operations API - */ - virtual bool hasNext() const = 0; - virtual const char *next() = 0; - virtual void reset() = 0; - }; - - /** - * Create a new iterator. Should be freed by the iterator's destructor. - */ - virtual std::unique_ptr<Iterator> getIterator() const = 0; - -#ifdef BUILD_TESTS - /** - * Save the raw data of all elements in the container to the output stream.
- */ - virtual void saveVectorsData(std::ostream &output) const = 0; -#endif -}; diff --git a/src/VecSim/containers/vecsim_results_container.h b/src/VecSim/containers/vecsim_results_container.h deleted file mode 100644 index 9d3775e50..000000000 --- a/src/VecSim/containers/vecsim_results_container.h +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/query_result_definitions.h" -#include "VecSim/utils/vecsim_stl.h" - -namespace vecsim_stl { - -// An abstract API for a query results container, used by RANGE queries. -struct abstract_results_container : public VecsimBaseObject { -public: - abstract_results_container(const std::shared_ptr<VecSimAllocator> &alloc) - : VecsimBaseObject(alloc) {} - ~abstract_results_container() = default; - - // Inserts (or updates) a new result in the container. - virtual inline void emplace(size_t id, double score) = 0; - - // Returns the size of the container - virtual inline size_t size() const = 0; - - // Returns a vector containing all current data, and passes its ownership - virtual inline VecSimQueryResultContainer get_results() = 0; -}; - -struct unique_results_container : public abstract_results_container { -private: - vecsim_stl::unordered_map<size_t, double> idToScore; - -public: - explicit unique_results_container(const std::shared_ptr<VecSimAllocator> &alloc) - : abstract_results_container(alloc), idToScore(alloc) {} - explicit unique_results_container(size_t cap, const std::shared_ptr<VecSimAllocator> &alloc) - : abstract_results_container(alloc), idToScore(cap, alloc) {} - - inline void emplace(size_t id, double score) override { - auto existing = idToScore.find(id); - if (existing == idToScore.end()) { - idToScore.emplace(id, score); - } else if (existing->second > score) { - existing->second = score; - } - } - - inline size_t size() const override { return idToScore.size(); } - - inline VecSimQueryResultContainer get_results() override { - VecSimQueryResultContainer results(this->allocator); - results.reserve(idToScore.size()); - for (auto res : idToScore) { - results.push_back(VecSimQueryResult{res.first, res.second}); - } - return results; - } -}; - -struct default_results_container : public abstract_results_container { -private: - VecSimQueryResultContainer _data; - -public: - explicit default_results_container(const std::shared_ptr<VecSimAllocator> &alloc) - : abstract_results_container(alloc), _data(alloc) {} - explicit default_results_container(size_t cap, const std::shared_ptr<VecSimAllocator> &alloc) - : abstract_results_container(alloc), _data(alloc) { - _data.reserve(cap); - } - ~default_results_container() = default; - - inline void emplace(size_t id, double score) override { - _data.push_back(VecSimQueryResult{id, score}); - } - inline size_t size() const override { return _data.size(); } - inline VecSimQueryResultContainer get_results() override { return std::move(_data); } -}; -} // namespace vecsim_stl diff --git a/src/VecSim/friend_test_decl.h b/src/VecSim/friend_test_decl.h deleted file mode 100644 index 3f34bb6c2..000000000 --- a/src/VecSim/friend_test_decl.h +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved.
- * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#ifndef INDEX_TEST_FRIEND_CLASS -#define INDEX_TEST_FRIEND_CLASS(class_name) \ - template \ - friend class class_name; -#endif diff --git a/src/VecSim/index_factories/brute_force_factory.cpp b/src/VecSim/index_factories/brute_force_factory.cpp deleted file mode 100644 index 9f18e40ad..000000000 --- a/src/VecSim/index_factories/brute_force_factory.cpp +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/index_factories/brute_force_factory.h" -#include "VecSim/algorithms/brute_force/brute_force.h" -#include "VecSim/algorithms/brute_force/brute_force_single.h" -#include "VecSim/algorithms/brute_force/brute_force_multi.h" -#include "VecSim/index_factories/components/components_factory.h" -#include "VecSim/index_factories/factory_utils.h" -#include "VecSim/types/bfloat16.h" -#include "VecSim/types/float16.h" - -using bfloat16 = vecsim_types::bfloat16; -using float16 = vecsim_types::float16; - -namespace BruteForceFactory { -template -inline VecSimIndex *NewIndex_ChooseMultiOrSingle(const BFParams *params, - const AbstractIndexInitParams &abstractInitParams, - IndexComponents &components) { - - // check if single and return new bf_index - if (params->multi) - return new (abstractInitParams.allocator) - BruteForceIndex_Multi(params, abstractInitParams, components); - else - return new (abstractInitParams.allocator) - BruteForceIndex_Single(params, abstractInitParams, components); -} - -VecSimIndex *NewIndex(const VecSimParams *params, bool is_normalized) { - const BFParams *bfParams = ¶ms->algoParams.bfParams; - AbstractIndexInitParams abstractInitParams = - VecSimFactory::NewAbstractInitParams(bfParams, params->logCtx, is_normalized); - return NewIndex(bfParams, abstractInitParams, is_normalized); -} - -VecSimIndex *NewIndex(const BFParams *bfparams, const AbstractIndexInitParams &abstractInitParams, - bool is_normalized) { - assert(is_normalized || - abstractInitParams.inputBlobSize == bfparams->dim * VecSimType_sizeof(bfparams->type)); - assert(!is_normalized || - abstractInitParams.inputBlobSize != bfparams->dim * VecSimType_sizeof(bfparams->type)); - if (bfparams->type == VecSimType_FLOAT32) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, bfparams->metric, bfparams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(bfparams, abstractInitParams, indexComponents); - } else if (bfparams->type == VecSimType_FLOAT64) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, bfparams->metric, bfparams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(bfparams, abstractInitParams, indexComponents); - } else if (bfparams->type == VecSimType_BFLOAT16) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, bfparams->metric, bfparams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(bfparams, abstractInitParams, - indexComponents); - } else if (bfparams->type == VecSimType_FLOAT16) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, 
bfparams->metric, bfparams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle<float16>(bfparams, abstractInitParams, - indexComponents); - } else if (bfparams->type == VecSimType_INT8) { - IndexComponents<int8_t, float> indexComponents = CreateIndexComponents<int8_t, float>( - abstractInitParams.allocator, bfparams->metric, bfparams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle<int8_t>(bfparams, abstractInitParams, - indexComponents); - } else if (bfparams->type == VecSimType_UINT8) { - IndexComponents<uint8_t, float> indexComponents = CreateIndexComponents<uint8_t, float>( - abstractInitParams.allocator, bfparams->metric, bfparams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle<uint8_t>(bfparams, abstractInitParams, - indexComponents); - } - - // If we got here something is wrong. - return NULL; -} - -VecSimIndex *NewIndex(const BFParams *bfparams, bool is_normalized) { - VecSimParams params = {.algoParams{.bfParams = BFParams{*bfparams}}}; - return NewIndex(&params, is_normalized); -} - -template <typename DataType, typename DistType = DataType> -inline size_t EstimateInitialSize_ChooseMultiOrSingle(bool is_multi) { - // check if single or multi and return the size of the matching class struct. - if (is_multi) - return sizeof(BruteForceIndex_Multi<DataType, DistType>); - else - return sizeof(BruteForceIndex_Single<DataType, DistType>); -} - -size_t EstimateInitialSize(const BFParams *params, bool is_normalized) { - - size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); - - // Constant part (not affected by parameters). - size_t est = sizeof(VecSimAllocator) + allocations_overhead; - - if (params->type == VecSimType_FLOAT32) { - est += EstimateComponentsMemory<float, float>(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle<float>(params->multi); - } else if (params->type == VecSimType_FLOAT64) { - est += EstimateComponentsMemory<double, double>(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle<double>(params->multi); - } else if (params->type == VecSimType_BFLOAT16) { - est += EstimateComponentsMemory<bfloat16, float>(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle<bfloat16, float>(params->multi); - } else if (params->type == VecSimType_FLOAT16) { - est += EstimateComponentsMemory<float16, float>(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle<float16, float>(params->multi); - } else if (params->type == VecSimType_INT8) { - est += EstimateComponentsMemory<int8_t, float>(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle<int8_t, float>(params->multi); - } else if (params->type == VecSimType_UINT8) { - est += EstimateComponentsMemory<uint8_t, float>(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle<uint8_t, float>(params->multi); - } else { - throw std::invalid_argument("Invalid params->type"); - } - - est += sizeof(DataBlocksContainer) + allocations_overhead; - return est; -} - -size_t EstimateElementSize(const BFParams *params) { - // counting the vector size + idToLabel entry + LabelToIds entry (map reservation) - return VecSimParams_GetStoredDataSize(params->type, params->dim, params->metric) + - sizeof(labelType) + sizeof(void *); -} -}; // namespace BruteForceFactory diff --git a/src/VecSim/index_factories/brute_force_factory.h b/src/VecSim/index_factories/brute_force_factory.h deleted file mode 100644 index 5828e6446..000000000 --- a/src/VecSim/index_factories/brute_force_factory.h +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3).
- */ -#pragma once - -#include // size_t -#include // std::shared_ptr - -#include "VecSim/vec_sim.h" //typedef VecSimIndex -#include "VecSim/vec_sim_common.h" // BFParams -#include "VecSim/memory/vecsim_malloc.h" // VecSimAllocator -#include "VecSim/vec_sim_index.h" - -namespace BruteForceFactory { -/** Overloading the NewIndex function to support different parameters - * @param is_normalized is used to determine the index's computer type. If the index metric is - * Cosine, and is_normalized == true, we will create the computer as if the metric is IP, assuming - * the blobs sent to the index are already normalized. For example, in case it's a tiered index, - * where the blobs are normalized by the frontend index. - */ -VecSimIndex *NewIndex(const VecSimParams *params, bool is_normalized = false); -VecSimIndex *NewIndex(const BFParams *bfparams, bool is_normalized = false); -VecSimIndex *NewIndex(const BFParams *bfparams, const AbstractIndexInitParams &abstractInitParams, - bool is_normalized); -size_t EstimateInitialSize(const BFParams *params, bool is_normalized = false); -size_t EstimateElementSize(const BFParams *params); - -}; // namespace BruteForceFactory diff --git a/src/VecSim/index_factories/components/components_factory.h b/src/VecSim/index_factories/components/components_factory.h deleted file mode 100644 index f52db7c5f..000000000 --- a/src/VecSim/index_factories/components/components_factory.h +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" -#include "VecSim/vec_sim_common.h" -#include "VecSim/vec_sim_index.h" -#include "VecSim/index_factories/components/preprocessors_factory.h" -#include "VecSim/spaces/computer/calculator.h" - -template -IndexComponents -CreateIndexComponents(std::shared_ptr allocator, VecSimMetric metric, size_t dim, - bool is_normalized) { - unsigned char alignment = 0; - spaces::dist_func_t distFunc = - spaces::GetDistFunc(metric, dim, &alignment); - // Currently we have only one distance calculator implementation - auto indexCalculator = new (allocator) DistanceCalculatorCommon(allocator, distFunc); - - // TODO: take into account quantization - auto preprocessors = - CreatePreprocessorsContainer(allocator, metric, dim, is_normalized, alignment); - - return {indexCalculator, preprocessors}; -} - -template -size_t EstimateComponentsMemory(VecSimMetric metric, bool is_normalized) { - size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); - - // Currently we have only one distance calculator implementation - size_t est = allocations_overhead + sizeof(DistanceCalculatorCommon); - - est += EstimatePreprocessorsContainerMemory(metric, is_normalized); - - return est; -} diff --git a/src/VecSim/index_factories/components/preprocessors_factory.h b/src/VecSim/index_factories/components/preprocessors_factory.h deleted file mode 100644 index c91863ea0..000000000 --- a/src/VecSim/index_factories/components/preprocessors_factory.h +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#pragma once - -#include "VecSim/spaces/computer/preprocessor_container.h" -#include "VecSim/vec_sim_common.h" - -struct PreprocessorsContainerParams { - VecSimMetric metric; - size_t dim; - unsigned char alignment; - size_t processed_bytes_count; -}; - -/** - * @brief Creates parameters for a preprocessors container based on the given metric, dimension, - * normalization flag, and alignment. - * - * @tparam DataType The data type of the vector elements (e.g., float, int). - * @param metric The similarity metric to be used (e.g., Cosine, Inner Product). - * @param dim The dimensionality of the vectors. - * @param is_normalized A flag indicating whether the vectors are already normalized. - * @param alignment The alignment requirement for the data. - * @return A PreprocessorsContainerParams object containing the processed parameters: - * - metric: The adjusted metric based on the input and normalization flag. - * - dim: The dimensionality of the vectors. - * - alignment: The alignment requirement for the data. - * - processed_bytes_count: The size of the processed data blob in bytes. - * - * @details - * If the metric is Cosine and the data type is integral, the processed bytes count may include - * additional space for normalization. If the vectors are already - * normalized (is_normalized == true), the metric is adjusted to Inner Product (IP) to skip - * redundant normalization during preprocessing. - */ -template -PreprocessorsContainerParams CreatePreprocessorsContainerParams(VecSimMetric metric, size_t dim, - bool is_normalized, - unsigned char alignment) { - // By default the processed blob size is the same as the original blob size. - size_t processed_bytes_count = dim * sizeof(DataType); - - VecSimMetric pp_metric = metric; - if (metric == VecSimMetric_Cosine) { - // if metric is cosine and DataType is integral, the processed_bytes_count includes the - // norm appended to the vector. - if (std::is_integral::value) { - processed_bytes_count += sizeof(float); - } - // if is_normalized == true, we will enforce skipping normalizing vector and query blobs by - // setting the metric to IP. 
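// (Editor's worked example, illustrative only.) For dim == 128 with DataType == int8_t and
// VecSimMetric_Cosine, processed_bytes_count == 128 * sizeof(int8_t) + sizeof(float) == 132:
// the extra four bytes hold the precomputed norm that is appended to integral vectors.
// If is_normalized == true, pp_metric is switched to VecSimMetric_IP below, so the cosine
// preprocessor is never installed and the blobs are not normalized again.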
- if (is_normalized) { - pp_metric = VecSimMetric_IP; - } - } - return {.metric = pp_metric, - .dim = dim, - .alignment = alignment, - .processed_bytes_count = processed_bytes_count}; -} - -template -PreprocessorsContainerAbstract * -CreatePreprocessorsContainer(std::shared_ptr allocator, - PreprocessorsContainerParams params) { - - if (params.metric == VecSimMetric_Cosine) { - auto multiPPContainer = - new (allocator) MultiPreprocessorsContainer(allocator, params.alignment); - auto cosine_preprocessor = new (allocator) - CosinePreprocessor(allocator, params.dim, params.processed_bytes_count); - int next_valid_pp_index = multiPPContainer->addPreprocessor(cosine_preprocessor); - UNUSED(next_valid_pp_index); - assert(next_valid_pp_index != -1 && "Cosine preprocessor was not added correctly"); - return multiPPContainer; - } - - return new (allocator) PreprocessorsContainerAbstract(allocator, params.alignment); -} - -template -PreprocessorsContainerAbstract * -CreatePreprocessorsContainer(std::shared_ptr allocator, VecSimMetric metric, - size_t dim, bool is_normalized, unsigned char alignment) { - - PreprocessorsContainerParams ppParams = - CreatePreprocessorsContainerParams(metric, dim, is_normalized, alignment); - return CreatePreprocessorsContainer(allocator, ppParams); -} - -template -size_t EstimatePreprocessorsContainerMemory(VecSimMetric metric, bool is_normalized = false) { - size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); - VecSimMetric pp_metric; - if (is_normalized && metric == VecSimMetric_Cosine) { - pp_metric = VecSimMetric_IP; - } else { - pp_metric = metric; - } - - if (pp_metric == VecSimMetric_Cosine) { - constexpr size_t n_preprocessors = 1; - // One entry in preprocessors array - size_t est = - allocations_overhead + sizeof(MultiPreprocessorsContainer); - est += allocations_overhead + sizeof(CosinePreprocessor); - return est; - } - - return allocations_overhead + sizeof(PreprocessorsContainerAbstract); -} diff --git a/src/VecSim/index_factories/factory_utils.h b/src/VecSim/index_factories/factory_utils.h deleted file mode 100644 index 271d2b5cc..000000000 --- a/src/VecSim/index_factories/factory_utils.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/vec_sim_index.h" - -namespace VecSimFactory { -template -static AbstractIndexInitParams NewAbstractInitParams(const IndexParams *algo_params, void *logCtx, - bool is_input_preprocessed) { - - size_t storedDataSize = - VecSimParams_GetStoredDataSize(algo_params->type, algo_params->dim, algo_params->metric); - - // If the input vectors are already processed (for example, normalized), the input blob size is - // the same as the stored data size. inputBlobSize = storedDataSize Otherwise, the input blob - // size is the original size of the vector. inputBlobSize = algo_params->dim * - // VecSimType_sizeof(algo_params->type) - size_t inputBlobSize = is_input_preprocessed - ? 
storedDataSize - : algo_params->dim * VecSimType_sizeof(algo_params->type); - AbstractIndexInitParams abstractInitParams = {.allocator = - VecSimAllocator::newVecsimAllocator(), - .dim = algo_params->dim, - .vecType = algo_params->type, - .storedDataSize = storedDataSize, - .metric = algo_params->metric, - .blockSize = algo_params->blockSize, - .multi = algo_params->multi, - .isDisk = false, - .logCtx = logCtx, - .inputBlobSize = inputBlobSize}; - return abstractInitParams; -} -} // namespace VecSimFactory diff --git a/src/VecSim/index_factories/hnsw_factory.cpp b/src/VecSim/index_factories/hnsw_factory.cpp deleted file mode 100644 index cb30ea734..000000000 --- a/src/VecSim/index_factories/hnsw_factory.cpp +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/algorithms/hnsw/hnsw_single.h" -#include "VecSim/algorithms/hnsw/hnsw_multi.h" -#include "VecSim/index_factories/hnsw_factory.h" -#include "VecSim/index_factories/components/components_factory.h" -#include "VecSim/index_factories/factory_utils.h" -#include "VecSim/algorithms/hnsw/hnsw.h" -#include "VecSim/types/bfloat16.h" -#include "VecSim/types/float16.h" - -using bfloat16 = vecsim_types::bfloat16; -using float16 = vecsim_types::float16; - -namespace HNSWFactory { - -template -inline HNSWIndex * -NewIndex_ChooseMultiOrSingle(const HNSWParams *params, - const AbstractIndexInitParams &abstractInitParams, - IndexComponents &components) { - // check if single and return new hnsw_index - if (params->multi) - return new (abstractInitParams.allocator) - HNSWIndex_Multi(params, abstractInitParams, components); - else - return new (abstractInitParams.allocator) - HNSWIndex_Single(params, abstractInitParams, components); -} - -VecSimIndex *NewIndex(const VecSimParams *params, bool is_normalized) { - const HNSWParams *hnswParams = ¶ms->algoParams.hnswParams; - AbstractIndexInitParams abstractInitParams = - VecSimFactory::NewAbstractInitParams(hnswParams, params->logCtx, is_normalized); - - if (hnswParams->type == VecSimType_FLOAT32) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, hnswParams->metric, hnswParams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(hnswParams, abstractInitParams, indexComponents); - - } else if (hnswParams->type == VecSimType_FLOAT64) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, hnswParams->metric, hnswParams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(hnswParams, abstractInitParams, - indexComponents); - - } else if (hnswParams->type == VecSimType_BFLOAT16) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, hnswParams->metric, hnswParams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(hnswParams, abstractInitParams, - indexComponents); - } else if (hnswParams->type == VecSimType_FLOAT16) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, hnswParams->metric, hnswParams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(hnswParams, abstractInitParams, - indexComponents); - } else if (hnswParams->type == VecSimType_INT8) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, 
hnswParams->metric, hnswParams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(hnswParams, abstractInitParams, - indexComponents); - } else if (hnswParams->type == VecSimType_UINT8) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, hnswParams->metric, hnswParams->dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(hnswParams, abstractInitParams, - indexComponents); - } - - // If we got here something is wrong. - return NULL; -} - -VecSimIndex *NewIndex(const HNSWParams *params, bool is_normalized) { - VecSimParams vecSimParams = {.algoParams = {.hnswParams = HNSWParams{*params}}}; - return NewIndex(&vecSimParams); -} - -template -inline size_t EstimateInitialSize_ChooseMultiOrSingle(bool is_multi) { - // check if single or multi and return the size of the matching class struct. - if (is_multi) - return sizeof(HNSWIndex_Multi); - else - return sizeof(HNSWIndex_Single); -} - -size_t EstimateInitialSize(const HNSWParams *params, bool is_normalized) { - size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); - - size_t est = sizeof(VecSimAllocator) + allocations_overhead; - if (params->type == VecSimType_FLOAT32) { - est += EstimateComponentsMemory(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle(params->multi); - } else if (params->type == VecSimType_FLOAT64) { - est += EstimateComponentsMemory(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle(params->multi); - } else if (params->type == VecSimType_BFLOAT16) { - est += EstimateComponentsMemory(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle(params->multi); - } else if (params->type == VecSimType_FLOAT16) { - est += EstimateComponentsMemory(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle(params->multi); - } else if (params->type == VecSimType_INT8) { - est += EstimateComponentsMemory(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle(params->multi); - } else if (params->type == VecSimType_UINT8) { - est += EstimateComponentsMemory(params->metric, is_normalized); - est += EstimateInitialSize_ChooseMultiOrSingle(params->multi); - } else { - throw std::invalid_argument("Invalid params->type"); - } - est += sizeof(DataBlocksContainer) + allocations_overhead; - - return est; -} - -size_t EstimateElementSize(const HNSWParams *params) { - - size_t M = (params->M) ? params->M : HNSW_DEFAULT_M; - size_t elementGraphDataSize = sizeof(ElementGraphData) + sizeof(idType) * M * 2; - - size_t size_total_data_per_element = - elementGraphDataSize + - VecSimParams_GetStoredDataSize(params->type, params->dim, params->metric); - - // when reserving space for new labels in the lookup hash table, each entry is a pointer to a - // label node (bucket). - size_t size_label_lookup_entry = sizeof(void *); - - // 1 entry in visited nodes + 1 entry in element metadata map + (approximately) 1 bucket in - // labels lookup hash map. - size_t size_meta_data = sizeof(tag_t) + sizeof(ElementMetaData) + size_label_lookup_entry; - - /* Disclaimer: we are neglecting two additional factors that consume memory: - * 1. The overall bucket size in labels_lookup hash table is usually higher than the number of - * requested buckets (which is the index capacity), and it is auto selected according to the - * hashing policy and the max load factor. - * 2. 
The incoming edges that aren't bidirectional are stored in a dynamic array - * (vecsim_stl::vector) Those edges' memory *is omitted completely* from this estimation. - */ - return size_meta_data + size_total_data_per_element; -} - -#ifdef BUILD_TESTS - -template -inline VecSimIndex *NewIndex_ChooseMultiOrSingle(std::ifstream &input, const HNSWParams *params, - const AbstractIndexInitParams &abstractInitParams, - IndexComponents &components, - HNSWSerializer::EncodingVersion version) { - HNSWIndex *index = nullptr; - // check if single and call the ctor that loads index information from file. - if (params->multi) - index = new (abstractInitParams.allocator) HNSWIndex_Multi( - input, params, abstractInitParams, components, version); - else - index = new (abstractInitParams.allocator) HNSWIndex_Single( - input, params, abstractInitParams, components, version); - - index->restoreGraph(input, version); - - return index; -} - -// Initialize @params from file for V3 -static void InitializeParams(std::ifstream &source_params, HNSWParams ¶ms) { - Serializer::readBinaryPOD(source_params, params.dim); - Serializer::readBinaryPOD(source_params, params.type); - Serializer::readBinaryPOD(source_params, params.metric); - Serializer::readBinaryPOD(source_params, params.blockSize); - Serializer::readBinaryPOD(source_params, params.multi); - Serializer::readBinaryPOD(source_params, params.initialCapacity); -} - -VecSimIndex *NewIndex(const std::string &location, bool is_normalized) { - - std::ifstream input(location, std::ios::binary); - if (!input.is_open()) { - throw std::runtime_error("Cannot open file"); - } - - HNSWSerializer::EncodingVersion version = HNSWSerializer::ReadVersion(input); - - VecSimAlgo algo = VecSimAlgo_BF; - Serializer::readBinaryPOD(input, algo); - if (algo != VecSimAlgo_HNSWLIB) { - input.close(); - auto bad_name = VecSimAlgo_ToString(algo); - if (bad_name == nullptr) { - bad_name = "Unknown (corrupted file?)"; - } - throw std::runtime_error( - std::string("Cannot load index: Expected HNSW file but got algorithm type: ") + - bad_name); - } - - HNSWParams params; - InitializeParams(input, params); - - VecSimParams vecsimParams = {.algo = VecSimAlgo_HNSWLIB, - .algoParams = {.hnswParams = HNSWParams{params}}}; - - AbstractIndexInitParams abstractInitParams = - VecSimFactory::NewAbstractInitParams(¶ms, vecsimParams.logCtx, is_normalized); - if (params.type == VecSimType_FLOAT32) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, params.metric, abstractInitParams.dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(input, ¶ms, abstractInitParams, - indexComponents, version); - } else if (params.type == VecSimType_FLOAT64) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, params.metric, abstractInitParams.dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(input, ¶ms, abstractInitParams, - indexComponents, version); - } else if (params.type == VecSimType_BFLOAT16) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, params.metric, abstractInitParams.dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(input, ¶ms, abstractInitParams, - indexComponents, version); - } else if (params.type == VecSimType_FLOAT16) { - IndexComponents indexComponents = CreateIndexComponents( - abstractInitParams.allocator, params.metric, abstractInitParams.dim, is_normalized); - return NewIndex_ChooseMultiOrSingle(input, ¶ms, abstractInitParams, - indexComponents, 
version); - } else if (params.type == VecSimType_INT8) { - IndexComponents<int8_t, float> indexComponents = CreateIndexComponents<int8_t, float>( - abstractInitParams.allocator, params.metric, abstractInitParams.dim, is_normalized); - return NewIndex_ChooseMultiOrSingle<int8_t>(input, &params, abstractInitParams, - indexComponents, version); - } else if (params.type == VecSimType_UINT8) { - IndexComponents<uint8_t, float> indexComponents = CreateIndexComponents<uint8_t, float>( - abstractInitParams.allocator, params.metric, abstractInitParams.dim, is_normalized); - return NewIndex_ChooseMultiOrSingle<uint8_t>(input, &params, abstractInitParams, - indexComponents, version); - } else { - auto bad_name = VecSimType_ToString(params.type); - if (bad_name == nullptr) { - bad_name = "Unknown (corrupted file?)"; - } - throw std::runtime_error(std::string("Cannot load index: bad index data type: ") + - bad_name); - } -} -#endif - -}; // namespace HNSWFactory diff --git a/src/VecSim/index_factories/hnsw_factory.h b/src/VecSim/index_factories/hnsw_factory.h deleted file mode 100644 index ccda1437f..000000000 --- a/src/VecSim/index_factories/hnsw_factory.h +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include <cstddef> // size_t -#include <memory> // std::shared_ptr - -#include "VecSim/vec_sim.h" //typedef VecSimIndex -#include "VecSim/vec_sim_common.h" // HNSWParams -#include "VecSim/memory/vecsim_malloc.h" // VecSimAllocator -#include "VecSim/vec_sim_index.h" - -namespace HNSWFactory { -/** @param is_normalized is used to determine the index's computer type. If the index metric is - * Cosine, and is_normalized == true, we will create the computer as if the metric is IP, assuming - * the blobs sent to the index are already normalized. For example, in case it's a tiered index, - * where the blobs are normalized by the frontend index. - */ -VecSimIndex *NewIndex(const VecSimParams *params, bool is_normalized = false); -VecSimIndex *NewIndex(const HNSWParams *params, bool is_normalized = false); -size_t EstimateInitialSize(const HNSWParams *params, bool is_normalized = false); -size_t EstimateElementSize(const HNSWParams *params); - -#ifdef BUILD_TESTS -// Factory function to be used before loading a serialized index. -// @params is only used for backward compatibility with V1. It won't be used if V2 or later is loaded. -// Required fields: type, dim, metric and multi. -// Optional fields that *** must be initialized to zero ***: blockSize, epsilon. -VecSimIndex *NewIndex(const std::string &location, bool is_normalized = false); - -#endif - -}; // namespace HNSWFactory diff --git a/src/VecSim/index_factories/index_factory.cpp b/src/VecSim/index_factories/index_factory.cpp deleted file mode 100644 index 16459a091..000000000 --- a/src/VecSim/index_factories/index_factory.cpp +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3).
- */ -#include "VecSim/vec_sim_index.h" -#include "index_factory.h" -#include "hnsw_factory.h" -#include "brute_force_factory.h" -#include "tiered_factory.h" -#include "svs_factory.h" - -namespace VecSimFactory { -VecSimIndex *NewIndex(const VecSimParams *params) { - VecSimIndex *index = NULL; - std::shared_ptr allocator = VecSimAllocator::newVecsimAllocator(); - try { - switch (params->algo) { - case VecSimAlgo_HNSWLIB: { - index = HNSWFactory::NewIndex(params); - break; - } - - case VecSimAlgo_BF: { - index = BruteForceFactory::NewIndex(params); - break; - } - case VecSimAlgo_TIERED: { - index = TieredFactory::NewIndex(¶ms->algoParams.tieredParams); - break; - } - case VecSimAlgo_SVS: { - index = SVSFactory::NewIndex(params); - break; - } - } - } catch (...) { - // Index will delete itself. For now, do nothing. - } - return index; -} - -size_t EstimateInitialSize(const VecSimParams *params) { - switch (params->algo) { - case VecSimAlgo_HNSWLIB: - return HNSWFactory::EstimateInitialSize(¶ms->algoParams.hnswParams); - case VecSimAlgo_BF: - return BruteForceFactory::EstimateInitialSize(¶ms->algoParams.bfParams); - case VecSimAlgo_TIERED: - return TieredFactory::EstimateInitialSize(¶ms->algoParams.tieredParams); - case VecSimAlgo_SVS:; // empty statement if svs not available - return SVSFactory::EstimateInitialSize(¶ms->algoParams.svsParams); - } - return -1; -} - -size_t EstimateElementSize(const VecSimParams *params) { - switch (params->algo) { - case VecSimAlgo_HNSWLIB: - return HNSWFactory::EstimateElementSize(¶ms->algoParams.hnswParams); - case VecSimAlgo_BF: - return BruteForceFactory::EstimateElementSize(¶ms->algoParams.bfParams); - case VecSimAlgo_TIERED: - return TieredFactory::EstimateElementSize(¶ms->algoParams.tieredParams); - case VecSimAlgo_SVS:; // empty statement if svs not available - return SVSFactory::EstimateElementSize(¶ms->algoParams.svsParams); - } - return -1; -} - -} // namespace VecSimFactory diff --git a/src/VecSim/index_factories/index_factory.h b/src/VecSim/index_factories/index_factory.h deleted file mode 100644 index 5f471a82b..000000000 --- a/src/VecSim/index_factories/index_factory.h +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/vec_sim.h" -#include "VecSim/vec_sim_common.h" -#include "VecSim/memory/vecsim_malloc.h" - -namespace VecSimFactory { -VecSimIndex *NewIndex(const VecSimParams *params); -size_t EstimateInitialSize(const VecSimParams *params); -size_t EstimateElementSize(const VecSimParams *params); -}; // namespace VecSimFactory diff --git a/src/VecSim/index_factories/svs_factory.cpp b/src/VecSim/index_factories/svs_factory.cpp deleted file mode 100644 index eac93fb55..000000000 --- a/src/VecSim/index_factories/svs_factory.cpp +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ - -#include "VecSim/index_factories/svs_factory.h" - -#if HAVE_SVS -#include "VecSim/memory/vecsim_malloc.h" -#include "VecSim/vec_sim_index.h" -#include "VecSim/algorithms/svs/svs.h" -#include "VecSim/index_factories/components/components_factory.h" -#include "VecSim/index_factories/factory_utils.h" - -namespace SVSFactory { - -namespace { - -// NewVectorsImpl() is the chain of a template helper functions to create a new SVS index. -template -VecSimIndex *NewIndexImpl(const VecSimParams *params, bool is_normalized) { - auto &svsParams = params->algoParams.svsParams; - auto abstractInitParams = - VecSimFactory::NewAbstractInitParams(&svsParams, params->logCtx, is_normalized); - auto preprocessors = CreatePreprocessorsContainer>( - abstractInitParams.allocator, svsParams.metric, svsParams.dim, is_normalized, 0); - IndexComponents, float> components = { - nullptr, preprocessors}; // calculator is not in use in svs. - bool forcePreprocessing = !is_normalized && svsParams.metric == VecSimMetric_Cosine; - if (svsParams.multi) { - return new (abstractInitParams.allocator) - SVSIndex( - svsParams, abstractInitParams, components, forcePreprocessing); - } else { - return new (abstractInitParams.allocator) - SVSIndex( - svsParams, abstractInitParams, components, forcePreprocessing); - } -} - -template -VecSimIndex *NewIndexImpl(const VecSimParams *params, bool is_normalized) { - // Ignore the 'supported' flag because we always fallback at least to the non-quantized mode - // elsewhere we got code coverage failure for the `supported==false` case - auto quantBits = - std::get<0>(svs_details::isSVSQuantBitsSupported(params->algoParams.svsParams.quantBits)); - - switch (quantBits) { - case VecSimSvsQuant_NONE: - return NewIndexImpl(params, is_normalized); - case VecSimSvsQuant_Scalar: - return NewIndexImpl(params, is_normalized); - case VecSimSvsQuant_8: - return NewIndexImpl(params, is_normalized); - case VecSimSvsQuant_4: - return NewIndexImpl(params, is_normalized); - case VecSimSvsQuant_4x4: - return NewIndexImpl(params, is_normalized); - case VecSimSvsQuant_4x8: - return NewIndexImpl(params, is_normalized); - case VecSimSvsQuant_4x8_LeanVec: - return NewIndexImpl(params, is_normalized); - case VecSimSvsQuant_8x8_LeanVec: - return NewIndexImpl(params, is_normalized); - default: - // If we got here something is wrong. - assert(false && "Unsupported quantization mode"); - return NULL; - } -} - -template -VecSimIndex *NewIndexImpl(const VecSimParams *params, bool is_normalized) { - assert(params && params->algo == VecSimAlgo_SVS); - switch (params->algoParams.svsParams.type) { - case VecSimType_FLOAT32: - return NewIndexImpl(params, is_normalized); - case VecSimType_FLOAT16: - return NewIndexImpl(params, is_normalized); - default: - // If we got here something is wrong. - assert(false && "Unsupported data type"); - return NULL; - } -} - -VecSimIndex *NewIndexImpl(const VecSimParams *params, bool is_normalized) { - assert(params && params->algo == VecSimAlgo_SVS); - switch (params->algoParams.svsParams.metric) { - case VecSimMetric_L2: - return NewIndexImpl(params, is_normalized); - case VecSimMetric_IP: - case VecSimMetric_Cosine: - return NewIndexImpl(params, is_normalized); - default: - // If we got here something is wrong. - assert(false && "Unknown distance metric type"); - return NULL; - } -} - -// QuantizedVectorSize() is the chain of template functions to estimate vector DataSize. 
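// (Editor's note, illustrative only.) The overloads below narrow runtime parameters into
// template arguments in stages: VecSimType -> VecSimSvsQuantBits -> storage traits. For
// example, a call like
//
//     QuantizedVectorSize(VecSimType_FLOAT32, VecSimSvsQuant_8, /*dims=*/128);
//
// first dispatches on the data type (float), then on the quantization mode (8-bit), and
// finally asks the matching SVSStorageTraits specialization for element_size(128, 0, 0).
// The concrete trait template arguments are omitted in this sketch.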
-template -constexpr size_t QuantizedVectorSize(size_t dims, size_t alignment = 0, size_t leanvec_dim = 0) { - return SVSStorageTraits::element_size( - dims, alignment, leanvec_dim); -} - -template -size_t QuantizedVectorSize(VecSimSvsQuantBits quant_bits, size_t dims, size_t alignment = 0, - size_t leanvec_dim = 0) { - // Ignore the 'supported' flag because we always fallback at least to the non-quantized mode - // elsewhere we got code coverage failure for the `supported==false` case - auto quantBits = std::get<0>(svs_details::isSVSQuantBitsSupported(quant_bits)); - - switch (quantBits) { - case VecSimSvsQuant_NONE: - return QuantizedVectorSize(dims, alignment); - case VecSimSvsQuant_Scalar: - return QuantizedVectorSize(dims, alignment); - case VecSimSvsQuant_8: - return QuantizedVectorSize(dims, alignment); - case VecSimSvsQuant_4: - return QuantizedVectorSize(dims, alignment); - case VecSimSvsQuant_4x4: - return QuantizedVectorSize(dims, alignment); - case VecSimSvsQuant_4x8: - return QuantizedVectorSize(dims, alignment); - case VecSimSvsQuant_4x8_LeanVec: - return QuantizedVectorSize(dims, alignment, leanvec_dim); - case VecSimSvsQuant_8x8_LeanVec: - return QuantizedVectorSize(dims, alignment, leanvec_dim); - default: - // If we got here something is wrong. - assert(false && "Unsupported quantization mode"); - return 0; - } -} - -size_t QuantizedVectorSize(VecSimType data_type, VecSimSvsQuantBits quant_bits, size_t dims, - size_t alignment = 0, size_t leanvec_dim = 0) { - switch (data_type) { - case VecSimType_FLOAT32: - return QuantizedVectorSize(quant_bits, dims, alignment, leanvec_dim); - case VecSimType_FLOAT16: - return QuantizedVectorSize(quant_bits, dims, alignment, leanvec_dim); - default: - // If we got here something is wrong. - assert(false && "Unsupported data type"); - return 0; - } -} - -size_t EstimateSVSIndexSize(const SVSParams *params) { - // SVSindex class has no fields which size depend on template specialization - // when VecSimIndexAbstract may depend on DataType template parameter - switch (params->type) { - case VecSimType_FLOAT32: - return sizeof(SVSIndex); - case VecSimType_FLOAT16: - return sizeof(SVSIndex); - default: - // If we got here something is wrong. - assert(false && "Unsupported data type"); - return 0; - } -} - -size_t EstimateComponentsMemorySVS(VecSimType type, VecSimMetric metric, bool is_normalized) { - // SVS index only includes a preprocessor container. - switch (type) { - case VecSimType_FLOAT32: - return EstimatePreprocessorsContainerMemory(metric, is_normalized); - case VecSimType_FLOAT16: - return EstimatePreprocessorsContainerMemory>( - metric, is_normalized); - default: - // If we got here something is wrong. 
- assert(false && "Unsupported data type"); - return 0; - } -} -} // namespace - -VecSimIndex *NewIndex(const VecSimParams *params, bool is_normalized) { - return NewIndexImpl(params, is_normalized); -} - -#if BUILD_TESTS -VecSimIndex *NewIndex(const std::string &location, const VecSimParams *params, bool is_normalized) { - auto index = NewIndexImpl(params, is_normalized); - // Side-cast to SVSIndexBase to call loadIndex - SVSIndexBase *svs_index = dynamic_cast(index); - if (svs_index != nullptr) { - try { - svs_index->loadIndex(location); - } catch (const std::exception &e) { - VecSimIndex_Free(index); - throw; - } - } else { - VecSimIndex_Free(index); - throw std::runtime_error( - "Cannot load index: Error in index creation before loading serialization"); - } - return index; -} -#endif - -size_t EstimateElementSize(const SVSParams *params) { - using graph_idx_type = uint32_t; - // Assuming that the graph_max_degree can be unset in params. - const auto graph_max_degree = svs_details::makeVamanaBuildParameters(*params).graph_max_degree; - const auto graph_node_size = SVSGraphBuilder::element_size(graph_max_degree); - const auto vector_size = - QuantizedVectorSize(params->type, params->quantBits, params->dim, 0, params->leanvec_dim); - - return vector_size + graph_node_size; -} - -size_t EstimateInitialSize(const SVSParams *params, bool is_normalized) { - size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); - size_t est = sizeof(VecSimAllocator) + allocations_overhead; - - est += EstimateSVSIndexSize(params); - est += EstimateComponentsMemorySVS(params->type, params->metric, is_normalized); - est += sizeof(DataBlocksContainer) + allocations_overhead; - return est; -} - -} // namespace SVSFactory - -// This is a temporary solution to avoid breaking the build when SVS is not available -// and to allow the code to compile without SVS support. -// TODO: remove HAVE_SVS when SVS will support all Redis platforms and compilers -#else // HAVE_SVS -namespace SVSFactory { -VecSimIndex *NewIndex(const VecSimParams *params, bool is_normalized) { return NULL; } -#if BUILD_TESTS -VecSimIndex *NewIndex(const std::string &location, const VecSimParams *params, bool is_normalized) { - return NULL; -} -#endif -size_t EstimateInitialSize(const SVSParams *params, bool is_normalized) { return -1; } -size_t EstimateElementSize(const SVSParams *params) { return -1; } -}; // namespace SVSFactory -#endif // HAVE_SVS diff --git a/src/VecSim/index_factories/svs_factory.h b/src/VecSim/index_factories/svs_factory.h deleted file mode 100644 index c4c6d04db..000000000 --- a/src/VecSim/index_factories/svs_factory.h +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ - -#pragma once - -#include // size_t -#include - -#include "VecSim/vec_sim.h" //typedef VecSimIndex -#include "VecSim/vec_sim_common.h" // VecSimParams, SVSParams - -namespace SVSFactory { -VecSimIndex *NewIndex(const VecSimParams *params, bool is_normalized = false); -#if BUILD_TESTS -VecSimIndex *NewIndex(const std::string &location, const VecSimParams *params, - bool is_normalized = false); -#endif -size_t EstimateInitialSize(const SVSParams *params, bool is_normalized = false); -size_t EstimateElementSize(const SVSParams *params); -}; // namespace SVSFactory diff --git a/src/VecSim/index_factories/tiered_factory.cpp b/src/VecSim/index_factories/tiered_factory.cpp deleted file mode 100644 index 337db6cc3..000000000 --- a/src/VecSim/index_factories/tiered_factory.cpp +++ /dev/null @@ -1,246 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/index_factories/tiered_factory.h" -#include "VecSim/index_factories/hnsw_factory.h" -#include "VecSim/index_factories/brute_force_factory.h" - -#include "VecSim/algorithms/hnsw/hnsw_tiered.h" -#include "VecSim/types/bfloat16.h" -#include "VecSim/types/float16.h" -#include "VecSim/index_factories/svs_factory.h" - -#if HAVE_SVS -#include "VecSim/algorithms/svs/svs_tiered.h" -#endif - -using bfloat16 = vecsim_types::bfloat16; -using float16 = vecsim_types::float16; - -namespace TieredFactory { - -namespace TieredHNSWFactory { - -static inline BFParams NewBFParams(const TieredIndexParams *params) { - auto hnsw_params = params->primaryIndexParams->algoParams.hnswParams; - BFParams bf_params = {.type = hnsw_params.type, - .dim = hnsw_params.dim, - .metric = hnsw_params.metric, - .multi = hnsw_params.multi, - .blockSize = hnsw_params.blockSize}; - - return bf_params; -} - -template -inline VecSimIndex *NewIndex(const TieredIndexParams *params) { - - // initialize hnsw index - // Normalization is done by the frontend index. - auto *hnsw_index = reinterpret_cast *>( - HNSWFactory::NewIndex(params->primaryIndexParams, true)); - // initialize brute force index - - BFParams bf_params = NewBFParams(params); - - AbstractIndexInitParams abstractInitParams = - VecSimFactory::NewAbstractInitParams(&bf_params, params->primaryIndexParams->logCtx, false); - assert(hnsw_index->getInputBlobSize() == abstractInitParams.storedDataSize); - assert(hnsw_index->getStoredDataSize() == abstractInitParams.storedDataSize); - auto frontendIndex = static_cast *>( - BruteForceFactory::NewIndex(&bf_params, abstractInitParams, false)); - - // Create new tiered hnsw index - std::shared_ptr management_layer_allocator = - VecSimAllocator::newVecsimAllocator(); - - return new (management_layer_allocator) TieredHNSWIndex( - hnsw_index, frontendIndex, *params, management_layer_allocator); -} - -inline size_t EstimateInitialSize(const TieredIndexParams *params) { - HNSWParams hnsw_params = params->primaryIndexParams->algoParams.hnswParams; - - // Add size estimation of VecSimTieredIndex sub indexes. - // Normalization is done by the frontend index. - size_t est = HNSWFactory::EstimateInitialSize(&hnsw_params, true); - - // Management layer allocator overhead. - size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); - est += sizeof(VecSimAllocator) + allocations_overhead; - - // Size of the TieredHNSWIndex struct. 
- if (hnsw_params.type == VecSimType_FLOAT32) { - est += sizeof(TieredHNSWIndex); - } else if (hnsw_params.type == VecSimType_FLOAT64) { - est += sizeof(TieredHNSWIndex); - } else if (hnsw_params.type == VecSimType_BFLOAT16) { - est += sizeof(TieredHNSWIndex); - } else if (hnsw_params.type == VecSimType_FLOAT16) { - est += sizeof(TieredHNSWIndex); - } else if (hnsw_params.type == VecSimType_INT8) { - est += sizeof(TieredHNSWIndex); - } else if (hnsw_params.type == VecSimType_UINT8) { - est += sizeof(TieredHNSWIndex); - } else { - throw std::invalid_argument("Invalid hnsw_params.type"); - } - - return est; -} - -VecSimIndex *NewIndex(const TieredIndexParams *params) { - // Tiered index that contains HNSW index as primary index - VecSimType type = params->primaryIndexParams->algoParams.hnswParams.type; - if (type == VecSimType_FLOAT32) { - return TieredHNSWFactory::NewIndex(params); - } else if (type == VecSimType_FLOAT64) { - return TieredHNSWFactory::NewIndex(params); - } else if (type == VecSimType_BFLOAT16) { - return TieredHNSWFactory::NewIndex(params); - } else if (type == VecSimType_FLOAT16) { - return TieredHNSWFactory::NewIndex(params); - } else if (type == VecSimType_INT8) { - return TieredHNSWFactory::NewIndex(params); - } else if (type == VecSimType_UINT8) { - return TieredHNSWFactory::NewIndex(params); - } - return nullptr; // Invalid type. -} -} // namespace TieredHNSWFactory - -namespace TieredSVSFactory { -BFParams NewBFParams(const TieredIndexParams *params) { - auto &svs_params = params->primaryIndexParams->algoParams.svsParams; - return BFParams{.type = svs_params.type, - .dim = svs_params.dim, - .metric = svs_params.metric, - .multi = svs_params.multi, - .blockSize = svs_params.blockSize}; -} - -#if HAVE_SVS -template -inline VecSimIndex *NewIndex(const TieredIndexParams *params) { - - // initialize svs index - // Normalization is done by the frontend index. - auto *svs_index = static_cast *>( - SVSFactory::NewIndex(params->primaryIndexParams, true)); - assert(svs_index != nullptr); - // initialize brute force index - - auto bf_params = NewBFParams(params); - - AbstractIndexInitParams abstractInitParams = - VecSimFactory::NewAbstractInitParams(&bf_params, params->primaryIndexParams->logCtx, false); - assert(svs_index->getInputBlobSize() == abstractInitParams.storedDataSize); - assert(svs_index->getStoredDataSize() == abstractInitParams.storedDataSize); - auto frontendIndex = static_cast *>( - BruteForceFactory::NewIndex(&bf_params, abstractInitParams, false)); - - // Create new tiered svs index - std::shared_ptr management_layer_allocator = - VecSimAllocator::newVecsimAllocator(); - - return new (management_layer_allocator) - TieredSVSIndex(svs_index, frontendIndex, *params, management_layer_allocator); -} - -inline size_t EstimateInitialSize(const TieredIndexParams *params) { - auto &svs_params = params->primaryIndexParams->algoParams.svsParams; - - // Add size estimation of VecSimTieredIndex sub indexes. - size_t est = SVSFactory::EstimateInitialSize(&svs_params, true); - - // Management layer allocator overhead. - size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); - est += sizeof(VecSimAllocator) + allocations_overhead; - - // Size of the TieredHNSWIndex struct. 
- switch (svs_params.type) { - case VecSimType_FLOAT32: - est += sizeof(TieredSVSIndex); - break; - case VecSimType_FLOAT16: - est += sizeof(TieredSVSIndex); - break; - default: - assert(false && "Unsupported data type"); - break; - } - - return est; -} - -VecSimIndex *NewIndex(const TieredIndexParams *params) { - // Tiered index that contains SVS index as primary index - VecSimType type = params->primaryIndexParams->algoParams.svsParams.type; - switch (type) { - case VecSimType_FLOAT32: - return TieredSVSFactory::NewIndex(params); - case VecSimType_FLOAT16: - return TieredSVSFactory::NewIndex(params); - default: - assert(false && "Unsupported data type"); - return nullptr; // Invalid type. - } - return nullptr; // Invalid type. -} - -// This is a temporary solution to avoid breaking the build when SVS is not available -// and to allow the code to compile without SVS support. -// TODO: remove HAVE_SVS when SVS will support all Redis platforms and compilers -#else // HAVE_SVS -inline VecSimIndex *NewIndex(const TieredIndexParams *params) { return nullptr; } -inline size_t EstimateInitialSize(const TieredIndexParams *params) { return 0; } -inline size_t EstimateElementSize(const TieredIndexParams *params) { return 0; } -#endif -} // namespace TieredSVSFactory - -VecSimIndex *NewIndex(const TieredIndexParams *params) { - // Tiered index that contains HNSW index as primary index - if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { - return TieredHNSWFactory::NewIndex(params); - } - // Tiered index that contains SVS index as primary index - if (params->primaryIndexParams->algo == VecSimAlgo_SVS) { - return TieredSVSFactory::NewIndex(params); - } - return nullptr; // Invalid algorithm or type. -} -size_t EstimateInitialSize(const TieredIndexParams *params) { - - size_t est = 0; - - BFParams bf_params{}; - if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { - est += TieredHNSWFactory::EstimateInitialSize(params); - bf_params = TieredHNSWFactory::NewBFParams(params); - } - if (params->primaryIndexParams->algo == VecSimAlgo_SVS) { - est += TieredSVSFactory::EstimateInitialSize(params); - bf_params = TieredSVSFactory::NewBFParams(params); - } - - est += BruteForceFactory::EstimateInitialSize(&bf_params, false); - return est; -} - -size_t EstimateElementSize(const TieredIndexParams *params) { - size_t est = 0; - if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { - est = HNSWFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.hnswParams); - } - if (params->primaryIndexParams->algo == VecSimAlgo_SVS) { - est = SVSFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.svsParams); - } - return est; -} - -}; // namespace TieredFactory diff --git a/src/VecSim/index_factories/tiered_factory.h b/src/VecSim/index_factories/tiered_factory.h deleted file mode 100644 index fbb55d3b3..000000000 --- a/src/VecSim/index_factories/tiered_factory.h +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#pragma once - -#include "VecSim/vec_sim.h" -#include "VecSim/vec_sim_common.h" -#include "VecSim/memory/vecsim_malloc.h" -#include "VecSim/vec_sim_index.h" -#include "VecSim/algorithms/hnsw/hnsw_tiered.h" -#include "VecSim/algorithms/svs/svs_tiered.h" -#include "VecSim/algorithms/brute_force/brute_force.h" -#include "VecSim/index_factories/factory_utils.h" - -namespace TieredFactory { - -VecSimIndex *NewIndex(const TieredIndexParams *params); - -// The size estimation is the sum of the buffer (brute force) and main index initial sizes -// estimations, plus the tiered index class size. Note it does not include the size of internal -// containers such as the job queue, as those depend on the user implementation. -size_t EstimateInitialSize(const TieredIndexParams *params); -size_t EstimateElementSize(const TieredIndexParams *params); - -#ifdef BUILD_TESTS -namespace TieredHNSWFactory { -// Build tiered index from existing HNSW index - for internal benchmarks purposes -template -VecSimIndex *NewIndex(const TieredIndexParams *params, HNSWIndex *hnsw_index) { - // Initialize brute force index. - BFParams bf_params = {.type = hnsw_index->getType(), - .dim = hnsw_index->getDim(), - .metric = hnsw_index->getMetric(), - .multi = hnsw_index->isMultiValue(), - .blockSize = hnsw_index->getBlockSize()}; - - AbstractIndexInitParams abstractInitParams = - VecSimFactory::NewAbstractInitParams(&bf_params, nullptr, false); - assert(hnsw_index->getInputBlobSize() == abstractInitParams.storedDataSize); - assert(hnsw_index->getStoredDataSize() == abstractInitParams.storedDataSize); - auto frontendIndex = static_cast *>( - BruteForceFactory::NewIndex(&bf_params, abstractInitParams, false)); - - // Create new tiered hnsw index - std::shared_ptr management_layer_allocator = - VecSimAllocator::newVecsimAllocator(); - return new (management_layer_allocator) TieredHNSWIndex( - hnsw_index, frontendIndex, *params, management_layer_allocator); -} -} // namespace TieredHNSWFactory - -// The function below is exported to calculate a brute force index size in tests to align -// with the logic of the TieredFactory::EstimateInitialSize(), which currently doesn’t have a -// verification of the backend index algorithm. To be removed once a proper verification is -// introduced. -namespace TieredSVSFactory { - -#if HAVE_SVS -template -inline VecSimIndex *NewIndex(const TieredIndexParams *params, - VecSimIndexAbstract *svs_index) { - // Initialize brute force index. 
- BFParams bf_params = {.type = svs_index->getType(), - .dim = svs_index->getDim(), - .metric = svs_index->getMetric(), - .multi = svs_index->isMultiValue(), - .blockSize = svs_index->getBlockSize()}; - - AbstractIndexInitParams abstractInitParams = - VecSimFactory::NewAbstractInitParams(&bf_params, params->primaryIndexParams->logCtx, false); - assert(svs_index->getInputBlobSize() == abstractInitParams.storedDataSize); - assert(svs_index->getStoredDataSize() == abstractInitParams.storedDataSize); - auto frontendIndex = static_cast *>( - BruteForceFactory::NewIndex(&bf_params, abstractInitParams, false)); - - // Create new tiered svs index - std::shared_ptr management_layer_allocator = - VecSimAllocator::newVecsimAllocator(); - - return new (management_layer_allocator) - TieredSVSIndex(svs_index, frontendIndex, *params, management_layer_allocator); -} -#endif -BFParams NewBFParams(const TieredIndexParams *params); -} // namespace TieredSVSFactory -#endif - -}; // namespace TieredFactory diff --git a/src/VecSim/info_iterator.cpp b/src/VecSim/info_iterator.cpp deleted file mode 100644 index d41105ba8..000000000 --- a/src/VecSim/info_iterator.cpp +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "info_iterator_struct.h" - -extern "C" size_t VecSimDebugInfoIterator_NumberOfFields(VecSimDebugInfoIterator *infoIterator) { - return infoIterator->numberOfFields(); -} - -extern "C" bool VecSimDebugInfoIterator_HasNextField(VecSimDebugInfoIterator *infoIterator) { - return infoIterator->hasNext(); -} - -extern "C" VecSim_InfoField * -VecSimDebugInfoIterator_NextField(VecSimDebugInfoIterator *infoIterator) { - if (infoIterator->hasNext()) { - return infoIterator->next(); - } - return NULL; -} - -extern "C" void VecSimDebugInfoIterator_Free(VecSimDebugInfoIterator *infoIterator) { - if (infoIterator != NULL) { - delete infoIterator; - } -} diff --git a/src/VecSim/info_iterator.h b/src/VecSim/info_iterator.h deleted file mode 100644 index cf1059787..000000000 --- a/src/VecSim/info_iterator.h +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include -#include "vec_sim_common.h" -#ifdef __cplusplus -extern "C" { -#endif - -/** - * @brief A struct to hold an index information for generic purposes. Each information field is of - * the type VecSim_InfoFieldType. This struct exposes an iterator-like API to iterate over the - * information fields. - */ -typedef struct VecSimDebugInfoIterator VecSimDebugInfoIterator; - -typedef enum { - INFOFIELD_STRING, - INFOFIELD_INT64, - INFOFIELD_UINT64, - INFOFIELD_FLOAT64, - INFOFIELD_ITERATOR -} VecSim_InfoFieldType; - -typedef union { - double floatingPointValue; // Floating point value. 64 bits float. - int64_t integerValue; // Integer value. Signed 64 bits integer. - uint64_t uintegerValue; // Unsigned value. Unsigned 64 bits integer. - const char *stringValue; // String value. - VecSimDebugInfoIterator *iteratorValue; // Iterator value. -} FieldValue; - -/** - * @brief A struct to hold field information. 
This struct contains three members: - * fieldType - Enum describing the content of the value. - * fieldName - Field name. - * fieldValue - A union of string/integer/float values. - */ -typedef struct { - const char *fieldName; // Field name. - VecSim_InfoFieldType fieldType; // Field type (in {STR, INT64, FLOAT64}) - FieldValue fieldValue; -} VecSim_InfoField; - -/** - * @brief Returns the number of fields in the info iterator. - * - * @param infoIterator Given info iterator. - * @return size_t Number of fields. - */ -size_t VecSimDebugInfoIterator_NumberOfFields(VecSimDebugInfoIterator *infoIterator); - -/** - * @brief Returns whether the fields iterator has more fields to yield. - * - * @param infoIterator Given info iterator. - * @return true Iterator is not depleted. - * @return false Otherwise. - */ -bool VecSimDebugInfoIterator_HasNextField(VecSimDebugInfoIterator *infoIterator); - -/** - * @brief Returns a pointer to the next info field. - * - * @param infoIterator Given info iterator. - * @return VecSim_InfoField* A pointer to the next info field. - */ -VecSim_InfoField *VecSimDebugInfoIterator_NextField(VecSimDebugInfoIterator *infoIterator); - -/** - * @brief Free an info iterator. - * - * @param infoIterator Given info iterator. - */ -void VecSimDebugInfoIterator_Free(VecSimDebugInfoIterator *infoIterator); - -#ifdef __cplusplus -} -#endif diff --git a/src/VecSim/info_iterator_struct.h b/src/VecSim/info_iterator_struct.h deleted file mode 100644 index d8a17a979..000000000 --- a/src/VecSim/info_iterator_struct.h +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "info_iterator.h" -#include "VecSim/utils/vecsim_stl.h" - -struct VecSimDebugInfoIterator { -private: - vecsim_stl::vector<VecSim_InfoField> fields; - size_t currentIndex; - -public: - VecSimDebugInfoIterator(size_t len, const std::shared_ptr<VecSimAllocator> &alloc) - : fields(alloc), currentIndex(0) { - this->fields.reserve(len); - } - - inline void addInfoField(VecSim_InfoField infoField) { this->fields.push_back(infoField); } - - inline bool hasNext() { return this->currentIndex < this->fields.size(); } - - inline VecSim_InfoField *next() { return &this->fields[this->currentIndex++]; } - - inline size_t numberOfFields() { return this->fields.size(); } - - virtual ~VecSimDebugInfoIterator() { - for (size_t i = 0; i < this->fields.size(); i++) { - if (this->fields[i].fieldType == INFOFIELD_ITERATOR) { - delete this->fields[i].fieldValue.iteratorValue; - } - } - } -}; diff --git a/src/VecSim/memory/memory_utils.h b/src/VecSim/memory/memory_utils.h deleted file mode 100644 index 3798f0c04..000000000 --- a/src/VecSim/memory/memory_utils.h +++ /dev/null @@ -1,17 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3).
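For orientation, the debug-info API deleted above is a plain C iterator. The sketch below is a hypothetical consumer (not from the library) that drains one iterator; note that per the `~VecSimDebugInfoIterator` destructor shown above, nested iterators are owned and freed by their parent, so only the root iterator is freed explicitly here:

    #include <stdio.h>

    // Walk all fields of an iterator without freeing it; recurse into nested iterators.
    void dumpFields(VecSimDebugInfoIterator *it) {
        while (VecSimDebugInfoIterator_HasNextField(it)) {
            VecSim_InfoField *f = VecSimDebugInfoIterator_NextField(it);
            switch (f->fieldType) {
            case INFOFIELD_STRING:  printf("%s: %s\n", f->fieldName, f->fieldValue.stringValue); break;
            case INFOFIELD_INT64:   printf("%s: %lld\n", f->fieldName, (long long)f->fieldValue.integerValue); break;
            case INFOFIELD_UINT64:  printf("%s: %llu\n", f->fieldName, (unsigned long long)f->fieldValue.uintegerValue); break;
            case INFOFIELD_FLOAT64: printf("%s: %f\n", f->fieldName, f->fieldValue.floatingPointValue); break;
            case INFOFIELD_ITERATOR: dumpFields(f->fieldValue.iteratorValue); break; // freed by the parent's destructor
            }
        }
    }

    // Caller side: free the root only; nested iterators are released transitively.
    //   dumpFields(root);
    //   VecSimDebugInfoIterator_Free(root);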
- */ -#pragma once -#include <memory> -#include <functional> - -namespace MemoryUtils { - -using alloc_deleter_t = std::function<void(void *)>; -using unique_blob = std::unique_ptr<void, alloc_deleter_t>; - -} // namespace MemoryUtils diff --git a/src/VecSim/memory/vecsim_base.cpp b/src/VecSim/memory/vecsim_base.cpp deleted file mode 100644 index 33382cb74..000000000 --- a/src/VecSim/memory/vecsim_base.cpp +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "vecsim_base.h" - -void *VecsimBaseObject::operator new(size_t size, std::shared_ptr<VecSimAllocator> allocator) { - return allocator->allocate(size); -} - -void *VecsimBaseObject::operator new[](size_t size, std::shared_ptr<VecSimAllocator> allocator) { - return allocator->allocate(size); -} - -void VecsimBaseObject::operator delete(void *p, size_t size) { - VecsimBaseObject *obj = reinterpret_cast<VecsimBaseObject *>(p); - obj->allocator->deallocate(obj, size); -} - -void VecsimBaseObject::operator delete[](void *p, size_t size) { - VecsimBaseObject *obj = reinterpret_cast<VecsimBaseObject *>(p); - obj->allocator->deallocate(obj, size); -} - -void VecsimBaseObject::operator delete(void *p, std::shared_ptr<VecSimAllocator> allocator) { - allocator->free_allocation(p); -} -void VecsimBaseObject::operator delete[](void *p, std::shared_ptr<VecSimAllocator> allocator) { - allocator->free_allocation(p); -} - -void operator delete(void *p, std::shared_ptr<VecSimAllocator> allocator) { - allocator->free_allocation(p); -} -void operator delete[](void *p, std::shared_ptr<VecSimAllocator> allocator) { - allocator->free_allocation(p); -} - -// TODO: Probably unused functions. See Codecov output in order to remove - -void operator delete(void *p, size_t size, std::shared_ptr<VecSimAllocator> allocator) { - allocator->deallocate(p, size); -} - -void operator delete[](void *p, size_t size, std::shared_ptr<VecSimAllocator> allocator) { - allocator->deallocate(p, size); -} - -std::shared_ptr<VecSimAllocator> VecsimBaseObject::getAllocator() const { return this->allocator; } diff --git a/src/VecSim/memory/vecsim_base.h b/src/VecSim/memory/vecsim_base.h deleted file mode 100644 index 672597126..000000000 --- a/src/VecSim/memory/vecsim_base.h +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "vecsim_malloc.h" -#include <memory> - -struct VecsimBaseObject { - -protected: - std::shared_ptr<VecSimAllocator> allocator; - -public: - VecsimBaseObject(std::shared_ptr<VecSimAllocator> allocator) : allocator(allocator) {} - - static void *operator new(size_t size, std::shared_ptr<VecSimAllocator> allocator); - static void *operator new[](size_t size, std::shared_ptr<VecSimAllocator> allocator); - static void operator delete(void *p, size_t size); - static void operator delete[](void *p, size_t size); - - // Placement delete. To be used in a try/catch clause when paired with the corresponding - // placement constructor. - static void operator delete(void *p, std::shared_ptr<VecSimAllocator> allocator); - static void operator delete[](void *p, std::shared_ptr<VecSimAllocator> allocator); - static void operator delete(void *p, size_t size, std::shared_ptr<VecSimAllocator> allocator); - static void operator delete[](void *p, size_t size, std::shared_ptr<VecSimAllocator> allocator); - - std::shared_ptr<VecSimAllocator> getAllocator() const; - virtual inline uint64_t getAllocationSize() const { - return this->allocator->getAllocationSize(); - } - - virtual ~VecsimBaseObject() = default; -}; diff --git a/src/VecSim/memory/vecsim_malloc.cpp b/src/VecSim/memory/vecsim_malloc.cpp deleted file mode 100644 index caa28fc1f..000000000 --- a/src/VecSim/memory/vecsim_malloc.cpp +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "vecsim_malloc.h" -#include <memory> -#include <new> -#include <cstring> -#include <sys/param.h> - -std::shared_ptr<VecSimAllocator> VecSimAllocator::newVecsimAllocator() { - std::shared_ptr<VecSimAllocator> allocator(new VecSimAllocator()); - return allocator; -} - -struct VecSimAllocationHeader { - std::size_t allocation_size : 63; - std::size_t is_aligned : 1; -}; - -size_t VecSimAllocator::allocation_header_size = sizeof(VecSimAllocationHeader); - -VecSimMemoryFunctions VecSimAllocator::memFunctions = {.allocFunction = malloc, - .callocFunction = calloc, - .reallocFunction = realloc, - .freeFunction = free}; - -void VecSimAllocator::setMemoryFunctions(VecSimMemoryFunctions memFunctions) { - VecSimAllocator::memFunctions = memFunctions; -} - -void *VecSimAllocator::allocate(size_t size) { - auto ptr = static_cast<VecSimAllocationHeader *>(vecsim_malloc(size + allocation_header_size)); - if (ptr) { - this->allocated += size + allocation_header_size; - *ptr = {size, false}; - return ptr + 1; - } - return nullptr; -} - -void *VecSimAllocator::allocate_aligned(size_t size, unsigned char alignment) { - if (!alignment) { - return allocate(size); - } - - size += alignment; // Add enough space for alignment. - auto ptr = static_cast<unsigned char *>(vecsim_malloc(size + allocation_header_size)); - if (ptr) { - this->allocated += size + allocation_header_size; - size_t remainder = (((uintptr_t)ptr) + allocation_header_size) % alignment; - unsigned char offset = alignment - remainder; - // Store the allocation header in the 8 bytes before the returned pointer. - new (ptr + offset) VecSimAllocationHeader{size, true}; - // Store the offset in the byte right before the header. - ptr[offset - 1] = offset; - // Return the aligned pointer. - return ptr + allocation_header_size + offset; - } - return nullptr; -} - -void VecSimAllocator::deallocate(void *p, size_t size) { free_allocation(p); } - -void *VecSimAllocator::reallocate(void *p, size_t size) { - if (!p) { - return this->allocate(size); - } - size_t oldSize = getPointerAllocationSize(p); - void *new_ptr = this->allocate(size); - if (new_ptr) { - memcpy(new_ptr, p, MIN(oldSize, size)); - free_allocation(p); - return new_ptr; - } - return nullptr; -} - -void VecSimAllocator::free_allocation(void *p) { - if (!p) - return; - - auto hdr = ((VecSimAllocationHeader *)p) - 1; - unsigned char offset = hdr->is_aligned ? ((unsigned char *)hdr)[-1] : 0; - - this->allocated -= (hdr->allocation_size + allocation_header_size); - vecsim_free((char *)p - offset - allocation_header_size); -} - -void *VecSimAllocator::callocate(size_t size) { - size_t *ptr = (size_t *)vecsim_calloc(1, size + allocation_header_size); - - if (ptr) { - this->allocated += size + allocation_header_size; - *ptr = size; - return ptr + 1; - } - return nullptr; -} - -std::unique_ptr<void, VecSimAllocator::Deleter> -VecSimAllocator::allocate_aligned_unique(size_t size, size_t alignment) { - void *ptr = this->allocate_aligned(size, alignment); - return {ptr, Deleter(*this)}; -} - -std::unique_ptr<void, VecSimAllocator::Deleter> VecSimAllocator::allocate_unique(size_t size) { - void *ptr = this->allocate(size); - return {ptr, Deleter(*this)}; -} - -void *VecSimAllocator::operator new(size_t size) { return vecsim_malloc(size); } - -void *VecSimAllocator::operator new[](size_t size) { return vecsim_malloc(size); } -void VecSimAllocator::operator delete(void *p, size_t size) { vecsim_free(p); } -void VecSimAllocator::operator delete[](void *p, size_t size) { vecsim_free(p); } - -uint64_t VecSimAllocator::getAllocationSize() const { return this->allocated; } diff --git a/src/VecSim/memory/vecsim_malloc.h b/src/VecSim/memory/vecsim_malloc.h deleted file mode 100644 index 178d22d3a..000000000 --- a/src/VecSim/memory/vecsim_malloc.h +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/vec_sim_common.h" -#include <atomic> -#include <memory> -#include <cstddef> -#include <cstdint> - -struct VecSimAllocator { - // Allow global vecsim memory functions to access this class. - friend inline void *vecsim_malloc(size_t n); - friend inline void *vecsim_calloc(size_t nelem, size_t elemsz); - friend inline void *vecsim_realloc(void *p, size_t n); - friend inline void vecsim_free(void *p); - -private: - std::atomic_uint64_t allocated; - - // Static member that indicates the additional size added to each allocation. - static size_t allocation_header_size; - static VecSimMemoryFunctions memFunctions; - - // Forward declaration of the deleter for the unique_ptr. - struct Deleter; - VecSimAllocator() : allocated(std::atomic_uint64_t(sizeof(VecSimAllocator))) {} - -public: - static std::shared_ptr<VecSimAllocator> newVecsimAllocator(); - void *allocate(size_t size); - void *allocate_aligned(size_t size, unsigned char alignment); - void *callocate(size_t size); - void deallocate(void *p, size_t size); - void *reallocate(void *p, size_t size); - void free_allocation(void *p); - - // Allocations for scope-lifetime memory. - std::unique_ptr<void, Deleter> allocate_aligned_unique(size_t size, size_t alignment); - std::unique_ptr<void, Deleter> allocate_unique(size_t size); - - void *operator new(size_t size); - void *operator new[](size_t size); - void operator delete(void *p, size_t size); - void operator delete[](void *p, size_t size); - - uint64_t getAllocationSize() const; - inline friend bool operator==(const VecSimAllocator &a, const VecSimAllocator &b) { - return a.allocated == b.allocated; - } - - inline friend bool operator!=(const VecSimAllocator &a, const VecSimAllocator &b) { - return a.allocated != b.allocated; - } - - static void setMemoryFunctions(VecSimMemoryFunctions memFunctions); - - static size_t getAllocationOverheadSize() { return allocation_header_size; } - -private: - // Retrieve the original requested allocation size.
Required for realloc(). - inline size_t getPointerAllocationSize(void *p) { return *(((size_t *)p) - 1); } - - struct Deleter { - VecSimAllocator &allocator; - explicit constexpr Deleter(VecSimAllocator &allocator) : allocator(allocator) {} - void operator()(void *ptr) const { allocator.free_allocation(ptr); } - }; -}; - -/** - * @brief Global function to call for allocating a memory buffer (malloc style). - * - * @param n - Amount of bytes to allocate. - * @return void* - Allocated buffer. - */ -inline void *vecsim_malloc(size_t n) { return VecSimAllocator::memFunctions.allocFunction(n); } - -/** - * @brief Global function to call for allocating a memory buffer initialized to zero (calloc style). - * - * @param nelem Number of elements. - * @param elemsz Element size. - * @return void* - Allocated buffer. - */ -inline void *vecsim_calloc(size_t nelem, size_t elemsz) { - return VecSimAllocator::memFunctions.callocFunction(nelem, elemsz); -} - -/** - * @brief Global function to reallocate a buffer (realloc style). - * - * @param p Allocated buffer. - * @param n Number of bytes required for the new buffer. - * @return void* Allocated buffer with size >= n. - */ -inline void *vecsim_realloc(void *p, size_t n) { - return VecSimAllocator::memFunctions.reallocFunction(p, n); -} - -/** - * @brief Global function to free an allocated buffer. - * - * @param p Allocated buffer. - */ -inline void vecsim_free(void *p) { VecSimAllocator::memFunctions.freeFunction(p); } - -template <typename T> -struct VecsimSTLAllocator { - using value_type = T; - -private: - VecsimSTLAllocator() {} - -public: - std::shared_ptr<VecSimAllocator> vecsim_allocator; - VecsimSTLAllocator(std::shared_ptr<VecSimAllocator> vecsim_allocator) - : vecsim_allocator(vecsim_allocator) {} - - // Copy constructor and assignment operator. Any VecsimSTLAllocator can be used for any type. - template <typename U> - VecsimSTLAllocator(const VecsimSTLAllocator<U> &other) - : vecsim_allocator(other.vecsim_allocator) {} - - template <typename U> - VecsimSTLAllocator &operator=(const VecsimSTLAllocator<U> &other) { - this->vecsim_allocator = other.vecsim_allocator; - return *this; - } - - T *allocate(size_t size) { return (T *)this->vecsim_allocator->allocate(size * sizeof(T)); } - - void deallocate(T *ptr, size_t size) { - this->vecsim_allocator->deallocate(ptr, size * sizeof(T)); - } -}; - -template <typename T, typename U> -bool operator==(const VecsimSTLAllocator<T> &a, const VecsimSTLAllocator<U> &b) { - return a.vecsim_allocator == b.vecsim_allocator; -} -template <typename T, typename U> -bool operator!=(const VecsimSTLAllocator<T> &a, const VecsimSTLAllocator<U> &b) { - return a.vecsim_allocator != b.vecsim_allocator; -} diff --git a/src/VecSim/query_result_definitions.h b/src/VecSim/query_result_definitions.h deleted file mode 100644 index 39f7354ce..000000000 --- a/src/VecSim/query_result_definitions.h +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/utils/vecsim_stl.h" -#include "VecSim/query_results.h" - -#include <limits> -#include <memory> - -// Use the "not a number" value to represent an invalid score. This is for distinguishing the invalid - // score from an "inf" score (which is valid). -#define INVALID_SCORE std::numeric_limits<double>::quiet_NaN() - -/** - * This file contains the headers to be used internally for creating an array of results in - * TopKQuery methods.
- */ -struct VecSimQueryResult { - size_t id; - double score; -}; - -using VecSimQueryResultContainer = vecsim_stl::vector<VecSimQueryResult>; - -struct VecSimQueryReply { - VecSimQueryResultContainer results; - VecSimQueryReply_Code code; - - VecSimQueryReply(std::shared_ptr<VecSimAllocator> allocator, - VecSimQueryReply_Code code = VecSim_QueryReply_OK) - : results(allocator), code(code) {} -}; - -#ifdef BUILD_TESTS -#include <ostream> - -// Print operators -inline std::ostream &operator<<(std::ostream &os, const VecSimQueryResult &result) { - os << "id: " << result.id << ", score: " << result.score; - return os; -} - -inline std::ostream &operator<<(std::ostream &os, const VecSimQueryReply &reply) { - for (const auto &result : reply.results) { - os << result << std::endl; - } - return os; -} -#endif diff --git a/src/VecSim/query_results.cpp b/src/VecSim/query_results.cpp deleted file mode 100644 index 845a980a7..000000000 --- a/src/VecSim/query_results.cpp +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/query_result_definitions.h" -#include "VecSim/vec_sim.h" -#include "VecSim/batch_iterator.h" -#include <memory> - -struct VecSimQueryReply_Iterator { - using iterator = decltype(VecSimQueryReply::results)::iterator; - const iterator begin, end; - iterator current; - - explicit VecSimQueryReply_Iterator(VecSimQueryReply *reply) - : begin(reply->results.begin()), end(reply->results.end()), current(begin) {} -}; - -extern "C" size_t VecSimQueryReply_Len(VecSimQueryReply *qr) { return qr->results.size(); } - -extern "C" VecSimQueryReply_Code VecSimQueryReply_GetCode(VecSimQueryReply *qr) { return qr->code; } - -extern "C" void VecSimQueryReply_Free(VecSimQueryReply *qr) { delete qr; } - -extern "C" VecSimQueryReply_Iterator *VecSimQueryReply_GetIterator(VecSimQueryReply *results) { - return new VecSimQueryReply_Iterator(results); -} - -extern "C" bool VecSimQueryReply_IteratorHasNext(VecSimQueryReply_Iterator *iterator) { - return iterator->current != iterator->end; -} - -extern "C" VecSimQueryResult *VecSimQueryReply_IteratorNext(VecSimQueryReply_Iterator *iterator) { - if (iterator->current == iterator->end) { - return nullptr; - } - - return std::to_address(iterator->current++); -} - -extern "C" int64_t VecSimQueryResult_GetId(const VecSimQueryResult *res) { - if (res == nullptr) { - return INVALID_ID; - } - return (int64_t)res->id; -} - -extern "C" double VecSimQueryResult_GetScore(const VecSimQueryResult *res) { - if (res == nullptr) { - return INVALID_SCORE; // "NaN" - } - return res->score; -} - -extern "C" void VecSimQueryReply_IteratorFree(VecSimQueryReply_Iterator *iterator) { - delete iterator; -} - -extern "C" void VecSimQueryReply_IteratorReset(VecSimQueryReply_Iterator *iterator) { - iterator->current = iterator->begin; -} - -/********************** batch iterator API ***************************/ -VecSimQueryReply *VecSimBatchIterator_Next(VecSimBatchIterator *iterator, size_t n_results, - VecSimQueryReply_Order order) { - assert((order == BY_ID || order == BY_SCORE) && - "Possible order values are only 'BY_ID' or 'BY_SCORE'"); - return iterator->getNextResults(n_results, order); -} - -bool VecSimBatchIterator_HasNext(VecSimBatchIterator *iterator) { return !iterator->isDepleted(); } - -void VecSimBatchIterator_Free(VecSimBatchIterator *iterator) { - //
The batch iterator may be deleted after the index, so keep the allocator alive until - // the deletion completes. - auto allocator = iterator->getAllocator(); - delete iterator; -} - -void VecSimBatchIterator_Reset(VecSimBatchIterator *iterator) { iterator->reset(); } diff --git a/src/VecSim/query_results.h b/src/VecSim/query_results.h deleted file mode 100644 index be5f630a3..000000000 --- a/src/VecSim/query_results.h +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include <stdint.h> -#include <stdbool.h> - -#include "vec_sim_common.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// The possible orderings for results returned from a query -typedef enum { BY_SCORE, BY_ID, BY_SCORE_THEN_ID } VecSimQueryReply_Order; - -typedef enum { - VecSim_QueryReply_OK = VecSim_OK, - VecSim_QueryReply_TimedOut, -} VecSimQueryReply_Code; - -////////////////////////////////////// VecSimQueryResult API ////////////////////////////////////// - -/** - * @brief A single query result. This is an opaque object from which a user can get the result - * vector id and score (compared to the query vector). - */ -typedef struct VecSimQueryResult VecSimQueryResult; - -/** - * @brief Get the id of the result vector. If item is nullptr, return INVALID_ID (defined as -1). - */ -int64_t VecSimQueryResult_GetId(const VecSimQueryResult *item); - -/** - * @brief Get the score of the result vector. If item is nullptr, return INVALID_SCORE (defined as - * the special value of NaN). - */ -double VecSimQueryResult_GetScore(const VecSimQueryResult *item); - -////////////////////////////////////// VecSimQueryReply API /////////////////////////////////////// - -/** - * @brief An opaque object from which results can be obtained via an iterator. - */ -typedef struct VecSimQueryReply VecSimQueryReply; - -/** - * @brief Get the length of the result list returned from a query. - */ -size_t VecSimQueryReply_Len(VecSimQueryReply *results); - -/** - * @brief Get the return code of a query. - */ -VecSimQueryReply_Code VecSimQueryReply_GetCode(VecSimQueryReply *results); - -/** - * @brief Release the entire query results list. - */ -void VecSimQueryReply_Free(VecSimQueryReply *results); - -////////////////////////////////// VecSimQueryReply_Iterator API ////////////////////////////////// - -/** - * @brief Iterator for going over the list of results that were returned from a query - */ -typedef struct VecSimQueryReply_Iterator VecSimQueryReply_Iterator; - -/** - * @brief Create an iterator for going over the list of results. The iterator needs to be freed - * with VecSimQueryReply_IteratorFree. - */ -VecSimQueryReply_Iterator *VecSimQueryReply_GetIterator(VecSimQueryReply *results); - -/** - * @brief Advance the iterator so it points to the next item, and return the value. - * The first call will return the first result. This will return NULL once the iterator is depleted. - */ -VecSimQueryResult *VecSimQueryReply_IteratorNext(VecSimQueryReply_Iterator *iterator); - -/** - * @brief Return true while the iterator points to some result, false if it is depleted. - */ -bool VecSimQueryReply_IteratorHasNext(VecSimQueryReply_Iterator *iterator); - -/** - * @brief Rewind the iterator to the beginning of the result list. - */ -void VecSimQueryReply_IteratorReset(VecSimQueryReply_Iterator *iterator); - -/** - * @brief Release the iterator. - */ -void VecSimQueryReply_IteratorFree(VecSimQueryReply_Iterator *iterator); - -///////////////////////////////////// VecSimBatchIterator API ///////////////////////////////////// - -/** - * @brief Iterator for running the same query over an index, getting in each iteration - * the best results that haven't been returned in previous iterations. - */ -typedef struct VecSimBatchIterator VecSimBatchIterator; - -/** - * @brief Run a TopKQuery over the underlying index of the given iterator using the BatchIterator_Next - * method, and return n_results new results. - * @param iterator the iterator that holds the current state of this "batched search". - * @param n_results number of new results to return. - * @param order enum - determines the returned results order (by id or by score). - * @return List of (at most) n_results new vectors which are the "nearest neighbours" to the - * underlying query vector of the iterator. - */ -VecSimQueryReply *VecSimBatchIterator_Next(VecSimBatchIterator *iterator, size_t n_results, - VecSimQueryReply_Order order); - -/** - * @brief Return true while the iterator has new results to return, false if it is depleted - * (using the BatchIterator_HasNext method). - */ -bool VecSimBatchIterator_HasNext(VecSimBatchIterator *iterator); - -/** - * @brief Release the iterator using the BatchIterator_Free method. - */ -void VecSimBatchIterator_Free(VecSimBatchIterator *iterator); - -/** - * @brief Reset the iterator back to the initial state using the BatchIterator_Reset method. - */ -void VecSimBatchIterator_Reset(VecSimBatchIterator *iterator); - -#ifdef __cplusplus -} -#endif diff --git a/src/VecSim/spaces/AVX_utils.h b/src/VecSim/spaces/AVX_utils.h deleted file mode 100644 index 2fe0b904e..000000000 --- a/src/VecSim/spaces/AVX_utils.h +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3).
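As a usage reference for the C API removed above, here is a hypothetical consumer sketch. It assumes the usual VecSimIndex_TopKQuery entry point from vec_sim.h (which is not part of this diff); the `index`, `query`, `k`, and `batchIt` names are placeholders:

    #include <stdio.h>
    #include "VecSim/vec_sim.h"
    #include "VecSim/query_results.h"

    void printTopK(VecSimIndex *index, const void *query, size_t k) {
        // Run a top-k query ordered by score; NULL means default query params.
        VecSimQueryReply *reply = VecSimIndex_TopKQuery(index, query, k, NULL, BY_SCORE);
        if (VecSimQueryReply_GetCode(reply) == VecSim_QueryReply_OK) {
            VecSimQueryReply_Iterator *it = VecSimQueryReply_GetIterator(reply);
            while (VecSimQueryReply_IteratorHasNext(it)) {
                VecSimQueryResult *res = VecSimQueryReply_IteratorNext(it);
                printf("id=%lld score=%f\n", (long long)VecSimQueryResult_GetId(res),
                       VecSimQueryResult_GetScore(res));
            }
            VecSimQueryReply_IteratorFree(it);
        }
        VecSimQueryReply_Free(reply);
    }

    // Batched variant: pull chunks of n results until the iterator is depleted.
    //   while (VecSimBatchIterator_HasNext(batchIt)) {
    //       VecSimQueryReply *chunk = VecSimBatchIterator_Next(batchIt, n, BY_SCORE);
    //       /* consume chunk as above */
    //       VecSimQueryReply_Free(chunk);
    //   }
    //   VecSimBatchIterator_Free(batchIt);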
- */ -#pragma once -#include "space_includes.h" - -template <__mmask8 mask> // (2^n)-1, where n is in 1..7 (1, 3, ..., 127) -static inline __m256 my_mm256_maskz_loadu_ps(const float *p) { - // Load 8 floats (assuming this is safe to do) - __m256 data = _mm256_loadu_ps(p); - // Set the mask for the loaded data (set 0 if a bit is 0) - __m256 masked_data = _mm256_blend_ps(_mm256_setzero_ps(), data, mask); - - return masked_data; -} - -template <__mmask8 mask> // (2^n)-1, where n is in 1..3 (1, 3, 7) -static inline __m256d my_mm256_maskz_loadu_pd(const double *p) { - // Load 4 doubles (assuming this is safe to do) - __m256d data = _mm256_loadu_pd(p); - // Set the mask for the loaded data (set 0 if a bit is 0) - __m256d masked_data = _mm256_blend_pd(_mm256_setzero_pd(), data, mask); - - return masked_data; -} - -static inline float my_mm256_reduce_add_ps(__m256 x) { - float PORTABLE_ALIGN32 TmpRes[8]; - _mm256_store_ps(TmpRes, x); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + - TmpRes[7]; -} diff --git a/src/VecSim/spaces/CMakeLists.txt b/src/VecSim/spaces/CMakeLists.txt deleted file mode 100644 index d88750e91..000000000 --- a/src/VecSim/spaces/CMakeLists.txt +++ /dev/null @@ -1,156 +0,0 @@ -# Build non optimized code in a single project without architecture optimization flag. -project(VectorSimilaritySpaces_no_optimization) -add_library(VectorSimilaritySpaces_no_optimization - L2/L2.cpp - IP/IP.cpp -) - -include(${root}/cmake/cpu_features.cmake) - -project(VectorSimilarity_Spaces) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall") - -set(OPTIMIZATIONS "") - -if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "(x86_64)|(AMD64|amd64)|(^i.86$)") - # Check that the compiler supports instructions flag. - # from gcc14+ -mavx512bw is implicitly enabled when -mavx512vbmi2 is requested - include(${root}/cmake/x86_64InstructionFlags.cmake) - - # build SSE/AVX* code only on x64 processors. - # This will add the relevant flag both to the space selector and the optimization. 
- if(CXX_AVX512BF16 AND CXX_AVX512VL) - message("Building with AVX512BF16 and AVX512VL") - set_source_files_properties(functions/AVX512BF16_VL.cpp PROPERTIES COMPILE_FLAGS "-mavx512bf16 -mavx512vl") - list(APPEND OPTIMIZATIONS functions/AVX512BF16_VL.cpp) - endif() - - if(CXX_AVX512VL AND CXX_AVX512FP16) - message("Building with AVX512FP16 and AVX512VL") - set_source_files_properties(functions/AVX512FP16_VL.cpp PROPERTIES COMPILE_FLAGS "-mavx512fp16 -mavx512vl") - list(APPEND OPTIMIZATIONS functions/AVX512FP16_VL.cpp) - endif() - - if(CXX_AVX512BW AND CXX_AVX512VBMI2) - message("Building with AVX512BW and AVX512VBMI2") - set_source_files_properties(functions/AVX512BW_VBMI2.cpp PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512vbmi2") - list(APPEND OPTIMIZATIONS functions/AVX512BW_VBMI2.cpp) - endif() - - if(CXX_AVX512F) - message("Building with AVX512F") - set_source_files_properties(functions/AVX512F.cpp PROPERTIES COMPILE_FLAGS "-mavx512f") - list(APPEND OPTIMIZATIONS functions/AVX512F.cpp) - endif() - - if(CXX_AVX512F AND CXX_AVX512BW AND CXX_AVX512VL AND CXX_AVX512VNNI) - message("Building with AVX512F, AVX512BW, AVX512VL and AVX512VNNI") - set_source_files_properties(functions/AVX512F_BW_VL_VNNI.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512vnni") - list(APPEND OPTIMIZATIONS functions/AVX512F_BW_VL_VNNI.cpp) - endif() - - if(CXX_AVX2) - message("Building with AVX2") - set_source_files_properties(functions/AVX2.cpp PROPERTIES COMPILE_FLAGS -mavx2) - list(APPEND OPTIMIZATIONS functions/AVX2.cpp) - endif() - - if(CXX_AVX2 AND CXX_FMA) - message("Building with AVX2 and FMA") - set_source_files_properties(functions/AVX2_FMA.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") - list(APPEND OPTIMIZATIONS functions/AVX2_FMA.cpp) - endif() - - if(CXX_F16C AND CXX_FMA AND CXX_AVX) - message("Building with CXX_F16C") - set_source_files_properties(functions/F16C.cpp PROPERTIES COMPILE_FLAGS "-mf16c -mfma -mavx") - list(APPEND OPTIMIZATIONS functions/F16C.cpp) - endif() - - if(CXX_AVX) - message("Building with AVX") - set_source_files_properties(functions/AVX.cpp PROPERTIES COMPILE_FLAGS -mavx) - list(APPEND OPTIMIZATIONS functions/AVX.cpp) - endif() - - if(CXX_SSE3) - message("Building with SSE3") - set_source_files_properties(functions/SSE3.cpp PROPERTIES COMPILE_FLAGS -msse3) - list(APPEND OPTIMIZATIONS functions/SSE3.cpp) - endif() - - if(CXX_SSE4) - message("Building with SSE4") - set_source_files_properties(functions/SSE4.cpp PROPERTIES COMPILE_FLAGS -msse4.1) - list(APPEND OPTIMIZATIONS functions/SSE4.cpp) - endif() - - if(CXX_SSE) - message("Building with SSE") - set_source_files_properties(functions/SSE.cpp PROPERTIES COMPILE_FLAGS -msse) - list(APPEND OPTIMIZATIONS functions/SSE.cpp) - endif() -endif() - -if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)|(ARM64)|(armv.*)") - include(${root}/cmake/aarch64InstructionFlags.cmake) - - # Create different optimization implementations for ARM architecture - if (CXX_NEON_DOTPROD) - message("Building with ARMV8.2 with dotprod") - set_source_files_properties(functions/NEON_DOTPROD.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+dotprod") - list(APPEND OPTIMIZATIONS functions/NEON_DOTPROD.cpp) - endif() - if (CXX_ARMV8A) - message("Building with ARMV8A") - set_source_files_properties(functions/NEON.cpp PROPERTIES COMPILE_FLAGS "-march=armv8-a") - list(APPEND OPTIMIZATIONS functions/NEON.cpp) - endif() - - # NEON half-precision support - if (CXX_NEON_HP AND CXX_ARMV8A) - message("Building with NEON+HP") - 
set_source_files_properties(functions/NEON_HP.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+fp16fml") - list(APPEND OPTIMIZATIONS functions/NEON_HP.cpp) - endif() - - # NEON bfloat16 support - if (CXX_NEON_BF16) - message("Building with NEON + BF16") - set_source_files_properties(functions/NEON_BF16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+bf16") - list(APPEND OPTIMIZATIONS functions/NEON_BF16.cpp) - endif() - - # SVE support - if (CXX_SVE) - message("Building with SVE") - set_source_files_properties(functions/SVE.cpp PROPERTIES COMPILE_FLAGS "-march=armv8-a+sve") - list(APPEND OPTIMIZATIONS functions/SVE.cpp) - endif() - - # SVE with BF16 support - if (CXX_SVE_BF16) - message("Building with SVE + BF16") - set_source_files_properties(functions/SVE_BF16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+sve+bf16") - list(APPEND OPTIMIZATIONS functions/SVE_BF16.cpp) - endif() - - # SVE2 support - if (CXX_SVE2) - message("Building with ARMV9A and SVE2") - set_source_files_properties(functions/SVE2.cpp PROPERTIES COMPILE_FLAGS "-march=armv9-a+sve2") - list(APPEND OPTIMIZATIONS functions/SVE2.cpp) - endif() -endif() - -# Here we are compiling the space selectors with the relevant optimization flag. -add_library(VectorSimilaritySpaces - L2_space.cpp - IP_space.cpp - spaces.cpp - ${OPTIMIZATIONS} - computer/preprocessor_container.cpp -) - -target_link_libraries(VectorSimilaritySpaces VectorSimilaritySpaces_no_optimization cpu_features) diff --git a/src/VecSim/spaces/IP/IP.cpp b/src/VecSim/spaces/IP/IP.cpp deleted file mode 100644 index e96cf0bc3..000000000 --- a/src/VecSim/spaces/IP/IP.cpp +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "IP.h" -#include "VecSim/types/bfloat16.h" -#include "VecSim/types/float16.h" -#include "VecSim/types/sq8.h" -#include - -using bfloat16 = vecsim_types::bfloat16; -using float16 = vecsim_types::float16; -using sq8 = vecsim_types::sq8; - -/* - * Optimized asymmetric SQ8-FP32 inner product using algebraic identity: - * IP(x, y) = Σ(x_i * y_i) - * ≈ Σ((min + delta * q_i) * y_i) - * = min * Σy_i + delta * Σ(q_i * y_i) - * = min * y_sum + delta * quantized_dot_product - * - * Uses 4x loop unrolling with multiple accumulators for ILP. - * pVect1 is storage (SQ8): [uint8_t values (dim)] [min_val] [delta] [x_sum] [x_sum_squares (L2 - * only)] - * pVect2 is query (FP32): [float values (dim)] [y_sum] [y_sum_squares (L2 only)] - * - * Returns raw inner product value (not distance). Used by SQ8_FP32_InnerProduct, SQ8_FP32_Cosine, - * SQ8_FP32_L2Sqr. 
- */ -float SQ8_FP32_InnerProduct_Impl(const void *pVect1v, const void *pVect2v, size_t dimension) { - const auto *pVect1 = static_cast(pVect1v); - const auto *pVect2 = static_cast(pVect2v); - - // Use 4 accumulators for instruction-level parallelism - float sum0 = 0, sum1 = 0, sum2 = 0, sum3 = 0; - - // Main loop: process 4 elements per iteration - size_t i = 0; - size_t dim4 = dimension & ~size_t(3); // dim4 is a multiple of 4 - for (; i < dim4; i += 4) { - sum0 += static_cast(pVect1[i + 0]) * pVect2[i + 0]; - sum1 += static_cast(pVect1[i + 1]) * pVect2[i + 1]; - sum2 += static_cast(pVect1[i + 2]) * pVect2[i + 2]; - sum3 += static_cast(pVect1[i + 3]) * pVect2[i + 3]; - } - - // Handle remainder (0-3 elements) - for (; i < dimension; i++) { - sum0 += static_cast(pVect1[i]) * pVect2[i]; - } - - // Combine accumulators - float quantized_dot = (sum0 + sum1) + (sum2 + sum3); - - // Get quantization parameters from stored vector (pVect1 is SQ8) - const float *params = reinterpret_cast(pVect1 + dimension); - const float min_val = params[sq8::MIN_VAL]; - const float delta = params[sq8::DELTA]; - - // Get precomputed y_sum from query blob (pVect2 is FP32, stored after the dim floats) - const float y_sum = pVect2[dimension + sq8::SUM_QUERY]; - - // Apply formula: IP = min * y_sum + delta * Σ(q_i * y_i) - return min_val * y_sum + delta * quantized_dot; -} - -float SQ8_FP32_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension) { - return 1.0f - SQ8_FP32_InnerProduct_Impl(pVect1v, pVect2v, dimension); -} - -float SQ8_FP32_Cosine(const void *pVect1v, const void *pVect2v, size_t dimension) { - return SQ8_FP32_InnerProduct(pVect1v, pVect2v, dimension); -} - -// SQ8-to-SQ8: Common inner product implementation that returns the raw inner product value -// (not distance). Used by both SQ8_SQ8_InnerProduct, SQ8_SQ8_Cosine, and SQ8_SQ8_L2Sqr. 
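The asymmetric identity the scalar kernel above relies on can be checked numerically in isolation. The sketch below is illustrative and not part of the library; it quantizes with the scheme the comments imply (x_i ≈ min + delta * q_i, delta = (max - min) / 255) and shows that the kernel's formula agrees with the dequantize-then-multiply reference up to float rounding:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const size_t dim = 4;
        float x[dim] = {0.25f, -1.5f, 0.75f, 2.0f};  // stored vector (to be quantized)
        float y[dim] = {1.0f, 0.5f, -0.25f, 0.125f}; // FP32 query
        float mn = *std::min_element(x, x + dim), mx = *std::max_element(x, x + dim);
        float delta = (mx - mn) / 255.0f;
        uint8_t q[dim];
        for (size_t i = 0; i < dim; i++) q[i] = (uint8_t)std::lround((x[i] - mn) / delta);
        float ref = 0, qdot = 0, y_sum = 0;
        for (size_t i = 0; i < dim; i++) {
            ref += (mn + delta * q[i]) * y[i]; // dequantize, then multiply
            qdot += (float)q[i] * y[i];        // what the hot loop actually computes
            y_sum += y[i];                     // precomputed and stored in the query blob
        }
        // Identity: min * y_sum + delta * sum(q_i * y_i) == sum(dequantized x_i * y_i)
        printf("ref=%f identity=%f\n", (double)ref, (double)(mn * y_sum + delta * qdot));
        return 0;
    }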
-// Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] -float SQ8_SQ8_InnerProduct_Impl(const void *pVect1v, const void *pVect2v, size_t dimension) { - const auto *pVect1 = static_cast(pVect1v); - const auto *pVect2 = static_cast(pVect2v); - - // Compute inner product of quantized values: Σ(q1[i]*q2[i]) - float product = 0; - for (size_t i = 0; i < dimension; i++) { - product += pVect1[i] * pVect2[i]; - } - - // Get quantization parameters from pVect1 - const float *params1 = reinterpret_cast(pVect1 + dimension); - const float min_val1 = params1[sq8::MIN_VAL]; - const float delta1 = params1[sq8::DELTA]; - const float sum1 = params1[sq8::SUM]; - - // Get quantization parameters from pVect2 - const float *params2 = reinterpret_cast(pVect2 + dimension); - const float min_val2 = params2[sq8::MIN_VAL]; - const float delta2 = params2[sq8::DELTA]; - const float sum2 = params2[sq8::SUM]; - - // Apply the algebraic formula using precomputed sums: - // IP = min1*sum2 + min2*sum1 + delta1*delta2*Σ(q1[i]*q2[i]) - dim*min1*min2 - return min_val1 * sum2 + min_val2 * sum1 - static_cast(dimension) * min_val1 * min_val2 + - delta1 * delta2 * product; -} - -// SQ8-to-SQ8: Both vectors are uint8 quantized with precomputed sum -// Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] -float SQ8_SQ8_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension) { - return 1.0f - SQ8_SQ8_InnerProduct_Impl(pVect1v, pVect2v, dimension); -} - -// SQ8-to-SQ8: Both vectors are uint8 quantized and normalized with precomputed sum -// Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] -float SQ8_SQ8_Cosine(const void *pVect1v, const void *pVect2v, size_t dimension) { - return SQ8_SQ8_InnerProduct(pVect1v, pVect2v, dimension); -} - -float FP32_InnerProduct(const void *pVect1, const void *pVect2, size_t dimension) { - auto *vec1 = (float *)pVect1; - auto *vec2 = (float *)pVect2; - - float res = 0; - for (size_t i = 0; i < dimension; i++) { - res += vec1[i] * vec2[i]; - } - return 1.0f - res; -} - -double FP64_InnerProduct(const void *pVect1, const void *pVect2, size_t dimension) { - auto *vec1 = (double *)pVect1; - auto *vec2 = (double *)pVect2; - - double res = 0; - for (size_t i = 0; i < dimension; i++) { - res += vec1[i] * vec2[i]; - } - return 1.0 - res; -} - -template -float BF16_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension) { - auto *pVect1 = (bfloat16 *)pVect1v; - auto *pVect2 = (bfloat16 *)pVect2v; - - float res = 0; - for (size_t i = 0; i < dimension; i++) { - float a = vecsim_types::bfloat16_to_float32(pVect1[i]); - float b = vecsim_types::bfloat16_to_float32(pVect2[i]); - res += a * b; - } - return 1.0f - res; -} - -float BF16_InnerProduct_LittleEndian(const void *pVect1v, const void *pVect2v, size_t dimension) { - return BF16_InnerProduct(pVect1v, pVect2v, dimension); -} - -float BF16_InnerProduct_BigEndian(const void *pVect1v, const void *pVect2v, size_t dimension) { - return BF16_InnerProduct(pVect1v, pVect2v, dimension); -} - -float FP16_InnerProduct(const void *pVect1, const void *pVect2, size_t dimension) { - auto *vec1 = (float16 *)pVect1; - auto *vec2 = (float16 *)pVect2; - - float res = 0; - for (size_t i = 0; i < dimension; i++) { - res += vecsim_types::FP16_to_FP32(vec1[i]) * vecsim_types::FP16_to_FP32(vec2[i]); - } - return 1.0f - res; -} - -// Return type for the inner product functions. 
-// The type should be able to hold `dimension * MAX_VAL(int_elem_t) * MAX_VAL(int_elem_t)`. -// To support dimension up to 2^16, we need the difference between the type and int_elem_t to be at -// least 2 bytes. We assert that in the implementation. -template -using ret_t = std::conditional_t; - -template -static inline ret_t -INTEGER_InnerProductImp(const int_elem_t *pVect1, const int_elem_t *pVect2, size_t dimension) { - static_assert(sizeof(ret_t) - sizeof(int_elem_t) * 2 >= sizeof(uint16_t)); - ret_t res = 0; - for (size_t i = 0; i < dimension; i++) { - res += pVect1[i] * pVect2[i]; - } - return res; -} - -float INT8_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension) { - const auto *pVect1 = static_cast(pVect1v); - const auto *pVect2 = static_cast(pVect2v); - return 1 - INTEGER_InnerProductImp(pVect1, pVect2, dimension); -} - -float INT8_Cosine(const void *pVect1v, const void *pVect2v, size_t dimension) { - const auto *pVect1 = static_cast(pVect1v); - const auto *pVect2 = static_cast(pVect2v); - // We expect the vectors' norm to be stored at the end of the vector. - float norm_v1 = *reinterpret_cast(pVect1 + dimension); - float norm_v2 = *reinterpret_cast(pVect2 + dimension); - return 1.0f - float(INTEGER_InnerProductImp(pVect1, pVect2, dimension)) / (norm_v1 * norm_v2); -} - -float UINT8_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension) { - const auto *pVect1 = static_cast(pVect1v); - const auto *pVect2 = static_cast(pVect2v); - return 1 - INTEGER_InnerProductImp(pVect1, pVect2, dimension); -} - -float UINT8_Cosine(const void *pVect1v, const void *pVect2v, size_t dimension) { - const auto *pVect1 = static_cast(pVect1v); - const auto *pVect2 = static_cast(pVect2v); - // We expect the vectors' norm to be stored at the end of the vector. - float norm_v1 = *reinterpret_cast(pVect1 + dimension); - float norm_v2 = *reinterpret_cast(pVect2 + dimension); - return 1.0f - float(INTEGER_InnerProductImp(pVect1, pVect2, dimension)) / (norm_v1 * norm_v2); -} diff --git a/src/VecSim/spaces/IP/IP.h b/src/VecSim/spaces/IP/IP.h deleted file mode 100644 index 64f2003ec..000000000 --- a/src/VecSim/spaces/IP/IP.h +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include - -// SQ8-FP32: Common inner product implementation that returns the raw inner product value -// (not distance). Used by SQ8_FP32_InnerProduct, SQ8_FP32_Cosine, and SQ8_FP32_L2Sqr. -// pVect1 is storage (SQ8): [uint8_t values (dim)] [min_val] [delta] [x_sum] [x_sum_squares (L2 -// only)] -// pVect2 is query (FP32): [float values (dim)] [y_sum] [y_sum_squares (L2 only)] -float SQ8_FP32_InnerProduct_Impl(const void *pVect1v, const void *pVect2v, size_t dimension); - -// pVect1v vector of type uint8 (SQ8) and pVect2v vector of type fp32 -float SQ8_FP32_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension); - -// pVect1v vector of type uint8 (SQ8) and pVect2v vector of type fp32 -float SQ8_FP32_Cosine(const void *pVect1v, const void *pVect2v, size_t dimension); - -// SQ8-to-SQ8: Common inner product implementation that returns the raw inner product value -// (not distance). Used by both SQ8_SQ8_InnerProduct, SQ8_SQ8_Cosine, and SQ8_SQ8_L2Sqr. 
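For reference, the identity behind SQ8_SQ8_InnerProduct_Impl follows from expanding the dequantized values. With \(m_k\), \(\delta_k\), \(q_{k,i}\) the min, delta, and codes of vector \(k\), \(n\) the dimension, and \(s_k = \sum_i (m_k + \delta_k q_{k,i})\) the stored sum of dequantized values:

    \begin{aligned}
    \sum_i x_i y_i
      &= \sum_i (m_1 + \delta_1 q_{1,i})(m_2 + \delta_2 q_{2,i}) \\
      &= n\, m_1 m_2 + m_1 \delta_2 \sum_i q_{2,i} + m_2 \delta_1 \sum_i q_{1,i}
         + \delta_1 \delta_2 \sum_i q_{1,i} q_{2,i} \\
      &= m_1 s_2 + m_2 s_1 - n\, m_1 m_2 + \delta_1 \delta_2 \sum_i q_{1,i} q_{2,i},
    \end{aligned}

which is exactly `min_val1 * sum2 + min_val2 * sum1 - dimension * min_val1 * min_val2 + delta1 * delta2 * product` in the scalar kernel.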
-// Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] -float SQ8_SQ8_InnerProduct_Impl(const void *pVect1v, const void *pVect2v, size_t dimension); - -// SQ8-to-SQ8: Both vectors are uint8 quantized with precomputed sum -// Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] -float SQ8_SQ8_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension); - -// SQ8-to-SQ8: Both vectors are uint8 quantized and normalized with precomputed sum -// Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] -float SQ8_SQ8_Cosine(const void *pVect1v, const void *pVect2v, size_t dimension); - -float FP32_InnerProduct(const void *pVect1, const void *pVect2, size_t dimension); - -double FP64_InnerProduct(const void *pVect1, const void *pVect2, size_t dimension); - -float FP16_InnerProduct(const void *pVect1, const void *pVect2, size_t dimension); - -float BF16_InnerProduct_LittleEndian(const void *pVect1v, const void *pVect2v, size_t dimension); -float BF16_InnerProduct_BigEndian(const void *pVect1v, const void *pVect2v, size_t dimension); - -float INT8_InnerProduct(const void *pVect1, const void *pVect2, size_t dimension); -float INT8_Cosine(const void *pVect1, const void *pVect2, size_t dimension); - -float UINT8_InnerProduct(const void *pVect1, const void *pVect2, size_t dimension); -float UINT8_Cosine(const void *pVect1, const void *pVect2, size_t dimension); diff --git a/src/VecSim/spaces/IP/IP_AVX2_BF16.h b/src/VecSim/spaces/IP/IP_AVX2_BF16.h deleted file mode 100644 index c7cd08bdc..000000000 --- a/src/VecSim/spaces/IP/IP_AVX2_BF16.h +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
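The AVX2 bfloat16 kernel that follows rests on a representation fact: a bfloat16 is the top 16 bits of an IEEE-754 float32, so interleaving zeros below each 16-bit value (`_mm256_unpacklo_epi16(zeros, v)` / `_mm256_unpackhi_epi16(zeros, v)`) widens bf16 lanes to fp32 for free. A scalar model of the same trick (illustrative helper, not library code):

    #include <cstdint>
    #include <cstring>

    // Widen one bfloat16 (given as its raw 16 bits) to float32 by placing it in
    // the upper half of a 32-bit word -- the scalar analogue of unpacking with zeros.
    inline float bf16_to_fp32(uint16_t bits) {
        uint32_t widened = (uint32_t)bits << 16;
        float out;
        std::memcpy(&out, &widened, sizeof(out)); // bit-level reinterpret, no conversion
        return out;
    }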
- */ -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/bfloat16.h" -#include "VecSim/spaces/AVX_utils.h" - -using bfloat16 = vecsim_types::bfloat16; - -static inline void InnerProductLowHalfStep(__m256i v1, __m256i v2, __m256i zeros, - __m256 &sum_prod) { - // Convert next 0:3, 8:11 bf16 to 8 floats - __m256i bf16_low1 = _mm256_unpacklo_epi16(zeros, v1); // AVX2 - __m256i bf16_low2 = _mm256_unpacklo_epi16(zeros, v2); - - sum_prod = _mm256_add_ps( - sum_prod, _mm256_mul_ps(_mm256_castsi256_ps(bf16_low1), _mm256_castsi256_ps(bf16_low2))); -} - -static inline void InnerProductHighHalfStep(__m256i v1, __m256i v2, __m256i zeros, - __m256 &sum_prod) { - // Convert next 4:7, 12:15 bf16 to 8 floats - __m256i bf16_high1 = _mm256_unpackhi_epi16(zeros, v1); - __m256i bf16_high2 = _mm256_unpackhi_epi16(zeros, v2); - - sum_prod = _mm256_add_ps( - sum_prod, _mm256_mul_ps(_mm256_castsi256_ps(bf16_high1), _mm256_castsi256_ps(bf16_high2))); -} - -static inline void InnerProductStep(bfloat16 *&pVect1, bfloat16 *&pVect2, __m256 &sum_prod) { - // Load 16 bf16 elements - __m256i v1 = _mm256_lddqu_si256((__m256i *)pVect1); // avx - pVect1 += 16; - __m256i v2 = _mm256_lddqu_si256((__m256i *)pVect2); - pVect2 += 16; - - __m256i zeros = _mm256_setzero_si256(); // avx - - // Compute dist for 0:3, 8:11 bf16 - InnerProductLowHalfStep(v1, v2, zeros, sum_prod); - - // Compute dist for 4:7, 12:15 bf16 - InnerProductHighHalfStep(v1, v2, zeros, sum_prod); -} - -template // 0..31 -float BF16_InnerProductSIMD32_AVX2(const void *pVect1v, const void *pVect2v, size_t dimension) { - bfloat16 *pVect1 = (bfloat16 *)pVect1v; - bfloat16 *pVect2 = (bfloat16 *)pVect2v; - - const bfloat16 *pEnd1 = pVect1 + dimension; - - __m256 sum_prod = _mm256_setzero_ps(); - - // Handle first (residual % 16) elements - if constexpr (residual % 16) { - // Load all 16 elements to a 256 bit register - __m256i v1 = _mm256_lddqu_si256((__m256i *)pVect1); // avx - pVect1 += residual % 16; - __m256i v2 = _mm256_lddqu_si256((__m256i *)pVect2); - pVect2 += residual % 16; - - // Unpack 0:3, 8:11 bf16 to 8 floats - __m256i zeros = _mm256_setzero_si256(); - __m256i v1_low = _mm256_unpacklo_epi16(zeros, v1); - __m256i v2_low = _mm256_unpacklo_epi16(zeros, v2); - - __m256 low_mul = _mm256_mul_ps(_mm256_castsi256_ps(v1_low), _mm256_castsi256_ps(v2_low)); - if constexpr (residual % 16 <= 4) { - constexpr unsigned char elem_to_calc = residual % 16; - constexpr __mmask8 mask = (1 << elem_to_calc) - 1; - low_mul = _mm256_blend_ps(_mm256_setzero_ps(), low_mul, mask); - } else { - __m256i v1_high = _mm256_unpackhi_epi16(zeros, v1); - __m256i v2_high = _mm256_unpackhi_epi16(zeros, v2); - __m256 high_mul = - _mm256_mul_ps(_mm256_castsi256_ps(v1_high), _mm256_castsi256_ps(v2_high)); - if constexpr (4 < residual % 16 && residual % 16 <= 8) { - // Keep only 4 first elements of low pack - constexpr __mmask8 mask = (1 << 4) - 1; - low_mul = _mm256_blend_ps(_mm256_setzero_ps(), low_mul, mask); - - // Keep (residual % 16 - 4) first elements of high_mul - constexpr unsigned char elem_to_calc = residual % 16 - 4; - constexpr __mmask8 mask2 = (1 << elem_to_calc) - 1; - high_mul = _mm256_blend_ps(_mm256_setzero_ps(), high_mul, mask2); - } else if constexpr (8 < residual % 16 && residual % 16 < 12) { - // Keep (residual % 16 - 4) first elements of low_mul - constexpr unsigned char elem_to_calc = residual % 16 - 4; - constexpr __mmask8 mask = (1 << elem_to_calc) - 1; - low_mul = _mm256_blend_ps(_mm256_setzero_ps(), low_mul, mask); - - // Keep ony 4 first elements of 
high_mul - constexpr __mmask8 mask2 = (1 << 4) - 1; - high_mul = _mm256_blend_ps(_mm256_setzero_ps(), high_mul, mask2); - } else if constexpr (residual % 16 >= 12) { - // Keep (residual % 16 - 8) first elements of high - constexpr unsigned char elem_to_calc = (residual % 16) - 8; - constexpr __mmask8 mask = (1 << elem_to_calc) - 1; - high_mul = _mm256_blend_ps(_mm256_setzero_ps(), high_mul, mask); - } - sum_prod = _mm256_add_ps(sum_prod, high_mul); - } - sum_prod = _mm256_add_ps(sum_prod, low_mul); - } - - // Do a single step if residual >=16 - if constexpr (residual >= 16) { - InnerProductStep(pVect1, pVect2, sum_prod); - } - - // Handle 512 bits (32 bfloat16) in chunks of max SIMD = 256 bits = 16 bfloat16 - do { - InnerProductStep(pVect1, pVect2, sum_prod); - InnerProductStep(pVect1, pVect2, sum_prod); - } while (pVect1 < pEnd1); - - return 1.0f - my_mm256_reduce_add_ps(sum_prod); -} diff --git a/src/VecSim/spaces/IP/IP_AVX2_FMA_SQ8_FP32.h b/src/VecSim/spaces/IP/IP_AVX2_FMA_SQ8_FP32.h deleted file mode 100644 index 5767a4828..000000000 --- a/src/VecSim/spaces/IP/IP_AVX2_FMA_SQ8_FP32.h +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/AVX_utils.h" -#include "VecSim/types/sq8.h" -using sq8 = vecsim_types::sq8; - -/* - * Optimized asymmetric SQ8 inner product using algebraic identity: - * - * IP(x, y) = Σ(x_i * y_i) - * ≈ Σ((min + delta * q_i) * y_i) - * = min * Σy_i + delta * Σ(q_i * y_i) - * = min * y_sum + delta * quantized_dot_product - * - * where y_sum = Σy_i is precomputed and stored in the query blob. - * This avoids dequantization in the hot loop - we only compute Σ(q_i * y_i). - * - * This version uses FMA instructions for better performance. - */ - -// Helper: compute Σ(q_i * y_i) for 8 elements using FMA (no dequantization) -// pVect1 = SQ8 storage (quantized values), pVect2 = FP32 query -static inline void InnerProductStepSQ8_FMA(const uint8_t *&pVect1, const float *&pVect2, - __m256 &sum256) { - // Load 8 uint8 elements and convert to float - __m128i v1_128 = _mm_loadl_epi64(reinterpret_cast(pVect1)); - pVect1 += 8; - - __m256i v1_256 = _mm256_cvtepu8_epi32(v1_128); - __m256 v1_f = _mm256_cvtepi32_ps(v1_256); - - // Load 8 float elements from query - __m256 v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - - // Accumulate q_i * y_i using FMA (no dequantization!) 
- sum256 = _mm256_fmadd_ps(v1_f, v2, sum256); -} - -// pVect1v = SQ8 storage, pVect2v = FP32 query -template // 0..15 -float SQ8_FP32_InnerProductImp_FMA(const void *pVect1v, const void *pVect2v, size_t dimension) { - const uint8_t *pVect1 = static_cast(pVect1v); // SQ8 storage - const float *pVect2 = static_cast(pVect2v); // FP32 query - const uint8_t *pEnd1 = pVect1 + dimension; - - // Initialize sum accumulator for Σ(q_i * y_i) - __m256 sum256 = _mm256_setzero_ps(); - - // Handle residual elements first (0-7 elements) - if constexpr (residual % 8) { - __mmask8 constexpr mask = (1 << (residual % 8)) - 1; - - // Load uint8 elements and convert to float - __m128i v1_128 = _mm_loadl_epi64(reinterpret_cast(pVect1)); - pVect1 += residual % 8; - - __m256i v1_256 = _mm256_cvtepu8_epi32(v1_128); - __m256 v1_f = _mm256_cvtepi32_ps(v1_256); - - // Load masked float elements from query - __m256 v2 = my_mm256_maskz_loadu_ps(pVect2); - pVect2 += residual % 8; - - // Compute q_i * y_i (no dequantization) - sum256 = _mm256_mul_ps(v1_f, v2); - } - - // If the residual is >=8, have another step of 8 floats - if constexpr (residual >= 8) { - InnerProductStepSQ8_FMA(pVect1, pVect2, sum256); - } - - // Process remaining full chunks of 16 elements (2x8) - // Using do-while since dim > 16 guarantees at least one iteration - do { - InnerProductStepSQ8_FMA(pVect1, pVect2, sum256); - InnerProductStepSQ8_FMA(pVect1, pVect2, sum256); - } while (pVect1 < pEnd1); - - // Reduce to get Σ(q_i * y_i) - float quantized_dot = my_mm256_reduce_add_ps(sum256); - - // Get quantization parameters from stored vector (after quantized data) - const uint8_t *pVect1Base = static_cast(pVect1v); - const float *params1 = reinterpret_cast(pVect1Base + dimension); - const float min_val = params1[sq8::MIN_VAL]; - const float delta = params1[sq8::DELTA]; - - // Get precomputed y_sum from query blob (stored after the dim floats) - const float y_sum = static_cast(pVect2v)[dimension + sq8::SUM_QUERY]; - - // Apply the algebraic formula: IP = min * y_sum + delta * Σ(q_i * y_i) - return min_val * y_sum + delta * quantized_dot; -} - -template // 0..15 -float SQ8_FP32_InnerProductSIMD16_AVX2_FMA(const void *pVect1v, const void *pVect2v, - size_t dimension) { - return 1.0f - SQ8_FP32_InnerProductImp_FMA(pVect1v, pVect2v, dimension); -} - -template // 0..15 -float SQ8_FP32_CosineSIMD16_AVX2_FMA(const void *pVect1v, const void *pVect2v, size_t dimension) { - // Cosine distance = 1 - IP (vectors are pre-normalized) - return SQ8_FP32_InnerProductSIMD16_AVX2_FMA(pVect1v, pVect2v, dimension); -} diff --git a/src/VecSim/spaces/IP/IP_AVX2_SQ8_FP32.h b/src/VecSim/spaces/IP/IP_AVX2_SQ8_FP32.h deleted file mode 100644 index dea167eb3..000000000 --- a/src/VecSim/spaces/IP/IP_AVX2_SQ8_FP32.h +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/AVX_utils.h" -#include "VecSim/types/sq8.h" - -using sq8 = vecsim_types::sq8; - -/* - * Optimized asymmetric SQ8 inner product using algebraic identity: - * - * IP(x, y) = Σ(x_i * y_i) - * ≈ Σ((min + delta * q_i) * y_i) - * = min * Σy_i + delta * Σ(q_i * y_i) - * = min * y_sum + delta * quantized_dot_product - * - * where y_sum = Σy_i is precomputed and stored in the query blob. 
- * This avoids dequantization in the hot loop - we only compute Σ(q_i * y_i). - */ - -// Helper: compute Σ(q_i * y_i) for 8 elements (no dequantization) -// pVect1 = SQ8 storage (quantized values), pVect2 = FP32 query -static inline void InnerProductStepSQ8_FP32(const uint8_t *&pVect1, const float *&pVect2, - __m256 &sum256) { - // Load 8 uint8 elements and convert to float - __m128i v1_128 = _mm_loadl_epi64(reinterpret_cast(pVect1)); - pVect1 += 8; - - __m256i v1_256 = _mm256_cvtepu8_epi32(v1_128); - __m256 v1_f = _mm256_cvtepi32_ps(v1_256); - - // Load 8 float elements from query - __m256 v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - - // Accumulate q_i * y_i (no dequantization!) - // Using mul + add since this is the non-FMA version - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1_f, v2)); -} - -// pVect1v = SQ8 storage, pVect2v = FP32 query -template // 0..15 -float SQ8_FP32_InnerProductImp_AVX2(const void *pVect1v, const void *pVect2v, size_t dimension) { - const uint8_t *pVect1 = static_cast(pVect1v); // SQ8 storage - const float *pVect2 = static_cast(pVect2v); // FP32 query - const uint8_t *pEnd1 = pVect1 + dimension; - - // Initialize sum accumulator for Σ(q_i * y_i) - __m256 sum256 = _mm256_setzero_ps(); - - // Handle residual elements first (0-7 elements) - if constexpr (residual % 8) { - __mmask8 constexpr mask = (1 << (residual % 8)) - 1; - - // Load uint8 elements and convert to float - __m128i v1_128 = _mm_loadl_epi64(reinterpret_cast(pVect1)); - pVect1 += residual % 8; - - __m256i v1_256 = _mm256_cvtepu8_epi32(v1_128); - __m256 v1_f = _mm256_cvtepi32_ps(v1_256); - - // Load masked float elements from query - __m256 v2 = my_mm256_maskz_loadu_ps(pVect2); - pVect2 += residual % 8; - - // Compute q_i * y_i (no dequantization) - sum256 = _mm256_mul_ps(v1_f, v2); - } - - // If the residual is >=8, have another step of 8 floats - if constexpr (residual >= 8) { - InnerProductStepSQ8_FP32(pVect1, pVect2, sum256); - } - - // Process remaining full chunks of 16 elements (2x8) - // Using do-while since dim > 16 guarantees at least one iteration - do { - InnerProductStepSQ8_FP32(pVect1, pVect2, sum256); - InnerProductStepSQ8_FP32(pVect1, pVect2, sum256); - } while (pVect1 < pEnd1); - - // Reduce to get Σ(q_i * y_i) - float quantized_dot = my_mm256_reduce_add_ps(sum256); - - // Get quantization parameters from stored vector (after quantized data) - const uint8_t *pVect1Base = static_cast(pVect1v); - const float *params1 = reinterpret_cast(pVect1Base + dimension); - const float min_val = params1[sq8::MIN_VAL]; - const float delta = params1[sq8::DELTA]; - - // Get precomputed y_sum from query blob (stored after the dim floats) - const float y_sum = static_cast(pVect2v)[dimension + sq8::SUM_QUERY]; - - // Apply the algebraic formula: IP = min * y_sum + delta * Σ(q_i * y_i) - return min_val * y_sum + delta * quantized_dot; -} - -template // 0..15 -float SQ8_FP32_InnerProductSIMD16_AVX2(const void *pVect1v, const void *pVect2v, size_t dimension) { - return 1.0f - SQ8_FP32_InnerProductImp_AVX2(pVect1v, pVect2v, dimension); -} - -template // 0..15 -float SQ8_FP32_CosineSIMD16_AVX2(const void *pVect1v, const void *pVect2v, size_t dimension) { - // Calculate inner product using common implementation with normalization - return SQ8_FP32_InnerProductSIMD16_AVX2(pVect1v, pVect2v, dimension); -} diff --git a/src/VecSim/spaces/IP/IP_AVX512BW_VBMI2_BF16.h b/src/VecSim/spaces/IP/IP_AVX512BW_VBMI2_BF16.h deleted file mode 100644 index cebeba041..000000000 --- 
a/src/VecSim/spaces/IP/IP_AVX512BW_VBMI2_BF16.h +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/bfloat16.h" - -using bfloat16 = vecsim_types::bfloat16; - -static inline void InnerProductHalfStep(bfloat16 *&pVect1, bfloat16 *&pVect2, __m512 &sum, - __mmask32 mask) { - __m512i v1 = _mm512_maskz_expandloadu_epi16(mask, pVect1); // AVX512_VBMI2 - __m512i v2 = _mm512_maskz_expandloadu_epi16(mask, pVect2); // AVX512_VBMI2 - sum = _mm512_fmadd_ps(_mm512_castsi512_ps(v1), _mm512_castsi512_ps(v2), sum); -} - -static inline void InnerProductStep(bfloat16 *&pVect1, bfloat16 *&pVect2, __m512 &sum) { - __m512i v1 = _mm512_loadu_si512((__m512i *)pVect1); - __m512i v2 = _mm512_loadu_si512((__m512i *)pVect2); - pVect1 += 32; - pVect2 += 32; - __m512i zeros = _mm512_setzero_si512(); - - // Convert 0:3, 8:11, .. 28:31 to float32 - __m512i v1_low = _mm512_unpacklo_epi16(zeros, v1); // AVX512BW - __m512i v2_low = _mm512_unpacklo_epi16(zeros, v2); - sum = _mm512_fmadd_ps(_mm512_castsi512_ps(v1_low), _mm512_castsi512_ps(v2_low), sum); - - // Convert 4:7, 12:15, .. 24:27 to float32 - __m512i v1_high = _mm512_unpackhi_epi16(zeros, v1); - __m512i v2_high = _mm512_unpackhi_epi16(zeros, v2); - sum = _mm512_fmadd_ps(_mm512_castsi512_ps(v1_high), _mm512_castsi512_ps(v2_high), sum); -} - -template // 0..31 -float BF16_InnerProductSIMD32_AVX512BW_VBMI2(const void *pVect1v, const void *pVect2v, - size_t dimension) { - bfloat16 *pVect1 = (bfloat16 *)pVect1v; - bfloat16 *pVect2 = (bfloat16 *)pVect2v; - - const bfloat16 *pEnd1 = pVect1 + dimension; - - __m512 sum = _mm512_setzero_ps(); - - // Handle first residual % 32 elements - if constexpr (residual) { - constexpr __mmask32 mask = 0xAAAAAAAA; // 01010101... - - // Calculate first 16 - if constexpr (residual >= 16) { - InnerProductHalfStep(pVect1, pVect2, sum, mask); - pVect1 += 16; - pVect2 += 16; - } - if constexpr (residual != 16) { - // Each element is represented by a pair of 01 bits - // Create a mask for the elements we want to process: - // mask2 = {01 * (residual % 16)}0000... - constexpr __mmask32 mask2 = mask & ((1 << ((residual % 16) * 2)) - 1); - InnerProductHalfStep(pVect1, pVect2, sum, mask2); - pVect1 += residual % 16; - pVect2 += residual % 16; - } - } - - // Handle 512 bits (32 bfloat16) in chunks of max SIMD = 512 bits = 32 bfloat16 - do { - InnerProductStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - return 1.0f - _mm512_reduce_add_ps(sum); -} diff --git a/src/VecSim/spaces/IP/IP_AVX512FP16_VL_FP16.h b/src/VecSim/spaces/IP/IP_AVX512FP16_VL_FP16.h deleted file mode 100644 index 130ff2c7a..000000000 --- a/src/VecSim/spaces/IP/IP_AVX512FP16_VL_FP16.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/float16.h" -#include - -using float16 = vecsim_types::float16; - -static void InnerProductStep(float16 *&pVect1, float16 *&pVect2, __m512h &sum) { - __m512h v1 = _mm512_loadu_ph(pVect1); - __m512h v2 = _mm512_loadu_ph(pVect2); - - sum = _mm512_fmadd_ph(v1, v2, sum); - pVect1 += 32; - pVect2 += 32; -} - -template // 0..31 -float FP16_InnerProductSIMD32_AVX512FP16_VL(const void *pVect1v, const void *pVect2v, - size_t dimension) { - auto *pVect1 = (float16 *)pVect1v; - auto *pVect2 = (float16 *)pVect2v; - - const float16 *pEnd1 = pVect1 + dimension; - - __m512h sum = _mm512_setzero_ph(); - - if constexpr (residual) { - constexpr __mmask32 mask = (1LU << residual) - 1; - __m512h v1 = _mm512_loadu_ph(pVect1); - pVect1 += residual; - __m512h v2 = _mm512_loadu_ph(pVect2); - pVect2 += residual; - sum = _mm512_maskz_mul_ph(mask, v1, v2); - } - - // We dealt with the residual part. We are left with some multiple of 32 16-bit floats. - do { - InnerProductStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - _Float16 res = _mm512_reduce_add_ph(sum); - return _Float16(1) - res; -} diff --git a/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_INT8.h b/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_INT8.h deleted file mode 100644 index add070942..000000000 --- a/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_INT8.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" - -static inline void InnerProductStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) { - __m256i temp_a = _mm256_loadu_epi8(pVect1); - __m512i va = _mm512_cvtepi8_epi16(temp_a); - pVect1 += 32; - - __m256i temp_b = _mm256_loadu_epi8(pVect2); - __m512i vb = _mm512_cvtepi8_epi16(temp_b); - pVect2 += 32; - - // _mm512_dpwssd_epi32(src, a, b) - // Multiply groups of 2 adjacent pairs of signed 16-bit integers in `a` with corresponding - // 16-bit integers in `b`, producing 2 intermediate signed 32-bit results. Sum these 2 results - // with the corresponding 32-bit integer in src, and store the packed 32-bit results in dst. - sum = _mm512_dpwssd_epi32(sum, va, vb); -} - -template // 0..63 -static inline int INT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) { - int8_t *pVect1 = (int8_t *)pVect1v; - int8_t *pVect2 = (int8_t *)pVect2v; - - const int8_t *pEnd1 = pVect1 + dimension; - - __m512i sum = _mm512_setzero_epi32(); - - // Deal with remainder first. `dim` is more than 32, so we have at least one 32-int_8 block, - // so mask loading is guaranteed to be safe - if constexpr (residual % 32) { - constexpr __mmask32 mask = (1LU << (residual % 32)) - 1; - __m256i temp_a = _mm256_maskz_loadu_epi8(mask, pVect1); - __m512i va = _mm512_cvtepi8_epi16(temp_a); - pVect1 += residual % 32; - - __m256i temp_b = _mm256_maskz_loadu_epi8(mask, pVect2); - __m512i vb = _mm512_cvtepi8_epi16(temp_b); - pVect2 += residual % 32; - - sum = _mm512_dpwssd_epi32(sum, va, vb); - } - - if constexpr (residual >= 32) { - InnerProductStep(pVect1, pVect2, sum); - } - - // We dealt with the residual part. We are left with some multiple of 64-int_8. 
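The _mm512_dpwssd_epi32 accumulation documented above can be modeled with a short scalar loop; a hedged sketch (helper name hypothetical), useful when verifying the SIMD kernel against reference results:

#include <cstddef>
#include <cstdint>

// Scalar model of the VNNI step: widen int8 to int16, multiply adjacent pairs,
// and accumulate the 32-bit partial sums. Pairs of int8 products always fit in int32.
static inline int32_t dpwssd_scalar_model(const int8_t *a, const int8_t *b, size_t n) {
    int32_t acc = 0;
    for (size_t i = 0; i + 1 < n; i += 2) {
        acc += static_cast<int16_t>(a[i]) * static_cast<int16_t>(b[i]) +
               static_cast<int16_t>(a[i + 1]) * static_cast<int16_t>(b[i + 1]);
    }
    return acc;
}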
-    while (pVect1 < pEnd1) {
-        InnerProductStep(pVect1, pVect2, sum);
-        InnerProductStep(pVect1, pVect2, sum);
-    }
-
-    return _mm512_reduce_add_epi32(sum);
-}
-
-template <unsigned char residual> // 0..63
-float INT8_InnerProductSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v,
-                                                 size_t dimension) {
-
-    return 1 - INT8_InnerProductImp<residual>(pVect1v, pVect2v, dimension);
-}
-template <unsigned char residual> // 0..63
-float INT8_CosineSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v,
-                                           size_t dimension) {
-    float ip = INT8_InnerProductImp<residual>(pVect1v, pVect2v, dimension);
-    float norm_v1 =
-        *reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect1v) + dimension);
-    float norm_v2 =
-        *reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect2v) + dimension);
-    return 1.0f - ip / (norm_v1 * norm_v2);
-}
diff --git a/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_SQ8_FP32.h b/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_SQ8_FP32.h
deleted file mode 100644
index 76a590519..000000000
--- a/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_SQ8_FP32.h
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-#include "VecSim/spaces/space_includes.h"
-#include "VecSim/types/sq8.h"
-#include
-
-using sq8 = vecsim_types::sq8;
-/*
- * Optimized asymmetric SQ8 inner product using algebraic identity:
- *
- * IP(x, y) = Σ(x_i * y_i)
- *          ≈ Σ((min + delta * q_i) * y_i)
- *          = min * Σy_i + delta * Σ(q_i * y_i)
- *          = min * y_sum + delta * quantized_dot_product
- *
- * where y_sum = Σy_i is precomputed and stored in the query blob.
- * This avoids dequantization in the hot loop - we only compute Σ(q_i * y_i).
- */
-
-// Helper: compute Σ(q_i * y_i) for 16 elements
-// pVec1 = SQ8 storage (quantized values), pVec2 = FP32 query
-static inline void SQ8_FP32_InnerProductStep(const uint8_t *&pVec1, const float *&pVec2,
-                                             __m512 &sum) {
-    // Load 16 uint8 elements from quantized vector and convert to float
-    __m128i v1_128 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(pVec1));
-    __m512i v1_512 = _mm512_cvtepu8_epi32(v1_128);
-    __m512 v1_f = _mm512_cvtepi32_ps(v1_512);
-
-    // Load 16 float elements from query (pVec2)
-    __m512 v2 = _mm512_loadu_ps(pVec2);
-
-    // Accumulate q_i * y_i (no dequantization!)
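// The fused multiply-add below accumulates each 16-float group with a single
// rounding per lane; the plain-AVX2 variant of this kernel has to pair
// _mm256_mul_ps with _mm256_add_ps instead.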
-    sum = _mm512_fmadd_ps(v1_f, v2, sum);
-
-    pVec1 += 16;
-    pVec2 += 16;
-}
-
-// Common implementation for both inner product and cosine similarity
-// pVec1v = SQ8 storage, pVec2v = FP32 query
-template <unsigned char residual> // 0..15
-float SQ8_FP32_InnerProductImp_AVX512(const void *pVec1v, const void *pVec2v, size_t dimension) {
-    const uint8_t *pVec1 = static_cast<const uint8_t *>(pVec1v); // SQ8 storage
-    const float *pVec2 = static_cast<const float *>(pVec2v);     // FP32 query
-    const uint8_t *pEnd1 = pVec1 + dimension;
-
-    // Initialize sum accumulator for Σ(q_i * y_i)
-    __m512 sum = _mm512_setzero_ps();
-
-    // Handle residual elements first (0 to 15)
-    if constexpr (residual > 0) {
-        __mmask16 mask = (1U << residual) - 1;
-
-        // Load uint8 elements (safe to load 16 bytes due to padding)
-        __m128i v1_128 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(pVec1));
-        __m512i v1_512 = _mm512_cvtepu8_epi32(v1_128);
-        __m512 v1_f = _mm512_cvtepi32_ps(v1_512);
-
-        // Load masked float elements from query
-        __m512 v2 = _mm512_maskz_loadu_ps(mask, pVec2);
-
-        // Compute q_i * y_i with mask (no dequantization)
-        sum = _mm512_maskz_mul_ps(mask, v1_f, v2);
-
-        pVec1 += residual;
-        pVec2 += residual;
-    }
-
-    // Process full chunks of 16 elements.
-    // Using do-while since dim > 16 guarantees at least one iteration.
-    do {
-        SQ8_FP32_InnerProductStep(pVec1, pVec2, sum);
-    } while (pVec1 < pEnd1);
-
-    // Reduce to get Σ(q_i * y_i)
-    float quantized_dot = _mm512_reduce_add_ps(sum);
-
-    // Get quantization parameters from stored vector (after quantized data).
-    // Use the original base pointer since pVec1 has been advanced.
-    const uint8_t *pVec1Base = static_cast<const uint8_t *>(pVec1v);
-    const float *params1 = reinterpret_cast<const float *>(pVec1Base + dimension);
-    const float min_val = params1[sq8::MIN_VAL];
-    const float delta = params1[sq8::DELTA];
-
-    // Get precomputed y_sum from query blob (stored after the dim floats).
-    // Use the original base pointer since pVec2 has been advanced.
-    const float y_sum = static_cast<const float *>(pVec2v)[dimension + sq8::SUM_QUERY];
-
-    // Apply the algebraic formula: IP = min * y_sum + delta * Σ(q_i * y_i)
-    return min_val * y_sum + delta * quantized_dot;
-}
-
-template <unsigned char residual> // 0..15
-float SQ8_FP32_InnerProductSIMD16_AVX512F_BW_VL_VNNI(const void *pVec1v, const void *pVec2v,
-                                                     size_t dimension) {
-    // The inner product similarity is 1 - ip
-    return 1.0f - SQ8_FP32_InnerProductImp_AVX512<residual>(pVec1v, pVec2v, dimension);
-}
-
-template <unsigned char residual> // 0..15
-float SQ8_FP32_CosineSIMD16_AVX512F_BW_VL_VNNI(const void *pVec1v, const void *pVec2v,
-                                               size_t dimension) {
-    // Cosine distance = 1 - IP (vectors are pre-normalized)
-    return SQ8_FP32_InnerProductSIMD16_AVX512F_BW_VL_VNNI<residual>(pVec1v, pVec2v, dimension);
-}
diff --git a/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_SQ8_SQ8.h b/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_SQ8_SQ8.h
deleted file mode 100644
index 899b466b9..000000000
--- a/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_SQ8_SQ8.h
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-#include "VecSim/spaces/space_includes.h"
-#include "VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_UINT8.h"
-#include "VecSim/types/sq8.h"
-#include
-
-using sq8 = vecsim_types::sq8;
-
-/**
- * SQ8-to-SQ8 distance functions using AVX512 VNNI with precomputed sum.
- * These functions compute distance between two SQ8 (scalar quantized 8-bit) vectors,
- * where BOTH vectors are uint8 quantized.
- *
- * Uses the precomputed sum stored in the vector data,
- * eliminating the need to compute it during distance calculation.
- *
- * Uses algebraic optimization to leverage integer VNNI instructions:
- *
- * With sum = Σv[i] (sum of original float values), the formula is:
- * IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1[i]*q2[i]) - dim*min1*min2
- *
- * Since sum is precomputed, we only need to compute the dot product Σ(q1[i]*q2[i]).
- * The dot product is computed using the efficient UINT8_InnerProductImp, which uses
- * VNNI instructions (_mm512_dpwssd_epi32) for native integer dot product computation.
- *
- * Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)]
- */
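The formula above follows from expanding the affine dequantization of both operands. With x_i = min1 + δ1·q1[i], y_i = min2 + δ2·q2[i], and sum_k = Σv_k[i] = dim·min_k + δ_k·Σq_k[i]:

    IP = Σ (min1 + δ1·q1[i]) · (min2 + δ2·q2[i])
       = dim·min1·min2 + min1·δ2·Σq2[i] + min2·δ1·Σq1[i] + δ1·δ2·Σ(q1[i]·q2[i])
       = dim·min1·min2 + min1·(sum2 - dim·min2) + min2·(sum1 - dim·min1) + δ1·δ2·Σ(q1[i]·q2[i])
       = min1·sum2 + min2·sum1 + δ1·δ2·Σ(q1[i]·q2[i]) - dim·min1·min2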
-
-// Common implementation for inner product between two SQ8 vectors with precomputed sum.
-// Uses UINT8_InnerProductImp for efficient dot product computation with VNNI.
-template <unsigned char residual> // 0..63
-float SQ8_SQ8_InnerProductImp(const void *pVec1v, const void *pVec2v, size_t dimension) {
-    // Compute raw dot product using the efficient UINT8 AVX512 VNNI implementation.
-    // UINT8_InnerProductImp uses _mm512_dpwssd_epi32 for native integer dot product.
-    int dot_product = UINT8_InnerProductImp<residual>(pVec1v, pVec2v, dimension);
-
-    // Get dequantization parameters and precomputed values from the end of vectors.
-    // Layout: [data (dim)] [min (float)] [delta (float)] [sum (float)]
-    const uint8_t *pVec1 = static_cast<const uint8_t *>(pVec1v);
-    const uint8_t *pVec2 = static_cast<const uint8_t *>(pVec2v);
-
-    const float *params1 = reinterpret_cast<const float *>(pVec1 + dimension);
-    const float min1 = params1[sq8::MIN_VAL];
-    const float delta1 = params1[sq8::DELTA];
-    const float sum1 = params1[sq8::SUM]; // Precomputed sum of original float elements
-
-    const float *params2 = reinterpret_cast<const float *>(pVec2 + dimension);
-    const float min2 = params2[sq8::MIN_VAL];
-    const float delta2 = params2[sq8::DELTA];
-    const float sum2 = params2[sq8::SUM]; // Precomputed sum of original float elements
-
-    // Apply the algebraic formula using precomputed sums:
-    // IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1[i]*q2[i]) - dim*min1*min2
-    return min1 * sum2 + min2 * sum1 + delta1 * delta2 * static_cast<float>(dot_product) -
-           static_cast<float>(dimension) * min1 * min2;
-}
-
-// SQ8-to-SQ8 Inner Product distance function.
-// Returns 1 - inner_product (distance form).
-template <unsigned char residual> // 0..63
-float SQ8_SQ8_InnerProductSIMD64_AVX512F_BW_VL_VNNI(const void *pVec1v, const void *pVec2v,
-                                                    size_t dimension) {
-    return 1.0f - SQ8_SQ8_InnerProductImp<residual>(pVec1v, pVec2v, dimension);
-}
-
-// SQ8-to-SQ8 Cosine distance function.
-// Returns 1 - inner_product.
-template <unsigned char residual> // 0..63
-float SQ8_SQ8_CosineSIMD64_AVX512F_BW_VL_VNNI(const void *pVec1v, const void *pVec2v,
                                              size_t dimension) {
-    // Assume vectors are normalized.
-    return SQ8_SQ8_InnerProductSIMD64_AVX512F_BW_VL_VNNI<residual>(pVec1v, pVec2v, dimension);
-}
diff --git a/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_UINT8.h b/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_UINT8.h
deleted file mode 100644
index deed0f706..000000000
--- a/src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_UINT8.h
+++ /dev/null
@@ -1,108 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */ -#pragma once -#include "VecSim/spaces/space_includes.h" - -static inline void InnerProductStep(uint8_t *&pVect1, uint8_t *&pVect2, __m512i &sum) { - __m512i va = _mm512_loadu_epi8(pVect1); // AVX512BW - pVect1 += 64; - - __m512i vb = _mm512_loadu_epi8(pVect2); // AVX512BW - pVect2 += 64; - - __m512i va_lo = _mm512_unpacklo_epi8(va, _mm512_setzero_si512()); // AVX512BW - __m512i vb_lo = _mm512_unpacklo_epi8(vb, _mm512_setzero_si512()); - sum = _mm512_dpwssd_epi32(sum, va_lo, vb_lo); - - __m512i va_hi = _mm512_unpackhi_epi8(va, _mm512_setzero_si512()); // AVX512BW - __m512i vb_hi = _mm512_unpackhi_epi8(vb, _mm512_setzero_si512()); - sum = _mm512_dpwssd_epi32(sum, va_hi, vb_hi); - - // _mm512_dpwssd_epi32(src, a, b) - // Multiply groups of 2 adjacent pairs of signed 16-bit integers in `a` with corresponding - // 16-bit integers in `b`, producing 2 intermediate signed 32-bit results. Sum these 2 results - // with the corresponding 32-bit integer in src, and store the packed 32-bit results in dst. -} - -template // 0..63 -static inline int UINT8_InnerProductImp(const void *pVect1v, const void *pVect2v, - size_t dimension) { - uint8_t *pVect1 = (uint8_t *)pVect1v; - uint8_t *pVect2 = (uint8_t *)pVect2v; - - const uint8_t *pEnd1 = pVect1 + dimension; - - __m512i sum = _mm512_setzero_epi32(); - - // Deal with remainder first. - if constexpr (residual) { - if constexpr (residual < 32) { - constexpr __mmask32 mask = (1LU << residual) - 1; - __m256i temp_a = _mm256_maskz_loadu_epi8(mask, pVect1); - __m512i va = _mm512_cvtepu8_epi16(temp_a); - - __m256i temp_b = _mm256_maskz_loadu_epi8(mask, pVect2); - __m512i vb = _mm512_cvtepu8_epi16(temp_b); - - sum = _mm512_dpwssd_epi32(sum, va, vb); - } else if constexpr (residual == 32) { - __m256i temp_a = _mm256_loadu_epi8(pVect1); - __m512i va = _mm512_cvtepu8_epi16(temp_a); - - __m256i temp_b = _mm256_loadu_epi8(pVect2); - __m512i vb = _mm512_cvtepu8_epi16(temp_b); - - sum = _mm512_dpwssd_epi32(sum, va, vb); - } else { - constexpr __mmask64 mask = (1LU << residual) - 1; - __m512i va = _mm512_maskz_loadu_epi8(mask, pVect1); - __m512i vb = _mm512_maskz_loadu_epi8(mask, pVect2); - - __m512i va_lo = _mm512_unpacklo_epi8(va, _mm512_setzero_si512()); - __m512i vb_lo = _mm512_unpacklo_epi8(vb, _mm512_setzero_si512()); - sum = _mm512_dpwssd_epi32(sum, va_lo, vb_lo); - - __m512i va_hi = _mm512_unpackhi_epi8(va, _mm512_setzero_si512()); - __m512i vb_hi = _mm512_unpackhi_epi8(vb, _mm512_setzero_si512()); - sum = _mm512_dpwssd_epi32(sum, va_hi, vb_hi); - } - pVect1 += residual; - pVect2 += residual; - - // We dealt with the residual part. - // We are left with some multiple of 64-uint_8 (might be 0). - while (pVect1 < pEnd1) { - InnerProductStep(pVect1, pVect2, sum); - } - } else { - // We have no residual, we have some non-zero multiple of 64-uint_8. 
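Zero-extension is what keeps the signed 16-bit multiply inside _mm512_dpwssd_epi32 exact for uint8 data: every widened operand lies in [0, 255]. A minimal scalar equivalent (helper name hypothetical), assuming the running sum stays within int32 as it does in the SIMD lanes:

#include <cstddef>
#include <cstdint>

// Scalar equivalent of the uint8 inner product computed by the loop below.
static inline int32_t uint8_dot_scalar_model(const uint8_t *a, const uint8_t *b, size_t dim) {
    int32_t acc = 0;
    for (size_t i = 0; i < dim; i++)
        acc += static_cast<uint16_t>(a[i]) * static_cast<uint16_t>(b[i]); // exact, fits int32
    return acc;
}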
- do { - InnerProductStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - } - - return _mm512_reduce_add_epi32(sum); -} - -template // 0..63 -float UINT8_InnerProductSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v, - size_t dimension) { - - return 1 - UINT8_InnerProductImp(pVect1v, pVect2v, dimension); -} -template // 0..63 -float UINT8_CosineSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v, - size_t dimension) { - float ip = UINT8_InnerProductImp(pVect1v, pVect2v, dimension); - float norm_v1 = - *reinterpret_cast(static_cast(pVect1v) + dimension); - float norm_v2 = - *reinterpret_cast(static_cast(pVect2v) + dimension); - return 1.0f - ip / (norm_v1 * norm_v2); -} diff --git a/src/VecSim/spaces/IP/IP_AVX512F_FP16.h b/src/VecSim/spaces/IP/IP_AVX512F_FP16.h deleted file mode 100644 index b09352533..000000000 --- a/src/VecSim/spaces/IP/IP_AVX512F_FP16.h +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/AVX_utils.h" -#include "VecSim/types/float16.h" - -using float16 = vecsim_types::float16; - -static void InnerProductStep(float16 *&pVect1, float16 *&pVect2, __m512 &sum) { - // Convert 16 half-floats into floats and store them in 512 bits register. - auto v1 = _mm512_cvtph_ps(_mm256_lddqu_si256((__m256i *)pVect1)); - auto v2 = _mm512_cvtph_ps(_mm256_lddqu_si256((__m256i *)pVect2)); - - // sum = v1 * v2 + sum - sum = _mm512_fmadd_ps(v1, v2, sum); - pVect1 += 16; - pVect2 += 16; -} - -template // 0..31 -float FP16_InnerProductSIMD32_AVX512(const void *pVect1v, const void *pVect2v, size_t dimension) { - auto *pVect1 = (float16 *)pVect1v; - auto *pVect2 = (float16 *)pVect2v; - - const float16 *pEnd1 = pVect1 + dimension; - - auto sum = _mm512_setzero_ps(); - - if constexpr (residual % 16) { - // Deal with remainder first. `dim` is more than 32, so we have at least one block of 32 - // 16-bit float so mask loading is guaranteed to be safe. - __mmask16 constexpr residuals_mask = (1 << (residual % 16)) - 1; - // Convert the first half-floats in the residual positions into floats and store them - // 512 bits register, where the floats in the positions corresponding to the non-residuals - // positions are zeros. - auto v1 = _mm512_maskz_mov_ps(residuals_mask, - _mm512_cvtph_ps(_mm256_lddqu_si256((__m256i *)pVect1))); - auto v2 = _mm512_maskz_mov_ps(residuals_mask, - _mm512_cvtph_ps(_mm256_lddqu_si256((__m256i *)pVect2))); - sum = _mm512_mul_ps(v1, v2); - pVect1 += residual % 16; - pVect2 += residual % 16; - } - if constexpr (residual >= 16) { - InnerProductStep(pVect1, pVect2, sum); - } - - // We dealt with the residual part. We are left with some multiple of 32 16-bit floats. - // In every iteration we process 2 chunks of 256bit (32 FP16) - do { - InnerProductStep(pVect1, pVect2, sum); - InnerProductStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - return 1.0f - _mm512_reduce_add_ps(sum); -} diff --git a/src/VecSim/spaces/IP/IP_AVX512F_FP32.h b/src/VecSim/spaces/IP/IP_AVX512F_FP32.h deleted file mode 100644 index 88421fd39..000000000 --- a/src/VecSim/spaces/IP/IP_AVX512F_FP32.h +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. 
- * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" - -static inline void InnerProductStep(float *&pVect1, float *&pVect2, __m512 &sum512) { - __m512 v1 = _mm512_loadu_ps(pVect1); - pVect1 += 16; - __m512 v2 = _mm512_loadu_ps(pVect2); - pVect2 += 16; - sum512 = _mm512_fmadd_ps(v1, v2, sum512); -} - -template // 0..15 -float FP32_InnerProductSIMD16_AVX512(const void *pVect1v, const void *pVect2v, size_t dimension) { - float *pVect1 = (float *)pVect1v; - float *pVect2 = (float *)pVect2v; - - const float *pEnd1 = pVect1 + dimension; - - __m512 sum512 = _mm512_setzero_ps(); - - // Deal with remainder first. `dim` is more than 16, so we have at least one 16-float block, - // so mask loading is guaranteed to be safe - if constexpr (residual) { - __mmask16 constexpr mask = (1 << residual) - 1; - __m512 v1 = _mm512_maskz_loadu_ps(mask, pVect1); - pVect1 += residual; - __m512 v2 = _mm512_maskz_loadu_ps(mask, pVect2); - pVect2 += residual; - sum512 = _mm512_mul_ps(v1, v2); - } - - // We dealt with the residual part. We are left with some multiple of 16 floats. - do { - InnerProductStep(pVect1, pVect2, sum512); - } while (pVect1 < pEnd1); - - return 1.0f - _mm512_reduce_add_ps(sum512); -} diff --git a/src/VecSim/spaces/IP/IP_AVX512F_FP64.h b/src/VecSim/spaces/IP/IP_AVX512F_FP64.h deleted file mode 100644 index e6eebcc44..000000000 --- a/src/VecSim/spaces/IP/IP_AVX512F_FP64.h +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" - -static inline void InnerProductStep(double *&pVect1, double *&pVect2, __m512d &sum512) { - __m512d v1 = _mm512_loadu_pd(pVect1); - pVect1 += 8; - __m512d v2 = _mm512_loadu_pd(pVect2); - pVect2 += 8; - sum512 = _mm512_fmadd_pd(v1, v2, sum512); -} - -template // 0..7 -double FP64_InnerProductSIMD8_AVX512(const void *pVect1v, const void *pVect2v, size_t dimension) { - double *pVect1 = (double *)pVect1v; - double *pVect2 = (double *)pVect2v; - - const double *pEnd1 = pVect1 + dimension; - - __m512d sum512 = _mm512_setzero_pd(); - - // Deal with remainder first. `dim` is more than 8, so we have at least one 8-double block, - // so mask loading is guaranteed to be safe - if constexpr (residual) { - __mmask8 constexpr mask = (1 << residual) - 1; - __m512d v1 = _mm512_maskz_loadu_pd(mask, pVect1); - pVect1 += residual; - __m512d v2 = _mm512_maskz_loadu_pd(mask, pVect2); - pVect2 += residual; - sum512 = _mm512_mul_pd(v1, v2); - } - - // We dealt with the residual part. We are left with some multiple of 8 doubles. - do { - InnerProductStep(pVect1, pVect2, sum512); - } while (pVect1 < pEnd1); - - return 1.0 - _mm512_reduce_add_pd(sum512); -} diff --git a/src/VecSim/spaces/IP/IP_AVX512_BF16_VL_BF16.h b/src/VecSim/spaces/IP/IP_AVX512_BF16_VL_BF16.h deleted file mode 100644 index 3c007ed35..000000000 --- a/src/VecSim/spaces/IP/IP_AVX512_BF16_VL_BF16.h +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. 
- * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/bfloat16.h" - -using bfloat16 = vecsim_types::bfloat16; - -static inline void InnerProductStep(bfloat16 *&pVect1, bfloat16 *&pVect2, __m512 &sum) { - __m512i vec1 = _mm512_loadu_si512((__m512i *)pVect1); - __m512i vec2 = _mm512_loadu_si512((__m512i *)pVect2); - - sum = _mm512_dpbf16_ps(sum, (__m512bh)vec1, (__m512bh)vec2); - pVect1 += 32; - pVect2 += 32; -} - -template // 0..31 -float BF16_InnerProductSIMD32_AVX512BF16_VL(const void *pVect1v, const void *pVect2v, - size_t dimension) { - bfloat16 *pVect1 = (bfloat16 *)pVect1v; - bfloat16 *pVect2 = (bfloat16 *)pVect2v; - - const bfloat16 *pEnd1 = pVect1 + dimension; - - __m512 sum = _mm512_setzero_ps(); - - if constexpr (residual) { - constexpr __mmask32 mask = (1LU << residual) - 1; - __m512i v1 = _mm512_maskz_loadu_epi16(mask, pVect1); - pVect1 += residual; - __m512i v2 = _mm512_maskz_loadu_epi16(mask, pVect2); - pVect2 += residual; - sum = _mm512_dpbf16_ps(sum, (__m512bh)v1, (__m512bh)v2); - } - - do { - InnerProductStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - return 1.0f - _mm512_reduce_add_ps(sum); -} diff --git a/src/VecSim/spaces/IP/IP_AVX_FP32.h b/src/VecSim/spaces/IP/IP_AVX_FP32.h deleted file mode 100644 index e495fb9a1..000000000 --- a/src/VecSim/spaces/IP/IP_AVX_FP32.h +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/AVX_utils.h" - -static inline void InnerProductStep(float *&pVect1, float *&pVect2, __m256 &sum256) { - __m256 v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - __m256 v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); -} - -template // 0..15 -float FP32_InnerProductSIMD16_AVX(const void *pVect1v, const void *pVect2v, size_t dimension) { - float *pVect1 = (float *)pVect1v; - float *pVect2 = (float *)pVect2v; - - const float *pEnd1 = pVect1 + dimension; - - __m256 sum256 = _mm256_setzero_ps(); - - // Deal with 1-7 floats with mask loading, if needed. `dim` is >16, so we have at least one - // 16-float block, so mask loading is guaranteed to be safe. - if constexpr (residual % 8) { - __mmask8 constexpr mask = (1 << (residual % 8)) - 1; - __m256 v1 = my_mm256_maskz_loadu_ps(pVect1); - pVect1 += residual % 8; - __m256 v2 = my_mm256_maskz_loadu_ps(pVect2); - pVect2 += residual % 8; - sum256 = _mm256_mul_ps(v1, v2); - } - - // If the reminder is >=8, have another step of 8 floats - if constexpr (residual >= 8) { - InnerProductStep(pVect1, pVect2, sum256); - } - - // We dealt with the residual part. We are left with some multiple of 16 floats. - // In each iteration we calculate 16 floats = 512 bits. 
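Like the other kernels in this change, this function is templated on residual = dim % 16, so one instantiation exists per possible remainder. A hedged sketch of how such kernels are typically bound once per index; the dispatch helper and table are hypothetical, not the library's actual selection code:

#include <cstddef>

using dist_fn_f32 = float (*)(const void *, const void *, size_t);

// Hypothetical residual-dispatch table: pick the instantiation once, at index
// construction, so the hot path carries no per-call branch on dim % 16.
static float fp32_ip_avx_dispatch(const void *a, const void *b, size_t dim) {
    static constexpr dist_fn_f32 table[16] = {
        FP32_InnerProductSIMD16_AVX<0>,  FP32_InnerProductSIMD16_AVX<1>,
        FP32_InnerProductSIMD16_AVX<2>,  FP32_InnerProductSIMD16_AVX<3>,
        FP32_InnerProductSIMD16_AVX<4>,  FP32_InnerProductSIMD16_AVX<5>,
        FP32_InnerProductSIMD16_AVX<6>,  FP32_InnerProductSIMD16_AVX<7>,
        FP32_InnerProductSIMD16_AVX<8>,  FP32_InnerProductSIMD16_AVX<9>,
        FP32_InnerProductSIMD16_AVX<10>, FP32_InnerProductSIMD16_AVX<11>,
        FP32_InnerProductSIMD16_AVX<12>, FP32_InnerProductSIMD16_AVX<13>,
        FP32_InnerProductSIMD16_AVX<14>, FP32_InnerProductSIMD16_AVX<15>,
    };
    return table[dim % 16](a, b, dim);
}

Resolving the residual at index-construction time removes a per-call branch from the distance hot path.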
- do { - InnerProductStep(pVect1, pVect2, sum256); - InnerProductStep(pVect1, pVect2, sum256); - } while (pVect1 < pEnd1); - - return 1.0f - my_mm256_reduce_add_ps(sum256); -} diff --git a/src/VecSim/spaces/IP/IP_AVX_FP64.h b/src/VecSim/spaces/IP/IP_AVX_FP64.h deleted file mode 100644 index b570e6b61..000000000 --- a/src/VecSim/spaces/IP/IP_AVX_FP64.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/AVX_utils.h" - -static inline void InnerProductStep(double *&pVect1, double *&pVect2, __m256d &sum256) { - __m256d v1 = _mm256_loadu_pd(pVect1); - pVect1 += 4; - __m256d v2 = _mm256_loadu_pd(pVect2); - pVect2 += 4; - sum256 = _mm256_add_pd(sum256, _mm256_mul_pd(v1, v2)); -} - -template // 0..7 -double FP64_InnerProductSIMD8_AVX(const void *pVect1v, const void *pVect2v, size_t dimension) { - double *pVect1 = (double *)pVect1v; - double *pVect2 = (double *)pVect2v; - - const double *pEnd1 = pVect1 + dimension; - - __m256d sum256 = _mm256_setzero_pd(); - - // Deal with 1-3 doubles with mask loading, if needed. `dim` is >8, so we have at least one - // 8-double block, so mask loading is guaranteed to be safe. - if constexpr (residual % 4) { - // _mm256_maskz_loadu_pd is not available in AVX - __mmask8 constexpr mask = (1 << (residual % 4)) - 1; - __m256d v1 = my_mm256_maskz_loadu_pd(pVect1); - pVect1 += residual % 4; - __m256d v2 = my_mm256_maskz_loadu_pd(pVect2); - pVect2 += residual % 4; - sum256 = _mm256_mul_pd(v1, v2); - } - - // If the reminder is >=4, have another step of 4 doubles - if constexpr (residual >= 4) { - InnerProductStep(pVect1, pVect2, sum256); - } - - // We dealt with the residual part. We are left with some multiple of 8 doubles. - // In each iteration we calculate 8 doubles = 512 bits. - do { - InnerProductStep(pVect1, pVect2, sum256); - InnerProductStep(pVect1, pVect2, sum256); - } while (pVect1 < pEnd1); - - double PORTABLE_ALIGN32 TmpRes[4]; - _mm256_store_pd(TmpRes, sum256); - double sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - return 1.0 - sum; -} diff --git a/src/VecSim/spaces/IP/IP_F16C_FP16.h b/src/VecSim/spaces/IP/IP_F16C_FP16.h deleted file mode 100644 index a6f2ec0f4..000000000 --- a/src/VecSim/spaces/IP/IP_F16C_FP16.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/AVX_utils.h" -#include "VecSim/types/float16.h" - -using float16 = vecsim_types::float16; - -static void InnerProductStep(float16 *&pVect1, float16 *&pVect2, __m256 &sum) { - // Convert 8 half-floats into floats and store them in 256 bits register. 
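// F16C provides only FP16<->FP32 conversion, with no half-precision arithmetic,
// so each 128-bit load of 8 half floats is widened to 8 FP32 lanes and the
// multiply-accumulate runs in single precision.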
- auto v1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i_u const *)(pVect1))); - auto v2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i_u const *)(pVect2))); - - // sum = v1 * v2 + sum - sum = _mm256_fmadd_ps(v1, v2, sum); - pVect1 += 8; - pVect2 += 8; -} - -template // 0..31 -float FP16_InnerProductSIMD32_F16C(const void *pVect1v, const void *pVect2v, size_t dimension) { - auto *pVect1 = (float16 *)pVect1v; - auto *pVect2 = (float16 *)pVect2v; - - const float16 *pEnd1 = pVect1 + dimension; - - auto sum = _mm256_setzero_ps(); - - if constexpr (residual % 8) { - // Deal with remainder first. `dim` is more than 32, so we have at least one block of 32 - // 16-bit float so mask loading is guaranteed to be safe. - __mmask16 constexpr residuals_mask = (1 << (residual % 8)) - 1; - // Convert the first 8 half-floats into floats and store them 256 bits register, - // where the floats in the positions corresponding to residuals are zeros. - auto v1 = _mm256_blend_ps(_mm256_setzero_ps(), - _mm256_cvtph_ps(_mm_loadu_si128((__m128i_u const *)pVect1)), - residuals_mask); - auto v2 = _mm256_blend_ps(_mm256_setzero_ps(), - _mm256_cvtph_ps(_mm_loadu_si128((__m128i_u const *)pVect2)), - residuals_mask); - sum = _mm256_mul_ps(v1, v2); - pVect1 += residual % 8; - pVect2 += residual % 8; - } - if constexpr (residual >= 8 && residual < 16) { - InnerProductStep(pVect1, pVect2, sum); - } else if constexpr (residual >= 16 && residual < 24) { - InnerProductStep(pVect1, pVect2, sum); - InnerProductStep(pVect1, pVect2, sum); - } else if constexpr (residual >= 24) { - InnerProductStep(pVect1, pVect2, sum); - InnerProductStep(pVect1, pVect2, sum); - InnerProductStep(pVect1, pVect2, sum); - } - - // We dealt with the residual part. We are left with some multiple of 32 16-bit floats. - // In every iteration we process 4 chunk of 128bit (32 FP16) - do { - InnerProductStep(pVect1, pVect2, sum); - InnerProductStep(pVect1, pVect2, sum); - InnerProductStep(pVect1, pVect2, sum); - InnerProductStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - return 1.0f - my_mm256_reduce_add_ps(sum); -} diff --git a/src/VecSim/spaces/IP/IP_NEON_BF16.h b/src/VecSim/spaces/IP/IP_NEON_BF16.h deleted file mode 100644 index 4e08865e2..000000000 --- a/src/VecSim/spaces/IP/IP_NEON_BF16.h +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include - -inline void InnerProduct_Step(const bfloat16_t *&vec1, const bfloat16_t *&vec2, float32x4_t &acc) { - // Load brain-half-precision vectors - bfloat16x8_t v1 = vld1q_bf16(vec1); - bfloat16x8_t v2 = vld1q_bf16(vec2); - vec1 += 8; - vec2 += 8; - // Compute multiplications and add to the accumulator - acc = vbfdotq_f32(acc, v1, v2); -} - -template // 0..31 -float BF16_InnerProduct_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - const auto *vec1 = static_cast(pVect1v); - const auto *vec2 = static_cast(pVect2v); - const auto *const v1End = vec1 + dimension; - float32x4_t acc1 = vdupq_n_f32(0.0f); - float32x4_t acc2 = vdupq_n_f32(0.0f); - float32x4_t acc3 = vdupq_n_f32(0.0f); - float32x4_t acc4 = vdupq_n_f32(0.0f); - - // First, handle the partial chunk residual - if constexpr (residual % 8) { - auto constexpr chunk_residual = residual % 8; - // TODO: special cases for some residuals and benchmark if its better - constexpr uint16x8_t mask = { - 0xFFFF, - (chunk_residual >= 2) ? 0xFFFF : 0, - (chunk_residual >= 3) ? 0xFFFF : 0, - (chunk_residual >= 4) ? 0xFFFF : 0, - (chunk_residual >= 5) ? 0xFFFF : 0, - (chunk_residual >= 6) ? 0xFFFF : 0, - (chunk_residual >= 7) ? 0xFFFF : 0, - 0, - }; - - // Load partial vectors - bfloat16x8_t v1 = vld1q_bf16(vec1); - bfloat16x8_t v2 = vld1q_bf16(vec2); - - // Apply mask to both vectors - bfloat16x8_t masked_v1 = - vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v1), mask)); - bfloat16x8_t masked_v2 = - vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v2), mask)); - - acc1 = vbfdotq_f32(acc1, masked_v1, masked_v2); - - // Advance pointers - vec1 += chunk_residual; - vec2 += chunk_residual; - } - - // Handle (residual - (residual % 8)) in chunks of 8 bfloat16 - if constexpr (residual >= 8) - InnerProduct_Step(vec1, vec2, acc2); - if constexpr (residual >= 16) - InnerProduct_Step(vec1, vec2, acc3); - if constexpr (residual >= 24) - InnerProduct_Step(vec1, vec2, acc4); - - // Process the rest of the vectors (the full chunks part) - while (vec1 < v1End) { - // TODO: use `vld1q_f16_x4` for quad-loading? - InnerProduct_Step(vec1, vec2, acc1); - InnerProduct_Step(vec1, vec2, acc2); - InnerProduct_Step(vec1, vec2, acc3); - InnerProduct_Step(vec1, vec2, acc4); - } - - // Accumulate accumulators - acc1 = vpaddq_f32(acc1, acc3); - acc2 = vpaddq_f32(acc2, acc4); - acc1 = vpaddq_f32(acc1, acc2); - - // Pairwise add to get horizontal sum - float32x2_t folded = vadd_f32(vget_low_f32(acc1), vget_high_f32(acc1)); - folded = vpadd_f32(folded, folded); - - // Extract result - return 1.0f - vget_lane_f32(folded, 0); -} diff --git a/src/VecSim/spaces/IP/IP_NEON_DOTPROD_INT8.h b/src/VecSim/spaces/IP/IP_NEON_DOTPROD_INT8.h deleted file mode 100644 index 9fbc2b28d..000000000 --- a/src/VecSim/spaces/IP/IP_NEON_DOTPROD_INT8.h +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/spaces/space_includes.h" -#include - -__attribute__((always_inline)) static inline void InnerProductOp(int8x16_t &v1, int8x16_t &v2, - int32x4_t &sum) { - sum = vdotq_s32(sum, v1, v2); -} - -__attribute__((always_inline)) static inline void InnerProductStep(int8_t *&pVect1, int8_t *&pVect2, - int32x4_t &sum) { - // Load 16 int8 elements (16 bytes) into NEON registers - int8x16_t v1 = vld1q_s8(pVect1); - int8x16_t v2 = vld1q_s8(pVect2); - - InnerProductOp(v1, v2, sum); - - pVect1 += 16; - pVect2 += 16; -} - -template // 0..63 -float INT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) { - int8_t *pVect1 = (int8_t *)pVect1v; - int8_t *pVect2 = (int8_t *)pVect2v; - - // Initialize multiple sum accumulators for better parallelism - int32x4_t sum0 = vdupq_n_s32(0); - int32x4_t sum1 = vdupq_n_s32(0); - - constexpr size_t final_residual = residual % 16; - if constexpr (final_residual > 0) { - // Define a compile-time constant mask based on final_residual - constexpr uint8x16_t mask = { - 0xFF, - (final_residual >= 2) ? 0xFF : 0, - (final_residual >= 3) ? 0xFF : 0, - (final_residual >= 4) ? 0xFF : 0, - (final_residual >= 5) ? 0xFF : 0, - (final_residual >= 6) ? 0xFF : 0, - (final_residual >= 7) ? 0xFF : 0, - (final_residual >= 8) ? 0xFF : 0, - (final_residual >= 9) ? 0xFF : 0, - (final_residual >= 10) ? 0xFF : 0, - (final_residual >= 11) ? 0xFF : 0, - (final_residual >= 12) ? 0xFF : 0, - (final_residual >= 13) ? 0xFF : 0, - (final_residual >= 14) ? 0xFF : 0, - (final_residual >= 15) ? 0xFF : 0, - 0, - }; - - // Load data directly from input vectors - int8x16_t v1 = vld1q_s8(pVect1); - int8x16_t v2 = vld1q_s8(pVect2); - - // Zero vector for replacement - int8x16_t zeros = vdupq_n_s8(0); - - // Apply bit select to zero out irrelevant elements - v1 = vbslq_s8(mask, v1, zeros); - v2 = vbslq_s8(mask, v2, zeros); - InnerProductOp(v1, v2, sum0); - pVect1 += final_residual; - pVect2 += final_residual; - } - - // Process 64 elements at a time in the main loop - const size_t num_of_chunks = dimension / 64; - - for (size_t i = 0; i < num_of_chunks; i++) { - InnerProductStep(pVect1, pVect2, sum0); - InnerProductStep(pVect1, pVect2, sum1); - InnerProductStep(pVect1, pVect2, sum0); - InnerProductStep(pVect1, pVect2, sum1); - } - - constexpr size_t residual_chunks = residual / 16; - - if constexpr (residual_chunks > 0) { - if constexpr (residual_chunks >= 1) { - InnerProductStep(pVect1, pVect2, sum0); - } - if constexpr (residual_chunks >= 2) { - InnerProductStep(pVect1, pVect2, sum1); - } - if constexpr (residual_chunks >= 3) { - InnerProductStep(pVect1, pVect2, sum0); - } - } - - // Combine all four sum registers - int32x4_t total_sum = vaddq_s32(sum0, sum1); - // Horizontal sum of the 4 elements in the combined sum register - int32_t result = vaddvq_s32(total_sum); - - return static_cast(result); -} - -template // 0..63 -float INT8_InnerProductSIMD16_NEON_DOTPROD(const void *pVect1v, const void *pVect2v, - size_t dimension) { - return 1.0f - INT8_InnerProductImp(pVect1v, pVect2v, dimension); -} - -template // 0..63 -float INT8_CosineSIMD_NEON_DOTPROD(const void *pVect1v, const void *pVect2v, size_t dimension) { - float ip = INT8_InnerProductImp(pVect1v, pVect2v, dimension); - float norm_v1 = - *reinterpret_cast(static_cast(pVect1v) + dimension); - float norm_v2 = - *reinterpret_cast(static_cast(pVect2v) + dimension); - return 1.0f - ip / (norm_v1 * norm_v2); -} diff --git a/src/VecSim/spaces/IP/IP_NEON_DOTPROD_SQ8_SQ8.h 
b/src/VecSim/spaces/IP/IP_NEON_DOTPROD_SQ8_SQ8.h deleted file mode 100644 index 7a122974f..000000000 --- a/src/VecSim/spaces/IP/IP_NEON_DOTPROD_SQ8_SQ8.h +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/IP/IP_NEON_DOTPROD_UINT8.h" -#include "VecSim/types/sq8.h" -#include - -using sq8 = vecsim_types::sq8; - -/** - * SQ8-to-SQ8 distance functions using ARM NEON DOTPROD with precomputed sum. - * These functions compute distance between two SQ8 (scalar quantized 8-bit) vectors, - * where BOTH vectors are uint8 quantized. - * - * Uses precomputed sum stored in the vector data, - * eliminating the need to compute them during distance calculation. - * - * Uses algebraic optimization with DOTPROD instruction: - * - * With sum = Σv[i] (sum of original float values), the formula is: - * IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1[i]*q2[i]) - dim*min1*min2 - * - * Since sum is precomputed, we only need to compute the dot product Σ(q1[i]*q2[i]). - * The dot product is computed using the efficient UINT8_InnerProductImp which uses - * the DOTPROD instruction (vdotq_u32) for native uint8 dot product computation. - * - * Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] - */ - -// Common implementation for inner product between two SQ8 vectors with precomputed sum -// Uses UINT8_InnerProductImp for efficient dot product computation with DOTPROD -template // 0..63 -float SQ8_SQ8_InnerProductSIMD64_NEON_DOTPROD_IMP(const void *pVec1v, const void *pVec2v, - size_t dimension) { - // Compute raw dot product using efficient UINT8 DOTPROD implementation - // UINT8_InnerProductImp uses vdotq_u32 for native uint8 dot product - float dot_product = UINT8_InnerProductImp(pVec1v, pVec2v, dimension); - - // Get dequantization parameters and precomputed values from the end of vectors - // Layout: [data (dim)] [min (float)] [delta (float)] [sum (float)] - const uint8_t *pVec1 = static_cast(pVec1v); - const uint8_t *pVec2 = static_cast(pVec2v); - - const float *params1 = reinterpret_cast(pVec1 + dimension); - const float min1 = params1[sq8::MIN_VAL]; - const float delta1 = params1[sq8::DELTA]; - const float sum1 = params1[sq8::SUM]; // Precomputed sum of original float elements - - const float *params2 = reinterpret_cast(pVec2 + dimension); - const float min2 = params2[sq8::MIN_VAL]; - const float delta2 = params2[sq8::DELTA]; - const float sum2 = params2[sq8::SUM]; // Precomputed sum of original float elements - - // Apply algebraic formula using precomputed sums: - // IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1*q2) - dim*min1*min2 - return min1 * sum2 + min2 * sum1 + delta1 * delta2 * dot_product - - static_cast(dimension) * min1 * min2; -} - -// SQ8-to-SQ8 Inner Product distance function -// Returns 1 - inner_product (distance form) -template // 0..63 -float SQ8_SQ8_InnerProductSIMD64_NEON_DOTPROD(const void *pVec1v, const void *pVec2v, - size_t dimension) { - return 1.0f - SQ8_SQ8_InnerProductSIMD64_NEON_DOTPROD_IMP(pVec1v, pVec2v, dimension); -} - -// SQ8-to-SQ8 Cosine distance function -// Returns 1 - inner_product (assumes vectors are pre-normalized) -template // 0..63 -float SQ8_SQ8_CosineSIMD64_NEON_DOTPROD(const void *pVec1v, const void 
*pVec2v, size_t dimension) { - return SQ8_SQ8_InnerProductSIMD64_NEON_DOTPROD(pVec1v, pVec2v, dimension); -} diff --git a/src/VecSim/spaces/IP/IP_NEON_DOTPROD_UINT8.h b/src/VecSim/spaces/IP/IP_NEON_DOTPROD_UINT8.h deleted file mode 100644 index 73682a21a..000000000 --- a/src/VecSim/spaces/IP/IP_NEON_DOTPROD_UINT8.h +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include - -__attribute__((always_inline)) static inline void InnerProductOp(uint8x16_t &v1, uint8x16_t &v2, - uint32x4_t &sum) { - sum = vdotq_u32(sum, v1, v2); -} - -__attribute__((always_inline)) static inline void -InnerProductStep(uint8_t *&pVect1, uint8_t *&pVect2, uint32x4_t &sum) { - // Load 16 uint8 elements (16 bytes) into NEON registers - uint8x16_t v1 = vld1q_u8(pVect1); - uint8x16_t v2 = vld1q_u8(pVect2); - InnerProductOp(v1, v2, sum); - - pVect1 += 16; - pVect2 += 16; -} - -template // 0..63 -float UINT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) { - uint8_t *pVect1 = (uint8_t *)pVect1v; - uint8_t *pVect2 = (uint8_t *)pVect2v; - - // Initialize multiple sum accumulators for better parallelism - uint32x4_t sum0 = vdupq_n_u32(0); - uint32x4_t sum1 = vdupq_n_u32(0); - - constexpr size_t final_residual = residual % 16; - if constexpr (final_residual > 0) { - constexpr uint8x16_t mask = { - 0xFF, - (final_residual >= 2) ? 0xFF : 0, - (final_residual >= 3) ? 0xFF : 0, - (final_residual >= 4) ? 0xFF : 0, - (final_residual >= 5) ? 0xFF : 0, - (final_residual >= 6) ? 0xFF : 0, - (final_residual >= 7) ? 0xFF : 0, - (final_residual >= 8) ? 0xFF : 0, - (final_residual >= 9) ? 0xFF : 0, - (final_residual >= 10) ? 0xFF : 0, - (final_residual >= 11) ? 0xFF : 0, - (final_residual >= 12) ? 0xFF : 0, - (final_residual >= 13) ? 0xFF : 0, - (final_residual >= 14) ? 0xFF : 0, - (final_residual >= 15) ? 
0xFF : 0, - 0, - }; - - // Load data directly from input vectors - uint8x16_t v1 = vld1q_u8(pVect1); - uint8x16_t v2 = vld1q_u8(pVect2); - - // Zero vector for replacement - uint8x16_t zeros = vdupq_n_u8(0); - - // Apply bit select to zero out irrelevant elements - v1 = vbslq_u8(mask, v1, zeros); - v2 = vbslq_u8(mask, v2, zeros); - InnerProductOp(v1, v2, sum1); - pVect1 += final_residual; - pVect2 += final_residual; - } - - // Process 64 elements at a time in the main loop - const size_t num_of_chunks = dimension / 64; - - for (size_t i = 0; i < num_of_chunks; i++) { - InnerProductStep(pVect1, pVect2, sum0); - InnerProductStep(pVect1, pVect2, sum1); - InnerProductStep(pVect1, pVect2, sum0); - InnerProductStep(pVect1, pVect2, sum1); - } - - constexpr size_t residual_chunks = residual / 16; - - if constexpr (residual_chunks > 0) { - if constexpr (residual_chunks >= 1) { - InnerProductStep(pVect1, pVect2, sum0); - } - if constexpr (residual_chunks >= 2) { - InnerProductStep(pVect1, pVect2, sum1); - } - if constexpr (residual_chunks >= 3) { - InnerProductStep(pVect1, pVect2, sum0); - } - } - - uint32x4_t total_sum = vaddq_u32(sum0, sum1); - - int32_t result = vaddvq_u32(total_sum); - - return static_cast(result); -} - -template // 0..63 -float UINT8_InnerProductSIMD16_NEON_DOTPROD(const void *pVect1v, const void *pVect2v, - size_t dimension) { - return 1.0f - UINT8_InnerProductImp(pVect1v, pVect2v, dimension); -} - -template // 0..63 -float UINT8_CosineSIMD_NEON_DOTPROD(const void *pVect1v, const void *pVect2v, size_t dimension) { - float ip = UINT8_InnerProductImp(pVect1v, pVect2v, dimension); - float norm_v1 = - *reinterpret_cast(static_cast(pVect1v) + dimension); - float norm_v2 = - *reinterpret_cast(static_cast(pVect2v) + dimension); - return 1.0f - ip / (norm_v1 * norm_v2); -} diff --git a/src/VecSim/spaces/IP/IP_NEON_FP16.h b/src/VecSim/spaces/IP/IP_NEON_FP16.h deleted file mode 100644 index fd547c457..000000000 --- a/src/VecSim/spaces/IP/IP_NEON_FP16.h +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include - -inline void InnerProduct_Step(const float16_t *&vec1, const float16_t *&vec2, float16x8_t &acc) { - // Load half-precision vectors - float16x8_t v1 = vld1q_f16(vec1); - float16x8_t v2 = vld1q_f16(vec2); - vec1 += 8; - vec2 += 8; - - // Multiply and accumulate - acc = vfmaq_f16(acc, v1, v2); -} - -template // 0..31 -float FP16_InnerProduct_NEON_HP(const void *pVect1v, const void *pVect2v, size_t dimension) { - const auto *vec1 = static_cast(pVect1v); - const auto *vec2 = static_cast(pVect2v); - const auto *const v1End = vec1 + dimension; - float16x8_t acc1 = vdupq_n_f16(0.0f); - float16x8_t acc2 = vdupq_n_f16(0.0f); - float16x8_t acc3 = vdupq_n_f16(0.0f); - float16x8_t acc4 = vdupq_n_f16(0.0f); - - // First, handle the partial chunk residual - if constexpr (residual % 8) { - auto constexpr chunk_residual = residual % 8; - // TODO: spacial cases for some residuals and benchmark if its better - constexpr uint16x8_t mask = { - 0xFFFF, - (chunk_residual >= 2) ? 0xFFFF : 0, - (chunk_residual >= 3) ? 0xFFFF : 0, - (chunk_residual >= 4) ? 0xFFFF : 0, - (chunk_residual >= 5) ? 0xFFFF : 0, - (chunk_residual >= 6) ? 0xFFFF : 0, - (chunk_residual >= 7) ? 
0xFFFF : 0, - 0, - }; - - // Load partial vectors - float16x8_t v1 = vld1q_f16(vec1); - float16x8_t v2 = vld1q_f16(vec2); - - // Apply mask to both vectors - float16x8_t masked_v1 = vbslq_f16(mask, v1, acc1); // `acc1` should be all zeros here - float16x8_t masked_v2 = vbslq_f16(mask, v2, acc2); // `acc2` should be all zeros here - - // Multiply and accumulate - acc1 = vfmaq_f16(acc1, masked_v1, masked_v2); - - // Advance pointers - vec1 += chunk_residual; - vec2 += chunk_residual; - } - - // Handle (residual - (residual % 8)) in chunks of 8 float16 - if constexpr (residual >= 8) - InnerProduct_Step(vec1, vec2, acc2); - if constexpr (residual >= 16) - InnerProduct_Step(vec1, vec2, acc3); - if constexpr (residual >= 24) - InnerProduct_Step(vec1, vec2, acc4); - - // Process the rest of the vectors (the full chunks part) - while (vec1 < v1End) { - // TODO: use `vld1q_f16_x4` for quad-loading? - InnerProduct_Step(vec1, vec2, acc1); - InnerProduct_Step(vec1, vec2, acc2); - InnerProduct_Step(vec1, vec2, acc3); - InnerProduct_Step(vec1, vec2, acc4); - } - - // Accumulate accumulators - acc1 = vpaddq_f16(acc1, acc3); - acc2 = vpaddq_f16(acc2, acc4); - acc1 = vpaddq_f16(acc1, acc2); - - // Horizontal sum of the accumulated values - float32x4_t sum_f32 = vcvt_f32_f16(vget_low_f16(acc1)); - sum_f32 = vaddq_f32(sum_f32, vcvt_f32_f16(vget_high_f16(acc1))); - - // Pairwise add to get horizontal sum - float32x2_t sum_2 = vadd_f32(vget_low_f32(sum_f32), vget_high_f32(sum_f32)); - sum_2 = vpadd_f32(sum_2, sum_2); - - // Extract result - return 1.0f - vget_lane_f32(sum_2, 0); -} diff --git a/src/VecSim/spaces/IP/IP_NEON_FP32.h b/src/VecSim/spaces/IP/IP_NEON_FP32.h deleted file mode 100644 index 664a4ef6f..000000000 --- a/src/VecSim/spaces/IP/IP_NEON_FP32.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/spaces/space_includes.h" -#include - -static inline void InnerProductStep(float *&pVect1, float *&pVect2, float32x4_t &sum) { - float32x4_t v1 = vld1q_f32(pVect1); - float32x4_t v2 = vld1q_f32(pVect2); - sum = vmlaq_f32(sum, v1, v2); - pVect1 += 4; - pVect2 += 4; -} - -template // 0..15 -float FP32_InnerProductSIMD16_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - float *pVect1 = (float *)pVect1v; - float *pVect2 = (float *)pVect2v; - - float32x4_t sum0 = vdupq_n_f32(0.0f); - float32x4_t sum1 = vdupq_n_f32(0.0f); - float32x4_t sum2 = vdupq_n_f32(0.0f); - float32x4_t sum3 = vdupq_n_f32(0.0f); - - const size_t num_of_chunks = dimension / 16; - - for (size_t i = 0; i < num_of_chunks; i++) { - InnerProductStep(pVect1, pVect2, sum0); - InnerProductStep(pVect1, pVect2, sum1); - InnerProductStep(pVect1, pVect2, sum2); - InnerProductStep(pVect1, pVect2, sum3); - } - - // Handle remaining complete 4-float blocks within residual - constexpr size_t remaining_chunks = residual / 4; - - // Unrolled loop for the 4-float blocks - if constexpr (remaining_chunks >= 1) { - InnerProductStep(pVect1, pVect2, sum0); - } - if constexpr (remaining_chunks >= 2) { - InnerProductStep(pVect1, pVect2, sum1); - } - if constexpr (remaining_chunks >= 3) { - InnerProductStep(pVect1, pVect2, sum2); - } - - // Handle final residual elements (0-3 elements) - constexpr size_t final_residual = residual % 4; - if constexpr (final_residual > 0) { - float32x4_t v1 = vdupq_n_f32(0.0f); - float32x4_t v2 = vdupq_n_f32(0.0f); - - if constexpr (final_residual >= 1) { - v1 = vld1q_lane_f32(pVect1, v1, 0); - v2 = vld1q_lane_f32(pVect2, v2, 0); - } - if constexpr (final_residual >= 2) { - v1 = vld1q_lane_f32(pVect1 + 1, v1, 1); - v2 = vld1q_lane_f32(pVect2 + 1, v2, 1); - } - if constexpr (final_residual >= 3) { - v1 = vld1q_lane_f32(pVect1 + 2, v1, 2); - v2 = vld1q_lane_f32(pVect2 + 2, v2, 2); - } - - sum3 = vmlaq_f32(sum3, v1, v2); - } - - // Combine all four sum accumulators - float32x4_t sum_combined = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)); - - // Horizontal sum of the 4 elements in the combined NEON register - float32x2_t sum_halves = vadd_f32(vget_low_f32(sum_combined), vget_high_f32(sum_combined)); - float32x2_t summed = vpadd_f32(sum_halves, sum_halves); - float sum = vget_lane_f32(summed, 0); - - return 1.0f - sum; -} diff --git a/src/VecSim/spaces/IP/IP_NEON_FP64.h b/src/VecSim/spaces/IP/IP_NEON_FP64.h deleted file mode 100644 index 113963b24..000000000 --- a/src/VecSim/spaces/IP/IP_NEON_FP64.h +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/spaces/space_includes.h" -#include - -inline void InnerProductStep(double *&pVect1, double *&pVect2, float64x2_t &sum) { - float64x2_t v1 = vld1q_f64(pVect1); - float64x2_t v2 = vld1q_f64(pVect2); - sum = vmlaq_f64(sum, v1, v2); - pVect1 += 2; - pVect2 += 2; -} - -template // 0..7 -double FP64_InnerProductSIMD8_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - double *pVect1 = (double *)pVect1v; - double *pVect2 = (double *)pVect2v; - - float64x2_t sum0 = vdupq_n_f64(0.0); - float64x2_t sum1 = vdupq_n_f64(0.0); - float64x2_t sum2 = vdupq_n_f64(0.0); - float64x2_t sum3 = vdupq_n_f64(0.0); - - const size_t num_of_chunks = dimension / 8; - - for (size_t i = 0; i < num_of_chunks; i++) { - InnerProductStep(pVect1, pVect2, sum0); - InnerProductStep(pVect1, pVect2, sum1); - InnerProductStep(pVect1, pVect2, sum2); - InnerProductStep(pVect1, pVect2, sum3); - } - - // Handle remaining complete 2-float blocks within residual - constexpr size_t remaining_chunks = residual / 2; - // Unrolled loop for the 2-float blocks - if constexpr (remaining_chunks >= 1) { - InnerProductStep(pVect1, pVect2, sum0); - } - if constexpr (remaining_chunks >= 2) { - InnerProductStep(pVect1, pVect2, sum1); - } - if constexpr (remaining_chunks >= 3) { - InnerProductStep(pVect1, pVect2, sum2); - } - - // Handle final residual elements (0-1 elements) - // This entire block is eliminated at compile time if final_residual is 0 - constexpr size_t final_residual = residual % 2; // Final 0-1 elements - if constexpr (final_residual == 1) { - float64x2_t v1 = vdupq_n_f64(0.0); - float64x2_t v2 = vdupq_n_f64(0.0); - v1 = vld1q_lane_f64(pVect1, v1, 0); - v2 = vld1q_lane_f64(pVect2, v2, 0); - - sum3 = vmlaq_f64(sum3, v1, v2); - } - - float64x2_t sum_combined = vaddq_f64(vaddq_f64(sum0, sum1), vaddq_f64(sum2, sum3)); - - // Horizontal sum of the 4 elements in the NEON register - float64x1_t summed = vadd_f64(vget_low_f64(sum_combined), vget_high_f64(sum_combined)); - double sum = vget_lane_f64(summed, 0); - - return 1.0 - sum; -} diff --git a/src/VecSim/spaces/IP/IP_NEON_INT8.h b/src/VecSim/spaces/IP/IP_NEON_INT8.h deleted file mode 100644 index 5118908d6..000000000 --- a/src/VecSim/spaces/IP/IP_NEON_INT8.h +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/spaces/space_includes.h" -#include - -__attribute__((always_inline)) static inline void InnerProductOp(int8x16_t &v1, int8x16_t &v2, - int32x4_t &sum) { - // Multiply low 8 elements (first half) - int16x8_t prod_low = vmull_s8(vget_low_s8(v1), vget_low_s8(v2)); - - // Multiply high 8 elements (second half) using vmull_high_s8 - int16x8_t prod_high = vmull_high_s8(v1, v2); - - // Pairwise add adjacent elements to 32-bit accumulators - sum = vpadalq_s16(sum, prod_low); - sum = vpadalq_s16(sum, prod_high); -} - -__attribute__((always_inline)) static inline void InnerProductStep(int8_t *&pVect1, int8_t *&pVect2, - int32x4_t &sum) { - // Load 16 int8 elements (16 bytes) into NEON registers - int8x16_t v1 = vld1q_s8(pVect1); - int8x16_t v2 = vld1q_s8(pVect2); - InnerProductOp(v1, v2, sum); - - pVect1 += 16; - pVect2 += 16; -} - -template // 0..63 -float INT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) { - int8_t *pVect1 = (int8_t *)pVect1v; - int8_t *pVect2 = (int8_t *)pVect2v; - - // Initialize multiple sum accumulators for better parallelism - int32x4_t sum0 = vdupq_n_s32(0); - int32x4_t sum1 = vdupq_n_s32(0); - - constexpr size_t final_residual = residual % 16; - if constexpr (final_residual > 0) { - // Define a compile-time constant mask based on final_residual - constexpr uint8x16_t mask = { - 0xFF, - (final_residual >= 2) ? 0xFF : 0, - (final_residual >= 3) ? 0xFF : 0, - (final_residual >= 4) ? 0xFF : 0, - (final_residual >= 5) ? 0xFF : 0, - (final_residual >= 6) ? 0xFF : 0, - (final_residual >= 7) ? 0xFF : 0, - (final_residual >= 8) ? 0xFF : 0, - (final_residual >= 9) ? 0xFF : 0, - (final_residual >= 10) ? 0xFF : 0, - (final_residual >= 11) ? 0xFF : 0, - (final_residual >= 12) ? 0xFF : 0, - (final_residual >= 13) ? 0xFF : 0, - (final_residual >= 14) ? 0xFF : 0, - (final_residual >= 15) ? 
0xFF : 0, - 0, - }; - - // Load data directly from input vectors - int8x16_t v1 = vld1q_s8(pVect1); - int8x16_t v2 = vld1q_s8(pVect2); - - // Zero vector for replacement - int8x16_t zeros = vdupq_n_s8(0); - - // Apply bit select to zero out irrelevant elements - v1 = vbslq_s8(mask, v1, zeros); - v2 = vbslq_s8(mask, v2, zeros); - InnerProductOp(v1, v2, sum0); - pVect1 += final_residual; - pVect2 += final_residual; - } - - // Process 64 elements at a time in the main loop - const size_t num_of_chunks = dimension / 64; - - for (size_t i = 0; i < num_of_chunks; i++) { - InnerProductStep(pVect1, pVect2, sum0); - InnerProductStep(pVect1, pVect2, sum1); - InnerProductStep(pVect1, pVect2, sum0); - InnerProductStep(pVect1, pVect2, sum1); - } - - constexpr size_t residual_chunks = residual / 16; - - if constexpr (residual_chunks > 0) { - if constexpr (residual_chunks >= 1) { - InnerProductStep(pVect1, pVect2, sum0); - } - if constexpr (residual_chunks >= 2) { - InnerProductStep(pVect1, pVect2, sum1); - } - if constexpr (residual_chunks >= 3) { - InnerProductStep(pVect1, pVect2, sum0); - } - } - - // Combine all four sum registers - int32x4_t total_sum = vaddq_s32(sum0, sum1); - // Horizontal sum of the 4 elements in the combined sum register - int32_t result = vaddvq_s32(total_sum); - - return static_cast(result); -} - -template // 0..15 -float INT8_InnerProductSIMD16_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - return 1.0f - INT8_InnerProductImp(pVect1v, pVect2v, dimension); -} - -template // 0..63 -float INT8_CosineSIMD_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - float ip = INT8_InnerProductImp(pVect1v, pVect2v, dimension); - float norm_v1 = - *reinterpret_cast(static_cast(pVect1v) + dimension); - float norm_v2 = - *reinterpret_cast(static_cast(pVect2v) + dimension); - return 1.0f - ip / (norm_v1 * norm_v2); -} diff --git a/src/VecSim/spaces/IP/IP_NEON_SQ8_FP32.h b/src/VecSim/spaces/IP/IP_NEON_SQ8_FP32.h deleted file mode 100644 index 53a89bc7d..000000000 --- a/src/VecSim/spaces/IP/IP_NEON_SQ8_FP32.h +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/sq8.h" -#include -using sq8 = vecsim_types::sq8; - -/* - * Optimized asymmetric SQ8 inner product using algebraic identity: - * - * IP(x, y) = Σ(x_i * y_i) - * ≈ Σ((min + delta * q_i) * y_i) - * = min * Σy_i + delta * Σ(q_i * y_i) - * = min * y_sum + delta * quantized_dot_product - * - * where y_sum = Σy_i is precomputed and stored in the query blob. - * This avoids dequantization in the hot loop - we only compute Σ(q_i * y_i). - */ - -// Helper: compute Σ(q_i * y_i) for 4 elements (no dequantization) -// pVect1 = SQ8 storage (quantized values), pVect2 = FP32 query -static inline void InnerProductStepSQ8_FP32(const uint8_t *&pVect1, const float *&pVect2, - float32x4_t &sum) { - // Load 4 uint8 elements and convert to float - uint8x8_t v1_u8 = vld1_u8(pVect1); - pVect1 += 4; - - uint32x4_t v1_u32 = vmovl_u16(vget_low_u16(vmovl_u8(v1_u8))); - float32x4_t v1_f = vcvtq_f32_u32(v1_u32); - - // Load 4 float elements from query - float32x4_t v2 = vld1q_f32(pVect2); - pVect2 += 4; - - // Accumulate q_i * y_i (no dequantization!) 
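The algebraic identity quoted in the header above is easy to verify in scalar form. A small sketch with hypothetical helper names, showing that the dequantize-in-the-loop form and the factored form agree up to float rounding:

    #include <cstddef>
    #include <cstdint>

    // Form 1: dequantize every element inside the loop.
    static float ip_dequantized(const uint8_t *q, const float *y, size_t dim,
                                float min_val, float delta) {
        float sum = 0.0f;
        for (size_t i = 0; i < dim; i++)
            sum += (min_val + delta * q[i]) * y[i];
        return sum;
    }

    // Form 2: min * y_sum + delta * sum(q_i * y_i); y_sum is precomputed in the real query blob.
    static float ip_factored(const uint8_t *q, const float *y, size_t dim,
                             float min_val, float delta) {
        float y_sum = 0.0f, qdot = 0.0f;
        for (size_t i = 0; i < dim; i++) {
            y_sum += y[i];
            qdot += static_cast<float>(q[i]) * y[i]; // the only work left in the hot loop
        }
        return min_val * y_sum + delta * qdot;
    }
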
- sum = vmlaq_f32(sum, v1_f, v2); -} - -// pVect1v = SQ8 storage, pVect2v = FP32 query -template // 0..15 -float SQ8_FP32_InnerProductSIMD16_NEON_IMP(const void *pVect1v, const void *pVect2v, - size_t dimension) { - const uint8_t *pVect1 = static_cast(pVect1v); // SQ8 storage - const float *pVect2 = static_cast(pVect2v); // FP32 query - - // Multiple accumulators for ILP - float32x4_t sum0 = vdupq_n_f32(0.0f); - float32x4_t sum1 = vdupq_n_f32(0.0f); - float32x4_t sum2 = vdupq_n_f32(0.0f); - float32x4_t sum3 = vdupq_n_f32(0.0f); - - const size_t num_of_chunks = dimension / 16; - - // Process 16 elements at a time in the main loop - for (size_t i = 0; i < num_of_chunks; i++) { - InnerProductStepSQ8_FP32(pVect1, pVect2, sum0); - InnerProductStepSQ8_FP32(pVect1, pVect2, sum1); - InnerProductStepSQ8_FP32(pVect1, pVect2, sum2); - InnerProductStepSQ8_FP32(pVect1, pVect2, sum3); - } - - // Handle remaining complete 4-element blocks within residual - if constexpr (residual >= 4) { - InnerProductStepSQ8_FP32(pVect1, pVect2, sum0); - } - if constexpr (residual >= 8) { - InnerProductStepSQ8_FP32(pVect1, pVect2, sum1); - } - if constexpr (residual >= 12) { - InnerProductStepSQ8_FP32(pVect1, pVect2, sum2); - } - - // Handle final residual elements (0-3 elements) - constexpr size_t final_residual = residual % 4; - if constexpr (final_residual > 0) { - float32x4_t v1_f = vdupq_n_f32(0.0f); - float32x4_t v2 = vdupq_n_f32(0.0f); - - if constexpr (final_residual >= 1) { - float q0 = static_cast(pVect1[0]); - v1_f = vld1q_lane_f32(&q0, v1_f, 0); - v2 = vld1q_lane_f32(pVect2, v2, 0); - } - if constexpr (final_residual >= 2) { - float q1 = static_cast(pVect1[1]); - v1_f = vld1q_lane_f32(&q1, v1_f, 1); - v2 = vld1q_lane_f32(pVect2 + 1, v2, 1); - } - if constexpr (final_residual >= 3) { - float q2 = static_cast(pVect1[2]); - v1_f = vld1q_lane_f32(&q2, v1_f, 2); - v2 = vld1q_lane_f32(pVect2 + 2, v2, 2); - } - - // Compute q_i * y_i (no dequantization) - sum3 = vmlaq_f32(sum3, v1_f, v2); - } - - // Combine all four sum accumulators - float32x4_t sum_combined = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)); - - // Horizontal sum to get Σ(q_i * y_i) - float32x2_t sum_halves = vadd_f32(vget_low_f32(sum_combined), vget_high_f32(sum_combined)); - float32x2_t summed = vpadd_f32(sum_halves, sum_halves); - float quantized_dot = vget_lane_f32(summed, 0); - - // Get quantization parameters from stored vector (after quantized data) - const uint8_t *pVect1Base = static_cast(pVect1v); - const float *params1 = reinterpret_cast(pVect1Base + dimension); - const float min_val = params1[sq8::MIN_VAL]; - const float delta = params1[sq8::DELTA]; - - // Get precomputed y_sum from query blob (stored after the dim floats) - const float y_sum = static_cast(pVect2v)[dimension + sq8::SUM_QUERY]; - - // Apply the algebraic formula: IP = min * y_sum + delta * Σ(q_i * y_i) - return min_val * y_sum + delta * quantized_dot; -} - -template // 0..15 -float SQ8_FP32_InnerProductSIMD16_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - return 1.0f - SQ8_FP32_InnerProductSIMD16_NEON_IMP(pVect1v, pVect2v, dimension); -} - -template // 0..15 -float SQ8_FP32_CosineSIMD16_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - // Cosine distance = 1 - IP (vectors are pre-normalized) - return SQ8_FP32_InnerProductSIMD16_NEON(pVect1v, pVect2v, dimension); -} diff --git a/src/VecSim/spaces/IP/IP_NEON_SQ8_SQ8.h b/src/VecSim/spaces/IP/IP_NEON_SQ8_SQ8.h deleted file mode 100644 index b89586322..000000000 --- 
a/src/VecSim/spaces/IP/IP_NEON_SQ8_SQ8.h +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/IP/IP_NEON_UINT8.h" -#include "VecSim/types/sq8.h" -#include - -using sq8 = vecsim_types::sq8; - -/** - * SQ8-to-SQ8 distance functions using ARM NEON with precomputed sum. - * These functions compute distance between two SQ8 (scalar quantized 8-bit) vectors, - * where BOTH vectors are uint8 quantized. - * - * Uses precomputed sum stored in the vector data, - * eliminating the need to compute them during distance calculation. - * - * Uses algebraic optimization: - * - * With sum = Σv[i] (sum of original float values), the formula is: - * IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1[i]*q2[i]) - dim*min1*min2 - * - * Since sum is precomputed, we only need to compute the dot product Σ(q1[i]*q2[i]). - * The dot product is computed using the efficient UINT8_InnerProductImp which uses - * native NEON uint8 multiply-accumulate instructions (vmull_u8, vpadalq_u16). - * - * Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] - */ - -// Common implementation for inner product between two SQ8 vectors with precomputed sum -// Uses UINT8_InnerProductImp for efficient dot product computation -template // 0..63 -float SQ8_SQ8_InnerProductSIMD64_NEON_IMP(const void *pVec1v, const void *pVec2v, - size_t dimension) { - // Compute raw dot product using efficient UINT8 implementation - // UINT8_InnerProductImp processes 16 elements at a time using native uint8 instructions - float dot_product = UINT8_InnerProductImp(pVec1v, pVec2v, dimension); - - // Get dequantization parameters and precomputed values from the end of pVec1 - // Layout: [data (dim)] [min (float)] [delta (float)] [sum (float)] - const uint8_t *pVec1 = static_cast(pVec1v); - const uint8_t *pVec2 = static_cast(pVec2v); - - const float *params1 = reinterpret_cast(pVec1 + dimension); - const float min1 = params1[sq8::MIN_VAL]; - const float delta1 = params1[sq8::DELTA]; - const float sum1 = params1[sq8::SUM]; // Precomputed sum of original float elements - - // Get dequantization parameters and precomputed values from the end of pVec2 - const float *params2 = reinterpret_cast(pVec2 + dimension); - const float min2 = params2[sq8::MIN_VAL]; - const float delta2 = params2[sq8::DELTA]; - const float sum2 = params2[sq8::SUM]; // Precomputed sum of original float elements - - // Apply algebraic formula using precomputed sums: - // IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1*q2) - dim*min1*min2 - return min1 * sum2 + min2 * sum1 + delta1 * delta2 * dot_product - - static_cast(dimension) * min1 * min2; -} - -// SQ8-to-SQ8 Inner Product distance function -// Returns 1 - inner_product (distance form) -template // 0..63 -float SQ8_SQ8_InnerProductSIMD64_NEON(const void *pVec1v, const void *pVec2v, size_t dimension) { - return 1.0f - SQ8_SQ8_InnerProductSIMD64_NEON_IMP(pVec1v, pVec2v, dimension); -} - -// SQ8-to-SQ8 Cosine distance function -// Returns 1 - inner_product (assumes vectors are pre-normalized) -template // 0..63 -float SQ8_SQ8_CosineSIMD64_NEON(const void *pVec1v, const void *pVec2v, size_t dimension) { - return SQ8_SQ8_InnerProductSIMD64_NEON(pVec1v, pVec2v, dimension); -} diff --git 
a/src/VecSim/spaces/IP/IP_NEON_UINT8.h b/src/VecSim/spaces/IP/IP_NEON_UINT8.h deleted file mode 100644 index 6263eeea4..000000000 --- a/src/VecSim/spaces/IP/IP_NEON_UINT8.h +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include - -__attribute__((always_inline)) static inline void InnerProductOp(uint8x16_t &v1, uint8x16_t &v2, - uint32x4_t &sum) { - // Multiply and accumulate low 8 elements (first half) - uint16x8_t prod_low = vmull_u8(vget_low_u8(v1), vget_low_u8(v2)); - - // Multiply and accumulate high 8 elements (second half) - uint16x8_t prod_high = vmull_u8(vget_high_u8(v1), vget_high_u8(v2)); - - // Pairwise add adjacent elements to 32-bit accumulators - sum = vpadalq_u16(sum, prod_low); - sum = vpadalq_u16(sum, prod_high); -} - -__attribute__((always_inline)) static inline void -InnerProductStep(uint8_t *&pVect1, uint8_t *&pVect2, uint32x4_t &sum) { - // Load 16 uint8 elements (16 bytes) into NEON registers - uint8x16_t v1 = vld1q_u8(pVect1); - uint8x16_t v2 = vld1q_u8(pVect2); - InnerProductOp(v1, v2, sum); - - pVect1 += 16; - pVect2 += 16; -} - -template // 0..63 -float UINT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) { - uint8_t *pVect1 = (uint8_t *)pVect1v; - uint8_t *pVect2 = (uint8_t *)pVect2v; - - // Initialize multiple sum accumulators for better parallelism - uint32x4_t sum0 = vdupq_n_u32(0); - uint32x4_t sum1 = vdupq_n_u32(0); - - constexpr size_t final_residual = residual % 16; - if constexpr (final_residual > 0) { - constexpr uint8x16_t mask = { - 0xFF, - (final_residual >= 2) ? 0xFF : 0, - (final_residual >= 3) ? 0xFF : 0, - (final_residual >= 4) ? 0xFF : 0, - (final_residual >= 5) ? 0xFF : 0, - (final_residual >= 6) ? 0xFF : 0, - (final_residual >= 7) ? 0xFF : 0, - (final_residual >= 8) ? 0xFF : 0, - (final_residual >= 9) ? 0xFF : 0, - (final_residual >= 10) ? 0xFF : 0, - (final_residual >= 11) ? 0xFF : 0, - (final_residual >= 12) ? 0xFF : 0, - (final_residual >= 13) ? 0xFF : 0, - (final_residual >= 14) ? 0xFF : 0, - (final_residual >= 15) ? 
0xFF : 0, - 0, - }; - - // Load data directly from input vectors - uint8x16_t v1 = vld1q_u8(pVect1); - uint8x16_t v2 = vld1q_u8(pVect2); - - // Zero vector for replacement - uint8x16_t zeros = vdupq_n_u8(0); - - // Apply bit select to zero out irrelevant elements - v1 = vbslq_u8(mask, v1, zeros); - v2 = vbslq_u8(mask, v2, zeros); - InnerProductOp(v1, v2, sum1); - pVect1 += final_residual; - pVect2 += final_residual; - } - - // Process 64 elements at a time in the main loop - const size_t num_of_chunks = dimension / 64; - - for (size_t i = 0; i < num_of_chunks; i++) { - InnerProductStep(pVect1, pVect2, sum0); - InnerProductStep(pVect1, pVect2, sum1); - InnerProductStep(pVect1, pVect2, sum0); - InnerProductStep(pVect1, pVect2, sum1); - } - - constexpr size_t residual_chunks = residual / 16; - - if constexpr (residual_chunks > 0) { - if constexpr (residual_chunks >= 1) { - InnerProductStep(pVect1, pVect2, sum0); - } - if constexpr (residual_chunks >= 2) { - InnerProductStep(pVect1, pVect2, sum1); - } - if constexpr (residual_chunks >= 3) { - InnerProductStep(pVect1, pVect2, sum0); - } - } - - uint32x4_t total_sum = vaddq_u32(sum0, sum1); - - // Horizontal sum of the 4 elements in the combined sum register - int32_t result = vaddvq_u32(total_sum); - - return static_cast(result); -} - -template // 0..15 -float UINT8_InnerProductSIMD16_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - return 1.0f - UINT8_InnerProductImp(pVect1v, pVect2v, dimension); -} - -template // 0..63 -float UINT8_CosineSIMD_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - float ip = UINT8_InnerProductImp(pVect1v, pVect2v, dimension); - float norm_v1 = - *reinterpret_cast(static_cast(pVect1v) + dimension); - float norm_v2 = - *reinterpret_cast(static_cast(pVect2v) + dimension); - return 1.0f - ip / (norm_v1 * norm_v2); -} diff --git a/src/VecSim/spaces/IP/IP_SSE3_BF16.h b/src/VecSim/spaces/IP/IP_SSE3_BF16.h deleted file mode 100644 index 3ad511bdd..000000000 --- a/src/VecSim/spaces/IP/IP_SSE3_BF16.h +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
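Like the int8 variant above it, the uint8 kernel returns the raw integer dot product and lets the cosine wrapper read a precomputed norm stored right after the payload. A scalar sketch of that path (hypothetical name, assuming the [dim bytes][float norm] layout the code above reads):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    static float uint8_cosine_ref(const uint8_t *a, const uint8_t *b, size_t dim) {
        int32_t dot = 0;
        for (size_t i = 0; i < dim; i++)
            dot += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
        float norm_a, norm_b; // each vector stores its norm immediately after the data
        std::memcpy(&norm_a, a + dim, sizeof(float));
        std::memcpy(&norm_b, b + dim, sizeof(float));
        return 1.0f - static_cast<float>(dot) / (norm_a * norm_b);
    }
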
- */ -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/bfloat16.h" - -using bfloat16 = vecsim_types::bfloat16; - -static inline void InnerProductLowHalfStep(__m128i v1, __m128i v2, __m128i zeros, - __m128 &sum_prod) { - // Convert next 0..3 bf16 to 4 floats - __m128i bf16_low1 = _mm_unpacklo_epi16(zeros, v1); // SSE2 - __m128i bf16_low2 = _mm_unpacklo_epi16(zeros, v2); - - sum_prod = - _mm_add_ps(sum_prod, _mm_mul_ps(_mm_castsi128_ps(bf16_low1), _mm_castsi128_ps(bf16_low2))); -} - -static inline void InnerProductHighHalfStep(__m128i v1, __m128i v2, __m128i zeros, - __m128 &sum_prod) { - // Convert next 4..7 bf16 to 4 floats - __m128i bf16_high1 = _mm_unpackhi_epi16(zeros, v1); - __m128i bf16_high2 = _mm_unpackhi_epi16(zeros, v2); - - sum_prod = _mm_add_ps(sum_prod, - _mm_mul_ps(_mm_castsi128_ps(bf16_high1), _mm_castsi128_ps(bf16_high2))); -} - -static inline void InnerProductStep(bfloat16 *&pVect1, bfloat16 *&pVect2, __m128 &sum_prod) { - // Load 8 bf16 elements - __m128i v1 = _mm_lddqu_si128((__m128i *)pVect1); // SSE3 - pVect1 += 8; - __m128i v2 = _mm_lddqu_si128((__m128i *)pVect2); - pVect2 += 8; - - __m128i zeros = _mm_setzero_si128(); // SSE2 - - // Compute dist for 0..3 bf16 - InnerProductLowHalfStep(v1, v2, zeros, sum_prod); - - // Compute dist for 4..7 bf16 - InnerProductHighHalfStep(v1, v2, zeros, sum_prod); -} - -template // 0..31 -float BF16_InnerProductSIMD32_SSE3(const void *pVect1v, const void *pVect2v, size_t dimension) { - bfloat16 *pVect1 = (bfloat16 *)pVect1v; - bfloat16 *pVect2 = (bfloat16 *)pVect2v; - - const bfloat16 *pEnd1 = pVect1 + dimension; - - __m128 sum_prod = _mm_setzero_ps(); - - // Handle first residual % 8 elements (smaller than step chunk size) - - // Handle residual % 4 - if constexpr (residual % 4) { - __m128i v1, v2; - constexpr bfloat16 zero = bfloat16(0); - if constexpr (residual % 4 == 3) { - v1 = _mm_setr_epi16(zero, pVect1[0], zero, pVect1[1], zero, pVect1[2], zero, - zero); // SSE2 - v2 = _mm_setr_epi16(zero, pVect2[0], zero, pVect2[1], zero, pVect2[2], zero, zero); - } else if constexpr (residual % 4 == 2) { - // load 2 bf16 element set the rest to 0 - v1 = _mm_setr_epi16(zero, pVect1[0], zero, pVect1[1], zero, zero, zero, zero); // SSE2 - v2 = _mm_setr_epi16(zero, pVect2[0], zero, pVect2[1], zero, zero, zero, zero); - } else if constexpr (residual % 4 == 1) { - // load only first element - v1 = _mm_setr_epi16(zero, pVect1[0], zero, zero, zero, zero, zero, zero); // SSE2 - v2 = _mm_setr_epi16(zero, pVect2[0], zero, zero, zero, zero, zero, zero); - } - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(_mm_castsi128_ps(v1), _mm_castsi128_ps(v2))); - pVect1 += residual % 4; - pVect2 += residual % 4; - } - - // If residual % 8 >= 4 we need to handle 4 more elements - if constexpr (residual % 8 >= 4) { - __m128i v1 = _mm_lddqu_si128((__m128i *)pVect1); - __m128i v2 = _mm_lddqu_si128((__m128i *)pVect2); - InnerProductLowHalfStep(v1, v2, _mm_setzero_si128(), sum_prod); - pVect1 += 4; - pVect2 += 4; - } - - // Handle (residual - (residual % 8)) in chunks of 8 bfloat16 - if constexpr (residual >= 24) - InnerProductStep(pVect1, pVect2, sum_prod); - if constexpr (residual >= 16) - InnerProductStep(pVect1, pVect2, sum_prod); - if constexpr (residual >= 8) - InnerProductStep(pVect1, pVect2, sum_prod); - - // Handle 512 bits (32 bfloat16) in chunks of max SIMD = 128 bits = 8 bfloat16 - do { - InnerProductStep(pVect1, pVect2, sum_prod); - InnerProductStep(pVect1, pVect2, sum_prod); - InnerProductStep(pVect1, pVect2, sum_prod); - 
InnerProductStep(pVect1, pVect2, sum_prod); - } while (pVect1 < pEnd1); - - // TmpRes must be 16 bytes aligned - float PORTABLE_ALIGN16 TmpRes[4]; - _mm_store_ps(TmpRes, sum_prod); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - - return 1.0f - sum; -} diff --git a/src/VecSim/spaces/IP/IP_SSE4_SQ8_FP32.h b/src/VecSim/spaces/IP/IP_SSE4_SQ8_FP32.h deleted file mode 100644 index 45f9f31f4..000000000 --- a/src/VecSim/spaces/IP/IP_SSE4_SQ8_FP32.h +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/sq8.h" -using sq8 = vecsim_types::sq8; - -/* - * Optimized asymmetric SQ8 inner product using algebraic identity: - * - * IP(x, y) = Σ(x_i * y_i) - * ≈ Σ((min + delta * q_i) * y_i) - * = min * Σy_i + delta * Σ(q_i * y_i) - * = min * y_sum + delta * quantized_dot_product - * - * where y_sum = Σy_i is precomputed and stored in the query blob. - * This avoids dequantization in the hot loop - we only compute Σ(q_i * y_i). - */ - -// Helper: compute Σ(q_i * y_i) for 4 elements (no dequantization) -// pVect1 = SQ8 storage (quantized values), pVect2 = FP32 query -static inline void InnerProductStepSQ8_FP32(const uint8_t *&pVect1, const float *&pVect2, - __m128 &sum) { - // Load 4 uint8 elements and convert to float - __m128i v1_i = _mm_cvtepu8_epi32(_mm_cvtsi32_si128(*reinterpret_cast(pVect1))); - pVect1 += 4; - - __m128 v1_f = _mm_cvtepi32_ps(v1_i); - - // Load 4 float elements from query - __m128 v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - - // Accumulate q_i * y_i (no dequantization!) 
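The SSE3 BF16 kernel above never converts bfloat16 explicitly: interleaving zeros below each 16-bit value with _mm_unpacklo_epi16(zeros, v) already is the conversion, because a bfloat16 is exactly the top half of an IEEE-754 float. The same trick for a single scalar value, as a standalone sketch:

    #include <cstdint>
    #include <cstring>

    // bf16 -> f32: move the 16 payload bits into the high half of a 32-bit word.
    static float bf16_to_f32(uint16_t bits) {
        uint32_t widened = static_cast<uint32_t>(bits) << 16;
        float out;
        std::memcpy(&out, &widened, sizeof(out));
        return out;
    }
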
-    // SSE doesn't have FMA, so use mul + add
-    sum = _mm_add_ps(sum, _mm_mul_ps(v1_f, v2));
-}
-
-// pVect1v = SQ8 storage, pVect2v = FP32 query
-template <unsigned char residual> // 0..15
-float SQ8_FP32_InnerProductSIMD16_SSE4_IMP(const void *pVect1v, const void *pVect2v,
-                                           size_t dimension) {
-    const uint8_t *pVect1 = static_cast<const uint8_t *>(pVect1v); // SQ8 storage
-    const float *pVect2 = static_cast<const float *>(pVect2v);     // FP32 query
-    const uint8_t *pEnd1 = pVect1 + dimension;
-
-    // Initialize sum accumulator for Σ(q_i * y_i)
-    __m128 sum = _mm_setzero_ps();
-
-    // Process residual elements first (1-3 elements)
-    if constexpr (residual % 4) {
-        __m128 v1_f;
-        __m128 v2;
-
-        if constexpr (residual % 4 == 3) {
-            v1_f = _mm_set_ps(0.0f, static_cast<float>(pVect1[2]), static_cast<float>(pVect1[1]),
-                              static_cast<float>(pVect1[0]));
-            v2 = _mm_set_ps(0.0f, pVect2[2], pVect2[1], pVect2[0]);
-        } else if constexpr (residual % 4 == 2) {
-            v1_f = _mm_set_ps(0.0f, 0.0f, static_cast<float>(pVect1[1]),
-                              static_cast<float>(pVect1[0]));
-            v2 = _mm_set_ps(0.0f, 0.0f, pVect2[1], pVect2[0]);
-        } else if constexpr (residual % 4 == 1) {
-            v1_f = _mm_set_ps(0.0f, 0.0f, 0.0f, static_cast<float>(pVect1[0]));
-            v2 = _mm_set_ps(0.0f, 0.0f, 0.0f, pVect2[0]);
-        }
-
-        pVect1 += residual % 4;
-        pVect2 += residual % 4;
-
-        // Compute q_i * y_i (no dequantization)
-        sum = _mm_mul_ps(v1_f, v2);
-    }
-
-    // Handle remaining residual in chunks of 4 (for residual 4-15)
-    if constexpr (residual >= 4) {
-        InnerProductStepSQ8_FP32(pVect1, pVect2, sum);
-    }
-    if constexpr (residual >= 8) {
-        InnerProductStepSQ8_FP32(pVect1, pVect2, sum);
-    }
-    if constexpr (residual >= 12) {
-        InnerProductStepSQ8_FP32(pVect1, pVect2, sum);
-    }
-
-    // Process remaining full chunks of 4 elements
-    // Using do-while since dim > 16 guarantees at least one iteration
-    do {
-        InnerProductStepSQ8_FP32(pVect1, pVect2, sum);
-    } while (pVect1 < pEnd1);
-
-    // Horizontal sum to get Σ(q_i * y_i)
-    float PORTABLE_ALIGN16 TmpRes[4];
-    _mm_store_ps(TmpRes, sum);
-    float quantized_dot = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
-
-    // Get quantization parameters from stored vector (after quantized data)
-    const uint8_t *pVect1Base = static_cast<const uint8_t *>(pVect1v);
-    const float *params1 = reinterpret_cast<const float *>(pVect1Base + dimension);
-    const float min_val = params1[sq8::MIN_VAL];
-    const float delta = params1[sq8::DELTA];
-
-    // Get precomputed y_sum from query blob (stored after the dim floats)
-    const float *pVect2Base = static_cast<const float *>(pVect2v);
-    const float y_sum = pVect2Base[dimension + sq8::SUM_QUERY];
-
-    // Apply the algebraic formula: IP = min * y_sum + delta * Σ(q_i * y_i)
-    return min_val * y_sum + delta * quantized_dot;
-}
-
-template <unsigned char residual> // 0..15
-float SQ8_FP32_InnerProductSIMD16_SSE4(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    return 1.0f - SQ8_FP32_InnerProductSIMD16_SSE4_IMP<residual>(pVect1v, pVect2v, dimension);
-}
-
-template <unsigned char residual> // 0..15
-float SQ8_FP32_CosineSIMD16_SSE4(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    // Cosine distance = 1 - IP (vectors are pre-normalized)
-    return SQ8_FP32_InnerProductSIMD16_SSE4<residual>(pVect1v, pVect2v, dimension);
-}
diff --git a/src/VecSim/spaces/IP/IP_SSE_FP32.h b/src/VecSim/spaces/IP/IP_SSE_FP32.h
deleted file mode 100644
index a82cc85ff..000000000
--- a/src/VecSim/spaces/IP/IP_SSE_FP32.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" - -static inline void InnerProductStep(float *&pVect1, float *&pVect2, __m128 &sum_prod) { - __m128 v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - __m128 v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); -} - -template // 0..15 -float FP32_InnerProductSIMD16_SSE(const void *pVect1v, const void *pVect2v, size_t dimension) { - float *pVect1 = (float *)pVect1v; - float *pVect2 = (float *)pVect2v; - - const float *pEnd1 = pVect1 + dimension; - - __m128 sum_prod = _mm_setzero_ps(); - - // Deal with %4 remainder first. `dim` is >16, so we have at least one 16-float block, - // so loading 4 floats and then masking them is safe. - if constexpr (residual % 4) { - __m128 v1, v2; - if constexpr (residual % 4 == 3) { - // Load 3 floats and set the last one to 0 - v1 = _mm_load_ss(pVect1); // load 1 float, set the rest to 0 - v2 = _mm_load_ss(pVect2); - v1 = _mm_loadh_pi(v1, (__m64 *)(pVect1 + 1)); - v2 = _mm_loadh_pi(v2, (__m64 *)(pVect2 + 1)); - } else if constexpr (residual % 4 == 2) { - // Load 2 floats and set the last two to 0 - v1 = _mm_loadh_pi(_mm_setzero_ps(), (__m64 *)pVect1); - v2 = _mm_loadh_pi(_mm_setzero_ps(), (__m64 *)pVect2); - } else if constexpr (residual % 4 == 1) { - // Load 1 float and set the last three to 0 - v1 = _mm_load_ss(pVect1); - v2 = _mm_load_ss(pVect2); - } - pVect1 += residual % 4; - pVect2 += residual % 4; - sum_prod = _mm_mul_ps(v1, v2); - } - - // have another 1, 2 or 3 4-float steps according to residual - if constexpr (residual >= 12) - InnerProductStep(pVect1, pVect2, sum_prod); - if constexpr (residual >= 8) - InnerProductStep(pVect1, pVect2, sum_prod); - if constexpr (residual >= 4) - InnerProductStep(pVect1, pVect2, sum_prod); - - // We dealt with the residual part. We are left with some multiple of 16 floats. - // In each iteration we calculate 16 floats = 512 bits. - do { - InnerProductStep(pVect1, pVect2, sum_prod); - InnerProductStep(pVect1, pVect2, sum_prod); - InnerProductStep(pVect1, pVect2, sum_prod); - InnerProductStep(pVect1, pVect2, sum_prod); - } while (pVect1 < pEnd1); - - // TmpRes must be 16 bytes aligned. - float PORTABLE_ALIGN16 TmpRes[4]; - _mm_store_ps(TmpRes, sum_prod); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - - return 1.0f - sum; -} diff --git a/src/VecSim/spaces/IP/IP_SSE_FP64.h b/src/VecSim/spaces/IP/IP_SSE_FP64.h deleted file mode 100644 index eb0ebab7f..000000000 --- a/src/VecSim/spaces/IP/IP_SSE_FP64.h +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
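The residual-first layout of the FP32 SSE kernel above relies on simple counting: for dim = 16*k + residual with k >= 1, handling residual % 4 elements with partial loads plus residual / 4 full 4-float steps leaves exactly 16*k elements for the unrolled do-while. A sketch of that bookkeeping (hypothetical names, assuming dim > 16 as the dispatch guarantees):

    #include <cassert>
    #include <cstddef>

    struct Fp32SseWorkPlan {
        size_t head_scalar; // residual % 4 elements, loaded with partial loads
        size_t head_steps;  // residual / 4 full 4-float steps
        size_t main_iters;  // do-while iterations, 16 floats (4 steps) each
    };

    static Fp32SseWorkPlan plan_fp32_sse(size_t dim) {
        assert(dim > 16); // dispatch guarantees at least one main iteration
        size_t residual = dim % 16;
        return {residual % 4, residual / 4, dim / 16};
    }
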
- */ -#include "VecSim/spaces/space_includes.h" - -static inline void InnerProductStep(double *&pVect1, double *&pVect2, __m128d &sum_prod) { - __m128d v1 = _mm_loadu_pd(pVect1); - pVect1 += 2; - __m128d v2 = _mm_loadu_pd(pVect2); - pVect2 += 2; - sum_prod = _mm_add_pd(sum_prod, _mm_mul_pd(v1, v2)); -} - -template // 0..7 -double FP64_InnerProductSIMD8_SSE(const void *pVect1v, const void *pVect2v, size_t dimension) { - - double *pVect1 = (double *)pVect1v; - double *pVect2 = (double *)pVect2v; - - const double *pEnd1 = pVect1 + dimension; - - __m128d sum_prod = _mm_setzero_pd(); - - // If residual is odd, we load 1 double and set the last one to 0 - if constexpr (residual % 2 == 1) { - __m128d v1 = _mm_load_sd(pVect1); - pVect1++; - __m128d v2 = _mm_load_sd(pVect2); - pVect2++; - sum_prod = _mm_mul_pd(v1, v2); - } - - // have another 1, 2 or 3 2-double steps according to residual - if constexpr (residual >= 6) - InnerProductStep(pVect1, pVect2, sum_prod); - if constexpr (residual >= 4) - InnerProductStep(pVect1, pVect2, sum_prod); - if constexpr (residual >= 2) - InnerProductStep(pVect1, pVect2, sum_prod); - - // We dealt with the residual part. We are left with some multiple of 8 doubles. - // In each iteration we calculate 8 doubles = 512 bits in total. - do { - InnerProductStep(pVect1, pVect2, sum_prod); - InnerProductStep(pVect1, pVect2, sum_prod); - InnerProductStep(pVect1, pVect2, sum_prod); - InnerProductStep(pVect1, pVect2, sum_prod); - } while (pVect1 < pEnd1); - - double PORTABLE_ALIGN16 TmpRes[2]; - _mm_store_pd(TmpRes, sum_prod); - double sum = TmpRes[0] + TmpRes[1]; - - return 1.0 - sum; -} diff --git a/src/VecSim/spaces/IP/IP_SVE_BF16.h b/src/VecSim/spaces/IP/IP_SVE_BF16.h deleted file mode 100644 index 9c39c1360..000000000 --- a/src/VecSim/spaces/IP/IP_SVE_BF16.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include - -inline void InnerProduct_Step(const bfloat16_t *vec1, const bfloat16_t *vec2, svfloat32_t &acc, - size_t &offset, const size_t chunk) { - svbool_t all = svptrue_b16(); - - // Load brain-half-precision vectors. 
-    svbfloat16_t v1 = svld1_bf16(all, vec1 + offset);
-    svbfloat16_t v2 = svld1_bf16(all, vec2 + offset);
-    // Compute multiplications and add to the accumulator
-    acc = svbfdot(acc, v1, v2);
-
-    // Move to next chunk
-    offset += chunk;
-}
-
-template <bool partial_chunk, unsigned char additional_steps> // [t/f, 0..3]
-float BF16_InnerProduct_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
-    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
-    const size_t chunk = svcnth(); // number of 16-bit elements in a register
-    svfloat32_t acc1 = svdup_f32(0.0f);
-    svfloat32_t acc2 = svdup_f32(0.0f);
-    svfloat32_t acc3 = svdup_f32(0.0f);
-    svfloat32_t acc4 = svdup_f32(0.0f);
-    size_t offset = 0;
-
-    // Process all full vectors
-    const size_t full_iterations = dimension / chunk / 4;
-    for (size_t iter = 0; iter < full_iterations; iter++) {
-        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
-        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
-        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);
-        InnerProduct_Step(vec1, vec2, acc4, offset, chunk);
-    }
-
-    // Perform between 0 and 3 additional steps, according to `additional_steps` value
-    if constexpr (additional_steps >= 1)
-        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
-    if constexpr (additional_steps >= 2)
-        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
-    if constexpr (additional_steps >= 3)
-        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);
-
-    // Handle the tail with the residual predicate
-    if constexpr (partial_chunk) {
-        svbool_t pg = svwhilelt_b16_u64(offset, dimension);
-
-        // Load brain-half-precision vectors.
-        // Inactive elements are zeros, according to the docs
-        svbfloat16_t v1 = svld1_bf16(pg, vec1 + offset);
-        svbfloat16_t v2 = svld1_bf16(pg, vec2 + offset);
-        // Compute multiplications and add to the accumulator.
-        acc4 = svbfdot(acc4, v1, v2);
-    }
-
-    // Accumulate accumulators
-    acc1 = svadd_f32_x(svptrue_b32(), acc1, acc3);
-    acc2 = svadd_f32_x(svptrue_b32(), acc2, acc4);
-    acc1 = svadd_f32_x(svptrue_b32(), acc1, acc2);
-
-    // Reduce the accumulated sum.
-    float result = svaddv_f32(svptrue_b32(), acc1);
-    return 1.0f - result;
-}
diff --git a/src/VecSim/spaces/IP/IP_SVE_FP16.h b/src/VecSim/spaces/IP/IP_SVE_FP16.h
deleted file mode 100644
index ac464977e..000000000
--- a/src/VecSim/spaces/IP/IP_SVE_FP16.h
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include <arm_sve.h>
-
-inline void InnerProduct_Step(const float16_t *vec1, const float16_t *vec2, svfloat16_t &acc,
-                              size_t &offset, const size_t chunk) {
-    svbool_t all = svptrue_b16();
-
-    // Load half-precision vectors.
-    svfloat16_t v1 = svld1_f16(all, vec1 + offset);
-    svfloat16_t v2 = svld1_f16(all, vec2 + offset);
-    // Compute multiplications and add to the accumulator
-    acc = svmla_f16_x(all, acc, v1, v2);
-
-    // Move to next chunk
-    offset += chunk;
-}
-
-template <bool partial_chunk, unsigned char additional_steps> // [t/f, 0..3]
-float FP16_InnerProduct_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const auto *vec1 = static_cast<const float16_t *>(pVect1v);
-    const auto *vec2 = static_cast<const float16_t *>(pVect2v);
-    const size_t chunk = svcnth(); // number of 16-bit elements in a register
-    svbool_t all = svptrue_b16();
-    svfloat16_t acc1 = svdup_f16(0.0f);
-    svfloat16_t acc2 = svdup_f16(0.0f);
-    svfloat16_t acc3 = svdup_f16(0.0f);
-    svfloat16_t acc4 = svdup_f16(0.0f);
-    size_t offset = 0;
-
-    // Process all full vectors
-    const size_t full_iterations = dimension / chunk / 4;
-    for (size_t iter = 0; iter < full_iterations; iter++) {
-        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
-        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
-        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);
-        InnerProduct_Step(vec1, vec2, acc4, offset, chunk);
-    }
-
-    // Perform between 0 and 3 additional steps, according to `additional_steps` value
-    if constexpr (additional_steps >= 1)
-        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
-    if constexpr (additional_steps >= 2)
-        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
-    if constexpr (additional_steps >= 3)
-        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);
-
-    // Handle the tail with the residual predicate
-    if constexpr (partial_chunk) {
-        svbool_t pg = svwhilelt_b16_u64(offset, dimension);
-
-        // Load half-precision vectors.
-        svfloat16_t v1 = svld1_f16(pg, vec1 + offset);
-        svfloat16_t v2 = svld1_f16(pg, vec2 + offset);
-        // Compute multiplications and add to the accumulator.
-        // use the existing value of `acc` for the inactive elements (by the `m` suffix)
-        acc4 = svmla_f16_m(pg, acc4, v1, v2);
-    }
-
-    // Accumulate accumulators
-    acc1 = svadd_f16_x(all, acc1, acc3);
-    acc2 = svadd_f16_x(all, acc2, acc4);
-    acc1 = svadd_f16_x(all, acc1, acc2);
-
-    // Reduce the accumulated sum.
-    float result = svaddv_f16(all, acc1);
-    return 1.0f - result;
-}
diff --git a/src/VecSim/spaces/IP/IP_SVE_FP32.h b/src/VecSim/spaces/IP/IP_SVE_FP32.h
deleted file mode 100644
index c1cc79ccd..000000000
--- a/src/VecSim/spaces/IP/IP_SVE_FP32.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
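Unlike the x86 kernels, the SVE kernels cannot bake the vector length into the template: svcnth()/svcntw() are runtime values. The templates only encode whether a predicated tail exists and how many whole-vector steps remain after the 4x-unrolled loop. Presumably the chooser derives these once per index; a sketch with a hypothetical helper:

    #include <cstddef>

    // chunk = lanes per SVE vector for the element type (e.g. svcnth() for 16-bit).
    struct SveKernelParams {
        bool partial_chunk;             // dimension % chunk != 0 -> predicated tail needed
        unsigned char additional_steps; // whole-vector steps left after the 4x unroll, 0..3
    };

    static SveKernelParams sve_params(size_t dimension, size_t chunk) {
        return {dimension % chunk != 0,
                static_cast<unsigned char>((dimension / chunk) % 4)};
    }
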
- */ -#include "VecSim/spaces/space_includes.h" - -#include - -static inline void InnerProductStep(float *&pVect1, float *&pVect2, size_t &offset, - svfloat32_t &sum, const size_t chunk) { - svfloat32_t v1 = svld1_f32(svptrue_b32(), pVect1 + offset); - svfloat32_t v2 = svld1_f32(svptrue_b32(), pVect2 + offset); - - sum = svmla_f32_x(svptrue_b32(), sum, v1, v2); - - offset += chunk; -} - -template -float FP32_InnerProductSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) { - float *pVect1 = (float *)pVect1v; - float *pVect2 = (float *)pVect2v; - size_t offset = 0; - - uint64_t chunk = svcntw(); - - svfloat32_t sum0 = svdup_f32(0.0f); - svfloat32_t sum1 = svdup_f32(0.0f); - svfloat32_t sum2 = svdup_f32(0.0f); - svfloat32_t sum3 = svdup_f32(0.0f); - - auto chunk_size = 4 * chunk; - const size_t number_of_chunks = dimension / chunk_size; - for (size_t i = 0; i < number_of_chunks; i++) { - InnerProductStep(pVect1, pVect2, offset, sum0, chunk); - InnerProductStep(pVect1, pVect2, offset, sum1, chunk); - InnerProductStep(pVect1, pVect2, offset, sum2, chunk); - InnerProductStep(pVect1, pVect2, offset, sum3, chunk); - } - - // Process remaining complete SVE vectors that didn't fit into the main loop - // These are full vector operations (0-3 elements) - if constexpr (additional_steps > 0) { - if constexpr (additional_steps >= 1) { - InnerProductStep(pVect1, pVect2, offset, sum0, chunk); - } - if constexpr (additional_steps >= 2) { - InnerProductStep(pVect1, pVect2, offset, sum1, chunk); - } - if constexpr (additional_steps >= 3) { - InnerProductStep(pVect1, pVect2, offset, sum3, chunk); - } - } - - // Process final tail elements that don't form a complete vector - // This section handles the case when dimension is not evenly divisible by SVE vector length - if constexpr (partial_chunk) { - // Create a predicate mask where each lane is active only for the remaining elements - svbool_t pg = - svwhilelt_b32(static_cast(offset), static_cast(dimension)); - - // Load vectors with predication - svfloat32_t v1 = svld1_f32(pg, pVect1 + offset); - svfloat32_t v2 = svld1_f32(pg, pVect2 + offset); - sum3 = svmla_f32_m(pg, sum3, v1, v2); - } - - sum0 = svadd_f32_x(svptrue_b32(), sum0, sum1); - sum2 = svadd_f32_x(svptrue_b32(), sum2, sum3); - // Perform vector addition in parallel - svfloat32_t sum_all = svadd_f32_x(svptrue_b32(), sum0, sum2); - // Single horizontal reduction at the end - float result = svaddv_f32(svptrue_b32(), sum_all); - return 1.0f - result; -} diff --git a/src/VecSim/spaces/IP/IP_SVE_FP64.h b/src/VecSim/spaces/IP/IP_SVE_FP64.h deleted file mode 100644 index 1e091e85c..000000000 --- a/src/VecSim/spaces/IP/IP_SVE_FP64.h +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/spaces/space_includes.h" - -#include - -inline void InnerProductStep(double *&pVect1, double *&pVect2, size_t &offset, svfloat64_t &sum, - const size_t chunk) { - // Load vectors - svfloat64_t v1 = svld1_f64(svptrue_b64(), pVect1 + offset); - svfloat64_t v2 = svld1_f64(svptrue_b64(), pVect2 + offset); - - // Multiply-accumulate - sum = svmla_f64_x(svptrue_b64(), sum, v1, v2); - - // Advance pointers - offset += chunk; -} - -template -double FP64_InnerProductSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) { - double *pVect1 = (double *)pVect1v; - double *pVect2 = (double *)pVect2v; - const size_t chunk = svcntd(); - size_t offset = 0; - - // Multiple accumulators to increase instruction-level parallelism - svfloat64_t sum0 = svdup_f64(0.0); - svfloat64_t sum1 = svdup_f64(0.0); - svfloat64_t sum2 = svdup_f64(0.0); - svfloat64_t sum3 = svdup_f64(0.0); - - auto chunk_size = 4 * chunk; - size_t number_of_chunks = dimension / chunk_size; - for (size_t i = 0; i < number_of_chunks; i++) { - InnerProductStep(pVect1, pVect2, offset, sum0, chunk); - InnerProductStep(pVect1, pVect2, offset, sum1, chunk); - InnerProductStep(pVect1, pVect2, offset, sum2, chunk); - InnerProductStep(pVect1, pVect2, offset, sum3, chunk); - } - - if constexpr (additional_steps >= 1) { - InnerProductStep(pVect1, pVect2, offset, sum0, chunk); - } - if constexpr (additional_steps >= 2) { - InnerProductStep(pVect1, pVect2, offset, sum1, chunk); - } - if constexpr (additional_steps >= 3) { - InnerProductStep(pVect1, pVect2, offset, sum2, chunk); - } - - if constexpr (partial_chunk) { - svbool_t pg = - svwhilelt_b64(static_cast(offset), static_cast(dimension)); - svfloat64_t v1 = svld1_f64(pg, pVect1 + offset); - svfloat64_t v2 = svld1_f64(pg, pVect2 + offset); - sum3 = svmla_f64_m(pg, sum3, v1, v2); - } - - // Combine the partial sums - sum0 = svadd_f64_x(svptrue_b64(), sum0, sum1); - sum2 = svadd_f64_x(svptrue_b64(), sum2, sum3); - - // Perform vector addition in parallel - svfloat64_t sum_all = svadd_f64_x(svptrue_b64(), sum0, sum2); - // Single horizontal reduction at the end - double result = svaddv_f64(svptrue_b64(), sum_all); - return 1.0 - result; -} diff --git a/src/VecSim/spaces/IP/IP_SVE_INT8.h b/src/VecSim/spaces/IP/IP_SVE_INT8.h deleted file mode 100644 index e8110bcff..000000000 --- a/src/VecSim/spaces/IP/IP_SVE_INT8.h +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */
-
-#pragma once
-#include "VecSim/spaces/space_includes.h"
-#include <arm_sve.h>
-
-inline void InnerProductStep(const int8_t *&pVect1, const int8_t *&pVect2, size_t &offset,
-                             svint32_t &sum, const size_t chunk) {
-    svbool_t pg = svptrue_b8();
-
-    // Load int8 vectors
-    svint8_t v1_i8 = svld1_s8(pg, pVect1 + offset);
-    svint8_t v2_i8 = svld1_s8(pg, pVect2 + offset);
-
-    sum = svdot_s32(sum, v1_i8, v2_i8);
-
-    offset += chunk; // Move to the next set of int8 elements
-}
-
-template <bool partial_chunk, unsigned char additional_steps>
-float INT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const int8_t *pVect1 = reinterpret_cast<const int8_t *>(pVect1v);
-    const int8_t *pVect2 = reinterpret_cast<const int8_t *>(pVect2v);
-
-    size_t offset = 0;
-    const size_t vl = svcntb();
-    const size_t chunk_size = 4 * vl;
-
-    // Each InnerProductStep adds at most 2^8 * 2^8 = 2^16 per lane
-    // Therefore, on a single accumulator, we can perform 2^15 steps before overflowing
-    // That scenario will happen only if the dimension of the vector is larger than 16*4*2^15 = 2^21
-    // (16 int8 in 1 SVE register) * (4 accumulators) * (2^15 steps)
-    // We can safely assume that the dimension is smaller than that
-    // So using int32_t is safe
-
-    svint32_t sum0 = svdup_s32(0);
-    svint32_t sum1 = svdup_s32(0);
-    svint32_t sum2 = svdup_s32(0);
-    svint32_t sum3 = svdup_s32(0);
-
-    size_t num_chunks = dimension / chunk_size;
-
-    for (size_t i = 0; i < num_chunks; ++i) {
-        InnerProductStep(pVect1, pVect2, offset, sum0, vl);
-        InnerProductStep(pVect1, pVect2, offset, sum1, vl);
-        InnerProductStep(pVect1, pVect2, offset, sum2, vl);
-        InnerProductStep(pVect1, pVect2, offset, sum3, vl);
-    }
-
-    // Process remaining complete SVE vectors that didn't fit into the main loop
-    // These are full vector operations (0-3 full vectors)
-    if constexpr (additional_steps > 0) {
-        if constexpr (additional_steps >= 1) {
-            InnerProductStep(pVect1, pVect2, offset, sum0, vl);
-        }
-        if constexpr (additional_steps >= 2) {
-            InnerProductStep(pVect1, pVect2, offset, sum1, vl);
-        }
-        if constexpr (additional_steps >= 3) {
-            InnerProductStep(pVect1, pVect2, offset, sum2, vl);
-        }
-    }
-
-    if constexpr (partial_chunk) {
-        svbool_t pg = svwhilelt_b8_u64(offset, dimension);
-
-        svint8_t v1_i8 = svld1_s8(pg, pVect1 + offset); // Load int8 vectors
-        svint8_t v2_i8 = svld1_s8(pg, pVect2 + offset); // Load int8 vectors
-
-        sum3 = svdot_s32(sum3, v1_i8, v2_i8);
-
-        pVect1 += vl;
-        pVect2 += vl;
-    }
-
-    sum0 = svadd_s32_x(svptrue_b32(), sum0, sum1);
-    sum2 = svadd_s32_x(svptrue_b32(), sum2, sum3);
-
-    // Perform vector addition in parallel and Horizontal sum
-    int32_t sum_all = svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sum0, sum2));
-
-    return sum_all;
-}
-
-template <bool partial_chunk, unsigned char additional_steps>
-float INT8_InnerProductSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    return 1.0f -
-           INT8_InnerProductImp<partial_chunk, additional_steps>(pVect1v, pVect2v, dimension);
-}
-
-template <bool partial_chunk, unsigned char additional_steps>
-float INT8_CosineSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    float ip = INT8_InnerProductImp<partial_chunk, additional_steps>(pVect1v, pVect2v, dimension);
-    float norm_v1 =
-        *reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect1v) + dimension);
-    float norm_v2 =
-        *reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect2v) + dimension);
-    return 1.0f - ip / (norm_v1 * norm_v2);
-}
diff --git a/src/VecSim/spaces/IP/IP_SVE_SQ8_FP32.h b/src/VecSim/spaces/IP/IP_SVE_SQ8_FP32.h
deleted file mode 100644
index c4d5dbd7f..000000000
--- a/src/VecSim/spaces/IP/IP_SVE_SQ8_FP32.h
+++ /dev/null
@@ -1,145 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
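The overflow comment in the int8 kernel above can be made concrete: svdot_s32 folds four int8*int8 products into each 32-bit lane per step, so a step adds at most 4 * 128 * 128 = 2^16 to a lane, an int32 lane survives about 2^15 steps, and with 16-byte vectors and 4 accumulators that covers dimensions up to 2^21. As arithmetic (a sketch, not part of the original sources):

    #include <cstdint>

    // Worst case per svdot_s32 step: four products of magnitude <= 128 * 128 per lane.
    constexpr int64_t kMaxPerStep = 4LL * 128 * 128;           // 2^16
    constexpr int64_t kSafeSteps = (1LL << 31) / kMaxPerStep;  // 2^15 steps per accumulator
    // With a 128-bit SVE vector (16 int8 lanes) and 4 accumulators:
    constexpr int64_t kSafeDim = kSafeSteps * 16 * 4;          // 2^21 elements
    static_assert(kSafeDim == (1LL << 21), "matches the bound quoted in the comment above");
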
- * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/sq8.h" -#include - -using sq8 = vecsim_types::sq8; -/* - * Optimized asymmetric SQ8 inner product using algebraic identity: - * - * IP(x, y) = Σ(x_i * y_i) - * ≈ Σ((min + delta * q_i) * y_i) - * = min * Σy_i + delta * Σ(q_i * y_i) - * = min * y_sum + delta * quantized_dot_product - * - * where y_sum = Σy_i is precomputed and stored in the query blob. - * This avoids dequantization in the hot loop - we only compute Σ(q_i * y_i). - */ - -// Helper: compute Σ(q_i * y_i) for one SVE vector width (no dequantization) -// pVect1 = SQ8 storage (quantized values), pVect2 = FP32 query -static inline void InnerProductStepSQ8_FP32(const uint8_t *pVect1, const float *pVect2, - size_t &offset, svfloat32_t &sum, const size_t chunk) { - svbool_t pg = svptrue_b32(); - - // Load uint8 elements and zero-extend to uint32 - svuint32_t v1_u32 = svld1ub_u32(pg, pVect1 + offset); - - // Convert uint32 to float32 - svfloat32_t v1_f = svcvt_f32_u32_x(pg, v1_u32); - - // Load float elements from query - svfloat32_t v2 = svld1_f32(pg, pVect2 + offset); - - // Accumulate q_i * y_i (no dequantization!) - sum = svmla_f32_x(pg, sum, v1_f, v2); - - offset += chunk; -} - -// pVect1v = SQ8 storage, pVect2v = FP32 query -template -float SQ8_FP32_InnerProductSIMD_SVE_IMP(const void *pVect1v, const void *pVect2v, - size_t dimension) { - const uint8_t *pVect1 = static_cast(pVect1v); // SQ8 storage - const float *pVect2 = static_cast(pVect2v); // FP32 query - size_t offset = 0; - - svbool_t pg = svptrue_b32(); - - // Get the number of 32-bit elements per vector at runtime - uint64_t chunk = svcntw(); - - // Multiple accumulators for ILP - svfloat32_t sum0 = svdup_f32(0.0f); - svfloat32_t sum1 = svdup_f32(0.0f); - svfloat32_t sum2 = svdup_f32(0.0f); - svfloat32_t sum3 = svdup_f32(0.0f); - - // Handle partial chunk if needed - if constexpr (partial_chunk) { - size_t remaining = dimension % chunk; - if (remaining > 0) { - // Create predicate for the remaining elements - svbool_t pg_partial = - svwhilelt_b32(static_cast(0), static_cast(remaining)); - - // Load uint8 elements and zero-extend to uint32 - svuint32_t v1_u32 = svld1ub_u32(pg_partial, pVect1 + offset); - - // Convert uint32 to float32 - svfloat32_t v1_f = svcvt_f32_u32_z(pg_partial, v1_u32); - - // Load float elements from query with predicate - svfloat32_t v2 = svld1_f32(pg_partial, pVect2); - - // Compute q_i * y_i (no dequantization) - sum0 = svmla_f32_z(pg_partial, sum0, v1_f, v2); - - offset += remaining; - } - } - - // Process 4 chunks at a time in the main loop - auto chunk_size = 4 * chunk; - const size_t number_of_chunks = - (dimension - (partial_chunk ? 
dimension % chunk : 0)) / chunk_size; - - for (size_t i = 0; i < number_of_chunks; i++) { - InnerProductStepSQ8_FP32(pVect1, pVect2, offset, sum0, chunk); - InnerProductStepSQ8_FP32(pVect1, pVect2, offset, sum1, chunk); - InnerProductStepSQ8_FP32(pVect1, pVect2, offset, sum2, chunk); - InnerProductStepSQ8_FP32(pVect1, pVect2, offset, sum3, chunk); - } - - // Handle remaining steps (0-3) - if constexpr (additional_steps > 0) { - InnerProductStepSQ8_FP32(pVect1, pVect2, offset, sum0, chunk); - } - if constexpr (additional_steps > 1) { - InnerProductStepSQ8_FP32(pVect1, pVect2, offset, sum1, chunk); - } - if constexpr (additional_steps > 2) { - InnerProductStepSQ8_FP32(pVect1, pVect2, offset, sum2, chunk); - } - - // Combine the accumulators - svfloat32_t sum = svadd_f32_z(pg, sum0, sum1); - sum = svadd_f32_z(pg, sum, sum2); - sum = svadd_f32_z(pg, sum, sum3); - - // Horizontal sum to get Σ(q_i * y_i) - float quantized_dot = svaddv_f32(pg, sum); - - // Get quantization parameters from stored vector (after quantized data) - const float *params1 = reinterpret_cast(pVect1 + dimension); - const float min_val = params1[sq8::MIN_VAL]; - const float delta = params1[sq8::DELTA]; - - // Get precomputed y_sum from query blob (stored after the dim floats) - const float y_sum = pVect2[dimension + sq8::SUM_QUERY]; - - // Apply the algebraic formula: IP = min * y_sum + delta * Σ(q_i * y_i) - return min_val * y_sum + delta * quantized_dot; -} - -template -float SQ8_FP32_InnerProductSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) { - return 1.0f - SQ8_FP32_InnerProductSIMD_SVE_IMP( - pVect1v, pVect2v, dimension); -} - -template -float SQ8_FP32_CosineSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) { - // Cosine distance = 1 - IP (vectors are pre-normalized) - return SQ8_FP32_InnerProductSIMD_SVE(pVect1v, pVect2v, - dimension); -} diff --git a/src/VecSim/spaces/IP/IP_SVE_SQ8_SQ8.h b/src/VecSim/spaces/IP/IP_SVE_SQ8_SQ8.h deleted file mode 100644 index a752817dd..000000000 --- a/src/VecSim/spaces/IP/IP_SVE_SQ8_SQ8.h +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/IP/IP_SVE_UINT8.h" -#include "VecSim/types/sq8.h" -#include - -using sq8 = vecsim_types::sq8; - -/** - * SQ8-to-SQ8 distance functions using ARM SVE with precomputed sum. - * These functions compute distance between two SQ8 (scalar quantized 8-bit) vectors, - * where BOTH vectors are uint8 quantized. - * - * Uses precomputed sum stored in the vector data, - * eliminating the need to compute them during distance calculation. - * - * Uses algebraic optimization with SVE dot product instruction: - * - * With sum = Σv[i] (sum of original float values), the formula is: - * IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1[i]*q2[i]) - dim*min1*min2 - * - * Since sum is precomputed, we only need to compute the dot product Σ(q1[i]*q2[i]). - * The dot product is computed using the efficient UINT8_InnerProductImp which uses - * SVE dot product instruction (svdot_u32) for native uint8 dot product computation. 
- * - * Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] - */ - -// Common implementation for inner product between two SQ8 vectors with precomputed sum -// Uses UINT8_InnerProductImp for efficient dot product computation with SVE -template -float SQ8_SQ8_InnerProductSIMD_SVE_IMP(const void *pVec1v, const void *pVec2v, size_t dimension) { - // Compute raw dot product using efficient UINT8 SVE implementation - // UINT8_InnerProductImp uses svdot_u32 for native uint8 dot product - float dot_product = - UINT8_InnerProductImp(pVec1v, pVec2v, dimension); - - // Get dequantization parameters and precomputed values from the end of vectors - // Layout: [data (dim)] [min (float)] [delta (float)] [sum (float)] - const uint8_t *pVec1 = static_cast(pVec1v); - const uint8_t *pVec2 = static_cast(pVec2v); - - const float *params1 = reinterpret_cast(pVec1 + dimension); - const float min1 = params1[sq8::MIN_VAL]; - const float delta1 = params1[sq8::DELTA]; - const float sum1 = params1[sq8::SUM]; // Precomputed sum of original float elements - - const float *params2 = reinterpret_cast(pVec2 + dimension); - const float min2 = params2[sq8::MIN_VAL]; - const float delta2 = params2[sq8::DELTA]; - const float sum2 = params2[sq8::SUM]; // Precomputed sum of original float elements - - // Apply algebraic formula with float conversion only at the end: - // IP = min1*sum2 + min2*sum1 + δ1*δ2 * Σ(q1*q2) - dim*min1*min2 - return min1 * sum2 + min2 * sum1 + delta1 * delta2 * dot_product - - static_cast(dimension) * min1 * min2; -} - -// SQ8-to-SQ8 Inner Product distance function -// Returns 1 - inner_product (distance form) -template -float SQ8_SQ8_InnerProductSIMD_SVE(const void *pVec1v, const void *pVec2v, size_t dimension) { - return 1.0f - SQ8_SQ8_InnerProductSIMD_SVE_IMP(pVec1v, pVec2v, - dimension); -} - -// SQ8-to-SQ8 Cosine distance function -// Returns 1 - inner_product (assumes vectors are pre-normalized) -template -float SQ8_SQ8_CosineSIMD_SVE(const void *pVec1v, const void *pVec2v, size_t dimension) { - // Assume vectors are normalized. - return SQ8_SQ8_InnerProductSIMD_SVE(pVec1v, pVec2v, dimension); -} diff --git a/src/VecSim/spaces/IP/IP_SVE_UINT8.h b/src/VecSim/spaces/IP/IP_SVE_UINT8.h deleted file mode 100644 index c1cc45b66..000000000 --- a/src/VecSim/spaces/IP/IP_SVE_UINT8.h +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
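The SQ8-to-SQ8 formula follows from expanding sum((min1 + d1*q1[i]) * (min2 + d2*q2[i])): the cross terms collapse into min1*sum2 and min2*sum1 (each of which double-counts dim*min1*min2 once, hence the subtraction). A scalar sketch with hypothetical names that computes the same expression:

    #include <cstddef>
    #include <cstdint>

    // Expanded: dim*min1*min2 + min1*d2*Σq2 + min2*d1*Σq1 + d1*d2*Σq1q2.
    // With sum_k = dim*min_k + d_k*Σq_k (the stored sum of original floats), this
    // rearranges to: min1*sum2 + min2*sum1 + d1*d2*Σq1q2 - dim*min1*min2.
    static float sq8_sq8_ip_ref(const uint8_t *q1, const uint8_t *q2, size_t dim,
                                float min1, float d1, float sum1,
                                float min2, float d2, float sum2) {
        float qdot = 0.0f;
        for (size_t i = 0; i < dim; i++)
            qdot += static_cast<float>(q1[i]) * static_cast<float>(q2[i]);
        return min1 * sum2 + min2 * sum1 + d1 * d2 * qdot
               - static_cast<float>(dim) * min1 * min2;
    }
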
- */
-#pragma once
-#include "VecSim/spaces/space_includes.h"
-#include <arm_sve.h>
-
-inline void InnerProductStep(const uint8_t *&pVect1, const uint8_t *&pVect2, size_t &offset,
-                             svuint32_t &sum, const size_t chunk) {
-    svbool_t pg = svptrue_b8();
-
-    // Load uint8 vectors
-    svuint8_t v1_ui8 = svld1_u8(pg, pVect1 + offset);
-    svuint8_t v2_ui8 = svld1_u8(pg, pVect2 + offset);
-
-    sum = svdot_u32(sum, v1_ui8, v2_ui8);
-
-    offset += chunk; // Move to the next set of uint8 elements
-}
-
-template <bool partial_chunk, unsigned char additional_steps>
-float UINT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const uint8_t *pVect1 = reinterpret_cast<const uint8_t *>(pVect1v);
-    const uint8_t *pVect2 = reinterpret_cast<const uint8_t *>(pVect2v);
-
-    size_t offset = 0;
-    const size_t vl = svcntb();
-    const size_t chunk_size = 4 * vl;
-
-    // Each InnerProductStep adds at most 2^8 * 2^8 = 2^16 per lane
-    // Therefore, on a single accumulator, we can perform 2^16 steps before overflowing
-    // That scenario will happen only if the dimension of the vector is larger than 16*4*2^16 = 2^22
-    // (16 uint8 in 1 SVE register) * (4 accumulators) * (2^16 steps)
-    // We can safely assume that the dimension is smaller than that
-    // So using uint32_t is safe
-
-    svuint32_t sum0 = svdup_u32(0);
-    svuint32_t sum1 = svdup_u32(0);
-    svuint32_t sum2 = svdup_u32(0);
-    svuint32_t sum3 = svdup_u32(0);
-
-    size_t num_chunks = dimension / chunk_size;
-
-    for (size_t i = 0; i < num_chunks; ++i) {
-        InnerProductStep(pVect1, pVect2, offset, sum0, vl);
-        InnerProductStep(pVect1, pVect2, offset, sum1, vl);
-        InnerProductStep(pVect1, pVect2, offset, sum2, vl);
-        InnerProductStep(pVect1, pVect2, offset, sum3, vl);
-    }
-
-    // Process remaining complete SVE vectors that didn't fit into the main loop
-    // These are full vector operations (0-3 full vectors)
-    if constexpr (additional_steps > 0) {
-        if constexpr (additional_steps >= 1) {
-            InnerProductStep(pVect1, pVect2, offset, sum0, vl);
-        }
-        if constexpr (additional_steps >= 2) {
-            InnerProductStep(pVect1, pVect2, offset, sum1, vl);
-        }
-        if constexpr (additional_steps >= 3) {
-            InnerProductStep(pVect1, pVect2, offset, sum2, vl);
-        }
-    }
-
-    if constexpr (partial_chunk) {
-        svbool_t pg = svwhilelt_b8_u64(offset, dimension);
-
-        svuint8_t v1_ui8 = svld1_u8(pg, pVect1 + offset); // Load uint8 vectors
-        svuint8_t v2_ui8 = svld1_u8(pg, pVect2 + offset); // Load uint8 vectors
-
-        sum3 = svdot_u32(sum3, v1_ui8, v2_ui8);
-
-        pVect1 += vl;
-        pVect2 += vl;
-    }
-
-    sum0 = svadd_u32_x(svptrue_b32(), sum0, sum1);
-    sum2 = svadd_u32_x(svptrue_b32(), sum2, sum3);
-
-    // Perform vector addition in parallel and Horizontal sum
-    uint32_t sum_all = svaddv_u32(svptrue_b32(), svadd_u32_x(svptrue_b32(), sum0, sum2));
-
-    return sum_all;
-}
-
-template <bool partial_chunk, unsigned char additional_steps>
-float UINT8_InnerProductSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    return 1.0f -
-           UINT8_InnerProductImp<partial_chunk, additional_steps>(pVect1v, pVect2v, dimension);
-}
-
-template <bool partial_chunk, unsigned char additional_steps>
-float UINT8_CosineSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    float ip = UINT8_InnerProductImp<partial_chunk, additional_steps>(pVect1v, pVect2v, dimension);
-    float norm_v1 =
-        *reinterpret_cast<const float *>(static_cast<const uint8_t *>(pVect1v) + dimension);
-    float norm_v2 =
-        *reinterpret_cast<const float *>(static_cast<const uint8_t *>(pVect2v) + dimension);
-    return 1.0f - ip / (norm_v1 * norm_v2);
-}
diff --git a/src/VecSim/spaces/IP_space.cpp b/src/VecSim/spaces/IP_space.cpp
deleted file mode 100644
index 859b90271..000000000
--- a/src/VecSim/spaces/IP_space.cpp
+++ /dev/null
@@ -1,684 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
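Both the int8 and uint8 SVE kernels lean on the SDOT/UDOT instructions: each svdot_u32 call folds four adjacent byte products into every 32-bit lane. A scalar model of one step (a sketch; vl is the vector length in bytes):

    #include <cstddef>
    #include <cstdint>

    // Lane j of the accumulator receives the sum of four adjacent byte
    // products a[4j..4j+3] * b[4j..4j+3], matching svdot_u32 semantics.
    static void svdot_u32_model(const uint8_t *a, const uint8_t *b,
                                uint32_t *acc, size_t vl) {
        for (size_t j = 0; j < vl / 4; j++)
            for (size_t k = 0; k < 4; k++)
                acc[j] += static_cast<uint32_t>(a[4 * j + k]) * b[4 * j + k];
    }
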
- * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/IP_space.h" -#include "VecSim/spaces/IP/IP.h" -#include "VecSim/types/bfloat16.h" -#include "VecSim/types/float16.h" -#include "VecSim/spaces/functions/AVX512F.h" -#include "VecSim/spaces/functions/F16C.h" -#include "VecSim/spaces/functions/AVX.h" -#include "VecSim/spaces/functions/SSE.h" -#include "VecSim/spaces/functions/AVX512BW_VBMI2.h" -#include "VecSim/spaces/functions/AVX512FP16_VL.h" -#include "VecSim/spaces/functions/AVX512BF16_VL.h" -#include "VecSim/spaces/functions/AVX512F_BW_VL_VNNI.h" -#include "VecSim/spaces/functions/AVX2.h" -#include "VecSim/spaces/functions/AVX2_FMA.h" -#include "VecSim/spaces/functions/SSE3.h" -#include "VecSim/spaces/functions/SSE4.h" -#include "VecSim/spaces/functions/NEON.h" -#include "VecSim/spaces/functions/NEON_DOTPROD.h" -#include "VecSim/spaces/functions/NEON_HP.h" -#include "VecSim/spaces/functions/NEON_BF16.h" -#include "VecSim/spaces/functions/SVE.h" -#include "VecSim/spaces/functions/SVE_BF16.h" -#include "VecSim/spaces/functions/SVE2.h" - -using bfloat16 = vecsim_types::bfloat16; -using float16 = vecsim_types::float16; - -namespace spaces { -// SQ8-FP32: asymmetric distance between SQ8 storage and FP32 query -dist_func_t IP_SQ8_FP32_GetDistFunc(size_t dim, unsigned char *alignment, - const void *arch_opt) { - unsigned char dummy_alignment; - if (alignment == nullptr) { - alignment = &dummy_alignment; - } - - dist_func_t ret_dist_func = SQ8_FP32_InnerProduct; - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); -#ifdef CPU_FEATURES_ARCH_AARCH64 - -#ifdef OPT_SVE2 - if (features.sve2) { - return Choose_SQ8_FP32_IP_implementation_SVE2(dim); - } -#endif -#ifdef OPT_SVE - if (features.sve) { - return Choose_SQ8_FP32_IP_implementation_SVE(dim); - } -#endif -#ifdef OPT_NEON - if (features.asimd) { - return Choose_SQ8_FP32_IP_implementation_NEON(dim); - } -#endif - -#endif - -#ifdef CPU_FEATURES_ARCH_X86_64 - // Optimizations assume at least 16 floats. If we have less, we use the naive implementation. 
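The selectors in this file all follow the same shape: probe CPU features once, then hand back a function pointer specialized for the widest usable ISA, keeping the naive path for small dimensions. A stripped-down sketch of that pattern (feature set and names hypothetical, not the library's actual API):

    #include <cstddef>

    using dist_fn = float (*)(const void *, const void *, size_t);

    // Portable fallback, always valid.
    static float ip_scalar(const void *a, const void *b, size_t dim) {
        const float *x = static_cast<const float *>(a);
        const float *y = static_cast<const float *>(b);
        float s = 0.0f;
        for (size_t i = 0; i < dim; i++) s += x[i] * y[i];
        return s;
    }

    struct Features { bool avx512, avx2, sse; }; // hypothetical probe result

    // Pick the widest implementation the CPU and the dimension both support;
    // small dims keep the scalar path because SIMD setup cost dominates there.
    static dist_fn select_ip(size_t dim, const Features &f, dist_fn avx512_impl,
                             dist_fn avx2_impl, dist_fn sse_impl) {
        if (dim < 16) return ip_scalar;
        if (f.avx512 && avx512_impl) return avx512_impl;
        if (f.avx2 && avx2_impl) return avx2_impl;
        if (f.sse && sse_impl) return sse_impl;
        return ip_scalar;
    }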
- if (dim < 16) { - return ret_dist_func; - } -#ifdef OPT_AVX512_F_BW_VL_VNNI - if (features.avx512f && features.avx512bw && features.avx512vnni) { - return Choose_SQ8_FP32_IP_implementation_AVX512F_BW_VL_VNNI(dim); - } -#endif -#ifdef OPT_AVX2_FMA - if (features.avx2 && features.fma3) { - return Choose_SQ8_FP32_IP_implementation_AVX2_FMA(dim); - } -#endif -#ifdef OPT_AVX2 - if (features.avx2) { - return Choose_SQ8_FP32_IP_implementation_AVX2(dim); - } -#endif -#ifdef OPT_SSE4 - if (features.sse4_1) { - return Choose_SQ8_FP32_IP_implementation_SSE4(dim); - } -#endif -#endif // __x86_64__ - return ret_dist_func; -} - -// SQ8-FP32: asymmetric cosine distance between SQ8 storage and FP32 query -dist_func_t Cosine_SQ8_FP32_GetDistFunc(size_t dim, unsigned char *alignment, - const void *arch_opt) { - unsigned char dummy_alignment; - if (alignment == nullptr) { - alignment = &dummy_alignment; - } - - dist_func_t ret_dist_func = SQ8_FP32_Cosine; - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); -#ifdef CPU_FEATURES_ARCH_AARCH64 - -#ifdef OPT_SVE2 - if (features.sve2) { - return Choose_SQ8_FP32_Cosine_implementation_SVE2(dim); - } -#endif -#ifdef OPT_SVE - if (features.sve) { - return Choose_SQ8_FP32_Cosine_implementation_SVE(dim); - } -#endif -#ifdef OPT_NEON - if (features.asimd) { - return Choose_SQ8_FP32_Cosine_implementation_NEON(dim); - } -#endif - -#endif - -#ifdef CPU_FEATURES_ARCH_X86_64 - // Optimizations assume at least 16 floats. If we have less, we use the naive implementation. - if (dim < 16) { - return ret_dist_func; - } -#ifdef OPT_AVX512_F_BW_VL_VNNI - if (features.avx512f && features.avx512bw && features.avx512vnni) { - return Choose_SQ8_FP32_Cosine_implementation_AVX512F_BW_VL_VNNI(dim); - } -#endif -#ifdef OPT_AVX2_FMA - if (features.avx2 && features.fma3) { - return Choose_SQ8_FP32_Cosine_implementation_AVX2_FMA(dim); - } -#endif -#ifdef OPT_AVX2 - if (features.avx2) { - return Choose_SQ8_FP32_Cosine_implementation_AVX2(dim); - } -#endif -#ifdef OPT_SSE4 - if (features.sse4_1) { - return Choose_SQ8_FP32_Cosine_implementation_SSE4(dim); - } -#endif -#endif // __x86_64__ - return ret_dist_func; -} - -// SQ8-to-SQ8 Inner Product distance function (both vectors are uint8 quantized with precomputed -// sum) -dist_func_t IP_SQ8_SQ8_GetDistFunc(size_t dim, unsigned char *alignment, - const void *arch_opt) { - unsigned char dummy_alignment; - if (alignment == nullptr) { - alignment = &dummy_alignment; - } - - dist_func_t ret_dist_func = SQ8_SQ8_InnerProduct; - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); - -#ifdef CPU_FEATURES_ARCH_AARCH64 -#ifdef OPT_SVE2 - if (features.sve2) { - return Choose_SQ8_SQ8_IP_implementation_SVE2(dim); - } -#endif -#ifdef OPT_SVE - if (features.sve) { - return Choose_SQ8_SQ8_IP_implementation_SVE(dim); - } -#endif -#ifdef OPT_NEON_DOTPROD - if (features.asimddp && dim >= 16) { - return Choose_SQ8_SQ8_IP_implementation_NEON_DOTPROD(dim); - } -#endif -#ifdef OPT_NEON - if (features.asimd && dim >= 16) { - return Choose_SQ8_SQ8_IP_implementation_NEON(dim); - } -#endif -#endif // AARCH64 - -#ifdef CPU_FEATURES_ARCH_X86_64 -#ifdef OPT_AVX512_F_BW_VL_VNNI - if (dim >= 64 && features.avx512f && features.avx512bw && features.avx512vnni) { - return Choose_SQ8_SQ8_IP_implementation_AVX512F_BW_VL_VNNI(dim); - } -#endif -#endif // __x86_64__ - return ret_dist_func; -} - -// SQ8-to-SQ8 Cosine distance function (both vectors are uint8 quantized with precomputed sum) -dist_func_t Cosine_SQ8_SQ8_GetDistFunc(size_t dim, 
unsigned char *alignment, - const void *arch_opt) { - unsigned char dummy_alignment; - if (alignment == nullptr) { - alignment = &dummy_alignment; - } - - dist_func_t ret_dist_func = SQ8_SQ8_Cosine; - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); - -#ifdef CPU_FEATURES_ARCH_AARCH64 -#ifdef OPT_SVE2 - if (features.sve2) { - return Choose_SQ8_SQ8_Cosine_implementation_SVE2(dim); - } -#endif -#ifdef OPT_SVE - if (features.sve) { - return Choose_SQ8_SQ8_Cosine_implementation_SVE(dim); - } -#endif -#ifdef OPT_NEON_DOTPROD - if (features.asimddp && dim >= 16) { - return Choose_SQ8_SQ8_Cosine_implementation_NEON_DOTPROD(dim); - } -#endif -#ifdef OPT_NEON - if (features.asimd && dim >= 16) { - return Choose_SQ8_SQ8_Cosine_implementation_NEON(dim); - } -#endif -#endif // AARCH64 - -#ifdef CPU_FEATURES_ARCH_X86_64 -#ifdef OPT_AVX512_F_BW_VL_VNNI - if (dim >= 64 && features.avx512f && features.avx512bw && features.avx512vnni) { - return Choose_SQ8_SQ8_Cosine_implementation_AVX512F_BW_VL_VNNI(dim); - } -#endif -#endif // __x86_64__ - return ret_dist_func; -} - -dist_func_t IP_FP32_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt) { - unsigned char dummy_alignment; - if (alignment == nullptr) { - alignment = &dummy_alignment; - } - - dist_func_t ret_dist_func = FP32_InnerProduct; - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); -#ifdef CPU_FEATURES_ARCH_AARCH64 - -#ifdef OPT_SVE2 - if (features.sve2) { - return Choose_FP32_IP_implementation_SVE2(dim); - } -#endif -#ifdef OPT_SVE - if (features.sve) { - return Choose_FP32_IP_implementation_SVE(dim); - } -#endif -#ifdef OPT_NEON - if (features.asimd) { - return Choose_FP32_IP_implementation_NEON(dim); - } -#endif - -#endif - -#ifdef CPU_FEATURES_ARCH_X86_64 - // Optimizations assume at least 16 floats. If we have less, we use the naive implementation. - if (dim < 16) { - return ret_dist_func; - } -#ifdef OPT_AVX512F - if (features.avx512f) { - if (dim % 16 == 0) // no point in aligning if we have an offsetting residual - *alignment = 16 * sizeof(float); // handles 16 floats - return Choose_FP32_IP_implementation_AVX512F(dim); - } -#endif -#ifdef OPT_AVX - if (features.avx) { - if (dim % 8 == 0) // no point in aligning if we have an offsetting residual - *alignment = 8 * sizeof(float); // handles 8 floats - return Choose_FP32_IP_implementation_AVX(dim); - } -#endif -#ifdef OPT_SSE - if (features.sse) { - if (dim % 4 == 0) // no point in aligning if we have an offsetting residual - *alignment = 4 * sizeof(float); // handles 4 floats - return Choose_FP32_IP_implementation_SSE(dim); - } -#endif -#endif // __x86_64__ - return ret_dist_func; -} - -dist_func_t IP_FP64_GetDistFunc(size_t dim, unsigned char *alignment, - const void *arch_opt) { - unsigned char dummy_alignment; - if (alignment == nullptr) { - alignment = &dummy_alignment; - } - - dist_func_t ret_dist_func = FP64_InnerProduct; - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); - -#ifdef CPU_FEATURES_ARCH_AARCH64 -#ifdef OPT_SVE2 - if (features.sve2) { - return Choose_FP64_IP_implementation_SVE2(dim); - } -#endif -#ifdef OPT_SVE - if (features.sve) { - return Choose_FP64_IP_implementation_SVE(dim); - } -#endif -#ifdef OPT_NEON - if (features.asimd) { - return Choose_FP64_IP_implementation_NEON(dim); - } -#endif - -#endif - -#ifdef CPU_FEATURES_ARCH_X86_64 - // Optimizations assume at least 8 doubles. If we have less, we use the naive implementation. 
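A hedged usage sketch of the selector API, assuming the declarations from IP_space.h in this diff (the data setup is purely illustrative):

    #include <cstddef>
    #include <vector>
    #include "VecSim/spaces/IP_space.h" // declares spaces::IP_FP32_GetDistFunc

    void example(size_t dim) {
        unsigned char alignment = 0;
        // nullptr arch_opt means: detect the current CPU's features.
        auto dist = spaces::IP_FP32_GetDistFunc(dim, &alignment, nullptr);

        std::vector<float> a(dim, 0.5f), b(dim, 0.25f);
        // `alignment` is a hint: allocating blobs on this boundary lets the
        // chosen kernel use aligned loads; 0 means no benefit for this dim.
        float ip_distance = dist(a.data(), b.data(), dim); // 1 - inner product
        (void)ip_distance;
    }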
- if (dim < 8) { - return ret_dist_func; - } -#ifdef OPT_AVX512F - if (features.avx512f) { - if (dim % 8 == 0) // no point in aligning if we have an offsetting residual - *alignment = 8 * sizeof(double); // handles 8 doubles - return Choose_FP64_IP_implementation_AVX512F(dim); - } -#endif -#ifdef OPT_AVX - if (features.avx) { - if (dim % 4 == 0) // no point in aligning if we have an offsetting residual - *alignment = 4 * sizeof(double); // handles 4 doubles - return Choose_FP64_IP_implementation_AVX(dim); - } -#endif -#ifdef OPT_SSE - if (features.sse) { - if (dim % 2 == 0) // no point in aligning if we have an offsetting residual - *alignment = 2 * sizeof(double); // handles 2 doubles - return Choose_FP64_IP_implementation_SSE(dim); - } -#endif -#endif // __x86_64__ */ - return ret_dist_func; -} - -dist_func_t IP_BF16_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt) { - unsigned char dummy_alignment; - if (!alignment) { - alignment = &dummy_alignment; - } - - dist_func_t ret_dist_func = BF16_InnerProduct_LittleEndian; - if (!is_little_endian()) { - return BF16_InnerProduct_BigEndian; - } - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); - -#if defined(CPU_FEATURES_ARCH_AARCH64) -#ifdef OPT_SVE_BF16 - if (features.svebf16) { - return Choose_BF16_IP_implementation_SVE_BF16(dim); - } -#endif -#ifdef OPT_NEON_BF16 - if (features.bf16 && dim >= 8) { // Optimization assumes at least 8 BF16s (full chunk) - return Choose_BF16_IP_implementation_NEON_BF16(dim); - } -#endif -#endif // AARCH64 - -#if defined(CPU_FEATURES_ARCH_X86_64) - // Optimizations assume at least 32 bfloats. If we have less, we use the naive implementation. - if (dim < 32) { - return ret_dist_func; - } - -#ifdef OPT_AVX512_BF16_VL - if (features.avx512_bf16 && features.avx512vl) { - if (dim % 32 == 0) // no point in aligning if we have an offsetting residual - *alignment = 32 * sizeof(bfloat16); // align to 512 bits. - return Choose_BF16_IP_implementation_AVX512BF16_VL(dim); - } -#endif -#ifdef OPT_AVX512_BW_VBMI2 - if (features.avx512bw && features.avx512vbmi2) { - if (dim % 32 == 0) // no point in aligning if we have an offsetting residual - *alignment = 32 * sizeof(bfloat16); // align to 512 bits. - return Choose_BF16_IP_implementation_AVX512BW_VBMI2(dim); - } -#endif -#ifdef OPT_AVX2 - if (features.avx2) { - if (dim % 16 == 0) // no point in aligning if we have an offsetting residual - *alignment = 16 * sizeof(bfloat16); // align to 256 bits. - return Choose_BF16_IP_implementation_AVX2(dim); - } -#endif -#ifdef OPT_SSE3 - if (features.sse3) { - if (dim % 8 == 0) // no point in aligning if we have an offsetting residual - *alignment = 8 * sizeof(bfloat16); // align to 128 bits. 
- return Choose_BF16_IP_implementation_SSE3(dim); - } -#endif -#endif // __x86_64__ - return ret_dist_func; -} - -dist_func_t IP_FP16_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt) { - unsigned char dummy_alignment; - if (alignment == nullptr) { - alignment = &dummy_alignment; - } - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); - - dist_func_t ret_dist_func = FP16_InnerProduct; - -#if defined(CPU_FEATURES_ARCH_AARCH64) -#ifdef OPT_SVE2 - if (features.sve2) { - return Choose_FP16_IP_implementation_SVE2(dim); - } -#endif -#ifdef OPT_SVE - if (features.sve) { - return Choose_FP16_IP_implementation_SVE(dim); - } -#endif -#ifdef OPT_NEON_HP - if (features.asimdhp && dim >= 8) { // Optimization assumes at least 8 16FPs (full chunk) - return Choose_FP16_IP_implementation_NEON_HP(dim); - } -#endif -#endif - -#if defined(CPU_FEATURES_ARCH_X86_64) - // Optimizations assume at least 32 16FPs. If we have less, we use the naive implementation. - if (dim < 32) { - return ret_dist_func; - } -#ifdef OPT_AVX512_FP16_VL - // More details about the dimension limitation can be found in this PR's description: - // https://github.com/RedisAI/VectorSimilarity/pull/477 - if (features.avx512_fp16 && features.avx512vl) { - if (dim % 32 == 0) // no point in aligning if we have an offsetting residual - *alignment = 32 * sizeof(float16); // handles 32 floats - return Choose_FP16_IP_implementation_AVX512FP16_VL(dim); - } -#endif -#ifdef OPT_AVX512F - if (features.avx512f) { - if (dim % 32 == 0) // no point in aligning if we have an offsetting residual - *alignment = 32 * sizeof(float16); // handles 32 floats - return Choose_FP16_IP_implementation_AVX512F(dim); - } -#endif -#ifdef OPT_F16C - if (features.f16c && features.fma3 && features.avx) { - if (dim % 16 == 0) // no point in aligning if we have an offsetting residual - *alignment = 16 * sizeof(float16); // handles 16 floats - return Choose_FP16_IP_implementation_F16C(dim); - } -#endif -#endif // __x86_64__ - return ret_dist_func; -} - -dist_func_t IP_INT8_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt) { - unsigned char dummy_alignment; - if (alignment == nullptr) { - alignment = &dummy_alignment; - } - - dist_func_t ret_dist_func = INT8_InnerProduct; - - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); - -#ifdef CPU_FEATURES_ARCH_AARCH64 -#ifdef OPT_SVE2 - if (features.sve2) { - return Choose_INT8_IP_implementation_SVE2(dim); - } -#endif -#ifdef OPT_SVE - if (features.sve) { - return Choose_INT8_IP_implementation_SVE(dim); - } -#endif -#ifdef OPT_NEON_DOTPROD // Should be the first check, as it is the most optimized - if (features.asimddp && dim >= 16) { - return Choose_INT8_IP_implementation_NEON_DOTPROD(dim); - } -#endif -#ifdef OPT_NEON - if (features.asimd && dim >= 16) { - return Choose_INT8_IP_implementation_NEON(dim); - } -#endif -#endif -#ifdef CPU_FEATURES_ARCH_X86_64 - // Optimizations assume at least 32 int8. If we have less, we use the naive implementation. - if (dim < 32) { - return ret_dist_func; - } - -#ifdef OPT_AVX512_F_BW_VL_VNNI - if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) { - if (dim % 32 == 0) // no point in aligning if we have an offsetting residual - *alignment = 32 * sizeof(int8_t); // align to 256 bits. 
- return Choose_INT8_IP_implementation_AVX512F_BW_VL_VNNI(dim); - } -#endif -#endif // __x86_64__ - return ret_dist_func; -} - -dist_func_t Cosine_INT8_GetDistFunc(size_t dim, unsigned char *alignment, - const void *arch_opt) { - unsigned char dummy_alignment; - if (alignment == nullptr) { - alignment = &dummy_alignment; - } - - dist_func_t ret_dist_func = INT8_Cosine; - - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); - -#ifdef CPU_FEATURES_ARCH_AARCH64 -#ifdef OPT_SVE2 - if (features.sve2) { - return Choose_INT8_Cosine_implementation_SVE2(dim); - } -#endif -#ifdef OPT_SVE - if (features.sve) { - return Choose_INT8_Cosine_implementation_SVE(dim); - } -#endif -#ifdef OPT_NEON_DOTPROD - if (features.asimddp && dim >= 16) { - return Choose_INT8_Cosine_implementation_NEON_DOTPROD(dim); - } -#endif -#ifdef OPT_NEON - if (features.asimd && dim >= 16) { - return Choose_INT8_Cosine_implementation_NEON(dim); - } -#endif -#endif -#ifdef CPU_FEATURES_ARCH_X86_64 - if (dim < 32) { - return ret_dist_func; - } -#ifdef OPT_AVX512_F_BW_VL_VNNI - if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) { - // For int8 vectors with cosine distance, the extra float for the norm shifts alignment to - // `(dim + sizeof(float)) % 32`. - // Vectors satisfying this have a residual, causing offset loads during calculation. - // To avoid complexity, we skip alignment here, assuming the performance impact is - // negligible. - return Choose_INT8_Cosine_implementation_AVX512F_BW_VL_VNNI(dim); - } -#endif - -#endif // __x86_64__ - return ret_dist_func; -} - -dist_func_t IP_UINT8_GetDistFunc(size_t dim, unsigned char *alignment, - const void *arch_opt) { - unsigned char dummy_alignment; - if (alignment == nullptr) { - alignment = &dummy_alignment; - } - - dist_func_t ret_dist_func = UINT8_InnerProduct; - - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); - -#ifdef CPU_FEATURES_ARCH_AARCH64 -#ifdef OPT_SVE2 - if (features.sve2) { - return Choose_UINT8_IP_implementation_SVE2(dim); - } -#endif -#ifdef OPT_SVE - if (features.sve) { - return Choose_UINT8_IP_implementation_SVE(dim); - } -#endif -#ifdef OPT_NEON_DOTPROD - if (features.asimddp && dim >= 16) { - return Choose_UINT8_IP_implementation_NEON_DOTPROD(dim); - } -#endif -#ifdef OPT_NEON - if (features.asimd && dim >= 16) { - return Choose_UINT8_IP_implementation_NEON(dim); - } -#endif -#endif -#ifdef CPU_FEATURES_ARCH_X86_64 - if (dim < 32) { - return ret_dist_func; - } -#ifdef OPT_AVX512_F_BW_VL_VNNI - if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) { - if (dim % 32 == 0) // no point in aligning if we have an offsetting residual - *alignment = 32 * sizeof(uint8_t); // align to 256 bits. 
- return Choose_UINT8_IP_implementation_AVX512F_BW_VL_VNNI(dim); - } -#endif -#endif // __x86_64__ - return ret_dist_func; -} - -dist_func_t Cosine_UINT8_GetDistFunc(size_t dim, unsigned char *alignment, - const void *arch_opt) { - unsigned char dummy_alignment; - if (alignment == nullptr) { - alignment = &dummy_alignment; - } - - dist_func_t ret_dist_func = UINT8_Cosine; - - [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt); - -#ifdef CPU_FEATURES_ARCH_AARCH64 -#ifdef OPT_SVE2 - if (features.sve2) { - return Choose_UINT8_Cosine_implementation_SVE2(dim); - } -#endif -#ifdef OPT_SVE - if (features.sve) { - return Choose_UINT8_Cosine_implementation_SVE(dim); - } -#endif -#ifdef OPT_NEON_DOTPROD - if (features.asimddp && dim >= 16) { - return Choose_UINT8_Cosine_implementation_NEON_DOTPROD(dim); - } -#endif -#ifdef OPT_NEON - if (features.asimd && dim >= 16) { - return Choose_UINT8_Cosine_implementation_NEON(dim); - } -#endif -#endif -#ifdef CPU_FEATURES_ARCH_X86_64 - if (dim < 32) { - return ret_dist_func; - } -#ifdef OPT_AVX512_F_BW_VL_VNNI - if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) { - // For uint8 vectors with cosine distance, the extra float for the norm shifts alignment to - // `(dim + sizeof(float)) % 32`. - // Vectors satisfying this have a residual, causing offset loads during calculation. - // To avoid complexity, we skip alignment here, assuming the performance impact is - // negligible. - return Choose_UINT8_Cosine_implementation_AVX512F_BW_VL_VNNI(dim); - } -#endif -#endif // __x86_64__ - return ret_dist_func; -} - -} // namespace spaces diff --git a/src/VecSim/spaces/IP_space.h b/src/VecSim/spaces/IP_space.h deleted file mode 100644 index b258ff481..000000000 --- a/src/VecSim/spaces/IP_space.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
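The no-alignment decision for the integer cosine selectors above comes down to simple stride arithmetic; a small illustrative check:

    #include <cstddef>

    // For uint8 cosine blobs the stored element is [dim bytes][float norm], so
    // the allocation stride is dim + sizeof(float). Even when dim % 32 == 0,
    // the stride is not a multiple of 32 (e.g. dim = 64 gives 68), so
    // consecutive blobs cannot all sit on a 32-byte boundary; hence no
    // alignment hint is published for those paths.
    static bool blob_stride_is_aligned(size_t dim) {
        const size_t stride = dim + sizeof(float);
        return stride % 32 == 0;
    }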
- */ -#pragma once -#include "VecSim/spaces/spaces.h" - -namespace spaces { -// SQ8-FP32: asymmetric distance between FP32 query and SQ8 storage -dist_func_t IP_SQ8_FP32_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); - -dist_func_t IP_FP32_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); -dist_func_t IP_FP64_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); -dist_func_t IP_BF16_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); -dist_func_t IP_FP16_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); -dist_func_t IP_INT8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); -dist_func_t Cosine_INT8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); -dist_func_t IP_UINT8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); -dist_func_t Cosine_UINT8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); -// SQ8-FP32: asymmetric cosine distance between FP32 query and SQ8 storage -dist_func_t Cosine_SQ8_FP32_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); -// SQ8-to-SQ8 distance functions (both vectors are uint8 quantized with precomputed sum) -dist_func_t IP_SQ8_SQ8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); -dist_func_t Cosine_SQ8_SQ8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr, - const void *arch_opt = nullptr); -} // namespace spaces diff --git a/src/VecSim/spaces/L2/L2.cpp b/src/VecSim/spaces/L2/L2.cpp deleted file mode 100644 index 7761df920..000000000 --- a/src/VecSim/spaces/L2/L2.cpp +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "L2.h" -#include "VecSim/spaces/IP/IP.h" -#include "VecSim/types/bfloat16.h" -#include "VecSim/types/float16.h" -#include "VecSim/types/sq8.h" -#include -#include - -using bfloat16 = vecsim_types::bfloat16; -using float16 = vecsim_types::float16; -using sq8 = vecsim_types::sq8; - -/* - * Optimized asymmetric SQ8-FP32 L2 squared distance using algebraic identity: - * ||x - y||² = Σx_i² - 2*IP(x, y) + Σy_i² - * = x_sum_squares - 2 * IP(x, y) + y_sum_squares - * where IP(x, y) = min * y_sum + delta * Σ(q_i * y_i) - * - * pVect1 is storage (SQ8): [uint8_t values (dim)] [min_val] [delta] [x_sum] [x_sum_squares] - * pVect2 is query (FP32): [float values (dim)] [y_sum] [y_sum_squares] - */ -float SQ8_FP32_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension) { - // Get the raw inner product using the common implementation - const float ip = SQ8_FP32_InnerProduct_Impl(pVect1v, pVect2v, dimension); - - // Get precomputed sum of squares from storage blob (pVect1 is SQ8) - const auto *pVect1 = static_cast(pVect1v); - const float *params = reinterpret_cast(pVect1 + dimension); - const float x_sum_sq = params[sq8::SUM_SQUARES]; - - // Get precomputed sum of squares from query blob (pVect2 is FP32) - const auto *pVect2 = static_cast(pVect2v); - const float y_sum_sq = pVect2[dimension + sq8::SUM_SQUARES_QUERY]; - - // L2² = ||x||² + ||y||² - 2*IP(x, y) - return x_sum_sq + y_sum_sq - 2.0f * ip; -} - -float FP32_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension) { - float *vec1 = (float *)pVect1v; - float *vec2 = (float *)pVect2v; - - float res = 0; - for (size_t i = 0; i < dimension; i++) { - float t = vec1[i] - vec2[i]; - res += t * t; - } - return res; -} - -double FP64_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension) { - double *vec1 = (double *)pVect1v; - double *vec2 = (double *)pVect2v; - - double res = 0; - for (size_t i = 0; i < dimension; i++) { - double t = vec1[i] - vec2[i]; - res += t * t; - } - return res; -} - -template -float BF16_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension) { - bfloat16 *pVect1 = (bfloat16 *)pVect1v; - bfloat16 *pVect2 = (bfloat16 *)pVect2v; - - float res = 0; - for (size_t i = 0; i < dimension; i++) { - float a = vecsim_types::bfloat16_to_float32(pVect1[i]); - float b = vecsim_types::bfloat16_to_float32(pVect2[i]); - float diff = a - b; - res += diff * diff; - } - return res; -} - -float BF16_L2Sqr_LittleEndian(const void *pVect1v, const void *pVect2v, size_t dimension) { - return BF16_L2Sqr(pVect1v, pVect2v, dimension); -} - -float BF16_L2Sqr_BigEndian(const void *pVect1v, const void *pVect2v, size_t dimension) { - return BF16_L2Sqr(pVect1v, pVect2v, dimension); -} - -float FP16_L2Sqr(const void *pVect1, const void *pVect2, size_t dimension) { - auto *vec1 = (float16 *)pVect1; - auto *vec2 = (float16 *)pVect2; - - float res = 0; - for (size_t i = 0; i < dimension; i++) { - float t = vecsim_types::FP16_to_FP32(vec1[i]) - vecsim_types::FP16_to_FP32(vec2[i]); - res += t * t; - } - return res; -} - -// Return type for the L2 functions. -// The type should be able to hold `dimension * MAX_VAL(int_elem_t) * MAX_VAL(int_elem_t)`. -// To support dimension up to 2^16, we need the difference between the type and int_elem_t to be at -// least 2 bytes. We assert that in the implementation. -template -using ret_t = std::conditional_t; - -// Difference type for the L2 functions. 
-// The type should be able to hold `MIN_VAL(int_elem_t)-MAX_VAL(int_elem_t)`, and should be signed -// to avoid unsigned arithmetic. This means that the difference type should be bigger than the -// size of the element type. We assert that in the implementation. -template -using diff_t = std::conditional_t; - -template -static inline ret_t INTEGER_L2Sqr(const int_elem_t *pVect1, const int_elem_t *pVect2, - size_t dimension) { - static_assert(sizeof(ret_t) - sizeof(int_elem_t) * 2 >= sizeof(uint16_t)); - static_assert(std::is_signed_v>); - static_assert(sizeof(diff_t) >= 2 * sizeof(int_elem_t)); - - ret_t res = 0; - for (size_t i = 0; i < dimension; i++) { - diff_t diff = pVect1[i] - pVect2[i]; - res += diff * diff; - } - return res; -} - -float INT8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension) { - const auto *pVect1 = static_cast(pVect1v); - const auto *pVect2 = static_cast(pVect2v); - return float(INTEGER_L2Sqr(pVect1, pVect2, dimension)); -} - -float UINT8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension) { - const auto *pVect1 = static_cast(pVect1v); - const auto *pVect2 = static_cast(pVect2v); - return float(INTEGER_L2Sqr(pVect1, pVect2, dimension)); -} - -// SQ8-to-SQ8 L2 squared distance (both vectors are uint8 quantized) -// Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] -// [sum_of_squares (float)] -// ||x - y||² = ||x||² + ||y||² - 2*IP(x, y) -// where: -// - ||x||² = sum_squares_x is precomputed and stored -// - ||y||² = sum_squares_y is precomputed and stored -// - IP(x, y) is computed using SQ8_SQ8_InnerProduct_Impl - -float SQ8_SQ8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension) { - const auto *pVect1 = static_cast(pVect1v); - const auto *pVect2 = static_cast(pVect2v); - - // Get precomputed sum of squares from both vectors - // Layout: [uint8_t values (dim)] [min_val] [delta] [sum] [sum_of_squares] - const float sum_sq_1 = - *reinterpret_cast(pVect1 + dimension + sq8::SUM_SQUARES * sizeof(float)); - const float sum_sq_2 = - *reinterpret_cast(pVect2 + dimension + sq8::SUM_SQUARES * sizeof(float)); - - // Use the common inner product implementation - const float ip = SQ8_SQ8_InnerProduct_Impl(pVect1v, pVect2v, dimension); - - // L2² = ||x||² + ||y||² - 2*IP(x, y) - return sum_sq_1 + sum_sq_2 - 2.0f * ip; -} diff --git a/src/VecSim/spaces/L2/L2.h b/src/VecSim/spaces/L2/L2.h deleted file mode 100644 index d055760f9..000000000 --- a/src/VecSim/spaces/L2/L2.h +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
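Several kernels in this diff lean on the same expansion, so a scalar sanity check of ||x - y||² = ||x||² + ||y||² - 2·IP(x, y) is worth spelling out (illustrative helpers, float inputs; the two sides agree up to floating-point rounding):

    #include <cstddef>

    // Direct form: accumulate squared differences.
    static float l2_direct(const float *x, const float *y, size_t dim) {
        float s = 0.0f;
        for (size_t i = 0; i < dim; i++) { float d = x[i] - y[i]; s += d * d; }
        return s;
    }

    // Identity form: this is what lets the SQ8 kernels reuse the inner-product
    // kernel plus two precomputed sums of squares instead of a subtract loop.
    static float l2_via_identity(const float *x, const float *y, size_t dim) {
        float xx = 0.0f, yy = 0.0f, xy = 0.0f;
        for (size_t i = 0; i < dim; i++) {
            xx += x[i] * x[i];
            yy += y[i] * y[i];
            xy += x[i] * y[i];
        }
        return xx + yy - 2.0f * xy;
    }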
- */ -#pragma once - -#include - -// SQ8-FP32: pVect1v vector of type uint8 (SQ8) and pVect2v vector of type fp32 -float SQ8_FP32_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension); - -float FP32_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension); - -double FP64_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension); - -float BF16_L2Sqr_LittleEndian(const void *pVect1v, const void *pVect2v, size_t dimension); -float BF16_L2Sqr_BigEndian(const void *pVect1v, const void *pVect2v, size_t dimension); - -float FP16_L2Sqr(const void *pVect1, const void *pVect2, size_t dimension); - -float INT8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension); - -float UINT8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension); - -// SQ8-to-SQ8 L2 squared distance (both vectors are uint8 quantized) -float SQ8_SQ8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension); diff --git a/src/VecSim/spaces/L2/L2_AVX2_BF16.h b/src/VecSim/spaces/L2/L2_AVX2_BF16.h deleted file mode 100644 index fa84eb389..000000000 --- a/src/VecSim/spaces/L2/L2_AVX2_BF16.h +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/bfloat16.h" - -using bfloat16 = vecsim_types::bfloat16; - -static inline void L2SqrLowHalfStep(__m256i v1, __m256i v2, __m256i zeros, __m256 &sum) { - // Convert next 0:3, 8:11 bf16 to 8 floats - __m256i bf16_low1 = _mm256_unpacklo_epi16(zeros, v1); // AVX2 - __m256i bf16_low2 = _mm256_unpacklo_epi16(zeros, v2); - - __m256 diff = _mm256_sub_ps(_mm256_castsi256_ps(bf16_low1), _mm256_castsi256_ps(bf16_low2)); - sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); -} - -static inline void L2SqrHighHalfStep(__m256i v1, __m256i v2, __m256i zeros, __m256 &sum) { - // Convert next 4:7, 12:15 bf16 to 8 floats - __m256i bf16_high1 = _mm256_unpackhi_epi16(zeros, v1); - __m256i bf16_high2 = _mm256_unpackhi_epi16(zeros, v2); - - __m256 diff = _mm256_sub_ps(_mm256_castsi256_ps(bf16_high1), _mm256_castsi256_ps(bf16_high2)); - sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); -} - -static inline void L2SqrStep(bfloat16 *&pVect1, bfloat16 *&pVect2, __m256 &sum) { - // Load 16 bf16 elements - __m256i v1 = _mm256_lddqu_si256((__m256i *)pVect1); // avx - pVect1 += 16; - __m256i v2 = _mm256_lddqu_si256((__m256i *)pVect2); - pVect2 += 16; - - __m256i zeros = _mm256_setzero_si256(); // avx - - // Compute dist for 0:3, 8:11 bf16 - L2SqrLowHalfStep(v1, v2, zeros, sum); - - // Compute dist for 4:7, 12:15 bf16 - L2SqrHighHalfStep(v1, v2, zeros, sum); -} - -template // 0..31 -float BF16_L2SqrSIMD32_AVX2(const void *pVect1v, const void *pVect2v, size_t dimension) { - bfloat16 *pVect1 = (bfloat16 *)pVect1v; - bfloat16 *pVect2 = (bfloat16 *)pVect2v; - - const bfloat16 *pEnd1 = pVect1 + dimension; - - __m256 sum = _mm256_setzero_ps(); - - // Handle first (residual % 16) elements - if constexpr (residual % 16) { - // Load all 16 elements to a 256 bit register - __m256i v1 = _mm256_lddqu_si256((__m256i *)pVect1); // avx - pVect1 += residual % 16; - __m256i v2 = _mm256_lddqu_si256((__m256i *)pVect2); - pVect2 += residual % 16; - - // Unpack 0:3, 8:11 bf16 to 8 floats - __m256i zeros = _mm256_setzero_si256(); - __m256i v1_low = 
_mm256_unpacklo_epi16(zeros, v1); - __m256i v2_low = _mm256_unpacklo_epi16(zeros, v2); - - __m256 low_diff = _mm256_sub_ps(_mm256_castsi256_ps(v1_low), _mm256_castsi256_ps(v2_low)); - if constexpr (residual % 16 <= 4) { - constexpr unsigned char elem_to_calc = residual % 16; - constexpr __mmask8 mask = (1 << elem_to_calc) - 1; - low_diff = _mm256_blend_ps(_mm256_setzero_ps(), low_diff, mask); - } else { - __m256i v1_high = _mm256_unpackhi_epi16(zeros, v1); - __m256i v2_high = _mm256_unpackhi_epi16(zeros, v2); - __m256 high_diff = - _mm256_sub_ps(_mm256_castsi256_ps(v1_high), _mm256_castsi256_ps(v2_high)); - - if constexpr (4 < residual % 16 && residual % 16 <= 8) { - // Keep only 4 first elements of low pack - constexpr __mmask8 mask2 = (1 << 4) - 1; - low_diff = _mm256_blend_ps(_mm256_setzero_ps(), low_diff, mask2); - - // Keep (residual % 16 - 4) first elements of high_diff - constexpr unsigned char elem_to_calc = residual % 16 - 4; - constexpr __mmask8 mask3 = (1 << elem_to_calc) - 1; - high_diff = _mm256_blend_ps(_mm256_setzero_ps(), high_diff, mask3); - } else if constexpr (8 < residual % 16 && residual % 16 < 12) { - // Keep (residual % 16 - 4) first elements of low_diff - constexpr unsigned char elem_to_calc = residual % 16 - 4; - constexpr __mmask8 mask2 = (1 << elem_to_calc) - 1; - low_diff = _mm256_blend_ps(_mm256_setzero_ps(), low_diff, mask2); - - // Keep ony 4 first elements of high_diff - constexpr __mmask8 mask3 = (1 << 4) - 1; - high_diff = _mm256_blend_ps(_mm256_setzero_ps(), high_diff, mask3); - } else if constexpr (residual % 16 >= 12) { - // Keep (residual % 16 - 8) first elements of high - constexpr unsigned char elem_to_calc = residual % 16 - 8; - constexpr __mmask8 mask2 = (1 << elem_to_calc) - 1; - high_diff = _mm256_blend_ps(_mm256_setzero_ps(), high_diff, mask2); - } - sum = _mm256_add_ps(sum, _mm256_mul_ps(high_diff, high_diff)); - } - sum = _mm256_add_ps(sum, _mm256_mul_ps(low_diff, low_diff)); - } - - // Do a single step if residual >=16 - if constexpr (residual >= 16) { - L2SqrStep(pVect1, pVect2, sum); - } - - // Handle 512 bits (32 bfloat16) in chunks of max SIMD = 256 bits = 16 bfloat16 - do { - L2SqrStep(pVect1, pVect2, sum); - L2SqrStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - return my_mm256_reduce_add_ps(sum); -} diff --git a/src/VecSim/spaces/L2/L2_AVX2_FMA_SQ8_FP32.h b/src/VecSim/spaces/L2/L2_AVX2_FMA_SQ8_FP32.h deleted file mode 100644 index 46eb4cc6e..000000000 --- a/src/VecSim/spaces/L2/L2_AVX2_FMA_SQ8_FP32.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/AVX_utils.h" -#include "VecSim/spaces/IP/IP_AVX2_FMA_SQ8_FP32.h" -#include "VecSim/types/sq8.h" - -using sq8 = vecsim_types::sq8; - -/* - * Optimized asymmetric SQ8 L2 squared distance using algebraic identity: - * - * ||x - y||² = Σx_i² - 2*IP(x, y) + Σy_i² - * = x_sum_squares - 2 * IP(x, y) + y_sum_squares - * - * where: - * - IP(x, y) = min * y_sum + delta * Σ(q_i * y_i) (computed via SQ8_FP32_InnerProductImp_FMA) - * - x_sum_squares and y_sum_squares are precomputed - * - * This avoids dequantization in the hot loop. 
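For context on where the precomputed tail values come from, a hedged sketch of quantization-time encoding under the documented blob layout (encoder name and rounding policy are assumptions, not the library's actual implementation):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Hypothetical encoder producing the blob layout used by these kernels:
    // [uint8 q values (dim)] [min] [delta] [sum] [sum_of_squares]
    // sum and sum_of_squares are taken over the *original* floats, which is
    // what the algebraic IP/L2 identities above require.
    static std::vector<uint8_t> encode_sq8(const float *v, size_t dim) {
        float mn = *std::min_element(v, v + dim);
        float mx = *std::max_element(v, v + dim);
        float delta = (mx - mn) / 255.0f;
        if (delta == 0.0f) delta = 1.0f; // avoid div-by-zero on constant vectors

        std::vector<uint8_t> blob(dim + 4 * sizeof(float));
        float sum = 0.0f, sum_sq = 0.0f;
        for (size_t i = 0; i < dim; i++) {
            blob[i] = static_cast<uint8_t>(std::lround((v[i] - mn) / delta));
            sum += v[i];
            sum_sq += v[i] * v[i];
        }
        float tail[4] = {mn, delta, sum, sum_sq};
        std::memcpy(blob.data() + dim, tail, sizeof(tail));
        return blob;
    }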
- */ - -// pVect1v = SQ8 storage, pVect2v = FP32 query -template // 0..15 -float SQ8_FP32_L2SqrSIMD16_AVX2_FMA(const void *pVect1v, const void *pVect2v, size_t dimension) { - // Get the raw inner product using the common SIMD implementation - const float ip = SQ8_FP32_InnerProductImp_FMA(pVect1v, pVect2v, dimension); - - // Get precomputed sum of squares from storage blob (pVect1v is SQ8 storage) - const uint8_t *pVect1 = static_cast(pVect1v); - const float *params = reinterpret_cast(pVect1 + dimension); - const float x_sum_sq = params[sq8::SUM_SQUARES]; - - // Get precomputed sum of squares from query blob (pVect2v is FP32 query) - const float y_sum_sq = static_cast(pVect2v)[dimension + sq8::SUM_SQUARES_QUERY]; - - // L2² = ||x||² + ||y||² - 2*IP(x, y) - return x_sum_sq + y_sum_sq - 2.0f * ip; -} diff --git a/src/VecSim/spaces/L2/L2_AVX2_SQ8_FP32.h b/src/VecSim/spaces/L2/L2_AVX2_SQ8_FP32.h deleted file mode 100644 index cc1fa4272..000000000 --- a/src/VecSim/spaces/L2/L2_AVX2_SQ8_FP32.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/AVX_utils.h" -#include "VecSim/spaces/IP/IP_AVX2_SQ8_FP32.h" -#include "VecSim/types/sq8.h" - -using sq8 = vecsim_types::sq8; - -/* - * Optimized asymmetric SQ8 L2 squared distance using algebraic identity: - * - * ||x - y||² = Σx_i² - 2*IP(x, y) + Σy_i² - * = x_sum_squares - 2 * IP(x, y) + y_sum_squares - * - * where: - * - IP(x, y) = min * y_sum + delta * Σ(q_i * y_i) (computed via SQ8_FP32_InnerProductImp_AVX2) - * - x_sum_squares and y_sum_squares are precomputed - * - * This avoids dequantization in the hot loop. - */ - -// pVect1v = SQ8 storage, pVect2v = FP32 query -template // 0..15 -float SQ8_FP32_L2SqrSIMD16_AVX2(const void *pVect1v, const void *pVect2v, size_t dimension) { - // Get the raw inner product using the common SIMD implementation - const float ip = SQ8_FP32_InnerProductImp_AVX2(pVect1v, pVect2v, dimension); - - // Get precomputed sum of squares from storage blob (pVect1v is SQ8 storage) - const uint8_t *pVect1 = static_cast(pVect1v); - const float *params = reinterpret_cast(pVect1 + dimension); - const float x_sum_sq = params[sq8::SUM_SQUARES]; - - // Get precomputed sum of squares from query blob (pVect2v is FP32 query) - const float y_sum_sq = static_cast(pVect2v)[dimension + sq8::SUM_SQUARES_QUERY]; - - // L2² = ||x||² + ||y||² - 2*IP(x, y) - return x_sum_sq + y_sum_sq - 2.0f * ip; -} diff --git a/src/VecSim/spaces/L2/L2_AVX512BW_VBMI2_BF16.h b/src/VecSim/spaces/L2/L2_AVX512BW_VBMI2_BF16.h deleted file mode 100644 index 6d7bf01e7..000000000 --- a/src/VecSim/spaces/L2/L2_AVX512BW_VBMI2_BF16.h +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
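The unpack-with-zeros idiom in the BF16 kernels works because bfloat16 is exactly the upper 16 bits of an IEEE float32. The same conversion, scalar and explicit (illustrative):

    #include <cstdint>
    #include <cstring>

    // Placing a bf16 bit pattern in the upper half of a 32-bit word and zeroing
    // the lower half yields the corresponding float32; this is what
    // _mm256_unpacklo_epi16(zeros, v) / _mm512_unpacklo_epi16(zeros, v) do for
    // many lanes at once.
    static float bf16_to_f32(uint16_t bf16_bits) {
        uint32_t bits = static_cast<uint32_t>(bf16_bits) << 16;
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }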
- */ -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/bfloat16.h" - -using bfloat16 = vecsim_types::bfloat16; - -static inline void L2SqrHalfStep(bfloat16 *&pVect1, bfloat16 *&pVect2, __m512 &sum, - __mmask32 mask) { - __m512i v1 = _mm512_maskz_expandloadu_epi16(mask, pVect1); // AVX512_VBMI2 - __m512i v2 = _mm512_maskz_expandloadu_epi16(mask, pVect2); // AVX512_VBMI2 - __m512 diff = _mm512_sub_ps(_mm512_castsi512_ps(v1), _mm512_castsi512_ps(v2)); - sum = _mm512_fmadd_ps(diff, diff, sum); -} - -static inline void L2SqrStep(bfloat16 *&pVect1, bfloat16 *&pVect2, __m512 &sum) { - __m512i v1 = _mm512_loadu_si512((__m512i *)pVect1); - __m512i v2 = _mm512_loadu_si512((__m512i *)pVect2); - pVect1 += 32; - pVect2 += 32; - __m512i zeros = _mm512_setzero_si512(); - - // Convert 0:3, 8:11, .. 28:31 to float32 - __m512i v1_low = _mm512_unpacklo_epi16(zeros, v1); // AVX512BW - __m512i v2_low = _mm512_unpacklo_epi16(zeros, v2); - __m512 diff = _mm512_sub_ps(_mm512_castsi512_ps(v1_low), _mm512_castsi512_ps(v2_low)); - sum = _mm512_fmadd_ps(diff, diff, sum); - - // Convert 4:7, 12:15, .. 24:27 to float32 - __m512i v1_high = _mm512_unpackhi_epi16(zeros, v1); - __m512i v2_high = _mm512_unpackhi_epi16(zeros, v2); - diff = _mm512_sub_ps(_mm512_castsi512_ps(v1_high), _mm512_castsi512_ps(v2_high)); - sum = _mm512_fmadd_ps(diff, diff, sum); -} - -template // 0..31 -float BF16_L2SqrSIMD32_AVX512BW_VBMI2(const void *pVect1v, const void *pVect2v, size_t dimension) { - bfloat16 *pVect1 = (bfloat16 *)pVect1v; - bfloat16 *pVect2 = (bfloat16 *)pVect2v; - - const bfloat16 *pEnd1 = pVect1 + dimension; - - __m512 sum = _mm512_setzero_ps(); - - // Handle first residual % 32 elements - if constexpr (residual) { - constexpr __mmask32 mask = 0xAAAAAAAA; - - // Calculate first 16 - if constexpr (residual >= 16) { - L2SqrHalfStep(pVect1, pVect2, sum, mask); - pVect1 += 16; - pVect2 += 16; - } - if constexpr (residual != 16) { - // Each element is represented by a pair of 01 bits - // Create a mask for the elements we want to process: - // mask2 = {01 * (residual % 16)}0000... - constexpr __mmask32 mask2 = mask & ((1 << ((residual % 16) * 2)) - 1); - L2SqrHalfStep(pVect1, pVect2, sum, mask2); - pVect1 += residual % 16; - pVect2 += residual % 16; - } - } - - // Handle 512 bits (32 bfloat16) in chunks of max SIMD = 512 bits = 32 bfloat16 - do { - L2SqrStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - return _mm512_reduce_add_ps(sum); -} diff --git a/src/VecSim/spaces/L2/L2_AVX512FP16_VL_FP16.h b/src/VecSim/spaces/L2/L2_AVX512FP16_VL_FP16.h deleted file mode 100644 index 27e909a30..000000000 --- a/src/VecSim/spaces/L2/L2_AVX512FP16_VL_FP16.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
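The 0xAAAAAAAA mask in L2SqrHalfStep above achieves the same widening in a single load: expand-load reads one contiguous bf16 element per set mask bit and places it at that lane, so with every odd 16-bit lane set, each element lands in the high half of a 32-bit lane and the register holds float32 bit patterns directly. A compile-time note on the mask shape (illustrative):

    #include <cstdint>

    // Every odd 16-bit lane set: bit pattern 10 repeated across 32 mask bits.
    // Cleared (even) lanes are zero-filled, forming the low halves of the floats.
    constexpr uint32_t kBf16HighHalves = 0xAAAAAAAAu;
    static_assert((kBf16HighHalves & 0x3u) == 0x2u, "element 0 lands in the high half");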
- */ -#include -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/float16.h" -#include - -using float16 = vecsim_types::float16; - -static inline void L2SqrStep(float16 *&pVect1, float16 *&pVect2, __m512h &sum) { - __m512h v1 = _mm512_loadu_ph(pVect1); - __m512h v2 = _mm512_loadu_ph(pVect2); - - __m512h diff = _mm512_sub_ph(v1, v2); - - sum = _mm512_fmadd_ph(diff, diff, sum); - pVect1 += 32; - pVect2 += 32; -} - -template // 0..31 -float FP16_L2SqrSIMD32_AVX512FP16_VL(const void *pVect1v, const void *pVect2v, size_t dimension) { - auto *pVect1 = (float16 *)pVect1v; - auto *pVect2 = (float16 *)pVect2v; - - const float16 *pEnd1 = pVect1 + dimension; - - __m512h sum = _mm512_setzero_ph(); - - if constexpr (residual) { - constexpr __mmask32 mask = (1LU << residual) - 1; - __m512h v1 = _mm512_loadu_ph(pVect1); - pVect1 += residual; - __m512h v2 = _mm512_loadu_ph(pVect2); - pVect2 += residual; - __m512h diff = _mm512_maskz_sub_ph(mask, v1, v2); - - sum = _mm512_mul_ph(diff, diff); - } - - // We dealt with the residual part. We are left with some multiple of 32 16-bit floats. - do { - L2SqrStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - _Float16 res = _mm512_reduce_add_ph(sum); - return res; -} diff --git a/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_INT8.h b/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_INT8.h deleted file mode 100644 index 5dd765b2a..000000000 --- a/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_INT8.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" - -static inline void L2SqrStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) { - __m256i temp_a = _mm256_loadu_epi8(pVect1); - __m512i va = _mm512_cvtepi8_epi16(temp_a); - pVect1 += 32; - - __m256i temp_b = _mm256_loadu_epi8(pVect2); - __m512i vb = _mm512_cvtepi8_epi16(temp_b); - pVect2 += 32; - - __m512i diff = _mm512_sub_epi16(va, vb); - // _mm512_dpwssd_epi32(src, a, b) - // Multiply groups of 2 adjacent pairs of signed 16-bit integers in `a` with corresponding - // 16-bit integers in `b`, producing 2 intermediate signed 32-bit results. Sum these 2 results - // with the corresponding 32-bit integer in src, and store the packed 32-bit results in dst. - sum = _mm512_dpwssd_epi32(sum, diff, diff); -} - -template // 0..63 -float INT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v, - size_t dimension) { - int8_t *pVect1 = (int8_t *)pVect1v; - int8_t *pVect2 = (int8_t *)pVect2v; - - const int8_t *pEnd1 = pVect1 + dimension; - - __m512i sum = _mm512_setzero_epi32(); - - // Deal with remainder first. `dim` is more than 32, so we have at least one 32-int_8 block, - // so mask loading is guaranteed to be safe - if constexpr (residual % 32) { - constexpr __mmask32 mask = (1LU << (residual % 32)) - 1; - __m256i temp_a = _mm256_loadu_epi8(pVect1); - __m512i va = _mm512_cvtepi8_epi16(temp_a); - pVect1 += residual % 32; - - __m256i temp_b = _mm256_loadu_epi8(pVect2); - __m512i vb = _mm512_cvtepi8_epi16(temp_b); - pVect2 += residual % 32; - - __m512i diff = _mm512_maskz_sub_epi16(mask, va, vb); - sum = _mm512_dpwssd_epi32(sum, diff, diff); - } - - if constexpr (residual >= 32) { - L2SqrStep(pVect1, pVect2, sum); - } - - // We dealt with the residual part. We are left with some multiple of 64-int_8. 
- while (pVect1 < pEnd1) { - L2SqrStep(pVect1, pVect2, sum); - L2SqrStep(pVect1, pVect2, sum); - } - - return _mm512_reduce_add_epi32(sum); -} diff --git a/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_SQ8_FP32.h b/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_SQ8_FP32.h deleted file mode 100644 index 57db23fb9..000000000 --- a/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_SQ8_FP32.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_SQ8_FP32.h" -#include "VecSim/types/sq8.h" - -using sq8 = vecsim_types::sq8; - -/* - * Optimized asymmetric SQ8 L2 squared distance using algebraic identity: - * - * ||x - y||² = Σx_i² - 2*IP(x, y) + Σy_i² - * = x_sum_squares - 2 * IP(x, y) + y_sum_squares - * - * where: - * - IP(x, y) = min * y_sum + delta * Σ(q_i * y_i) (computed via SQ8_FP32_InnerProductImp_AVX512) - * - x_sum_squares and y_sum_squares are precomputed - * - * This avoids dequantization in the hot loop. - */ - -// pVect1v = SQ8 storage, pVect2v = FP32 query -template // 0..15 -float SQ8_FP32_L2SqrSIMD16_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v, - size_t dimension) { - // Get the raw inner product using the common SIMD implementation - const float ip = SQ8_FP32_InnerProductImp_AVX512(pVect1v, pVect2v, dimension); - - // Get precomputed sum of squares from storage blob (pVect1v is SQ8 storage) - const uint8_t *pVect1 = static_cast(pVect1v); - const float *params = reinterpret_cast(pVect1 + dimension); - const float x_sum_sq = params[sq8::SUM_SQUARES]; - - // Get precomputed sum of squares from query blob (pVect2v is FP32 query) - const float y_sum_sq = static_cast(pVect2v)[dimension + sq8::SUM_SQUARES_QUERY]; - - // L2² = ||x||² + ||y||² - 2*IP(x, y) - return x_sum_sq + y_sum_sq - 2.0f * ip; -} diff --git a/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_SQ8_SQ8.h b/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_SQ8_SQ8.h deleted file mode 100644 index df3043bf5..000000000 --- a/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_SQ8_SQ8.h +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_SQ8_SQ8.h" - -/** - * SQ8-to-SQ8 L2 squared distance using AVX512 VNNI. - * Computes L2 squared distance between two SQ8 (scalar quantized 8-bit) vectors, - * where BOTH vectors are uint8 quantized. - * - * Uses the identity: ||x - y||² = ||x||² + ||y||² - 2*IP(x, y) - * where ||x||² and ||y||² are precomputed sum of squares stored in the vector data. 
- * - * Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] - * [sum_of_squares (float)] - */ - -// L2 squared distance using the common inner product implementation -template // 0..63 -float SQ8_SQ8_L2SqrSIMD64_AVX512F_BW_VL_VNNI(const void *pVec1v, const void *pVec2v, - size_t dimension) { - - // Use the common inner product implementation (returns raw IP, not distance) - const float ip = SQ8_SQ8_InnerProductImp(pVec1v, pVec2v, dimension); - - const uint8_t *pVec1 = static_cast(pVec1v); - const uint8_t *pVec2 = static_cast(pVec2v); - // Get precomputed sum of squares from both vectors - // Layout: [uint8_t values (dim)] [min_val] [delta] [sum] [sum_of_squares] - const float sum_sq_1 = *reinterpret_cast(pVec1 + dimension + 3 * sizeof(float)); - const float sum_sq_2 = *reinterpret_cast(pVec2 + dimension + 3 * sizeof(float)); - - // L2² = ||x||² + ||y||² - 2*IP(x, y) - return sum_sq_1 + sum_sq_2 - 2.0f * ip; -} diff --git a/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_UINT8.h b/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_UINT8.h deleted file mode 100644 index 350b759ea..000000000 --- a/src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_UINT8.h +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" - -static inline void L2SqrStep(uint8_t *&pVect1, uint8_t *&pVect2, __m512i &sum) { - __m512i va = _mm512_loadu_epi8(pVect1); // AVX512BW - pVect1 += 64; - - __m512i vb = _mm512_loadu_epi8(pVect2); // AVX512BW - pVect2 += 64; - - __m512i va_lo = _mm512_unpacklo_epi8(va, _mm512_setzero_si512()); // AVX512BW - __m512i vb_lo = _mm512_unpacklo_epi8(vb, _mm512_setzero_si512()); - __m512i diff_lo = _mm512_sub_epi16(va_lo, vb_lo); - sum = _mm512_dpwssd_epi32(sum, diff_lo, diff_lo); - - __m512i va_hi = _mm512_unpackhi_epi8(va, _mm512_setzero_si512()); // AVX512BW - __m512i vb_hi = _mm512_unpackhi_epi8(vb, _mm512_setzero_si512()); - __m512i diff_hi = _mm512_sub_epi16(va_hi, vb_hi); - sum = _mm512_dpwssd_epi32(sum, diff_hi, diff_hi); - - // _mm512_dpwssd_epi32(src, a, b) - // Multiply groups of 2 adjacent pairs of signed 16-bit integers in `a` with corresponding - // 16-bit integers in `b`, producing 2 intermediate signed 32-bit results. Sum these 2 results - // with the corresponding 32-bit integer in src, and store the packed 32-bit results in dst. -} - -template // 0..63 -float UINT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v, - size_t dimension) { - uint8_t *pVect1 = (uint8_t *)pVect1v; - uint8_t *pVect2 = (uint8_t *)pVect2v; - - const uint8_t *pEnd1 = pVect1 + dimension; - - __m512i sum = _mm512_setzero_epi32(); - - // Deal with remainder first. 
- if constexpr (residual) { - if constexpr (residual < 32) { - constexpr __mmask32 mask = (1LU << residual) - 1; - __m256i temp_a = _mm256_maskz_loadu_epi8(mask, pVect1); - __m512i va = _mm512_cvtepu8_epi16(temp_a); - - __m256i temp_b = _mm256_maskz_loadu_epi8(mask, pVect2); - __m512i vb = _mm512_cvtepu8_epi16(temp_b); - - __m512i diff = _mm512_sub_epi16(va, vb); - sum = _mm512_dpwssd_epi32(sum, diff, diff); - } else if constexpr (residual == 32) { - __m256i temp_a = _mm256_loadu_epi8(pVect1); - __m512i va = _mm512_cvtepu8_epi16(temp_a); - - __m256i temp_b = _mm256_loadu_epi8(pVect2); - __m512i vb = _mm512_cvtepu8_epi16(temp_b); - - __m512i diff = _mm512_sub_epi16(va, vb); - sum = _mm512_dpwssd_epi32(sum, diff, diff); - } else { - constexpr __mmask64 mask = (1LU << residual) - 1; - __m512i va = _mm512_maskz_loadu_epi8(mask, pVect1); // AVX512BW - __m512i vb = _mm512_maskz_loadu_epi8(mask, pVect2); // AVX512BW - - __m512i va_lo = _mm512_unpacklo_epi8(va, _mm512_setzero_si512()); // AVX512BW - __m512i vb_lo = _mm512_unpacklo_epi8(vb, _mm512_setzero_si512()); - __m512i diff_lo = _mm512_sub_epi16(va_lo, vb_lo); - sum = _mm512_dpwssd_epi32(sum, diff_lo, diff_lo); - - __m512i va_hi = _mm512_unpackhi_epi8(va, _mm512_setzero_si512()); // AVX512BW - __m512i vb_hi = _mm512_unpackhi_epi8(vb, _mm512_setzero_si512()); - __m512i diff_hi = _mm512_sub_epi16(va_hi, vb_hi); - sum = _mm512_dpwssd_epi32(sum, diff_hi, diff_hi); - } - pVect1 += residual; - pVect2 += residual; - - // We dealt with the residual part. - // We are left with some multiple of 64-uint_8 (might be 0). - while (pVect1 < pEnd1) { - L2SqrStep(pVect1, pVect2, sum); - } - } else { - // We have no residual, we have some non-zero multiple of 64-uint_8. - do { - L2SqrStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - } - - return _mm512_reduce_add_epi32(sum); -} diff --git a/src/VecSim/spaces/L2/L2_AVX512F_FP16.h b/src/VecSim/spaces/L2/L2_AVX512F_FP16.h deleted file mode 100644 index e2a21414f..000000000 --- a/src/VecSim/spaces/L2/L2_AVX512F_FP16.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/AVX_utils.h" -#include "VecSim/types/float16.h" - -using float16 = vecsim_types::float16; - -static void L2SqrStep(float16 *&pVect1, float16 *&pVect2, __m512 &sum) { - // Convert 16 half-floats into floats and store them in 512 bits register. - auto v1 = _mm512_cvtph_ps(_mm256_lddqu_si256((__m256i *)pVect1)); - auto v2 = _mm512_cvtph_ps(_mm256_lddqu_si256((__m256i *)pVect2)); - - // sum = (v1 - v2)^2 + sum - auto c = _mm512_sub_ps(v1, v2); - sum = _mm512_fmadd_ps(c, c, sum); - pVect1 += 16; - pVect2 += 16; -} - -template // 0..31 -float FP16_L2SqrSIMD32_AVX512(const void *pVect1v, const void *pVect2v, size_t dimension) { - auto *pVect1 = (float16 *)pVect1v; - auto *pVect2 = (float16 *)pVect2v; - - const float16 *pEnd1 = pVect1 + dimension; - - auto sum = _mm512_setzero_ps(); - - if constexpr (residual % 16) { - // Deal with remainder first. `dim` is more than 32, so we have at least one block of 32 - // 16-bit float so mask loading is guaranteed to be safe. 
- __mmask16 constexpr residuals_mask = (1 << (residual % 16)) - 1; - // Convert the first half-floats in the residual positions into floats and store them - // 512 bits register, where the floats in the positions corresponding to the non-residuals - // positions are zeros. - auto v1 = _mm512_maskz_mov_ps(residuals_mask, - _mm512_cvtph_ps(_mm256_lddqu_si256((__m256i *)pVect1))); - auto v2 = _mm512_maskz_mov_ps(residuals_mask, - _mm512_cvtph_ps(_mm256_lddqu_si256((__m256i *)pVect2))); - auto c = _mm512_sub_ps(v1, v2); - sum = _mm512_mul_ps(c, c); - pVect1 += residual % 16; - pVect2 += residual % 16; - } - if constexpr (residual >= 16) { - L2SqrStep(pVect1, pVect2, sum); - } - - // We dealt with the residual part. We are left with some multiple of 32 16-bit floats. - // In every iteration we process 2 chunk of 256bit (32 FP16) - do { - L2SqrStep(pVect1, pVect2, sum); - L2SqrStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - return _mm512_reduce_add_ps(sum); -} diff --git a/src/VecSim/spaces/L2/L2_AVX512F_FP32.h b/src/VecSim/spaces/L2/L2_AVX512F_FP32.h deleted file mode 100644 index 0100e4264..000000000 --- a/src/VecSim/spaces/L2/L2_AVX512F_FP32.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" - -static inline void L2SqrStep(float *&pVect1, float *&pVect2, __m512 &sum) { - __m512 v1 = _mm512_loadu_ps(pVect1); - pVect1 += 16; - __m512 v2 = _mm512_loadu_ps(pVect2); - pVect2 += 16; - __m512 diff = _mm512_sub_ps(v1, v2); - - sum = _mm512_fmadd_ps(diff, diff, sum); -} - -template // 0..15 -float FP32_L2SqrSIMD16_AVX512(const void *pVect1v, const void *pVect2v, size_t dimension) { - float *pVect1 = (float *)pVect1v; - float *pVect2 = (float *)pVect2v; - - const float *pEnd1 = pVect1 + dimension; - - __m512 sum = _mm512_setzero_ps(); - - // Deal with remainder first. `dim` is more than 16, so we have at least one 16-float block, - // so mask loading is guaranteed to be safe - if constexpr (residual) { - __mmask16 constexpr mask = (1 << residual) - 1; - __m512 v1 = _mm512_maskz_loadu_ps(mask, pVect1); - pVect1 += residual; - __m512 v2 = _mm512_maskz_loadu_ps(mask, pVect2); - pVect2 += residual; - __m512 diff = _mm512_sub_ps(v1, v2); - sum = _mm512_mul_ps(diff, diff); - } - - // We dealt with the residual part. We are left with some multiple of 16 floats. - do { - L2SqrStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - return _mm512_reduce_add_ps(sum); -} diff --git a/src/VecSim/spaces/L2/L2_AVX512F_FP64.h b/src/VecSim/spaces/L2/L2_AVX512F_FP64.h deleted file mode 100644 index 1a54c7048..000000000 --- a/src/VecSim/spaces/L2/L2_AVX512F_FP64.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
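All of these kernels receive the residual as a template parameter, which implies a setup-time chooser that switches on dim % chunk so the hot loop never branches on the remainder. A minimal sketch of that mechanism (the Choose_* helpers in this diff are assumed to work this way; the kernel body is a stand-in):

    #include <cstddef>

    using dist_fn = float (*)(const void *, const void *, size_t);

    // One instantiation per possible residual; stand-in body for the sketch.
    template <unsigned char residual>
    float kernel(const void *, const void *, size_t) {
        return static_cast<float>(residual);
    }

    // Hypothetical chooser: pick the instantiation matching dim % 16 at setup
    // time, so the residual is a compile-time constant inside the kernel.
    template <unsigned char r = 15>
    dist_fn choose(size_t dim) {
        if (dim % 16 == r) return kernel<r>;
        if constexpr (r > 0) {
            return choose<r - 1>(dim);
        } else {
            return kernel<0>;
        }
    }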
- */
-#include "VecSim/spaces/space_includes.h"
-
-static inline void L2SqrStep(double *&pVect1, double *&pVect2, __m512d &sum) {
-    __m512d v1 = _mm512_loadu_pd(pVect1);
-    pVect1 += 8;
-    __m512d v2 = _mm512_loadu_pd(pVect2);
-    pVect2 += 8;
-    __m512d diff = _mm512_sub_pd(v1, v2);
-
-    sum = _mm512_fmadd_pd(diff, diff, sum);
-}
-
-template <unsigned char residual> // 0..7
-double FP64_L2SqrSIMD8_AVX512(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    double *pVect1 = (double *)pVect1v;
-    double *pVect2 = (double *)pVect2v;
-
-    const double *pEnd1 = pVect1 + dimension;
-
-    __m512d sum = _mm512_setzero_pd();
-
-    // Deal with the remainder first. `dim` is more than 8, so we have at least one
-    // 8-double block, so mask loading is guaranteed to be safe.
-    if constexpr (residual) {
-        __mmask8 constexpr mask = (1 << residual) - 1;
-        __m512d v1 = _mm512_maskz_loadu_pd(mask, pVect1);
-        pVect1 += residual;
-        __m512d v2 = _mm512_maskz_loadu_pd(mask, pVect2);
-        pVect2 += residual;
-        __m512d diff = _mm512_sub_pd(v1, v2);
-        sum = _mm512_mul_pd(diff, diff);
-    }
-
-    // We dealt with the residual part. We are left with some multiple of 8 doubles.
-    do {
-        L2SqrStep(pVect1, pVect2, sum);
-    } while (pVect1 < pEnd1);
-
-    return _mm512_reduce_add_pd(sum);
-}
diff --git a/src/VecSim/spaces/L2/L2_AVX_FP32.h b/src/VecSim/spaces/L2/L2_AVX_FP32.h
deleted file mode 100644
index 4751d4726..000000000
--- a/src/VecSim/spaces/L2/L2_AVX_FP32.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include "VecSim/spaces/space_includes.h"
-#include "VecSim/spaces/AVX_utils.h"
-
-static inline void L2SqrStep(float *&pVect1, float *&pVect2, __m256 &sum) {
-    __m256 v1 = _mm256_loadu_ps(pVect1);
-    pVect1 += 8;
-    __m256 v2 = _mm256_loadu_ps(pVect2);
-    pVect2 += 8;
-    __m256 diff = _mm256_sub_ps(v1, v2);
-    // sum = _mm256_fmadd_ps(diff, diff, sum);
-    sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
-}
-
-template <unsigned char residual> // 0..15
-float FP32_L2SqrSIMD16_AVX(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    float *pVect1 = (float *)pVect1v;
-    float *pVect2 = (float *)pVect2v;
-
-    const float *pEnd1 = pVect1 + dimension;
-
-    __m256 sum = _mm256_setzero_ps();
-
-    // Deal with 1-7 floats with mask loading, if needed
-    if constexpr (residual % 8) {
-        __mmask8 constexpr mask8 = (1 << (residual % 8)) - 1;
-        __m256 v1 = my_mm256_maskz_loadu_ps<mask8>(pVect1);
-        pVect1 += residual % 8;
-        __m256 v2 = my_mm256_maskz_loadu_ps<mask8>(pVect2);
-        pVect2 += residual % 8;
-        __m256 diff = _mm256_sub_ps(v1, v2);
-        sum = _mm256_mul_ps(diff, diff);
-    }
-
-    // If the remainder is >= 8, do another step of 8 floats
-    if constexpr (residual >= 8) {
-        L2SqrStep(pVect1, pVect2, sum);
-    }
-
-    // We dealt with the residual part. We are left with some multiple of 16 floats.
-    // In each iteration we calculate 16 floats = 512 bits.
-    do {
-        L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
-    } while (pVect1 < pEnd1);
-
-    return my_mm256_reduce_add_ps(sum);
-}
diff --git a/src/VecSim/spaces/L2/L2_AVX_FP64.h b/src/VecSim/spaces/L2/L2_AVX_FP64.h
deleted file mode 100644
index 09257dca5..000000000
--- a/src/VecSim/spaces/L2/L2_AVX_FP64.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include "VecSim/spaces/space_includes.h"
-#include "VecSim/spaces/AVX_utils.h"
-
-static inline void L2SqrStep(double *&pVect1, double *&pVect2, __m256d &sum) {
-    __m256d v1 = _mm256_loadu_pd(pVect1);
-    pVect1 += 4;
-    __m256d v2 = _mm256_loadu_pd(pVect2);
-    pVect2 += 4;
-    __m256d diff = _mm256_sub_pd(v1, v2);
-    // sum = _mm256_fmadd_pd(diff, diff, sum);
-    sum = _mm256_add_pd(sum, _mm256_mul_pd(diff, diff));
-}
-
-template <unsigned char residual> // 0..7
-double FP64_L2SqrSIMD8_AVX(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    double *pVect1 = (double *)pVect1v;
-    double *pVect2 = (double *)pVect2v;
-
-    const double *pEnd1 = pVect1 + dimension;
-
-    __m256d sum = _mm256_setzero_pd();
-
-    // Deal with 1-3 doubles with mask loading, if needed
-    if constexpr (residual % 4) {
-        // _mm256_maskz_loadu_pd is not available in AVX
-        __mmask8 constexpr mask4 = (1 << (residual % 4)) - 1;
-        __m256d v1 = my_mm256_maskz_loadu_pd<mask4>(pVect1);
-        pVect1 += residual % 4;
-        __m256d v2 = my_mm256_maskz_loadu_pd<mask4>(pVect2);
-        pVect2 += residual % 4;
-        __m256d diff = _mm256_sub_pd(v1, v2);
-        sum = _mm256_mul_pd(diff, diff);
-    }
-
-    // If the remainder is >= 4, do another step of 4 doubles
-    if constexpr (residual >= 4) {
-        L2SqrStep(pVect1, pVect2, sum);
-    }
-
-    // We dealt with the residual part. We are left with some multiple of 8 doubles.
-    // In each iteration we calculate 8 doubles = 512 bits.
-    do {
-        L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
-    } while (pVect1 < pEnd1);
-
-    double PORTABLE_ALIGN32 TmpRes[4];
-    _mm256_store_pd(TmpRes, sum);
-    return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
-}
diff --git a/src/VecSim/spaces/L2/L2_F16C_FP16.h b/src/VecSim/spaces/L2/L2_F16C_FP16.h
deleted file mode 100644
index c193b9cec..000000000
--- a/src/VecSim/spaces/L2/L2_F16C_FP16.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include <immintrin.h>
-#include "VecSim/spaces/space_includes.h"
-#include "VecSim/spaces/AVX_utils.h"
-#include "VecSim/types/float16.h"
-
-using float16 = vecsim_types::float16;
-
-static void L2SqrStep(float16 *&pVect1, float16 *&pVect2, __m256 &sum) {
-    // Convert 8 half-floats into floats and store them in a 256-bit register.
-    auto v1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i_u const *)(pVect1)));
-    auto v2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i_u const *)(pVect2)));
-
-    // sum = (v1 - v2)^2 + sum
-    auto c = _mm256_sub_ps(v1, v2);
-    sum = _mm256_fmadd_ps(c, c, sum);
-    pVect1 += 8;
-    pVect2 += 8;
-}
-
-template <unsigned char residual> // 0..31
-float FP16_L2SqrSIMD32_F16C(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    auto *pVect1 = (float16 *)pVect1v;
-    auto *pVect2 = (float16 *)pVect2v;
-
-    const float16 *pEnd1 = pVect1 + dimension;
-
-    auto sum = _mm256_setzero_ps();
-
-    if constexpr (residual % 8) {
-        // Deal with the remainder first. `dim` is more than 32, so we have at least one
-        // block of 32 16-bit floats, so mask loading is guaranteed to be safe.
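-        // Illustrative note (not in the original source): e.g. for dimension = 45 the
-        // dispatcher would instantiate residual = 45 % 32 = 13, so residual % 8 = 5 and
-        // residuals_mask = 0b11111 keeps lanes 0..4 in the blends below.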
-        __mmask16 constexpr residuals_mask = (1 << (residual % 8)) - 1;
-        // Convert the first 8 half-floats into floats and store them in a 256-bit register,
-        // where the floats in the positions beyond the residual are zeros.
-        auto v1 = _mm256_blend_ps(_mm256_setzero_ps(),
-                                  _mm256_cvtph_ps(_mm_loadu_si128((__m128i_u const *)pVect1)),
-                                  residuals_mask);
-        auto v2 = _mm256_blend_ps(_mm256_setzero_ps(),
-                                  _mm256_cvtph_ps(_mm_loadu_si128((__m128i_u const *)pVect2)),
-                                  residuals_mask);
-        // sum = (v1 - v2)^2 + sum
-        auto c = _mm256_sub_ps(v1, v2);
-        sum = _mm256_fmadd_ps(c, c, sum);
-        pVect1 += residual % 8;
-        pVect2 += residual % 8;
-    }
-    if constexpr (residual >= 8 && residual < 16) {
-        L2SqrStep(pVect1, pVect2, sum);
-    } else if constexpr (residual >= 16 && residual < 24) {
-        L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
-    } else if constexpr (residual >= 24) {
-        L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
-    }
-    // We dealt with the residual part. We are left with some multiple of 32 16-bit floats.
-    // In every iteration we process four chunks of 128 bits (32 FP16 values in total).
-    do {
-        L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
-    } while (pVect1 < pEnd1);
-
-    return my_mm256_reduce_add_ps(sum);
-}
diff --git a/src/VecSim/spaces/L2/L2_NEON_BF16.h b/src/VecSim/spaces/L2/L2_NEON_BF16.h
deleted file mode 100644
index 2447f1b6f..000000000
--- a/src/VecSim/spaces/L2/L2_NEON_BF16.h
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include <arm_neon.h>
-
-// Assumes little-endianness
-inline void L2Sqr_Op(float32x4_t &acc, bfloat16x8_t &v1, bfloat16x8_t &v2) {
-    float32x4_t v1_lo = vcvtq_low_f32_bf16(v1);
-    float32x4_t v2_lo = vcvtq_low_f32_bf16(v2);
-    float32x4_t diff_lo = vsubq_f32(v1_lo, v2_lo);
-
-    acc = vfmaq_f32(acc, diff_lo, diff_lo);
-
-    float32x4_t v1_hi = vcvtq_high_f32_bf16(v1);
-    float32x4_t v2_hi = vcvtq_high_f32_bf16(v2);
-    float32x4_t diff_hi = vsubq_f32(v1_hi, v2_hi);
-
-    acc = vfmaq_f32(acc, diff_hi, diff_hi);
-}
-
-inline void L2Sqr_Step(const bfloat16_t *&vec1, const bfloat16_t *&vec2, float32x4_t &acc) {
-    // Load brain-half-precision vectors
-    bfloat16x8_t v1 = vld1q_bf16(vec1);
-    bfloat16x8_t v2 = vld1q_bf16(vec2);
-    vec1 += 8;
-    vec2 += 8;
-    L2Sqr_Op(acc, v1, v2);
-}
-
-template <unsigned char residual> // 0..31
-float BF16_L2Sqr_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
-    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
-    const auto *const v1End = vec1 + dimension;
-    float32x4_t acc1 = vdupq_n_f32(0.0f);
-    float32x4_t acc2 = vdupq_n_f32(0.0f);
-    float32x4_t acc3 = vdupq_n_f32(0.0f);
-    float32x4_t acc4 = vdupq_n_f32(0.0f);
-
-    // First, handle the partial chunk residual
-    if constexpr (residual % 8) {
-        auto constexpr chunk_residual = residual % 8;
-        // TODO: special cases for some residuals and benchmark if it's better
-        constexpr uint16x8_t mask = {
-            0xFFFF,
-            (chunk_residual >= 2) ? 0xFFFF : 0,
-            (chunk_residual >= 3) ? 0xFFFF : 0,
-            (chunk_residual >= 4) ? 0xFFFF : 0,
-            (chunk_residual >= 5) ? 0xFFFF : 0,
-            (chunk_residual >= 6) ? 0xFFFF : 0,
-            (chunk_residual >= 7) ?
0xFFFF : 0, - 0, - }; - - // Load partial vectors - bfloat16x8_t v1 = vld1q_bf16(vec1); - bfloat16x8_t v2 = vld1q_bf16(vec2); - - // Apply mask to both vectors - bfloat16x8_t masked_v1 = - vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v1), mask)); - bfloat16x8_t masked_v2 = - vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v2), mask)); - - L2Sqr_Op(acc1, masked_v1, masked_v2); - - // Advance pointers - vec1 += chunk_residual; - vec2 += chunk_residual; - } - - // Handle (residual - (residual % 8)) in chunks of 8 bfloat16 - if constexpr (residual >= 8) - L2Sqr_Step(vec1, vec2, acc2); - if constexpr (residual >= 16) - L2Sqr_Step(vec1, vec2, acc3); - if constexpr (residual >= 24) - L2Sqr_Step(vec1, vec2, acc4); - - // Process the rest of the vectors (the full chunks part) - while (vec1 < v1End) { - // TODO: use `vld1q_f16_x4` for quad-loading? - L2Sqr_Step(vec1, vec2, acc1); - L2Sqr_Step(vec1, vec2, acc2); - L2Sqr_Step(vec1, vec2, acc3); - L2Sqr_Step(vec1, vec2, acc4); - } - - // Accumulate accumulators - acc1 = vpaddq_f32(acc1, acc3); - acc2 = vpaddq_f32(acc2, acc4); - acc1 = vpaddq_f32(acc1, acc2); - - // Pairwise add to get horizontal sum - float32x2_t folded = vadd_f32(vget_low_f32(acc1), vget_high_f32(acc1)); - folded = vpadd_f32(folded, folded); - - // Extract result - return vget_lane_f32(folded, 0); -} diff --git a/src/VecSim/spaces/L2/L2_NEON_DOTPROD_INT8.h b/src/VecSim/spaces/L2/L2_NEON_DOTPROD_INT8.h deleted file mode 100644 index 5eb49c57b..000000000 --- a/src/VecSim/spaces/L2/L2_NEON_DOTPROD_INT8.h +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */
-#include "VecSim/spaces/space_includes.h"
-#include <arm_neon.h>
-
-__attribute__((always_inline)) static inline void L2SquareOp(const int8x16_t &v1,
                                                              const int8x16_t &v2, uint32x4_t &sum) {
-    // Compute absolute differences; |v1 - v2| fits in 8 bits only when read as unsigned
-    int8x16_t diff = vabdq_s8(v1, v2);
-
-    // Reinterpret the absolute differences as uint8 for the unsigned dot product
-    uint8x16_t diff_u8 = vreinterpretq_u8_s8(diff);
-
-    // Use dot product to square and accumulate (diff·diff)
-    sum = vdotq_u32(sum, diff_u8, diff_u8);
-}
-
-__attribute__((always_inline)) static inline void L2SquareStep16(int8_t *&pVect1, int8_t *&pVect2,
                                                                   uint32x4_t &sum) {
-    // Load 16 int8 elements (16 bytes) into NEON registers
-    int8x16_t v1 = vld1q_s8(pVect1);
-    int8x16_t v2 = vld1q_s8(pVect2);
-
-    L2SquareOp(v1, v2, sum);
-
-    pVect1 += 16;
-    pVect2 += 16;
-}
-
-static inline void L2SquareStep32(int8_t *&pVect1, int8_t *&pVect2, uint32x4_t &sum0,
                                   uint32x4_t &sum1) {
-    // Load 32 int8 elements (32 bytes) at once
-    int8x16x2_t v1 = vld1q_s8_x2(pVect1);
-    int8x16x2_t v2 = vld1q_s8_x2(pVect2);
-
-    auto v1_0 = v1.val[0];
-    auto v2_0 = v2.val[0];
-    L2SquareOp(v1_0, v2_0, sum0);
-
-    auto v1_1 = v1.val[1];
-    auto v2_1 = v2.val[1];
-    L2SquareOp(v1_1, v2_1, sum1);
-
-    pVect1 += 32;
-    pVect2 += 32;
-}
-
-template <unsigned char residual> // 0..63
-float INT8_L2SqrSIMD16_NEON_DOTPROD(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    int8_t *pVect1 = (int8_t *)pVect1v;
-    int8_t *pVect2 = (int8_t *)pVect2v;
-
-    uint32x4_t sum0 = vdupq_n_u32(0);
-    uint32x4_t sum1 = vdupq_n_u32(0);
-    uint32x4_t sum2 = vdupq_n_u32(0);
-    uint32x4_t sum3 = vdupq_n_u32(0);
-
-    constexpr size_t final_residual = residual % 16;
-    if constexpr (final_residual > 0) {
-        // Define a compile-time constant mask based on final_residual
-        constexpr uint8x16_t mask = {
-            0xFF,
-            (final_residual >= 2) ? 0xFF : 0,
-            (final_residual >= 3) ? 0xFF : 0,
-            (final_residual >= 4) ? 0xFF : 0,
-            (final_residual >= 5) ? 0xFF : 0,
-            (final_residual >= 6) ? 0xFF : 0,
-            (final_residual >= 7) ? 0xFF : 0,
-            (final_residual >= 8) ? 0xFF : 0,
-            (final_residual >= 9) ? 0xFF : 0,
-            (final_residual >= 10) ? 0xFF : 0,
-            (final_residual >= 11) ? 0xFF : 0,
-            (final_residual >= 12) ? 0xFF : 0,
-            (final_residual >= 13) ? 0xFF : 0,
-            (final_residual >= 14) ? 0xFF : 0,
-            (final_residual >= 15) ?
0xFF : 0, - 0, - }; - - // Load data directly from input vectors - int8x16_t v1 = vld1q_s8(pVect1); - int8x16_t v2 = vld1q_s8(pVect2); - - // Zero vector for replacement - int8x16_t zeros = vdupq_n_s8(0); - - // Apply bit select to zero out irrelevant elements - v1 = vbslq_s8(mask, v1, zeros); - v2 = vbslq_s8(mask, v2, zeros); - L2SquareOp(v1, v2, sum0); - pVect1 += final_residual; - pVect2 += final_residual; - } - - // Process 64 elements at a time in the main loop - size_t num_of_chunks = dimension / 64; - - for (size_t i = 0; i < num_of_chunks; i++) { - L2SquareStep32(pVect1, pVect2, sum0, sum1); - L2SquareStep32(pVect1, pVect2, sum2, sum3); - } - - constexpr size_t num_of_32_chunks = residual / 32; - if constexpr (num_of_32_chunks) { - L2SquareStep32(pVect1, pVect2, sum0, sum1); - } - - constexpr size_t residual_chunks = (residual % 32) / 16; - if constexpr (residual_chunks > 0) { - L2SquareStep16(pVect1, pVect2, sum2); - } - - // Horizontal sum of the 4 elements in the sum register to get final result - uint32x4_t total_sum = vaddq_u32(sum0, sum1); - - total_sum = vaddq_u32(total_sum, sum2); - total_sum = vaddq_u32(total_sum, sum3); - - // Horizontal sum of the 4 elements in the combined sum register - uint32_t result = vaddvq_u32(total_sum); - - // Return the L2 squared distance as a float - return static_cast(result); -} diff --git a/src/VecSim/spaces/L2/L2_NEON_DOTPROD_SQ8_SQ8.h b/src/VecSim/spaces/L2/L2_NEON_DOTPROD_SQ8_SQ8.h deleted file mode 100644 index 7de9f336a..000000000 --- a/src/VecSim/spaces/L2/L2_NEON_DOTPROD_SQ8_SQ8.h +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/IP/IP_NEON_DOTPROD_SQ8_SQ8.h" -#include "VecSim/types/sq8.h" - -using sq8 = vecsim_types::sq8; - -/** - * SQ8-to-SQ8 L2 squared distance functions for NEON with DOTPROD extension. - * Computes L2 squared distance between two SQ8 (scalar quantized 8-bit) vectors, - * where BOTH vectors are uint8 quantized. - * - * Uses the identity: ||x - y||² = ||x||² + ||y||² - 2*IP(x, y) - * where ||x||² and ||y||² are precomputed sum of squares stored in the vector data. 
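- *
- * Worked example (illustrative values, not from the original source): for x = (1, 2, 2)
- * and y = (2, 0, 1): ||x||² = 9, ||y||² = 5 and IP(x, y) = 4, so
- * ||x - y||² = 9 + 5 - 2*4 = 6, matching (1-2)² + (2-0)² + (2-1)² = 6 directly.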
- * - * Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] - * [sum_of_squares (float)] - */ - -// L2 squared distance using the common inner product implementation -template // 0..63 -float SQ8_SQ8_L2SqrSIMD64_NEON_DOTPROD(const void *pVec1v, const void *pVec2v, size_t dimension) { - // Use the common inner product implementation (returns raw IP, not distance) - const float ip = - SQ8_SQ8_InnerProductSIMD64_NEON_DOTPROD_IMP(pVec1v, pVec2v, dimension); - - const uint8_t *pVec1 = static_cast(pVec1v); - const uint8_t *pVec2 = static_cast(pVec2v); - - // Get precomputed sum of squares from both vectors - // Layout: [uint8_t values (dim)] [min_val] [delta] [sum] [sum_of_squares] - const float sum_sq_1 = - *reinterpret_cast(pVec1 + dimension + sq8::SUM_SQUARES * sizeof(float)); - const float sum_sq_2 = - *reinterpret_cast(pVec2 + dimension + sq8::SUM_SQUARES * sizeof(float)); - - // L2² = ||x||² + ||y||² - 2*IP(x, y) - return sum_sq_1 + sum_sq_2 - 2.0f * ip; -} diff --git a/src/VecSim/spaces/L2/L2_NEON_DOTPROD_UINT8.h b/src/VecSim/spaces/L2/L2_NEON_DOTPROD_UINT8.h deleted file mode 100644 index 654c0b3b1..000000000 --- a/src/VecSim/spaces/L2/L2_NEON_DOTPROD_UINT8.h +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" -#include - -__attribute__((always_inline)) static inline void -L2SquareOp(const uint8x16_t &v1, const uint8x16_t &v2, uint32x4_t &sum) { - // Explicitly reinterpret the int8 vectors as uint8 for vabdq_u8 - - // Compute absolute differences (results in uint8x16_t) - uint8x16_t diff = vabdq_u8(v1, v2); - - // Use dot product to square and accumulate (diff·diff) - sum = vdotq_u32(sum, diff, diff); -} - -__attribute__((always_inline)) static inline void L2SquareStep16(uint8_t *&pVect1, uint8_t *&pVect2, - uint32x4_t &sum) { - // Load 16 uint8 elements into NEON registers - uint8x16_t v1 = vld1q_u8(pVect1); - uint8x16_t v2 = vld1q_u8(pVect2); - - L2SquareOp(v1, v2, sum); - - pVect1 += 16; - pVect2 += 16; -} - -__attribute__((always_inline)) static inline void -L2SquareStep32(uint8_t *&pVect1, uint8_t *&pVect2, uint32x4_t &sum1, uint32x4_t &sum2) { - uint8x16x2_t v1_pair = vld1q_u8_x2(pVect1); - uint8x16x2_t v2_pair = vld1q_u8_x2(pVect2); - - // Reference the individual vectors - uint8x16_t v1_first = v1_pair.val[0]; - uint8x16_t v1_second = v1_pair.val[1]; - uint8x16_t v2_first = v2_pair.val[0]; - uint8x16_t v2_second = v2_pair.val[1]; - - L2SquareOp(v1_first, v2_first, sum1); - L2SquareOp(v1_second, v2_second, sum2); - - pVect1 += 32; - pVect2 += 32; -} - -template // 0..63 -float UINT8_L2SqrSIMD16_NEON_DOTPROD(const void *pVect1v, const void *pVect2v, size_t dimension) { - uint8_t *pVect1 = (uint8_t *)pVect1v; - uint8_t *pVect2 = (uint8_t *)pVect2v; - - uint32x4_t sum0 = vdupq_n_u32(0); - uint32x4_t sum1 = vdupq_n_u32(0); - uint32x4_t sum2 = vdupq_n_u32(0); - uint32x4_t sum3 = vdupq_n_u32(0); - - constexpr size_t final_residual = residual % 16; - if constexpr (final_residual > 0) { - constexpr uint8x16_t mask = { - 0xFF, - (final_residual >= 2) ? 0xFF : 0, - (final_residual >= 3) ? 0xFF : 0, - (final_residual >= 4) ? 0xFF : 0, - (final_residual >= 5) ? 0xFF : 0, - (final_residual >= 6) ? 0xFF : 0, - (final_residual >= 7) ? 
0xFF : 0,
-            (final_residual >= 8) ? 0xFF : 0,
-            (final_residual >= 9) ? 0xFF : 0,
-            (final_residual >= 10) ? 0xFF : 0,
-            (final_residual >= 11) ? 0xFF : 0,
-            (final_residual >= 12) ? 0xFF : 0,
-            (final_residual >= 13) ? 0xFF : 0,
-            (final_residual >= 14) ? 0xFF : 0,
-            (final_residual >= 15) ? 0xFF : 0,
-            0,
-        };
-
-        // Load data directly from input vectors
-        uint8x16_t v1 = vld1q_u8(pVect1);
-        uint8x16_t v2 = vld1q_u8(pVect2);
-
-        // Zero vector for replacement
-        uint8x16_t zeros = vdupq_n_u8(0);
-
-        // Apply bit select to zero out irrelevant elements
-        v1 = vbslq_u8(mask, v1, zeros);
-        v2 = vbslq_u8(mask, v2, zeros);
-        L2SquareOp(v1, v2, sum1);
-        pVect1 += final_residual;
-        pVect2 += final_residual;
-    }
-
-    // Process 64 elements at a time in the main loop
-    size_t num_of_chunks = dimension / 64;
-
-    for (size_t i = 0; i < num_of_chunks; i++) {
-        L2SquareStep32(pVect1, pVect2, sum0, sum2);
-        L2SquareStep32(pVect1, pVect2, sum1, sum3);
-    }
-
-    constexpr size_t num_of_32_chunks = residual / 32;
-
-    if constexpr (num_of_32_chunks) {
-        L2SquareStep32(pVect1, pVect2, sum0, sum1);
-    }
-
-    constexpr size_t residual_chunks = (residual % 32) / 16;
-    if constexpr (residual_chunks > 0) {
-        L2SquareStep16(pVect1, pVect2, sum0);
-    }
-
-    // Horizontal sum of the 4 elements in the sum register to get final result
-    uint32x4_t total_sum = vaddq_u32(sum0, sum1);
-
-    total_sum = vaddq_u32(total_sum, sum2);
-    total_sum = vaddq_u32(total_sum, sum3);
-
-    // Horizontal sum of the 4 elements in the combined sum register
-    uint32_t result = vaddvq_u32(total_sum);
-
-    // Return the L2 squared distance as a float
-    return static_cast<float>(result);
-}
diff --git a/src/VecSim/spaces/L2/L2_NEON_FP16.h b/src/VecSim/spaces/L2/L2_NEON_FP16.h
deleted file mode 100644
index e2786aa7a..000000000
--- a/src/VecSim/spaces/L2/L2_NEON_FP16.h
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include <arm_neon.h>
-
-inline void L2Sqr_Step(const float16_t *&vec1, const float16_t *&vec2, float16x8_t &acc) {
-    // Load half-precision vectors
-    float16x8_t v1 = vld1q_f16(vec1);
-    float16x8_t v2 = vld1q_f16(vec2);
-    vec1 += 8;
-    vec2 += 8;
-
-    // Calculate differences
-    float16x8_t diff = vsubq_f16(v1, v2);
-    // Square and accumulate
-    acc = vfmaq_f16(acc, diff, diff);
-}
-
-template <unsigned char residual> // 0..31
-float FP16_L2Sqr_NEON_HP(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const auto *vec1 = static_cast<const float16_t *>(pVect1v);
-    const auto *vec2 = static_cast<const float16_t *>(pVect2v);
-    const auto *const v1End = vec1 + dimension;
-    float16x8_t acc1 = vdupq_n_f16(0.0f);
-    float16x8_t acc2 = vdupq_n_f16(0.0f);
-    float16x8_t acc3 = vdupq_n_f16(0.0f);
-    float16x8_t acc4 = vdupq_n_f16(0.0f);
-
-    // First, handle the partial chunk residual
-    if constexpr (residual % 8) {
-        auto constexpr chunk_residual = residual % 8;
-        // TODO: special cases for some residuals and benchmark if it's better
-        constexpr uint16x8_t mask = {
-            0xFFFF,
-            (chunk_residual >= 2) ? 0xFFFF : 0,
-            (chunk_residual >= 3) ? 0xFFFF : 0,
-            (chunk_residual >= 4) ? 0xFFFF : 0,
-            (chunk_residual >= 5) ? 0xFFFF : 0,
-            (chunk_residual >= 6) ? 0xFFFF : 0,
-            (chunk_residual >= 7) ?
0xFFFF : 0, - 0, - }; - - // Load partial vectors - float16x8_t v1 = vld1q_f16(vec1); - float16x8_t v2 = vld1q_f16(vec2); - - // Apply mask to both vectors - float16x8_t masked_v1 = vbslq_f16(mask, v1, acc1); // `acc1` should be all zeros here - float16x8_t masked_v2 = vbslq_f16(mask, v2, acc2); // `acc2` should be all zeros here - - // Calculate differences - float16x8_t diff = vsubq_f16(masked_v1, masked_v2); - // Square and accumulate - acc1 = vfmaq_f16(acc1, diff, diff); - - // Advance pointers - vec1 += chunk_residual; - vec2 += chunk_residual; - } - - // Handle (residual - (residual % 8)) in chunks of 8 float16 - if constexpr (residual >= 8) - L2Sqr_Step(vec1, vec2, acc2); - if constexpr (residual >= 16) - L2Sqr_Step(vec1, vec2, acc3); - if constexpr (residual >= 24) - L2Sqr_Step(vec1, vec2, acc4); - - // Process the rest of the vectors (the full chunks part) - while (vec1 < v1End) { - // TODO: use `vld1q_f16_x4` for quad-loading? - L2Sqr_Step(vec1, vec2, acc1); - L2Sqr_Step(vec1, vec2, acc2); - L2Sqr_Step(vec1, vec2, acc3); - L2Sqr_Step(vec1, vec2, acc4); - } - - // Accumulate accumulators - acc1 = vpaddq_f16(acc1, acc3); - acc2 = vpaddq_f16(acc2, acc4); - acc1 = vpaddq_f16(acc1, acc2); - - // Horizontal sum of the accumulated values - float32x4_t sum_f32 = vcvt_f32_f16(vget_low_f16(acc1)); - sum_f32 = vaddq_f32(sum_f32, vcvt_f32_f16(vget_high_f16(acc1))); - - // Pairwise add to get horizontal sum - float32x2_t sum_2 = vadd_f32(vget_low_f32(sum_f32), vget_high_f32(sum_f32)); - sum_2 = vpadd_f32(sum_2, sum_2); - - // Extract result - return vget_lane_f32(sum_2, 0); -} diff --git a/src/VecSim/spaces/L2/L2_NEON_FP32.h b/src/VecSim/spaces/L2/L2_NEON_FP32.h deleted file mode 100644 index f6c8b618f..000000000 --- a/src/VecSim/spaces/L2/L2_NEON_FP32.h +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/spaces/space_includes.h" -#include - -static inline void L2SquareStep(float *&pVect1, float *&pVect2, float32x4_t &sum) { - float32x4_t v1 = vld1q_f32(pVect1); - float32x4_t v2 = vld1q_f32(pVect2); - - float32x4_t diff = vsubq_f32(v1, v2); - - sum = vmlaq_f32(sum, diff, diff); - - pVect1 += 4; - pVect2 += 4; -} - -template // 0..15 -float FP32_L2SqrSIMD16_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - float *pVect1 = (float *)pVect1v; - float *pVect2 = (float *)pVect2v; - - float32x4_t sum0 = vdupq_n_f32(0.0f); - float32x4_t sum1 = vdupq_n_f32(0.0f); - float32x4_t sum2 = vdupq_n_f32(0.0f); - float32x4_t sum3 = vdupq_n_f32(0.0f); - - const size_t num_of_chunks = dimension / 16; - - for (size_t i = 0; i < num_of_chunks; i++) { - L2SquareStep(pVect1, pVect2, sum0); - L2SquareStep(pVect1, pVect2, sum1); - L2SquareStep(pVect1, pVect2, sum2); - L2SquareStep(pVect1, pVect2, sum3); - } - - // Handle remaining complete 4-float blocks within residual - constexpr size_t remaining_chunks = residual / 4; - // Unrolled loop for the 4-float blocks - if constexpr (remaining_chunks >= 1) { - L2SquareStep(pVect1, pVect2, sum0); - } - if constexpr (remaining_chunks >= 2) { - L2SquareStep(pVect1, pVect2, sum1); - } - if constexpr (remaining_chunks >= 3) { - L2SquareStep(pVect1, pVect2, sum2); - } - - // Handle final residual elements (0-3 elements) - constexpr size_t final_residual = residual % 4; - if constexpr (final_residual > 0) { - float32x4_t v1 = vdupq_n_f32(0.0f); - float32x4_t v2 = vdupq_n_f32(0.0f); - - if constexpr (final_residual >= 1) { - v1 = vld1q_lane_f32(pVect1, v1, 0); - v2 = vld1q_lane_f32(pVect2, v2, 0); - } - if constexpr (final_residual >= 2) { - v1 = vld1q_lane_f32(pVect1 + 1, v1, 1); - v2 = vld1q_lane_f32(pVect2 + 1, v2, 1); - } - if constexpr (final_residual >= 3) { - v1 = vld1q_lane_f32(pVect1 + 2, v1, 2); - v2 = vld1q_lane_f32(pVect2 + 2, v2, 2); - } - - float32x4_t diff = vsubq_f32(v1, v2); - sum3 = vmlaq_f32(sum3, diff, diff); - } - - // Combine all four sum accumulators - float32x4_t sum_combined = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)); - - // Horizontal sum of the 4 elements in the combined NEON register - float32x2_t sum_halves = vadd_f32(vget_low_f32(sum_combined), vget_high_f32(sum_combined)); - float32x2_t summed = vpadd_f32(sum_halves, sum_halves); - float sum = vget_lane_f32(summed, 0); - - return sum; -} diff --git a/src/VecSim/spaces/L2/L2_NEON_FP64.h b/src/VecSim/spaces/L2/L2_NEON_FP64.h deleted file mode 100644 index 92d3d0a56..000000000 --- a/src/VecSim/spaces/L2/L2_NEON_FP64.h +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/spaces/space_includes.h" -#include - -inline void L2SquareStep(double *&pVect1, double *&pVect2, float64x2_t &sum) { - float64x2_t v1 = vld1q_f64(pVect1); - float64x2_t v2 = vld1q_f64(pVect2); - - // Calculate difference between vectors - float64x2_t diff = vsubq_f64(v1, v2); - - // Square and accumulate - sum = vmlaq_f64(sum, diff, diff); - - pVect1 += 2; - pVect2 += 2; -} - -template // 0..7 -double FP64_L2SqrSIMD8_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - double *pVect1 = (double *)pVect1v; - double *pVect2 = (double *)pVect2v; - - float64x2_t sum0 = vdupq_n_f64(0.0); - float64x2_t sum1 = vdupq_n_f64(0.0); - float64x2_t sum2 = vdupq_n_f64(0.0); - float64x2_t sum3 = vdupq_n_f64(0.0); - // These are compile-time constants derived from the template parameter - - // Calculate how many full 8-element blocks to process - const size_t num_of_chunks = dimension / 8; - - for (size_t i = 0; i < num_of_chunks; i++) { - L2SquareStep(pVect1, pVect2, sum0); - L2SquareStep(pVect1, pVect2, sum1); - L2SquareStep(pVect1, pVect2, sum2); - L2SquareStep(pVect1, pVect2, sum3); - } - - // Handle remaining complete 2-float blocks within residual - constexpr size_t remaining_chunks = residual / 2; - // Unrolled loop for the 2-float blocks - if constexpr (remaining_chunks >= 1) { - L2SquareStep(pVect1, pVect2, sum0); - } - if constexpr (remaining_chunks >= 2) { - L2SquareStep(pVect1, pVect2, sum1); - } - if constexpr (remaining_chunks >= 3) { - L2SquareStep(pVect1, pVect2, sum2); - } - - // Handle final residual element - constexpr size_t final_residual = residual % 2; // Final element - if constexpr (final_residual > 0) { - float64x2_t v1 = vdupq_n_f64(0.0); - float64x2_t v2 = vdupq_n_f64(0.0); - v1 = vld1q_lane_f64(pVect1, v1, 0); - v2 = vld1q_lane_f64(pVect2, v2, 0); - - // Calculate difference and square - float64x2_t diff = vsubq_f64(v1, v2); - sum3 = vmlaq_f64(sum3, diff, diff); - } - - float64x2_t sum_combined = vaddq_f64(vaddq_f64(sum0, sum1), vaddq_f64(sum2, sum3)); - - // Horizontal sum of the 4 elements in the NEON register - float64x1_t sum = vadd_f64(vget_low_f64(sum_combined), vget_high_f64(sum_combined)); - return vget_lane_f64(sum, 0); -} diff --git a/src/VecSim/spaces/L2/L2_NEON_INT8.h b/src/VecSim/spaces/L2/L2_NEON_INT8.h deleted file mode 100644 index 17a2875c0..000000000 --- a/src/VecSim/spaces/L2/L2_NEON_INT8.h +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/spaces/space_includes.h" -#include - -__attribute__((always_inline)) static inline void L2SquareOp(const int8x16_t &v1, - const int8x16_t &v2, int32x4_t &sum) { - // Compute absolute differences and widen to 16-bit in one step - // Use vabdl_s8 for the low half - int16x8_t diff_low = vabdl_s8(vget_low_s8(v1), vget_low_s8(v2)); - int16x8_t diff_high = vabdl_high_s8(v1, v2); - - // Square and accumulate the differences using vmlal_s16 - sum = vmlal_s16(sum, vget_low_s16(diff_low), vget_low_s16(diff_low)); - sum = vmlal_s16(sum, vget_high_s16(diff_low), vget_high_s16(diff_low)); - sum = vmlal_s16(sum, vget_low_s16(diff_high), vget_low_s16(diff_high)); - sum = vmlal_s16(sum, vget_high_s16(diff_high), vget_high_s16(diff_high)); -} - -__attribute__((always_inline)) static inline void L2SquareStep16(int8_t *&pVect1, int8_t *&pVect2, - int32x4_t &sum) { - // Load 16 int8 elements (16 bytes) into NEON registers - int8x16_t v1 = vld1q_s8(pVect1); - int8x16_t v2 = vld1q_s8(pVect2); - - L2SquareOp(v1, v2, sum); - - pVect1 += 16; - pVect2 += 16; -} - -static inline void L2SquareStep32(int8_t *&pVect1, int8_t *&pVect2, int32x4_t &sum0, - int32x4_t &sum1) { - // Load 32 int8 elements (32 bytes) at once - int8x16x2_t v1 = vld1q_s8_x2(pVect1); - int8x16x2_t v2 = vld1q_s8_x2(pVect2); - - auto v1_0 = v1.val[0]; - auto v2_0 = v2.val[0]; - L2SquareOp(v1_0, v2_0, sum0); - - auto v1_1 = v1.val[1]; - auto v2_1 = v2.val[1]; - L2SquareOp(v1_1, v2_1, sum1); - - pVect1 += 32; - pVect2 += 32; -} - -template // 0..63 -float INT8_L2SqrSIMD16_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - int8_t *pVect1 = (int8_t *)pVect1v; - int8_t *pVect2 = (int8_t *)pVect2v; - - int32x4_t sum0 = vdupq_n_s32(0); - int32x4_t sum1 = vdupq_n_s32(0); - int32x4_t sum2 = vdupq_n_s32(0); - int32x4_t sum3 = vdupq_n_s32(0); - - constexpr size_t final_residual = residual % 16; - if constexpr (final_residual > 0) { - // Define a compile-time constant mask based on final_residual - constexpr uint8x16_t mask = { - 0xFF, - (final_residual >= 2) ? 0xFF : 0, - (final_residual >= 3) ? 0xFF : 0, - (final_residual >= 4) ? 0xFF : 0, - (final_residual >= 5) ? 0xFF : 0, - (final_residual >= 6) ? 0xFF : 0, - (final_residual >= 7) ? 0xFF : 0, - (final_residual >= 8) ? 0xFF : 0, - (final_residual >= 9) ? 0xFF : 0, - (final_residual >= 10) ? 0xFF : 0, - (final_residual >= 11) ? 0xFF : 0, - (final_residual >= 12) ? 0xFF : 0, - (final_residual >= 13) ? 0xFF : 0, - (final_residual >= 14) ? 0xFF : 0, - (final_residual >= 15) ? 
0xFF : 0, - 0, - }; - - // Load data directly from input vectors - int8x16_t v1 = vld1q_s8(pVect1); - int8x16_t v2 = vld1q_s8(pVect2); - - // Zero vector for replacement - int8x16_t zeros = vdupq_n_s8(0); - - // Apply bit select to zero out irrelevant elements - v1 = vbslq_s8(mask, v1, zeros); - v2 = vbslq_s8(mask, v2, zeros); - L2SquareOp(v1, v2, sum0); - pVect1 += final_residual; - pVect2 += final_residual; - } - - // Process 64 elements at a time in the main loop - size_t num_of_chunks = dimension / 64; - - for (size_t i = 0; i < num_of_chunks; i++) { - L2SquareStep32(pVect1, pVect2, sum0, sum1); - L2SquareStep32(pVect1, pVect2, sum2, sum3); - } - - constexpr size_t num_of_32_chunks = residual / 32; - - if constexpr (num_of_32_chunks) { - L2SquareStep32(pVect1, pVect2, sum0, sum1); - } - - constexpr size_t residual_chunks = (residual % 32) / 16; - if constexpr (residual_chunks >= 1) { - L2SquareStep16(pVect1, pVect2, sum2); - } - if constexpr (residual_chunks >= 2) { - L2SquareStep16(pVect1, pVect2, sum3); - } - - // Horizontal sum of the 4 elements in the sum register to get final result - int32x4_t total_sum = vaddq_s32(sum0, sum1); - - total_sum = vaddq_s32(total_sum, sum2); - total_sum = vaddq_s32(total_sum, sum3); - - // Horizontal sum of the 4 elements in the combined sum register - int32_t result = vaddvq_s32(total_sum); - - // Return the L2 squared distance as a float - return static_cast(result); -} diff --git a/src/VecSim/spaces/L2/L2_NEON_SQ8_FP32.h b/src/VecSim/spaces/L2/L2_NEON_SQ8_FP32.h deleted file mode 100644 index e98beb13e..000000000 --- a/src/VecSim/spaces/L2/L2_NEON_SQ8_FP32.h +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/IP/IP_NEON_SQ8_FP32.h" -#include "VecSim/types/sq8.h" -#include - -using sq8 = vecsim_types::sq8; - -/* - * Optimized asymmetric SQ8 L2 squared distance using algebraic identity: - * - * ||x - y||² = Σx_i² - 2*IP(x, y) + Σy_i² - * = x_sum_squares - 2 * IP(x, y) + y_sum_squares - * - * where: - * - IP(x, y) = min * y_sum + delta * Σ(q_i * y_i) (computed via - * SQ8_FP32_InnerProductSIMD16_NEON_IMP) - * - x_sum_squares and y_sum_squares are precomputed - * - * This avoids dequantization in the hot loop. 
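- *
- * Derivation sketch (added for clarity): each stored element dequantizes to
- * x_i = min + delta * q_i, so
- * IP(x, y) = Σ(min + delta * q_i) * y_i = min * Σy_i + delta * Σ(q_i * y_i)
- *          = min * y_sum + delta * Σ(q_i * y_i),
- * which is exactly the form the inner-product implementation above computes.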
- */ - -// pVect1v = SQ8 storage, pVect2v = FP32 query -template // 0..15 -float SQ8_FP32_L2SqrSIMD16_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - // Get the raw inner product using the common SIMD implementation - const float ip = SQ8_FP32_InnerProductSIMD16_NEON_IMP(pVect1v, pVect2v, dimension); - - // Get precomputed sum of squares from storage blob (pVect1v is SQ8 storage) - const uint8_t *pVect1 = static_cast(pVect1v); - const float *params = reinterpret_cast(pVect1 + dimension); - const float x_sum_sq = params[sq8::SUM_SQUARES]; - - // Get precomputed sum of squares from query blob (pVect2v is FP32 query) - const float y_sum_sq = static_cast(pVect2v)[dimension + sq8::SUM_SQUARES_QUERY]; - - // L2² = ||x||² + ||y||² - 2*IP(x, y) - return x_sum_sq + y_sum_sq - 2.0f * ip; -} diff --git a/src/VecSim/spaces/L2/L2_NEON_SQ8_SQ8.h b/src/VecSim/spaces/L2/L2_NEON_SQ8_SQ8.h deleted file mode 100644 index e86838404..000000000 --- a/src/VecSim/spaces/L2/L2_NEON_SQ8_SQ8.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/IP/IP_NEON_SQ8_SQ8.h" -#include "VecSim/types/sq8.h" - -using sq8 = vecsim_types::sq8; - -/** - * SQ8-to-SQ8 L2 squared distance functions for NEON. - * Computes L2 squared distance between two SQ8 (scalar quantized 8-bit) vectors, - * where BOTH vectors are uint8 quantized. - * - * Uses the identity: ||x - y||² = ||x||² + ||y||² - 2*IP(x, y) - * where ||x||² and ||y||² are precomputed sum of squares stored in the vector data. - * - * Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)] - * [sum_of_squares (float)] - */ - -// L2 squared distance using the common inner product implementation -template // 0..63 -float SQ8_SQ8_L2SqrSIMD64_NEON(const void *pVec1v, const void *pVec2v, size_t dimension) { - // Use the common inner product implementation (returns raw IP, not distance) - const float ip = SQ8_SQ8_InnerProductSIMD64_NEON_IMP(pVec1v, pVec2v, dimension); - - const uint8_t *pVec1 = static_cast(pVec1v); - const uint8_t *pVec2 = static_cast(pVec2v); - - // Get precomputed sum of squares from both vectors - // Layout: [uint8_t values (dim)] [min_val] [delta] [sum] [sum_of_squares] - const float sum_sq_1 = - *reinterpret_cast(pVec1 + dimension + sq8::SUM_SQUARES * sizeof(float)); - const float sum_sq_2 = - *reinterpret_cast(pVec2 + dimension + sq8::SUM_SQUARES * sizeof(float)); - - // L2² = ||x||² + ||y||² - 2*IP(x, y) - return sum_sq_1 + sum_sq_2 - 2.0f * ip; -} diff --git a/src/VecSim/spaces/L2/L2_NEON_UINT8.h b/src/VecSim/spaces/L2/L2_NEON_UINT8.h deleted file mode 100644 index aa3769867..000000000 --- a/src/VecSim/spaces/L2/L2_NEON_UINT8.h +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/spaces/space_includes.h" -#include - -__attribute__((always_inline)) static inline void -L2SquareOp(const uint8x16_t &v1, const uint8x16_t &v2, uint32x4_t &sum) { - // Compute absolute differences and widen to 16-bit in one step - uint16x8_t diff_low = vabdl_u8(vget_low_u8(v1), vget_low_u8(v2)); - uint16x8_t diff_high = vabdl_u8(vget_high_u8(v1), vget_high_u8(v2)); - - // Square and accumulate the differences using vmlal_u16 - sum = vmlal_u16(sum, vget_low_u16(diff_low), vget_low_u16(diff_low)); - sum = vmlal_u16(sum, vget_high_u16(diff_low), vget_high_u16(diff_low)); - sum = vmlal_u16(sum, vget_low_u16(diff_high), vget_low_u16(diff_high)); - sum = vmlal_u16(sum, vget_high_u16(diff_high), vget_high_u16(diff_high)); -} - -__attribute__((always_inline)) static inline void L2SquareStep16(uint8_t *&pVect1, uint8_t *&pVect2, - uint32x4_t &sum) { - // Load 16 uint8 elements into NEON registers - uint8x16_t v1 = vld1q_u8(pVect1); - uint8x16_t v2 = vld1q_u8(pVect2); - - L2SquareOp(v1, v2, sum); - - pVect1 += 16; - pVect2 += 16; -} - -__attribute__((always_inline)) static inline void -L2SquareStep32(uint8_t *&pVect1, uint8_t *&pVect2, uint32x4_t &sum1, uint32x4_t &sum2) { - uint8x16x2_t v1_pair = vld1q_u8_x2(pVect1); - uint8x16x2_t v2_pair = vld1q_u8_x2(pVect2); - - // Reference the individual vectors - uint8x16_t v1_first = v1_pair.val[0]; - uint8x16_t v1_second = v1_pair.val[1]; - uint8x16_t v2_first = v2_pair.val[0]; - uint8x16_t v2_second = v2_pair.val[1]; - - L2SquareOp(v1_first, v2_first, sum1); - L2SquareOp(v1_second, v2_second, sum2); - - pVect1 += 32; - pVect2 += 32; -} - -template // 0..63 -float UINT8_L2SqrSIMD16_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) { - uint8_t *pVect1 = (uint8_t *)pVect1v; - uint8_t *pVect2 = (uint8_t *)pVect2v; - - uint32x4_t sum0 = vdupq_n_u32(0); - uint32x4_t sum1 = vdupq_n_u32(0); - uint32x4_t sum2 = vdupq_n_u32(0); - uint32x4_t sum3 = vdupq_n_u32(0); - - constexpr size_t final_residual = residual % 16; - if constexpr (final_residual > 0) { - constexpr uint8x16_t mask = { - 0xFF, - (final_residual >= 2) ? 0xFF : 0, - (final_residual >= 3) ? 0xFF : 0, - (final_residual >= 4) ? 0xFF : 0, - (final_residual >= 5) ? 0xFF : 0, - (final_residual >= 6) ? 0xFF : 0, - (final_residual >= 7) ? 0xFF : 0, - (final_residual >= 8) ? 0xFF : 0, - (final_residual >= 9) ? 0xFF : 0, - (final_residual >= 10) ? 0xFF : 0, - (final_residual >= 11) ? 0xFF : 0, - (final_residual >= 12) ? 0xFF : 0, - (final_residual >= 13) ? 0xFF : 0, - (final_residual >= 14) ? 0xFF : 0, - (final_residual >= 15) ? 
0xFF : 0, - 0, - }; - - // Load data directly from input vectors - uint8x16_t v1 = vld1q_u8(pVect1); - uint8x16_t v2 = vld1q_u8(pVect2); - - // Zero vector for replacement - uint8x16_t zeros = vdupq_n_u8(0); - - // Apply bit select to zero out irrelevant elements - v1 = vbslq_u8(mask, v1, zeros); - v2 = vbslq_u8(mask, v2, zeros); - L2SquareOp(v1, v2, sum1); - pVect1 += final_residual; - pVect2 += final_residual; - } - - // Process 64 elements at a time in the main loop - size_t num_of_chunks = dimension / 64; - - for (size_t i = 0; i < num_of_chunks; i++) { - L2SquareStep32(pVect1, pVect2, sum0, sum2); - L2SquareStep32(pVect1, pVect2, sum1, sum3); - } - - constexpr size_t num_of_32_chunks = residual / 32; - - if constexpr (num_of_32_chunks) { - L2SquareStep32(pVect1, pVect2, sum0, sum1); - } - - // Handle residual elements (0-63) - // First, process full chunks of 16 elements - constexpr size_t residual_chunks = (residual % 32) / 16; - if constexpr (residual_chunks > 0) { - L2SquareStep16(pVect1, pVect2, sum0); - } - - // Horizontal sum of the 4 elements in the sum register to get final result - uint32x4_t total_sum = vaddq_u32(sum0, sum1); - - total_sum = vaddq_u32(total_sum, sum2); - total_sum = vaddq_u32(total_sum, sum3); - - // Horizontal sum of the 4 elements in the combined sum register - int32_t result = vaddvq_u32(total_sum); - - // Return the L2 squared distance as a float - return static_cast(result); -} diff --git a/src/VecSim/spaces/L2/L2_SSE3_BF16.h b/src/VecSim/spaces/L2/L2_SSE3_BF16.h deleted file mode 100644 index fac2e5ca1..000000000 --- a/src/VecSim/spaces/L2/L2_SSE3_BF16.h +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/spaces/space_includes.h" -#include "VecSim/types/bfloat16.h" - -using bfloat16 = vecsim_types::bfloat16; - -static inline void L2SqrLowHalfStep(__m128i v1, __m128i v2, __m128i zeros, __m128 &sum) { - // Convert next 0..3 bf16 to 4 floats - __m128i bf16_low1 = _mm_unpacklo_epi16(zeros, v1); // SSE2 - __m128i bf16_low2 = _mm_unpacklo_epi16(zeros, v2); - - __m128 diff = _mm_sub_ps(_mm_castsi128_ps(bf16_low1), _mm_castsi128_ps(bf16_low2)); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); -} - -static inline void L2SqrHighHalfStep(__m128i v1, __m128i v2, __m128i zeros, __m128 &sum) { - // Convert next 4..7 bf16 to 4 floats - __m128i bf16_high1 = _mm_unpackhi_epi16(zeros, v1); - __m128i bf16_high2 = _mm_unpackhi_epi16(zeros, v2); - - __m128 diff = _mm_sub_ps(_mm_castsi128_ps(bf16_high1), _mm_castsi128_ps(bf16_high2)); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); -} - -static inline void L2SqrStep(bfloat16 *&pVect1, bfloat16 *&pVect2, __m128 &sum) { - // Load 8 bf16 elements - __m128i v1 = _mm_lddqu_si128((__m128i *)pVect1); // SSE3 - pVect1 += 8; - __m128i v2 = _mm_lddqu_si128((__m128i *)pVect2); - pVect2 += 8; - - __m128i zeros = _mm_setzero_si128(); // SSE2 - - // Compute dist for 0..3 bf16 - L2SqrLowHalfStep(v1, v2, zeros, sum); - - // Compute dist for 4..7 bf16 - L2SqrHighHalfStep(v1, v2, zeros, sum); -} - -template // 0..31 -float BF16_L2SqrSIMD32_SSE3(const void *pVect1v, const void *pVect2v, size_t dimension) { - bfloat16 *pVect1 = (bfloat16 *)pVect1v; - bfloat16 *pVect2 = (bfloat16 *)pVect2v; - - const bfloat16 *pEnd1 = pVect1 + dimension; - - __m128 sum = _mm_setzero_ps(); - - // Handle first residual % 8 elements (smaller than step chunk size) - - // Handle residual % 4 - if constexpr (residual % 4) { - __m128i v1, v2; - constexpr bfloat16 zero = bfloat16(0); - if constexpr (residual % 4 == 3) { - v1 = _mm_setr_epi16(zero, pVect1[0], zero, pVect1[1], zero, pVect1[2], zero, - zero); // SSE2 - v2 = _mm_setr_epi16(zero, pVect2[0], zero, pVect2[1], zero, pVect2[2], zero, zero); - } else if constexpr (residual % 4 == 2) { - // Load 2 bf16 element set the rest to 0 - v1 = _mm_setr_epi16(zero, pVect1[0], zero, pVect1[1], zero, zero, zero, zero); // SSE2 - v2 = _mm_setr_epi16(zero, pVect2[0], zero, pVect2[1], zero, zero, zero, zero); - } else if constexpr (residual % 4 == 1) { - // Load only first element - v1 = _mm_setr_epi16(zero, pVect1[0], zero, zero, zero, zero, zero, zero); // SSE2 - v2 = _mm_setr_epi16(zero, pVect2[0], zero, zero, zero, zero, zero, zero); - } - __m128 diff = _mm_sub_ps(_mm_castsi128_ps(v1), _mm_castsi128_ps(v2)); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - pVect1 += residual % 4; - pVect2 += residual % 4; - } - - // If (residual % 8 >= 4) we need to handle 4 more elements - if constexpr (residual % 8 >= 4) { - __m128i v1 = _mm_lddqu_si128((__m128i *)pVect1); - __m128i v2 = _mm_lddqu_si128((__m128i *)pVect2); - L2SqrLowHalfStep(v1, v2, _mm_setzero_si128(), sum); - pVect1 += 4; - pVect2 += 4; - } - - // Handle (residual - (residual % 8)) in chunks of 8 bfloat16 - if constexpr (residual >= 24) - L2SqrStep(pVect1, pVect2, sum); - if constexpr (residual >= 16) - L2SqrStep(pVect1, pVect2, sum); - if constexpr (residual >= 8) - L2SqrStep(pVect1, pVect2, sum); - - // Handle 512 bits (32 bfloat16) in chunks of max SIMD = 128 bits = 8 bfloat16 - do { - L2SqrStep(pVect1, pVect2, sum); - L2SqrStep(pVect1, pVect2, sum); - L2SqrStep(pVect1, pVect2, sum); - L2SqrStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - // TmpRes must be 16 bytes 
aligned - float PORTABLE_ALIGN16 TmpRes[4]; - _mm_store_ps(TmpRes, sum); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; -} diff --git a/src/VecSim/spaces/L2/L2_SSE4_SQ8_FP32.h b/src/VecSim/spaces/L2/L2_SSE4_SQ8_FP32.h deleted file mode 100644 index 29c662786..000000000 --- a/src/VecSim/spaces/L2/L2_SSE4_SQ8_FP32.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/IP/IP_SSE4_SQ8_FP32.h" -#include "VecSim/types/sq8.h" - -using sq8 = vecsim_types::sq8; - -/* - * Optimized asymmetric SQ8 L2 squared distance using algebraic identity: - * - * ||x - y||² = Σx_i² - 2*IP(x, y) + Σy_i² - * = x_sum_squares - 2 * IP(x, y) + y_sum_squares - * - * where: - * - IP(x, y) = min * y_sum + delta * Σ(q_i * y_i) (computed via - * SQ8_FP32_InnerProductSIMD16_SSE4_IMP) - * - x_sum_squares and y_sum_squares are precomputed - * - * This avoids dequantization in the hot loop. - */ - -// pVect1v = SQ8 storage, pVect2v = FP32 query -template // 0..15 -float SQ8_FP32_L2SqrSIMD16_SSE4(const void *pVect1v, const void *pVect2v, size_t dimension) { - // Get the raw inner product using the common SIMD implementation - const float ip = SQ8_FP32_InnerProductSIMD16_SSE4_IMP(pVect1v, pVect2v, dimension); - - // Get precomputed sum of squares from storage blob (pVect1v is SQ8 storage) - const uint8_t *pVect1 = static_cast(pVect1v); - const float *params = reinterpret_cast(pVect1 + dimension); - const float x_sum_sq = params[sq8::SUM_SQUARES]; - - // Get precomputed sum of squares from query blob (pVect2v is FP32 query) - const float y_sum_sq = static_cast(pVect2v)[dimension + sq8::SUM_SQUARES_QUERY]; - - // L2² = ||x||² + ||y||² - 2*IP(x, y) - return x_sum_sq + y_sum_sq - 2.0f * ip; -} diff --git a/src/VecSim/spaces/L2/L2_SSE_FP32.h b/src/VecSim/spaces/L2/L2_SSE_FP32.h deleted file mode 100644 index e04cc4fe5..000000000 --- a/src/VecSim/spaces/L2/L2_SSE_FP32.h +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" - -static inline void L2SqrStep(float *&pVect1, float *&pVect2, __m128 &sum) { - __m128 v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - __m128 v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - __m128 diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); -} - -template // 0..15 -float FP32_L2SqrSIMD16_SSE(const void *pVect1v, const void *pVect2v, size_t dimension) { - float *pVect1 = (float *)pVect1v; - float *pVect2 = (float *)pVect2v; - - const float *pEnd1 = pVect1 + dimension; - - __m128 sum = _mm_setzero_ps(); - - // Deal with %4 remainder first. `dim` is >16, so we have at least one 16-float block, - // so loading 4 floats and then masking them is safe. - if constexpr (residual % 4) { - __m128 v1, v2, diff; - if constexpr (residual % 4 == 3) { - // Load 3 floats and set the last one to 0 - v1 = _mm_loadr_ps(pVect1); // load 4 floats - v2 = _mm_loadr_ps(pVect2); - // sets the last float of v1 to the last of v2, so the diff is 0. 
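-            // Illustrative walk-through (not in the original source): _mm_loadr_ps loads
-            // four floats in reverse order, so lane 0 holds pVect[3], the element just past
-            // the 3-float residual; _mm_move_ss then copies lane 0 of v2 into v1, so that
-            // lane's difference is zero while lanes 1..3 hold the three valid elements.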
- v1 = _mm_move_ss(v1, v2); - } else if constexpr (residual % 4 == 2) { - // Load 2 floats and set the last two to 0 - v1 = _mm_loadh_pi(_mm_setzero_ps(), (__m64 *)pVect1); - v2 = _mm_loadh_pi(_mm_setzero_ps(), (__m64 *)pVect2); - } else if constexpr (residual % 4 == 1) { - // Load 1 float and set the last three to 0 - v1 = _mm_load_ss(pVect1); - v2 = _mm_load_ss(pVect2); - } - pVect1 += residual % 4; - pVect2 += residual % 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_mul_ps(diff, diff); - } - - // have another 1, 2 or 3 4-floats steps according to residual - if constexpr (residual >= 12) - L2SqrStep(pVect1, pVect2, sum); - if constexpr (residual >= 8) - L2SqrStep(pVect1, pVect2, sum); - if constexpr (residual >= 4) - L2SqrStep(pVect1, pVect2, sum); - - // We dealt with the residual part. We are left with some multiple of 16 floats. - // In each iteration we calculate 16 floats = 512 bits. - do { - L2SqrStep(pVect1, pVect2, sum); - L2SqrStep(pVect1, pVect2, sum); - L2SqrStep(pVect1, pVect2, sum); - L2SqrStep(pVect1, pVect2, sum); - } while (pVect1 < pEnd1); - - // TmpRes must be 16 bytes aligned - float PORTABLE_ALIGN16 TmpRes[4]; - _mm_store_ps(TmpRes, sum); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; -} diff --git a/src/VecSim/spaces/L2/L2_SSE_FP64.h b/src/VecSim/spaces/L2/L2_SSE_FP64.h deleted file mode 100644 index 4640c1cbd..000000000 --- a/src/VecSim/spaces/L2/L2_SSE_FP64.h +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "VecSim/spaces/space_includes.h" - -static inline void L2SqrStep(double *&pVect1, double *&pVect2, __m128d &sum) { - __m128d v1 = _mm_loadu_pd(pVect1); - pVect1 += 2; - __m128d v2 = _mm_loadu_pd(pVect2); - pVect2 += 2; - __m128d diff = _mm_sub_pd(v1, v2); - sum = _mm_add_pd(sum, _mm_mul_pd(diff, diff)); -} - -template // 0..7 -double FP64_L2SqrSIMD8_SSE(const void *pVect1v, const void *pVect2v, size_t dimension) { - double *pVect1 = (double *)pVect1v; - double *pVect2 = (double *)pVect2v; - - const double *pEnd1 = pVect1 + dimension; - - __m128d sum = _mm_setzero_pd(); - - // If residual is odd, we load 1 double and set the last one to 0 - if constexpr (residual % 2 == 1) { - __m128d v1 = _mm_load_sd(pVect1); - pVect1++; - __m128d v2 = _mm_load_sd(pVect2); - pVect2++; - __m128d diff = _mm_sub_pd(v1, v2); - sum = _mm_mul_pd(diff, diff); - } - - // have another 1, 2 or 3 2-double steps according to residual - if constexpr (residual >= 6) - L2SqrStep(pVect1, pVect2, sum); - if constexpr (residual >= 4) - L2SqrStep(pVect1, pVect2, sum); - if constexpr (residual >= 2) - L2SqrStep(pVect1, pVect2, sum); - - // We dealt with the residual part. We are left with some multiple of 8 doubles. - // In each iteration we calculate 8 doubles = 512 bits in total. 
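-    // Illustrative accounting (not in the original source): e.g. for dimension = 23 the
-    // dispatcher would instantiate residual = 23 % 8 = 7; the residual handling above then
-    // consumes 1 + 3*2 = 7 doubles, and the loop below runs twice over the remaining 16.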
-    do {
-        L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
-    } while (pVect1 < pEnd1);
-
-    // TmpRes must be 16 bytes aligned
-    double PORTABLE_ALIGN16 TmpRes[2];
-    _mm_store_pd(TmpRes, sum);
-    return TmpRes[0] + TmpRes[1];
-}
diff --git a/src/VecSim/spaces/L2/L2_SVE_BF16.h b/src/VecSim/spaces/L2/L2_SVE_BF16.h
deleted file mode 100644
index 0eec1a2c5..000000000
--- a/src/VecSim/spaces/L2/L2_SVE_BF16.h
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include <arm_sve.h>
-
-// Assumes little-endianness
-inline void L2Sqr_Op(svfloat32_t &acc, svbfloat16_t &v1, svbfloat16_t &v2) {
-    svfloat32_t v1_lo = svreinterpret_f32(svzip1(svdup_u16(0), svreinterpret_u16(v1)));
-    svfloat32_t v2_lo = svreinterpret_f32(svzip1(svdup_u16(0), svreinterpret_u16(v2)));
-    svfloat32_t diff_lo = svsub_f32_x(svptrue_b32(), v1_lo, v2_lo);
-
-    acc = svmla_f32_x(svptrue_b32(), acc, diff_lo, diff_lo);
-
-    svfloat32_t v1_hi = svreinterpret_f32(svzip2(svdup_u16(0), svreinterpret_u16(v1)));
-    svfloat32_t v2_hi = svreinterpret_f32(svzip2(svdup_u16(0), svreinterpret_u16(v2)));
-    svfloat32_t diff_hi = svsub_f32_x(svptrue_b32(), v1_hi, v2_hi);
-
-    acc = svmla_f32_x(svptrue_b32(), acc, diff_hi, diff_hi);
-}
-
-inline void L2Sqr_Step(const bfloat16_t *vec1, const bfloat16_t *vec2, svfloat32_t &acc,
                        size_t &offset, const size_t chunk) {
-    svbool_t all = svptrue_b16();
-
-    // Load brain-half-precision vectors.
-    svbfloat16_t v1 = svld1_bf16(all, vec1 + offset);
-    svbfloat16_t v2 = svld1_bf16(all, vec2 + offset);
-    // Compute the squared differences and add them to the accumulator
-    L2Sqr_Op(acc, v1, v2);
-
-    // Move to next chunk
-    offset += chunk;
-}
-
-template <bool partial_chunk, unsigned char additional_steps> // [t/f, 0..3]
-float BF16_L2Sqr_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
-    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
-    const size_t chunk = svcnth(); // number of 16-bit elements in a register
-    svfloat32_t acc1 = svdup_f32(0.0f);
-    svfloat32_t acc2 = svdup_f32(0.0f);
-    svfloat32_t acc3 = svdup_f32(0.0f);
-    svfloat32_t acc4 = svdup_f32(0.0f);
-    size_t offset = 0;
-
-    // Process all full vectors
-    const size_t full_iterations = dimension / chunk / 4;
-    for (size_t iter = 0; iter < full_iterations; iter++) {
-        L2Sqr_Step(vec1, vec2, acc1, offset, chunk);
-        L2Sqr_Step(vec1, vec2, acc2, offset, chunk);
-        L2Sqr_Step(vec1, vec2, acc3, offset, chunk);
-        L2Sqr_Step(vec1, vec2, acc4, offset, chunk);
-    }
-
-    // Perform between 0 and 3 additional steps, according to `additional_steps` value
-    if constexpr (additional_steps >= 1)
-        L2Sqr_Step(vec1, vec2, acc1, offset, chunk);
-    if constexpr (additional_steps >= 2)
-        L2Sqr_Step(vec1, vec2, acc2, offset, chunk);
-    if constexpr (additional_steps >= 3)
-        L2Sqr_Step(vec1, vec2, acc3, offset, chunk);
-
-    // Handle the tail with the residual predicate
-    if constexpr (partial_chunk) {
-        svbool_t pg = svwhilelt_b16_u64(offset, dimension);
-
-        // Load brain-half-precision vectors.
-        // Inactive elements are zeros, according to the docs
-        svbfloat16_t v1 = svld1_bf16(pg, vec1 + offset);
-        svbfloat16_t v2 = svld1_bf16(pg, vec2 + offset);
-        // Compute the squared differences and add them to the accumulator.
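-        // Note (added for clarity): the predicated loads zero the inactive lanes, so their
-        // differences are zero and they contribute nothing to the accumulator below.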
diff --git a/src/VecSim/spaces/L2/L2_SVE_BF16.h b/src/VecSim/spaces/L2/L2_SVE_BF16.h
deleted file mode 100644
index 0eec1a2c5..000000000
--- a/src/VecSim/spaces/L2/L2_SVE_BF16.h
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include <arm_sve.h>
-
-// Assumes little-endianness
-inline void L2Sqr_Op(svfloat32_t &acc, svbfloat16_t &v1, svbfloat16_t &v2) {
-    svfloat32_t v1_lo = svreinterpret_f32(svzip1(svdup_u16(0), svreinterpret_u16(v1)));
-    svfloat32_t v2_lo = svreinterpret_f32(svzip1(svdup_u16(0), svreinterpret_u16(v2)));
-    svfloat32_t diff_lo = svsub_f32_x(svptrue_b32(), v1_lo, v2_lo);
-
-    acc = svmla_f32_x(svptrue_b32(), acc, diff_lo, diff_lo);
-
-    svfloat32_t v1_hi = svreinterpret_f32(svzip2(svdup_u16(0), svreinterpret_u16(v1)));
-    svfloat32_t v2_hi = svreinterpret_f32(svzip2(svdup_u16(0), svreinterpret_u16(v2)));
-    svfloat32_t diff_hi = svsub_f32_x(svptrue_b32(), v1_hi, v2_hi);
-
-    acc = svmla_f32_x(svptrue_b32(), acc, diff_hi, diff_hi);
-}
-
-inline void L2Sqr_Step(const bfloat16_t *vec1, const bfloat16_t *vec2, svfloat32_t &acc,
-                       size_t &offset, const size_t chunk) {
-    svbool_t all = svptrue_b16();
-
-    // Load brain-half-precision vectors.
-    svbfloat16_t v1 = svld1_bf16(all, vec1 + offset);
-    svbfloat16_t v2 = svld1_bf16(all, vec2 + offset);
-    // Compute squared differences and add them to the accumulator
-    L2Sqr_Op(acc, v1, v2);
-
-    // Move to next chunk
-    offset += chunk;
-}
-
-template <bool partial_chunk, unsigned char additional_steps> // [t/f, 0..3]
-float BF16_L2Sqr_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
-    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
-    const size_t chunk = svcnth(); // number of 16-bit elements in a register
-    svfloat32_t acc1 = svdup_f32(0.0f);
-    svfloat32_t acc2 = svdup_f32(0.0f);
-    svfloat32_t acc3 = svdup_f32(0.0f);
-    svfloat32_t acc4 = svdup_f32(0.0f);
-    size_t offset = 0;
-
-    // Process all full vectors
-    const size_t full_iterations = dimension / chunk / 4;
-    for (size_t iter = 0; iter < full_iterations; iter++) {
-        L2Sqr_Step(vec1, vec2, acc1, offset, chunk);
-        L2Sqr_Step(vec1, vec2, acc2, offset, chunk);
-        L2Sqr_Step(vec1, vec2, acc3, offset, chunk);
-        L2Sqr_Step(vec1, vec2, acc4, offset, chunk);
-    }
-
-    // Perform between 0 and 3 additional steps, according to `additional_steps` value
-    if constexpr (additional_steps >= 1)
-        L2Sqr_Step(vec1, vec2, acc1, offset, chunk);
-    if constexpr (additional_steps >= 2)
-        L2Sqr_Step(vec1, vec2, acc2, offset, chunk);
-    if constexpr (additional_steps >= 3)
-        L2Sqr_Step(vec1, vec2, acc3, offset, chunk);
-
-    // Handle the tail with the residual predicate
-    if constexpr (partial_chunk) {
-        svbool_t pg = svwhilelt_b16_u64(offset, dimension);
-
-        // Load brain-half-precision vectors.
-        // Inactive elements are zeros, according to the docs
-        svbfloat16_t v1 = svld1_bf16(pg, vec1 + offset);
-        svbfloat16_t v2 = svld1_bf16(pg, vec2 + offset);
-        // Compute squared differences and add them to the accumulator.
-        L2Sqr_Op(acc4, v1, v2);
-    }
-
-    // Combine the accumulators
-    acc1 = svadd_f32_x(svptrue_b32(), acc1, acc3);
-    acc2 = svadd_f32_x(svptrue_b32(), acc2, acc4);
-    acc1 = svadd_f32_x(svptrue_b32(), acc1, acc2);
-
-    // Reduce the accumulated sum.
-    float result = svaddv_f32(svptrue_b32(), acc1);
-    return result;
-}
diff --git a/src/VecSim/spaces/L2/L2_SVE_FP16.h b/src/VecSim/spaces/L2/L2_SVE_FP16.h
deleted file mode 100644
index 24b5ee2df..000000000
--- a/src/VecSim/spaces/L2/L2_SVE_FP16.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include <arm_sve.h>
-
-inline void L2Sqr_Step(const float16_t *vec1, const float16_t *vec2, svfloat16_t &acc,
-                       size_t &offset, const size_t chunk) {
-    svbool_t all = svptrue_b16();
-
-    svfloat16_t v1 = svld1_f16(all, vec1 + offset);
-    svfloat16_t v2 = svld1_f16(all, vec2 + offset);
-    // Compute difference in half precision.
-    svfloat16_t diff = svsub_f16_x(all, v1, v2);
-    // Square the differences and accumulate
-    acc = svmla_f16_x(all, acc, diff, diff);
-    offset += chunk;
-}
-
-template <bool partial_chunk, unsigned char additional_steps> // [t/f, 0..3]
-float FP16_L2Sqr_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const auto *vec1 = static_cast<const float16_t *>(pVect1v);
-    const auto *vec2 = static_cast<const float16_t *>(pVect2v);
-    const size_t chunk = svcnth(); // number of 16-bit elements in a register
-    svbool_t all = svptrue_b16();
-    svfloat16_t acc1 = svdup_f16(0.0f);
-    svfloat16_t acc2 = svdup_f16(0.0f);
-    svfloat16_t acc3 = svdup_f16(0.0f);
-    svfloat16_t acc4 = svdup_f16(0.0f);
-    size_t offset = 0;
-
-    // Process all full vectors
-    const size_t full_iterations = dimension / chunk / 4;
-    for (size_t iter = 0; iter < full_iterations; iter++) {
-        L2Sqr_Step(vec1, vec2, acc1, offset, chunk);
-        L2Sqr_Step(vec1, vec2, acc2, offset, chunk);
-        L2Sqr_Step(vec1, vec2, acc3, offset, chunk);
-        L2Sqr_Step(vec1, vec2, acc4, offset, chunk);
-    }
-
-    // Perform between 0 and 3 additional steps, according to `additional_steps` value
-    if constexpr (additional_steps >= 1)
-        L2Sqr_Step(vec1, vec2, acc1, offset, chunk);
-    if constexpr (additional_steps >= 2)
-        L2Sqr_Step(vec1, vec2, acc2, offset, chunk);
-    if constexpr (additional_steps >= 3)
-        L2Sqr_Step(vec1, vec2, acc3, offset, chunk);
-
-    // Handle partial chunk, if needed
-    if constexpr (partial_chunk) {
-        svbool_t pg = svwhilelt_b16_u64(offset, dimension);
-
-        // Load half-precision vectors.
-        svfloat16_t v1 = svld1_f16(pg, vec1 + offset);
-        svfloat16_t v2 = svld1_f16(pg, vec2 + offset);
-        // Compute difference in half precision.
-        svfloat16_t diff = svsub_f16_x(pg, v1, v2);
-        // Square the differences.
-        // Use `m` suffix to keep the inactive elements as they are in `acc`
-        acc4 = svmla_f16_m(pg, acc4, diff, diff);
-    }
-
-    // Combine the accumulators
-    acc1 = svadd_f16_x(all, acc1, acc3);
-    acc2 = svadd_f16_x(all, acc2, acc4);
-    acc1 = svadd_f16_x(all, acc1, acc2);
-
-    // Reduce the accumulated sum.
-    float result = svaddv_f16(all, acc1);
-    return result;
-}
diff --git a/src/VecSim/spaces/L2/L2_SVE_FP32.h b/src/VecSim/spaces/L2/L2_SVE_FP32.h
deleted file mode 100644
index 8367baa97..000000000
--- a/src/VecSim/spaces/L2/L2_SVE_FP32.h
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include "VecSim/spaces/space_includes.h"
-#include <arm_sve.h>
-
-static inline void L2SquareStep(float *&pVect1, float *&pVect2, size_t &offset, svfloat32_t &sum,
-                                const size_t chunk) {
-    // Load vectors
-    svfloat32_t v1 = svld1_f32(svptrue_b32(), pVect1 + offset);
-    svfloat32_t v2 = svld1_f32(svptrue_b32(), pVect2 + offset);
-
-    // Calculate difference between vectors
-    svfloat32_t diff = svsub_f32_x(svptrue_b32(), v1, v2);
-
-    // Square the difference and accumulate: sum += diff * diff
-    sum = svmla_f32_z(svptrue_b32(), sum, diff, diff);
-
-    // Advance pointers by the vector length
-    offset += chunk;
-}
-
-template <bool partial_chunk, unsigned char additional_steps>
-float FP32_L2SqrSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    float *pVect1 = (float *)pVect1v;
-    float *pVect2 = (float *)pVect2v;
-    size_t offset = 0;
-
-    // Get the number of 32-bit elements per vector at runtime
-    uint64_t chunk = svcntw();
-
-    // Multiple accumulators to increase instruction-level parallelism
-    svfloat32_t sum0 = svdup_f32(0.0f);
-    svfloat32_t sum1 = svdup_f32(0.0f);
-    svfloat32_t sum2 = svdup_f32(0.0f);
-    svfloat32_t sum3 = svdup_f32(0.0f);
-
-    // Process vectors in chunks, with unrolling for better pipelining
-    auto chunk_size = 4 * chunk;
-    size_t number_of_chunks = dimension / chunk_size;
-    for (size_t i = 0; i < number_of_chunks; ++i) {
-        // Process 4 chunks with separate accumulators
-        L2SquareStep(pVect1, pVect2, offset, sum0, chunk);
-        L2SquareStep(pVect1, pVect2, offset, sum1, chunk);
-        L2SquareStep(pVect1, pVect2, offset, sum2, chunk);
-        L2SquareStep(pVect1, pVect2, offset, sum3, chunk);
-    }
-
-    // Process remaining complete SVE vectors that didn't fit into the main loop
-    // These are full vector operations (0-3 full vectors)
-    if constexpr (additional_steps > 0) {
-        if constexpr (additional_steps >= 1) {
-            L2SquareStep(pVect1, pVect2, offset, sum0, chunk);
-        }
-        if constexpr (additional_steps >= 2) {
-            L2SquareStep(pVect1, pVect2, offset, sum1, chunk);
-        }
-        if constexpr (additional_steps >= 3) {
-            L2SquareStep(pVect1, pVect2, offset, sum2, chunk);
-        }
-    }
-
-    // Process final tail elements that don't form a complete vector
-    // This section handles the case when dimension is not evenly divisible by SVE vector length
-    if constexpr (partial_chunk) {
-        // Create a predicate mask where each lane is active only for the remaining elements
-        svbool_t pg =
-            svwhilelt_b32(static_cast<uint32_t>(offset), static_cast<uint32_t>(dimension));
-
-        // Load vectors with predication
-        svfloat32_t v1 = svld1_f32(pg, pVect1 + offset);
-        svfloat32_t v2 = svld1_f32(pg, pVect2 + offset);
-
-        svfloat32_t diff = svsub_f32_m(pg, v1, v2);
-
-        sum3 = svmla_f32_m(pg, sum3, diff, diff);
-    }
-
-    sum0 = svadd_f32_x(svptrue_b32(), sum0, sum1);
-    sum2 = svadd_f32_x(svptrue_b32(), sum2, sum3);
-    svfloat32_t sum_all = svadd_f32_x(svptrue_b32(), sum0, sum2);
-    float result = svaddv_f32(svptrue_b32(), sum_all);
-    return result;
-}
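All of the SVE kernels above share the whilelt tail idiom: a single predicated load handles whatever partial chunk remains, with inactive lanes contributing zero. A self-contained example of the same pattern, assuming an SVE-enabled AArch64 toolchain (e.g. compiled with -march=armv8-a+sve):

    #include <arm_sve.h>
    #include <cstddef>

    // Sum a float array with a whilelt-governed loop: full and partial chunks
    // take the same predicated path, mirroring the tail handling above.
    float sve_sum(const float *data, size_t n) {
        svfloat32_t acc = svdup_f32(0.0f);
        for (size_t i = 0; i < n; i += svcntw()) {
            // Lanes whose index is >= n are inactive in the predicate.
            svbool_t pg = svwhilelt_b32_u64(i, n);
            svfloat32_t v = svld1_f32(pg, data + i); // inactive lanes load as 0
            acc = svadd_f32_m(pg, acc, v);           // only active lanes accumulate
        }
        return svaddv_f32(svptrue_b32(), acc);
    }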
diff --git a/src/VecSim/spaces/L2/L2_SVE_FP64.h b/src/VecSim/spaces/L2/L2_SVE_FP64.h
deleted file mode 100644
index 8fb822544..000000000
--- a/src/VecSim/spaces/L2/L2_SVE_FP64.h
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include "VecSim/spaces/space_includes.h"
-#include <arm_sve.h>
-
-inline void L2SquareStep(double *&pVect1, double *&pVect2, size_t &offset, svfloat64_t &sum,
-                         const size_t chunk) {
-    // Load vectors
-    svfloat64_t v1 = svld1_f64(svptrue_b64(), pVect1 + offset);
-    svfloat64_t v2 = svld1_f64(svptrue_b64(), pVect2 + offset);
-
-    // Calculate difference between vectors
-    svfloat64_t diff = svsub_f64_x(svptrue_b64(), v1, v2);
-
-    // Square the difference and accumulate: sum += diff * diff
-    sum = svmla_f64_x(svptrue_b64(), sum, diff, diff);
-
-    // Advance pointers by the vector length
-    offset += chunk;
-}
-
-template <bool partial_chunk, unsigned char additional_steps>
-double FP64_L2SqrSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    double *pVect1 = (double *)pVect1v;
-    double *pVect2 = (double *)pVect2v;
-    const size_t chunk = svcntd();
-    size_t offset = 0;
-
-    // Multiple accumulators to increase instruction-level parallelism
-    svfloat64_t sum0 = svdup_f64(0.0);
-    svfloat64_t sum1 = svdup_f64(0.0);
-    svfloat64_t sum2 = svdup_f64(0.0);
-    svfloat64_t sum3 = svdup_f64(0.0);
-
-    // Process vectors in chunks, with unrolling for better pipelining
-    auto chunk_size = 4 * chunk;
-    size_t number_of_chunks = dimension / chunk_size;
-    for (size_t i = 0; i < number_of_chunks; ++i) {
-        // Process 4 chunks with separate accumulators
-        L2SquareStep(pVect1, pVect2, offset, sum0, chunk);
-        L2SquareStep(pVect1, pVect2, offset, sum1, chunk);
-        L2SquareStep(pVect1, pVect2, offset, sum2, chunk);
-        L2SquareStep(pVect1, pVect2, offset, sum3, chunk);
-    }
-
-    if constexpr (additional_steps >= 1) {
-        L2SquareStep(pVect1, pVect2, offset, sum0, chunk);
-    }
-    if constexpr (additional_steps >= 2) {
-        L2SquareStep(pVect1, pVect2, offset, sum1, chunk);
-    }
-    if constexpr (additional_steps >= 3) {
-        L2SquareStep(pVect1, pVect2, offset, sum2, chunk);
-    }
-
-    if constexpr (partial_chunk) {
-        svbool_t pg =
-            svwhilelt_b64(static_cast<uint64_t>(offset), static_cast<uint64_t>(dimension));
-
-        // Load vectors with predication
-        svfloat64_t v1 = svld1_f64(pg, pVect1 + offset);
-        svfloat64_t v2 = svld1_f64(pg, pVect2 + offset);
-
-        // Calculate difference with predication
-        svfloat64_t diff = svsub_f64_x(pg, v1, v2);
-
-        // Square the difference and accumulate with predication
-        sum3 = svmla_f64_m(pg, sum3, diff, diff);
-    }
-
-    // Combine the partial sums
-    sum0 = svadd_f64_x(svptrue_b64(), sum0, sum1);
-    sum2 = svadd_f64_x(svptrue_b64(), sum2, sum3);
-    svfloat64_t sum_all = svadd_f64_x(svptrue_b64(), sum0, sum2);
-    double result = svaddv_f64(svptrue_b64(), sum_all);
-    return result;
-}
diff --git a/src/VecSim/spaces/L2/L2_SVE_INT8.h b/src/VecSim/spaces/L2/L2_SVE_INT8.h
deleted file mode 100644
index af959d19a..000000000
--- a/src/VecSim/spaces/L2/L2_SVE_INT8.h
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include "VecSim/spaces/space_includes.h"
-#include <arm_sve.h>
-
-// Aligned step using svptrue_b8()
-inline void L2SquareStep(const int8_t *&pVect1, const int8_t *&pVect2, size_t &offset,
-                         svuint32_t &sum, const size_t chunk) {
-    svbool_t pg = svptrue_b8();
-    // Note: Because all the bits are 1, the extension to 16 and 32 bits does not make a difference
-    // Otherwise, pg should be recalculated for 16 and 32 operations
-
-    svint8_t v1_i8 = svld1_s8(pg, pVect1 + offset); // Load int8 vectors from pVect1
-    svint8_t v2_i8 = svld1_s8(pg, pVect2 + offset); // Load int8 vectors from pVect2
-
-    // The result of svabd can be reinterpreted as uint8
-    svuint8_t abs_diff = svreinterpret_u8_s8(svabd_s8_x(pg, v1_i8, v2_i8));
-
-    sum = svdot_u32(sum, abs_diff, abs_diff);
-    offset += chunk; // Move to the next set of int8 elements
-}
-
-template <bool partial_chunk, unsigned char additional_steps>
-float INT8_L2SqrSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const int8_t *pVect1 = reinterpret_cast<const int8_t *>(pVect1v);
-    const int8_t *pVect2 = reinterpret_cast<const int8_t *>(pVect2v);
-
-    // number of uint8 per SVE register (we use uint accumulators)
-    const size_t vl = svcntb();
-    const size_t chunk_size = 4 * vl;
-    svbool_t all = svptrue_b8();
-
-    // Each svdot lane accumulates 4 products, each at most (2^8 - 1)^2 < 2^16, so a single
-    // L2SquareStep adds less than 2^18 per 32-bit lane.
-    // Therefore, a single uint32 accumulator can absorb more than 2^14 steps before overflowing.
-    // That scenario would require the dimension of the vector to exceed 16*4*2^14 = 2^20
-    // (16 int8 in 1 SVE register) * (4 accumulators) * (2^14 steps)
-    // We can safely assume that the dimension is smaller than that
-    // So using uint32_t is safe
-
-    svuint32_t sum0 = svdup_u32(0);
-    svuint32_t sum1 = svdup_u32(0);
-    svuint32_t sum2 = svdup_u32(0);
-    svuint32_t sum3 = svdup_u32(0);
-
-    size_t offset = 0;
-    size_t num_main_blocks = dimension / chunk_size;
-
-    for (size_t i = 0; i < num_main_blocks; ++i) {
-        L2SquareStep(pVect1, pVect2, offset, sum0, vl);
-        L2SquareStep(pVect1, pVect2, offset, sum1, vl);
-        L2SquareStep(pVect1, pVect2, offset, sum2, vl);
-        L2SquareStep(pVect1, pVect2, offset, sum3, vl);
-    }
-
-    if constexpr (additional_steps > 0) {
-        if constexpr (additional_steps >= 1) {
-            L2SquareStep(pVect1, pVect2, offset, sum0, vl);
-        }
-        if constexpr (additional_steps >= 2) {
-            L2SquareStep(pVect1, pVect2, offset, sum1, vl);
-        }
-        if constexpr (additional_steps >= 3) {
-            L2SquareStep(pVect1, pVect2, offset, sum2, vl);
-        }
-    }
-
-    if constexpr (partial_chunk) {
-
-        svbool_t pg = svwhilelt_b8_u64(offset, dimension);
-
-        svint8_t v1_i8 = svld1_s8(pg, pVect1 + offset); // Load int8 vectors from pVect1
-        svint8_t v2_i8 = svld1_s8(pg, pVect2 + offset); // Load int8 vectors from pVect2
-
-        // The result of svabd can be reinterpreted as uint8
-        svuint8_t abs_diff = svreinterpret_u8_s8(svabd_s8_x(all, v1_i8, v2_i8));
-
-        // Summing without masking by pg is safe because svld1 sets inactive lanes to 0
-        sum3 = svdot_u32(sum3, abs_diff, abs_diff);
-    }
-
-    sum0 = svadd_u32_x(all, sum0, sum1);
-    sum2 = svadd_u32_x(all, sum2, sum3);
-    svuint32_t sum_all = svadd_u32_x(all, sum0, sum2);
-    return svaddv_u32(svptrue_b32(), sum_all);
-}
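The overflow bound in the comments above can be reproduced with plain integer arithmetic. This standalone check assumes the 128-bit register and 4-accumulator layout mentioned there:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Worst case per byte pair: |a - b| <= 255, so each product is <= 255^2.
        const uint64_t max_product = 255ull * 255ull;      // < 2^16
        // svdot folds 4 adjacent products into each 32-bit lane per step.
        const uint64_t max_step = 4 * max_product;         // < 2^18
        // Steps a single uint32 lane can absorb before it may overflow.
        const uint64_t safe_steps = UINT32_MAX / max_step; // > 2^14
        // With 16 byte lanes per 128-bit register and 4 rotating accumulators,
        // this bounds the maximal safely supported dimension.
        const uint64_t max_dim = safe_steps * 4 * 16;      // ~2^20
        std::printf("safe steps per lane: %llu, max dimension: %llu\n",
                    (unsigned long long)safe_steps, (unsigned long long)max_dim);
    }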
diff --git a/src/VecSim/spaces/L2/L2_SVE_SQ8_FP32.h b/src/VecSim/spaces/L2/L2_SVE_SQ8_FP32.h
deleted file mode 100644
index 0ae9fec74..000000000
--- a/src/VecSim/spaces/L2/L2_SVE_SQ8_FP32.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-#include "VecSim/spaces/space_includes.h"
-#include "VecSim/spaces/IP/IP_SVE_SQ8_FP32.h"
-#include "VecSim/types/sq8.h"
-#include <arm_sve.h>
-
-using sq8 = vecsim_types::sq8;
-
-/*
- * Optimized asymmetric SQ8-FP32 L2 squared distance using algebraic identity:
- *
- * ||x - y||² = Σx_i² - 2*IP(x, y) + Σy_i²
- *            = x_sum_squares - 2 * IP(x, y) + y_sum_squares
- *
- * where:
- * - IP(x, y) = min * y_sum + delta * Σ(q_i * y_i) (computed via
- *   SQ8_FP32_InnerProductSIMD_SVE_IMP)
- * - x_sum_squares and y_sum_squares are precomputed
- *
- * This avoids dequantization in the hot loop.
- */
-
-// pVect1v = SQ8 storage, pVect2v = FP32 query
-template <bool partial_chunk, unsigned char additional_steps>
-float SQ8_FP32_L2SqrSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    // Get the raw inner product using the common SIMD implementation
-    const float ip = SQ8_FP32_InnerProductSIMD_SVE_IMP<partial_chunk, additional_steps>(
-        pVect1v, pVect2v, dimension);
-
-    // Get precomputed sum of squares from storage blob (pVect1v is SQ8 storage)
-    const uint8_t *pVect1 = static_cast<const uint8_t *>(pVect1v);
-    const float *params = reinterpret_cast<const float *>(pVect1 + dimension);
-    const float x_sum_sq = params[sq8::SUM_SQUARES];
-
-    // Get precomputed sum of squares from query blob (pVect2v is FP32 query)
-    const float y_sum_sq = static_cast<const float *>(pVect2v)[dimension + sq8::SUM_SQUARES_QUERY];
-
-    // L2² = ||x||² + ||y||² - 2*IP(x, y)
-    return x_sum_sq + y_sum_sq - 2.0f * ip;
-}
diff --git a/src/VecSim/spaces/L2/L2_SVE_SQ8_SQ8.h b/src/VecSim/spaces/L2/L2_SVE_SQ8_SQ8.h
deleted file mode 100644
index 90801f82a..000000000
--- a/src/VecSim/spaces/L2/L2_SVE_SQ8_SQ8.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-#include "VecSim/spaces/space_includes.h"
-#include "VecSim/spaces/IP/IP_SVE_SQ8_SQ8.h"
-#include "VecSim/types/sq8.h"
-
-using sq8 = vecsim_types::sq8;
-
-/**
- * SQ8-to-SQ8 L2 squared distance functions for SVE.
- * Computes L2 squared distance between two SQ8 (scalar quantized 8-bit) vectors,
- * where BOTH vectors are uint8 quantized.
- *
- * Uses the identity: ||x - y||² = ||x||² + ||y||² - 2*IP(x, y)
- * where ||x||² and ||y||² are precomputed sum of squares stored in the vector data.
- *
- * Vector layout: [uint8_t values (dim)] [min_val (float)] [delta (float)] [sum (float)]
- *                [sum_of_squares (float)]
- */
-
-// L2 squared distance using the common inner product implementation
-template <bool partial_chunk, unsigned char additional_steps>
-float SQ8_SQ8_L2SqrSIMD_SVE(const void *pVec1v, const void *pVec2v, size_t dimension) {
-    // Use the common inner product implementation (returns raw IP, not distance)
-    const float ip = SQ8_SQ8_InnerProductSIMD_SVE_IMP<partial_chunk, additional_steps>(
-        pVec1v, pVec2v, dimension);
-
-    const uint8_t *pVec1 = static_cast<const uint8_t *>(pVec1v);
-    const uint8_t *pVec2 = static_cast<const uint8_t *>(pVec2v);
-
-    // Get precomputed sum of squares from both vectors
-    // Layout: [uint8_t values (dim)] [min_val] [delta] [sum] [sum_of_squares]
-    const float sum_sq_1 =
-        *reinterpret_cast<const float *>(pVec1 + dimension + sq8::SUM_SQUARES * sizeof(float));
-    const float sum_sq_2 =
-        *reinterpret_cast<const float *>(pVec2 + dimension + sq8::SUM_SQUARES * sizeof(float));
-
-    // L2² = ||x||² + ||y||² - 2*IP(x, y)
-    return sum_sq_1 + sum_sq_2 - 2.0f * ip;
-}
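Both SQ8 L2 kernels above reduce the L2 distance to an inner product plus precomputed norms. A scalar reference of that identity is a useful test oracle for the SIMD paths; the function below is illustrative, not the library's API:

    #include <cstddef>

    // ||x - y||^2 = sum(x_i^2) + sum(y_i^2) - 2 * IP(x, y), computed directly.
    // A kernel that stores the sum of squares per vector only needs IP at query time.
    float l2_from_ip_reference(const float *x, const float *y, size_t dim) {
        float ip = 0, xx = 0, yy = 0;
        for (size_t i = 0; i < dim; i++) {
            ip += x[i] * y[i];
            xx += x[i] * x[i];
            yy += y[i] * y[i];
        }
        return xx + yy - 2.0f * ip; // equals sum((x_i - y_i)^2)
    }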
diff --git a/src/VecSim/spaces/L2/L2_SVE_UINT8.h b/src/VecSim/spaces/L2/L2_SVE_UINT8.h
deleted file mode 100644
index 553db2169..000000000
--- a/src/VecSim/spaces/L2/L2_SVE_UINT8.h
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include "VecSim/spaces/space_includes.h"
-#include <arm_sve.h>
-
-// Aligned step using svptrue_b8()
-inline void L2SquareStep(const uint8_t *&pVect1, const uint8_t *&pVect2, size_t &offset,
-                         svuint32_t &sum, const size_t chunk) {
-    svbool_t pg = svptrue_b8();
-    // Note: Because all the bits are 1, the extension to 16 and 32 bits does not make a difference
-    // Otherwise, pg should be recalculated for 16 and 32 operations
-
-    svuint8_t v1_ui8 = svld1_u8(pg, pVect1 + offset); // Load uint8 vectors from pVect1
-    svuint8_t v2_ui8 = svld1_u8(pg, pVect2 + offset); // Load uint8 vectors from pVect2
-
-    svuint8_t abs_diff = svabd_u8_x(pg, v1_ui8, v2_ui8);
-
-    sum = svdot_u32(sum, abs_diff, abs_diff);
-
-    offset += chunk; // Move to the next set of uint8 elements
-}
-
-template <bool partial_chunk, unsigned char additional_steps>
-float UINT8_L2SqrSIMD_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
-    const uint8_t *pVect1 = reinterpret_cast<const uint8_t *>(pVect1v);
-    const uint8_t *pVect2 = reinterpret_cast<const uint8_t *>(pVect2v);
-
-    // number of uint8 per SVE register
-    const size_t vl = svcntb();
-    const size_t chunk_size = 4 * vl;
-    svbool_t all = svptrue_b8();
-
-    // Each svdot lane accumulates 4 products, each at most (2^8 - 1)^2 < 2^16, so a single
-    // L2SquareStep adds less than 2^18 per 32-bit lane.
-    // Therefore, a single uint32 accumulator can absorb more than 2^14 steps before overflowing.
-    // That scenario would require the dimension of the vector to exceed 16*4*2^14 = 2^20
-    // (16 uint8 in 1 SVE register) * (4 accumulators) * (2^14 steps)
-    // We can safely assume that the dimension is smaller than that
-    // So using uint32_t is safe
-
-    svuint32_t sum0 = svdup_u32(0);
-    svuint32_t sum1 = svdup_u32(0);
-    svuint32_t sum2 = svdup_u32(0);
-    svuint32_t sum3 = svdup_u32(0);
-
-    size_t offset = 0;
-    size_t num_main_blocks = dimension / chunk_size;
-
-    for (size_t i = 0; i < num_main_blocks; ++i) {
-        L2SquareStep(pVect1, pVect2, offset, sum0, vl);
-        L2SquareStep(pVect1, pVect2, offset, sum1, vl);
-        L2SquareStep(pVect1, pVect2, offset, sum2, vl);
-        L2SquareStep(pVect1, pVect2, offset, sum3, vl);
-    }
-
-    if constexpr (additional_steps > 0) {
-        if constexpr (additional_steps >= 1) {
-            L2SquareStep(pVect1, pVect2, offset, sum0, vl);
-        }
-        if constexpr (additional_steps >= 2) {
-            L2SquareStep(pVect1, pVect2, offset, sum1, vl);
-        }
-        if constexpr (additional_steps >= 3) {
-            L2SquareStep(pVect1, pVect2, offset, sum2, vl);
-        }
-    }
-
-    if constexpr (partial_chunk) {
-
-        svbool_t pg = svwhilelt_b8_u64(offset, dimension);
-        svuint8_t v1_ui8 = svld1_u8(pg, pVect1 + offset); // Load uint8 vectors from pVect1
-        svuint8_t v2_ui8 = svld1_u8(pg, pVect2 + offset); // Load uint8 vectors from pVect2
-
-        svuint8_t abs_diff = svabd_u8_x(all, v1_ui8, v2_ui8);
-
-        // Summing without masking by pg is safe because svld1 sets inactive lanes to 0
-        sum3 = svdot_u32(sum3, abs_diff, abs_diff);
-    }
-
-    sum0 = svadd_u32_x(all, sum0, sum1);
-    sum2 = svadd_u32_x(all, sum2, sum3);
-    svuint32_t sum_all = svadd_u32_x(all, sum0, sum2);
-    return svaddv_u32(svptrue_b32(), sum_all);
-}
diff --git a/src/VecSim/spaces/L2_space.cpp b/src/VecSim/spaces/L2_space.cpp
deleted file mode 100644
index dcccd513f..000000000
--- a/src/VecSim/spaces/L2_space.cpp
+++ /dev/null
@@ -1,468 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include "VecSim/spaces/space_includes.h"
-#include "VecSim/spaces/L2_space.h"
-#include "VecSim/spaces/L2/L2.h"
-#include "VecSim/types/bfloat16.h"
-#include "VecSim/types/float16.h"
-#include "VecSim/spaces/functions/F16C.h"
-#include "VecSim/spaces/functions/AVX512F.h"
-#include "VecSim/spaces/functions/AVX.h"
-#include "VecSim/spaces/functions/SSE.h"
-#include "VecSim/spaces/functions/AVX512BW_VBMI2.h"
-#include "VecSim/spaces/functions/AVX512FP16_VL.h"
-#include "VecSim/spaces/functions/AVX512F_BW_VL_VNNI.h"
-#include "VecSim/spaces/functions/AVX2.h"
-#include "VecSim/spaces/functions/AVX2_FMA.h"
-#include "VecSim/spaces/functions/SSE3.h"
-#include "VecSim/spaces/functions/SSE4.h"
-#include "VecSim/spaces/functions/NEON.h"
-#include "VecSim/spaces/functions/NEON_DOTPROD.h"
-#include "VecSim/spaces/functions/NEON_HP.h"
-#include "VecSim/spaces/functions/NEON_BF16.h"
-#include "VecSim/spaces/functions/SVE.h"
-#include "VecSim/spaces/functions/SVE_BF16.h"
-#include "VecSim/spaces/functions/SVE2.h"
-
-using bfloat16 = vecsim_types::bfloat16;
-using float16 = vecsim_types::float16;
-
-namespace spaces {
-
-// SQ8-FP32: asymmetric L2 distance between SQ8 storage and FP32 query
-dist_func_t<float> L2_SQ8_FP32_GetDistFunc(size_t dim, unsigned char *alignment,
-                                           const void *arch_opt) {
-    unsigned char dummy_alignment;
-    if (!alignment) {
-        alignment = &dummy_alignment;
-    }
-
-    dist_func_t<float> ret_dist_func = SQ8_FP32_L2Sqr;
-
-    [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt);
-#ifdef CPU_FEATURES_ARCH_AARCH64
-#ifdef OPT_SVE2
-    if (features.sve2) {
-        return Choose_SQ8_FP32_L2_implementation_SVE2(dim);
-    }
-#endif
-#ifdef OPT_SVE
-    if (features.sve) {
-        return Choose_SQ8_FP32_L2_implementation_SVE(dim);
-    }
-#endif
-#ifdef OPT_NEON
-    if (features.asimd) {
-        return Choose_SQ8_FP32_L2_implementation_NEON(dim);
-    }
-#endif
-#endif
-
-#ifdef CPU_FEATURES_ARCH_X86_64
-    // Optimizations assume at least 16 floats. If we have less, we use the naive implementation.
-
-    if (dim < 16) {
-        return ret_dist_func;
-    }
-#ifdef OPT_AVX512_F_BW_VL_VNNI
-    if (features.avx512f && features.avx512bw && features.avx512vnni) {
-        return Choose_SQ8_FP32_L2_implementation_AVX512F_BW_VL_VNNI(dim);
-    }
-#endif
-#ifdef OPT_AVX2_FMA
-    if (features.avx2 && features.fma3) {
-        return Choose_SQ8_FP32_L2_implementation_AVX2_FMA(dim);
-    }
-#endif
-#ifdef OPT_AVX2
-    if (features.avx2) {
-        return Choose_SQ8_FP32_L2_implementation_AVX2(dim);
-    }
-#endif
-#ifdef OPT_SSE4
-    if (features.sse4_1) {
-        return Choose_SQ8_FP32_L2_implementation_SSE4(dim);
-    }
-#endif
-#endif // __x86_64__
-    return ret_dist_func;
-}
-
-dist_func_t<float> L2_FP32_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt) {
-    unsigned char dummy_alignment;
-    if (!alignment) {
-        alignment = &dummy_alignment;
-    }
-
-    dist_func_t<float> ret_dist_func = FP32_L2Sqr;
-
-    [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt);
-#ifdef CPU_FEATURES_ARCH_AARCH64
-#ifdef OPT_SVE2
-    if (features.sve2) {
-        return Choose_FP32_L2_implementation_SVE2(dim);
-    }
-#endif
-#ifdef OPT_SVE
-    if (features.sve) {
-        return Choose_FP32_L2_implementation_SVE(dim);
-    }
-#endif
-#ifdef OPT_NEON
-    if (features.asimd) {
-        return Choose_FP32_L2_implementation_NEON(dim);
-    }
-#endif
-#endif
-
-#ifdef CPU_FEATURES_ARCH_X86_64
-    // Optimizations assume at least 16 floats. If we have less, we use the naive implementation.
-
-    if (dim < 16) {
-        return ret_dist_func;
-    }
-#ifdef OPT_AVX512F
-    if (features.avx512f) {
-        if (dim % 16 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 16 * sizeof(float); // handles 16 floats
-        return Choose_FP32_L2_implementation_AVX512F(dim);
-    }
-#endif
-#ifdef OPT_AVX
-    if (features.avx) {
-        if (dim % 8 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 8 * sizeof(float); // handles 8 floats
-        return Choose_FP32_L2_implementation_AVX(dim);
-    }
-#endif
-#ifdef OPT_SSE
-    if (features.sse) {
-        if (dim % 4 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 4 * sizeof(float); // handles 4 floats
-        return Choose_FP32_L2_implementation_SSE(dim);
-    }
-#endif
-#endif // __x86_64__
-    return ret_dist_func;
-}
-
-dist_func_t<double> L2_FP64_GetDistFunc(size_t dim, unsigned char *alignment,
-                                        const void *arch_opt) {
-    unsigned char dummy_alignment;
-    if (!alignment) {
-        alignment = &dummy_alignment;
-    }
-
-    dist_func_t<double> ret_dist_func = FP64_L2Sqr;
-    [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt);
-
-#ifdef CPU_FEATURES_ARCH_AARCH64
-#ifdef OPT_SVE2
-    if (features.sve2) {
-        return Choose_FP64_L2_implementation_SVE2(dim);
-    }
-#endif
-#ifdef OPT_SVE
-    if (features.sve) {
-        return Choose_FP64_L2_implementation_SVE(dim);
-    }
-#endif
-#ifdef OPT_NEON
-    if (features.asimd) {
-        return Choose_FP64_L2_implementation_NEON(dim);
-    }
-#endif
-#endif
-
-#ifdef CPU_FEATURES_ARCH_X86_64
-    // Optimizations assume at least 8 doubles. If we have less, we use the naive implementation.
-    if (dim < 8) {
-        return ret_dist_func;
-    }
-#ifdef OPT_AVX512F
-    if (features.avx512f) {
-        if (dim % 8 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 8 * sizeof(double); // handles 8 doubles
-        return Choose_FP64_L2_implementation_AVX512F(dim);
-    }
-#endif
-#ifdef OPT_AVX
-    if (features.avx) {
-        if (dim % 4 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 4 * sizeof(double); // handles 4 doubles
-        return Choose_FP64_L2_implementation_AVX(dim);
-    }
-#endif
-#ifdef OPT_SSE
-    if (features.sse) {
-        if (dim % 2 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 2 * sizeof(double); // handles 2 doubles
-        return Choose_FP64_L2_implementation_SSE(dim);
-    }
-#endif
-#endif // __x86_64__
-    return ret_dist_func;
-}
-
-dist_func_t<float> L2_BF16_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt) {
-    unsigned char dummy_alignment;
-    if (!alignment) {
-        alignment = &dummy_alignment;
-    }
-
-    dist_func_t<float> ret_dist_func = BF16_L2Sqr_LittleEndian;
-    if (!is_little_endian()) {
-        return BF16_L2Sqr_BigEndian;
-    }
-    [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt);
-
-#if defined(CPU_FEATURES_ARCH_AARCH64)
-#ifdef OPT_SVE_BF16
-    if (features.svebf16) {
-        return Choose_BF16_L2_implementation_SVE_BF16(dim);
-    }
-#endif
-#ifdef OPT_NEON_BF16
-    if (features.bf16 && dim >= 8) { // Optimization assumes at least 8 BF16s (full chunk)
-        return Choose_BF16_L2_implementation_NEON_BF16(dim);
-    }
-#endif
-#endif // AARCH64
-
-#if defined(CPU_FEATURES_ARCH_X86_64)
-    // Optimizations assume at least 32 bfloats. If we have less, we use the naive implementation.
-    if (dim < 32) {
-        return ret_dist_func;
-    }
-#ifdef OPT_AVX512_BW_VBMI2
-    if (features.avx512bw && features.avx512vbmi2) {
-        if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 32 * sizeof(bfloat16); // align to 512 bits.
-        return Choose_BF16_L2_implementation_AVX512BW_VBMI2(dim);
-    }
-#endif
-#ifdef OPT_AVX2
-    if (features.avx2) {
-        if (dim % 16 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 16 * sizeof(bfloat16); // align to 256 bits.
-        return Choose_BF16_L2_implementation_AVX2(dim);
-    }
-#endif
-#ifdef OPT_SSE3
-    if (features.sse3) {
-        if (dim % 8 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 8 * sizeof(bfloat16); // align to 128 bits.
-        return Choose_BF16_L2_implementation_SSE3(dim);
-    }
-#endif
-#endif // __x86_64__
-    return ret_dist_func;
-}
-
-dist_func_t<float> L2_FP16_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt) {
-    unsigned char dummy_alignment;
-    if (alignment == nullptr) {
-        alignment = &dummy_alignment;
-    }
-    [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt);
-
-    dist_func_t<float> ret_dist_func = FP16_L2Sqr;
-
-#if defined(CPU_FEATURES_ARCH_AARCH64)
-#ifdef OPT_SVE2
-    if (features.sve2) {
-        return Choose_FP16_L2_implementation_SVE2(dim);
-    }
-#endif
-#ifdef OPT_SVE
-    if (features.sve) {
-        return Choose_FP16_L2_implementation_SVE(dim);
-    }
-#endif
-#ifdef OPT_NEON_HP
-    if (features.asimdhp && dim >= 8) { // Optimization assumes at least 8 FP16s (full chunk)
-        return Choose_FP16_L2_implementation_NEON_HP(dim);
-    }
-#endif
-#endif // CPU_FEATURES_ARCH_AARCH64
-
-#if defined(CPU_FEATURES_ARCH_X86_64)
-    // Optimizations assume at least 32 FP16s. If we have less, we use the naive implementation.
-    if (dim < 32) {
-        return ret_dist_func;
-    }
-#ifdef OPT_AVX512_FP16_VL
-    // More details about the dimension limitation can be found in this PR's description:
-    // https://github.com/RedisAI/VectorSimilarity/pull/477
-    if (features.avx512_fp16 && features.avx512vl) {
-        if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 32 * sizeof(float16); // handles 32 float16s
-        return Choose_FP16_L2_implementation_AVX512FP16_VL(dim);
-    }
-#endif
-#ifdef OPT_AVX512F
-    if (features.avx512f) {
-        if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 32 * sizeof(float16); // handles 32 float16s
-        return Choose_FP16_L2_implementation_AVX512F(dim);
-    }
-#endif
-#ifdef OPT_F16C
-    if (features.f16c && features.fma3 && features.avx) {
-        if (dim % 16 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 16 * sizeof(float16); // handles 16 float16s
-        return Choose_FP16_L2_implementation_F16C(dim);
-    }
-#endif
-#endif // __x86_64__
-    return ret_dist_func;
-}
-
-dist_func_t<float> L2_INT8_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt) {
-    unsigned char dummy_alignment;
-    if (alignment == nullptr) {
-        alignment = &dummy_alignment;
-    }
-
-    dist_func_t<float> ret_dist_func = INT8_L2Sqr;
-
-    [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt);
-
-#ifdef CPU_FEATURES_ARCH_AARCH64
-#ifdef OPT_SVE2
-    if (features.sve2) {
-        return Choose_INT8_L2_implementation_SVE2(dim);
-    }
-#endif
-#ifdef OPT_SVE
-    if (features.sve) {
-        return Choose_INT8_L2_implementation_SVE(dim);
-    }
-#endif
-#ifdef OPT_NEON_DOTPROD
-    if (features.asimddp && dim >= 16) {
-        return Choose_INT8_L2_implementation_NEON_DOTPROD(dim);
-    }
-#endif
-#ifdef OPT_NEON
-    if (features.asimd && dim >= 16) {
-        return Choose_INT8_L2_implementation_NEON(dim);
-    }
-#endif
-#endif
-#ifdef CPU_FEATURES_ARCH_X86_64
-    if (dim < 32) {
-        return ret_dist_func;
-    }
-#ifdef OPT_AVX512_F_BW_VL_VNNI
-    if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) {
-        if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 32 * sizeof(int8_t); // align to 256 bits.
-        return Choose_INT8_L2_implementation_AVX512F_BW_VL_VNNI(dim);
-    }
-#endif
-#endif // __x86_64__
-    return ret_dist_func;
-}
-
-dist_func_t<float> L2_UINT8_GetDistFunc(size_t dim, unsigned char *alignment,
-                                        const void *arch_opt) {
-    unsigned char dummy_alignment;
-    if (alignment == nullptr) {
-        alignment = &dummy_alignment;
-    }
-
-    dist_func_t<float> ret_dist_func = UINT8_L2Sqr;
-    // Optimizations assume at least 32 uint8 elements. If we have less, we use the naive
-    // implementation.
-    [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt);
-
-#ifdef CPU_FEATURES_ARCH_AARCH64
-#ifdef OPT_SVE2
-    if (features.sve2) {
-        return Choose_UINT8_L2_implementation_SVE2(dim);
-    }
-#endif
-#ifdef OPT_SVE
-    if (features.sve) {
-        return Choose_UINT8_L2_implementation_SVE(dim);
-    }
-#endif
-#ifdef OPT_NEON_DOTPROD
-    if (features.asimddp && dim >= 16) {
-        return Choose_UINT8_L2_implementation_NEON_DOTPROD(dim);
-    }
-#endif
-#ifdef OPT_NEON
-    if (features.asimd && dim >= 16) {
-        return Choose_UINT8_L2_implementation_NEON(dim);
-    }
-#endif
-#endif // __aarch64__
-#ifdef CPU_FEATURES_ARCH_X86_64
-    if (dim < 32) {
-        return ret_dist_func;
-    }
-#ifdef OPT_AVX512_F_BW_VL_VNNI
-    if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) {
-        if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
-            *alignment = 32 * sizeof(int8_t); // align to 256 bits.
-        return Choose_UINT8_L2_implementation_AVX512F_BW_VL_VNNI(dim);
-    }
-#endif
-#endif // __x86_64__
-    return ret_dist_func;
-}
-
-// SQ8-to-SQ8 L2 squared distance function (both vectors are uint8 quantized)
-dist_func_t<float> L2_SQ8_SQ8_GetDistFunc(size_t dim, unsigned char *alignment,
-                                          const void *arch_opt) {
-    unsigned char dummy_alignment;
-    if (alignment == nullptr) {
-        alignment = &dummy_alignment;
-    }
-
-    dist_func_t<float> ret_dist_func = SQ8_SQ8_L2Sqr;
-    [[maybe_unused]] auto features = getCpuOptimizationFeatures(arch_opt);
-
-#ifdef CPU_FEATURES_ARCH_AARCH64
-#ifdef OPT_SVE2
-    if (features.sve2) {
-        return Choose_SQ8_SQ8_L2_implementation_SVE2(dim);
-    }
-#endif
-#ifdef OPT_SVE
-    if (features.sve) {
-        return Choose_SQ8_SQ8_L2_implementation_SVE(dim);
-    }
-#endif
-#ifdef OPT_NEON_DOTPROD
-    // DOTPROD uses integer arithmetic - much faster than float-based NEON
    
-    if (dim >= 16 && features.asimddp) {
-        return Choose_SQ8_SQ8_L2_implementation_NEON_DOTPROD(dim);
-    }
-#endif
-#ifdef OPT_NEON
-    if (dim >= 16 && features.asimd) {
-        return Choose_SQ8_SQ8_L2_implementation_NEON(dim);
-    }
-#endif
-#endif // AARCH64
-
-#ifdef CPU_FEATURES_ARCH_X86_64
-#ifdef OPT_AVX512_F_BW_VL_VNNI
-    // AVX512 VNNI SQ8_SQ8 uses 64-element chunks
-    if (dim >= 64 && features.avx512f && features.avx512bw && features.avx512vnni) {
-        return Choose_SQ8_SQ8_L2_implementation_AVX512F_BW_VL_VNNI(dim);
-    }
-#endif
-#endif // __x86_64__
-    return ret_dist_func;
-}
-
-} // namespace spaces
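A typical call site for the selectors above, inferred from the signatures rather than taken from the repository, would look like this: the caller resolves the distance function once per index, keeping the returned pointer and the reported preferred alignment.

    #include <cstddef>
    // Assumes the declarations from L2_space.h (deleted below) are visible.

    float l2_distance_example(const float *a, const float *b, size_t dim) {
        unsigned char alignment = 0;
        // Picks the best implementation for the current CPU and dimension;
        // `alignment` reports the preferred byte alignment for vector blobs.
        spaces::dist_func_t<float> f = spaces::L2_FP32_GetDistFunc(dim, &alignment);
        return f(a, b, dim);
    }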
diff --git a/src/VecSim/spaces/L2_space.h b/src/VecSim/spaces/L2_space.h
deleted file mode 100644
index dd2dfec0c..000000000
--- a/src/VecSim/spaces/L2_space.h
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-#include "VecSim/spaces/spaces.h"
-
-namespace spaces {
-dist_func_t<float> L2_FP32_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
-                                       const void *arch_opt = nullptr);
-dist_func_t<double> L2_FP64_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
-                                        const void *arch_opt = nullptr);
-dist_func_t<float> L2_BF16_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
-                                       const void *arch_opt = nullptr);
-dist_func_t<float> L2_FP16_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
-                                       const void *arch_opt = nullptr);
-dist_func_t<float> L2_INT8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
-                                       const void *arch_opt = nullptr);
-dist_func_t<float> L2_UINT8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
-                                        const void *arch_opt = nullptr);
-// SQ8-FP32: asymmetric L2 distance between FP32 query and SQ8 storage
-dist_func_t<float> L2_SQ8_FP32_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
-                                           const void *arch_opt = nullptr);
-dist_func_t<float> L2_SQ8_SQ8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
-                                          const void *arch_opt = nullptr);
-} // namespace spaces
diff --git a/src/VecSim/spaces/computer/calculator.h b/src/VecSim/spaces/computer/calculator.h
deleted file mode 100644
index a82293700..000000000
--- a/src/VecSim/spaces/computer/calculator.h
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-#include <memory>
-
-#include "VecSim/memory/vecsim_base.h"
-#include "VecSim/spaces/spaces.h"
-
-// We need this "wrapper" class to hold the DistanceCalculatorInterface in the index, which is not
-// templated according to the distance function signature.
-template <typename DistType>
-class IndexCalculatorInterface : public VecsimBaseObject {
-public:
-    explicit IndexCalculatorInterface(std::shared_ptr<VecSimAllocator> allocator)
-        : VecsimBaseObject(allocator) {}
-
-    virtual ~IndexCalculatorInterface() = default;
-
-    virtual DistType calcDistance(const void *v1, const void *v2, size_t dim) const = 0;
-};
-
-/**
- * This object's purpose is to calculate the distance between two vectors.
- * It extends the IndexCalculatorInterface class' type to hold the distance function.
- * Every specific implementation of the distance calculator should hold, by reference or by value,
- * the parameters required for the calculation. The distance calculation API of all
- * DistanceCalculator classes is calcDistance(v1, v2, dim). Internally it calls the distance
- * function according to the template signature, allowing flexibility in the distance function
- * arguments.
- */
-template <typename DistType, typename DistFuncType>
-class DistanceCalculatorInterface : public IndexCalculatorInterface<DistType> {
-public:
-    DistanceCalculatorInterface(std::shared_ptr<VecSimAllocator> allocator, DistFuncType dist_func)
-        : IndexCalculatorInterface<DistType>(allocator), dist_func(dist_func) {}
-    virtual DistType calcDistance(const void *v1, const void *v2, size_t dim) const = 0;
-
-protected:
-    DistFuncType dist_func;
-};
-
-template <typename DistType>
-class DistanceCalculatorCommon
-    : public DistanceCalculatorInterface<DistType, spaces::dist_func_t<DistType>> {
-public:
-    DistanceCalculatorCommon(std::shared_ptr<VecSimAllocator> allocator,
-                             spaces::dist_func_t<DistType> dist_func)
-        : DistanceCalculatorInterface<DistType, spaces::dist_func_t<DistType>>(allocator,
-                                                                               dist_func) {}
-
-    DistType calcDistance(const void *v1, const void *v2, size_t dim) const override {
-        return this->dist_func(v1, v2, dim);
-    }
-};
diff --git a/src/VecSim/spaces/computer/preprocessor_container.cpp b/src/VecSim/spaces/computer/preprocessor_container.cpp
deleted file mode 100644
index 37d678206..000000000
--- a/src/VecSim/spaces/computer/preprocessor_container.cpp
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include "VecSim/spaces/computer/preprocessor_container.h"
-
-ProcessedBlobs PreprocessorsContainerAbstract::preprocess(const void *original_blob,
-                                                          size_t input_blob_size) const {
-    return ProcessedBlobs(preprocessForStorage(original_blob, input_blob_size),
-                          preprocessQuery(original_blob, input_blob_size));
-}
-
-MemoryUtils::unique_blob
-PreprocessorsContainerAbstract::preprocessForStorage(const void *original_blob,
-                                                     size_t input_blob_size) const {
-    return wrapWithDummyDeleter(const_cast<void *>(original_blob));
-}
-
-MemoryUtils::unique_blob PreprocessorsContainerAbstract::preprocessQuery(const void *original_blob,
-                                                                         size_t input_blob_size,
-                                                                         bool force_copy) const {
-    return maybeCopyToAlignedMem(original_blob, input_blob_size, force_copy);
-}
-
-void PreprocessorsContainerAbstract::preprocessStorageInPlace(void *blob,
-                                                              size_t input_blob_size) const {}
-
-MemoryUtils::unique_blob PreprocessorsContainerAbstract::maybeCopyToAlignedMem(
-    const void *original_blob, size_t input_blob_size, bool force_copy) const {
-    bool needs_copy =
-        force_copy || (this->alignment && ((uintptr_t)original_blob % this->alignment != 0));
-
-    if (needs_copy) {
-        auto aligned_mem = this->allocator->allocate_aligned(input_blob_size, this->alignment);
-        memcpy(aligned_mem, original_blob, input_blob_size);
-        return this->wrapAllocated(aligned_mem);
-    }
-
-    // Returning a unique_ptr with a no-op deleter
-    return wrapWithDummyDeleter(const_cast<void *>(original_blob));
-}
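The copy decision in maybeCopyToAlignedMem boils down to a single address test; as a standalone illustration (hypothetical helper name, same logic as above):

    #include <cstdint>

    // A blob needs copying when a non-zero alignment is requested and the
    // pointer's address is not a multiple of it - the test used above.
    bool needs_aligned_copy(const void *blob, unsigned char alignment) {
        return alignment != 0 && (reinterpret_cast<uintptr_t>(blob) % alignment) != 0;
    }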
diff --git a/src/VecSim/spaces/computer/preprocessor_container.h b/src/VecSim/spaces/computer/preprocessor_container.h
deleted file mode 100644
index 454504bb3..000000000
--- a/src/VecSim/spaces/computer/preprocessor_container.h
+++ /dev/null
@@ -1,250 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-
-#include <array>
-#include <cassert>
-
-#include "VecSim/memory/vecsim_base.h"
-#include "VecSim/memory/memory_utils.h"
-#include "VecSim/spaces/computer/preprocessors.h"
-
-struct ProcessedBlobs;
-
-class PreprocessorsContainerAbstract : public VecsimBaseObject {
-public:
-    PreprocessorsContainerAbstract(std::shared_ptr<VecSimAllocator> allocator,
-                                   unsigned char alignment)
-        : VecsimBaseObject(allocator), alignment(alignment) {}
-    // It is assumed that the resulting query blob is aligned.
-    virtual ProcessedBlobs preprocess(const void *original_blob, size_t input_blob_size) const;
-
-    virtual MemoryUtils::unique_blob preprocessForStorage(const void *original_blob,
-                                                          size_t input_blob_size) const;
-
-    // It is assumed that the resulting query blob is aligned.
-    virtual MemoryUtils::unique_blob preprocessQuery(const void *original_blob,
-                                                     size_t input_blob_size,
-                                                     bool force_copy = false) const;
-
-    virtual void preprocessStorageInPlace(void *blob, size_t input_blob_size) const;
-
-    unsigned char getAlignment() const { return alignment; }
-
-protected:
-    const unsigned char alignment;
-
-    // Allocate and copy the blob only if the original blob is not aligned.
-    MemoryUtils::unique_blob maybeCopyToAlignedMem(const void *original_blob,
-                                                   size_t input_blob_size,
-                                                   bool force_copy = false) const;
-
-    MemoryUtils::unique_blob wrapAllocated(void *blob) const {
-        return MemoryUtils::unique_blob(
-            blob, [this](void *ptr) { this->allocator->free_allocation(ptr); });
-    }
-
-    static MemoryUtils::unique_blob wrapWithDummyDeleter(void *ptr) {
-        return MemoryUtils::unique_blob(ptr, [](void *) {});
-    }
-};
-
-template <typename DataType, size_t n_preprocessors>
-class MultiPreprocessorsContainer : public PreprocessorsContainerAbstract {
-protected:
-    std::array<PreprocessorInterface *, n_preprocessors> preprocessors;
-
-public:
-    MultiPreprocessorsContainer(std::shared_ptr<VecSimAllocator> allocator,
-                                unsigned char alignment)
-        : PreprocessorsContainerAbstract(allocator, alignment) {
-        assert(n_preprocessors);
-        std::fill_n(preprocessors.begin(), n_preprocessors, nullptr);
-    }
-
-    ~MultiPreprocessorsContainer() override {
-        for (auto pp : preprocessors) {
-            if (!pp)
-                break;
-
-            delete pp;
-        }
-    }
-
-    /** @returns On success, the next uninitialized index, or 0 if capacity was reached (after
-     * inserting the preprocessor); -1 if the container is full and the preprocessor was not added.
-     */
-    int addPreprocessor(PreprocessorInterface *preprocessor);
-
-    ProcessedBlobs preprocess(const void *original_blob, size_t input_blob_size) const override;
-
-    MemoryUtils::unique_blob preprocessForStorage(const void *original_blob,
-                                                  size_t input_blob_size) const override;
-
-    MemoryUtils::unique_blob preprocessQuery(const void *original_blob, size_t input_blob_size,
                                             bool force_copy = false) const override;
-
-    void preprocessStorageInPlace(void *blob, size_t input_blob_size) const override;
-
-#ifdef BUILD_TESTS
-    std::array<PreprocessorInterface *, n_preprocessors> getPreprocessors() const {
-        return preprocessors;
-    }
-#endif
-
-private:
-    using Base = PreprocessorsContainerAbstract;
-};
-
-/* ======================= ProcessedBlobs Definition ======================= */
-
-struct ProcessedBlobs {
-    explicit ProcessedBlobs() = default;
-
-    explicit ProcessedBlobs(MemoryUtils::unique_blob &&storage_blob,
-                            MemoryUtils::unique_blob &&query_blob)
-        : storage_blob(std::move(storage_blob)), query_blob(std::move(query_blob)) {}
-
-    ProcessedBlobs(ProcessedBlobs &&other) noexcept
-        : storage_blob(std::move(other.storage_blob)), query_blob(std::move(other.query_blob)) {}
-
-    // Move assignment operator
-    ProcessedBlobs &operator=(ProcessedBlobs &&other) noexcept {
-        if (this != &other) {
-            storage_blob = std::move(other.storage_blob);
-            query_blob = std::move(other.query_blob);
-        }
-        return *this;
-    }
-
-    // Delete copy constructor and assignment operator to avoid copying unique_ptr
-    ProcessedBlobs(const ProcessedBlobs &) = delete;
-    ProcessedBlobs &operator=(const ProcessedBlobs &) = delete;
-
-    const void *getStorageBlob() const { return storage_blob.get(); }
-    const void *getQueryBlob() const { return query_blob.get(); }
-
-private:
-    MemoryUtils::unique_blob storage_blob;
-    MemoryUtils::unique_blob query_blob;
-};
-
-/* ====================================== Implementation ======================================*/
-
-/* ======================= MultiPreprocessorsContainer ======================= */
-
-// On success, returns the array size after adding the preprocessor, or 0 when we add the last
-// preprocessor. Returns -1 if the array is full and we failed to add the preprocessor.
-template <typename DataType, size_t n_preprocessors>
-int MultiPreprocessorsContainer<DataType, n_preprocessors>::addPreprocessor(
-    PreprocessorInterface *preprocessor) {
-    for (size_t curr_pp_idx = 0; curr_pp_idx < n_preprocessors; curr_pp_idx++) {
-        if (preprocessors[curr_pp_idx] == nullptr) {
-            preprocessors[curr_pp_idx] = preprocessor;
-            const size_t pp_arr_size = curr_pp_idx + 1;
-            return pp_arr_size >= n_preprocessors ? 0 : pp_arr_size;
-        }
-    }
-    return -1;
-}
-
-template <typename DataType, size_t n_preprocessors>
-ProcessedBlobs
-MultiPreprocessorsContainer<DataType, n_preprocessors>::preprocess(const void *original_blob,
-                                                                   size_t input_blob_size) const {
-    // No preprocessors were added yet.
-    if (preprocessors[0] == nullptr) {
-        // query might need to be aligned
-        auto query_ptr = this->maybeCopyToAlignedMem(original_blob, input_blob_size);
-        return ProcessedBlobs(
-            std::move(Base::wrapWithDummyDeleter(const_cast<void *>(original_blob))),
-            std::move(query_ptr));
-    }
-
-    void *storage_blob = nullptr;
-    void *query_blob = nullptr;
-
-    // Use separate variables for storage_blob_size and query_blob_size, in case we need to
-    // change them to different values.
-    size_t storage_blob_size = input_blob_size;
-    size_t query_blob_size = input_blob_size;
-
-    for (auto pp : preprocessors) {
-        if (!pp)
-            break;
-        pp->preprocess(original_blob, storage_blob, query_blob, storage_blob_size, query_blob_size,
-                       this->alignment);
-    }
-    // At least one blob was allocated.
-
-    // If they point to the same memory, we need to free only one of them.
-    if (storage_blob == query_blob) {
-        return ProcessedBlobs(std::move(this->wrapAllocated(storage_blob)),
-                              std::move(Base::wrapWithDummyDeleter(storage_blob)));
-    }
-
-    if (storage_blob == nullptr) { // we processed only the query
-        return ProcessedBlobs(
-            std::move(Base::wrapWithDummyDeleter(const_cast<void *>(original_blob))),
-            std::move(this->wrapAllocated(query_blob)));
-    }
-
-    if (query_blob == nullptr) { // we processed only the storage
-        // query might need to be aligned
-        auto query_ptr = this->maybeCopyToAlignedMem(original_blob, input_blob_size);
-        return ProcessedBlobs(std::move(this->wrapAllocated(storage_blob)), std::move(query_ptr));
-    }
-
-    // Else, both were allocated separately, we need to release both.
-    return ProcessedBlobs(std::move(this->wrapAllocated(storage_blob)),
-                          std::move(this->wrapAllocated(query_blob)));
-}
-
-template <typename DataType, size_t n_preprocessors>
-MemoryUtils::unique_blob
-MultiPreprocessorsContainer<DataType, n_preprocessors>::preprocessForStorage(
-    const void *original_blob, size_t input_blob_size) const {
-
-    void *storage_blob = nullptr;
-    for (auto pp : preprocessors) {
-        if (!pp)
-            break;
-        pp->preprocessForStorage(original_blob, storage_blob, input_blob_size);
-    }
-
-    return storage_blob ? std::move(this->wrapAllocated(storage_blob))
-                        : std::move(Base::wrapWithDummyDeleter(const_cast<void *>(original_blob)));
-}
-
-template <typename DataType, size_t n_preprocessors>
-MemoryUtils::unique_blob MultiPreprocessorsContainer<DataType, n_preprocessors>::preprocessQuery(
-    const void *original_blob, size_t input_blob_size, bool force_copy) const {
-
-    void *query_blob = nullptr;
-    for (auto pp : preprocessors) {
-        if (!pp)
-            break;
-        // modifies the memory in place
-        pp->preprocessQuery(original_blob, query_blob, input_blob_size, this->alignment);
-    }
-    return query_blob
-               ? std::move(this->wrapAllocated(query_blob))
-               : std::move(this->maybeCopyToAlignedMem(original_blob, input_blob_size, force_copy));
-}
-
-template <typename DataType, size_t n_preprocessors>
-void MultiPreprocessorsContainer<DataType, n_preprocessors>::preprocessStorageInPlace(
-    void *blob, size_t input_blob_size) const {
-
-    for (auto pp : preprocessors) {
-        if (!pp)
-            break;
-        // modifies the memory in place
-        pp->preprocessStorageInPlace(blob, input_blob_size);
-    }
-}
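A sketch of how the addPreprocessor contract above is meant to be consumed, based only on the documented return codes (the wiring is hypothetical and assumes the definitions from the header above):

    // Returns false only when the container was already full and the
    // preprocessor was rejected (rc == -1); rc == 0 means it landed in the
    // final free slot, any positive rc is the next free index.
    template <typename DataType, size_t N>
    bool try_add(MultiPreprocessorsContainer<DataType, N> &container,
                 PreprocessorInterface *pp) {
        int rc = container.addPreprocessor(pp);
        return rc != -1;
    }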
diff --git a/src/VecSim/spaces/computer/preprocessors.h b/src/VecSim/spaces/computer/preprocessors.h
deleted file mode 100644
index 5954b3fc1..000000000
--- a/src/VecSim/spaces/computer/preprocessors.h
+++ /dev/null
@@ -1,484 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-
-#pragma once
-
-#include <cassert>
-#include <cmath>
-#include <cstring>
-#include <memory>
-#include <type_traits>
-
-#include "VecSim/memory/vecsim_base.h"
-#include "VecSim/spaces/spaces.h"
-#include "VecSim/memory/memory_utils.h"
-#include "VecSim/types/sq8.h"
-
-class PreprocessorInterface : public VecsimBaseObject {
-public:
-    PreprocessorInterface(std::shared_ptr<VecSimAllocator> allocator)
-        : VecsimBaseObject(allocator) {}
-    // Note: input_blob_size is relevant for both storage blob and query blob, as we assume results
-    // are the same size.
-    // Use the overload below for different sizes.
-    virtual void preprocess(const void *original_blob, void *&storage_blob, void *&query_blob,
-                            size_t &input_blob_size, unsigned char alignment) const = 0;
-    virtual void preprocess(const void *original_blob, void *&storage_blob, void *&query_blob,
-                            size_t &storage_blob_size, size_t &query_blob_size,
-                            unsigned char alignment) const = 0;
-    virtual void preprocessForStorage(const void *original_blob, void *&storage_blob,
-                                      size_t &input_blob_size) const = 0;
-    virtual void preprocessQuery(const void *original_blob, void *&query_blob,
-                                 size_t &input_blob_size, unsigned char alignment) const = 0;
-    virtual void preprocessStorageInPlace(void *original_blob, size_t input_blob_size) const = 0;
-};
-
-template <typename DataType>
-class CosinePreprocessor : public PreprocessorInterface {
-public:
-    // This preprocessor requires that storage_blob and query_blob have identical memory sizes
-    // both before processing (as input) and after preprocessing completes.
-    CosinePreprocessor(std::shared_ptr<VecSimAllocator> allocator, size_t dim,
-                       size_t processed_bytes_count)
-        : PreprocessorInterface(allocator), normalize_func(spaces::GetNormalizeFunc<DataType>()),
-          dim(dim), processed_bytes_count(processed_bytes_count) {}
-
-    void preprocess(const void *original_blob, void *&storage_blob, void *&query_blob,
-                    size_t &storage_blob_size, size_t &query_blob_size,
-                    unsigned char alignment) const override {
-        // This assert verifies that the current use of this function is for blobs of the same
-        // size, which is the case for the Cosine preprocessor. If we ever need to support different
-        // sizes for storage and query blobs, we can remove the assert and implement the logic to
-        // handle different sizes.
-        assert(storage_blob_size == query_blob_size);
-
-        preprocess(original_blob, storage_blob, query_blob, storage_blob_size, alignment);
-        // Ensure both blobs have the same size after processing.
-        query_blob_size = storage_blob_size;
-    }
-
-    void preprocess(const void *original_blob, void *&storage_blob, void *&query_blob,
-                    size_t &input_blob_size, unsigned char alignment) const override {
-        // This assert verifies that if a blob was allocated by a previous preprocessor, its
-        // size matches our expected processed size. Therefore, it is safe to skip re-allocation and
-        // process it inplace. Supporting dynamic resizing would require additional size checks (if
-        // statements) and memory management logic, which could impact performance. Currently, no
-        // code path requires this capability. If resizing becomes necessary in the future, remove
-        // the assertions and implement appropriate allocation handling with performance
-        // considerations.
-        assert(storage_blob == nullptr || input_blob_size == processed_bytes_count);
-        assert(query_blob == nullptr || input_blob_size == processed_bytes_count);
-
-        // Case 1: Blobs are different (one might be null, or both are allocated and processed
-        // separately).
-        if (storage_blob != query_blob) {
-            // If one of them is null, allocate memory for it and copy the original_blob to it.
-            if (storage_blob == nullptr) {
-                storage_blob = this->allocator->allocate(processed_bytes_count);
-                memcpy(storage_blob, original_blob, input_blob_size);
-            } else if (query_blob == nullptr) {
-                query_blob = this->allocator->allocate_aligned(processed_bytes_count, alignment);
-                memcpy(query_blob, original_blob, input_blob_size);
-            }
-
-            // Normalize both blobs.
-            normalize_func(storage_blob, this->dim);
-            normalize_func(query_blob, this->dim);
-        } else { // Case 2: Blobs are the same (either both are null or processed in the same way).
-            if (query_blob == nullptr) { // If both blobs are null, allocate query_blob and set
                                         // storage_blob to point to it.
-                query_blob = this->allocator->allocate_aligned(processed_bytes_count, alignment);
-                memcpy(query_blob, original_blob, input_blob_size);
-                storage_blob = query_blob;
-            }
-            // Normalize one of them (since they point to the same memory).
-            normalize_func(query_blob, this->dim);
-        }
-
-        input_blob_size = processed_bytes_count;
-    }
-
-    void preprocessForStorage(const void *original_blob, void *&blob,
-                              size_t &input_blob_size) const override {
-        // see assert docs in preprocess
-        assert(blob == nullptr || input_blob_size == processed_bytes_count);
-
-        if (blob == nullptr) {
-            blob = this->allocator->allocate(processed_bytes_count);
-            memcpy(blob, original_blob, input_blob_size);
-        }
-        normalize_func(blob, this->dim);
-        input_blob_size = processed_bytes_count;
-    }
-
-    void preprocessQuery(const void *original_blob, void *&blob, size_t &input_blob_size,
-                         unsigned char alignment) const override {
-        // see assert docs in preprocess
-        assert(blob == nullptr || input_blob_size == processed_bytes_count);
-        if (blob == nullptr) {
-            blob = this->allocator->allocate_aligned(processed_bytes_count, alignment);
-            memcpy(blob, original_blob, input_blob_size);
-        }
-        normalize_func(blob, this->dim);
-        input_blob_size = processed_bytes_count;
-    }
-
-    void preprocessStorageInPlace(void *blob, size_t input_blob_size) const override {
-        assert(blob);
-        assert(input_blob_size == this->processed_bytes_count);
-        normalize_func(blob, this->dim);
-    }
-
-private:
-    spaces::normalizeVector_f<DataType> normalize_func;
-    const size_t dim;
-    const size_t processed_bytes_count;
-};
-
-/*
- * QuantPreprocessor is a preprocessor that quantizes storage vectors from DataType to a
- * lower precision representation using OUTPUT_TYPE (uint8_t).
- * Query vectors remain as DataType for asymmetric distance computation.
- *
- * The quantized storage blob contains the quantized values along with metadata (min value,
- * scaling factor, and precomputed sums for reconstruction) in a single contiguous blob.
- * The quantization is done by finding the minimum and maximum values of the input vector,
- * and then scaling the values to fit in the range of [0, 255].
- *
- * Storage layout:
- * | quantized_values[dim] | min_val | delta | x_sum | (x_sum_squares for L2 only) |
- * where:
- * x_sum = Σx_i: sum of the original values,
- * x_sum_squares = Σx_i²: sum of squares of the original values.
- *
- * The quantized blob size is:
- * - For L2: dim * sizeof(OUTPUT_TYPE) + 4 * sizeof(DataType)
- * - For IP/Cosine: dim * sizeof(OUTPUT_TYPE) + 3 * sizeof(DataType)
- *
- * Reconstruction formulas:
- * Given quantized value q_i, the original value is reconstructed as:
- * x_i ≈ min + delta * q_i
- *
- * Query processing:
- * The query vector is not quantized. It remains as DataType, but we precompute
- * and store metric-specific values to accelerate asymmetric distance computation:
- * - For IP/Cosine: y_sum = Σy_i (sum of query values)
- * - For L2: y_sum = Σy_i (sum of query values), y_sum_squares = Σy_i² (sum of squared query
- *   values)
- *
- * Query blob layout:
- * - For IP/Cosine: | query_values[dim] | y_sum |
- * - For L2: | query_values[dim] | y_sum | y_sum_squares |
- *
- * Query blob size:
- * - For IP/Cosine: (dim + 1) * sizeof(DataType)
- * - For L2: (dim + 2) * sizeof(DataType)
- *
- * === Asymmetric distance (storage x quantized, query y remains float) ===
- *
- * For IP/Cosine:
- * IP(x, y) = Σ(x_i * y_i)
- *          ≈ Σ((min + delta * q_i) * y_i)
- *          = min * Σy_i + delta * Σ(q_i * y_i)
- *          = min * y_sum + delta * quantized_dot_product
- * where y_sum = Σy_i is precomputed and stored in the query blob.
- *
- * For L2:
- * ||x - y||² = Σx_i² - 2*Σ(x_i * y_i) + Σy_i²
- *            = x_sum_squares - 2 * IP(x, y) + y_sum_squares
- * where:
- * - x_sum_squares = Σx_i² is precomputed and stored in the storage blob
- * - IP(x, y) is computed using the formula above
- * - y_sum_squares = Σy_i² is precomputed and stored in the query blob
- *
- * === Symmetric distance (both x and y are quantized) ===
- *
- * For IP/Cosine:
- * IP(x, y) = Σ((min_x + delta_x * qx_i) * (min_y + delta_y * qy_i))
- *          = dim * min_x * min_y
- *            + min_x * delta_y * Σqy_i + min_y * delta_x * Σqx_i
- *            + delta_x * delta_y * Σ(qx_i * qy_i)
- *          = dim * min_x * min_y
- *            + min_x * (sum_y - dim * min_y) + min_y * (sum_x - dim * min_x)
- *            + delta_x * delta_y * Σ(qx_i * qy_i)
- *          = min_x * sum_y + min_y * sum_x - dim * min_x * min_y
- *            + delta_x * delta_y * Σ(qx_i * qy_i)
- * where:
- * - sum_x, sum_y are precomputed sums of original values
- * - Σqx_i = (sum_x - dim * min_x) / delta_x (sum of quantized values, derived from stored sum)
- * - Σqy_i = (sum_y - dim * min_y) / delta_y
- *
- * For L2:
- * ||x - y||² = sum_sq_x + sum_sq_y - 2 * IP(x, y)
- * where sum_sq_x, sum_sq_y are precomputed sums of squared original values.
- */
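To make the quantization round-trip described above concrete, here is a small self-contained scalar sketch (illustrative names, independent of the deleted code):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Scalar sketch of the SQ8 scheme: q_i = round((x_i - min) / delta),
    // reconstruction x_i ~= min + delta * q_i.
    int main() {
        std::vector<float> x = {0.1f, -0.4f, 0.9f, 0.25f};
        auto [mn_it, mx_it] = std::minmax_element(x.begin(), x.end());
        float mn = *mn_it, delta = (*mx_it - mn) / 255.0f;
        if (delta == 0.0f) delta = 1.0f; // all-equal vector: everything quantizes to 0

        std::vector<uint8_t> q(x.size());
        for (size_t i = 0; i < x.size(); i++)
            q[i] = static_cast<uint8_t>(std::round((x[i] - mn) / delta));

        // Reconstruction error is bounded by delta / 2 per element.
        for (size_t i = 0; i < x.size(); i++)
            std::printf("x=%+.4f  q=%3u  x'=%+.4f\n", x[i], q[i], mn + delta * q[i]);
    }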
-template <typename DataType, VecSimMetric Metric>
-class QuantPreprocessor : public PreprocessorInterface {
-    using OUTPUT_TYPE = uint8_t;
-    using sq8 = vecsim_types::sq8;
-
-    static_assert(Metric == VecSimMetric_L2 || Metric == VecSimMetric_IP ||
-                      Metric == VecSimMetric_Cosine,
-                  "QuantPreprocessor only supports L2, IP and Cosine metrics");
-
-    // Helper function to perform quantization. This function is used by the storage preprocessing
-    // methods.
-    void quantize(const DataType *input, OUTPUT_TYPE *quantized) const {
-        assert(input && quantized);
-        // Find min and max values
-        auto [min_val, max_val] = find_min_max(input);
-
-        // Calculate scaling factor
-        const DataType diff = (max_val - min_val);
-        // Delta = diff / 255.0f
-        const DataType delta = (diff == DataType{0}) ? DataType{1} : diff / DataType{255};
-        const DataType inv_delta = DataType{1} / delta;
-
-        // Compute sum (and sum of squares for L2) while quantizing
-        // 4 independent accumulators (sum)
-        DataType s0{}, s1{}, s2{}, s3{};
-
-        // 4 independent accumulators (sum of squares), only used for L2
-        DataType q0{}, q1{}, q2{}, q3{};
-
-        size_t i = 0;
-        // round dim down to the nearest multiple of 4
-        size_t dim_round_down = this->dim & ~size_t(3);
-
-        // Quantize the values
-        for (; i < dim_round_down; i += 4) {
-            // Load once
-            const DataType x0 = input[i + 0];
-            const DataType x1 = input[i + 1];
-            const DataType x2 = input[i + 2];
-            const DataType x3 = input[i + 3];
-            // We know (input - min) >= 0.
-            // If min == max, all values are the same and should be quantized to 0.
-            // reconstruction will yield the same original value for all vectors.
-            quantized[i + 0] = static_cast<OUTPUT_TYPE>(std::round((x0 - min_val) * inv_delta));
-            quantized[i + 1] = static_cast<OUTPUT_TYPE>(std::round((x1 - min_val) * inv_delta));
-            quantized[i + 2] = static_cast<OUTPUT_TYPE>(std::round((x2 - min_val) * inv_delta));
-            quantized[i + 3] = static_cast<OUTPUT_TYPE>(std::round((x3 - min_val) * inv_delta));
-
-            // Accumulate sum for all metrics
-            s0 += x0;
-            s1 += x1;
-            s2 += x2;
-            s3 += x3;
-
-            // Accumulate sum of squares only for L2 metric
-            if constexpr (Metric == VecSimMetric_L2) {
-                q0 += x0 * x0;
-                q1 += x1 * x1;
-                q2 += x2 * x2;
-                q3 += x3 * x3;
-            }
-        }
-
-        // Tail: 0..3 remaining elements (still the same pass, just finishing work)
-        DataType sum = (s0 + s1) + (s2 + s3);
-        DataType sum_squares = (q0 + q1) + (q2 + q3);
-
-        for (; i < this->dim; ++i) {
-            const DataType x = input[i];
-            quantized[i] = static_cast<OUTPUT_TYPE>(std::round((x - min_val) * inv_delta));
-            sum += x;
-            if constexpr (Metric == VecSimMetric_L2) {
-                sum_squares += x * x;
-            }
-        }
-
-        DataType *metadata = reinterpret_cast<DataType *>(quantized + this->dim);
-
-        // Store min_val and delta in the metadata
-        metadata[sq8::MIN_VAL] = min_val;
-        metadata[sq8::DELTA] = delta;
-
-        // Store sum (for all metrics) and sum_squares (for L2 only)
-        metadata[sq8::SUM] = sum;
-        if constexpr (Metric == VecSimMetric_L2) {
-            metadata[sq8::SUM_SQUARES] = sum_squares;
-        }
-    }
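
A quick standalone round-trip of the quantization rule used by quantize() above (a sketch with illustrative names, not the deleted code): since q_i = round((x_i - min) / delta), the reconstruction min + delta * q_i recovers x_i to within delta / 2, up to floating-point rounding:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> x = {-1.5f, 0.25f, 3.75f, 2.0f, -0.5f};
    const auto [mn_it, mx_it] = std::minmax_element(x.begin(), x.end());
    const float mn = *mn_it, mx = *mx_it;
    const float delta = (mx == mn) ? 1.0f : (mx - mn) / 255.0f;

    float max_err = 0.0f;
    for (float v : x) {
        // Quantize, then reconstruct and measure the error.
        const std::uint8_t q = static_cast<std::uint8_t>(std::round((v - mn) / delta));
        max_err = std::max(max_err, std::fabs(v - (mn + delta * float(q))));
    }
    std::printf("delta=%.6f max reconstruction error=%.6f (bound: delta/2=%.6f)\n",
                delta, max_err, delta / 2.0f);
    return 0;
}
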
-
-    // Computes and assigns query metadata in a single pass over the input vector.
-    // For IP/Cosine: assigns y_sum = Σy_i
-    // For L2: assigns y_sum = Σy_i and y_sum_squares = Σy_i²
-    void assign_query_metadata(const DataType *input, DataType *output_metadata) const {
-        // 4 independent accumulators for sum
-        DataType s0{}, s1{}, s2{}, s3{};
-        // 4 independent accumulators for sum of squares (only used for L2)
-        DataType q0{}, q1{}, q2{}, q3{};
-
-        size_t i = 0;
-        // round dim down to the nearest multiple of 4
-        size_t dim_round_down = this->dim & ~size_t(3);
-
-        for (; i < dim_round_down; i += 4) {
-            const DataType y0 = input[i + 0];
-            const DataType y1 = input[i + 1];
-            const DataType y2 = input[i + 2];
-            const DataType y3 = input[i + 3];
-
-            s0 += y0;
-            s1 += y1;
-            s2 += y2;
-            s3 += y3;
-
-            if constexpr (Metric == VecSimMetric_L2) {
-                q0 += y0 * y0;
-                q1 += y1 * y1;
-                q2 += y2 * y2;
-                q3 += y3 * y3;
-            }
-        }
-
-        DataType sum = (s0 + s1) + (s2 + s3);
-        DataType sum_squares = (q0 + q1) + (q2 + q3);
-
-        // Tail: handle remaining elements
-        for (; i < this->dim; ++i) {
-            const DataType y = input[i];
-            sum += y;
-            if constexpr (Metric == VecSimMetric_L2) {
-                sum_squares += y * y;
-            }
-        }
-
-        // Assign the computed metadata
-        output_metadata[sq8::SUM_QUERY] = sum; // y_sum for all metrics
-        if constexpr (Metric == VecSimMetric_L2) {
-            output_metadata[sq8::SUM_SQUARES_QUERY] = sum_squares; // y_sum_squares for L2 only
-        }
-    }
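
Combining the storage metadata written by quantize() with the query metadata computed above yields the asymmetric L2 distance without dequantizing the stored vector: ||x - y||² = x_sum_squares - 2 * IP(x, y) + y_sum_squares, where IP(x, y) is reconstructed as min * y_sum + delta * Σ(q_i * y_i). A standalone numeric check (illustrative names, not code from this diff):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> x = {0.2f, -0.8f, 1.4f, 0.6f}; // storage vector (quantized)
    const std::vector<float> y = {1.0f, 0.3f, -0.2f, 0.9f}; // query vector (stays float)
    const std::size_t dim = x.size();

    // Storage-side metadata, as written by quantize(): min, delta, x_sum_squares.
    float mn = x[0], mx = x[0], x_sum_squares = 0.0f;
    for (float v : x) {
        mn = std::min(mn, v);
        mx = std::max(mx, v);
        x_sum_squares += v * v;
    }
    const float delta = (mx == mn) ? 1.0f : (mx - mn) / 255.0f;

    // Query-side metadata, as computed by assign_query_metadata(): y_sum, y_sum_squares.
    float y_sum = 0.0f, y_sum_squares = 0.0f, qdot = 0.0f, exact = 0.0f;
    for (std::size_t i = 0; i < dim; i++) {
        const std::uint8_t q = static_cast<std::uint8_t>(std::round((x[i] - mn) / delta));
        y_sum += y[i];
        y_sum_squares += y[i] * y[i];
        qdot += float(q) * y[i];
        exact += (x[i] - y[i]) * (x[i] - y[i]);
    }
    const float ip = mn * y_sum + delta * qdot;                 // asymmetric IP(x, y)
    const float l2 = x_sum_squares - 2.0f * ip + y_sum_squares; // asymmetric ||x - y||²
    std::printf("exact L2^2=%.6f reconstructed=%.6f\n", exact, l2);
    return 0;
}
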
-
-public:
-    QuantPreprocessor(std::shared_ptr<VecSimAllocator> allocator, size_t dim)
-        : PreprocessorInterface(allocator), dim(dim),
-          storage_bytes_count(dim * sizeof(OUTPUT_TYPE) +
-                              (vecsim_types::sq8::storage_metadata_count()) *
-                                  sizeof(DataType)),
-          query_bytes_count((dim + vecsim_types::sq8::query_metadata_count()) *
-                            sizeof(DataType)) {
-        static_assert(std::is_floating_point_v<DataType>,
-                      "QuantPreprocessor only supports floating-point types");
-    }
-
-    void preprocess(const void *original_blob, void *&storage_blob, void *&query_blob,
-                    size_t &input_blob_size, unsigned char alignment) const override {
-        assert(false &&
-               "QuantPreprocessor does not support identical size for storage and query blobs");
-    }
-
-    /**
-     * Preprocesses the original blob into separate storage and query blobs.
-     *
-     * Storage vectors are quantized to uint8_t values, with metadata (min, delta, sum, and
-     * sum_squares for L2) appended for distance reconstruction.
-     *
-     * Query vectors remain as DataType for asymmetric distance computation, with a precomputed
-     * sum (for IP/Cosine) or sum and sum of squares (for L2) appended for efficient distance
-     * calculation.
-     *
-     * Possible scenarios (currently only CASE 1 is implemented):
-     * - CASE 1: STORAGE BLOB AND QUERY BLOB NEED ALLOCATION (storage_blob == query_blob == nullptr)
-     * - CASE 2: STORAGE BLOB EXISTS (storage_blob != nullptr)
-     * - CASE 2A: STORAGE BLOB EXISTS and its size is insufficient
-     *   (storage_blob_size < required_size) - reallocate storage
-     * - CASE 2B: STORAGE AND QUERY SHARE MEMORY (storage_blob == query_blob != nullptr) -
-     *   reallocate storage
-     * - CASE 2C: SEPARATE STORAGE AND QUERY BLOBS (storage_blob != query_blob) - quantize storage
-     *   in-place
-     */
-    void preprocess(const void *original_blob, void *&storage_blob, void *&query_blob,
-                    size_t &storage_blob_size, size_t &query_blob_size,
-                    unsigned char alignment) const override {
-        // CASE 1: STORAGE BLOB NEEDS ALLOCATION - the only implemented case
-        assert(!storage_blob && "CASE 1: storage_blob must be nullptr");
-        assert(!query_blob && "CASE 1: query_blob must be nullptr");
-
-        // storage_blob_size and query_blob_size must point to different memory slots.
-        assert(&storage_blob_size != &query_blob_size);
-
-        // CASE 2A: STORAGE BLOB EXISTS and its size is insufficient - not implemented
-        // storage_blob && storage_blob_size < required_size
-        // CASE 2B: STORAGE EXISTS AND EQUALS QUERY BLOB - not implemented
-        // storage_blob && storage_blob == query_blob
-        // (if we want to handle this, we need to separate the blobs)
-        // CASE 2C: SEPARATE STORAGE AND QUERY BLOBS - not implemented
-        // storage_blob && storage_blob != query_blob
-        // We can quantize the storage blob in-place (if we already checked storage_blob_size is
-        // sufficient)
-
-        preprocessForStorage(original_blob, storage_blob, storage_blob_size);
-        preprocessQuery(original_blob, query_blob, query_blob_size, alignment);
-    }
-
-    void preprocessForStorage(const void *original_blob, void *&blob,
-                              size_t &input_blob_size) const override {
-        assert(!blob && "storage_blob must be nullptr");
-
-        blob = this->allocator->allocate(storage_bytes_count);
-        // Cast to appropriate types
-        const DataType *input = static_cast<const DataType *>(original_blob);
-        OUTPUT_TYPE *quantized = static_cast<OUTPUT_TYPE *>(blob);
-        quantize(input, quantized);
-
-        input_blob_size = storage_bytes_count;
-    }
-
-    /**
-     * Preprocesses the query vector for asymmetric distance computation.
-     *
-     * The query blob contains the original float values followed by precomputed values:
-     * - For IP/Cosine: y_sum = Σy_i (sum of query values)
-     * - For L2: y_sum = Σy_i (sum of query values), y_sum_squares = Σy_i² (sum of squared
-     *   query values)
-     *
-     * Query blob layout:
-     * - For IP/Cosine: | query_values[dim] | y_sum |
-     * - For L2: | query_values[dim] | y_sum | y_sum_squares |
-     *
-     * Query blob size:
-     * - For IP/Cosine: (dim + 1) * sizeof(DataType)
-     * - For L2: (dim + 2) * sizeof(DataType)
-     */
-    void preprocessQuery(const void *original_blob, void *&blob, size_t &query_blob_size,
-                         unsigned char alignment) const override {
-        assert(!blob && "query_blob must be nullptr");
-
-        // Allocate aligned memory for the query blob
-        blob = this->allocator->allocate_aligned(this->query_bytes_count, alignment);
-        memcpy(blob, original_blob, this->dim * sizeof(DataType));
-        const DataType *input = static_cast<const DataType *>(original_blob);
-        DataType *output = static_cast<DataType *>(blob);
-
-        // Compute and assign query metadata (sum for IP/Cosine, sum and sum_squares for L2)
-        assign_query_metadata(input, output + this->dim);
-
-        query_blob_size = this->query_bytes_count;
-    }
-
-    void preprocessStorageInPlace(void *original_blob, size_t input_blob_size) const override {
-        assert(original_blob);
-        assert(input_blob_size >= storage_bytes_count &&
-               "Input buffer too small for in-place quantization");
-
-        quantize(static_cast<const DataType *>(original_blob),
-                 static_cast<OUTPUT_TYPE *>(original_blob));
-    }
-
-private:
-    std::pair<DataType, DataType> find_min_max(const DataType *input) const {
-        auto [min_it, max_it] = std::minmax_element(input, input + dim);
-        return {*min_it, *max_it};
-    }
-
-    const size_t dim;
-    const size_t storage_bytes_count;
-    const size_t query_bytes_count;
-};
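
Every Choose_* function in the dispatch files deleted below follows the same pattern: the CHOOSE_IMPLEMENTATION macro (defined in implementation_chooser.h, which this diff does not show) selects a kernel instantiation based on dim modulo the SIMD chunk size, so the length of the residual tail is fixed at compile time; the SVE variants instead pass a runtime vector-length query such as svcntw or svcntb to CHOOSE_SVE_IMPLEMENTATION. Since the macro body is not part of this diff, the following is only a guessed sketch of that dispatch shape, with hypothetical names:

#include <cstddef>
#include <cstdio>
#include <utility>

// Hypothetical analogue of spaces::dist_func_t; the real alias lives in
// VecSim/spaces/spaces.h, which this diff does not show.
template <typename DistType>
using dist_func_t = DistType (*)(const void *, const void *, std::size_t);

// Scalar stand-in for a residual-specialized kernel. A real kernel would process
// dim - residual elements in full SIMD chunks and finish the last `residual`
// elements with a compile-time-specialized tail.
template <std::size_t residual>
float ip_impl(const void *a, const void *b, std::size_t dim) {
    const float *x = static_cast<const float *>(a);
    const float *y = static_cast<const float *>(b);
    float res = 0.0f;
    for (std::size_t i = 0; i < dim; i++)
        res += x[i] * y[i];
    return res;
}

// Build a table of all residual instantiations and index it with dim % chunk,
// which is what a CHOOSE_IMPLEMENTATION-style macro boils down to.
template <std::size_t... Rs>
dist_func_t<float> choose_impl(std::size_t dim, std::index_sequence<Rs...>) {
    static constexpr dist_func_t<float> table[] = {ip_impl<Rs>...};
    return table[dim % sizeof...(Rs)];
}

// Analogue of Choose_FP32_IP_implementation_AVX(dim) with a chunk of 16 floats.
dist_func_t<float> choose_fp32_ip(std::size_t dim) {
    return choose_impl(dim, std::make_index_sequence<16>{});
}

int main() {
    const float a[7] = {1, 2, 3, 4, 5, 6, 7}, b[7] = {7, 6, 5, 4, 3, 2, 1};
    std::printf("ip=%f (kernel chosen for residual 7 %% 16)\n", choose_fp32_ip(7)(a, b, 7));
    return 0;
}
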
diff --git a/src/VecSim/spaces/functions/AVX.cpp b/src/VecSim/spaces/functions/AVX.cpp
deleted file mode 100644
index 4b707a5b5..000000000
--- a/src/VecSim/spaces/functions/AVX.cpp
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include "AVX.h"
-
-#include "VecSim/spaces/L2/L2_AVX_FP32.h"
-#include "VecSim/spaces/L2/L2_AVX_FP64.h"
-
-#include "VecSim/spaces/IP/IP_AVX_FP32.h"
-#include "VecSim/spaces/IP/IP_AVX_FP64.h"
-
-namespace spaces {
-
-#include "implementation_chooser.h"
-
-dist_func_t<float> Choose_FP32_IP_implementation_AVX(size_t dim) {
-    dist_func_t<float> ret_dist_func;
-    CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, FP32_InnerProductSIMD16_AVX);
-    return ret_dist_func;
-}
-
-dist_func_t<double> Choose_FP64_IP_implementation_AVX(size_t dim) {
-    dist_func_t<double> ret_dist_func;
-    CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 8, FP64_InnerProductSIMD8_AVX);
-    return ret_dist_func;
-}
-
-dist_func_t<float> Choose_FP32_L2_implementation_AVX(size_t dim) {
-    dist_func_t<float> ret_dist_func;
-    CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, FP32_L2SqrSIMD16_AVX);
-    return ret_dist_func;
-}
-
-dist_func_t<double> Choose_FP64_L2_implementation_AVX(size_t dim) {
-    dist_func_t<double> ret_dist_func;
-    CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 8, FP64_L2SqrSIMD8_AVX);
-    return ret_dist_func;
-}
-
-#include "implementation_chooser_cleanup.h"
-
-} // namespace spaces
diff --git a/src/VecSim/spaces/functions/AVX.h b/src/VecSim/spaces/functions/AVX.h
deleted file mode 100644
index 6e2e5843f..000000000
--- a/src/VecSim/spaces/functions/AVX.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-
-#include "VecSim/spaces/spaces.h"
-
-namespace spaces {
-
-dist_func_t<float> Choose_FP32_IP_implementation_AVX(size_t dim);
-dist_func_t<double> Choose_FP64_IP_implementation_AVX(size_t dim);
-
-dist_func_t<float> Choose_FP32_L2_implementation_AVX(size_t dim);
-dist_func_t<double> Choose_FP64_L2_implementation_AVX(size_t dim);
-
-} // namespace spaces
diff --git a/src/VecSim/spaces/functions/AVX2.cpp b/src/VecSim/spaces/functions/AVX2.cpp
deleted file mode 100644
index 322ed0aec..000000000
--- a/src/VecSim/spaces/functions/AVX2.cpp
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */ -#include "AVX2.h" - -#include "VecSim/spaces/IP/IP_AVX2_BF16.h" -#include "VecSim/spaces/L2/L2_AVX2_BF16.h" -#include "VecSim/spaces/IP/IP_AVX2_SQ8_FP32.h" -#include "VecSim/spaces/L2/L2_AVX2_SQ8_FP32.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_BF16_IP_implementation_AVX2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, BF16_InnerProductSIMD32_AVX2); - return ret_dist_func; -} - -dist_func_t Choose_BF16_L2_implementation_AVX2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, BF16_L2SqrSIMD32_AVX2); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_IP_implementation_AVX2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_InnerProductSIMD16_AVX2); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_Cosine_implementation_AVX2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_CosineSIMD16_AVX2); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_L2_implementation_AVX2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_L2SqrSIMD16_AVX2); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX2.h b/src/VecSim/spaces/functions/AVX2.h deleted file mode 100644 index 081c42a4e..000000000 --- a/src/VecSim/spaces/functions/AVX2.h +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_SQ8_FP32_IP_implementation_AVX2(size_t dim); -dist_func_t Choose_SQ8_FP32_Cosine_implementation_AVX2(size_t dim); -dist_func_t Choose_SQ8_FP32_L2_implementation_AVX2(size_t dim); - -dist_func_t Choose_BF16_IP_implementation_AVX2(size_t dim); -dist_func_t Choose_BF16_L2_implementation_AVX2(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX2_FMA.cpp b/src/VecSim/spaces/functions/AVX2_FMA.cpp deleted file mode 100644 index c859128b2..000000000 --- a/src/VecSim/spaces/functions/AVX2_FMA.cpp +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "AVX2_FMA.h" -#include "VecSim/spaces/L2/L2_AVX2_FMA_SQ8_FP32.h" -#include "VecSim/spaces/IP/IP_AVX2_FMA_SQ8_FP32.h" - -namespace spaces { - -#include "implementation_chooser.h" -// FMA optimized implementations -dist_func_t Choose_SQ8_FP32_IP_implementation_AVX2_FMA(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_InnerProductSIMD16_AVX2_FMA); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_Cosine_implementation_AVX2_FMA(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_CosineSIMD16_AVX2_FMA); - return ret_dist_func; -} -dist_func_t Choose_SQ8_FP32_L2_implementation_AVX2_FMA(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_L2SqrSIMD16_AVX2_FMA); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX2_FMA.h b/src/VecSim/spaces/functions/AVX2_FMA.h deleted file mode 100644 index b20b1a588..000000000 --- a/src/VecSim/spaces/functions/AVX2_FMA.h +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_SQ8_FP32_IP_implementation_AVX2_FMA(size_t dim); -dist_func_t Choose_SQ8_FP32_Cosine_implementation_AVX2_FMA(size_t dim); -dist_func_t Choose_SQ8_FP32_L2_implementation_AVX2_FMA(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX512BF16_VL.cpp b/src/VecSim/spaces/functions/AVX512BF16_VL.cpp deleted file mode 100644 index 8ed623205..000000000 --- a/src/VecSim/spaces/functions/AVX512BF16_VL.cpp +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "AVX512BF16_VL.h" - -#include "VecSim/spaces/IP/IP_AVX512_BF16_VL_BF16.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_BF16_IP_implementation_AVX512BF16_VL(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, BF16_InnerProductSIMD32_AVX512BF16_VL); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX512BF16_VL.h b/src/VecSim/spaces/functions/AVX512BF16_VL.h deleted file mode 100644 index 9c7e50ed5..000000000 --- a/src/VecSim/spaces/functions/AVX512BF16_VL.h +++ /dev/null @@ -1,17 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_BF16_IP_implementation_AVX512BF16_VL(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX512BW_VBMI2.cpp b/src/VecSim/spaces/functions/AVX512BW_VBMI2.cpp deleted file mode 100644 index 532cb1882..000000000 --- a/src/VecSim/spaces/functions/AVX512BW_VBMI2.cpp +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "AVX512BW_VBMI2.h" - -#include "VecSim/spaces/IP/IP_AVX512BW_VBMI2_BF16.h" -#include "VecSim/spaces/L2/L2_AVX512BW_VBMI2_BF16.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_BF16_IP_implementation_AVX512BW_VBMI2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, BF16_InnerProductSIMD32_AVX512BW_VBMI2); - return ret_dist_func; -} - -dist_func_t Choose_BF16_L2_implementation_AVX512BW_VBMI2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, BF16_L2SqrSIMD32_AVX512BW_VBMI2); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX512BW_VBMI2.h b/src/VecSim/spaces/functions/AVX512BW_VBMI2.h deleted file mode 100644 index 801236e11..000000000 --- a/src/VecSim/spaces/functions/AVX512BW_VBMI2.h +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_BF16_IP_implementation_AVX512BW_VBMI2(size_t dim); -dist_func_t Choose_BF16_L2_implementation_AVX512BW_VBMI2(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX512F.cpp b/src/VecSim/spaces/functions/AVX512F.cpp deleted file mode 100644 index e765f4c8b..000000000 --- a/src/VecSim/spaces/functions/AVX512F.cpp +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "AVX512F.h" - -#include "VecSim/spaces/L2/L2_AVX512F_FP16.h" -#include "VecSim/spaces/L2/L2_AVX512F_FP32.h" -#include "VecSim/spaces/L2/L2_AVX512F_FP64.h" - -#include "VecSim/spaces/IP/IP_AVX512F_FP16.h" -#include "VecSim/spaces/IP/IP_AVX512F_FP32.h" -#include "VecSim/spaces/IP/IP_AVX512F_FP64.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_FP32_IP_implementation_AVX512F(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, FP32_InnerProductSIMD16_AVX512); - return ret_dist_func; -} - -dist_func_t Choose_FP64_IP_implementation_AVX512F(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 8, FP64_InnerProductSIMD8_AVX512); - return ret_dist_func; -} - -dist_func_t Choose_FP32_L2_implementation_AVX512F(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, FP32_L2SqrSIMD16_AVX512); - return ret_dist_func; -} - -dist_func_t Choose_FP64_L2_implementation_AVX512F(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 8, FP64_L2SqrSIMD8_AVX512); - return ret_dist_func; -} - -dist_func_t Choose_FP16_IP_implementation_AVX512F(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, FP16_InnerProductSIMD32_AVX512); - return ret_dist_func; -} - -dist_func_t Choose_FP16_L2_implementation_AVX512F(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, FP16_L2SqrSIMD32_AVX512); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX512F.h b/src/VecSim/spaces/functions/AVX512F.h deleted file mode 100644 index fd36f312f..000000000 --- a/src/VecSim/spaces/functions/AVX512F.h +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_FP16_IP_implementation_AVX512F(size_t dim); -dist_func_t Choose_FP32_IP_implementation_AVX512F(size_t dim); -dist_func_t Choose_FP64_IP_implementation_AVX512F(size_t dim); - -dist_func_t Choose_FP16_L2_implementation_AVX512F(size_t dim); -dist_func_t Choose_FP32_L2_implementation_AVX512F(size_t dim); -dist_func_t Choose_FP64_L2_implementation_AVX512F(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX512FP16_VL.cpp b/src/VecSim/spaces/functions/AVX512FP16_VL.cpp deleted file mode 100644 index 93549b589..000000000 --- a/src/VecSim/spaces/functions/AVX512FP16_VL.cpp +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "AVX512FP16_VL.h" - -#include "VecSim/spaces/IP/IP_AVX512FP16_VL_FP16.h" -#include "VecSim/spaces/L2/L2_AVX512FP16_VL_FP16.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_FP16_IP_implementation_AVX512FP16_VL(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, FP16_InnerProductSIMD32_AVX512FP16_VL); - return ret_dist_func; -} - -dist_func_t Choose_FP16_L2_implementation_AVX512FP16_VL(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, FP16_L2SqrSIMD32_AVX512FP16_VL); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX512FP16_VL.h b/src/VecSim/spaces/functions/AVX512FP16_VL.h deleted file mode 100644 index 8ffe348a3..000000000 --- a/src/VecSim/spaces/functions/AVX512FP16_VL.h +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_FP16_IP_implementation_AVX512FP16_VL(size_t dim); -dist_func_t Choose_FP16_L2_implementation_AVX512FP16_VL(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX512F_BW_VL_VNNI.cpp b/src/VecSim/spaces/functions/AVX512F_BW_VL_VNNI.cpp deleted file mode 100644 index 3b8813b89..000000000 --- a/src/VecSim/spaces/functions/AVX512F_BW_VL_VNNI.cpp +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "AVX512F_BW_VL_VNNI.h" - -#include "VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_INT8.h" -#include "VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_INT8.h" - -#include "VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_UINT8.h" -#include "VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_UINT8.h" - -#include "VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_SQ8_FP32.h" -#include "VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_SQ8_FP32.h" - -#include "VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_SQ8_SQ8.h" -#include "VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_SQ8_SQ8.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_INT8_L2_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, INT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} - -dist_func_t Choose_INT8_IP_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, INT8_InnerProductSIMD64_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} - -dist_func_t Choose_INT8_Cosine_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, INT8_CosineSIMD64_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_L2_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, UINT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_IP_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, UINT8_InnerProductSIMD64_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_Cosine_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, UINT8_CosineSIMD64_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_IP_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_InnerProductSIMD16_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} -dist_func_t Choose_SQ8_FP32_Cosine_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_CosineSIMD16_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} -dist_func_t Choose_SQ8_FP32_L2_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_L2SqrSIMD16_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} -// SQ8-to-SQ8 distance functions (both vectors are uint8 quantized with precomputed sum) -dist_func_t Choose_SQ8_SQ8_IP_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, SQ8_SQ8_InnerProductSIMD64_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_SQ8_Cosine_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, SQ8_SQ8_CosineSIMD64_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_SQ8_L2_implementation_AVX512F_BW_VL_VNNI(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, SQ8_SQ8_L2SqrSIMD64_AVX512F_BW_VL_VNNI); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/AVX512F_BW_VL_VNNI.h 
b/src/VecSim/spaces/functions/AVX512F_BW_VL_VNNI.h deleted file mode 100644 index fe1583491..000000000 --- a/src/VecSim/spaces/functions/AVX512F_BW_VL_VNNI.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_INT8_L2_implementation_AVX512F_BW_VL_VNNI(size_t dim); -dist_func_t Choose_INT8_IP_implementation_AVX512F_BW_VL_VNNI(size_t dim); -dist_func_t Choose_INT8_Cosine_implementation_AVX512F_BW_VL_VNNI(size_t dim); - -dist_func_t Choose_UINT8_L2_implementation_AVX512F_BW_VL_VNNI(size_t dim); -dist_func_t Choose_UINT8_IP_implementation_AVX512F_BW_VL_VNNI(size_t dim); -dist_func_t Choose_UINT8_Cosine_implementation_AVX512F_BW_VL_VNNI(size_t dim); - -dist_func_t Choose_SQ8_FP32_IP_implementation_AVX512F_BW_VL_VNNI(size_t dim); -dist_func_t Choose_SQ8_FP32_Cosine_implementation_AVX512F_BW_VL_VNNI(size_t dim); -dist_func_t Choose_SQ8_FP32_L2_implementation_AVX512F_BW_VL_VNNI(size_t dim); - -// SQ8-to-SQ8 distance functions (both vectors are uint8 quantized with precomputed sum) -dist_func_t Choose_SQ8_SQ8_IP_implementation_AVX512F_BW_VL_VNNI(size_t dim); -dist_func_t Choose_SQ8_SQ8_Cosine_implementation_AVX512F_BW_VL_VNNI(size_t dim); -dist_func_t Choose_SQ8_SQ8_L2_implementation_AVX512F_BW_VL_VNNI(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/F16C.cpp b/src/VecSim/spaces/functions/F16C.cpp deleted file mode 100644 index 7a4127adc..000000000 --- a/src/VecSim/spaces/functions/F16C.cpp +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "F16C.h" - -#include "VecSim/spaces/IP/IP_F16C_FP16.h" -#include "VecSim/spaces/L2/L2_F16C_FP16.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_FP16_IP_implementation_F16C(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, FP16_InnerProductSIMD32_F16C); - return ret_dist_func; -} - -dist_func_t Choose_FP16_L2_implementation_F16C(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, FP16_L2SqrSIMD32_F16C); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/F16C.h b/src/VecSim/spaces/functions/F16C.h deleted file mode 100644 index 1e977d0eb..000000000 --- a/src/VecSim/spaces/functions/F16C.h +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_FP16_IP_implementation_F16C(size_t dim); -dist_func_t Choose_FP16_L2_implementation_F16C(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/NEON.cpp b/src/VecSim/spaces/functions/NEON.cpp deleted file mode 100644 index 0c9a286e3..000000000 --- a/src/VecSim/spaces/functions/NEON.cpp +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "NEON.h" -#include "VecSim/spaces/L2/L2_NEON_FP32.h" -#include "VecSim/spaces/IP/IP_NEON_FP32.h" -#include "VecSim/spaces/L2/L2_NEON_INT8.h" -#include "VecSim/spaces/L2/L2_NEON_UINT8.h" -#include "VecSim/spaces/IP/IP_NEON_INT8.h" -#include "VecSim/spaces/IP/IP_NEON_UINT8.h" -#include "VecSim/spaces/L2/L2_NEON_FP64.h" -#include "VecSim/spaces/IP/IP_NEON_FP64.h" -#include "VecSim/spaces/L2/L2_NEON_SQ8_FP32.h" -#include "VecSim/spaces/IP/IP_NEON_SQ8_FP32.h" -#include "VecSim/spaces/IP/IP_NEON_SQ8_SQ8.h" -#include "VecSim/spaces/L2/L2_NEON_SQ8_SQ8.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_INT8_IP_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, INT8_InnerProductSIMD16_NEON); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_IP_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, UINT8_InnerProductSIMD16_NEON); - return ret_dist_func; -} - -dist_func_t Choose_FP32_IP_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, FP32_InnerProductSIMD16_NEON); - return ret_dist_func; -} - -dist_func_t Choose_FP64_IP_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 8, FP64_InnerProductSIMD8_NEON); - return ret_dist_func; -} - -dist_func_t Choose_INT8_Cosine_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, INT8_CosineSIMD_NEON); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_Cosine_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, UINT8_CosineSIMD_NEON); - return ret_dist_func; -} - -dist_func_t Choose_FP32_L2_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, FP32_L2SqrSIMD16_NEON); - return ret_dist_func; -} -dist_func_t Choose_INT8_L2_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, INT8_L2SqrSIMD16_NEON); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_L2_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, UINT8_L2SqrSIMD16_NEON); - return ret_dist_func; -} - -dist_func_t Choose_FP64_L2_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 8, FP64_L2SqrSIMD8_NEON); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_L2_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_L2SqrSIMD16_NEON); - return ret_dist_func; -} - -dist_func_t 
Choose_SQ8_FP32_IP_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_InnerProductSIMD16_NEON); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_Cosine_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_CosineSIMD16_NEON); - return ret_dist_func; -} - -// SQ8-to-SQ8 distance functions (both vectors are uint8 quantized with precomputed sum) -// Uses 64-element chunking to leverage efficient UINT8_InnerProductImp -dist_func_t Choose_SQ8_SQ8_IP_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, SQ8_SQ8_InnerProductSIMD64_NEON); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_SQ8_Cosine_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, SQ8_SQ8_CosineSIMD64_NEON); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_SQ8_L2_implementation_NEON(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, SQ8_SQ8_L2SqrSIMD64_NEON); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/NEON.h b/src/VecSim/spaces/functions/NEON.h deleted file mode 100644 index 08060b402..000000000 --- a/src/VecSim/spaces/functions/NEON.h +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_INT8_IP_implementation_NEON(size_t dim); -dist_func_t Choose_INT8_L2_implementation_NEON(size_t dim); -dist_func_t Choose_INT8_Cosine_implementation_NEON(size_t dim); - -dist_func_t Choose_UINT8_IP_implementation_NEON(size_t dim); -dist_func_t Choose_UINT8_L2_implementation_NEON(size_t dim); -dist_func_t Choose_UINT8_Cosine_implementation_NEON(size_t dim); - -dist_func_t Choose_FP32_IP_implementation_NEON(size_t dim); -dist_func_t Choose_FP32_L2_implementation_NEON(size_t dim); - -dist_func_t Choose_FP64_IP_implementation_NEON(size_t dim); -dist_func_t Choose_FP64_L2_implementation_NEON(size_t dim); - -dist_func_t Choose_SQ8_FP32_L2_implementation_NEON(size_t dim); -dist_func_t Choose_SQ8_FP32_IP_implementation_NEON(size_t dim); -dist_func_t Choose_SQ8_FP32_Cosine_implementation_NEON(size_t dim); - -// SQ8-to-SQ8 distance functions (both vectors are uint8 quantized with precomputed sum) -dist_func_t Choose_SQ8_SQ8_IP_implementation_NEON(size_t dim); -dist_func_t Choose_SQ8_SQ8_Cosine_implementation_NEON(size_t dim); -dist_func_t Choose_SQ8_SQ8_L2_implementation_NEON(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/NEON_BF16.cpp b/src/VecSim/spaces/functions/NEON_BF16.cpp deleted file mode 100644 index 4de205bb8..000000000 --- a/src/VecSim/spaces/functions/NEON_BF16.cpp +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "NEON_BF16.h" - -#include "VecSim/spaces/L2/L2_NEON_BF16.h" -#include "VecSim/spaces/IP/IP_NEON_BF16.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_BF16_L2_implementation_NEON_BF16(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, BF16_L2Sqr_NEON); - return ret_dist_func; -} - -dist_func_t Choose_BF16_IP_implementation_NEON_BF16(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, BF16_InnerProduct_NEON); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/NEON_BF16.h b/src/VecSim/spaces/functions/NEON_BF16.h deleted file mode 100644 index 85cded665..000000000 --- a/src/VecSim/spaces/functions/NEON_BF16.h +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_BF16_IP_implementation_NEON_BF16(size_t dim); - -dist_func_t Choose_BF16_L2_implementation_NEON_BF16(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/NEON_DOTPROD.cpp b/src/VecSim/spaces/functions/NEON_DOTPROD.cpp deleted file mode 100644 index 12f762093..000000000 --- a/src/VecSim/spaces/functions/NEON_DOTPROD.cpp +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "NEON.h" -#include "VecSim/spaces/IP/IP_NEON_DOTPROD_INT8.h" -#include "VecSim/spaces/IP/IP_NEON_DOTPROD_UINT8.h" -#include "VecSim/spaces/IP/IP_NEON_DOTPROD_SQ8_SQ8.h" -#include "VecSim/spaces/L2/L2_NEON_DOTPROD_INT8.h" -#include "VecSim/spaces/L2/L2_NEON_DOTPROD_UINT8.h" -#include "VecSim/spaces/L2/L2_NEON_DOTPROD_SQ8_SQ8.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_INT8_IP_implementation_NEON_DOTPROD(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, INT8_InnerProductSIMD16_NEON_DOTPROD); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_IP_implementation_NEON_DOTPROD(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, UINT8_InnerProductSIMD16_NEON_DOTPROD); - return ret_dist_func; -} - -dist_func_t Choose_INT8_Cosine_implementation_NEON_DOTPROD(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, INT8_CosineSIMD_NEON_DOTPROD); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_Cosine_implementation_NEON_DOTPROD(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, UINT8_CosineSIMD_NEON_DOTPROD); - return ret_dist_func; -} - -dist_func_t Choose_INT8_L2_implementation_NEON_DOTPROD(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, INT8_L2SqrSIMD16_NEON_DOTPROD); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_L2_implementation_NEON_DOTPROD(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, UINT8_L2SqrSIMD16_NEON_DOTPROD); - return ret_dist_func; -} - -// SQ8-to-SQ8 distance functions (both vectors are uint8 quantized with precomputed sum) -dist_func_t Choose_SQ8_SQ8_IP_implementation_NEON_DOTPROD(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, SQ8_SQ8_InnerProductSIMD64_NEON_DOTPROD); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_SQ8_Cosine_implementation_NEON_DOTPROD(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, SQ8_SQ8_CosineSIMD64_NEON_DOTPROD); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_SQ8_L2_implementation_NEON_DOTPROD(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 64, SQ8_SQ8_L2SqrSIMD64_NEON_DOTPROD); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/NEON_DOTPROD.h b/src/VecSim/spaces/functions/NEON_DOTPROD.h deleted file mode 100644 index 0fda479bc..000000000 --- a/src/VecSim/spaces/functions/NEON_DOTPROD.h +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_INT8_IP_implementation_NEON_DOTPROD(size_t dim); -dist_func_t Choose_INT8_Cosine_implementation_NEON_DOTPROD(size_t dim); - -dist_func_t Choose_UINT8_IP_implementation_NEON_DOTPROD(size_t dim); -dist_func_t Choose_UINT8_Cosine_implementation_NEON_DOTPROD(size_t dim); - -dist_func_t Choose_INT8_L2_implementation_NEON_DOTPROD(size_t dim); -dist_func_t Choose_UINT8_L2_implementation_NEON_DOTPROD(size_t dim); - -// SQ8-to-SQ8 DOTPROD-optimized distance functions (with precomputed sum) -dist_func_t Choose_SQ8_SQ8_IP_implementation_NEON_DOTPROD(size_t dim); -dist_func_t Choose_SQ8_SQ8_Cosine_implementation_NEON_DOTPROD(size_t dim); -dist_func_t Choose_SQ8_SQ8_L2_implementation_NEON_DOTPROD(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/NEON_HP.cpp b/src/VecSim/spaces/functions/NEON_HP.cpp deleted file mode 100644 index 2dea94934..000000000 --- a/src/VecSim/spaces/functions/NEON_HP.cpp +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "NEON_HP.h" - -#include "VecSim/spaces/L2/L2_NEON_FP16.h" -#include "VecSim/spaces/IP/IP_NEON_FP16.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_FP16_L2_implementation_NEON_HP(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, FP16_L2Sqr_NEON_HP); - return ret_dist_func; -} - -dist_func_t Choose_FP16_IP_implementation_NEON_HP(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, FP16_InnerProduct_NEON_HP); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/NEON_HP.h b/src/VecSim/spaces/functions/NEON_HP.h deleted file mode 100644 index c65bd6948..000000000 --- a/src/VecSim/spaces/functions/NEON_HP.h +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_FP16_IP_implementation_NEON_HP(size_t dim); - -dist_func_t Choose_FP16_L2_implementation_NEON_HP(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SSE.cpp b/src/VecSim/spaces/functions/SSE.cpp deleted file mode 100644 index 9963fa86f..000000000 --- a/src/VecSim/spaces/functions/SSE.cpp +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "SSE.h" - -#include "VecSim/spaces/L2/L2_SSE_FP32.h" -#include "VecSim/spaces/L2/L2_SSE_FP64.h" -#include "VecSim/spaces/L2/L2_SSE4_SQ8_FP32.h" - -#include "VecSim/spaces/IP/IP_SSE_FP32.h" -#include "VecSim/spaces/IP/IP_SSE_FP64.h" -#include "VecSim/spaces/IP/IP_SSE4_SQ8_FP32.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_FP32_IP_implementation_SSE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, FP32_InnerProductSIMD16_SSE); - return ret_dist_func; -} - -dist_func_t Choose_FP64_IP_implementation_SSE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 8, FP64_InnerProductSIMD8_SSE); - return ret_dist_func; -} - -dist_func_t Choose_FP32_L2_implementation_SSE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, FP32_L2SqrSIMD16_SSE); - return ret_dist_func; -} - -dist_func_t Choose_FP64_L2_implementation_SSE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 8, FP64_L2SqrSIMD8_SSE); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SSE.h b/src/VecSim/spaces/functions/SSE.h deleted file mode 100644 index 2eba22f31..000000000 --- a/src/VecSim/spaces/functions/SSE.h +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_FP32_IP_implementation_SSE(size_t dim); -dist_func_t Choose_FP64_IP_implementation_SSE(size_t dim); - -dist_func_t Choose_FP32_L2_implementation_SSE(size_t dim); -dist_func_t Choose_FP64_L2_implementation_SSE(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SSE3.cpp b/src/VecSim/spaces/functions/SSE3.cpp deleted file mode 100644 index 4c60c2e02..000000000 --- a/src/VecSim/spaces/functions/SSE3.cpp +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "SSE3.h" - -#include "VecSim/spaces/IP/IP_SSE3_BF16.h" -#include "VecSim/spaces/L2/L2_SSE3_BF16.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_BF16_IP_implementation_SSE3(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, BF16_InnerProductSIMD32_SSE3); - return ret_dist_func; -} - -dist_func_t Choose_BF16_L2_implementation_SSE3(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 32, BF16_L2SqrSIMD32_SSE3); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SSE3.h b/src/VecSim/spaces/functions/SSE3.h deleted file mode 100644 index d181063a7..000000000 --- a/src/VecSim/spaces/functions/SSE3.h +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. 
- * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_BF16_IP_implementation_SSE3(size_t dim); -dist_func_t Choose_BF16_L2_implementation_SSE3(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SSE4.cpp b/src/VecSim/spaces/functions/SSE4.cpp deleted file mode 100644 index 5f5bbc1ba..000000000 --- a/src/VecSim/spaces/functions/SSE4.cpp +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "SSE4.h" -#include "VecSim/spaces/IP/IP_SSE4_SQ8_FP32.h" -#include "VecSim/spaces/L2/L2_SSE4_SQ8_FP32.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_SQ8_FP32_IP_implementation_SSE4(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_InnerProductSIMD16_SSE4); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_Cosine_implementation_SSE4(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_CosineSIMD16_SSE4); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_L2_implementation_SSE4(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_IMPLEMENTATION(ret_dist_func, dim, 16, SQ8_FP32_L2SqrSIMD16_SSE4); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SSE4.h b/src/VecSim/spaces/functions/SSE4.h deleted file mode 100644 index e47948137..000000000 --- a/src/VecSim/spaces/functions/SSE4.h +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_SQ8_FP32_IP_implementation_SSE4(size_t dim); -dist_func_t Choose_SQ8_FP32_Cosine_implementation_SSE4(size_t dim); -dist_func_t Choose_SQ8_FP32_L2_implementation_SSE4(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SVE.cpp b/src/VecSim/spaces/functions/SVE.cpp deleted file mode 100644 index fde853db2..000000000 --- a/src/VecSim/spaces/functions/SVE.cpp +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "SVE.h" - -#include "VecSim/spaces/L2/L2_SVE_FP32.h" -#include "VecSim/spaces/IP/IP_SVE_FP32.h" - -#include "VecSim/spaces/IP/IP_SVE_FP16.h" -#include "VecSim/spaces/L2/L2_SVE_FP16.h" - -#include "VecSim/spaces/IP/IP_SVE_FP64.h" -#include "VecSim/spaces/L2/L2_SVE_FP64.h" - -#include "VecSim/spaces/L2/L2_SVE_INT8.h" -#include "VecSim/spaces/IP/IP_SVE_INT8.h" - -#include "VecSim/spaces/L2/L2_SVE_UINT8.h" -#include "VecSim/spaces/IP/IP_SVE_UINT8.h" -#include "VecSim/spaces/IP/IP_SVE_SQ8_FP32.h" -#include "VecSim/spaces/L2/L2_SVE_SQ8_FP32.h" - -#include "VecSim/spaces/IP/IP_SVE_SQ8_SQ8.h" -#include "VecSim/spaces/L2/L2_SVE_SQ8_SQ8.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_FP32_IP_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP32_InnerProductSIMD_SVE, dim, svcntw); - return ret_dist_func; -} -dist_func_t Choose_FP32_L2_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP32_L2SqrSIMD_SVE, dim, svcntw); - return ret_dist_func; -} - -dist_func_t Choose_FP16_IP_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP16_InnerProduct_SVE, dim, svcnth); - return ret_dist_func; -} -dist_func_t Choose_FP16_L2_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP16_L2Sqr_SVE, dim, svcnth); - return ret_dist_func; -} - -dist_func_t Choose_FP64_IP_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP64_InnerProductSIMD_SVE, dim, svcntd); - return ret_dist_func; -} -dist_func_t Choose_FP64_L2_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP64_L2SqrSIMD_SVE, dim, svcntd); - return ret_dist_func; -} - -dist_func_t Choose_INT8_L2_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, INT8_L2SqrSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_INT8_IP_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, INT8_InnerProductSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_INT8_Cosine_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, INT8_CosineSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_L2_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, UINT8_L2SqrSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_IP_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, UINT8_InnerProductSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_Cosine_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, UINT8_CosineSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_IP_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_FP32_InnerProductSIMD_SVE, dim, svcntw); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_Cosine_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_FP32_CosineSIMD_SVE, dim, svcntw); - return ret_dist_func; -} - -dist_func_t 
Choose_SQ8_FP32_L2_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_FP32_L2SqrSIMD_SVE, dim, svcntw); - return ret_dist_func; -} - -// SQ8-to-SQ8 distance functions (both vectors are uint8 quantized with precomputed sum) -// Note: Use svcntb for uint8 elements (not svcntw which is for 32-bit elements) -dist_func_t Choose_SQ8_SQ8_IP_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_SQ8_InnerProductSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_SQ8_Cosine_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_SQ8_CosineSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_SQ8_L2_implementation_SVE(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_SQ8_L2SqrSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SVE.h b/src/VecSim/spaces/functions/SVE.h deleted file mode 100644 index bd3bc97c3..000000000 --- a/src/VecSim/spaces/functions/SVE.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_FP32_IP_implementation_SVE(size_t dim); -dist_func_t Choose_FP32_L2_implementation_SVE(size_t dim); - -dist_func_t Choose_FP16_IP_implementation_SVE(size_t dim); -dist_func_t Choose_FP16_L2_implementation_SVE(size_t dim); - -dist_func_t Choose_FP64_IP_implementation_SVE(size_t dim); -dist_func_t Choose_FP64_L2_implementation_SVE(size_t dim); - -dist_func_t Choose_INT8_IP_implementation_SVE(size_t dim); -dist_func_t Choose_INT8_Cosine_implementation_SVE(size_t dim); -dist_func_t Choose_INT8_L2_implementation_SVE(size_t dim); - -dist_func_t Choose_UINT8_L2_implementation_SVE(size_t dim); -dist_func_t Choose_UINT8_Cosine_implementation_SVE(size_t dim); -dist_func_t Choose_UINT8_IP_implementation_SVE(size_t dim); - -dist_func_t Choose_SQ8_FP32_IP_implementation_SVE(size_t dim); -dist_func_t Choose_SQ8_FP32_Cosine_implementation_SVE(size_t dim); -dist_func_t Choose_SQ8_FP32_L2_implementation_SVE(size_t dim); - -// SQ8-to-SQ8 distance functions (both vectors are uint8 quantized with precomputed sum) -dist_func_t Choose_SQ8_SQ8_IP_implementation_SVE(size_t dim); -dist_func_t Choose_SQ8_SQ8_Cosine_implementation_SVE(size_t dim); -dist_func_t Choose_SQ8_SQ8_L2_implementation_SVE(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SVE2.cpp b/src/VecSim/spaces/functions/SVE2.cpp deleted file mode 100644 index 4215d79cf..000000000 --- a/src/VecSim/spaces/functions/SVE2.cpp +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
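The SVE choosers above all pass an element-count intrinsic as `chunk_getter`: svcntb, svcnth, svcntw, and svcntd return the number of 8-, 16-, 32-, and 64-bit lanes per register at runtime, since SVE vector length is implementation-defined. A minimal vector-length-agnostic sketch of the loop shape these kernels use (simplified to one predicated loop rather than the 4-step unrolled structure that `additional_steps` implies, returning the raw inner product rather than a distance; requires compiling with SVE enabled, e.g. `-march=armv8-a+sve`):

```cpp
#include <arm_sve.h>
#include <cstddef>
#include <cstdint>

// Vector-length-agnostic FP32 inner product (illustrative sketch).
float fp32_ip_sve_sketch(const float *a, const float *b, size_t dim) {
    svfloat32_t acc = svdup_n_f32(0.0f);
    size_t i = 0;
    svbool_t pg = svwhilelt_b32((uint64_t)i, (uint64_t)dim);
    while (svptest_any(svptrue_b32(), pg)) {
        svfloat32_t va = svld1_f32(pg, a + i); // predicated load handles the tail
        svfloat32_t vb = svld1_f32(pg, b + i);
        acc = svmla_f32_m(pg, acc, va, vb);    // acc += a*b on active lanes
        i += svcntw();                         // lanes per register: the chooser's chunk_getter
        pg = svwhilelt_b32((uint64_t)i, (uint64_t)dim);
    }
    return svaddv_f32(svptrue_b32(), acc);     // horizontal reduction
}
```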
- */ -#include "SVE2.h" - -#include "VecSim/spaces/L2/L2_SVE_FP32.h" -#include "VecSim/spaces/IP/IP_SVE_FP32.h" - -#include "VecSim/spaces/IP/IP_SVE_FP16.h" // using SVE implementation, but different compilation flags -#include "VecSim/spaces/L2/L2_SVE_FP16.h" // using SVE implementation, but different compilation flags - -#include "VecSim/spaces/IP/IP_SVE_FP64.h" -#include "VecSim/spaces/L2/L2_SVE_FP64.h" -#include "VecSim/spaces/L2/L2_SVE_INT8.h" // SVE2 implementation is identical to SVE -#include "VecSim/spaces/IP/IP_SVE_INT8.h" // SVE2 implementation is identical to SVE -#include "VecSim/spaces/L2/L2_SVE_UINT8.h" // SVE2 implementation is identical to SVE -#include "VecSim/spaces/IP/IP_SVE_UINT8.h" // SVE2 implementation is identical to SVE -#include "VecSim/spaces/IP/IP_SVE_SQ8_FP32.h" // SVE2 implementation is identical to SVE -#include "VecSim/spaces/L2/L2_SVE_SQ8_FP32.h" // SVE2 implementation is identical to SVE -#include "VecSim/spaces/IP/IP_SVE_SQ8_SQ8.h" // SVE2 implementation is identical to SVE -#include "VecSim/spaces/L2/L2_SVE_SQ8_SQ8.h" // SVE2 implementation is identical to SVE - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_FP32_IP_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP32_InnerProductSIMD_SVE, dim, svcntw); - return ret_dist_func; -} -dist_func_t Choose_FP32_L2_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP32_L2SqrSIMD_SVE, dim, svcntw); - return ret_dist_func; -} - -dist_func_t Choose_FP16_IP_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP16_InnerProduct_SVE, dim, svcnth); - return ret_dist_func; -} -dist_func_t Choose_FP16_L2_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP16_L2Sqr_SVE, dim, svcnth); - return ret_dist_func; -} - -dist_func_t Choose_FP64_IP_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP64_InnerProductSIMD_SVE, dim, svcntd); - return ret_dist_func; -} -dist_func_t Choose_FP64_L2_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, FP64_L2SqrSIMD_SVE, dim, svcntd); - return ret_dist_func; -} - -dist_func_t Choose_INT8_L2_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, INT8_L2SqrSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_INT8_IP_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, INT8_InnerProductSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_INT8_Cosine_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, INT8_CosineSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_L2_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, UINT8_L2SqrSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_IP_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, UINT8_InnerProductSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_UINT8_Cosine_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, UINT8_CosineSIMD_SVE, dim, svcntb); - 
return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_IP_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_FP32_InnerProductSIMD_SVE, dim, svcntw); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_Cosine_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_FP32_CosineSIMD_SVE, dim, svcntw); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_FP32_L2_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_FP32_L2SqrSIMD_SVE, dim, svcntw); - return ret_dist_func; -} - -// SQ8-to-SQ8 distance functions (both vectors are uint8 quantized) -// Note: Use svcntb for uint8 elements (not svcntw which is for 32-bit elements) -dist_func_t Choose_SQ8_SQ8_IP_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_SQ8_InnerProductSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_SQ8_Cosine_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_SQ8_CosineSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -dist_func_t Choose_SQ8_SQ8_L2_implementation_SVE2(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, SQ8_SQ8_L2SqrSIMD_SVE, dim, svcntb); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SVE2.h b/src/VecSim/spaces/functions/SVE2.h deleted file mode 100644 index 04078a91e..000000000 --- a/src/VecSim/spaces/functions/SVE2.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
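The SVE2 translation unit above reuses the SVE kernel headers and differs only in compilation flags, which is why a parallel set of `_SVE2` chooser symbols exists per target ISA. A sketch of the kind of CPU-feature gating these per-ISA choosers plugged into before this removal (the real wiring lives in IP_space.cpp/L2_space.cpp, which are not shown in this hunk; the `dist_func_t<float>` spelling and the scalar fallback name are assumptions):

```cpp
#include <cstddef>
#include "cpuinfo_aarch64.h" // Google cpu_features, already used elsewhere in this diff

// Illustrative FP32 inner-product selector, preferring the newest supported ISA.
spaces::dist_func_t<float> pick_fp32_ip(size_t dim, const cpu_features::Aarch64Features &f) {
    if (f.sve2)
        return spaces::Choose_FP32_IP_implementation_SVE2(dim);
    if (f.sve)
        return spaces::Choose_FP32_IP_implementation_SVE(dim);
    return FP32_InnerProduct_Naive; // hypothetical scalar fallback
}
```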
- */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_FP32_IP_implementation_SVE2(size_t dim); -dist_func_t Choose_FP32_L2_implementation_SVE2(size_t dim); - -dist_func_t Choose_FP16_IP_implementation_SVE2(size_t dim); -dist_func_t Choose_FP16_L2_implementation_SVE2(size_t dim); - -dist_func_t Choose_FP64_IP_implementation_SVE2(size_t dim); -dist_func_t Choose_FP64_L2_implementation_SVE2(size_t dim); - -dist_func_t Choose_INT8_L2_implementation_SVE2(size_t dim); -dist_func_t Choose_INT8_Cosine_implementation_SVE2(size_t dim); -dist_func_t Choose_INT8_IP_implementation_SVE2(size_t dim); - -dist_func_t Choose_UINT8_L2_implementation_SVE2(size_t dim); -dist_func_t Choose_UINT8_Cosine_implementation_SVE2(size_t dim); -dist_func_t Choose_UINT8_IP_implementation_SVE2(size_t dim); - -dist_func_t Choose_SQ8_FP32_IP_implementation_SVE2(size_t dim); -dist_func_t Choose_SQ8_FP32_Cosine_implementation_SVE2(size_t dim); -dist_func_t Choose_SQ8_FP32_L2_implementation_SVE2(size_t dim); - -// SQ8-to-SQ8 distance functions (both vectors are uint8 quantized) -dist_func_t Choose_SQ8_SQ8_IP_implementation_SVE2(size_t dim); -dist_func_t Choose_SQ8_SQ8_Cosine_implementation_SVE2(size_t dim); -dist_func_t Choose_SQ8_SQ8_L2_implementation_SVE2(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SVE_BF16.cpp b/src/VecSim/spaces/functions/SVE_BF16.cpp deleted file mode 100644 index b457cdb7f..000000000 --- a/src/VecSim/spaces/functions/SVE_BF16.cpp +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "SVE_BF16.h" - -#include "VecSim/spaces/IP/IP_SVE_BF16.h" -#include "VecSim/spaces/L2/L2_SVE_BF16.h" - -namespace spaces { - -#include "implementation_chooser.h" - -dist_func_t Choose_BF16_IP_implementation_SVE_BF16(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, BF16_InnerProduct_SVE, dim, svcnth); - return ret_dist_func; -} -dist_func_t Choose_BF16_L2_implementation_SVE_BF16(size_t dim) { - dist_func_t ret_dist_func; - CHOOSE_SVE_IMPLEMENTATION(ret_dist_func, BF16_L2Sqr_SVE, dim, svcnth); - return ret_dist_func; -} - -#include "implementation_chooser_cleanup.h" - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/SVE_BF16.h b/src/VecSim/spaces/functions/SVE_BF16.h deleted file mode 100644 index e8317ae57..000000000 --- a/src/VecSim/spaces/functions/SVE_BF16.h +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/spaces/spaces.h" - -namespace spaces { - -dist_func_t Choose_BF16_IP_implementation_SVE_BF16(size_t dim); -dist_func_t Choose_BF16_L2_implementation_SVE_BF16(size_t dim); - -} // namespace spaces diff --git a/src/VecSim/spaces/functions/implementation_chooser.h b/src/VecSim/spaces/functions/implementation_chooser.h deleted file mode 100644 index 9c153413f..000000000 --- a/src/VecSim/spaces/functions/implementation_chooser.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. 
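SVE_BF16 above sits in its own translation unit because the BF16 kernels need an extra target feature (`+bf16`) on top of plain SVE. For reference, a worked example of the scalar round-to-nearest-even conversion these kernels mirror, matching `float_to_bf16` as defined later in this diff (types/bfloat16.h); the standalone function here is a rewrite with memcpy instead of pointer punning:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t float_to_bf16_bits(float ff) {
    uint32_t f32;
    std::memcpy(&f32, &ff, sizeof(f32));
    uint32_t lsb = (f32 >> 16) & 1; // low bit of the kept mantissa, used for ties-to-even
    f32 += lsb + 0x7FFF;            // round to nearest, ties to even
    return uint16_t(f32 >> 16);
}

int main() {
    // 1.0f (0x3F800000) is exactly representable: prints 3F80.
    std::printf("%04X\n", unsigned(float_to_bf16_bits(1.0f)));
    // 1.00390625f (0x3F808000) lies exactly between two bf16 values; ties-to-even
    // rounds down to 0x3F80 (1.0) rather than up to 0x3F81.
    std::printf("%04X\n", unsigned(float_to_bf16_bits(1.00390625f)));
}
```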
 - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -/* - * This file contains macro magic to choose the implementation of a function based on the - * dimension's remainder. It is used to collapse large and repetitive switch statements that are - * used to choose and define the templated values of the implementation of the distance functions. - * We assume that we are dealing with 512-bit blocks, so we define a chunk size of 32 for 16-bit - * elements, 16 for 32-bit elements, and a chunk size of 8 for 64-bit elements. The main macro is - * CHOOSE_IMPLEMENTATION, and it's the one that should be used. - */ - -// Macro for a single case. Sets __ret_dist_func to the function with the given remainder. -#define C1(func, N) \ - case (N): \ - __ret_dist_func = func<(N)>; \ - break; - -// Macros for folding cases of a switch statement, for easier readability. -// Each macro expands into a sequence of cases, from 0 to N-1, doubling the previous macro. -#define C2(func, N) C1(func, 2 * (N)) C1(func, 2 * (N) + 1) -#define C4(func, N) C2(func, 2 * (N)) C2(func, 2 * (N) + 1) -#define C8(func, N) C4(func, 2 * (N)) C4(func, 2 * (N) + 1) -#define C16(func, N) C8(func, 2 * (N)) C8(func, 2 * (N) + 1) -#define C32(func, N) C16(func, 2 * (N)) C16(func, 2 * (N) + 1) -#define C64(func, N) C32(func, 2 * (N)) C32(func, 2 * (N) + 1) - -// Macros for 8, 16, 32 and 64 cases. Used to collapse the switch statement. -// Expands into 0-7, 0-15, 0-31 or 0-63 cases respectively. -#define CASES8(func) C8(func, 0) -#define CASES16(func) C16(func, 0) -#define CASES32(func) C32(func, 0) -#define CASES64(func) C64(func, 0) - -// Main macro. Expands into a switch statement that chooses the implementation based on the -// dimension's remainder. -// @params: -// out: The output variable that will be set to the chosen implementation. -// dim: The dimension. -// func: The templated function that we want to choose the implementation for. -// chunk: The chunk size. Can be 64, 32, 16 or 8. Should be the number of elements of the expected -// type fitting in the expected register size. -#define CHOOSE_IMPLEMENTATION(out, dim, chunk, func) \ - do { \ - decltype(out) __ret_dist_func; \ - switch ((dim) % (chunk)) { CASES##chunk(func) } \ - out = __ret_dist_func; \ - } while (0) - -#define SVE_CASE(base_func, N) \ - case (N): \ - if (partial_chunk) \ - __ret_dist_func = base_func<true, N>; \ - else \ - __ret_dist_func = base_func<false, N>; \ - break - -#define CHOOSE_SVE_IMPLEMENTATION(out, base_func, dim, chunk_getter) \ - do { \ - decltype(out) __ret_dist_func; \ - size_t chunk = chunk_getter(); \ - bool partial_chunk = dim % chunk; \ - /* Assuming `base_func` has its main loop for 4 steps */ \ - unsigned char additional_steps = (dim / chunk) % 4; \ - switch (additional_steps) { \ - SVE_CASE(base_func, 0); \ - SVE_CASE(base_func, 1); \ - SVE_CASE(base_func, 2); \ - SVE_CASE(base_func, 3); \ - } \ - out = __ret_dist_func; \ - } while (0) diff --git a/src/VecSim/spaces/functions/implementation_chooser_cleanup.h b/src/VecSim/spaces/functions/implementation_chooser_cleanup.h deleted file mode 100644 index f64a5f464..000000000 --- a/src/VecSim/spaces/functions/implementation_chooser_cleanup.h +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. 
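To make the macro ladder above concrete, here is a hand-expanded equivalent of one chooser call. `FP32_L2SqrSIMD16` is a stand-in name for any kernel templated on the dimension remainder, and only three of the sixteen cases are spelled out:

```cpp
#include <cstddef>

// Stand-in for a distance kernel templated on the dimension remainder.
template <unsigned char residual>
float FP32_L2SqrSIMD16(const void *, const void *, size_t) { return 0.0f; }

using dist_func = float (*)(const void *, const void *, size_t);

// Hand-expanded equivalent of CHOOSE_IMPLEMENTATION(ret, dim, 16, FP32_L2SqrSIMD16):
dist_func choose_expanded(size_t dim) {
    dist_func ret;
    switch (dim % 16) {
    case 0: ret = FP32_L2SqrSIMD16<0>; break;
    case 1: ret = FP32_L2SqrSIMD16<1>; break;
    /* ... cases 2..14, generated by the C1/C2/.../C16 ladder ... */
    case 15: ret = FP32_L2SqrSIMD16<15>; break;
    default: ret = nullptr; break; // unreachable once all 16 cases are spelled out
    }
    return ret;
}
```

The payoff is that the tail handling (`dim % chunk` leftover elements) is resolved at compile time per specialization instead of being a runtime branch inside the hot loop.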
 - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -/* - * Include this file after you are done using the implementation chooser. - */ - -#undef C1 -#undef C2 -#undef C4 -#undef C8 -#undef C16 -#undef C32 -#undef C64 - -#undef CASES8 -#undef CASES16 -#undef CASES32 -#undef CASES64 - -#undef SVE_CASE - -#undef CHOOSE_IMPLEMENTATION -#undef CHOOSE_SVE_IMPLEMENTATION diff --git a/src/VecSim/spaces/normalize/compute_norm.h b/src/VecSim/spaces/normalize/compute_norm.h deleted file mode 100644 index 2fc2550ac..000000000 --- a/src/VecSim/spaces/normalize/compute_norm.h +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include <cmath> -#include <type_traits> - -namespace spaces { - -template <typename DataType> -static inline float IntegralType_ComputeNorm(const DataType *vec, const size_t dim) { - static_assert(std::is_integral_v<DataType>, "DataType must be an integral type"); - - int sum = 0; - - for (size_t i = 0; i < dim; i++) { - // No need to cast to int because C++ integer promotion ensures vec[i] is promoted to int - // before multiplication. - sum += vec[i] * vec[i]; - } - return sqrt(sum); -} - -} // namespace spaces diff --git a/src/VecSim/spaces/normalize/normalize_naive.h b/src/VecSim/spaces/normalize/normalize_naive.h deleted file mode 100644 index 85bdc88c1..000000000 --- a/src/VecSim/spaces/normalize/normalize_naive.h +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/types/bfloat16.h" -#include "VecSim/types/float16.h" -#include "compute_norm.h" -#include <cmath> -#include <vector> - -using bfloat16 = vecsim_types::bfloat16; -using float16 = vecsim_types::float16; - -namespace spaces { - -template <typename DataType> -static inline void normalizeVector_imp(void *vec, const size_t dim) { - DataType *input_vector = (DataType *)vec; - // Cast to double to avoid float overflow. 
- double sum = 0; - - for (size_t i = 0; i < dim; i++) { - sum += (double)input_vector[i] * (double)input_vector[i]; - } - DataType norm = sqrt(sum); - - for (size_t i = 0; i < dim; i++) { - input_vector[i] = input_vector[i] / norm; - } -} - -template -static inline void bfloat16_normalizeVector(void *vec, const size_t dim) { - bfloat16 *input_vector = (bfloat16 *)vec; - - std::vector f32_tmp(dim); - - float sum = 0; - - for (size_t i = 0; i < dim; i++) { - float val = vecsim_types::bfloat16_to_float32(input_vector[i]); - f32_tmp[i] = val; - sum += val * val; - } - - float norm = sqrt(sum); - - for (size_t i = 0; i < dim; i++) { - input_vector[i] = vecsim_types::float_to_bf16(f32_tmp[i] / norm); - } -} - -static inline void float16_normalizeVector(void *vec, const size_t dim) { - float16 *input_vector = (float16 *)vec; - - std::vector f32_tmp(dim); - - float sum = 0; - - for (size_t i = 0; i < dim; i++) { - float val = vecsim_types::FP16_to_FP32(input_vector[i]); - f32_tmp[i] = val; - sum += val * val; - } - - float norm = sqrt(sum); - - for (size_t i = 0; i < dim; i++) { - input_vector[i] = vecsim_types::FP32_to_FP16(f32_tmp[i] / norm); - } -} - -template -static inline void integer_normalizeVector(void *vec, const size_t dim) { - DataType *input_vector = static_cast(vec); - - float norm = IntegralType_ComputeNorm(input_vector, dim); - - // Store norm at the end of the vector. - *reinterpret_cast(input_vector + dim) = norm; -} - -} // namespace spaces diff --git a/src/VecSim/spaces/space_includes.h b/src/VecSim/spaces/space_includes.h deleted file mode 100644 index 8860fb966..000000000 --- a/src/VecSim/spaces/space_includes.h +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/utils/alignment.h" - -#include "cpu_features_macros.h" -#ifdef CPU_FEATURES_ARCH_X86_64 -#include "cpuinfo_x86.h" -#endif // CPU_FEATURES_ARCH_X86_64 -#ifdef CPU_FEATURES_ARCH_AARCH64 -#include "cpuinfo_aarch64.h" -#endif // CPU_FEATURES_ARCH_AARCH64 - -#if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE__) -#if defined(__GNUC__) -#include -// Override missing implementations in GCC < 11 -// Full list and suggested alternatives for each missing function can be found here: -// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=95483 -#if (__GNUC__ < 11) -#define _mm256_loadu_epi8(ptr) _mm256_maskz_loadu_epi8(~0, ptr) -#define _mm512_loadu_epi8(ptr) _mm512_maskz_loadu_epi8(~0, ptr) -#endif -#elif defined(__clang__) -#include -#elif defined(_MSC_VER) -#include -#include -#endif - -#endif // __AVX512F__ || __AVX__ || __SSE__ diff --git a/src/VecSim/spaces/spaces.cpp b/src/VecSim/spaces/spaces.cpp deleted file mode 100644 index baf5c886f..000000000 --- a/src/VecSim/spaces/spaces.cpp +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
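Note that `integer_normalizeVector` above does not rescale the int8/uint8 payload; it computes the L2 norm once and appends it after the elements, and `VecSimParams_GetStoredDataSize` later in this diff reserves the extra `sizeof(float)` for exactly this. A scalar sketch of how a cosine kernel can then use the stored norms without rescanning either vector (illustrative, not the deleted SIMD code):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Cosine distance over int8 vectors with a float L2 norm stored after the payload.
static float INT8_Cosine_Ref(const void *pa, const void *pb, size_t dim) {
    const int8_t *a = static_cast<const int8_t *>(pa);
    const int8_t *b = static_cast<const int8_t *>(pb);
    int ip = 0;
    for (size_t i = 0; i < dim; i++)
        ip += int(a[i]) * int(b[i]); // integer promotion avoids int8 overflow
    float norm_a, norm_b; // precomputed norms live right after the dim elements
    std::memcpy(&norm_a, a + dim, sizeof(float));
    std::memcpy(&norm_b, b + dim, sizeof(float));
    return 1.0f - float(ip) / (norm_a * norm_b);
}
```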
- */ -#include "VecSim/types/bfloat16.h" -#include "VecSim/types/float16.h" -#include "VecSim/types/sq8.h" -#include "VecSim/spaces/space_includes.h" -#include "VecSim/spaces/spaces.h" -#include "VecSim/spaces/IP_space.h" -#include "VecSim/spaces/L2_space.h" -#include "VecSim/spaces/normalize/normalize_naive.h" - -#include -namespace spaces { - -// Set the distance function for a given data type, metric and dimension. The alignment hint is -// determined according to the chosen implementation and available optimizations. - -template <> -dist_func_t GetDistFunc(VecSimMetric metric, size_t dim, - unsigned char *alignment) { - switch (metric) { - case VecSimMetric_Cosine: - case VecSimMetric_IP: - return IP_BF16_GetDistFunc(dim, alignment); - case VecSimMetric_L2: - return L2_BF16_GetDistFunc(dim, alignment); - } - throw std::invalid_argument("Invalid metric"); -} - -template <> -dist_func_t GetDistFunc(VecSimMetric metric, size_t dim, - unsigned char *alignment) { - switch (metric) { - case VecSimMetric_Cosine: - case VecSimMetric_IP: - return IP_FP16_GetDistFunc(dim, alignment); - case VecSimMetric_L2: - return L2_FP16_GetDistFunc(dim, alignment); - } - throw std::invalid_argument("Invalid metric"); -} - -template <> -dist_func_t GetDistFunc(VecSimMetric metric, size_t dim, - unsigned char *alignment) { - switch (metric) { - case VecSimMetric_Cosine: - case VecSimMetric_IP: - return IP_FP32_GetDistFunc(dim, alignment); - case VecSimMetric_L2: - return L2_FP32_GetDistFunc(dim, alignment); - } - throw std::invalid_argument("Invalid metric"); -} - -template <> -dist_func_t GetDistFunc(VecSimMetric metric, size_t dim, - unsigned char *alignment) { - switch (metric) { - case VecSimMetric_Cosine: - case VecSimMetric_IP: - return IP_FP64_GetDistFunc(dim, alignment); - case VecSimMetric_L2: - return L2_FP64_GetDistFunc(dim, alignment); - } - throw std::invalid_argument("Invalid metric"); -} - -template <> -dist_func_t GetDistFunc(VecSimMetric metric, size_t dim, - unsigned char *alignment) { - switch (metric) { - case VecSimMetric_Cosine: - return Cosine_INT8_GetDistFunc(dim, alignment); - case VecSimMetric_IP: - return IP_INT8_GetDistFunc(dim, alignment); - case VecSimMetric_L2: - return L2_INT8_GetDistFunc(dim, alignment); - } - throw std::invalid_argument("Invalid metric"); -} - -template <> -dist_func_t GetDistFunc(VecSimMetric metric, size_t dim, - unsigned char *alignment) { - switch (metric) { - case VecSimMetric_Cosine: - return Cosine_UINT8_GetDistFunc(dim, alignment); - case VecSimMetric_IP: - return IP_UINT8_GetDistFunc(dim, alignment); - case VecSimMetric_L2: - return L2_UINT8_GetDistFunc(dim, alignment); - } - throw std::invalid_argument("Invalid metric"); -} - -template <> -dist_func_t GetDistFunc(VecSimMetric metric, size_t dim, - unsigned char *alignment) { - switch (metric) { - case VecSimMetric_Cosine: - return Cosine_SQ8_SQ8_GetDistFunc(dim, alignment); - case VecSimMetric_IP: - return IP_SQ8_SQ8_GetDistFunc(dim, alignment); - case VecSimMetric_L2: - return L2_SQ8_SQ8_GetDistFunc(dim, alignment); - } - throw std::invalid_argument("Invalid metric"); -} - -template <> -dist_func_t GetDistFunc(VecSimMetric metric, size_t dim, - unsigned char *alignment) { - switch (metric) { - case VecSimMetric_Cosine: - return Cosine_SQ8_FP32_GetDistFunc(dim, alignment); - case VecSimMetric_IP: - return IP_SQ8_FP32_GetDistFunc(dim, alignment); - case VecSimMetric_L2: - return L2_SQ8_FP32_GetDistFunc(dim, alignment); - } - throw std::invalid_argument("Invalid metric"); -} - -template <> 
-normalizeVector_f GetNormalizeFunc(void) { - return normalizeVector_imp; -} - -template <> -normalizeVector_f GetNormalizeFunc(void) { - return normalizeVector_imp; -} - -template <> -normalizeVector_f GetNormalizeFunc(void) { - if (is_little_endian()) { - return bfloat16_normalizeVector; - } else { - return bfloat16_normalizeVector; - } -} - -template <> -normalizeVector_f GetNormalizeFunc(void) { - return float16_normalizeVector; -} - -/** The returned function computes the norm and stores it at the end of the given vector */ -template <> -normalizeVector_f GetNormalizeFunc(void) { - return integer_normalizeVector; -} -template <> -normalizeVector_f GetNormalizeFunc(void) { - return integer_normalizeVector; -} - -} // namespace spaces diff --git a/src/VecSim/spaces/spaces.h b/src/VecSim/spaces/spaces.h deleted file mode 100644 index 982d3f749..000000000 --- a/src/VecSim/spaces/spaces.h +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/vec_sim_common.h" // enum VecSimMetric -#include "space_includes.h" - -namespace spaces { - -template -using dist_func_t = RET_TYPE (*)(const void *, const void *, size_t); - -// Get the distance function for comparing vectors of type VecType1 and VecType2, for a given metric -// and dimension. The returned function has the signature: dist(VecType1*, VecType2*, size_t) -> -// DistType. VecType2 defaults to VecType1 when both vectors are of the same type. The alignment -// hint is set based on the chosen implementation and available optimizations. -template -dist_func_t GetDistFunc(VecSimMetric metric, size_t dim, unsigned char *alignment); - -template -using normalizeVector_f = void (*)(void *input_vector, const size_t dim); - -template -normalizeVector_f GetNormalizeFunc(); - -static int inline is_little_endian() { - unsigned int x = 1; - return *(char *)&x; -} - -static inline auto getCpuOptimizationFeatures(const void *arch_opt = nullptr) { - -#if defined(CPU_FEATURES_ARCH_AARCH64) - using FeaturesType = cpu_features::Aarch64Features; - constexpr auto getFeatures = cpu_features::GetAarch64Info; -#else - using FeaturesType = cpu_features::X86Features; // Fallback - constexpr auto getFeatures = cpu_features::GetX86Info; -#endif - return arch_opt ? *static_cast(arch_opt) : getFeatures().features; -} - -} // namespace spaces diff --git a/src/VecSim/tombstone_interface.h b/src/VecSim/tombstone_interface.h deleted file mode 100644 index 0d72cf121..000000000 --- a/src/VecSim/tombstone_interface.h +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once -#include -#include "vec_sim_common.h" - -/* - * Defines a simple tombstone API for indexes. - * Every index that has to implement "marking as deleted" mechanism should inherit this API and - * implement the required functions. The implementation should also update the `numMarkedDeleted` - * property to hold the number of vectors marked as deleted. 
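A sketch of the typical call-site flow through the two factories above, as the code stood before this removal. The template-argument spellings are reconstructed from the declared pattern (the flattened diff drops the angle-bracketed text), so treat `GetDistFunc<float, float>` and `GetNormalizeFunc<float>` as assumptions:

```cpp
#include <cstddef>
#include "VecSim/spaces/spaces.h"

// Cosine query over float32 vectors (illustrative sketch).
void query_example(void *stored_blob, void *query_blob, size_t dim) {
    unsigned char alignment = 0;
    auto dist = spaces::GetDistFunc<float, float>(VecSimMetric_Cosine, dim, &alignment);
    auto normalize = spaces::GetNormalizeFunc<float>();
    normalize(query_blob, dim); // cosine is served by IP kernels on normalized input
    float d = dist(stored_blob, query_blob, dim);
    (void)d;
}
```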
- */ -struct VecSimIndexTombstone { -protected: - size_t numMarkedDeleted; - -public: - VecSimIndexTombstone() : numMarkedDeleted(0) {} - ~VecSimIndexTombstone() = default; - - inline size_t getNumMarkedDeleted() const { return numMarkedDeleted; } - - /** - * @param label vector to mark as deleted - * @return a vector of internal ids that has been marked as deleted (to be disposed later on). - */ - virtual inline vecsim_stl::vector markDelete(labelType label) = 0; -}; diff --git a/src/VecSim/types/bfloat16.h b/src/VecSim/types/bfloat16.h deleted file mode 100644 index 9e1430f96..000000000 --- a/src/VecSim/types/bfloat16.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include -#include -#include - -namespace vecsim_types { -struct bfloat16 { - uint16_t val; - bfloat16() = default; - explicit constexpr bfloat16(uint16_t val) : val(val) {} - operator uint16_t() const { return val; } -}; - -static inline bfloat16 float_to_bf16(const float ff) { - uint32_t *p_f32 = (uint32_t *)&ff; - uint32_t f32 = *p_f32; - uint32_t lsb = (f32 >> 16) & 1; - uint32_t round = lsb + 0x7FFF; - f32 += round; - return bfloat16(f32 >> 16); -} - -template -inline float bfloat16_to_float32(bfloat16 val) { - size_t constexpr bytes_offset = is_little ? 1 : 0; - float result = 0; - bfloat16 *p_result = (bfloat16 *)&result + bytes_offset; - *p_result = val; - return result; -} - -} // namespace vecsim_types diff --git a/src/VecSim/types/float16.h b/src/VecSim/types/float16.h deleted file mode 100644 index fef2fa0b3..000000000 --- a/src/VecSim/types/float16.h +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include -#include -#include -namespace vecsim_types { -struct float16 { - uint16_t val; - float16() = default; - explicit constexpr float16(uint16_t val) : val(val) {} - operator uint16_t() const { return val; } -}; - -inline float _interpret_as_float(uint32_t num) { - void *num_ptr = # - return *(float *)num_ptr; -} - -inline int32_t _interpret_as_int(float num) { - void *num_ptr = # - return *(int32_t *)num_ptr; -} - -static inline float FP16_to_FP32(float16 input) { - // https://gist.github.com/2144712 - // Fabian "ryg" Giesen. - - const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift - - int32_t o = ((int32_t)(input & 0x7fffu)) << 13; // exponent/mantissa bits - int32_t exp = shifted_exp & o; // just the exponent - o += (int32_t)(127 - 15) << 23; // exponent adjust - - int32_t infnan_val = o + ((int32_t)(128 - 16) << 23); - int32_t zerodenorm_val = - _interpret_as_int(_interpret_as_float(o + (1u << 23)) - _interpret_as_float(113u << 23)); - int32_t reg_val = (exp == 0) ? zerodenorm_val : o; - - int32_t sign_bit = ((int32_t)(input & 0x8000u)) << 16; - return _interpret_as_float(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit); -} - -static inline float16 FP32_to_FP16(float input) { - // via Fabian "ryg" Giesen. 
- // https://gist.github.com/2156668 - uint32_t sign_mask = 0x80000000u; - int32_t o; - - uint32_t fint = _interpret_as_int(input); - uint32_t sign = fint & sign_mask; - fint ^= sign; - - // NOTE all the integer compares in this function can be safely - // compiled into signed compares since all operands are below - // 0x80000000. Important if you want fast straight SSE2 code (since - // there's no unsigned PCMPGTD). - - // Inf or NaN (all exponent bits set) - // NaN->qNaN and Inf->Inf - // unconditional assignment here, will override with right value for - // the regular case below. - uint32_t f32infty = 255u << 23; - o = (fint > f32infty) ? 0x7e00u : 0x7c00u; - - // (De)normalized number or zero - // update fint unconditionally to save the blending; we don't need it - // anymore for the Inf/NaN case anyway. - - const uint32_t round_mask = ~0xfffu; - const uint32_t magic = 15u << 23; - - // Shift exponent down, denormalize if necessary. - // NOTE This represents half-float denormals using single - // precision denormals. The main reason to do this is that - // there's no shift with per-lane variable shifts in SSE*, which - // we'd otherwise need. It has some funky side effects though: - // - This conversion will actually respect the FTZ (Flush To Zero) - // flag in MXCSR - if it's set, no half-float denormals will be - // generated. I'm honestly not sure whether this is good or - // bad. It's definitely interesting. - // - If the underlying HW doesn't support denormals (not an issue - // with Intel CPUs, but might be a problem on GPUs or PS3 SPUs), - // you will always get flush-to-zero behavior. This is bad, - // unless you're on a CPU where you don't care. - // - Denormals tend to be slow. FP32 denormals are rare in - // practice outside of things like recursive filters in DSP - - // not a typical half-float application. Whether FP16 denormals - // are rare in practice, I don't know. Whatever slow path your - // HW may or may not have for denormals, this may well hit it. - float fscale = _interpret_as_float(fint & round_mask) * _interpret_as_float(magic); - fscale = std::min(fscale, _interpret_as_float((31u << 23) - 0x1000u)); - int32_t fint2 = _interpret_as_int(fscale) - round_mask; - - if (fint < f32infty) - o = fint2 >> 13; // Take the bits! - - return float16(o | (sign >> 16)); -} - -} // namespace vecsim_types diff --git a/src/VecSim/types/sq8.h b/src/VecSim/types/sq8.h deleted file mode 100644 index dfe296321..000000000 --- a/src/VecSim/types/sq8.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include -#include -#include "VecSim/vec_sim_common.h" - -namespace vecsim_types { - -// Represents a scalar-quantized 8-bit blob with reconstruction metadata -struct sq8 { - using value_type = uint8_t; - - // Metadata layout indices (stored after quantized values) - enum MetadataIndex : size_t { - MIN_VAL = 0, - DELTA = 1, - SUM = 2, - SUM_SQUARES = 3 // Only for L2 - }; - - enum QueryMetadataIndex : size_t { - SUM_QUERY = 0, - SUM_SQUARES_QUERY = 1 // Only for L2 - }; - - // Template on metric — compile-time constant when metric is known - template - static constexpr size_t storage_metadata_count() { - return (Metric == VecSimMetric_L2) ? 
4 : 3; - } - - template - static constexpr size_t query_metadata_count() { - return (Metric == VecSimMetric_L2) ? 2 : 1; - } -}; - -} // namespace vecsim_types diff --git a/src/VecSim/utils/alignment.h b/src/VecSim/utils/alignment.h deleted file mode 100644 index e26f1de22..000000000 --- a/src/VecSim/utils/alignment.h +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#if defined(__GNUC__) || defined(__clang__) -#define PORTABLE_ALIGN16 __attribute__((aligned(16))) -#define PORTABLE_ALIGN32 __attribute__((aligned(32))) -#define PORTABLE_ALIGN64 __attribute__((aligned(64))) -#elif defined(_MSC_VER) -#define PORTABLE_ALIGN16 __declspec(align(16)) -#define PORTABLE_ALIGN32 __declspec(align(32)) -#define PORTABLE_ALIGN64 __declspec(align(64)) -#endif - -#define PORTABLE_ALIGN PORTABLE_ALIGN64 - -// TODO: relax the above alignment requirements according to the CPU architecture -#ifndef PORTABLE_ALIGN -#if defined(__AVX512F__) -#define PORTABLE_ALIGN PORTABLE_ALIGN64 -#elif defined(__AVX__) -#define PORTABLE_ALIGN PORTABLE_ALIGN32 -#elif defined(__SSE__) -#define PORTABLE_ALIGN PORTABLE_ALIGN16 -#else -#define PORTABLE_ALIGN -#endif -#endif // ifndef PORTABLE_ALIGN diff --git a/src/VecSim/utils/query_result_utils.h b/src/VecSim/utils/query_result_utils.h deleted file mode 100644 index 5bb8f35c9..000000000 --- a/src/VecSim/utils/query_result_utils.h +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/query_result_definitions.h" -#include - -#define VECSIM_EPSILON (1e-6) - -inline bool double_eq(double a, double b) { return fabs(a - b) < VECSIM_EPSILON; } - -// Compare two results by score, and if the scores are equal, by id. -inline int cmpVecSimQueryResultByScoreThenId(const VecSimQueryResultContainer::iterator res1, - const VecSimQueryResultContainer::iterator res2) { - return !double_eq(res1->score, res2->score) ? (res1->score > res2->score ? 1 : -1) - : (int)(res1->id - res2->id); -} - -// Append the current result to the merged results, after verifying that it did not added yet (if -// verification is needed). Also update the set, limit and the current result. -template -inline constexpr void maybe_append(VecSimQueryResultContainer &results, - VecSimQueryResultContainer::iterator &cur_res, - std::unordered_set &ids, size_t &limit) { - // In a single line, checks (only if a check is needed) if we already inserted the current id to - // the merged results, add it to the set if not, and returns its conclusion. - if (!withSet || ids.insert(cur_res->id).second) { - results.push_back(*cur_res); - limit--; - } - cur_res++; -} - -// Assumes that the arrays are sorted by score firstly and by id secondarily. -// By the end of the function, the first and second referenced pointers will point to the first -// element that was not merged (in each array), or to the end of the array if it was merged -// completely. 
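Before the `merge_results` template that follows, a minimal standalone illustration of its two-pointer contract: both inputs are sorted ascending by (score, id), and an exactly equal (score, id) pair is treated as the same result and emitted once, which is what lets the `withSet=false` fast path skip the hash set. Names here are illustrative, not the library types:

```cpp
#include <cstddef>
#include <vector>

struct Res { size_t id; double score; };

static std::vector<Res> merge_sorted(const std::vector<Res> &a, const std::vector<Res> &b,
                                     size_t limit) {
    std::vector<Res> out;
    size_t i = 0, j = 0;
    while (out.size() < limit && i < a.size() && j < b.size()) {
        if (a[i].score == b[j].score && a[i].id == b[j].id) {
            out.push_back(a[i]); // shared result: keep one copy, advance both
            i++; j++;
        } else if (a[i].score < b[j].score ||
                   (a[i].score == b[j].score && a[i].id < b[j].id)) {
            out.push_back(a[i++]);
        } else {
            out.push_back(b[j++]);
        }
    }
    while (out.size() < limit && i < a.size()) out.push_back(a[i++]);
    while (out.size() < limit && j < b.size()) out.push_back(b[j++]);
    return out;
}
```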
-template -std::pair merge_results(VecSimQueryResultContainer &results, - VecSimQueryResultContainer &first, - VecSimQueryResultContainer &second, size_t limit) { - // Allocate the merged results array with the minimum size needed. - // Min of the limit and the sum of the lengths of the two arrays. - results.reserve(std::min(limit, first.size() + second.size())); - // Will hold the ids of the results we've already added to the merged results. - // Will be used only if withSet is true. - std::unordered_set ids; - auto cur_first = first.begin(); - auto cur_second = second.begin(); - - while (limit && cur_first != first.end() && cur_second != second.end()) { - int cmp = cmpVecSimQueryResultByScoreThenId(cur_first, cur_second); - if (cmp > 0) { - maybe_append(results, cur_second, ids, limit); - } else if (cmp < 0) { - maybe_append(results, cur_first, ids, limit); - } else { - // Even if `withSet` is true, we encountered an exact duplicate, so we know that this id - // didn't appear before in both arrays, and it won't appear again in both arrays, so we - // can add it to the merged results, and skip adding it to the set. - results.push_back(*cur_first); - cur_first++; - cur_second++; - limit--; - } - } - - // If we didn't exit the loop because of the limit, at least one of the arrays is empty. - // We can try appending the rest of the other array. - if (limit != 0) { - if (cur_first == first.end()) { - while (limit && cur_second != second.end()) { - maybe_append(results, cur_second, ids, limit); - } - } else { - while (limit && cur_first != first.end()) { - maybe_append(results, cur_first, ids, limit); - } - } - } - - // Return the number of elements that were merged from each array. - return {cur_first - first.begin(), cur_second - second.begin()}; -} - -// Assumes that the arrays are sorted by score firstly and by id secondarily. -// Use withSet=false if you can guarantee that shared ids between the two lists -// will also have identical scores. In this case, any duplicates will naturally align -// at the front of both lists during the merge, so they can be removed without explicitly -// tracking seen ids — enabling a more efficient merge. -template -VecSimQueryReply *merge_result_lists(VecSimQueryReply *first, VecSimQueryReply *second, - size_t limit) { - - auto mergedResults = new VecSimQueryReply(first->results.getAllocator()); - merge_results(mergedResults->results, first->results, second->results, limit); - - VecSimQueryReply_Free(first); - VecSimQueryReply_Free(second); - return mergedResults; -} - -// Concatenate the results of two queries into the results of the first query, consuming the second. -static inline void concat_results(VecSimQueryReply *first, VecSimQueryReply *second) { - first->results.insert(first->results.end(), second->results.begin(), second->results.end()); - VecSimQueryReply_Free(second); -} - -// Sorts the results by id and removes duplicates. -// Assumes that a result can appear at most twice in the results list. -// @returns the number of unique results. 
This should be set to be the new length of the results -template -void filter_results_by_id(VecSimQueryReply *results) { - if (VecSimQueryReply_Len(results) < 2) { - return; - } - sort_results_by_id(results); - - size_t i, cur_end; - for (i = 0, cur_end = 0; i < VecSimQueryReply_Len(results) - 1; i++, cur_end++) { - const VecSimQueryResult *cur_res = results->results.data() + i; - const VecSimQueryResult *next_res = cur_res + 1; - if (VecSimQueryResult_GetId(cur_res) == VecSimQueryResult_GetId(next_res)) { - if (IsMulti) { - // On multi value index, scores might be different and we want to keep the lower - // score. - if (VecSimQueryResult_GetScore(cur_res) < VecSimQueryResult_GetScore(next_res)) { - results->results[cur_end] = *cur_res; - } else { - results->results[cur_end] = *next_res; - } - } else { - // On single value index, scores are the same so we can keep any of the results. - results->results[cur_end] = *cur_res; - } - // Assuming every id can appear at most twice, we can skip the next comparison between - // the current and the next result. - i++; - } else { - results->results[cur_end] = *cur_res; - } - } - // If the last result is unique, we need to add it to the results. - if (i == VecSimQueryReply_Len(results) - 1) { - results->results[cur_end++] = results->results[i]; - // Logically, we should increment cur_end and i here, but we don't need to because it won't - // affect the rest of the function. - } - results->results.resize(cur_end); -} diff --git a/src/VecSim/utils/serializer.h b/src/VecSim/utils/serializer.h deleted file mode 100644 index 211a887b6..000000000 --- a/src/VecSim/utils/serializer.h +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once - -#include -#include - -/* - * Serializer Abstraction Layer for Vector Indexes - * ----------------------------------------------- - * This header defines the base `Serializer` class, which provides a generic interface for - * serializing vector indexes to disk. It is designed to be inherited - * by algorithm-specific serializers (e.g., HNSWSerializer, SVSSerializer), and provides a - * versioned, extensible mechanism for managing persistent representations of index state. - * Each serializer subclass must define its own EncodingVersion enum. - * How to Extend: - * 1. Derive a new class from `Serializer`, e.g., `MyIndexSerializer`. - * 2. Implement `saveIndex()` and `saveIndexIMP()`. - * 3. Implement `saveIndexFields()` to write out relevant fields in a deterministic order. - * 4. Optionally, add version-aware deserialization methods. - * - * Example Inheritance Tree: - * Serializer (abstract) - * ├── HNSWSerializer - * │ └── HNSWIndex - * └── SVSSerializer - * └── SVSIndex - */ - -class Serializer { -public: - enum class EncodingVersion { INVALID }; - - Serializer(EncodingVersion version = EncodingVersion::INVALID) : m_version(version) {} - - virtual void saveIndex(const std::string &location) = 0; - - EncodingVersion getVersion() const; - - static EncodingVersion ReadVersion(std::ifstream &input); - - // Helper functions for serializing the index. 
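A hypothetical round trip through the `writeBinaryPOD`/`readBinaryPOD` helpers that follow; the `Header` fields are invented for illustration. A subclass's `saveIndexFields` is expected to emit fields in the same fixed order the loader reads them back, and since the helpers copy raw bytes, the format is only portable across builds with the same endianness and struct field sizes:

```cpp
#include <cstddef>
#include <fstream>
#include "VecSim/utils/serializer.h"

struct Header { size_t dim; size_t count; }; // illustrative fields only

void save_header(std::ofstream &out, const Header &h) {
    Serializer::writeBinaryPOD(out, h.dim);
    Serializer::writeBinaryPOD(out, h.count);
}

Header load_header(std::ifstream &in) {
    Header h{};
    Serializer::readBinaryPOD(in, h.dim);   // must mirror the write order exactly
    Serializer::readBinaryPOD(in, h.count);
    return h;
}
```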
- template - static inline void writeBinaryPOD(std::ostream &out, const T &podRef) { - out.write((char *)&podRef, sizeof(T)); - } - - template - static inline void readBinaryPOD(std::istream &in, T &podRef) { - in.read((char *)&podRef, sizeof(T)); - } - -protected: - EncodingVersion m_version; - - // Index memory size might be changed during index saving. - virtual void saveIndexIMP(std::ofstream &output) = 0; - -private: - virtual void saveIndexFields(std::ofstream &output) const = 0; -}; diff --git a/src/VecSim/utils/updatable_heap.h b/src/VecSim/utils/updatable_heap.h deleted file mode 100644 index f79e78eb5..000000000 --- a/src/VecSim/utils/updatable_heap.h +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/memory/vecsim_base.h" -#include "vecsim_stl.h" -#include -#include - -namespace vecsim_stl { - -// This class implements updatable max heap. insertion, updating and deletion (of the max priority) -// are done in O(log(n)), finding the max priority takes O(1), as well as getting the size and if -// the heap is empty. -// The priority can only be updated DOWN, because we only care about the lowest distance score for a -// vector, and that is the use of this heap. We use it to hold the top candidates while performing -// VSS on multi-valued indexes, and we need to find and delete the worst score easily. -template -class updatable_max_heap : public abstract_priority_queue { -private: - // Maps a priority that exists in the heap to its value. - using PVmultimap = std::multimap, - VecsimSTLAllocator>>; - PVmultimap priorityToValue; - - // Maps a value in the heap to its node in the `priorityToValue` multimap. - std::unordered_map, std::equal_to, - VecsimSTLAllocator>> - valueToNode; - -public: - updatable_max_heap(const std::shared_ptr &alloc); - ~updatable_max_heap() = default; - - inline void emplace(Priority p, Value v) override; - inline bool empty() const override; - inline void pop() override; - inline const std::pair top() const override; - inline size_t size() const override; - -private: - inline auto top_ptr() const; -}; - -template -updatable_max_heap::updatable_max_heap( - const std::shared_ptr &alloc) - : abstract_priority_queue(alloc), priorityToValue(alloc), valueToNode(alloc) {} - -template -size_t updatable_max_heap::size() const { - return valueToNode.size(); -} - -template -bool updatable_max_heap::empty() const { - return valueToNode.empty(); -} - -template -auto updatable_max_heap::top_ptr() const { - // The `.begin()` of "priorityToValue" is the max priority element. - auto x = priorityToValue.begin(); - // x has the max priority, but there might be multiple values with the same priority. We need to - // find the value with the highest value as well. - auto [begin, end] = priorityToValue.equal_range(x->first); - auto y = std::max_element(begin, end, - [](const auto &a, const auto &b) { return a.second < b.second; }); - return y; -} - -template -const std::pair updatable_max_heap::top() const { - auto x = top_ptr(); - return *x; -} - -template -void updatable_max_heap::pop() { - auto to_remove = top_ptr(); - valueToNode.erase(to_remove->second); // Erase by the value of the top pair. - priorityToValue.erase(to_remove); // Erase by iterator deletes only the specific pair. 
-} - -template -void updatable_max_heap::emplace(Priority p, Value v) { - // This function either inserting a new value or updating the priority of the value, if the new - // priority is higher. - auto existing_v = valueToNode.find(v); - if (existing_v == valueToNode.end()) { - // Case 1: value is not in the heap. Insert it. - auto node = priorityToValue.emplace(p, v); - valueToNode.emplace(v, node); - } else if (existing_v->second->first > p) { - // Case 2: value is in the heap, and its new priority is higher. Update its priority. - - // Erase the old priority from the `priorityToValue` map. - // Erase by iterator deletes only the specific pair. - priorityToValue.erase(existing_v->second); - // Re-insert the updated value to the `priorityToValue` map. - auto new_node = priorityToValue.emplace(p, v); - // Update the node of the value. - existing_v->second = new_node; - } -} - -} // namespace vecsim_stl diff --git a/src/VecSim/utils/vec_utils.cpp b/src/VecSim/utils/vec_utils.cpp deleted file mode 100644 index edb3fc989..000000000 --- a/src/VecSim/utils/vec_utils.cpp +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "vec_utils.h" -#include "VecSim/query_result_definitions.h" -#include "VecSim/types/bfloat16.h" -#include "VecSim/types/float16.h" -#include -#include -#include -#include -#include - -using bfloat16 = vecsim_types::bfloat16; -using float16 = vecsim_types::float16; - -const char *VecSimCommonStrings::ALGORITHM_STRING = "ALGORITHM"; -const char *VecSimCommonStrings::FLAT_STRING = "FLAT"; -const char *VecSimCommonStrings::HNSW_STRING = "HNSW"; -const char *VecSimCommonStrings::TIERED_STRING = "TIERED"; -const char *VecSimCommonStrings::SVS_STRING = "SVS"; - -const char *VecSimCommonStrings::TYPE_STRING = "TYPE"; -const char *VecSimCommonStrings::FLOAT32_STRING = "FLOAT32"; -const char *VecSimCommonStrings::FLOAT64_STRING = "FLOAT64"; -const char *VecSimCommonStrings::BFLOAT16_STRING = "BFLOAT16"; -const char *VecSimCommonStrings::FLOAT16_STRING = "FLOAT16"; -const char *VecSimCommonStrings::INT8_STRING = "INT8"; -const char *VecSimCommonStrings::UINT8_STRING = "UINT8"; -const char *VecSimCommonStrings::INT32_STRING = "INT32"; -const char *VecSimCommonStrings::INT64_STRING = "INT64"; - -const char *VecSimCommonStrings::METRIC_STRING = "METRIC"; -const char *VecSimCommonStrings::COSINE_STRING = "COSINE"; -const char *VecSimCommonStrings::IP_STRING = "IP"; -const char *VecSimCommonStrings::L2_STRING = "L2"; - -const char *VecSimCommonStrings::DIMENSION_STRING = "DIMENSION"; -const char *VecSimCommonStrings::INDEX_SIZE_STRING = "INDEX_SIZE"; -const char *VecSimCommonStrings::INDEX_LABEL_COUNT_STRING = "INDEX_LABEL_COUNT"; -const char *VecSimCommonStrings::IS_MULTI_STRING = "IS_MULTI_VALUE"; -const char *VecSimCommonStrings::IS_DISK_STRING = "IS_DISK"; -const char *VecSimCommonStrings::MEMORY_STRING = "MEMORY"; - -const char *VecSimCommonStrings::HNSW_EF_RUNTIME_STRING = "EF_RUNTIME"; -const char *VecSimCommonStrings::HNSW_M_STRING = "M"; -const char *VecSimCommonStrings::HNSW_EF_CONSTRUCTION_STRING = "EF_CONSTRUCTION"; -const char *VecSimCommonStrings::EPSILON_STRING = "EPSILON"; -const char *VecSimCommonStrings::HNSW_MAX_LEVEL = "MAX_LEVEL"; -const char *VecSimCommonStrings::HNSW_ENTRYPOINT = "ENTRYPOINT"; -const char 
*VecSimCommonStrings::NUM_MARKED_DELETED = "NUMBER_OF_MARKED_DELETED"; - -const char *VecSimCommonStrings::SVS_SEARCH_WS_STRING = "SEARCH_WINDOW_SIZE"; -const char *VecSimCommonStrings::SVS_CONSTRUCTION_WS_STRING = "CONSTRUCTION_WINDOW_SIZE"; -const char *VecSimCommonStrings::SVS_SEARCH_BC_STRING = "SEARCH_BUFFER_CAPACITY"; -const char *VecSimCommonStrings::SVS_USE_SEARCH_HISTORY_STRING = "USE_SEARCH_HISTORY"; -const char *VecSimCommonStrings::SVS_ALPHA_STRING = "ALPHA"; -const char *VecSimCommonStrings::SVS_QUANT_BITS_STRING = "QUANT_BITS"; -const char *VecSimCommonStrings::SVS_GRAPH_MAX_DEGREE_STRING = "GRAPH_MAX_DEGREE"; -const char *VecSimCommonStrings::SVS_MAX_CANDIDATE_POOL_SIZE_STRING = "MAX_CANDIDATE_POOL_SIZE"; -const char *VecSimCommonStrings::SVS_PRUNE_TO_STRING = "PRUNE_TO"; -const char *VecSimCommonStrings::SVS_NUM_THREADS_STRING = "NUM_THREADS"; -const char *VecSimCommonStrings::SVS_LAST_RESERVED_THREADS_STRING = "LAST_RESERVED_NUM_THREADS"; -const char *VecSimCommonStrings::SVS_LEANVEC_DIM_STRING = "LEANVEC_DIMENSION"; - -const char *VecSimCommonStrings::BLOCK_SIZE_STRING = "BLOCK_SIZE"; -const char *VecSimCommonStrings::SEARCH_MODE_STRING = "LAST_SEARCH_MODE"; -const char *VecSimCommonStrings::HYBRID_POLICY_STRING = "HYBRID_POLICY"; -const char *VecSimCommonStrings::BATCH_SIZE_STRING = "BATCH_SIZE"; - -const char *VecSimCommonStrings::TIERED_MANAGEMENT_MEMORY_STRING = "MANAGEMENT_LAYER_MEMORY"; -const char *VecSimCommonStrings::TIERED_BACKGROUND_INDEXING_STRING = "BACKGROUND_INDEXING"; -const char *VecSimCommonStrings::TIERED_BUFFER_LIMIT_STRING = "TIERED_BUFFER_LIMIT"; -const char *VecSimCommonStrings::FRONTEND_INDEX_STRING = "FRONTEND_INDEX"; -const char *VecSimCommonStrings::BACKEND_INDEX_STRING = "BACKEND_INDEX"; -// Tiered HNSW specific -const char *VecSimCommonStrings::TIERED_HNSW_SWAP_JOBS_THRESHOLD_STRING = - "TIERED_HNSW_SWAP_JOBS_THRESHOLD"; -// Tiered SVS specific -const char *VecSimCommonStrings::TIERED_SVS_TRAINING_THRESHOLD_STRING = - "TIERED_SVS_TRAINING_THRESHOLD"; -const char *VecSimCommonStrings::TIERED_SVS_UPDATE_THRESHOLD_STRING = "TIERED_SVS_UPDATE_THRESHOLD"; -const char *VecSimCommonStrings::TIERED_SVS_THREADS_RESERVE_TIMEOUT_STRING = - "TIERED_SVS_THREADS_RESERVE_TIMEOUT"; - -// Log levels -const char *VecSimCommonStrings::LOG_DEBUG_STRING = "debug"; -const char *VecSimCommonStrings::LOG_VERBOSE_STRING = "verbose"; -const char *VecSimCommonStrings::LOG_NOTICE_STRING = "notice"; -const char *VecSimCommonStrings::LOG_WARNING_STRING = "warning"; - -void sort_results_by_id(VecSimQueryReply *rep) { - std::sort(rep->results.begin(), rep->results.end(), - [](const VecSimQueryResult &a, const VecSimQueryResult &b) { return a.id < b.id; }); -} - -void sort_results_by_score(VecSimQueryReply *rep) { - std::sort( - rep->results.begin(), rep->results.end(), - [](const VecSimQueryResult &a, const VecSimQueryResult &b) { return a.score < b.score; }); -} - -void sort_results_by_score_then_id(VecSimQueryReply *rep) { - std::sort(rep->results.begin(), rep->results.end(), - [](const VecSimQueryResult &a, const VecSimQueryResult &b) { - if (a.score == b.score) { - return a.id < b.id; - } - return a.score < b.score; - }); -} - -void sort_results(VecSimQueryReply *rep, VecSimQueryReply_Order order) { - switch (order) { - case BY_ID: - return sort_results_by_id(rep); - case BY_SCORE: - return sort_results_by_score(rep); - case BY_SCORE_THEN_ID: - return sort_results_by_score_then_id(rep); - } -} - -VecSimResolveCode validate_positive_integer_param(VecSimRawParam rawParam, 
long long *val) { - char *ep; // For checking that strtoll used all rawParam.valLen chars. - errno = 0; - *val = strtoll(rawParam.value, &ep, 0); - // Here we verify that val is positive and strtoll was successful. - // The last test checks that the entire rawParam.value was used. - // We catch here inputs like "3.14", "123text" and so on. - if (*val <= 0 || *val == LLONG_MAX || errno != 0 || (rawParam.value + rawParam.valLen) != ep) { - return VecSimParamResolverErr_BadValue; - } - return VecSimParamResolver_OK; -} - -VecSimResolveCode validate_positive_double_param(VecSimRawParam rawParam, double *val) { - char *ep; // For checking that strtold used all rawParam.valLen chars. - errno = 0; - *val = strtod(rawParam.value, &ep); - // Here we verify that val is positive and strtod was successful. - // The last test checks that the entire rawParam.value was used. - // We catch here inputs like "-3.14", "123text" and so on. - if (*val <= 0 || *val == DBL_MAX || errno != 0 || (rawParam.value + rawParam.valLen) != ep) { - return VecSimParamResolverErr_BadValue; - } - return VecSimParamResolver_OK; -} - -VecSimResolveCode validate_vecsim_bool_param(VecSimRawParam rawParam, VecSimOptionMode *val) { - // Here we verify that given value is strictly ON or OFF - std::string value(rawParam.value, rawParam.valLen); - std::transform(value.begin(), value.end(), value.begin(), ::toupper); - if (value == "ON") { - *val = VecSimOption_ENABLE; - } else if (value == "OFF") { - *val = VecSimOption_DISABLE; - } else if (value == "AUTO") { - *val = VecSimOption_AUTO; - } else { - return VecSimParamResolverErr_BadValue; - } - return VecSimParamResolver_OK; -} - -const char *VecSimAlgo_ToString(VecSimAlgo vecsimAlgo) { - switch (vecsimAlgo) { - case VecSimAlgo_BF: - return VecSimCommonStrings::FLAT_STRING; - case VecSimAlgo_HNSWLIB: - return VecSimCommonStrings::HNSW_STRING; - case VecSimAlgo_TIERED: - return VecSimCommonStrings::TIERED_STRING; - case VecSimAlgo_SVS: - return VecSimCommonStrings::SVS_STRING; - } - return NULL; -} -const char *VecSimType_ToString(VecSimType vecsimType) { - switch (vecsimType) { - case VecSimType_FLOAT32: - return VecSimCommonStrings::FLOAT32_STRING; - case VecSimType_FLOAT64: - return VecSimCommonStrings::FLOAT64_STRING; - case VecSimType_BFLOAT16: - return VecSimCommonStrings::BFLOAT16_STRING; - case VecSimType_FLOAT16: - return VecSimCommonStrings::FLOAT16_STRING; - case VecSimType_INT8: - return VecSimCommonStrings::INT8_STRING; - case VecSimType_UINT8: - return VecSimCommonStrings::UINT8_STRING; - case VecSimType_INT32: - return VecSimCommonStrings::INT32_STRING; - case VecSimType_INT64: - return VecSimCommonStrings::INT64_STRING; - } - return NULL; -} - -const char *VecSimMetric_ToString(VecSimMetric vecsimMetric) { - switch (vecsimMetric) { - case VecSimMetric_Cosine: - return "COSINE"; - case VecSimMetric_IP: - return "IP"; - case VecSimMetric_L2: - return "L2"; - } - return NULL; -} - -const char *VecSimSearchMode_ToString(VecSearchMode vecsimSearchMode) { - switch (vecsimSearchMode) { - case EMPTY_MODE: - return "EMPTY_MODE"; - case STANDARD_KNN: - return "STANDARD_KNN"; - case HYBRID_ADHOC_BF: - return "HYBRID_ADHOC_BF"; - case HYBRID_BATCHES: - return "HYBRID_BATCHES"; - case HYBRID_BATCHES_TO_ADHOC_BF: - return "HYBRID_BATCHES_TO_ADHOC_BF"; - case RANGE_QUERY: - return "RANGE_QUERY"; - } - return NULL; -} - -const char *VecSimQuantBits_ToString(VecSimSvsQuantBits quantBits) { - switch (quantBits) { - case VecSimSvsQuant_NONE: - return "NONE"; - case 
VecSimSvsQuant_Scalar: - return "Scalar"; - case VecSimSvsQuant_4: - return "4"; - case VecSimSvsQuant_8: - return "8"; - case VecSimSvsQuant_4x4: - return "4x4"; - case VecSimSvsQuant_4x8: - return "4x8"; - case VecSimSvsQuant_4x8_LeanVec: - return "4x8_LeanVec"; - case VecSimSvsQuant_8x8_LeanVec: - return "8x8_LeanVec"; - } - return NULL; -} - -size_t VecSimType_sizeof(VecSimType type) { - switch (type) { - case VecSimType_FLOAT32: - return sizeof(float); - case VecSimType_FLOAT64: - return sizeof(double); - case VecSimType_BFLOAT16: - return sizeof(bfloat16); - case VecSimType_FLOAT16: - return sizeof(float16); - case VecSimType_INT8: - return sizeof(int8_t); - case VecSimType_UINT8: - return sizeof(uint8_t); - case VecSimType_INT32: - return sizeof(int32_t); - case VecSimType_INT64: - return sizeof(int64_t); - } - return 0; -} - -size_t VecSimParams_GetStoredDataSize(VecSimType type, size_t dim, VecSimMetric metric) { - size_t storedDataSize = VecSimType_sizeof(type) * dim; - if (metric == VecSimMetric_Cosine && (type == VecSimType_INT8 || type == VecSimType_UINT8)) { - storedDataSize += sizeof(float); // For the norm - } - return storedDataSize; -} diff --git a/src/VecSim/utils/vec_utils.h b/src/VecSim/utils/vec_utils.h deleted file mode 100644 index 53adb62bb..000000000 --- a/src/VecSim/utils/vec_utils.h +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include -#include "VecSim/vec_sim_common.h" -#include "VecSim/types/bfloat16.h" -#include "VecSim/types/float16.h" -#include "VecSim/query_results.h" -#include "VecSim/utils/vecsim_stl.h" -#include -#include - -struct VecSimCommonStrings { -public: - static const char *ALGORITHM_STRING; - static const char *FLAT_STRING; - static const char *HNSW_STRING; - static const char *TIERED_STRING; - static const char *SVS_STRING; - - static const char *TYPE_STRING; - static const char *FLOAT32_STRING; - static const char *FLOAT64_STRING; - static const char *BFLOAT16_STRING; - static const char *FLOAT16_STRING; - static const char *INT8_STRING; - static const char *UINT8_STRING; - static const char *INT32_STRING; - static const char *INT64_STRING; - - static const char *METRIC_STRING; - static const char *COSINE_STRING; - static const char *IP_STRING; - static const char *L2_STRING; - - static const char *DIMENSION_STRING; - static const char *INDEX_SIZE_STRING; - static const char *INDEX_LABEL_COUNT_STRING; - static const char *IS_MULTI_STRING; - static const char *IS_DISK_STRING; - static const char *MEMORY_STRING; - - static const char *HNSW_EF_RUNTIME_STRING; - static const char *HNSW_EF_CONSTRUCTION_STRING; - static const char *HNSW_M_STRING; - static const char *EPSILON_STRING; - static const char *HNSW_MAX_LEVEL; - static const char *HNSW_ENTRYPOINT; - static const char *NUM_MARKED_DELETED; - // static const char *HNSW_VISITED_NODES_POOL_SIZE_STRING; - - static const char *SVS_SEARCH_WS_STRING; - static const char *SVS_CONSTRUCTION_WS_STRING; - static const char *SVS_SEARCH_BC_STRING; - static const char *SVS_USE_SEARCH_HISTORY_STRING; - static const char *SVS_ALPHA_STRING; - static const char *SVS_QUANT_BITS_STRING; - static const char *SVS_GRAPH_MAX_DEGREE_STRING; - static const char *SVS_MAX_CANDIDATE_POOL_SIZE_STRING; - static const char 
*SVS_PRUNE_TO_STRING; - static const char *SVS_NUM_THREADS_STRING; - static const char *SVS_LAST_RESERVED_THREADS_STRING; - static const char *SVS_LEANVEC_DIM_STRING; - - static const char *BLOCK_SIZE_STRING; - static const char *SEARCH_MODE_STRING; - static const char *HYBRID_POLICY_STRING; - static const char *BATCH_SIZE_STRING; - - static const char *TIERED_MANAGEMENT_MEMORY_STRING; - static const char *TIERED_BACKGROUND_INDEXING_STRING; - static const char *TIERED_BUFFER_LIMIT_STRING; - static const char *FRONTEND_INDEX_STRING; - static const char *BACKEND_INDEX_STRING; - // tiered HNSW specific - static const char *TIERED_HNSW_SWAP_JOBS_THRESHOLD_STRING; - // tiered SVS specific - static const char *TIERED_SVS_TRAINING_THRESHOLD_STRING; - static const char *TIERED_SVS_UPDATE_THRESHOLD_STRING; - static const char *TIERED_SVS_THREADS_RESERVE_TIMEOUT_STRING; - - // Log levels - static const char *LOG_DEBUG_STRING; - static const char *LOG_VERBOSE_STRING; - static const char *LOG_NOTICE_STRING; - static const char *LOG_WARNING_STRING; -}; - -void sort_results_by_id(VecSimQueryReply *results); - -void sort_results_by_score(VecSimQueryReply *results); - -void sort_results_by_score_then_id(VecSimQueryReply *results); - -void sort_results(VecSimQueryReply *results, VecSimQueryReply_Order order); - -VecSimResolveCode validate_positive_integer_param(VecSimRawParam rawParam, long long *val); - -VecSimResolveCode validate_positive_double_param(VecSimRawParam rawParam, double *val); - -VecSimResolveCode validate_vecsim_bool_param(VecSimRawParam rawParam, VecSimOptionMode *val); - -const char *VecSimAlgo_ToString(VecSimAlgo vecsimAlgo); - -const char *VecSimType_ToString(VecSimType vecsimType); - -const char *VecSimMetric_ToString(VecSimMetric vecsimMetric); - -const char *VecSimSearchMode_ToString(VecSearchMode vecsimSearchMode); - -const char *VecSimQuantBits_ToString(VecSimSvsQuantBits quantBits); - -size_t VecSimType_sizeof(VecSimType vecsimType); - -/** Returns the size in bytes of a stored or query vector */ -size_t VecSimParams_GetStoredDataSize(VecSimType type, size_t dim, VecSimMetric metric); diff --git a/src/VecSim/utils/vecsim_stl.h b/src/VecSim/utils/vecsim_stl.h deleted file mode 100644 index a55eb968c..000000000 --- a/src/VecSim/utils/vecsim_stl.h +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#include "VecSim/memory/vecsim_base.h" -#include -#include -#include -#include -#include -#include - -namespace vecsim_stl { - -template -using unordered_map = std::unordered_map, std::equal_to, - VecsimSTLAllocator>>; - -template -class vector : public VecsimBaseObject, public std::vector> { -public: - explicit vector(const std::shared_ptr &alloc) - : VecsimBaseObject(alloc), std::vector>(alloc) {} - explicit vector(size_t cap, const std::shared_ptr &alloc) - : VecsimBaseObject(alloc), std::vector>(cap, alloc) {} - explicit vector(size_t cap, T val, const std::shared_ptr &alloc) - : VecsimBaseObject(alloc), std::vector>(cap, val, alloc) {} - - bool remove(T element) { - auto it = std::find(this->begin(), this->end(), element); - if (it != this->end()) { - // Swap the last element with the current one (equivalent to removing the element from - // the list). 
-            *it = this->back();
-            this->pop_back();
-            return true;
-        }
-        return false;
-    }
-};
-
-template <typename Priority, typename Value>
-struct abstract_priority_queue : public VecsimBaseObject {
-public:
-    abstract_priority_queue(const std::shared_ptr<VecSimAllocator> &alloc)
-        : VecsimBaseObject(alloc) {}
-    ~abstract_priority_queue() = default;
-
-    virtual void emplace(Priority p, Value v) = 0;
-    virtual bool empty() const = 0;
-    virtual void pop() = 0;
-    virtual const std::pair<Priority, Value> top() const = 0;
-    virtual size_t size() const = 0;
-};
-
-// max-heap
-template <typename Priority, typename Value,
-          typename std_queue = std::priority_queue<std::pair<Priority, Value>,
-                                                   vecsim_stl::vector<std::pair<Priority, Value>>,
-                                                   std::less<std::pair<Priority, Value>>>>
-struct max_priority_queue : public abstract_priority_queue<Priority, Value>, public std_queue {
-public:
-    max_priority_queue(const std::shared_ptr<VecSimAllocator> &alloc)
-        : abstract_priority_queue<Priority, Value>(alloc), std_queue(alloc) {}
-    ~max_priority_queue() = default;
-
-    void emplace(Priority p, Value v) override { std_queue::emplace(p, v); }
-    bool empty() const override { return std_queue::empty(); }
-    void pop() override { std_queue::pop(); }
-    const std::pair<Priority, Value> top() const override { return std_queue::top(); }
-    size_t size() const override { return std_queue::size(); }
-
-    // Random order iteration
-    const auto begin() const { return this->c.begin(); }
-    const auto end() const { return this->c.end(); }
-};
-
-// min-heap
-template <typename Priority, typename Value>
-using min_priority_queue =
-    std::priority_queue<std::pair<Priority, Value>, vecsim_stl::vector<std::pair<Priority, Value>>,
-                        std::greater<std::pair<Priority, Value>>>;
-
-template <typename T>
-class set : public VecsimBaseObject, public std::set<T, std::less<T>, VecsimSTLAllocator<T>> {
-public:
-    explicit set(const std::shared_ptr<VecSimAllocator> &alloc)
-        : VecsimBaseObject(alloc), std::set<T, std::less<T>, VecsimSTLAllocator<T>>(alloc) {}
-};
-
-template <typename T>
-class unordered_set
-    : public VecsimBaseObject,
-      public std::unordered_set<T, std::hash<T>, std::equal_to<T>, VecsimSTLAllocator<T>> {
-public:
-    explicit unordered_set(const std::shared_ptr<VecSimAllocator> &alloc)
-        : VecsimBaseObject(alloc),
-          std::unordered_set<T, std::hash<T>, std::equal_to<T>, VecsimSTLAllocator<T>>(alloc) {}
-    explicit unordered_set(size_t n_bucket, const std::shared_ptr<VecSimAllocator> &alloc)
-        : VecsimBaseObject(alloc),
-          std::unordered_set<T, std::hash<T>, std::equal_to<T>, VecsimSTLAllocator<T>>(n_bucket,
-                                                                                       alloc) {}
-};
-
-} // namespace vecsim_stl
diff --git a/src/VecSim/vec_sim.cpp b/src/VecSim/vec_sim.cpp
deleted file mode 100644
index 1cc8ea8b0..000000000
--- a/src/VecSim/vec_sim.cpp
+++ /dev/null
@@ -1,354 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#include "VecSim/vec_sim.h"
-#include "VecSim/query_results.h"
-#include "VecSim/query_result_definitions.h"
-#include "VecSim/utils/vec_utils.h"
-#include "VecSim/spaces/spaces.h"
-#include "VecSim/index_factories/index_factory.h"
-#include "VecSim/vec_sim_index.h"
-#include "VecSim/types/bfloat16.h"
-#include
-#include "memory.h"
-
-extern "C" void VecSim_SetTimeoutCallbackFunction(timeoutCallbackFunction callback) {
-    VecSimIndex::setTimeoutCallbackFunction(callback);
-}
-
-extern "C" void VecSim_SetLogCallbackFunction(logCallbackFunction callback) {
-    VecSimIndex::setLogCallbackFunction(callback);
-}
-
-extern "C" void VecSim_SetWriteMode(VecSimWriteMode mode) { VecSimIndex::setWriteMode(mode); }
-
-static VecSimResolveCode _ResolveParams_EFRuntime(VecSimAlgo index_type, VecSimRawParam rparam,
-                                                  VecSimQueryParams *qparams,
-                                                  VecsimQueryType query_type) {
-    long long num_val;
-    // EF_RUNTIME is a valid parameter only in HNSW algorithm.
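/* Caller-facing sketch of the flow this resolver participates in; `index` stands for an existing
   HNSW index handle (illustrative, not part of the original file):

   const char *name = VecSimCommonStrings::HNSW_EF_RUNTIME_STRING;
   VecSimRawParam rparams[1] = {{name, strlen(name), "200", 3}};
   VecSimQueryParams qparams;
   VecSimResolveCode rc = VecSimIndex_ResolveParams(index, rparams, 1, &qparams, QUERY_TYPE_KNN);
   // On success, qparams.hnswRuntimeParams.efRuntime == 200.
*/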
- if (index_type != VecSimAlgo_HNSWLIB) { - return VecSimParamResolverErr_UnknownParam; - } - // EF_RUNTIME is invalid for range query - if (query_type == QUERY_TYPE_RANGE) { - return VecSimParamResolverErr_UnknownParam; - } - if (qparams->hnswRuntimeParams.efRuntime != 0) { - return VecSimParamResolverErr_AlreadySet; - } - if (validate_positive_integer_param(rparam, &num_val) != VecSimParamResolver_OK) { - return VecSimParamResolverErr_BadValue; - } - - qparams->hnswRuntimeParams.efRuntime = (size_t)num_val; - return VecSimParamResolver_OK; -} - -static VecSimResolveCode _ResolveParams_SearchWS(VecSimAlgo index_type, VecSimRawParam rparam, - VecSimQueryParams *qparams) { - long long num_val; - // SEARCH_WS is a valid parameter only in SVS algorithm. - if (index_type != VecSimAlgo_SVS) { - return VecSimParamResolverErr_UnknownParam; - } - if (qparams->svsRuntimeParams.windowSize != 0) { - return VecSimParamResolverErr_AlreadySet; - } - if (validate_positive_integer_param(rparam, &num_val) != VecSimParamResolver_OK) { - return VecSimParamResolverErr_BadValue; - } - - qparams->svsRuntimeParams.windowSize = (size_t)num_val; - return VecSimParamResolver_OK; -} - -static VecSimResolveCode _ResolveParams_SearchBC(VecSimAlgo index_type, VecSimRawParam rparam, - VecSimQueryParams *qparams) { - long long num_val; - // SEARCH_BC is a valid parameter only in SVS algorithm. - if (index_type != VecSimAlgo_SVS) { - return VecSimParamResolverErr_UnknownParam; - } - if (qparams->svsRuntimeParams.bufferCapacity != 0) { - return VecSimParamResolverErr_AlreadySet; - } - if (validate_positive_integer_param(rparam, &num_val) != VecSimParamResolver_OK) { - return VecSimParamResolverErr_BadValue; - } - - qparams->svsRuntimeParams.bufferCapacity = (size_t)num_val; - return VecSimParamResolver_OK; -} -static VecSimResolveCode _ResolveParams_UseSearchHistory(VecSimAlgo index_type, - VecSimRawParam rparam, - VecSimQueryParams *qparams) { - VecSimOptionMode bool_val; - // USE_SEARCH_HISTORY is a valid parameter only in SVS algorithm. - if (index_type != VecSimAlgo_SVS) { - return VecSimParamResolverErr_UnknownParam; - } - if (qparams->svsRuntimeParams.searchHistory != 0) { - return VecSimParamResolverErr_AlreadySet; - } - if (validate_vecsim_bool_param(rparam, &bool_val) != VecSimParamResolver_OK) { - return VecSimParamResolverErr_BadValue; - } - - qparams->svsRuntimeParams.searchHistory = bool_val; - return VecSimParamResolver_OK; -} - -static VecSimResolveCode _ResolveParams_BatchSize(VecSimRawParam rparam, VecSimQueryParams *qparams, - VecsimQueryType query_type) { - long long num_val; - if (query_type != QUERY_TYPE_HYBRID) { - return VecSimParamResolverErr_InvalidPolicy_NHybrid; - } - if (qparams->batchSize != 0) { - return VecSimParamResolverErr_AlreadySet; - } - if (validate_positive_integer_param(rparam, &num_val) != VecSimParamResolver_OK) { - return VecSimParamResolverErr_BadValue; - } - qparams->batchSize = (size_t)num_val; - return VecSimParamResolver_OK; -} - -static VecSimResolveCode _ResolveParams_Epsilon(VecSimAlgo index_type, VecSimRawParam rparam, - VecSimQueryParams *qparams, - VecsimQueryType query_type) { - double num_val; - // EPSILON is a valid parameter only in HNSW or SVS algorithms. - if (index_type != VecSimAlgo_HNSWLIB && index_type != VecSimAlgo_SVS) { - return VecSimParamResolverErr_UnknownParam; - } - if (query_type != QUERY_TYPE_RANGE) { - return VecSimParamResolverErr_InvalidPolicy_NRange; - } - auto &epsilon_ref = index_type == VecSimAlgo_HNSWLIB ? 
qparams->hnswRuntimeParams.epsilon - : qparams->svsRuntimeParams.epsilon; - if (epsilon_ref != 0) { - return VecSimParamResolverErr_AlreadySet; - } - if (validate_positive_double_param(rparam, &num_val) != VecSimParamResolver_OK) { - return VecSimParamResolverErr_BadValue; - } - epsilon_ref = num_val; - return VecSimParamResolver_OK; -} - -static VecSimResolveCode _ResolveParams_HybridPolicy(VecSimRawParam rparam, - VecSimQueryParams *qparams, - VecsimQueryType query_type) { - if (query_type != QUERY_TYPE_HYBRID) { - return VecSimParamResolverErr_InvalidPolicy_NHybrid; - } - if (qparams->searchMode != 0) { - return VecSimParamResolverErr_AlreadySet; - } - if (!strcasecmp(rparam.value, VECSIM_POLICY_BATCHES)) { - qparams->searchMode = HYBRID_BATCHES; - } else if (!strcasecmp(rparam.value, VECSIM_POLICY_ADHOC_BF)) { - qparams->searchMode = HYBRID_ADHOC_BF; - } else if (!strcasecmp(rparam.value, VECSIM_POLICY_INVALID)) { - return VecSimParamResolverErr_InvalidPolicy_NExits; - } else { - return VecSimParamResolverErr_InvalidPolicy_NExits; - } - return VecSimParamResolver_OK; -} - -extern "C" VecSimIndex *VecSimIndex_New(const VecSimParams *params) { - return VecSimFactory::NewIndex(params); -} - -extern "C" size_t VecSimIndex_EstimateInitialSize(const VecSimParams *params) { - return VecSimFactory::EstimateInitialSize(params); -} - -extern "C" int VecSimIndex_AddVector(VecSimIndex *index, const void *blob, size_t label) { - return index->addVector(blob, label); -} - -extern "C" int VecSimIndex_DeleteVector(VecSimIndex *index, size_t label) { - return index->deleteVector(label); -} - -extern "C" double VecSimIndex_GetDistanceFrom_Unsafe(VecSimIndex *index, size_t label, - const void *blob) { - return index->getDistanceFrom_Unsafe(label, blob); -} - -extern "C" size_t VecSimIndex_EstimateElementSize(const VecSimParams *params) { - return VecSimFactory::EstimateElementSize(params); -} - -extern "C" void VecSim_Normalize(void *blob, size_t dim, VecSimType type) { - if (type == VecSimType_FLOAT32) { - spaces::GetNormalizeFunc()(blob, dim); - } else if (type == VecSimType_FLOAT64) { - spaces::GetNormalizeFunc()(blob, dim); - } else if (type == VecSimType_BFLOAT16) { - spaces::GetNormalizeFunc()(blob, dim); - } else if (type == VecSimType_FLOAT16) { - spaces::GetNormalizeFunc()(blob, dim); - } else if (type == VecSimType_INT8) { - // assuming blob is large enough to store the norm at the end of the vector - spaces::GetNormalizeFunc()(blob, dim); - } else if (type == VecSimType_UINT8) { - // assuming blob is large enough to store the norm at the end of the vector - spaces::GetNormalizeFunc()(blob, dim); - } -} - -extern "C" size_t VecSimIndex_IndexSize(VecSimIndex *index) { return index->indexSize(); } - -extern "C" VecSimResolveCode VecSimIndex_ResolveParams(VecSimIndex *index, VecSimRawParam *rparams, - int paramNum, VecSimQueryParams *qparams, - VecsimQueryType query_type) { - - if (!qparams || (!rparams && (paramNum != 0))) { - return VecSimParamResolverErr_NullParam; - } - VecSimAlgo index_type = index->basicInfo().algo; - - bzero(qparams, sizeof(VecSimQueryParams)); - auto res = VecSimParamResolver_OK; - for (int i = 0; i < paramNum; i++) { - if (!strcasecmp(rparams[i].name, VecSimCommonStrings::HNSW_EF_RUNTIME_STRING)) { - if ((res = _ResolveParams_EFRuntime(index_type, rparams[i], qparams, query_type)) != - VecSimParamResolver_OK) { - return res; - } - } else if (!strcasecmp(rparams[i].name, VecSimCommonStrings::EPSILON_STRING)) { - if ((res = _ResolveParams_Epsilon(index_type, rparams[i], 
qparams, query_type)) != - VecSimParamResolver_OK) { - return res; - } - } else if (!strcasecmp(rparams[i].name, VecSimCommonStrings::BATCH_SIZE_STRING)) { - if ((res = _ResolveParams_BatchSize(rparams[i], qparams, query_type)) != - VecSimParamResolver_OK) { - return res; - } - } else if (!strcasecmp(rparams[i].name, VecSimCommonStrings::HYBRID_POLICY_STRING)) { - if ((res = _ResolveParams_HybridPolicy(rparams[i], qparams, query_type)) != - VecSimParamResolver_OK) { - return res; - } - } else if (!strcasecmp(rparams[i].name, VecSimCommonStrings::SVS_SEARCH_WS_STRING)) { - if ((res = _ResolveParams_SearchWS(index_type, rparams[i], qparams)) != - VecSimParamResolver_OK) { - return res; - } - } else if (!strcasecmp(rparams[i].name, VecSimCommonStrings::SVS_SEARCH_BC_STRING)) { - if ((res = _ResolveParams_SearchBC(index_type, rparams[i], qparams)) != - VecSimParamResolver_OK) { - return res; - } - } else if (!strcasecmp(rparams[i].name, - VecSimCommonStrings::SVS_USE_SEARCH_HISTORY_STRING)) { - if ((res = _ResolveParams_UseSearchHistory(index_type, rparams[i], qparams)) != - VecSimParamResolver_OK) { - return res; - } - } else { - return VecSimParamResolverErr_UnknownParam; - } - } - // The combination of AD-HOC with batch_size is invalid, as there are no batches in this policy. - if (qparams->searchMode == HYBRID_ADHOC_BF && qparams->batchSize > 0) { - return VecSimParamResolverErr_InvalidPolicy_AdHoc_With_BatchSize; - } - // Also, 'ef_runtime' is meaning less in AD-HOC policy, since it doesn't involve search in HNSW - // graph. - if (qparams->searchMode == HYBRID_ADHOC_BF && index_type == VecSimAlgo_HNSWLIB && - qparams->hnswRuntimeParams.efRuntime > 0) { - return VecSimParamResolverErr_InvalidPolicy_AdHoc_With_EfRuntime; - } - if (qparams->searchMode != 0) { - index->setLastSearchMode(qparams->searchMode); - } - return res; -} - -extern "C" VecSimQueryReply *VecSimIndex_TopKQuery(VecSimIndex *index, const void *queryBlob, - size_t k, VecSimQueryParams *queryParams, - VecSimQueryReply_Order order) { - assert((order == BY_ID || order == BY_SCORE) && - "Possible order values are only 'BY_ID' or 'BY_SCORE'"); - VecSimQueryReply *results; - results = index->topKQuery(queryBlob, k, queryParams); - - if (order == BY_ID) { - sort_results_by_id(results); - } - return results; -} - -extern "C" VecSimQueryReply *VecSimIndex_RangeQuery(VecSimIndex *index, const void *queryBlob, - double radius, VecSimQueryParams *queryParams, - VecSimQueryReply_Order order) { - if (order != BY_ID && order != BY_SCORE) { - throw std::runtime_error("Possible order values are only 'BY_ID' or 'BY_SCORE'"); - } - if (radius < 0) { - throw std::runtime_error("radius must be non-negative"); - } - return index->rangeQuery(queryBlob, radius, queryParams, order); -} - -extern "C" void VecSimIndex_Free(VecSimIndex *index) { - std::shared_ptr allocator = - index->getAllocator(); // Save allocator so it will not deallocate itself - delete index; -} - -extern "C" VecSimIndexDebugInfo VecSimIndex_DebugInfo(VecSimIndex *index) { - return index->debugInfo(); -} - -extern "C" VecSimDebugInfoIterator *VecSimIndex_DebugInfoIterator(VecSimIndex *index) { - return index->debugInfoIterator(); -} - -extern "C" VecSimIndexBasicInfo VecSimIndex_BasicInfo(VecSimIndex *index) { - return index->basicInfo(); -} - -extern "C" VecSimIndexStatsInfo VecSimIndex_StatsInfo(VecSimIndex *index) { - return index->statisticInfo(); -} - -extern "C" VecSimBatchIterator *VecSimBatchIterator_New(VecSimIndex *index, const void *queryBlob, - VecSimQueryParams 
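/* A combination the resolver above rejects, as a sketch (`index` is an illustrative handle):

   VecSimRawParam rparams[2] = {
       {VecSimCommonStrings::HYBRID_POLICY_STRING, strlen(VecSimCommonStrings::HYBRID_POLICY_STRING),
        VECSIM_POLICY_ADHOC_BF, strlen(VECSIM_POLICY_ADHOC_BF)},
       {VecSimCommonStrings::BATCH_SIZE_STRING, strlen(VecSimCommonStrings::BATCH_SIZE_STRING),
        "50", 2}};
   VecSimQueryParams qparams;
   // Ad-hoc scanning has no batches, so this returns
   // VecSimParamResolverErr_InvalidPolicy_AdHoc_With_BatchSize:
   VecSimResolveCode rc = VecSimIndex_ResolveParams(index, rparams, 2, &qparams, QUERY_TYPE_HYBRID);
*/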
*queryParams) { - return index->newBatchIterator(queryBlob, queryParams); -} - -extern "C" void VecSimTieredIndex_GC(VecSimIndex *index) { - if (index->basicInfo().isTiered) { - index->runGC(); - } -} - -extern "C" void VecSimTieredIndex_AcquireSharedLocks(VecSimIndex *index) { - index->acquireSharedLocks(); -} - -extern "C" void VecSimTieredIndex_ReleaseSharedLocks(VecSimIndex *index) { - index->releaseSharedLocks(); -} - -extern "C" void VecSim_SetMemoryFunctions(VecSimMemoryFunctions memoryfunctions) { - VecSimAllocator::setMemoryFunctions(memoryfunctions); -} - -extern "C" bool VecSimIndex_PreferAdHocSearch(VecSimIndex *index, size_t subsetSize, size_t k, - bool initial_check) { - return index->preferAdHocSearch(subsetSize, k, initial_check); -} diff --git a/src/VecSim/vec_sim.h b/src/VecSim/vec_sim.h deleted file mode 100644 index 56110a900..000000000 --- a/src/VecSim/vec_sim.h +++ /dev/null @@ -1,264 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -#include -#include "query_results.h" -#include "vec_sim_common.h" -#include "info_iterator.h" - -typedef struct VecSimIndexInterface VecSimIndex; - -/** - * @brief Create a new VecSim index based on the given params. - * @param params index configurations (initial size, data type, dimension, metric, algorithm and the - * algorithm-related params). - * @return A pointer to the created index. - */ -VecSimIndex *VecSimIndex_New(const VecSimParams *params); - -/** - * @brief Estimates the size of an empty index according to the parameters. - * @param params index configurations (initial size, data type, dimension, metric, algorithm and the - * algorithm-related params). - * @return Estimated index size. - */ -size_t VecSimIndex_EstimateInitialSize(const VecSimParams *params); - -/** - * @brief Estimates the size of a single vector and its metadata according to the parameters, WHEN - * THE INDEX IS RESIZING BY A BLOCK. That is, this function estimates the allocation size of a new - * block upon resizing all the internal data structures, and returns the size of a single vector in - * that block. This value can be used later to decide what is the best block size for the block - * size, when the memory limit is known. - * ("memory limit for a block" / "size of a single vector in a block" = "block size") - * @param params index configurations (initial size, data type, dimension, metric, algorithm and the - * algorithm-related params). - * @return The estimated single vector memory consumption, considering the parameters. - */ -size_t VecSimIndex_EstimateElementSize(const VecSimParams *params); - -/** - * @brief Release an index and its internal data. - * @param index the index to release. - */ -void VecSimIndex_Free(VecSimIndex *index); - -/** - * @brief Add a vector to an index. - * @param index the index to which the vector is added. - * @param blob binary representation of the vector. Blob size should match the index data type and - * dimension. - * @param label the label of the added vector - * @return the number of new vectors inserted (1 for new insertion, 0 for override). - */ -int VecSimIndex_AddVector(VecSimIndex *index, const void *blob, size_t label); - -/** - * @brief Remove a vector from an index. 
- * @param index the index from which the vector is removed.
- * @param label the label of the removed vector
- * @return the number of vectors removed (0 if the label was not found)
- */
-int VecSimIndex_DeleteVector(VecSimIndex *index, size_t label);
-
-/**
- * @brief Calculate the distance of a vector from an index to a vector. This function assumes that
- * the vector fits the index - its type and dimension are the same as the index's, and if the
- * index's distance metric is cosine, the vector is already normalized.
- * IMPORTANT: for a tiered index, this should be called while *locks are locked for shared
- * ownership*, as we avoid acquiring the locks internally. That is because this is usually called
- * for every vector individually, and the overhead of acquiring and releasing the locks is
- * significant in that case.
- * @param index the index in which the first vector is located, and that defines the distance
- * metric.
- * @param label the label of the vector in the index.
- * @param blob binary representation of the second vector. Blob size should match the index data
- * type and dimension, and the vector should be pre-normalized if needed.
- * @return The distance (according to the index's distance metric) between `blob` and the vector
- * with label `label`.
- */
-double VecSimIndex_GetDistanceFrom_Unsafe(VecSimIndex *index, size_t label, const void *blob);
-
-/**
- * @brief Normalize the vector blob in place.
- * @param blob binary representation of a vector. Blob size should match the specified type and
- * dimension.
- * @param dim vector dimension.
- * @param type vector type.
- */
-void VecSim_Normalize(void *blob, size_t dim, VecSimType type);
-
-/**
- * @brief Return the number of vectors in the index.
- * @param index the index whose size is requested.
- * @return index size.
- */
-size_t VecSimIndex_IndexSize(VecSimIndex *index);
-
-/**
- * @brief Resolve a VecSimRawParam array and generate a VecSimQueryParams struct.
- * @param index the index for which the params are resolved.
- * @param rparams array of raw params to resolve.
- * @param paramNum number of params in rparams (or number of params in rparams to resolve).
- * @param qparams pointer to VecSimQueryParams struct to set.
- * @param query_type indicates if the query is a hybrid, range or "standard" VSS query.
- * @return VecSim_OK if the resolve was successful, a VecSimResolveCode error code if not.
- */
-VecSimResolveCode VecSimIndex_ResolveParams(VecSimIndex *index, VecSimRawParam *rparams,
-                                            int paramNum, VecSimQueryParams *qparams,
-                                            VecsimQueryType query_type);
-
-/**
- * @brief Search for the k closest vectors to a given vector in the index. The results can be
- * ordered by their score or id.
- * @param index the index to query in.
- * @param queryBlob binary representation of the query vector. Blob size should match the index data
- * type and dimension.
- * @param k the number of "nearest neighbours" to return (upper bound).
- * @param queryParams run time params for the search, which are algorithm-specific.
- * @param order the criterion by which to sort the results list. Options are by score, or by id.
- * @return An opaque object that represents a list of results. The user can access the id and score
- * (which is the distance according to the index metric) of every result through
- * VecSimQueryReply_Iterator.
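 *
 * A usage sketch (the reply accessors live in query_results.h; the iterator names here are
 * assumed from that header rather than shown in this diff):
 *
 *     VecSimQueryReply *rep = VecSimIndex_TopKQuery(index, query, 10, NULL, BY_SCORE);
 *     VecSimQueryReply_Iterator *it = VecSimQueryReply_GetIterator(rep);
 *     while (VecSimQueryReply_IteratorHasNext(it)) {
 *         VecSimQueryResult *item = VecSimQueryReply_IteratorNext(it);
 *         // VecSimQueryResult_GetId(item), VecSimQueryResult_GetScore(item)
 *     }
 *     VecSimQueryReply_IteratorFree(it);
 *     VecSimQueryReply_Free(rep);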
- */
-VecSimQueryReply *VecSimIndex_TopKQuery(VecSimIndex *index, const void *queryBlob, size_t k,
-                                        VecSimQueryParams *queryParams, VecSimQueryReply_Order);
-
-/**
- * @brief Search for the vectors that are in a given range in the index with respect to a given
- * vector. The results can be ordered by their score or id.
- * @param index the index to query in.
- * @param queryBlob binary representation of the query vector. Blob size should match the index data
- * type and dimension.
- * @param radius the radius around the query vector within which to search for vectors.
- * @param queryParams run time params for the search, which are algorithm-specific.
- * @param order the criterion by which to sort the results list. Options are by score, or by id.
- * @return An opaque object that represents a list of results. The user can access the id and score
- * (which is the distance according to the index metric) of every result through
- * VecSimQueryReply_Iterator.
- */
-VecSimQueryReply *VecSimIndex_RangeQuery(VecSimIndex *index, const void *queryBlob, double radius,
-                                         VecSimQueryParams *queryParams, VecSimQueryReply_Order);
-/**
- * @brief Return index information.
- * @param index the index whose info is requested.
- * @return Index general and specific meta-data.
- */
-VecSimIndexDebugInfo VecSimIndex_DebugInfo(VecSimIndex *index);
-
-/**
- * @brief Return basic immutable index information.
- * @param index the index whose info is requested.
- * @return Index basic meta-data.
- */
-VecSimIndexBasicInfo VecSimIndex_BasicInfo(VecSimIndex *index);
-
-/**
- * @brief Return statistics information.
- * @param index the index whose info is requested.
- * @return Index statistic data.
- */
-VecSimIndexStatsInfo VecSimIndex_StatsInfo(VecSimIndex *index);
-
-/**
- * @brief Returns an info iterator for generic reply purposes.
- *
- * @param index the index whose info is requested.
- * @return VecSimDebugInfoIterator* An iterable containing the index general and specific meta-data.
- */
-VecSimDebugInfoIterator *VecSimIndex_DebugInfoIterator(VecSimIndex *index);
-
-/**
- * @brief Create a new batch iterator for a specific index, for a specific query vector,
- * using the Index_BatchIteratorNew method of the index. Should be released with a
- * VecSimBatchIterator_Free call.
- * @param index the index in which the search will be done (in batches)
- * @param queryBlob binary representation of the vector. Blob size should match the index data type
- * and dimension.
- * @param queryParams run time params for the search, which are algorithm-specific.
- * @return A fresh batch iterator.
- */
-VecSimBatchIterator *VecSimBatchIterator_New(VecSimIndex *index, const void *queryBlob,
-                                             VecSimQueryParams *queryParams);
-
-/**
- * @brief Run async garbage collection for tiered async index.
- */
-void VecSimTieredIndex_GC(VecSimIndex *index);
-
-/**
- * @brief Return true if the heuristic says that it is better to use ad-hoc brute-force
- * search over the index instead of using a batch iterator.
- *
- * @param subsetSize the estimated number of vectors in the index that pass the filter
- * (that is, query results can come only from a subset of vectors of this size).
- *
- * @param k the number of required results to return from the query.
- *
- * @param initial_check flag to indicate if this check is performed for the first time (upon
- * creating the hybrid iterator), or after running batches.
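 *
 * A sketch of the intended call pattern, with ~1000 vectors passing the filter and k = 10:
 *
 *     bool adhoc = VecSimIndex_PreferAdHocSearch(index, 1000, 10, true);
 *     // adhoc == true  -> compute distances directly on the filtered subset;
 *     // adhoc == false -> create a batch iterator and scan in batches.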
- */
-bool VecSimIndex_PreferAdHocSearch(VecSimIndex *index, size_t subsetSize, size_t k,
-                                   bool initial_check);
-
-/**
- * @brief Acquire/Release the required locks of the tiered index externally before executing
- * an unsafe *READ* operation (as the locks are acquired for shared ownership).
- * @param index the tiered index to protect (does nothing for non-tiered indexes).
- */
-void VecSimTieredIndex_AcquireSharedLocks(VecSimIndex *index);
-
-void VecSimTieredIndex_ReleaseSharedLocks(VecSimIndex *index);
-
-/**
- * @brief Allow 3rd party memory functions to be used for memory management.
- *
- * @param memoryfunctions VecSimMemoryFunctions struct.
- */
-void VecSim_SetMemoryFunctions(VecSimMemoryFunctions memoryfunctions);
-
-/**
- * @brief Allow a 3rd party timeout callback to be used for limiting the runtime of a query.
- *
- * @param callback timeoutCallbackFunction function. Should take a void * and return an int.
- */
-void VecSim_SetTimeoutCallbackFunction(timeoutCallbackFunction callback);
-
-/**
- * @brief Allow a 3rd party log callback to be used for logging.
- *
- * @param callback logCallbackFunction function. Should take a void * and return void.
- */
-void VecSim_SetLogCallbackFunction(logCallbackFunction callback);
-
-/**
- * @brief Set the context for logging (e.g., test name or file name).
- *
- * @param test_name the name of the test.
- * @param test_type the type of the test (e.g., "unit" or "flow").
- */
-void VecSim_SetTestLogContext(const char *test_name, const char *test_type);
-
-/**
- * @brief Allow a 3rd party to set the write mode for a tiered index - async insert/delete using
- * background jobs, or insert/delete in place.
- * @note In a tiered index scenario, this should be called from the main thread only (that is, the
- * thread that is calling the add/delete vector functions).
- *
- * @param mode VecSimWriteMode the mode in which we add/remove vectors (async or in-place).
- */
-void VecSim_SetWriteMode(VecSimWriteMode mode);
-
-#ifdef __cplusplus
-}
-#endif
diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h
deleted file mode 100644
index fa136b7fe..000000000
--- a/src/VecSim/vec_sim_common.h
+++ /dev/null
@@ -1,479 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-#include
-#include
-#include
-#include
-
-// Common definitions
-#define DEFAULT_BLOCK_SIZE 1024
-#define INVALID_ID UINT_MAX
-#define INVALID_LABEL SIZE_MAX
-#define UNUSED(x) (void)(x)
-
-// Hybrid policy values
-#define VECSIM_POLICY_ADHOC_BF "adhoc_bf"
-#define VECSIM_POLICY_BATCHES "batches"
-#define VECSIM_POLICY_INVALID "invalid_policy"
-
-// HNSW default parameters
-#define HNSW_DEFAULT_M 16
-#define HNSW_DEFAULT_EF_C 200
-#define HNSW_DEFAULT_EF_RT 10
-#define HNSW_DEFAULT_EPSILON 0.01
-
-#define HNSW_INVALID_LEVEL SIZE_MAX
-#define INVALID_JOB_ID UINT_MAX
-#define INVALID_INFO UINT_MAX
-
-// SVS-Vamana default parameters
-#define SVS_VAMANA_DEFAULT_ALPHA_L2 1.2f
-#define SVS_VAMANA_DEFAULT_ALPHA_IP 0.95f
-#define SVS_VAMANA_DEFAULT_GRAPH_MAX_DEGREE 32
-#define SVS_VAMANA_DEFAULT_CONSTRUCTION_WINDOW_SIZE 200
-#define SVS_VAMANA_DEFAULT_USE_SEARCH_HISTORY true
-#define SVS_VAMANA_DEFAULT_NUM_THREADS 1
-// NOTE: optimal training threshold may depend on the SVSIndex compression mode.
-// it might be good to implement a utility to compute default threshold based on index parameters -// DEFAULT_BLOCK_SIZE is used to round the training threshold to FLAT index blocks -#define SVS_VAMANA_DEFAULT_TRAINING_THRESHOLD (10 * DEFAULT_BLOCK_SIZE) // 10 * 1024 vectors -// Default batch update threshold for SVS index. -#define SVS_VAMANA_DEFAULT_UPDATE_THRESHOLD (1 * DEFAULT_BLOCK_SIZE) // 1 * 1024 vectors -#define SVS_VAMANA_DEFAULT_SEARCH_WINDOW_SIZE 10 -// NOTE: No need to have SVS_VAMANA_DEFAULT_SEARCH_BUFFER_CAPACITY -// as the default is determined by the search_window_size -#define SVS_VAMANA_DEFAULT_LEANVEC_DIM 0 -#define SVS_VAMANA_DEFAULT_EPSILON 0.01f - -// Datatypes for indexing. -typedef enum { - VecSimType_FLOAT32, - VecSimType_FLOAT64, - VecSimType_BFLOAT16, - VecSimType_FLOAT16, - VecSimType_INT8, - VecSimType_UINT8, - VecSimType_INT32, - VecSimType_INT64 -} VecSimType; - -// Algorithm type/library. -typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_TIERED, VecSimAlgo_SVS } VecSimAlgo; - -typedef enum { - VecSimOption_AUTO = 0, - VecSimOption_ENABLE = 1, - VecSimOption_DISABLE = 2, -} VecSimOptionMode; - -typedef enum { - VecSimBool_TRUE = 1, - VecSimBool_FALSE = 0, - VecSimBool_UNSET = -1, -} VecSimBool; - -// Distance metric -typedef enum { VecSimMetric_L2, VecSimMetric_IP, VecSimMetric_Cosine } VecSimMetric; - -typedef size_t labelType; -typedef unsigned int idType; - -/** - * @brief Query Runtime raw parameters. - * Use VecSimIndex_ResolveParams to generate VecSimQueryParams from array of VecSimRawParams. - * - */ -typedef struct { - const char *name; - size_t nameLen; - const char *value; - size_t valLen; -} VecSimRawParam; - -#define VecSim_OK 0 - -typedef enum { - VecSimParamResolver_OK = VecSim_OK, // for returning VecSim_OK as an enum value - VecSimParamResolverErr_NullParam, - VecSimParamResolverErr_AlreadySet, - VecSimParamResolverErr_UnknownParam, - VecSimParamResolverErr_BadValue, - VecSimParamResolverErr_InvalidPolicy_NExits, - VecSimParamResolverErr_InvalidPolicy_NHybrid, - VecSimParamResolverErr_InvalidPolicy_NRange, - VecSimParamResolverErr_InvalidPolicy_AdHoc_With_BatchSize, - VecSimParamResolverErr_InvalidPolicy_AdHoc_With_EfRuntime -} VecSimResolveCode; - -typedef enum { - VecSimDebugCommandCode_OK = VecSim_OK, // for returning VecSim_OK as an enum value - VecSimDebugCommandCode_BadIndex, - VecSimDebugCommandCode_LabelNotExists, - VecSimDebugCommandCode_MultiNotSupported -} VecSimDebugCommandCode; - -typedef struct AsyncJob AsyncJob; // forward declaration. - -// Write async/sync mode -typedef enum { VecSim_WriteAsync, VecSim_WriteInPlace } VecSimWriteMode; - -/** - * Callback signatures for asynchronous tiered index. - */ -typedef void (*JobCallback)(AsyncJob *); -typedef int (*SubmitCB)(void *job_queue, void *index_ctx, AsyncJob **jobs, JobCallback *CBs, - size_t jobs_len); - -/** - * @brief Index initialization parameters. - * - */ -typedef struct VecSimParams VecSimParams; -typedef struct { - VecSimType type; // Datatype to index. - size_t dim; // Vector's dimension. - VecSimMetric metric; // Distance metric to use in the index. - bool multi; // Determines if the index should multi-index or not. - size_t initialCapacity; // Deprecated - size_t blockSize; - size_t M; - size_t efConstruction; - size_t efRuntime; - double epsilon; -} HNSWParams; - -typedef struct { - VecSimType type; // Datatype to index. - size_t dim; // Vector's dimension. - VecSimMetric metric; // Distance metric to use in the index. 
- bool multi; // Determines if the index should multi-index or not. - size_t initialCapacity; // Deprecated. - size_t blockSize; -} BFParams; - -typedef enum { - VecSimSvsQuant_NONE = 0, // No quantization. - VecSimSvsQuant_Scalar = 1, // 8-bit scalar quantization - VecSimSvsQuant_4 = 4, // 4-bit quantization - VecSimSvsQuant_8 = 8, // 8-bit quantization - VecSimSvsQuant_4x4 = 4 | (4 << 8), // 4-bit quantization with 4-bit residuals - VecSimSvsQuant_4x8 = 4 | (8 << 8), // 4-bit quantization with 8-bit residuals - VecSimSvsQuant_4x8_LeanVec = 4 | (8 << 8) | (1 << 16), // LeanVec 4x8 quantization - VecSimSvsQuant_8x8_LeanVec = 8 | (8 << 8) | (1 << 16), // LeanVec 8x8 quantization -} VecSimSvsQuantBits; - -typedef struct { - VecSimType type; // Datatype to index. - size_t dim; // Vector's dimension. - VecSimMetric metric; // Distance metric to use in the index. - bool multi; // Determines if the index should multi-index or not. - size_t blockSize; - - /* SVS-Vamana specifics. See Intel ScalableVectorSearch documentation */ - VecSimSvsQuantBits quantBits; // Quantization level. - float alpha; // The pruning parameter. - size_t graph_max_degree; // Maximum degree in the graph. - size_t construction_window_size; // Search window size to use during graph construction. - size_t max_candidate_pool_size; // Limit on the number of neighbors considered during pruning. - size_t prune_to; // Amount that candidates will be pruned. - VecSimOptionMode use_search_history; // Either the contents of the search buffer can be used or - // the entire search history. - size_t num_threads; // Maximum number of threads in threadpool. - size_t search_window_size; // Search window size to use during search. - size_t search_buffer_capacity; // Search buffer capacity to use during search. - size_t leanvec_dim; // Leanvec dimension to use when LeanVec is enabled. - double epsilon; // Epsilon parameter for SVS graph accuracy/latency for range search. -} SVSParams; - -// A struct that contains HNSW tiered index specific params. -typedef struct { - size_t swapJobThreshold; // The minimum number of swap jobs to accumulate before applying - // all the ready swap jobs in a batch. -} TieredHNSWParams; - -// A struct that contains HNSW Disk tiered index specific params. -// Consider removing and use TieredHNSWParams instead if they both share swapJobThreshold -typedef struct { - char _placeholder; // Reserved for future fields and avoid compiler errors -} TieredHNSWDiskParams; - -// A struct that contains SVS tiered index specific params. -typedef struct { - size_t trainingTriggerThreshold; // The flat index size threshold to trigger the initialization - // of backend index. - size_t updateTriggerThreshold; // The flat index size threshold to trigger the vectors migration - // to backend index. - size_t updateJobWaitTime; // The time (microseconds) to wait for Redis threads reservation - // before executing the scheduled SVS Index update job. -} TieredSVSParams; - -// A struct that contains the common tiered index params. -typedef struct { - void *jobQueue; // External queue that holds the jobs. - void *jobQueueCtx; // External context to be sent to the submit callback. - SubmitCB submitCb; // A callback that submits an array of jobs into a given jobQueue. - size_t flatBufferLimit; // Maximum size allowed for the flat buffer. If flat buffer is full, use - // in-place insertion. - VecSimParams *primaryIndexParams; // Parameters to initialize the index. 
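/* Reading note, as a sketch: the VecSimSvsQuantBits values above pack three fields into one
   integer - primary quantization bits in the low byte, residual bits in the second byte, and a
   LeanVec flag in the third (VecSimSvsQuant_Scalar == 1 is a discrete tag rather than a bit
   width). Illustrative helpers, not part of the original header:

   static inline int svs_primary_bits(VecSimSvsQuantBits q) { return q & 0xFF; }
   static inline int svs_residual_bits(VecSimSvsQuantBits q) { return (q >> 8) & 0xFF; }
   static inline bool svs_is_leanvec(VecSimSvsQuantBits q) { return ((q >> 16) & 1) != 0; }
   // e.g. VecSimSvsQuant_4x8_LeanVec -> primary = 4, residual = 8, LeanVec = true
*/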
- union { - TieredHNSWParams tieredHnswParams; - TieredSVSParams tieredSVSParams; - TieredHNSWDiskParams tieredHnswDiskParams; - } specificParams; -} TieredIndexParams; - -typedef union { - HNSWParams hnswParams; - BFParams bfParams; - TieredIndexParams tieredParams; - SVSParams svsParams; -} AlgoParams; - -struct VecSimParams { - VecSimAlgo algo; // Algorithm to use. - AlgoParams algoParams; - void *logCtx; // External context that stores the index log. -}; - -typedef struct { - void *storage; // Opaque pointer to disk storage - const char *indexName; - size_t indexNameLen; -} VecSimDiskContext; - -typedef struct { - VecSimParams *indexParams; - VecSimDiskContext *diskContext; -} VecSimParamsDisk; - -/** - * The specific job types in use (to be extended in the future by demand) - */ -typedef enum { - HNSW_INSERT_VECTOR_JOB, - HNSW_REPAIR_NODE_CONNECTIONS_JOB, - HNSW_SEARCH_JOB, - HNSW_SWAP_JOB, - SVS_BATCH_UPDATE_JOB, - SVS_GC_JOB, - INVALID_JOB // to indicate that finding a JobType >= INVALID_JOB is an error -} JobType; - -typedef struct { - size_t efRuntime; // EF parameter for HNSW graph accuracy/latency for search. - double epsilon; // Epsilon parameter for HNSW graph accuracy/latency for range search. -} HNSWRuntimeParams; - -typedef struct { - size_t windowSize; // Search window size for Vamana graph accuracy/latency tune. - size_t bufferCapacity; // Search buffer capacity for Vamana graph accuracy/latency tune. - VecSimOptionMode searchHistory; // Enabling of the visited set for search. - double epsilon; // Epsilon parameter for SVS graph accuracy/latency for range search. -} SVSRuntimeParams; - -/** - * @brief Query runtime information - the search mode in RediSearch (used for debug/testing). - * - */ -typedef enum { - EMPTY_MODE, // Default value to initialize the "lastMode" field with. - STANDARD_KNN, // Run k-nn query over the entire vector index. - HYBRID_ADHOC_BF, // Measure ad-hoc the distance for every result that passes the filters, - // and take the top k results. - HYBRID_BATCHES, // Get the top vector results in batches upon demand, and keep the results that - // passes the filters until we reach k results. - HYBRID_BATCHES_TO_ADHOC_BF, // Start with batches and dynamically switched to ad-hoc BF. - RANGE_QUERY, // Run range query, to return all vectors that are within a given range from the - // query vector. -} VecSearchMode; - -typedef enum { - QUERY_TYPE_NONE, // Use when no params are given. - QUERY_TYPE_KNN, - QUERY_TYPE_HYBRID, - QUERY_TYPE_RANGE, -} VecsimQueryType; - -/** - * @brief Query Runtime parameters. - * - */ -typedef struct { - union { - HNSWRuntimeParams hnswRuntimeParams; - SVSRuntimeParams svsRuntimeParams; - }; - size_t batchSize; - VecSearchMode searchMode; - void *timeoutCtx; // This parameter is not exposed directly to the user, and we shouldn't expect - // to get it from the parameters resolve function. -} VecSimQueryParams; - -/** - * Index info that is static and immutable (cannot be changed over time) - */ -typedef struct { - VecSimAlgo algo; // Algorithm being used (if index is tiered, this is the backend index). - VecSimMetric metric; // Index distance metric - VecSimType type; // Datatype the index holds. - bool isMulti; // Determines if the index should multi-index or not. - bool isTiered; // Is the index is tiered or not. - bool isDisk; // Is the index stored on disk. - size_t blockSize; // Brute force algorithm vector block (mini matrix) size - size_t dim; // Vector size (dimension). 
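/* End-to-end creation sketch using the parameter structs above (field values illustrative;
   designated initializers zero the omitted fields):

   HNSWParams hnsw = {.type = VecSimType_FLOAT32, .dim = 128, .metric = VecSimMetric_Cosine,
                      .multi = false, .M = HNSW_DEFAULT_M, .efConstruction = HNSW_DEFAULT_EF_C,
                      .efRuntime = HNSW_DEFAULT_EF_RT, .epsilon = HNSW_DEFAULT_EPSILON};
   VecSimParams params = {.algo = VecSimAlgo_HNSWLIB, .algoParams = {.hnswParams = hnsw}};
   VecSimIndex *index = VecSimIndex_New(&params);
   float v[128] = {0.0f};
   VecSimIndex_AddVector(index, v, 42); // label 42
   // ... queries ...
   VecSimIndex_Free(index);
*/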
-} VecSimIndexBasicInfo; - -/** - * Index info for statistics - a thin and efficient (no locks, no calculations) info. Can be used in - * production without worrying about performance - */ -typedef struct { - size_t memory; - size_t numberOfMarkedDeleted; // The number of vectors that are marked as deleted (HNSW/tiered - // only). -} VecSimIndexStatsInfo; - -typedef struct { - VecSimIndexBasicInfo basicInfo; // Index immutable meta-data. - size_t indexSize; // Current count of vectors. - size_t indexLabelCount; // Current unique count of labels. - uint64_t memory; // Index memory consumption. - VecSearchMode lastMode; // The mode in which the last query ran. -} CommonInfo; - -typedef struct { - size_t M; // Number of allowed edges per node in graph. - size_t efConstruction; // EF parameter for HNSW graph accuracy/latency for indexing. - size_t efRuntime; // EF parameter for HNSW graph accuracy/latency for search. - double epsilon; // Epsilon parameter for HNSW graph accuracy/latency for range search. - size_t max_level; // Number of graph levels. - size_t entrypoint; // Entrypoint vector label. - size_t visitedNodesPoolSize; // The max number of parallel graph scans so far. - size_t numberOfMarkedDeletedNodes; // The number of nodes that are marked as deleted. -} hnswInfoStruct; - -typedef struct { - char dummy; // For not having this as an empty struct, can be removed after we extend this. -} bfInfoStruct; - -typedef struct { - VecSimSvsQuantBits quantBits; // Quantization flavor. - float alpha; // The pruning parameter. - size_t graphMaxDegree; // Maximum degree in the graph. - size_t constructionWindowSize; // Search window size to use during graph construction. - size_t maxCandidatePoolSize; // Limit on the number of neighbors considered during pruning. - size_t pruneTo; // Amount that candidates will be pruned. - bool useSearchHistory; // Either the contents of the search buffer can be used or - // the entire search history. - size_t numThreads; // Maximum number of threads to be used by svs for ingestion. - size_t lastReservedThreads; // Number of threads that were successfully reserved by the last - // ingestion operation. - size_t numberOfMarkedDeletedNodes; // The number of nodes that are marked as deleted. - size_t searchWindowSize; // Search window size for Vamana graph accuracy/latency tune. - size_t searchBufferCapacity; // Search buffer capacity for Vamana graph accuracy/latency tune. - size_t leanvecDim; // Leanvec dimension to use when LeanVec is enabled. - double epsilon; // Epsilon parameter for SVS graph accuracy/latency for range search. -} svsInfoStruct; - -typedef struct HnswTieredInfo { - size_t pendingSwapJobsThreshold; -} HnswTieredInfo; - -typedef struct SvsTieredInfo { - size_t trainingTriggerThreshold; - size_t updateTriggerThreshold; - size_t updateJobWaitTime; // The time (microseconds) to wait for Redis threads reservation - // before executing the scheduled SVS Index update job. - bool indexUpdateScheduled; -} SvsTieredInfo; - -typedef struct { - - // Since we cannot recursively have a struct that contains itself, we need this workaround. - union { - hnswInfoStruct hnswInfo; - svsInfoStruct svsInfo; - } backendInfo; // The backend index info. - union { - HnswTieredInfo hnswTieredInfo; - SvsTieredInfo svsTieredInfo; - } specificTieredBackendInfo; // Info relevant for tiered index with a specific backend. - CommonInfo backendCommonInfo; // Common index info. - CommonInfo frontendCommonInfo; // Common index info. 
- bfInfoStruct bfInfo; // The brute force index info. - - uint64_t management_layer_memory; // Memory consumption of the management layer. - VecSimBool backgroundIndexing; // Determines if the index is currently being indexed in the - // background. - size_t bufferLimit; // Maximum number of vectors allowed in the flat buffer. -} tieredInfoStruct; - -/** - * @brief Index information. Should only be used for debug/testing. - * - */ -typedef struct { - CommonInfo commonInfo; - union { - bfInfoStruct bfInfo; - hnswInfoStruct hnswInfo; - svsInfoStruct svsInfo; - tieredInfoStruct tieredInfo; - }; -} VecSimIndexDebugInfo; - -// Memory function declarations. -typedef void *(*allocFn)(size_t n); -typedef void *(*callocFn)(size_t nelem, size_t elemsz); -typedef void *(*reallocFn)(void *p, size_t n); -typedef void (*freeFn)(void *p); -typedef char *(*strdupFn)(const char *s); - -/** - * @brief A struct to pass 3rd party memory functions to Vecsimlib. - * - */ -typedef struct { - allocFn allocFunction; // Malloc like function. - callocFn callocFunction; // Calloc like function. - reallocFn reallocFunction; // Realloc like function. - freeFn freeFunction; // Free function. -} VecSimMemoryFunctions; - -/** - * @brief A struct to pass 3rd party timeout function to Vecsimlib. - * @param ctx some generic context to pass to the function - * @return the function should return a non-zero value on timeout - */ -typedef int (*timeoutCallbackFunction)(void *ctx); - -/** - * @brief A struct to pass 3rd party logging function to Vecsimlib. - * @param ctx some generic context to pass to the function - * @param level loglevel (in redis we should choose from: "warning", "notice", "verbose", "debug") - * @param message the message to log - */ -typedef void (*logCallbackFunction)(void *ctx, const char *level, const char *message); - -// Round up to the nearest multiplication of blockSize. -static inline size_t RoundUpInitialCapacity(size_t initialCapacity, size_t blockSize) { - return initialCapacity % blockSize ? initialCapacity + blockSize - initialCapacity % blockSize - : initialCapacity; -} - -#define VECSIM_TIMEOUT(ctx) (__builtin_expect(VecSimIndexInterface::timeoutCallback(ctx), false)) - -#ifdef __cplusplus -} -#endif diff --git a/src/VecSim/vec_sim_debug.cpp b/src/VecSim/vec_sim_debug.cpp deleted file mode 100644 index c83790111..000000000 --- a/src/VecSim/vec_sim_debug.cpp +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#include "vec_sim_debug.h" -#include "VecSim/vec_sim_index.h" -#include "VecSim/algorithms/hnsw/hnsw.h" -#include "VecSim/algorithms/hnsw/hnsw_tiered.h" -#include "VecSim/types/bfloat16.h" - -extern "C" int VecSimDebug_GetElementNeighborsInHNSWGraph(VecSimIndex *index, size_t label, - int ***neighborsData) { - - // Set as if we return an error, and upon success we will set the pointers appropriately. 
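/* Registration sketch for the callback typedefs in vec_sim_common.h above (function names are
   illustrative; a real timeout callback would consult its context/deadline):

   static int my_timeout(void *ctx) { return 0; }   // non-zero aborts the running query
   static void my_log(void *ctx, const char *level, const char *message) {
       fprintf(stderr, "[%s] %s\n", level, message); // needs <cstdio>
   }

   VecSim_SetTimeoutCallbackFunction(my_timeout);
   VecSim_SetLogCallbackFunction(my_log);

   // Also from above: RoundUpInitialCapacity(1000, 1024) == 1024,
   // and RoundUpInitialCapacity(2048, 1024) == 2048.
*/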
-    *neighborsData = nullptr;
-    VecSimIndexBasicInfo info = index->basicInfo();
-    if (info.algo != VecSimAlgo_HNSWLIB) {
-        return VecSimDebugCommandCode_BadIndex;
-    }
-    if (!info.isTiered) {
-        if (info.type == VecSimType_FLOAT32) {
-            return dynamic_cast<HNSWIndex<float, float> *>(index)->getHNSWElementNeighbors(
-                label, neighborsData);
-        } else if (info.type == VecSimType_FLOAT64) {
-            return dynamic_cast<HNSWIndex<double, double> *>(index)->getHNSWElementNeighbors(
-                label, neighborsData);
-        } else if (info.type == VecSimType_BFLOAT16) {
-            return dynamic_cast<HNSWIndex<bfloat16, float> *>(index)
-                ->getHNSWElementNeighbors(label, neighborsData);
-        } else if (info.type == VecSimType_FLOAT16) {
-            return dynamic_cast<HNSWIndex<float16, float> *>(index)
-                ->getHNSWElementNeighbors(label, neighborsData);
-        } else if (info.type == VecSimType_INT8) {
-            return dynamic_cast<HNSWIndex<int8_t, float> *>(index)->getHNSWElementNeighbors(
-                label, neighborsData);
-        } else if (info.type == VecSimType_UINT8) {
-            return dynamic_cast<HNSWIndex<uint8_t, float> *>(index)->getHNSWElementNeighbors(
-                label, neighborsData);
-        } else {
-            assert(false && "Invalid data type");
-        }
-    } else {
-        if (info.type == VecSimType_FLOAT32) {
-            return dynamic_cast<TieredHNSWIndex<float, float> *>(index)->getHNSWElementNeighbors(
-                label, neighborsData);
-        } else if (info.type == VecSimType_FLOAT64) {
-            return dynamic_cast<TieredHNSWIndex<double, double> *>(index)->getHNSWElementNeighbors(
-                label, neighborsData);
-        } else if (info.type == VecSimType_BFLOAT16) {
-            return dynamic_cast<TieredHNSWIndex<bfloat16, float> *>(index)
-                ->getHNSWElementNeighbors(label, neighborsData);
-        } else if (info.type == VecSimType_FLOAT16) {
-            return dynamic_cast<TieredHNSWIndex<float16, float> *>(index)
-                ->getHNSWElementNeighbors(label, neighborsData);
-        } else if (info.type == VecSimType_INT8) {
-            return dynamic_cast<TieredHNSWIndex<int8_t, float> *>(index)->getHNSWElementNeighbors(
-                label, neighborsData);
-        } else if (info.type == VecSimType_UINT8) {
-            return dynamic_cast<TieredHNSWIndex<uint8_t, float> *>(index)->getHNSWElementNeighbors(
-                label, neighborsData);
-        } else {
-            assert(false && "Invalid data type");
-        }
-    }
-    return VecSimDebugCommandCode_BadIndex;
-}
-
-extern "C" void VecSimDebug_ReleaseElementNeighborsInHNSWGraph(int **neighborsData) {
-    if (neighborsData == nullptr) {
-        return;
-    }
-    size_t i = 0;
-    while (neighborsData[i] != nullptr) {
-        delete[] neighborsData[i];
-        i++;
-    }
-    delete[] neighborsData;
-}
diff --git a/src/VecSim/vec_sim_debug.h b/src/VecSim/vec_sim_debug.h
deleted file mode 100644
index 833009ee1..000000000
--- a/src/VecSim/vec_sim_debug.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Copyright (c) 2006-Present, Redis Ltd.
- * All rights reserved.
- *
- * Licensed under your choice of the Redis Source Available License 2.0
- * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
- * GNU Affero General Public License v3 (AGPLv3).
- */
-#pragma once
-
-#include "vec_sim.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/**
- * @brief: Dump the neighbors of an element in an HNSW index in the following format:
- * an array with <topLevel+2> entries, where each entry is an array itself. Every internal array
- * in a position <l> where <0<=l<=topLevel> corresponds to the neighbors of the element in the
- * graph in level <l>. It contains <n+1> entries, where <n> is the number of neighbors in level l.
- * The last entry in the external array is NULL (indicates its length). The first entry in each
- * internal array contains the number <n>, while the next <n> entries are the *labels* of the
- * element's neighbors in this level.
- * Note: currently only HNSW indexes of type single are supported (multi not yet) - tiered included.
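 *
 * A consumption sketch for this format (assumes a single-type HNSW index handle `index`):
 *
 *     int **levels = NULL;
 *     if (VecSimDebug_GetElementNeighborsInHNSWGraph(index, label, &levels) ==
 *         VecSimDebugCommandCode_OK) {
 *         for (size_t l = 0; levels[l] != NULL; l++) {
 *             int n = levels[l][0];               // neighbor count in level l
 *             for (int i = 1; i <= n; i++) {
 *                 // levels[l][i] is a neighbor label in level l
 *             }
 *         }
 *         VecSimDebug_ReleaseElementNeighborsInHNSWGraph(levels);
 *     }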
- * For cleanup, VecSimDebug_ReleaseElementNeighborsInHNSWGraph need to be called with the value - * pointed by neighborsData as returned from this call. - * @param index - the index in which the element resides. - * @param label - the label to dump its neighbors in every level in which it exits. - * @param neighborsData - a pointer to a 2-dim array of integer which is a placeholder for the - * output of the neighbors' labels that will be allocated and stored in the format described above. - * - */ -// TODO: Implement the full version that supports MULTI as well. This will require adding an -// additional dim to the array and perhaps differentiating between internal ids of labels in the -// output format. Also, we may want in the future to dump the incoming edges as well. -int VecSimDebug_GetElementNeighborsInHNSWGraph(VecSimIndex *index, size_t label, - int ***neighborsData); - -/** - * @brief: Release the neighbors data allocated by VecSimDebug_GetElementNeighborsInHNSWGraph. - * @param neighborsData - the 2-dim array returned in the placeholder to be de-allocated. - */ -void VecSimDebug_ReleaseElementNeighborsInHNSWGraph(int **neighborsData); - -#ifdef __cplusplus -} -#endif diff --git a/src/VecSim/vec_sim_index.h b/src/VecSim/vec_sim_index.h deleted file mode 100644 index 017935ce9..000000000 --- a/src/VecSim/vec_sim_index.h +++ /dev/null @@ -1,369 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "vec_sim_common.h" -#include "vec_sim_interface.h" -#include "query_results.h" -#include -#include "VecSim/memory/vecsim_base.h" -#include "VecSim/utils/vec_utils.h" -#include "VecSim/spaces/spaces.h" -#include "VecSim/spaces/computer/calculator.h" -#include "VecSim/spaces/computer/preprocessor_container.h" -#include "info_iterator_struct.h" -#include "containers/data_blocks_container.h" -#include "containers/raw_data_container_interface.h" - -#include -#include - -/** - * @brief Struct for initializing an abstract index class. - * - * @param allocator The allocator to use for the index. - * @param dim The dimension of the vectors in the index. - * @param vecType The type of the vectors in the index. - * @param storedDataSize The size of stored vectors (possibly after pre-processing) in bytes. - * @param metric The metric to use in the index. - * @param blockSize The block size to use in the index. - * @param multi Determines if the index should multi-index or not. - * @param logCtx The context to use for logging. - * @param inputBlobSize The size of input vectors/queries blob in bytes. May differ from dim * - * sizeof(vecType) when vectors have been externally preprocessed (e.g., cosine normalization adds - * extra bytes). For example, in tiered indexes, the backend receives preprocessed blobs, not raw - * input vectors. - */ -struct AbstractIndexInitParams { - std::shared_ptr allocator; - size_t dim; - VecSimType vecType; - size_t storedDataSize; - VecSimMetric metric; - size_t blockSize; - bool multi; - bool isDisk; // Whether the index stores vectors on disk - void *logCtx; - size_t inputBlobSize; -}; - -/** - * @brief Struct for initializing the components of the abstract index. - * The index takes ownership of the components allocations' and is responsible for freeing - * them when the index is destroyed. 
- * - * @param indexCalculator The distance calculator for the index. - * @param preprocessors The preprocessing pipeline for ingesting user data before storage and - * querying. - */ -template -struct IndexComponents { - IndexCalculatorInterface *indexCalculator; - PreprocessorsContainerAbstract *preprocessors; -}; - -/** - * @brief Abstract C++ class for vector index, delete and lookup - * - */ -template -struct VecSimIndexAbstract : public VecSimIndexInterface { -protected: - size_t dim; // Vector's dimension. - VecSimType vecType; // Datatype to index. - VecSimMetric metric; // Distance metric to use in the index. - size_t blockSize; // Index's vector block size (determines by how many vectors to resize when - // resizing) - mutable VecSearchMode lastMode; // The last search mode in RediSearch (used for debug/testing). - bool isMulti; // Determines if the index should multi-index or not. - bool isDisk; // Whether the index stores vectors on disk. - void *logCallbackCtx; // Context for the log callback. - RawDataContainer *vectors; // The raw vectors data container. -private: - IndexCalculatorInterface *indexCalculator; // Distance calculator. - PreprocessorsContainerAbstract *preprocessors; // Storage and query preprocessors. - - size_t inputBlobSize; // The size of input vectors/queries blob in bytes. May differ from dim * - // sizeof(vecType) when vectors have been externally preprocessed (e.g., - // cosine normalization adds extra bytes). For example, in tiered indexes, - // the backend receives preprocessed blobs, not raw input vectors. - size_t storedDataSize; // Vector element data size in bytes to be stored - // (possibly after pre-processing and may differ from inputBlobSize if - // NOT externally preprocessed). -protected: - /** - * @brief Get the common info object - * - * @return CommonInfo - */ - CommonInfo getCommonInfo() const { - CommonInfo info; - info.basicInfo = this->getBasicInfo(); - info.lastMode = this->lastMode; - info.memory = this->getAllocationSize(); - info.indexSize = this->indexSize(); - info.indexLabelCount = this->indexLabelCount(); - return info; - } - - IndexCalculatorInterface *getIndexCalculator() const { return indexCalculator; } - -public: - /** - * @brief Construct a new Vec Sim Index object - * - */ - VecSimIndexAbstract(const AbstractIndexInitParams ¶ms, - const IndexComponents &components) - : VecSimIndexInterface(params.allocator), dim(params.dim), vecType(params.vecType), - metric(params.metric), - blockSize(params.blockSize ? params.blockSize : DEFAULT_BLOCK_SIZE), lastMode(EMPTY_MODE), - isMulti(params.multi), isDisk(params.isDisk), logCallbackCtx(params.logCtx), - indexCalculator(components.indexCalculator), preprocessors(components.preprocessors), - inputBlobSize(params.inputBlobSize), storedDataSize(params.storedDataSize) { - assert(VecSimType_sizeof(vecType)); - assert(storedDataSize); - assert(inputBlobSize); - this->vectors = new (this->allocator) DataBlocksContainer( - this->blockSize, this->storedDataSize, this->allocator, this->getAlignment()); - } - - /** - * @brief Destroy the Vec Sim Index object - * - */ - virtual ~VecSimIndexAbstract() noexcept { - delete this->vectors; - delete indexCalculator; - delete preprocessors; - } - - /** - * @brief Calculate the distance between two vectors based on index parameters. - * - * @return the distance between the vectors. 
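A small illustration of the inputBlobSize/storedDataSize distinction described above: under cosine similarity, an int8 vector may be stored with a trailing float32 norm, so the stored size exceeds dim * sizeof(int8_t). The helper below is hypothetical, not library code; it only shows the layout arithmetic.

```cpp
// Hypothetical sketch: store an int8 vector followed by its float32 norm,
// so storedDataSize = dim * sizeof(int8_t) + sizeof(float).
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<std::uint8_t> storeWithNorm(const std::int8_t *vec, size_t dim) {
    float norm = 0.0f;
    for (size_t i = 0; i < dim; i++)
        norm += float(vec[i]) * float(vec[i]);
    norm = std::sqrt(norm);

    std::vector<std::uint8_t> blob(dim * sizeof(std::int8_t) + sizeof(float));
    std::memcpy(blob.data(), vec, dim);                   // vector elements
    std::memcpy(blob.data() + dim, &norm, sizeof(float)); // appended metadata
    return blob;
}

int main() {
    std::int8_t v[4] = {1, 2, 3, 4};
    auto blob = storeWithNorm(v, 4);
    return blob.size() == 4 + sizeof(float) ? 0 : 1;
}
```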
- */ - DistType calcDistance(const void *vector_data1, const void *vector_data2) const { - return indexCalculator->calcDistance(vector_data1, vector_data2, this->dim); - } - - /** - * @brief Preprocess a blob for both storage and query. - * - * @param original_blob will be copied. - * @return two unique_ptr of the processed blobs. - */ - ProcessedBlobs preprocess(const void *original_blob) const; - - /** - * @brief Preprocess a blob for query. - * - * @param queryBlob will be copied if preprocessing is required, or if force_copy is set to - * true. - * @return unique_ptr of the processed blob. - */ - MemoryUtils::unique_blob preprocessQuery(const void *queryBlob, bool force_copy = false) const; - - /** - * @brief Preprocess a blob for storage. - * - * @param original_blob will be copied. - * @return unique_ptr of the processed blob. - */ - MemoryUtils::unique_blob preprocessForStorage(const void *original_blob) const; - - /** - * @brief Preprocess a blob for storage in place. - * - * @param blob will be directly modified, not copied. - */ - void preprocessStorageInPlace(void *blob) const; - - inline size_t getDim() const { return dim; } - inline void setLastSearchMode(VecSearchMode mode) override { this->lastMode = mode; } - inline bool isMultiValue() const { return isMulti; } - inline VecSimType getType() const { return vecType; } - inline VecSimMetric getMetric() const { return metric; } - inline size_t getStoredDataSize() const { return storedDataSize; } - inline size_t getInputBlobSize() const { return inputBlobSize; } - inline size_t getBlockSize() const { return blockSize; } - inline auto getAlignment() const { return this->preprocessors->getAlignment(); } - - virtual inline VecSimIndexStatsInfo statisticInfo() const override { - return VecSimIndexStatsInfo{ - .memory = this->getAllocationSize(), - .numberOfMarkedDeleted = 0, - }; - } - - virtual VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, - VecSimQueryParams *queryParams) const = 0; - VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, - VecSimQueryParams *queryParams, - VecSimQueryReply_Order order) const override { - auto results = rangeQuery(queryBlob, radius, queryParams); - sort_results(results, order); - return results; - } - - void log(const char *level, const char *fmt, ...) const { - if (VecSimIndexInterface::logCallback) { - // Format the message and call the callback - va_list args; - va_start(args, fmt); - int len = vsnprintf(NULL, 0, fmt, args); - va_end(args); - char *buf = new char[len + 1]; - va_start(args, fmt); - vsnprintf(buf, len + 1, fmt, args); - va_end(args); - logCallback(this->logCallbackCtx, level, buf); - delete[] buf; - } - } - - // Adds all common info to the info iterator, besides the block size (currently 8 fields). 
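The log() method above uses the classic two-pass vsnprintf idiom: a first call with a NULL buffer only measures the formatted length, and a second call writes into an exactly-sized buffer. A freestanding sketch of that idiom:

```cpp
// Two-pass vsnprintf: measure, then format. Note the va_list must be
// restarted (va_end + va_start) before the second pass.
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <string>

std::string formatMessage(const char *fmt, ...) {
    va_list args;
    va_start(args, fmt);
    int len = std::vsnprintf(nullptr, 0, fmt, args); // measuring pass
    va_end(args);

    std::string buf(len, '\0');
    va_start(args, fmt);
    std::vsnprintf(buf.data(), len + 1, fmt, args); // formatting pass (C++17)
    va_end(args);
    return buf;
}

int main() {
    std::string msg = formatMessage("index size: %zu", (size_t)42);
    std::printf("%s\n", msg.c_str());
    return 0;
}
```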
- void addCommonInfoToIterator(VecSimDebugInfoIterator *infoIterator, - const CommonInfo &info) const { - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::TYPE_STRING, - .fieldType = INFOFIELD_STRING, - .fieldValue = {FieldValue{.stringValue = VecSimType_ToString(info.basicInfo.type)}}}); - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::DIMENSION_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.basicInfo.dim}}}); - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::METRIC_STRING, - .fieldType = INFOFIELD_STRING, - .fieldValue = {FieldValue{ - .stringValue = VecSimMetric_ToString(info.basicInfo.metric)}}}); - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::IS_MULTI_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.basicInfo.isMulti}}}); - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::IS_DISK_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.basicInfo.isDisk}}}); - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::INDEX_SIZE_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.indexSize}}}); - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::INDEX_LABEL_COUNT_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.indexLabelCount}}}); - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::MEMORY_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.memory}}}); - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::SEARCH_MODE_STRING, - .fieldType = INFOFIELD_STRING, - .fieldValue = {FieldValue{.stringValue = VecSimSearchMode_ToString(info.lastMode)}}}); - } - - /** - * @brief Get the basic static info object - * - * @return basicInfo - */ - VecSimIndexBasicInfo getBasicInfo() const { - VecSimIndexBasicInfo info{ - .metric = this->metric, - .type = this->vecType, - .isMulti = this->isMulti, - .isDisk = this->isDisk, - .blockSize = this->blockSize, - .dim = this->dim, - }; - return info; - } -#ifdef BUILD_TESTS - void replacePPContainer(PreprocessorsContainerAbstract *newPPContainer) { - delete this->preprocessors; - this->preprocessors = newPPContainer; - } - - IndexComponents get_components() const { - return {.indexCalculator = this->indexCalculator, .preprocessors = this->preprocessors}; - } - - /** - * @brief Used for testing - get only the vector elements associated with a given label. - * This function copies only the vector(s) elements into the output vector, - * without any additional metadata that might be stored with the vector. - * - * Important: This method returns ONLY the vector elements, even if the stored vector contains - * additional metadata. For example, with int8_t/uint8_t vectors using cosine similarity, - * this method will NOT return the norm that is stored with the vector(s). - * - * If you need the complete data including any metadata, use getStoredVectorDataByLabel() - * instead. 
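The info-field plumbing above relies on a tagged union: each VecSim_InfoField carries a fieldType that tells the consumer which member of fieldValue to read. A self-contained miniature of that pattern (names here are illustrative, not the library's):

```cpp
// Miniature of the tagged-union info-field pattern: the `type` tag selects
// the valid union member.
#include <cstdint>
#include <cstdio>

enum FieldType { FIELD_STRING, FIELD_UINT64 };

struct InfoField {
    const char *name;
    FieldType type;
    union { // which member is valid is determined by `type`
        const char *stringValue;
        std::uint64_t uintegerValue;
    };
};

int main() {
    InfoField type_field{};
    type_field.name = "TYPE";
    type_field.type = FIELD_STRING;
    type_field.stringValue = "FLOAT32";

    InfoField dim_field{};
    dim_field.name = "DIMENSION";
    dim_field.type = FIELD_UINT64;
    dim_field.uintegerValue = 128;

    for (const InfoField &f : {type_field, dim_field}) {
        if (f.type == FIELD_STRING)
            std::printf("%s: %s\n", f.name, f.stringValue);
        else
            std::printf("%s: %llu\n", f.name,
                        (unsigned long long)f.uintegerValue);
    }
    return 0;
}
```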
- * - * @param label The label to retrieve vector(s) elements for - * @param vectors_output Empty vector to be filled with vector(s) - */ - virtual void getDataByLabel(labelType label, - std::vector> &vectors_output) const = 0; - - /** - * @brief Used for testing - get the complete raw data associated with a given label. - * This function returns the ENTIRE vector(s) data as stored in the index, including any - * additional metadata that might be stored alongside the vector elements. - * - * For example: - * - For int8_t/uint8_t vectors with cosine similarity, this includes the norm stored at the end - * - For other vector types or future implementations, this will include any additional data - * that might be stored with the vector - * - * Use this method when you need access to the complete vector data as it is stored internally. - * - * @param label The label to retrieve data for - * @return A vector containing the complete vector data (elements + metadata) for the given - * label - */ - virtual std::vector> getStoredVectorDataByLabel(labelType label) const = 0; -#endif - - /** - * Virtual functions that access the label lookup which is implemented in the derived classes - * Return all the labels in the index - this should be used for computing the number of distinct - * labels in a tiered index, and caller should hold the appropriate guards. - */ - virtual vecsim_stl::set getLabelsSet() const = 0; - -protected: - void runGC() override {} // Do nothing, relevant for tiered index only. - void acquireSharedLocks() override {} // Do nothing, relevant for tiered index only. - void releaseSharedLocks() override {} // Do nothing, relevant for tiered index only. -}; - -template -ProcessedBlobs VecSimIndexAbstract::preprocess(const void *blob) const { - return this->preprocessors->preprocess(blob, inputBlobSize); -} - -template -MemoryUtils::unique_blob -VecSimIndexAbstract::preprocessQuery(const void *queryBlob, - bool force_copy) const { - return this->preprocessors->preprocessQuery(queryBlob, inputBlobSize, force_copy); -} - -template -MemoryUtils::unique_blob -VecSimIndexAbstract::preprocessForStorage(const void *original_blob) const { - return this->preprocessors->preprocessForStorage(original_blob, inputBlobSize); -} - -template -void VecSimIndexAbstract::preprocessStorageInPlace(void *blob) const { - this->preprocessors->preprocessStorageInPlace(blob, inputBlobSize); -} diff --git a/src/VecSim/vec_sim_interface.cpp b/src/VecSim/vec_sim_interface.cpp deleted file mode 100644 index 2359f10fd..000000000 --- a/src/VecSim/vec_sim_interface.cpp +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/vec_sim_interface.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// Global variable to store the current log context -struct TestNameLogContext { - std::string test_name = ""; - std::string test_type = ""; -}; - -static TestNameLogContext test_name_log_context; - -extern "C" void VecSim_SetTestLogContext(const char *test_name, const char *test_type) { - test_name_log_context.test_name = std::string(test_name); - test_name_log_context.test_type = std::string(test_type); -} - -/* Example: - * createLogString("ERROR", "Failed to open file"); - * → "[2025-06-11 09:13:47.237] [ERROR] Failed to open file" - */ -static std::string createLogString(const char *level, const char *message) { - // Get current timestamp - auto now = std::chrono::system_clock::now(); - auto time_t = std::chrono::system_clock::to_time_t(now); - auto ms = std::chrono::duration_cast(now.time_since_epoch()) % 1000; - - // Format timestamp - std::tm *tm_info = std::localtime(&time_t); - char timestamp[100]; - std::strftime(timestamp, sizeof(timestamp), "%Y-%m-%d %H:%M:%S", tm_info); - - // Format log entry - std::ostringstream oss; - oss << "[" << timestamp << "." << std::setw(3) << std::setfill('0') << ms.count() << "] [" - << level << "] " << message; - return oss.str(); -} - -// writes the logs to a file -void Vecsim_Log(void *ctx, const char *level, const char *message) { - std::string log_entry = createLogString(level, message); - // If test name context is not provided, write it to stdout - if (test_name_log_context.test_name.empty() || test_name_log_context.test_type.empty()) { - std::cout << log_entry << std::endl; - return; - } - - std::ostringstream path_stream; - path_stream << "logs/tests/" << test_name_log_context.test_type << "/" - << test_name_log_context.test_name << ".log"; - - // Write to file - std::ofstream log_file(path_stream.str(), std::ios::app); - if (log_file.is_open()) { - log_file << log_entry << std::endl; - log_file.close(); - } -} - -timeoutCallbackFunction VecSimIndexInterface::timeoutCallback = [](void *ctx) { return 0; }; -logCallbackFunction VecSimIndexInterface::logCallback = Vecsim_Log; -VecSimWriteMode VecSimIndexInterface::asyncWriteMode = VecSim_WriteAsync; diff --git a/src/VecSim/vec_sim_interface.h b/src/VecSim/vec_sim_interface.h deleted file mode 100644 index 0b627bc57..000000000 --- a/src/VecSim/vec_sim_interface.h +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once -#include "vec_sim_common.h" -#include "query_results.h" -#include "VecSim/memory/vecsim_base.h" -#include "info_iterator_struct.h" - -#include -#include -#include -/** - * @brief Abstract C++ class for vector index, delete and lookup - * - */ -struct VecSimIndexInterface : public VecsimBaseObject { - -public: - /** - * @brief Construct a new Vec Sim Index object - * - */ - VecSimIndexInterface(std::shared_ptr allocator) - : VecsimBaseObject(allocator) {} - - /** - * @brief Destroy the Vec Sim Index object - * - */ - virtual ~VecSimIndexInterface() = default; - - /** - * @brief Add a vector blob and its id to the index. - * - * @param blob binary representation of the vector. Blob size should match the index data type - * and dimension. 
The blob will be copied and processed by the index. - * @param label the label of the added vector. - * @return the number of new vectors inserted (1 for new insertion, 0 for override). - */ - virtual int addVector(const void *blob, labelType label) = 0; - - /** - * @brief Remove a vector from an index. - * - * @param label the label of the vector to remove - * @return the number of vectors deleted - */ - virtual int deleteVector(labelType label) = 0; - - /** - * @brief Calculate the distance of a vector from an index to a vector. - * @param index the index from which the first vector is located, and that defines the distance - * metric. - * @param id the id of the vector in the index. - * @param blob binary representation of the second vector. Blob size should match the index data - * type and dimension, and pre-normalized if needed. - * @return The distance (according to the index's distance metric) between `blob` and the vector - * with id `id`. - */ - virtual double getDistanceFrom_Unsafe(labelType id, const void *blob) const = 0; - - /** - * @brief Return the number of vectors in the index (including ones that are marked as deleted). - * - * @return index size. - */ - virtual size_t indexSize() const = 0; - - /** - * @brief Return the index capacity, so we know if resize is required for adding new vectors. - * - * @return index capacity. - */ - virtual size_t indexCapacity() const = 0; - - /** - * @brief Return the number of unique labels in the index (which are not deleted). - * !!! Note: for tiered index, this should only be called in debug mode, as it may require - * locking the indexes and going over the labels sets, which is time-consuming. !!! - * - * @return index label count. - */ - virtual size_t indexLabelCount() const = 0; - - /** - * @brief Search for the k closest vectors to a given vector in the index. - * @param queryBlob binary representation of the query vector. Blob size should match the index - * data type and dimension. The index is responsible to process the query vector. - * @param k the number of "nearest neighbors" to return (upper bound). - * @param queryParams run time params for the search, which are algorithm-specific. - * @return An opaque object the represents a list of results. User can access the id and score - * (which is the distance according to the index metric) of every result through - * VecSimQueryReply_Iterator. - */ - virtual VecSimQueryReply *topKQuery(const void *queryBlob, size_t k, - VecSimQueryParams *queryParams) const = 0; - - /** - * @brief Search for the vectors that are in a given range in the index with respect to a given - * vector. The results can be ordered by their score or id. - * @param queryBlob binary representation of the query vector. Blob size should match the index - * data type and dimension. The index is responsible to process the query vector. - * @param radius the radius around the query vector to search vectors within it. - * @param queryParams run time params for the search, which are algorithm-specific. - * @param order the criterion to sort the results list by it. Options are by score, or by id. - * @return An opaque object the represents a list of results. User can access the id and score - * (which is the distance according to the index metric) of every result through - * VecSimQueryReply_Iterator. - */ - virtual VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, - VecSimQueryParams *queryParams, - VecSimQueryReply_Order order) const = 0; - - /** - * @brief Return index information. 
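A hedged usage sketch for the query interface documented above, built from the reply-iterator functions that appear later in this diff (in wrap_results of bindings.cpp). It assumes the VecSim headers are available and that passing NULL query params selects defaults; index construction is elided.

```cpp
// Iterate a top-K reply via the C API. Assumes the VecSim library headers;
// nullptr for the query params (defaults) is an assumption of this sketch.
#include "VecSim/vec_sim.h"
#include <cstdio>

void printTopK(VecSimIndex *index, const void *queryBlob, size_t k) {
    VecSimQueryReply *reply =
        VecSimIndex_TopKQuery(index, queryBlob, k, nullptr, BY_SCORE);
    VecSimQueryReply_Iterator *it = VecSimQueryReply_GetIterator(reply);
    while (VecSimQueryReply_IteratorHasNext(it)) {
        VecSimQueryResult *item = VecSimQueryReply_IteratorNext(it);
        // The score is the distance under the index's metric, per the docs above.
        std::printf("label=%ld score=%f\n", (long)VecSimQueryResult_GetId(item),
                    VecSimQueryResult_GetScore(item));
    }
    VecSimQueryReply_IteratorFree(it);
    VecSimQueryReply_Free(reply);
}
```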
- * - * @return Index general and specific meta-data. Note that this operation might - * be time consuming (specially for tiered index where computing label count required - * locking and going over the labels sets). So this should be used carefully. - */ - virtual VecSimIndexDebugInfo debugInfo() const = 0; - - /** - * @brief Return index static information. - * - * @return Index general and specific meta-data (for quick and lock-less data retrieval) - */ - virtual VecSimIndexBasicInfo basicInfo() const = 0; - - /** - * @brief Return index statistic information. - * - * @return Index general and specific statistic data (for quick and lock-less retrieval) - */ - virtual VecSimIndexStatsInfo statisticInfo() const = 0; - - /** - * @brief Returns an index information in an iterable structure. - * - * @return VecSimDebugInfoIterator Index general and specific meta-data. - */ - virtual VecSimDebugInfoIterator *debugInfoIterator() const = 0; - - /** - * @brief A function to be implemented by the inheriting index and called by rangeQuery. - * @param queryBlob binary representation of the query vector. Blob size should match the index - * data type and dimension. The index is responsible to process the query vector. - */ - virtual VecSimBatchIterator *newBatchIterator(const void *queryBlob, - VecSimQueryParams *queryParams) const = 0; - /** - * @brief Return True if heuristics says that it is better to use ad-hoc brute-force - * search over the index instead of using batch iterator. - * - * @param subsetSize the estimated number of vectors in the index that pass the filter - * (that is, query results can be only from a subset of vector of this size). - * - * @param k the number of required results to return from the query. - * - * @param initial_check flag to indicate if this check is performed for the first time (upon - * creating the hybrid iterator), or after running batches. - */ - - virtual bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) const = 0; - - /** - * @brief Set the latest search mode in the index data (for info/debugging). - * @param mode The search mode. - */ - virtual inline void setLastSearchMode(VecSearchMode mode) = 0; - - /** - * @brief Run async garbage collection for tiered async index. - */ - virtual void runGC() = 0; - - /** - * @brief Acquire the locks for shared ownership in tiered async index. - */ - virtual void acquireSharedLocks() = 0; - - /** - * @brief Release the locks for shared ownership in tiered async index. - */ - virtual void releaseSharedLocks() = 0; - - /** - * @brief Allow 3rd party timeout callback to be used for limiting runtime of a query. - * - * @param callback timeoutCallbackFunction function. should get void* and return int. - */ - static timeoutCallbackFunction timeoutCallback; - inline static void setTimeoutCallbackFunction(timeoutCallbackFunction callback) { - VecSimIndexInterface::timeoutCallback = callback; - } - - static logCallbackFunction logCallback; - inline static void setLogCallbackFunction(logCallbackFunction callback) { - VecSimIndexInterface::logCallback = callback; - } - - /** - * @brief Allow 3rd party to set the write mode for tiered index - async insert/delete using - * background jobs, or insert/delete inplace. - * - * @param mode VecSimWriteMode the mode in which we add/remove vectors (async or in-place). 
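A sketch of wiring the timeout hook described above. The callback shape is taken from the default installed in vec_sim_interface.cpp earlier in this diff (a lambda returning 0); treating a non-zero return as "abort the query", and the exact meaning of ctx, are assumptions of this sketch.

```cpp
// Install a cancellation-flag timeout callback. A capture-less lambda
// converts to the required function pointer, as the default callback does.
#include "VecSim/vec_sim_interface.h"

void installCancelFlag() {
    VecSimIndexInterface::setTimeoutCallbackFunction([](void *ctx) -> int {
        // ctx is assumed to be a caller-supplied bool* cancellation flag;
        // non-zero is assumed to mean "timed out / abort".
        return (ctx != nullptr && *static_cast<bool *>(ctx)) ? 1 : 0;
    });
}
```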
- */ - static VecSimWriteMode asyncWriteMode; - inline static void setWriteMode(VecSimWriteMode mode) { - VecSimIndexInterface::asyncWriteMode = mode; - } -#ifdef BUILD_TESTS - virtual void fitMemory() = 0; - /** - * @brief get the capacity of the meta data containers. - * - * @return The capacity of the meta data containers in number of elements. - * The value returned from this function may differ from the indexCapacity() function. For - * example, in HNSW, the capacity of the meta data containers is the capacity of the labels - * lookup table, while the capacity of the data containers is the capacity of the vectors - * container. - */ - virtual size_t indexMetaDataCapacity() const = 0; -#endif -}; diff --git a/src/VecSim/vec_sim_tiered_index.h b/src/VecSim/vec_sim_tiered_index.h deleted file mode 100644 index 0a36b1f86..000000000 --- a/src/VecSim/vec_sim_tiered_index.h +++ /dev/null @@ -1,434 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ - -#pragma once - -#include "vec_sim_index.h" -#include "algorithms/brute_force/brute_force.h" -#include "VecSim/batch_iterator.h" -#include "VecSim/tombstone_interface.h" -#include "VecSim/utils/query_result_utils.h" -#include "VecSim/utils/alignment.h" - -#include - -#define TIERED_LOG this->backendIndex->log - -/** - * Definition of generic job structure for asynchronous tiered index. - */ -struct AsyncJob : public VecsimBaseObject { - JobType jobType; - JobCallback Execute; // A callback that receives a job as its input and executes the job. - VecSimIndex *index; - bool isValid; - - AsyncJob(std::shared_ptr allocator, JobType type, JobCallback callback, - VecSimIndex *index_ref) - : VecsimBaseObject(allocator), jobType(type), Execute(callback), index(index_ref), - isValid(true) {} -}; - -// All read operations (including KNN, range, batch iterators and get-distance-from) are guaranteed -// to consider all vectors that were added to the index before the query was submitted. The results -// may include vectors that were added after the query was submitted, with no guarantees. -template -class VecSimTieredIndex : public VecSimIndexInterface { -protected: - VecSimIndexAbstract *backendIndex; - BruteForceIndex *frontendIndex; - - void *jobQueue; - void *jobQueueCtx; // External context to be sent to the submit callback. - SubmitCB SubmitJobsToQueue; - - mutable std::shared_mutex flatIndexGuard; - mutable std::shared_mutex mainIndexGuard; - void lockMainIndexGuard() const { - mainIndexGuard.lock(); -#ifdef BUILD_TESTS - mainIndexGuard_write_lock_count++; -#endif - } - - void unlockMainIndexGuard() const { mainIndexGuard.unlock(); } -#ifdef BUILD_TESTS - mutable std::atomic_int mainIndexGuard_write_lock_count = 0; -#endif - size_t flatBufferLimit; - - void submitSingleJob(AsyncJob *job) { - this->SubmitJobsToQueue(this->jobQueue, this->jobQueueCtx, &job, &job->Execute, 1); - } - - void submitJobs(vecsim_stl::vector &jobs) { - vecsim_stl::vector callbacks(jobs.size(), this->allocator); - for (size_t i = 0; i < jobs.size(); i++) { - callbacks[i] = jobs[i]->Execute; - } - this->SubmitJobsToQueue(this->jobQueue, this->jobQueueCtx, jobs.data(), callbacks.data(), - jobs.size()); - } - - /** - * @brief Return the union of unique labels in both index tiers (which are not deleted). 
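The AsyncJob flow above (a job object bundling a callback plus context, handed to an external queue via submitSingleJob) can be sketched with local stand-in types. The callback signature and the job-owns-itself deletion are assumptions here, not taken from this diff:

```cpp
// Stand-in sketch of the async-job pattern: a base Job carries its Execute
// callback; a static trampoline downcasts, runs, and frees the job.
#include <cstdio>
#include <queue>

struct Job;
using JobCallback = void (*)(Job *); // assumed signature

struct Job {
    JobCallback Execute;
    bool isValid = true;
};

struct InsertJob : Job {
    int label;
    explicit InsertJob(int l) : label(l) { Execute = ExecuteCb; }
    static void ExecuteCb(Job *job) {
        auto *self = static_cast<InsertJob *>(job);
        if (self->isValid)
            std::printf("inserting label %d\n", self->label);
        delete self; // assumption: jobs own themselves once submitted
    }
};

int main() {
    // A trivial queue-plus-worker stand-in for the external job queue.
    std::queue<Job *> q;
    q.push(new InsertJob(7));
    while (!q.empty()) {
        Job *j = q.front();
        q.pop();
        j->Execute(j);
    }
    return 0;
}
```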
- * This is a debug-only method for tiered indexes that computes the union of labels - * from both frontend and backend indexes. It assumes that caller holds the appropriate - * locks and it is time-consuming. - * !!! Note: this should only be called in debug mode for tiered indexes !!! - * - * @return index label count for debug purposes. - */ - vecsim_stl::vector computeUnifiedIndexLabelsSetUnsafe() const { - auto [flat_labels, backend_labels] = - std::make_pair(this->frontendIndex->getLabelsSet(), this->backendIndex->getLabelsSet()); - - // Compute the union of the two sets. - vecsim_stl::vector labels_union(this->allocator); - labels_union.reserve(flat_labels.size() + backend_labels.size()); - std::set_union(flat_labels.begin(), flat_labels.end(), backend_labels.begin(), - backend_labels.end(), std::back_inserter(labels_union)); - return labels_union; - } - -#ifdef BUILD_TESTS -public: - int getMainIndexGuardWriteLockCount() const { return mainIndexGuard_write_lock_count; } -#endif - // For both topK and range, Use withSet=false if you can guarantee that shared ids between the - // two lists will also have identical scores. In this case, any duplicates will naturally align - // at the front of both lists during the merge, so they can be removed without explicitly - // tracking seen ids — enabling a more efficient merge. - template - VecSimQueryReply *topKQueryImp(const void *queryBlob, size_t k, - VecSimQueryParams *queryParams) const; - - template - VecSimQueryReply *rangeQueryImp(const void *queryBlob, double radius, - VecSimQueryParams *queryParams, - VecSimQueryReply_Order order) const; - -public: - VecSimTieredIndex(VecSimIndexAbstract *backendIndex_, - BruteForceIndex *frontendIndex_, - TieredIndexParams tieredParams, std::shared_ptr allocator) - : VecSimIndexInterface(allocator), backendIndex(backendIndex_), - frontendIndex(frontendIndex_), jobQueue(tieredParams.jobQueue), - jobQueueCtx(tieredParams.jobQueueCtx), SubmitJobsToQueue(tieredParams.submitCb), - flatBufferLimit(tieredParams.flatBufferLimit) {} - - virtual ~VecSimTieredIndex() { - VecSimIndex_Free(backendIndex); - VecSimIndex_Free(frontendIndex); - } - - VecSimQueryReply *topKQuery(const void *queryBlob, size_t k, - VecSimQueryParams *queryParams) const override; - - VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, - VecSimQueryParams *queryParams, - VecSimQueryReply_Order order) const override; - - virtual inline uint64_t getAllocationSize() const override { - return this->allocator->getAllocationSize() + this->backendIndex->getAllocationSize() + - this->frontendIndex->getAllocationSize(); - } - virtual size_t getNumMarkedDeleted() const = 0; - size_t indexLabelCount() const override; - VecSimIndexStatsInfo statisticInfo() const override; - virtual VecSimIndexDebugInfo debugInfo() const override; - virtual VecSimDebugInfoIterator *debugInfoIterator() const override; - - bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) const override { - // For now, decide according to the bigger index. - return this->backendIndex->indexSize() > this->frontendIndex->indexSize() - ? this->backendIndex->preferAdHocSearch(subsetSize, k, initial_check) - : this->frontendIndex->preferAdHocSearch(subsetSize, k, initial_check); - } - - // Return the current state of the global write mode (async/in-place). 
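computeUnifiedIndexLabelsSetUnsafe() above leans on std::set_union, which merges two sorted ranges and emits shared labels once; std::set iteration is sorted, which is what makes this valid. A standalone check of that behavior:

```cpp
// Union of two sorted label sets: shared labels are counted once.
#include <algorithm>
#include <cstdio>
#include <iterator>
#include <set>
#include <vector>

int main() {
    std::set<size_t> flat_labels = {1, 3, 5};
    std::set<size_t> backend_labels = {3, 4, 5, 8};

    std::vector<size_t> labels_union;
    labels_union.reserve(flat_labels.size() + backend_labels.size());
    std::set_union(flat_labels.begin(), flat_labels.end(),
                   backend_labels.begin(), backend_labels.end(),
                   std::back_inserter(labels_union));

    // Shared labels (3 and 5) appear once: the union has 5 entries.
    std::printf("unified label count: %zu\n", labels_union.size());
    return 0;
}
```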
- static VecSimWriteMode getWriteMode() { return VecSimIndexInterface::asyncWriteMode; } - -#ifdef BUILD_TESTS - inline BruteForceIndex *getFlatBufferIndex() { return this->frontendIndex; } - inline size_t getFlatBufferLimit() { return this->flatBufferLimit; } - - virtual void fitMemory() override { - this->backendIndex->fitMemory(); - this->frontendIndex->fitMemory(); - } -#endif -}; - -template -template -VecSimQueryReply * -VecSimTieredIndex::topKQueryImp(const void *queryBlob, size_t k, - VecSimQueryParams *queryParams) const { - this->flatIndexGuard.lock_shared(); - - // If the flat buffer is empty, we can simply query the main index. - if (this->frontendIndex->indexSize() == 0) { - // Release the flat lock and acquire the main lock. - this->flatIndexGuard.unlock_shared(); - - // Simply query the main index and return the results while holding the lock. - auto processed_query_ptr = this->frontendIndex->preprocessQuery(queryBlob); - const void *processed_query = processed_query_ptr.get(); - this->mainIndexGuard.lock_shared(); - auto res = this->backendIndex->topKQuery(processed_query, k, queryParams); - this->mainIndexGuard.unlock_shared(); - - return res; - } else { - // No luck... first query the flat buffer and release the lock. - // The query blob is already processed according to the frontend index. - auto flat_results = this->frontendIndex->topKQuery(queryBlob, k, queryParams); - this->flatIndexGuard.unlock_shared(); - - // If the query failed (currently only on timeout), return the error code. - if (flat_results->code != VecSim_QueryReply_OK) { - assert(flat_results->results.empty()); - return flat_results; - } - - auto processed_query_ptr = this->frontendIndex->preprocessQuery(queryBlob); - const void *processed_query = processed_query_ptr.get(); - // Lock the main index and query it. - this->mainIndexGuard.lock_shared(); - auto main_results = this->backendIndex->topKQuery(processed_query, k, queryParams); - this->mainIndexGuard.unlock_shared(); - - // If the query failed (currently only on timeout), return the error code. - if (main_results->code != VecSim_QueryReply_OK) { - // Free the flat results. - VecSimQueryReply_Free(flat_results); - - assert(main_results->results.empty()); - return main_results; - } - - return merge_result_lists(main_results, flat_results, k); - } -} -template -VecSimQueryReply * -VecSimTieredIndex::topKQuery(const void *queryBlob, size_t k, - VecSimQueryParams *queryParams) const { - if (this->backendIndex->isMultiValue()) { - return this->topKQueryImp(queryBlob, k, queryParams); // Multi-value index - } else { - // Calling with withSet=false for optimized performance, assuming that shared IDs across - // lists also have identical scores — in which case duplicates are implicitly avoided by the - // merge logic. - return this->topKQueryImp(queryBlob, k, queryParams); - } -} - -template -VecSimQueryReply * -VecSimTieredIndex::rangeQuery(const void *queryBlob, double radius, - VecSimQueryParams *queryParams, - VecSimQueryReply_Order order) const { - if (this->backendIndex->isMultiValue()) { - return this->rangeQueryImp(queryBlob, radius, queryParams, - order); // Multi-value index - } else { - // Calling with withSet=false for optimized performance, assuming that shared IDs across - // lists also have identical scores — in which case duplicates are implicitly avoided by the - // merge logic. 
- return this->rangeQueryImp(queryBlob, radius, queryParams, order); - } -} - -template -template -VecSimQueryReply * -VecSimTieredIndex::rangeQueryImp(const void *queryBlob, double radius, - VecSimQueryParams *queryParams, - VecSimQueryReply_Order order) const { - this->flatIndexGuard.lock_shared(); - - // If the flat buffer is empty, we can simply query the main index. - if (this->frontendIndex->indexSize() == 0) { - // Release the flat lock and acquire the main lock. - this->flatIndexGuard.unlock_shared(); - - auto processed_query_ptr = this->frontendIndex->preprocessQuery(queryBlob); - const void *processed_query = processed_query_ptr.get(); - // Simply query the main index and return the results while holding the lock. - this->mainIndexGuard.lock_shared(); - auto res = this->backendIndex->rangeQuery(processed_query, radius, queryParams); - this->mainIndexGuard.unlock_shared(); - - // We could have passed the order to the main index, but we can sort them here after - // unlocking it instead. - sort_results(res, order); - return res; - } else { - // No luck... first query the flat buffer and release the lock. - // The query blob is already processed according to the frontend index. - auto flat_results = this->frontendIndex->rangeQuery(queryBlob, radius, queryParams); - this->flatIndexGuard.unlock_shared(); - - // If the query failed (currently only on timeout), return the error code and the partial - // results. - if (flat_results->code != VecSim_QueryReply_OK) { - return flat_results; - } - - auto processed_query_ptr = this->frontendIndex->preprocessQuery(queryBlob); - const void *processed_query = processed_query_ptr.get(); - // Lock the main index and query it. - this->mainIndexGuard.lock_shared(); - auto main_results = this->backendIndex->rangeQuery(processed_query, radius, queryParams); - this->mainIndexGuard.unlock_shared(); - - // Merge the results and return, avoiding duplicates. - // At this point, the return code of the FLAT index is OK, and the return code of the MAIN - // index is either OK or TIMEOUT. Make sure to return the return code of the MAIN index. - if (BY_SCORE == order) { - sort_results_by_score_then_id(main_results); - sort_results_by_score_then_id(flat_results); - - // Keep the return code of the main index. - auto code = main_results->code; - - // Merge the sorted results with no limit (all the results are valid). - VecSimQueryReply *ret = merge_result_lists(main_results, flat_results, -1); - // Restore the return code and return. - ret->code = code; - return ret; - - } else { // BY_ID - // Notice that we don't modify the return code of the main index in any step. - concat_results(main_results, flat_results); - filter_results_by_id(main_results); - return main_results; - } - } -} - -template -VecSimIndexStatsInfo VecSimTieredIndex::statisticInfo() const { - auto stats = VecSimIndexStatsInfo{ - .memory = this->getAllocationSize(), - .numberOfMarkedDeleted = this->getNumMarkedDeleted(), - }; - - return stats; -} - -template -size_t VecSimTieredIndex::indexLabelCount() const { - // This is a debug-only method for tiered indexes that computes the union of labels - // from both frontend and backend indexes. It requires locking and is time-consuming. - // !!! Note: this should only be called in debug mode for tiered indexes !!! 
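Both query paths above follow the same locking discipline: take the frontend (flat) shared lock, query, release; then take the backend shared lock, query, release; and only then merge, so the two reader locks are never held at once. A stand-in skeleton of that shape (the merge here is a simplified sort/dedup, not the score-aware merge above):

```cpp
// Two-tier query skeleton: each tier is read under its own shared lock,
// and the locks are never held simultaneously.
#include <algorithm>
#include <shared_mutex>
#include <vector>

struct Tiered {
    std::shared_mutex flatGuard, mainGuard;
    std::vector<int> flat, main; // stand-ins for the two indexes

    std::vector<int> query() {
        std::vector<int> flat_results;
        {
            std::shared_lock lock(flatGuard); // frontend first, then release
            flat_results = flat;
        }
        std::vector<int> main_results;
        {
            std::shared_lock lock(mainGuard); // backend second
            main_results = main;
        }
        // Simplified merge: concatenate, sort, drop duplicate ids.
        main_results.insert(main_results.end(), flat_results.begin(),
                            flat_results.end());
        std::sort(main_results.begin(), main_results.end());
        main_results.erase(
            std::unique(main_results.begin(), main_results.end()),
            main_results.end());
        return main_results;
    }
};

int main() {
    Tiered t;
    t.flat = {1, 2};
    t.main = {2, 3};
    return t.query().size() == 3 ? 0 : 1;
}
```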
- std::shared_lock flat_lock(this->flatIndexGuard); - std::shared_lock main_lock(this->mainIndexGuard); - return computeUnifiedIndexLabelsSetUnsafe().size(); -} - -template -VecSimIndexDebugInfo VecSimTieredIndex::debugInfo() const { - VecSimIndexDebugInfo info; - this->flatIndexGuard.lock_shared(); - this->mainIndexGuard.lock_shared(); - - VecSimIndexDebugInfo frontendInfo = this->frontendIndex->debugInfo(); - VecSimIndexDebugInfo backendInfo = this->backendIndex->debugInfo(); - - info.commonInfo.indexLabelCount = this->computeUnifiedIndexLabelsSetUnsafe().size(); - - this->flatIndexGuard.unlock_shared(); - this->mainIndexGuard.unlock_shared(); - - info.commonInfo.indexSize = - frontendInfo.commonInfo.indexSize + backendInfo.commonInfo.indexSize; - info.commonInfo.memory = this->getAllocationSize(); - info.commonInfo.lastMode = backendInfo.commonInfo.lastMode; - - VecSimIndexBasicInfo basic_info{ - .algo = backendInfo.commonInfo.basicInfo.algo, - .metric = backendInfo.commonInfo.basicInfo.metric, - .type = backendInfo.commonInfo.basicInfo.type, - .isMulti = this->backendIndex->isMultiValue(), - .isTiered = true, - .isDisk = backendInfo.commonInfo.basicInfo.isDisk, - .blockSize = backendInfo.commonInfo.basicInfo.blockSize, - .dim = backendInfo.commonInfo.basicInfo.dim, - }; - info.commonInfo.basicInfo = basic_info; - - // NOTE: backgroundIndexing needs to be set by the backend index. - info.tieredInfo.backgroundIndexing = VecSimBool_UNSET; - - switch (backendInfo.commonInfo.basicInfo.algo) { - case VecSimAlgo_HNSWLIB: - info.tieredInfo.backendInfo.hnswInfo = backendInfo.hnswInfo; - break; - case VecSimAlgo_SVS: - info.tieredInfo.backendInfo.svsInfo = backendInfo.svsInfo; - break; - case VecSimAlgo_BF: - case VecSimAlgo_TIERED: - assert(false && "Invalid backend algorithm"); - } - - info.tieredInfo.backendCommonInfo = backendInfo.commonInfo; - // For now, this is hard coded to FLAT - info.tieredInfo.frontendCommonInfo = frontendInfo.commonInfo; - info.tieredInfo.bfInfo = frontendInfo.bfInfo; - - info.tieredInfo.management_layer_memory = this->allocator->getAllocationSize(); - info.tieredInfo.bufferLimit = this->flatBufferLimit; - return info; -} - -template -VecSimDebugInfoIterator *VecSimTieredIndex::debugInfoIterator() const { - VecSimIndexDebugInfo info = this->debugInfo(); - // For readability. Update this number when needed. - size_t numberOfInfoFields = 14; - auto *infoIterator = new VecSimDebugInfoIterator(numberOfInfoFields, this->allocator); - - // Set tiered explicitly as algo name for root iterator. 
- infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::ALGORITHM_STRING, - .fieldType = INFOFIELD_STRING, - .fieldValue = {FieldValue{.stringValue = VecSimCommonStrings::TIERED_STRING}}}); - - this->backendIndex->addCommonInfoToIterator(infoIterator, info.commonInfo); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::TIERED_MANAGEMENT_MEMORY_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.tieredInfo.management_layer_memory}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::TIERED_BACKGROUND_INDEXING_STRING, - .fieldType = INFOFIELD_INT64, - .fieldValue = {FieldValue{.integerValue = info.tieredInfo.backgroundIndexing}}}); - - infoIterator->addInfoField( - VecSim_InfoField{.fieldName = VecSimCommonStrings::TIERED_BUFFER_LIMIT_STRING, - .fieldType = INFOFIELD_UINT64, - .fieldValue = {FieldValue{.uintegerValue = info.tieredInfo.bufferLimit}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::FRONTEND_INDEX_STRING, - .fieldType = INFOFIELD_ITERATOR, - .fieldValue = {FieldValue{.iteratorValue = this->frontendIndex->debugInfoIterator()}}}); - - infoIterator->addInfoField(VecSim_InfoField{ - .fieldName = VecSimCommonStrings::BACKEND_INDEX_STRING, - .fieldType = INFOFIELD_ITERATOR, - .fieldValue = {FieldValue{.iteratorValue = this->backendIndex->debugInfoIterator()}}}); - return infoIterator; -}; diff --git a/src/VecSim/version.h b/src/VecSim/version.h deleted file mode 100644 index c87fbb25e..000000000 --- a/src/VecSim/version.h +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). - */ -#pragma once - -#define VSS_VERSION_MAJOR 99 -#define VSS_VERSION_MINOR 99 -#define VSS_VERSION_PATCH 99 diff --git a/src/python_bindings/BF_iterator_demo.ipynb b/src/python_bindings/BF_iterator_demo.ipynb deleted file mode 100644 index 596393adb..000000000 --- a/src/python_bindings/BF_iterator_demo.ipynb +++ /dev/null @@ -1,115 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "86973d44", - "metadata": {}, - "outputs": [], - "source": [ - "from VecSim import *\n", - "import numpy as np\n", - "\n", - "dim = 128\n", - "num_elements = 1000\n", - "\n", - "# Create a brute force index for vectors of 128 floats. 
Use 'L2' as the distance metric\n", - "bf_params = BFParams()\n", - "bf_params.blockSize = num_elements\n", - "bf_index = BFIndex(bf_params, VecSimType_FLOAT32, dim, VecSimMetric_L2)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "03ff62d4", - "metadata": {}, - "outputs": [], - "source": [ - "# Add 1M random vectors to the index\n", - "data = np.float32(np.random.random((num_elements, dim)))\n", - "vectors = []\n", - "\n", - "for i, vector in enumerate(data):\n", - " bf_index.add_vector(vector, i)\n", - " vectors.append((i, vector))\n", - "\n", - "print(f'Index size: {bf_index.index_size()}')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bc831b57", - "metadata": {}, - "outputs": [], - "source": [ - "# Create a random query vector\n", - "query_data = np.float32(np.random.random((1, dim)))\n", - "\n", - "# Create batch iterator for this query vector\n", - "batch_iterator = bf_index.create_batch_iterator(query_data)\n", - "returned_results_num = 0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a3c3fbee", - "metadata": {}, - "outputs": [], - "source": [ - "# Get the next best results\n", - "batch_size = 100\n", - "labels, distances = batch_iterator.get_next_results(batch_size, BY_SCORE)\n", - "\n", - "print (f'Results in rank {returned_results_num}-{returned_results_num+len(labels[0])} are: \\n')\n", - "print (f'labels: {labels}')\n", - "print (f'scores: {distances}')\n", - "\n", - "returned_results_num += len(labels[0])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7925ecf6", - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "\n", - "# Run batches until depleted\n", - "batch_size = 15\n", - "start = time.time()\n", - "while(batch_iterator.has_next()):\n", - " labels, distances = batch_iterator.get_next_results(batch_size, BY_ID)\n", - " returned_results_num += len(labels[0])\n", - "\n", - "print(f'Total results returned: {returned_results_num}\\n')\n", - "print(f'Total search time: {time.time() - start}')" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python [conda env:vsim] *", - "language": "python", - "name": "conda-env-vsim-py" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/python_bindings/CMakeLists.txt b/src/python_bindings/CMakeLists.txt deleted file mode 100644 index 48d4a71fa..000000000 --- a/src/python_bindings/CMakeLists.txt +++ /dev/null @@ -1,32 +0,0 @@ -cmake_minimum_required(VERSION 3.25) -project(VecSim LANGUAGES CXX) -set(CMAKE_CXX_STANDARD 20) - -# build bindings with -DBUILD_TESTS flag -option(VECSIM_BUILD_TESTS "Build tests" ON) -ADD_DEFINITIONS(-DBUILD_TESTS) -get_filename_component(root ${CMAKE_CURRENT_LIST_DIR}/../.. 
ABSOLUTE) - -include(FetchContent) -FetchContent_Declare( - pybind11 - GIT_REPOSITORY https://github.com/pybind/pybind11 - GIT_TAG v2.10.1 -) -FetchContent_GetProperties(pybind11) - -if(NOT pybind11_POPULATED) - FetchContent_Populate(pybind11) - add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR}) -endif() - -include(${root}/cmake/svs.cmake) -add_subdirectory(${root}/src/VecSim VectorSimilarity) - -include_directories(${root}/src ${root}/tests/utils) - -pybind11_add_module(VecSim ../../tests/utils/mock_thread_pool.cpp bindings.cpp) - -target_link_libraries(VecSim PRIVATE VectorSimilarity) - -add_dependencies(VecSim VectorSimilarity) diff --git a/src/python_bindings/HNSW_iterator_demo.ipynb b/src/python_bindings/HNSW_iterator_demo.ipynb deleted file mode 100644 index 5856f91dd..000000000 --- a/src/python_bindings/HNSW_iterator_demo.ipynb +++ /dev/null @@ -1,176 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "86973d44", - "metadata": {}, - "outputs": [], - "source": [ - "from VecSim import *\n", - "import numpy as np\n", - "\n", - "dim = 100\n", - "num_elements = 100000\n", - "M = 32\n", - "efConstruction = 200\n", - "efRuntime = 200\n", - "\n", - "# Create a hnsw index for vectors of 100 floats. Use 'L2' as the distance metric\n", - "hnswparams = HNSWParams()\n", - "hnswparams.M = M\n", - "hnswparams.efConstruction = efConstruction\n", - "hnswparams.efRuntime = efRuntime\n", - "hnswparams.dim = dim\n", - "hnswparams.type = VecSimType_FLOAT32\n", - "hnswparams.metric = VecSimMetric_L2\n", - "\n", - "hnsw_index = HNSWIndex(hnswparams)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "03ff62d4", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# Add 100k random vectors and insert then to the index\n", - "data = np.float32(np.random.random((num_elements, dim)))\n", - "vectors = []\n", - "\n", - "for i, vector in enumerate(data):\n", - " hnsw_index.add_vector(vector, i)\n", - " vectors.append((i, vector))\n", - "\n", - "print(f'Index size: {hnsw_index.index_size()}')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bc831b57", - "metadata": {}, - "outputs": [], - "source": [ - "# Create a random query vector\n", - "hnsw_index.set_ef(300)\n", - "query_data = np.float32(np.random.random((1, dim)))\n", - "\n", - "# Create batch iterator for this query vector\n", - "batch_iterator = hnsw_index.create_batch_iterator(query_data)\n", - "returned_results_num = 0\n", - "accumulated_labels = []\n", - "total_time = 0\n", - "\n", - "from scipy import spatial\n", - "\n", - "# Sort distances of every vector from the target vector and get the actual order\n", - "dists = [(spatial.distance.euclidean(query_data, vec), key) for key, vec in vectors]\n", - "dists = sorted(dists)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a3c3fbee", - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "# Get the next best results\n", - "import time\n", - "\n", - "start = time.time()\n", - "batch_size = 100\n", - "labels, distances = batch_iterator.get_next_results(batch_size, BY_SCORE)\n", - "total_time += time.time()-start\n", - "\n", - "print (f'Results in rank {returned_results_num}-{returned_results_num+len(labels[0])} are: \\n')\n", - "print (f'scores: {distances}\\n')\n", - "print (f'labels: {labels}')\n", - "\n", - "returned_results_num += len(labels[0])\n", - "accumulated_labels.extend(labels[0])\n" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "d79d0bb9", - "metadata": {}, - "outputs": [], - "source": [ - "# Measure recall and time\n", - "\n", - "keys = [key for _, key in dists[:returned_results_num]]\n", - "correct = len(set(accumulated_labels).intersection(set(keys)))\n", - "\n", - "print(f'Total search time: {total_time}')\n", - "print(f'Recall for {returned_results_num} results in index of size {num_elements} with dim={dim} is: ', correct/returned_results_num)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8328da4b", - "metadata": {}, - "outputs": [], - "source": [ - "# Comapre to \"stadnrd\" KNN search\n", - "\n", - "start = time.time()\n", - "labels_knn, distances_knn = hnsw_index.knn_query(query_data, returned_results_num)\n", - "print(f'Total search time: {time.time() - start}')\n", - "\n", - "keys = [key for _, key in dists[:returned_results_num]]\n", - "correct = len(set(labels_knn[0]).intersection(set(keys)))\n", - "print(f'Recall for {returned_results_num} results in index of size {num_elements} with dim={dim} is: ', correct/returned_results_num)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7925ecf6", - "metadata": {}, - "outputs": [], - "source": [ - "# Run batches until depleted\n", - "batch_iterator.reset()\n", - "returned_results_num = 0\n", - "batch_size = 100\n", - "start = time.time()\n", - "while(batch_iterator.has_next()):\n", - " labels, distances = batch_iterator.get_next_results(batch_size, BY_ID)\n", - " returned_results_num += len(labels[0])\n", - "\n", - "print(f'Total results returned: {returned_results_num}\\n')\n", - "print(f'Total search time: {time.time() - start}')" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python [conda env:vsim] *", - "language": "python", - "name": "conda-env-vsim-py" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/python_bindings/bindings.cpp b/src/python_bindings/bindings.cpp deleted file mode 100644 index b68c14653..000000000 --- a/src/python_bindings/bindings.cpp +++ /dev/null @@ -1,875 +0,0 @@ -/* - * Copyright (c) 2006-Present, Redis Ltd. - * All rights reserved. - * - * Licensed under your choice of the Redis Source Available License 2.0 - * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the - * GNU Affero General Public License v3 (AGPLv3). 
- */ -#include "VecSim/vec_sim.h" -#include "VecSim/algorithms/hnsw/hnsw.h" -#include "VecSim/index_factories/hnsw_factory.h" -#if HAVE_SVS -#include "VecSim/algorithms/svs/svs.h" -#include "VecSim/index_factories/svs_factory.h" -#endif -#include "VecSim/batch_iterator.h" -#include "VecSim/types/bfloat16.h" -#include "VecSim/types/float16.h" - -#include "pybind11/pybind11.h" -#include "pybind11/numpy.h" -#include "pybind11/stl.h" -#include -#include -#include -#include -#include "mock_thread_pool.h" - -namespace py = pybind11; - -using bfloat16 = vecsim_types::bfloat16; -using float16 = vecsim_types::float16; - -// Helper function that iterates query results and wrap them in python numpy object - -// a tuple of two 2D arrays: (labels, distances) -py::object wrap_results(VecSimQueryReply **res, size_t num_res, size_t num_queries = 1) { - auto *data_numpy_l = new long[num_res * num_queries]; - auto *data_numpy_d = new double[num_res * num_queries]; - // Default "padding" for the entries that will stay empty (in case of less than k results return - // in KNN, or results of range queries with number of results lower than the maximum in the - // batch (which determines the arrays' shape) - std::fill_n(data_numpy_l, num_res * num_queries, -1); - std::fill_n(data_numpy_d, num_res * num_queries, -1.0); - - for (size_t i = 0; i < num_queries; i++) { - VecSimQueryReply_Iterator *iterator = VecSimQueryReply_GetIterator(res[i]); - size_t res_ind = i * num_res; - while (VecSimQueryReply_IteratorHasNext(iterator)) { - VecSimQueryResult *item = VecSimQueryReply_IteratorNext(iterator); - data_numpy_d[res_ind] = VecSimQueryResult_GetScore(item); - data_numpy_l[res_ind++] = (long)VecSimQueryResult_GetId(item); - } - VecSimQueryReply_IteratorFree(iterator); - VecSimQueryReply_Free(res[i]); - } - - py::capsule free_when_done_l(data_numpy_l, [](void *labels) { delete[] (long *)labels; }); - py::capsule free_when_done_d(data_numpy_d, [](void *dists) { delete[] (double *)dists; }); - return py::make_tuple( - py::array_t( - {(size_t)num_queries, num_res}, // shape - {num_res * sizeof(long), sizeof(long)}, // C-style contiguous strides for size_t - data_numpy_l, // the data pointer (labels array) - free_when_done_l), - py::array_t( - {(size_t)num_queries, num_res}, // shape - {num_res * sizeof(double), sizeof(double)}, // C-style contiguous strides for double - data_numpy_d, // the data pointer (distances array) - free_when_done_d)); -} - -class PyBatchIterator { -private: - // Hold the index pointer, so that it will be destroyed **after** the batch iterator. Hence, - // the index field should come before the iterator field. - std::shared_ptr vectorIndex; - std::shared_ptr batchIterator; - -public: - PyBatchIterator(const std::shared_ptr &vecIndex, - const std::shared_ptr &batchIterator) - : vectorIndex(vecIndex), batchIterator(batchIterator) {} - - bool hasNext() { return VecSimBatchIterator_HasNext(batchIterator.get()); } - - py::object getNextResults(size_t n_res, VecSimQueryReply_Order order) { - VecSimQueryReply *results; - { - // We create this object inside the scope to enable parallel execution of the batch - // iterator from different Python threads. - py::gil_scoped_release py_gil; - results = VecSimBatchIterator_Next(batchIterator.get(), n_res, order); - } - // The number of results may be lower than n_res, if there are less than n_res remaining - // vectors in the index that hadn't been returned yet. 
- size_t actual_n_res = VecSimQueryReply_Len(results); - return wrap_results(&results, actual_n_res); - } - void reset() { VecSimBatchIterator_Reset(batchIterator.get()); } - virtual ~PyBatchIterator() = default; -}; - -// @input or @query arguments are a py::object object. (numpy arrays are acceptable) -class PyVecSimIndex { -private: - template - inline py::object rawVectorsAsNumpy(labelType label, size_t dim) { - std::vector> vectors; - if (index->basicInfo().algo == VecSimAlgo_BF) { - dynamic_cast *>(this->index.get()) - ->getDataByLabel(label, vectors); - } else { - // index is HNSW - dynamic_cast *>(this->index.get()) - ->getDataByLabel(label, vectors); - } - size_t n_vectors = vectors.size(); - auto *data_numpy = new NPArrayType[n_vectors * dim]; - - // Copy the vector blobs into one contiguous array of data, and free the original buffer - // afterwards. - if constexpr (std::is_same_v) { - for (size_t i = 0; i < n_vectors; i++) { - for (size_t j = 0; j < dim; j++) { - data_numpy[i * dim + j] = vecsim_types::bfloat16_to_float32(vectors[i][j]); - } - } - } else if constexpr (std::is_same_v) { - for (size_t i = 0; i < n_vectors; i++) { - for (size_t j = 0; j < dim; j++) { - data_numpy[i * dim + j] = vecsim_types::FP16_to_FP32(vectors[i][j]); - } - } - } else { - for (size_t i = 0; i < n_vectors; i++) { - memcpy(data_numpy + i * dim, vectors[i].data(), dim * sizeof(NPArrayType)); - } - } - - py::capsule free_when_done(data_numpy, - [](void *vector_data) { delete[] (NPArrayType *)vector_data; }); - return py::array_t( - {n_vectors, dim}, // shape - {dim * sizeof(NPArrayType), - sizeof(NPArrayType)}, // C-style contiguous strides for the data type - data_numpy, // the data pointer - free_when_done); - } - -protected: - std::shared_ptr index; - - inline VecSimQueryReply *searchKnnInternal(const char *query, size_t k, - VecSimQueryParams *query_params) { - return VecSimIndex_TopKQuery(index.get(), query, k, query_params, BY_SCORE); - } - - inline void addVectorInternal(const char *vector_data, size_t id) { - VecSimIndex_AddVector(index.get(), vector_data, id); - } - - inline VecSimQueryReply *searchRangeInternal(const char *query, double radius, - VecSimQueryParams *query_params) { - return VecSimIndex_RangeQuery(index.get(), query, radius, query_params, BY_SCORE); - } - -public: - PyVecSimIndex() = default; - - explicit PyVecSimIndex(const VecSimParams ¶ms) { - index = std::shared_ptr(VecSimIndex_New(¶ms), VecSimIndex_Free); - } - - void addVector(const py::object &input, size_t id) { - py::array vector_data(input); - py::gil_scoped_release py_gil; - addVectorInternal((const char *)vector_data.data(0), id); - } - - void deleteVector(size_t id) { VecSimIndex_DeleteVector(index.get(), id); } - - py::object knn(const py::object &input, size_t k, VecSimQueryParams *query_params) { - py::array query(input); - VecSimQueryReply *res; - { - py::gil_scoped_release py_gil; - res = searchKnnInternal((const char *)query.data(0), k, query_params); - } - return wrap_results(&res, k); - } - - py::object range(const py::object &input, double radius, VecSimQueryParams *query_params) { - py::array query(input); - VecSimQueryReply *res; - { - py::gil_scoped_release py_gil; - res = searchRangeInternal((const char *)query.data(0), radius, query_params); - } - return wrap_results(&res, VecSimQueryReply_Len(res)); - } - - size_t indexSize() { return VecSimIndex_IndexSize(index.get()); } - - VecSimType indexType() { return index->basicInfo().type; } - - size_t indexMemory() { return 
-    size_t indexMemory() { return this->index->getAllocationSize(); }
-
-    virtual PyBatchIterator createBatchIterator(const py::object &input,
-                                                VecSimQueryParams *query_params) {
-        py::array query(input);
-        auto py_batch_ptr = std::shared_ptr<VecSimBatchIterator>(
-            VecSimBatchIterator_New(index.get(), (const char *)query.data(0), query_params),
-            VecSimBatchIterator_Free);
-        return PyBatchIterator(index, py_batch_ptr);
-    }
-
-    void runGC() { VecSimTieredIndex_GC(index.get()); }
-
-    py::object getVector(labelType label) {
-        VecSimIndexBasicInfo info = index->basicInfo();
-        size_t dim = info.dim;
-        if (info.type == VecSimType_FLOAT32) {
-            return rawVectorsAsNumpy<float, float>(label, dim);
-        } else if (info.type == VecSimType_FLOAT64) {
-            return rawVectorsAsNumpy<double, double>(label, dim);
-        } else if (info.type == VecSimType_BFLOAT16) {
-            return rawVectorsAsNumpy<bfloat16, float, float>(label, dim);
-        } else if (info.type == VecSimType_FLOAT16) {
-            return rawVectorsAsNumpy<float16, float, float>(label, dim);
-        } else if (info.type == VecSimType_INT8) {
-            return rawVectorsAsNumpy<int8_t, float>(label, dim);
-        } else {
-            throw std::runtime_error("Invalid vector data type");
-        }
-    }
-
-    virtual ~PyVecSimIndex() = default; // The deleter function was given to the shared pointer.
-};
-
-class PyHNSWLibIndex : public PyVecSimIndex {
-private:
-    std::shared_ptr<std::shared_mutex>
-        indexGuard; // Protects parallel operations on the index. Make sure to release the GIL
-                    // while locking the mutex.
-    template <typename search_param_t> // size_t/double for KNN/range queries.
-    using QueryFunc =
-        std::function<VecSimQueryReply *(const char *, search_param_t, VecSimQueryParams *)>;
-
-    template <typename search_param_t> // size_t/double for KNN/range queries.
-    void runParallelQueries(const py::array &queries, size_t n_queries, search_param_t param,
-                            VecSimQueryParams *query_params, int n_threads,
-                            QueryFunc<search_param_t> queryFunc, VecSimQueryReply **results) {
-
-        // Use the number of hardware cores as the default number of threads, unless specified
-        // otherwise.
-        if (n_threads <= 0) {
-            n_threads = (int)std::thread::hardware_concurrency();
-        }
-        std::atomic_int global_counter(0);
-
-        auto parallel_search = [&](const py::array &items) {
-            while (true) {
-                int ind = global_counter.fetch_add(1);
-                if ((size_t)ind >= n_queries) {
-                    break;
-                }
-                {
-                    std::shared_lock lock(*indexGuard);
-                    results[ind] = queryFunc((const char *)items.data(ind), param, query_params);
-                }
-            }
-        };
-        std::thread thread_objs[n_threads];
-        {
-            // Release the Python GIL while the threads are running.
-            py::gil_scoped_release py_gil;
-            for (size_t i = 0; i < n_threads; i++) {
-                thread_objs[i] = std::thread(parallel_search, queries);
-            }
-            for (size_t i = 0; i < n_threads; i++) {
-                thread_objs[i].join();
-            }
-        }
-    }
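# A pure-Python sketch of the work-distribution scheme runParallelQueries() above uses: each
# worker repeatedly claims the next unprocessed query index from a shared counter until all
# queries are taken. itertools.count plus a lock stands in for std::atomic_int, and do_query
# stands in for the C++ query functor; both helper names are hypothetical.
import itertools
import threading

def run_parallel(queries, do_query, n_threads=4):
    counter = itertools.count()
    counter_lock = threading.Lock()
    results = [None] * len(queries)

    def worker():
        while True:
            with counter_lock:
                ind = next(counter)
            if ind >= len(queries):
                return
            results[ind] = do_query(queries[ind])

    threads = [threading.Thread(target=worker) for _ in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results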
-
-public:
-    explicit PyHNSWLibIndex(const HNSWParams &hnsw_params) {
-        VecSimParams params = {.algo = VecSimAlgo_HNSWLIB,
-                               .algoParams = {.hnswParams = HNSWParams{hnsw_params}}};
-        this->index = std::shared_ptr<VecSimIndex>(VecSimIndex_New(&params), VecSimIndex_Free);
-        this->indexGuard = std::make_shared<std::shared_mutex>();
-    }
-
-    // @params is required only in V1.
-    explicit PyHNSWLibIndex(const std::string &location) {
-        this->index =
-            std::shared_ptr<VecSimIndex>(HNSWFactory::NewIndex(location), VecSimIndex_Free);
-        this->indexGuard = std::make_shared<std::shared_mutex>();
-    }
-
-    void setDefaultEf(size_t ef) {
-        auto *hnsw = reinterpret_cast<HNSWIndex<float, float> *>(index.get());
-        hnsw->setEf(ef);
-    }
-    void saveIndex(const std::string &location) {
-        auto type = VecSimIndex_BasicInfo(this->index.get()).type;
-        if (type == VecSimType_FLOAT32) {
-            auto *hnsw = dynamic_cast<HNSWIndex<float, float> *>(index.get());
-            hnsw->saveIndex(location);
-        } else if (type == VecSimType_FLOAT64) {
-            auto *hnsw = dynamic_cast<HNSWIndex<double, double> *>(index.get());
-            hnsw->saveIndex(location);
-        } else if (type == VecSimType_BFLOAT16) {
-            auto *hnsw = dynamic_cast<HNSWIndex<bfloat16, float> *>(index.get());
-            hnsw->saveIndex(location);
-        } else if (type == VecSimType_FLOAT16) {
-            auto *hnsw = dynamic_cast<HNSWIndex<float16, float> *>(index.get());
-            hnsw->saveIndex(location);
-        } else if (type == VecSimType_INT8) {
-            auto *hnsw = dynamic_cast<HNSWIndex<int8_t, float> *>(index.get());
-            hnsw->saveIndex(location);
-        } else if (type == VecSimType_UINT8) {
-            auto *hnsw = dynamic_cast<HNSWIndex<uint8_t, float> *>(index.get());
-            hnsw->saveIndex(location);
-        } else {
-            throw std::runtime_error("Invalid index data type");
-        }
-    }
-    py::object searchKnnParallel(const py::object &input, size_t k,
-                                 VecSimQueryParams *query_params, int n_threads) {
-
-        py::array queries(input);
-        if (queries.ndim() != 2) {
-            throw std::runtime_error("Input queries array must be a 2D array");
-        }
-        size_t n_queries = queries.shape(0);
-        QueryFunc<size_t> searchKnnWrapper(
-            [this](const char *query_, size_t k_,
-                   VecSimQueryParams *query_params_) -> VecSimQueryReply * {
-                return this->searchKnnInternal(query_, k_, query_params_);
-            });
-        VecSimQueryReply *results[n_queries];
-        runParallelQueries(queries, n_queries, k, query_params, n_threads, searchKnnWrapper,
-                           results);
-        return wrap_results(results, k, n_queries);
-    }
-    py::object searchRangeParallel(const py::object &input, double radius,
-                                   VecSimQueryParams *query_params, int n_threads) {
-        py::array queries(input);
-        if (queries.ndim() != 2) {
-            throw std::runtime_error("Input queries array must be a 2D array");
-        }
-        size_t n_queries = queries.shape(0);
-        QueryFunc<double> searchRangeWrapper(
-            [this](const char *query_, double radius_,
-                   VecSimQueryParams *query_params_) -> VecSimQueryReply * {
-                return this->searchRangeInternal(query_, radius_, query_params_);
-            });
-        VecSimQueryReply *results[n_queries];
-        runParallelQueries(queries, n_queries, radius, query_params, n_threads,
-                           searchRangeWrapper, results);
-        size_t max_results_num = 1;
-        for (size_t i = 0; i < n_queries; i++) {
-            if (VecSimQueryReply_Len(results[i]) > max_results_num) {
-                max_results_num = VecSimQueryReply_Len(results[i]);
-            }
-        }
-        // We return 2D numpy arrays of results (labels and distances), with a padding of "-1" in
-        // the empty entries of the matrices.
-        return wrap_results(results, max_results_num, n_queries);
-    }
-
-    void addVectorsParallel(const py::object &input, const py::object &vectors_labels,
-                            int n_threads) {
-        py::array vectors_data(input);
-        py::array_t<labelType> labels(vectors_labels);
-
-        if (vectors_data.ndim() != 2) {
-            throw std::runtime_error("Input vectors data array must be a 2D array");
-        }
-        if (labels.ndim() != 1) {
-            throw std::runtime_error("Input vectors labels array must be a 1D array");
-        }
-        if (vectors_data.shape(0) != labels.shape(0)) {
-            throw std::runtime_error(
-                "The first dim of the vectors data and labels arrays must be equal");
-        }
-        size_t n_vectors = vectors_data.shape(0);
-        // Use the number of hardware cores as the default number of threads, unless specified
-        // otherwise.
-        if (n_threads <= 0) {
-            n_threads = (int)std::thread::hardware_concurrency();
-        }
-        // The index decides internally when to allocate a new block, inside "addVector", where an
-        // internal counter is incremented for each vector. To guarantee that the thread taking
-        // the write lock is the one that performs the resizing, no thread may bypass the thread
-        // whose global counter value is a multiple of the block size. Hence, we take the barrier
-        // lock on every iteration and choose the appropriate lock (read/write) based on the
-        // global counter, so threads never call "addVector" while holding the wrong lock.
-        std::mutex barrier;
-        std::atomic<size_t> global_counter{};
-        size_t block_size = VecSimIndex_BasicInfo(this->index.get()).blockSize;
-        auto parallel_insert =
-            [&](const py::array &data, const py::array_t<labelType> &labels) {
-                while (true) {
-                    bool exclusive = true;
-                    barrier.lock();
-                    size_t ind = global_counter++;
-                    if (ind >= n_vectors) {
-                        barrier.unlock();
-                        break;
-                    }
-                    if (ind % block_size != 0) {
-                        // Read lock for normal operations
-                        indexGuard->lock_shared();
-                        exclusive = false;
-                    } else {
-                        // Exclusive lock for block resizing operations
-                        indexGuard->lock();
-                    }
-                    barrier.unlock();
-                    this->addVectorInternal((const char *)data.data(ind), labels.at(ind));
-                    exclusive ? indexGuard->unlock() : indexGuard->unlock_shared();
-                }
-            };
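# The lock-selection rule that the parallel_insert lambda above enforces, as a tiny Python
# sketch: the thread whose counter value lands on a block boundary performs the resize and
# must take the exclusive lock; every other thread takes the shared one. The block size of
# 1024 below is only an illustrative assumption.
def lock_kind(counter_value, block_size=1024):
    return "exclusive" if counter_value % block_size == 0 else "shared"

assert lock_kind(0) == "exclusive"     # the first insertion allocates a block
assert lock_kind(1023) == "shared"     # interior insertions only read-lock
assert lock_kind(1024) == "exclusive"  # the next block boundary resizes again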
-        std::thread thread_objs[n_threads];
-        {
-            // Release the Python GIL while the threads are running.
-            py::gil_scoped_release py_gil;
-            for (size_t i = 0; i < n_threads; i++) {
-                thread_objs[i] = std::thread(parallel_insert, vectors_data, labels);
-            }
-            for (size_t i = 0; i < n_threads; i++) {
-                thread_objs[i].join();
-            }
-        }
-    }
-
-    bool checkIntegrity() {
-        auto type = VecSimIndex_BasicInfo(this->index.get()).type;
-        if (type == VecSimType_FLOAT32) {
-            return dynamic_cast<HNSWIndex<float, float> *>(this->index.get())
-                ->checkIntegrity()
-                .valid_state;
-        } else if (type == VecSimType_FLOAT64) {
-            return dynamic_cast<HNSWIndex<double, double> *>(this->index.get())
-                ->checkIntegrity()
-                .valid_state;
-        } else if (type == VecSimType_BFLOAT16) {
-            return dynamic_cast<HNSWIndex<bfloat16, float> *>(this->index.get())
-                ->checkIntegrity()
-                .valid_state;
-        } else if (type == VecSimType_FLOAT16) {
-            return dynamic_cast<HNSWIndex<float16, float> *>(this->index.get())
-                ->checkIntegrity()
-                .valid_state;
-        } else if (type == VecSimType_INT8) {
-            return dynamic_cast<HNSWIndex<int8_t, float> *>(this->index.get())
-                ->checkIntegrity()
-                .valid_state;
-        } else if (type == VecSimType_UINT8) {
-            return dynamic_cast<HNSWIndex<uint8_t, float> *>(this->index.get())
-                ->checkIntegrity()
-                .valid_state;
-        } else {
-            throw std::runtime_error("Invalid index data type");
-        }
-    }
-    PyBatchIterator createBatchIterator(const py::object &input,
-                                        VecSimQueryParams *query_params) override {
-
-        py::array query(input);
-        py::gil_scoped_release py_gil;
-        // Capture indexGuard by value, so that the refcount of the mutex is incremented and the
-        // mutex stays alive until the batch iterator is freed.
-        auto del = [indexGuardPtr = this->indexGuard](VecSimBatchIterator *pyBatchIter) {
-            VecSimBatchIterator_Free(pyBatchIter);
-            indexGuardPtr->unlock_shared();
-        };
-        indexGuard->lock_shared();
-        auto py_batch_ptr = std::shared_ptr<VecSimBatchIterator>(
-            VecSimBatchIterator_New(index.get(), (const char *)query.data(0), query_params), del);
-        return PyBatchIterator(index, py_batch_ptr);
-    }
-};
-
-class PyTieredIndex : public PyVecSimIndex {
-protected:
-    tieredIndexMock mock_thread_pool;
-
-    VecSimIndexAbstract<float, float> *getFlatBuffer() {
-        return reinterpret_cast<VecSimTieredIndex<float, float> *>(this->index.get())
-            ->getFlatBufferIndex();
-    }
-
-    TieredIndexParams getTieredIndexParams(size_t buffer_limit) {
-        // Create TieredIndexParams using the mock thread pool.
-        return TieredIndexParams{
-            .jobQueue = &(this->mock_thread_pool.jobQ),
-            .jobQueueCtx = this->mock_thread_pool.ctx,
-            .submitCb = tieredIndexMock::submit_callback,
-            .flatBufferLimit = buffer_limit,
-        };
-    }
-
-public:
-    explicit PyTieredIndex() { mock_thread_pool.init_threads(); }
-
-    void WaitForIndex(size_t waiting_duration = 10) {
-        mock_thread_pool.thread_pool_wait(waiting_duration);
-    }
-
-    size_t getFlatIndexSize() { return getFlatBuffer()->indexLabelCount(); }
-
-    size_t getThreadsNum() { return mock_thread_pool.thread_pool_size; }
-
-    size_t getBufferLimit() {
-        return reinterpret_cast<VecSimTieredIndex<float, float> *>(this->index.get())
-            ->getFlatBufferLimit();
-    }
-};
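# A short sketch of driving a tiered index from Python, assuming the compiled VecSim module
# built from these bindings is importable. Vectors land in the flat buffer first and migrate
# to the HNSW backend on the mock thread pool's background threads, so callers wait before
# reading the backend's label count. (Tiered_HNSWIndex is registered further below.)
from VecSim import (Tiered_HNSWIndex, HNSWParams, TieredHNSWParams,
                    VecSimMetric_L2, VecSimType_FLOAT32)
import numpy as np

hnsw_params = HNSWParams()
hnsw_params.dim = 16
hnsw_params.type = VecSimType_FLOAT32
hnsw_params.metric = VecSimMetric_L2

tiered = Tiered_HNSWIndex(hnsw_params, TieredHNSWParams(), 1024)
tiered.add_vector(np.float32(np.random.random(16)), 0)
tiered.wait_for_index()           # drain the mock thread pool's job queue
print(tiered.hnsw_label_count())  # expected: 1, once the insert job has run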
-
-class PyTiered_HNSWIndex : public PyTieredIndex {
-public:
-    explicit PyTiered_HNSWIndex(const HNSWParams &hnsw_params,
-                                const TieredHNSWParams &tiered_hnsw_params, size_t buffer_limit) {
-
-        // Create primaryIndexParams and the specific params for the HNSW tiered index.
-        VecSimParams primary_index_params = {.algo = VecSimAlgo_HNSWLIB,
-                                             .algoParams = {.hnswParams = HNSWParams{hnsw_params}}};
-
-        auto tiered_params = this->getTieredIndexParams(buffer_limit);
-        tiered_params.primaryIndexParams = &primary_index_params;
-        tiered_params.specificParams.tieredHnswParams = tiered_hnsw_params;
-
-        // Create VecSimParams for TieredIndexParams
-        VecSimParams params = {.algo = VecSimAlgo_TIERED,
-                               .algoParams = {.tieredParams = TieredIndexParams{tiered_params}}};
-
-        this->index = std::shared_ptr<VecSimIndex>(VecSimIndex_New(&params), VecSimIndex_Free);
-
-        // Set the created tiered index in the index external context.
-        this->mock_thread_pool.ctx->index_strong_ref = this->index;
-    }
-
-    size_t HNSWLabelCount() {
-        return this->index->debugInfo().tieredInfo.backendCommonInfo.indexLabelCount;
-    }
-};
-
-class PyBFIndex : public PyVecSimIndex {
-public:
-    explicit PyBFIndex(const BFParams &bf_params) {
-        VecSimParams params = {.algo = VecSimAlgo_BF,
-                               .algoParams = {.bfParams = BFParams{bf_params}}};
-        this->index = std::shared_ptr<VecSimIndex>(VecSimIndex_New(&params), VecSimIndex_Free);
-    }
-};
-
-#if HAVE_SVS
-class PySVSIndex : public PyVecSimIndex {
-public:
-    explicit PySVSIndex(const SVSParams &svs_params) {
-        VecSimParams params = {.algo = VecSimAlgo_SVS, .algoParams = {.svsParams = svs_params}};
-        this->index = std::shared_ptr<VecSimIndex>(VecSimIndex_New(&params), VecSimIndex_Free);
-        if (!this->index) {
-            throw std::runtime_error("Index creation failed");
-        }
-    }
-
-    explicit PySVSIndex(const std::string &location, const SVSParams &svs_params) {
-        VecSimParams params = {.algo = VecSimAlgo_SVS, .algoParams = {.svsParams = svs_params}};
-        this->index =
-            std::shared_ptr<VecSimIndex>(SVSFactory::NewIndex(location, &params), VecSimIndex_Free);
-        if (!this->index) {
-            throw std::runtime_error("Index creation failed");
-        }
-    }
-
-    void addVectorsParallel(const py::object &input, const py::object &vectors_labels) {
-        py::array vectors_data(input);
-        py::array_t<labelType> labels(vectors_labels);
-
-        if (vectors_data.ndim() != 2) {
-            throw std::runtime_error("Input vectors data array must be a 2D array");
-        }
-        if (labels.ndim() != 1) {
-            throw std::runtime_error("Input vectors labels array must be a 1D array");
-        }
-        if (vectors_data.shape(0) != labels.shape(0)) {
-            throw std::runtime_error(
-                "The first dim of the vectors data and labels arrays must be equal");
-        }
-        size_t n_vectors = vectors_data.shape(0);
-
-        auto svs_index = dynamic_cast<SVSIndexBase *>(this->index.get());
-        assert(svs_index);
-        svs_index->addVectors(vectors_data.data(), labels.data(), n_vectors);
-    }
-
-    void checkIntegrity() {
-        auto svs_index = dynamic_cast<SVSIndexBase *>(this->index.get());
-        assert(svs_index);
-        try {
-            svs_index->checkIntegrity();
-        } catch (const std::exception &e) {
-            throw std::runtime_error(std::string("SVSIndex integrity check failed: ") + e.what());
-        }
-    }
-
-    void saveIndex(const std::string &location) {
-        auto svs_index = dynamic_cast<SVSIndexBase *>(this->index.get());
-        assert(svs_index);
-        svs_index->saveIndex(location);
-    }
-
-    void loadIndex(const std::string &location) {
-        auto svs_index = dynamic_cast<SVSIndexBase *>(this->index.get());
-        assert(svs_index);
-        svs_index->loadIndex(location);
-    }
-
-    size_t getLabelsCount() const { return this->index->debugInfo().commonInfo.indexLabelCount; }
-};
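# A sketch of the SVS persistence round trip exposed by PySVSIndex above, assuming a build
# with HAVE_SVS and the compiled VecSim module. The field names follow the SVSParams binding
# registered further below; the save path and the uint64 label dtype are assumptions for
# illustration.
from VecSim import SVSIndex, SVSParams, VecSimMetric_L2, VecSimType_FLOAT32
import numpy as np

params = SVSParams()
params.dim = 16
params.type = VecSimType_FLOAT32
params.metric = VecSimMetric_L2

index = SVSIndex(params)
data = np.float32(np.random.random((100, 16)))
index.add_vector_parallel(data, np.arange(100, dtype=np.uint64))
index.save_index("/tmp/svs_index")

restored = SVSIndex("/tmp/svs_index", params)  # the (location, params) constructor
assert restored.get_labels_count() == 100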
-
-class PyTiered_SVSIndex : public PyTieredIndex {
-public:
-    explicit PyTiered_SVSIndex(const SVSParams &svs_params,
-                               const TieredSVSParams &tiered_svs_params, size_t buffer_limit) {
-
-        // Create primaryIndexParams and the specific params for the SVS tiered index.
-        VecSimParams primary_index_params = {.algo = VecSimAlgo_SVS,
-                                             .algoParams = {.svsParams = svs_params}};
-
-        if (primary_index_params.algoParams.svsParams.num_threads == 0) {
-            primary_index_params.algoParams.svsParams.num_threads =
-                this->mock_thread_pool.thread_pool_size; // Use the mock thread pool size as default
-        }
-
-        auto tiered_params = this->getTieredIndexParams(buffer_limit);
-        tiered_params.primaryIndexParams = &primary_index_params;
-        tiered_params.specificParams.tieredSVSParams = tiered_svs_params;
-
-        // Create VecSimParams for TieredIndexParams
-        VecSimParams params = {.algo = VecSimAlgo_TIERED,
-                               .algoParams = {.tieredParams = tiered_params}};
-
-        this->index = std::shared_ptr<VecSimIndex>(VecSimIndex_New(&params), VecSimIndex_Free);
-
-        // Set the created tiered index in the index external context.
-        this->mock_thread_pool.ctx->index_strong_ref = this->index;
-    }
-
-    size_t SVSLabelCount() {
-        return this->index->debugInfo().tieredInfo.backendCommonInfo.indexLabelCount;
-    }
-};
-#endif
-
-PYBIND11_MODULE(VecSim, m) {
-    py::enum_<VecSimAlgo>(m, "VecSimAlgo")
-        .value("VecSimAlgo_HNSWLIB", VecSimAlgo_HNSWLIB)
-        .value("VecSimAlgo_BF", VecSimAlgo_BF)
-        .value("VecSimAlgo_SVS", VecSimAlgo_SVS)
-        .export_values();
-
-    py::enum_<VecSimType>(m, "VecSimType")
-        .value("VecSimType_FLOAT32", VecSimType_FLOAT32)
-        .value("VecSimType_FLOAT64", VecSimType_FLOAT64)
-        .value("VecSimType_BFLOAT16", VecSimType_BFLOAT16)
-        .value("VecSimType_FLOAT16", VecSimType_FLOAT16)
-        .value("VecSimType_INT8", VecSimType_INT8)
-        .value("VecSimType_UINT8", VecSimType_UINT8)
-        .value("VecSimType_INT32", VecSimType_INT32)
-        .value("VecSimType_INT64", VecSimType_INT64)
-        .export_values();
-
-    py::enum_<VecSimMetric>(m, "VecSimMetric")
-        .value("VecSimMetric_L2", VecSimMetric_L2)
-        .value("VecSimMetric_IP", VecSimMetric_IP)
-        .value("VecSimMetric_Cosine", VecSimMetric_Cosine)
-        .export_values();
-
-    py::enum_<VecSimOptionMode>(m, "VecSimOptionMode")
-        .value("VecSimOption_AUTO", VecSimOption_AUTO)
-        .value("VecSimOption_ENABLE", VecSimOption_ENABLE)
-        .value("VecSimOption_DISABLE", VecSimOption_DISABLE)
-        .export_values();
-
-    py::enum_<VecSimQueryReply_Order>(m, "VecSimQueryReply_Order")
-        .value("BY_SCORE", BY_SCORE)
-        .value("BY_ID", BY_ID)
-        .export_values();
-
-    py::class_<HNSWParams>(m, "HNSWParams")
-        .def(py::init())
-        .def_readwrite("type", &HNSWParams::type)
-        .def_readwrite("dim", &HNSWParams::dim)
-        .def_readwrite("metric", &HNSWParams::metric)
-        .def_readwrite("multi", &HNSWParams::multi)
-        .def_readwrite("initialCapacity", &HNSWParams::initialCapacity)
-        .def_readwrite("M", &HNSWParams::M)
-        .def_readwrite("efConstruction", &HNSWParams::efConstruction)
-        .def_readwrite("efRuntime", &HNSWParams::efRuntime)
-        .def_readwrite("epsilon", &HNSWParams::epsilon);
-
-    py::class_<BFParams>(m, "BFParams")
-        .def(py::init())
-        .def_readwrite("type", &BFParams::type)
-        .def_readwrite("dim", &BFParams::dim)
-        .def_readwrite("metric", &BFParams::metric)
-        .def_readwrite("multi", &BFParams::multi)
-        .def_readwrite("initialCapacity", &BFParams::initialCapacity)
-        .def_readwrite("blockSize", &BFParams::blockSize);
-
-    py::enum_<VecSimSvsQuantBits>(m, "VecSimSvsQuantBits")
-        .value("VecSimSvsQuant_NONE", VecSimSvsQuant_NONE)
-        .value("VecSimSvsQuant_Scalar", VecSimSvsQuant_Scalar)
-        .value("VecSimSvsQuant_4", VecSimSvsQuant_4)
-        .value("VecSimSvsQuant_8", VecSimSvsQuant_8)
-        .value("VecSimSvsQuant_4x4", VecSimSvsQuant_4x4)
-        .value("VecSimSvsQuant_4x8", VecSimSvsQuant_4x8)
-        .value("VecSimSvsQuant_4x8_LeanVec", VecSimSvsQuant_4x8_LeanVec)
-        .value("VecSimSvsQuant_8x8_LeanVec", VecSimSvsQuant_8x8_LeanVec)
-        .export_values();
-
-    py::class_<SVSParams>(m, "SVSParams")
"SVSParams") - .def(py::init()) - .def_readwrite("type", &SVSParams::type) - .def_readwrite("dim", &SVSParams::dim) - .def_readwrite("metric", &SVSParams::metric) - .def_readwrite("multi", &SVSParams::multi) - .def_readwrite("blockSize", &SVSParams::blockSize) - .def_readwrite("quantBits", &SVSParams::quantBits) - .def_readwrite("alpha", &SVSParams::alpha) - .def_readwrite("graph_max_degree", &SVSParams::graph_max_degree) - .def_readwrite("construction_window_size", &SVSParams::construction_window_size) - .def_readwrite("max_candidate_pool_size", &SVSParams::max_candidate_pool_size) - .def_readwrite("prune_to", &SVSParams::prune_to) - .def_readwrite("use_search_history", &SVSParams::use_search_history) - .def_readwrite("search_window_size", &SVSParams::search_window_size) - .def_readwrite("search_buffer_capacity", &SVSParams::search_buffer_capacity) - .def_readwrite("leanvec_dim", &SVSParams::leanvec_dim) - .def_readwrite("epsilon", &SVSParams::epsilon) - .def_readwrite("num_threads", &SVSParams::num_threads); - - py::class_(m, "TieredHNSWParams") - .def(py::init()) - .def_readwrite("swapJobThreshold", &TieredHNSWParams::swapJobThreshold); - - py::class_(m, "TieredSVSParams") - .def(py::init()) - .def_readwrite("trainingTriggerThreshold", &TieredSVSParams::trainingTriggerThreshold) - .def_readwrite("updateTriggerThreshold", &TieredSVSParams::updateTriggerThreshold) - .def_readwrite("updateJobWaitTime", &TieredSVSParams::updateJobWaitTime); - - py::class_(m, "AlgoParams") - .def(py::init()) - .def_readwrite("hnswParams", &AlgoParams::hnswParams) - .def_readwrite("bfParams", &AlgoParams::bfParams) - .def_readwrite("svsParams", &AlgoParams::svsParams); - - py::class_(m, "VecSimParams") - .def(py::init()) - .def_readwrite("algo", &VecSimParams::algo) - .def_readwrite("algoParams", &VecSimParams::algoParams); - - py::class_ queryParams(m, "VecSimQueryParams"); - - queryParams.def(py::init<>()) - .def_readwrite("hnswRuntimeParams", &VecSimQueryParams::hnswRuntimeParams) - .def_readwrite("svsRuntimeParams", &VecSimQueryParams::svsRuntimeParams) - .def_readwrite("batchSize", &VecSimQueryParams::batchSize); - - py::class_(queryParams, "HNSWRuntimeParams") - .def(py::init<>()) - .def_readwrite("efRuntime", &HNSWRuntimeParams::efRuntime) - .def_readwrite("epsilon", &HNSWRuntimeParams::epsilon); - - py::class_(queryParams, "SVSRuntimeParams") - .def(py::init<>()) - .def_readwrite("windowSize", &SVSRuntimeParams::windowSize) - .def_readwrite("bufferCapacity", &SVSRuntimeParams::bufferCapacity) - .def_readwrite("searchHistory", &SVSRuntimeParams::searchHistory) - .def_readwrite("epsilon", &SVSRuntimeParams::epsilon); - - py::class_(m, "VecSimIndex") - .def(py::init([](const VecSimParams ¶ms) { return new PyVecSimIndex(params); }), - py::arg("params")) - .def("add_vector", &PyVecSimIndex::addVector) - .def("delete_vector", &PyVecSimIndex::deleteVector) - .def("knn_query", &PyVecSimIndex::knn, py::arg("vector"), py::arg("k"), - py::arg("query_param") = nullptr) - .def("range_query", &PyVecSimIndex::range, py::arg("vector"), py::arg("radius"), - py::arg("query_param") = nullptr) - .def("index_size", &PyVecSimIndex::indexSize) - .def("index_type", &PyVecSimIndex::indexType) - .def("index_memory", &PyVecSimIndex::indexMemory) - .def("create_batch_iterator", &PyVecSimIndex::createBatchIterator, py::arg("query_blob"), - py::arg("query_param") = nullptr) - .def("get_vector", &PyVecSimIndex::getVector) - .def("run_gc", &PyVecSimIndex::runGC); - - py::class_(m, "HNSWIndex") - .def(py::init([](const HNSWParams 
-        .def(py::init([](const HNSWParams &params) { return new PyHNSWLibIndex(params); }),
-             py::arg("params"))
-        .def(py::init([](const std::string &location) { return new PyHNSWLibIndex(location); }),
-             py::arg("location"))
-        .def("set_ef", &PyHNSWLibIndex::setDefaultEf)
-        .def("save_index", &PyHNSWLibIndex::saveIndex)
-        .def("knn_parallel", &PyHNSWLibIndex::searchKnnParallel, py::arg("queries"), py::arg("k"),
-             py::arg("query_param") = nullptr, py::arg("num_threads") = -1)
-        .def("add_vector_parallel", &PyHNSWLibIndex::addVectorsParallel, py::arg("vectors"),
-             py::arg("labels"), py::arg("num_threads") = -1)
-        .def("check_integrity", &PyHNSWLibIndex::checkIntegrity)
-        .def("range_parallel", &PyHNSWLibIndex::searchRangeParallel, py::arg("queries"),
-             py::arg("radius"), py::arg("query_param") = nullptr, py::arg("num_threads") = -1)
-        .def("create_batch_iterator", &PyHNSWLibIndex::createBatchIterator, py::arg("query_blob"),
-             py::arg("query_param") = nullptr);
-
-    py::class_<PyTieredIndex, PyVecSimIndex>(m, "TieredIndex")
-        .def("wait_for_index", &PyTieredIndex::WaitForIndex, py::arg("waiting_duration") = 10)
-        .def("get_curr_bf_size", &PyTieredIndex::getFlatIndexSize)
-        .def("get_buffer_limit", &PyTieredIndex::getBufferLimit)
-        .def("get_threads_num", &PyTieredIndex::getThreadsNum);
-
-    py::class_<PyTiered_HNSWIndex, PyTieredIndex>(m, "Tiered_HNSWIndex")
-        .def(py::init([](const HNSWParams &hnsw_params, const TieredHNSWParams &tiered_hnsw_params,
-                         size_t flat_buffer_size = DEFAULT_BLOCK_SIZE) {
-                 return new PyTiered_HNSWIndex(hnsw_params, tiered_hnsw_params, flat_buffer_size);
-             }),
-             py::arg("hnsw_params"), py::arg("tiered_hnsw_params"), py::arg("flat_buffer_size"))
-        .def("hnsw_label_count", &PyTiered_HNSWIndex::HNSWLabelCount);
-
-    py::class_<PyBFIndex, PyVecSimIndex>(m, "BFIndex")
-        .def(py::init([](const BFParams &params) { return new PyBFIndex(params); }),
-             py::arg("params"));
-#if HAVE_SVS
-    py::class_<PySVSIndex, PyVecSimIndex>(m, "SVSIndex")
-        .def(py::init([](const SVSParams &params) { return new PySVSIndex(params); }),
-             py::arg("params"))
-        .def(py::init([](const std::string &location, const SVSParams &params) {
-                 return new PySVSIndex(location, params);
-             }),
-             py::arg("location"), py::arg("params"))
-        .def("add_vector_parallel", &PySVSIndex::addVectorsParallel, py::arg("vectors"),
-             py::arg("labels"))
-        .def("check_integrity", &PySVSIndex::checkIntegrity)
-        .def("save_index", &PySVSIndex::saveIndex, py::arg("location"))
-        .def("load_index", &PySVSIndex::loadIndex, py::arg("location"))
-        .def("get_labels_count", &PySVSIndex::getLabelsCount);
-
-    py::class_<PyTiered_SVSIndex, PyTieredIndex>(m, "Tiered_SVSIndex")
-        .def(py::init([](const SVSParams &svs_params, const TieredSVSParams &tiered_svs_params,
-                         size_t flat_buffer_size = DEFAULT_BLOCK_SIZE) {
-                 return new PyTiered_SVSIndex(svs_params, tiered_svs_params, flat_buffer_size);
-             }),
-             py::arg("svs_params"), py::arg("tiered_svs_params"),
-             py::arg("flat_buffer_size") = DEFAULT_BLOCK_SIZE)
-        .def("svs_label_count", &PyTiered_SVSIndex::SVSLabelCount);
-#endif
-
-    py::class_<PyBatchIterator>(m, "BatchIterator")
-        .def("has_next", &PyBatchIterator::hasNext)
-        .def("get_next_results", &PyBatchIterator::getNextResults)
-        .def("reset", &PyBatchIterator::reset);
-
-    m.def(
-        "set_log_context",
-        [](const std::string &test_name, const std::string &test_type) {
-            // Call the C++ function to set the global context
-            VecSim_SetTestLogContext(test_name.c_str(), test_type.c_str());
-        },
-        "Set the context (test name) for logging");
-}
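# End-to-end usage of the module registered above, as a small sketch: build an HNSW index,
# insert a vector, and query it back. Assumes the compiled VecSim module is importable.
from VecSim import HNSWIndex, HNSWParams, VecSimMetric_L2, VecSimType_FLOAT32
import numpy as np

params = HNSWParams()
params.dim = 16
params.type = VecSimType_FLOAT32
params.metric = VecSimMetric_L2
params.M = 16
params.efConstruction = 100

index = HNSWIndex(params)
vec = np.float32(np.random.random(16))
index.add_vector(vec, 0)
labels, dists = index.knn_query(vec, 1)
assert labels[0][0] == 0  # the stored vector is its own nearest neighbor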
diff --git a/tests/benchmark/CMakeLists.txt b/tests/benchmark/CMakeLists.txt
index 859f2c0af..c98aa7c30 100644
--- a/tests/benchmark/CMakeLists.txt
+++ b/tests/benchmark/CMakeLists.txt
@@ -39,6 +39,10 @@ endif()
 
 # Spaces benchmarks
 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+# Simple HNSW benchmark (self-contained, no external data needed)
+add_executable(simple_hnsw_bench simple_hnsw_bench.cpp)
+target_link_libraries(simple_hnsw_bench VectorSimilarity)
+
 set(DATA_TYPE fp32 fp64 bf16 fp16 int8 uint8 sq8_fp32 sq8_sq8)
 foreach(data_type IN LISTS DATA_TYPE)
     add_executable(bm_spaces_${data_type} spaces_benchmarks/bm_spaces_${data_type}.cpp)
diff --git a/tests/benchmark/simple_hnsw_bench.cpp b/tests/benchmark/simple_hnsw_bench.cpp
new file mode 100644
index 000000000..5f6af44a9
--- /dev/null
+++ b/tests/benchmark/simple_hnsw_bench.cpp
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2006-Present, Redis Ltd.
+ * All rights reserved.
+ *
+ * Licensed under your choice of (a) the Redis Source Available License 2.0
+ * (RSALv2); (b) the Server Side Public License v1 (SSPLv1); or (c) the
+ * GNU Affero General Public License v3 (AGPLv3).
+ *
+ * Simple HNSW benchmark that creates an in-memory index for comparison with Rust.
+ * This benchmark is self-contained and doesn't require external data files.
+ *
+ * Usage: ./simple_hnsw_bench
+ */
+#include <iostream>
+#include <vector>
+#include <random>
+#include <chrono>
+#include "VecSim/vec_sim.h"
+#include "VecSim/algorithms/hnsw/hnsw.h"
+#include "VecSim/index_factories/hnsw_factory.h"
+
+constexpr size_t DIM = 128;
+constexpr size_t N_VECTORS = 10000;
+constexpr size_t M = 16;
+constexpr size_t EF_CONSTRUCTION = 100;
+constexpr size_t EF_RUNTIME = 100;
+constexpr size_t N_QUERIES = 1000;
+
+std::vector<float> generate_random_vector(size_t dim, std::mt19937 &gen) {
+    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+    std::vector<float> vec(dim);
+    for (size_t i = 0; i < dim; ++i) {
+        vec[i] = dist(gen);
+    }
+    return vec;
+}
+
+int main() {
+    std::mt19937 gen(42); // Fixed seed for reproducibility
+
+    std::cout << "=== C++ HNSW Benchmark ===" << std::endl;
+    std::cout << "Config: " << N_VECTORS << " vectors, " << DIM << " dimensions, M=" << M
+              << ", ef_construction=" << EF_CONSTRUCTION << ", ef_runtime=" << EF_RUNTIME
+              << std::endl;
+    std::cout << std::endl;
+
+    // Create HNSW parameters (designated initializers must follow the declaration order)
+    HNSWParams params = {.type = VecSimType_FLOAT32,
+                         .dim = DIM,
+                         .metric = VecSimMetric_L2,
+                         .M = M,
+                         .efConstruction = EF_CONSTRUCTION,
+                         .efRuntime = EF_RUNTIME};
+
+    // Create index
+    VecSimIndex *index = HNSWFactory::NewIndex(&params);
+
+    // Generate and insert vectors
+    std::cout << "Inserting " << N_VECTORS << " vectors..." << std::endl;
+    auto insert_start = std::chrono::high_resolution_clock::now();
+
+    for (size_t i = 0; i < N_VECTORS; ++i) {
+        auto vec = generate_random_vector(DIM, gen);
+        VecSimIndex_AddVector(index, vec.data(), i);
+    }
+
+    auto insert_end = std::chrono::high_resolution_clock::now();
+    auto insert_duration =
+        std::chrono::duration_cast<std::chrono::milliseconds>(insert_end - insert_start);
+    double insert_throughput = N_VECTORS * 1000.0 / insert_duration.count();
+
+    std::cout << "Insertion time: " << insert_duration.count() << " ms (" << insert_throughput
+              << " vec/s)" << std::endl;
+    std::cout << std::endl;
+
+    // Generate query vectors
+    std::vector<std::vector<float>> queries;
+    for (size_t i = 0; i < N_QUERIES; ++i) {
+        queries.push_back(generate_random_vector(DIM, gen));
+    }
+
+    // KNN Search benchmark
+    std::cout << "Running " << N_QUERIES << " KNN queries (k=10)..." << std::endl;
+
+    auto knn_start = std::chrono::high_resolution_clock::now();
+
+    for (size_t i = 0; i < N_QUERIES; ++i) {
+        VecSimQueryReply *results =
+            VecSimIndex_TopKQuery(index, queries[i].data(), 10, nullptr, BY_SCORE);
+        VecSimQueryReply_Free(results);
+    }
+
+    auto knn_end = std::chrono::high_resolution_clock::now();
+    auto knn_duration =
+        std::chrono::duration_cast<std::chrono::microseconds>(knn_end - knn_start);
+    double avg_knn_time = static_cast<double>(knn_duration.count()) / N_QUERIES;
+    double knn_throughput = N_QUERIES * 1000000.0 / knn_duration.count();
+
+    std::cout << "KNN k=10 (ef=" << EF_RUNTIME << "): avg " << avg_knn_time << " µs ("
+              << knn_throughput << " queries/s)" << std::endl;
+    std::cout << std::endl;
+
+    // Range Search benchmark
+    std::cout << "Running " << N_QUERIES << " Range queries (radius=10.0)..." << std::endl;
+
+    auto range_start = std::chrono::high_resolution_clock::now();
+
+    for (size_t i = 0; i < N_QUERIES; ++i) {
+        VecSimQueryReply *results =
+            VecSimIndex_RangeQuery(index, queries[i].data(), 10.0, nullptr, BY_SCORE);
+        VecSimQueryReply_Free(results);
+    }
+
+    auto range_end = std::chrono::high_resolution_clock::now();
+    auto range_duration =
+        std::chrono::duration_cast<std::chrono::microseconds>(range_end - range_start);
+    double avg_range_time = static_cast<double>(range_duration.count()) / N_QUERIES;
+    double range_throughput = N_QUERIES * 1000000.0 / range_duration.count();
+
+    std::cout << "Range (r=10): avg " << avg_range_time << " µs (" << range_throughput
+              << " queries/s)" << std::endl;
+    std::cout << std::endl;
+
+    // Cleanup
+    VecSimIndex_Free(index);
+
+    std::cout << "=== Summary ===" << std::endl;
+    std::cout << "Insertion: " << insert_throughput << " vec/s" << std::endl;
+    std::cout << "KNN (k=10): " << avg_knn_time << " µs (" << knn_throughput << " q/s)"
+              << std::endl;
+    std::cout << "Range (r=10): " << avg_range_time << " µs (" << range_throughput << " q/s)"
+              << std::endl;
+
+    return 0;
+}
diff --git a/tests/flow/test_hnsw.py b/tests/flow/test_hnsw.py
index 245e82e05..8750d31a6 100644
--- a/tests/flow/test_hnsw.py
+++ b/tests/flow/test_hnsw.py
@@ -14,57 +14,88 @@ import hnswlib
-# compare results with the original version of hnswlib - do not use elements deletion.
+# Validate HNSW L2 results against brute force ground truth.
 def test_sanity_hnswlib_index_L2():
     dim = 16
     num_elements = 10000
-    space = 'l2'
     M = 16
     efConstruction = 100
-    efRuntime = 10
+    efRuntime = 50  # Higher ef for more reliable recall
+    k = 10
+    num_queries = 10  # Multiple queries for robust recall measurement
 
     index = create_hnsw_index(dim, num_elements, VecSimMetric_L2, VecSimType_FLOAT32,
                               efConstruction, M, efRuntime)
 
-    p = hnswlib.Index(space=space, dim=dim)
-    p.init_index(max_elements=num_elements, ef_construction=efConstruction, M=M)
-    p.set_ef(efRuntime)
-
+    np.random.seed(42)
     data = np.float32(np.random.random((num_elements, dim)))
    for i, vector in enumerate(data):
         index.add_vector(vector, i)
-        p.add_items(vector, i)
 
-    query_data = np.float32(np.random.random((1, dim)))
-    hnswlib_labels, hnswlib_distances = p.knn_query(query_data, k=10)
-    redis_labels, redis_distances = index.knn_query(query_data, 10)
-    assert_allclose(hnswlib_labels, redis_labels, rtol=1e-5, atol=0)
-    assert_allclose(hnswlib_distances, redis_distances, rtol=1e-5, atol=0)
+    # Run multiple queries and compute average recall
+    from scipy.spatial import distance
+    total_recall = 0
+    for _ in range(num_queries):
+        query_data = np.float32(np.random.random((1, dim)))
+        hnsw_labels, hnsw_distances = index.knn_query(query_data, k)
+
+        # Compute brute force ground truth
+        all_dists = np.array([distance.sqeuclidean(query_data.flatten(), vec) for vec in data])
+        bf_indices = np.argsort(all_dists)[:k]
+
+        hnsw_set = set(hnsw_labels[0])
+        bf_set = set(bf_indices)
+        recall = len(hnsw_set & bf_set) / k
+        total_recall += recall
+
+        # Verify distances are correct for returned labels
+        for label, dist in zip(hnsw_labels[0], hnsw_distances[0]):
+            true_dist = all_dists[label]
+            assert_allclose(dist, true_dist, rtol=1e-5, atol=1e-6)
+
+    avg_recall = total_recall / num_queries
+    assert avg_recall >= 0.9, f"Average recall {avg_recall:.2f} is too low, expected >= 0.9"
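# The recall@k measurement used by the test above (and repeated by the cosine test below),
# factored into a standalone sketch; recall_at_k is a hypothetical helper, not part of the
# test file.
import numpy as np

def recall_at_k(approx_labels, exact_dists, k):
    # approx_labels: labels the index returned for one query, shape (k,).
    # exact_dists: exact distances from the query to every stored vector.
    ground_truth = set(np.argsort(exact_dists)[:k])
    return len(set(approx_labels) & ground_truth) / k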
+
+
+# Validate HNSW cosine results against brute force ground truth.
 def test_sanity_hnswlib_index_cosine():
     dim = 16
     num_elements = 10000
-    space = 'cosine'
     M = 16
     efConstruction = 100
-    efRuntime = 10
+    efRuntime = 50  # Higher ef for more reliable recall
+    k = 10
+    num_queries = 10  # Multiple queries for robust recall measurement
 
     index = create_hnsw_index(dim, num_elements, VecSimMetric_Cosine, VecSimType_FLOAT32,
                               efConstruction, M, efRuntime)
 
-    p = hnswlib.Index(space=space, dim=dim)
-    p.init_index(max_elements=num_elements, ef_construction=efConstruction, M=M)
-    p.set_ef(efRuntime)
-
+    np.random.seed(42)
     data = np.float32(np.random.random((num_elements, dim)))
     for i, vector in enumerate(data):
         index.add_vector(vector, i)
-        p.add_items(vector, i)
 
-    query_data = np.float32(np.random.random((1, dim)))
-    hnswlib_labels, hnswlib_distances = p.knn_query(query_data, k=10)
-    redis_labels, redis_distances = index.knn_query(query_data, 10)
-    assert_allclose(hnswlib_labels, redis_labels, rtol=1e-5, atol=0)
-    assert_allclose(hnswlib_distances, redis_distances, rtol=1e-5, atol=0)
+    # Run multiple queries and compute average recall
+    from scipy.spatial import distance
+    total_recall = 0
+    for _ in range(num_queries):
+        query_data = np.float32(np.random.random((1, dim)))
+        hnsw_labels, hnsw_distances = index.knn_query(query_data, k)
+
+        # Compute brute force ground truth using cosine distance
+        all_dists = np.array([distance.cosine(query_data.flatten(), vec) for vec in data])
+        bf_indices = np.argsort(all_dists)[:k]
+
+        hnsw_set = set(hnsw_labels[0])
+        bf_set = set(bf_indices)
+        recall = len(hnsw_set & bf_set) / k
+        total_recall += recall
+
+        # Verify distances are correct for returned labels
+        for label, dist in zip(hnsw_labels[0], hnsw_distances[0]):
+            true_dist = all_dists[label]
+            assert_allclose(dist, true_dist, rtol=1e-5, atol=1e-6)
+
+    avg_recall = total_recall / num_queries
+    assert avg_recall >= 0.9, f"Average recall {avg_recall:.2f} is too low, expected >= 0.9"
 
 # Validate correctness of delete implementation comparing the brute force search. We test the search recall which is not