diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..cace1367 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,7 @@ +- bump: minor + changes: + added: + - Place-level (city) impact analysis for US Census places with population over 100,000. + - Input validation for place region strings. + removed: + - Deprecated "city/nyc" region format (use "place/NY-51000" instead). diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 9cd68c82..864aa90a 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -376,29 +376,49 @@ def _apply_us_region_to_simulation( """Apply US-specific regional filtering to a simulation. Note: Most US regions (states, congressional districts) now use - scoped datasets rather than filtering. Only NYC still requires - filtering from the national dataset (and is still using the pooled - CPS by default). This should be replaced with an approach based on - the new datasets. + scoped datasets rather than filtering. Place-level regions use + the parent state's dataset and filter by place_fips. """ - if region == "city/nyc": - simulation = self._filter_us_simulation_by_nyc( + if isinstance(region, str) and region.startswith("place/"): + simulation = self._filter_us_simulation_by_place( simulation=simulation, simulation_type=simulation_type, + region=region, reform=reform, ) return simulation - def _filter_us_simulation_by_nyc( + def _filter_us_simulation_by_place( self, simulation: CountryMicrosimulation, simulation_type: type, + region: str, reform: ReformType | None, ) -> CountrySimulation: - """Filter a US simulation to only include NYC households.""" + """Filter a US simulation to only include households in a specific Census place. + + Args: + simulation: The microsimulation to filter. + simulation_type: The type of simulation to create. + region: A place region string (e.g., "place/NJ-57000"). + reform: The reform to apply to the filtered simulation. 
+ + Returns: + A new simulation containing only households in the specified place. + """ + from policyengine.utils.data.datasets import parse_us_place_region + + _, place_fips_code = parse_us_place_region(region) df = simulation.to_input_dataframe() - in_nyc = simulation.calculate("in_nyc", map_to="person").values - return simulation_type(dataset=df[in_nyc], reform=reform) + # Get place_fips at person level since to_input_dataframe() is person-level + person_place_fips = simulation.calculate( + "place_fips", map_to="person" + ).values + # place_fips may be stored as bytes in HDF5; handle both str and bytes + mask = (person_place_fips == place_fips_code) | ( + person_place_fips == place_fips_code.encode() + ) + return simulation_type(dataset=df[mask], reform=reform) def check_model_version(self) -> None: """ diff --git a/policyengine/utils/charts.py b/policyengine/utils/charts.py index 6575c1e3..2dafd4c9 100644 --- a/policyengine/utils/charts.py +++ b/policyengine/utils/charts.py @@ -4,13 +4,11 @@ def add_fonts(): - fonts = HTML( - """ + fonts = HTML(""" - """ - ) + """) return display_html(fonts) diff --git a/policyengine/utils/data/datasets.py b/policyengine/utils/data/datasets.py index b1962e53..9b6614e2 100644 --- a/policyengine/utils/data/datasets.py +++ b/policyengine/utils/data/datasets.py @@ -47,13 +47,8 @@ def _get_default_us_dataset(region: str | None) -> str: if region_type == "nationwide": return ECPS_2024 - elif region_type == "city": - # TODO: Implement a better approach to this for our one - # city, New York City. 
- # Cities use the pooled CPS dataset - return CPS_2023_POOLED - # For state and congressional_district, region is guaranteed to be non-None + # For state, congressional_district, and place, region is guaranteed to be non-None assert region is not None if region_type == "state": @@ -68,6 +63,11 @@ def _get_default_us_dataset(region: str | None) -> str: state_code, district_number ) + elif region_type == "place": + # Expected format: "place/NJ-57000" + state_code, _ = parse_us_place_region(region) + return get_us_state_dataset_path(state_code) + raise ValueError(f"Unhandled US region type: {region_type}") @@ -124,9 +124,11 @@ def get_us_congressional_district_dataset_path( return f"{US_DATA_BUCKET}/districts/{state_code.upper()}-{district_number:02d}.h5" -USRegionType = Literal["nationwide", "city", "state", "congressional_district"] +USRegionType = Literal[ + "nationwide", "state", "congressional_district", "place" +] -US_REGION_PREFIXES = ("city", "state", "congressional_district") +US_REGION_PREFIXES = ("state", "congressional_district", "place") def determine_us_region_type(region: str | None) -> USRegionType: @@ -134,11 +136,11 @@ def determine_us_region_type(region: str | None) -> USRegionType: Determine the type of US region from a region string. Args: - region: A region string (e.g., "us", "city/nyc", "state/CA", - "congressional_district/CA-01") or None. + region: A region string (e.g., "us", "state/CA", + "congressional_district/CA-01", "place/NJ-57000") or None. Returns: - One of "nationwide", "city", "state", or "congressional_district". + One of "nationwide", "state", "congressional_district", or "place". Raises: ValueError: If the region string has an unrecognized prefix. @@ -154,3 +156,45 @@ def determine_us_region_type(region: str | None) -> USRegionType: f"Unrecognized US region format: '{region}'. 
def parse_us_place_region(region: str) -> Tuple[str, str]:
    """Parse a place region string into (state_code, place_fips).

    Format: 'place/{STATE}-{PLACE_FIPS}'
    Example: 'place/NJ-57000' -> ('NJ', '57000')

    Args:
        region: A place region string (e.g., "place/NJ-57000").

    Returns:
        A tuple of (state_code, place_fips).

    Raises:
        ValueError: If the region format is invalid or missing required parts.
    """
    # NOTE: the "Expected format" message parts below are plain string
    # literals concatenated onto an f-string; only the f-prefixed literal is
    # formatted. They previously used doubled braces ("{{STATE}}"), which
    # rendered literally in the error text — single braces are correct here.
    if not region.startswith("place/"):
        raise ValueError(
            f"Invalid place region format: '{region}'. "
            "Expected format: 'place/{STATE}-{PLACE_FIPS}'"
        )

    place_str = region.split("/")[1]
    if "-" not in place_str:
        raise ValueError(
            f"Invalid place region format: '{region}'. "
            "Expected format: 'place/{STATE}-{PLACE_FIPS}'"
        )

    # Split on the first dash only, so FIPS codes containing dashes survive.
    state_code, place_fips = place_str.split("-", 1)

    if not state_code:
        raise ValueError(
            f"Invalid place region: '{region}'. State code cannot be empty."
        )
    if not place_fips:
        raise ValueError(
            f"Invalid place region: '{region}'. Place FIPS code cannot be empty."
        )

    return state_code, place_fips
Testing garbage/invalid inputs +""" + +import sys +import traceback +from typing import Any + +# ============================================================================= +# Configuration +# ============================================================================= + +# Place definitions: (name, region_string, parent_state_region) +PLACES_TO_TEST = [ + ("New York City, NY", "place/NY-51000", "state/NY"), + ("Paterson, NJ", "place/NJ-57000", "state/NJ"), + ("Grand Rapids, MI", "place/MI-34000", "state/MI"), +] + +# Reform: Make CTC fully refundable +CTC_FULLY_REFUNDABLE_REFORM = { + "gov.irs.credits.ctc.refundable.fully_refundable": { + "2024-01-01.2100-12-31": True + } +} + +# Garbage/invalid inputs to test +GARBAGE_INPUTS = [ + ("Empty place FIPS", "place/NY-"), + ("Invalid state code", "place/XX-12345"), + ("Non-existent place FIPS", "place/NY-99999"), + ("Malformed region - no dash", "place/NY12345"), + ("Malformed region - no slash", "placeNY-12345"), + ("Wrong prefix", "city/NY-51000"), + ("SQL injection attempt", "place/NY-51000'; DROP TABLE--"), + ("Very long FIPS", "place/NY-" + "1" * 100), + ("Negative FIPS", "place/NY--12345"), + ("None as region", None), + ("Empty string", ""), + ("Just place/", "place/"), + ("Lowercase state", "place/ny-51000"), +] + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def print_header(title: str) -> None: + """Print a formatted section header.""" + print("\n" + "=" * 70) + print(f" {title}") + print("=" * 70) + + +def print_subheader(title: str) -> None: + """Print a formatted subsection header.""" + print(f"\n--- {title} ---") + + +def format_currency(amount: float) -> str: + """Format a number as currency.""" + if abs(amount) >= 1e9: + return f"${amount/1e9:.2f}B" + elif abs(amount) >= 1e6: + return f"${amount/1e6:.2f}M" + elif abs(amount) >= 1e3: + return 
f"${amount/1e3:.2f}K" + else: + return f"${amount:.2f}" + + +def safe_simulation_run( + region: str | None, + reform: dict | None = None, + description: str = "", +) -> dict[str, Any] | None: + """ + Safely attempt to run a simulation, catching and reporting errors. + + Returns simulation results dict or None if failed. + """ + from policyengine import Simulation + + try: + options = { + "country": "us", + "scope": "macro", + "time_period": "2024", + } + if region is not None: + options["region"] = region + if reform is not None: + options["reform"] = reform + + sim = Simulation(**options) + + # Try to get basic metrics + result = { + "success": True, + "region": region, + "household_count_baseline": None, + "total_net_income_baseline": None, + "budgetary_impact": None, + } + + # Get baseline household count and income + try: + baseline_sim = sim.baseline_simulation + if baseline_sim is not None: + weights = baseline_sim.calculate("household_weight").values + result["household_count_baseline"] = weights.sum() + + net_income = baseline_sim.calculate( + "household_net_income", map_to="household" + ).values + result["total_net_income_baseline"] = ( + net_income * weights + ).sum() + except Exception as e: + result["baseline_error"] = str(e) + + # Get budgetary impact if reform was applied + if reform is not None: + try: + impact = sim.calculate("budgetary/overall/budgetary_impact") + result["budgetary_impact"] = impact + except Exception as e: + result["impact_error"] = str(e) + + return result + + except Exception as e: + return { + "success": False, + "region": region, + "error": str(e), + "error_type": type(e).__name__, + "traceback": traceback.format_exc(), + } + + +# ============================================================================= +# Test Functions +# ============================================================================= + + +def test_place_vs_state_comparison() -> dict[str, Any]: + """ + Test place-level simulations and compare to state-level 
results. + + Returns a summary of results. + """ + print_header("Place vs State Comparison Tests") + + results = { + "places": {}, + "states": {}, + "comparisons": {}, + } + + # First, run state-level simulations + print_subheader("Running State-Level Baseline Simulations") + state_regions = set(place[2] for place in PLACES_TO_TEST) + + for state_region in state_regions: + print(f" Running {state_region}...", end=" ", flush=True) + result = safe_simulation_run(state_region) + results["states"][state_region] = result + if result and result.get("success"): + print( + f"OK - {result.get('household_count_baseline', 'N/A'):,.0f} households" + ) + else: + print(f"FAILED - {result.get('error', 'Unknown error')}") + + # Run place-level simulations + print_subheader("Running Place-Level Baseline Simulations") + + for place_name, place_region, state_region in PLACES_TO_TEST: + print( + f" Running {place_name} ({place_region})...", end=" ", flush=True + ) + result = safe_simulation_run(place_region) + results["places"][place_region] = result + if result and result.get("success"): + print( + f"OK - {result.get('household_count_baseline', 'N/A'):,.0f} households" + ) + else: + print(f"FAILED - {result.get('error', 'Unknown error')}") + + # Compare place to state + print_subheader("Place-to-State Population Ratios") + + for place_name, place_region, state_region in PLACES_TO_TEST: + place_result = results["places"].get(place_region, {}) + state_result = results["states"].get(state_region, {}) + + if ( + place_result.get("success") + and state_result.get("success") + and place_result.get("household_count_baseline") + and state_result.get("household_count_baseline") + ): + + place_hh = place_result["household_count_baseline"] + state_hh = state_result["household_count_baseline"] + ratio = place_hh / state_hh if state_hh > 0 else 0 + + results["comparisons"][place_region] = { + "place_households": place_hh, + "state_households": state_hh, + "ratio": ratio, + } + + print(f" 
{place_name}:") + print(f" Place households: {place_hh:>12,.0f}") + print(f" State households: {state_hh:>12,.0f}") + print(f" Ratio (place/state): {ratio:>10.2%}") + else: + print(f" {place_name}: Could not compare (missing data)") + + return results + + +def test_ctc_reform_impact() -> dict[str, Any]: + """ + Test CTC fully refundable reform impact at place and state levels. + + Checks that budgetary impact is roughly proportional to population. + """ + print_header("CTC Fully Refundable Reform Impact Tests") + + results = { + "places": {}, + "states": {}, + "proportionality_checks": {}, + } + + # Run state-level reform simulations + print_subheader("Running State-Level Reform Simulations") + state_regions = set(place[2] for place in PLACES_TO_TEST) + + for state_region in state_regions: + print( + f" Running {state_region} with CTC reform...", end=" ", flush=True + ) + result = safe_simulation_run(state_region, CTC_FULLY_REFUNDABLE_REFORM) + results["states"][state_region] = result + if result and result.get("success"): + impact = result.get("budgetary_impact") + if impact is not None: + print(f"OK - Impact: {format_currency(impact)}") + else: + print(f"OK - Impact: N/A") + else: + print(f"FAILED - {result.get('error', 'Unknown error')}") + + # Run place-level reform simulations + print_subheader("Running Place-Level Reform Simulations") + + for place_name, place_region, state_region in PLACES_TO_TEST: + print( + f" Running {place_name} with CTC reform...", end=" ", flush=True + ) + result = safe_simulation_run(place_region, CTC_FULLY_REFUNDABLE_REFORM) + results["places"][place_region] = result + if result and result.get("success"): + impact = result.get("budgetary_impact") + if impact is not None: + print(f"OK - Impact: {format_currency(impact)}") + else: + print(f"OK - Impact: N/A") + else: + print(f"FAILED - {result.get('error', 'Unknown error')}") + + # Check proportionality + print_subheader("Proportionality Check (Impact vs Population Ratio)") + + for 
place_name, place_region, state_region in PLACES_TO_TEST: + place_result = results["places"].get(place_region, {}) + state_result = results["states"].get(state_region, {}) + + place_impact = place_result.get("budgetary_impact") + state_impact = state_result.get("budgetary_impact") + place_hh = place_result.get("household_count_baseline") + state_hh = state_result.get("household_count_baseline") + + if all( + v is not None + for v in [place_impact, state_impact, place_hh, state_hh] + ): + if state_impact != 0 and state_hh != 0: + impact_ratio = place_impact / state_impact + pop_ratio = place_hh / state_hh + + # Check if impact ratio is within 50% of population ratio + # (allowing for demographic differences) + ratio_diff = ( + abs(impact_ratio - pop_ratio) / pop_ratio + if pop_ratio > 0 + else float("inf") + ) + is_proportional = ratio_diff < 0.5 + + results["proportionality_checks"][place_region] = { + "place_impact": place_impact, + "state_impact": state_impact, + "impact_ratio": impact_ratio, + "population_ratio": pop_ratio, + "ratio_difference": ratio_diff, + "is_proportional": is_proportional, + } + + status = "PASS" if is_proportional else "WARN" + print(f" {place_name}: [{status}]") + print(f" Impact ratio: {impact_ratio:>10.2%}") + print(f" Population ratio: {pop_ratio:>10.2%}") + print(f" Difference: {ratio_diff:>10.2%}") + else: + print(f" {place_name}: Could not check (missing data)") + + return results + + +def test_garbage_inputs() -> dict[str, Any]: + """ + Test how the system handles garbage/invalid inputs. + + This helps identify where we need more guards or type checking. 
+ """ + print_header("Garbage Input Tests") + + results = { + "tests": {}, + "needs_guards": [], + } + + for test_name, garbage_input in GARBAGE_INPUTS: + print(f" Testing: {test_name}", end=" ", flush=True) + print(f"({repr(garbage_input)[:50]})") + + result = safe_simulation_run(garbage_input) + results["tests"][test_name] = { + "input": garbage_input, + "result": result, + } + + if result and result.get("success"): + # This might indicate we need more validation + print(f" Result: SUCCEEDED (unexpected - may need validation)") + results["needs_guards"].append( + { + "test_name": test_name, + "input": garbage_input, + "reason": "Succeeded when it probably should have failed", + } + ) + else: + error_type = ( + result.get("error_type", "Unknown") if result else "Unknown" + ) + error_msg = ( + result.get("error", "Unknown error") + if result + else "Unknown error" + ) + print(f" Result: FAILED ({error_type})") + print(f" Error: {error_msg[:100]}") + + # Check if error message is helpful + if ( + "Traceback" in str(result.get("traceback", "")) + and "KeyError" in error_type + ): + results["needs_guards"].append( + { + "test_name": test_name, + "input": garbage_input, + "reason": f"KeyError instead of descriptive error: {error_msg}", + } + ) + + return results + + +def test_datasets_functions_directly() -> dict[str, Any]: + """ + Test the datasets.py functions directly with various inputs. 
+ """ + print_header("Direct Function Tests (datasets.py)") + + from policyengine.utils.data.datasets import ( + determine_us_region_type, + parse_us_place_region, + get_default_dataset, + ) + + results = {"tests": [], "needs_guards": []} + + # Test determine_us_region_type + print_subheader("Testing determine_us_region_type()") + + test_cases = [ + ("place/NY-51000", "place"), + ("place/NJ-57000", "place"), + ("state/NY", "state"), + ("us", "nationwide"), + (None, "nationwide"), + ("congressional_district/CA-01", "congressional_district"), + ] + + for input_val, expected in test_cases: + try: + result = determine_us_region_type(input_val) + status = "PASS" if result == expected else "FAIL" + print(f" {repr(input_val):40} -> {result:25} [{status}]") + results["tests"].append( + { + "function": "determine_us_region_type", + "input": input_val, + "result": result, + "expected": expected, + "passed": result == expected, + } + ) + except Exception as e: + print(f" {repr(input_val):40} -> ERROR: {e}") + results["tests"].append( + { + "function": "determine_us_region_type", + "input": input_val, + "error": str(e), + "passed": False, + } + ) + + # Test invalid inputs for determine_us_region_type + print_subheader("Testing determine_us_region_type() with invalid inputs") + + invalid_inputs = [ + "invalid/something", + "place", + "", + "city/nyc", # Should now fail since we removed city + ] + + for input_val in invalid_inputs: + try: + result = determine_us_region_type(input_val) + print(f" {repr(input_val):40} -> {result} [UNEXPECTED SUCCESS]") + results["needs_guards"].append( + { + "function": "determine_us_region_type", + "input": input_val, + "reason": "Should have raised ValueError", + } + ) + except ValueError as e: + print(f" {repr(input_val):40} -> ValueError (expected) [PASS]") + except Exception as e: + print(f" {repr(input_val):40} -> {type(e).__name__}: {e}") + + # Test parse_us_place_region + print_subheader("Testing parse_us_place_region()") + + 
place_test_cases = [ + ("place/NY-51000", ("NY", "51000")), + ("place/NJ-57000", ("NJ", "57000")), + ("place/MI-34000", ("MI", "34000")), + ("place/ca-12345", ("ca", "12345")), # lowercase + ] + + for input_val, expected in place_test_cases: + try: + result = parse_us_place_region(input_val) + status = "PASS" if result == expected else "FAIL" + print(f" {repr(input_val):30} -> {result} [{status}]") + except Exception as e: + print(f" {repr(input_val):30} -> ERROR: {e}") + + # Test invalid inputs for parse_us_place_region + print_subheader("Testing parse_us_place_region() with invalid inputs") + + invalid_place_inputs = [ + "state/NY", # Wrong prefix + "place/NY", # Missing FIPS + "place/", # Empty + "NY-51000", # Missing prefix + "place/NY-", # Empty FIPS + ] + + for input_val in invalid_place_inputs: + try: + result = parse_us_place_region(input_val) + print(f" {repr(input_val):30} -> {result} [UNEXPECTED SUCCESS]") + results["needs_guards"].append( + { + "function": "parse_us_place_region", + "input": input_val, + "reason": f"Should have raised error, got {result}", + } + ) + except (ValueError, IndexError) as e: + print( + f" {repr(input_val):30} -> {type(e).__name__} (expected) [PASS]" + ) + except Exception as e: + print(f" {repr(input_val):30} -> {type(e).__name__}: {e}") + results["needs_guards"].append( + { + "function": "parse_us_place_region", + "input": input_val, + "reason": f"Got {type(e).__name__} instead of ValueError", + } + ) + + return results + + +def print_summary(all_results: dict[str, Any]) -> None: + """Print a summary of all test results.""" + print_header("SUMMARY") + + # Check for items needing guards + all_needs_guards = [] + for test_name, results in all_results.items(): + if isinstance(results, dict) and "needs_guards" in results: + for item in results["needs_guards"]: + item["test_suite"] = test_name + all_needs_guards.append(item) + + if all_needs_guards: + print("\nItems that may need additional guards/validation:") + for item in 
all_needs_guards: + print( + f" - {item.get('test_suite', 'Unknown')}: {item.get('test_name', item.get('function', 'Unknown'))}" + ) + print(f" Input: {repr(item.get('input', 'N/A'))[:60]}") + print(f" Reason: {item.get('reason', 'N/A')}") + else: + print("\nNo additional guards appear to be needed.") + + print("\n" + "=" * 70) + print(" Test script completed") + print("=" * 70) + + +# ============================================================================= +# Main +# ============================================================================= + + +def main(): + """Run all tests.""" + print("\n" + "=" * 70) + print(" Place-Level Simulation Test Script") + print(" Testing: NYC (NY), Paterson (NJ), Grand Rapids (MI)") + print("=" * 70) + + all_results = {} + + # Run direct function tests first (faster, no simulation) + all_results["direct_function_tests"] = test_datasets_functions_directly() + + # Run place vs state comparison + all_results["place_vs_state"] = test_place_vs_state_comparison() + + # Run CTC reform tests + all_results["ctc_reform"] = test_ctc_reform_impact() + + # Run garbage input tests + all_results["garbage_inputs"] = test_garbage_inputs() + + # Print summary + print_summary(all_results) + + return all_results + + +if __name__ == "__main__": + results = main() diff --git a/tests/conftest.py b/tests/conftest.py index e816468f..86f9d14e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,3 +12,17 @@ mock_single_economy_without_districts, mock_single_economy_with_null_districts, ) + +from tests.fixtures.country.us_places import ( + mock_simulation_type, + mock_reform, + simulation_instance, + mock_sim_mixed_places, + mock_sim_no_paterson, + mock_sim_all_paterson, + mock_sim_bytes_places, + mock_sim_multiple_nj_places, + mock_sim_for_reform_test, + mini_place_dataset, + mini_place_dataset_with_bytes, +) diff --git a/tests/country/test_us.py b/tests/country/test_us.py index 385b3c1e..6e1666e7 100644 --- a/tests/country/test_us.py +++ 
b/tests/country/test_us.py @@ -2,22 +2,34 @@ def test_us_macro_single(): + """Test US macro single economy calculation using a state dataset. + + Uses Delaware (smallest state by population with data) to reduce memory + usage in CI while still testing the full simulation pipeline. + """ from policyengine import Simulation sim = Simulation( scope="macro", country="us", + region="state/DE", ) sim.calculate_single_economy() def test_us_macro_comparison(): + """Test US macro economy comparison using a state dataset. + + Uses Delaware (smallest state by population with data) to reduce memory + usage in CI while still testing the full comparison pipeline. + """ from policyengine import Simulation sim = Simulation( scope="macro", country="us", + region="state/DE", reform={ "gov.usda.snap.income.deductions.earned_income": {"2025": 0.05} }, diff --git a/tests/country/test_us_places.py b/tests/country/test_us_places.py new file mode 100644 index 00000000..bfdda841 --- /dev/null +++ b/tests/country/test_us_places.py @@ -0,0 +1,324 @@ +"""Tests for US place-level (city) filtering functionality.""" + +import pytest +import pandas as pd +from unittest.mock import Mock, patch + +from tests.fixtures.country.us_places import ( + # Place FIPS Constants + NJ_PATERSON_FIPS, + NJ_NEWARK_FIPS, + NJ_JERSEY_CITY_FIPS, + CA_LOS_ANGELES_FIPS, + TX_HOUSTON_FIPS, + NONEXISTENT_PLACE_FIPS, + # Region String Constants + NJ_PATERSON_REGION, + NJ_NEWARK_REGION, + # Test Data Arrays + MIXED_PLACES_WITH_PATERSON, + PLACES_WITHOUT_PATERSON, + ALL_PATERSON_PLACES, + MIXED_PLACES_BYTES, + MULTIPLE_NJ_PLACES, + TWO_PLACES_FOR_REFORM_TEST, + # Expected Results Constants + EXPECTED_PATERSON_COUNT_IN_MIXED, + EXPECTED_PATERSON_COUNT_IN_BYTES, + EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ, + MINI_DATASET_PATERSON_COUNT, + MINI_DATASET_PATERSON_IDS, + MINI_DATASET_NEWARK_COUNT, + MINI_DATASET_NEWARK_IDS, + MINI_DATASET_JERSEY_CITY_COUNT, + MINI_DATASET_JERSEY_CITY_IDS, + MINI_DATASET_PATERSON_TOTAL_WEIGHT, + 
MINI_DATASET_NEWARK_TOTAL_WEIGHT, + MINI_DATASET_JERSEY_CITY_TOTAL_WEIGHT, + MINI_DATASET_BYTES_PATERSON_COUNT, + MINI_DATASET_BYTES_PATERSON_IDS, + # Factory Functions + create_mock_simulation_with_place_fips, + create_mock_simulation_with_bytes_place_fips, + create_mock_simulation_type, + create_simulation_instance, +) + + +class TestFilterUsSimulationByPlace: + """Tests for the _filter_us_simulation_by_place method.""" + + def test__given__households_in_target_place__then__filters_to_matching_households( + self, + ): + # Given + mock_sim = create_mock_simulation_with_place_fips( + MIXED_PLACES_WITH_PATERSON + ) + mock_simulation_type = create_mock_simulation_type() + + # When + sim_instance = create_simulation_instance() + result = sim_instance._filter_us_simulation_by_place( + simulation=mock_sim, + simulation_type=mock_simulation_type, + region=NJ_PATERSON_REGION, + reform=None, + ) + + # Then + call_args = mock_simulation_type.call_args + filtered_df = call_args.kwargs["dataset"] + + assert len(filtered_df) == EXPECTED_PATERSON_COUNT_IN_MIXED + assert all(filtered_df["place_fips"] == NJ_PATERSON_FIPS) + + def test__given__no_households_in_target_place__then__returns_empty_dataset( + self, + ): + # Given + mock_sim = create_mock_simulation_with_place_fips( + PLACES_WITHOUT_PATERSON + ) + mock_simulation_type = create_mock_simulation_type() + + # When + sim_instance = create_simulation_instance() + result = sim_instance._filter_us_simulation_by_place( + simulation=mock_sim, + simulation_type=mock_simulation_type, + region=NJ_PATERSON_REGION, + reform=None, + ) + + # Then + call_args = mock_simulation_type.call_args + filtered_df = call_args.kwargs["dataset"] + + assert len(filtered_df) == 0 + + def test__given__all_households_in_target_place__then__returns_all_households( + self, + ): + # Given + mock_sim = create_mock_simulation_with_place_fips(ALL_PATERSON_PLACES) + mock_simulation_type = create_mock_simulation_type() + + # When + sim_instance = 
create_simulation_instance() + result = sim_instance._filter_us_simulation_by_place( + simulation=mock_sim, + simulation_type=mock_simulation_type, + region=NJ_PATERSON_REGION, + reform=None, + ) + + # Then + call_args = mock_simulation_type.call_args + filtered_df = call_args.kwargs["dataset"] + + assert len(filtered_df) == len(ALL_PATERSON_PLACES) + + def test__given__bytes_place_fips_in_dataset__then__still_filters_correctly( + self, + ): + # Given: place_fips stored as bytes (as it might be in HDF5) + mock_sim = create_mock_simulation_with_bytes_place_fips( + MIXED_PLACES_BYTES + ) + mock_simulation_type = create_mock_simulation_type() + + # When + sim_instance = create_simulation_instance() + result = sim_instance._filter_us_simulation_by_place( + simulation=mock_sim, + simulation_type=mock_simulation_type, + region=NJ_PATERSON_REGION, + reform=None, + ) + + # Then + call_args = mock_simulation_type.call_args + filtered_df = call_args.kwargs["dataset"] + + assert len(filtered_df) == EXPECTED_PATERSON_COUNT_IN_BYTES + + def test__given__reform_provided__then__passes_reform_to_simulation_type( + self, + ): + # Given + mock_sim = create_mock_simulation_with_place_fips( + TWO_PLACES_FOR_REFORM_TEST + ) + mock_simulation_type = create_mock_simulation_type() + mock_reform = Mock() + + # When + sim_instance = create_simulation_instance() + result = sim_instance._filter_us_simulation_by_place( + simulation=mock_sim, + simulation_type=mock_simulation_type, + region=NJ_PATERSON_REGION, + reform=mock_reform, + ) + + # Then + call_args = mock_simulation_type.call_args + assert call_args.kwargs["reform"] is mock_reform + + def test__given__different_place_in_same_state__then__filters_correctly( + self, + ): + # Given: Multiple NJ places + mock_sim = create_mock_simulation_with_place_fips(MULTIPLE_NJ_PLACES) + mock_simulation_type = create_mock_simulation_type() + + # When: Filter for Newark only + sim_instance = create_simulation_instance() + result = 
sim_instance._filter_us_simulation_by_place( + simulation=mock_sim, + simulation_type=mock_simulation_type, + region=NJ_NEWARK_REGION, + reform=None, + ) + + # Then + call_args = mock_simulation_type.call_args + filtered_df = call_args.kwargs["dataset"] + + assert len(filtered_df) == EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ + assert all(filtered_df["place_fips"] == NJ_NEWARK_FIPS) + + +class TestApplyUsRegionToSimulationWithPlace: + """Tests for _apply_us_region_to_simulation with place regions.""" + + def test__given__place_region__then__calls_filter_by_place(self): + # Given + sim_instance = create_simulation_instance() + mock_simulation = Mock() + mock_simulation_type = Mock + + # When / Then + with patch.object( + sim_instance, + "_filter_us_simulation_by_place", + return_value=Mock(), + ) as mock_filter: + result = sim_instance._apply_us_region_to_simulation( + simulation=mock_simulation, + simulation_type=mock_simulation_type, + region=NJ_PATERSON_REGION, + reform=None, + ) + + mock_filter.assert_called_once_with( + simulation=mock_simulation, + simulation_type=mock_simulation_type, + region=NJ_PATERSON_REGION, + reform=None, + ) + + +class TestMiniDatasetPlaceFiltering: + """Integration-style tests using a mini in-memory dataset. + + Uses the mini_place_dataset fixture from conftest.py. 
+ """ + + def test__given__mini_dataset__then__paterson_filter_returns_correct_count( + self, mini_place_dataset + ): + # Given + df = mini_place_dataset + + # When + filtered_df = df[df["place_fips"] == NJ_PATERSON_FIPS] + + # Then + assert len(filtered_df) == MINI_DATASET_PATERSON_COUNT + assert ( + filtered_df["household_id"].tolist() == MINI_DATASET_PATERSON_IDS + ) + + def test__given__mini_dataset__then__newark_filter_returns_correct_count( + self, mini_place_dataset + ): + # Given + df = mini_place_dataset + + # When + filtered_df = df[df["place_fips"] == NJ_NEWARK_FIPS] + + # Then + assert len(filtered_df) == MINI_DATASET_NEWARK_COUNT + assert filtered_df["household_id"].tolist() == MINI_DATASET_NEWARK_IDS + + def test__given__mini_dataset__then__jersey_city_filter_returns_correct_count( + self, mini_place_dataset + ): + # Given + df = mini_place_dataset + + # When + filtered_df = df[df["place_fips"] == NJ_JERSEY_CITY_FIPS] + + # Then + assert len(filtered_df) == MINI_DATASET_JERSEY_CITY_COUNT + assert ( + filtered_df["household_id"].tolist() + == MINI_DATASET_JERSEY_CITY_IDS + ) + + def test__given__mini_dataset__then__total_weight_sums_correctly_per_place( + self, mini_place_dataset + ): + # Given + df = mini_place_dataset + + # When + paterson_weight = df[df["place_fips"] == NJ_PATERSON_FIPS][ + "household_weight" + ].sum() + newark_weight = df[df["place_fips"] == NJ_NEWARK_FIPS][ + "household_weight" + ].sum() + jersey_city_weight = df[df["place_fips"] == NJ_JERSEY_CITY_FIPS][ + "household_weight" + ].sum() + + # Then + assert paterson_weight == MINI_DATASET_PATERSON_TOTAL_WEIGHT + assert newark_weight == MINI_DATASET_NEWARK_TOTAL_WEIGHT + assert jersey_city_weight == MINI_DATASET_JERSEY_CITY_TOTAL_WEIGHT + + def test__given__mini_dataset__then__nonexistent_place_returns_empty( + self, mini_place_dataset + ): + # Given + df = mini_place_dataset + + # When + filtered_df = df[df["place_fips"] == NONEXISTENT_PLACE_FIPS] + + # Then + assert len(filtered_df) 
== 0 + + def test__given__mini_dataset_with_bytes_fips__then__filtering_works( + self, mini_place_dataset_with_bytes + ): + # Given + df = mini_place_dataset_with_bytes + + # When: Filter handling both str and bytes + mask = (df["place_fips"] == NJ_PATERSON_FIPS) | ( + df["place_fips"] == NJ_PATERSON_FIPS.encode() + ) + filtered_df = df[mask] + + # Then + assert len(filtered_df) == MINI_DATASET_BYTES_PATERSON_COUNT + assert ( + filtered_df["household_id"].tolist() + == MINI_DATASET_BYTES_PATERSON_IDS + ) diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 131d59d6..a5862e50 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -10,3 +10,62 @@ mock_single_economy_without_districts, mock_single_economy_with_null_districts, ) + +from tests.fixtures.country.us_places import ( + # Place FIPS Constants + NJ_PATERSON_FIPS, + NJ_NEWARK_FIPS, + NJ_JERSEY_CITY_FIPS, + CA_LOS_ANGELES_FIPS, + TX_HOUSTON_FIPS, + NONEXISTENT_PLACE_FIPS, + NJ_PATERSON_FIPS_BYTES, + NJ_NEWARK_FIPS_BYTES, + CA_LOS_ANGELES_FIPS_BYTES, + TX_HOUSTON_FIPS_BYTES, + # Region String Constants + NJ_PATERSON_REGION, + NJ_NEWARK_REGION, + NJ_JERSEY_CITY_REGION, + CA_LOS_ANGELES_REGION, + TX_HOUSTON_REGION, + # Test Data Arrays + MIXED_PLACES_WITH_PATERSON, + PLACES_WITHOUT_PATERSON, + ALL_PATERSON_PLACES, + MIXED_PLACES_BYTES, + MULTIPLE_NJ_PLACES, + TWO_PLACES_FOR_REFORM_TEST, + # Expected Results Constants + EXPECTED_PATERSON_COUNT_IN_MIXED, + EXPECTED_PATERSON_COUNT_IN_BYTES, + EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ, + MINI_DATASET_PATERSON_COUNT, + MINI_DATASET_PATERSON_IDS, + MINI_DATASET_NEWARK_COUNT, + MINI_DATASET_NEWARK_IDS, + MINI_DATASET_JERSEY_CITY_COUNT, + MINI_DATASET_JERSEY_CITY_IDS, + MINI_DATASET_PATERSON_TOTAL_WEIGHT, + MINI_DATASET_NEWARK_TOTAL_WEIGHT, + MINI_DATASET_JERSEY_CITY_TOTAL_WEIGHT, + MINI_DATASET_BYTES_PATERSON_COUNT, + MINI_DATASET_BYTES_PATERSON_IDS, + # Factory Functions + create_mock_simulation_with_place_fips, + 
create_mock_simulation_with_bytes_place_fips, + create_mock_simulation_type, + create_simulation_instance, + # Pytest Fixtures + mock_simulation_type, + mock_reform, + simulation_instance, + mock_sim_mixed_places, + mock_sim_no_paterson, + mock_sim_all_paterson, + mock_sim_bytes_places, + mock_sim_multiple_nj_places, + mock_sim_for_reform_test, + mini_place_dataset, + mini_place_dataset_with_bytes, +) diff --git a/tests/fixtures/country/__init__.py b/tests/fixtures/country/__init__.py new file mode 100644 index 00000000..4eaa139b --- /dev/null +++ b/tests/fixtures/country/__init__.py @@ -0,0 +1,60 @@ +"""Test fixtures for country-specific tests.""" + +from tests.fixtures.country.us_places import ( + # Place FIPS Constants + NJ_PATERSON_FIPS, + NJ_NEWARK_FIPS, + NJ_JERSEY_CITY_FIPS, + CA_LOS_ANGELES_FIPS, + TX_HOUSTON_FIPS, + NONEXISTENT_PLACE_FIPS, + NJ_PATERSON_FIPS_BYTES, + NJ_NEWARK_FIPS_BYTES, + CA_LOS_ANGELES_FIPS_BYTES, + TX_HOUSTON_FIPS_BYTES, + # Region String Constants + NJ_PATERSON_REGION, + NJ_NEWARK_REGION, + NJ_JERSEY_CITY_REGION, + CA_LOS_ANGELES_REGION, + TX_HOUSTON_REGION, + # Test Data Arrays + MIXED_PLACES_WITH_PATERSON, + PLACES_WITHOUT_PATERSON, + ALL_PATERSON_PLACES, + MIXED_PLACES_BYTES, + MULTIPLE_NJ_PLACES, + TWO_PLACES_FOR_REFORM_TEST, + # Expected Results Constants + EXPECTED_PATERSON_COUNT_IN_MIXED, + EXPECTED_PATERSON_COUNT_IN_BYTES, + EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ, + MINI_DATASET_PATERSON_COUNT, + MINI_DATASET_PATERSON_IDS, + MINI_DATASET_NEWARK_COUNT, + MINI_DATASET_NEWARK_IDS, + MINI_DATASET_JERSEY_CITY_COUNT, + MINI_DATASET_JERSEY_CITY_IDS, + MINI_DATASET_PATERSON_TOTAL_WEIGHT, + MINI_DATASET_NEWARK_TOTAL_WEIGHT, + MINI_DATASET_JERSEY_CITY_TOTAL_WEIGHT, + MINI_DATASET_BYTES_PATERSON_COUNT, + MINI_DATASET_BYTES_PATERSON_IDS, + # Factory Functions + create_mock_simulation_with_place_fips, + create_mock_simulation_with_bytes_place_fips, + create_mock_simulation_type, + create_simulation_instance, + # Pytest Fixtures + 
mock_simulation_type, + mock_reform, + simulation_instance, + mock_sim_mixed_places, + mock_sim_no_paterson, + mock_sim_all_paterson, + mock_sim_bytes_places, + mock_sim_multiple_nj_places, + mock_sim_for_reform_test, + mini_place_dataset, + mini_place_dataset_with_bytes, +) diff --git a/tests/fixtures/country/us_places.py b/tests/fixtures/country/us_places.py new file mode 100644 index 00000000..ff65af60 --- /dev/null +++ b/tests/fixtures/country/us_places.py @@ -0,0 +1,358 @@ +"""Test fixtures for US place-level filtering tests.""" + +import pytest +import numpy as np +import pandas as pd +from unittest.mock import Mock + +# ============================================================================= +# Place FIPS Constants +# ============================================================================= + +# New Jersey places +NJ_PATERSON_FIPS = "57000" +NJ_NEWARK_FIPS = "51000" +NJ_JERSEY_CITY_FIPS = "36000" + +# Other state places +CA_LOS_ANGELES_FIPS = "44000" +TX_HOUSTON_FIPS = "35000" + +# Non-existent place for testing empty results +NONEXISTENT_PLACE_FIPS = "99999" + +# Bytes versions (as stored in HDF5) +NJ_PATERSON_FIPS_BYTES = b"57000" +NJ_NEWARK_FIPS_BYTES = b"51000" +CA_LOS_ANGELES_FIPS_BYTES = b"44000" +TX_HOUSTON_FIPS_BYTES = b"35000" + +# ============================================================================= +# Region String Constants +# ============================================================================= + +NJ_PATERSON_REGION = "place/NJ-57000" +NJ_NEWARK_REGION = "place/NJ-51000" +NJ_JERSEY_CITY_REGION = "place/NJ-36000" +CA_LOS_ANGELES_REGION = "place/CA-44000" +TX_HOUSTON_REGION = "place/TX-35000" + +# ============================================================================= +# Test Data Arrays - Place FIPS combinations for various test scenarios +# ============================================================================= + +# Mixed places: 3 Paterson, 1 LA, 1 Houston (5 total) +MIXED_PLACES_WITH_PATERSON = [ + 
NJ_PATERSON_FIPS, + NJ_PATERSON_FIPS, + CA_LOS_ANGELES_FIPS, + TX_HOUSTON_FIPS, + NJ_PATERSON_FIPS, +] + +# No Paterson: LA, Houston, Newark (3 total) +PLACES_WITHOUT_PATERSON = [ + CA_LOS_ANGELES_FIPS, + TX_HOUSTON_FIPS, + NJ_NEWARK_FIPS, +] + +# All Paterson (3 total) +ALL_PATERSON_PLACES = [ + NJ_PATERSON_FIPS, + NJ_PATERSON_FIPS, + NJ_PATERSON_FIPS, +] + +# Bytes version: 2 Paterson, 1 LA, 1 Houston (4 total) +MIXED_PLACES_BYTES = [ + NJ_PATERSON_FIPS_BYTES, + NJ_PATERSON_FIPS_BYTES, + CA_LOS_ANGELES_FIPS_BYTES, + TX_HOUSTON_FIPS_BYTES, +] + +# Multiple NJ places: 2 Paterson, 2 Newark, 1 Jersey City (5 total) +MULTIPLE_NJ_PLACES = [ + NJ_PATERSON_FIPS, + NJ_NEWARK_FIPS, + NJ_JERSEY_CITY_FIPS, + NJ_PATERSON_FIPS, + NJ_NEWARK_FIPS, +] + +# Simple two-place array for reform test +TWO_PLACES_FOR_REFORM_TEST = [ + NJ_PATERSON_FIPS, + CA_LOS_ANGELES_FIPS, +] + +# ============================================================================= +# Expected Results Constants +# ============================================================================= + +# Expected counts when filtering MIXED_PLACES_WITH_PATERSON for Paterson +EXPECTED_PATERSON_COUNT_IN_MIXED = 3 + +# Expected counts when filtering MIXED_PLACES_BYTES for Paterson +EXPECTED_PATERSON_COUNT_IN_BYTES = 2 + +# Expected counts when filtering MULTIPLE_NJ_PLACES for Newark +EXPECTED_NEWARK_COUNT_IN_MULTIPLE_NJ = 2 + +# Mini dataset expected values +MINI_DATASET_PATERSON_COUNT = 4 +MINI_DATASET_PATERSON_IDS = [0, 1, 4, 7] +MINI_DATASET_NEWARK_COUNT = 3 +MINI_DATASET_NEWARK_IDS = [2, 5, 8] +MINI_DATASET_JERSEY_CITY_COUNT = 3 +MINI_DATASET_JERSEY_CITY_IDS = [3, 6, 9] + +# Mini dataset expected weights per place +MINI_DATASET_PATERSON_TOTAL_WEIGHT = 4200.0 +MINI_DATASET_NEWARK_TOTAL_WEIGHT = 5200.0 +MINI_DATASET_JERSEY_CITY_TOTAL_WEIGHT = 3600.0 + +# Mini dataset with bytes expected values +MINI_DATASET_BYTES_PATERSON_COUNT = 2 +MINI_DATASET_BYTES_PATERSON_IDS = [0, 1] + + +# 
============================================================================= +# Mock Factory Functions +# ============================================================================= + + +def create_mock_simulation_with_place_fips( + place_fips_values: list[str], + household_ids: list[int] | None = None, +) -> Mock: + """Create a mock simulation with place_fips data. + + Args: + place_fips_values: List of place FIPS codes for each household. + household_ids: Optional list of household IDs. + + Returns: + Mock simulation object with calculate() and to_input_dataframe() configured. + """ + if household_ids is None: + household_ids = list(range(len(place_fips_values))) + + mock_sim = Mock() + + # Mock calculate to return place_fips values + mock_calculate_result = Mock() + mock_calculate_result.values = np.array(place_fips_values) + mock_sim.calculate.return_value = mock_calculate_result + + # Mock to_input_dataframe to return a DataFrame + df = pd.DataFrame( + { + "household_id": household_ids, + "place_fips": place_fips_values, + } + ) + mock_sim.to_input_dataframe.return_value = df + + return mock_sim + + +def create_mock_simulation_with_bytes_place_fips( + place_fips_values: list[bytes], + household_ids: list[int] | None = None, +) -> Mock: + """Create a mock simulation with bytes place_fips data (as from HDF5). + + Args: + place_fips_values: List of place FIPS codes as bytes. + household_ids: Optional list of household IDs. + + Returns: + Mock simulation object with calculate() and to_input_dataframe() configured. 
+ """ + if household_ids is None: + household_ids = list(range(len(place_fips_values))) + + mock_sim = Mock() + + mock_calculate_result = Mock() + mock_calculate_result.values = np.array(place_fips_values) + mock_sim.calculate.return_value = mock_calculate_result + + df = pd.DataFrame( + { + "household_id": household_ids, + "place_fips": place_fips_values, + } + ) + mock_sim.to_input_dataframe.return_value = df + + return mock_sim + + +def create_mock_simulation_type() -> Mock: + """Create a mock simulation type class. + + Returns: + Mock that can be called to create simulation instances. + """ + mock_simulation_type = Mock() + mock_simulation_type.return_value = Mock() + return mock_simulation_type + + +def create_simulation_instance(): + """Create a Simulation instance without calling __init__. + + Returns: + A bare Simulation instance for testing methods directly. + """ + from policyengine.simulation import Simulation + + return object.__new__(Simulation) + + +# ============================================================================= +# Pytest Fixtures - Pre-configured Mocks +# ============================================================================= + + +@pytest.fixture +def mock_simulation_type() -> Mock: + """Fixture for a mock simulation type class.""" + return create_mock_simulation_type() + + +@pytest.fixture +def mock_reform() -> Mock: + """Fixture for a mock reform object.""" + return Mock(name="mock_reform") + + +@pytest.fixture +def simulation_instance(): + """Fixture for a bare Simulation instance.""" + return create_simulation_instance() + + +@pytest.fixture +def mock_sim_mixed_places() -> Mock: + """Mock simulation with mixed places including Paterson.""" + return create_mock_simulation_with_place_fips(MIXED_PLACES_WITH_PATERSON) + + +@pytest.fixture +def mock_sim_no_paterson() -> Mock: + """Mock simulation with no Paterson households.""" + return create_mock_simulation_with_place_fips(PLACES_WITHOUT_PATERSON) + + +@pytest.fixture +def 
mock_sim_all_paterson() -> Mock: + """Mock simulation where all households are in Paterson.""" + return create_mock_simulation_with_place_fips(ALL_PATERSON_PLACES) + + +@pytest.fixture +def mock_sim_bytes_places() -> Mock: + """Mock simulation with bytes-encoded place FIPS (HDF5 style).""" + return create_mock_simulation_with_bytes_place_fips(MIXED_PLACES_BYTES) + + +@pytest.fixture +def mock_sim_multiple_nj_places() -> Mock: + """Mock simulation with multiple NJ places.""" + return create_mock_simulation_with_place_fips(MULTIPLE_NJ_PLACES) + + +@pytest.fixture +def mock_sim_for_reform_test() -> Mock: + """Mock simulation for testing reform passthrough.""" + return create_mock_simulation_with_place_fips(TWO_PLACES_FOR_REFORM_TEST) + + +# ============================================================================= +# Pytest Fixtures - Dataset Fixtures +# ============================================================================= + + +@pytest.fixture +def mini_place_dataset() -> pd.DataFrame: + """Create a mini dataset with place_fips for testing. + + Simulates 10 households across 3 NJ places: + - Paterson (57000): 4 households (IDs: 0, 1, 4, 7) + - Newark (51000): 3 households (IDs: 2, 5, 8) + - Jersey City (36000): 3 households (IDs: 3, 6, 9) + + Returns: + DataFrame with household_id, place_fips, household_weight, + and household_net_income columns. 
+ """ + return pd.DataFrame( + { + "household_id": list(range(10)), + "place_fips": [ + NJ_PATERSON_FIPS, + NJ_PATERSON_FIPS, + NJ_NEWARK_FIPS, + NJ_JERSEY_CITY_FIPS, + NJ_PATERSON_FIPS, + NJ_NEWARK_FIPS, + NJ_JERSEY_CITY_FIPS, + NJ_PATERSON_FIPS, + NJ_NEWARK_FIPS, + NJ_JERSEY_CITY_FIPS, + ], + "household_weight": [ + 1000.0, + 1500.0, + 2000.0, + 1200.0, + 800.0, + 1800.0, + 1100.0, + 900.0, + 1400.0, + 1300.0, + ], + "household_net_income": [ + 50000.0, + 60000.0, + 75000.0, + 45000.0, + 55000.0, + 80000.0, + 40000.0, + 65000.0, + 70000.0, + 48000.0, + ], + } + ) + + +@pytest.fixture +def mini_place_dataset_with_bytes() -> pd.DataFrame: + """Create a mini dataset with bytes place_fips (as from HDF5). + + Simulates 4 households across 2 NJ places with bytes-encoded FIPS: + - Paterson (b"57000"): 2 households (IDs: 0, 1) + - Newark (b"51000"): 2 households (IDs: 2, 3) + + Returns: + DataFrame with household_id, place_fips (as bytes), and household_weight. + """ + return pd.DataFrame( + { + "household_id": [0, 1, 2, 3], + "place_fips": [ + NJ_PATERSON_FIPS_BYTES, + NJ_PATERSON_FIPS_BYTES, + NJ_NEWARK_FIPS_BYTES, + NJ_NEWARK_FIPS_BYTES, + ], + "household_weight": [1000.0, 1000.0, 1000.0, 1000.0], + } + ) diff --git a/tests/fixtures/simulation.py b/tests/fixtures/simulation.py index d2361fef..a0f368fc 100644 --- a/tests/fixtures/simulation.py +++ b/tests/fixtures/simulation.py @@ -1,7 +1,11 @@ -from policyengine.simulation import SimulationOptions from unittest.mock import patch, Mock import pytest -from policyengine.utils.data.datasets import CPS_2023 + +# Constants that don't require heavy imports +SAMPLE_DATASET_FILENAME = "frs_2023_24.h5" +SAMPLE_DATASET_BUCKET_NAME = "policyengine-uk-data-private" +SAMPLE_DATASET_URI_PREFIX = "gs://" +SAMPLE_DATASET_FILE_ADDRESS = f"{SAMPLE_DATASET_URI_PREFIX}{SAMPLE_DATASET_BUCKET_NAME}/{SAMPLE_DATASET_FILENAME}" non_data_uk_sim_options = { "country": "uk", @@ -21,25 +25,34 @@ "baseline": None, } -uk_sim_options_no_data = 
SimulationOptions.model_validate( - { - **non_data_uk_sim_options, - "data": None, - } -) -us_sim_options_cps_dataset = SimulationOptions.model_validate( - {**non_data_us_sim_options, "data": CPS_2023} -) +# Lazy-loaded simulation options to avoid importing policyengine.simulation at collection time +def get_uk_sim_options_no_data(): + from policyengine.simulation import SimulationOptions + + return SimulationOptions.model_validate( + { + **non_data_uk_sim_options, + "data": None, + } + ) + + +def get_us_sim_options_cps_dataset(): + from policyengine.simulation import SimulationOptions + from policyengine.utils.data.datasets import CPS_2023 + + return SimulationOptions.model_validate( + {**non_data_us_sim_options, "data": CPS_2023} + ) -SAMPLE_DATASET_FILENAME = "frs_2023_24.h5" -SAMPLE_DATASET_BUCKET_NAME = "policyengine-uk-data-private" -SAMPLE_DATASET_URI_PREFIX = "gs://" -SAMPLE_DATASET_FILE_ADDRESS = f"{SAMPLE_DATASET_URI_PREFIX}{SAMPLE_DATASET_BUCKET_NAME}/{SAMPLE_DATASET_FILENAME}" -uk_sim_options_pe_dataset = SimulationOptions.model_validate( - {**non_data_uk_sim_options, "data": SAMPLE_DATASET_FILE_ADDRESS} -) +def get_uk_sim_options_pe_dataset(): + from policyengine.simulation import SimulationOptions + + return SimulationOptions.model_validate( + {**non_data_uk_sim_options, "data": SAMPLE_DATASET_FILE_ADDRESS} + ) @pytest.fixture diff --git a/tests/test_simulation.py b/tests/test_simulation.py index e51f23a0..f761588e 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -1,7 +1,7 @@ from .fixtures.simulation import ( - uk_sim_options_no_data, - uk_sim_options_pe_dataset, - us_sim_options_cps_dataset, + get_uk_sim_options_no_data, + get_uk_sim_options_pe_dataset, + get_us_sim_options_cps_dataset, mock_get_default_dataset, mock_dataset, SAMPLE_DATASET_FILENAME, @@ -10,28 +10,27 @@ import sys from copy import deepcopy -from policyengine import Simulation -from policyengine.outputs.macro.single.calculate_single_economy import ( - 
GeneralEconomyTask, -) - class TestSimulation: class TestSetData: def test__given_no_data_option__sets_default_dataset( self, mock_get_default_dataset, mock_dataset ): + from policyengine import Simulation # Don't run entire init script sim = object.__new__(Simulation) + uk_sim_options_no_data = get_uk_sim_options_no_data() sim.options = deepcopy(uk_sim_options_no_data) sim._set_data(uk_sim_options_no_data.data) def test__given_pe_dataset__sets_data_option_to_dataset( self, mock_dataset ): + from policyengine import Simulation sim = object.__new__(Simulation) + uk_sim_options_pe_dataset = get_uk_sim_options_pe_dataset() sim.options = deepcopy(uk_sim_options_pe_dataset) sim._set_data(uk_sim_options_pe_dataset.data) @@ -41,6 +40,7 @@ def test__given_cps_2023_in_filename__sets_time_period_to_2023( from policyengine import Simulation sim = object.__new__(Simulation) + us_sim_options_cps_dataset = get_us_sim_options_cps_dataset() sim.options = deepcopy(us_sim_options_cps_dataset) sim._set_data(us_sim_options_cps_dataset.data) @@ -49,7 +49,7 @@ def test__given_dataset_with_time_period__sets_time_period(self): from policyengine import Simulation sim = object.__new__(Simulation) - + us_sim_options_cps_dataset = get_us_sim_options_cps_dataset() print("Dataset:", us_sim_options_cps_dataset.data, file=sys.stderr) assert ( sim._set_data_time_period(us_sim_options_cps_dataset.data) @@ -62,6 +62,7 @@ def test__given_dataset_without_time_period__does_not_set_time_period( from policyengine import Simulation sim = object.__new__(Simulation) + uk_sim_options_pe_dataset = get_uk_sim_options_pe_dataset() assert ( sim._set_data_time_period(uk_sim_options_pe_dataset.data) == None @@ -71,6 +72,9 @@ class TestCalculateCliffs: def test__calculates_correct_cliff_metrics( self, mock_simulation_with_cliff_vars ): + from policyengine.outputs.macro.single.calculate_single_economy import ( + GeneralEconomyTask, + ) task = object.__new__(GeneralEconomyTask) task.simulation = 
mock_simulation_with_cliff_vars diff --git a/tests/utils/data/test_datasets.py b/tests/utils/data/test_datasets.py index 48ae5f5b..eb2f8ba7 100644 --- a/tests/utils/data/test_datasets.py +++ b/tests/utils/data/test_datasets.py @@ -1,7 +1,14 @@ """Tests for datasets.py utilities.""" import pytest -from policyengine.utils.data.datasets import process_gs_path +from policyengine.utils.data.datasets import ( + process_gs_path, + determine_us_region_type, + parse_us_place_region, + get_us_state_dataset_path, + _get_default_us_dataset, + US_DATA_BUCKET, +) class TestProcessGsPath: @@ -80,3 +87,197 @@ def test_real_policyengine_paths(self): assert bucket == "policyengine-us-data" assert path == "states/CA.h5" assert version is None + + +class TestDetermineUsRegionType: + """Tests for determine_us_region_type function with place regions.""" + + def test__given__place_region_string__then__returns_place(self): + # Given + region = "place/NJ-57000" + + # When + result = determine_us_region_type(region) + + # Then + assert result == "place" + + def test__given__place_region_with_different_state__then__returns_place( + self, + ): + # Given + region = "place/CA-44000" + + # When + result = determine_us_region_type(region) + + # Then + assert result == "place" + + def test__given__none_region__then__returns_nationwide(self): + # Given + region = None + + # When + result = determine_us_region_type(region) + + # Then + assert result == "nationwide" + + def test__given__us_region__then__returns_nationwide(self): + # Given + region = "us" + + # When + result = determine_us_region_type(region) + + # Then + assert result == "nationwide" + + def test__given__state_region__then__returns_state(self): + # Given + region = "state/CA" + + # When + result = determine_us_region_type(region) + + # Then + assert result == "state" + + def test__given__congressional_district_region__then__returns_congressional_district( + self, + ): + # Given + region = "congressional_district/CA-01" + + # When + 
result = determine_us_region_type(region) + + # Then + assert result == "congressional_district" + + def test__given__invalid_region__then__raises_value_error(self): + # Given + region = "invalid/something" + + # When / Then + with pytest.raises(ValueError) as exc_info: + determine_us_region_type(region) + + assert "Unrecognized US region format" in str(exc_info.value) + + +class TestParseUsPlaceRegion: + """Tests for parse_us_place_region function.""" + + def test__given__nj_paterson_region__then__returns_nj_and_57000(self): + # Given + region = "place/NJ-57000" + + # When + state_code, place_fips = parse_us_place_region(region) + + # Then + assert state_code == "NJ" + assert place_fips == "57000" + + def test__given__ca_los_angeles_region__then__returns_ca_and_44000(self): + # Given + region = "place/CA-44000" + + # When + state_code, place_fips = parse_us_place_region(region) + + # Then + assert state_code == "CA" + assert place_fips == "44000" + + def test__given__tx_houston_region__then__returns_tx_and_35000(self): + # Given + region = "place/TX-35000" + + # When + state_code, place_fips = parse_us_place_region(region) + + # Then + assert state_code == "TX" + assert place_fips == "35000" + + def test__given__lowercase_state_code__then__returns_lowercase(self): + # Given + region = "place/ny-51000" + + # When + state_code, place_fips = parse_us_place_region(region) + + # Then + assert state_code == "ny" + assert place_fips == "51000" + + def test__given__place_fips_with_leading_zeros__then__preserves_zeros( + self, + ): + # Given + region = "place/AL-07000" + + # When + state_code, place_fips = parse_us_place_region(region) + + # Then + assert state_code == "AL" + assert place_fips == "07000" + + +class TestGetDefaultUsDatasetForPlace: + """Tests for _get_default_us_dataset with place regions.""" + + def test__given__place_region__then__returns_state_dataset_path(self): + # Given + region = "place/NJ-57000" + + # When + result = _get_default_us_dataset(region) 
+ + # Then + expected = f"{US_DATA_BUCKET}/states/NJ.h5" + assert result == expected + + def test__given__ca_place_region__then__returns_ca_state_dataset(self): + # Given + region = "place/CA-44000" + + # When + result = _get_default_us_dataset(region) + + # Then + expected = f"{US_DATA_BUCKET}/states/CA.h5" + assert result == expected + + def test__given__lowercase_state_in_place__then__returns_uppercase_state_path( + self, + ): + # Given + region = "place/tx-35000" + + # When + result = _get_default_us_dataset(region) + + # Then + # get_us_state_dataset_path uppercases the state code + expected = f"{US_DATA_BUCKET}/states/TX.h5" + assert result == expected + + def test__given__place_region__then__result_matches_state_region_dataset( + self, + ): + # Given + place_region = "place/GA-04000" + state_region = "state/GA" + + # When + place_result = _get_default_us_dataset(place_region) + state_result = _get_default_us_dataset(state_region) + + # Then + # Place regions should load the same dataset as their parent state + assert place_result == state_result