Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions httomolibgpu/misc/sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ---------------------------------------------------------------------------
# Copyright 2026 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either ecpress or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------------
# Created By : Tomography Team at DLS <scientificsoftware@diamond.ac.uk>
# Created Date: 22 April 2026
# ---------------------------------------------------------------------------

from typing import Tuple
import cupy as cp


def argsort_with_reverse(
data: cp.ndarray, axis: int = -1
) -> Tuple[cp.ndarray, cp.ndarray]:
"""
Compute sorting indices for an 1D or 2D array, and efficiently compute the indices to revert the sort.
"""
dim = len(data.shape)
if 1 <= dim <= 2:
pass
else:
raise ValueError("only 1D and 2D arrays are supported")
if axis < 0:
axis = dim + axis
if axis >= dim:
raise ValueError("invalid axis")
sort_indices = cp.argsort(data, axis=axis)
reverse_sort_indices = cp.empty_like(sort_indices)
if dim == 1:
reverse_sort_indices[sort_indices] = cp.arange(0, data.size)
elif axis == 0: # sort rows
nrows, ncols = data.shape
cols = cp.arange(ncols)[None, :]
reverse_sort_indices[sort_indices, cols] = cp.arange(nrows)[:, None]
else: # sort columns
nrows, ncols = data.shape
rows = cp.arange(nrows)[:, None]
reverse_sort_indices[rows, sort_indices] = cp.arange(ncols)

return sort_indices, reverse_sort_indices
54 changes: 20 additions & 34 deletions httomolibgpu/prep/stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np
import pywt
from httomolibgpu import cupywrapper
from httomolibgpu.misc.sorting import argsort_with_reverse

cp = cupywrapper.cp
cupy_run = cupywrapper.cupy_run
Expand Down Expand Up @@ -116,21 +117,18 @@ def _rs_sort(sinogram, size, dim):
"""
Remove stripes using the sorting technique.
"""
sinogram = cp.transpose(sinogram)
#: Sort each row of the sinogram by its grayscale values
sortvals, sortvals_reverse = argsort_with_reverse(sinogram, axis=0)

#: Sort each column of the sinogram by its grayscale values
#: Keep track of the sorting indices so we can reverse it below
sortvals = cp.argsort(sinogram, axis=1)
sortvals_reverse = cp.argsort(sortvals, axis=1)
sino_sort = cp.take_along_axis(sinogram, sortvals, axis=1)
sino_sort = cp.take_along_axis(sinogram, sortvals, axis=0)

#: Now apply the median filter on the sorted image along each row
sino_sort = median_filter(sino_sort, (size, 1) if dim == 1 else (size, size))
sino_sort = median_filter(sino_sort, (1, size) if dim == 1 else (size, size))

#: step 3: re-sort the smoothed image columns to the original rows
sino_corrected = cp.take_along_axis(sino_sort, sortvals_reverse, axis=1)
sino_corrected = cp.take_along_axis(sino_sort, sortvals_reverse, axis=0)

return cp.transpose(sino_corrected)
return sino_corrected


def remove_stripe_ti(
Expand Down Expand Up @@ -821,6 +819,7 @@ def remove_all_stripe(
la_size: int = 61,
sm_size: int = 21,
dim: Literal[1, 2] = 1,
normalize: bool = False,
) -> cp.ndarray:
"""
Remove all types of stripe artifacts from sinogram using Nghia Vo's
Expand All @@ -839,6 +838,8 @@ def remove_all_stripe(
Window size of the median filter to remove small-to-medium stripes.
dim : {1, 2},
Dimension of the window.
normalize : bool,
Controls whether to normalize while removing large stripes

Returns
-------
Expand All @@ -859,10 +860,10 @@ def remove_all_stripe(
__check_variable_type(dim, [int], "dim", [1, 2], methods_name)
###################################

matindex = _create_matindex(data.shape[2], data.shape[0])
for m in range(data.shape[1]):
sino = data[:, m, :]
sino = _rs_dead(sino, snr, la_size, matindex)
sino = _rs_dead(sino, snr, la_size)
sino = _rs_large(sino, snr, la_size, normalize)
sino = _rs_sort(sino, sm_size, dim)
sino = cp.nan_to_num(sino)
data[:, m, :] = sino
Expand Down Expand Up @@ -910,15 +911,17 @@ def _detect_stripe(listdata, snr):
return listmask


def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
def _rs_large(sinogram, snr, size, normalize, drop_ratio=0.1):
"""
Remove large stripes.
"""
drop_ratio = max(min(drop_ratio, 0.8), 0) # = cp.clip(drop_ratio, 0.0, 0.8)
nrow, ncol = sinogram.shape
ndrop = int(0.5 * drop_ratio * nrow)
sinosort = cp.sort(sinogram, axis=0)
sort_indices, sort_indices_reverse = argsort_with_reverse(sinogram, axis=0)
sinosort = cp.take_along_axis(sinogram, sort_indices, axis=0)
sinosmooth = median_filter(sinosort, (1, size))

list1 = cp.mean(sinosort[ndrop : nrow - ndrop], axis=0)
list2 = cp.mean(sinosmooth[ndrop : nrow - ndrop], axis=0)
listfact = list1 / list2
Expand All @@ -927,30 +930,16 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
listmask = _detect_stripe(listfact, snr)
listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)
matfact = cp.tile(listfact, (nrow, 1))
# Normalize
if norm is True:
if normalize:
sinogram = sinogram / matfact
sinogram1 = cp.transpose(sinogram)
matcombine = cp.asarray(cp.dstack((matindex, sinogram1)))

ids = cp.argsort(matcombine[:, :, 1], axis=1)
matsort = matcombine.copy()
matsort[:, :, 0] = cp.take_along_axis(matsort[:, :, 0], ids, axis=1)
matsort[:, :, 1] = cp.take_along_axis(matsort[:, :, 1], ids, axis=1)

matsort[:, :, 1] = cp.transpose(sinosmooth)
ids = cp.argsort(matsort[:, :, 0], axis=1)
matsortback = matsort.copy()
matsortback[:, :, 0] = cp.take_along_axis(matsortback[:, :, 0], ids, axis=1)
matsortback[:, :, 1] = cp.take_along_axis(matsortback[:, :, 1], ids, axis=1)

sino_corrected = cp.transpose(matsortback[:, :, 1])
sort_indices, sort_indices_reverse = argsort_with_reverse(sinogram, axis=0)
sino_corrected = cp.take_along_axis(sinosmooth, sort_indices_reverse, axis=0)
listxmiss = cp.where(listmask > 0.0)[0]
sinogram[:, listxmiss] = sino_corrected[:, listxmiss]
return sinogram


def _rs_dead(sinogram, snr, size, matindex, norm=True):
def _rs_dead(sinogram, snr, size):
"""remove unresponsive and fluctuating stripes"""
sinogram = cp.copy(sinogram) # Make it mutable
nrow, _ = sinogram.shape
Expand Down Expand Up @@ -979,9 +968,6 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True):
sinogram[:, listx[ids]] - sinogram[:, listx[ids - 1]]
)

# Remove residual stripes
if norm is True:
sinogram = _rs_large(sinogram, snr, size, matindex)
return sinogram


Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ dev = [
"imageio",
"h5py",
"pre-commit",
"pyfftw"
"pyfftw",
"tomophantom"
]


Expand Down
144 changes: 144 additions & 0 deletions tests/test_misc/test_sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import pytest
import cupy as cp

from httomolibgpu.misc.sorting import argsort_with_reverse


class TestArgsort1D:
def test_1d_sorted(self):
data = cp.array([1, 2, 3, 4, 5])
sort_idx, rev_idx = argsort_with_reverse(data)
cp.testing.assert_array_equal(sort_idx, cp.array([0, 1, 2, 3, 4]))
cp.testing.assert_array_equal(data[sort_idx], data)
cp.testing.assert_array_equal(rev_idx[sort_idx], cp.arange(5))

def test_1d_reverse_sorted(self):
data = cp.array([5, 4, 3, 2, 1])
sort_idx, rev_idx = argsort_with_reverse(data)
cp.testing.assert_array_equal(sort_idx, cp.array([4, 3, 2, 1, 0]))
cp.testing.assert_array_equal(data[sort_idx], cp.array([1, 2, 3, 4, 5]))
cp.testing.assert_array_equal(rev_idx[sort_idx], cp.arange(5))

def test_1d_random(self):
data = cp.array([10, 3, 7, 1, 9])
sort_idx, rev_idx = argsort_with_reverse(data)
sorted_data = data[sort_idx]
assert cp.all(sorted_data[:-1] <= sorted_data[1:])
cp.testing.assert_array_equal(rev_idx[sort_idx], cp.arange(5))

def test_1d_duplicates(self):
data = cp.array([2, 1, 2, 1])
sort_idx, rev_idx = argsort_with_reverse(data)
sorted_data = data[sort_idx]
assert cp.all(sorted_data[:-1] <= sorted_data[1:])
cp.testing.assert_array_equal(rev_idx[sort_idx], cp.arange(4))

def test_1d_single_element(self):
data = cp.array([42])
sort_idx, rev_idx = argsort_with_reverse(data)
cp.testing.assert_array_equal(sort_idx, cp.array([0]))
cp.testing.assert_array_equal(rev_idx[sort_idx], cp.array([0]))

def test_1d_negative_axis(self):
data = cp.array([3, 1, 2])
sort_idx, rev_idx = argsort_with_reverse(data, axis=-1)
cp.testing.assert_array_equal(sort_idx, cp.argsort(data))
cp.testing.assert_array_equal(rev_idx[sort_idx], cp.arange(3))


class TestArgsort2D:
def test_2d_axis0_sorted(self):
data = cp.array([[1, 2], [3, 4], [5, 6]])
sort_idx, rev_idx = argsort_with_reverse(data, axis=0)
sorted_data = data[sort_idx, cp.arange(data.shape[1])[None, :]]
assert cp.all(sorted_data[:-1, :] <= sorted_data[1:, :])
rows, cols = data.shape
col_indices = cp.arange(cols)[None, :]
cp.testing.assert_array_equal(
rev_idx[sort_idx, col_indices], cp.tile(cp.arange(rows)[:, None], [1, 2])
)

def test_2d_axis0_random(self):
data = cp.array([[5, 1], [2, 4], [3, 6]])
sort_idx, rev_idx = argsort_with_reverse(data, axis=0)
rows, cols = data.shape
col_indices = cp.arange(cols)[None, :]
cp.testing.assert_array_equal(
rev_idx[sort_idx, col_indices], cp.tile(cp.arange(rows)[:, None], [1, 2])
)

def test_2d_axis0_negative_axis(self):
data = cp.array([[1, 2], [3, 4]])
sort_idx, rev_idx = argsort_with_reverse(data, axis=-2)
rows, cols = data.shape
col_indices = cp.arange(cols)[None, :]
cp.testing.assert_array_equal(
rev_idx[sort_idx, col_indices], cp.tile(cp.arange(rows)[:, None], [1, 2])
)

def test_2d_axis1_sorted(self):
data = cp.array([[1, 2, 3], [4, 5, 6]])
sort_idx, rev_idx = argsort_with_reverse(data, axis=1)
rows, cols = data.shape
row_indices = cp.arange(rows)[:, None]
cp.testing.assert_array_equal(
rev_idx[row_indices, sort_idx], cp.tile(cp.arange(cols), [2, 1])
)

def test_2d_axis1_random(self):
data = cp.array([[3, 1, 2], [6, 4, 5]])
sort_idx, rev_idx = argsort_with_reverse(data, axis=1)
rows, cols = data.shape
row_indices = cp.arange(rows)[:, None]
cp.testing.assert_array_equal(
rev_idx[row_indices, sort_idx], cp.tile(cp.arange(cols), [2, 1])
)

def test_2d_axis1_negative_axis(self):
data = cp.array([[1, 2], [3, 4]])
sort_idx, rev_idx = argsort_with_reverse(data, axis=-1)
rows, cols = data.shape
row_indices = cp.arange(rows)[:, None]
cp.testing.assert_array_equal(
rev_idx[row_indices, sort_idx], cp.tile(cp.arange(cols), [2, 1])
)

def test_2d_single_row(self):
data = cp.array([[3, 1, 2]])
sort_idx, rev_idx = argsort_with_reverse(data, axis=1)
rows, cols = data.shape
row_indices = cp.arange(rows)[:, None]
cp.testing.assert_array_equal(
rev_idx[row_indices, sort_idx], cp.arange(cols)[None, :]
)

def test_2d_single_col(self):
data = cp.array([[3], [1], [2]])
sort_idx, rev_idx = argsort_with_reverse(data, axis=0)
rows, cols = data.shape
col_indices = cp.arange(cols)[None, :]
cp.testing.assert_array_equal(
rev_idx[sort_idx, col_indices], cp.arange(rows)[:, None]
)


class TestArgsortErrors:
def test_invalid_dim_3d(self):
data = cp.ones((2, 2, 2))
with pytest.raises(ValueError):
argsort_with_reverse(data)

def test_invalid_dim_0d(self):
data = cp.array(1)
with pytest.raises(ValueError):
argsort_with_reverse(data)

def test_invalid_axis_1d(self):
data = cp.array([1, 2, 3])
with pytest.raises(ValueError):
argsort_with_reverse(data, axis=1)

def test_invalid_axis_2d(self):
data = cp.array([[1, 2], [3, 4]])
with pytest.raises(ValueError):
argsort_with_reverse(data, axis=2)
Loading
Loading