diff --git a/CHANGELOG.md b/CHANGELOG.md index d8557e0e0e..be55e36542 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ ENHANCEMENTS: * Pass OIDC vars directly to the devcontainer ([#4871](https://github.com/microsoft/AzureTRE/issues/4871)) BUG FIXES: +* Implement service bus consumer monitoring with heartbeat detection, automatic recovery, and /health endpoint integration to prevent operations getting stuck indefinitely ([#4464](https://github.com/microsoft/AzureTRE/issues/4464)) * Fix property substitution not occuring where there is only a main step in the pipeline ([#4824](https://github.com/microsoft/AzureTRE/issues/4824)) * Fix Mysql template ignored storage_mb ([#4846](https://github.com/microsoft/AzureTRE/issues/4846)) * Fix duplicate `TOPIC_SUBSCRIPTION_NAME` in `core/terraform/airlock/airlock_processor.tf` ([#4847](https://github.com/microsoft/AzureTRE/pull/4847)) diff --git a/api_app/_version.py b/api_app/_version.py index 6623c5202f..2cb28789f2 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.25.14" +__version__ = "0.25.15" diff --git a/api_app/api/routes/health.py b/api_app/api/routes/health.py index 2cefe21266..c9674d6179 100644 --- a/api_app/api/routes/health.py +++ b/api_app/api/routes/health.py @@ -3,7 +3,7 @@ from core import credentials from models.schemas.status import HealthCheck, ServiceStatus, StatusEnum from resources import strings -from services.health_checker import create_resource_processor_status, create_state_store_status, create_service_bus_status +from services.health_checker import create_airlock_consumer_status, create_deployment_consumer_status, create_resource_processor_status, create_state_store_status, create_service_bus_status from services.logging import logger router = APIRouter() @@ -14,22 +14,28 @@ async def health_check(request: Request) -> HealthCheck: # The health endpoint checks the status of key components of the system. 
# Note that Resource Processor checks incur Azure management calls, so # calling this endpoint frequently may result in API throttling. + deployment_consumer = getattr(request.app.state, 'deployment_status_updater', None) + airlock_consumer = getattr(request.app.state, 'airlock_status_updater', None) + async with credentials.get_credential_async_context() as credential: - cosmos, sb, rp = await asyncio.gather( + cosmos, sb, rp, deploy, airlock = await asyncio.gather( create_state_store_status(), create_service_bus_status(credential), - create_resource_processor_status(credential) + create_resource_processor_status(credential), + create_deployment_consumer_status(deployment_consumer), + create_airlock_consumer_status(airlock_consumer), ) - cosmos_status, cosmos_message = cosmos - sb_status, sb_message = sb - rp_status, rp_message = rp - if cosmos_status == StatusEnum.not_ok or sb_status == StatusEnum.not_ok or rp_status == StatusEnum.not_ok: - logger.error(f'Cosmos Status: {cosmos_status}, message: {cosmos_message}') - logger.error(f'Service Bus Status: {sb_status}, message: {sb_message}') - logger.error(f'Resource Processor Status: {rp_status}, message: {rp_message}') - services = [ServiceStatus(service=strings.COSMOS_DB, status=cosmos_status, message=cosmos_message), - ServiceStatus(service=strings.SERVICE_BUS, status=sb_status, message=sb_message), - ServiceStatus(service=strings.RESOURCE_PROCESSOR, status=rp_status, message=rp_message)] + services = [ + ServiceStatus(service=strings.COSMOS_DB, status=cosmos[0], message=cosmos[1]), + ServiceStatus(service=strings.SERVICE_BUS, status=sb[0], message=sb[1]), + ServiceStatus(service=strings.RESOURCE_PROCESSOR, status=rp[0], message=rp[1]), + ServiceStatus(service=strings.DEPLOYMENT_STATUS_CONSUMER, status=deploy[0], message=deploy[1]), + ServiceStatus(service=strings.AIRLOCK_STATUS_CONSUMER, status=airlock[0], message=airlock[1]), + ] + + for svc in services: + if svc.status == StatusEnum.not_ok: + 
logger.error(f'{svc.service} Status: {svc.status}, message: {svc.message}') return HealthCheck(services=services) diff --git a/api_app/main.py b/api_app/main.py index 0bdc769141..f3f53fecc9 100644 --- a/api_app/main.py +++ b/api_app/main.py @@ -34,8 +34,12 @@ async def lifespan(app: FastAPI): airlockStatusUpdater = AirlockStatusUpdater() await airlockStatusUpdater.init_repos() - asyncio.create_task(deploymentStatusUpdater.receive_messages()) - asyncio.create_task(airlockStatusUpdater.receive_messages()) + # Store consumer references on app.state so the /health endpoint can check their heartbeats + app.state.deployment_status_updater = deploymentStatusUpdater + app.state.airlock_status_updater = airlockStatusUpdater + + asyncio.create_task(deploymentStatusUpdater.supervisor_with_heartbeat_check()) + asyncio.create_task(airlockStatusUpdater.supervisor_with_heartbeat_check()) yield diff --git a/api_app/resources/strings.py b/api_app/resources/strings.py index c54a40ba52..07726b43ad 100644 --- a/api_app/resources/strings.py +++ b/api_app/resources/strings.py @@ -106,6 +106,12 @@ RESOURCE_PROCESSOR_GENERAL_ERROR_MESSAGE = "Resource Processor is not responding" RESOURCE_PROCESSOR_HEALTHY_MESSAGE = "HealthState/healthy" +# Service bus consumer status +DEPLOYMENT_STATUS_CONSUMER = "Deployment Status Consumer" +AIRLOCK_STATUS_CONSUMER = "Airlock Status Consumer" +CONSUMER_HEARTBEAT_STALE = "{} heartbeat is stale or missing" +CONSUMER_NOT_INITIALIZED = "{} has not been initialized" + # Error strings ACCESS_APP_IS_MISSING_ROLE = "The App is missing role" ACCESS_PLEASE_SUPPLY_CLIENT_ID = "Please supply the client_id for the AAD application" diff --git a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py index a643404a86..8cf09822ce 100644 --- a/api_app/service_bus/airlock_request_status_update.py +++ b/api_app/service_bus/airlock_request_status_update.py @@ -16,12 +16,13 @@ from models.domain.airlock_operations import 
StepResultStatusUpdateMessage from core import config, credentials from resources import strings +from service_bus.service_bus_consumer import ServiceBusConsumer -class AirlockStatusUpdater(): +class AirlockStatusUpdater(ServiceBusConsumer): def __init__(self): - pass + super().__init__("airlock_status_updater") async def init_repos(self): self.airlock_request_repo = await AirlockRequestRepository.create() @@ -36,9 +37,10 @@ async def receive_messages(self): try: current_time = time.time() polling_count += 1 + # Log a heartbeat message every 60 seconds to show the service is still working if current_time - last_heartbeat_time >= 60: - logger.info(f"Queue reader heartbeat: Polled {config.SERVICE_BUS_STEP_RESULT_QUEUE} queue {polling_count} times in the last minute") + logger.info(f"{config.SERVICE_BUS_STEP_RESULT_QUEUE} queue polled {polling_count} times in the last minute") last_heartbeat_time = current_time polling_count = 0 @@ -60,17 +62,20 @@ async def receive_messages(self): await asyncio.sleep(10) + # Update heartbeat for supervisor monitoring + self.update_heartbeat() + except OperationTimeoutError: # Timeout occurred whilst connecting to a session - this is expected and indicates no non-empty sessions are available logger.debug("No sessions for this process. Will look again...") - except ServiceBusConnectionError: + except ServiceBusConnectionError as e: # Occasionally there will be a transient / network-level error in connecting to SB. - logger.info("Unknown Service Bus connection error. Will retry...") + logger.warning(f"Service Bus connection error (will retry): {e}") except Exception as e: # Catch all other exceptions, log them via .exception to get the stack trace, and reconnect - logger.exception(f"Unknown exception. 
Will retry - {e}") + logger.exception(f"Unexpected error in message processing: {type(e).__name__}: {e}") async def process_message(self, msg): with tracer.start_as_current_span("process_message") as current_span: diff --git a/api_app/service_bus/deployment_status_updater.py b/api_app/service_bus/deployment_status_updater.py index 41670464c7..c04fb0c06f 100644 --- a/api_app/service_bus/deployment_status_updater.py +++ b/api_app/service_bus/deployment_status_updater.py @@ -1,7 +1,7 @@ -import asyncio import json import uuid import time +from typing import Dict, List, Any from pydantic import ValidationError, parse_obj_as @@ -21,11 +21,12 @@ from models.domain.operation import DeploymentStatusUpdateMessage, Operation, OperationStep, Status from resources import strings from services.logging import logger, tracer +from service_bus.service_bus_consumer import ServiceBusConsumer -class DeploymentStatusUpdater(): +class DeploymentStatusUpdater(ServiceBusConsumer): def __init__(self): - pass + super().__init__("deployment_status_updater") async def init_repos(self): self.operations_repo = await OperationRepository.create() @@ -33,9 +34,6 @@ async def init_repos(self): self.resource_template_repo = await ResourceTemplateRepository.create() self.resource_history_repo = await ResourceHistoryRepository.create() - def run(self, *args, **kwargs): - asyncio.run(self.receive_messages()) - async def receive_messages(self): with tracer.start_as_current_span("deployment_status_receive_messages"): last_heartbeat_time = 0 @@ -45,9 +43,10 @@ async def receive_messages(self): try: current_time = time.time() polling_count += 1 + # Log a heartbeat message every 60 seconds to show the service is still working if current_time - last_heartbeat_time >= 60: - logger.info(f"Queue reader heartbeat: Polled {config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue {polling_count} times in the last minute") + logger.info(f"{config.SERVICE_BUS_DEPLOYMENT_STATUS_UPDATE_QUEUE} queue polled 
{polling_count} times in the last minute") last_heartbeat_time = current_time polling_count = 0 @@ -69,19 +68,22 @@ async def receive_messages(self): await receiver.abandon_message(msg) logger.info(f"Closing session: {receiver.session.session_id}") + # Update heartbeat for supervisor monitoring + self.update_heartbeat() + except OperationTimeoutError: # Timeout occurred whilst connecting to a session - this is expected and indicates no non-empty sessions are available logger.debug("No sessions for this process. Will look again...") - except ServiceBusConnectionError: + except ServiceBusConnectionError as e: # Occasionally there will be a transient / network-level error in connecting to SB. - logger.info("Unknown Service Bus connection error. Will retry...") + logger.warning(f"Service Bus connection error (will retry): {e}") except Exception as e: # Catch all other exceptions, log them via .exception to get the stack trace, and reconnect - logger.exception(f"Unknown exception. Will retry - {e}") + logger.exception(f"Unexpected error in message processing: {type(e).__name__}: {e}") - async def process_message(self, msg): + async def process_message(self, msg) -> bool: complete_message = False message = "" @@ -115,6 +117,11 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage try: # update the op operation = await self.operations_repo.get_operation_by_id(str(message.operationId)) + + # Add null safety for operation steps + if not operation.steps: + raise ValueError(f"Operation {message.operationId} has no steps") + step_to_update = None is_last_step = False @@ -128,7 +135,7 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage is_last_step = True if step_to_update is None: - raise f"Error finding step {message.stepId} in operation {message.operationId}" + raise ValueError(f"Step {message.stepId} not found in operation {message.operationId}") # update the step status step_to_update.status = message.status @@ 
-159,7 +166,8 @@ async def update_status_in_database(self, message: DeploymentStatusUpdateMessage # more steps in the op to do? if is_last_step is False: - assert current_step_index < (len(operation.steps) - 1) + if current_step_index >= len(operation.steps) - 1: + raise ValueError(f"Step index {current_step_index} is the last step in operation (has {len(operation.steps)} steps), but more steps were expected") next_step = operation.steps[current_step_index + 1] # catch any errors in updating the resource - maybe Cosmos / schema invalid etc, and report them back to the op @@ -255,7 +263,7 @@ def get_failure_status_for_action(self, action: RequestAction): return status - def create_updated_resource_document(self, resource: dict, message: DeploymentStatusUpdateMessage): + def create_updated_resource_document(self, resource: Dict[str, Any], message: DeploymentStatusUpdateMessage) -> Dict[str, Any]: """ Merge the outputs with the resource document to persist """ @@ -268,7 +276,7 @@ def create_updated_resource_document(self, resource: dict, message: DeploymentSt return resource - def convert_outputs_to_dict(self, outputs_list: [Output]): + def convert_outputs_to_dict(self, outputs_list: List[Output]) -> Dict[str, Any]: """ Convert a list of Porter outputs to a dictionary """ diff --git a/api_app/service_bus/service_bus_consumer.py b/api_app/service_bus/service_bus_consumer.py new file mode 100644 index 0000000000..dd04fbcce8 --- /dev/null +++ b/api_app/service_bus/service_bus_consumer.py @@ -0,0 +1,77 @@ +import asyncio +import time + +from services.logging import logger + +# Configuration constants for monitoring intervals +HEARTBEAT_CHECK_INTERVAL_SECONDS = 60 +HEARTBEAT_STALENESS_THRESHOLD_SECONDS = 300 +RESTART_DELAY_SECONDS = 5 +MAX_RESTART_DELAY_SECONDS = 300 +SUPERVISOR_ERROR_DELAY_SECONDS = 30 + + +class ServiceBusConsumer: + + def __init__(self, consumer_name: str): + self.service_name = consumer_name.replace('_', ' ').title() + self._last_heartbeat: float = 
time.monotonic() + self._restart_delay: float = RESTART_DELAY_SECONDS + logger.info(f"Initializing {self.service_name}") + + def update_heartbeat(self): + self._last_heartbeat = time.monotonic() + + def check_heartbeat(self, max_age_seconds: int = HEARTBEAT_STALENESS_THRESHOLD_SECONDS) -> bool: + age = time.monotonic() - self._last_heartbeat + if age > max_age_seconds: + logger.warning(f"{self.service_name} heartbeat is {age:.1f}s old (threshold: {max_age_seconds}s)") + return False + return True + + async def supervisor_with_heartbeat_check(self): + task = None + try: + while True: + try: + task_just_started = False + if task is None or task.done(): + if task and task.done(): + try: + await task + except Exception as e: + logger.exception(f"{self.service_name} task failed: {e}") + await asyncio.sleep(self._restart_delay) + self._restart_delay = min(self._restart_delay * 2, MAX_RESTART_DELAY_SECONDS) + + logger.info(f"Starting {self.service_name} task...") + task = asyncio.create_task(self.receive_messages()) + self.update_heartbeat() + task_just_started = True + + await asyncio.sleep(HEARTBEAT_CHECK_INTERVAL_SECONDS) + + if not self.check_heartbeat(): + logger.warning(f"{self.service_name} heartbeat stale, restarting...") + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + task = None + elif not task_just_started: + self._restart_delay = RESTART_DELAY_SECONDS + except Exception as e: + logger.exception(f"{self.service_name} supervisor error: {e}") + await asyncio.sleep(SUPERVISOR_ERROR_DELAY_SECONDS) + finally: + if task and not task.done(): + logger.info(f"Cleaning up {self.service_name} task...") + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def receive_messages(self): + raise NotImplementedError("Subclasses must implement receive_messages()") diff --git a/api_app/services/health_checker.py b/api_app/services/health_checker.py index a4d53067b0..8b56757317 100644 --- 
a/api_app/services/health_checker.py +++ b/api_app/services/health_checker.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple from azure.core import exceptions from azure.servicebus.aio import ServiceBusClient from azure.mgmt.compute.aio import ComputeManagementClient @@ -11,6 +11,7 @@ from core import config from models.schemas.status import StatusEnum from resources import strings +from service_bus.service_bus_consumer import ServiceBusConsumer from services.logging import logger @@ -55,6 +56,22 @@ async def create_service_bus_status(credential) -> Tuple[StatusEnum, str]: return status, message +def create_consumer_status(consumer: Optional[ServiceBusConsumer], name: str) -> Tuple[StatusEnum, str]: + if consumer is None: + return StatusEnum.not_ok, strings.CONSUMER_NOT_INITIALIZED.format(name) + if consumer.check_heartbeat(): + return StatusEnum.ok, "" + return StatusEnum.not_ok, strings.CONSUMER_HEARTBEAT_STALE.format(name) + + +async def create_deployment_consumer_status(consumer: Optional[ServiceBusConsumer]) -> Tuple[StatusEnum, str]: + return create_consumer_status(consumer, strings.DEPLOYMENT_STATUS_CONSUMER) + + +async def create_airlock_consumer_status(consumer: Optional[ServiceBusConsumer]) -> Tuple[StatusEnum, str]: + return create_consumer_status(consumer, strings.AIRLOCK_STATUS_CONSUMER) + + async def create_resource_processor_status(credential) -> Tuple[StatusEnum, str]: status = StatusEnum.ok message = "" diff --git a/api_app/services/logging.py b/api_app/services/logging.py index ad6966b6d8..f674ee48a1 100644 --- a/api_app/services/logging.py +++ b/api_app/services/logging.py @@ -45,6 +45,7 @@ "urllib3.connectionpool" ] + logger = logging.getLogger("azuretre_api") tracer = trace.get_tracer("azuretre_api") diff --git a/api_app/tests_ma/test_api/test_routes/test_health.py b/api_app/tests_ma/test_api/test_routes/test_health.py index 21e795b619..ab52fe87f6 100644 --- a/api_app/tests_ma/test_api/test_routes/test_health.py 
+++ b/api_app/tests_ma/test_api/test_routes/test_health.py @@ -1,9 +1,10 @@ import pytest from httpx import AsyncClient -from mock import patch +from mock import patch, MagicMock from models.schemas.status import StatusEnum from resources import strings +from service_bus.service_bus_consumer import ServiceBusConsumer pytestmark = pytest.mark.asyncio @@ -48,3 +49,84 @@ async def test_health_response_contains_resource_processor_status(health_check_c response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) assert {"message": message, "service": strings.RESOURCE_PROCESSOR, "status": strings.OK} in response.json()["services"] + + +@patch("api.routes.health.create_resource_processor_status") +@patch("api.routes.health.create_service_bus_status") +@patch("api.routes.health.create_state_store_status") +async def test_health_response_contains_consumer_statuses(health_check_cosmos_mock, health_check_service_bus_mock, health_check_rp_mock, app, + client: AsyncClient) -> None: + """Test that health endpoint includes deployment and airlock consumer status.""" + message = "" + health_check_cosmos_mock.return_value = StatusEnum.ok, message + health_check_service_bus_mock.return_value = StatusEnum.ok, message + health_check_rp_mock.return_value = StatusEnum.ok, message + + # Simulate consumers stored on app.state with healthy heartbeats + mock_deployment_consumer = MagicMock(spec=ServiceBusConsumer) + mock_deployment_consumer.check_heartbeat.return_value = True + mock_airlock_consumer = MagicMock(spec=ServiceBusConsumer) + mock_airlock_consumer.check_heartbeat.return_value = True + app.state.deployment_status_updater = mock_deployment_consumer + app.state.airlock_status_updater = mock_airlock_consumer + + response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) + services = response.json()["services"] + + assert {"message": "", "service": strings.DEPLOYMENT_STATUS_CONSUMER, "status": strings.OK} in services + assert {"message": "", 
"service": strings.AIRLOCK_STATUS_CONSUMER, "status": strings.OK} in services + + +@patch("api.routes.health.create_resource_processor_status") +@patch("api.routes.health.create_service_bus_status") +@patch("api.routes.health.create_state_store_status") +async def test_health_response_reports_stale_consumer(health_check_cosmos_mock, health_check_service_bus_mock, health_check_rp_mock, app, + client: AsyncClient) -> None: + """Test that health endpoint reports not_ok when a consumer heartbeat is stale.""" + message = "" + health_check_cosmos_mock.return_value = StatusEnum.ok, message + health_check_service_bus_mock.return_value = StatusEnum.ok, message + health_check_rp_mock.return_value = StatusEnum.ok, message + + # Simulate deployment consumer with stale heartbeat + mock_deployment_consumer = MagicMock(spec=ServiceBusConsumer) + mock_deployment_consumer.check_heartbeat.return_value = False + mock_airlock_consumer = MagicMock(spec=ServiceBusConsumer) + mock_airlock_consumer.check_heartbeat.return_value = True + app.state.deployment_status_updater = mock_deployment_consumer + app.state.airlock_status_updater = mock_airlock_consumer + + response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) + services = response.json()["services"] + + deploy_svc = next(s for s in services if s["service"] == strings.DEPLOYMENT_STATUS_CONSUMER) + assert deploy_svc["status"] == strings.NOT_OK + assert deploy_svc["message"] == strings.CONSUMER_HEARTBEAT_STALE.format(strings.DEPLOYMENT_STATUS_CONSUMER) + + airlock_svc = next(s for s in services if s["service"] == strings.AIRLOCK_STATUS_CONSUMER) + assert airlock_svc["status"] == strings.OK + + +@patch("api.routes.health.create_resource_processor_status") +@patch("api.routes.health.create_service_bus_status") +@patch("api.routes.health.create_state_store_status") +async def test_health_response_handles_missing_consumers(health_check_cosmos_mock, health_check_service_bus_mock, health_check_rp_mock, app, + client: 
AsyncClient) -> None: + """Test that health endpoint handles missing consumer references gracefully.""" + message = "" + health_check_cosmos_mock.return_value = StatusEnum.ok, message + health_check_service_bus_mock.return_value = StatusEnum.ok, message + health_check_rp_mock.return_value = StatusEnum.ok, message + + # Remove consumer references from app.state if they exist + if hasattr(app.state, 'deployment_status_updater'): + delattr(app.state, 'deployment_status_updater') + if hasattr(app.state, 'airlock_status_updater'): + delattr(app.state, 'airlock_status_updater') + + response = await client.get(app.url_path_for(strings.API_GET_HEALTH_STATUS)) + services = response.json()["services"] + + deploy_svc = next(s for s in services if s["service"] == strings.DEPLOYMENT_STATUS_CONSUMER) + assert deploy_svc["status"] == strings.NOT_OK + assert deploy_svc["message"] == strings.CONSUMER_NOT_INITIALIZED.format(strings.DEPLOYMENT_STATUS_CONSUMER) diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py new file mode 100644 index 0000000000..293138eac9 --- /dev/null +++ b/api_app/tests_ma/test_service_bus/test_service_bus_consumer.py @@ -0,0 +1,395 @@ +import asyncio +import time +import pytest +from unittest.mock import patch + +from service_bus.service_bus_consumer import ( + ServiceBusConsumer, + HEARTBEAT_STALENESS_THRESHOLD_SECONDS, + RESTART_DELAY_SECONDS, + MAX_RESTART_DELAY_SECONDS, + HEARTBEAT_CHECK_INTERVAL_SECONDS, + SUPERVISOR_ERROR_DELAY_SECONDS, +) + + +# Create a concrete implementation for testing +class MockConsumer(ServiceBusConsumer): + def __init__(self): + super().__init__("test_consumer") + self.receive_messages_called = False + + async def receive_messages(self): + self.receive_messages_called = True + await asyncio.sleep(0.1) + return + + +@pytest.mark.asyncio +async def test_init(): + """Test initialization of ServiceBusConsumer.""" + consumer = MockConsumer() + 
assert consumer.service_name == "Test Consumer" + assert consumer._restart_delay == RESTART_DELAY_SECONDS + assert consumer._last_heartbeat > 0 + + +@pytest.mark.asyncio +async def test_update_heartbeat(): + """Test updating heartbeat updates timestamp.""" + consumer = MockConsumer() + old_heartbeat = consumer._last_heartbeat + await asyncio.sleep(0.01) + consumer.update_heartbeat() + + assert consumer._last_heartbeat > old_heartbeat + + +@pytest.mark.asyncio +async def test_check_heartbeat_recent(): + """Test checking a recent heartbeat returns True.""" + consumer = MockConsumer() + assert consumer.check_heartbeat(max_age_seconds=300) is True + + +@pytest.mark.asyncio +async def test_check_heartbeat_stale(): + """Test checking a stale heartbeat returns False.""" + consumer = MockConsumer() + consumer._last_heartbeat = time.monotonic() - 400 + assert consumer.check_heartbeat(max_age_seconds=300) is False + + +@pytest.mark.asyncio +async def test_check_heartbeat_default_uses_constant(): + """Test that check_heartbeat default max_age_seconds uses the module constant.""" + import inspect + sig = inspect.signature(ServiceBusConsumer.check_heartbeat) + default = sig.parameters['max_age_seconds'].default + assert default == HEARTBEAT_STALENESS_THRESHOLD_SECONDS + + +def test_restart_delay_configuration(): + """Test that configuration constants exist and have reasonable values.""" + assert RESTART_DELAY_SECONDS > 0 + assert RESTART_DELAY_SECONDS <= 10 + assert MAX_RESTART_DELAY_SECONDS >= RESTART_DELAY_SECONDS + assert MAX_RESTART_DELAY_SECONDS <= 600 + assert HEARTBEAT_CHECK_INTERVAL_SECONDS > 0 + assert HEARTBEAT_STALENESS_THRESHOLD_SECONDS > HEARTBEAT_CHECK_INTERVAL_SECONDS + assert SUPERVISOR_ERROR_DELAY_SECONDS > 0 + + +@pytest.mark.asyncio +async def test_backoff_increases_on_consecutive_failures(): + """Test that restart delay increases exponentially on consecutive failures via supervisor.""" + consumer = MockConsumer() + + async def failing_receive(): + raise 
RuntimeError("Simulated failure") + + consumer.receive_messages = failing_receive + + sleep_calls = [] + task_create_calls = 0 + + class FailingTask: + def __init__(self): + nonlocal task_create_calls + task_create_calls += 1 + + def done(self): + return True + + def cancel(self): + pass + + def __await__(self): + async def _await(): + raise RuntimeError("Simulated task failure") + return _await().__await__() + + async def mock_sleep(duration): + sleep_calls.append(duration) + if task_create_calls >= 4: + raise asyncio.CancelledError() + + def create_failing_task(coro): + coro.close() + return FailingTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_failing_task): + try: + await consumer.supervisor_with_heartbeat_check() + except asyncio.CancelledError: + pass + + # The backoff delays should increase exponentially: 5, 10, 20, ... + backoff_sleeps = [d for d in sleep_calls if d != HEARTBEAT_CHECK_INTERVAL_SECONDS] + assert backoff_sleeps[0] == RESTART_DELAY_SECONDS + assert backoff_sleeps[1] == RESTART_DELAY_SECONDS * 2 + assert backoff_sleeps[2] == RESTART_DELAY_SECONDS * 4 + + +@pytest.mark.asyncio +async def test_backoff_caps_at_maximum(): + """Test that restart delay caps at MAX_RESTART_DELAY_SECONDS via supervisor.""" + consumer = MockConsumer() + consumer._restart_delay = MAX_RESTART_DELAY_SECONDS + + task_create_calls = 0 + + class FailingTask: + def __init__(self): + nonlocal task_create_calls + task_create_calls += 1 + + def done(self): + return True + + def cancel(self): + pass + + def __await__(self): + async def _await(): + raise RuntimeError("Simulated task failure") + return _await().__await__() + + sleep_calls = [] + + async def mock_sleep(duration): + sleep_calls.append(duration) + if task_create_calls >= 2: + raise asyncio.CancelledError() + + def create_failing_task(coro): + coro.close() + return FailingTask() + + with 
patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_failing_task): + try: + await consumer.supervisor_with_heartbeat_check() + except asyncio.CancelledError: + pass + + backoff_sleeps = [d for d in sleep_calls if d != HEARTBEAT_CHECK_INTERVAL_SECONDS] + assert backoff_sleeps[0] == MAX_RESTART_DELAY_SECONDS + + +@pytest.mark.asyncio +async def test_supervisor_restarts_failed_task(): + """Test supervisor restarts the receive_messages task when it fails.""" + consumer = MockConsumer() + + task_create_calls = 0 + sleep_calls = [] + + class FailOnFirstDoneTask: + """A mock task that reports done() immediately to simulate task failure.""" + + def __init__(self): + nonlocal task_create_calls + task_create_calls += 1 + self._is_first_task = (task_create_calls == 1) + + def done(self): + # First task always reports done (crashed) + # Second task always reports running + return self._is_first_task + + def cancel(self): + pass + + def __await__(self): + async def _await(): + if self._is_first_task: + raise RuntimeError("Simulated task failure") + return None + return _await().__await__() + + iteration = 0 + + async def mock_sleep(duration): + nonlocal iteration + sleep_calls.append(duration) + iteration += 1 + if iteration >= 4: + raise KeyboardInterrupt("Test complete") + + consumer.check_heartbeat = lambda **kwargs: True + + def create_fail_task(coro): + coro.close() + return FailOnFirstDoneTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_fail_task): + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass + + assert task_create_calls >= 2 + + +@pytest.mark.asyncio +async def test_supervisor_restarts_on_stale_heartbeat(): + """Test supervisor cancels and restarts task when heartbeat goes 
stale.""" + consumer = MockConsumer() + + heartbeat_calls = 0 + task_create_calls = 0 + task_cancel_calls = 0 + sleep_calls = [] + + def mock_check_heartbeat(**kwargs): + nonlocal heartbeat_calls + heartbeat_calls += 1 + if heartbeat_calls == 1: + return True # Heartbeat is fresh + elif heartbeat_calls == 2: + return False # Heartbeat is stale, should trigger restart + else: + raise KeyboardInterrupt("Test complete") + + async def mock_sleep(duration): + sleep_calls.append(duration) + + class MockTask: + def __init__(self): + nonlocal task_create_calls + task_create_calls += 1 + + def cancel(self): + nonlocal task_cancel_calls + task_cancel_calls += 1 + + def done(self): + return False + + def __await__(self): + async def _await(): + return None + return _await().__await__() + + consumer.check_heartbeat = mock_check_heartbeat + + def create_mock_task(coro): + coro.close() + return MockTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_mock_task): + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass + + assert heartbeat_calls >= 2 + assert task_create_calls >= 2 + assert task_cancel_calls >= 1 + assert HEARTBEAT_CHECK_INTERVAL_SECONDS in sleep_calls + assert consumer._restart_delay == RESTART_DELAY_SECONDS + + +@pytest.mark.asyncio +async def test_supervisor_cleanup_on_shutdown(): + """Test supervisor properly cleans up tasks when interrupted.""" + consumer = MockConsumer() + + task_created = False + task_cancelled = False + + class MockTask: + def __init__(self): + nonlocal task_created + task_created = True + self.done_count = 0 + + def done(self): + return self.done_count > 0 + + def cancel(self): + nonlocal task_cancelled + task_cancelled = True + self.done_count = 1 + + def __await__(self): + async def _await(): + if task_cancelled: + raise asyncio.CancelledError() + return None + return 
_await().__await__() + + call_count = 0 + + async def mock_sleep(duration): + nonlocal call_count + call_count += 1 + if call_count >= 2: + raise KeyboardInterrupt("Test cleanup") + + def create_mock_task(coro): + coro.close() + return MockTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_mock_task), \ + patch.object(consumer, "check_heartbeat", return_value=True): + + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass + + assert task_created, "Task should have been created" + assert task_cancelled, "Task should have been cancelled during cleanup" + + +@pytest.mark.asyncio +async def test_supervisor_backoff_resets_after_healthy_cycle(): + """Test that supervisor resets backoff after the task has been running through a full heartbeat cycle.""" + consumer = MockConsumer() + consumer._restart_delay = 160 + + heartbeat_calls = 0 + + def mock_check_heartbeat(**kwargs): + nonlocal heartbeat_calls + heartbeat_calls += 1 + if heartbeat_calls <= 2: + return True # Two healthy cycles — backoff should reset on the second + raise KeyboardInterrupt("Test complete") + + async def mock_sleep(duration): + pass + + class MockTask: + def done(self): + return False + + def cancel(self): + pass + + def __await__(self): + async def _await(): + return None + return _await().__await__() + + consumer.check_heartbeat = mock_check_heartbeat + + def create_mock_task(coro): + coro.close() + return MockTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_mock_task): + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass + + # Backoff resets after a task has been running continuously through a healthy heartbeat cycle + assert consumer._restart_delay == 
RESTART_DELAY_SECONDS diff --git a/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py new file mode 100644 index 0000000000..deba83c1ce --- /dev/null +++ b/api_app/tests_ma/test_service_bus/test_service_bus_edge_cases.py @@ -0,0 +1,115 @@ +import asyncio +import time +import pytest +from unittest.mock import patch +from service_bus.service_bus_consumer import ( + ServiceBusConsumer, + RESTART_DELAY_SECONDS, + HEARTBEAT_CHECK_INTERVAL_SECONDS, + HEARTBEAT_STALENESS_THRESHOLD_SECONDS, + MAX_RESTART_DELAY_SECONDS, + SUPERVISOR_ERROR_DELAY_SECONDS, +) + + +# Minimal concrete ServiceBusConsumer for the edge-case tests; receive_messages records that it ran, sleeps briefly and returns +class MockConsumerForEdgeCases(ServiceBusConsumer): + def __init__(self): + super().__init__("test_consumer_edge") + self.receive_messages_called = False + + async def receive_messages(self): + self.receive_messages_called = True + await asyncio.sleep(0.01) + return + + +@pytest.mark.asyncio +async def test_supervisor_error_recovery(): + """Test that the supervisor survives an unexpected ValueError raised by its patched asyncio.sleep on the first call.""" + consumer = MockConsumerForEdgeCases() + + error_triggered = False + iteration = 0 + + class MockTask: + def done(self): + return False + + def cancel(self): + pass + + def __await__(self): + async def _await(): + return None + return _await().__await__() + + async def mock_sleep(duration): + nonlocal iteration, error_triggered + iteration += 1 + if iteration == 1 and not error_triggered: + error_triggered = True + raise ValueError("Unexpected supervisor error") + if iteration >= 3: + raise KeyboardInterrupt("Test complete") + + def create_mock_task(coro): + coro.close() + return MockTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_mock_task): + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass + + 
assert error_triggered, "Supervisor should have encountered and recovered from the error" + + +@pytest.mark.asyncio +async def test_supervisor_backoff_not_reset_on_stale_heartbeat(): + """Test that backoff is NOT reset when restarting due to a stale heartbeat. + The first heartbeat check reports stale and _restart_delay must still equal the preset 160.""" + consumer = MockConsumerForEdgeCases() + consumer._restart_delay = 160 + + heartbeat_calls = 0 + + def mock_check_heartbeat(**kwargs): + nonlocal heartbeat_calls + heartbeat_calls += 1 + if heartbeat_calls == 1: + return False # Stale — backoff should NOT reset + raise KeyboardInterrupt("Test complete") + + async def mock_sleep(duration): + pass + + class MockTask: + def done(self): + return False + + def cancel(self): + pass + + def __await__(self): + async def _await(): + return None + return _await().__await__() + + consumer.check_heartbeat = mock_check_heartbeat + + def create_mock_task(coro): + coro.close() + return MockTask() + + with patch("service_bus.service_bus_consumer.asyncio.sleep", side_effect=mock_sleep), \ + patch("service_bus.service_bus_consumer.asyncio.create_task", side_effect=create_mock_task): + try: + await consumer.supervisor_with_heartbeat_check() + except KeyboardInterrupt: + pass + + # Backoff is preserved after stale heartbeat restart (not reset until a healthy cycle) + assert consumer._restart_delay == 160