diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a639c56..bdd39fa 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.11', '3.12', '3.13'] + python-version: ['3.11'] steps: - uses: actions/checkout@v4 diff --git a/shine/scene.py b/shine/scene.py index 2decbe4..ee0dc82 100644 --- a/shine/scene.py +++ b/shine/scene.py @@ -159,8 +159,8 @@ def render_one_galaxy( ny=img_cfg.size_y, scale=img_cfg.pixel_scale, offset=( - x - img_cfg.size_x / 2 + 0.5, - y - img_cfg.size_y / 2 + 0.5, + x - img_cfg.size_x / 2, + y - img_cfg.size_y / 2, ), ).array @@ -262,8 +262,8 @@ def render_single_scene( ny=img_cfg.size_y, scale=img_cfg.pixel_scale, offset=( - x - img_cfg.size_x / 2 + 0.5, - y - img_cfg.size_y / 2 + 0.5, + x - img_cfg.size_x / 2, + y - img_cfg.size_y / 2, ), ).array diff --git a/tests/test_validation/test_batched_inference.py b/tests/test_validation/test_batched_inference.py index f82c8be..dc20d1a 100644 --- a/tests/test_validation/test_batched_inference.py +++ b/tests/test_validation/test_batched_inference.py @@ -4,15 +4,20 @@ import numpy as np import pytest import yaml +import jax -from shine.config import ConfigHandler +from shine.config import ( + ConfigHandler, + MAPConfig, +) from shine.data import Observation from shine.validation.extraction import split_batched_idata from shine.validation.simulation import ( BatchSimulationResult, generate_batch_observations, ) - +from shine.inference import Inference +from shine.scene import SceneBuilder @pytest.fixture def level0_config(tmp_path): @@ -86,6 +91,12 @@ def multi_object_config(tmp_path): return ConfigHandler.load(str(config_path)) +@pytest.fixture +def enable_x64_temporarily(): + initial_state = jax.config.x64_enabled + jax.config.update("jax_enable_x64", True) + yield + jax.config.update("jax_enable_x64", initial_state) def _make_batched_mock_idata( n_batch=3, @@ -200,7 +211,45 @@ def test_batched_model_runs(self, level0_config): assert tr["g1"]["value"].shape == (2,) assert tr["g2"]["value"].shape == (2,) - + def test_centers_match(self, level0_config, enable_x64_temporarily): + """Verify that the inferred MAP center matches the exact observation center""" + test_config = level0_config.model_copy() + test_config.image.noise.sigma = 1e-6 + test_config.inference.method = "map" + test_config.inference.map_config = MAPConfig(num_steps=150, learning_rate=0.1) + + batch_sim = generate_batch_observations( + test_config, + shear_pairs=[(0.01, 0.0)], + seeds=[42], + ) + + builder = SceneBuilder(test_config) + model_fn = builder.build_batched_model(n_batch=1) + + engine = Inference(model=model_fn, config=test_config.inference) + rng_key = jax.random.PRNGKey(42) + + batched_estimates = engine.run_map( + rng_key=rng_key, + observed_data=batch_sim.images, + extra_args={"psf": batch_sim.psf_model}, + map_config=test_config.inference.map_config, + ) + x_inferred = float(np.array(batched_estimates["x"])[0]) + y_inferred = float(np.array(batched_estimates["y"])[0]) + + fallback_x = test_config.image.size_x / 2.0 + fallback_y = test_config.image.size_y / 2.0 + expected_center_x = batch_sim.ground_truths[0].get("x", fallback_x) + expected_center_y = batch_sim.ground_truths[0].get("y", fallback_y) + + assert x_inferred == pytest.approx(expected_center_x, abs=1e-2), ( + f"X-axis alignment error. Expected X-coordinate: {expected_center_x}, Inferred X-coordinate: {x_inferred}. " + ) + assert y_inferred == pytest.approx(expected_center_y, abs=1e-2), ( + f"Y-axis alignment error. Expected Y-coordinate: {expected_center_y}, Inferred Y-coordinate: {y_inferred}. " + ) class TestSplitBatchedIdata: """Tests for split_batched_idata()."""