Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
423 changes: 394 additions & 29 deletions crates/quspin-core/src/expm/algorithm.rs

Large diffs are not rendered by default.

166 changes: 166 additions & 0 deletions crates/quspin-core/src/expm/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,167 @@

use std::fmt::Debug;
use std::ops::{Add, Mul};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};

use num_complex::Complex;

use crate::primitive::Primitive;

// ---------------------------------------------------------------------------
// AtomicAccum — parallel scatter-write accumulator
// ---------------------------------------------------------------------------

/// A single atomically-accumatable output slot for parallel scatter-writes.
///
/// Used by [`crate::expm::linear_operator::LinearOperator::dot_transpose_chunk`]
/// so that multiple threads can safely scatter-add into a shared output array
/// without per-thread intermediate buffers.
///
/// All operations use `Relaxed` ordering; the caller is responsible for
/// inserting the appropriate synchronisation barrier (e.g. the rayon thread-
/// pool join) before reading the accumulated results.
pub trait AtomicAccum: Send + Sync {
type Value;
/// Atomically add `val` to this slot.
fn fetch_add(&self, val: Self::Value);
/// Load the current value.
fn load(&self) -> Self::Value;
/// Construct a zero-valued slot.
fn zero() -> Self;
}

// ---------------------------------------------------------------------------
// Private helpers: CAS-loop atomic add for f32 / f64
// ---------------------------------------------------------------------------

#[inline]
fn cas_add_f32(atom: &AtomicU32, val: f32) {
atom.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
Some((f32::from_bits(bits) + val).to_bits())
})
.unwrap();
}

#[inline]
fn cas_add_f64(atom: &AtomicU64, val: f64) {
atom.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
Some((f64::from_bits(bits) + val).to_bits())
})
.unwrap();
}

// ---------------------------------------------------------------------------
// Concrete atomic accumulators
// ---------------------------------------------------------------------------

/// Atomic accumulator for `f32`.
pub struct AtomicF32(AtomicU32);

impl AtomicAccum for AtomicF32 {
type Value = f32;
#[inline]
fn fetch_add(&self, val: f32) {
cas_add_f32(&self.0, val);
}
#[inline]
fn load(&self) -> f32 {
f32::from_bits(self.0.load(Ordering::Relaxed))
}
#[inline]
fn zero() -> Self {
AtomicF32(AtomicU32::new(0))
}
}

/// Atomic accumulator for `f64`.
pub struct AtomicF64(AtomicU64);

impl AtomicAccum for AtomicF64 {
type Value = f64;
#[inline]
fn fetch_add(&self, val: f64) {
cas_add_f64(&self.0, val);
}
#[inline]
fn load(&self) -> f64 {
f64::from_bits(self.0.load(Ordering::Relaxed))
}
#[inline]
fn zero() -> Self {
AtomicF64(AtomicU64::new(0))
}
}

/// Atomic accumulator for `Complex<f32>`.
///
/// The real and imaginary parts are updated independently (each with a
/// separate CAS loop). No cross-field atomicity is provided; observers
/// must not read until all threads have finished.
pub struct AtomicComplex32 {
re: AtomicU32,
im: AtomicU32,
}

impl AtomicAccum for AtomicComplex32 {
type Value = Complex<f32>;
#[inline]
fn fetch_add(&self, val: Complex<f32>) {
cas_add_f32(&self.re, val.re);
cas_add_f32(&self.im, val.im);
}
#[inline]
fn load(&self) -> Complex<f32> {
Complex::new(
f32::from_bits(self.re.load(Ordering::Relaxed)),
f32::from_bits(self.im.load(Ordering::Relaxed)),
)
}
#[inline]
fn zero() -> Self {
AtomicComplex32 {
re: AtomicU32::new(0),
im: AtomicU32::new(0),
}
}
}

/// Atomic accumulator for `Complex<f64>`.
///
/// The real and imaginary parts are updated independently (each with a
/// separate CAS loop). No cross-field atomicity is provided; observers
/// must not read until all threads have finished.
pub struct AtomicComplex64 {
re: AtomicU64,
im: AtomicU64,
}

impl AtomicAccum for AtomicComplex64 {
type Value = Complex<f64>;
#[inline]
fn fetch_add(&self, val: Complex<f64>) {
cas_add_f64(&self.re, val.re);
cas_add_f64(&self.im, val.im);
}
#[inline]
fn load(&self) -> Complex<f64> {
Complex::new(
f64::from_bits(self.re.load(Ordering::Relaxed)),
f64::from_bits(self.im.load(Ordering::Relaxed)),
)
}
#[inline]
fn zero() -> Self {
AtomicComplex64 {
re: AtomicU64::new(0),
im: AtomicU64::new(0),
}
}
}

// ---------------------------------------------------------------------------
// ExpmComputation trait
// ---------------------------------------------------------------------------

/// Extension of [`Primitive`] with the operations needed by the Taylor-series
/// matrix-exponential algorithm.
///
Expand All @@ -15,12 +171,18 @@ pub trait ExpmComputation: Primitive {
/// The real scalar type: `f32` for `f32`/`Complex<f32>`,
/// `f64` for `f64`/`Complex<f64>`.
type Real: Copy
+ Send
+ PartialOrd
+ Add<Output = Self::Real>
+ Mul<Output = Self::Real>
+ Default
+ Debug;

/// Atomic accumulator for parallel scatter-writes into a shared output
/// slice. Used by
/// [`LinearOperator::dot_transpose_chunk`](crate::expm::linear_operator::LinearOperator::dot_transpose_chunk).
type Atomic: AtomicAccum<Value = Self>;

/// `|self|` as `Self::Real` (modulus for complex types).
fn abs_val(self) -> Self::Real;

Expand All @@ -43,6 +205,7 @@ pub trait ExpmComputation: Primitive {

impl ExpmComputation for f32 {
type Real = f32;
type Atomic = AtomicF32;

#[inline]
fn abs_val(self) -> f32 {
Expand All @@ -68,6 +231,7 @@ impl ExpmComputation for f32 {

impl ExpmComputation for f64 {
type Real = f64;
type Atomic = AtomicF64;

#[inline]
fn abs_val(self) -> f64 {
Expand All @@ -93,6 +257,7 @@ impl ExpmComputation for f64 {

impl ExpmComputation for Complex<f32> {
type Real = f32;
type Atomic = AtomicComplex32;

#[inline]
fn abs_val(self) -> f32 {
Expand All @@ -118,6 +283,7 @@ impl ExpmComputation for Complex<f32> {

impl ExpmComputation for Complex<f64> {
type Real = f64;
type Atomic = AtomicComplex64;

#[inline]
fn abs_val(self) -> f64 {
Expand Down
Loading
Loading