diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index f0b9c7be3..fc923cf29 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -24,7 +24,9 @@ class PipelineManager: def __init__(self) -> None: - self._pipelines: list[Pipeline] = [ + self._pipelines: list[Pipeline] = [] + self._hinter = PipelineHinter() + for builtin_pipeline in [ ComputeGroupPipeline(), FleetPipeline(), GatewayPipeline(), @@ -35,8 +37,12 @@ def __init__(self) -> None: PlacementGroupPipeline(), RunPipeline(), VolumePipeline(), - ] - self._hinter = PipelineHinter(self._pipelines) + ]: + self.register_pipeline(builtin_pipeline) + + def register_pipeline(self, pipeline: Pipeline): + self._pipelines.append(pipeline) + self._hinter.register_pipeline(pipeline) def start(self): for pipeline in self._pipelines: @@ -64,11 +70,11 @@ def hinter(self): class PipelineHinter: - def __init__(self, pipelines: list[Pipeline]) -> None: - self._pipelines = pipelines + def __init__(self) -> None: self._hint_fetch_map: dict[str, list[Pipeline]] = {} - for pipeline in self._pipelines: - self._hint_fetch_map.setdefault(pipeline.hint_fetch_model_name, []).append(pipeline) + + def register_pipeline(self, pipeline: Pipeline): + self._hint_fetch_map.setdefault(pipeline.hint_fetch_model_name, []).append(pipeline) def hint_fetch(self, model_name: str): pipelines = self._hint_fetch_map.get(model_name) @@ -79,11 +85,17 @@ def hint_fetch(self, model_name: str): pipeline.hint_fetch() +_pipeline_manager = PipelineManager() + + +def get_pipeline_manager() -> PipelineManager: + return _pipeline_manager + + def start_pipeline_tasks() -> PipelineManager: """ Start tasks processed by fetch-workers pipelines based on db + in-memory queues. Suitable for tasks that run frequently and need to lock rows for a long time. """ - pipeline_manager = PipelineManager() - pipeline_manager.start() - return pipeline_manager + _pipeline_manager.start() + return _pipeline_manager