From ebc11c0bd526e43e7e231e55621912718801d506 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Feb 2026 22:18:37 +0530 Subject: [PATCH 01/11] Refactor search interface with unified SearchDispatch trait - Add SearchDispatch trait for unified search entry point - Split search types into separate files (graph_search, range_search, multihop_search, diverse_search) - Rename execute_search to search, parameters to search_params - Rename search_recorded to debug_search to signal debug-only intent - Move internal implementations to respective search type files - Add From impls for SearchParams/RangeSearchParams to new types - Add test helper functions in test modules for backwards compat - Bump version to 0.46.0 --- Cargo.lock | 30 +- Cargo.toml | 28 +- .../src/search/graph/knn.rs | 3 +- .../src/search/graph/multihop.rs | 9 +- .../src/search/graph/range.rs | 10 +- .../src/search/provider/disk_provider.rs | 11 +- diskann-providers/src/index/diskann_async.rs | 291 ++++--- diskann-providers/src/index/wrapped_async.rs | 3 +- diskann/src/graph/index.rs | 777 ++---------------- diskann/src/graph/mod.rs | 8 + diskann/src/graph/search/dispatch.rs | 33 + diskann/src/graph/search/diverse_search.rs | 165 ++++ diskann/src/graph/search/graph_search.rs | 227 +++++ diskann/src/graph/search/mod.rs | 41 + diskann/src/graph/search/multihop_search.rs | 312 +++++++ diskann/src/graph/search/range_search.rs | 371 +++++++++ diskann/src/graph/test/cases/grid.rs | 8 +- 17 files changed, 1456 insertions(+), 871 deletions(-) create mode 100644 diskann/src/graph/search/dispatch.rs create mode 100644 diskann/src/graph/search/diverse_search.rs create mode 100644 diskann/src/graph/search/graph_search.rs create mode 100644 diskann/src/graph/search/multihop_search.rs create mode 100644 diskann/src/graph/search/range_search.rs diff --git a/Cargo.lock b/Cargo.lock index 9e2ba2409..b65475315 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -403,7 +403,7 @@ dependencies = [ [[package]] name = "diskann" 
-version = "0.45.0" +version = "0.46.0" dependencies = [ "anyhow", "bytemuck", @@ -427,7 +427,7 @@ dependencies = [ [[package]] name = "diskann-benchmark" -version = "0.45.0" +version = "0.46.0" dependencies = [ "anyhow", "bf-tree", @@ -464,7 +464,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-core" -version = "0.45.0" +version = "0.46.0" dependencies = [ "anyhow", "diskann", @@ -481,7 +481,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-runner" -version = "0.45.0" +version = "0.46.0" dependencies = [ "anyhow", "clap", @@ -495,7 +495,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-simd" -version = "0.45.0" +version = "0.46.0" dependencies = [ "anyhow", "diskann-benchmark-runner", @@ -512,7 +512,7 @@ dependencies = [ [[package]] name = "diskann-disk" -version = "0.45.0" +version = "0.46.0" dependencies = [ "anyhow", "bincode", @@ -547,7 +547,7 @@ dependencies = [ [[package]] name = "diskann-label-filter" -version = "0.45.0" +version = "0.46.0" dependencies = [ "anyhow", "bf-tree", @@ -570,7 +570,7 @@ dependencies = [ [[package]] name = "diskann-linalg" -version = "0.45.0" +version = "0.46.0" dependencies = [ "approx", "cfg-if", @@ -584,7 +584,7 @@ dependencies = [ [[package]] name = "diskann-platform" -version = "0.45.0" +version = "0.46.0" dependencies = [ "io-uring", "libc", @@ -594,7 +594,7 @@ dependencies = [ [[package]] name = "diskann-providers" -version = "0.45.0" +version = "0.46.0" dependencies = [ "anyhow", "approx", @@ -638,7 +638,7 @@ dependencies = [ [[package]] name = "diskann-quantization" -version = "0.45.0" +version = "0.46.0" dependencies = [ "bytemuck", "cfg-if", @@ -657,7 +657,7 @@ dependencies = [ [[package]] name = "diskann-tools" -version = "0.45.0" +version = "0.46.0" dependencies = [ "anyhow", "bincode", @@ -689,7 +689,7 @@ dependencies = [ [[package]] name = "diskann-utils" -version = "0.45.0" +version = "0.46.0" dependencies = [ "cfg-if", "diskann-vector", @@ -703,7 +703,7 @@ dependencies = [ 
[[package]] name = "diskann-vector" -version = "0.45.0" +version = "0.46.0" dependencies = [ "approx", "cfg-if", @@ -717,7 +717,7 @@ dependencies = [ [[package]] name = "diskann-wide" -version = "0.45.0" +version = "0.46.0" dependencies = [ "cfg-if", "half", diff --git a/Cargo.toml b/Cargo.toml index bd9e75037..ce3e47293 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ default-members = [ resolver = "3" [workspace.package] -version = "0.45.0" # Obeying semver +version = "0.46.0" # Obeying semver description = "DiskANN is a fast approximate nearest neighbor search library for high dimensional data" authors = ["Microsoft"] documentation = "https://github.com/microsoft/DiskANN" @@ -46,22 +46,22 @@ undocumented_unsafe_blocks = "warn" [workspace.dependencies] # Base And Numerics -diskann-wide = { path = "diskann-wide", version = "0.45.0" } -diskann-vector = { path = "diskann-vector", version = "0.45.0" } -diskann-linalg = { path = "diskann-linalg", version = "0.45.0" } -diskann-utils = { path = "diskann-utils", default-features = false, version = "0.45.0" } -diskann-quantization = { path = "diskann-quantization", default-features = false, version = "0.45.0" } -diskann-platform = { path = "diskann-platform", version = "0.45.0" } +diskann-wide = { path = "diskann-wide", version = "0.46.0" } +diskann-vector = { path = "diskann-vector", version = "0.46.0" } +diskann-linalg = { path = "diskann-linalg", version = "0.46.0" } +diskann-utils = { path = "diskann-utils", default-features = false, version = "0.46.0" } +diskann-quantization = { path = "diskann-quantization", default-features = false, version = "0.46.0" } +diskann-platform = { path = "diskann-platform", version = "0.46.0" } # Algorithm -diskann = { path = "diskann", version = "0.45.0" } +diskann = { path = "diskann", version = "0.46.0" } # Providers -diskann-providers = { path = "diskann-providers", default-features = false, version = "0.45.0" } -diskann-disk = { path = "diskann-disk", version = "0.45.0" } 
-diskann-label-filter = { path = "diskann-label-filter", version = "0.45.0" } +diskann-providers = { path = "diskann-providers", default-features = false, version = "0.46.0" } +diskann-disk = { path = "diskann-disk", version = "0.46.0" } +diskann-label-filter = { path = "diskann-label-filter", version = "0.46.0" } # Infra -diskann-benchmark-runner = { path = "diskann-benchmark-runner", version = "0.45.0" } -diskann-benchmark-core = { path = "diskann-benchmark-core", version = "0.45.0" } -diskann-tools = { path = "diskann-tools", version = "0.45.0" } +diskann-benchmark-runner = { path = "diskann-benchmark-runner", version = "0.46.0" } +diskann-benchmark-core = { path = "diskann-benchmark-core", version = "0.46.0" } +diskann-tools = { path = "diskann-tools", version = "0.46.0" } # External dependencies (shared versions) anyhow = "1.0.98" diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index ad07f2c82..2d1a3c064 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -113,13 +113,14 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); + let graph_search = graph::GraphSearch::from(*parameters); let stats = self .index .search( self.strategy.get(index)?, &context, self.queries.row(index), - parameters, + &graph_search, buffer, ) .await?; diff --git a/diskann-benchmark-core/src/search/graph/multihop.rs b/diskann-benchmark-core/src/search/graph/multihop.rs index 6dfb646bb..e191b7944 100644 --- a/diskann-benchmark-core/src/search/graph/multihop.rs +++ b/diskann-benchmark-core/src/search/graph/multihop.rs @@ -111,15 +111,18 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); + let multihop_search = graph::MultihopSearch::new( + graph::GraphSearch::from(*parameters), + &*self.labels[index], + ); let stats = self .index - .multihop_search( + .search( self.strategy.get(index)?, &context, 
self.queries.row(index), - parameters, + &multihop_search, buffer, - &*self.labels[index], ) .await?; diff --git a/diskann-benchmark-core/src/search/graph/range.rs b/diskann-benchmark-core/src/search/graph/range.rs index a5669ae25..8cfaf9a9a 100644 --- a/diskann-benchmark-core/src/search/graph/range.rs +++ b/diskann-benchmark-core/src/search/graph/range.rs @@ -104,16 +104,18 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let (_, ids, distances) = self + let range_search = graph::RangeSearch::from(*parameters); + let result = self .index - .range_search( + .search( self.strategy.get(index)?, &context, self.queries.row(index), - parameters, + &range_search, + &mut (), ) .await?; - buffer.extend(std::iter::zip(ids.into_iter(), distances.into_iter())); + buffer.extend(std::iter::zip(result.ids.into_iter(), result.distances.into_iter())); Ok(Metrics {}) } diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index ab0a4f4e7..be12f22f3 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -19,7 +19,7 @@ use diskann::{ graph::{ self, glue::{self, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy}, - search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer, SearchParams, + search_output_buffer, AdjacencyList, DiskANNIndex, GraphSearch, SearchOutputBuffer, SearchParams, }, neighbor::Neighbor, provider::{ @@ -993,11 +993,12 @@ where &mut result_output_buffer, ))? } else { + let graph_search = GraphSearch::new(k_value, search_list_size as usize, beam_width)?; self.runtime.block_on(self.index.search( &strategy, &DefaultContext, strategy.query, - &SearchParams::new(k_value, search_list_size as usize, beam_width)?, + &graph_search, &mut result_output_buffer, ))? 
}; @@ -1040,7 +1041,7 @@ fn ensure_vertex_loaded>( #[cfg(test)] mod disk_provider_tests { use diskann::{ - graph::{search::record::VisitedSearchRecord, SearchParamsError}, + graph::{search::record::VisitedSearchRecord, SearchParams, SearchParamsError}, utils::IntoUsize, ANNErrorKind, }; @@ -1626,7 +1627,7 @@ mod disk_provider_tests { let mut search_record = VisitedSearchRecord::new(0); search_engine .runtime - .block_on(search_engine.index.search_recorded( + .block_on(search_engine.index.debug_search( &strategy, &DefaultContext, &query_vector, @@ -2088,7 +2089,7 @@ mod disk_provider_tests { let mut search_record = VisitedSearchRecord::new(0); search_engine .runtime - .block_on(search_engine.index.search_recorded( + .block_on(search_engine.index.debug_search( &strategy, &DefaultContext, &query_vector, diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 5451af3ad..8a7d982a9 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -215,6 +215,55 @@ pub(crate) mod tests { // Callbacks for use with `simplified_builder`. fn no_modify(_: &mut diskann::graph::config::Builder) {} + ////////////////////////// + // Test helper functions // + ////////////////////////// + + use diskann::graph::index::SearchStats; + + /// Test helper: performs multihop search using the dispatch API. 
+ async fn multihop_search( + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + search_params: &SearchParams, + output: &mut OB, + filter: &dyn graph::index::QueryLabelProvider, + ) -> diskann::ANNResult + where + DP: DataProvider, + T: Sync + ?Sized, + S: graph::glue::SearchStrategy, + O: Send, + OB: graph::search_output_buffer::SearchOutputBuffer + Send, + { + let multihop = graph::MultihopSearch::new( + graph::GraphSearch::from(*search_params), + filter, + ); + index.search(strategy, context, query, &multihop, output).await + } + + /// Test helper: performs range search using the dispatch API. + async fn range_search( + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + search_params: &RangeSearchParams, + ) -> diskann::ANNResult<(SearchStats, Vec, Vec)> + where + DP: DataProvider, + T: Sync + ?Sized, + S: graph::glue::SearchStrategy, + O: Send + Default + Clone, + { + let range_search = graph::RangeSearch::from(*search_params); + let result = index.search(strategy, context, query, &range_search, &mut ()).await?; + Ok((result.stats, result.ids, result.distances)) + } + ///////////////////////////////////////// // Tests from the original async index // ///////////////////////////////////////// @@ -400,17 +449,17 @@ pub(crate) mod tests { let mut distances = vec![0.0; parameters.search_k]; let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - index - .multihop_search( - strategy, - ¶meters.context, - query, - &SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(), - &mut result_output_buffer, - filter, - ) - .await - .unwrap(); + multihop_search( + index, + strategy, + ¶meters.context, + query, + &SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(), + &mut result_output_buffer, + filter, + ) + .await + .unwrap(); // Loop over the requested number of results to check, invoking the checker closure. 
// @@ -1443,17 +1492,17 @@ pub(crate) mod tests { let filter = CallbackFilter::new(blocked, adjusted, 0.5); - let stats = index - .multihop_search( - &FullPrecision, - ¶meters.context, - query.as_slice(), - &SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(), - &mut result_output_buffer, - &filter, - ) - .await - .unwrap(); + let stats = multihop_search( + &index, + &FullPrecision, + ¶meters.context, + query.as_slice(), + &SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(), + &mut result_output_buffer, + &filter, + ) + .await + .unwrap(); // Retrieve callback metrics for detailed validation let callback_metrics = filter.metrics(); @@ -2272,30 +2321,30 @@ pub(crate) mod tests { let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); { // Full Precision Search. - let (_, ids, _) = index - .range_search( - &FullPrecision, - ctx, - query, - &RangeSearchParams::new_default(starting_l_value, radius).unwrap(), - ) - .await - .unwrap(); + let (_, ids, _) = range_search( + &*index, + &FullPrecision, + ctx, + query, + &RangeSearchParams::new_default(starting_l_value, radius).unwrap(), + ) + .await + .unwrap(); assert_range_results_exactly_match(q, >, &ids, radius, None); } { // Quantized Search - let (_, ids, _) = index - .range_search( - &Hybrid::new(None), - ctx, - query, - &RangeSearchParams::new_default(starting_l_value, radius).unwrap(), - ) - .await - .unwrap(); + let (_, ids, _) = range_search( + &*index, + &Hybrid::new(None), + ctx, + query, + &RangeSearchParams::new_default(starting_l_value, radius).unwrap(), + ) + .await + .unwrap(); assert_range_results_exactly_match(q, >, &ids, radius, None); } @@ -2304,24 +2353,24 @@ pub(crate) mod tests { // Test with an inner radius assert!(inner_radius <= radius); - let (_, ids, _) = index - .range_search( - &FullPrecision, - ctx, - query, - &RangeSearchParams::new( - None, - starting_l_value, - None, - radius, - Some(inner_radius), - 1.0, - 1.0, - ) - 
.unwrap(), + let (_, ids, _) = range_search( + &*index, + &FullPrecision, + ctx, + query, + &RangeSearchParams::new( + None, + starting_l_value, + None, + radius, + Some(inner_radius), + 1.0, + 1.0, ) - .await - .unwrap(); + .unwrap(), + ) + .await + .unwrap(); assert_range_results_exactly_match(q, >, &ids, radius, Some(inner_radius)); } @@ -2329,15 +2378,15 @@ pub(crate) mod tests { { // Test with a lower initial beam to trigger more two-round searches // We don't expect results to exactly match here - let (_, ids, _) = index - .range_search( - &FullPrecision, - ctx, - query, - &RangeSearchParams::new_default(lower_l_value, radius).unwrap(), - ) - .await - .unwrap(); + let (_, ids, _) = range_search( + &*index, + &FullPrecision, + ctx, + query, + &RangeSearchParams::new_default(lower_l_value, radius).unwrap(), + ) + .await + .unwrap(); // check that ids don't have duplicates let mut ids_set = std::collections::HashSet::new(); @@ -4104,17 +4153,17 @@ pub(crate) mod tests { // but reject everything via on_visit let filter = RejectAllFilter::only([0_u32]); - let stats = index - .multihop_search( - &FullPrecision, - &DefaultContext, - query.as_slice(), - &SearchParams::new_default(10, 20).unwrap(), - &mut result_output_buffer, - &filter, - ) - .await - .unwrap(); + let stats = multihop_search( + &index, + &FullPrecision, + &DefaultContext, + query.as_slice(), + &SearchParams::new_default(10, 20).unwrap(), + &mut result_output_buffer, + &filter, + ) + .await + .unwrap(); // When all candidates are rejected via on_visit, result_count should be 0 // because rejected candidates are not added to the search frontier @@ -4166,17 +4215,17 @@ pub(crate) mod tests { let target = (num_points / 2) as u32; let filter = TerminatingFilter::new(target); - let stats = index - .multihop_search( - &FullPrecision, - &DefaultContext, - query.as_slice(), - &SearchParams::new_default(10, 40).unwrap(), - &mut result_output_buffer, - &filter, - ) - .await - .unwrap(); + let stats = 
multihop_search( + &index, + &FullPrecision, + &DefaultContext, + query.as_slice(), + &SearchParams::new_default(10, 40).unwrap(), + &mut result_output_buffer, + &filter, + ) + .await + .unwrap(); let hits = filter.hits(); @@ -4230,17 +4279,17 @@ pub(crate) mod tests { let mut baseline_buffer = search_output_buffer::IdDistance::new(&mut baseline_ids, &mut baseline_distances); - let baseline_stats = index - .multihop_search( - &FullPrecision, - &DefaultContext, - query.as_slice(), - &SearchParams::new_default(10, 20).unwrap(), - &mut baseline_buffer, - &EvenFilter, // Just filter to even IDs - ) - .await - .unwrap(); + let baseline_stats = multihop_search( + &index, + &FullPrecision, + &DefaultContext, + query.as_slice(), + &SearchParams::new_default(10, 20).unwrap(), + &mut baseline_buffer, + &EvenFilter, // Just filter to even IDs + ) + .await + .unwrap(); // Now run with a filter that boosts a specific far-away point let boosted_point = (num_points - 2) as u32; // A point far from origin @@ -4251,17 +4300,17 @@ pub(crate) mod tests { let mut adjusted_buffer = search_output_buffer::IdDistance::new(&mut adjusted_ids, &mut adjusted_distances); - let adjusted_stats = index - .multihop_search( - &FullPrecision, - &DefaultContext, - query.as_slice(), - &SearchParams::new_default(10, 20).unwrap(), - &mut adjusted_buffer, - &filter, - ) - .await - .unwrap(); + let adjusted_stats = multihop_search( + &index, + &FullPrecision, + &DefaultContext, + query.as_slice(), + &SearchParams::new_default(10, 20).unwrap(), + &mut adjusted_buffer, + &filter, + ) + .await + .unwrap(); // Both searches should return results assert!( @@ -4377,17 +4426,17 @@ pub(crate) mod tests { let max_visits = 5; let filter = TerminateAfterN::new(max_visits); - let _stats = index - .multihop_search( - &FullPrecision, - &DefaultContext, - query.as_slice(), - &SearchParams::new_default(10, 100).unwrap(), // Large L to ensure we'd visit more without termination - &mut result_output_buffer, - &filter, - ) 
- .await - .unwrap(); + let _stats = multihop_search( + &index, + &FullPrecision, + &DefaultContext, + query.as_slice(), + &SearchParams::new_default(10, 100).unwrap(), // Large L to ensure we'd visit more without termination + &mut result_output_buffer, + &filter, + ) + .await + .unwrap(); // The search should have stopped after max_visits assert!( diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index d3d37416d..050b76afb 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -235,9 +235,10 @@ where O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, { + let graph_search = diskann::graph::GraphSearch::from(*search_params); self.handle.block_on( self.inner - .search(strategy, context, query, search_params, output), + .search(strategy, context, query, &graph_search, output), ) } diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index ea48adc0b..f74bb36e4 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -24,10 +24,10 @@ use thiserror::Error; use tokio::task::JoinSet; use super::{ - AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, RangeSearchParams, SearchParams, + AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, SearchParams, glue::{ - self, AsElement, ExpandBeam, FillSet, HybridPredicate, IdIterator, InplaceDeleteStrategy, - InsertStrategy, Predicate, PredicateMut, PruneStrategy, SearchExt, SearchPostProcess, + self, AsElement, ExpandBeam, FillSet, IdIterator, InplaceDeleteStrategy, + InsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, aliases, }, internal::{BackedgeBuffer, SortedNeighbors, prune}, @@ -38,9 +38,6 @@ use super::{ search_output_buffer, }; -#[cfg(feature = "experimental_diversity_search")] -use super::DiverseSearchParams; - use crate::{ ANNError, ANNErrorKind, ANNResult, error::{ErrorExt, IntoANNResult}, @@ -110,6 +107,7 @@ pub struct 
DegreeStats { /// This struct provides detailed metrics about the search process, including /// the number of nodes visited, the number of distance computations performed, /// the number of hops taken during the search, and the total number of results returned. +#[derive(Debug, Clone, Copy)] pub struct SearchStats { /// The total number of distance computations performed during the search. pub cmps: u32, @@ -220,53 +218,6 @@ struct SetBatchElements { batch: Arc<[VectorIdBoxSlice]>, } -pub struct NotInMutWithLabelCheck<'a, K> -where - K: VectorId, -{ - visited_set: &'a mut hashbrown::HashSet, - query_label_evaluator: &'a dyn QueryLabelProvider, -} - -impl<'a, K> NotInMutWithLabelCheck<'a, K> -where - K: VectorId, -{ - /// Construct a new `NotInMutWithLabelCheck` around `visited_set`. - pub fn new( - visited_set: &'a mut hashbrown::HashSet, - query_label_evaluator: &'a dyn QueryLabelProvider, - ) -> Self { - Self { - visited_set, - query_label_evaluator, - } - } -} - -impl Predicate for NotInMutWithLabelCheck<'_, K> -where - K: VectorId, -{ - fn eval(&self, item: &K) -> bool { - !self.visited_set.contains(item) && self.query_label_evaluator.is_match(*item) - } -} - -impl PredicateMut for NotInMutWithLabelCheck<'_, K> -where - K: VectorId, -{ - fn eval_mut(&mut self, item: &K) -> bool { - if self.query_label_evaluator.is_match(*item) { - return self.visited_set.insert(*item); - } - false - } -} - -impl HybridPredicate for NotInMutWithLabelCheck<'_, K> where K: VectorId {} - impl DiskANNIndex where DP: DataProvider, @@ -296,7 +247,7 @@ where /// * `l`: The default window size to use. /// * `additional`: Extra capacity, usually to allow start points to be filtered from /// the result. 
- fn search_scratch( + pub(crate) fn search_scratch( &self, l: usize, additional: usize, @@ -2061,7 +2012,7 @@ where } // A is the accessor type, T is the query type used for BuildQueryComputer - fn search_internal( + pub(crate) fn search_internal( &self, beam_width: Option, start_ids: &[DP::InternalId], @@ -2136,204 +2087,6 @@ where } } - // A is the accessor type, T is the query type used for BuildQueryComputer - // scratch.in_range is guaranteed to include the starting points - fn range_search_internal( - &self, - search_params: &RangeSearchParams, - accessor: &mut A, - computer: &A::QueryComputer, - scratch: &mut SearchScratch, - ) -> impl SendFuture> - where - A: ExpandBeam + SearchExt, - T: ?Sized, - { - async move { - let beam_width = search_params.beam_width.unwrap_or(1); - - for neighbor in &scratch.in_range { - scratch.range_frontier.push_back(neighbor.id); - } - - let mut neighbors = Vec::with_capacity(self.max_degree_with_slack()); - - let max_returned = search_params.max_returned.unwrap_or(usize::MAX); - - while !scratch.range_frontier.is_empty() { - scratch.beam_nodes.clear(); - - // In this loop we are going to find the beam_width number of remaining nodes within the radius - // Each of these nodes will be a frontier node. - while !scratch.range_frontier.is_empty() && scratch.beam_nodes.len() < beam_width { - let next = scratch.range_frontier.pop_front(); - if let Some(next_node) = next { - scratch.beam_nodes.push(next_node); - } - } - - neighbors.clear(); - accessor - .expand_beam( - scratch.beam_nodes.iter().copied(), - computer, - glue::NotInMut::new(&mut scratch.visited), - |distance, id| neighbors.push(Neighbor::new(id, distance)), - ) - .await?; - - // The predicate ensure that the contents of `neighbors` are unique. 
- for neighbor in neighbors.iter() { - if neighbor.distance <= search_params.radius * search_params.range_search_slack - && scratch.in_range.len() < max_returned - { - scratch.in_range.push(*neighbor); - scratch.range_frontier.push_back(neighbor.id); - } - } - scratch.cmps += neighbors.len() as u32; - scratch.hops += scratch.beam_nodes.len() as u32; - } - - Ok(InternalSearchStats { - cmps: scratch.cmps, - hops: scratch.hops, - range_search_second_round: true, - }) - } - } - - // A is the accessor type, T is the query type used for BuildQueryComputer - fn multihop_search_internal( - &self, - search_params: &SearchParams, - accessor: &mut A, - computer: &A::QueryComputer, - scratch: &mut SearchScratch, - search_record: &mut SR, - query_label_evaluator: &dyn QueryLabelProvider, - ) -> impl SendFuture> - where - A: ExpandBeam + SearchExt, - T: ?Sized, - SR: SearchRecord + ?Sized, - { - async move { - let beam_width = search_params.beam_width.unwrap_or(1); - - // Helper to build the final stats from scratch state. - let make_stats = |scratch: &SearchScratch| InternalSearchStats { - cmps: scratch.cmps, - hops: scratch.hops, - range_search_second_round: false, - }; - - // Initialize search state if not already initialized. 
- // This allows paged search to call multihop_search_internal multiple times - if scratch.visited.is_empty() { - let start_ids = accessor.starting_points().await?; - - for id in start_ids { - scratch.visited.insert(id); - let element = accessor - .get_element(id) - .await - .escalate("start point retrieval must succeed")?; - let dist = computer.evaluate_similarity(element.reborrow()); - scratch.best.insert(Neighbor::new(id, dist)); - } - } - - // Pre-allocate with good capacity to avoid repeated allocations - let mut one_hop_neighbors = Vec::with_capacity(self.max_degree_with_slack()); - let mut two_hop_neighbors = Vec::with_capacity(self.max_degree_with_slack()); - let mut candidates_two_hop_expansion = Vec::with_capacity(self.max_degree_with_slack()); - - while scratch.best.has_notvisited_node() && !accessor.terminate_early() { - scratch.beam_nodes.clear(); - one_hop_neighbors.clear(); - candidates_two_hop_expansion.clear(); - two_hop_neighbors.clear(); - - // In this loop we are going to find the beam_width number of nodes that are closest to the query. - // Each of these nodes will be a frontier node. 
- while scratch.best.has_notvisited_node() && scratch.beam_nodes.len() < beam_width { - let closest_node = scratch.best.closest_notvisited(); - search_record.record(closest_node, scratch.hops, scratch.cmps); - scratch.beam_nodes.push(closest_node.id); - } - - // compute distances from query to one-hop neighbors, and mark them visited - accessor - .expand_beam( - scratch.beam_nodes.iter().copied(), - computer, - glue::NotInMut::new(&mut scratch.visited), - |distance, id| one_hop_neighbors.push(Neighbor::new(id, distance)), - ) - .await?; - - // Process one-hop neighbors based on on_visit() decision - for neighbor in one_hop_neighbors.iter().copied() { - match query_label_evaluator.on_visit(neighbor) { - QueryVisitDecision::Accept(accepted) => { - scratch.best.insert(accepted); - } - QueryVisitDecision::Reject => { - // Rejected nodes: still add to two-hop expansion so we can traverse through them - candidates_two_hop_expansion.push(neighbor); - } - QueryVisitDecision::Terminate => { - scratch.cmps += one_hop_neighbors.len() as u32; - scratch.hops += scratch.beam_nodes.len() as u32; - return Ok(make_stats(scratch)); - } - } - } - - scratch.cmps += one_hop_neighbors.len() as u32; - scratch.hops += scratch.beam_nodes.len() as u32; - - // sort the candidates for two-hop expansion by distance to query point - candidates_two_hop_expansion.sort_unstable_by(|a, b| { - a.distance - .partial_cmp(&b.distance) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - // limit the number of two-hop candidates to avoid too many expansions - candidates_two_hop_expansion.truncate(self.max_degree_with_slack() / 2); - - // Expand each two-hop candidate: if its neighbor is a match, compute its distance - // to the query and insert into `scratch.visited` - // If it is not a match, do nothing - let two_hop_expansion_candidate_ids: Vec = - candidates_two_hop_expansion.iter().map(|n| n.id).collect(); - - accessor - .expand_beam( - two_hop_expansion_candidate_ids.iter().copied(), - computer, - 
NotInMutWithLabelCheck::new(&mut scratch.visited, query_label_evaluator), - |distance, id| { - two_hop_neighbors.push(Neighbor::new(id, distance)); - }, - ) - .await?; - - // Next, insert the new matches into `scratch.best` and increment stats counters - two_hop_neighbors - .iter() - .for_each(|neighbor| scratch.best.insert(*neighbor)); - - scratch.cmps += two_hop_neighbors.len() as u32; - scratch.hops += two_hop_expansion_candidate_ids.len() as u32; - } - - Ok(make_stats(scratch)) - } - } - /// Filter out start nodes from the best candidates in the scratch. fn filter_search_candidates( &self, @@ -2363,48 +2116,75 @@ where } } - /// Performs a graph-based search towards a target query vector recording the path taken. + /// Execute a search using the unified search dispatch interface. /// - /// This method executes a search using the provided `strategy` to access and process elements. - /// It computes the similarity between the query vector and the elements in the index, moving towards the - /// nearest neighbors according to the search parameters. - /// The path taken is recorded according to the search_record object passed in. + /// This method provides a single entry point for all search types. The `search_params` argument + /// implements [`search::SearchDispatch`], which defines the complete search behavior including + /// algorithm selection and post-processing. /// - /// # Arguments + /// # Supported Search Types /// - /// * `strategy` - The search strategy to use for accessing and processing elements. - /// * `context` - The context to pass through to providers. - /// * `query` - The query vector for which nearest neighbors are sought. - /// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`) and beam width. - /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. 
- /// * `search_record` - A mutable reference to a search record object that will record the path taken during the search. + /// - [`search::GraphSearch`]: Standard graph-based ANN search + /// - [`search::MultihopSearch`]: Label-filtered search with multi-hop expansion + /// - [`search::RangeSearch`]: Range-based search within a distance radius + /// - [`search::DiverseSearch`]: Diversity-aware search (feature-gated) /// - /// # Returns + /// For flat (brute-force) search, use [`Self::flat_search`] directly due to its + /// unique iterator type constraints. /// - /// Returns a tuple containing: - /// - An optional vector of visited nodes (if requested in `search_params`). - /// - The number of distance computations performed. - /// - The number of hops (always zero for flat search, as no graph traversal occurs). + /// # Example /// - /// # Errors + /// ```ignore + /// use diskann::graph::{GraphSearch, RangeSearch, SearchDispatch}; /// - /// Returns an error if there is a failure accessing elements or if the provided parameters are invalid. + /// // Standard graph search + /// let params = GraphSearch::new(10, 100, None)?; + /// let stats = index.search(&strategy, &context, &query, &params, &mut output).await?; + /// + /// // Range search (note: uses () as output buffer, results in Output type) + /// let params = RangeSearch::new(100, 0.5)?; + /// let result = index.search(&strategy, &context, &query, &params, &mut ()).await?; + /// // result.ids and result.distances contain the matches + /// ``` + pub fn search<'a, S, T, O: 'a, OB, P>( + &'a self, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + search_params: &'a P, + output: &'a mut OB, + ) -> impl SendFuture> + 'a + where + P: super::search::SearchDispatch, + T: ?Sized, + OB: ?Sized, + { + search_params.dispatch(self, strategy, context, query, output) + } + + /// Perform a graph search while recording the traversal path. + /// + /// **Note:** This method is intended for debugging and analysis only. 
+ /// For production searches, use [`Self::search`] with [`super::search::GraphSearch`]. + /// + /// Records which nodes were visited during the search traversal, useful for + /// understanding search behavior or diagnosing issues. #[allow(clippy::too_many_arguments)] - pub fn search_recorded( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &SearchParams, - output: &mut OB, - search_record: &mut SR, - ) -> impl SendFuture> + pub fn debug_search<'a, S, T, O, OB, SR>( + &'a self, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + search_params: &'a SearchParams, + output: &'a mut OB, + search_record: &'a mut SR, + ) -> impl SendFuture> + 'a where T: Sync + ?Sized, S: SearchStrategy, - O: Send, + O: Send + 'a, OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, - SR: SearchRecord + ?Sized, + SR: SearchRecord + Send, { async move { let mut accessor = strategy @@ -2414,11 +2194,12 @@ where let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; - let mut scratch = self.search_scratch(search_params.l_value, start_ids.len()); + let graph_search = super::search::GraphSearch::from(*search_params); + let mut scratch = self.search_scratch(graph_search.l, start_ids.len()); let stats = self .search_internal( - search_params.beam_width, + graph_search.beam_width, &start_ids, &mut accessor, &computer, @@ -2433,7 +2214,7 @@ where &mut accessor, query, &computer, - scratch.best.iter().take(search_params.l_value.into_usize()), + scratch.best.iter().take(graph_search.l.into_usize()), output, ) .send() @@ -2444,57 +2225,6 @@ where } } - /// Performs a graph-based search towards a target query vector. - /// - /// This method executes a search using the provided `strategy` to access and process elements. 
- /// It computes the similarity between the query vector and the elements in the index, moving towards the - /// nearest neighbors according to the search parameters. - /// - /// # Arguments - /// - /// * `strategy` - The search strategy to use for accessing and processing elements. - /// * `context` - The context to pass through to providers. - /// * `query` - The query vector for which nearest neighbors are sought. - /// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`) and beam width. - /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. - /// - /// # Returns - /// - /// Returns a tuple containing: - /// - An optional vector of visited nodes (if requested in `search_params`). - /// - The number of distance computations performed. - /// - The number of hops (always zero for flat search, as no graph traversal occurs). - /// - /// # Errors - /// - /// Returns an error if there is a failure accessing elements or if the provided parameters are invalid. - pub fn search( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &SearchParams, - output: &mut OB, - ) -> impl SendFuture> - where - T: Sync + ?Sized, - S: SearchStrategy, - O: Send, - OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, - { - async move { - self.search_recorded( - strategy, - context, - query, - search_params, - output, - &mut NoopSearchRecord::new(), - ) - .await - } - } - /// Performs a brute-force flat search over the points matching a provided filter function. /// /// This method executes a linear scan through all points in the index, applying the provided @@ -2512,10 +2242,7 @@ where /// /// # Returns /// - /// Returns a tuple containing: - /// - An optional vector of visited nodes (if requested in `search_params`). - /// - The number of distance computations performed. - /// - The number of hops (always zero for flat search, as no graph traversal occurs). 
+ /// Returns search statistics including the number of distance computations performed. /// /// # Errors /// @@ -2591,229 +2318,6 @@ where }) } - /// A helper function for range search that allows an external application - /// to perform their own post-processing on the raw in-range results - #[allow(clippy::type_complexity)] - pub fn range_search_raw( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &RangeSearchParams, - ) -> impl SendFuture>)>> - where - T: Sync + ?Sized, - S: SearchStrategy, - O: Send + Default + Clone, - { - async move { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; - - let mut scratch = self.search_scratch(search_params.starting_l_value, start_ids.len()); - - let initial_stats = self - .search_internal( - search_params.beam_width, - &start_ids, - &mut accessor, - &computer, - &mut scratch, - &mut NoopSearchRecord::new(), - ) - .await?; - - let mut in_range = Vec::with_capacity(search_params.starting_l_value.into_usize()); - - for neighbor in scratch - .best - .iter() - .take(search_params.starting_l_value.into_usize()) - { - if neighbor.distance <= search_params.radius { - in_range.push(neighbor); - } - } - - // clear the visited set and repopulate it with just the in-range points - scratch.visited.clear(); - for neighbor in in_range.iter() { - scratch.visited.insert(neighbor.id); - } - scratch.in_range = in_range; - - let stats = if scratch.in_range.len() - >= ((search_params.starting_l_value as f32) * search_params.initial_search_slack) - as usize - { - // Move to range search - let range_stats = self - .range_search_internal(search_params, &mut accessor, &computer, &mut scratch) - .await?; - - InternalSearchStats { - cmps: initial_stats.cmps, - hops: initial_stats.hops + range_stats.hops, - range_search_second_round: true, - } - 
} else { - initial_stats - }; - - Ok(( - stats.finish(scratch.in_range.len() as u32), - scratch.in_range.to_vec(), - )) - } - } - - /// Given a `query` vector, search for all results within a specified radius - /// `l_value` is the search depth of the initial search phase - /// - /// Note that the radii in `search_params` are raw distances, not similarity scores; - /// the user is expected to execute any necessary transformations to their desired - /// radius before calling this function. - /// - /// We allow complicated types here to avoid needing an entirely new type definition - /// for just one function - #[allow(clippy::type_complexity)] - pub fn range_search( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &RangeSearchParams, - ) -> impl SendFuture, Vec)>> - where - T: Sync + ?Sized, - S: SearchStrategy, - O: Send + Default + Clone, - { - async move { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - - let (mut stats, in_range) = self - .range_search_raw(strategy, context, query, search_params) - .await?; - // create a new output buffer for the range search - // need to initialize distance buffer to max value because of later filtering step - let mut result_ids: Vec = vec![O::default(); in_range.len()]; - let mut result_dists: Vec = vec![f32::MAX; in_range.len()]; - - let mut output_buffer = search_output_buffer::IdDistance::new( - result_ids.as_mut_slice(), - result_dists.as_mut_slice(), - ); - - let _ = strategy - .post_processor() - .post_process( - &mut accessor, - query, - &computer, - in_range.into_iter(), - &mut output_buffer, - ) - .send() - .await - .into_ann_result()?; - - // Filter the output buffer for points with distance between inner and outer radius - // Note this takes a dependency on the output of `post_process` being sorted by distance - - let inner_cutoff = if let Some(inner_radius) 
= search_params.inner_radius { - result_dists - .iter() - .position(|dist| *dist > inner_radius) - .unwrap_or(result_dists.len()) - } else { - 0 - }; - - let outer_cutoff = result_dists - .iter() - .position(|dist| *dist > search_params.radius) - .unwrap_or(result_dists.len()); - - result_ids.truncate(outer_cutoff); - result_ids.drain(0..inner_cutoff); - - result_dists.truncate(outer_cutoff); - result_dists.drain(0..inner_cutoff); - - let result_count = result_ids.len(); - - stats.result_count = result_count as u32; - - Ok((stats, result_ids, result_dists)) - } - } - - /// Graph search that takes into account label filter matching by expanding - /// each non-matching neighborhood to search for matching nodes - /// Label provider must be included as a function argument - /// Note that if the Strategy is of type BetaFilter, this function assumes - /// but does not enforce that the label provider used in the strategy - /// is the same as the one in the function argument - pub fn multihop_search( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &SearchParams, - output: &mut OB, - query_label_evaluator: &dyn QueryLabelProvider, - ) -> impl SendFuture> - where - T: Sync + ?Sized, - S: SearchStrategy, - O: Send, - OB: search_output_buffer::SearchOutputBuffer + Send, - { - async move { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - - let start_ids = accessor.starting_points().await?; - - let mut scratch = self.search_scratch(search_params.l_value, start_ids.len()); - - let stats = self - .multihop_search_internal( - search_params, - &mut accessor, - &computer, - &mut scratch, - &mut NoopSearchRecord::new(), - query_label_evaluator, - ) - .await?; - - let result_count = strategy - .post_processor() - .post_process( - &mut accessor, - query, - &computer, - 
scratch.best.iter().take(search_params.l_value.into_usize()), - output, - ) - .send() - .await - .into_ann_result()?; - - Ok(stats.finish(result_count as u32)) - } - } - ////////////////// // Paged Search // ////////////////// @@ -3595,15 +3099,15 @@ struct InplaceDeleteWorkList { in_neighbors: Vec, } -/// Private internal struct for recording search statistics. -struct InternalSearchStats { - cmps: u32, - hops: u32, - range_search_second_round: bool, +/// Internal struct for recording search statistics. +pub(crate) struct InternalSearchStats { + pub(crate) cmps: u32, + pub(crate) hops: u32, + pub(crate) range_search_second_round: bool, } impl InternalSearchStats { - fn finish(self, result_count: u32) -> SearchStats { + pub(crate) fn finish(self, result_count: u32) -> SearchStats { SearchStats { cmps: self.cmps, hops: self.hops, @@ -3613,136 +3117,3 @@ impl InternalSearchStats { } } -#[cfg(feature = "experimental_diversity_search")] -impl DiskANNIndex -where - DP: DataProvider, -{ - /// Create a diverse search scratch with DiverseNeighborQueue - fn create_diverse_scratch

( - &self, - l_value: usize, - beam_width: Option, - diverse_params: &DiverseSearchParams

, - k_value: usize, - ) -> SearchScratch> - where - P: crate::neighbor::AttributeValueProvider, - { - use crate::neighbor::DiverseNeighborQueue; - - let attribute_provider = diverse_params.attribute_provider.clone(); - let diverse_queue = DiverseNeighborQueue::new( - l_value, - // SAFETY: k_value is guaranteed to be non-zero by SearchParams validation by caller - #[allow(clippy::expect_used)] - NonZeroUsize::new(k_value).expect("k_value must be non-zero"), - diverse_params.diverse_results_k, - attribute_provider, - ); - - SearchScratch { - best: diverse_queue, - visited: HashSet::with_capacity(self.estimate_visited_set_capacity(Some(l_value))), - id_scratch: Vec::with_capacity(self.max_degree_with_slack()), - beam_nodes: Vec::with_capacity(beam_width.unwrap_or(1)), - range_frontier: std::collections::VecDeque::new(), - in_range: Vec::new(), - hops: 0, - cmps: 0, - } - } - - /// Experimental diverse search implementation using DiverseNeighborQueue. - /// - /// This method performs a graph-based search with diversity constraints, using the provided - /// diverse search parameters to filter results based on attribute values. - /// - /// # Arguments - /// - /// * `strategy` - The search strategy to use for accessing and processing elements. - /// * `context` - The context to pass through to providers. - /// * `query` - The query vector for which nearest neighbors are sought. - /// * `search_params` - Parameters controlling the search behavior, including l_value, beam width, and k_value. - /// * `diverse_params` - Diversity parameters including attribute provider and alpha value. - /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. - /// * `search_record` - A mutable reference to a search record object that will record the path taken during the search. - /// - /// # Returns - /// - /// Returns search statistics including comparisons and hops performed. 
- /// - /// # Errors - /// - /// Returns an error if there is a failure accessing elements or if the provided parameters are invalid. - #[allow(clippy::too_many_arguments)] - pub fn diverse_search_experimental( - &self, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &SearchParams, - diverse_params: &DiverseSearchParams

, - output: &mut OB, - search_record: &mut SR, - ) -> impl SendFuture> - where - T: Sync + ?Sized, - S: glue::SearchStrategy, - O: Send, - OB: search_output_buffer::SearchOutputBuffer + Send, - SR: super::search::record::SearchRecord + ?Sized, - P: crate::neighbor::AttributeValueProvider, - { - async move { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; - - // Use diverse search with DiverseNeighborQueue - // TODO: Use scratch pool in future PRs to avoid allocation. - let mut diverse_scratch = self.create_diverse_scratch( - search_params.l_value, - search_params.beam_width, - diverse_params, - search_params.k_value, - ); - - let stats = self - .search_internal( - search_params.beam_width, - &start_ids, - &mut accessor, - &computer, - &mut diverse_scratch, - search_record, - ) - .await?; - - // Post-process diverse results to keep only diverse_results_k items - diverse_scratch.best.post_process(); - - // TODO: Post processing will change for diverse search in future PRs - let result_count = strategy - .post_processor() - .post_process( - &mut accessor, - query, - &computer, - diverse_scratch - .best - .iter() - .take(search_params.l_value.into_usize()), - output, - ) - .send() - .await - .into_ann_result()?; - - Ok(stats.finish(result_count as u32)) - } - } -} diff --git a/diskann/src/graph/mod.rs b/diskann/src/graph/mod.rs index d203a74b0..2622e7ecb 100644 --- a/diskann/src/graph/mod.rs +++ b/diskann/src/graph/mod.rs @@ -32,6 +32,14 @@ pub use misc::DiverseSearchParams; pub mod glue; pub mod search; +// Re-export unified search interface as the primary API. 
+pub use search::{ + GraphSearch, MultihopSearch, RangeSearch, RangeSearchOutput, SearchDispatch, +}; + +#[cfg(feature = "experimental_diversity_search")] +pub use search::DiverseSearch; + mod internal; // Integration tests and test providers. diff --git a/diskann/src/graph/search/dispatch.rs b/diskann/src/graph/search/dispatch.rs new file mode 100644 index 000000000..b6e7a358f --- /dev/null +++ b/diskann/src/graph/search/dispatch.rs @@ -0,0 +1,33 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Core search dispatch trait. + +use diskann_utils::future::SendFuture; + +use crate::{ANNResult, graph::index::DiskANNIndex, provider::DataProvider}; + +/// Trait for search parameter types that execute their own search logic. +/// +/// Each search type (graph search, multi-hop search, range search, etc.) implements +/// this trait to define its complete search behavior. The [`DiskANNIndex::search`] +/// method delegates to the `dispatch` method. +pub trait SearchDispatch +where + DP: DataProvider, +{ + /// The result type returned by this search. + type Output; + + /// Execute the search operation with full search logic. + fn dispatch<'a>( + &'a self, + index: &'a DiskANNIndex, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + output: &'a mut OB, + ) -> impl SendFuture>; +} diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs new file mode 100644 index 000000000..ecf77d6cb --- /dev/null +++ b/diskann/src/graph/search/diverse_search.rs @@ -0,0 +1,165 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Diversity-aware search (feature-gated).
+ +#![cfg(feature = "experimental_diversity_search")] + +use std::num::NonZeroUsize; + +use diskann_utils::future::{AssertSend, SendFuture}; +use hashbrown::HashSet; + +use super::{dispatch::SearchDispatch, graph_search::GraphSearch, record::NoopSearchRecord, scratch::SearchScratch}; +use crate::{ + ANNResult, + error::IntoANNResult, + graph::{ + DiverseSearchParams, + glue::{SearchExt, SearchPostProcess, SearchStrategy}, + index::{DiskANNIndex, SearchStats}, + search_output_buffer::SearchOutputBuffer, + }, + neighbor::{AttributeValueProvider, DiverseNeighborQueue, NeighborQueue}, + provider::{BuildQueryComputer, DataProvider}, + utils::IntoUsize, +}; + +/// Parameters for diversity-aware search. +/// +/// Returns results that are diverse across a specified attribute. +#[derive(Debug)] +pub struct DiverseSearch

+where + P: AttributeValueProvider, +{ + /// Base graph search parameters. + pub inner: GraphSearch, + /// Diversity-specific parameters. + pub diverse_params: DiverseSearchParams

, +} + +impl

DiverseSearch

+where + P: AttributeValueProvider, +{ + /// Create new diverse search parameters. + pub fn new(inner: GraphSearch, diverse_params: DiverseSearchParams

) -> Self { + Self { inner, diverse_params } + } +} + +impl SearchDispatch for DiverseSearch

+where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send, + OB: SearchOutputBuffer + Send, + P: AttributeValueProvider, +{ + type Output = SearchStats; + + fn dispatch<'a>( + &'a self, + index: &'a DiskANNIndex, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + output: &'a mut OB, + ) -> impl SendFuture> { + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut diverse_scratch = create_diverse_scratch( + index, + self.inner.l, + self.inner.beam_width, + &self.diverse_params, + self.inner.k, + ); + + let stats = index + .search_internal( + self.inner.beam_width, + &start_ids, + &mut accessor, + &computer, + &mut diverse_scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + // Post-process diverse results + diverse_scratch.best.post_process(); + + let result_count = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + diverse_scratch.best.iter().take(self.inner.l.into_usize()), + output, + ) + .send() + .await + .into_ann_result()?; + + Ok(stats.finish(result_count as u32)) + } + } +} + +//============================================================================= +// Internal Implementation +//============================================================================= + +/// Create a diverse search scratch with DiverseNeighborQueue. +/// +/// # Arguments +/// +/// * `index` - The DiskANN index for capacity estimation +/// * `l_value` - Search list size +/// * `beam_width` - Optional beam width for parallel exploration +/// * `diverse_params` - Diversity-specific parameters +/// * `k_value` - Number of results to return +pub(crate) fn create_diverse_scratch( + index: &DiskANNIndex, + l_value: usize, + beam_width: Option, + diverse_params: &DiverseSearchParams

, + k_value: usize, +) -> SearchScratch> +where + DP: DataProvider, + P: AttributeValueProvider, +{ + let attribute_provider = diverse_params.attribute_provider.clone(); + let diverse_queue = DiverseNeighborQueue::new( + l_value, + // SAFETY: k_value is guaranteed to be non-zero by SearchParams validation by caller + #[allow(clippy::expect_used)] + NonZeroUsize::new(k_value).expect("k_value must be non-zero"), + diverse_params.diverse_results_k, + attribute_provider, + ); + + SearchScratch { + best: diverse_queue, + visited: HashSet::with_capacity(index.estimate_visited_set_capacity(Some(l_value))), + id_scratch: Vec::with_capacity(index.max_degree_with_slack()), + beam_nodes: Vec::with_capacity(beam_width.unwrap_or(1)), + range_frontier: std::collections::VecDeque::new(), + in_range: Vec::new(), + hops: 0, + cmps: 0, + } +} diff --git a/diskann/src/graph/search/graph_search.rs b/diskann/src/graph/search/graph_search.rs new file mode 100644 index 000000000..de2d09fe1 --- /dev/null +++ b/diskann/src/graph/search/graph_search.rs @@ -0,0 +1,227 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Standard graph-based ANN search. + +use std::fmt::Debug; + +use diskann_utils::future::{AssertSend, SendFuture}; + +use super::dispatch::SearchDispatch; +use crate::{ + ANNResult, + error::IntoANNResult, + graph::{ + glue::{SearchExt, SearchPostProcess, SearchStrategy}, + index::{DiskANNIndex, SearchStats}, + search::record::NoopSearchRecord, + search_output_buffer::SearchOutputBuffer, + }, + provider::{BuildQueryComputer, DataProvider}, + utils::IntoUsize, +}; + +/// Parameters for standard graph-based ANN search. +/// +/// This is the primary search mode, using the Vamana graph structure for efficient +/// approximate nearest neighbor traversal. +#[derive(Debug, Clone, Copy)] +pub struct GraphSearch { + /// Number of results to return (k in k-NN). + pub k: usize, + /// Search list size - controls accuracy vs speed tradeoff. 
+ pub l: usize, + /// Optional beam width for parallel graph exploration. + pub beam_width: Option, +} + +impl GraphSearch { + /// Create new graph search parameters. + /// + /// # Errors + /// + /// Returns an error if `l < k` or if any value is zero. + pub fn new( + k: usize, + l: usize, + beam_width: Option, + ) -> Result { + use super::super::SearchParamsError; + + if k > l { + return Err(SearchParamsError::LLessThanK { l_value: l, k_value: k }); + } + if let Some(bw) = beam_width { + if bw == 0 { + return Err(SearchParamsError::BeamWidthZero); + } + } + if k == 0 { + return Err(SearchParamsError::KZero); + } + if l == 0 { + return Err(SearchParamsError::LZero); + } + + Ok(Self { k, l, beam_width }) + } + + /// Create parameters with default beam width. + pub fn new_default(k: usize, l: usize) -> Result { + Self::new(k, l, None) + } +} + +impl From for GraphSearch { + fn from(params: super::super::SearchParams) -> Self { + Self { + k: params.k_value, + l: params.l_value, + beam_width: params.beam_width, + } + } +} + +/// Implement SearchDispatch for SearchParams to provide backwards compatibility. +/// This treats SearchParams as an alias for GraphSearch. 
+impl SearchDispatch for super::super::SearchParams +where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, +{ + type Output = SearchStats; + + fn dispatch<'a>( + &'a self, + index: &'a DiskANNIndex, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + output: &'a mut OB, + ) -> impl SendFuture> { + async move { + let graph_search = GraphSearch::from(*self); + graph_search.dispatch(index, strategy, context, query, output).await + } + } +} + +impl SearchDispatch for GraphSearch +where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, +{ + type Output = SearchStats; + + fn dispatch<'a>( + &'a self, + index: &'a DiskANNIndex, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + output: &'a mut OB, + ) -> impl SendFuture> { + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(self.l, start_ids.len()); + + let stats = index + .search_internal( + self.beam_width, + &start_ids, + &mut accessor, + &computer, + &mut scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + let result_count = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + scratch.best.iter().take(self.l.into_usize()), + output, + ) + .send() + .await + .into_ann_result()?; + + Ok(stats.finish(result_count as u32)) + } + } +} + +//============================================================================= +// Recorded Graph Search +//============================================================================= + +/// Graph search with traversal path recording. +/// +/// Records the path taken during search for debugging or analysis. 
+pub struct RecordedGraphSearch<'r, SR: ?Sized> { + /// Base graph search parameters. + pub inner: GraphSearch, + /// The recorder to capture search path. + pub recorder: &'r mut SR, +} + +impl<'r, SR: ?Sized> RecordedGraphSearch<'r, SR> { + /// Create new recorded search parameters. + pub fn new(inner: GraphSearch, recorder: &'r mut SR) -> Self { + Self { inner, recorder } + } +} + +impl<'r, SR: Debug + ?Sized> Debug for RecordedGraphSearch<'r, SR> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RecordedGraphSearch") + .field("inner", &self.inner) + .finish_non_exhaustive() + } +} + +// Note: RecordedGraphSearch cannot implement SearchDispatch because it holds &mut recorder +// which conflicts with the shared reference semantics of dispatch. Users should call +// the search logic directly or use a Cell/RefCell pattern if needed. + +//============================================================================= +// Tests +//============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_graph_search_validation() { + // Valid + assert!(GraphSearch::new(10, 100, None).is_ok()); + assert!(GraphSearch::new(10, 100, Some(4)).is_ok()); + assert!(GraphSearch::new(10, 10, None).is_ok()); // k == l is valid + + // Invalid: l < k + assert!(GraphSearch::new(100, 10, None).is_err()); + + // Invalid: zero values + assert!(GraphSearch::new(0, 100, None).is_err()); + assert!(GraphSearch::new(10, 0, None).is_err()); + assert!(GraphSearch::new(10, 100, Some(0)).is_err()); + } +} diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 2b02ac39f..82d549950 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -3,5 +3,46 @@ * Licensed under the MIT license. */ +//! Unified search execution framework. +//! +//! This module provides the primary search interface for DiskANN. All search types +//! 
are represented as parameter structs that implement [`SearchDispatch`], which +//! contains the complete search logic. +//! +//! # Usage +//! +//! ```ignore +//! use diskann::graph::{GraphSearch, RangeSearch, MultihopSearch, SearchDispatch}; +//! +//! // Standard graph search +//! let params = GraphSearch::new(10, 100, None)?; +//! let stats = index.search(&strategy, &context, &query, &params, &mut output).await?; +//! +//! // Range search +//! let params = RangeSearch::new(100, 0.5)?; +//! let result = index.search(&strategy, &context, &query, &params, &mut ()).await?; +//! println!("Found {} points within radius", result.ids.len()); +//! ``` + +mod dispatch; +mod graph_search; +mod multihop_search; +mod range_search; + pub mod record; pub(crate) mod scratch; + +// Re-export the core dispatch trait. +pub use dispatch::SearchDispatch; + +// Re-export search parameter types. +pub use graph_search::{GraphSearch, RecordedGraphSearch}; +pub use multihop_search::MultihopSearch; +pub use range_search::{RangeSearch, RangeSearchOutput}; + +// Feature-gated diverse search. +#[cfg(feature = "experimental_diversity_search")] +mod diverse_search; + +#[cfg(feature = "experimental_diversity_search")] +pub use diverse_search::DiverseSearch; diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs new file mode 100644 index 000000000..0e25d9585 --- /dev/null +++ b/diskann/src/graph/search/multihop_search.rs @@ -0,0 +1,312 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Label-filtered search using multi-hop expansion.
+ +use std::fmt::Debug; + +use diskann_utils::future::{AssertSend, SendFuture}; +use diskann_utils::Reborrow; +use diskann_vector::PreprocessedDistanceFunction; +use hashbrown::HashSet; + +use super::{dispatch::SearchDispatch, record::SearchRecord, scratch::SearchScratch}; +use crate::{ + ANNResult, + error::{ErrorExt, IntoANNResult}, + graph::{ + SearchParams, + glue::{self, ExpandBeam, HybridPredicate, Predicate, PredicateMut, SearchExt, SearchPostProcess, SearchStrategy}, + index::{DiskANNIndex, InternalSearchStats, QueryLabelProvider, QueryVisitDecision, SearchStats}, + search::record::NoopSearchRecord, + search_output_buffer::SearchOutputBuffer, + }, + neighbor::Neighbor, + provider::{BuildQueryComputer, DataProvider}, + utils::{IntoUsize, VectorId}, +}; + +use super::graph_search::GraphSearch; + +/// Parameters for label-filtered search using multi-hop expansion. +/// +/// This search extends standard graph search by expanding through non-matching +/// nodes to find matching neighbors. More efficient than flat search when the +/// matching subset is reasonably large. +pub struct MultihopSearch<'q, InternalId> { + /// Base graph search parameters. + pub inner: GraphSearch, + /// Label evaluator for determining node matches. + pub label_evaluator: &'q dyn QueryLabelProvider, +} + +impl Debug for MultihopSearch<'_, InternalId> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MultihopSearch") + .field("inner", &self.inner) + .field("label_evaluator", self.label_evaluator) + .finish() + } +} + +impl<'q, InternalId> MultihopSearch<'q, InternalId> { + /// Create new multihop search parameters. 
+ pub fn new( + inner: GraphSearch, + label_evaluator: &'q dyn QueryLabelProvider, + ) -> Self { + Self { inner, label_evaluator } + } +} + +impl<'q, DP, S, T, O, OB> SearchDispatch for MultihopSearch<'q, DP::InternalId> +where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send, + OB: SearchOutputBuffer + Send, +{ + type Output = SearchStats; + + fn dispatch<'a>( + &'a self, + index: &'a DiskANNIndex, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + output: &'a mut OB, + ) -> impl SendFuture> { + let params = SearchParams { + k_value: self.inner.k, + l_value: self.inner.l, + beam_width: self.inner.beam_width, + }; + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + let computer = accessor.build_query_computer(query).into_ann_result()?; + + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(params.l_value, start_ids.len()); + + let stats = multihop_search_internal( + index.max_degree_with_slack(), + &params, + &mut accessor, + &computer, + &mut scratch, + &mut NoopSearchRecord::new(), + self.label_evaluator, + ) + .await?; + + let result_count = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + scratch.best.iter().take(params.l_value.into_usize()), + output, + ) + .send() + .await + .into_ann_result()?; + + Ok(stats.finish(result_count as u32)) + } + } +} + +//============================================================================= +// Internal Implementation +//============================================================================= + +/// A predicate that checks if an item is not in the visited set AND matches the label filter. +/// +/// Used during two-hop expansion to filter neighbors based on both visitation +/// status and label matching criteria.
+pub struct NotInMutWithLabelCheck<'a, K> +where + K: VectorId, +{ + visited_set: &'a mut HashSet, + query_label_evaluator: &'a dyn QueryLabelProvider, +} + +impl<'a, K> NotInMutWithLabelCheck<'a, K> +where + K: VectorId, +{ + /// Construct a new `NotInMutWithLabelCheck` around `visited_set`. + pub fn new( + visited_set: &'a mut HashSet, + query_label_evaluator: &'a dyn QueryLabelProvider, + ) -> Self { + Self { + visited_set, + query_label_evaluator, + } + } +} + +impl Predicate for NotInMutWithLabelCheck<'_, K> +where + K: VectorId, +{ + fn eval(&self, item: &K) -> bool { + !self.visited_set.contains(item) && self.query_label_evaluator.is_match(*item) + } +} + +impl PredicateMut for NotInMutWithLabelCheck<'_, K> +where + K: VectorId, +{ + fn eval_mut(&mut self, item: &K) -> bool { + if self.query_label_evaluator.is_match(*item) { + return self.visited_set.insert(*item); + } + false + } +} + +impl HybridPredicate for NotInMutWithLabelCheck<'_, K> where K: VectorId {} + +/// Internal multihop search implementation. +/// +/// Performs label-filtered search by expanding through non-matching nodes +/// to find matching neighbors within two hops. +pub(crate) async fn multihop_search_internal( + max_degree_with_slack: usize, + search_params: &SearchParams, + accessor: &mut A, + computer: &A::QueryComputer, + scratch: &mut SearchScratch, + search_record: &mut SR, + query_label_evaluator: &dyn QueryLabelProvider, +) -> ANNResult +where + I: VectorId, + A: ExpandBeam + SearchExt, + T: ?Sized, + SR: SearchRecord + ?Sized, +{ + let beam_width = search_params.beam_width.unwrap_or(1); + + // Helper to build the final stats from scratch state. + let make_stats = |scratch: &SearchScratch| InternalSearchStats { + cmps: scratch.cmps, + hops: scratch.hops, + range_search_second_round: false, + }; + + // Initialize search state if not already initialized. 
+ // This allows paged search to call multihop_search_internal multiple times + if scratch.visited.is_empty() { + let start_ids = accessor.starting_points().await?; + + for id in start_ids { + scratch.visited.insert(id); + let element = accessor + .get_element(id) + .await + .escalate("start point retrieval must succeed")?; + let dist = computer.evaluate_similarity(element.reborrow()); + scratch.best.insert(Neighbor::new(id, dist)); + } + } + + // Pre-allocate with good capacity to avoid repeated allocations + let mut one_hop_neighbors = Vec::with_capacity(max_degree_with_slack); + let mut two_hop_neighbors = Vec::with_capacity(max_degree_with_slack); + let mut candidates_two_hop_expansion = Vec::with_capacity(max_degree_with_slack); + + while scratch.best.has_notvisited_node() && !accessor.terminate_early() { + scratch.beam_nodes.clear(); + one_hop_neighbors.clear(); + candidates_two_hop_expansion.clear(); + two_hop_neighbors.clear(); + + // In this loop we are going to find the beam_width number of nodes that are closest to the query. + // Each of these nodes will be a frontier node. 
+ while scratch.best.has_notvisited_node() && scratch.beam_nodes.len() < beam_width { + let closest_node = scratch.best.closest_notvisited(); + search_record.record(closest_node, scratch.hops, scratch.cmps); + scratch.beam_nodes.push(closest_node.id); + } + + // compute distances from query to one-hop neighbors, and mark them visited + accessor + .expand_beam( + scratch.beam_nodes.iter().copied(), + computer, + glue::NotInMut::new(&mut scratch.visited), + |distance, id| one_hop_neighbors.push(Neighbor::new(id, distance)), + ) + .await?; + + // Process one-hop neighbors based on on_visit() decision + for neighbor in one_hop_neighbors.iter().copied() { + match query_label_evaluator.on_visit(neighbor) { + QueryVisitDecision::Accept(accepted) => { + scratch.best.insert(accepted); + } + QueryVisitDecision::Reject => { + // Rejected nodes: still add to two-hop expansion so we can traverse through them + candidates_two_hop_expansion.push(neighbor); + } + QueryVisitDecision::Terminate => { + scratch.cmps += one_hop_neighbors.len() as u32; + scratch.hops += scratch.beam_nodes.len() as u32; + return Ok(make_stats(scratch)); + } + } + } + + scratch.cmps += one_hop_neighbors.len() as u32; + scratch.hops += scratch.beam_nodes.len() as u32; + + // sort the candidates for two-hop expansion by distance to query point + candidates_two_hop_expansion.sort_unstable_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // limit the number of two-hop candidates to avoid too many expansions + candidates_two_hop_expansion.truncate(max_degree_with_slack / 2); + + // Expand each two-hop candidate: if its neighbor is a match, compute its distance + // to the query and insert into `scratch.visited` + // If it is not a match, do nothing + let two_hop_expansion_candidate_ids: Vec = + candidates_two_hop_expansion.iter().map(|n| n.id).collect(); + + accessor + .expand_beam( + two_hop_expansion_candidate_ids.iter().copied(), + computer, + 
NotInMutWithLabelCheck::new(&mut scratch.visited, query_label_evaluator), + |distance, id| { + two_hop_neighbors.push(Neighbor::new(id, distance)); + }, + ) + .await?; + + // Next, insert the new matches into `scratch.best` and increment stats counters + two_hop_neighbors + .iter() + .for_each(|neighbor| scratch.best.insert(*neighbor)); + + scratch.cmps += two_hop_neighbors.len() as u32; + scratch.hops += two_hop_expansion_candidate_ids.len() as u32; + } + + Ok(make_stats(scratch)) +} diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs new file mode 100644 index 000000000..64b064016 --- /dev/null +++ b/diskann/src/graph/search/range_search.rs @@ -0,0 +1,371 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Range-based search within a distance radius. + +use diskann_utils::future::{AssertSend, SendFuture}; + +use super::{dispatch::SearchDispatch, scratch::SearchScratch}; +use crate::{ + ANNResult, + error::IntoANNResult, + graph::{ + RangeSearchParams, + glue::{self, ExpandBeam, SearchExt, SearchPostProcess, SearchStrategy}, + index::{DiskANNIndex, InternalSearchStats, SearchStats}, + search::record::NoopSearchRecord, + search_output_buffer, + }, + neighbor::Neighbor, + provider::{BuildQueryComputer, DataProvider}, + utils::IntoUsize, +}; + +/// Result from a range search operation. +pub struct RangeSearchOutput { + /// Search statistics. + pub stats: SearchStats, + /// IDs of points within the radius. + pub ids: Vec, + /// Distances corresponding to each ID. + pub distances: Vec, +} + +/// Parameters for range-based search. +/// +/// Finds all points within a specified distance radius from the query. +#[derive(Debug, Clone, Copy)] +pub struct RangeSearch { + /// Maximum results to return (None = unlimited). + pub max_returned: Option, + /// Initial search list size. + pub starting_l: usize, + /// Optional beam width. 
+ pub beam_width: Option, + /// Outer radius - points within this distance are candidates. + pub radius: f32, + /// Inner radius - points closer than this are excluded. + pub inner_radius: Option, + /// Slack factor for initial search phase (0.0 to 1.0). + pub initial_slack: f32, + /// Slack factor for range expansion (>= 1.0). + pub range_slack: f32, +} + +impl RangeSearch { + /// Create range search with default slack values. + pub fn new( + starting_l: usize, + radius: f32, + ) -> Result { + Self::with_options(None, starting_l, None, radius, None, 1.0, 1.0) + } + + /// Create range search with full options. + #[allow(clippy::too_many_arguments)] + pub fn with_options( + max_returned: Option, + starting_l: usize, + beam_width: Option, + radius: f32, + inner_radius: Option, + initial_slack: f32, + range_slack: f32, + ) -> Result { + use super::super::RangeSearchParamsError; + + if let Some(bw) = beam_width { + if bw == 0 { + return Err(RangeSearchParamsError::BeamWidthZero); + } + } + if starting_l == 0 { + return Err(RangeSearchParamsError::LZero); + } + if !(0.0..=1.0).contains(&initial_slack) { + return Err(RangeSearchParamsError::StartingListSlackValueError); + } + if range_slack < 1.0 { + return Err(RangeSearchParamsError::RangeSearchSlackValueError); + } + if let Some(inner) = inner_radius { + if inner > radius { + return Err(RangeSearchParamsError::InnerRadiusValueError); + } + } + + Ok(Self { + max_returned, + starting_l, + beam_width, + radius, + inner_radius, + initial_slack, + range_slack, + }) + } + + fn to_legacy_params(&self) -> RangeSearchParams { + RangeSearchParams { + max_returned: self.max_returned, + starting_l_value: self.starting_l, + beam_width: self.beam_width, + radius: self.radius, + inner_radius: self.inner_radius, + initial_search_slack: self.initial_slack, + range_search_slack: self.range_slack, + } + } +} + +impl From for RangeSearch { + fn from(params: RangeSearchParams) -> Self { + Self { + max_returned: params.max_returned, + 
starting_l: params.starting_l_value, + beam_width: params.beam_width, + radius: params.radius, + inner_radius: params.inner_radius, + initial_slack: params.initial_search_slack, + range_slack: params.range_search_slack, + } + } +} + +impl SearchDispatch for RangeSearch +where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send + Default + Clone, +{ + type Output = RangeSearchOutput; + + fn dispatch<'a>( + &'a self, + index: &'a DiskANNIndex, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + _output: &'a mut (), + ) -> impl SendFuture> { + let search_params = self.to_legacy_params(); + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(search_params.starting_l_value, start_ids.len()); + + let initial_stats = index + .search_internal( + search_params.beam_width, + &start_ids, + &mut accessor, + &computer, + &mut scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + let mut in_range = Vec::with_capacity(search_params.starting_l_value.into_usize()); + + for neighbor in scratch + .best + .iter() + .take(search_params.starting_l_value.into_usize()) + { + if neighbor.distance <= search_params.radius { + in_range.push(neighbor); + } + } + + // clear the visited set and repopulate it with just the in-range points + scratch.visited.clear(); + for neighbor in in_range.iter() { + scratch.visited.insert(neighbor.id); + } + scratch.in_range = in_range; + + let stats = if scratch.in_range.len() + >= ((search_params.starting_l_value as f32) * search_params.initial_search_slack) + as usize + { + // Move to range search + let range_stats = range_search_internal( + index.max_degree_with_slack(), + &search_params, + &mut accessor, + &computer, + &mut scratch, + ) + .await?; + + InternalSearchStats { + cmps: 
initial_stats.cmps, + hops: initial_stats.hops + range_stats.hops, + range_search_second_round: true, + } + } else { + initial_stats + }; + + // Post-process results + let mut result_ids: Vec = vec![O::default(); scratch.in_range.len()]; + let mut result_dists: Vec = vec![f32::MAX; scratch.in_range.len()]; + + let mut output_buffer = search_output_buffer::IdDistance::new( + result_ids.as_mut_slice(), + result_dists.as_mut_slice(), + ); + + let _ = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + scratch.in_range.iter().copied(), + &mut output_buffer, + ) + .send() + .await + .into_ann_result()?; + + // Filter by inner/outer radius + let inner_cutoff = if let Some(inner_radius) = search_params.inner_radius { + result_dists + .iter() + .position(|dist| *dist > inner_radius) + .unwrap_or(result_dists.len()) + } else { + 0 + }; + + let outer_cutoff = result_dists + .iter() + .position(|dist| *dist > search_params.radius) + .unwrap_or(result_dists.len()); + + result_ids.truncate(outer_cutoff); + result_ids.drain(0..inner_cutoff); + + result_dists.truncate(outer_cutoff); + result_dists.drain(0..inner_cutoff); + + let result_count = result_ids.len(); + + Ok(RangeSearchOutput { + stats: SearchStats { + cmps: stats.cmps, + hops: stats.hops, + result_count: result_count as u32, + range_search_second_round: stats.range_search_second_round, + }, + ids: result_ids, + distances: result_dists, + }) + } + } +} + +//============================================================================= +// Internal Implementation +//============================================================================= + +/// Internal range search implementation. +/// +/// Expands the search frontier to find all points within the specified radius. +/// Called after the initial graph search has identified starting candidates. 
+pub(crate) async fn range_search_internal( + max_degree_with_slack: usize, + search_params: &RangeSearchParams, + accessor: &mut A, + computer: &A::QueryComputer, + scratch: &mut SearchScratch, +) -> ANNResult +where + I: crate::utils::VectorId, + A: ExpandBeam + SearchExt, + T: ?Sized, +{ + let beam_width = search_params.beam_width.unwrap_or(1); + + for neighbor in &scratch.in_range { + scratch.range_frontier.push_back(neighbor.id); + } + + let mut neighbors = Vec::with_capacity(max_degree_with_slack); + + let max_returned = search_params.max_returned.unwrap_or(usize::MAX); + + while !scratch.range_frontier.is_empty() { + scratch.beam_nodes.clear(); + + // In this loop we are going to find the beam_width number of remaining nodes within the radius + // Each of these nodes will be a frontier node. + while !scratch.range_frontier.is_empty() && scratch.beam_nodes.len() < beam_width { + let next = scratch.range_frontier.pop_front(); + if let Some(next_node) = next { + scratch.beam_nodes.push(next_node); + } + } + + neighbors.clear(); + accessor + .expand_beam( + scratch.beam_nodes.iter().copied(), + computer, + glue::NotInMut::new(&mut scratch.visited), + |distance, id| neighbors.push(Neighbor::new(id, distance)), + ) + .await?; + + // The predicate ensures that the contents of `neighbors` are unique. 
+ for neighbor in neighbors.iter() { + if neighbor.distance <= search_params.radius * search_params.range_search_slack + && scratch.in_range.len() < max_returned + { + scratch.in_range.push(*neighbor); + scratch.range_frontier.push_back(neighbor.id); + } + } + scratch.cmps += neighbors.len() as u32; + scratch.hops += scratch.beam_nodes.len() as u32; + } + + Ok(InternalSearchStats { + cmps: scratch.cmps, + hops: scratch.hops, + range_search_second_round: true, + }) +} + +//============================================================================= +// Tests +//============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_range_search_validation() { + // Valid + assert!(RangeSearch::new(100, 0.5).is_ok()); + + // Invalid: zero l + assert!(RangeSearch::new(0, 0.5).is_err()); + + // Invalid slack values + assert!(RangeSearch::with_options(None, 100, None, 0.5, None, 1.5, 1.0).is_err()); + assert!(RangeSearch::with_options(None, 100, None, 0.5, None, 1.0, 0.5).is_err()); + + // Invalid inner radius > radius + assert!(RangeSearch::with_options(None, 100, None, 0.5, Some(1.0), 1.0, 1.0).is_err()); + } +} diff --git a/diskann/src/graph/test/cases/grid.rs b/diskann/src/graph/test/cases/grid.rs index 2ea40b677..7c0c03d7f 100644 --- a/diskann/src/graph/test/cases/grid.rs +++ b/diskann/src/graph/test/cases/grid.rs @@ -9,7 +9,7 @@ use diskann_vector::distance::Metric; use crate::{ graph::{ - self, DiskANNIndex, + self, DiskANNIndex, GraphSearch, test::{provider as test_provider, synthetic::Grid}, }, neighbor::Neighbor, @@ -126,10 +126,10 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { // are correct. 
let index = setup_grid_search(grid, size); - let params = graph::SearchParams::new(10, 10, Some(beam_width)).unwrap(); + let params = GraphSearch::new(10, 10, Some(beam_width)).unwrap(); let context = test_provider::Context::new(); - let mut neighbors = vec![Neighbor::::default(); params.k_value]; + let mut neighbors = vec![Neighbor::::default(); params.k]; let graph::index::SearchStats { cmps, hops, @@ -147,7 +147,7 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { assert_eq!( result_count.into_usize(), - params.k_value, + params.k, "grid search should be configured to always return the requested number of neighbors", ); From 77e66774f2011c6a530c1fa92cb70d0a6cf9306a Mon Sep 17 00:00:00 2001 From: narendatha <164128452+narendatha@users.noreply.github.com> Date: Thu, 12 Feb 2026 22:39:01 +0530 Subject: [PATCH 02/11] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- diskann/src/graph/search/dispatch.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann/src/graph/search/dispatch.rs b/diskann/src/graph/search/dispatch.rs index b6e7a358f..8152d1750 100644 --- a/diskann/src/graph/search/dispatch.rs +++ b/diskann/src/graph/search/dispatch.rs @@ -11,7 +11,7 @@ use crate::{ANNResult, graph::index::DiskANNIndex, provider::DataProvider}; /// Trait for search parameter types that execute their own search logic. /// -/// Each search type (graph search, flat search, range search, etc.) implements +/// Each search type (graph search, range search, etc.) implements /// this trait to define its complete search behavior. The [`DiskANNIndex::search`] /// method delegates to the `dispatch` method. 
pub trait SearchDispatch From e64fcb497e8109bd7349ef0d3969ecfe97aebbf4 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Feb 2026 23:23:30 +0530 Subject: [PATCH 03/11] refactor(search): change SearchDispatch::dispatch to take &mut self This enables RecordedGraphSearch to implement SearchDispatch properly, since it holds &mut recorder which requires mutable access. - Change dispatch signature from &self to &mut self - Update index.search to take &mut P for search params - Implement SearchDispatch for RecordedGraphSearch - Remove apologetic comment about trait limitations - Update all callers to use &mut params --- .../src/search/graph/knn.rs | 4 +- .../src/search/graph/multihop.rs | 4 +- .../src/search/graph/range.rs | 4 +- .../src/search/provider/disk_provider.rs | 4 +- diskann-providers/src/index/diskann_async.rs | 71 +++++++++++++------ diskann-providers/src/index/wrapped_async.rs | 4 +- diskann/src/graph/index.rs | 8 +-- diskann/src/graph/search/dispatch.rs | 5 +- diskann/src/graph/search/diverse_search.rs | 2 +- diskann/src/graph/search/graph_search.rs | 66 +++++++++++++++-- diskann/src/graph/search/mod.rs | 8 +-- diskann/src/graph/search/multihop_search.rs | 2 +- diskann/src/graph/search/range_search.rs | 2 +- diskann/src/graph/test/cases/grid.rs | 4 +- 14 files changed, 135 insertions(+), 53 deletions(-) diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 2d1a3c064..9e106b133 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -113,14 +113,14 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let graph_search = graph::GraphSearch::from(*parameters); + let mut graph_search = graph::GraphSearch::from(*parameters); let stats = self .index .search( self.strategy.get(index)?, &context, self.queries.row(index), - &graph_search, + &mut graph_search, buffer, ) .await?; diff --git 
a/diskann-benchmark-core/src/search/graph/multihop.rs b/diskann-benchmark-core/src/search/graph/multihop.rs index e191b7944..08f09d920 100644 --- a/diskann-benchmark-core/src/search/graph/multihop.rs +++ b/diskann-benchmark-core/src/search/graph/multihop.rs @@ -111,7 +111,7 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let multihop_search = graph::MultihopSearch::new( + let mut multihop_search = graph::MultihopSearch::new( graph::GraphSearch::from(*parameters), &*self.labels[index], ); @@ -121,7 +121,7 @@ where self.strategy.get(index)?, &context, self.queries.row(index), - &multihop_search, + &mut multihop_search, buffer, ) .await?; diff --git a/diskann-benchmark-core/src/search/graph/range.rs b/diskann-benchmark-core/src/search/graph/range.rs index 8cfaf9a9a..2e164f75d 100644 --- a/diskann-benchmark-core/src/search/graph/range.rs +++ b/diskann-benchmark-core/src/search/graph/range.rs @@ -104,14 +104,14 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let range_search = graph::RangeSearch::from(*parameters); + let mut range_search = graph::RangeSearch::from(*parameters); let result = self .index .search( self.strategy.get(index)?, &context, self.queries.row(index), - &range_search, + &mut range_search, &mut (), ) .await?; diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index be12f22f3..3e37e02bd 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -993,12 +993,12 @@ where &mut result_output_buffer, ))? } else { - let graph_search = GraphSearch::new(k_value, search_list_size as usize, beam_width)?; + let mut graph_search = GraphSearch::new(k_value, search_list_size as usize, beam_width)?; self.runtime.block_on(self.index.search( &strategy, &DefaultContext, strategy.query, - &graph_search, + &mut graph_search, &mut result_output_buffer, ))? 
}; diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 8a7d982a9..147b49492 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -238,11 +238,11 @@ pub(crate) mod tests { O: Send, OB: graph::search_output_buffer::SearchOutputBuffer + Send, { - let multihop = graph::MultihopSearch::new( + let mut multihop = graph::MultihopSearch::new( graph::GraphSearch::from(*search_params), filter, ); - index.search(strategy, context, query, &multihop, output).await + index.search(strategy, context, query, &mut multihop, output).await } /// Test helper: performs range search using the dispatch API. @@ -259,8 +259,8 @@ pub(crate) mod tests { S: graph::glue::SearchStrategy, O: Send + Default + Clone, { - let range_search = graph::RangeSearch::from(*search_params); - let result = index.search(strategy, context, query, &range_search, &mut ()).await?; + let mut range_search = graph::RangeSearch::from(*search_params); + let result = index.search(strategy, context, query, &mut range_search, &mut ()).await?; Ok((result.stats, result.ids, result.distances)) } @@ -403,12 +403,14 @@ pub(crate) mod tests { let mut distances = vec![0.0; parameters.search_k]; let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut search_params = + SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(); index .search( &strategy, ¶meters.context, query, - &SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(), + &mut search_params, &mut result_output_buffer, ) .await @@ -449,12 +451,14 @@ pub(crate) mod tests { let mut distances = vec![0.0; parameters.search_k]; let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut search_params = + SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(); multihop_search( index, strategy, 
¶meters.context, query, - &SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(), + &mut search_params, &mut result_output_buffer, filter, ) @@ -1492,12 +1496,14 @@ pub(crate) mod tests { let filter = CallbackFilter::new(blocked, adjusted, 0.5); + let mut search_params = + SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(); let stats = multihop_search( &index, &FullPrecision, ¶meters.context, query.as_slice(), - &SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(), + &mut search_params, &mut result_output_buffer, &filter, ) @@ -2239,13 +2245,14 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); // Full Precision Search. index .search( &FullPrecision, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), + &mut search_params, &mut result_output_buffer, ) .await @@ -2256,13 +2263,14 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); // Quantized Search index .search( &Hybrid::new(None), ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), + &mut search_params, &mut result_output_buffer, ) .await @@ -2505,13 +2513,15 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut search_params = + SearchParams::new_default(top_k, search_l).unwrap(); // Full Precision Search. 
index .search( &FullPrecision, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), + &mut search_params, &mut result_output_buffer, ) .await @@ -2522,13 +2532,15 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut search_params = + SearchParams::new_default(top_k, search_l).unwrap(); // Quantized Search index .search( &Quantized, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), + &mut search_params, &mut result_output_buffer, ) .await @@ -2608,13 +2620,15 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut search_params = + SearchParams::new_default(top_k, top_k).unwrap(); // Quantized Search index .search( &Quantized, ctx, query, - &SearchParams::new_default(top_k, top_k).unwrap(), + &mut search_params, &mut result_output_buffer, ) .await @@ -2721,12 +2735,13 @@ pub(crate) mod tests { // Full Precision Search. 
let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); index .search( &FullPrecision, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), + &mut search_params, &mut output, ) .await @@ -2738,13 +2753,14 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); + let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); index .search( &strategy, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), + &mut search_params, &mut output, ) .await @@ -2846,13 +2862,14 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); + let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); index .search( &strategy, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), + &mut search_params, &mut output, ) .await @@ -2938,13 +2955,14 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); // Full Precision Search. index .search( &Quantized, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), + &mut search_params, &mut result_output_buffer, ) .await @@ -3518,13 +3536,14 @@ pub(crate) mod tests { let gt = groundtruth(queries.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); // Full Precision Search. 
index .search( &Hybrid::new(max_fp_vecs_per_prune), ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), + &mut search_params, &mut result_output_buffer, ) .await @@ -3664,13 +3683,14 @@ pub(crate) mod tests { let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); // Full Precision Search. index .search( &FullPrecision, ctx, query, - &SearchParams::new_default(top_k, search_l).unwrap(), + &mut search_params, &mut result_output_buffer, ) .await @@ -4153,12 +4173,13 @@ pub(crate) mod tests { // but reject everything via on_visit let filter = RejectAllFilter::only([0_u32]); + let mut search_params = SearchParams::new_default(10, 20).unwrap(); let stats = multihop_search( &index, &FullPrecision, &DefaultContext, query.as_slice(), - &SearchParams::new_default(10, 20).unwrap(), + &mut search_params, &mut result_output_buffer, &filter, ) @@ -4215,12 +4236,13 @@ pub(crate) mod tests { let target = (num_points / 2) as u32; let filter = TerminatingFilter::new(target); + let mut search_params = SearchParams::new_default(10, 40).unwrap(); let stats = multihop_search( &index, &FullPrecision, &DefaultContext, query.as_slice(), - &SearchParams::new_default(10, 40).unwrap(), + &mut search_params, &mut result_output_buffer, &filter, ) @@ -4279,12 +4301,13 @@ pub(crate) mod tests { let mut baseline_buffer = search_output_buffer::IdDistance::new(&mut baseline_ids, &mut baseline_distances); + let mut search_params = SearchParams::new_default(10, 20).unwrap(); let baseline_stats = multihop_search( &index, &FullPrecision, &DefaultContext, query.as_slice(), - &SearchParams::new_default(10, 20).unwrap(), + &mut search_params, &mut baseline_buffer, &EvenFilter, // Just filter to even IDs ) @@ -4300,12 +4323,13 @@ pub(crate) mod tests { let mut adjusted_buffer = 
search_output_buffer::IdDistance::new(&mut adjusted_ids, &mut adjusted_distances); + let mut search_params = SearchParams::new_default(10, 20).unwrap(); let adjusted_stats = multihop_search( &index, &FullPrecision, &DefaultContext, query.as_slice(), - &SearchParams::new_default(10, 20).unwrap(), + &mut search_params, &mut adjusted_buffer, &filter, ) @@ -4426,12 +4450,13 @@ pub(crate) mod tests { let max_visits = 5; let filter = TerminateAfterN::new(max_visits); + let mut search_params = SearchParams::new_default(10, 100).unwrap(); // Large L to ensure we'd visit more without termination let _stats = multihop_search( &index, &FullPrecision, &DefaultContext, query.as_slice(), - &SearchParams::new_default(10, 100).unwrap(), // Large L to ensure we'd visit more without termination + &mut search_params, &mut result_output_buffer, &filter, ) diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 050b76afb..d8b0f4054 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -235,10 +235,10 @@ where O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, { - let graph_search = diskann::graph::GraphSearch::from(*search_params); + let mut graph_search = diskann::graph::GraphSearch::from(*search_params); self.handle.block_on( self.inner - .search(strategy, context, query, &graph_search, output), + .search(strategy, context, query, &mut graph_search, output), ) } diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index f74bb36e4..c151fad44 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -2139,11 +2139,11 @@ where /// /// // Standard graph search /// let params = GraphSearch::new(10, 100, None)?; - /// let stats = index.search(&strategy, &context, &query, ¶ms, &mut output).await?; + /// let stats = index.search(&strategy, &context, &query, &mut params, &mut output).await?; /// /// // Range search (note: uses () as 
output buffer, results in Output type) - /// let params = RangeSearch::new(100, 0.5)?; - /// let result = index.search(&strategy, &context, &query, ¶ms, &mut ()).await?; + /// let mut params = RangeSearch::new(100, 0.5)?; + /// let result = index.search(&strategy, &context, &query, &mut params, &mut ()).await?; /// // result.ids and result.distances contain the matches /// ``` pub fn search<'a, S, T, O: 'a, OB, P>( @@ -2151,7 +2151,7 @@ where strategy: &'a S, context: &'a DP::Context, query: &'a T, - search_params: &'a P, + search_params: &'a mut P, output: &'a mut OB, ) -> impl SendFuture> + 'a where diff --git a/diskann/src/graph/search/dispatch.rs b/diskann/src/graph/search/dispatch.rs index 8152d1750..79e10c31e 100644 --- a/diskann/src/graph/search/dispatch.rs +++ b/diskann/src/graph/search/dispatch.rs @@ -14,6 +14,9 @@ use crate::{ANNResult, graph::index::DiskANNIndex, provider::DataProvider}; /// Each search type (graph search, range search, etc.) implements /// this trait to define its complete search behavior. The [`DiskANNIndex::search`] /// method delegates to the `dispatch` method. +/// +/// The `dispatch` method takes `&mut self` to support search types that need to +/// record state during execution (e.g., [`RecordedGraphSearch`] for path recording). pub trait SearchDispatch where DP: DataProvider, @@ -23,7 +26,7 @@ where /// Execute the search operation with full search logic. 
fn dispatch<'a>( - &'a self, + &'a mut self, index: &'a DiskANNIndex, strategy: &'a S, context: &'a DP::Context, diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index ecf77d6cb..a5d23c809 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -63,7 +63,7 @@ where type Output = SearchStats; fn dispatch<'a>( - &'a self, + &'a mut self, index: &'a DiskANNIndex, strategy: &'a S, context: &'a DP::Context, diff --git a/diskann/src/graph/search/graph_search.rs b/diskann/src/graph/search/graph_search.rs index de2d09fe1..748d80534 100644 --- a/diskann/src/graph/search/graph_search.rs +++ b/diskann/src/graph/search/graph_search.rs @@ -97,7 +97,7 @@ where type Output = SearchStats; fn dispatch<'a>( - &'a self, + &'a mut self, index: &'a DiskANNIndex, strategy: &'a S, context: &'a DP::Context, @@ -105,7 +105,7 @@ where output: &'a mut OB, ) -> impl SendFuture> { async move { - let graph_search = GraphSearch::from(*self); + let mut graph_search = GraphSearch::from(*self); graph_search.dispatch(index, strategy, context, query, output).await } } @@ -122,7 +122,7 @@ where type Output = SearchStats; fn dispatch<'a>( - &'a self, + &'a mut self, index: &'a DiskANNIndex, strategy: &'a S, context: &'a DP::Context, @@ -197,9 +197,63 @@ impl<'r, SR: Debug + ?Sized> Debug for RecordedGraphSearch<'r, SR> { } } -// Note: RecordedGraphSearch cannot implement SearchDispatch because it holds &mut recorder -// which conflicts with the shared reference semantics of dispatch. Users should call -// the search logic directly or use a Cell/RefCell pattern if needed. 
+impl<'r, DP, S, T, O, OB, SR> SearchDispatch for RecordedGraphSearch<'r, SR> +where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, + SR: super::record::SearchRecord + ?Sized, +{ + type Output = SearchStats; + + fn dispatch<'a>( + &'a mut self, + index: &'a DiskANNIndex, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + output: &'a mut OB, + ) -> impl SendFuture> { + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(self.inner.l, start_ids.len()); + + let stats = index + .search_internal( + self.inner.beam_width, + &start_ids, + &mut accessor, + &computer, + &mut scratch, + self.recorder, + ) + .await?; + + let result_count = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + scratch.best.iter().take(self.inner.l.into_usize()), + output, + ) + .send() + .await + .into_ann_result()?; + + Ok(stats.finish(result_count as u32)) + } + } +} //============================================================================= // Tests diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 82d549950..4d6547aae 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -15,12 +15,12 @@ //! use diskann::graph::{GraphSearch, RangeSearch, MultihopSearch, SearchDispatch}; //! //! // Standard graph search -//! let params = GraphSearch::new(10, 100, None)?; -//! let stats = index.search(&strategy, &context, &query, ¶ms, &mut output).await?; +//! let mut params = GraphSearch::new(10, 100, None)?; +//! let stats = index.search(&strategy, &context, &query, &mut params, &mut output).await?; //! //! // Range search -//! let params = RangeSearch::new(100, 0.5)?; -//! 
let result = index.search(&strategy, &context, &query, &params, &mut ()).await?; +//! let mut params = RangeSearch::new(100, 0.5)?; +//! let result = index.search(&strategy, &context, &query, &mut params, &mut ()).await?; //! println!("Found {} points within radius", result.ids.len()); //! ``` diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 0e25d9585..3afa43b19 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -72,7 +72,7 @@ where type Output = SearchStats; fn dispatch<'a>( - &'a self, + &'a mut self, index: &'a DiskANNIndex, strategy: &'a S, context: &'a DP::Context, diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 64b064016..7ac99a6c0 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -144,7 +144,7 @@ where type Output = RangeSearchOutput; fn dispatch<'a>( - &'a self, + &'a mut self, index: &'a DiskANNIndex, strategy: &'a S, context: &'a DP::Context, diff --git a/diskann/src/graph/test/cases/grid.rs b/diskann/src/graph/test/cases/grid.rs index 7c0c03d7f..790ee3c96 100644 --- a/diskann/src/graph/test/cases/grid.rs +++ b/diskann/src/graph/test/cases/grid.rs @@ -126,7 +126,7 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { // are correct. 
let index = setup_grid_search(grid, size); - let params = GraphSearch::new(10, 10, Some(beam_width)).unwrap(); + let mut params = GraphSearch::new(10, 10, Some(beam_width)).unwrap(); let context = test_provider::Context::new(); let mut neighbors = vec![Neighbor::::default(); params.k]; @@ -140,7 +140,7 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { &test_provider::Strategy::new(), &context, query.as_slice(), - &params, + &mut params, &mut crate::neighbor::BackInserter::new(neighbors.as_mut_slice()), )) .unwrap(); From dbd34816091302522e2f631eb703c24e561fb6d4 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Feb 2026 23:30:35 +0530 Subject: [PATCH 04/11] refactor: replace manual Debug impls with derive - RecordedGraphSearch in graph_search.rs - MultihopSearch in multihop_search.rs --- diskann/src/graph/search/graph_search.rs | 9 +-------- diskann/src/graph/search/multihop_search.rs | 12 +----------- 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/diskann/src/graph/search/graph_search.rs b/diskann/src/graph/search/graph_search.rs index 748d80534..8e9bb917b 100644 --- a/diskann/src/graph/search/graph_search.rs +++ b/diskann/src/graph/search/graph_search.rs @@ -175,6 +175,7 @@ where /// Graph search with traversal path recording. /// /// Records the path taken during search for debugging or analysis. +#[derive(Debug)] pub struct RecordedGraphSearch<'r, SR: ?Sized> { /// Base graph search parameters. 
pub inner: GraphSearch, @@ -189,14 +190,6 @@ impl<'r, SR: ?Sized> RecordedGraphSearch<'r, SR> { } } -impl<'r, SR: Debug + ?Sized> Debug for RecordedGraphSearch<'r, SR> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("RecordedGraphSearch") - .field("inner", &self.inner) - .finish_non_exhaustive() - } -} - impl<'r, DP, S, T, O, OB, SR> SearchDispatch for RecordedGraphSearch<'r, SR> where DP: DataProvider, diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 3afa43b19..c4100c894 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -5,8 +5,6 @@ //! Label-filtered search using multi-hop expansion. -use std::fmt::Debug; - use diskann_utils::future::{AssertSend, SendFuture}; use diskann_utils::Reborrow; use diskann_vector::PreprocessedDistanceFunction; @@ -35,6 +33,7 @@ use super::graph_search::GraphSearch; /// This search extends standard graph search by expanding through non-matching /// nodes to find matching neighbors. More efficient than flat search when the /// matching subset is reasonably large. +#[derive(Debug)] pub struct MultihopSearch<'q, InternalId> { /// Base graph search parameters. pub inner: GraphSearch, @@ -42,15 +41,6 @@ pub struct MultihopSearch<'q, InternalId> { pub label_evaluator: &'q dyn QueryLabelProvider, } -impl Debug for MultihopSearch<'_, InternalId> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MultihopSearch") - .field("inner", &self.inner) - .field("label_evaluator", self.label_evaluator) - .finish() - } -} - impl<'q, InternalId> MultihopSearch<'q, InternalId> { /// Create new multihop search parameters. 
pub fn new( From caac0f0e8df9a8016ec3b3fdbe3d8ad3908a5539 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Feb 2026 23:47:34 +0530 Subject: [PATCH 05/11] style: cargo fmt --- .../src/search/graph/multihop.rs | 6 +-- .../src/search/graph/range.rs | 5 +- .../src/search/provider/disk_provider.rs | 6 ++- diskann-providers/src/index/diskann_async.rs | 47 ++++++------------- diskann/src/graph/index.rs | 6 +-- diskann/src/graph/mod.rs | 4 +- diskann/src/graph/search/diverse_search.rs | 10 +++- diskann/src/graph/search/graph_search.rs | 9 +++- diskann/src/graph/search/multihop_search.rs | 34 +++++++++----- 9 files changed, 63 insertions(+), 64 deletions(-) diff --git a/diskann-benchmark-core/src/search/graph/multihop.rs b/diskann-benchmark-core/src/search/graph/multihop.rs index 08f09d920..368d1cdf1 100644 --- a/diskann-benchmark-core/src/search/graph/multihop.rs +++ b/diskann-benchmark-core/src/search/graph/multihop.rs @@ -111,10 +111,8 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let mut multihop_search = graph::MultihopSearch::new( - graph::GraphSearch::from(*parameters), - &*self.labels[index], - ); + let mut multihop_search = + graph::MultihopSearch::new(graph::GraphSearch::from(*parameters), &*self.labels[index]); let stats = self .index .search( diff --git a/diskann-benchmark-core/src/search/graph/range.rs b/diskann-benchmark-core/src/search/graph/range.rs index 2e164f75d..ebab8a0f2 100644 --- a/diskann-benchmark-core/src/search/graph/range.rs +++ b/diskann-benchmark-core/src/search/graph/range.rs @@ -115,7 +115,10 @@ where &mut (), ) .await?; - buffer.extend(std::iter::zip(result.ids.into_iter(), result.distances.into_iter())); + buffer.extend(std::iter::zip( + result.ids.into_iter(), + result.distances.into_iter(), + )); Ok(Metrics {}) } diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 3e37e02bd..517932139 100644 --- 
a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -19,7 +19,8 @@ use diskann::{ graph::{ self, glue::{self, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy}, - search_output_buffer, AdjacencyList, DiskANNIndex, GraphSearch, SearchOutputBuffer, SearchParams, + search_output_buffer, AdjacencyList, DiskANNIndex, GraphSearch, SearchOutputBuffer, + SearchParams, }, neighbor::Neighbor, provider::{ @@ -993,7 +994,8 @@ where &mut result_output_buffer, ))? } else { - let mut graph_search = GraphSearch::new(k_value, search_list_size as usize, beam_width)?; + let mut graph_search = + GraphSearch::new(k_value, search_list_size as usize, beam_width)?; self.runtime.block_on(self.index.search( &strategy, &DefaultContext, diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 147b49492..a50758612 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -238,11 +238,11 @@ pub(crate) mod tests { O: Send, OB: graph::search_output_buffer::SearchOutputBuffer + Send, { - let mut multihop = graph::MultihopSearch::new( - graph::GraphSearch::from(*search_params), - filter, - ); - index.search(strategy, context, query, &mut multihop, output).await + let mut multihop = + graph::MultihopSearch::new(graph::GraphSearch::from(*search_params), filter); + index + .search(strategy, context, query, &mut multihop, output) + .await } /// Test helper: performs range search using the dispatch API. 
@@ -260,7 +260,9 @@ pub(crate) mod tests { O: Send + Default + Clone, { let mut range_search = graph::RangeSearch::from(*search_params); - let result = index.search(strategy, context, query, &mut range_search, &mut ()).await?; + let result = index + .search(strategy, context, query, &mut range_search, &mut ()) + .await?; Ok((result.stats, result.ids, result.distances)) } @@ -2513,8 +2515,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = - SearchParams::new_default(top_k, search_l).unwrap(); + let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); // Full Precision Search. index .search( @@ -2532,8 +2533,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = - SearchParams::new_default(top_k, search_l).unwrap(); + let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); // Quantized Search index .search( @@ -2620,8 +2620,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = - SearchParams::new_default(top_k, top_k).unwrap(); + let mut search_params = SearchParams::new_default(top_k, top_k).unwrap(); // Quantized Search index .search( @@ -2737,13 +2736,7 @@ pub(crate) mod tests { let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); index - .search( - &FullPrecision, - ctx, - query, - &mut search_params, - &mut output, - ) + .search(&FullPrecision, ctx, query, &mut search_params, &mut output) .await .unwrap(); assert_top_k_exactly_match(q, &gt, &ids, &distances, top_k); @@ -2756,13 +2749,7 @@ pub(crate) mod tests { let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); index - .search( - &strategy, - ctx, - 
query, - &mut search_params, - &mut output, - ) + .search(&strategy, ctx, query, &mut search_params, &mut output) .await .unwrap(); assert_top_k_exactly_match(q, &gt, &ids, &distances, top_k); @@ -2865,13 +2852,7 @@ pub(crate) mod tests { let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); index - .search( - &strategy, - ctx, - query, - &mut search_params, - &mut output, - ) + .search(&strategy, ctx, query, &mut search_params, &mut output) .await .unwrap(); diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index c151fad44..4af44e1e4 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -26,9 +26,8 @@ use tokio::task::JoinSet; use super::{ AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, SearchParams, glue::{ - self, AsElement, ExpandBeam, FillSet, IdIterator, InplaceDeleteStrategy, - InsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, - SearchStrategy, aliases, + self, AsElement, ExpandBeam, FillSet, IdIterator, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, aliases, }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ @@ -3116,4 +3115,3 @@ impl InternalSearchStats { } } } - diff --git a/diskann/src/graph/mod.rs b/diskann/src/graph/mod.rs index 2622e7ecb..9865e9b80 100644 --- a/diskann/src/graph/mod.rs +++ b/diskann/src/graph/mod.rs @@ -33,9 +33,7 @@ pub mod glue; pub mod search; // Re-export unified search interface as the primary API. 
-pub use search::{ - GraphSearch, MultihopSearch, RangeSearch, RangeSearchOutput, SearchDispatch, -}; +pub use search::{GraphSearch, MultihopSearch, RangeSearch, RangeSearchOutput, SearchDispatch}; #[cfg(feature = "experimental_diversity_search")] pub use search::DiverseSearch; diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index a5d23c809..4e036c0ef 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -12,7 +12,10 @@ use std::num::NonZeroUsize; use diskann_utils::future::{AssertSend, SendFuture}; use hashbrown::HashSet; -use super::{dispatch::SearchDispatch, graph_search::GraphSearch, record::NoopSearchRecord, scratch::SearchScratch}; +use super::{ + dispatch::SearchDispatch, graph_search::GraphSearch, record::NoopSearchRecord, + scratch::SearchScratch, +}; use crate::{ ANNResult, error::IntoANNResult, @@ -47,7 +50,10 @@ where { /// Create new diverse search parameters. pub fn new(inner: GraphSearch, diverse_params: DiverseSearchParams

) -> Self { - Self { inner, diverse_params } + Self { + inner, + diverse_params, + } } } diff --git a/diskann/src/graph/search/graph_search.rs b/diskann/src/graph/search/graph_search.rs index 8e9bb917b..2c3269cec 100644 --- a/diskann/src/graph/search/graph_search.rs +++ b/diskann/src/graph/search/graph_search.rs @@ -51,7 +51,10 @@ impl GraphSearch { use super::super::SearchParamsError; if k > l { - return Err(SearchParamsError::LLessThanK { l_value: l, k_value: k }); + return Err(SearchParamsError::LLessThanK { + l_value: l, + k_value: k, + }); } if let Some(bw) = beam_width { if bw == 0 { @@ -106,7 +109,9 @@ where ) -> impl SendFuture> { async move { let mut graph_search = GraphSearch::from(*self); - graph_search.dispatch(index, strategy, context, query, output).await + graph_search + .dispatch(index, strategy, context, query, output) + .await } } } diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index c4100c894..eca94e06f 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -5,8 +5,8 @@ //! Label-filtered search using multi-hop expansion. 
-use diskann_utils::future::{AssertSend, SendFuture}; use diskann_utils::Reborrow; +use diskann_utils::future::{AssertSend, SendFuture}; use diskann_vector::PreprocessedDistanceFunction; use hashbrown::HashSet; @@ -16,8 +16,13 @@ use crate::{ error::{ErrorExt, IntoANNResult}, graph::{ SearchParams, - glue::{self, ExpandBeam, HybridPredicate, Predicate, PredicateMut, SearchExt, SearchPostProcess, SearchStrategy}, - index::{DiskANNIndex, InternalSearchStats, QueryLabelProvider, QueryVisitDecision, SearchStats}, + glue::{ + self, ExpandBeam, HybridPredicate, Predicate, PredicateMut, SearchExt, + SearchPostProcess, SearchStrategy, + }, + index::{ + DiskANNIndex, InternalSearchStats, QueryLabelProvider, QueryVisitDecision, SearchStats, + }, search::record::NoopSearchRecord, search_output_buffer::SearchOutputBuffer, }, @@ -47,7 +52,10 @@ impl<'q, InternalId> MultihopSearch<'q, InternalId> { inner: GraphSearch, label_evaluator: &'q dyn QueryLabelProvider, ) -> Self { - Self { inner, label_evaluator } + Self { + inner, + label_evaluator, + } } } @@ -85,15 +93,15 @@ where let mut scratch = index.search_scratch(params.l_value, start_ids.len()); let stats = multihop_search_internal( - index.max_degree_with_slack(), - &params, - &mut accessor, - &computer, - &mut scratch, - &mut NoopSearchRecord::new(), - self.label_evaluator, - ) - .await?; + index.max_degree_with_slack(), + &params, + &mut accessor, + &computer, + &mut scratch, + &mut NoopSearchRecord::new(), + self.label_evaluator, + ) + .await?; let result_count = strategy .post_processor() From 9111e03c7271bb9b8b44de8aab8f98bd2d6ec030 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Fri, 13 Feb 2026 16:17:13 +0530 Subject: [PATCH 06/11] refactor: apply PR code review feedback - Rename SearchDispatch trait to Search and move to mod.rs - Remove unnecessary doc comment about &mut self (can be inferred) - Remove 'feature-gated' mention from diverse_search.rs module doc - Fix banner styles to use ///// format instead of //=== - Make 
create_diverse_scratch an inherent method on DiverseSearch - Remove flat_search mention from index.search docs - Redirect debug_search to use RecordedGraphSearch internally - Revert version bump (0.46.0 -> 0.45.0) Co-authored-by: hildebrandmw --- Cargo.lock | 30 +++---- Cargo.toml | 28 +++--- diskann-utils/Cargo.toml | 2 +- diskann/src/graph/index.rs | 58 +++---------- diskann/src/graph/mod.rs | 2 +- diskann/src/graph/search/dispatch.rs | 36 -------- diskann/src/graph/search/diverse_search.rs | 96 ++++++++------------- diskann/src/graph/search/graph_search.rs | 22 ++--- diskann/src/graph/search/mod.rs | 33 +++++-- diskann/src/graph/search/multihop_search.rs | 10 +-- diskann/src/graph/search/range_search.rs | 16 ++-- 11 files changed, 134 insertions(+), 199 deletions(-) delete mode 100644 diskann/src/graph/search/dispatch.rs diff --git a/Cargo.lock b/Cargo.lock index b65475315..9e2ba2409 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -403,7 +403,7 @@ dependencies = [ [[package]] name = "diskann" -version = "0.46.0" +version = "0.45.0" dependencies = [ "anyhow", "bytemuck", @@ -427,7 +427,7 @@ dependencies = [ [[package]] name = "diskann-benchmark" -version = "0.46.0" +version = "0.45.0" dependencies = [ "anyhow", "bf-tree", @@ -464,7 +464,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-core" -version = "0.46.0" +version = "0.45.0" dependencies = [ "anyhow", "diskann", @@ -481,7 +481,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-runner" -version = "0.46.0" +version = "0.45.0" dependencies = [ "anyhow", "clap", @@ -495,7 +495,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-simd" -version = "0.46.0" +version = "0.45.0" dependencies = [ "anyhow", "diskann-benchmark-runner", @@ -512,7 +512,7 @@ dependencies = [ [[package]] name = "diskann-disk" -version = "0.46.0" +version = "0.45.0" dependencies = [ "anyhow", "bincode", @@ -547,7 +547,7 @@ dependencies = [ [[package]] name = "diskann-label-filter" -version = "0.46.0" +version = 
"0.45.0" dependencies = [ "anyhow", "bf-tree", @@ -570,7 +570,7 @@ dependencies = [ [[package]] name = "diskann-linalg" -version = "0.46.0" +version = "0.45.0" dependencies = [ "approx", "cfg-if", @@ -584,7 +584,7 @@ dependencies = [ [[package]] name = "diskann-platform" -version = "0.46.0" +version = "0.45.0" dependencies = [ "io-uring", "libc", @@ -594,7 +594,7 @@ dependencies = [ [[package]] name = "diskann-providers" -version = "0.46.0" +version = "0.45.0" dependencies = [ "anyhow", "approx", @@ -638,7 +638,7 @@ dependencies = [ [[package]] name = "diskann-quantization" -version = "0.46.0" +version = "0.45.0" dependencies = [ "bytemuck", "cfg-if", @@ -657,7 +657,7 @@ dependencies = [ [[package]] name = "diskann-tools" -version = "0.46.0" +version = "0.45.0" dependencies = [ "anyhow", "bincode", @@ -689,7 +689,7 @@ dependencies = [ [[package]] name = "diskann-utils" -version = "0.46.0" +version = "0.45.0" dependencies = [ "cfg-if", "diskann-vector", @@ -703,7 +703,7 @@ dependencies = [ [[package]] name = "diskann-vector" -version = "0.46.0" +version = "0.45.0" dependencies = [ "approx", "cfg-if", @@ -717,7 +717,7 @@ dependencies = [ [[package]] name = "diskann-wide" -version = "0.46.0" +version = "0.45.0" dependencies = [ "cfg-if", "half", diff --git a/Cargo.toml b/Cargo.toml index ce3e47293..bd9e75037 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ default-members = [ resolver = "3" [workspace.package] -version = "0.46.0" # Obeying semver +version = "0.45.0" # Obeying semver description = "DiskANN is a fast approximate nearest neighbor search library for high dimensional data" authors = ["Microsoft"] documentation = "https://github.com/microsoft/DiskANN" @@ -46,22 +46,22 @@ undocumented_unsafe_blocks = "warn" [workspace.dependencies] # Base And Numerics -diskann-wide = { path = "diskann-wide", version = "0.46.0" } -diskann-vector = { path = "diskann-vector", version = "0.46.0" } -diskann-linalg = { path = "diskann-linalg", version = "0.46.0" } 
-diskann-utils = { path = "diskann-utils", default-features = false, version = "0.46.0" } -diskann-quantization = { path = "diskann-quantization", default-features = false, version = "0.46.0" } -diskann-platform = { path = "diskann-platform", version = "0.46.0" } +diskann-wide = { path = "diskann-wide", version = "0.45.0" } +diskann-vector = { path = "diskann-vector", version = "0.45.0" } +diskann-linalg = { path = "diskann-linalg", version = "0.45.0" } +diskann-utils = { path = "diskann-utils", default-features = false, version = "0.45.0" } +diskann-quantization = { path = "diskann-quantization", default-features = false, version = "0.45.0" } +diskann-platform = { path = "diskann-platform", version = "0.45.0" } # Algorithm -diskann = { path = "diskann", version = "0.46.0" } +diskann = { path = "diskann", version = "0.45.0" } # Providers -diskann-providers = { path = "diskann-providers", default-features = false, version = "0.46.0" } -diskann-disk = { path = "diskann-disk", version = "0.46.0" } -diskann-label-filter = { path = "diskann-label-filter", version = "0.46.0" } +diskann-providers = { path = "diskann-providers", default-features = false, version = "0.45.0" } +diskann-disk = { path = "diskann-disk", version = "0.45.0" } +diskann-label-filter = { path = "diskann-label-filter", version = "0.45.0" } # Infra -diskann-benchmark-runner = { path = "diskann-benchmark-runner", version = "0.46.0" } -diskann-benchmark-core = { path = "diskann-benchmark-core", version = "0.46.0" } -diskann-tools = { path = "diskann-tools", version = "0.46.0" } +diskann-benchmark-runner = { path = "diskann-benchmark-runner", version = "0.45.0" } +diskann-benchmark-core = { path = "diskann-benchmark-core", version = "0.45.0" } +diskann-tools = { path = "diskann-tools", version = "0.45.0" } # External dependencies (shared versions) anyhow = "1.0.98" diff --git a/diskann-utils/Cargo.toml b/diskann-utils/Cargo.toml index e2dc5ee12..06c161978 100644 --- a/diskann-utils/Cargo.toml +++ 
b/diskann-utils/Cargo.toml @@ -38,4 +38,4 @@ default = ["rayon"] # Enable Rayon-based Parallelism for tagged kernels. rayon = ["dep:rayon"] # Enable testing utilities like test_data_root() -testing = [] \ No newline at end of file +testing = [] diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 4af44e1e4..2260864b5 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -31,6 +31,7 @@ use super::{ }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ + Search, record::{NoopSearchRecord, SearchRecord, VisitedSearchRecord}, scratch::{self, PriorityQueueConfiguration, SearchScratch, SearchScratchParams}, }, @@ -2115,10 +2116,10 @@ where } } - /// Execute a search using the unified search dispatch interface. + /// Execute a search using the unified search interface. /// - /// This method provides a single entry point for all search types. The `parameters` argument - /// implements [`search::SearchDispatch`], which defines the complete search behavior including + /// This method provides a single entry point for all search types. The `search_params` argument + /// implements [`search::Search`], which defines the complete search behavior including /// algorithm selection and post-processing. /// /// # Supported Search Types @@ -2128,16 +2129,13 @@ where /// - [`search::RangeSearch`]: Range-based search within a distance radius /// - [`search::DiverseSearch`]: Diversity-aware search (feature-gated) /// - /// For flat (brute-force) search, use [`Self::flat_search`] directly due to its - /// unique iterator type constraints. 
- /// /// # Example /// /// ```ignore - /// use diskann::graph::{GraphSearch, RangeSearch, SearchDispatch}; + /// use diskann::graph::{GraphSearch, RangeSearch, Search}; /// /// // Standard graph search - /// let params = GraphSearch::new(10, 100, None)?; + /// let mut params = GraphSearch::new(10, 100, None)?; /// let stats = index.search(&strategy, &context, &query, &mut params, &mut output).await?; /// /// // Range search (note: uses () as output buffer, results in Output type) @@ -2154,7 +2152,7 @@ where output: &'a mut OB, ) -> impl SendFuture> + 'a where - P: super::search::SearchDispatch, + P: super::search::Search, T: ?Sized, OB: ?Sized, { @@ -2183,44 +2181,16 @@ where S: SearchStrategy, O: Send + 'a, OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, - SR: SearchRecord + Send, + SR: SearchRecord, { async move { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; - - let graph_search = super::search::GraphSearch::from(*search_params); - let mut scratch = self.search_scratch(graph_search.l, start_ids.len()); - - let stats = self - .search_internal( - graph_search.beam_width, - &start_ids, - &mut accessor, - &computer, - &mut scratch, - search_record, - ) - .await?; - - let result_count = strategy - .post_processor() - .post_process( - &mut accessor, - query, - &computer, - scratch.best.iter().take(graph_search.l.into_usize()), - output, - ) - .send() + let mut recorded_search = super::search::RecordedGraphSearch::new( + super::search::GraphSearch::from(*search_params), + search_record, + ); + recorded_search + .dispatch(self, strategy, context, query, output) .await - .into_ann_result()?; - - Ok(stats.finish(result_count as u32)) } } diff --git a/diskann/src/graph/mod.rs b/diskann/src/graph/mod.rs index 9865e9b80..ae97ba869 100644 --- a/diskann/src/graph/mod.rs +++ 
b/diskann/src/graph/mod.rs @@ -33,7 +33,7 @@ pub mod glue; pub mod search; // Re-export unified search interface as the primary API. -pub use search::{GraphSearch, MultihopSearch, RangeSearch, RangeSearchOutput, SearchDispatch}; +pub use search::{GraphSearch, MultihopSearch, RangeSearch, RangeSearchOutput, Search}; #[cfg(feature = "experimental_diversity_search")] pub use search::DiverseSearch; diff --git a/diskann/src/graph/search/dispatch.rs b/diskann/src/graph/search/dispatch.rs deleted file mode 100644 index 79e10c31e..000000000 --- a/diskann/src/graph/search/dispatch.rs +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -//! Core search dispatch trait. - -use diskann_utils::future::SendFuture; - -use crate::{ANNResult, graph::index::DiskANNIndex, provider::DataProvider}; - -/// Trait for search parameter types that execute their own search logic. -/// -/// Each search type (graph search, range search, etc.) implements -/// this trait to define its complete search behavior. The [`DiskANNIndex::search`] -/// method delegates to the `dispatch` method. -/// -/// The `dispatch` method takes `&mut self` to support search types that need to -/// record state during execution (e.g., [`RecordedGraphSearch`] for path recording). -pub trait SearchDispatch -where - DP: DataProvider, -{ - /// The result type returned by this search. - type Output; - - /// Execute the search operation with full search logic. - fn dispatch<'a>( - &'a mut self, - index: &'a DiskANNIndex, - strategy: &'a S, - context: &'a DP::Context, - query: &'a T, - output: &'a mut OB, - ) -> impl SendFuture>; -} diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index 4e036c0ef..0b138a65c 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -//! Diversity-aware search (feature-gated). 
+//! Diversity-aware search. #![cfg(feature = "experimental_diversity_search")] @@ -12,10 +12,7 @@ use std::num::NonZeroUsize; use diskann_utils::future::{AssertSend, SendFuture}; use hashbrown::HashSet; -use super::{ - dispatch::SearchDispatch, graph_search::GraphSearch, record::NoopSearchRecord, - scratch::SearchScratch, -}; +use super::{Search, graph_search::GraphSearch, record::NoopSearchRecord, scratch::SearchScratch}; use crate::{ ANNResult, error::IntoANNResult, @@ -55,9 +52,42 @@ where diverse_params, } } + + /// Create search scratch with DiverseNeighborQueue for this search. + fn create_scratch( + &self, + index: &DiskANNIndex, + ) -> SearchScratch> + where + DP: DataProvider, + P: AttributeValueProvider, + { + let attribute_provider = self.diverse_params.attribute_provider.clone(); + let diverse_queue = DiverseNeighborQueue::new( + self.inner.l, + // SAFETY: k_value is guaranteed to be non-zero by GraphSearch validation + #[allow(clippy::expect_used)] + NonZeroUsize::new(self.inner.k).expect("k_value must be non-zero"), + self.diverse_params.diverse_results_k, + attribute_provider, + ); + + SearchScratch { + best: diverse_queue, + visited: HashSet::with_capacity( + index.estimate_visited_set_capacity(Some(self.inner.l)), + ), + id_scratch: Vec::with_capacity(index.max_degree_with_slack()), + beam_nodes: Vec::with_capacity(self.inner.beam_width.unwrap_or(1)), + range_frontier: std::collections::VecDeque::new(), + in_range: Vec::new(), + hops: 0, + cmps: 0, + } + } } -impl SearchDispatch for DiverseSearch

+impl Search for DiverseSearch

where DP: DataProvider, T: Sync + ?Sized, @@ -84,13 +114,7 @@ where let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; - let mut diverse_scratch = create_diverse_scratch( - index, - self.inner.l, - self.inner.beam_width, - &self.diverse_params, - self.inner.k, - ); + let mut diverse_scratch = self.create_scratch(index); let stats = index .search_internal( @@ -123,49 +147,3 @@ where } } } - -//============================================================================= -// Internal Implementation -//============================================================================= - -/// Create a diverse search scratch with DiverseNeighborQueue. -/// -/// # Arguments -/// -/// * `index` - The DiskANN index for capacity estimation -/// * `l_value` - Search list size -/// * `beam_width` - Optional beam width for parallel exploration -/// * `diverse_params` - Diversity-specific parameters -/// * `k_value` - Number of results to return -pub(crate) fn create_diverse_scratch( - index: &DiskANNIndex, - l_value: usize, - beam_width: Option, - diverse_params: &DiverseSearchParams

, - k_value: usize, -) -> SearchScratch> -where - DP: DataProvider, - P: AttributeValueProvider, -{ - let attribute_provider = diverse_params.attribute_provider.clone(); - let diverse_queue = DiverseNeighborQueue::new( - l_value, - // SAFETY: k_value is guaranteed to be non-zero by SearchParams validation by caller - #[allow(clippy::expect_used)] - NonZeroUsize::new(k_value).expect("k_value must be non-zero"), - diverse_params.diverse_results_k, - attribute_provider, - ); - - SearchScratch { - best: diverse_queue, - visited: HashSet::with_capacity(index.estimate_visited_set_capacity(Some(l_value))), - id_scratch: Vec::with_capacity(index.max_degree_with_slack()), - beam_nodes: Vec::with_capacity(beam_width.unwrap_or(1)), - range_frontier: std::collections::VecDeque::new(), - in_range: Vec::new(), - hops: 0, - cmps: 0, - } -} diff --git a/diskann/src/graph/search/graph_search.rs b/diskann/src/graph/search/graph_search.rs index 2c3269cec..876ff4bb9 100644 --- a/diskann/src/graph/search/graph_search.rs +++ b/diskann/src/graph/search/graph_search.rs @@ -9,7 +9,7 @@ use std::fmt::Debug; use diskann_utils::future::{AssertSend, SendFuture}; -use super::dispatch::SearchDispatch; +use super::Search; use crate::{ ANNResult, error::IntoANNResult, @@ -87,9 +87,9 @@ impl From for GraphSearch { } } -/// Implement SearchDispatch for SearchParams to provide backwards compatibility. +/// Implement Search for SearchParams to provide backwards compatibility. /// This treats SearchParams as an alias for GraphSearch. 
-impl SearchDispatch for super::super::SearchParams +impl Search for super::super::SearchParams where DP: DataProvider, T: Sync + ?Sized, @@ -116,7 +116,7 @@ where } } -impl SearchDispatch for GraphSearch +impl Search for GraphSearch where DP: DataProvider, T: Sync + ?Sized, @@ -173,9 +173,9 @@ where } } -//============================================================================= -// Recorded Graph Search -//============================================================================= +/////////////////////////// +// Recorded Graph Search // +/////////////////////////// /// Graph search with traversal path recording. /// @@ -195,7 +195,7 @@ impl<'r, SR: ?Sized> RecordedGraphSearch<'r, SR> { } } -impl<'r, DP, S, T, O, OB, SR> SearchDispatch for RecordedGraphSearch<'r, SR> +impl<'r, DP, S, T, O, OB, SR> Search for RecordedGraphSearch<'r, SR> where DP: DataProvider, T: Sync + ?Sized, @@ -253,9 +253,9 @@ where } } -//============================================================================= -// Tests -//============================================================================= +/////////// +// Tests // +/////////// #[cfg(test)] mod tests { diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 4d6547aae..1f21929b0 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -6,13 +6,13 @@ //! Unified search execution framework. //! //! This module provides the primary search interface for DiskANN. All search types -//! are represented as parameter structs that implement [`SearchDispatch`], which +//! are represented as parameter structs that implement [`Search`], which //! contains the complete search logic. //! //! # Usage //! //! ```ignore -//! use diskann::graph::{GraphSearch, RangeSearch, MultihopSearch, SearchDispatch}; +//! use diskann::graph::{GraphSearch, RangeSearch, MultihopSearch, Search}; //! //! // Standard graph search //! 
let mut params = GraphSearch::new(10, 100, None)?; @@ -24,7 +24,10 @@ //! println!("Found {} points within radius", result.ids.len()); //! ``` -mod dispatch; +use diskann_utils::future::SendFuture; + +use crate::{ANNResult, graph::index::DiskANNIndex, provider::DataProvider}; + mod graph_search; mod multihop_search; mod range_search; @@ -32,8 +35,28 @@ mod range_search; pub mod record; pub(crate) mod scratch; -// Re-export the core dispatch trait. -pub use dispatch::SearchDispatch; +/// Trait for search parameter types that execute their own search logic. +/// +/// Each search type (graph search, range search, etc.) implements this trait +/// to define its complete search behavior. The [`DiskANNIndex::search`] method +/// delegates to the `dispatch` method. +pub trait Search +where + DP: DataProvider, +{ + /// The result type returned by this search. + type Output; + + /// Execute the search operation with full search logic. + fn dispatch<'a>( + &'a mut self, + index: &'a DiskANNIndex, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + output: &'a mut OB, + ) -> impl SendFuture>; +} // Re-export search parameter types. 
pub use graph_search::{GraphSearch, RecordedGraphSearch}; diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index eca94e06f..9c70561f7 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -10,7 +10,7 @@ use diskann_utils::future::{AssertSend, SendFuture}; use diskann_vector::PreprocessedDistanceFunction; use hashbrown::HashSet; -use super::{dispatch::SearchDispatch, record::SearchRecord, scratch::SearchScratch}; +use super::{Search, record::SearchRecord, scratch::SearchScratch}; use crate::{ ANNResult, error::{ErrorExt, IntoANNResult}, @@ -59,7 +59,7 @@ impl<'q, InternalId> MultihopSearch<'q, InternalId> { } } -impl<'q, DP, S, T, O, OB> SearchDispatch for MultihopSearch<'q, DP::InternalId> +impl<'q, DP, S, T, O, OB> Search for MultihopSearch<'q, DP::InternalId> where DP: DataProvider, T: Sync + ?Sized, @@ -121,9 +121,9 @@ where } } -//============================================================================= -// Internal Implementation -//============================================================================= +///////////////////////////// +// Internal Implementation // +///////////////////////////// /// A predicate that checks if an item is not in the visited set AND matches the label filter. 
/// diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 7ac99a6c0..29f5f2440 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -7,7 +7,7 @@ use diskann_utils::future::{AssertSend, SendFuture}; -use super::{dispatch::SearchDispatch, scratch::SearchScratch}; +use super::{Search, scratch::SearchScratch}; use crate::{ ANNResult, error::IntoANNResult, @@ -134,7 +134,7 @@ impl From for RangeSearch { } } -impl SearchDispatch for RangeSearch +impl Search for RangeSearch where DP: DataProvider, T: Sync + ?Sized, @@ -273,9 +273,9 @@ where } } -//============================================================================= -// Internal Implementation -//============================================================================= +///////////////////////////// +// Internal Implementation // +///////////////////////////// /// Internal range search implementation. /// @@ -345,9 +345,9 @@ where }) } -//============================================================================= -// Tests -//============================================================================= +/////////// +// Tests // +/////////// #[cfg(test)] mod tests { From 015eb976172474499f25445439b1a2517b46ca51 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Fri, 13 Feb 2026 19:46:32 +0530 Subject: [PATCH 07/11] Refactor search interface: rename GraphSearch to KnnSearch with NonZeroUsize - Rename GraphSearch to KnnSearch for clarity (k-NN search) - Change k_value and l_value from usize to NonZeroUsize (compile-time zero check) - Rename RecordedGraphSearch to RecordedKnnSearch - Move RangeSearch validation from misc.rs to range_search.rs - Remove deprecated SearchParams, SearchParamsError, RangeSearchParams, RangeSearchParamsError - Update all callers across diskann, diskann-providers, diskann-disk, diskann-benchmark crates - Remove ensure_positive helper (no longer needed with NonZeroUsize) - Fix duplicated 
cfg attribute in diverse_search.rs --- .../src/search/graph/knn.rs | 36 +- .../src/search/graph/multihop.rs | 29 +- .../src/search/graph/range.rs | 22 +- diskann-benchmark/src/backend/index/result.rs | 6 +- .../src/backend/index/search/knn.rs | 15 +- .../src/backend/index/search/range.rs | 6 +- diskann-benchmark/src/inputs/async_.rs | 10 +- .../src/search/provider/disk_provider.rs | 48 +-- diskann-providers/src/index/diskann_async.rs | 116 ++++--- diskann-providers/src/index/wrapped_async.rs | 8 +- diskann/src/error/ann_error.rs | 11 - diskann/src/error/mod.rs | 1 - diskann/src/graph/index.rs | 28 +- diskann/src/graph/misc.rs | 210 ------------ diskann/src/graph/mod.rs | 10 +- diskann/src/graph/search/diverse_search.rs | 27 +- diskann/src/graph/search/graph_search.rs | 90 ++--- diskann/src/graph/search/knn_search.rs | 309 ++++++++++++++++++ diskann/src/graph/search/mod.rs | 12 +- diskann/src/graph/search/multihop_search.rs | 28 +- diskann/src/graph/search/range_search.rs | 157 +++++---- diskann/src/graph/test/cases/grid.rs | 15 +- 22 files changed, 653 insertions(+), 541 deletions(-) create mode 100644 diskann/src/graph/search/knn_search.rs diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 9e106b133..71031a944 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -5,7 +5,7 @@ //! A built-in helper for benchmarking K-nearest neighbors. -use std::{num::NonZeroUsize, sync::Arc}; +use std::sync::Arc; use diskann::{ ANNResult, @@ -29,7 +29,7 @@ use crate::{ /// the latter. Result aggregation for [`search::search_all`] is provided /// by the [`Aggregator`] type. /// -/// The provided implementation of [`Search`] accepts [`graph::SearchParams`] +/// The provided implementation of [`Search`] accepts [`graph::KnnSearch`] /// and returns [`Metrics`] as additional output. 
#[derive(Debug)] pub struct KNN @@ -92,7 +92,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::SearchParams; + type Parameters = graph::KnnSearch; type Output = Metrics; fn num_queries(&self) -> usize { @@ -100,7 +100,7 @@ where } fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { - search::IdCount::Fixed(NonZeroUsize::new(parameters.k_value).unwrap_or(diskann::utils::ONE)) + search::IdCount::Fixed(parameters.k_value()) } async fn search( @@ -113,14 +113,14 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let mut graph_search = graph::GraphSearch::from(*parameters); + let mut knn_search = *parameters; let stats = self .index .search( self.strategy.get(index)?, &context, self.queries.row(index), - &mut graph_search, + &mut knn_search, buffer, ) .await?; @@ -143,7 +143,7 @@ pub struct Summary { pub setup: search::Setup, /// The [`Search::Parameters`] used for the batch of runs. - pub parameters: graph::SearchParams, + pub parameters: graph::KnnSearch, /// The end-to-end latency for each repetition in the batch. pub end_to_end_latencies: Vec, @@ -208,7 +208,7 @@ impl<'a, I> Aggregator<'a, I> { } } -impl search::Aggregate for Aggregator<'_, I> +impl search::Aggregate for Aggregator<'_, I> where I: crate::recall::RecallCompatible, { @@ -216,7 +216,7 @@ where fn aggregate( &mut self, - run: search::Run, + run: search::Run, mut results: Vec>, ) -> anyhow::Result

{ // Compute the recall using just the first result. @@ -281,13 +281,15 @@ where #[cfg(test)] mod tests { + use std::num::NonZeroUsize; + use super::*; use diskann::graph::test::provider; #[test] fn test_knn() { - let nearest_neighbors = 5; + let nearest_neighbors = NonZeroUsize::new(5).unwrap(); let index = search::graph::test_grid_provider(); @@ -311,7 +313,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( knn.clone(), - graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors, NonZeroUsize::new(10).unwrap(), None).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -322,7 +324,7 @@ mod tests { assert_eq!(*rows.row(0).first().unwrap(), 0); for r in 0..rows.nrows() { - assert_eq!(rows.row(r).len(), nearest_neighbors); + assert_eq!(rows.row(r).len(), nearest_neighbors.get()); } const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap(); @@ -335,17 +337,19 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors, NonZeroUsize::new(10).unwrap(), None) + .unwrap(), setup.clone(), ), search::Run::new( - graph::SearchParams::new(nearest_neighbors, 15, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors, NonZeroUsize::new(15).unwrap(), None) + .unwrap(), setup.clone(), ), ]; - let recall_k = nearest_neighbors; - let recall_n = nearest_neighbors; + let recall_k = nearest_neighbors.get(); + let recall_n = nearest_neighbors.get(); let all = search::search_all(knn, parameters, Aggregator::new(rows, recall_k, recall_n)).unwrap(); diff --git a/diskann-benchmark-core/src/search/graph/multihop.rs b/diskann-benchmark-core/src/search/graph/multihop.rs index 368d1cdf1..92741311b 100644 --- a/diskann-benchmark-core/src/search/graph/multihop.rs +++ b/diskann-benchmark-core/src/search/graph/multihop.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. 
*/ -use std::{num::NonZeroUsize, sync::Arc}; +use std::sync::Arc; use diskann::{ ANNResult, @@ -22,7 +22,7 @@ use crate::search::{self, Search, graph::Strategy}; /// [`search::search_all`] is provided by the [`search::graph::knn::Aggregator`] type (same /// aggregator as [`search::graph::KNN`]). /// -/// The provided implementation of [`Search`] accepts [`graph::SearchParams`] +/// The provided implementation of [`Search`] accepts [`graph::KnnSearch`] /// and returns [`search::graph::knn::Metrics`] as additional output. #[derive(Debug)] pub struct MultiHop @@ -90,7 +90,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::SearchParams; + type Parameters = graph::KnnSearch; type Output = super::knn::Metrics; fn num_queries(&self) -> usize { @@ -98,7 +98,7 @@ where } fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { - search::IdCount::Fixed(NonZeroUsize::new(parameters.k_value).unwrap_or(diskann::utils::ONE)) + search::IdCount::Fixed(parameters.k_value()) } async fn search( @@ -111,8 +111,7 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let mut multihop_search = - graph::MultihopSearch::new(graph::GraphSearch::from(*parameters), &*self.labels[index]); + let mut multihop_search = graph::MultihopSearch::new(*parameters, &*self.labels[index]); let stats = self .index .search( @@ -137,6 +136,8 @@ where #[cfg(test)] mod tests { + use std::num::NonZeroUsize; + use super::*; use diskann::graph::{index::QueryLabelProvider, test::provider}; @@ -153,7 +154,7 @@ mod tests { #[test] fn test_multihop() { - let nearest_neighbors = 5; + let nearest_neighbors = NonZeroUsize::new(5).unwrap(); let index = search::graph::test_grid_provider(); @@ -180,7 +181,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( multihop.clone(), - graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors, 
NonZeroUsize::new(10).unwrap(), None).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -192,7 +193,7 @@ mod tests { // Check that only even IDs are returned. for r in 0..rows.nrows() { - assert_eq!(rows.row(r).len(), nearest_neighbors); + assert_eq!(rows.row(r).len(), nearest_neighbors.get()); for &id in rows.row(r) { assert_eq!(id % 2, 0, "Found odd ID {} in row {}", id, r); } @@ -208,17 +209,19 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors, NonZeroUsize::new(10).unwrap(), None) + .unwrap(), setup.clone(), ), search::Run::new( - graph::SearchParams::new(nearest_neighbors, 15, None).unwrap(), + graph::KnnSearch::new(nearest_neighbors, NonZeroUsize::new(15).unwrap(), None) + .unwrap(), setup.clone(), ), ]; - let recall_k = nearest_neighbors; - let recall_n = nearest_neighbors; + let recall_k = nearest_neighbors.get(); + let recall_n = nearest_neighbors.get(); let all = search::search_all( multihop, diff --git a/diskann-benchmark-core/src/search/graph/range.rs b/diskann-benchmark-core/src/search/graph/range.rs index ebab8a0f2..f82064f6e 100644 --- a/diskann-benchmark-core/src/search/graph/range.rs +++ b/diskann-benchmark-core/src/search/graph/range.rs @@ -27,7 +27,7 @@ use crate::{ /// by the [`Aggregator`] type. /// /// The provided implementation of [`Search`] accepts -/// [`graph::RangeSearchParams`] and returns [`Metrics`] as additional output. +/// [`graph::RangeSearch`] and returns [`Metrics`] as additional output. 
#[derive(Debug)] pub struct Range where @@ -83,7 +83,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::RangeSearchParams; + type Parameters = graph::RangeSearch; type Output = Metrics; fn num_queries(&self) -> usize { @@ -91,7 +91,7 @@ where } fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { - search::IdCount::Dynamic(NonZeroUsize::new(parameters.starting_l_value)) + search::IdCount::Dynamic(NonZeroUsize::new(parameters.starting_l())) } async fn search( @@ -104,7 +104,7 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let mut range_search = graph::RangeSearch::from(*parameters); + let mut range_search = *parameters; let result = self .index .search( @@ -134,8 +134,8 @@ pub struct Summary { /// The [`search::Setup`] used for the batch of runs. pub setup: search::Setup, - /// The [`graph::RangeSearchParams`] used for the batch of runs. - pub parameters: graph::RangeSearchParams, + /// The [`graph::RangeSearch`] used for the batch of runs. + pub parameters: graph::RangeSearch, /// The end-to-end latency for each repetition in the batch. pub end_to_end_latencies: Vec, @@ -179,7 +179,7 @@ impl<'a, I> Aggregator<'a, I> { } } -impl search::Aggregate for Aggregator<'_, I> +impl search::Aggregate for Aggregator<'_, I> where I: crate::recall::RecallCompatible, { @@ -188,7 +188,7 @@ where #[inline(never)] fn aggregate( &mut self, - run: search::Run, + run: search::Run, mut results: Vec>, ) -> anyhow::Result { // Compute the recall using just the first result. @@ -266,7 +266,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( range.clone(), - graph::RangeSearchParams::new(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(), + graph::RangeSearch::with_options(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -285,11 +285,11 @@ mod tests { // Try the aggregated strategy. 
let parameters = [ search::Run::new( - graph::RangeSearchParams::new(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(), + graph::RangeSearch::with_options(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(), setup.clone(), ), search::Run::new( - graph::RangeSearchParams::new(None, 15, None, 2.0, None, 0.8, 1.2).unwrap(), + graph::RangeSearch::with_options(None, 15, None, 2.0, None, 0.8, 1.2).unwrap(), setup.clone(), ), ]; diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index c7e2ab75c..1d6102f9b 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -143,8 +143,8 @@ impl SearchResults { Self { num_tasks: setup.tasks.into(), - search_n: parameters.k_value, - search_l: parameters.l_value, + search_n: parameters.k_value().get(), + search_l: parameters.l_value().get(), qps, search_latencies: end_to_end_latencies, mean_latencies, @@ -284,7 +284,7 @@ impl RangeSearchResults { Self { num_tasks: setup.tasks.into(), - initial_l: parameters.starting_l_value, + initial_l: parameters.starting_l(), qps, search_latencies: end_to_end_latencies, mean_latencies, diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 723d32155..0f32b0b2c 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -49,8 +49,9 @@ pub(crate) fn run( .search_l .iter() .map(|search_l| { - let search_params = - diskann::graph::SearchParams::new(run.search_n, *search_l, None).unwrap(); + let k = NonZeroUsize::new(run.search_n).expect("search_n must be non-zero"); + let l = NonZeroUsize::new(*search_l).expect("search_l must be non-zero"); + let search_params = diskann::graph::KnnSearch::new(k, l, None).unwrap(); core_search::Run::new(search_params, setup.clone()) }) @@ -63,7 +64,7 @@ pub(crate) fn run( Ok(all) } -type Run = core_search::Run; +type Run = core_search::Run; 
pub(crate) trait Knn { fn search_all( &self, @@ -83,13 +84,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::KNN: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::SearchParams, + Parameters = diskann::graph::KnnSearch, Output = core_search::graph::knn::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, @@ -109,13 +110,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::MultiHop: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::SearchParams, + Parameters = diskann::graph::KnnSearch, Output = core_search::graph::knn::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, diff --git a/diskann-benchmark/src/backend/index/search/range.rs b/diskann-benchmark/src/backend/index/search/range.rs index 6ed6dc25f..d78a66bc8 100644 --- a/diskann-benchmark/src/backend/index/search/range.rs +++ b/diskann-benchmark/src/backend/index/search/range.rs @@ -30,7 +30,7 @@ impl<'a> RangeSearchSteps<'a> { } } -type Run = core_search::Run; +type Run = core_search::Run; pub(crate) trait Range { fn search_all( @@ -79,13 +79,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::Range: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::RangeSearchParams, + Parameters = diskann::graph::RangeSearch, Output = core_search::graph::range::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, ) -> anyhow::Result> { let results = core_search::search_all( diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index e12c26419..d6fe09c45 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -8,7 +8,7 @@ use 
std::num::{NonZeroU32, NonZeroUsize}; use anyhow::{anyhow, Context}; use diskann::{ - graph::{self, config, RangeSearchParams, RangeSearchParamsError, StartPointStrategy}, + graph::{self, config, RangeSearch, RangeSearchError, StartPointStrategy}, utils::IntoUsize, }; use diskann_benchmark_core::streaming::executors::bigann; @@ -90,13 +90,11 @@ pub(crate) struct GraphRangeSearch { } impl GraphRangeSearch { - pub(crate) fn construct_params( - &self, - ) -> Result, RangeSearchParamsError> { + pub(crate) fn construct_params(&self) -> Result, RangeSearchError> { self.initial_search_l .iter() .map(|&l| { - RangeSearchParams::new( + RangeSearch::with_options( self.max_returned, l, self.beam_width, @@ -111,7 +109,7 @@ impl GraphRangeSearch { } impl CheckDeserialization for GraphRangeSearch { - // all necessary checks are carried out when RangeSearchParams is initialized + // all necessary checks are carried out when RangeSearch is initialized fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { self.construct_params() .context("invalid range search params")?; diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 517932139..f6b18332a 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -19,8 +19,7 @@ use diskann::{ graph::{ self, glue::{self, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy}, - search_output_buffer, AdjacencyList, DiskANNIndex, GraphSearch, SearchOutputBuffer, - SearchParams, + search_output_buffer, AdjacencyList, DiskANNIndex, KnnSearch, SearchOutputBuffer, }, neighbor::Neighbor, provider::{ @@ -984,23 +983,25 @@ where let strategy = self.search_strategy(query, vector_filter); let timer = Instant::now(); + let k = NonZeroUsize::new(k_value).expect("k_value must be non-zero"); + let l = NonZeroUsize::new(search_list_size as usize) + .expect("search_list_size must be 
non-zero"); let stats = if is_flat_search { self.runtime.block_on(self.index.flat_search( &strategy, &DefaultContext, strategy.query, vector_filter, - &SearchParams::new(k_value, search_list_size as usize, beam_width)?, + &KnnSearch::new(k, l, beam_width)?, &mut result_output_buffer, ))? } else { - let mut graph_search = - GraphSearch::new(k_value, search_list_size as usize, beam_width)?; + let mut knn_search = KnnSearch::new(k, l, beam_width)?; self.runtime.block_on(self.index.search( &strategy, &DefaultContext, strategy.query, - &mut graph_search, + &mut knn_search, &mut result_output_buffer, ))? }; @@ -1042,8 +1043,10 @@ fn ensure_vertex_loaded>( #[cfg(test)] mod disk_provider_tests { + use std::num::NonZeroUsize; + use diskann::{ - graph::{search::record::VisitedSearchRecord, SearchParams, SearchParamsError}, + graph::{search::record::VisitedSearchRecord, KnnSearch, KnnSearchError}, utils::IntoUsize, ANNErrorKind, }; @@ -1071,6 +1074,11 @@ mod disk_provider_tests { utils::{QueryStatistics, VirtualAlignedReaderFactory}, }; + /// Helper to create NonZeroUsize from usize (for tests only). 
+ fn nz(v: usize) -> NonZeroUsize { + NonZeroUsize::new(v).expect("value must be non-zero") + } + const TEST_INDEX_PREFIX_128DIM: &str = "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search"; const TEST_INDEX_128DIM: &str = @@ -1532,17 +1540,15 @@ mod disk_provider_tests { "index_path is not correct" ); - let res = SearchParams::new_default(0, 10); + // Test error case: l < k + let res = KnnSearch::new_default(nz(20), nz(10)); assert!(res.is_err()); assert_eq!( - >::into(res.unwrap_err()).kind(), + >::into(res.unwrap_err()).kind(), ANNErrorKind::IndexError ); - let res = SearchParams::new_default(20, 10); - assert!(res.is_err()); - let res = SearchParams::new_default(10, 0); - assert!(res.is_err()); - let res = SearchParams::new(10, 10, Some(0)); + // Test error case: beam_width = 0 + let res = KnnSearch::new(nz(10), nz(10), Some(0)); assert!(res.is_err()); let search_engine = @@ -1633,7 +1639,7 @@ mod disk_provider_tests { &strategy, &DefaultContext, &query_vector, - &SearchParams::new(10, 10, Some(4)).unwrap(), + &KnnSearch::new(nz(10), nz(10), Some(4)).unwrap(), &mut result_output_buffer, &mut search_record, )) @@ -1755,7 +1761,7 @@ mod disk_provider_tests { attribute_provider.clone(), ); - let search_params = SearchParams::new(10, 20, None).unwrap(); + let search_params = KnnSearch::new(nz(10), nz(20), None).unwrap(); search_engine .runtime @@ -1802,8 +1808,12 @@ mod disk_provider_tests { ); let strategy2 = search_engine.search_strategy(&query_vector, &|_| true); let mut search_record2 = VisitedSearchRecord::new(0); - let search_params2 = - SearchParams::new(return_list_size as usize, search_list_size as usize, None).unwrap(); + let search_params2 = KnnSearch::new( + nz(return_list_size as usize), + nz(search_list_size as usize), + None, + ) + .unwrap(); let stats = search_engine .runtime @@ -2095,7 +2105,7 @@ mod disk_provider_tests { &strategy, &DefaultContext, &query_vector, - &SearchParams::new(10, 10, Some(4)).unwrap(), + 
&KnnSearch::new(nz(10), nz(10), Some(4)).unwrap(), &mut result_output_buffer, &mut search_record, )) diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index a50758612..66b8c8687 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -173,8 +173,8 @@ pub(crate) mod tests { use crate::storage::VirtualStorageProvider; use diskann::{ graph::{ - self, AdjacencyList, ConsolidateKind, InplaceDeleteMethod, RangeSearchParams, - SearchParams, StartPointStrategy, + self, AdjacencyList, ConsolidateKind, InplaceDeleteMethod, KnnSearch, RangeSearch, + StartPointStrategy, config::IntraBatchCandidates, glue::{AsElement, InplaceDeleteStrategy, InsertStrategy, SearchStrategy, aliases}, index::{PartitionedNeighbors, QueryLabelProvider, QueryVisitDecision}, @@ -215,6 +215,11 @@ pub(crate) mod tests { // Callbacks for use with `simplified_builder`. fn no_modify(_: &mut diskann::graph::config::Builder) {} + /// Helper to create NonZeroUsize from usize (for tests only). 
+ fn nz(v: usize) -> NonZeroUsize { + NonZeroUsize::new(v).expect("value must be non-zero") + } + ////////////////////////// // Test helper functions // ////////////////////////// @@ -227,7 +232,7 @@ pub(crate) mod tests { strategy: &S, context: &DP::Context, query: &T, - search_params: &SearchParams, + search_params: &KnnSearch, output: &mut OB, filter: &dyn graph::index::QueryLabelProvider, ) -> diskann::ANNResult @@ -238,8 +243,7 @@ pub(crate) mod tests { O: Send, OB: graph::search_output_buffer::SearchOutputBuffer + Send, { - let mut multihop = - graph::MultihopSearch::new(graph::GraphSearch::from(*search_params), filter); + let mut multihop = graph::MultihopSearch::new(*search_params, filter); index .search(strategy, context, query, &mut multihop, output) .await @@ -251,7 +255,7 @@ pub(crate) mod tests { strategy: &S, context: &DP::Context, query: &T, - search_params: &RangeSearchParams, + search_params: &RangeSearch, ) -> diskann::ANNResult<(SearchStats, Vec, Vec)> where DP: DataProvider, @@ -259,7 +263,7 @@ pub(crate) mod tests { S: graph::glue::SearchStrategy, O: Send + Default + Clone, { - let mut range_search = graph::RangeSearch::from(*search_params); + let mut range_search = *search_params; let result = index .search(strategy, context, query, &mut range_search, &mut ()) .await?; @@ -405,14 +409,15 @@ pub(crate) mod tests { let mut distances = vec![0.0; parameters.search_k]; let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = - SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(); + let mut graph_search = + graph::KnnSearch::new_default(nz(parameters.search_k), nz(parameters.search_l)) + .unwrap(); index .search( &strategy, ¶meters.context, query, - &mut search_params, + &mut graph_search, &mut result_output_buffer, ) .await @@ -453,14 +458,14 @@ pub(crate) mod tests { let mut distances = vec![0.0; parameters.search_k]; let mut result_output_buffer = 
search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = - SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(); + let search_params = + KnnSearch::new_default(nz(parameters.search_k), nz(parameters.search_l)).unwrap(); multihop_search( index, strategy, &parameters.context, query, - &mut search_params, + &search_params, &mut result_output_buffer, filter, ) @@ -1498,14 +1503,14 @@ pub(crate) mod tests { let filter = CallbackFilter::new(blocked, adjusted, 0.5); - let mut search_params = - SearchParams::new_default(parameters.search_k, parameters.search_l).unwrap(); + let search_params = + KnnSearch::new_default(nz(parameters.search_k), nz(parameters.search_l)).unwrap(); let stats = multihop_search( &index, &FullPrecision, &parameters.context, query.as_slice(), - &mut search_params, + &search_params, &mut result_output_buffer, &filter, ) @@ -2247,14 +2252,15 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); + let mut graph_search = + graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); // Full Precision Search.
index .search( &FullPrecision, ctx, query, - &mut search_params, + &mut graph_search, &mut result_output_buffer, ) .await @@ -2265,14 +2271,15 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); + let mut graph_search = + graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); // Quantized Search index .search( &Hybrid::new(None), ctx, query, - &mut search_params, + &mut graph_search, &mut result_output_buffer, ) .await @@ -2336,7 +2343,7 @@ pub(crate) mod tests { &FullPrecision, ctx, query, - &RangeSearchParams::new_default(starting_l_value, radius).unwrap(), + &RangeSearch::new(starting_l_value, radius).unwrap(), ) .await .unwrap(); @@ -2351,7 +2358,7 @@ pub(crate) mod tests { &Hybrid::new(None), ctx, query, - &RangeSearchParams::new_default(starting_l_value, radius).unwrap(), + &RangeSearch::new(starting_l_value, radius).unwrap(), ) .await .unwrap(); @@ -2368,7 +2375,7 @@ pub(crate) mod tests { &FullPrecision, ctx, query, - &RangeSearchParams::new( + &RangeSearch::with_options( None, starting_l_value, None, @@ -2393,7 +2400,7 @@ pub(crate) mod tests { &FullPrecision, ctx, query, - &RangeSearchParams::new_default(lower_l_value, radius).unwrap(), + &RangeSearch::new(lower_l_value, radius).unwrap(), ) .await .unwrap(); @@ -2515,14 +2522,15 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); + let mut graph_search = + graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); // Full Precision Search. 
index .search( &FullPrecision, ctx, query, - &mut search_params, + &mut graph_search, &mut result_output_buffer, ) .await @@ -2533,14 +2541,15 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); + let mut graph_search = + graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); // Quantized Search index .search( &Quantized, ctx, query, - &mut search_params, + &mut graph_search, &mut result_output_buffer, ) .await @@ -2620,14 +2629,15 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = SearchParams::new_default(top_k, top_k).unwrap(); + let mut graph_search = + graph::KnnSearch::new_default(nz(top_k), nz(top_k)).unwrap(); // Quantized Search index .search( &Quantized, ctx, query, - &mut search_params, + &mut graph_search, &mut result_output_buffer, ) .await @@ -2734,9 +2744,9 @@ pub(crate) mod tests { // Full Precision Search. 
let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); index - .search(&FullPrecision, ctx, query, &mut search_params, &mut output) + .search(&FullPrecision, ctx, query, &mut graph_search, &mut output) .await .unwrap(); assert_top_k_exactly_match(q, &gt, &ids, &distances, top_k); @@ -2746,10 +2756,10 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); - let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); index - .search(&strategy, ctx, query, &mut search_params, &mut output) + .search(&strategy, ctx, query, &mut graph_search, &mut output) .await .unwrap(); assert_top_k_exactly_match(q, &gt, &ids, &distances, top_k); @@ -2849,10 +2859,10 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); - let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); index - .search(&strategy, ctx, query, &mut search_params, &mut output) + .search(&strategy, ctx, query, &mut graph_search, &mut output) .await .unwrap(); @@ -2936,14 +2946,14 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); // Full Precision Search.
index .search( &Quantized, ctx, query, - &mut search_params, + &mut graph_search, &mut result_output_buffer, ) .await @@ -3517,14 +3527,14 @@ pub(crate) mod tests { let gt = groundtruth(queries.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); // Full Precision Search. index .search( &Hybrid::new(max_fp_vecs_per_prune), ctx, query, - &mut search_params, + &mut graph_search, &mut result_output_buffer, ) .await @@ -3664,14 +3674,14 @@ pub(crate) mod tests { let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut search_params = SearchParams::new_default(top_k, search_l).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); // Full Precision Search. 
index .search( &FullPrecision, ctx, query, - &mut search_params, + &mut graph_search, &mut result_output_buffer, ) .await @@ -3933,7 +3943,7 @@ pub(crate) mod tests { attribute_provider.clone(), ); - let search_params = diskann::graph::SearchParams::new( + let search_params = diskann::graph::KnnSearch::new( return_list_size, search_list_size, None, // beam_width @@ -4154,13 +4164,13 @@ pub(crate) mod tests { // but reject everything via on_visit let filter = RejectAllFilter::only([0_u32]); - let mut search_params = SearchParams::new_default(10, 20).unwrap(); + let search_params = KnnSearch::new_default(nz(10), nz(20)).unwrap(); let stats = multihop_search( &index, &FullPrecision, &DefaultContext, query.as_slice(), - &mut search_params, + &search_params, &mut result_output_buffer, &filter, ) @@ -4217,13 +4227,13 @@ pub(crate) mod tests { let target = (num_points / 2) as u32; let filter = TerminatingFilter::new(target); - let mut search_params = SearchParams::new_default(10, 40).unwrap(); + let search_params = KnnSearch::new_default(nz(10), nz(40)).unwrap(); let stats = multihop_search( &index, &FullPrecision, &DefaultContext, query.as_slice(), - &mut search_params, + &search_params, &mut result_output_buffer, &filter, ) @@ -4282,13 +4292,13 @@ pub(crate) mod tests { let mut baseline_buffer = search_output_buffer::IdDistance::new(&mut baseline_ids, &mut baseline_distances); - let mut search_params = SearchParams::new_default(10, 20).unwrap(); + let search_params = KnnSearch::new_default(nz(10), nz(20)).unwrap(); let baseline_stats = multihop_search( &index, &FullPrecision, &DefaultContext, query.as_slice(), - &mut search_params, + &search_params, &mut baseline_buffer, &EvenFilter, // Just filter to even IDs ) @@ -4304,13 +4314,13 @@ pub(crate) mod tests { let mut adjusted_buffer = search_output_buffer::IdDistance::new(&mut adjusted_ids, &mut adjusted_distances); - let mut search_params = SearchParams::new_default(10, 20).unwrap(); + let search_params = 
KnnSearch::new_default(nz(10), nz(20)).unwrap(); let adjusted_stats = multihop_search( &index, &FullPrecision, &DefaultContext, query.as_slice(), - &mut search_params, + &search_params, &mut adjusted_buffer, &filter, ) @@ -4431,13 +4441,13 @@ pub(crate) mod tests { let max_visits = 5; let filter = TerminateAfterN::new(max_visits); - let mut search_params = SearchParams::new_default(10, 100).unwrap(); // Large L to ensure we'd visit more without termination + let search_params = KnnSearch::new_default(nz(10), nz(100)).unwrap(); // Large L to ensure we'd visit more without termination let _stats = multihop_search( &index, &FullPrecision, &DefaultContext, query.as_slice(), - &mut search_params, + &search_params, &mut result_output_buffer, &filter, ) diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index d8b0f4054..4ca45c5d8 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -8,7 +8,7 @@ use std::{num::NonZeroUsize, sync::Arc}; use diskann::{ ANNResult, graph::{ - self, ConsolidateKind, InplaceDeleteMethod, SearchParams, + self, ConsolidateKind, InplaceDeleteMethod, KnnSearch, glue::{ self, AsElement, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchStrategy, }, @@ -226,7 +226,7 @@ where strategy: &S, context: &DP::Context, query: &T, - search_params: &SearchParams, + search_params: &KnnSearch, output: &mut OB, ) -> ANNResult where @@ -235,10 +235,10 @@ where O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, { - let mut graph_search = diskann::graph::GraphSearch::from(*search_params); + let mut knn_search = *search_params; self.handle.block_on( self.inner - .search(strategy, context, query, &mut graph_search, output), + .search(strategy, context, query, &mut knn_search, output), ) } diff --git a/diskann/src/error/ann_error.rs b/diskann/src/error/ann_error.rs index 6a87a8f6d..040e06037 100644 --- a/diskann/src/error/ann_error.rs +++ 
b/diskann/src/error/ann_error.rs @@ -718,17 +718,6 @@ where } } -pub(crate) fn ensure_positive(value: T, error: E) -> Result -where - T: PartialOrd + Default + Debug, -{ - if value > T::default() { - Ok(value) - } else { - Err(error) - } -} - // /// An internal macro for creating opaque, adhoc errors to help when debugging. // macro_rules! ann_error { // ($($arg:tt)+) => {{ diff --git a/diskann/src/error/mod.rs b/diskann/src/error/mod.rs index 7c052c4d1..3a9f9ab50 100644 --- a/diskann/src/error/mod.rs +++ b/diskann/src/error/mod.rs @@ -4,7 +4,6 @@ */ pub(crate) mod ann_error; -pub(crate) use ann_error::ensure_positive; pub use ann_error::{ANNError, ANNErrorKind, ANNResult, DiskANNError, ErrorContext, IntoANNResult}; pub(crate) mod ranked; diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 2260864b5..e444cfac1 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -24,7 +24,7 @@ use thiserror::Error; use tokio::task::JoinSet; use super::{ - AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, SearchParams, + AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, KnnSearch, glue::{ self, AsElement, ExpandBeam, FillSet, IdIterator, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, aliases, @@ -2124,7 +2124,7 @@ where /// /// # Supported Search Types /// - /// - [`search::GraphSearch`]: Standard graph-based ANN search + /// - [`search::KnnSearch`]: Standard k-NN graph-based search /// - [`search::MultihopSearch`]: Label-filtered search with multi-hop expansion /// - [`search::RangeSearch`]: Range-based search within a distance radius /// - [`search::DiverseSearch`]: Diversity-aware search (feature-gated) @@ -2132,11 +2132,11 @@ where /// # Example /// /// ```ignore - /// use diskann::graph::{GraphSearch, RangeSearch, Search}; + /// use diskann::graph::{KnnSearch, RangeSearch, Search}; /// - /// // Standard graph search - /// let mut params = GraphSearch::new(10, 
100, None)?; - /// let stats = index.search(&strategy, &context, &query, &mut params, &mut output).await?; + /// // Standard k-NN search + /// let mut params = KnnSearch::new(10, 100, None)?; + /// let stats = index.search(&strategy, &context, &query, &mut params, &mut output).await?; /// /// // Range search (note: uses () as output buffer, results in Output type) /// let mut params = RangeSearch::new(100, 0.5)?; @@ -2162,7 +2162,7 @@ where /// Perform a graph search while recording the traversal path. /// /// **Note:** This method is intended for debugging and analysis only. - /// For production searches, use [`Self::search`] with [`super::search::GraphSearch`]. + /// For production searches, use [`Self::search`] with [`super::search::KnnSearch`]. /// /// Records which nodes were visited during the search traversal, useful for /// understanding search behavior or diagnosing issues. @@ -2172,7 +2172,7 @@ where strategy: &'a S, context: &'a DP::Context, query: &'a T, - search_params: &'a SearchParams, + search_params: &'a KnnSearch, output: &'a mut OB, search_record: &'a mut SR, ) -> impl SendFuture> + 'a @@ -2184,10 +2184,8 @@ where SR: SearchRecord, { async move { - let mut recorded_search = super::search::RecordedGraphSearch::new( - super::search::GraphSearch::from(*search_params), - search_record, - ); + let mut recorded_search = + super::search::RecordedKnnSearch::new(*search_params, search_record); recorded_search .dispatch(self, strategy, context, query, output) .await @@ -2227,7 +2225,7 @@ where context: &'a DP::Context, query: &T, vector_filter: &(dyn Fn(&DP::ExternalId) -> bool + Send + Sync), - search_params: &SearchParams, + search_params: &KnnSearch, output: &mut OB, ) -> ANNResult where @@ -2244,7 +2242,7 @@ where let mut scratch = { let num_start_points = accessor.starting_points().await?.len(); - self.search_scratch(search_params.l_value, num_start_points) + self.search_scratch(search_params.l_value().get(), num_start_points) }; let id_iterator =
accessor.id_iterator().await?; @@ -2272,7 +2270,7 @@ where &mut accessor, query, &computer, - scratch.best.iter().take(search_params.l_value.into_usize()), + scratch.best.iter().take(search_params.l_value().get()), output, ) .send() diff --git a/diskann/src/graph/misc.rs b/diskann/src/graph/misc.rs index 8c58f6edb..067bd2d21 100644 --- a/diskann/src/graph/misc.rs +++ b/diskann/src/graph/misc.rs @@ -3,10 +3,6 @@ * Licensed under the MIT license. */ -use thiserror::Error; - -use crate::{ANNError, ANNErrorKind, error::ensure_positive}; - // enum used to return the status of the vector that `consolidate_vector` // was called on: Deleted if the vector was already deleted, and Complete // if the vector was not deleted (and thus is now consolidated) @@ -35,141 +31,6 @@ pub enum InplaceDeleteMethod { OneHop, } -// Parameters for the search algorithm -#[derive(Copy, Clone, Debug)] -pub struct SearchParams { - pub k_value: usize, - pub l_value: usize, - pub beam_width: Option, -} - -#[derive(Debug, Error)] -pub enum SearchParamsError { - #[error("l_value ({l_value}) cannot be less than k_value ({k_value})")] - LLessThanK { l_value: usize, k_value: usize }, - #[error("beam width cannot be zero")] - BeamWidthZero, - #[error("l_value cannot be zero")] - LZero, - #[error("k_value cannot be zero")] - KZero, -} - -impl From for ANNError { - fn from(err: SearchParamsError) -> Self { - Self::new(ANNErrorKind::IndexError, err) - } -} - -impl SearchParams { - pub fn new( - k_value: usize, - l_value: usize, - beam_width: Option, - ) -> Result { - if k_value > l_value { - return Err(SearchParamsError::LLessThanK { l_value, k_value }); - } - if let Some(beam_width) = beam_width { - ensure_positive(beam_width, SearchParamsError::BeamWidthZero)?; - } - ensure_positive(k_value, SearchParamsError::KZero)?; - ensure_positive(l_value, SearchParamsError::LZero)?; - - Ok(Self { - k_value, - l_value, - beam_width, - }) - } - - pub fn new_default(k_value: usize, l_value: usize) -> Result { - 
SearchParams::new(k_value, l_value, None) - } -} - -// Parameters for the search algorithm -#[derive(Copy, Clone, Debug)] -pub struct RangeSearchParams { - pub max_returned: Option, - pub starting_l_value: usize, - pub beam_width: Option, - pub radius: f32, - pub inner_radius: Option, - pub initial_search_slack: f32, - pub range_search_slack: f32, -} - -#[derive(Debug, Error)] -pub enum RangeSearchParamsError { - #[error("beam width cannot be zero")] - BeamWidthZero, - #[error("l_value cannot be zero")] - LZero, - #[error("initial_search_slack must be between 0 and 1.0")] - StartingListSlackValueError, - #[error("range_search_slack must be greater than or equal to 1.0")] - RangeSearchSlackValueError, - #[error("inner_radius must be less than or equal to radius")] - InnerRadiusValueError, -} - -impl From for ANNError { - fn from(err: RangeSearchParamsError) -> Self { - Self::new(ANNErrorKind::IndexError, err) - } -} - -impl RangeSearchParams { - pub fn new( - max_returned: Option, - starting_l_value: usize, - beam_width: Option, - radius: f32, - inner_radius: Option, - initial_search_slack: f32, - range_search_slack: f32, - ) -> Result { - // note that radius is allowed to be negative due to inner product metrics - if let Some(beam_width) = beam_width { - ensure_positive(beam_width, RangeSearchParamsError::BeamWidthZero)?; - } - ensure_positive(starting_l_value, RangeSearchParamsError::LZero)?; - if !(0.0..=1.0).contains(&initial_search_slack) { - return Err(RangeSearchParamsError::StartingListSlackValueError); - } - if range_search_slack < 1.0 { - return Err(RangeSearchParamsError::RangeSearchSlackValueError); - } - if let Some(inner_radius) = inner_radius - && inner_radius > radius - { - return Err(RangeSearchParamsError::InnerRadiusValueError); - } - - Ok(Self { - max_returned, - starting_l_value, - beam_width, - radius, - inner_radius, - initial_search_slack, - range_search_slack, - }) - } - - pub fn new_default( - starting_l_value: usize, - radius: f32, - ) -> 
Result { - RangeSearchParams::new(None, starting_l_value, None, radius, None, 1.0, 1.0) - } - - pub fn l_value(&self) -> usize { - self.starting_l_value - } -} - // Parameters for diverse search #[cfg(feature = "experimental_diversity_search")] #[derive(Clone, Debug)] @@ -224,75 +85,4 @@ mod tests { _ => panic!("Expected not deleted variant"), } } - - #[test] - fn test_range_search_params_error_cases() { - { - // test starting list slack factor error - let res = RangeSearchParams::new( - None, // max returned - 10, // starting l value - None, // beam width - 1.0, // radius - None, // inner radius - 1.1, // initial search slack - 1.0, // range search slack - ); - assert!(res.is_err()); - assert_eq!( - res.unwrap_err().to_string(), - "initial_search_slack must be between 0 and 1.0" - ); - } - { - // test range search slack factor error - let res = RangeSearchParams::new( - None, // max returned - 10, // starting l value - None, // beam width - 1.0, // radius - None, // inner radius - 1.0, // initial search slack - 0.9, // range search slack - ); - assert!(res.is_err()); - assert_eq!( - res.unwrap_err().to_string(), - "range_search_slack must be greater than or equal to 1.0" - ); - } - { - // test inner radius error - let res = RangeSearchParams::new( - None, // max returned - 10, // starting l value - None, // beam width - 1.0, // radius - Some(2.0), // inner radius - 1.0, // initial search slack - 1.0, // range search slack - ); - assert!(res.is_err()); - assert_eq!( - res.unwrap_err().to_string(), - "inner_radius must be less than or equal to radius" - ); - } - } - - #[test] - fn test_range_search_params_impl() { - let res = RangeSearchParams::new( - None, // max returned - 10, // starting l value - None, // beam width - 1.0, // radius - None, // inner radius - 1.0, // initial search slack - 1.0, // range search slack - ) - .unwrap(); - - assert_eq!(res.l_value(), 10); - } } diff --git a/diskann/src/graph/mod.rs b/diskann/src/graph/mod.rs index ae97ba869..011734e28 
100644 --- a/diskann/src/graph/mod.rs +++ b/diskann/src/graph/mod.rs @@ -21,10 +21,7 @@ mod start_point; pub use start_point::{SampleableForStart, StartPointStrategy}; mod misc; -pub use misc::{ - ConsolidateKind, InplaceDeleteMethod, RangeSearchParams, RangeSearchParamsError, SearchParams, - SearchParamsError, -}; +pub use misc::{ConsolidateKind, InplaceDeleteMethod}; #[cfg(feature = "experimental_diversity_search")] pub use misc::DiverseSearchParams; @@ -33,7 +30,10 @@ pub mod glue; pub mod search; // Re-export unified search interface as the primary API. -pub use search::{GraphSearch, MultihopSearch, RangeSearch, RangeSearchOutput, Search}; +pub use search::{ + KnnSearch, KnnSearchError, MultihopSearch, RangeSearch, RangeSearchError, RangeSearchOutput, + RecordedKnnSearch, Search, +}; #[cfg(feature = "experimental_diversity_search")] pub use search::DiverseSearch; diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index 0b138a65c..023327fdf 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -5,14 +5,10 @@ //! Diversity-aware search. -#![cfg(feature = "experimental_diversity_search")] - -use std::num::NonZeroUsize; - use diskann_utils::future::{AssertSend, SendFuture}; use hashbrown::HashSet; -use super::{Search, graph_search::GraphSearch, record::NoopSearchRecord, scratch::SearchScratch}; +use super::{KnnSearch, Search, record::NoopSearchRecord, scratch::SearchScratch}; use crate::{ ANNResult, error::IntoANNResult, @@ -24,7 +20,6 @@ use crate::{ }, neighbor::{AttributeValueProvider, DiverseNeighborQueue, NeighborQueue}, provider::{BuildQueryComputer, DataProvider}, - utils::IntoUsize, }; /// Parameters for diversity-aware search. @@ -35,8 +30,8 @@ pub struct DiverseSearch

<P> where P: AttributeValueProvider, { - /// Base graph search parameters. - pub inner: GraphSearch, + /// Base k-NN search parameters. + pub inner: KnnSearch, /// Diversity-specific parameters. pub diverse_params: DiverseSearchParams

<P>, } @@ -46,7 +41,7 @@ where P: AttributeValueProvider, { /// Create new diverse search parameters. - pub fn new(inner: GraphSearch, diverse_params: DiverseSearchParams

<P>) -> Self { + pub fn new(inner: KnnSearch, diverse_params: DiverseSearchParams

<P>) -> Self { Self { inner, diverse_params, @@ -64,10 +59,8 @@ where { let attribute_provider = self.diverse_params.attribute_provider.clone(); let diverse_queue = DiverseNeighborQueue::new( - self.inner.l, - // SAFETY: k_value is guaranteed to be non-zero by GraphSearch validation - #[allow(clippy::expect_used)] - NonZeroUsize::new(self.inner.k).expect("k_value must be non-zero"), + self.inner.l_value().get(), + self.inner.k_value(), self.diverse_params.diverse_results_k, attribute_provider, ); @@ -75,10 +68,10 @@ where SearchScratch { best: diverse_queue, visited: HashSet::with_capacity( - index.estimate_visited_set_capacity(Some(self.inner.l)), + index.estimate_visited_set_capacity(Some(self.inner.l_value().get())), ), id_scratch: Vec::with_capacity(index.max_degree_with_slack()), - beam_nodes: Vec::with_capacity(self.inner.beam_width.unwrap_or(1)), + beam_nodes: Vec::with_capacity(self.inner.beam_width().unwrap_or(1)), range_frontier: std::collections::VecDeque::new(), in_range: Vec::new(), hops: 0, @@ -118,7 +111,7 @@ where let stats = index .search_internal( - self.inner.beam_width, + self.inner.beam_width(), &start_ids, &mut accessor, &computer, @@ -136,7 +129,7 @@ where &mut accessor, query, &computer, - diverse_scratch.best.iter().take(self.inner.l.into_usize()), + diverse_scratch.best.iter().take(self.inner.l_value().get()), output, ) .send() diff --git a/diskann/src/graph/search/graph_search.rs b/diskann/src/graph/search/graph_search.rs index 876ff4bb9..354a85d14 100644 --- a/diskann/src/graph/search/graph_search.rs +++ b/diskann/src/graph/search/graph_search.rs @@ -27,14 +27,16 @@ use crate::{ /// /// This is the primary search mode, using the Vamana graph structure for efficient /// approximate nearest neighbor traversal. +/// +/// This type is also exported as `SearchParams` for backwards compatibility. #[derive(Debug, Clone, Copy)] pub struct GraphSearch { /// Number of results to return (k in k-NN).
- pub k: usize, + k_value: usize, /// Search list size - controls accuracy vs speed tradeoff. - pub l: usize, + l_value: usize, /// Optional beam width for parallel graph exploration. - pub beam_width: Option, + beam_width: Option, } impl GraphSearch { @@ -42,77 +44,57 @@ impl GraphSearch { /// /// # Errors /// - /// Returns an error if `l < k` or if any value is zero. + /// Returns an error if `l_value < k_value` or if any value is zero. pub fn new( - k: usize, - l: usize, + k_value: usize, + l_value: usize, beam_width: Option, ) -> Result { use super::super::SearchParamsError; - if k > l { - return Err(SearchParamsError::LLessThanK { - l_value: l, - k_value: k, - }); + if k_value > l_value { + return Err(SearchParamsError::LLessThanK { l_value, k_value }); } if let Some(bw) = beam_width { if bw == 0 { return Err(SearchParamsError::BeamWidthZero); } } - if k == 0 { + if k_value == 0 { return Err(SearchParamsError::KZero); } - if l == 0 { + if l_value == 0 { return Err(SearchParamsError::LZero); } - Ok(Self { k, l, beam_width }) + Ok(Self { + k_value, + l_value, + beam_width, + }) } /// Create parameters with default beam width. - pub fn new_default(k: usize, l: usize) -> Result { - Self::new(k, l, None) + pub fn new_default(k_value: usize, l_value: usize) -> Result { + Self::new(k_value, l_value, None) } -} -impl From for GraphSearch { - fn from(params: super::super::SearchParams) -> Self { - Self { - k: params.k_value, - l: params.l_value, - beam_width: params.beam_width, - } + /// Returns the number of results to return (k in k-NN). + #[inline] + pub fn k_value(&self) -> usize { + self.k_value } -} -/// Implement Search for SearchParams to provide backwards compatibility. -/// This treats SearchParams as an alias for GraphSearch. -impl Search for super::super::SearchParams -where - DP: DataProvider, - T: Sync + ?Sized, - S: SearchStrategy, - O: Send, - OB: SearchOutputBuffer + Send + ?Sized, -{ - type Output = SearchStats; + /// Returns the search list size. 
+ #[inline] + pub fn l_value(&self) -> usize { + self.l_value + } - fn dispatch<'a>( - &'a mut self, - index: &'a DiskANNIndex, - strategy: &'a S, - context: &'a DP::Context, - query: &'a T, - output: &'a mut OB, - ) -> impl SendFuture> { - async move { - let mut graph_search = GraphSearch::from(*self); - graph_search - .dispatch(index, strategy, context, query, output) - .await - } + /// Returns the optional beam width for parallel graph exploration. + #[inline] + pub fn beam_width(&self) -> Option { + self.beam_width } } @@ -142,7 +124,7 @@ where let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; - let mut scratch = index.search_scratch(self.l, start_ids.len()); + let mut scratch = index.search_scratch(self.l_value, start_ids.len()); let stats = index .search_internal( @@ -161,7 +143,7 @@ where &mut accessor, query, &computer, - scratch.best.iter().take(self.l.into_usize()), + scratch.best.iter().take(self.l_value.into_usize()), output, ) .send() @@ -222,7 +204,7 @@ where let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; - let mut scratch = index.search_scratch(self.inner.l, start_ids.len()); + let mut scratch = index.search_scratch(self.inner.l_value, start_ids.len()); let stats = index .search_internal( @@ -241,7 +223,7 @@ where &mut accessor, query, &computer, - scratch.best.iter().take(self.inner.l.into_usize()), + scratch.best.iter().take(self.inner.l_value.into_usize()), output, ) .send() diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs new file mode 100644 index 000000000..1c790a2eb --- /dev/null +++ b/diskann/src/graph/search/knn_search.rs @@ -0,0 +1,309 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Standard k-NN (k-nearest neighbor) graph-based search. 
+ +use std::{fmt::Debug, num::NonZeroUsize}; + +use diskann_utils::future::{AssertSend, SendFuture}; +use thiserror::Error; + +use super::Search; +use crate::{ + ANNError, ANNErrorKind, ANNResult, + error::IntoANNResult, + graph::{ + glue::{SearchExt, SearchPostProcess, SearchStrategy}, + index::{DiskANNIndex, SearchStats}, + search::record::NoopSearchRecord, + search_output_buffer::SearchOutputBuffer, + }, + provider::{BuildQueryComputer, DataProvider}, + utils::IntoUsize, +}; + +/// Error type for [`KnnSearch`] parameter validation. +#[derive(Debug, Error)] +pub enum KnnSearchError { + #[error("l_value ({l_value}) cannot be less than k_value ({k_value})")] + LLessThanK { l_value: usize, k_value: usize }, + #[error("beam width cannot be zero")] + BeamWidthZero, +} + +impl From for ANNError { + fn from(err: KnnSearchError) -> Self { + Self::new(ANNErrorKind::IndexError, err) + } +} + +/// Parameters for standard k-NN (k-nearest neighbor) graph-based search. +/// +/// This is the primary search mode, using the Vamana graph structure for efficient +/// approximate nearest neighbor traversal. +#[derive(Debug, Clone, Copy)] +pub struct KnnSearch { + /// Number of results to return (k in k-NN). + k_value: NonZeroUsize, + /// Search list size - controls accuracy vs speed tradeoff. + l_value: NonZeroUsize, + /// Optional beam width for parallel graph exploration. + beam_width: Option, +} + +impl KnnSearch { + /// Create new k-NN search parameters. + /// + /// # Errors + /// + /// Returns an error if `l_value < k_value` or if beam_width is zero. 
+ pub fn new( + k_value: NonZeroUsize, + l_value: NonZeroUsize, + beam_width: Option, + ) -> Result { + if k_value > l_value { + return Err(KnnSearchError::LLessThanK { + l_value: l_value.get(), + k_value: k_value.get(), + }); + } + if let Some(bw) = beam_width + && bw == 0 + { + return Err(KnnSearchError::BeamWidthZero); + } + + Ok(Self { + k_value, + l_value, + beam_width, + }) + } + + /// Create parameters with default beam width. + pub fn new_default( + k_value: NonZeroUsize, + l_value: NonZeroUsize, + ) -> Result { + Self::new(k_value, l_value, None) + } + + /// Returns the number of results to return (k in k-NN). + #[inline] + pub fn k_value(&self) -> NonZeroUsize { + self.k_value + } + + /// Returns the search list size. + #[inline] + pub fn l_value(&self) -> NonZeroUsize { + self.l_value + } + + /// Returns the optional beam width for parallel graph exploration. + #[inline] + pub fn beam_width(&self) -> Option { + self.beam_width + } +} + +impl Search for KnnSearch +where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, +{ + type Output = SearchStats; + + fn dispatch<'a>( + &'a mut self, + index: &'a DiskANNIndex, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + output: &'a mut OB, + ) -> impl SendFuture> { + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(self.l_value.get(), start_ids.len()); + + let stats = index + .search_internal( + self.beam_width, + &start_ids, + &mut accessor, + &computer, + &mut scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + let result_count = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + scratch.best.iter().take(self.l_value.get().into_usize()), + output, + ) + 
.send() + .await + .into_ann_result()?; + + Ok(stats.finish(result_count as u32)) + } + } +} + +//////////////////////// +// Recorded KnnSearch // +//////////////////////// + +/// K-NN search with traversal path recording. +/// +/// Records the path taken during search for debugging or analysis. +#[derive(Debug)] +pub struct RecordedKnnSearch<'r, SR: ?Sized> { + /// Base k-NN search parameters. + pub inner: KnnSearch, + /// The recorder to capture search path. + pub recorder: &'r mut SR, +} + +impl<'r, SR: ?Sized> RecordedKnnSearch<'r, SR> { + /// Create new recorded search parameters. + pub fn new(inner: KnnSearch, recorder: &'r mut SR) -> Self { + Self { inner, recorder } + } +} + +impl<'r, DP, S, T, O, OB, SR> Search for RecordedKnnSearch<'r, SR> +where + DP: DataProvider, + T: Sync + ?Sized, + S: SearchStrategy, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, + SR: super::record::SearchRecord + ?Sized, +{ + type Output = SearchStats; + + fn dispatch<'a>( + &'a mut self, + index: &'a DiskANNIndex, + strategy: &'a S, + context: &'a DP::Context, + query: &'a T, + output: &'a mut OB, + ) -> impl SendFuture> { + async move { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(self.inner.l_value.get(), start_ids.len()); + + let stats = index + .search_internal( + self.inner.beam_width, + &start_ids, + &mut accessor, + &computer, + &mut scratch, + self.recorder, + ) + .await?; + + let result_count = strategy + .post_processor() + .post_process( + &mut accessor, + query, + &computer, + scratch + .best + .iter() + .take(self.inner.l_value.get().into_usize()), + output, + ) + .send() + .await + .into_ann_result()?; + + Ok(stats.finish(result_count as u32)) + } + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use 
super::*; + + #[test] + fn test_knn_search_validation() { + // Valid + assert!( + KnnSearch::new( + NonZeroUsize::new(10).unwrap(), + NonZeroUsize::new(100).unwrap(), + None + ) + .is_ok() + ); + assert!( + KnnSearch::new( + NonZeroUsize::new(10).unwrap(), + NonZeroUsize::new(100).unwrap(), + Some(4) + ) + .is_ok() + ); + assert!( + KnnSearch::new( + NonZeroUsize::new(10).unwrap(), + NonZeroUsize::new(10).unwrap(), + None + ) + .is_ok() + ); // k == l is valid + + // Invalid: l < k + assert!( + KnnSearch::new( + NonZeroUsize::new(100).unwrap(), + NonZeroUsize::new(10).unwrap(), + None + ) + .is_err() + ); + + // Invalid: zero beam_width + assert!( + KnnSearch::new( + NonZeroUsize::new(10).unwrap(), + NonZeroUsize::new(100).unwrap(), + Some(0) + ) + .is_err() + ); + } +} diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 1f21929b0..65dd182ae 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -12,10 +12,10 @@ //! # Usage //! //! ```ignore -//! use diskann::graph::{GraphSearch, RangeSearch, MultihopSearch, Search}; +//! use diskann::graph::{KnnSearch, RangeSearch, MultihopSearch, Search}; //! -//! // Standard graph search -//! let mut params = GraphSearch::new(10, 100, None)?; +//! // Standard k-NN search +//! let mut params = KnnSearch::new(10, 100, None)?;; //! let stats = index.search(&strategy, &context, &query, &mut params, &mut output).await?; //! //! // Range search @@ -28,7 +28,7 @@ use diskann_utils::future::SendFuture; use crate::{ANNResult, graph::index::DiskANNIndex, provider::DataProvider}; -mod graph_search; +mod knn_search; mod multihop_search; mod range_search; @@ -59,9 +59,9 @@ where } // Re-export search parameter types. 
-pub use graph_search::{GraphSearch, RecordedGraphSearch}; +pub use knn_search::{KnnSearch, KnnSearchError, RecordedKnnSearch}; pub use multihop_search::MultihopSearch; -pub use range_search::{RangeSearch, RangeSearchOutput}; +pub use range_search::{RangeSearch, RangeSearchError, RangeSearchOutput}; // Feature-gated diverse search. #[cfg(feature = "experimental_diversity_search")] diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 9c70561f7..57706aef2 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -10,12 +10,11 @@ use diskann_utils::future::{AssertSend, SendFuture}; use diskann_vector::PreprocessedDistanceFunction; use hashbrown::HashSet; -use super::{Search, record::SearchRecord, scratch::SearchScratch}; +use super::{KnnSearch, Search, record::SearchRecord, scratch::SearchScratch}; use crate::{ ANNResult, error::{ErrorExt, IntoANNResult}, graph::{ - SearchParams, glue::{ self, ExpandBeam, HybridPredicate, Predicate, PredicateMut, SearchExt, SearchPostProcess, SearchStrategy, @@ -28,11 +27,9 @@ use crate::{ }, neighbor::Neighbor, provider::{BuildQueryComputer, DataProvider}, - utils::{IntoUsize, VectorId}, + utils::VectorId, }; -use super::graph_search::GraphSearch; - /// Parameters for label-filtered search using multi-hop expansion. /// /// This search extends standard graph search by expanding through non-matching @@ -41,17 +38,14 @@ use super::graph_search::GraphSearch; #[derive(Debug)] pub struct MultihopSearch<'q, InternalId> { /// Base graph search parameters. - pub inner: GraphSearch, + pub inner: KnnSearch, /// Label evaluator for determining node matches. pub label_evaluator: &'q dyn QueryLabelProvider, } impl<'q, InternalId> MultihopSearch<'q, InternalId> { /// Create new multihop search parameters. 
- pub fn new( - inner: GraphSearch, - label_evaluator: &'q dyn QueryLabelProvider, - ) -> Self { + pub fn new(inner: KnnSearch, label_evaluator: &'q dyn QueryLabelProvider) -> Self { Self { inner, label_evaluator, @@ -77,11 +71,7 @@ where query: &'a T, output: &'a mut OB, ) -> impl SendFuture> { - let params = SearchParams { - k_value: self.inner.k, - l_value: self.inner.l, - beam_width: self.inner.beam_width, - }; + let params = self.inner; async move { let mut accessor = strategy .search_accessor(&index.data_provider, context) @@ -90,7 +80,7 @@ where let start_ids = accessor.starting_points().await?; - let mut scratch = index.search_scratch(params.l_value, start_ids.len()); + let mut scratch = index.search_scratch(params.l_value().get(), start_ids.len()); let stats = multihop_search_internal( index.max_degree_with_slack(), @@ -109,7 +99,7 @@ where &mut accessor, query, &computer, - scratch.best.iter().take(params.l_value.into_usize()), + scratch.best.iter().take(params.l_value().get()), output, ) .send() @@ -182,7 +172,7 @@ impl HybridPredicate for NotInMutWithLabelCheck<'_, K> where K: VectorId { /// to find matching neighbors within two hops. pub(crate) async fn multihop_search_internal( max_degree_with_slack: usize, - search_params: &SearchParams, + search_params: &KnnSearch, accessor: &mut A, computer: &A::QueryComputer, scratch: &mut SearchScratch, @@ -195,7 +185,7 @@ where T: ?Sized, SR: SearchRecord + ?Sized, { - let beam_width = search_params.beam_width.unwrap_or(1); + let beam_width = search_params.beam_width().unwrap_or(1); // Helper to build the final stats from scratch state. let make_stats = |scratch: &SearchScratch| InternalSearchStats { diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 29f5f2440..cbae5e38f 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -6,13 +6,13 @@ //! Range-based search within a distance radius. 
use diskann_utils::future::{AssertSend, SendFuture}; +use thiserror::Error; use super::{Search, scratch::SearchScratch}; use crate::{ - ANNResult, + ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - RangeSearchParams, glue::{self, ExpandBeam, SearchExt, SearchPostProcess, SearchStrategy}, index::{DiskANNIndex, InternalSearchStats, SearchStats}, search::record::NoopSearchRecord, @@ -33,33 +33,51 @@ pub struct RangeSearchOutput { pub distances: Vec, } +/// Error type for [`RangeSearch`] parameter validation. +#[derive(Debug, Error)] +pub enum RangeSearchError { + #[error("beam width cannot be zero")] + BeamWidthZero, + #[error("l_value cannot be zero")] + LZero, + #[error("initial_search_slack must be between 0 and 1.0")] + StartingListSlackValueError, + #[error("range_search_slack must be greater than or equal to 1.0")] + RangeSearchSlackValueError, + #[error("inner_radius must be less than or equal to radius")] + InnerRadiusValueError, +} + +impl From for ANNError { + fn from(err: RangeSearchError) -> Self { + Self::new(ANNErrorKind::IndexError, err) + } +} + /// Parameters for range-based search. /// /// Finds all points within a specified distance radius from the query. #[derive(Debug, Clone, Copy)] pub struct RangeSearch { /// Maximum results to return (None = unlimited). - pub max_returned: Option, + max_returned: Option, /// Initial search list size. - pub starting_l: usize, + starting_l: usize, /// Optional beam width. - pub beam_width: Option, + beam_width: Option, /// Outer radius - points within this distance are candidates. - pub radius: f32, + radius: f32, /// Inner radius - points closer than this are excluded. - pub inner_radius: Option, + inner_radius: Option, /// Slack factor for initial search phase (0.0 to 1.0). - pub initial_slack: f32, + initial_slack: f32, /// Slack factor for range expansion (>= 1.0). - pub range_slack: f32, + range_slack: f32, } impl RangeSearch { /// Create range search with default slack values. 
- pub fn new( - starting_l: usize, - radius: f32, - ) -> Result { + pub fn new(starting_l: usize, radius: f32) -> Result { Self::with_options(None, starting_l, None, radius, None, 1.0, 1.0) } @@ -73,27 +91,25 @@ impl RangeSearch { inner_radius: Option, initial_slack: f32, range_slack: f32, - ) -> Result { - use super::super::RangeSearchParamsError; - - if let Some(bw) = beam_width { - if bw == 0 { - return Err(RangeSearchParamsError::BeamWidthZero); - } + ) -> Result { + if let Some(bw) = beam_width + && bw == 0 + { + return Err(RangeSearchError::BeamWidthZero); } if starting_l == 0 { - return Err(RangeSearchParamsError::LZero); + return Err(RangeSearchError::LZero); } if !(0.0..=1.0).contains(&initial_slack) { - return Err(RangeSearchParamsError::StartingListSlackValueError); + return Err(RangeSearchError::StartingListSlackValueError); } if range_slack < 1.0 { - return Err(RangeSearchParamsError::RangeSearchSlackValueError); + return Err(RangeSearchError::RangeSearchSlackValueError); } - if let Some(inner) = inner_radius { - if inner > radius { - return Err(RangeSearchParamsError::InnerRadiusValueError); - } + if let Some(inner) = inner_radius + && inner > radius + { + return Err(RangeSearchError::InnerRadiusValueError); } Ok(Self { @@ -107,30 +123,46 @@ impl RangeSearch { }) } - fn to_legacy_params(&self) -> RangeSearchParams { - RangeSearchParams { - max_returned: self.max_returned, - starting_l_value: self.starting_l, - beam_width: self.beam_width, - radius: self.radius, - inner_radius: self.inner_radius, - initial_search_slack: self.initial_slack, - range_search_slack: self.range_slack, - } + /// Returns the maximum number of results to return. 
+ #[inline] + pub fn max_returned(&self) -> Option { + self.max_returned } -} -impl From for RangeSearch { - fn from(params: RangeSearchParams) -> Self { - Self { - max_returned: params.max_returned, - starting_l: params.starting_l_value, - beam_width: params.beam_width, - radius: params.radius, - inner_radius: params.inner_radius, - initial_slack: params.initial_search_slack, - range_slack: params.range_search_slack, - } + /// Returns the initial search list size. + #[inline] + pub fn starting_l(&self) -> usize { + self.starting_l + } + + /// Returns the optional beam width. + #[inline] + pub fn beam_width(&self) -> Option { + self.beam_width + } + + /// Returns the outer radius. + #[inline] + pub fn radius(&self) -> f32 { + self.radius + } + + /// Returns the inner radius (points closer are excluded). + #[inline] + pub fn inner_radius(&self) -> Option { + self.inner_radius + } + + /// Returns the initial search slack factor. + #[inline] + pub fn initial_slack(&self) -> f32 { + self.initial_slack + } + + /// Returns the range search slack factor. 
+ #[inline] + pub fn range_slack(&self) -> f32 { + self.range_slack } } @@ -151,7 +183,7 @@ where query: &'a T, _output: &'a mut (), ) -> impl SendFuture> { - let search_params = self.to_legacy_params(); + let search_params = *self; async move { let mut accessor = strategy .search_accessor(&index.data_provider, context) @@ -159,11 +191,11 @@ where let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; - let mut scratch = index.search_scratch(search_params.starting_l_value, start_ids.len()); + let mut scratch = index.search_scratch(search_params.starting_l(), start_ids.len()); let initial_stats = index .search_internal( - search_params.beam_width, + search_params.beam_width(), &start_ids, &mut accessor, &computer, @@ -172,14 +204,14 @@ where ) .await?; - let mut in_range = Vec::with_capacity(search_params.starting_l_value.into_usize()); + let mut in_range = Vec::with_capacity(search_params.starting_l().into_usize()); for neighbor in scratch .best .iter() - .take(search_params.starting_l_value.into_usize()) + .take(search_params.starting_l().into_usize()) { - if neighbor.distance <= search_params.radius { + if neighbor.distance <= search_params.radius() { in_range.push(neighbor); } } @@ -192,8 +224,7 @@ where scratch.in_range = in_range; let stats = if scratch.in_range.len() - >= ((search_params.starting_l_value as f32) * search_params.initial_search_slack) - as usize + >= ((search_params.starting_l() as f32) * search_params.initial_slack()) as usize { // Move to range search let range_stats = range_search_internal( @@ -237,7 +268,7 @@ where .into_ann_result()?; // Filter by inner/outer radius - let inner_cutoff = if let Some(inner_radius) = search_params.inner_radius { + let inner_cutoff = if let Some(inner_radius) = search_params.inner_radius() { result_dists .iter() .position(|dist| *dist > inner_radius) @@ -248,7 +279,7 @@ where let outer_cutoff = result_dists .iter() - .position(|dist| *dist > 
search_params.radius) + .position(|dist| *dist > search_params.radius()) .unwrap_or(result_dists.len()); result_ids.truncate(outer_cutoff); @@ -283,7 +314,7 @@ where /// Called after the initial graph search has identified starting candidates. pub(crate) async fn range_search_internal( max_degree_with_slack: usize, - search_params: &RangeSearchParams, + search_params: &RangeSearch, accessor: &mut A, computer: &A::QueryComputer, scratch: &mut SearchScratch, @@ -293,7 +324,7 @@ where A: ExpandBeam + SearchExt, T: ?Sized, { - let beam_width = search_params.beam_width.unwrap_or(1); + let beam_width = search_params.beam_width().unwrap_or(1); for neighbor in &scratch.in_range { scratch.range_frontier.push_back(neighbor.id); @@ -301,7 +332,7 @@ where let mut neighbors = Vec::with_capacity(max_degree_with_slack); - let max_returned = search_params.max_returned.unwrap_or(usize::MAX); + let max_returned = search_params.max_returned().unwrap_or(usize::MAX); while !scratch.range_frontier.is_empty() { scratch.beam_nodes.clear(); @@ -327,7 +358,7 @@ where // The predicate ensures that the contents of `neighbors` are unique. for neighbor in neighbors.iter() { - if neighbor.distance <= search_params.radius * search_params.range_search_slack + if neighbor.distance <= search_params.radius() * search_params.range_slack() && scratch.in_range.len() < max_returned { scratch.in_range.push(*neighbor); diff --git a/diskann/src/graph/test/cases/grid.rs b/diskann/src/graph/test/cases/grid.rs index 790ee3c96..9af303260 100644 --- a/diskann/src/graph/test/cases/grid.rs +++ b/diskann/src/graph/test/cases/grid.rs @@ -3,13 +3,13 @@ * Licensed under the MIT license. 
*/ -use std::sync::Arc; +use std::{num::NonZeroUsize, sync::Arc}; use diskann_vector::distance::Metric; use crate::{ graph::{ - self, DiskANNIndex, GraphSearch, + self, DiskANNIndex, KnnSearch, test::{provider as test_provider, synthetic::Grid}, }, neighbor::Neighbor, @@ -126,10 +126,15 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { // are correct. let index = setup_grid_search(grid, size); - let mut params = GraphSearch::new(10, 10, Some(beam_width)).unwrap(); + let mut params = KnnSearch::new( + NonZeroUsize::new(10).unwrap(), + NonZeroUsize::new(10).unwrap(), + Some(beam_width), + ) + .unwrap(); let context = test_provider::Context::new(); - let mut neighbors = vec![Neighbor::::default(); params.k]; + let mut neighbors = vec![Neighbor::::default(); params.k_value().get()]; let graph::index::SearchStats { cmps, hops, @@ -147,7 +152,7 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { assert_eq!( result_count.into_usize(), - params.k, + params.k_value().get(), "grid search should be configured to always return the requested number of neighbors", ); From 33b2158a9faa7c58581d9aaddb17f1518729ba4d Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Fri, 13 Feb 2026 21:18:27 +0530 Subject: [PATCH 08/11] Add documentation to Search trait and KnnSearch dispatch methods --- diskann-providers/src/index/diskann_async.rs | 280 +++++++------------ diskann/src/graph/search/graph_search.rs | 261 ----------------- diskann/src/graph/search/knn_search.rs | 53 ++++ diskann/src/graph/search/mod.rs | 21 ++ 4 files changed, 182 insertions(+), 433 deletions(-) delete mode 100644 diskann/src/graph/search/graph_search.rs diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 66b8c8687..1a599a85f 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -222,54 +222,6 @@ pub(crate) mod tests { ////////////////////////// // Test helper 
functions // - ////////////////////////// - - use diskann::graph::index::SearchStats; - - /// Test helper: performs multihop search using the dispatch API. - async fn multihop_search( - index: &DiskANNIndex, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &KnnSearch, - output: &mut OB, - filter: &dyn graph::index::QueryLabelProvider, - ) -> diskann::ANNResult - where - DP: DataProvider, - T: Sync + ?Sized, - S: graph::glue::SearchStrategy, - O: Send, - OB: graph::search_output_buffer::SearchOutputBuffer + Send, - { - let mut multihop = graph::MultihopSearch::new(*search_params, filter); - index - .search(strategy, context, query, &mut multihop, output) - .await - } - - /// Test helper: performs range search using the dispatch API. - async fn range_search( - index: &DiskANNIndex, - strategy: &S, - context: &DP::Context, - query: &T, - search_params: &RangeSearch, - ) -> diskann::ANNResult<(SearchStats, Vec, Vec)> - where - DP: DataProvider, - T: Sync + ?Sized, - S: graph::glue::SearchStrategy, - O: Send + Default + Clone, - { - let mut range_search = *search_params; - let result = index - .search(strategy, context, query, &mut range_search, &mut ()) - .await?; - Ok((result.stats, result.ids, result.distances)) - } - ///////////////////////////////////////// // Tests from the original async index // ///////////////////////////////////////// @@ -460,17 +412,17 @@ pub(crate) mod tests { search_output_buffer::IdDistance::new(&mut ids, &mut distances); let search_params = KnnSearch::new_default(nz(parameters.search_k), nz(parameters.search_l)).unwrap(); - multihop_search( - index, - strategy, - &parameters.context, - query, - &search_params, - &mut result_output_buffer, - filter, - ) - .await - .unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, filter); + index + .search( + strategy, + &parameters.context, + query, + &mut multihop, + &mut result_output_buffer, + ) + .await + .unwrap(); // Loop over the requested number of results to
check, invoking the checker closure. // @@ -1505,17 +1457,17 @@ pub(crate) mod tests { let search_params = KnnSearch::new_default(nz(parameters.search_k), nz(parameters.search_l)).unwrap(); - let stats = multihop_search( - &index, - &FullPrecision, - &parameters.context, - query.as_slice(), - &search_params, - &mut result_output_buffer, - &filter, - ) - .await - .unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, &filter); + let stats = index + .search( + &FullPrecision, + &parameters.context, + query.as_slice(), + &mut multihop, + &mut result_output_buffer, + ) + .await + .unwrap(); // Retrieve callback metrics for detailed validation let callback_metrics = filter.metrics(); @@ -2338,76 +2290,60 @@ pub(crate) mod tests { let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); { // Full Precision Search. - let (_, ids, _) = range_search( - &*index, - &FullPrecision, - ctx, - query, - &RangeSearch::new(starting_l_value, radius).unwrap(), - ) - .await - .unwrap(); + let mut range_search = RangeSearch::new(starting_l_value, radius).unwrap(); + let result = index + .search(&FullPrecision, ctx, query, &mut range_search, &mut ()) + .await + .unwrap(); - assert_range_results_exactly_match(q, &gt, &ids, radius, None); + assert_range_results_exactly_match(q, &gt, &result.ids, radius, None); } { // Quantized Search - let (_, ids, _) = range_search( - &*index, - &Hybrid::new(None), - ctx, - query, - &RangeSearch::new(starting_l_value, radius).unwrap(), - ) - .await - .unwrap(); + let mut range_search = RangeSearch::new(starting_l_value, radius).unwrap(); + let result = index + .search(&Hybrid::new(None), ctx, query, &mut range_search, &mut ()) + .await + .unwrap(); - assert_range_results_exactly_match(q, &gt, &ids, radius, None); + assert_range_results_exactly_match(q, &gt, &result.ids, radius, None); } { // Test with an inner radius assert!(inner_radius <= radius); - let (_, ids, _) = range_search( - &*index, - &FullPrecision, - ctx, - query, -
&RangeSearch::with_options( - None, - starting_l_value, - None, - radius, - Some(inner_radius), - 1.0, - 1.0, - ) - .unwrap(), + let mut range_search = RangeSearch::with_options( + None, + starting_l_value, + None, + radius, + Some(inner_radius), + 1.0, + 1.0, ) - .await .unwrap(); + let result = index + .search(&FullPrecision, ctx, query, &mut range_search, &mut ()) + .await + .unwrap(); - assert_range_results_exactly_match(q, &gt, &ids, radius, Some(inner_radius)); + assert_range_results_exactly_match(q, &gt, &result.ids, radius, Some(inner_radius)); } { // Test with a lower initial beam to trigger more two-round searches // We don't expect results to exactly match here - let (_, ids, _) = range_search( - &*index, - &FullPrecision, - ctx, - query, - &RangeSearch::new(lower_l_value, radius).unwrap(), - ) - .await - .unwrap(); + let mut range_search = RangeSearch::new(lower_l_value, radius).unwrap(); + let result = index + .search(&FullPrecision, ctx, query, &mut range_search, &mut ()) + .await + .unwrap(); // check that ids don't have duplicates let mut ids_set = std::collections::HashSet::new(); - for id in &ids { + for id in &result.ids { assert!(ids_set.insert(*id)); } } @@ -4165,17 +4101,17 @@ pub(crate) mod tests { let filter = RejectAllFilter::only([0_u32]); let search_params = KnnSearch::new_default(nz(10), nz(20)).unwrap(); - let stats = multihop_search( - &index, - &FullPrecision, - &DefaultContext, - query.as_slice(), - &search_params, - &mut result_output_buffer, - &filter, - ) - .await - .unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, &filter); + let stats = index + .search( + &FullPrecision, + &DefaultContext, + query.as_slice(), + &mut multihop, + &mut result_output_buffer, + ) + .await + .unwrap(); // When all candidates are rejected via on_visit, result_count should be 0 // because rejected candidates are not added to the search frontier @@ -4228,17 +4164,17 @@ pub(crate) mod tests { let filter =
TerminatingFilter::new(target); let search_params = KnnSearch::new_default(nz(10), nz(40)).unwrap(); - let stats = multihop_search( - &index, - &FullPrecision, - &DefaultContext, - query.as_slice(), - &search_params, - &mut result_output_buffer, - &filter, - ) - .await - .unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, &filter); + let stats = index + .search( + &FullPrecision, + &DefaultContext, + query.as_slice(), + &mut multihop, + &mut result_output_buffer, + ) + .await + .unwrap(); let hits = filter.hits(); @@ -4293,17 +4229,17 @@ pub(crate) mod tests { search_output_buffer::IdDistance::new(&mut baseline_ids, &mut baseline_distances); let search_params = KnnSearch::new_default(nz(10), nz(20)).unwrap(); - let baseline_stats = multihop_search( - &index, - &FullPrecision, - &DefaultContext, - query.as_slice(), - &search_params, - &mut baseline_buffer, - &EvenFilter, // Just filter to even IDs - ) - .await - .unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, &EvenFilter); + let baseline_stats = index + .search( + &FullPrecision, + &DefaultContext, + query.as_slice(), + &mut multihop, + &mut baseline_buffer, + ) + .await + .unwrap(); // Now run with a filter that boosts a specific far-away point let boosted_point = (num_points - 2) as u32; // A point far from origin @@ -4315,17 +4251,17 @@ pub(crate) mod tests { search_output_buffer::IdDistance::new(&mut adjusted_ids, &mut adjusted_distances); let search_params = KnnSearch::new_default(nz(10), nz(20)).unwrap(); - let adjusted_stats = multihop_search( - &index, - &FullPrecision, - &DefaultContext, - query.as_slice(), - &search_params, - &mut adjusted_buffer, - &filter, - ) - .await - .unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, &filter); + let adjusted_stats = index + .search( + &FullPrecision, + &DefaultContext, + query.as_slice(), + &mut multihop, + &mut adjusted_buffer, + ) + .await + .unwrap(); // Both searches should return results 
assert!( @@ -4442,17 +4378,17 @@ pub(crate) mod tests { let filter = TerminateAfterN::new(max_visits); let search_params = KnnSearch::new_default(nz(10), nz(100)).unwrap(); // Large L to ensure we'd visit more without termination - let _stats = multihop_search( - &index, - &FullPrecision, - &DefaultContext, - query.as_slice(), - &search_params, - &mut result_output_buffer, - &filter, - ) - .await - .unwrap(); + let mut multihop = graph::MultihopSearch::new(search_params, &filter); + let _stats = index + .search( + &FullPrecision, + &DefaultContext, + query.as_slice(), + &mut multihop, + &mut result_output_buffer, + ) + .await + .unwrap(); // The search should have stopped after max_visits assert!( diff --git a/diskann/src/graph/search/graph_search.rs b/diskann/src/graph/search/graph_search.rs deleted file mode 100644 index 354a85d14..000000000 --- a/diskann/src/graph/search/graph_search.rs +++ /dev/null @@ -1,261 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -//! Standard graph-based ANN search. - -use std::fmt::Debug; - -use diskann_utils::future::{AssertSend, SendFuture}; - -use super::Search; -use crate::{ - ANNResult, - error::IntoANNResult, - graph::{ - glue::{SearchExt, SearchPostProcess, SearchStrategy}, - index::{DiskANNIndex, SearchStats}, - search::record::NoopSearchRecord, - search_output_buffer::SearchOutputBuffer, - }, - provider::{BuildQueryComputer, DataProvider}, - utils::IntoUsize, -}; - -/// Parameters for standard graph-based ANN search. -/// -/// This is the primary search mode, using the Vamana graph structure for efficient -/// approximate nearest neighbor traversal. -/// -/// This type is also exported as `SearchParams` for backwards compatibility. -#[derive(Debug, Clone, Copy)] -pub struct GraphSearch { - /// Number of results to return (k in k-NN). - k_value: usize, - /// Search list size - controls accuracy vs speed tradeoff. 
- l_value: usize, - /// Optional beam width for parallel graph exploration. - beam_width: Option, -} - -impl GraphSearch { - /// Create new graph search parameters. - /// - /// # Errors - /// - /// Returns an error if `l_value < k_value` or if any value is zero. - pub fn new( - k_value: usize, - l_value: usize, - beam_width: Option, - ) -> Result { - use super::super::SearchParamsError; - - if k_value > l_value { - return Err(SearchParamsError::LLessThanK { l_value, k_value }); - } - if let Some(bw) = beam_width { - if bw == 0 { - return Err(SearchParamsError::BeamWidthZero); - } - } - if k_value == 0 { - return Err(SearchParamsError::KZero); - } - if l_value == 0 { - return Err(SearchParamsError::LZero); - } - - Ok(Self { - k_value, - l_value, - beam_width, - }) - } - - /// Create parameters with default beam width. - pub fn new_default(k_value: usize, l_value: usize) -> Result { - Self::new(k_value, l_value, None) - } - - /// Returns the number of results to return (k in k-NN). - #[inline] - pub fn k_value(&self) -> usize { - self.k_value - } - - /// Returns the search list size. - #[inline] - pub fn l_value(&self) -> usize { - self.l_value - } - - /// Returns the optional beam width for parallel graph exploration. 
- #[inline] - pub fn beam_width(&self) -> Option { - self.beam_width - } -} - -impl Search for GraphSearch -where - DP: DataProvider, - T: Sync + ?Sized, - S: SearchStrategy, - O: Send, - OB: SearchOutputBuffer + Send + ?Sized, -{ - type Output = SearchStats; - - fn dispatch<'a>( - &'a mut self, - index: &'a DiskANNIndex, - strategy: &'a S, - context: &'a DP::Context, - query: &'a T, - output: &'a mut OB, - ) -> impl SendFuture> { - async move { - let mut accessor = strategy - .search_accessor(&index.data_provider, context) - .into_ann_result()?; - - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; - - let mut scratch = index.search_scratch(self.l_value, start_ids.len()); - - let stats = index - .search_internal( - self.beam_width, - &start_ids, - &mut accessor, - &computer, - &mut scratch, - &mut NoopSearchRecord::new(), - ) - .await?; - - let result_count = strategy - .post_processor() - .post_process( - &mut accessor, - query, - &computer, - scratch.best.iter().take(self.l_value.into_usize()), - output, - ) - .send() - .await - .into_ann_result()?; - - Ok(stats.finish(result_count as u32)) - } - } -} - -/////////////////////////// -// Recorded Graph Search // -/////////////////////////// - -/// Graph search with traversal path recording. -/// -/// Records the path taken during search for debugging or analysis. -#[derive(Debug)] -pub struct RecordedGraphSearch<'r, SR: ?Sized> { - /// Base graph search parameters. - pub inner: GraphSearch, - /// The recorder to capture search path. - pub recorder: &'r mut SR, -} - -impl<'r, SR: ?Sized> RecordedGraphSearch<'r, SR> { - /// Create new recorded search parameters. 
- pub fn new(inner: GraphSearch, recorder: &'r mut SR) -> Self { - Self { inner, recorder } - } -} - -impl<'r, DP, S, T, O, OB, SR> Search for RecordedGraphSearch<'r, SR> -where - DP: DataProvider, - T: Sync + ?Sized, - S: SearchStrategy, - O: Send, - OB: SearchOutputBuffer + Send + ?Sized, - SR: super::record::SearchRecord + ?Sized, -{ - type Output = SearchStats; - - fn dispatch<'a>( - &'a mut self, - index: &'a DiskANNIndex, - strategy: &'a S, - context: &'a DP::Context, - query: &'a T, - output: &'a mut OB, - ) -> impl SendFuture> { - async move { - let mut accessor = strategy - .search_accessor(&index.data_provider, context) - .into_ann_result()?; - - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; - - let mut scratch = index.search_scratch(self.inner.l_value, start_ids.len()); - - let stats = index - .search_internal( - self.inner.beam_width, - &start_ids, - &mut accessor, - &computer, - &mut scratch, - self.recorder, - ) - .await?; - - let result_count = strategy - .post_processor() - .post_process( - &mut accessor, - query, - &computer, - scratch.best.iter().take(self.inner.l_value.into_usize()), - output, - ) - .send() - .await - .into_ann_result()?; - - Ok(stats.finish(result_count as u32)) - } - } -} - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_graph_search_validation() { - // Valid - assert!(GraphSearch::new(10, 100, None).is_ok()); - assert!(GraphSearch::new(10, 100, Some(4)).is_ok()); - assert!(GraphSearch::new(10, 10, None).is_ok()); // k == l is valid - - // Invalid: l < k - assert!(GraphSearch::new(100, 10, None).is_err()); - - // Invalid: zero values - assert!(GraphSearch::new(0, 100, None).is_err()); - assert!(GraphSearch::new(10, 0, None).is_err()); - assert!(GraphSearch::new(10, 100, Some(0)).is_err()); - } -} diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index 
1c790a2eb..b46bf07db 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -110,6 +110,35 @@ impl KnnSearch { } } +/// Standard k-NN graph-based search implementation. +/// +/// This is the primary search type for approximate nearest neighbor queries. It performs +/// a greedy beam search over the graph, maintaining a priority queue of the best candidates +/// found so far. The search explores neighbors of promising candidates until convergence. +/// +/// # Algorithm +/// +/// 1. Initialize with starting points +/// 2. Compute distances from query to starting points +/// 3. Greedily expand the most promising unexplored candidate +/// 4. Add the candidate's neighbors to the frontier +/// 5. Repeat until no unexplored candidates remain within the search list +/// 6. Return the top-k results from the best candidates found +/// +/// # Parameters +/// +/// - `k_value`: Number of nearest neighbors to return +/// - `l_value`: Search list size (larger values improve recall at cost of latency) +/// - `beam_width`: Optional parallel exploration width +/// +/// # Example +/// +/// ```ignore +/// use diskann::graph::{search::KnnSearch, Search}; +/// +/// let mut params = KnnSearch::new(10, 100, None)?; +/// let stats = index.search(&strategy, &context, &query, &mut params, &mut output).await?; +/// ``` impl Search for KnnSearch where DP: DataProvider, @@ -120,6 +149,30 @@ where { type Output = SearchStats; + /// Execute the k-NN search on the given index. + /// + /// This method executes a search using the provided `strategy` to access and process elements. + /// It computes the similarity between the query vector and the elements in the index, traversing + /// the graph towards the nearest neighbors according to the search parameters. + /// + /// # Arguments + /// + /// * `index` - The DiskANN index to search. + /// * `strategy` - The search strategy to use for accessing and processing elements. 
+ /// * `context` - The context to pass through to providers. + /// * `query` - The query vector for which nearest neighbors are sought. + /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. + /// + /// # Returns + /// + /// Returns [`SearchStats`] containing: + /// - The number of distance computations performed. + /// - The number of hops (graph traversal steps). + /// - Timing information for the search operation. + /// + /// # Errors + /// + /// Returns an error if there is a failure accessing elements or computing distances. fn dispatch<'a>( &'a mut self, index: &'a DiskANNIndex, diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 65dd182ae..7847cfcbe 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -48,6 +48,27 @@ where type Output; /// Execute the search operation with full search logic. + /// + /// This method executes a search using the provided `strategy` to access and process elements. + /// It computes the similarity between the query vector and the elements in the index, + /// finding nearest neighbors according to the search parameters. + /// + /// # Arguments + /// + /// * `index` - The DiskANN index to search. + /// * `strategy` - The search strategy to use for accessing and processing elements. + /// * `context` - The context to pass through to providers. + /// * `query` - The query vector for which nearest neighbors are sought. + /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. + /// + /// # Returns + /// + /// Returns `Self::Output` which varies by search type (e.g., [`SearchStats`](super::index::SearchStats) + /// for k-NN, [`RangeSearchOutput`] for range search). + /// + /// # Errors + /// + /// Returns an error if there is a failure accessing elements or computing distances. 
fn dispatch<'a>( &'a mut self, index: &'a DiskANNIndex, From 5e81b353c2ba9c3aad37a0b6d97fe0f4d24c7395 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Fri, 13 Feb 2026 21:43:40 +0530 Subject: [PATCH 09/11] Remove debug_search, use RecordedKnnSearch + search() directly; Remove diverse_search_experimental, use DiverseSearch + search() directly --- .../src/search/provider/disk_provider.rs | 57 +++++++++---------- diskann-providers/src/index/diskann_async.rs | 15 ++--- diskann/src/graph/index.rs | 34 ----------- 3 files changed, 32 insertions(+), 74 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index f6b18332a..712d2bdfa 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -1633,15 +1633,17 @@ mod disk_provider_tests { ); let strategy = search_engine.search_strategy(&query_vector, &|_| true); let mut search_record = VisitedSearchRecord::new(0); + let search_params = KnnSearch::new(nz(10), nz(10), Some(4)).unwrap(); + let mut recorded_search = + diskann::graph::search::RecordedKnnSearch::new(search_params, &mut search_record); search_engine .runtime - .block_on(search_engine.index.debug_search( + .block_on(search_engine.index.search( &strategy, &DefaultContext, - &query_vector, - &KnnSearch::new(nz(10), nz(10), Some(4)).unwrap(), + query_vector.as_slice(), + &mut recorded_search, &mut result_output_buffer, - &mut search_record, )) .unwrap(); @@ -1752,7 +1754,6 @@ mod disk_provider_tests { &mut associated_data, ); let strategy = search_engine.search_strategy(&query_vector, &|_| true); - let mut search_record = VisitedSearchRecord::new(0); // Create diverse search parameters with attribute provider let diverse_params = DiverseSearchParams::new( @@ -1763,29 +1764,22 @@ mod disk_provider_tests { let search_params = KnnSearch::new(nz(10), nz(20), None).unwrap(); - search_engine + let mut diverse_search = 
diskann::graph::DiverseSearch::new(search_params, diverse_params); + let stats = search_engine .runtime - .block_on(search_engine.index.diverse_search_experimental( + .block_on(search_engine.index.search( &strategy, &DefaultContext, - &query_vector, - &search_params, - &diverse_params, + query_vector.as_slice(), + &mut diverse_search, &mut result_output_buffer, - &mut search_record, )) .unwrap(); - let ids = search_record - .visited - .iter() - .map(|n| n.id) - .collect::>(); - - // Verify that search was performed and visited some nodes + // Verify that search was performed and returned some results assert!( - !ids.is_empty(), - "Expected to visit some nodes during diversity search" + stats.result_count > 0, + "Expected to get some results during diversity search" ); let return_list_size = 10; @@ -1797,7 +1791,7 @@ mod disk_provider_tests { attribute_provider.clone(), ); - // Test diverse search using the experimental API + // Test diverse search using the search API let mut indices2 = vec![0u32; return_list_size as usize]; let mut distances2 = vec![0f32; return_list_size as usize]; let mut associated_data2 = vec![(); return_list_size as usize]; @@ -1807,7 +1801,6 @@ mod disk_provider_tests { &mut associated_data2, ); let strategy2 = search_engine.search_strategy(&query_vector, &|_| true); - let mut search_record2 = VisitedSearchRecord::new(0); let search_params2 = KnnSearch::new( nz(return_list_size as usize), nz(search_list_size as usize), @@ -1815,16 +1808,16 @@ mod disk_provider_tests { ) .unwrap(); + let mut diverse_search2 = + diskann::graph::DiverseSearch::new(search_params2, diverse_params); let stats = search_engine .runtime - .block_on(search_engine.index.diverse_search_experimental( + .block_on(search_engine.index.search( &strategy2, &DefaultContext, - &query_vector, - &search_params2, - &diverse_params, + query_vector.as_slice(), + &mut diverse_search2, &mut result_output_buffer2, - &mut search_record2, )) .unwrap(); @@ -2099,15 +2092,17 @@ mod 
disk_provider_tests { let strategy = search_engine.search_strategy(&query_vector, &|_| true); let mut search_record = VisitedSearchRecord::new(0); + let search_params = KnnSearch::new(nz(10), nz(10), Some(4)).unwrap(); + let mut recorded_search = + diskann::graph::search::RecordedKnnSearch::new(search_params, &mut search_record); search_engine .runtime - .block_on(search_engine.index.debug_search( + .block_on(search_engine.index.search( &strategy, &DefaultContext, - &query_vector, - &KnnSearch::new(nz(10), nz(10), Some(4)).unwrap(), + query_vector.as_slice(), + &mut recorded_search, &mut result_output_buffer, - &mut search_record, )) .unwrap(); let visited_ids = search_record diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 1a599a85f..ef648045d 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -3880,24 +3880,21 @@ pub(crate) mod tests { ); let search_params = diskann::graph::KnnSearch::new( - return_list_size, - search_list_size, + nz(return_list_size), + nz(search_list_size), None, // beam_width ) .unwrap(); - use diskann::graph::search::record::NoopSearchRecord; - let mut search_record = NoopSearchRecord::new(); + let mut diverse_search = diskann::graph::DiverseSearch::new(search_params, diverse_params); let result = index - .diverse_search_experimental( + .search( &FullPrecision, &DefaultContext, - &query, - &search_params, - &diverse_params, + query.as_slice(), + &mut diverse_search, &mut result_output_buffer, - &mut search_record, ) .await; diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index e444cfac1..b848eaf32 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -31,7 +31,6 @@ use super::{ }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ - Search, record::{NoopSearchRecord, SearchRecord, VisitedSearchRecord}, scratch::{self, PriorityQueueConfiguration, SearchScratch, 
SearchScratchParams}, }, @@ -2159,39 +2158,6 @@ where search_params.dispatch(self, strategy, context, query, output) } - /// Perform a graph search while recording the traversal path. - /// - /// **Note:** This method is intended for debugging and analysis only. - /// For production searches, use [`Self::search`] with [`super::search::KnnSearch`]. - /// - /// Records which nodes were visited during the search traversal, useful for - /// understanding search behavior or diagnosing issues. - #[allow(clippy::too_many_arguments)] - pub fn debug_search<'a, S, T, O, OB, SR>( - &'a self, - strategy: &'a S, - context: &'a DP::Context, - query: &'a T, - search_params: &'a KnnSearch, - output: &'a mut OB, - search_record: &'a mut SR, - ) -> impl SendFuture> + 'a - where - T: Sync + ?Sized, - S: SearchStrategy, - O: Send + 'a, - OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, - SR: SearchRecord, - { - async move { - let mut recorded_search = - super::search::RecordedKnnSearch::new(*search_params, search_record); - recorded_search - .dispatch(self, strategy, context, query, output) - .await - } - } - /// Performs a brute-force flat search over the points matching a provided filter function. 
/// /// This method executes a linear scan through all points in the index, applying the provided From 168bfe5be16ce7decc146e320cf56404af7b737f Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Sat, 14 Feb 2026 15:42:05 +0530 Subject: [PATCH 10/11] Implement PR review feedback: search interface improvements - Rename dispatch() to search() in Search trait and all implementations - Reorder DiskANNIndex::search params: search_params now first argument - KnnSearch::new takes usize instead of NonZeroUsize (with KZero/LZero errors) - Change beam_width from Option to Option - Add #[track_caller] to From and From - Make DiverseSearch fields private with accessor methods - Remove unused nz() helper functions - Update all callers across workspace --- .../src/search/graph/knn.rs | 10 +- .../src/search/graph/multihop.rs | 10 +- .../src/search/graph/range.rs | 2 +- .../src/backend/index/search/knn.rs | 4 +- .../src/search/provider/disk_provider.rs | 36 ++--- diskann-providers/src/index/diskann_async.rs | 82 +++++------ diskann-providers/src/index/wrapped_async.rs | 2 +- diskann/src/graph/index.rs | 10 +- diskann/src/graph/search/diverse_search.rs | 34 +++-- diskann/src/graph/search/knn_search.rs | 136 ++++++++---------- diskann/src/graph/search/mod.rs | 22 +-- diskann/src/graph/search/multihop_search.rs | 16 +-- diskann/src/graph/search/range_search.rs | 15 +- diskann/src/graph/test/cases/grid.rs | 11 +- 14 files changed, 183 insertions(+), 207 deletions(-) diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 71031a944..427169044 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -117,10 +117,10 @@ where let stats = self .index .search( + &mut knn_search, self.strategy.get(index)?, &context, self.queries.row(index), - &mut knn_search, buffer, ) .await?; @@ -313,7 +313,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = 
search::search( knn.clone(), - graph::KnnSearch::new(nearest_neighbors, NonZeroUsize::new(10).unwrap(), None).unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -337,13 +337,11 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::KnnSearch::new(nearest_neighbors, NonZeroUsize::new(10).unwrap(), None) - .unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(), setup.clone(), ), search::Run::new( - graph::KnnSearch::new(nearest_neighbors, NonZeroUsize::new(15).unwrap(), None) - .unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 15, None).unwrap(), setup.clone(), ), ]; diff --git a/diskann-benchmark-core/src/search/graph/multihop.rs b/diskann-benchmark-core/src/search/graph/multihop.rs index 92741311b..245f1717b 100644 --- a/diskann-benchmark-core/src/search/graph/multihop.rs +++ b/diskann-benchmark-core/src/search/graph/multihop.rs @@ -115,10 +115,10 @@ where let stats = self .index .search( + &mut multihop_search, self.strategy.get(index)?, &context, self.queries.row(index), - &mut multihop_search, buffer, ) .await?; @@ -181,7 +181,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( multihop.clone(), - graph::KnnSearch::new(nearest_neighbors, NonZeroUsize::new(10).unwrap(), None).unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -209,13 +209,11 @@ mod tests { // Try the aggregated strategy. 
let parameters = [ search::Run::new( - graph::KnnSearch::new(nearest_neighbors, NonZeroUsize::new(10).unwrap(), None) - .unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(), setup.clone(), ), search::Run::new( - graph::KnnSearch::new(nearest_neighbors, NonZeroUsize::new(15).unwrap(), None) - .unwrap(), + graph::KnnSearch::new(nearest_neighbors.get(), 15, None).unwrap(), setup.clone(), ), ]; diff --git a/diskann-benchmark-core/src/search/graph/range.rs b/diskann-benchmark-core/src/search/graph/range.rs index f82064f6e..e66fef5ec 100644 --- a/diskann-benchmark-core/src/search/graph/range.rs +++ b/diskann-benchmark-core/src/search/graph/range.rs @@ -108,10 +108,10 @@ where let result = self .index .search( + &mut range_search, self.strategy.get(index)?, &context, self.queries.row(index), - &mut range_search, &mut (), ) .await?; diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 0f32b0b2c..2ad3a8dd5 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -49,8 +49,8 @@ pub(crate) fn run( .search_l .iter() .map(|search_l| { - let k = NonZeroUsize::new(run.search_n).expect("search_n must be non-zero"); - let l = NonZeroUsize::new(*search_l).expect("search_l must be non-zero"); + let k = run.search_n; + let l = *search_l; let search_params = diskann::graph::KnnSearch::new(k, l, None).unwrap(); core_search::Run::new(search_params, setup.clone()) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 712d2bdfa..b53c0eaad 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -983,9 +983,8 @@ where let strategy = self.search_strategy(query, vector_filter); let timer = Instant::now(); - let k = NonZeroUsize::new(k_value).expect("k_value must be non-zero"); - let l = 
NonZeroUsize::new(search_list_size as usize) - .expect("search_list_size must be non-zero"); + let k = k_value; + let l = search_list_size as usize; let stats = if is_flat_search { self.runtime.block_on(self.index.flat_search( &strategy, @@ -998,10 +997,10 @@ where } else { let mut knn_search = KnnSearch::new(k, l, beam_width)?; self.runtime.block_on(self.index.search( + &mut knn_search, &strategy, &DefaultContext, strategy.query, - &mut knn_search, &mut result_output_buffer, ))? }; @@ -1074,11 +1073,6 @@ mod disk_provider_tests { utils::{QueryStatistics, VirtualAlignedReaderFactory}, }; - /// Helper to create NonZeroUsize from usize (for tests only). - fn nz(v: usize) -> NonZeroUsize { - NonZeroUsize::new(v).expect("value must be non-zero") - } - const TEST_INDEX_PREFIX_128DIM: &str = "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search"; const TEST_INDEX_128DIM: &str = @@ -1541,14 +1535,14 @@ mod disk_provider_tests { ); // Test error case: l < k - let res = KnnSearch::new_default(nz(20), nz(10)); + let res = KnnSearch::new_default(20, 10); assert!(res.is_err()); assert_eq!( >::into(res.unwrap_err()).kind(), ANNErrorKind::IndexError ); // Test error case: beam_width = 0 - let res = KnnSearch::new(nz(10), nz(10), Some(0)); + let res = KnnSearch::new(10, 10, Some(0)); assert!(res.is_err()); let search_engine = @@ -1633,7 +1627,7 @@ mod disk_provider_tests { ); let strategy = search_engine.search_strategy(&query_vector, &|_| true); let mut search_record = VisitedSearchRecord::new(0); - let search_params = KnnSearch::new(nz(10), nz(10), Some(4)).unwrap(); + let search_params = KnnSearch::new(10, 10, Some(4)).unwrap(); let mut recorded_search = diskann::graph::search::RecordedKnnSearch::new(search_params, &mut search_record); search_engine @@ -1762,16 +1756,16 @@ mod disk_provider_tests { attribute_provider.clone(), ); - let search_params = KnnSearch::new(nz(10), nz(20), None).unwrap(); + let search_params = KnnSearch::new(10, 20, None).unwrap(); let mut 
diverse_search = diskann::graph::DiverseSearch::new(search_params, diverse_params); let stats = search_engine .runtime .block_on(search_engine.index.search( + &mut diverse_search, &strategy, &DefaultContext, query_vector.as_slice(), - &mut diverse_search, &mut result_output_buffer, )) .unwrap(); @@ -1801,22 +1795,18 @@ mod disk_provider_tests { &mut associated_data2, ); let strategy2 = search_engine.search_strategy(&query_vector, &|_| true); - let search_params2 = KnnSearch::new( - nz(return_list_size as usize), - nz(search_list_size as usize), - None, - ) - .unwrap(); + let search_params2 = + KnnSearch::new(return_list_size as usize, search_list_size as usize, None).unwrap(); let mut diverse_search2 = diskann::graph::DiverseSearch::new(search_params2, diverse_params); let stats = search_engine .runtime .block_on(search_engine.index.search( + &mut diverse_search2, &strategy2, &DefaultContext, query_vector.as_slice(), - &mut diverse_search2, &mut result_output_buffer2, )) .unwrap(); @@ -2092,16 +2082,16 @@ mod disk_provider_tests { let strategy = search_engine.search_strategy(&query_vector, &|_| true); let mut search_record = VisitedSearchRecord::new(0); - let search_params = KnnSearch::new(nz(10), nz(10), Some(4)).unwrap(); + let search_params = KnnSearch::new(10, 10, Some(4)).unwrap(); let mut recorded_search = diskann::graph::search::RecordedKnnSearch::new(search_params, &mut search_record); search_engine .runtime .block_on(search_engine.index.search( + &mut recorded_search, &strategy, &DefaultContext, query_vector.as_slice(), - &mut recorded_search, &mut result_output_buffer, )) .unwrap(); diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index ef648045d..6ecc2ba36 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -215,11 +215,6 @@ pub(crate) mod tests { // Callbacks for use with `simplified_builder`. 
fn no_modify(_: &mut diskann::graph::config::Builder) {} - /// Helper to create NonZeroUsize from usize (for tests only). - fn nz(v: usize) -> NonZeroUsize { - NonZeroUsize::new(v).expect("value must be non-zero") - } - ////////////////////////// // Test helper functions // ///////////////////////////////////////// @@ -362,7 +357,7 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); let mut graph_search = - graph::KnnSearch::new_default(nz(parameters.search_k), nz(parameters.search_l)) + graph::KnnSearch::new_default(parameters.search_k, parameters.search_l) .unwrap(); index .search( @@ -411,7 +406,7 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); let search_params = - KnnSearch::new_default(nz(parameters.search_k), nz(parameters.search_l)).unwrap(); + KnnSearch::new_default(parameters.search_k, parameters.search_l).unwrap(); let mut multihop = graph::MultihopSearch::new(search_params, filter); index .search( @@ -1456,7 +1451,7 @@ pub(crate) mod tests { let filter = CallbackFilter::new(blocked, adjusted, 0.5); let search_params = - KnnSearch::new_default(nz(parameters.search_k), nz(parameters.search_l)).unwrap(); + KnnSearch::new_default(parameters.search_k, parameters.search_l).unwrap(); let mut multihop = graph::MultihopSearch::new(search_params, &filter); let stats = index .search( @@ -2204,15 +2199,14 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut graph_search = - graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Full Precision Search. 
index .search( + &mut graph_search, &FullPrecision, ctx, query, - &mut graph_search, &mut result_output_buffer, ) .await @@ -2223,15 +2217,14 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut graph_search = - graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Quantized Search index .search( + &mut graph_search, &Hybrid::new(None), ctx, query, - &mut graph_search, &mut result_output_buffer, ) .await @@ -2292,7 +2285,7 @@ pub(crate) mod tests { // Full Precision Search. let mut range_search = RangeSearch::new(starting_l_value, radius).unwrap(); let result = index - .search(&FullPrecision, ctx, query, &mut range_search, &mut ()) + .search(&mut range_search, &FullPrecision, ctx, query, &mut ()) .await .unwrap(); @@ -2303,7 +2296,7 @@ pub(crate) mod tests { // Quantized Search let mut range_search = RangeSearch::new(starting_l_value, radius).unwrap(); let result = index - .search(&Hybrid::new(None), ctx, query, &mut range_search, &mut ()) + .search(&mut range_search, &Hybrid::new(None), ctx, query, &mut ()) .await .unwrap(); @@ -2325,7 +2318,7 @@ pub(crate) mod tests { ) .unwrap(); let result = index - .search(&FullPrecision, ctx, query, &mut range_search, &mut ()) + .search(&mut range_search, &FullPrecision, ctx, query, &mut ()) .await .unwrap(); @@ -2337,7 +2330,7 @@ pub(crate) mod tests { // We don't expect results to exactly match here let mut range_search = RangeSearch::new(lower_l_value, radius).unwrap(); let result = index - .search(&FullPrecision, ctx, query, &mut range_search, &mut ()) + .search(&mut range_search, &FullPrecision, ctx, query, &mut ()) .await .unwrap(); @@ -2459,14 +2452,14 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); let mut graph_search = - graph::KnnSearch::new_default(nz(top_k), 
nz(search_l)).unwrap(); + graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Full Precision Search. index .search( + &mut graph_search, &FullPrecision, ctx, query, - &mut graph_search, &mut result_output_buffer, ) .await @@ -2478,14 +2471,14 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); let mut graph_search = - graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); + graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Quantized Search index .search( + &mut graph_search, &Quantized, ctx, query, - &mut graph_search, &mut result_output_buffer, ) .await @@ -2565,15 +2558,14 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut graph_search = - graph::KnnSearch::new_default(nz(top_k), nz(top_k)).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(top_k, top_k).unwrap(); // Quantized Search index .search( + &mut graph_search, &Quantized, ctx, query, - &mut graph_search, &mut result_output_buffer, ) .await @@ -2680,9 +2672,9 @@ pub(crate) mod tests { // Full Precision Search. 
let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); index - .search(&FullPrecision, ctx, query, &mut graph_search, &mut output) + .search(&mut graph_search, &FullPrecision, ctx, query, &mut output) .await .unwrap(); assert_top_k_exactly_match(q, >, &ids, &distances, top_k); @@ -2692,10 +2684,10 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); - let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); index - .search(&strategy, ctx, query, &mut graph_search, &mut output) + .search(&mut graph_search, &strategy, ctx, query, &mut output) .await .unwrap(); assert_top_k_exactly_match(q, >, &ids, &distances, top_k); @@ -2795,10 +2787,10 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); - let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); index - .search(&strategy, ctx, query, &mut graph_search, &mut output) + .search(&mut graph_search, &strategy, ctx, query, &mut output) .await .unwrap(); @@ -2882,14 +2874,14 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Full Precision Search. 
index .search( + &mut graph_search, &Quantized, ctx, query, - &mut graph_search, &mut result_output_buffer, ) .await @@ -3463,14 +3455,14 @@ pub(crate) mod tests { let gt = groundtruth(queries.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Full Precision Search. index .search( + &mut graph_search, &Hybrid::new(max_fp_vecs_per_prune), ctx, query, - &mut graph_search, &mut result_output_buffer, ) .await @@ -3610,14 +3602,14 @@ pub(crate) mod tests { let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let mut graph_search = graph::KnnSearch::new_default(nz(top_k), nz(search_l)).unwrap(); + let mut graph_search = graph::KnnSearch::new_default(top_k, search_l).unwrap(); // Full Precision Search. 
index .search( + &mut graph_search, &FullPrecision, ctx, query, - &mut graph_search, &mut result_output_buffer, ) .await @@ -3880,8 +3872,8 @@ pub(crate) mod tests { ); let search_params = diskann::graph::KnnSearch::new( - nz(return_list_size), - nz(search_list_size), + return_list_size, + search_list_size, None, // beam_width ) .unwrap(); @@ -3890,10 +3882,10 @@ pub(crate) mod tests { let result = index .search( + &mut diverse_search, &FullPrecision, &DefaultContext, query.as_slice(), - &mut diverse_search, &mut result_output_buffer, ) .await; @@ -4097,7 +4089,7 @@ pub(crate) mod tests { // but reject everything via on_visit let filter = RejectAllFilter::only([0_u32]); - let search_params = KnnSearch::new_default(nz(10), nz(20)).unwrap(); + let search_params = KnnSearch::new_default(10, 20).unwrap(); let mut multihop = graph::MultihopSearch::new(search_params, &filter); let stats = index .search( @@ -4160,7 +4152,7 @@ pub(crate) mod tests { let target = (num_points / 2) as u32; let filter = TerminatingFilter::new(target); - let search_params = KnnSearch::new_default(nz(10), nz(40)).unwrap(); + let search_params = KnnSearch::new_default(10, 40).unwrap(); let mut multihop = graph::MultihopSearch::new(search_params, &filter); let stats = index .search( @@ -4225,7 +4217,7 @@ pub(crate) mod tests { let mut baseline_buffer = search_output_buffer::IdDistance::new(&mut baseline_ids, &mut baseline_distances); - let search_params = KnnSearch::new_default(nz(10), nz(20)).unwrap(); + let search_params = KnnSearch::new_default(10, 20).unwrap(); let mut multihop = graph::MultihopSearch::new(search_params, &EvenFilter); let baseline_stats = index .search( @@ -4247,7 +4239,7 @@ pub(crate) mod tests { let mut adjusted_buffer = search_output_buffer::IdDistance::new(&mut adjusted_ids, &mut adjusted_distances); - let search_params = KnnSearch::new_default(nz(10), nz(20)).unwrap(); + let search_params = KnnSearch::new_default(10, 20).unwrap(); let mut multihop = 
graph::MultihopSearch::new(search_params, &filter); let adjusted_stats = index .search( @@ -4374,7 +4366,7 @@ pub(crate) mod tests { let max_visits = 5; let filter = TerminateAfterN::new(max_visits); - let search_params = KnnSearch::new_default(nz(10), nz(100)).unwrap(); // Large L to ensure we'd visit more without termination + let search_params = KnnSearch::new_default(10, 100).unwrap(); // Large L to ensure we'd visit more without termination let mut multihop = graph::MultihopSearch::new(search_params, &filter); let _stats = index .search( diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 4ca45c5d8..0b4787213 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -238,7 +238,7 @@ where let mut knn_search = *search_params; self.handle.block_on( self.inner - .search(strategy, context, query, &mut knn_search, output), + .search(&mut knn_search, strategy, context, query, output), ) } diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index b848eaf32..b7521c912 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -2135,19 +2135,19 @@ where /// /// // Standard k-NN search /// let mut params = KnnSearch::new(10, 100, None)?; - /// let stats = index.search(&strategy, &context, &query, &mut params, &mut output).await?;; + /// let stats = index.search(&mut params, &strategy, &context, &query, &mut output).await?; /// /// // Range search (note: uses () as output buffer, results in Output type) /// let mut params = RangeSearch::new(100, 0.5)?; - /// let result = index.search(&strategy, &context, &query, &mut params, &mut ()).await?; + /// let result = index.search(&mut params, &strategy, &context, &query, &mut ()).await?; /// // result.ids and result.distances contain the matches /// ``` pub fn search<'a, S, T, O: 'a, OB, P>( &'a self, + search_params: &'a mut P, strategy: &'a S, context: &'a DP::Context, query: &'a 
T, - search_params: &'a mut P, output: &'a mut OB, ) -> impl SendFuture> + 'a where @@ -2155,7 +2155,7 @@ where T: ?Sized, OB: ?Sized, { - search_params.dispatch(self, strategy, context, query, output) + search_params.search(self, strategy, context, query, output) } /// Performs a brute-force flat search over the points matching a provided filter function. @@ -2170,7 +2170,7 @@ where /// * `context` - The context to pass through to providers. /// * `query` - The query vector for which nearest neighbors are sought. /// * `vector_filter` - A predicate function used to filter candidate vectors based on their external IDs. - /// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`) and beam width. + /// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`). /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. /// /// # Returns diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index 023327fdf..dc919fc08 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -31,9 +31,9 @@ where P: AttributeValueProvider, { /// Base k-NN search parameters. - pub inner: KnnSearch, + inner: KnnSearch, /// Diversity-specific parameters. - pub diverse_params: DiverseSearchParams

, + diverse_params: DiverseSearchParams

, } impl

DiverseSearch

@@ -48,6 +48,18 @@ where } } + /// Returns a reference to the inner k-NN search parameters. + #[inline] + pub fn inner(&self) -> &KnnSearch { + &self.inner + } + + /// Returns a reference to the diversity-specific parameters. + #[inline] + pub fn diverse_params(&self) -> &DiverseSearchParams

{ + &self.diverse_params + } + /// Create search scratch with DiverseNeighborQueue for this search. fn create_scratch( &self, @@ -71,7 +83,7 @@ where index.estimate_visited_set_capacity(Some(self.inner.l_value().get())), ), id_scratch: Vec::with_capacity(index.max_degree_with_slack()), - beam_nodes: Vec::with_capacity(self.inner.beam_width().unwrap_or(1)), + beam_nodes: Vec::with_capacity(self.inner.beam_width().map_or(1, |nz| nz.get())), range_frontier: std::collections::VecDeque::new(), in_range: Vec::new(), hops: 0, @@ -91,13 +103,13 @@ where { type Output = SearchStats; - fn dispatch<'a>( - &'a mut self, - index: &'a DiskANNIndex, - strategy: &'a S, - context: &'a DP::Context, - query: &'a T, - output: &'a mut OB, + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, ) -> impl SendFuture> { async move { let mut accessor = strategy @@ -111,7 +123,7 @@ where let stats = index .search_internal( - self.inner.beam_width(), + self.inner.beam_width().map(|nz| nz.get()), &start_ids, &mut accessor, &computer, diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index b46bf07db..f97fa7643 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -31,9 +31,14 @@ pub enum KnnSearchError { LLessThanK { l_value: usize, k_value: usize }, #[error("beam width cannot be zero")] BeamWidthZero, + #[error("k_value cannot be zero")] + KZero, + #[error("l_value cannot be zero")] + LZero, } impl From for ANNError { + #[track_caller] fn from(err: KnnSearchError) -> Self { Self::new(ANNErrorKind::IndexError, err) } @@ -50,7 +55,7 @@ pub struct KnnSearch { /// Search list size - controls accuracy vs speed tradeoff. l_value: NonZeroUsize, /// Optional beam width for parallel graph exploration. 
- beam_width: Option, + beam_width: Option, } impl KnnSearch { @@ -58,17 +63,21 @@ impl KnnSearch { /// /// # Errors /// - /// Returns an error if `l_value < k_value` or if beam_width is zero. + /// Returns an error if `k_value` is zero, `l_value` is zero, + /// `l_value < k_value`, or if `beam_width` is zero. pub fn new( - k_value: NonZeroUsize, - l_value: NonZeroUsize, + k_value: usize, + l_value: usize, beam_width: Option, ) -> Result { + if k_value == 0 { + return Err(KnnSearchError::KZero); + } + if l_value == 0 { + return Err(KnnSearchError::LZero); + } if k_value > l_value { - return Err(KnnSearchError::LLessThanK { - l_value: l_value.get(), - k_value: k_value.get(), - }); + return Err(KnnSearchError::LLessThanK { l_value, k_value }); } if let Some(bw) = beam_width && bw == 0 @@ -76,18 +85,16 @@ impl KnnSearch { return Err(KnnSearchError::BeamWidthZero); } + // SAFETY: We've validated k_value != 0 and l_value != 0 above Ok(Self { - k_value, - l_value, - beam_width, + k_value: unsafe { NonZeroUsize::new_unchecked(k_value) }, + l_value: unsafe { NonZeroUsize::new_unchecked(l_value) }, + beam_width: beam_width.and_then(NonZeroUsize::new), }) } /// Create parameters with default beam width. - pub fn new_default( - k_value: NonZeroUsize, - l_value: NonZeroUsize, - ) -> Result { + pub fn new_default(k_value: usize, l_value: usize) -> Result { Self::new(k_value, l_value, None) } @@ -105,7 +112,7 @@ impl KnnSearch { /// Returns the optional beam width for parallel graph exploration. 
#[inline] - pub fn beam_width(&self) -> Option { + pub fn beam_width(&self) -> Option { self.beam_width } } @@ -137,7 +144,7 @@ impl KnnSearch { /// use diskann::graph::{search::KnnSearch, Search}; /// /// let mut params = KnnSearch::new(10, 100, None)?; -/// let stats = index.search(&strategy, &context, &query, &mut params, &mut output).await?; +/// let stats = index.search(&mut params, &strategy, &context, &query, &mut output).await?; /// ``` impl Search for KnnSearch where @@ -173,13 +180,13 @@ where /// # Errors /// /// Returns an error if there is a failure accessing elements or computing distances. - fn dispatch<'a>( - &'a mut self, - index: &'a DiskANNIndex, - strategy: &'a S, - context: &'a DP::Context, - query: &'a T, - output: &'a mut OB, + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, ) -> impl SendFuture> { async move { let mut accessor = strategy @@ -193,7 +200,7 @@ where let stats = index .search_internal( - self.beam_width, + self.beam_width.map(|nz| nz.get()), &start_ids, &mut accessor, &computer, @@ -253,13 +260,13 @@ where { type Output = SearchStats; - fn dispatch<'a>( - &'a mut self, - index: &'a DiskANNIndex, - strategy: &'a S, - context: &'a DP::Context, - query: &'a T, - output: &'a mut OB, + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, ) -> impl SendFuture> { async move { let mut accessor = strategy @@ -273,7 +280,7 @@ where let stats = index .search_internal( - self.inner.beam_width, + self.inner.beam_width.map(|nz| nz.get()), &start_ids, &mut accessor, &computer, @@ -314,49 +321,32 @@ mod tests { #[test] fn test_knn_search_validation() { // Valid - assert!( - KnnSearch::new( - NonZeroUsize::new(10).unwrap(), - NonZeroUsize::new(100).unwrap(), - None - ) - .is_ok() - ); - assert!( - KnnSearch::new( - NonZeroUsize::new(10).unwrap(), - NonZeroUsize::new(100).unwrap(), - Some(4) - ) - .is_ok() - 
); - assert!( - KnnSearch::new( - NonZeroUsize::new(10).unwrap(), - NonZeroUsize::new(10).unwrap(), - None - ) - .is_ok() - ); // k == l is valid + assert!(KnnSearch::new(10, 100, None).is_ok()); + assert!(KnnSearch::new(10, 100, Some(4)).is_ok()); + assert!(KnnSearch::new(10, 10, None).is_ok()); // k == l is valid + + // Invalid: k = 0 + assert!(matches!( + KnnSearch::new(0, 100, None), + Err(KnnSearchError::KZero) + )); + + // Invalid: l = 0 + assert!(matches!( + KnnSearch::new(10, 0, None), + Err(KnnSearchError::LZero) + )); // Invalid: l < k - assert!( - KnnSearch::new( - NonZeroUsize::new(100).unwrap(), - NonZeroUsize::new(10).unwrap(), - None - ) - .is_err() - ); + assert!(matches!( + KnnSearch::new(100, 10, None), + Err(KnnSearchError::LLessThanK { .. }) + )); // Invalid: zero beam_width - assert!( - KnnSearch::new( - NonZeroUsize::new(10).unwrap(), - NonZeroUsize::new(100).unwrap(), - Some(0) - ) - .is_err() - ); + assert!(matches!( + KnnSearch::new(10, 100, Some(0)), + Err(KnnSearchError::BeamWidthZero) + )); } } diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 7847cfcbe..49df6dbac 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -15,12 +15,12 @@ //! use diskann::graph::{KnnSearch, RangeSearch, MultihopSearch, Search}; //! //! // Standard k-NN search -//! let mut params = KnnSearch::new(10, 100, None)?;; -//! let stats = index.search(&strategy, &context, &query, &mut params, &mut output).await?; +//! let mut params = KnnSearch::new(10, 100, None)?; +//! let stats = index.search(&mut params, &strategy, &context, &query, &mut output).await?; //! //! // Range search //! let mut params = RangeSearch::new(100, 0.5)?; -//! let result = index.search(&strategy, &context, &query, &mut params, &mut ()).await?; +//! let result = index.search(&mut params, &strategy, &context, &query, &mut ()).await?; //! println!("Found {} points within radius", result.ids.len()); //! 
``` @@ -39,7 +39,7 @@ pub(crate) mod scratch; /// /// Each search type (graph search, range search, etc.) implements this trait /// to define its complete search behavior. The [`DiskANNIndex::search`] method -/// delegates to the `dispatch` method. +/// delegates to the `search` method. pub trait Search where DP: DataProvider, @@ -69,13 +69,13 @@ where /// # Errors /// /// Returns an error if there is a failure accessing elements or computing distances. - fn dispatch<'a>( - &'a mut self, - index: &'a DiskANNIndex, - strategy: &'a S, - context: &'a DP::Context, - query: &'a T, - output: &'a mut OB, + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, ) -> impl SendFuture>; } diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 57706aef2..240c77b86 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -63,13 +63,13 @@ where { type Output = SearchStats; - fn dispatch<'a>( - &'a mut self, - index: &'a DiskANNIndex, - strategy: &'a S, - context: &'a DP::Context, - query: &'a T, - output: &'a mut OB, + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, ) -> impl SendFuture> { let params = self.inner; async move { @@ -185,7 +185,7 @@ where T: ?Sized, SR: SearchRecord + ?Sized, { - let beam_width = search_params.beam_width().unwrap_or(1); + let beam_width = search_params.beam_width().map_or(1, |nz| nz.get()); // Helper to build the final stats from scratch state. 
let make_stats = |scratch: &SearchScratch| InternalSearchStats { diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index cbae5e38f..8d878eaff 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -49,6 +49,7 @@ pub enum RangeSearchError { } impl From for ANNError { + #[track_caller] fn from(err: RangeSearchError) -> Self { Self::new(ANNErrorKind::IndexError, err) } @@ -175,13 +176,13 @@ where { type Output = RangeSearchOutput; - fn dispatch<'a>( - &'a mut self, - index: &'a DiskANNIndex, - strategy: &'a S, - context: &'a DP::Context, - query: &'a T, - _output: &'a mut (), + fn search( + &mut self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + _output: &mut (), ) -> impl SendFuture> { let search_params = *self; async move { diff --git a/diskann/src/graph/test/cases/grid.rs b/diskann/src/graph/test/cases/grid.rs index 9af303260..4c4fe047f 100644 --- a/diskann/src/graph/test/cases/grid.rs +++ b/diskann/src/graph/test/cases/grid.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use std::{num::NonZeroUsize, sync::Arc}; +use std::sync::Arc; use diskann_vector::distance::Metric; @@ -126,12 +126,7 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { // are correct. 
let index = setup_grid_search(grid, size); - let mut params = KnnSearch::new( - NonZeroUsize::new(10).unwrap(), - NonZeroUsize::new(10).unwrap(), - Some(beam_width), - ) - .unwrap(); + let mut params = KnnSearch::new(10, 10, Some(beam_width)).unwrap(); let context = test_provider::Context::new(); let mut neighbors = vec![Neighbor::::default(); params.k_value().get()]; @@ -142,10 +137,10 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { range_search_second_round, } = rt .block_on(index.search( + &mut params, &test_provider::Strategy::new(), &context, query.as_slice(), - &mut params, &mut crate::neighbor::BackInserter::new(neighbors.as_mut_slice()), )) .unwrap(); From 3615079f50c38b3418f08515601bc0ba4704ccd2 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Sat, 14 Feb 2026 15:54:52 +0530 Subject: [PATCH 11/11] fix: correct remaining search param ordering in test files --- .../src/search/provider/disk_provider.rs | 4 +--- diskann-providers/src/index/diskann_async.rs | 19 +++++++++---------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index b53c0eaad..8f32bf3a9 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -1042,8 +1042,6 @@ fn ensure_vertex_loaded>( #[cfg(test)] mod disk_provider_tests { - use std::num::NonZeroUsize; - use diskann::{ graph::{search::record::VisitedSearchRecord, KnnSearch, KnnSearchError}, utils::IntoUsize, @@ -1633,10 +1631,10 @@ mod disk_provider_tests { search_engine .runtime .block_on(search_engine.index.search( + &mut recorded_search, &strategy, &DefaultContext, query_vector.as_slice(), - &mut recorded_search, &mut result_output_buffer, )) .unwrap(); diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 6ecc2ba36..e47371ee1 100644 --- 
a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -357,14 +357,13 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); let mut graph_search = - graph::KnnSearch::new_default(parameters.search_k, parameters.search_l) - .unwrap(); + graph::KnnSearch::new_default(parameters.search_k, parameters.search_l).unwrap(); index .search( + &mut graph_search, &strategy, ¶meters.context, query, - &mut graph_search, &mut result_output_buffer, ) .await @@ -410,10 +409,10 @@ pub(crate) mod tests { let mut multihop = graph::MultihopSearch::new(search_params, filter); index .search( + &mut multihop, strategy, ¶meters.context, query, - &mut multihop, &mut result_output_buffer, ) .await @@ -1455,10 +1454,10 @@ pub(crate) mod tests { let mut multihop = graph::MultihopSearch::new(search_params, &filter); let stats = index .search( + &mut multihop, &FullPrecision, ¶meters.context, query.as_slice(), - &mut multihop, &mut result_output_buffer, ) .await @@ -4093,10 +4092,10 @@ pub(crate) mod tests { let mut multihop = graph::MultihopSearch::new(search_params, &filter); let stats = index .search( + &mut multihop, &FullPrecision, &DefaultContext, query.as_slice(), - &mut multihop, &mut result_output_buffer, ) .await @@ -4156,10 +4155,10 @@ pub(crate) mod tests { let mut multihop = graph::MultihopSearch::new(search_params, &filter); let stats = index .search( + &mut multihop, &FullPrecision, &DefaultContext, query.as_slice(), - &mut multihop, &mut result_output_buffer, ) .await @@ -4221,10 +4220,10 @@ pub(crate) mod tests { let mut multihop = graph::MultihopSearch::new(search_params, &EvenFilter); let baseline_stats = index .search( + &mut multihop, &FullPrecision, &DefaultContext, query.as_slice(), - &mut multihop, &mut baseline_buffer, ) .await @@ -4243,10 +4242,10 @@ pub(crate) mod tests { let mut multihop = graph::MultihopSearch::new(search_params, &filter); let 
adjusted_stats = index .search( + &mut multihop, &FullPrecision, &DefaultContext, query.as_slice(), - &mut multihop, &mut adjusted_buffer, ) .await @@ -4370,10 +4369,10 @@ pub(crate) mod tests { let mut multihop = graph::MultihopSearch::new(search_params, &filter); let _stats = index .search( + &mut multihop, &FullPrecision, &DefaultContext, query.as_slice(), - &mut multihop, &mut result_output_buffer, ) .await