Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
- bump: minor
changes:
added:
- Entity relationship approach for simulation filtering that preserves household integrity
- Reusable variable validation functions (`get_variable`, `validate_variable_entity`, `validate_household_variable`)
changed:
- Refactored `_filter_simulation_by_household_variable` to use explicit entity relationship mapping
- Place-level filtering now builds an `entity_rel` DataFrame for cleaner filtering logic
191 changes: 182 additions & 9 deletions policyengine/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,79 @@
SubsampleType = Optional[int]


# =============================================================================
# Variable Validation Functions
# =============================================================================


def get_variable(tax_benefit_system: Any, variable_name: str) -> Any:
    """Look up a variable by name in the tax-benefit system.

    Args:
        tax_benefit_system: Object exposing a ``variables`` mapping that is
            keyed by variable name.
        variable_name: Name of the variable to fetch.

    Returns:
        The entry stored under ``variable_name`` in
        ``tax_benefit_system.variables``.

    Raises:
        ValueError: If ``variable_name`` is not a key of the mapping.
    """
    known_variables = tax_benefit_system.variables
    if variable_name in known_variables:
        return known_variables[variable_name]

    # Unknown names are reported as ValueError rather than a bare KeyError so
    # callers get a message that names the missing variable.
    raise ValueError(
        f"Variable '{variable_name}' not found in tax-benefit system"
    )


def validate_variable_entity(
    tax_benefit_system: Any,
    variable_name: str,
    expected_entity: str,
) -> None:
    """Check that a variable is defined on the expected entity.

    Args:
        tax_benefit_system: The tax-benefit system holding the variable.
        variable_name: Name of the variable to check.
        expected_entity: Entity key the variable must have
            (e.g., "household", "person").

    Raises:
        ValueError: If the variable does not exist (raised by
            ``get_variable``) or its ``entity.key`` differs from
            ``expected_entity``.
    """
    found_entity = get_variable(tax_benefit_system, variable_name).entity.key
    if found_entity == expected_entity:
        return

    raise ValueError(
        f"Variable '{variable_name}' is a {found_entity}-level variable, "
        f"not a {expected_entity}-level variable."
    )


def validate_household_variable(
    tax_benefit_system: Any,
    variable_name: str,
) -> None:
    """Check that a variable is defined at the household level.

    Args:
        tax_benefit_system: The tax-benefit system holding the variable.
        variable_name: Name of the variable to check.

    Raises:
        ValueError: If the variable does not exist (raised by
            ``get_variable``) or its ``entity.key`` is not "household".
    """
    entity_key = get_variable(tax_benefit_system, variable_name).entity.key
    if entity_key == "household":
        return

    # The message explains the restriction: filtering is only offered on
    # household-level variables so that households are kept whole.
    raise ValueError(
        f"Variable '{variable_name}' is a {entity_key}-level variable, "
        f"not a household-level variable. Only household-level variables can be "
        f"used for filtering to preserve household integrity."
    )


class SimulationOptions(BaseModel):
country: CountryType = Field(..., description="The country to simulate.")
scope: ScopeType = Field(..., description="The scope of the simulation.")
Expand Down Expand Up @@ -388,6 +461,108 @@ def _apply_us_region_to_simulation(
)
return simulation

def _build_entity_relationships(
    self,
    simulation: CountryMicrosimulation,
) -> pd.DataFrame:
    """Map every person to the IDs of the entities that contain them.

    The result has one row per person. It always carries ``person_id`` and
    ``household_id``; the columns ``tax_unit_id``, ``spm_unit_id``,
    ``family_id`` and ``marital_unit_id`` are added only when a variable of
    that name exists in the simulation's tax-benefit system.

    Args:
        simulation: The microsimulation to read entity IDs from.

    Returns:
        A DataFrame with the default positional index and one column per
        entity ID, every column expressed at person level.
    """

    def ids_per_person(id_variable: str):
        # Project the entity's ID down to person level so that every column
        # has one value per person.
        return simulation.calculate(id_variable, map_to="person").values

    columns = {
        "person_id": simulation.calculate("person_id").values,
        # Household membership is required for all countries.
        "household_id": ids_per_person("household_id"),
    }

    # Remaining entities are country-specific; include only those the
    # tax-benefit system actually defines.
    known_variables = simulation.tax_benefit_system.variables
    for id_variable in (
        "tax_unit_id",
        "spm_unit_id",
        "family_id",
        "marital_unit_id",
    ):
        if id_variable in known_variables:
            columns[id_variable] = ids_per_person(id_variable)

    return pd.DataFrame(columns)

def _filter_simulation_by_household_variable(
    self,
    simulation: CountryMicrosimulation,
    simulation_type: type,
    variable_name: str,
    variable_value: Any,
    reform: ReformType | None,
) -> CountrySimulation:
    """Build a new simulation restricted to households matching a value.

    Households are selected by comparing a household-level variable with
    ``variable_value``. Every person belonging to a selected household is
    then kept, so households are never split.

    Args:
        simulation: The microsimulation to filter.
        simulation_type: Callable used to construct the result; it is
            invoked as ``simulation_type(dataset=..., reform=...)``.
        variable_name: Name of the variable to filter on. Must be a
            household-level variable.
        variable_value: The value to match. When it is a ``str``, values
            equal to its encoded ``bytes`` form also match, since string
            variables may be stored as bytes in HDF5.
        reform: Optional reform forwarded to the new simulation.

    Returns:
        A new simulation built from the input rows of persons in the
        matching households.

    Raises:
        ValueError: If the variable is missing or is not household-level.
    """
    # Fail fast before any calculation is triggered.
    validate_household_variable(simulation.tax_benefit_system, variable_name)

    person_entities = self._build_entity_relationships(simulation)

    # Both arrays come from un-mapped calculate() calls; the household-level
    # check above is what lets them be paired element by element.
    household_values = simulation.calculate(variable_name).values
    household_ids = simulation.calculate("household_id").values

    is_match = household_values == variable_value
    if isinstance(variable_value, str):
        # Also accept the bytes encoding of the requested string.
        is_match = is_match | (household_values == variable_value.encode())

    selected_households = set(household_ids[is_match])

    # Keep every person whose household was selected.
    in_selected = person_entities["household_id"].isin(selected_households)
    kept_positions = person_entities[in_selected].index

    # person_entities has a default positional index, so its labels can be
    # used as row positions in the input frame.
    # NOTE(review): assumes to_input_dataframe() rows follow the same person
    # order as calculate("person_id") - confirm.
    input_frame = simulation.to_input_dataframe()
    subset_frame = input_frame.iloc[kept_positions]

    return simulation_type(dataset=subset_frame, reform=reform)

def _filter_us_simulation_by_place(
self,
simulation: CountryMicrosimulation,
Expand All @@ -409,16 +584,14 @@ def _filter_us_simulation_by_place(
from policyengine.utils.data.datasets import parse_us_place_region

_, place_fips_code = parse_us_place_region(region)
df = simulation.to_input_dataframe()
# Get place_fips at person level since to_input_dataframe() is person-level
person_place_fips = simulation.calculate(
"place_fips", map_to="person"
).values
# place_fips may be stored as bytes in HDF5; handle both str and bytes
mask = (person_place_fips == place_fips_code) | (
person_place_fips == place_fips_code.encode()

return self._filter_simulation_by_household_variable(
simulation=simulation,
simulation_type=simulation_type,
variable_name="place_fips",
variable_value=place_fips_code,
reform=reform,
)
return simulation_type(dataset=df[mask], reform=reform)

def check_model_version(self) -> None:
"""
Expand Down
Loading
Loading