diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29b..cace1367 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,7 @@
+- bump: minor
+ changes:
+ added:
+ - Place-level (city) impact analysis for US Census places with population over 100,000.
+ - Input validation for place region strings.
+ removed:
+ - Deprecated "city/nyc" region format (use "place/NY-51000" instead).
diff --git a/policyengine/simulation.py b/policyengine/simulation.py
index 9cd68c82..864aa90a 100644
--- a/policyengine/simulation.py
+++ b/policyengine/simulation.py
@@ -376,29 +376,49 @@ def _apply_us_region_to_simulation(
"""Apply US-specific regional filtering to a simulation.
Note: Most US regions (states, congressional districts) now use
- scoped datasets rather than filtering. Only NYC still requires
- filtering from the national dataset (and is still using the pooled
- CPS by default). This should be replaced with an approach based on
- the new datasets.
+ scoped datasets rather than filtering. Place-level regions use
+ the parent state's dataset and filter by place_fips.
"""
- if region == "city/nyc":
- simulation = self._filter_us_simulation_by_nyc(
+ if isinstance(region, str) and region.startswith("place/"):
+ simulation = self._filter_us_simulation_by_place(
simulation=simulation,
simulation_type=simulation_type,
+ region=region,
reform=reform,
)
return simulation
- def _filter_us_simulation_by_nyc(
+ def _filter_us_simulation_by_place(
self,
simulation: CountryMicrosimulation,
simulation_type: type,
+ region: str,
reform: ReformType | None,
) -> CountrySimulation:
- """Filter a US simulation to only include NYC households."""
+ """Filter a US simulation to only include households in a specific Census place.
+
+ Args:
+ simulation: The microsimulation to filter.
+ simulation_type: The type of simulation to create.
+ region: A place region string (e.g., "place/NJ-57000").
+ reform: The reform to apply to the filtered simulation.
+
+ Returns:
+ A new simulation containing only households in the specified place.
+ """
+ from policyengine.utils.data.datasets import parse_us_place_region
+
+ _, place_fips_code = parse_us_place_region(region)
df = simulation.to_input_dataframe()
- in_nyc = simulation.calculate("in_nyc", map_to="person").values
- return simulation_type(dataset=df[in_nyc], reform=reform)
+ # Get place_fips at person level since to_input_dataframe() is person-level
+ person_place_fips = simulation.calculate(
+ "place_fips", map_to="person"
+ ).values
+ # place_fips may be stored as bytes in HDF5; handle both str and bytes
+ mask = (person_place_fips == place_fips_code) | (
+ person_place_fips == place_fips_code.encode()
+ )
+ return simulation_type(dataset=df[mask], reform=reform)
def check_model_version(self) -> None:
"""
diff --git a/policyengine/utils/charts.py b/policyengine/utils/charts.py
index 6575c1e3..2dafd4c9 100644
--- a/policyengine/utils/charts.py
+++ b/policyengine/utils/charts.py
@@ -4,13 +4,11 @@
def add_fonts():
- fonts = HTML(
- """
+ fonts = HTML("""
- """
- )
+ """)
return display_html(fonts)
diff --git a/policyengine/utils/data/datasets.py b/policyengine/utils/data/datasets.py
index b1962e53..9b6614e2 100644
--- a/policyengine/utils/data/datasets.py
+++ b/policyengine/utils/data/datasets.py
@@ -47,13 +47,8 @@ def _get_default_us_dataset(region: str | None) -> str:
if region_type == "nationwide":
return ECPS_2024
- elif region_type == "city":
- # TODO: Implement a better approach to this for our one
- # city, New York City.
- # Cities use the pooled CPS dataset
- return CPS_2023_POOLED
- # For state and congressional_district, region is guaranteed to be non-None
+ # For state, congressional_district, and place, region is guaranteed to be non-None
assert region is not None
if region_type == "state":
@@ -68,6 +63,11 @@ def _get_default_us_dataset(region: str | None) -> str:
state_code, district_number
)
+ elif region_type == "place":
+ # Expected format: "place/NJ-57000"
+ state_code, _ = parse_us_place_region(region)
+ return get_us_state_dataset_path(state_code)
+
raise ValueError(f"Unhandled US region type: {region_type}")
@@ -124,9 +124,11 @@ def get_us_congressional_district_dataset_path(
return f"{US_DATA_BUCKET}/districts/{state_code.upper()}-{district_number:02d}.h5"
-USRegionType = Literal["nationwide", "city", "state", "congressional_district"]
+USRegionType = Literal[
+ "nationwide", "state", "congressional_district", "place"
+]
-US_REGION_PREFIXES = ("city", "state", "congressional_district")
+US_REGION_PREFIXES = ("state", "congressional_district", "place")
def determine_us_region_type(region: str | None) -> USRegionType:
@@ -134,11 +136,11 @@ def determine_us_region_type(region: str | None) -> USRegionType:
Determine the type of US region from a region string.
Args:
- region: A region string (e.g., "us", "city/nyc", "state/CA",
- "congressional_district/CA-01") or None.
+ region: A region string (e.g., "us", "state/CA",
+ "congressional_district/CA-01", "place/NJ-57000") or None.
Returns:
- One of "nationwide", "city", "state", or "congressional_district".
+ One of "nationwide", "state", "congressional_district", or "place".
Raises:
ValueError: If the region string has an unrecognized prefix.
@@ -154,3 +156,45 @@ def determine_us_region_type(region: str | None) -> USRegionType:
f"Unrecognized US region format: '{region}'. "
f"Expected 'us', or one of the following prefixes: {list(US_REGION_PREFIXES)}"
)
+
+
+def parse_us_place_region(region: str) -> tuple[str, str]:
+ """Parse a place region string into (state_code, place_fips).
+
+ Format: 'place/{STATE}-{PLACE_FIPS}'
+ Example: 'place/NJ-57000' -> ('NJ', '57000')
+
+ Args:
+ region: A place region string (e.g., "place/NJ-57000").
+
+ Returns:
+ A tuple of (state_code, place_fips).
+
+ Raises:
+ ValueError: If the region format is invalid or missing required parts.
+ """
+ if not region.startswith("place/"):
+ raise ValueError(
+ f"Invalid place region format: '{region}'. "
+ "Expected format: 'place/{{STATE}}-{{PLACE_FIPS}}'"
+ )
+
+ place_str = region.split("/")[1]
+ if "-" not in place_str:
+ raise ValueError(
+ f"Invalid place region format: '{region}'. "
+ "Expected format: 'place/{{STATE}}-{{PLACE_FIPS}}'"
+ )
+
+ state_code, place_fips = place_str.split("-", 1)
+
+ if not state_code:
+ raise ValueError(
+ f"Invalid place region: '{region}'. State code cannot be empty."
+ )
+ if not place_fips:
+ raise ValueError(
+ f"Invalid place region: '{region}'. Place FIPS code cannot be empty."
+ )
+
+ return state_code, place_fips
diff --git a/scripts/test_place_simulations.py b/scripts/test_place_simulations.py
new file mode 100644
index 00000000..c5215196
--- /dev/null
+++ b/scripts/test_place_simulations.py
@@ -0,0 +1,591 @@
+"""
+Test script for place-level simulations.
+
+This script tests the place-level filtering functionality by:
+1. Running simulations for specific cities (NYC, Paterson NJ, Grand Rapids MI)
+2. Comparing results to their parent states
+3. Testing a CTC fully refundable reform
+4. Checking that budgetary impact is proportional to population
+5. Testing garbage/invalid inputs
+"""
+
+import sys
+import traceback
+from typing import Any
+
+# =============================================================================
+# Configuration
+# =============================================================================
+
+# Place definitions: (name, region_string, parent_state_region)
+PLACES_TO_TEST = [
+ ("New York City, NY", "place/NY-51000", "state/NY"),
+ ("Paterson, NJ", "place/NJ-57000", "state/NJ"),
+ ("Grand Rapids, MI", "place/MI-34000", "state/MI"),
+]
+
+# Reform: Make CTC fully refundable
+CTC_FULLY_REFUNDABLE_REFORM = {
+ "gov.irs.credits.ctc.refundable.fully_refundable": {
+ "2024-01-01.2100-12-31": True
+ }
+}
+
+# Garbage/invalid inputs to test
+GARBAGE_INPUTS = [
+ ("Empty place FIPS", "place/NY-"),
+ ("Invalid state code", "place/XX-12345"),
+ ("Non-existent place FIPS", "place/NY-99999"),
+ ("Malformed region - no dash", "place/NY12345"),
+ ("Malformed region - no slash", "placeNY-12345"),
+ ("Wrong prefix", "city/NY-51000"),
+ ("SQL injection attempt", "place/NY-51000'; DROP TABLE--"),
+ ("Very long FIPS", "place/NY-" + "1" * 100),
+ ("Negative FIPS", "place/NY--12345"),
+ ("None as region", None),
+ ("Empty string", ""),
+ ("Just place/", "place/"),
+ ("Lowercase state", "place/ny-51000"),
+]
+
+
+# =============================================================================
+# Helper Functions
+# =============================================================================
+
+
+def print_header(title: str) -> None:
+ """Print a formatted section header."""
+ print("\n" + "=" * 70)
+ print(f" {title}")
+ print("=" * 70)
+
+
+def print_subheader(title: str) -> None:
+ """Print a formatted subsection header."""
+ print(f"\n--- {title} ---")
+
+
+def format_currency(amount: float) -> str:
+ """Format a number as currency."""
+ if abs(amount) >= 1e9:
+ return f"${amount/1e9:.2f}B"
+ elif abs(amount) >= 1e6:
+ return f"${amount/1e6:.2f}M"
+ elif abs(amount) >= 1e3:
+ return f"${amount/1e3:.2f}K"
+ else:
+ return f"${amount:.2f}"
+
+
+def safe_simulation_run(
+ region: str | None,
+ reform: dict | None = None,
+ description: str = "",
+) -> dict[str, Any] | None:
+ """
+ Safely attempt to run a simulation, catching and reporting errors.
+
+ Returns simulation results dict or None if failed.
+ """
+ from policyengine import Simulation
+
+ try:
+ options = {
+ "country": "us",
+ "scope": "macro",
+ "time_period": "2024",
+ }
+ if region is not None:
+ options["region"] = region
+ if reform is not None:
+ options["reform"] = reform
+
+ sim = Simulation(**options)
+
+ # Try to get basic metrics
+ result = {
+ "success": True,
+ "region": region,
+ "household_count_baseline": None,
+ "total_net_income_baseline": None,
+ "budgetary_impact": None,
+ }
+
+ # Get baseline household count and income
+ try:
+ baseline_sim = sim.baseline_simulation
+ if baseline_sim is not None:
+ weights = baseline_sim.calculate("household_weight").values
+ result["household_count_baseline"] = weights.sum()
+
+ net_income = baseline_sim.calculate(
+ "household_net_income", map_to="household"
+ ).values
+ result["total_net_income_baseline"] = (
+ net_income * weights
+ ).sum()
+ except Exception as e:
+ result["baseline_error"] = str(e)
+
+ # Get budgetary impact if reform was applied
+ if reform is not None:
+ try:
+ impact = sim.calculate("budgetary/overall/budgetary_impact")
+ result["budgetary_impact"] = impact
+ except Exception as e:
+ result["impact_error"] = str(e)
+
+ return result
+
+ except Exception as e:
+ return {
+ "success": False,
+ "region": region,
+ "error": str(e),
+ "error_type": type(e).__name__,
+ "traceback": traceback.format_exc(),
+ }
+
+
+# =============================================================================
+# Test Functions
+# =============================================================================
+
+
+def test_place_vs_state_comparison() -> dict[str, Any]:
+ """
+ Test place-level simulations and compare to state-level results.
+
+ Returns a summary of results.
+ """
+ print_header("Place vs State Comparison Tests")
+
+ results = {
+ "places": {},
+ "states": {},
+ "comparisons": {},
+ }
+
+ # First, run state-level simulations
+ print_subheader("Running State-Level Baseline Simulations")
+    state_regions = {place[2] for place in PLACES_TO_TEST}
+
+ for state_region in state_regions:
+ print(f" Running {state_region}...", end=" ", flush=True)
+ result = safe_simulation_run(state_region)
+ results["states"][state_region] = result
+ if result and result.get("success"):
+            hh = result.get("household_count_baseline")
+            hh_str = f"{hh:,.0f}" if hh is not None else "N/A"
+            print(f"OK - {hh_str} households")
+ else:
+ print(f"FAILED - {result.get('error', 'Unknown error')}")
+
+ # Run place-level simulations
+ print_subheader("Running Place-Level Baseline Simulations")
+
+ for place_name, place_region, state_region in PLACES_TO_TEST:
+ print(
+ f" Running {place_name} ({place_region})...", end=" ", flush=True
+ )
+ result = safe_simulation_run(place_region)
+ results["places"][place_region] = result
+ if result and result.get("success"):
+            hh = result.get("household_count_baseline")
+            hh_str = f"{hh:,.0f}" if hh is not None else "N/A"
+            print(f"OK - {hh_str} households")
+ else:
+ print(f"FAILED - {result.get('error', 'Unknown error')}")
+
+ # Compare place to state
+ print_subheader("Place-to-State Population Ratios")
+
+ for place_name, place_region, state_region in PLACES_TO_TEST:
+ place_result = results["places"].get(place_region, {})
+ state_result = results["states"].get(state_region, {})
+
+ if (
+ place_result.get("success")
+ and state_result.get("success")
+ and place_result.get("household_count_baseline")
+ and state_result.get("household_count_baseline")
+ ):
+
+ place_hh = place_result["household_count_baseline"]
+ state_hh = state_result["household_count_baseline"]
+ ratio = place_hh / state_hh if state_hh > 0 else 0
+
+ results["comparisons"][place_region] = {
+ "place_households": place_hh,
+ "state_households": state_hh,
+ "ratio": ratio,
+ }
+
+ print(f" {place_name}:")
+ print(f" Place households: {place_hh:>12,.0f}")
+ print(f" State households: {state_hh:>12,.0f}")
+ print(f" Ratio (place/state): {ratio:>10.2%}")
+ else:
+ print(f" {place_name}: Could not compare (missing data)")
+
+ return results
+
+
+def test_ctc_reform_impact() -> dict[str, Any]:
+ """
+ Test CTC fully refundable reform impact at place and state levels.
+
+ Checks that budgetary impact is roughly proportional to population.
+ """
+ print_header("CTC Fully Refundable Reform Impact Tests")
+
+ results = {
+ "places": {},
+ "states": {},
+ "proportionality_checks": {},
+ }
+
+ # Run state-level reform simulations
+ print_subheader("Running State-Level Reform Simulations")
+    state_regions = {place[2] for place in PLACES_TO_TEST}
+
+ for state_region in state_regions:
+ print(
+ f" Running {state_region} with CTC reform...", end=" ", flush=True
+ )
+ result = safe_simulation_run(state_region, CTC_FULLY_REFUNDABLE_REFORM)
+ results["states"][state_region] = result
+ if result and result.get("success"):
+ impact = result.get("budgetary_impact")
+ if impact is not None:
+ print(f"OK - Impact: {format_currency(impact)}")
+ else:
+ print(f"OK - Impact: N/A")
+ else:
+ print(f"FAILED - {result.get('error', 'Unknown error')}")
+
+ # Run place-level reform simulations
+ print_subheader("Running Place-Level Reform Simulations")
+
+ for place_name, place_region, state_region in PLACES_TO_TEST:
+ print(
+ f" Running {place_name} with CTC reform...", end=" ", flush=True
+ )
+ result = safe_simulation_run(place_region, CTC_FULLY_REFUNDABLE_REFORM)
+ results["places"][place_region] = result
+ if result and result.get("success"):
+ impact = result.get("budgetary_impact")
+ if impact is not None:
+ print(f"OK - Impact: {format_currency(impact)}")
+ else:
+ print(f"OK - Impact: N/A")
+ else:
+ print(f"FAILED - {result.get('error', 'Unknown error')}")
+
+ # Check proportionality
+ print_subheader("Proportionality Check (Impact vs Population Ratio)")
+
+ for place_name, place_region, state_region in PLACES_TO_TEST:
+ place_result = results["places"].get(place_region, {})
+ state_result = results["states"].get(state_region, {})
+
+ place_impact = place_result.get("budgetary_impact")
+ state_impact = state_result.get("budgetary_impact")
+ place_hh = place_result.get("household_count_baseline")
+ state_hh = state_result.get("household_count_baseline")
+
+ if all(
+ v is not None
+ for v in [place_impact, state_impact, place_hh, state_hh]
+ ):
+ if state_impact != 0 and state_hh != 0:
+ impact_ratio = place_impact / state_impact
+ pop_ratio = place_hh / state_hh
+
+ # Check if impact ratio is within 50% of population ratio
+ # (allowing for demographic differences)
+ ratio_diff = (
+ abs(impact_ratio - pop_ratio) / pop_ratio
+ if pop_ratio > 0
+ else float("inf")
+ )
+ is_proportional = ratio_diff < 0.5
+
+ results["proportionality_checks"][place_region] = {
+ "place_impact": place_impact,
+ "state_impact": state_impact,
+ "impact_ratio": impact_ratio,
+ "population_ratio": pop_ratio,
+ "ratio_difference": ratio_diff,
+ "is_proportional": is_proportional,
+ }
+
+ status = "PASS" if is_proportional else "WARN"
+ print(f" {place_name}: [{status}]")
+ print(f" Impact ratio: {impact_ratio:>10.2%}")
+ print(f" Population ratio: {pop_ratio:>10.2%}")
+ print(f" Difference: {ratio_diff:>10.2%}")
+ else:
+ print(f" {place_name}: Could not check (missing data)")
+
+ return results
+
+
+def test_garbage_inputs() -> dict[str, Any]:
+ """
+ Test how the system handles garbage/invalid inputs.
+
+ This helps identify where we need more guards or type checking.
+ """
+ print_header("Garbage Input Tests")
+
+ results = {
+ "tests": {},
+ "needs_guards": [],
+ }
+
+ for test_name, garbage_input in GARBAGE_INPUTS:
+ print(f" Testing: {test_name}", end=" ", flush=True)
+ print(f"({repr(garbage_input)[:50]})")
+
+ result = safe_simulation_run(garbage_input)
+ results["tests"][test_name] = {
+ "input": garbage_input,
+ "result": result,
+ }
+
+ if result and result.get("success"):
+ # This might indicate we need more validation
+ print(f" Result: SUCCEEDED (unexpected - may need validation)")
+ results["needs_guards"].append(
+ {
+ "test_name": test_name,
+ "input": garbage_input,
+ "reason": "Succeeded when it probably should have failed",
+ }
+ )
+ else:
+ error_type = (
+ result.get("error_type", "Unknown") if result else "Unknown"
+ )
+ error_msg = (
+ result.get("error", "Unknown error")
+ if result
+ else "Unknown error"
+ )
+ print(f" Result: FAILED ({error_type})")
+ print(f" Error: {error_msg[:100]}")
+
+ # Check if error message is helpful
+ if (
+ "Traceback" in str(result.get("traceback", ""))
+ and "KeyError" in error_type
+ ):
+ results["needs_guards"].append(
+ {
+ "test_name": test_name,
+ "input": garbage_input,
+ "reason": f"KeyError instead of descriptive error: {error_msg}",
+ }
+ )
+
+ return results
+
+
+def test_datasets_functions_directly() -> dict[str, Any]:
+ """
+ Test the datasets.py functions directly with various inputs.
+ """
+ print_header("Direct Function Tests (datasets.py)")
+
+ from policyengine.utils.data.datasets import (
+ determine_us_region_type,
+ parse_us_place_region,
+ get_default_dataset,
+ )
+
+ results = {"tests": [], "needs_guards": []}
+
+ # Test determine_us_region_type
+ print_subheader("Testing determine_us_region_type()")
+
+ test_cases = [
+ ("place/NY-51000", "place"),
+ ("place/NJ-57000", "place"),
+ ("state/NY", "state"),
+ ("us", "nationwide"),
+ (None, "nationwide"),
+ ("congressional_district/CA-01", "congressional_district"),
+ ]
+
+ for input_val, expected in test_cases:
+ try:
+ result = determine_us_region_type(input_val)
+ status = "PASS" if result == expected else "FAIL"
+ print(f" {repr(input_val):40} -> {result:25} [{status}]")
+ results["tests"].append(
+ {
+ "function": "determine_us_region_type",
+ "input": input_val,
+ "result": result,
+ "expected": expected,
+ "passed": result == expected,
+ }
+ )
+ except Exception as e:
+ print(f" {repr(input_val):40} -> ERROR: {e}")
+ results["tests"].append(
+ {
+ "function": "determine_us_region_type",
+ "input": input_val,
+ "error": str(e),
+ "passed": False,
+ }
+ )
+
+ # Test invalid inputs for determine_us_region_type
+ print_subheader("Testing determine_us_region_type() with invalid inputs")
+
+ invalid_inputs = [
+ "invalid/something",
+ "place",
+ "",
+ "city/nyc", # Should now fail since we removed city
+ ]
+
+ for input_val in invalid_inputs:
+ try:
+ result = determine_us_region_type(input_val)
+ print(f" {repr(input_val):40} -> {result} [UNEXPECTED SUCCESS]")
+ results["needs_guards"].append(
+ {
+ "function": "determine_us_region_type",
+ "input": input_val,
+ "reason": "Should have raised ValueError",
+ }
+ )
+        except ValueError:
+ print(f" {repr(input_val):40} -> ValueError (expected) [PASS]")
+ except Exception as e:
+ print(f" {repr(input_val):40} -> {type(e).__name__}: {e}")
+
+ # Test parse_us_place_region
+ print_subheader("Testing parse_us_place_region()")
+
+ place_test_cases = [
+ ("place/NY-51000", ("NY", "51000")),
+ ("place/NJ-57000", ("NJ", "57000")),
+ ("place/MI-34000", ("MI", "34000")),
+ ("place/ca-12345", ("ca", "12345")), # lowercase
+ ]
+
+ for input_val, expected in place_test_cases:
+ try:
+ result = parse_us_place_region(input_val)
+ status = "PASS" if result == expected else "FAIL"
+ print(f" {repr(input_val):30} -> {result} [{status}]")
+ except Exception as e:
+ print(f" {repr(input_val):30} -> ERROR: {e}")
+
+ # Test invalid inputs for parse_us_place_region
+ print_subheader("Testing parse_us_place_region() with invalid inputs")
+
+ invalid_place_inputs = [
+ "state/NY", # Wrong prefix
+ "place/NY", # Missing FIPS
+ "place/", # Empty
+ "NY-51000", # Missing prefix
+ "place/NY-", # Empty FIPS
+ ]
+
+ for input_val in invalid_place_inputs:
+ try:
+ result = parse_us_place_region(input_val)
+ print(f" {repr(input_val):30} -> {result} [UNEXPECTED SUCCESS]")
+ results["needs_guards"].append(
+ {
+ "function": "parse_us_place_region",
+ "input": input_val,
+ "reason": f"Should have raised error, got {result}",
+ }
+ )
+        except (ValueError, IndexError):
+ print(
+ f" {repr(input_val):30} -> {type(e).__name__} (expected) [PASS]"
+ )
+ except Exception as e:
+ print(f" {repr(input_val):30} -> {type(e).__name__}: {e}")
+ results["needs_guards"].append(
+ {
+ "function": "parse_us_place_region",
+ "input": input_val,
+ "reason": f"Got {type(e).__name__} instead of ValueError",
+ }
+ )
+
+ return results
+
+
+def print_summary(all_results: dict[str, Any]) -> None:
+ """Print a summary of all test results."""
+ print_header("SUMMARY")
+
+ # Check for items needing guards
+ all_needs_guards = []
+ for test_name, results in all_results.items():
+ if isinstance(results, dict) and "needs_guards" in results:
+ for item in results["needs_guards"]:
+ item["test_suite"] = test_name
+ all_needs_guards.append(item)
+
+ if all_needs_guards:
+ print("\nItems that may need additional guards/validation:")
+ for item in all_needs_guards:
+ print(
+ f" - {item.get('test_suite', 'Unknown')}: {item.get('test_name', item.get('function', 'Unknown'))}"
+ )
+ print(f" Input: {repr(item.get('input', 'N/A'))[:60]}")
+ print(f" Reason: {item.get('reason', 'N/A')}")
+ else:
+ print("\nNo additional guards appear to be needed.")
+
+ print("\n" + "=" * 70)
+ print(" Test script completed")
+ print("=" * 70)
+
+
+# =============================================================================
+# Main
+# =============================================================================
+
+
+def main():
+ """Run all tests."""
+ print("\n" + "=" * 70)
+ print(" Place-Level Simulation Test Script")
+ print(" Testing: NYC (NY), Paterson (NJ), Grand Rapids (MI)")
+ print("=" * 70)
+
+ all_results = {}
+
+ # Run direct function tests first (faster, no simulation)
+ all_results["direct_function_tests"] = test_datasets_functions_directly()
+
+ # Run place vs state comparison
+ all_results["place_vs_state"] = test_place_vs_state_comparison()
+
+ # Run CTC reform tests
+ all_results["ctc_reform"] = test_ctc_reform_impact()
+
+ # Run garbage input tests
+ all_results["garbage_inputs"] = test_garbage_inputs()
+
+ # Print summary
+ print_summary(all_results)
+
+ return all_results
+
+
+if __name__ == "__main__":
+ results = main()
diff --git a/tests/conftest.py b/tests/conftest.py
index e816468f..86f9d14e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,3 +12,17 @@
mock_single_economy_without_districts,
mock_single_economy_with_null_districts,
)
+
+from tests.fixtures.country.us_places import (
+ mock_simulation_type,
+ mock_reform,
+ simulation_instance,
+ mock_sim_mixed_places,
+ mock_sim_no_paterson,
+ mock_sim_all_paterson,
+ mock_sim_bytes_places,
+ mock_sim_multiple_nj_places,
+ mock_sim_for_reform_test,
+ mini_place_dataset,
+ mini_place_dataset_with_bytes,
+)
diff --git a/tests/country/test_us.py b/tests/country/test_us.py
index 385b3c1e..6e1666e7 100644
--- a/tests/country/test_us.py
+++ b/tests/country/test_us.py
@@ -2,22 +2,34 @@
def test_us_macro_single():
+ """Test US macro single economy calculation using a state dataset.
+
+ Uses Delaware (smallest state by population with data) to reduce memory
+ usage in CI while still testing the full simulation pipeline.
+ """
from policyengine import Simulation
sim = Simulation(
scope="macro",
country="us",
+ region="state/DE",
)
sim.calculate_single_economy()
def test_us_macro_comparison():
+ """Test US macro economy comparison using a state dataset.
+
+ Uses Delaware (smallest state by population with data) to reduce memory
+ usage in CI while still testing the full comparison pipeline.
+ """
from policyengine import Simulation
sim = Simulation(
scope="macro",
country="us",
+ region="state/DE",
reform={
"gov.usda.snap.income.deductions.earned_income": {"2025": 0.05}
},
diff --git a/tests/country/test_us_places.py b/tests/country/test_us_places.py
new file mode 100644
index 00000000..bfdda841
--- /dev/null
+++ b/tests/country/test_us_places.py
@@ -0,0 +1,324 @@
+"""Tests for US place-level (city) filtering functionality."""
+
+import pytest
+import pandas as pd
+from unittest.mock import Mock, patch
+
+from tests.fixtures.country.us_places import (
+ # Place FIPS Constants
+ NJ_PATERSON_FIPS,
+ NJ_NEWARK_FIPS,
+ NJ_JERSEY_CITY_FIPS,
+ CA_LOS_ANGELES_FIPS,
+ TX_HOUSTON_FIPS,
+ NONEXISTENT_PLACE_FIPS,
+ # Region String Constants
+ NJ_PATERSON_REGION,
+ NJ_NEWARK_REGION,
+ # Test Data Arrays
+ MIXED_PLACES_WITH_PATERSON,
+ PLACES_WITHOUT_PATERSON,
+ ALL_PATERSON_PLACES,
+ MIXED_PLACES_BYTES,
+ MULTIPLE_NJ_PLACES,
+ TWO_PLACES_FOR_REFORM_TEST,
+ # Expected Results Constants
+ EXPECTED_PATERSON_COUNT_IN_MIXED,
+ EXPECTED_PATERSON_COUNT_IN_BYTES,
+ EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ,
+ MINI_DATASET_PATERSON_COUNT,
+ MINI_DATASET_PATERSON_IDS,
+ MINI_DATASET_NEWARK_COUNT,
+ MINI_DATASET_NEWARK_IDS,
+ MINI_DATASET_JERSEY_CITY_COUNT,
+ MINI_DATASET_JERSEY_CITY_IDS,
+ MINI_DATASET_PATERSON_TOTAL_WEIGHT,
+ MINI_DATASET_NEWARK_TOTAL_WEIGHT,
+ MINI_DATASET_JERSEY_CITY_TOTAL_WEIGHT,
+ MINI_DATASET_BYTES_PATERSON_COUNT,
+ MINI_DATASET_BYTES_PATERSON_IDS,
+ # Factory Functions
+ create_mock_simulation_with_place_fips,
+ create_mock_simulation_with_bytes_place_fips,
+ create_mock_simulation_type,
+ create_simulation_instance,
+)
+
+
+class TestFilterUsSimulationByPlace:
+ """Tests for the _filter_us_simulation_by_place method."""
+
+ def test__given__households_in_target_place__then__filters_to_matching_households(
+ self,
+ ):
+ # Given
+ mock_sim = create_mock_simulation_with_place_fips(
+ MIXED_PLACES_WITH_PATERSON
+ )
+ mock_simulation_type = create_mock_simulation_type()
+
+ # When
+ sim_instance = create_simulation_instance()
+ result = sim_instance._filter_us_simulation_by_place(
+ simulation=mock_sim,
+ simulation_type=mock_simulation_type,
+ region=NJ_PATERSON_REGION,
+ reform=None,
+ )
+
+ # Then
+ call_args = mock_simulation_type.call_args
+ filtered_df = call_args.kwargs["dataset"]
+
+ assert len(filtered_df) == EXPECTED_PATERSON_COUNT_IN_MIXED
+ assert all(filtered_df["place_fips"] == NJ_PATERSON_FIPS)
+
+ def test__given__no_households_in_target_place__then__returns_empty_dataset(
+ self,
+ ):
+ # Given
+ mock_sim = create_mock_simulation_with_place_fips(
+ PLACES_WITHOUT_PATERSON
+ )
+ mock_simulation_type = create_mock_simulation_type()
+
+ # When
+ sim_instance = create_simulation_instance()
+ result = sim_instance._filter_us_simulation_by_place(
+ simulation=mock_sim,
+ simulation_type=mock_simulation_type,
+ region=NJ_PATERSON_REGION,
+ reform=None,
+ )
+
+ # Then
+ call_args = mock_simulation_type.call_args
+ filtered_df = call_args.kwargs["dataset"]
+
+ assert len(filtered_df) == 0
+
+ def test__given__all_households_in_target_place__then__returns_all_households(
+ self,
+ ):
+ # Given
+ mock_sim = create_mock_simulation_with_place_fips(ALL_PATERSON_PLACES)
+ mock_simulation_type = create_mock_simulation_type()
+
+ # When
+ sim_instance = create_simulation_instance()
+ result = sim_instance._filter_us_simulation_by_place(
+ simulation=mock_sim,
+ simulation_type=mock_simulation_type,
+ region=NJ_PATERSON_REGION,
+ reform=None,
+ )
+
+ # Then
+ call_args = mock_simulation_type.call_args
+ filtered_df = call_args.kwargs["dataset"]
+
+ assert len(filtered_df) == len(ALL_PATERSON_PLACES)
+
+ def test__given__bytes_place_fips_in_dataset__then__still_filters_correctly(
+ self,
+ ):
+ # Given: place_fips stored as bytes (as it might be in HDF5)
+ mock_sim = create_mock_simulation_with_bytes_place_fips(
+ MIXED_PLACES_BYTES
+ )
+ mock_simulation_type = create_mock_simulation_type()
+
+ # When
+ sim_instance = create_simulation_instance()
+ result = sim_instance._filter_us_simulation_by_place(
+ simulation=mock_sim,
+ simulation_type=mock_simulation_type,
+ region=NJ_PATERSON_REGION,
+ reform=None,
+ )
+
+ # Then
+ call_args = mock_simulation_type.call_args
+ filtered_df = call_args.kwargs["dataset"]
+
+ assert len(filtered_df) == EXPECTED_PATERSON_COUNT_IN_BYTES
+
+ def test__given__reform_provided__then__passes_reform_to_simulation_type(
+ self,
+ ):
+ # Given
+ mock_sim = create_mock_simulation_with_place_fips(
+ TWO_PLACES_FOR_REFORM_TEST
+ )
+ mock_simulation_type = create_mock_simulation_type()
+ mock_reform = Mock()
+
+ # When
+ sim_instance = create_simulation_instance()
+ result = sim_instance._filter_us_simulation_by_place(
+ simulation=mock_sim,
+ simulation_type=mock_simulation_type,
+ region=NJ_PATERSON_REGION,
+ reform=mock_reform,
+ )
+
+ # Then
+ call_args = mock_simulation_type.call_args
+ assert call_args.kwargs["reform"] is mock_reform
+
+ def test__given__different_place_in_same_state__then__filters_correctly(
+ self,
+ ):
+ # Given: Multiple NJ places
+ mock_sim = create_mock_simulation_with_place_fips(MULTIPLE_NJ_PLACES)
+ mock_simulation_type = create_mock_simulation_type()
+
+ # When: Filter for Newark only
+ sim_instance = create_simulation_instance()
+ result = sim_instance._filter_us_simulation_by_place(
+ simulation=mock_sim,
+ simulation_type=mock_simulation_type,
+ region=NJ_NEWARK_REGION,
+ reform=None,
+ )
+
+ # Then
+ call_args = mock_simulation_type.call_args
+ filtered_df = call_args.kwargs["dataset"]
+
+ assert len(filtered_df) == EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ
+ assert all(filtered_df["place_fips"] == NJ_NEWARK_FIPS)
+
+
+class TestApplyUsRegionToSimulationWithPlace:
+ """Tests for _apply_us_region_to_simulation with place regions."""
+
+ def test__given__place_region__then__calls_filter_by_place(self):
+ # Given
+ sim_instance = create_simulation_instance()
+ mock_simulation = Mock()
+ mock_simulation_type = Mock
+
+ # When / Then
+ with patch.object(
+ sim_instance,
+ "_filter_us_simulation_by_place",
+ return_value=Mock(),
+ ) as mock_filter:
+ result = sim_instance._apply_us_region_to_simulation(
+ simulation=mock_simulation,
+ simulation_type=mock_simulation_type,
+ region=NJ_PATERSON_REGION,
+ reform=None,
+ )
+
+ mock_filter.assert_called_once_with(
+ simulation=mock_simulation,
+ simulation_type=mock_simulation_type,
+ region=NJ_PATERSON_REGION,
+ reform=None,
+ )
+
+
+class TestMiniDatasetPlaceFiltering:
+ """Integration-style tests using a mini in-memory dataset.
+
+ Uses the mini_place_dataset fixture from conftest.py.
+ """
+
+ def test__given__mini_dataset__then__paterson_filter_returns_correct_count(
+ self, mini_place_dataset
+ ):
+ # Given
+ df = mini_place_dataset
+
+ # When
+ filtered_df = df[df["place_fips"] == NJ_PATERSON_FIPS]
+
+ # Then
+ assert len(filtered_df) == MINI_DATASET_PATERSON_COUNT
+ assert (
+ filtered_df["household_id"].tolist() == MINI_DATASET_PATERSON_IDS
+ )
+
+ def test__given__mini_dataset__then__newark_filter_returns_correct_count(
+ self, mini_place_dataset
+ ):
+ # Given
+ df = mini_place_dataset
+
+ # When
+ filtered_df = df[df["place_fips"] == NJ_NEWARK_FIPS]
+
+ # Then
+ assert len(filtered_df) == MINI_DATASET_NEWARK_COUNT
+ assert filtered_df["household_id"].tolist() == MINI_DATASET_NEWARK_IDS
+
+ def test__given__mini_dataset__then__jersey_city_filter_returns_correct_count(
+ self, mini_place_dataset
+ ):
+ # Given
+ df = mini_place_dataset
+
+ # When
+ filtered_df = df[df["place_fips"] == NJ_JERSEY_CITY_FIPS]
+
+ # Then
+ assert len(filtered_df) == MINI_DATASET_JERSEY_CITY_COUNT
+ assert (
+ filtered_df["household_id"].tolist()
+ == MINI_DATASET_JERSEY_CITY_IDS
+ )
+
+ def test__given__mini_dataset__then__total_weight_sums_correctly_per_place(
+ self, mini_place_dataset
+ ):
+ # Given
+ df = mini_place_dataset
+
+ # When
+ paterson_weight = df[df["place_fips"] == NJ_PATERSON_FIPS][
+ "household_weight"
+ ].sum()
+ newark_weight = df[df["place_fips"] == NJ_NEWARK_FIPS][
+ "household_weight"
+ ].sum()
+ jersey_city_weight = df[df["place_fips"] == NJ_JERSEY_CITY_FIPS][
+ "household_weight"
+ ].sum()
+
+ # Then
+ assert paterson_weight == MINI_DATASET_PATERSON_TOTAL_WEIGHT
+ assert newark_weight == MINI_DATASET_NEWARK_TOTAL_WEIGHT
+ assert jersey_city_weight == MINI_DATASET_JERSEY_CITY_TOTAL_WEIGHT
+
+ def test__given__mini_dataset__then__nonexistent_place_returns_empty(
+ self, mini_place_dataset
+ ):
+ # Given
+ df = mini_place_dataset
+
+ # When
+ filtered_df = df[df["place_fips"] == NONEXISTENT_PLACE_FIPS]
+
+ # Then
+ assert len(filtered_df) == 0
+
+ def test__given__mini_dataset_with_bytes_fips__then__filtering_works(
+ self, mini_place_dataset_with_bytes
+ ):
+ # Given
+ df = mini_place_dataset_with_bytes
+
+ # When: Filter handling both str and bytes
+ mask = (df["place_fips"] == NJ_PATERSON_FIPS) | (
+ df["place_fips"] == NJ_PATERSON_FIPS.encode()
+ )
+ filtered_df = df[mask]
+
+ # Then
+ assert len(filtered_df) == MINI_DATASET_BYTES_PATERSON_COUNT
+ assert (
+ filtered_df["household_id"].tolist()
+ == MINI_DATASET_BYTES_PATERSON_IDS
+ )
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
index 131d59d6..a5862e50 100644
--- a/tests/fixtures/__init__.py
+++ b/tests/fixtures/__init__.py
@@ -10,3 +10,62 @@
mock_single_economy_without_districts,
mock_single_economy_with_null_districts,
)
+
+from tests.fixtures.country.us_places import (
+ # Place FIPS Constants
+ NJ_PATERSON_FIPS,
+ NJ_NEWARK_FIPS,
+ NJ_JERSEY_CITY_FIPS,
+ CA_LOS_ANGELES_FIPS,
+ TX_HOUSTON_FIPS,
+ NONEXISTENT_PLACE_FIPS,
+ NJ_PATERSON_FIPS_BYTES,
+ NJ_NEWARK_FIPS_BYTES,
+ CA_LOS_ANGELES_FIPS_BYTES,
+ TX_HOUSTON_FIPS_BYTES,
+ # Region String Constants
+ NJ_PATERSON_REGION,
+ NJ_NEWARK_REGION,
+ NJ_JERSEY_CITY_REGION,
+ CA_LOS_ANGELES_REGION,
+ TX_HOUSTON_REGION,
+ # Test Data Arrays
+ MIXED_PLACES_WITH_PATERSON,
+ PLACES_WITHOUT_PATERSON,
+ ALL_PATERSON_PLACES,
+ MIXED_PLACES_BYTES,
+ MULTIPLE_NJ_PLACES,
+ TWO_PLACES_FOR_REFORM_TEST,
+ # Expected Results Constants
+ EXPECTED_PATERSON_COUNT_IN_MIXED,
+ EXPECTED_PATERSON_COUNT_IN_BYTES,
+ EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ,
+ MINI_DATASET_PATERSON_COUNT,
+ MINI_DATASET_PATERSON_IDS,
+ MINI_DATASET_NEWARK_COUNT,
+ MINI_DATASET_NEWARK_IDS,
+ MINI_DATASET_JERSEY_CITY_COUNT,
+ MINI_DATASET_JERSEY_CITY_IDS,
+ MINI_DATASET_PATERSON_TOTAL_WEIGHT,
+ MINI_DATASET_NEWARK_TOTAL_WEIGHT,
+ MINI_DATASET_JERSEY_CITY_TOTAL_WEIGHT,
+ MINI_DATASET_BYTES_PATERSON_COUNT,
+ MINI_DATASET_BYTES_PATERSON_IDS,
+ # Factory Functions
+ create_mock_simulation_with_place_fips,
+ create_mock_simulation_with_bytes_place_fips,
+ create_mock_simulation_type,
+ create_simulation_instance,
+ # Pytest Fixtures
+ mock_simulation_type,
+ mock_reform,
+ simulation_instance,
+ mock_sim_mixed_places,
+ mock_sim_no_paterson,
+ mock_sim_all_paterson,
+ mock_sim_bytes_places,
+ mock_sim_multiple_nj_places,
+ mock_sim_for_reform_test,
+ mini_place_dataset,
+ mini_place_dataset_with_bytes,
+)
diff --git a/tests/fixtures/country/__init__.py b/tests/fixtures/country/__init__.py
new file mode 100644
index 00000000..4eaa139b
--- /dev/null
+++ b/tests/fixtures/country/__init__.py
@@ -0,0 +1,60 @@
+"""Test fixtures for country-specific tests."""
+
+from tests.fixtures.country.us_places import (
+ # Place FIPS Constants
+ NJ_PATERSON_FIPS,
+ NJ_NEWARK_FIPS,
+ NJ_JERSEY_CITY_FIPS,
+ CA_LOS_ANGELES_FIPS,
+ TX_HOUSTON_FIPS,
+ NONEXISTENT_PLACE_FIPS,
+ NJ_PATERSON_FIPS_BYTES,
+ NJ_NEWARK_FIPS_BYTES,
+ CA_LOS_ANGELES_FIPS_BYTES,
+ TX_HOUSTON_FIPS_BYTES,
+ # Region String Constants
+ NJ_PATERSON_REGION,
+ NJ_NEWARK_REGION,
+ NJ_JERSEY_CITY_REGION,
+ CA_LOS_ANGELES_REGION,
+ TX_HOUSTON_REGION,
+ # Test Data Arrays
+ MIXED_PLACES_WITH_PATERSON,
+ PLACES_WITHOUT_PATERSON,
+ ALL_PATERSON_PLACES,
+ MIXED_PLACES_BYTES,
+ MULTIPLE_NJ_PLACES,
+ TWO_PLACES_FOR_REFORM_TEST,
+ # Expected Results Constants
+ EXPECTED_PATERSON_COUNT_IN_MIXED,
+ EXPECTED_PATERSON_COUNT_IN_BYTES,
+ EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ,
+ MINI_DATASET_PATERSON_COUNT,
+ MINI_DATASET_PATERSON_IDS,
+ MINI_DATASET_NEWARK_COUNT,
+ MINI_DATASET_NEWARK_IDS,
+ MINI_DATASET_JERSEY_CITY_COUNT,
+ MINI_DATASET_JERSEY_CITY_IDS,
+ MINI_DATASET_PATERSON_TOTAL_WEIGHT,
+ MINI_DATASET_NEWARK_TOTAL_WEIGHT,
+ MINI_DATASET_JERSEY_CITY_TOTAL_WEIGHT,
+ MINI_DATASET_BYTES_PATERSON_COUNT,
+ MINI_DATASET_BYTES_PATERSON_IDS,
+ # Factory Functions
+ create_mock_simulation_with_place_fips,
+ create_mock_simulation_with_bytes_place_fips,
+ create_mock_simulation_type,
+ create_simulation_instance,
+ # Pytest Fixtures
+ mock_simulation_type,
+ mock_reform,
+ simulation_instance,
+ mock_sim_mixed_places,
+ mock_sim_no_paterson,
+ mock_sim_all_paterson,
+ mock_sim_bytes_places,
+ mock_sim_multiple_nj_places,
+ mock_sim_for_reform_test,
+ mini_place_dataset,
+ mini_place_dataset_with_bytes,
+)
diff --git a/tests/fixtures/country/us_places.py b/tests/fixtures/country/us_places.py
new file mode 100644
index 00000000..ff65af60
--- /dev/null
+++ b/tests/fixtures/country/us_places.py
@@ -0,0 +1,358 @@
+"""Test fixtures for US place-level filtering tests."""
+
+import pytest
+import numpy as np
+import pandas as pd
+from unittest.mock import Mock
+
+# =============================================================================
+# Place FIPS Constants
+# =============================================================================
+
+# New Jersey places
+NJ_PATERSON_FIPS = "57000"
+NJ_NEWARK_FIPS = "51000"
+NJ_JERSEY_CITY_FIPS = "36000"
+
+# Other state places
+CA_LOS_ANGELES_FIPS = "44000"
+TX_HOUSTON_FIPS = "35000"
+
+# Non-existent place for testing empty results
+NONEXISTENT_PLACE_FIPS = "99999"
+
+# Bytes versions (as stored in HDF5)
+NJ_PATERSON_FIPS_BYTES = b"57000"
+NJ_NEWARK_FIPS_BYTES = b"51000"
+CA_LOS_ANGELES_FIPS_BYTES = b"44000"
+TX_HOUSTON_FIPS_BYTES = b"35000"
+
+# =============================================================================
+# Region String Constants
+# =============================================================================
+
+NJ_PATERSON_REGION = "place/NJ-57000"
+NJ_NEWARK_REGION = "place/NJ-51000"
+NJ_JERSEY_CITY_REGION = "place/NJ-36000"
+CA_LOS_ANGELES_REGION = "place/CA-44000"
+TX_HOUSTON_REGION = "place/TX-35000"
+
+# =============================================================================
+# Test Data Arrays - Place FIPS combinations for various test scenarios
+# =============================================================================
+
+# Mixed places: 3 Paterson, 1 LA, 1 Houston (5 total)
+MIXED_PLACES_WITH_PATERSON = [
+ NJ_PATERSON_FIPS,
+ NJ_PATERSON_FIPS,
+ CA_LOS_ANGELES_FIPS,
+ TX_HOUSTON_FIPS,
+ NJ_PATERSON_FIPS,
+]
+
+# No Paterson: LA, Houston, Newark (3 total)
+PLACES_WITHOUT_PATERSON = [
+ CA_LOS_ANGELES_FIPS,
+ TX_HOUSTON_FIPS,
+ NJ_NEWARK_FIPS,
+]
+
+# All Paterson (3 total)
+ALL_PATERSON_PLACES = [
+ NJ_PATERSON_FIPS,
+ NJ_PATERSON_FIPS,
+ NJ_PATERSON_FIPS,
+]
+
+# Bytes version: 2 Paterson, 1 LA, 1 Houston (4 total)
+MIXED_PLACES_BYTES = [
+ NJ_PATERSON_FIPS_BYTES,
+ NJ_PATERSON_FIPS_BYTES,
+ CA_LOS_ANGELES_FIPS_BYTES,
+ TX_HOUSTON_FIPS_BYTES,
+]
+
+# Multiple NJ places: 2 Paterson, 2 Newark, 1 Jersey City (5 total)
+MULTIPLE_NJ_PLACES = [
+ NJ_PATERSON_FIPS,
+ NJ_NEWARK_FIPS,
+ NJ_JERSEY_CITY_FIPS,
+ NJ_PATERSON_FIPS,
+ NJ_NEWARK_FIPS,
+]
+
+# Simple two-place array for reform test
+TWO_PLACES_FOR_REFORM_TEST = [
+ NJ_PATERSON_FIPS,
+ CA_LOS_ANGELES_FIPS,
+]
+
+# =============================================================================
+# Expected Results Constants
+# =============================================================================
+
+# Expected counts when filtering MIXED_PLACES_WITH_PATERSON for Paterson
+EXPECTED_PATERSON_COUNT_IN_MIXED = 3
+
+# Expected counts when filtering MIXED_PLACES_BYTES for Paterson
+EXPECTED_PATERSON_COUNT_IN_BYTES = 2
+
+# Expected counts when filtering MULTIPLE_NJ_PLACES for Newark
+EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ = 2
+
+# Mini dataset expected values
+MINI_DATASET_PATERSON_COUNT = 4
+MINI_DATASET_PATERSON_IDS = [0, 1, 4, 7]
+MINI_DATASET_NEWARK_COUNT = 3
+MINI_DATASET_NEWARK_IDS = [2, 5, 8]
+MINI_DATASET_JERSEY_CITY_COUNT = 3
+MINI_DATASET_JERSEY_CITY_IDS = [3, 6, 9]
+
+# Mini dataset expected weights per place
+MINI_DATASET_PATERSON_TOTAL_WEIGHT = 4200.0
+MINI_DATASET_NEWARK_TOTAL_WEIGHT = 5200.0
+MINI_DATASET_JERSEY_CITY_TOTAL_WEIGHT = 3600.0
+
+# Mini dataset with bytes expected values
+MINI_DATASET_BYTES_PATERSON_COUNT = 2
+MINI_DATASET_BYTES_PATERSON_IDS = [0, 1]
+
+
+# =============================================================================
+# Mock Factory Functions
+# =============================================================================
+
+
+def create_mock_simulation_with_place_fips(
+ place_fips_values: list[str],
+ household_ids: list[int] | None = None,
+) -> Mock:
+ """Create a mock simulation with place_fips data.
+
+ Args:
+ place_fips_values: List of place FIPS codes for each household.
+ household_ids: Optional list of household IDs.
+
+ Returns:
+ Mock simulation object with calculate() and to_input_dataframe() configured.
+ """
+ if household_ids is None:
+ household_ids = list(range(len(place_fips_values)))
+
+ mock_sim = Mock()
+
+ # Mock calculate to return place_fips values
+ mock_calculate_result = Mock()
+ mock_calculate_result.values = np.array(place_fips_values)
+ mock_sim.calculate.return_value = mock_calculate_result
+
+ # Mock to_input_dataframe to return a DataFrame
+ df = pd.DataFrame(
+ {
+ "household_id": household_ids,
+ "place_fips": place_fips_values,
+ }
+ )
+ mock_sim.to_input_dataframe.return_value = df
+
+ return mock_sim
+
+
+def create_mock_simulation_with_bytes_place_fips(
+ place_fips_values: list[bytes],
+ household_ids: list[int] | None = None,
+) -> Mock:
+ """Create a mock simulation with bytes place_fips data (as from HDF5).
+
+ Args:
+ place_fips_values: List of place FIPS codes as bytes.
+ household_ids: Optional list of household IDs.
+
+ Returns:
+ Mock simulation object with calculate() and to_input_dataframe() configured.
+ """
+ if household_ids is None:
+ household_ids = list(range(len(place_fips_values)))
+
+ mock_sim = Mock()
+
+ mock_calculate_result = Mock()
+ mock_calculate_result.values = np.array(place_fips_values)
+ mock_sim.calculate.return_value = mock_calculate_result
+
+ df = pd.DataFrame(
+ {
+ "household_id": household_ids,
+ "place_fips": place_fips_values,
+ }
+ )
+ mock_sim.to_input_dataframe.return_value = df
+
+ return mock_sim
+
+
+def create_mock_simulation_type() -> Mock:
+ """Create a mock simulation type class.
+
+ Returns:
+ Mock that can be called to create simulation instances.
+ """
+ mock_simulation_type = Mock()
+ mock_simulation_type.return_value = Mock()
+ return mock_simulation_type
+
+
+def create_simulation_instance():
+ """Create a Simulation instance without calling __init__.
+
+ Returns:
+ A bare Simulation instance for testing methods directly.
+ """
+ from policyengine.simulation import Simulation
+
+ return object.__new__(Simulation)
+
+
+# =============================================================================
+# Pytest Fixtures - Pre-configured Mocks
+# =============================================================================
+
+
+@pytest.fixture
+def mock_simulation_type() -> Mock:
+ """Fixture for a mock simulation type class."""
+ return create_mock_simulation_type()
+
+
+@pytest.fixture
+def mock_reform() -> Mock:
+ """Fixture for a mock reform object."""
+ return Mock(name="mock_reform")
+
+
+@pytest.fixture
+def simulation_instance():
+ """Fixture for a bare Simulation instance."""
+ return create_simulation_instance()
+
+
+@pytest.fixture
+def mock_sim_mixed_places() -> Mock:
+ """Mock simulation with mixed places including Paterson."""
+ return create_mock_simulation_with_place_fips(MIXED_PLACES_WITH_PATERSON)
+
+
+@pytest.fixture
+def mock_sim_no_paterson() -> Mock:
+ """Mock simulation with no Paterson households."""
+ return create_mock_simulation_with_place_fips(PLACES_WITHOUT_PATERSON)
+
+
+@pytest.fixture
+def mock_sim_all_paterson() -> Mock:
+ """Mock simulation where all households are in Paterson."""
+ return create_mock_simulation_with_place_fips(ALL_PATERSON_PLACES)
+
+
+@pytest.fixture
+def mock_sim_bytes_places() -> Mock:
+ """Mock simulation with bytes-encoded place FIPS (HDF5 style)."""
+ return create_mock_simulation_with_bytes_place_fips(MIXED_PLACES_BYTES)
+
+
+@pytest.fixture
+def mock_sim_multiple_nj_places() -> Mock:
+ """Mock simulation with multiple NJ places."""
+ return create_mock_simulation_with_place_fips(MULTIPLE_NJ_PLACES)
+
+
+@pytest.fixture
+def mock_sim_for_reform_test() -> Mock:
+ """Mock simulation for testing reform passthrough."""
+ return create_mock_simulation_with_place_fips(TWO_PLACES_FOR_REFORM_TEST)
+
+
+# =============================================================================
+# Pytest Fixtures - Dataset Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def mini_place_dataset() -> pd.DataFrame:
+ """Create a mini dataset with place_fips for testing.
+
+ Simulates 10 households across 3 NJ places:
+ - Paterson (57000): 4 households (IDs: 0, 1, 4, 7)
+ - Newark (51000): 3 households (IDs: 2, 5, 8)
+ - Jersey City (36000): 3 households (IDs: 3, 6, 9)
+
+ Returns:
+ DataFrame with household_id, place_fips, household_weight,
+ and household_net_income columns.
+ """
+ return pd.DataFrame(
+ {
+ "household_id": list(range(10)),
+ "place_fips": [
+ NJ_PATERSON_FIPS,
+ NJ_PATERSON_FIPS,
+ NJ_NEWARK_FIPS,
+ NJ_JERSEY_CITY_FIPS,
+ NJ_PATERSON_FIPS,
+ NJ_NEWARK_FIPS,
+ NJ_JERSEY_CITY_FIPS,
+ NJ_PATERSON_FIPS,
+ NJ_NEWARK_FIPS,
+ NJ_JERSEY_CITY_FIPS,
+ ],
+ "household_weight": [
+ 1000.0,
+ 1500.0,
+ 2000.0,
+ 1200.0,
+ 800.0,
+ 1800.0,
+ 1100.0,
+ 900.0,
+ 1400.0,
+ 1300.0,
+ ],
+ "household_net_income": [
+ 50000.0,
+ 60000.0,
+ 75000.0,
+ 45000.0,
+ 55000.0,
+ 80000.0,
+ 40000.0,
+ 65000.0,
+ 70000.0,
+ 48000.0,
+ ],
+ }
+ )
+
+
+@pytest.fixture
+def mini_place_dataset_with_bytes() -> pd.DataFrame:
+ """Create a mini dataset with bytes place_fips (as from HDF5).
+
+ Simulates 4 households across 2 NJ places with bytes-encoded FIPS:
+ - Paterson (b"57000"): 2 households (IDs: 0, 1)
+ - Newark (b"51000"): 2 households (IDs: 2, 3)
+
+ Returns:
+ DataFrame with household_id, place_fips (as bytes), and household_weight.
+ """
+ return pd.DataFrame(
+ {
+ "household_id": [0, 1, 2, 3],
+ "place_fips": [
+ NJ_PATERSON_FIPS_BYTES,
+ NJ_PATERSON_FIPS_BYTES,
+ NJ_NEWARK_FIPS_BYTES,
+ NJ_NEWARK_FIPS_BYTES,
+ ],
+ "household_weight": [1000.0, 1000.0, 1000.0, 1000.0],
+ }
+ )
diff --git a/tests/fixtures/simulation.py b/tests/fixtures/simulation.py
index d2361fef..a0f368fc 100644
--- a/tests/fixtures/simulation.py
+++ b/tests/fixtures/simulation.py
@@ -1,7 +1,11 @@
-from policyengine.simulation import SimulationOptions
from unittest.mock import patch, Mock
import pytest
-from policyengine.utils.data.datasets import CPS_2023
+
+# Constants that don't require heavy imports
+SAMPLE_DATASET_FILENAME = "frs_2023_24.h5"
+SAMPLE_DATASET_BUCKET_NAME = "policyengine-uk-data-private"
+SAMPLE_DATASET_URI_PREFIX = "gs://"
+SAMPLE_DATASET_FILE_ADDRESS = f"{SAMPLE_DATASET_URI_PREFIX}{SAMPLE_DATASET_BUCKET_NAME}/{SAMPLE_DATASET_FILENAME}"
non_data_uk_sim_options = {
"country": "uk",
@@ -21,25 +25,34 @@
"baseline": None,
}
-uk_sim_options_no_data = SimulationOptions.model_validate(
- {
- **non_data_uk_sim_options,
- "data": None,
- }
-)
-us_sim_options_cps_dataset = SimulationOptions.model_validate(
- {**non_data_us_sim_options, "data": CPS_2023}
-)
+# Lazy-loaded simulation options to avoid importing policyengine.simulation at collection time
+def get_uk_sim_options_no_data():
+ from policyengine.simulation import SimulationOptions
+
+ return SimulationOptions.model_validate(
+ {
+ **non_data_uk_sim_options,
+ "data": None,
+ }
+ )
+
+
+def get_us_sim_options_cps_dataset():
+ from policyengine.simulation import SimulationOptions
+ from policyengine.utils.data.datasets import CPS_2023
+
+ return SimulationOptions.model_validate(
+ {**non_data_us_sim_options, "data": CPS_2023}
+ )
-SAMPLE_DATASET_FILENAME = "frs_2023_24.h5"
-SAMPLE_DATASET_BUCKET_NAME = "policyengine-uk-data-private"
-SAMPLE_DATASET_URI_PREFIX = "gs://"
-SAMPLE_DATASET_FILE_ADDRESS = f"{SAMPLE_DATASET_URI_PREFIX}{SAMPLE_DATASET_BUCKET_NAME}/{SAMPLE_DATASET_FILENAME}"
-uk_sim_options_pe_dataset = SimulationOptions.model_validate(
- {**non_data_uk_sim_options, "data": SAMPLE_DATASET_FILE_ADDRESS}
-)
+def get_uk_sim_options_pe_dataset():
+ from policyengine.simulation import SimulationOptions
+
+ return SimulationOptions.model_validate(
+ {**non_data_uk_sim_options, "data": SAMPLE_DATASET_FILE_ADDRESS}
+ )
@pytest.fixture
diff --git a/tests/test_simulation.py b/tests/test_simulation.py
index e51f23a0..f761588e 100644
--- a/tests/test_simulation.py
+++ b/tests/test_simulation.py
@@ -1,7 +1,7 @@
from .fixtures.simulation import (
- uk_sim_options_no_data,
- uk_sim_options_pe_dataset,
- us_sim_options_cps_dataset,
+ get_uk_sim_options_no_data,
+ get_uk_sim_options_pe_dataset,
+ get_us_sim_options_cps_dataset,
mock_get_default_dataset,
mock_dataset,
SAMPLE_DATASET_FILENAME,
@@ -10,28 +10,27 @@
import sys
from copy import deepcopy
-from policyengine import Simulation
-from policyengine.outputs.macro.single.calculate_single_economy import (
- GeneralEconomyTask,
-)
-
class TestSimulation:
class TestSetData:
def test__given_no_data_option__sets_default_dataset(
self, mock_get_default_dataset, mock_dataset
):
+ from policyengine import Simulation
# Don't run entire init script
sim = object.__new__(Simulation)
+ uk_sim_options_no_data = get_uk_sim_options_no_data()
sim.options = deepcopy(uk_sim_options_no_data)
sim._set_data(uk_sim_options_no_data.data)
def test__given_pe_dataset__sets_data_option_to_dataset(
self, mock_dataset
):
+ from policyengine import Simulation
sim = object.__new__(Simulation)
+ uk_sim_options_pe_dataset = get_uk_sim_options_pe_dataset()
sim.options = deepcopy(uk_sim_options_pe_dataset)
sim._set_data(uk_sim_options_pe_dataset.data)
@@ -41,6 +40,7 @@ def test__given_cps_2023_in_filename__sets_time_period_to_2023(
from policyengine import Simulation
sim = object.__new__(Simulation)
+ us_sim_options_cps_dataset = get_us_sim_options_cps_dataset()
sim.options = deepcopy(us_sim_options_cps_dataset)
sim._set_data(us_sim_options_cps_dataset.data)
@@ -49,7 +49,7 @@ def test__given_dataset_with_time_period__sets_time_period(self):
from policyengine import Simulation
sim = object.__new__(Simulation)
-
+ us_sim_options_cps_dataset = get_us_sim_options_cps_dataset()
print("Dataset:", us_sim_options_cps_dataset.data, file=sys.stderr)
assert (
sim._set_data_time_period(us_sim_options_cps_dataset.data)
@@ -62,6 +62,7 @@ def test__given_dataset_without_time_period__does_not_set_time_period(
from policyengine import Simulation
sim = object.__new__(Simulation)
+ uk_sim_options_pe_dataset = get_uk_sim_options_pe_dataset()
assert (
sim._set_data_time_period(uk_sim_options_pe_dataset.data)
== None
@@ -71,6 +72,9 @@ class TestCalculateCliffs:
def test__calculates_correct_cliff_metrics(
self, mock_simulation_with_cliff_vars
):
+ from policyengine.outputs.macro.single.calculate_single_economy import (
+ GeneralEconomyTask,
+ )
task = object.__new__(GeneralEconomyTask)
task.simulation = mock_simulation_with_cliff_vars
diff --git a/tests/utils/data/test_datasets.py b/tests/utils/data/test_datasets.py
index 48ae5f5b..eb2f8ba7 100644
--- a/tests/utils/data/test_datasets.py
+++ b/tests/utils/data/test_datasets.py
@@ -1,7 +1,14 @@
"""Tests for datasets.py utilities."""
import pytest
-from policyengine.utils.data.datasets import process_gs_path
+from policyengine.utils.data.datasets import (
+ process_gs_path,
+ determine_us_region_type,
+ parse_us_place_region,
+ get_us_state_dataset_path,
+ _get_default_us_dataset,
+ US_DATA_BUCKET,
+)
class TestProcessGsPath:
@@ -80,3 +87,197 @@ def test_real_policyengine_paths(self):
assert bucket == "policyengine-us-data"
assert path == "states/CA.h5"
assert version is None
+
+
+class TestDetermineUsRegionType:
+ """Tests for determine_us_region_type function with place regions."""
+
+ def test__given__place_region_string__then__returns_place(self):
+ # Given
+ region = "place/NJ-57000"
+
+ # When
+ result = determine_us_region_type(region)
+
+ # Then
+ assert result == "place"
+
+ def test__given__place_region_with_different_state__then__returns_place(
+ self,
+ ):
+ # Given
+ region = "place/CA-44000"
+
+ # When
+ result = determine_us_region_type(region)
+
+ # Then
+ assert result == "place"
+
+ def test__given__none_region__then__returns_nationwide(self):
+ # Given
+ region = None
+
+ # When
+ result = determine_us_region_type(region)
+
+ # Then
+ assert result == "nationwide"
+
+ def test__given__us_region__then__returns_nationwide(self):
+ # Given
+ region = "us"
+
+ # When
+ result = determine_us_region_type(region)
+
+ # Then
+ assert result == "nationwide"
+
+ def test__given__state_region__then__returns_state(self):
+ # Given
+ region = "state/CA"
+
+ # When
+ result = determine_us_region_type(region)
+
+ # Then
+ assert result == "state"
+
+ def test__given__congressional_district_region__then__returns_congressional_district(
+ self,
+ ):
+ # Given
+ region = "congressional_district/CA-01"
+
+ # When
+ result = determine_us_region_type(region)
+
+ # Then
+ assert result == "congressional_district"
+
+ def test__given__invalid_region__then__raises_value_error(self):
+ # Given
+ region = "invalid/something"
+
+ # When / Then
+ with pytest.raises(ValueError) as exc_info:
+ determine_us_region_type(region)
+
+ assert "Unrecognized US region format" in str(exc_info.value)
+
+
+class TestParseUsPlaceRegion:
+ """Tests for parse_us_place_region function."""
+
+ def test__given__nj_paterson_region__then__returns_nj_and_57000(self):
+ # Given
+ region = "place/NJ-57000"
+
+ # When
+ state_code, place_fips = parse_us_place_region(region)
+
+ # Then
+ assert state_code == "NJ"
+ assert place_fips == "57000"
+
+ def test__given__ca_los_angeles_region__then__returns_ca_and_44000(self):
+ # Given
+ region = "place/CA-44000"
+
+ # When
+ state_code, place_fips = parse_us_place_region(region)
+
+ # Then
+ assert state_code == "CA"
+ assert place_fips == "44000"
+
+ def test__given__tx_houston_region__then__returns_tx_and_35000(self):
+ # Given
+ region = "place/TX-35000"
+
+ # When
+ state_code, place_fips = parse_us_place_region(region)
+
+ # Then
+ assert state_code == "TX"
+ assert place_fips == "35000"
+
+ def test__given__lowercase_state_code__then__returns_lowercase(self):
+ # Given
+ region = "place/ny-51000"
+
+ # When
+ state_code, place_fips = parse_us_place_region(region)
+
+ # Then
+ assert state_code == "ny"
+ assert place_fips == "51000"
+
+ def test__given__place_fips_with_leading_zeros__then__preserves_zeros(
+ self,
+ ):
+ # Given
+ region = "place/AL-07000"
+
+ # When
+ state_code, place_fips = parse_us_place_region(region)
+
+ # Then
+ assert state_code == "AL"
+ assert place_fips == "07000"
+
+
+class TestGetDefaultUsDatasetForPlace:
+ """Tests for _get_default_us_dataset with place regions."""
+
+ def test__given__place_region__then__returns_state_dataset_path(self):
+ # Given
+ region = "place/NJ-57000"
+
+ # When
+ result = _get_default_us_dataset(region)
+
+ # Then
+ expected = f"{US_DATA_BUCKET}/states/NJ.h5"
+ assert result == expected
+
+ def test__given__ca_place_region__then__returns_ca_state_dataset(self):
+ # Given
+ region = "place/CA-44000"
+
+ # When
+ result = _get_default_us_dataset(region)
+
+ # Then
+ expected = f"{US_DATA_BUCKET}/states/CA.h5"
+ assert result == expected
+
+ def test__given__lowercase_state_in_place__then__returns_uppercase_state_path(
+ self,
+ ):
+ # Given
+ region = "place/tx-35000"
+
+ # When
+ result = _get_default_us_dataset(region)
+
+ # Then
+ # get_us_state_dataset_path uppercases the state code
+ expected = f"{US_DATA_BUCKET}/states/TX.h5"
+ assert result == expected
+
+ def test__given__place_region__then__result_matches_state_region_dataset(
+ self,
+ ):
+ # Given
+ place_region = "place/GA-04000"
+ state_region = "state/GA"
+
+ # When
+ place_result = _get_default_us_dataset(place_region)
+ state_result = _get_default_us_dataset(state_region)
+
+ # Then
+ # Place regions should load the same dataset as their parent state
+ assert place_result == state_result