diff --git a/benches/construct.rs b/benches/construct.rs index 71a4fb905..958eaa3b6 100644 --- a/benches/construct.rs +++ b/benches/construct.rs @@ -21,7 +21,7 @@ fn zeros_f64(bench: &mut Bencher) #[bench] fn map_regular(bench: &mut test::Bencher) { - let a = Array::linspace(0., 127., 128) + let a = Array::linspace(0.0..=127.0, 128) .into_shape_with_order((8, 16)) .unwrap(); bench.iter(|| a.map(|&x| 2. * x)); @@ -31,7 +31,7 @@ fn map_regular(bench: &mut test::Bencher) #[bench] fn map_stride(bench: &mut test::Bencher) { - let a = Array::linspace(0., 127., 256) + let a = Array::linspace(0.0..=127.0, 256) .into_shape_with_order((8, 32)) .unwrap(); let av = a.slice(s![.., ..;2]); diff --git a/benches/iter.rs b/benches/iter.rs index bc483c8c2..0e18f1230 100644 --- a/benches/iter.rs +++ b/benches/iter.rs @@ -47,7 +47,7 @@ fn iter_sum_2d_transpose(bench: &mut Bencher) #[bench] fn iter_filter_sum_2d_u32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256) + let a = Array::linspace(0.0..=1.0, 256) .into_shape_with_order((16, 16)) .unwrap(); let b = a.mapv(|x| (x * 100.) as u32); @@ -58,7 +58,7 @@ fn iter_filter_sum_2d_u32(bench: &mut Bencher) #[bench] fn iter_filter_sum_2d_f32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256) + let a = Array::linspace(0.0..=1.0, 256) .into_shape_with_order((16, 16)) .unwrap(); let b = a * 100.; @@ -69,7 +69,7 @@ fn iter_filter_sum_2d_f32(bench: &mut Bencher) #[bench] fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256) + let a = Array::linspace(0.0..=1.0, 256) .into_shape_with_order((16, 16)) .unwrap(); let b = a.mapv(|x| (x * 100.) as u32); @@ -81,7 +81,7 @@ fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) #[bench] fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256) + let a = Array::linspace(0.0..=1.0, 256) .into_shape_with_order((16, 16)) .unwrap(); let b = a * 100.; @@ -93,7 +93,7 @@ fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) #[bench] fn iter_rev_step_by_contiguous(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 512); + let a = Array::linspace(0.0..=1.0, 512); bench.iter(|| { a.iter().rev().step_by(2).for_each(|x| { black_box(x); @@ -105,7 +105,7 @@ fn iter_rev_step_by_contiguous(bench: &mut Bencher) #[bench] fn iter_rev_step_by_discontiguous(bench: &mut Bencher) { - let mut a = Array::linspace(0., 1., 1024); + let mut a = Array::linspace(0.0..=1.0, 1024); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); bench.iter(|| { a.iter().rev().step_by(2).for_each(|x| { diff --git a/src/finite_bounds.rs b/src/finite_bounds.rs new file mode 100644 index 000000000..565fe2bcb --- /dev/null +++ b/src/finite_bounds.rs @@ -0,0 +1,42 @@ +use num_traits::Float; + +pub enum Bound +{ + Included(F), + Excluded(F), +} + +/// A version of std::ops::RangeBounds that only implements a..b and a..=b ranges. +pub trait FiniteBounds +{ + fn start_bound(&self) -> F; + fn end_bound(&self) -> Bound; +} + +impl FiniteBounds for std::ops::Range +where F: Float +{ + fn start_bound(&self) -> F + { + self.start + } + + fn end_bound(&self) -> Bound + { + Bound::Excluded(self.end) + } +} + +impl FiniteBounds for std::ops::RangeInclusive +where F: Float +{ + fn start_bound(&self) -> F + { + *self.start() + } + + fn end_bound(&self) -> Bound + { + Bound::Included(*self.end()) + } +} diff --git a/src/impl_constructors.rs b/src/impl_constructors.rs index ba01e2ca3..846223f26 100644 --- a/src/impl_constructors.rs +++ b/src/impl_constructors.rs @@ -58,10 +58,7 @@ where S: DataOwned pub fn from_vec(v: Vec) -> Self { if mem::size_of::() == 0 { - assert!( - v.len() <= isize::MAX as usize, - "Length must fit in `isize`.", - ); + assert!(v.len() <= isize::MAX as usize, "Length must fit in `isize`.",); } unsafe { Self::from_shape_vec_unchecked(v.len() as Ix, v) } } @@ -99,10 +96,12 @@ where S: DataOwned /// assert!(array == arr1(&[0.0, 0.25, 0.5, 0.75, 1.0])) /// ``` #[cfg(feature = "std")] - pub fn linspace(start: A, end: A, n: usize) -> Self - where A: Float + pub fn linspace(range: R, n: usize) -> Self + where + R: crate::finite_bounds::FiniteBounds, + A: Float, { - Self::from(to_vec(linspace::linspace(start, end, n))) + Self::from(to_vec(linspace::linspace(range, n))) } /// Create a one-dimensional array with elements from `start` to `end` @@ -145,10 +144,12 @@ where S: DataOwned /// # } /// ``` #[cfg(feature = "std")] - pub fn logspace(base: A, start: A, end: A, n: usize) -> Self - where A: Float + pub fn logspace(base: A, range: R, n: usize) -> Self + where + R: crate::finite_bounds::FiniteBounds, + A: Float, { - Self::from(to_vec(logspace::logspace(base, start, end, n))) + Self::from(to_vec(logspace::logspace(base, range, n))) } /// Create a one-dimensional array with `n` geometrically spaced elements diff --git a/src/lib.rs b/src/lib.rs index 41e5ca350..970c3f126 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -199,6 +199,8 @@ mod indexes; mod iterators; mod layout; mod linalg_traits; +#[cfg(feature = "std")] +mod finite_bounds; mod linspace; #[cfg(feature = "std")] pub use crate::linspace::{linspace, range, Linspace}; diff --git a/src/linspace.rs b/src/linspace.rs index 411c480db..ff52bf0c1 100644 --- a/src/linspace.rs +++ b/src/linspace.rs @@ -6,6 +6,9 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. #![cfg(feature = "std")] + +use crate::finite_bounds::{Bound, FiniteBounds}; + use num_traits::Float; /// An iterator of a sequence of evenly spaced floats. @@ -71,17 +74,24 @@ impl ExactSizeIterator for Linspace where Linspace: Iterator {} /// The iterator element type is `F`, where `F` must implement [`Float`], e.g. /// [`f32`] or [`f64`]. /// -/// **Panics** if converting `n - 1` to type `F` fails. +/// **Panics** if converting `n` to type `F` fails. #[inline] -pub fn linspace(a: F, b: F, n: usize) -> Linspace -where F: Float +pub fn linspace(range: R, n: usize) -> Linspace +where + R: FiniteBounds, + F: Float, { - let step = if n > 1 { - let num_steps = F::from(n - 1).expect("Converting number of steps to `A` must not fail."); + let (a, b, num_steps) = match (range.start_bound(), range.end_bound()) { + (a, Bound::Included(b)) => (a, b, F::from(n - 1).expect("Converting number of steps to `A` must not fail.")), + (a, Bound::Excluded(b)) => (a, b, F::from(n).expect("Converting number of steps to `A` must not fail.")), + }; + + let step = if num_steps > F::zero() { (b - a) / num_steps } else { F::zero() }; + Linspace { start: a, step, diff --git a/src/logspace.rs b/src/logspace.rs index 463012018..dd1b7ae19 100644 --- a/src/logspace.rs +++ b/src/logspace.rs @@ -6,6 +6,9 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. #![cfg(feature = "std")] + +use crate::finite_bounds::{Bound, FiniteBounds}; + use num_traits::Float; /// An iterator of a sequence of logarithmically spaced number. @@ -79,15 +82,22 @@ impl ExactSizeIterator for Logspace where Logspace: Iterator {} /// /// **Panics** if converting `n - 1` to type `F` fails. #[inline] -pub fn logspace(base: F, a: F, b: F, n: usize) -> Logspace -where F: Float +pub fn logspace(base: F, range: R, n: usize) -> Logspace +where + R: FiniteBounds, + F: Float, { - let step = if n > 1 { - let num_steps = F::from(n - 1).expect("Converting number of steps to `A` must not fail."); + let (a, b, num_steps) = match (range.start_bound(), range.end_bound()) { + (a, Bound::Included(b)) => (a, b, F::from(n - 1).expect("Converting number of steps to `A` must not fail.")), + (a, Bound::Excluded(b)) => (a, b, F::from(n).expect("Converting number of steps to `A` must not fail.")), + }; + + let step = if num_steps > F::zero() { (b - a) / num_steps } else { F::zero() }; + Logspace { sign: base.signum(), base: base.abs(), @@ -110,23 +120,23 @@ mod tests use crate::{arr1, Array1}; use approx::assert_abs_diff_eq; - let array: Array1<_> = logspace(10.0, 0.0, 3.0, 4).collect(); + let array: Array1<_> = logspace(10.0, 0.0..=3.0, 4).collect(); assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]), epsilon = 1e-12); - let array: Array1<_> = logspace(10.0, 3.0, 0.0, 4).collect(); + let array: Array1<_> = logspace(10.0, 3.0..=0.0, 4).collect(); assert_abs_diff_eq!(array, arr1(&[1e3, 1e2, 1e1, 1e0]), epsilon = 1e-12); - let array: Array1<_> = logspace(-10.0, 3.0, 0.0, 4).collect(); + let array: Array1<_> = logspace(-10.0, 3.0..=0.0, 4).collect(); assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]), epsilon = 1e-12); - let array: Array1<_> = logspace(-10.0, 0.0, 3.0, 4).collect(); + let array: Array1<_> = logspace(-10.0, 0.0..=3.0, 4).collect(); assert_abs_diff_eq!(array, arr1(&[-1e0, -1e1, -1e2, -1e3]), epsilon = 1e-12); } #[test] fn iter_forward() { - let mut iter = logspace(10.0f64, 0.0, 3.0, 4); + let mut iter = logspace(10.0f64, 0.0..=3.0, 4); assert!(iter.size_hint() == (4, Some(4))); @@ -142,7 +152,7 @@ mod tests #[test] fn iter_backward() { - let mut iter = logspace(10.0f64, 0.0, 3.0, 4); + let mut iter = logspace(10.0f64, 0.0..=3.0, 4); assert!(iter.size_hint() == (4, Some(4)));