diff --git a/diskann-quantization/src/multi_vector/matrix.rs b/diskann-quantization/src/multi_vector/matrix.rs index e8789b2c8..36e13cc2a 100644 --- a/diskann-quantization/src/multi_vector/matrix.rs +++ b/diskann-quantization/src/multi_vector/matrix.rs @@ -655,9 +655,9 @@ impl Mat { Self { ptr, repr } } - #[cfg(test)] - fn as_ptr(&self) -> NonNull { - self.ptr + /// Return the base pointer for the [`Mat`]. + pub fn as_raw_ptr(&self) -> *const u8 { + self.ptr.as_ptr() } } @@ -783,6 +783,11 @@ impl<'a, T: Repr> MatRef<'a, T> { _lifetime: PhantomData, } } + + /// Return the base pointer for the [`MatRef`]. + pub fn as_raw_ptr(&self) -> *const u8 { + self.ptr.as_ptr() + } } impl<'a, T: Copy> MatRef<'a, Standard> { @@ -961,6 +966,11 @@ impl<'a, T: ReprMut> MatMut<'a, T> { _lifetime: PhantomData, } } + + /// Return the base pointer for the [`MatMut`]. + pub fn as_raw_ptr(&self) -> *const u8 { + self.ptr.as_ptr() + } } // Reborrow: MatMut -> MatRef @@ -1396,7 +1406,7 @@ mod tests { #[test] fn mat_new_and_basic_accessors() { let mat = Mat::new(Standard::::new(3, 4).unwrap(), 42usize).unwrap(); - let base: *const u8 = mat.as_ptr().as_ptr(); + let base: *const u8 = mat.as_raw_ptr(); assert_eq!(mat.num_vectors(), 3); assert_eq!(mat.vector_dim(), 4); @@ -1418,7 +1428,7 @@ mod tests { #[test] fn mat_new_with_default() { let mat = Mat::new(Standard::::new(2, 3).unwrap(), Defaulted).unwrap(); - let base: *const u8 = mat.as_ptr().as_ptr(); + let base: *const u8 = mat.as_raw_ptr(); assert_eq!(mat.num_vectors(), 2); for (i, row) in mat.rows().enumerate() { @@ -1456,6 +1466,10 @@ mod tests { check_mat_ref(mat.reborrow(), repr, ctx); check_mat_mut(mat.reborrow_mut(), repr, ctx); check_rows(mat.rows(), repr, ctx); + + // Check reborrow preserves pointers. + assert_eq!(mat.as_raw_ptr(), mat.reborrow().as_raw_ptr()); + assert_eq!(mat.as_raw_ptr(), mat.reborrow_mut().as_raw_ptr()); } // Populate the matrix using `MatMut` @@ -1514,7 +1528,7 @@ mod tests { // Cloned allocation is independent. if repr.num_elements() > 0 { - assert_ne!(mat.as_ptr(), cloned.as_ptr()); + assert_ne!(mat.as_raw_ptr(), cloned.as_raw_ptr()); } } @@ -1528,7 +1542,7 @@ mod tests { check_rows(owned.rows(), repr, ctx); if repr.num_elements() > 0 { - assert_ne!(mat.as_ptr(), owned.as_ptr()); + assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr()); } } @@ -1542,7 +1556,7 @@ mod tests { check_rows(owned.rows(), repr, ctx); if repr.num_elements() > 0 { - assert_ne!(mat.as_ptr(), owned.as_ptr()); + assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr()); } } } @@ -1560,8 +1574,15 @@ mod tests { { let ctx = &lazy_format!("{ctx} - by matmut"); let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect(); + let ptr = b.as_ptr().cast::(); let mut matmut = MatMut::new(repr, &mut b).unwrap(); + assert_eq!( + ptr, + matmut.as_raw_ptr(), + "underlying memory should be preserved", + ); + fill_mat_mut(matmut.reborrow_mut(), repr); check_mat_mut(matmut.reborrow_mut(), repr, ctx); @@ -1579,8 +1600,15 @@ mod tests { { let ctx = &lazy_format!("{ctx} - by rows"); let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect(); + let ptr = b.as_ptr().cast::(); let mut matmut = MatMut::new(repr, &mut b).unwrap(); + assert_eq!( + ptr, + matmut.as_raw_ptr(), + "underlying memory should be preserved", + ); + fill_rows_mut(matmut.rows_mut(), repr); check_mat_mut(matmut.reborrow_mut(), repr, ctx);