Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.11', '3.12', '3.13']
python-version: ['3.11']

steps:
- uses: actions/checkout@v4
Expand Down
8 changes: 4 additions & 4 deletions shine/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def render_one_galaxy(
ny=img_cfg.size_y,
scale=img_cfg.pixel_scale,
offset=(
x - img_cfg.size_x / 2 + 0.5,
y - img_cfg.size_y / 2 + 0.5,
x - img_cfg.size_x / 2,
y - img_cfg.size_y / 2,
),
).array

Expand Down Expand Up @@ -262,8 +262,8 @@ def render_single_scene(
ny=img_cfg.size_y,
scale=img_cfg.pixel_scale,
offset=(
x - img_cfg.size_x / 2 + 0.5,
y - img_cfg.size_y / 2 + 0.5,
x - img_cfg.size_x / 2,
y - img_cfg.size_y / 2,
),
).array

Expand Down
55 changes: 52 additions & 3 deletions tests/test_validation/test_batched_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@
import numpy as np
import pytest
import yaml
import jax

from shine.config import ConfigHandler
from shine.config import (
ConfigHandler,
MAPConfig,
)
from shine.data import Observation
from shine.validation.extraction import split_batched_idata
from shine.validation.simulation import (
BatchSimulationResult,
generate_batch_observations,
)

from shine.inference import Inference
from shine.scene import SceneBuilder

@pytest.fixture
def level0_config(tmp_path):
Expand Down Expand Up @@ -86,6 +91,12 @@ def multi_object_config(tmp_path):

return ConfigHandler.load(str(config_path))

@pytest.fixture
def enable_x64_temporarily():
initial_state = jax.config.x64_enabled
jax.config.update("jax_enable_x64", True)
yield
jax.config.update("jax_enable_x64", initial_state)

def _make_batched_mock_idata(
n_batch=3,
Expand Down Expand Up @@ -200,7 +211,45 @@ def test_batched_model_runs(self, level0_config):
assert tr["g1"]["value"].shape == (2,)
assert tr["g2"]["value"].shape == (2,)


def test_centers_match(self, level0_config, enable_x64_temporarily):
"""Verify that the inferred MAP center matches the exact observation center"""
test_config = level0_config.model_copy()
test_config.image.noise.sigma = 1e-6
test_config.inference.method = "map"
test_config.inference.map_config = MAPConfig(num_steps=150, learning_rate=0.1)

batch_sim = generate_batch_observations(
test_config,
shear_pairs=[(0.01, 0.0)],
seeds=[42],
)

builder = SceneBuilder(test_config)
model_fn = builder.build_batched_model(n_batch=1)

engine = Inference(model=model_fn, config=test_config.inference)
rng_key = jax.random.PRNGKey(42)

batched_estimates = engine.run_map(
rng_key=rng_key,
observed_data=batch_sim.images,
extra_args={"psf": batch_sim.psf_model},
map_config=test_config.inference.map_config,
)
x_inferred = float(np.array(batched_estimates["x"])[0])
y_inferred = float(np.array(batched_estimates["y"])[0])

fallback_x = test_config.image.size_x / 2.0
fallback_y = test_config.image.size_y / 2.0
expected_center_x = batch_sim.ground_truths[0].get("x", fallback_x)
expected_center_y = batch_sim.ground_truths[0].get("y", fallback_y)

assert x_inferred == pytest.approx(expected_center_x, abs=1e-2), (
f"X-axis alignment error. Expected X-coordinate: {expected_center_x}, Inferred X-coordinate: {x_inferred}. "
)
assert y_inferred == pytest.approx(expected_center_y, abs=1e-2), (
f"Y-axis alignment error. Expected Y-coordinate: {expected_center_y}, Inferred Y-coordinate: {y_inferred}. "
)
class TestSplitBatchedIdata:
"""Tests for split_batched_idata()."""

Expand Down
Loading