def argsort_with_reverse(
    data: cp.ndarray, axis: int = -1
) -> Tuple[cp.ndarray, cp.ndarray]:
    """
    Compute the indices that sort a 1D or 2D array along ``axis``, together
    with the inverse permutation that restores the original order.

    The inverse is filled with one scatter assignment
    (``reverse[sort_indices] = 0..n-1``) rather than by running a second
    ``argsort`` over the sort indices.

    Parameters
    ----------
    data : cp.ndarray
        1D or 2D array to sort.
    axis : int, optional
        Axis to sort along. Negative values count from the last axis.
        Defaults to -1.

    Returns
    -------
    Tuple[cp.ndarray, cp.ndarray]
        ``(sort_indices, reverse_sort_indices)``, both with the shape of
        ``data``. Gathering ``data`` with ``sort_indices`` along ``axis``
        yields the sorted array; gathering that sorted array with
        ``reverse_sort_indices`` along the same axis yields ``data`` again.

    Raises
    ------
    ValueError
        If ``data`` is not 1D or 2D, or if ``axis`` is outside
        ``[-data.ndim, data.ndim)``.
    """
    dim = data.ndim
    if dim not in (1, 2):
        raise ValueError("only 1D and 2D arrays are supported")
    # Validate both bounds *before* normalising. Checking only the upper bound
    # after adding ``dim`` lets e.g. axis=-4 on a 2D array through, where the
    # sort would run along axis 0 but the inverse would be built for axis 1.
    if not -dim <= axis < dim:
        raise ValueError("invalid axis")
    axis %= dim

    sort_indices = cp.argsort(data, axis=axis)
    reverse_sort_indices = cp.empty_like(sort_indices)
    if dim == 1:
        reverse_sort_indices[sort_indices] = cp.arange(data.size)
    elif axis == 0:
        # Sorting along axis 0: every column carries its own permutation of
        # the row positions, so scatter the row numbers column by column.
        nrows, ncols = data.shape
        cols = cp.arange(ncols)[None, :]
        reverse_sort_indices[sort_indices, cols] = cp.arange(nrows)[:, None]
    else:
        # Sorting along axis 1: every row carries its own permutation of the
        # column positions.
        nrows, ncols = data.shape
        rows = cp.arange(nrows)[:, None]
        reverse_sort_indices[rows, sort_indices] = cp.arange(ncols)

    return sort_indices, reverse_sort_indices
""" - sinogram = cp.transpose(sinogram) + #: Sort each row of the sinogram by its grayscale values + sortvals, sortvals_reverse = argsort_with_reverse(sinogram, axis=0) - #: Sort each column of the sinogram by its grayscale values - #: Keep track of the sorting indices so we can reverse it below - sortvals = cp.argsort(sinogram, axis=1) - sortvals_reverse = cp.argsort(sortvals, axis=1) - sino_sort = cp.take_along_axis(sinogram, sortvals, axis=1) + sino_sort = cp.take_along_axis(sinogram, sortvals, axis=0) #: Now apply the median filter on the sorted image along each row - sino_sort = median_filter(sino_sort, (size, 1) if dim == 1 else (size, size)) + sino_sort = median_filter(sino_sort, (1, size) if dim == 1 else (size, size)) #: step 3: re-sort the smoothed image columns to the original rows - sino_corrected = cp.take_along_axis(sino_sort, sortvals_reverse, axis=1) + sino_corrected = cp.take_along_axis(sino_sort, sortvals_reverse, axis=0) - return cp.transpose(sino_corrected) + return sino_corrected def remove_stripe_ti( @@ -821,6 +819,7 @@ def remove_all_stripe( la_size: int = 61, sm_size: int = 21, dim: Literal[1, 2] = 1, + normalize: bool = False, ) -> cp.ndarray: """ Remove all types of stripe artifacts from sinogram using Nghia Vo's @@ -839,6 +838,8 @@ def remove_all_stripe( Window size of the median filter to remove small-to-medium stripes. dim : {1, 2}, Dimension of the window. 
+ normalize : bool, + Controls whether to normalize while removing large stripes Returns ------- @@ -859,10 +860,10 @@ def remove_all_stripe( __check_variable_type(dim, [int], "dim", [1, 2], methods_name) ################################### - matindex = _create_matindex(data.shape[2], data.shape[0]) for m in range(data.shape[1]): sino = data[:, m, :] - sino = _rs_dead(sino, snr, la_size, matindex) + sino = _rs_dead(sino, snr, la_size) + sino = _rs_large(sino, snr, la_size, normalize) sino = _rs_sort(sino, sm_size, dim) sino = cp.nan_to_num(sino) data[:, m, :] = sino @@ -910,15 +911,17 @@ def _detect_stripe(listdata, snr): return listmask -def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True): +def _rs_large(sinogram, snr, size, normalize, drop_ratio=0.1): """ Remove large stripes. """ drop_ratio = max(min(drop_ratio, 0.8), 0) # = cp.clip(drop_ratio, 0.0, 0.8) nrow, ncol = sinogram.shape ndrop = int(0.5 * drop_ratio * nrow) - sinosort = cp.sort(sinogram, axis=0) + sort_indices, sort_indices_reverse = argsort_with_reverse(sinogram, axis=0) + sinosort = cp.take_along_axis(sinogram, sort_indices, axis=0) sinosmooth = median_filter(sinosort, (1, size)) + list1 = cp.mean(sinosort[ndrop : nrow - ndrop], axis=0) list2 = cp.mean(sinosmooth[ndrop : nrow - ndrop], axis=0) listfact = list1 / list2 @@ -927,30 +930,16 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True): listmask = _detect_stripe(listfact, snr) listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype) matfact = cp.tile(listfact, (nrow, 1)) - # Normalize - if norm is True: + if normalize: sinogram = sinogram / matfact - sinogram1 = cp.transpose(sinogram) - matcombine = cp.asarray(cp.dstack((matindex, sinogram1))) - - ids = cp.argsort(matcombine[:, :, 1], axis=1) - matsort = matcombine.copy() - matsort[:, :, 0] = cp.take_along_axis(matsort[:, :, 0], ids, axis=1) - matsort[:, :, 1] = cp.take_along_axis(matsort[:, :, 1], ids, axis=1) - - matsort[:, :, 1] = 
def _is_non_decreasing(arr):
    """Return True when ``arr`` never decreases along its first axis."""
    return bool(cp.all(arr[:-1] <= arr[1:]))


def _assert_inverse_1d(order, inverse):
    """Check that ``inverse`` maps every sorted position back to its rank."""
    cp.testing.assert_array_equal(inverse[order], cp.arange(order.size))


def _assert_inverse_axis0(data, order, inverse):
    """Check the inverse permutation of a sort along axis 0 of a 2D array."""
    nrows, ncols = data.shape
    col_ids = cp.arange(ncols)[None, :]
    expected = cp.tile(cp.arange(nrows)[:, None], (1, ncols))
    cp.testing.assert_array_equal(inverse[order, col_ids], expected)


def _assert_inverse_axis1(data, order, inverse):
    """Check the inverse permutation of a sort along axis 1 of a 2D array."""
    nrows, ncols = data.shape
    row_ids = cp.arange(nrows)[:, None]
    expected = cp.tile(cp.arange(ncols), (nrows, 1))
    cp.testing.assert_array_equal(inverse[row_ids, order], expected)


class TestArgsort1D:
    """``argsort_with_reverse`` on one-dimensional input."""

    def test_1d_sorted(self):
        values = cp.array([1, 2, 3, 4, 5])
        order, inverse = argsort_with_reverse(values)
        # Already ordered input: the sort permutation is the identity.
        cp.testing.assert_array_equal(order, cp.array([0, 1, 2, 3, 4]))
        cp.testing.assert_array_equal(values[order], values)
        _assert_inverse_1d(order, inverse)

    def test_1d_reverse_sorted(self):
        values = cp.array([5, 4, 3, 2, 1])
        order, inverse = argsort_with_reverse(values)
        cp.testing.assert_array_equal(order, cp.array([4, 3, 2, 1, 0]))
        cp.testing.assert_array_equal(values[order], cp.array([1, 2, 3, 4, 5]))
        _assert_inverse_1d(order, inverse)

    def test_1d_random(self):
        values = cp.array([10, 3, 7, 1, 9])
        order, inverse = argsort_with_reverse(values)
        assert _is_non_decreasing(values[order])
        _assert_inverse_1d(order, inverse)

    def test_1d_duplicates(self):
        values = cp.array([2, 1, 2, 1])
        order, inverse = argsort_with_reverse(values)
        assert _is_non_decreasing(values[order])
        _assert_inverse_1d(order, inverse)

    def test_1d_single_element(self):
        values = cp.array([42])
        order, inverse = argsort_with_reverse(values)
        cp.testing.assert_array_equal(order, cp.array([0]))
        _assert_inverse_1d(order, inverse)

    def test_1d_negative_axis(self):
        values = cp.array([3, 1, 2])
        order, inverse = argsort_with_reverse(values, axis=-1)
        cp.testing.assert_array_equal(order, cp.argsort(values))
        _assert_inverse_1d(order, inverse)


class TestArgsort2D:
    """``argsort_with_reverse`` on two-dimensional input, both axes."""

    def test_2d_axis0_sorted(self):
        grid = cp.array([[1, 2], [3, 4], [5, 6]])
        order, inverse = argsort_with_reverse(grid, axis=0)
        col_ids = cp.arange(grid.shape[1])[None, :]
        assert _is_non_decreasing(grid[order, col_ids])
        _assert_inverse_axis0(grid, order, inverse)

    def test_2d_axis0_random(self):
        grid = cp.array([[5, 1], [2, 4], [3, 6]])
        order, inverse = argsort_with_reverse(grid, axis=0)
        _assert_inverse_axis0(grid, order, inverse)

    def test_2d_axis0_negative_axis(self):
        grid = cp.array([[1, 2], [3, 4]])
        order, inverse = argsort_with_reverse(grid, axis=-2)
        _assert_inverse_axis0(grid, order, inverse)

    def test_2d_axis1_sorted(self):
        grid = cp.array([[1, 2, 3], [4, 5, 6]])
        order, inverse = argsort_with_reverse(grid, axis=1)
        _assert_inverse_axis1(grid, order, inverse)

    def test_2d_axis1_random(self):
        grid = cp.array([[3, 1, 2], [6, 4, 5]])
        order, inverse = argsort_with_reverse(grid, axis=1)
        _assert_inverse_axis1(grid, order, inverse)

    def test_2d_axis1_negative_axis(self):
        grid = cp.array([[1, 2], [3, 4]])
        order, inverse = argsort_with_reverse(grid, axis=-1)
        _assert_inverse_axis1(grid, order, inverse)

    def test_2d_single_row(self):
        grid = cp.array([[3, 1, 2]])
        order, inverse = argsort_with_reverse(grid, axis=1)
        _assert_inverse_axis1(grid, order, inverse)

    def test_2d_single_col(self):
        grid = cp.array([[3], [1], [2]])
        order, inverse = argsort_with_reverse(grid, axis=0)
        _assert_inverse_axis0(grid, order, inverse)


class TestArgsortErrors:
    """Inputs that ``argsort_with_reverse`` must reject with ``ValueError``."""

    def test_invalid_dim_3d(self):
        cube = cp.ones((2, 2, 2))
        with pytest.raises(ValueError):
            argsort_with_reverse(cube)

    def test_invalid_dim_0d(self):
        scalar = cp.array(1)
        with pytest.raises(ValueError):
            argsort_with_reverse(scalar)

    def test_invalid_axis_1d(self):
        values = cp.array([1, 2, 3])
        with pytest.raises(ValueError):
            argsort_with_reverse(values, axis=1)

    def test_invalid_axis_2d(self):
        grid = cp.array([[1, 2], [3, 4]])
        with pytest.raises(ValueError):
            argsort_with_reverse(grid, axis=2)
def test_remove_all_stripe_quality(ensure_clean_memory):
    """
    Stripe removal must lower the reconstruction error on a synthetic phantom.

    A tomophantom model is projected analytically, corrupted with Poisson
    noise and mixed stripes, and reconstructed with FBP twice: once as-is and
    once after ``remove_all_stripe`` (default arguments). The RMSE against the
    phantom has to be strictly smaller for the stripe-removed data.
    """
    phantom_model = 13  # model number from the tomophantom library
    size = 256  # cubic phantom edge; also used for both detector dimensions
    library_file = os.path.join(
        os.path.dirname(tomophantom.__file__), "phantomlib", "Phantom3DLibrary.dat"
    )
    ground_truth = tomophantom.TomoP3D.Model(phantom_model, size, library_file)

    # Projection geometry
    n_angles = int(0.5 * np.pi * size)
    angles_deg = np.linspace(0.0, 179.9, n_angles, dtype="float32")
    angles_rad = angles_deg * (np.pi / 180.0)
    clean_projections = tomophantom.TomoP3D.ModelSino(
        phantom_model,
        size,
        size,  # detector column count (horizontal)
        size,  # detector row count (vertical)
        angles_deg,
        library_file,
    )

    # Noise and stripe artefacts; the fixed seed keeps the noise reproducible.
    noisy_projections = tomophantom.artefacts.artefacts_mix(
        clean_projections,
        noise_type="Poisson",
        noise_amplitude=10000,
        noise_seed=0,
        stripes_percentage=1.2,
        stripes_maxthickness=3,
        stripes_intensity=0.25,
        stripes_type="mix",
        stripes_variability=0.005,
    )
    # NOTE(review): the axis swap presumably converts tomophantom's layout to
    # the one the httomolibgpu functions expect -- confirm.
    noisy_projections = cp.asarray(noisy_projections).swapaxes(0, 1)

    def reconstruction_rmse(projections):
        """FBP-reconstruct ``projections`` and return the RMSE to the phantom."""
        volume = FBP3d_tomobar(projections, angles_rad)
        return QualityTools(ground_truth, volume.get()).rmse()

    # The noisy baseline is evaluated before stripe removal is called.
    rmse_noisy = reconstruction_rmse(noisy_projections)
    rmse_stripe_removed = reconstruction_rmse(remove_all_stripe(noisy_projections))

    assert rmse_noisy > rmse_stripe_removed