Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions diskann-benchmark-core/src/search/graph/knn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

//! A built-in helper for benchmarking K-nearest neighbors.

use std::{num::NonZeroUsize, sync::Arc};
use std::sync::Arc;

use diskann::{
ANNResult,
Expand All @@ -29,7 +29,7 @@ use crate::{
/// the latter. Result aggregation for [`search::search_all`] is provided
/// by the [`Aggregator`] type.
///
/// The provided implementation of [`Search`] accepts [`graph::SearchParams`]
/// The provided implementation of [`Search`] accepts [`graph::KnnSearch`]
/// and returns [`Metrics`] as additional output.
#[derive(Debug)]
pub struct KNN<DP, T, S>
Expand Down Expand Up @@ -92,15 +92,15 @@ where
T: AsyncFriendly + Clone,
{
type Id = DP::ExternalId;
type Parameters = graph::SearchParams;
type Parameters = graph::KnnSearch;
type Output = Metrics;

fn num_queries(&self) -> usize {
self.queries.nrows()
}

fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
search::IdCount::Fixed(NonZeroUsize::new(parameters.k_value).unwrap_or(diskann::utils::ONE))
search::IdCount::Fixed(parameters.k_value())
}

async fn search<O>(
Expand All @@ -113,13 +113,14 @@ where
O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
{
let context = DP::Context::default();
let mut knn_search = *parameters;
let stats = self
.index
.search(
&mut knn_search,
self.strategy.get(index)?,
&context,
self.queries.row(index),
parameters,
buffer,
)
.await?;
Expand All @@ -142,7 +143,7 @@ pub struct Summary {
pub setup: search::Setup,

/// The [`Search::Parameters`] used for the batch of runs.
pub parameters: graph::SearchParams,
pub parameters: graph::KnnSearch,

/// The end-to-end latency for each repetition in the batch.
pub end_to_end_latencies: Vec<MicroSeconds>,
Expand Down Expand Up @@ -207,15 +208,15 @@ impl<'a, I> Aggregator<'a, I> {
}
}

impl<I> search::Aggregate<graph::SearchParams, I, Metrics> for Aggregator<'_, I>
impl<I> search::Aggregate<graph::KnnSearch, I, Metrics> for Aggregator<'_, I>
where
I: crate::recall::RecallCompatible,
{
type Output = Summary;

fn aggregate(
&mut self,
run: search::Run<graph::SearchParams>,
run: search::Run<graph::KnnSearch>,
mut results: Vec<search::SearchResults<I, Metrics>>,
) -> anyhow::Result<Summary> {
// Compute the recall using just the first result.
Expand Down Expand Up @@ -280,13 +281,15 @@ where

#[cfg(test)]
mod tests {
use std::num::NonZeroUsize;

use super::*;

use diskann::graph::test::provider;

#[test]
fn test_knn() {
let nearest_neighbors = 5;
let nearest_neighbors = NonZeroUsize::new(5).unwrap();

let index = search::graph::test_grid_provider();

Expand All @@ -310,7 +313,7 @@ mod tests {
let rt = crate::tokio::runtime(2).unwrap();
let results = search::search(
knn.clone(),
graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(),
graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(),
NonZeroUsize::new(2).unwrap(),
&rt,
)
Expand All @@ -321,7 +324,7 @@ mod tests {
assert_eq!(*rows.row(0).first().unwrap(), 0);

for r in 0..rows.nrows() {
assert_eq!(rows.row(r).len(), nearest_neighbors);
assert_eq!(rows.row(r).len(), nearest_neighbors.get());
}

const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap();
Expand All @@ -334,17 +337,17 @@ mod tests {
// Try the aggregated strategy.
let parameters = [
search::Run::new(
graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(),
graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(),
setup.clone(),
),
search::Run::new(
graph::SearchParams::new(nearest_neighbors, 15, None).unwrap(),
graph::KnnSearch::new(nearest_neighbors.get(), 15, None).unwrap(),
setup.clone(),
),
];

let recall_k = nearest_neighbors;
let recall_n = nearest_neighbors;
let recall_k = nearest_neighbors.get();
let recall_n = nearest_neighbors.get();

let all =
search::search_all(knn, parameters, Aggregator::new(rows, recall_k, recall_n)).unwrap();
Expand Down
30 changes: 16 additions & 14 deletions diskann-benchmark-core/src/search/graph/multihop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Licensed under the MIT license.
*/

use std::{num::NonZeroUsize, sync::Arc};
use std::sync::Arc;

use diskann::{
ANNResult,
Expand All @@ -22,7 +22,7 @@ use crate::search::{self, Search, graph::Strategy};
/// [`search::search_all`] is provided by the [`search::graph::knn::Aggregator`] type (same
/// aggregator as [`search::graph::KNN`]).
///
/// The provided implementation of [`Search`] accepts [`graph::SearchParams`]
/// The provided implementation of [`Search`] accepts [`graph::KnnSearch`]
/// and returns [`search::graph::knn::Metrics`] as additional output.
#[derive(Debug)]
pub struct MultiHop<DP, T, S>
Expand Down Expand Up @@ -90,15 +90,15 @@ where
T: AsyncFriendly + Clone,
{
type Id = DP::ExternalId;
type Parameters = graph::SearchParams;
type Parameters = graph::KnnSearch;
type Output = super::knn::Metrics;

fn num_queries(&self) -> usize {
self.queries.nrows()
}

fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
search::IdCount::Fixed(NonZeroUsize::new(parameters.k_value).unwrap_or(diskann::utils::ONE))
search::IdCount::Fixed(parameters.k_value())
}

async fn search<O>(
Expand All @@ -111,15 +111,15 @@ where
O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
{
let context = DP::Context::default();
let mut multihop_search = graph::MultihopSearch::new(*parameters, &*self.labels[index]);
let stats = self
.index
.multihop_search(
.search(
&mut multihop_search,
self.strategy.get(index)?,
&context,
self.queries.row(index),
parameters,
buffer,
&*self.labels[index],
)
.await?;

Expand All @@ -136,6 +136,8 @@ where

#[cfg(test)]
mod tests {
use std::num::NonZeroUsize;

use super::*;

use diskann::graph::{index::QueryLabelProvider, test::provider};
Expand All @@ -152,7 +154,7 @@ mod tests {

#[test]
fn test_multihop() {
let nearest_neighbors = 5;
let nearest_neighbors = NonZeroUsize::new(5).unwrap();

let index = search::graph::test_grid_provider();

Expand All @@ -179,7 +181,7 @@ mod tests {
let rt = crate::tokio::runtime(2).unwrap();
let results = search::search(
multihop.clone(),
graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(),
graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(),
NonZeroUsize::new(2).unwrap(),
&rt,
)
Expand All @@ -191,7 +193,7 @@ mod tests {

// Check that only even IDs are returned.
for r in 0..rows.nrows() {
assert_eq!(rows.row(r).len(), nearest_neighbors);
assert_eq!(rows.row(r).len(), nearest_neighbors.get());
for &id in rows.row(r) {
assert_eq!(id % 2, 0, "Found odd ID {} in row {}", id, r);
}
Expand All @@ -207,17 +209,17 @@ mod tests {
// Try the aggregated strategy.
let parameters = [
search::Run::new(
graph::SearchParams::new(nearest_neighbors, 10, None).unwrap(),
graph::KnnSearch::new(nearest_neighbors.get(), 10, None).unwrap(),
setup.clone(),
),
search::Run::new(
graph::SearchParams::new(nearest_neighbors, 15, None).unwrap(),
graph::KnnSearch::new(nearest_neighbors.get(), 15, None).unwrap(),
setup.clone(),
),
];

let recall_k = nearest_neighbors;
let recall_n = nearest_neighbors;
let recall_k = nearest_neighbors.get();
let recall_n = nearest_neighbors.get();

let all = search::search_all(
multihop,
Expand Down
33 changes: 19 additions & 14 deletions diskann-benchmark-core/src/search/graph/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::{
/// by the [`Aggregator`] type.
///
/// The provided implementation of [`Search`] accepts
/// [`graph::RangeSearchParams`] and returns [`Metrics`] as additional output.
/// [`graph::RangeSearch`] and returns [`Metrics`] as additional output.
#[derive(Debug)]
pub struct Range<DP, T, S>
where
Expand Down Expand Up @@ -83,15 +83,15 @@ where
T: AsyncFriendly + Clone,
{
type Id = DP::ExternalId;
type Parameters = graph::RangeSearchParams;
type Parameters = graph::RangeSearch;
type Output = Metrics;

fn num_queries(&self) -> usize {
self.queries.nrows()
}

fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
search::IdCount::Dynamic(NonZeroUsize::new(parameters.starting_l_value))
search::IdCount::Dynamic(NonZeroUsize::new(parameters.starting_l()))
}

async fn search<O>(
Expand All @@ -104,16 +104,21 @@ where
O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
{
let context = DP::Context::default();
let (_, ids, distances) = self
let mut range_search = *parameters;
let result = self
.index
.range_search(
.search(
&mut range_search,
self.strategy.get(index)?,
&context,
self.queries.row(index),
parameters,
&mut (),
)
.await?;
buffer.extend(std::iter::zip(ids.into_iter(), distances.into_iter()));
buffer.extend(std::iter::zip(
result.ids.into_iter(),
result.distances.into_iter(),
));

Ok(Metrics {})
}
Expand All @@ -129,8 +134,8 @@ pub struct Summary {
/// The [`search::Setup`] used for the batch of runs.
pub setup: search::Setup,

/// The [`graph::RangeSearchParams`] used for the batch of runs.
pub parameters: graph::RangeSearchParams,
/// The [`graph::RangeSearch`] used for the batch of runs.
pub parameters: graph::RangeSearch,

/// The end-to-end latency for each repetition in the batch.
pub end_to_end_latencies: Vec<MicroSeconds>,
Expand Down Expand Up @@ -174,7 +179,7 @@ impl<'a, I> Aggregator<'a, I> {
}
}

impl<I> search::Aggregate<graph::RangeSearchParams, I, Metrics> for Aggregator<'_, I>
impl<I> search::Aggregate<graph::RangeSearch, I, Metrics> for Aggregator<'_, I>
where
I: crate::recall::RecallCompatible,
{
Expand All @@ -183,7 +188,7 @@ where
#[inline(never)]
fn aggregate(
&mut self,
run: search::Run<graph::RangeSearchParams>,
run: search::Run<graph::RangeSearch>,
mut results: Vec<search::SearchResults<I, Metrics>>,
) -> anyhow::Result<Summary> {
// Compute the recall using just the first result.
Expand Down Expand Up @@ -261,7 +266,7 @@ mod tests {
let rt = crate::tokio::runtime(2).unwrap();
let results = search::search(
range.clone(),
graph::RangeSearchParams::new(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(),
graph::RangeSearch::with_options(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(),
NonZeroUsize::new(2).unwrap(),
&rt,
)
Expand All @@ -280,11 +285,11 @@ mod tests {
// Try the aggregated strategy.
let parameters = [
search::Run::new(
graph::RangeSearchParams::new(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(),
graph::RangeSearch::with_options(None, 10, None, 2.0, None, 0.8, 1.2).unwrap(),
setup.clone(),
),
search::Run::new(
graph::RangeSearchParams::new(None, 15, None, 2.0, None, 0.8, 1.2).unwrap(),
graph::RangeSearch::with_options(None, 15, None, 2.0, None, 0.8, 1.2).unwrap(),
setup.clone(),
),
];
Expand Down
6 changes: 3 additions & 3 deletions diskann-benchmark/src/backend/index/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ impl SearchResults {

Self {
num_tasks: setup.tasks.into(),
search_n: parameters.k_value,
search_l: parameters.l_value,
search_n: parameters.k_value().get(),
search_l: parameters.l_value().get(),
qps,
search_latencies: end_to_end_latencies,
mean_latencies,
Expand Down Expand Up @@ -284,7 +284,7 @@ impl RangeSearchResults {

Self {
num_tasks: setup.tasks.into(),
initial_l: parameters.starting_l_value,
initial_l: parameters.starting_l(),
qps,
search_latencies: end_to_end_latencies,
mean_latencies,
Expand Down
Loading
Loading