From 05649dc18004619e3c28f99d6c9fd4950796dd2f Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 5 Feb 2026 20:45:36 +0300 Subject: [PATCH 1/5] Fix household-level filtering to preserve household integrity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add generic _filter_simulation_by_household_variable() method that: - Validates the variable is household-level via tax-benefit system - Filters at household level by selecting household IDs first - Keeps ALL persons in matching households (subsample pattern) - Handles bytes encoding for HDF5 string values - Refactor _filter_us_simulation_by_place() to use the new method This follows the same pattern as policyengine-core's subsample() method to ensure household integrity is preserved during filtering. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- policyengine/simulation.py | 82 +++++++++++++++++++++++++++++++++----- 1 file changed, 73 insertions(+), 9 deletions(-) diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 864aa90a..67c876d5 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -388,6 +388,72 @@ def _apply_us_region_to_simulation( ) return simulation + def _filter_simulation_by_household_variable( + self, + simulation: CountryMicrosimulation, + simulation_type: type, + variable_name: str, + variable_value: Any, + reform: ReformType | None, + ) -> CountrySimulation: + """Filter a simulation to only include households where a variable matches a value. + + Uses household-level filtering to preserve household integrity - all persons + in matching households are kept together. + + Args: + simulation: The microsimulation to filter. + simulation_type: The type of simulation to create (e.g., Microsimulation). + variable_name: The name of the variable to filter on. Must be a + household-level variable. + variable_value: The value to match. For string variables that may be + stored as bytes in HDF5, both str and bytes versions are checked. + reform: Optional reform to apply to the filtered simulation. + + Returns: + A new simulation containing only households where the variable matches. + + Raises: + ValueError: If the variable is not a household-level variable. + """ + # Validate that the variable is household-level + tbs = simulation.tax_benefit_system + if variable_name not in tbs.variables: + raise ValueError(f"Variable '{variable_name}' not found in tax-benefit system") + + variable = tbs.variables[variable_name] + if variable.entity.key != "household": + raise ValueError( + f"Variable '{variable_name}' is a {variable.entity.key}-level variable, " + f"not a household-level variable. Only household-level variables can be " + f"used for filtering to preserve household integrity." + ) + + df = simulation.to_input_dataframe() + + # Find the household_id column + hh_id_cols = [c for c in df.columns if c.startswith("household_id__")] + if not hh_id_cols: + raise ValueError("Could not find household_id column in dataframe") + hh_id_col = hh_id_cols[0] + + # Get variable values at person level (since df is person-level) + values = simulation.calculate(variable_name, map_to="person").values + + # Create mask, handling potential bytes encoding for string values + if isinstance(variable_value, str): + mask = (values == variable_value) | (values == variable_value.encode()) + else: + mask = values == variable_value + + # Get household IDs where any person matches + matching_hh_ids = df.loc[mask, hh_id_col].unique() + + # Keep ALL persons in matching households + subset_df = df[df[hh_id_col].isin(matching_hh_ids)] + + return simulation_type(dataset=subset_df, reform=reform) + def _filter_us_simulation_by_place( self, simulation: CountryMicrosimulation, @@ -409,16 +475,14 @@ def _filter_us_simulation_by_place( from policyengine.utils.data.datasets import parse_us_place_region _, place_fips_code = parse_us_place_region(region) - df = simulation.to_input_dataframe() - # Get place_fips at person level since to_input_dataframe() is person-level - person_place_fips = simulation.calculate( - "place_fips", map_to="person" - ).values - # place_fips may be stored as bytes in HDF5; handle both str and bytes - mask = (person_place_fips == place_fips_code) | ( - person_place_fips == place_fips_code.encode() + + return self._filter_simulation_by_household_variable( + simulation=simulation, + simulation_type=simulation_type, + variable_name="place_fips", + variable_value=place_fips_code, + reform=reform, ) - return simulation_type(dataset=df[mask], reform=reform) def check_model_version(self) -> None: """ From 9215c1f86e485276e9072969640baf2defb78f3d Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 5 Feb 2026 22:01:26 +0300 Subject: [PATCH 2/5] Refactor filtering to use entity relationship approach MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add _build_entity_relationships() method that creates an explicit mapping between persons and all entity types (household, tax_unit, spm_unit, family, marital_unit) - Refactor _filter_simulation_by_household_variable() to use entity_rel for cleaner, more explicit filtering logic - Filter at household level, then use entity_rel to find all persons in matching households This approach makes the entity relationships explicit and enables future extensions for filtering at other entity levels. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- policyengine/simulation.py | 73 ++++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 19 deletions(-) diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 67c876d5..e1243aea 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -388,6 +388,41 @@ def _apply_us_region_to_simulation( ) return simulation + def _build_entity_relationships( + self, + simulation: CountryMicrosimulation, + ) -> pd.DataFrame: + """Build a DataFrame mapping each person to their containing entities. + + Creates an explicit relationship map between persons and all entity + types (household, tax_unit, etc.). This enables filtering at any + entity level while preserving the integrity of all related entities. + + Args: + simulation: The microsimulation to extract relationships from. + + Returns: + A DataFrame indexed by person with columns for each entity ID. + """ + entity_rel = pd.DataFrame({"person_id": simulation.calculate("person_id").values}) + + # Add household relationship (required for all countries) + entity_rel["household_id"] = simulation.calculate( + "household_id", map_to="person" + ).values + + # Add country-specific entity relationships + tbs = simulation.tax_benefit_system + optional_entities = ["tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"] + + for entity_id in optional_entities: + if entity_id in tbs.variables: + entity_rel[entity_id] = simulation.calculate( + entity_id, map_to="person" + ).values + + return entity_rel + def _filter_simulation_by_household_variable( self, simulation: CountryMicrosimulation, @@ -398,8 +433,9 @@ def _filter_simulation_by_household_variable( ) -> CountrySimulation: """Filter a simulation to only include households where a variable matches a value. - Uses household-level filtering to preserve household integrity - all persons - in matching households are kept together. + Uses the entity relationship approach: builds an explicit map of all + entity relationships, filters at the household level, and keeps all + persons in matching households to preserve entity integrity. Args: simulation: The microsimulation to filter. @@ -416,7 +452,6 @@ def _filter_simulation_by_household_variable( Raises: ValueError: If the variable is not a household-level variable. """ - # Validate that the variable is household-level tbs = simulation.tax_benefit_system if variable_name not in tbs.variables: raise ValueError(f"Variable '{variable_name}' not found in tax-benefit system") @@ -429,28 +464,28 @@ def _filter_simulation_by_household_variable( f"used for filtering to preserve household integrity." ) - df = simulation.to_input_dataframe() - - # Find the household_id column - hh_id_cols = [c for c in df.columns if c.startswith("household_id__")] - if not hh_id_cols: - raise ValueError("Could not find household_id column in dataframe") - hh_id_col = hh_id_cols[0] + # Build entity relationships + entity_rel = self._build_entity_relationships(simulation) - # Get variable values at person level (since df is person-level) - values = simulation.calculate(variable_name, map_to="person").values + # Get household-level variable values + hh_values = simulation.calculate(variable_name).values + hh_ids = simulation.calculate("household_id").values - # Create mask, handling potential bytes encoding for string values + # Create mask for matching households, handling bytes encoding if isinstance(variable_value, str): - mask = (values == variable_value) | (values == variable_value.encode()) + hh_mask = (hh_values == variable_value) | (hh_values == variable_value.encode()) else: - mask = values == variable_value + hh_mask = hh_values == variable_value - # Get household IDs where any person matches - matching_hh_ids = df.loc[mask, hh_id_col].unique() + matching_hh_ids = set(hh_ids[hh_mask]) - # Keep ALL persons in matching households - subset_df = df[df[hh_id_col].isin(matching_hh_ids)] + # Filter entity_rel to persons in matching households + person_mask = entity_rel["household_id"].isin(matching_hh_ids) + filtered_entity_rel = entity_rel[person_mask] + + # Filter the input DataFrame using the filtered person indices + df = simulation.to_input_dataframe() + subset_df = df.iloc[filtered_entity_rel.index] return simulation_type(dataset=subset_df, reform=reform) From ce81e710364349c5f0e1f52dec72739250d9e923 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 5 Feb 2026 22:06:53 +0300 Subject: [PATCH 3/5] Update tests for entity_rel filtering approach MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Update fixtures to support entity_rel approach with: - create_mock_tax_benefit_system() for variable entity validation - Enhanced mock simulations that handle multiple calculate() calls - Support for persons_per_household parameter - Add TestBuildEntityRelationships test class - Add TestFilterSimulationByHouseholdVariable test class with validation tests - Update existing tests to work with person-level DataFrame output - Add test for multi-person household entity preservation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/country/test_us_places.py | 158 ++++++++++++++++++++++++- tests/fixtures/country/us_places.py | 171 +++++++++++++++++++++++----- 2 files changed, 298 insertions(+), 31 deletions(-) diff --git a/tests/country/test_us_places.py b/tests/country/test_us_places.py index bfdda841..bd6101a1 100644 --- a/tests/country/test_us_places.py +++ b/tests/country/test_us_places.py @@ -1,7 +1,14 @@ -"""Tests for US place-level (city) filtering functionality.""" +"""Tests for US place-level (city) filtering functionality. + +Tests the entity_rel filtering approach which: +1. Builds explicit entity relationships (person -> household, tax_unit, etc.) +2. Filters at household level to preserve entity integrity +3. Creates new simulations from filtered DataFrames +""" import pytest import pandas as pd +import numpy as np from unittest.mock import Mock, patch from tests.fixtures.country.us_places import ( @@ -42,11 +49,50 @@ create_mock_simulation_with_bytes_place_fips, create_mock_simulation_type, create_simulation_instance, + create_mock_tax_benefit_system, ) +class TestBuildEntityRelationships: + """Tests for the _build_entity_relationships method.""" + + def test__given__simulation__then__returns_dataframe_with_person_and_household_ids( + self, + ): + # Given + mock_sim = create_mock_simulation_with_place_fips( + MIXED_PLACES_WITH_PATERSON, persons_per_household=2 + ) + + # When + sim_instance = create_simulation_instance() + entity_rel = sim_instance._build_entity_relationships(mock_sim) + + # Then + assert "person_id" in entity_rel.columns + assert "household_id" in entity_rel.columns + # 5 households * 2 persons each = 10 persons + assert len(entity_rel) == 10 + + def test__given__simulation__then__includes_optional_entity_ids_when_available( + self, + ): + # Given + mock_sim = create_mock_simulation_with_place_fips( + MIXED_PLACES_WITH_PATERSON + ) + + # When + sim_instance = create_simulation_instance() + entity_rel = sim_instance._build_entity_relationships(mock_sim) + + # Then: Optional US entity IDs should be present + assert "tax_unit_id" in entity_rel.columns + assert "spm_unit_id" in entity_rel.columns + + class TestFilterUsSimulationByPlace: - """Tests for the _filter_us_simulation_by_place method.""" + """Tests for the _filter_us_simulation_by_place method using entity_rel approach.""" def test__given__households_in_target_place__then__filters_to_matching_households( self, @@ -70,8 +116,10 @@ def test__given__households_in_target_place__then__filters_to_matching_household call_args = mock_simulation_type.call_args filtered_df = call_args.kwargs["dataset"] + # With entity_rel, DataFrame is person-level assert len(filtered_df) == EXPECTED_PATERSON_COUNT_IN_MIXED - assert all(filtered_df["place_fips"] == NJ_PATERSON_FIPS) + # Verify all records belong to Paterson households + assert all(filtered_df["place_fips__2024"] == NJ_PATERSON_FIPS) def test__given__no_households_in_target_place__then__returns_empty_dataset( self, @@ -187,7 +235,109 @@ def test__given__different_place_in_same_state__then__filters_correctly( filtered_df = call_args.kwargs["dataset"] assert len(filtered_df) == EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ - assert all(filtered_df["place_fips"] == NJ_NEWARK_FIPS) + assert all(filtered_df["place_fips__2024"] == NJ_NEWARK_FIPS) + + def test__given__multi_person_households__then__preserves_all_persons( + self, + ): + # Given: 3 households with 2 persons each, only first household in Paterson + mock_sim = create_mock_simulation_with_place_fips( + [NJ_PATERSON_FIPS, NJ_NEWARK_FIPS, NJ_JERSEY_CITY_FIPS], + persons_per_household=2, + ) + mock_simulation_type = create_mock_simulation_type() + + # When + sim_instance = create_simulation_instance() + result = sim_instance._filter_us_simulation_by_place( + simulation=mock_sim, + simulation_type=mock_simulation_type, + region=NJ_PATERSON_REGION, + reform=None, + ) + + # Then: Should have 2 persons (both from the Paterson household) + call_args = mock_simulation_type.call_args + filtered_df = call_args.kwargs["dataset"] + + assert len(filtered_df) == 2 # 1 household * 2 persons + # All persons should be from household 0 + assert all(filtered_df["household_id__2024"] == 0) + + +class TestFilterSimulationByHouseholdVariable: + """Tests for _filter_simulation_by_household_variable validation and behavior.""" + + def test__given__non_household_variable__then__raises_value_error(self): + # Given: A mock with a person-level variable + mock_sim = Mock() + mock_tbs = Mock() + mock_var = Mock() + mock_var.entity = Mock() + mock_var.entity.key = "person" # Not household-level + mock_tbs.variables = {"age": mock_var} + mock_sim.tax_benefit_system = mock_tbs + + mock_simulation_type = create_mock_simulation_type() + + # When / Then + sim_instance = create_simulation_instance() + with pytest.raises(ValueError) as exc_info: + sim_instance._filter_simulation_by_household_variable( + simulation=mock_sim, + simulation_type=mock_simulation_type, + variable_name="age", + variable_value=30, + reform=None, + ) + + assert "person-level variable" in str(exc_info.value) + assert "household-level variable" in str(exc_info.value) + + def test__given__nonexistent_variable__then__raises_value_error(self): + # Given + mock_sim = Mock() + mock_tbs = Mock() + mock_tbs.variables = {} # Empty - no variables + mock_sim.tax_benefit_system = mock_tbs + + mock_simulation_type = create_mock_simulation_type() + + # When / Then + sim_instance = create_simulation_instance() + with pytest.raises(ValueError) as exc_info: + sim_instance._filter_simulation_by_household_variable( + simulation=mock_sim, + simulation_type=mock_simulation_type, + variable_name="nonexistent_var", + variable_value="test", + reform=None, + ) + + assert "not found" in str(exc_info.value) + + def test__given__household_variable__then__filters_successfully(self): + # Given + mock_sim = create_mock_simulation_with_place_fips( + MIXED_PLACES_WITH_PATERSON + ) + mock_simulation_type = create_mock_simulation_type() + + # When + sim_instance = create_simulation_instance() + result = sim_instance._filter_simulation_by_household_variable( + simulation=mock_sim, + simulation_type=mock_simulation_type, + variable_name="place_fips", + variable_value=NJ_PATERSON_FIPS, + reform=None, + ) + + # Then: Should create simulation with filtered data + assert mock_simulation_type.called + call_args = mock_simulation_type.call_args + filtered_df = call_args.kwargs["dataset"] + assert len(filtered_df) == EXPECTED_PATERSON_COUNT_IN_MIXED class TestApplyUsRegionToSimulationWithPlace: diff --git a/tests/fixtures/country/us_places.py b/tests/fixtures/country/us_places.py index ff65af60..f7ce250c 100644 --- a/tests/fixtures/country/us_places.py +++ b/tests/fixtures/country/us_places.py @@ -1,9 +1,15 @@ -"""Test fixtures for US place-level filtering tests.""" +"""Test fixtures for US place-level filtering tests. + +These fixtures support testing the entity_rel filtering approach which: +1. Builds explicit entity relationships (person -> household, tax_unit, etc.) +2. Filters at household level to preserve entity integrity +3. Creates new simulations from filtered DataFrames +""" import pytest import numpy as np import pandas as pd -from unittest.mock import Mock +from unittest.mock import Mock, MagicMock # ============================================================================= # Place FIPS Constants @@ -123,36 +129,113 @@ # ============================================================================= +def create_mock_tax_benefit_system(household_variables: list[str] | None = None) -> Mock: + """Create a mock tax benefit system with variable entity information. + + Args: + household_variables: List of variable names that are household-level. + Defaults to ["place_fips"]. + + Returns: + Mock TaxBenefitSystem with variables dict containing entity info. + """ + if household_variables is None: + household_variables = ["place_fips"] + + mock_tbs = Mock() + mock_tbs.variables = {} + + for var_name in household_variables: + mock_var = Mock() + mock_var.entity = Mock() + mock_var.entity.key = "household" + mock_tbs.variables[var_name] = mock_var + + # Add standard entity ID variables + for entity_id in ["person_id", "household_id", "tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"]: + mock_var = Mock() + mock_var.entity = Mock() + # Entity IDs belong to their respective entities + entity_name = entity_id.replace("_id", "") + mock_var.entity.key = entity_name if entity_name != "person" else "person" + mock_tbs.variables[entity_id] = mock_var + + return mock_tbs + + def create_mock_simulation_with_place_fips( place_fips_values: list[str], household_ids: list[int] | None = None, + persons_per_household: int = 1, ) -> Mock: - """Create a mock simulation with place_fips data. + """Create a mock simulation with place_fips data for entity_rel filtering. + + Supports the entity_rel approach by mocking: + - calculate() with variable-specific return values + - tax_benefit_system.variables for entity validation + - to_input_dataframe() returning person-level DataFrame Args: place_fips_values: List of place FIPS codes for each household. household_ids: Optional list of household IDs. + persons_per_household: Number of persons per household (default 1). Returns: - Mock simulation object with calculate() and to_input_dataframe() configured. + Mock simulation object configured for entity_rel filtering. """ if household_ids is None: household_ids = list(range(len(place_fips_values))) - mock_sim = Mock() + num_households = len(place_fips_values) + num_persons = num_households * persons_per_household - # Mock calculate to return place_fips values - mock_calculate_result = Mock() - mock_calculate_result.values = np.array(place_fips_values) - mock_sim.calculate.return_value = mock_calculate_result + # Create person-level data by repeating household data + person_ids = list(range(num_persons)) + person_household_ids = [] + person_place_fips = [] + for i, (hh_id, place) in enumerate(zip(household_ids, place_fips_values)): + for _ in range(persons_per_household): + person_household_ids.append(hh_id) + person_place_fips.append(place) - # Mock to_input_dataframe to return a DataFrame - df = pd.DataFrame( - { - "household_id": household_ids, - "place_fips": place_fips_values, - } - ) + mock_sim = Mock() + + # Mock tax_benefit_system + mock_sim.tax_benefit_system = create_mock_tax_benefit_system() + + # Mock calculate to return different values based on variable and map_to + def mock_calculate(variable_name, map_to=None, period=None): + result = Mock() + if variable_name == "place_fips": + if map_to == "person": + result.values = np.array(person_place_fips) + else: + result.values = np.array(place_fips_values) + elif variable_name == "person_id": + result.values = np.array(person_ids) + elif variable_name == "household_id": + if map_to == "person": + result.values = np.array(person_household_ids) + else: + result.values = np.array(household_ids) + elif variable_name in ["tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"]: + # For simplicity, use household_id as proxy for other entity IDs + if map_to == "person": + result.values = np.array(person_household_ids) + else: + result.values = np.array(household_ids) + else: + result.values = np.array([]) + return result + + mock_sim.calculate = mock_calculate + + # Mock to_input_dataframe to return person-level DataFrame + df = pd.DataFrame({ + "person_id__2024": person_ids, + "household_id__2024": person_household_ids, + "place_fips__2024": person_place_fips, + }) mock_sim.to_input_dataframe.return_value = df return mock_sim @@ -161,31 +244,65 @@ def create_mock_simulation_with_place_fips( def create_mock_simulation_with_bytes_place_fips( place_fips_values: list[bytes], household_ids: list[int] | None = None, + persons_per_household: int = 1, ) -> Mock: """Create a mock simulation with bytes place_fips data (as from HDF5). Args: place_fips_values: List of place FIPS codes as bytes. household_ids: Optional list of household IDs. + persons_per_household: Number of persons per household (default 1). Returns: - Mock simulation object with calculate() and to_input_dataframe() configured. + Mock simulation object configured for entity_rel filtering. """ if household_ids is None: household_ids = list(range(len(place_fips_values))) - mock_sim = Mock() + num_households = len(place_fips_values) + num_persons = num_households * persons_per_household - mock_calculate_result = Mock() - mock_calculate_result.values = np.array(place_fips_values) - mock_sim.calculate.return_value = mock_calculate_result + person_ids = list(range(num_persons)) + person_household_ids = [] + person_place_fips = [] + for i, (hh_id, place) in enumerate(zip(household_ids, place_fips_values)): + for _ in range(persons_per_household): + person_household_ids.append(hh_id) + person_place_fips.append(place) - df = pd.DataFrame( - { - "household_id": household_ids, - "place_fips": place_fips_values, - } - ) + mock_sim = Mock() + mock_sim.tax_benefit_system = create_mock_tax_benefit_system() + + def mock_calculate(variable_name, map_to=None, period=None): + result = Mock() + if variable_name == "place_fips": + if map_to == "person": + result.values = np.array(person_place_fips) + else: + result.values = np.array(place_fips_values) + elif variable_name == "person_id": + result.values = np.array(person_ids) + elif variable_name == "household_id": + if map_to == "person": + result.values = np.array(person_household_ids) + else: + result.values = np.array(household_ids) + elif variable_name in ["tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"]: + if map_to == "person": + result.values = np.array(person_household_ids) + else: + result.values = np.array(household_ids) + else: + result.values = np.array([]) + return result + + mock_sim.calculate = mock_calculate + + df = pd.DataFrame({ + "person_id__2024": person_ids, + "household_id__2024": person_household_ids, + "place_fips__2024": person_place_fips, + }) mock_sim.to_input_dataframe.return_value = df return mock_sim From 5c5c97e7ddc973a86c6a82e9c459ebeeff235e73 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 5 Feb 2026 22:15:11 +0300 Subject: [PATCH 4/5] Refactor variable validation into reusable functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add get_variable() to retrieve variable from TBS with error handling - Add validate_variable_entity() for generic entity type validation - Add validate_household_variable() for household-specific validation - Refactor _filter_simulation_by_household_variable to use new functions - Add comprehensive tests for all validation functions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- policyengine/simulation.py | 85 ++++++++++++++++++++++---- tests/test_simulation.py | 118 +++++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 11 deletions(-) diff --git a/policyengine/simulation.py b/policyengine/simulation.py index e1243aea..ce4a3dfb 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -47,6 +47,79 @@ SubsampleType = Optional[int] +# ============================================================================= +# Variable Validation Functions +# ============================================================================= + + +def get_variable(tax_benefit_system: Any, variable_name: str) -> Any: + """Get a variable from the tax-benefit system, raising if not found. + + Args: + tax_benefit_system: The tax-benefit system to search. + variable_name: The name of the variable to find. + + Returns: + The variable object from the tax-benefit system. + + Raises: + ValueError: If the variable is not found. + """ + if variable_name not in tax_benefit_system.variables: + raise ValueError( + f"Variable '{variable_name}' not found in tax-benefit system" + ) + return tax_benefit_system.variables[variable_name] + + +def validate_variable_entity( + tax_benefit_system: Any, + variable_name: str, + expected_entity: str, +) -> None: + """Validate that a variable belongs to the expected entity type. + + Args: + tax_benefit_system: The tax-benefit system containing the variable. + variable_name: The name of the variable to validate. + expected_entity: The expected entity key (e.g., "household", "person"). + + Raises: + ValueError: If the variable is not found or belongs to a different entity. + """ + variable = get_variable(tax_benefit_system, variable_name) + actual_entity = variable.entity.key + + if actual_entity != expected_entity: + raise ValueError( + f"Variable '{variable_name}' is a {actual_entity}-level variable, " + f"not a {expected_entity}-level variable." + ) + + +def validate_household_variable( + tax_benefit_system: Any, + variable_name: str, +) -> None: + """Validate that a variable is a household-level variable. + + Args: + tax_benefit_system: The tax-benefit system containing the variable. + variable_name: The name of the variable to validate. + + Raises: + ValueError: If the variable is not found or is not household-level. + """ + variable = get_variable(tax_benefit_system, variable_name) + + if variable.entity.key != "household": + raise ValueError( + f"Variable '{variable_name}' is a {variable.entity.key}-level variable, " + f"not a household-level variable. Only household-level variables can be " + f"used for filtering to preserve household integrity." + ) + + class SimulationOptions(BaseModel): country: CountryType = Field(..., description="The country to simulate.") scope: ScopeType = Field(..., description="The scope of the simulation.") @@ -452,17 +525,7 @@ def _filter_simulation_by_household_variable( Raises: ValueError: If the variable is not a household-level variable. """ - tbs = simulation.tax_benefit_system - if variable_name not in tbs.variables: - raise ValueError(f"Variable '{variable_name}' not found in tax-benefit system") - - variable = tbs.variables[variable_name] - if variable.entity.key != "household": - raise ValueError( - f"Variable '{variable_name}' is a {variable.entity.key}-level variable, " - f"not a household-level variable. Only household-level variables can be " - f"used for filtering to preserve household integrity." - ) + validate_household_variable(simulation.tax_benefit_system, variable_name) # Build entity relationships entity_rel = self._build_entity_relationships(simulation) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index f761588e..b4f9c752 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -8,7 +8,9 @@ mock_simulation_with_cliff_vars, ) import sys +import pytest from copy import deepcopy +from unittest.mock import Mock class TestSimulation: @@ -83,3 +85,119 @@ def test__calculates_correct_cliff_metrics( assert cliff_result.cliff_gap == 100.0 assert cliff_result.cliff_share == 0.5 + + +class TestVariableValidation: + """Tests for variable validation functions.""" + + @staticmethod + def _create_mock_tbs(variables: dict[str, str]) -> Mock: + """Create a mock tax-benefit system with specified variables. + + Args: + variables: Dict mapping variable names to entity keys. + + Returns: + Mock TBS with variables configured. + """ + mock_tbs = Mock() + mock_tbs.variables = {} + for var_name, entity_key in variables.items(): + mock_var = Mock() + mock_var.entity = Mock() + mock_var.entity.key = entity_key + mock_tbs.variables[var_name] = mock_var + return mock_tbs + + class TestGetVariable: + def test__given__existing_variable__then__returns_variable(self): + from policyengine.simulation import get_variable + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"place_fips": "household"} + ) + + result = get_variable(mock_tbs, "place_fips") + + assert result is mock_tbs.variables["place_fips"] + + def test__given__nonexistent_variable__then__raises_value_error(self): + from policyengine.simulation import get_variable + + mock_tbs = TestVariableValidation._create_mock_tbs({}) + + with pytest.raises(ValueError) as exc_info: + get_variable(mock_tbs, "nonexistent") + + assert "not found" in str(exc_info.value) + assert "nonexistent" in str(exc_info.value) + + class TestValidateVariableEntity: + def test__given__matching_entity__then__passes(self): + from policyengine.simulation import validate_variable_entity + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"place_fips": "household"} + ) + + # Should not raise + validate_variable_entity(mock_tbs, "place_fips", "household") + + def test__given__mismatched_entity__then__raises_value_error(self): + from policyengine.simulation import validate_variable_entity + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"age": "person"} + ) + + with pytest.raises(ValueError) as exc_info: + validate_variable_entity(mock_tbs, "age", "household") + + assert "person-level" in str(exc_info.value) + assert "household-level" in str(exc_info.value) + + def test__given__nonexistent_variable__then__raises_value_error(self): + from policyengine.simulation import validate_variable_entity + + mock_tbs = TestVariableValidation._create_mock_tbs({}) + + with pytest.raises(ValueError) as exc_info: + validate_variable_entity(mock_tbs, "nonexistent", "household") + + assert "not found" in str(exc_info.value) + + class TestValidateHouseholdVariable: + def test__given__household_variable__then__passes(self): + from policyengine.simulation import validate_household_variable + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"place_fips": "household"} + ) + + # Should not raise + validate_household_variable(mock_tbs, "place_fips") + + def test__given__person_variable__then__raises_value_error(self): + from policyengine.simulation import validate_household_variable + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"age": "person"} + ) + + with pytest.raises(ValueError) as exc_info: + validate_household_variable(mock_tbs, "age") + + assert "person-level" in str(exc_info.value) + assert "household-level" in str(exc_info.value) + + def test__given__tax_unit_variable__then__raises_value_error(self): + from policyengine.simulation import validate_household_variable + + mock_tbs = TestVariableValidation._create_mock_tbs( + {"filing_status": "tax_unit"} + ) + + with pytest.raises(ValueError) as exc_info: + validate_household_variable(mock_tbs, "filing_status") + + assert "tax_unit-level" in str(exc_info.value) From ba01c582d9a2ffc22183892caf8d50b9d09a252d Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 5 Feb 2026 22:18:47 +0300 Subject: [PATCH 5/5] Add changelog entry and format code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- changelog_entry.yaml | 8 +++++ policyengine/simulation.py | 19 +++++++--- tests/fixtures/country/us_places.py | 55 +++++++++++++++++++++-------- 3 files changed, 63 insertions(+), 19 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..047421db 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,8 @@ +- bump: minor + changes: + added: + - Entity relationship approach for simulation filtering that preserves household integrity + - Reusable variable validation functions (`get_variable`, `validate_variable_entity`, `validate_household_variable`) + changed: + - Refactored `_filter_simulation_by_household_variable` to use explicit entity relationship mapping + - Place-level filtering now builds entity_rel DataFrame for cleaner filtering logic diff --git a/policyengine/simulation.py b/policyengine/simulation.py index ce4a3dfb..8ad432ca 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -477,7 +477,9 @@ def _build_entity_relationships( Returns: A DataFrame indexed by person with columns for each entity ID. """ - entity_rel = pd.DataFrame({"person_id": simulation.calculate("person_id").values}) + entity_rel = pd.DataFrame( + {"person_id": simulation.calculate("person_id").values} + ) # Add household relationship (required for all countries) entity_rel["household_id"] = simulation.calculate( @@ -486,7 +488,12 @@ def _build_entity_relationships( # Add country-specific entity relationships tbs = simulation.tax_benefit_system - optional_entities = ["tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"] + optional_entities = [ + "tax_unit_id", + "spm_unit_id", + "family_id", + "marital_unit_id", + ] for entity_id in optional_entities: if entity_id in tbs.variables: @@ -525,7 +532,9 @@ def _filter_simulation_by_household_variable( Raises: ValueError: If the variable is not a household-level variable. """ - validate_household_variable(simulation.tax_benefit_system, variable_name) + validate_household_variable( + simulation.tax_benefit_system, variable_name + ) # Build entity relationships entity_rel = self._build_entity_relationships(simulation) @@ -536,7 +545,9 @@ def _filter_simulation_by_household_variable( # Create mask for matching households, handling bytes encoding if isinstance(variable_value, str): - hh_mask = (hh_values == variable_value) | (hh_values == variable_value.encode()) + hh_mask = (hh_values == variable_value) | ( + hh_values == variable_value.encode() + ) else: hh_mask = hh_values == variable_value diff --git a/tests/fixtures/country/us_places.py b/tests/fixtures/country/us_places.py index f7ce250c..3dc7e12a 100644 --- a/tests/fixtures/country/us_places.py +++ b/tests/fixtures/country/us_places.py @@ -129,7 +129,9 @@ # ============================================================================= -def create_mock_tax_benefit_system(household_variables: list[str] | None = None) -> Mock: +def create_mock_tax_benefit_system( + household_variables: list[str] | None = None, +) -> Mock: """Create a mock tax benefit system with variable entity information. Args: @@ -152,12 +154,21 @@ def create_mock_tax_benefit_system(household_variables: list[str] | None = None) mock_tbs.variables[var_name] = mock_var # Add standard entity ID variables - for entity_id in ["person_id", "household_id", "tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"]: + for entity_id in [ + "person_id", + "household_id", + "tax_unit_id", + "spm_unit_id", + "family_id", + "marital_unit_id", + ]: mock_var = Mock() mock_var.entity = Mock() # Entity IDs belong to their respective entities entity_name = entity_id.replace("_id", "") - mock_var.entity.key = entity_name if entity_name != "person" else "person" + mock_var.entity.key = ( + entity_name if entity_name != "person" else "person" + ) mock_tbs.variables[entity_id] = mock_var return mock_tbs @@ -218,7 +229,12 @@ def mock_calculate(variable_name, map_to=None, period=None): result.values = np.array(person_household_ids) else: result.values = np.array(household_ids) - elif variable_name in ["tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"]: + elif variable_name in [ + "tax_unit_id", + "spm_unit_id", + "family_id", + "marital_unit_id", + ]: # For simplicity, use household_id as proxy for other entity IDs if map_to == "person": result.values = np.array(person_household_ids) @@ -231,11 +247,13 @@ def mock_calculate(variable_name, map_to=None, period=None): mock_sim.calculate = mock_calculate # Mock to_input_dataframe to return person-level DataFrame - df = pd.DataFrame({ - "person_id__2024": person_ids, - "household_id__2024": person_household_ids, - "place_fips__2024": person_place_fips, - }) + df = pd.DataFrame( + { + "person_id__2024": person_ids, + "household_id__2024": person_household_ids, + "place_fips__2024": person_place_fips, + } + ) mock_sim.to_input_dataframe.return_value = df return mock_sim @@ -287,7 +305,12 @@ def mock_calculate(variable_name, map_to=None, period=None): result.values = np.array(person_household_ids) else: result.values = np.array(household_ids) - elif variable_name in ["tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"]: + elif variable_name in [ + "tax_unit_id", + "spm_unit_id", + "family_id", + "marital_unit_id", + ]: if map_to == "person": result.values = np.array(person_household_ids) else: @@ -298,11 +321,13 @@ def mock_calculate(variable_name, map_to=None, period=None): mock_sim.calculate = mock_calculate - df = pd.DataFrame({ - "person_id__2024": person_ids, - "household_id__2024": person_household_ids, - "place_fips__2024": person_place_fips, - }) + df = pd.DataFrame( + { + "person_id__2024": person_ids, + "household_id__2024": person_household_ids, + "place_fips__2024": person_place_fips, + } + ) mock_sim.to_input_dataframe.return_value = df return mock_sim