diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..e5391df5 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + fixed: + - Optimisation improvements for loading tax-benefit systems (caching). diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index fea8216a..1dfb9954 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -125,6 +125,7 @@ def __init__( self.is_over_dataset = dataset is not None self.invalidated_caches = set() + self._fast_cache: dict = {} self.debug: bool = False self.trace: bool = trace self.tracer: SimpleTracer = ( @@ -481,6 +482,17 @@ def calculate( elif period is None and self.default_calculation_period is not None: period = periods.period(self.default_calculation_period) + # Fast path: skip tracer, random seed and all _calculate() machinery for + # already-computed values. map_to and decode_enums are NOT cached here — + # they are post-processing steps that vary per call site. + if map_to is None and not decode_enums: + _fast_key = (variable_name, str(period)) + _fast_cache = getattr(self, "_fast_cache", None) + if _fast_cache is not None: + _cached = _fast_cache.get(_fast_key) + if _cached is not None: + return _cached + self.tracer.record_calculation_start( variable_name, period, self.branch_name ) @@ -804,6 +816,9 @@ def _calculate( if is_cache_available: smc.set_cache_value(cache_path, array) + if hasattr(self, "_fast_cache"): + self._fast_cache[(variable_name, str(period))] = array + return array def purge_cache_of_invalid_values(self) -> None: @@ -813,6 +828,7 @@ def purge_cache_of_invalid_values(self) -> None: for _name, _period in self.invalidated_caches: holder = self.get_holder(_name) holder.delete_arrays(_period) + self._fast_cache.pop((_name, str(_period)), None) self.invalidated_caches = set() def calculate_add( @@ -1193,6 +1209,12 @@ def delete_arrays(self, variable: str, period: Period = None) -> None: True """ self.get_holder(variable).delete_arrays(period) + if period is None: + self._fast_cache = { + k: v for k, v in self._fast_cache.items() if k[0] != variable + } + else: + self._fast_cache.pop((variable, str(period)), None) def get_known_periods(self, variable: str) -> List[Period]: """ @@ -1281,8 +1303,15 @@ def clone( new_dict = new.__dict__ for key, value in self.__dict__.items(): - if key not in ("debug", "trace", "tracer", "branches"): + if key not in ( + "debug", + "trace", + "tracer", + "branches", + "_fast_cache", + ): new_dict[key] = value + new._fast_cache = {} new.persons = self.persons.clone(new) setattr(new, new.persons.entity.key, new.persons)