diff --git a/src/workload-orchestration/HISTORY.rst b/src/workload-orchestration/HISTORY.rst index f80c37473ad..49671d6fc4a 100644 --- a/src/workload-orchestration/HISTORY.rst +++ b/src/workload-orchestration/HISTORY.rst @@ -2,6 +2,17 @@ Release History =============== +5.1.1 +++++++ +* Resolved solution template name to uniqueIdentifier for ``az workload-orchestration target solution-revision-list`` and ``az workload-orchestration target solution-instance-list`` +* Added shared ``_target_helper.py`` for reusable solution template resolution logic +* Added ``az workload-orchestration support create-bundle`` command for troubleshooting Day 0 (installation) and Day N (runtime) issues on 3rd-party Kubernetes clusters: + * Collects cluster info, node details, pod/deployment/service/event descriptions across configurable namespaces + * Collects container logs (current + previous for crash-looping pods) with configurable tail lines + * Runs prerequisite validation checks across 10 categories + * Generates a zip bundle for sharing with Microsoft support + * Includes retry with exponential backoff and per-call timeout for resilient K8s API access + 5.1.0 ++++++ * Added new target solution management command: diff --git a/src/workload-orchestration/azext_workload_orchestration/_help.py b/src/workload-orchestration/azext_workload_orchestration/_help.py index 126d5d00714..1f93a9946f0 100644 --- a/src/workload-orchestration/azext_workload_orchestration/_help.py +++ b/src/workload-orchestration/azext_workload_orchestration/_help.py @@ -9,3 +9,40 @@ # pylint: disable=too-many-lines from knack.help_files import helps # pylint: disable=unused-import + + +helps['workload-orchestration support'] = """ +type: group +short-summary: Commands for troubleshooting and diagnostics of workload orchestration deployments. +""" + +helps['workload-orchestration support create-bundle'] = """ +type: command +short-summary: Create a support bundle for troubleshooting workload orchestration issues. 
+long-summary: | + Collects cluster information, resource descriptions, container logs, and runs + prerequisite validation checks. The output is a zip file that can be shared with + Microsoft support for troubleshooting Day 0 (installation) and Day N (runtime) issues. + + Collected data includes: + - Cluster info (version, nodes, namespaces) + - Pod/Deployment/Service/DaemonSet/Event descriptions per namespace + - Container logs (tailed by default) + - Network configuration (kube-proxy, external services, pod CIDRs) + - StorageClass, PV, webhook, CRD inventory + - WO component health (Symphony, cert-manager) + - Prerequisite checks (K8s version, node capacity, DNS, storage, RBAC) +examples: + - name: Create a support bundle with defaults + text: az workload-orchestration support create-bundle + - name: Create a named bundle + text: az workload-orchestration support create-bundle --bundle-name my-cluster-debug + - name: Create a bundle in a specific directory + text: az workload-orchestration support create-bundle --output-dir /tmp/bundles + - name: Collect full logs (no tail) for WO namespace only + text: az workload-orchestration support create-bundle --full-logs --namespaces workloadorchestration + - name: Run checks only, skip log collection + text: az workload-orchestration support create-bundle --skip-logs + - name: Use a specific kubeconfig and context + text: az workload-orchestration support create-bundle --kube-config ~/.kube/prod-config --kube-context my-cluster +""" diff --git a/src/workload-orchestration/azext_workload_orchestration/_params.py b/src/workload-orchestration/azext_workload_orchestration/_params.py index cfcec717c9c..b17dd8cccd7 100644 --- a/src/workload-orchestration/azext_workload_orchestration/_params.py +++ b/src/workload-orchestration/azext_workload_orchestration/_params.py @@ -10,4 +10,58 @@ def load_arguments(self, _): # pylint: disable=unused-argument - pass + with self.argument_context('workload-orchestration support 
create-bundle') as c: + c.argument( + 'bundle_name', + options_list=['--bundle-name', '-n'], + help='Optional name for the support bundle. ' + 'Defaults to wo-support-bundle-YYYYMMDD-HHMMSS.', + ) + c.argument( + 'output_dir', + options_list=['--output-dir', '-d'], + help='Directory where the support bundle zip will be saved. Defaults to current directory.', + ) + c.argument( + 'namespaces', + options_list=['--namespaces'], + nargs='+', + help='Kubernetes namespaces to collect logs and resources from. ' + 'Defaults to kube-system, workloadorchestration, cert-manager.', + ) + c.argument( + 'tail_lines', + options_list=['--tail-lines'], + type=int, + help='Number of log lines to collect per container (default: 1000). ' + 'Use --full-logs to collect all lines.', + ) + c.argument( + 'full_logs', + options_list=['--full-logs'], + action='store_true', + help='Collect full container logs instead of tailing. ' + 'Warning: may produce very large bundles.', + ) + c.argument( + 'skip_checks', + options_list=['--skip-checks'], + action='store_true', + help='Skip prerequisite validation checks and only collect logs/resources.', + ) + c.argument( + 'skip_logs', + options_list=['--skip-logs'], + action='store_true', + help='Skip container log collection and only run checks/collect resources.', + ) + c.argument( + 'kube_config', + options_list=['--kube-config'], + help='Path to kubeconfig file. Defaults to ~/.kube/config.', + ) + c.argument( + 'kube_context', + options_list=['--kube-context'], + help='Kubernetes context to use. 
Defaults to current context.', + ) diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/_resource_validator.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/_resource_validator.py new file mode 100644 index 00000000000..2aee68816dc --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/_resource_validator.py @@ -0,0 +1,79 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: skip-file +# flake8: noqa + +from azure.cli.core.aaz import * +from azure.cli.core.azclierror import ValidationError + + +class ValidateResourceExists(AAZHttpOperation): + """Validates that an ARM resource exists by making a GET request to its resource ID.""" + CLIENT_TYPE = "MgmtClient" + + def __init__(self, ctx, resource_id, resource_label="Resource"): + super().__init__(ctx) + self._resource_id = str(resource_id) + self._resource_label = resource_label + + def __call__(self, *args, **kwargs): + request = self.make_request() + session = self.client.send_request(request=request, stream=False, **kwargs) + if session.http_response.status_code == 404: + raise ValidationError( + f"{self._resource_label} not found. The resource with ID '{self._resource_id}' does not exist. " + f"Please provide a valid {self._resource_label.lower()} resource ID." + ) + if session.http_response.status_code != 200: + raise ValidationError( + f"Failed to validate {self._resource_label.lower()} existence for ID '{self._resource_id}'. 
" + f"Received status code: {session.http_response.status_code}" + ) + + @property + def url(self): + return self.client.format_url( + "{resourceId}", + **self.url_parameters + ) + + @property + def method(self): + return "GET" + + @property + def error_format(self): + return "MgmtErrorFormat" + + @property + def url_parameters(self): + parameters = { + **self.serialize_url_param( + "resourceId", self._resource_id, + required=True, + skip_quote=True, + ), + } + return parameters + + @property + def query_parameters(self): + parameters = { + **self.serialize_query_param( + "api-version", "2025-06-01", + required=True, + ), + } + return parameters + + @property + def header_parameters(self): + parameters = { + **self.serialize_header_param( + "Accept", "application/json", + ), + } + return parameters diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/config_template/_link.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/config_template/_link.py index b7c274b5b03..6c3de2eabf7 100644 --- a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/config_template/_link.py +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/config_template/_link.py @@ -9,6 +9,8 @@ # flake8: noqa from azure.cli.core.aaz import * +from azext_workload_orchestration.aaz.latest.workload_orchestration._resource_validator import ValidateResourceExists + @register_command( "workload-orchestration config-template link", @@ -71,6 +73,9 @@ def _build_arguments_schema(cls, *args, **kwargs): def _execute_operations(self): self.pre_operations() + if has_value(self.ctx.args.hierarchy_ids): + for hierarchy_id in self.ctx.args.hierarchy_ids: + ValidateResourceExists(ctx=self.ctx, resource_id=hierarchy_id, resource_label="Hierarchy")() yield self.ConfigTemplatesLinkToHierarchies(ctx=self.ctx)() self.post_operations() diff --git 
a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/configuration/_config_helper.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/configuration/_config_helper.py index 9a478e6f949..5c694b761d5 100644 --- a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/configuration/_config_helper.py +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/configuration/_config_helper.py @@ -115,8 +115,19 @@ def try_get_config_id(lookup_id, api_version = "2025-08-01"): # If we reach here, no configuration was found - raise CLIInternalError(f"No configuration linked to this hierarchy: {hierarchy_id_str}") - + if "microsoft.edge/targets" in hierarchy_id_str.lower(): + raise CLIInternalError( + f"Missing target configuration and configuration reference for Target: {hierarchy_id_str}" + ) + elif "microsoft.edge/sites" in hierarchy_id_str.lower(): + raise CLIInternalError( + f"Missing site configuration and configuration reference for Site: {hierarchy_id_str}" + ) + else: + raise CLIInternalError( + f"Hierarchy Id can either be of Target or Site Resource. 
Invalid Id: {hierarchy_id_str}" + ) + @staticmethod def getTemplateUniqueIdentifier(subscription_id, template_resource_group_name, template_name, solution_flag, client): """ diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/context/site_reference/_create.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/context/site_reference/_create.py index 2372603feeb..96d10ade4a6 100644 --- a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/context/site_reference/_create.py +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/context/site_reference/_create.py @@ -9,6 +9,7 @@ # flake8: noqa from azure.cli.core.aaz import * +from azext_workload_orchestration.aaz.latest.workload_orchestration._resource_validator import ValidateResourceExists @register_command( @@ -78,6 +79,8 @@ def _build_arguments_schema(cls, *args, **kwargs): def _execute_operations(self): self.pre_operations() + if has_value(self.ctx.args.site_id): + ValidateResourceExists(ctx=self.ctx, resource_id=self.ctx.args.site_id, resource_label="Site")() yield self.SiteReferencesCreateOrUpdate(ctx=self.ctx)() self.post_operations() diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/solution_template/__init__.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/solution_template/__init__.py index 808d30168a5..488a2b0b037 100644 --- a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/solution_template/__init__.py +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/solution_template/__init__.py @@ -19,4 +19,5 @@ from ._bulk_publish_solution import * from ._bulk_review_solution import * # from ._update import * +from ._update_capabilities import * from ._wait import * 
diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/solution_template/_update_capabilities.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/solution_template/_update_capabilities.py new file mode 100644 index 00000000000..d14f0120816 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/solution_template/_update_capabilities.py @@ -0,0 +1,299 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# +# Code generated by aaz-dev-tools +# -------------------------------------------------------------------------------------------- + +# pylint: skip-file +# flake8: noqa + +from azure.cli.core.aaz import * + + +@register_command( + "workload-orchestration solution-template update-capabilities", +) +class UpdateCapabilities(AAZCommand): + """Update the capabilities of a Solution Template Resource + + :example: Update Solution Template Capabilities + az workload-orchestration solution-template update-capabilities -n mySolutionTemplate --capabilities "capability1" "capability2" "capability3" --location eastus --resource-group myResourceGroup + """ + + _aaz_info = { + "version": "2025-08-01", + "resources": [ + ["mgmt-plane", "/subscriptions/{}/resourcegroups/{}/providers/Microsoft.Edge/solutiontemplates/{}", "2025-08-01"], + ] + } + + AZ_SUPPORT_NO_WAIT = True + + def _handler(self, command_args): + super()._handler(command_args) + return self.build_lro_poller(self._execute_operations, self._output) + + _args_schema = None + + @classmethod + def _build_arguments_schema(cls, *args, **kwargs): + if cls._args_schema is not None: + return cls._args_schema + cls._args_schema = super()._build_arguments_schema(*args, **kwargs) + + 
_args_schema = cls._args_schema + _args_schema.resource_group = AAZResourceGroupNameArg( + required=True, + ) + _args_schema.solution_template_name = AAZStrArg( + options=["-n", "--name", "--solution-template-name"], + help="The name of the SolutionTemplate", + required=True, + fmt=AAZStrArgFormat( + pattern="^[a-zA-Z0-9-]{3,24}$", + ), + ) + + _args_schema.capabilities = AAZListArg( + options=["--capabilities"], + arg_group="Properties", + help="List of capabilities", + required=True, + ) + _args_schema.description = AAZStrArg( + options=["--description"], + arg_group="Properties", + help="Description of Solution template", + required=True, + ) + + capabilities = cls._args_schema.capabilities + capabilities.Element = AAZStrArg() + + _args_schema.location = AAZResourceLocationArg( + arg_group="Resource", + help="The geo-location where the resource lives", + required=True, + fmt=AAZResourceLocationArgFormat( + resource_group_arg="resource_group", + ), + ) + + return cls._args_schema + + def _execute_operations(self): + self.pre_operations() + yield self.SolutionTemplatesUpdateCapabilities(ctx=self.ctx)() + self.post_operations() + + @register_callback + def pre_operations(self): + pass + + @register_callback + def post_operations(self): + pass + + def _output(self, *args, **kwargs): + result = self.deserialize_output(self.ctx.vars.instance, client_flatten=True) + return result + + class SolutionTemplatesUpdateCapabilities(AAZHttpOperation): + CLIENT_TYPE = "MgmtClient" + + def __call__(self, *args, **kwargs): + request = self.make_request() + session = self.client.send_request(request=request, stream=False, **kwargs) + if session.http_response.status_code in [202]: + return self.client.build_lro_polling( + self.ctx.args.no_wait, + session, + self.on_200_201, + self.on_error, + lro_options={"final-state-via": "azure-async-operation"}, + path_format_arguments=self.url_parameters, + ) + if session.http_response.status_code in [200, 201]: + return 
self.client.build_lro_polling( + self.ctx.args.no_wait, + session, + self.on_200_201, + self.on_error, + lro_options={"final-state-via": "azure-async-operation"}, + path_format_arguments=self.url_parameters, + ) + + return self.on_error(session.http_response) + + @property + def url(self): + return self.client.format_url( + "/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.Edge/solutionTemplates/{solutionTemplateName}", + **self.url_parameters + ) + + @property + def method(self): + return "PUT" + + @property + def error_format(self): + return "MgmtErrorFormat" + + @property + def url_parameters(self): + parameters = { + **self.serialize_url_param( + "resourceGroupName", self.ctx.args.resource_group, + required=True, + ), + **self.serialize_url_param( + "solutionTemplateName", self.ctx.args.solution_template_name, + required=True, + ), + **self.serialize_url_param( + "subscriptionId", self.ctx.subscription_id, + required=True, + ), + } + return parameters + + @property + def query_parameters(self): + parameters = { + **self.serialize_query_param( + "api-version", "2025-08-01", + required=True, + ), + } + return parameters + + @property + def header_parameters(self): + parameters = { + **self.serialize_header_param( + "Content-Type", "application/json", + ), + **self.serialize_header_param( + "Accept", "application/json", + ), + } + return parameters + + @property + def content(self): + _content_value, _builder = self.new_content_builder( + self.ctx.args, + typ=AAZObjectType, + typ_kwargs={"flags": {"required": True, "client_flatten": True}} + ) + _builder.set_prop("location", AAZStrType, ".location", typ_kwargs={"flags": {"required": True}}) + _builder.set_prop("properties", AAZObjectType) + + properties = _builder.get(".properties") + if properties is not None: + properties.set_prop("capabilities", AAZListType, ".capabilities", typ_kwargs={"flags": {"required": True}}) + properties.set_prop("description", AAZStrType, 
".description", typ_kwargs={"flags": {"required": True}}) + + capabilities = _builder.get(".properties.capabilities") + if capabilities is not None: + capabilities.set_elements(AAZStrType, ".") + + return self.serialize_content(_content_value) + + def on_200_201(self, session): + data = self.deserialize_http_content(session) + self.ctx.set_var( + "instance", + data, + schema_builder=self._build_schema_on_200_201 + ) + + _schema_on_200_201 = None + + @classmethod + def _build_schema_on_200_201(cls): + if cls._schema_on_200_201 is not None: + return cls._schema_on_200_201 + + cls._schema_on_200_201 = AAZObjectType() + + _schema_on_200_201 = cls._schema_on_200_201 + _schema_on_200_201.e_tag = AAZStrType( + serialized_name="eTag", + flags={"read_only": True}, + ) + _schema_on_200_201.id = AAZStrType( + flags={"read_only": True}, + ) + _schema_on_200_201.location = AAZStrType( + flags={"required": True}, + ) + _schema_on_200_201.name = AAZStrType( + flags={"read_only": True}, + ) + _schema_on_200_201.properties = AAZObjectType() + _schema_on_200_201.system_data = AAZObjectType( + serialized_name="systemData", + flags={"read_only": True}, + ) + _schema_on_200_201.tags = AAZDictType() + _schema_on_200_201.type = AAZStrType( + flags={"read_only": True}, + ) + + properties = cls._schema_on_200_201.properties + properties.capabilities = AAZListType( + flags={"required": True}, + ) + properties.description = AAZStrType( + flags={"required": True}, + ) + properties.enable_external_validation = AAZBoolType( + serialized_name="enableExternalValidation", + ) + properties.latest_version = AAZStrType( + serialized_name="latestVersion", + flags={"read_only": True}, + ) + properties.provisioning_state = AAZStrType( + serialized_name="provisioningState", + flags={"read_only": True}, + ) + properties.state = AAZStrType() + + capabilities = cls._schema_on_200_201.properties.capabilities + capabilities.Element = AAZStrType() + + system_data = cls._schema_on_200_201.system_data + 
system_data.created_at = AAZStrType( + serialized_name="createdAt", + ) + system_data.created_by = AAZStrType( + serialized_name="createdBy", + ) + system_data.created_by_type = AAZStrType( + serialized_name="createdByType", + ) + system_data.last_modified_at = AAZStrType( + serialized_name="lastModifiedAt", + ) + system_data.last_modified_by = AAZStrType( + serialized_name="lastModifiedBy", + ) + system_data.last_modified_by_type = AAZStrType( + serialized_name="lastModifiedByType", + ) + + tags = cls._schema_on_200_201.tags + tags.Element = AAZStrType() + + return cls._schema_on_200_201 + + +class _UpdateCapabilitiesHelper: + """Helper class for UpdateCapabilities""" + + +__all__ = ["UpdateCapabilities"] \ No newline at end of file diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/target/_solution_instance_list.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/target/_solution_instance_list.py index b9b3eb5ec91..3b5908c1fe2 100644 --- a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/target/_solution_instance_list.py +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/target/_solution_instance_list.py @@ -9,6 +9,7 @@ # flake8: noqa from azure.cli.core.aaz import * +from ._target_helper import TargetHelper @register_command( @@ -17,7 +18,7 @@ class ListSolutionInstances(AAZCommand): """List all solution instances of a solution deployed on a target :example: - az workload-orchestration solution-instance list -g MyResourceGroup --solution-name MySolution + az workload-orchestration solution-instance list -g MyResourceGroup --solution-template-id MySolution """ _aaz_info = { @@ -56,14 +57,9 @@ def _build_arguments_schema(cls, *args, **kwargs): ), ) _args_schema.solution_name = AAZStrArg( - options=["--solution-template-name", "--solution"], - help="Name of the solution", + 
options=["--solution-template-id", "--solution-id"], + help="ARM resource ID of the solution template (e.g. /subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Edge/solutionTemplates/{name})", required=True, - fmt=AAZStrArgFormat( - pattern="^[a-zA-Z0-9]([-a-zA-Z0-9]*[a-zA-Z0-9])?(\\.[a-zA-Z0-9]([-a-zA-Z0-9]*[a-zA-Z0-9])?)*$", - max_length=61, - min_length=3, - ), ) return cls._args_schema @@ -89,6 +85,12 @@ class SolutionInstancesList(AAZHttpOperation): CLIENT_TYPE = "MgmtClient" def __call__(self, *args, **kwargs): + # Resolve solution template ARM resource ID to its uniqueIdentifier + self.unique_identifier = TargetHelper.get_solution_template_unique_identifier( + self.ctx.args.solution_name, + self.client + ) + request = self.make_request() session = self.client.send_request(request=request, stream=False, **kwargs) if session.http_response.status_code in [200]: @@ -127,7 +129,7 @@ def url_parameters(self): required=True, ), **self.serialize_url_param( - "solutionName", self.ctx.args.solution_name, + "solutionName", self.unique_identifier, required=True, ), } diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/target/_solution_revision_list.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/target/_solution_revision_list.py index 48bcb2fdc1e..9c358955e53 100644 --- a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/target/_solution_revision_list.py +++ b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/target/_solution_revision_list.py @@ -9,6 +9,7 @@ # flake8: noqa from azure.cli.core.aaz import * +from ._target_helper import TargetHelper @register_command( @@ -18,7 +19,7 @@ class ListRevisions(AAZCommand): """List all revisions of a solution deployed on a target :example: List all revisions of a solution on a target - az workload-orchestration target solution-revision-list -g 
MyResourceGroup --target-name MyTarget --solution-name MySolution + az workload-orchestration target solution-revision-list -g MyResourceGroup --target-name MyTarget --solution-template-id MySolution """ _aaz_info = { @@ -49,14 +50,9 @@ def _build_arguments_schema(cls, *args, **kwargs): required=True, ) _args_schema.solution_name = AAZStrArg( - options=["--solution-template-name", "--solution"], - help="Name of the solution", + options=["--solution-template-id", "--solution-id"], + help="ARM resource ID of the solution template (e.g. /subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Edge/solutionTemplates/{name})", required=True, - fmt=AAZStrArgFormat( - pattern="^(?!v-)(?!.*-v-)[a-zA-Z0-9]([-a-zA-Z0-9]*[a-zA-Z0-9])?(\\.[a-zA-Z0-9]([-a-zA-Z0-9]*[a-zA-Z0-9])?)*$", - max_length=61, - min_length=3, - ), ) _args_schema.target_name = AAZStrArg( options=["--target-name", "--name", "-n"], @@ -92,6 +88,12 @@ class TargetSolutionRevisionsList(AAZHttpOperation): CLIENT_TYPE = "MgmtClient" def __call__(self, *args, **kwargs): + # Resolve solution template ARM resource ID to its uniqueIdentifier + self.unique_identifier = TargetHelper.get_solution_template_unique_identifier( + self.ctx.args.solution_name, + self.client + ) + request = self.make_request() session = self.client.send_request(request=request, stream=False, **kwargs) if session.http_response.status_code in [200]: @@ -122,7 +124,7 @@ def url_parameters(self): required=True, ), **self.serialize_url_param( - "solutionName", self.ctx.args.solution_name, + "solutionName", self.unique_identifier, required=True, ), **self.serialize_url_param( diff --git a/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/target/_target_helper.py b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/target/_target_helper.py new file mode 100644 index 00000000000..ef15fbe346d --- /dev/null +++ 
b/src/workload-orchestration/azext_workload_orchestration/aaz/latest/workload_orchestration/target/_target_helper.py @@ -0,0 +1,69 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +# Code generated by aaz-dev-tools. DO NOT EDIT. + +# pylint: skip-file +# flake8: noqa + + +class TargetHelper: + """Shared helper for target commands.""" + + @staticmethod + def get_solution_template_unique_identifier(solution_template_resource_id, client): + """Fetch the solution template by its full ARM resource ID and return its uniqueIdentifier. + + Args: + solution_template_resource_id: Full ARM resource ID of the solution template + (e.g. /subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Edge/solutionTemplates/{name}) + client: HTTP client for making the request + + Returns: + str: The uniqueIdentifier from template properties, or the template name extracted from the ID as fallback + + Raises: + CLIInternalError: If the template does not exist or the request fails + """ + from azure.cli.core.azclierror import CLIInternalError + import json + + template_url = client.format_url( + "{solutionTemplateId}", + solutionTemplateId=solution_template_resource_id + ) + request = client._request("GET", template_url, { + "api-version": "2025-08-01" + }, { + "Accept": "application/json" + }, None, {}, None) + + try: + response = client.send_request(request=request, stream=False) + + if response.http_response.status_code == 404: + raise CLIInternalError( + f"Solution template not found: '{solution_template_resource_id}'." 
+ ) + if response.http_response.status_code != 200: + raise CLIInternalError( + f"Failed to get solution template '{solution_template_resource_id}': HTTP {response.http_response.status_code}" + ) + + data = json.loads(response.http_response.text()) + unique_identifier = data.get("properties", {}).get("uniqueIdentifier") + + if unique_identifier and unique_identifier.strip(): + return unique_identifier + # Fallback: extract the template name from the ARM resource ID + return solution_template_resource_id.rstrip("/").split("/")[-1] + except CLIInternalError: + # Propagate explicitly raised CLIInternalError instances unchanged. + raise + except Exception as exc: + # Wrap unexpected errors (e.g., network issues, JSON parsing failures) + # in CLIInternalError to match the documented behavior. + raise CLIInternalError( + f"Failed to get solution template '{solution_template_resource_id}': {exc}" + ) from exc diff --git a/src/workload-orchestration/azext_workload_orchestration/commands.py b/src/workload-orchestration/azext_workload_orchestration/commands.py index b0d842e4993..1f1d9c002a7 100644 --- a/src/workload-orchestration/azext_workload_orchestration/commands.py +++ b/src/workload-orchestration/azext_workload_orchestration/commands.py @@ -8,8 +8,7 @@ # pylint: disable=too-many-lines # pylint: disable=too-many-statements -# from azure.cli.core.commands import CliCommandType - def load_command_table(self, _): # pylint: disable=unused-argument - pass + with self.command_group('workload-orchestration support', is_preview=True) as g: + g.custom_command('create-bundle', 'create_support_bundle') diff --git a/src/workload-orchestration/azext_workload_orchestration/custom.py b/src/workload-orchestration/azext_workload_orchestration/custom.py index 86df1e48ef5..849a65ef9de 100644 --- a/src/workload-orchestration/azext_workload_orchestration/custom.py +++ b/src/workload-orchestration/azext_workload_orchestration/custom.py @@ -5,10 +5,5 @@ # Code generated by aaz-dev-tools # 
-------------------------------------------------------------------------------------------- -# pylint: disable=too-many-lines -# pylint: disable=too-many-statements - -from knack.log import get_logger - - -logger = get_logger(__name__) +# Support bundle command +from azext_workload_orchestration.support import create_support_bundle # pylint: disable=unused-import # noqa: F401 diff --git a/src/workload-orchestration/azext_workload_orchestration/support/README.md b/src/workload-orchestration/azext_workload_orchestration/support/README.md new file mode 100644 index 00000000000..7fd07d9a004 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/support/README.md @@ -0,0 +1,463 @@ +# Support Bundle Package + +Diagnostic toolkit for the `az workload-orchestration support create-bundle` CLI command. + +Collects Kubernetes cluster health data, container logs, prerequisite validation +checks, and WO-specific resource states into a ZIP bundle for troubleshooting +Day 0 (installation) and Day N (runtime) issues on 3rd-party clusters. + +## Architecture + +``` +support/ +├── __init__.py # Public API — exports create_support_bundle() +├── bundle.py # Orchestrator — wires everything together +├── consts.py # Constants — thresholds, namespaces, folder names +├── utils.py # Infrastructure — K8s client, safe API calls, file I/O +├── collectors.py # Data collection — resources, logs, metrics +├── validators.py # Health checks — 18 checks across 10 categories +└── README.md # This file +``` + +### Data Flow + +``` +create_support_bundle() # bundle.py — entry point + │ + ├── get_kubernetes_client() # utils.py — connect to cluster + ├── create_bundle_directory() # utils.py — create folder structure + ├── detect_cluster_capabilities() # utils.py — detect installed components + │ + ├── run_all_checks() # validators.py — 18 prerequisite checks + │ ├── _check_k8s_version() + │ ├── _check_node_readiness() + │ ├── _check_dns_health() + │ └── ... 
(18 total) + │ + ├── collect_cluster_info() # collectors.py — version, nodes, namespaces + ├── collect_cluster_resources() # collectors.py — SCs, CRDs, webhooks + ├── validate_namespaces() # collectors.py — skip invalid namespaces + │ + ├── for each namespace: + │ ├── collect_namespace_resources() # pods, deployments, services, etc. + │ ├── collect_container_logs() # threaded log collection + │ ├── collect_previous_logs() # crash-looping pod logs + │ ├── collect_resource_quotas() # quotas & limit ranges + │ └── collect_pvcs() # persistent volume claims + │ + ├── collect_wo_components() # collectors.py — Symphony, cert-manager, Gatekeeper + ├── collect_metrics() # collectors.py — node/pod metrics + │ + ├── _compute_health_summary() # bundle.py — score 0-100, status + ├── write metadata.json # bundle.py — full bundle metadata + └── create_zip_bundle() # utils.py — zip + cleanup +``` + +### Bundle Output Structure + +``` +wo-support-bundle-YYYYMMDD-HHMMSS.zip +├── metadata.json # Bundle info, health summary, capabilities +├── cluster-info/ +│ ├── cluster-info.json # K8s version, nodes, namespaces +│ ├── capabilities.json # Detected components (Symphony, cert-manager, etc.) +│ └── metrics.json # Node/pod resource usage (if metrics-server) +├── checks/ +│ ├── cluster-info--k8s-version.json +│ ├── node-health--node-readiness.json +│ └── ... (one file per check) +├── resources/ +│ ├── cluster-resources.json # StorageClasses, CRDs, webhooks, CSI +│ ├── kube-system-resources.json # Pods, deployments, services per ns +│ ├── kube-system-quotas.json # ResourceQuotas, LimitRanges +│ └── wo-components.json # Symphony targets, ClusterIssuers +└── logs/ + ├── kube-system/ + │ ├── coredns-abc--coredns.log + │ └── ... + └── workloadorchestration/ + ├── symphony-api-xyz--symphony-api.log + ├── symphony-api-xyz--symphony-api--previous.log # crashed container + └── ... +``` + +## Module Guide + +### consts.py — Constants + +All tunable values in one place. No business logic. 
+ +| Constant Group | Examples | Purpose | +|----------------|----------|---------| +| Bundle defaults | `DEFAULT_TAIL_LINES=1000`, `DEFAULT_API_TIMEOUT_SECONDS=30` | Collection behavior | +| Retry | `DEFAULT_MAX_RETRIES=3`, `DEFAULT_RETRY_BACKOFF_BASE=1.0` | API call resilience | +| Namespaces | `WO_NAMESPACE`, `DEFAULT_NAMESPACES`, `PROTECTED_NAMESPACES` | Which namespaces to collect | +| API groups | `API_GROUP_SYMPHONY`, `API_GROUP_CERT_MANAGER` | Capability detection | +| Thresholds | `MIN_CPU_CORES=2`, `MIN_MEMORY_GI=4`, `MIN_NODE_COUNT_PROD=3` | Prerequisite minimums | +| Folder names | `FOLDER_LOGS`, `FOLDER_CHECKS`, `FOLDER_RESOURCES` | Bundle directory layout | +| Status values | `STATUS_PASS`, `STATUS_FAIL`, `STATUS_WARN` | Check result statuses | + +### utils.py — Infrastructure + +Shared utilities used by collectors and validators. + +| Function | Purpose | +|----------|---------| +| `get_kubernetes_client()` | Initialize K8s API clients from kubeconfig | +| `safe_api_call()` | Wrap any K8s API call with retry, timeout, and RBAC error detection | +| `create_bundle_directory()` | Create the bundle folder structure | +| `create_zip_bundle()` | Zip the bundle and clean up raw directory | +| `detect_cluster_capabilities()` | Detect installed components (Symphony, cert-manager, Gatekeeper, etc.) 
| +| `write_json()` / `write_text()` | Safe file writers (never crash on I/O errors) | +| `write_check_result()` | Write a check result to the checks/ folder | +| `parse_cpu()` / `parse_memory_gi()` | Parse K8s resource strings (`"3860m"` → `3.86`) | +| `format_bytes()` | Human-readable byte formatting | +| `check_disk_space()` | Pre-flight disk space check | + +**Key pattern — `safe_api_call()`:** +```python +result, err = safe_api_call( + core.list_namespaced_pod, namespace, + description="list pods in kube-system", # for error messages + max_retries=3, # retries on 500/502/503/504 + timeout_seconds=30, # per-call timeout +) +if err: + logger.warning("Failed: %s", err) +else: + process(result) +``` + +### collectors.py — Data Collection + +Gathers cluster state into the bundle directory. + +| Function | What it collects | Output location | +|----------|-----------------|-----------------| +| `validate_namespaces()` | Pre-flight namespace existence check | (no file — filters list) | +| `collect_cluster_info()` | K8s version, nodes, namespaces | `cluster-info/cluster-info.json` | +| `collect_namespace_resources()` | Pods, Deployments, Services, DaemonSets, StatefulSets, ReplicaSets, Jobs, CronJobs, Ingresses, NetworkPolicies, ServiceAccounts, Events, ConfigMaps | `resources/{ns}-resources.json` | +| `collect_cluster_resources()` | StorageClasses, PVs, webhooks, CRDs, CSI drivers | `resources/cluster-resources.json` | +| `collect_container_logs()` | Container logs (threaded, with tail + truncation) | `logs/{ns}/{pod}--{container}.log` | +| `collect_previous_logs()` | Previous logs for crash-looping containers | `logs/{ns}/{pod}--{container}--previous.log` | +| `collect_wo_components()` | Symphony targets, ClusterIssuers, Gatekeeper templates | `resources/wo-components.json` | +| `collect_resource_quotas()` | ResourceQuotas, LimitRanges | `resources/{ns}-quotas.json` | +| `collect_pvcs()` | PersistentVolumeClaims | `resources/{ns}-pvcs.json` | +| `collect_metrics()` 
| Node/pod metrics (if metrics-server available) | `cluster-info/metrics.json` | + +### validators.py — Health Checks + +18 prerequisite checks organized in `run_all_checks()`. + +| # | Check Name | Category | What it validates | +|---|------------|----------|-------------------| +| 1 | `k8s-version` | cluster-info | Server version ≥ 1.24.0 | +| 2 | `node-readiness` | node-health | All nodes Ready, no pressure conditions | +| 3 | `node-capacity` | node-health | CPU ≥ 2 cores, Memory ≥ 4Gi per node | +| 4 | `cluster-resources` | node-health | Total cluster CPU/memory | +| 5 | `dns-pods` | dns-health | CoreDNS pods running | +| 6 | `dns-resolution` | dns-health | External DNS resolution works | +| 7 | `default-storage-class` | storage | Default StorageClass exists | +| 8 | `csi-drivers` | storage | CSI drivers installed | +| 9 | `cert-manager-installed` | cert-manager | cert-manager pods running | +| 10 | `wo-namespace` | wo-components | workloadorchestration ns exists | +| 11 | `protected-namespace` | wo-components | WO ns is not a protected namespace | +| 12 | `wo-pods` | wo-components | All WO pods running | +| 13 | `wo-webhooks` | wo-components | Symphony webhooks configured | +| 14 | `policy-engines` | admission-controllers | Gatekeeper/Kyverno detected | +| 15 | `psa-labels` | admission-controllers | Pod Security Admission labels | +| 16 | `resource-quotas` | wo-components | Quota usage on WO namespace | +| 17 | `image-pull-secrets` | registry-access | Image pull secrets available | +| 18 | `proxy-settings` | wo-components | Proxy env vars in WO pods | + +### bundle.py — Orchestrator + +Main entry point. Wires collectors + validators together. 
+ +| Function | Purpose | +|----------|---------| +| `create_support_bundle()` | Main orchestration — called by CLI | +| `_compute_health_summary()` | Score 0-100 + HEALTHY/DEGRADED/CRITICAL | +| `_out()` | Console output (uses logger.warning per az CLI convention) | +| `_print_cluster_info()` | Format cluster info for console | +| `_print_capabilities()` | Format detected capabilities | +| `_print_check_results()` | Format check results with [PASS]/[FAIL] icons | + +## How to Add a New Check + +Adding a new prerequisite check takes 3 steps: write the check, register it, add a test. + +### Step 1: Write the check function in `validators.py` + +Every check function has the **exact same signature** — 4 arguments, returns a dict: + +```python +def _check_my_new_thing(clients, bundle_dir, cluster_info, capabilities): + """Check that my new thing is properly configured.""" +``` + +**Arguments available to every check:** + +| Argument | Type | What it gives you | +|----------|------|-------------------| +| `clients` | dict | K8s API clients: `clients["core_v1"]` (CoreV1Api), `clients["apps_v1"]` (AppsV1Api), `clients["storage_v1"]`, `clients["admissionregistration_v1"]`, `clients["custom_objects"]` | +| `bundle_dir` | str | Path to bundle directory — pass to `write_check_result()` | +| `cluster_info` | dict | Pre-collected cluster data: `cluster_info["nodes"]` (list), `cluster_info["server_version"]`, `cluster_info["namespaces"]` | +| `capabilities` | dict | Detected components: `capabilities["has_symphony"]`, `capabilities["has_cert_manager"]`, `capabilities["has_gatekeeper"]`, `capabilities["has_metrics"]`, `capabilities["has_kyverno"]`, `capabilities["has_openshift"]` | + +**Template — copy this and modify:** + +```python +def _check_my_new_thing(clients, bundle_dir, cluster_info, capabilities): + """Check that my new thing is properly configured.""" + core = clients["core_v1"] + + # 1. 
Call K8s API using safe_api_call (handles timeouts, retries, RBAC) + result, err = safe_api_call( + core.list_namespaced_pod, "my-namespace", + description="list pods in my-namespace", + ) + + # 2. Handle API errors gracefully (never crash) + if err: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "my-new-check", + STATUS_WARN, f"Could not verify: {err}" + ) + + # 3. Validate and return PASS/FAIL/WARN + pods = result.items + if len(pods) >= 1: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "my-new-check", + STATUS_PASS, f"{len(pods)} pod(s) found", + details={"pod_count": len(pods)}, # optional extra data + ) + else: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "my-new-check", + STATUS_FAIL, "No pods found — ensure my-thing is installed" + ) +``` + +**Rules for check functions:** +- Signature must be `(clients, bundle_dir, cluster_info, capabilities)` — always 4 args +- Always use `safe_api_call()` for K8s API calls — never call APIs directly +- Always return `write_check_result()` — never return raw dicts +- Use `STATUS_PASS`, `STATUS_FAIL`, `STATUS_WARN`, or `STATUS_SKIP` from consts +- Never raise exceptions — handle errors and return WARN/ERROR status +- Use `capabilities` dict to skip checks when a component isn't installed + +**Real example — checking if a CRD exists:** + +```python +def _check_symphony_crds(clients, bundle_dir, cluster_info, capabilities): + """Check that Symphony CRDs are installed.""" + if not capabilities.get("has_symphony"): + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "symphony-crds", + STATUS_SKIP, "Symphony not detected on this cluster" + ) + + custom = clients["custom_objects"] + result, err = safe_api_call( + custom.list_cluster_custom_object, + "apiextensions.k8s.io", "v1", "customresourcedefinitions", + description="list CRDs for Symphony check", + ) + if err: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "symphony-crds", + 
STATUS_WARN, f"Could not list CRDs: {err}" + ) + + symphony_crds = [ + c for c in result.get("items", []) + if "symphony" in c.get("spec", {}).get("group", "") + ] + if symphony_crds: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "symphony-crds", + STATUS_PASS, f"{len(symphony_crds)} Symphony CRD(s) installed", + details={"crds": [c["metadata"]["name"] for c in symphony_crds]}, + ) + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "symphony-crds", + STATUS_FAIL, "No Symphony CRDs found — WO extension may not be installed" + ) +``` + +### Step 2: Register in `run_all_checks()` + +Add one line to the `checks` list in `validators.py`: + +```python +checks = [ + (_check_k8s_version, "Kubernetes version compatibility"), + # ... existing checks ... + (_check_proxy_settings, "Proxy configuration"), + (_check_my_new_thing, "My new thing"), # ← ADD HERE +] +``` + +The string (second element) is a human-readable description used in log messages. +The check will automatically: +- Run during bundle creation +- Show `[PASS]`/`[FAIL]`/`[WARN]` in console output +- Write result JSON to `checks/{category}--{check-name}.json` +- Count toward the health summary score (PASS=100%, WARN=50%, FAIL=0%) + +### Step 3: Add a unit test in `test_support_bundle.py` + +Every check should have at least 2 tests: one for PASS, one for FAIL/WARN. 
+ +```python +class TestMyNewCheck(unittest.TestCase): + """Tests for _check_my_new_thing.""" + + def _run_check(self, pods): + """Helper: run the check with mocked pods.""" + from azext_workload_orchestration.support.validators import _check_my_new_thing + + # Build mock pod list + pod_list = MagicMock() + pod_list.items = pods + + # Build mock clients + clients = {"core_v1": MagicMock(), "apps_v1": MagicMock(), + "custom_objects": MagicMock(), "storage_v1": MagicMock(), + "admissionregistration_v1": MagicMock(), "apis": MagicMock(), + "version": MagicMock()} + clients["core_v1"].list_namespaced_pod = MagicMock(return_value=pod_list) + + cluster_info = {"nodes": [], "server_version": {}, "namespaces": []} + capabilities = {"has_symphony": True, "has_cert_manager": True} + + with tempfile.TemporaryDirectory() as tmpdir: + os.makedirs(os.path.join(tmpdir, "checks"), exist_ok=True) + return _check_my_new_thing(clients, tmpdir, cluster_info, capabilities) + + def test_pods_found_passes(self): + pod = MagicMock() + pod.metadata.name = "my-pod" + result = self._run_check([pod]) + self.assertEqual(result["status"], "PASS") + + def test_no_pods_fails(self): + result = self._run_check([]) + self.assertEqual(result["status"], "FAIL") + + def test_api_error_returns_warn(self): + from azext_workload_orchestration.support.validators import _check_my_new_thing + from kubernetes.client.exceptions import ApiException + + clients = {"core_v1": MagicMock(), "apps_v1": MagicMock(), + "custom_objects": MagicMock(), "storage_v1": MagicMock(), + "admissionregistration_v1": MagicMock(), "apis": MagicMock(), + "version": MagicMock()} + clients["core_v1"].list_namespaced_pod = MagicMock( + side_effect=ApiException(status=403, reason="Forbidden") + ) + + with tempfile.TemporaryDirectory() as tmpdir: + os.makedirs(os.path.join(tmpdir, "checks"), exist_ok=True) + result = _check_my_new_thing(clients, tmpdir, {}, {}) + self.assertEqual(result["status"], "WARN") + self.assertIn("403", 
result["message"]) +``` + +### Checklist for adding a new check + +- [ ] Function name starts with `_check_` and is in `validators.py` +- [ ] Uses `safe_api_call()` for all K8s API calls +- [ ] Returns `write_check_result()` in every code path (PASS, FAIL, WARN, SKIP) +- [ ] Handles API errors gracefully (never raises) +- [ ] Uses `capabilities` to skip when component isn't installed +- [ ] Registered in the `checks` list in `run_all_checks()` +- [ ] Has at least 2 unit tests (PASS path + FAIL/error path) +- [ ] All 170+ tests still pass after adding + +## How to Add a New Resource Collector + +### Namespace-scoped resource + +Add a block to `collect_namespace_resources()` in `collectors.py`. Pattern: + +```python +# HorizontalPodAutoscalers +try: + from kubernetes import client as _k8s_client + autoscaling_v2 = _k8s_client.AutoscalingV2Api() + result, err = safe_api_call( + autoscaling_v2.list_namespaced_horizontal_pod_autoscaler, namespace, + description=f"list HPAs in {namespace}" + ) + if result: + resources["hpas"] = [ + {"name": h.metadata.name, "min": h.spec.min_replicas, "max": h.spec.max_replicas} + for h in result.items + ] +except Exception as ex: + logger.debug("Autoscaling API not available: %s", ex) +``` + +**Rules:** +- Use `safe_api_call()` — never call K8s APIs directly +- Wrap non-core APIs in `try/except` (they may not be available on all clusters) +- Only extract fields that are useful for diagnostics (name, status, counts) +- Never collect secrets/tokens/credentials +- Add the resource key to the `resources` dict (e.g., `resources["hpas"]`) + +### Cluster-scoped resource + +Same pattern, but add to `collect_cluster_resources()` in `collectors.py`. +Uses `list_*` instead of `list_namespaced_*`. 
+ +### Adding a test for a new collector + +```python +class TestCollectHPAs(unittest.TestCase): + def test_hpas_collected(self): + from azext_workload_orchestration.support.collectors import collect_namespace_resources + + hpa = MagicMock() + hpa.metadata.name = "my-hpa" + hpa.spec.min_replicas = 1 + hpa.spec.max_replicas = 10 + hpa_list = MagicMock() + hpa_list.items = [hpa] + + # ... setup clients with mock, call collect_namespace_resources + # ... assert "hpas" in result +``` + +## CLI Parameters + +| Parameter | Python arg | Default | Description | +|-----------|-----------|---------|-------------| +| `--output-dir` / `-d` | `output_dir` | cwd | Where to save the zip | +| `--namespaces` | `namespaces` | kube-system, workloadorchestration, cert-manager | Namespaces to collect | +| `--tail-lines` | `tail_lines` | 1000 | Log lines per container | +| `--full-logs` | `full_logs` | False | Collect complete logs | +| `--skip-checks` | `skip_checks` | False | Skip prerequisite checks | +| `--skip-logs` | `skip_logs` | False | Skip log collection | +| `--kube-config` | `kube_config` | ~/.kube/config | Kubeconfig path | +| `--kube-context` | `kube_context` | current | K8s context name | + +## Testing + +```bash +# Run all 170 unit tests +cd azure-cli-extensions/src/workload-orchestration +python -m pytest azext_workload_orchestration/tests/test_support_bundle.py -v + +# Run specific test class +python -m pytest ... -k "TestHealthSummary" + +# Run with coverage +python -m pytest ... --cov=azext_workload_orchestration.support +``` + +Test file: `tests/test_support_bundle.py` (~2100 lines, 170 tests) + +Tests mock the kubernetes client — no live cluster needed for unit tests. 
diff --git a/src/workload-orchestration/azext_workload_orchestration/support/__init__.py b/src/workload-orchestration/azext_workload_orchestration/support/__init__.py
new file mode 100644
index 00000000000..c91713eef18
--- /dev/null
+++ b/src/workload-orchestration/azext_workload_orchestration/support/__init__.py
@@ -0,0 +1,22 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information.
+# --------------------------------------------------------------------------------------------
+
+"""Support bundle package for workload-orchestration CLI extension.
+
+This package provides the ``az workload-orchestration support create-bundle`` command
+which collects Kubernetes cluster diagnostics, runs prerequisite validation checks,
+and packages everything into a zip bundle for troubleshooting.
+
+Modules:
+    consts — Constants (namespaces, thresholds, folder names, API groups)
+    utils — K8s client initialization, safe API calls, file writers, parsers
+    collectors — Resource descriptions, container logs, and metrics collection
+    validators — 18 prerequisite checks across 10 validation categories
+    bundle — Main orchestration logic that ties everything together
+"""
+
+from azext_workload_orchestration.support.bundle import create_support_bundle
+
+__all__ = ["create_support_bundle"]
diff --git a/src/workload-orchestration/azext_workload_orchestration/support/bundle.py b/src/workload-orchestration/azext_workload_orchestration/support/bundle.py
new file mode 100644
index 00000000000..228c059db31
--- /dev/null
+++ b/src/workload-orchestration/azext_workload_orchestration/support/bundle.py
@@ -0,0 +1,706 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=too-many-lines,too-many-statements,too-many-branches +# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments +# pylint: disable=broad-exception-caught,consider-using-f-string +# pylint: disable=import-outside-toplevel,raise-missing-from +# pylint: disable=unused-argument,unspecified-encoding + +"""Main orchestration for the support bundle command.""" + +import os +import time +from datetime import datetime, timezone + +from knack.log import get_logger + +from azext_workload_orchestration.support.consts import ( + DEFAULT_NAMESPACES, + DEFAULT_TAIL_LINES, + STATUS_PASS, + STATUS_FAIL, + STATUS_WARN, + FOLDER_CLUSTER_INFO, +) +from azext_workload_orchestration.support.utils import ( + get_kubernetes_client, + create_bundle_directory, + create_zip_bundle, + detect_cluster_capabilities, + write_json, + format_bytes, + check_disk_space, +) + +logger = get_logger(__name__) + + +def create_support_bundle(cmd, + bundle_name=None, + output_dir=None, + namespaces=None, + tail_lines=None, + full_logs=False, + skip_checks=False, + skip_logs=False, + kube_config=None, + kube_context=None): + """Create a support bundle for troubleshooting workload orchestration issues.""" + from azure.cli.core.azclierror import CLIError + from azext_workload_orchestration.support.collectors import ( + collect_cluster_info, + collect_namespace_resources, + collect_cluster_resources, + collect_container_logs, + collect_wo_components, + collect_previous_logs, + collect_resource_quotas, + collect_metrics, + collect_pvcs, + validate_namespaces, + collect_network_config, + collect_all_events, + ) + from azext_workload_orchestration.support.validators import run_all_checks + + start_time = time.time() + namespaces = namespaces or DEFAULT_NAMESPACES + tail = None if 
full_logs else (tail_lines or DEFAULT_TAIL_LINES) + errors = [] + + # --- Step 1: Initialize K8s clients --- + _out("") + _out("Connecting to Kubernetes cluster...") + clients = get_kubernetes_client(kube_config=kube_config, kube_context=kube_context) + + # Show connection details + ctx = clients.get("context_info", {}) + _out(" Context: %s", ctx.get("context", "unknown")) + _out(" Cluster: %s", ctx.get("cluster", "unknown")) + + # Verify we can actually reach the cluster + try: + version_result = clients["version"].get_code() + _out(" Connected: Kubernetes %s", version_result.git_version) + except Exception as ex: + raise CLIError( + f"Cannot reach Kubernetes cluster: {ex}. " + f"Context '{ctx.get('context', '?')}' may be stale or the " + "cluster may be unreachable. Try running " + "'az aks get-credentials' to refresh." + ) + + # --- Step 2: Create bundle directory --- + try: + bundle_dir, bundle_name = create_bundle_directory(output_dir, bundle_name) + except Exception as ex: + raise CLIError( + f"Failed to create bundle directory: {ex}. " + f"Check that the output directory '{output_dir or os.getcwd()}' exists " + "and you have write permissions." + ) + + # Pre-flight: check disk space + ok, free = check_disk_space(output_dir or os.getcwd(), 100 * 1024 * 1024) + if not ok: + _out(" [WARN] Low disk space (%s free). 
Bundle may fail.", format_bytes(free)) + + # --- Step 3: Collect cluster info --- + cluster_info = {} + _out("") + _out("Collecting cluster information...") + try: + cluster_info = collect_cluster_info(clients, bundle_dir) + _print_cluster_info(cluster_info) + except Exception as ex: + err_msg = "Step 3 - Collect cluster info failed: %s" % ex + errors.append(err_msg) + _out(" [ERROR] %s", err_msg) + + # --- Step 4: Detect capabilities --- + capabilities = {} + try: + capabilities = detect_cluster_capabilities(clients) + write_json( + os.path.join(bundle_dir, FOLDER_CLUSTER_INFO, "capabilities.json"), + capabilities, + ) + _print_capabilities(capabilities) + except Exception as ex: + err_msg = "Step 4 - Detect capabilities failed: %s" % ex + errors.append(err_msg) + _out(" [ERROR] %s", err_msg) + + # --- Step 4b: Collect cluster-wide events --- + try: + all_events = collect_all_events(clients, bundle_dir) + if all_events: + warning_count = sum(1 for e in all_events if e["type"] == "Warning") + _out(" Events: %d total (%d warnings)", len(all_events), warning_count) + except Exception as ex: + err_msg = "Step 4b - Collect cluster events failed: %s" % ex + errors.append(err_msg) + _out(" [ERROR] %s", err_msg) + + # --- Step 5: Run prerequisite checks --- + check_results = [] + if not skip_checks: + _out("") + _out("Running prerequisite checks...") + _out("-" * 58) + try: + check_results = run_all_checks(clients, bundle_dir, cluster_info, capabilities) + _print_check_results(check_results) + except Exception as ex: + err_msg = "Step 5 - Prerequisite checks failed: %s" % ex + errors.append(err_msg) + _out(" [ERROR] %s", err_msg) + + # --- Step 6: Collect cluster-scoped resources --- + _out("") + _out("Collecting resources...") + try: + cluster_res = collect_cluster_resources(clients, bundle_dir) + sc_count = len(cluster_res.get("storage_classes", [])) + wh_count = len(cluster_res.get("validating_webhooks", [])) + len(cluster_res.get("mutating_webhooks", [])) + crd_count 
= len(cluster_res.get("crds", [])) + _out(" Cluster-scoped: %d StorageClasses, %d webhooks, %d CRDs", sc_count, wh_count, crd_count) + except Exception as ex: + err_msg = "Step 6 - Collect cluster-scoped resources failed: %s" % ex + errors.append(err_msg) + _out(" [ERROR] %s", err_msg) + + # --- Step 6b: Validate namespaces exist --- + skipped_ns = [] + try: + namespaces, skipped_ns = validate_namespaces(clients, namespaces) + if skipped_ns: + for ns, reason in skipped_ns: + _out(" [SKIP] Namespace '%s': %s", ns, reason) + if not namespaces: + _out(" [WARN] No valid namespaces to collect resources from") + except Exception as ex: + err_msg = "Step 6b - Namespace validation failed: %s" % ex + errors.append(err_msg) + _out(" [ERROR] %s (proceeding with original list)", err_msg) + + # --- Step 7: Collect per-namespace resources --- + for ns in namespaces: + try: + ns_res = collect_namespace_resources(clients, bundle_dir, ns) + collect_resource_quotas(clients, bundle_dir, ns) + collect_pvcs(clients, bundle_dir, ns) + pod_count = len(ns_res.get("pods", [])) + dep_count = len(ns_res.get("deployments", [])) + svc_count = len(ns_res.get("services", [])) + job_count = len(ns_res.get("jobs", [])) + parts = ["%d pods" % pod_count, "%d deployments" % dep_count, + "%d services" % svc_count] + if job_count: + parts.append("%d jobs" % job_count) + rs_count = len(ns_res.get("replicasets", [])) + if rs_count: + parts.append("%d replicasets" % rs_count) + _out(" %s: %s", ns, ", ".join(parts)) + except Exception as ex: + err_msg = "Step 7 - Collect namespace '%s' resources failed: %s" % (ns, ex) + errors.append(err_msg) + _out(" [ERROR] %s", err_msg) + + # --- Step 8: Collect WO-specific components --- + try: + wo_res = collect_wo_components(clients, bundle_dir, capabilities) + if wo_res: + parts = [] + if "symphony_targets" in wo_res: + parts.append("%d Symphony targets" % len(wo_res["symphony_targets"])) + if "cluster_issuers" in wo_res: + parts.append("%d ClusterIssuers" % 
len(wo_res["cluster_issuers"])) + if "gatekeeper_templates" in wo_res: + parts.append("%d Gatekeeper templates" % len(wo_res["gatekeeper_templates"])) + if parts: + _out(" WO components: %s", ", ".join(parts)) + except Exception as ex: + err_msg = "Step 8 - Collect WO components failed: %s" % ex + errors.append(err_msg) + _out(" [ERROR] %s", err_msg) + + # --- Step 8b: Collect metrics --- + try: + metrics = collect_metrics(clients, bundle_dir, capabilities) + if metrics: + nm = len(metrics.get("node_metrics", [])) + pm = len(metrics.get("wo_pod_metrics", [])) + _out(" Metrics: %d node(s), %d WO pod(s)", nm, pm) + except Exception as ex: + err_msg = "Step 8b - Collect metrics failed: %s" % ex + errors.append(err_msg) + _out(" [ERROR] %s", err_msg) + + # --- Step 8c: Collect network configuration --- + try: + net_info = collect_network_config(clients, bundle_dir) + if net_info: + parts = [] + if net_info.get("kube_proxy_config"): + parts.append("kube-proxy config") + ep_count = len(net_info.get("endpoint_slices", [])) + if ep_count: + parts.append("%d endpoint slices" % ep_count) + svc_count = len(net_info.get("external_services", [])) + if svc_count: + parts.append("%d external services" % svc_count) + if parts: + _out(" Network: %s", ", ".join(parts)) + except Exception as ex: + err_msg = "Step 8c - Collect network config failed: %s" % ex + errors.append(err_msg) + _out(" [ERROR] %s", err_msg) + + # --- Step 9: Collect container logs --- + total_logs = 0 + total_prev = 0 + if not skip_logs: + _out("") + _out("Collecting container logs%s...", + "" if full_logs else " (tail=%d lines)" % tail) + for ns in namespaces: + try: + count = collect_container_logs(clients, bundle_dir, ns, tail_lines=tail) + total_logs += count + prev = collect_previous_logs(clients, bundle_dir, ns, tail_lines=tail) + total_prev += prev + extra = " + %d previous" % prev if prev else "" + _out(" %s: %d logs%s", ns, count, extra) + except Exception as ex: + err_msg = "Step 9 - Collect logs for 
namespace '%s' failed: %s" % (ns, ex) + errors.append(err_msg) + _out(" [ERROR] %s", err_msg) + + # --- Step 10: Write bundle metadata --- + elapsed = time.time() - start_time + health_summary = _compute_health_summary(check_results, errors) + metadata = { + "bundle_name": bundle_name, + "created_at": datetime.now(timezone.utc).isoformat(), + "collection_time_seconds": round(elapsed, 1), + "health_summary": health_summary, + "namespaces_collected": namespaces, + "namespaces_skipped": [{"name": ns, "reason": r} for ns, r in skipped_ns] if skipped_ns else None, + "tail_lines": tail, + "full_logs": full_logs, + "skip_checks": skip_checks, + "skip_logs": skip_logs, + "total_logs_collected": total_logs, + "total_previous_logs": total_prev, + "check_count": len(check_results), + "capabilities": capabilities, + "cluster_version": cluster_info.get("server_version", {}).get("git_version", "unknown"), + "node_count": cluster_info.get("node_count", 0), + "errors": errors if errors else None, + } + write_json(os.path.join(bundle_dir, "metadata.json"), metadata) + + # --- Step 10b: Write checks summary --- + if check_results: + from azext_workload_orchestration.support.consts import FOLDER_CHECKS + checks_summary = { + "total": len(check_results), + "passed": sum(1 for c in check_results if c.get("status") == STATUS_PASS), + "failed": sum(1 for c in check_results if c.get("status") == STATUS_FAIL), + "warned": sum(1 for c in check_results if c.get("status") == STATUS_WARN), + "skipped": sum(1 for c in check_results if c.get("status") == "SKIP"), + "errored": sum(1 for c in check_results if c.get("status") == "ERROR"), + "checks": [ + { + "name": c.get("check_name", "unknown"), + "category": c.get("category", "unknown"), + "status": c.get("status", "UNKNOWN"), + "message": c.get("message", ""), + } + for c in check_results + ], + } + write_json(os.path.join(bundle_dir, FOLDER_CHECKS, "summary.json"), checks_summary) + + # --- Step 10c: Write human-readable summary --- + 
    # Step 10c call: SUMMARY.md is the human-readable entry point of the bundle.
    _write_summary_md(bundle_dir, bundle_name, cluster_info, capabilities,
                      check_results, namespaces, total_logs, total_prev, errors)

    # --- Step 11: Create zip ---
    zip_path = create_zip_bundle(bundle_dir, bundle_name, output_dir)

    try:
        zip_size = os.path.getsize(zip_path)
    except OSError as ex:
        # Best-effort: a missing/unreadable zip is reported but does not abort.
        err_msg = "Failed to read zip file size: %s" % ex
        errors.append(err_msg)
        _out(" [ERROR] %s", err_msg)
        zip_size = 0

    # --- Final summary ---
    passed = sum(1 for c in check_results if c.get("status") == STATUS_PASS)
    failed = sum(1 for c in check_results if c.get("status") == STATUS_FAIL)
    warned = sum(1 for c in check_results if c.get("status") == STATUS_WARN)

    _out("")
    _out("=" * 58)
    if errors:
        _out(" Support bundle created with %d error(s)", len(errors))
    else:
        _out(" Support bundle created successfully!")
    _out("")
    _out(" File: %s", zip_path)
    _out(" Size: %s", format_bytes(zip_size))
    _out(" Time: %.1fs", elapsed)
    _out("")
    if check_results:
        _out(" Checks: %d passed, %d failed, %d warnings", passed, failed, warned)
    if not skip_logs:
        log_msg = " Logs: %d container logs" % total_logs
        if total_prev:
            log_msg += " + %d previous" % total_prev
        _out(log_msg)
    if errors:
        _out("")
        _out(" Errors:")
        for err in errors:
            _out(" - %s", err)
    _out("=" * 58)
    _out("")

    # Returned dict is the command's JSON output (az CLI renders it).
    return {
        "bundle_path": zip_path,
        "bundle_size": zip_size,
        "bundle_size_human": format_bytes(zip_size),
        "collection_time_seconds": round(elapsed, 1),
        "logs_collected": total_logs,
        "previous_logs_collected": total_prev,
        "checks_run": len(check_results),
        "checks_passed": passed,
        "checks_failed": failed,
        "checks_warned": warned,
        "errors": errors if errors else None,
    }


def _compute_health_summary(check_results, errors):
    """Compute a health summary from check results.

    Returns a dict with check counts and collection error count.
    """
    # Empty-results short circuit: all check counters are zero but collection
    # errors are still reported (checks may have been skipped via --skip-checks).
    if not check_results:
        return {
            "checks_total": 0,
            "checks_passed": 0,
            "checks_failed": 0,
            "checks_warned": 0,
            "collection_errors": len(errors) if errors else 0,
        }

    return {
        "checks_total": len(check_results),
        "checks_passed": sum(1 for c in check_results if c.get("status") == STATUS_PASS),
        "checks_failed": sum(1 for c in check_results if c.get("status") == STATUS_FAIL),
        "checks_warned": sum(1 for c in check_results if c.get("status") == STATUS_WARN),
        "collection_errors": len(errors) if errors else 0,
    }


def _append_namespace_resources(bundle_dir, namespaces, lines):
    """Append per-namespace resource counts to summary lines."""
    import json
    for ns in namespaces:
        res_file = os.path.join(bundle_dir, "resources", ns, "resources.json")
        if not os.path.exists(res_file):
            continue
        try:
            # NOTE(review): open() without explicit encoding uses the locale
            # default — the bundle files are written by this extension, so
            # presumably UTF-8; confirm and consider encoding="utf-8".
            with open(res_file, "r") as f:
                res_data = json.load(f)
            # Only non-empty list-valued entries contribute a "<count> <kind>" part.
            parts = [
                f"{len(items)} {key}"
                for key, items in res_data.items()
                if isinstance(items, list) and items
            ]
            if parts:
                lines.append(f"**{ns}:** {', '.join(parts)}")
                lines.append("")
        except Exception:  # pylint: disable=broad-exception-caught
            # Summary generation is best-effort: a corrupt/missing file is skipped.
            pass


def _append_wo_components(bundle_dir, lines):
    """Append WO component details to summary lines."""
    import json
    wo_file = os.path.join(bundle_dir, "resources", "cluster", "wo-components.json")
    if not os.path.exists(wo_file):
        return
    try:
        # NOTE(review): same locale-default encoding caveat as
        # _append_namespace_resources — confirm and consider encoding="utf-8".
        with open(wo_file, "r") as f:
            wo_data = json.load(f)
        if not wo_data:
            return
        lines.append("### WO Components")
        lines.append("")
        for key, items in wo_data.items():
            if not isinstance(items, list):
                continue
            label = key.replace("_", " ").title()
            lines.append(f"- **{label}:** {len(items)}")
            for item in items:
                name = item.get("name", "?")
                # "status" preferred, falls back to "ready", then "?".
                _status = item.get("status", item.get("ready", "?"))
                lines.append(f" - `{name}` — {_status}")
        lines.append("")
    except Exception:  # pylint: disable=broad-exception-caught
        # Best-effort section: unreadable file means the section is omitted.
        pass


def _write_summary_md(bundle_dir, bundle_name, cluster_info, capabilities,
                      check_results, namespaces, total_logs, total_prev, errors):
    # pylint: disable=too-many-branches
    """Write a comprehensive SUMMARY.md at the bundle root.

    This is the single file a DRI opens first — it summarizes everything
    in the bundle: cluster state, check results, collected resources, errors.
    """
    from azext_workload_orchestration.support.utils import write_text

    lines = []
    lines.append("# WO Support Bundle — Summary Report")
    lines.append("")

    sv = cluster_info.get("server_version", {})
    # NOTE(review): _ctx_name is computed but never used (hence noqa) — dead
    # assignment, candidate for removal.
    _ctx_name = cluster_info.get("context", "unknown")  # noqa: F841
    lines.append("## Cluster Overview")
    lines.append("")
    lines.append("| Field | Value |")
    lines.append("|-------|-------|")
    lines.append(f"| **Bundle** | `{bundle_name}` |")
    lines.append(f"| **Kubernetes Version** | {sv.get('git_version', 'unknown')} |")
    lines.append(f"| **Platform** | {sv.get('platform', 'unknown')} |")
    lines.append(f"| **Node Count** | {cluster_info.get('node_count', 0)} |")
    lines.append(f"| **Namespace Count** | {len(cluster_info.get('namespaces', []))} |")
    lines.append(f"| **Namespaces Collected** | {', '.join(namespaces)} |")

    # Detected capabilities ("has_" prefix stripped for display)
    detected = [k.replace("has_", "") for k, v in capabilities.items() if v]
    not_detected = [k.replace("has_", "") for k, v in capabilities.items() if not v]
    lines.append(f"| **Components Detected** | {', '.join(detected) if detected else 'none'} |")
    if not_detected:
        lines.append(f"| **Not Detected** | {', '.join(not_detected)} |")
    lines.append("")

    # Nodes
    nodes = cluster_info.get("nodes", [])
    if nodes:
        lines.append("## Nodes")
        lines.append("")
        lines.append("| Name | Ready | Roles | CPU | Memory | Runtime | Kubelet |")
        lines.append("|------|-------|-------|-----|--------|---------|---------|")
        for n in nodes:
            ready = "✅ Yes" if n.get("ready") == "True" else "❌ No"
            roles = ", ".join(n.get("roles", [""]))
            lines.append(
                f"| {n['name']} | {ready} | {roles} "
                f"| {n.get('allocatable_cpu', '?')} "
                f"| {n.get('allocatable_memory', '?')} "
                f"| {n.get('container_runtime', '?')} "
                f"| {n.get('kubelet_version', '?')} |"
            )

        # Node conditions (pressure, etc.) — any non-Ready condition that is
        # "True" (MemoryPressure, DiskPressure, ...) is a problem worth surfacing.
        has_issues = False
        for n in nodes:
            conditions = n.get("conditions", {})
            for cond, val in conditions.items():
                if cond != "Ready" and val == "True":
                    if not has_issues:
                        lines.append("")
                        lines.append("### ⚠️ Node Conditions")
                        lines.append("")
                        has_issues = True
                    lines.append(f"- **{n['name']}**: {cond} = True")

        # Taints
        tainted = [n for n in nodes if n.get("taints")]
        if tainted:
            lines.append("")
            lines.append("### Node Taints")
            lines.append("")
            for n in tainted:
                for t in n["taints"]:
                    lines.append(
                        f"- **{n['name']}**: `{t.get('key', '?')}="
                        f"{t.get('value', '')}:{t.get('effect', '?')}`"
                    )
        lines.append("")

    # Checks — the main section
    if check_results:
        passed = sum(1 for c in check_results if c.get("status") == STATUS_PASS)
        failed = sum(1 for c in check_results if c.get("status") == STATUS_FAIL)
        warned = sum(1 for c in check_results if c.get("status") == STATUS_WARN)

        lines.append("## Prerequisite Checks")
        lines.append("")
        lines.append(f"> **{passed} passed, {failed} failed, {warned} warnings** "
                     f"(out of {len(check_results)} total)")
        lines.append("")

        # Failed checks first (most important)
        failed_checks = [c for c in check_results if c.get("status") == STATUS_FAIL]
        if failed_checks:
            lines.append("### ❌ Failed Checks (Action Required)")
            lines.append("")
            for c in failed_checks:
                lines.append(f"- **{c.get('check_name', '?')}** ({c.get('category', '?')}): {c.get('message', '')}")
            lines.append("")

        # Warnings
        warn_checks = [c for c in check_results if c.get("status") == STATUS_WARN]
        if warn_checks:
            lines.append("### ⚠️ Warnings")
            lines.append("")
            for c in warn_checks:
                lines.append(f"- **{c.get('check_name', '?')}** ({c.get('category', '?')}): {c.get('message', '')}")
            lines.append("")

        # Full table
        lines.append("### All Checks")
        lines.append("")
        lines.append("| Status | Check | Category | Details |")
        lines.append("|--------|-------|----------|---------|")

        status_icons = {
            STATUS_PASS: "✅ PASS",
            STATUS_FAIL: "❌ FAIL",
            STATUS_WARN: "⚠️ WARN",
            "SKIP": "⏭️ SKIP",
            "ERROR": "💥 ERROR",
        }
        for c in check_results:
            icon = status_icons.get(c.get("status"), c.get("status", "?"))
            name = c.get("check_name", "unknown")
            cat = c.get("category", "")
            # Pipes would break the markdown table — escape them.
            msg = c.get("message", "").replace("|", "\\|")
            lines.append(f"| {icon} | {name} | {cat} | {msg} |")
        lines.append("")

    # Data collected
    lines.append("## Data Collected")
    lines.append("")
    lines.append("| Item | Count |")
    lines.append("|------|-------|")
    lines.append(f"| Container logs | {total_logs} |")
    if total_prev:
        lines.append(f"| Previous logs (crash-looping pods) | {total_prev} |")
    lines.append(f"| Namespaces collected | {len(namespaces)} |")
    lines.append(f"| Prerequisite checks | {len(check_results)} |")
    lines.append("")

    # Per-namespace resource counts (read from collected files)
    lines.append("### Resources Per Namespace")
    lines.append("")
    _append_namespace_resources(bundle_dir, namespaces, lines)

    # WO components
    _append_wo_components(bundle_dir, lines)

    # Errors
    if errors:
        lines.append("## ⚠️ Collection Errors")
        lines.append("")
        lines.append("The following errors occurred during bundle collection. "
                     "The bundle was still generated but may be missing some data.")
        lines.append("")
        for err in errors:
            lines.append(f"- {err}")
        lines.append("")

    # Bundle contents guide
    lines.append("## How to Read This Bundle")
    lines.append("")
    lines.append("| File/Folder | What's Inside |")
    lines.append("|-------------|---------------|")
    lines.append("| 📄 `SUMMARY.md` | This file — start here |")
    lines.append("| 📄 `metadata.json` | Bundle parameters, timestamps, capabilities |")
    lines.append("| 📁 `checks/` | Individual check results (JSON) + `summary.json` |")
    lines.append("| 📁 `cluster-info/` | K8s version, node details, namespace list, metrics |")
    lines.append("| 📁 `resources/` | Per-namespace resource descriptions, cluster-scoped resources, network config |")
    lines.append("| 📁 `logs/` | Container logs organized by `namespace/pod--container.log` |")
    lines.append("")
    lines.append("### Quick Troubleshooting")
    lines.append("")
    lines.append("1. **Check failed?** → Look at the ❌ Failed Checks section above")
    lines.append("2. **Pod crashing?** → Check `logs//----previous.log`")
    lines.append("3. **WO not working?** → Check `resources/wo-components.json` and `logs/workloadorchestration/`")
    lines.append("4. **Network issues?** → Check `resources/network-config.json`")
    lines.append("5. **Storage issues?** → Check `resources/cluster-resources.json` (storage_classes, csi_drivers)")
    lines.append("")

    write_text(os.path.join(bundle_dir, "SUMMARY.md"), "\n".join(lines))


def _out(msg, *args):
    """Print a line to console via logger.warning (az CLI convention)."""
    # Only forward *args when present so literal '%' in msg is not interpreted.
    if args:
        logger.warning(msg, *args)
    else:
        logger.warning(msg)


def _print_cluster_info(cluster_info):
    """Print cluster overview to console."""
    sv = cluster_info.get("server_version", {})
    version = sv.get("git_version", "unknown")
    node_count = cluster_info.get("node_count", 0)
    ns_count = len(cluster_info.get("namespaces", []))

    _out("")
    _out(" Cluster: Kubernetes %s", version)
    _out(" Nodes: %d", node_count)
    _out(" Namespaces: %d", ns_count)

    # Show node details
    for node in cluster_info.get("nodes", []):
        cpu = node.get("allocatable_cpu", "?")
        mem = node.get("allocatable_memory", "?")
        ready = node.get("ready", "?")
        roles = ", ".join(node.get("roles", [""]))
        # NOTE(review): _status is computed but never used (hence noqa) — dead
        # assignment, candidate for removal.
        _status = "Ready" if ready == "True" else "NOT READY"  # noqa: F841
        # "! " marker flags not-ready nodes in the console listing.
        _out(" %s %s [%s] cpu=%s mem=%s",
             " " if ready == "True" else "! ", node["name"], roles, cpu, mem)


def _print_capabilities(capabilities):
    """Print detected capabilities."""
    detected = [k.replace("has_", "") for k, v in capabilities.items() if v]
    if detected:
        _out(" Detected: %s", ", ".join(detected))


def _print_check_results(check_results):
    """Print each check result with status icon."""
    status_icons = {
        STATUS_PASS: "[PASS]",
        STATUS_FAIL: "[FAIL]",
        STATUS_WARN: "[WARN]",
        "SKIP": "[SKIP]",
        "ERROR": "[ERR!]",
    }

    for c in check_results:
        icon = status_icons.get(c.get("status"), "[????]")
        name = c.get("check_name", "unknown")
        msg = c.get("message", "")
        _out(" %s %-25s %s", icon, name, msg)

    passed = sum(1 for c in check_results if c.get("status") == STATUS_PASS)
    failed = sum(1 for c in check_results if c.get("status") == STATUS_FAIL)
    warned = sum(1 for c in check_results if c.get("status") == STATUS_WARN)
    _out("-" * 58)
    _out(" %d passed, %d failed, %d warnings", passed, failed, warned)
diff --git a/src/workload-orchestration/azext_workload_orchestration/support/collectors.py b/src/workload-orchestration/azext_workload_orchestration/support/collectors.py
new file mode 100644
index 00000000000..f33d0903bb0
--- /dev/null
+++ b/src/workload-orchestration/azext_workload_orchestration/support/collectors.py
@@ -0,0 +1,1070 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

# pylint: disable=import-outside-toplevel,too-many-branches,too-many-statements
# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments

"""Data collectors for the workload-orchestration support bundle feature."""

import os
from concurrent.futures import ThreadPoolExecutor, as_completed

from knack.log import get_logger

from azext_workload_orchestration.support.consts import (
    DEFAULT_TAIL_LINES,
    DEFAULT_MAX_LOG_SIZE_BYTES,
    FOLDER_RESOURCES,
    FOLDER_CLUSTER_INFO,
    WO_NAMESPACE,
)
from azext_workload_orchestration.support.utils import (
    safe_api_call,
    write_json,
    write_text,
    create_namespace_log_dir,
)

logger = get_logger(__name__)


# ---------------------------------------------------------------------------
# Namespace validation (pre-flight)
# ---------------------------------------------------------------------------

def validate_namespaces(clients, namespaces):
    """Validate that requested namespaces exist on the cluster.

    Returns (valid_namespaces, skipped_namespaces) where skipped_namespaces
    is a list of (namespace, reason) tuples.
    """
    core = clients["core_v1"]
    valid = []
    skipped = []

    for ns in namespaces:
        # max_retries=1: a missing namespace is a definitive answer — no
        # point in the default retry/backoff.
        result, err = safe_api_call(
            core.read_namespace, ns,
            description=f"validate namespace '{ns}'",
            max_retries=1,
        )
        if result:
            # A terminating namespace would yield partial/confusing data.
            if result.status and result.status.phase == "Terminating":
                skipped.append((ns, "namespace is terminating"))
                logger.warning("Namespace '%s' is terminating, skipping", ns)
            else:
                valid.append(ns)
        else:
            skipped.append((ns, err or "namespace not found"))
            logger.warning("Namespace '%s' not found, skipping: %s", ns, err)

    return valid, skipped


# ---------------------------------------------------------------------------
# Resource directory helpers
# ---------------------------------------------------------------------------

def _get_ns_resource_dir(bundle_dir, namespace):
    """Get (and create) the per-namespace resource subdirectory."""
    ns_dir = os.path.join(bundle_dir, FOLDER_RESOURCES, namespace)
    os.makedirs(ns_dir, exist_ok=True)
    return ns_dir


def _get_cluster_resource_dir(bundle_dir):
    """Get (and create) the cluster-scoped resource subdirectory."""
    cluster_dir = os.path.join(bundle_dir, FOLDER_RESOURCES, "cluster")
    os.makedirs(cluster_dir, exist_ok=True)
    return cluster_dir


# ---------------------------------------------------------------------------
# Cluster info collection
# ---------------------------------------------------------------------------

def collect_cluster_info(clients, bundle_dir):
    """Collect basic cluster information (version, nodes, namespaces)."""
    info = {}

    # Kubernetes version
    version_client = clients["version"]
    result, _err = safe_api_call(version_client.get_code, description="get server version")
    if result:
        info["server_version"] = {
            "major": result.major,
            "minor": result.minor,
            "git_version": result.git_version,
            "platform": result.platform,
        }

    # Node summary
    core = clients["core_v1"]
    result, _err = safe_api_call(core.list_node, description="list nodes")
    if result:
        nodes = []
        for node in result.items:
            status = node.status
            conditions_list = [
                {"type": c.type, "status": c.status, "reason": c.reason, "message": c.message}
                for c in (status.conditions or [])
            ]
            # type -> status map for quick Ready/pressure lookups
            conditions = {c.type: c.status for c in (status.conditions or [])}
            alloc = status.allocatable or {}
            nodes.append({
                "name": node.metadata.name,
                "ready": conditions.get("Ready", "Unknown"),
                "roles": _get_node_roles(node),
                "os": node.status.node_info.os_image if status.node_info else "unknown",
                "container_runtime": status.node_info.container_runtime_version if status.node_info else "unknown",
                "kubelet_version": status.node_info.kubelet_version if status.node_info else "unknown",
                "allocatable_cpu": alloc.get("cpu", "0"),
                "allocatable_memory": alloc.get("memory", "0"),
                "taints": [
                    {"key": t.key, "effect": t.effect, "value": t.value}
                    for t in (node.spec.taints or [])
                ],
                "conditions": conditions,
                "conditions_detail": conditions_list,
            })
        info["nodes"] = nodes
        info["node_count"] = len(nodes)

    # Namespace list
    result, _err = safe_api_call(core.list_namespace, description="list namespaces")
    if result:
        info["namespaces"] = [
            {
                "name": ns.metadata.name,
                "status": ns.status.phase,
                "labels": dict(ns.metadata.labels or {}),
            }
            for ns in result.items
        ]

    write_json(os.path.join(bundle_dir, FOLDER_CLUSTER_INFO, "cluster-info.json"), info)
    logger.info("Collected cluster info: %d nodes, %d namespaces",
                info.get("node_count", 0), len(info.get("namespaces", [])))
    return info


def collect_all_events(clients, bundle_dir):
    """Collect events from all namespaces (Warning events prioritized).

    Saves to cluster-info/events.json. Limits to most recent 500 events
    to keep bundle size reasonable.
    """
    core = clients["core_v1"]
    result, err = safe_api_call(
        core.list_event_for_all_namespaces,
        description="list events across all namespaces",
    )
    if not result:
        logger.debug("Could not collect cluster events: %s", err)
        return []

    events = []
    for e in result.items:
        events.append({
            "namespace": e.metadata.namespace,
            "type": e.type,
            "reason": e.reason,
            "message": e.message,
            "involved_object": f"{e.involved_object.kind}/{e.involved_object.name}",
            "count": e.count,
            "first_timestamp": str(e.first_timestamp) if e.first_timestamp else None,
            "last_timestamp": str(e.last_timestamp) if e.last_timestamp else None,
        })

    # Sort: Warning first, then by last_timestamp descending, limit to 500
    # NOTE(review): the key sorts last_timestamp ASCENDING (no reverse flag),
    # so within each type the [:500] slice keeps the OLDEST events, not the
    # most recent — this contradicts both the comment above and the docstring.
    # Confirm intent; a fix would reverse the timestamp ordering.
    events.sort(key=lambda e: (
        0 if e["type"] == "Warning" else 1,
        e.get("last_timestamp") or "",
    ))
    events = events[:500]

    write_json(os.path.join(bundle_dir, FOLDER_CLUSTER_INFO, "events.json"), events)
    warning_count = sum(1 for e in events if e["type"] == "Warning")
    logger.info("Collected %d cluster events (%d warnings)", len(events), warning_count)
    return events


def _get_node_roles(node):
    """Extract node roles from labels."""
    roles = []
    for label in (node.metadata.labels or {}):
        if label.startswith("node-role.kubernetes.io/"):
            roles.append(label.split("/")[-1])
    # [""] sentinel keeps ", ".join(...) callers happy for role-less nodes.
    return roles if roles else [""]


# ---------------------------------------------------------------------------
# Resource collection
# ---------------------------------------------------------------------------

def collect_namespace_resources(clients, bundle_dir, namespace):
    """Collect resource descriptions for a given namespace."""
    core = clients["core_v1"]
    apps = clients["apps_v1"]
    resources = {}

    # Pods
    result, _err = safe_api_call(
        core.list_namespaced_pod, namespace, description=f"list pods in {namespace}"
    )
    if result:
        resources["pods"] = [
            {
                "name": p.metadata.name,
                "phase": p.status.phase,
                "ready": _pod_ready_count(p),
                "restarts": _pod_restart_count(p),
                "node": p.spec.node_name,
                "containers": _get_container_details(p),
            }
            for p in result.items
        ]

    # Deployments
    result, _err = safe_api_call(
        apps.list_namespaced_deployment, namespace, description=f"list deployments in {namespace}"
    )
    if result:
        resources["deployments"] = [
            {
                "name": d.metadata.name,
                "replicas": d.spec.replicas,
                "ready_replicas": d.status.ready_replicas or 0,
                "available_replicas": d.status.available_replicas or 0,
            }
            for d in result.items
        ]

    # Services
    result, _err = safe_api_call(
        core.list_namespaced_service, namespace, description=f"list services in {namespace}"
    )
    if result:
        resources["services"] = [
            {
                "name": s.metadata.name,
                "type": s.spec.type,
                "cluster_ip": s.spec.cluster_ip,
                "ports": [
                    {"port": p.port, "target_port": str(p.target_port), "protocol": p.protocol}
                    for p in (s.spec.ports or [])
                ],
            }
            for s in result.items
        ]

    # DaemonSets
    result, _err = safe_api_call(
        apps.list_namespaced_daemon_set, namespace, description=f"list daemonsets in {namespace}"
    )
    if result:
        resources["daemonsets"] = [
            {
                "name": ds.metadata.name,
                "desired": ds.status.desired_number_scheduled,
                "ready": ds.status.number_ready,
            }
            for ds in result.items
        ]

    # StatefulSets
    result, _err = safe_api_call(
        apps.list_namespaced_stateful_set, namespace,
        description=f"list statefulsets in {namespace}"
    )
    if result:
        resources["statefulsets"] = [
            {
                "name": ss.metadata.name,
                "replicas": ss.spec.replicas,
                "ready_replicas": ss.status.ready_replicas or 0,
            }
            for ss in result.items
        ]

    # Events
    result, _err = safe_api_call(
        core.list_namespaced_event, namespace, description=f"list events in {namespace}"
    )
    if result:
        resources["events"] = [
            {
                "type": e.type,
                "reason": e.reason,
                "message": e.message,
                "involved_object": f"{e.involved_object.kind}/{e.involved_object.name}",
                "count": e.count,
                "last_timestamp": str(e.last_timestamp) if e.last_timestamp else None,
            }
            for e in result.items
        ]

    # ConfigMaps (names only, not data — could contain secrets)
    result, _err = safe_api_call(
        core.list_namespaced_config_map, namespace, description=f"list configmaps in {namespace}"
    )
    if result:
        resources["configmaps"] = [
            {"name": cm.metadata.name, "data_keys": list((cm.data or {}).keys())}
            for cm in result.items
        ]

    # ReplicaSets
    result, _err = safe_api_call(
        apps.list_namespaced_replica_set, namespace,
        description=f"list replicasets in {namespace}"
    )
    if result:
        resources["replicasets"] = [
            {
                "name": rs.metadata.name,
                "replicas": rs.spec.replicas,
                "ready_replicas": rs.status.ready_replicas or 0,
                "available_replicas": rs.status.available_replicas or 0,
                "owner": _get_owner_ref(rs),
            }
            for rs in result.items
        ]

    # Jobs
    # NOTE(review): BatchV1Api/NetworkingV1Api below are built from the default
    # kubernetes client configuration rather than taken from `clients` like
    # core/apps above — confirm both use the same kubeconfig/context.
    try:
        from kubernetes import client as _k8s_client
        batch_v1 = _k8s_client.BatchV1Api()
        result, _err = safe_api_call(
            batch_v1.list_namespaced_job, namespace,
            description=f"list jobs in {namespace}"
        )
        if result:
            resources["jobs"] = [
                {
                    "name": j.metadata.name,
                    "active": j.status.active or 0,
                    "succeeded": j.status.succeeded or 0,
                    "failed": j.status.failed or 0,
                    "completions": j.spec.completions,
                    "start_time": str(j.status.start_time) if j.status.start_time else None,
                    "completion_time": str(j.status.completion_time) if j.status.completion_time else None,
                }
                for j in result.items
            ]

        # CronJobs
        result, _err = safe_api_call(
            batch_v1.list_namespaced_cron_job, namespace,
            description=f"list cronjobs in {namespace}"
        )
        if result:
            resources["cronjobs"] = [
                {
                    "name": cj.metadata.name,
                    "schedule": cj.spec.schedule,
                    "suspend": cj.spec.suspend,
                    "active_jobs": len(cj.status.active or []),
                    "last_schedule": str(cj.status.last_schedule_time) if cj.status.last_schedule_time else None,
                    "last_successful": str(cj.status.last_successful_time) if cj.status.last_successful_time else None,
                }
                for cj in result.items
            ]
    except Exception as ex:  # pylint: disable=broad-exception-caught
        logger.debug("Batch API not available for %s: %s", namespace, ex)

    # Ingresses
    try:
        from kubernetes import client as _k8s_client
        networking_v1 = _k8s_client.NetworkingV1Api()
        result, _err = safe_api_call(
            networking_v1.list_namespaced_ingress, namespace,
            description=f"list ingresses in {namespace}"
        )
        if result:
            resources["ingresses"] = [
                {
                    "name": ing.metadata.name,
                    "class_name": ing.spec.ingress_class_name,
                    "rules_count": len(ing.spec.rules or []),
                    "tls_count": len(ing.spec.tls or []),
                    "hosts": [r.host for r in (ing.spec.rules or []) if r.host],
                }
                for ing in result.items
            ]

        # NetworkPolicies
        result, _err = safe_api_call(
            networking_v1.list_namespaced_network_policy, namespace,
            description=f"list network policies in {namespace}"
        )
        if result:
            resources["network_policies"] = [
                {
                    "name": np.metadata.name,
                    "pod_selector": (dict(np.spec.pod_selector.match_labels or {})
                                     if np.spec.pod_selector and np.spec.pod_selector.match_labels
                                     else {}),
                    "policy_types": np.spec.policy_types or [],
                    "ingress_rules": len(np.spec.ingress or []) if np.spec.ingress else 0,
                    "egress_rules": len(np.spec.egress or []) if np.spec.egress else 0,
                }
                for np in result.items
            ]
    except Exception as ex:  # pylint: disable=broad-exception-caught
        logger.debug("Networking API not available for %s: %s", namespace, ex)

    # ServiceAccounts
    result, _err = safe_api_call(
        core.list_namespaced_service_account, namespace,
        description=f"list service accounts in {namespace}"
    )
    if result:
        resources["service_accounts"] = [
            {
                "name": sa.metadata.name,
                "secrets_count": len(sa.secrets or []) if sa.secrets else 0,
                "image_pull_secrets": [
                    ips.name for ips in (sa.image_pull_secrets or [])
                ],
            }
            for sa in result.items
        ]

    ns_res_dir = _get_ns_resource_dir(bundle_dir, namespace)
    filepath = os.path.join(ns_res_dir, "resources.json")
    write_json(filepath, resources)
    pod_count = len(resources.get("pods", []))
    logger.info("Collected resources for %s: %d pods, %d resource types",
                namespace, pod_count, len(resources))
    return resources


def _get_owner_ref(resource):
    """Extract owner reference (controller) for a resource."""
    refs = resource.metadata.owner_references or []
    if refs:
        # Only the first owner is reported; K8s objects rarely have more than one.
        return {"kind": refs[0].kind, "name": refs[0].name}
    return None


def _get_container_details(pod):
    """Extract container status details for a pod."""
    details = []
    statuses = {cs.name: cs for cs in (pod.status.container_statuses or [])}
    for c in (pod.spec.containers or []):
        cs = statuses.get(c.name)
        info = {"name": c.name}
        if cs:
            info["ready"] = cs.ready
            info["restart_count"] = cs.restart_count
            # Extract current state
            if cs.state:
                if cs.state.running:
                    info["state"] = "running"
                elif cs.state.waiting:
                    info["state"] = "waiting"
                    info["reason"] = cs.state.waiting.reason
                    info["message"] = cs.state.waiting.message
                elif cs.state.terminated:
                    info["state"] = "terminated"
                    info["reason"] = cs.state.terminated.reason
                    info["exit_code"] = cs.state.terminated.exit_code
            # Extract last state (previous run)
            if cs.last_state and cs.last_state.terminated:
                info["last_terminated_reason"] = cs.last_state.terminated.reason
                info["last_exit_code"] = cs.last_state.terminated.exit_code
        details.append(info)
    return details


def _pod_ready_count(pod):
    """Return 'ready/total' string for a pod."""
    containers = pod.spec.containers or []
    total = len(containers)
    ready = sum(
        1 for cs in (pod.status.container_statuses or []) if cs.ready
    )
    return f"{ready}/{total}"


def _pod_restart_count(pod):
    """Return total restart count across all containers."""
    return sum(cs.restart_count for cs in (pod.status.container_statuses or []))
# ---------------------------------------------------------------------------
# Cluster-scoped resource collection
# ---------------------------------------------------------------------------

def collect_cluster_resources(clients, bundle_dir):
    """Collect cluster-scoped resources (StorageClasses, CRDs, webhooks, PVs)."""
    cluster = {}

    # StorageClasses
    storage = clients["storage_v1"]
    result, _err = safe_api_call(storage.list_storage_class, description="list storage classes")
    if result:
        cluster["storage_classes"] = [
            {
                "name": sc.metadata.name,
                "provisioner": sc.provisioner,
                "is_default": _is_default_sc(sc),
                "reclaim_policy": sc.reclaim_policy,
            }
            for sc in result.items
        ]

    # PersistentVolumes
    core = clients["core_v1"]
    result, _err = safe_api_call(core.list_persistent_volume, description="list PVs")
    if result:
        cluster["persistent_volumes"] = [
            {
                "name": pv.metadata.name,
                "capacity": dict(pv.spec.capacity or {}),
                "status": pv.status.phase,
                "storage_class": pv.spec.storage_class_name,
                "claim": f"{pv.spec.claim_ref.namespace}/{pv.spec.claim_ref.name}" if pv.spec.claim_ref else None,
            }
            for pv in result.items
        ]

    # Validating Webhooks
    admission = clients["admissionregistration_v1"]
    result, _err = safe_api_call(
        admission.list_validating_webhook_configuration,
        description="list validating webhooks",
    )
    if result:
        cluster["validating_webhooks"] = [
            {
                "name": w.metadata.name,
                "webhook_count": len(w.webhooks or []),
                # set-comprehension deduplicates; list() makes it JSON-serializable
                "failure_policies": list({wh.failure_policy for wh in (w.webhooks or [])}),
            }
            for w in result.items
        ]

    # Mutating Webhooks
    result, _err = safe_api_call(
        admission.list_mutating_webhook_configuration,
        description="list mutating webhooks",
    )
    if result:
        cluster["mutating_webhooks"] = [
            {
                "name": w.metadata.name,
                "webhook_count": len(w.webhooks or []),
                "failure_policies": list({wh.failure_policy for wh in (w.webhooks or [])}),
            }
            for w in result.items
        ]

    # CRDs (names only — full JSON is huge)
    custom = clients["custom_objects"]
    result, _err = safe_api_call(
        custom.list_cluster_custom_object,
        "apiextensions.k8s.io", "v1", "customresourcedefinitions",
        description="list CRDs",
    )
    if result:
        cluster["crds"] = [
            {
                "name": crd.get("metadata", {}).get("name", "unknown"),
                "group": crd.get("spec", {}).get("group", "unknown"),
            }
            for crd in result.get("items", [])
        ]

    # CSI Drivers
    result, _err = safe_api_call(storage.list_csi_driver, description="list CSI drivers")
    if result:
        cluster["csi_drivers"] = [
            {
                "name": d.metadata.name,
                "attach_required": d.spec.attach_required if d.spec else None,
            }
            for d in result.items
        ]

    cluster_dir = _get_cluster_resource_dir(bundle_dir)
    filepath = os.path.join(cluster_dir, "resources.json")
    write_json(filepath, cluster)
    logger.info("Collected cluster resources: %d SCs, %d webhooks, %d CRDs, %d CSI drivers",
                len(cluster.get("storage_classes", [])),
                len(cluster.get("validating_webhooks", [])) + len(cluster.get("mutating_webhooks", [])),
                len(cluster.get("crds", [])),
                len(cluster.get("csi_drivers", [])))
    return cluster


def _is_default_sc(sc):
    """Check if a StorageClass is the default (v1 or beta annotation)."""
    from azext_workload_orchestration.support.consts import (
        SC_DEFAULT_ANNOTATION_V1, SC_DEFAULT_ANNOTATION_BETA,
    )
    ann = sc.metadata.annotations or {}
    return (
        ann.get(SC_DEFAULT_ANNOTATION_V1) == "true"
        or ann.get(SC_DEFAULT_ANNOTATION_BETA) == "true"
    )


# ---------------------------------------------------------------------------
# Container log collection
# ---------------------------------------------------------------------------

def collect_container_logs(clients, bundle_dir, namespace, tail_lines=DEFAULT_TAIL_LINES,
                           max_workers=5, log_timeout=None):
    """Collect container logs for all pods in a namespace.

    Uses threading for parallel log fetching. 
Returns count of logs collected.
+    """
+    from azext_workload_orchestration.support.consts import DEFAULT_LOG_TIMEOUT_SECONDS
+
+    per_log_timeout = log_timeout or DEFAULT_LOG_TIMEOUT_SECONDS
+    core = clients["core_v1"]
+    result, err = safe_api_call(
+        core.list_namespaced_pod, namespace, description=f"list pods for logs in {namespace}"
+    )
+    if not result:
+        logger.warning("Could not list pods in %s: %s", namespace, err)
+        return 0
+
+    ns_log_dir = create_namespace_log_dir(bundle_dir, namespace)
+
+    # Build list of (pod_name, container_name) to collect
+    targets = []
+    for pod in result.items:
+        for container in (pod.spec.containers or []):
+            targets.append((pod.metadata.name, container.name))
+
+    if not targets:
+        return 0
+
+    collected = 0
+
+    def _fetch_log(pod_name, container_name):
+        log_result, _log_err = safe_api_call(
+            core.read_namespaced_pod_log,
+            pod_name, namespace,
+            container=container_name,
+            tail_lines=tail_lines,
+            _preload_content=True,
+            description=f"logs {namespace}/{pod_name}/{container_name}",
+        )
+        if log_result is not None:
+            # Truncate if exceeds max size
+            log_text = log_result
+            if len(log_text.encode("utf-8", errors="replace")) > DEFAULT_MAX_LOG_SIZE_BYTES:
+                lines = log_text.splitlines()
+                truncated = []
+                size = 0
+                for line in reversed(lines):
+                    size += len(line.encode("utf-8", errors="replace")) + 1
+                    if size > DEFAULT_MAX_LOG_SIZE_BYTES:
+                        break
+                    truncated.insert(0, line)
+                log_text = f"[TRUNCATED to last {len(truncated)} lines]\n" + "\n".join(truncated)
+
+            filepath = os.path.join(ns_log_dir, f"{pod_name}--{container_name}.log")
+            write_text(filepath, log_text)
+            return True
+        return False
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = {
+            executor.submit(_fetch_log, pod, container): (pod, container)
+            for pod, container in targets
+        }
+        # NOTE(review): the timeout passed to as_completed() raises TimeoutError at this
+        # for statement itself, OUTSIDE the try/except below, so one slow batch aborts
+        # all remaining log collection for this namespace instead of degrading per-pod.
+        # TODO confirm this is intended, or wrap the loop in its own try/except.
+        for future in as_completed(futures, timeout=per_log_timeout * len(targets)):
+            pod, container = futures[future]
+            try:
+                if future.result(timeout=per_log_timeout):
+                    collected += 1
+            except TimeoutError:
+                # NOTE(review): future.result() raises concurrent.futures.TimeoutError,
+                # which is only an alias of the builtin TimeoutError caught here from
+                # Python 3.11 onward — on 3.9/3.10 this handler may never fire.
+                # TODO confirm the minimum supported Python version.
+                logger.debug("Timeout collecting log for %s/%s", pod, container)
+            except Exception as ex:  # pylint: disable=broad-exception-caught
+                logger.debug("Failed to collect log for %s/%s: %s", pod, container, ex)
+
+    logger.info("Collected %d/%d container logs in %s", collected, len(targets), namespace)
+    return collected
+
+
+# ---------------------------------------------------------------------------
+# WO-specific collection
+# ---------------------------------------------------------------------------
+
+def collect_wo_components(clients, bundle_dir, capabilities):
+    """Collect WO-specific resources: Symphony CRDs, cert-manager status, etc."""
+    wo_info = {}
+    custom = clients["custom_objects"]
+
+    # Symphony targets (if symphony is installed)
+    # NOTE(review): 'has_symphony' is detected from the 'solution.symphony' API group
+    # (see API_GROUP_SYMPHONY in consts / detect_cluster_capabilities), while targets
+    # are listed under 'fabric.symphony' — verify both groups always register together,
+    # otherwise this branch can be skipped on clusters that do expose targets.
+    if capabilities.get("has_symphony"):
+        result, _err = safe_api_call(
+            custom.list_namespaced_custom_object,
+            "fabric.symphony", "v1", WO_NAMESPACE, "targets",
+            description="list Symphony targets",
+        )
+        if result:
+            wo_info["symphony_targets"] = [
+                {
+                    "name": t.get("metadata", {}).get("name", "unknown"),
+                    "status": t.get("status", {}).get("provisioningStatus", {}).get("status", "unknown"),
+                }
+                for t in result.get("items", [])
+            ]
+
+    # cert-manager ClusterIssuers (if cert-manager is installed)
+    if capabilities.get("has_cert_manager"):
+        result, _err = safe_api_call(
+            custom.list_cluster_custom_object,
+            "cert-manager.io", "v1", "clusterissuers",
+            description="list ClusterIssuers",
+        )
+        if result:
+            wo_info["cluster_issuers"] = [
+                {
+                    "name": ci.get("metadata", {}).get("name", "unknown"),
+                    "ready": _cert_issuer_ready(ci),
+                }
+                for ci in result.get("items", [])
+            ]
+
+    # Gatekeeper constraints (if gatekeeper is installed)
+    if capabilities.get("has_gatekeeper"):
+        result, _err = safe_api_call(
+            custom.list_cluster_custom_object,
+            "templates.gatekeeper.sh", "v1", "constrainttemplates",
+            description="list Gatekeeper ConstraintTemplates",
+        )
+        if result:
+
wo_info["gatekeeper_templates"] = [ + {"name": t.get("metadata", {}).get("name", "unknown")} for t in result.get("items", []) + ] + + filepath = os.path.join(_get_cluster_resource_dir(bundle_dir), "wo-components.json") + write_json(filepath, wo_info) + return wo_info + + +def _cert_issuer_ready(issuer): + """Check if a cert-manager issuer is Ready.""" + conditions = issuer.get("status", {}).get("conditions", []) + for c in conditions: + if c.get("type") == "Ready": + return c.get("status") == "True" + return False + + +# --------------------------------------------------------------------------- +# Previous container logs (crash-looping pods) +# --------------------------------------------------------------------------- + +def collect_previous_logs(clients, bundle_dir, namespace, tail_lines=DEFAULT_TAIL_LINES): + """Collect previous container logs for pods that have restarted. + + Only collects previous logs for containers with restart_count > 0. + Returns count of previous logs collected. 
+ """ + core = clients["core_v1"] + result, _err = safe_api_call( + core.list_namespaced_pod, namespace, + description=f"list pods for previous logs in {namespace}", + ) + if not result: + return 0 + + ns_log_dir = create_namespace_log_dir(bundle_dir, namespace) + collected = 0 + + for pod in result.items: + for cs in (pod.status.container_statuses or []): + if cs.restart_count and cs.restart_count > 0: + log_result, _log_err = safe_api_call( + core.read_namespaced_pod_log, + pod.metadata.name, namespace, + container=cs.name, + tail_lines=tail_lines, + previous=True, + _preload_content=True, + description=f"previous logs {namespace}/{pod.metadata.name}/{cs.name}", + ) + if log_result: + filepath = os.path.join( + ns_log_dir, f"{pod.metadata.name}--{cs.name}--previous.log" + ) + try: + write_text(filepath, log_result) + collected += 1 + except OSError as ex: + logger.warning("Failed to write previous log %s: %s", filepath, ex) + + if collected: + logger.info("Collected %d previous container logs in %s", collected, namespace) + return collected + + +# --------------------------------------------------------------------------- +# Resource quotas and limit ranges +# --------------------------------------------------------------------------- + +def collect_resource_quotas(clients, bundle_dir, namespace): + """Collect ResourceQuotas and LimitRanges for a namespace.""" + core = clients["core_v1"] + quota_data = {} + + # ResourceQuotas + result, _err = safe_api_call( + core.list_namespaced_resource_quota, namespace, + description=f"list resource quotas in {namespace}", + ) + if result and result.items: + quota_data["resource_quotas"] = [ + { + "name": rq.metadata.name, + "hard": dict(rq.status.hard or {}) if rq.status else {}, + "used": dict(rq.status.used or {}) if rq.status else {}, + } + for rq in result.items + ] + + # LimitRanges + result, _err = safe_api_call( + core.list_namespaced_limit_range, namespace, + description=f"list limit ranges in {namespace}", + ) + if 
result and result.items: + quota_data["limit_ranges"] = [ + { + "name": lr.metadata.name, + "limits": [ + { + "type": lim.type, + "default": dict(lim.default or {}), + "default_request": dict(lim.default_request or {}), + "max": dict(getattr(lim, "max", None) or {}), + "min": dict(getattr(lim, "min", None) or {}), + } + for lim in (lr.spec.limits or []) + ], + } + for lr in result.items + ] + + if quota_data: + ns_res_dir = _get_ns_resource_dir(bundle_dir, namespace) + filepath = os.path.join(ns_res_dir, "quotas.json") + write_json(filepath, quota_data) + + return quota_data + + +# --------------------------------------------------------------------------- +# Metrics (kubectl top equivalent) +# --------------------------------------------------------------------------- + +def collect_metrics(clients, bundle_dir, capabilities): + """Collect node and pod metrics if metrics-server is available.""" + if not capabilities.get("has_metrics"): + logger.info("Metrics API not available, skipping metrics collection") + return {} + + custom = clients["custom_objects"] + metrics = {} + + # Node metrics + result, _err = safe_api_call( + custom.list_cluster_custom_object, + "metrics.k8s.io", "v1beta1", "nodes", + description="get node metrics", + ) + if result: + metrics["node_metrics"] = [ + { + "name": n.get("metadata", {}).get("name", "unknown"), + "cpu": n.get("usage", {}).get("cpu", "0"), + "memory": n.get("usage", {}).get("memory", "0"), + } + for n in result.get("items", []) + ] + + # Pod metrics (WO namespace) + result, _err = safe_api_call( + custom.list_namespaced_custom_object, + "metrics.k8s.io", "v1beta1", WO_NAMESPACE, "pods", + description="get WO pod metrics", + ) + if result: + metrics["wo_pod_metrics"] = [ + { + "name": p.get("metadata", {}).get("name", "unknown"), + "containers": [ + { + "name": c.get("name", "unknown"), + "cpu": c.get("usage", {}).get("cpu", "0"), + "memory": c.get("usage", {}).get("memory", "0"), + } + for c in p.get("containers", []) + ], + 
} + for p in result.get("items", []) + ] + + if metrics: + filepath = os.path.join(bundle_dir, FOLDER_CLUSTER_INFO, "metrics.json") + write_json(filepath, metrics) + logger.info("Collected metrics: %d nodes, %d WO pods", + len(metrics.get("node_metrics", [])), + len(metrics.get("wo_pod_metrics", []))) + + return metrics + + +# --------------------------------------------------------------------------- +# PersistentVolumeClaims per namespace +# --------------------------------------------------------------------------- + +def collect_pvcs(clients, bundle_dir, namespace): + """Collect PVC information for a namespace.""" + core = clients["core_v1"] + result, _err = safe_api_call( + core.list_namespaced_persistent_volume_claim, namespace, + description=f"list PVCs in {namespace}", + ) + if not result or not result.items: + return [] + + pvcs = [ + { + "name": pvc.metadata.name, + "status": pvc.status.phase, + "capacity": dict(pvc.status.capacity or {}) if pvc.status.capacity else {}, + "storage_class": pvc.spec.storage_class_name, + "access_modes": pvc.spec.access_modes, + "volume_name": pvc.spec.volume_name, + } + for pvc in result.items + ] + + ns_res_dir = _get_ns_resource_dir(bundle_dir, namespace) + filepath = os.path.join(ns_res_dir, "pvcs.json") + write_json(filepath, pvcs) + return pvcs + + +# --------------------------------------------------------------------------- +# Network configuration collection (iptables/proxy/connectivity) +# --------------------------------------------------------------------------- + +def collect_network_config(clients, bundle_dir): + """Collect network configuration for diagnosing connectivity issues. + + Collects: kube-proxy ConfigMap (contains iptables mode/rules config), + Services with external access (LoadBalancer/NodePort), and endpoint slices + for kube-system to verify service mesh health. + """ + core = clients["core_v1"] + net_info = {} + + # 1. 
kube-proxy ConfigMap — contains iptables mode, CIDR ranges, proxy rules + result, err = safe_api_call( + core.read_namespaced_config_map, "kube-proxy", "kube-system", + description="read kube-proxy ConfigMap", + ) + if result: + data = result.data or {} + net_info["kube_proxy_config"] = { + "data_keys": list(data.keys()), + } + # Parse the config.conf or kubeconfig if present + for key in ("config.conf", "kubeconfig.conf"): + if key in data: + net_info["kube_proxy_config"][key] = data[key] + else: + logger.debug("kube-proxy ConfigMap not found: %s", err) + + # 2. Services with external access (LoadBalancer, NodePort, ExternalName) + result, _err = safe_api_call( + core.list_service_for_all_namespaces, + description="list all services for network config", + ) + if result: + external_svcs = [] + for svc in result.items: + svc_type = svc.spec.type + if svc_type in ("LoadBalancer", "NodePort", "ExternalName"): + external_svcs.append({ + "name": svc.metadata.name, + "namespace": svc.metadata.namespace, + "type": svc_type, + "cluster_ip": svc.spec.cluster_ip, + "external_ips": getattr(svc.spec, 'external_i_ps', None) or getattr(svc.spec, 'external_ips', []), + "ports": [ + { + "port": p.port, + "target_port": str(p.target_port), + "node_port": p.node_port, + "protocol": p.protocol, + } + for p in (svc.spec.ports or []) + ], + "load_balancer_ip": ( + svc.status.load_balancer.ingress[0].ip + if svc.status and svc.status.load_balancer + and svc.status.load_balancer.ingress + else None + ), + }) + net_info["external_services"] = external_svcs + + # 3. 
Endpoint slices for kube-system (verify service discovery works) + try: + from kubernetes import client as _k8s_client + discovery_v1 = _k8s_client.DiscoveryV1Api() + result, _err = safe_api_call( + discovery_v1.list_namespaced_endpoint_slice, "kube-system", + description="list endpoint slices in kube-system", + ) + if result: + net_info["endpoint_slices"] = [ + { + "name": eps.metadata.name, + "address_type": eps.address_type, + "endpoints_count": len(eps.endpoints or []), + "ports": [ + {"port": p.port, "protocol": p.protocol, "name": p.name} + for p in (eps.ports or []) + ], + } + for eps in result.items + ] + except Exception as ex: # pylint: disable=broad-exception-caught + logger.debug("Discovery API not available: %s", ex) + + # 4. Cluster CIDR / pod CIDR from node specs + result, _err = safe_api_call( + core.list_node, description="list nodes for pod CIDRs", + ) + if result: + net_info["node_cidrs"] = [ + { + "name": node.metadata.name, + "pod_cidr": node.spec.pod_cidr, + "pod_cidrs": getattr(node.spec, 'pod_cid_rs', None) or getattr(node.spec, 'pod_cidrs', None), + } + for node in result.items + ] + + if net_info: + filepath = os.path.join(_get_cluster_resource_dir(bundle_dir), "network-config.json") + write_json(filepath, net_info) + logger.info("Collected network config: %d external services, %s", + len(net_info.get("external_services", [])), + "kube-proxy config found" if net_info.get("kube_proxy_config") else "no kube-proxy") + + return net_info diff --git a/src/workload-orchestration/azext_workload_orchestration/support/consts.py b/src/workload-orchestration/azext_workload_orchestration/support/consts.py new file mode 100644 index 00000000000..10ea75e11d3 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/support/consts.py @@ -0,0 +1,99 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +"""Constants for the workload-orchestration support bundle feature.""" + +# Bundle defaults +DEFAULT_TAIL_LINES = 1000 +DEFAULT_TIMEOUT_SECONDS = 600 # 10 minutes total +DEFAULT_API_TIMEOUT_SECONDS = 30 # per-API-call timeout +DEFAULT_LOG_TIMEOUT_SECONDS = 60 # per-container log fetch timeout +DEFAULT_MAX_LOG_SIZE_BYTES = 5 * 1024 * 1024 # 5 MB per container +DEFAULT_MAX_BUNDLE_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB total +BUNDLE_PREFIX = "wo-support-bundle" + +# Retry defaults +DEFAULT_MAX_RETRIES = 3 +DEFAULT_RETRY_BACKOFF_BASE = 1.0 # seconds; retries wait 1s, 2s, 4s + +# WO-relevant namespaces +WO_NAMESPACE = "workloadorchestration" +CERT_MANAGER_NAMESPACE = "cert-manager" +KUBE_SYSTEM_NAMESPACE = "kube-system" +DEFAULT_NAMESPACES = [KUBE_SYSTEM_NAMESPACE, WO_NAMESPACE, CERT_MANAGER_NAMESPACE] + +# Protected namespaces — deploying workloads here is not recommended +PROTECTED_NAMESPACES = [ + "kube-system", + "kube-public", + "kube-node-lease", + "azure-arc", + "azure-arc-release", + "azure-extensions", + "gatekeeper-system", + "azure-workload-identity-system", + "cert-manager", + "flux-system", +] + +# DNS +DNS_SERVICE_LABEL = "k8s-app=kube-dns" +DNS_INTERNAL_HOST = "kubernetes.default.svc.cluster.local" +DNS_EXTERNAL_HOST = "mcr.microsoft.com" + +# Test pod +TEST_POD_IMAGE = "busybox:1.36" +TEST_POD_TIMEOUT = 60 # seconds +TEST_POD_PREFIX = "wo-diag-" + +# API groups for capability detection +API_GROUP_GATEKEEPER_TEMPLATES = "templates.gatekeeper.sh" +API_GROUP_GATEKEEPER_CONSTRAINTS = "constraints.gatekeeper.sh" +API_GROUP_KYVERNO = "kyverno.io" +API_GROUP_CERT_MANAGER = "cert-manager.io" +API_GROUP_SYMPHONY = "solution.symphony" +API_GROUP_OPENSHIFT_SECURITY = "security.openshift.io" +API_GROUP_METRICS = "metrics.k8s.io" + +# cert-manager CRD detection 
+CERT_MANAGER_CRD_SUFFIX = ".cert-manager.io" +CERT_MANAGER_ISSUER_PLURAL = "clusterissuers" + +# StorageClass annotations (check both v1 and beta) +SC_DEFAULT_ANNOTATION_V1 = "storageclass.kubernetes.io/is-default-class" +SC_DEFAULT_ANNOTATION_BETA = "storageclass.beta.kubernetes.io/is-default-class" + +# PSA label prefix +PSA_LABEL_PREFIX = "pod-security.kubernetes.io/" + +# Check categories +CATEGORY_CLUSTER_INFO = "cluster-info" +CATEGORY_NODE_HEALTH = "node-health" +CATEGORY_DNS_HEALTH = "dns-health" +CATEGORY_STORAGE = "storage" +CATEGORY_REGISTRY_ACCESS = "registry-access" +CATEGORY_CERT_MANAGER = "cert-manager" +CATEGORY_WO_COMPONENTS = "wo-components" +CATEGORY_ADMISSION_CONTROLLERS = "admission-controllers" +CATEGORY_CONNECTIVITY = "connectivity" +CATEGORY_RBAC = "rbac" + +# Check result statuses +STATUS_PASS = "PASS" +STATUS_FAIL = "FAIL" +STATUS_WARN = "WARN" +STATUS_SKIP = "SKIP" +STATUS_ERROR = "ERROR" + +# Minimum resource requirements +MIN_CPU_CORES = 2 +MIN_MEMORY_GI = 4 +MIN_NODE_COUNT_PROD = 3 + +# Bundle folder structure +FOLDER_LOGS = "logs" +FOLDER_RESOURCES = "resources" +FOLDER_CHECKS = "checks" +FOLDER_CLUSTER_INFO = "cluster-info" diff --git a/src/workload-orchestration/azext_workload_orchestration/support/utils.py b/src/workload-orchestration/azext_workload_orchestration/support/utils.py new file mode 100644 index 00000000000..347d0605722 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/support/utils.py @@ -0,0 +1,419 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +# pylint: disable=import-outside-toplevel,too-many-return-statements,too-many-branches +# pylint: disable=raise-missing-from,too-many-locals,broad-exception-caught +# pylint: disable=too-many-arguments,too-many-positional-arguments + +"""Utility functions for the workload-orchestration support bundle feature.""" + +import json +import os +import shutil +from datetime import datetime, timezone + +from knack.log import get_logger + +from azext_workload_orchestration.support.consts import ( + BUNDLE_PREFIX, + FOLDER_LOGS, + FOLDER_RESOURCES, + FOLDER_CHECKS, + FOLDER_CLUSTER_INFO, +) + +logger = get_logger(__name__) + + +# --------------------------------------------------------------------------- +# Kubernetes client initialization +# --------------------------------------------------------------------------- + +def get_kubernetes_client(kube_config=None, kube_context=None): + """Initialize and return kubernetes API clients. + + Returns a dict with 'core_v1', 'apps_v1', 'custom_objects', 'storage_v1', + 'admissionregistration_v1', 'apis', 'version' clients, plus 'context_info' + with the active context name, cluster, and kubeconfig path. + """ + try: + from kubernetes import client, config + from kubernetes.config import list_kube_config_contexts + except ImportError: + raise CLIError( + "The 'kubernetes' package is required. 
" + "Install it with: pip install kubernetes>=24.2.0" + ) + + config_file = kube_config or os.path.expanduser("~/.kube/config") + + # Read context info before loading + context_info = {"context": "unknown", "cluster": "unknown", "kubeconfig": config_file} + try: + _contexts, active = list_kube_config_contexts(config_file=kube_config) + if active: + context_info["context"] = active.get("name", "unknown") + context_info["cluster"] = active.get("context", {}).get("cluster", "unknown") + context_info["user"] = active.get("context", {}).get("user", "unknown") + if kube_context: + context_info["context"] = kube_context + except (TypeError, KeyError, FileNotFoundError, OSError): + pass + + try: + config.load_kube_config( + config_file=kube_config, + context=kube_context, + ) + except config.ConfigException as ex: + raise CLIError( + f"Failed to load kubeconfig: {ex}. " + "Make sure you have a valid kubeconfig file at " + f"'{config_file}'. Run 'az aks get-credentials' or " + "'export KUBECONFIG=/path/to/config'." + ) + except Exception as ex: + raise CLIError( + f"Failed to load kubeconfig: {ex}. " + "Make sure you have a valid kubeconfig and cluster context. " + "Run 'az aks get-credentials' or set KUBECONFIG." + ) + + return { + "core_v1": client.CoreV1Api(), + "apps_v1": client.AppsV1Api(), + "custom_objects": client.CustomObjectsApi(), + "storage_v1": client.StorageV1Api(), + "admissionregistration_v1": client.AdmissionregistrationV1Api(), + "apis": client.ApisApi(), + "version": client.VersionApi(), + "context_info": context_info, + } + + +# --------------------------------------------------------------------------- +# Bundle directory management +# --------------------------------------------------------------------------- + +def create_bundle_directory(output_dir=None, bundle_name=None): + """Create the bundle directory structure and return its path. + + Args: + output_dir: Optional directory to create the bundle in. 
+ bundle_name: Optional custom name for the bundle. Defaults to + wo-support-bundle-YYYYMMDD-HHMMSS. + + Returns (bundle_dir, bundle_name) tuple. + """ + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S") + if bundle_name: + # Sanitize: replace spaces/special chars, append timestamp for uniqueness + import re + safe_name = re.sub(r'[^\w\-.]', '-', bundle_name).strip('-') + bundle_name = f"{safe_name}-{timestamp}" + else: + bundle_name = f"{BUNDLE_PREFIX}-{timestamp}" + + if output_dir: + base = os.path.abspath(output_dir) + os.makedirs(base, exist_ok=True) + else: + base = os.getcwd() + + bundle_dir = os.path.join(base, bundle_name) + os.makedirs(bundle_dir, exist_ok=True) + + # Create sub-folders + for folder in (FOLDER_LOGS, FOLDER_RESOURCES, FOLDER_CHECKS, FOLDER_CLUSTER_INFO): + os.makedirs(os.path.join(bundle_dir, folder), exist_ok=True) + + # Create per-namespace log directories + # (populated later when we know which namespaces to collect) + return bundle_dir, bundle_name + + +def create_namespace_log_dir(bundle_dir, namespace): + """Create a log subdirectory for a namespace.""" + ns_dir = os.path.join(bundle_dir, FOLDER_LOGS, namespace) + os.makedirs(ns_dir, exist_ok=True) + return ns_dir + + +def create_zip_bundle(bundle_dir, bundle_name, output_dir=None): + """Zip the bundle directory and remove the raw folder. + + Returns the path to the zip file. If zip creation fails, keeps the raw + directory so data is not lost. + """ + if output_dir: + zip_base = os.path.join(os.path.abspath(output_dir), bundle_name) + else: + zip_base = os.path.join(os.path.dirname(bundle_dir), bundle_name) + + try: + zip_path = shutil.make_archive(zip_base, "zip", os.path.dirname(bundle_dir), bundle_name) + except (IOError, OSError, PermissionError) as ex: + logger.warning("Failed to create zip: %s. Raw bundle preserved at: %s", ex, bundle_dir) + raise CLIError( + f"Failed to create zip bundle: {ex}. 
" + f"Raw bundle data preserved at: {bundle_dir}" + ) + + # Only clean up raw directory after successful zip + shutil.rmtree(bundle_dir, ignore_errors=True) + + return zip_path + + +# --------------------------------------------------------------------------- +# Safe API call wrapper +# --------------------------------------------------------------------------- + +def safe_api_call(func, *args, description="API call", max_retries=None, + timeout_seconds=None, **kwargs): + """Execute a kubernetes API call with error handling, timeout, and retry. + + Returns (result, error_string). On success error_string is None. + On failure result is None and error_string describes the problem. + + Args: + func: The kubernetes API method to call. + description: Human-readable description for logging. + max_retries: Number of retries on transient errors (default from consts). + timeout_seconds: Per-call timeout in seconds (default from consts). + **kwargs: Additional keyword arguments passed to the API call. + """ + import time as _time + + try: + from kubernetes.client.exceptions import ApiException + except ImportError: + return None, "kubernetes package not available" + + from azext_workload_orchestration.support.consts import ( + DEFAULT_MAX_RETRIES, + DEFAULT_RETRY_BACKOFF_BASE, + DEFAULT_API_TIMEOUT_SECONDS, + ) + + retries = max_retries if max_retries is not None else DEFAULT_MAX_RETRIES + timeout = timeout_seconds if timeout_seconds is not None else DEFAULT_API_TIMEOUT_SECONDS + + # Inject timeout into the API call if not already set + if "_request_timeout" not in kwargs: + kwargs["_request_timeout"] = timeout + + _NON_RETRYABLE = {400, 401, 403, 404, 405, 409, 422} + + last_err = None + for attempt in range(retries + 1): + try: + result = func(*args, **kwargs) + return result, None + except ApiException as ex: + if ex.status == 403: + msg = ( + f"Permission denied for {description} (403 Forbidden). " + "The service account may lack the required RBAC role. 
" + "Ensure the user has at least 'view' ClusterRole binding." + ) + logger.warning(msg) + return None, msg + if ex.status == 401: + msg = ( + f"Authentication failed for {description} (401 Unauthorized). " + "Cluster credentials may be expired. " + "Run 'az aks get-credentials' to refresh." + ) + logger.warning(msg) + return None, msg + if ex.status == 404: + msg = f"Resource not found for {description} (404)" + logger.debug(msg) + return None, msg + if ex.status in _NON_RETRYABLE: + msg = f"{description} failed: {ex.status} {ex.reason}" + logger.warning(msg) + return None, msg + # Retryable error (429, 500, 502, 503, 504, etc.) + last_err = f"{description} failed: {ex.status} {ex.reason}" + if attempt < retries: + wait = DEFAULT_RETRY_BACKOFF_BASE * (2 ** attempt) + logger.debug("Retrying %s in %.1fs (attempt %d/%d): %s", + description, wait, attempt + 1, retries, last_err) + _time.sleep(wait) + else: + logger.warning("%s (exhausted %d retries)", last_err, retries) + except (ConnectionError, TimeoutError, OSError) as ex: + last_err = f"{description} failed: {type(ex).__name__}: {ex}" + if attempt < retries: + wait = DEFAULT_RETRY_BACKOFF_BASE * (2 ** attempt) + logger.debug("Retrying %s in %.1fs (attempt %d/%d): %s", + description, wait, attempt + 1, retries, last_err) + _time.sleep(wait) + else: + logger.warning("%s (exhausted %d retries)", last_err, retries) + except Exception as ex: + msg = f"{description} failed: {type(ex).__name__}: {ex}" + logger.warning(msg) + return None, msg + + return None, last_err + + +# --------------------------------------------------------------------------- +# File writers +# --------------------------------------------------------------------------- + +def write_json(filepath, data): + """Write data as formatted JSON. 
Returns True on success.""" + try: + with open(filepath, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, default=str) + return True + except (IOError, OSError, PermissionError, TypeError) as ex: + logger.warning("Failed to write %s: %s", filepath, ex) + return False + + +def write_text(filepath, text): + """Write plain text to file. Returns True on success.""" + try: + with open(filepath, "w", encoding="utf-8") as f: + f.write(text if text else "") + return True + except (IOError, OSError, PermissionError) as ex: + logger.warning("Failed to write %s: %s", filepath, ex) + return False + + +def write_check_result(bundle_dir, category, check_name, status, message, details=None): + """Write a single prerequisite check result to the checks folder. + + Returns a dict representing the check result. + """ + result = { + "category": category, + "check_name": check_name, + "status": status, + "message": message, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + if details: + result["details"] = details + + filepath = os.path.join(bundle_dir, FOLDER_CHECKS, f"{category}--{check_name}.json") + write_json(filepath, result) + return result + + +# --------------------------------------------------------------------------- +# Resource parsing helpers +# --------------------------------------------------------------------------- + +def parse_cpu(cpu_str): + """Parse Kubernetes CPU string to float cores. + + Examples: '3860m' -> 3.86, '4' -> 4.0, '500m' -> 0.5 + """ + if not cpu_str: + return 0.0 + cpu_str = str(cpu_str).strip() + if cpu_str.endswith("m"): + return float(cpu_str[:-1]) / 1000.0 + return float(cpu_str) + + +def parse_memory_gi(mem_str): + """Parse Kubernetes memory string to GiB. 
+ + Examples: '27601704Ki' -> ~26.3, '4Gi' -> 4.0, '4096Mi' -> 4.0 + """ + if not mem_str: + return 0.0 + mem_str = str(mem_str).strip() + if mem_str.endswith("Ki"): + return float(mem_str[:-2]) / (1024 * 1024) + if mem_str.endswith("Mi"): + return float(mem_str[:-2]) / 1024 + if mem_str.endswith("Gi"): + return float(mem_str[:-2]) + if mem_str.endswith("Ti"): + return float(mem_str[:-2]) * 1024 + # Plain bytes + try: + return float(mem_str) / (1024 ** 3) + except ValueError: + return 0.0 + + +def format_bytes(size_bytes): + """Format byte count to human-readable string.""" + if size_bytes < 1024: + return f"{size_bytes} B" + if size_bytes < 1024 * 1024: + return f"{size_bytes / 1024:.1f} KB" + if size_bytes < 1024 * 1024 * 1024: + return f"{size_bytes / (1024 * 1024):.1f} MB" + return f"{size_bytes / (1024 ** 3):.1f} GB" + + +def check_disk_space(path, estimated_bytes): + """Check if there is enough disk space. Returns (ok, free_bytes).""" + usage = shutil.disk_usage(path) + needed = estimated_bytes * 2 # raw + zip + return usage.free >= needed, usage.free + + +# --------------------------------------------------------------------------- +# Detect cluster capabilities +# --------------------------------------------------------------------------- + +def detect_cluster_capabilities(clients): + """Detect which optional components are installed on the cluster. + + Returns a dict of capability booleans. 
+ """ + apis_client = clients["apis"] + result, err = safe_api_call(apis_client.get_api_versions, description="get API groups") + if err: + logger.warning("Could not detect cluster capabilities: %s", err) + return { + "has_gatekeeper": False, "has_kyverno": False, + "has_cert_manager": False, "has_symphony": False, + "has_openshift": False, "has_metrics": False, + } + + group_names = {g.name for g in (result.groups or [])} + + from azext_workload_orchestration.support.consts import ( + API_GROUP_GATEKEEPER_TEMPLATES, + API_GROUP_KYVERNO, + API_GROUP_CERT_MANAGER, + API_GROUP_SYMPHONY, + API_GROUP_OPENSHIFT_SECURITY, + API_GROUP_METRICS, + ) + + return { + "has_gatekeeper": API_GROUP_GATEKEEPER_TEMPLATES in group_names, + "has_kyverno": API_GROUP_KYVERNO in group_names, + "has_cert_manager": API_GROUP_CERT_MANAGER in group_names, + "has_symphony": API_GROUP_SYMPHONY in group_names, + "has_openshift": API_GROUP_OPENSHIFT_SECURITY in group_names, + "has_metrics": API_GROUP_METRICS in group_names, + } + + +# --------------------------------------------------------------------------- +# CLI error helper +# --------------------------------------------------------------------------- + +try: + from azure.cli.core.azclierror import CLIError +except ImportError: + # Fallback for testing outside azure-cli + class CLIError(Exception): + pass diff --git a/src/workload-orchestration/azext_workload_orchestration/support/validators.py b/src/workload-orchestration/azext_workload_orchestration/support/validators.py new file mode 100644 index 00000000000..38f2f79c3ae --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/support/validators.py @@ -0,0 +1,849 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +# pylint: disable=unused-argument,import-outside-toplevel,too-many-locals + +"""Prerequisite validators for the workload-orchestration support bundle feature. + +Each check function has the same signature (clients, bundle_dir, cluster_info, capabilities) +for consistency. Not all checks use all arguments. +""" + +from knack.log import get_logger + +from azext_workload_orchestration.support.consts import ( + CATEGORY_CLUSTER_INFO, + CATEGORY_NODE_HEALTH, + CATEGORY_DNS_HEALTH, + CATEGORY_STORAGE, + CATEGORY_REGISTRY_ACCESS, + CATEGORY_CERT_MANAGER, + CATEGORY_WO_COMPONENTS, + CATEGORY_ADMISSION_CONTROLLERS, + CATEGORY_CONNECTIVITY, + MIN_CPU_CORES, + MIN_MEMORY_GI, + MIN_NODE_COUNT_PROD, + DNS_SERVICE_LABEL, + WO_NAMESPACE, + CERT_MANAGER_NAMESPACE, + PROTECTED_NAMESPACES, + STATUS_PASS, + STATUS_FAIL, + STATUS_WARN, + STATUS_SKIP, + STATUS_ERROR, + PSA_LABEL_PREFIX, +) +from azext_workload_orchestration.support.utils import ( + safe_api_call, + write_check_result, + parse_cpu, + parse_memory_gi, +) + +logger = get_logger(__name__) + + +def run_all_checks(clients, bundle_dir, cluster_info, capabilities): + """Run all prerequisite validation checks. + + Returns a list of check result dicts. 
+ """ + results = [] + + checks = [ + (_check_k8s_version, "Kubernetes version compatibility"), + (_check_node_readiness, "Node readiness"), + (_check_node_capacity, "Node capacity (CPU/memory)"), + (_check_cluster_resources, "Cluster-wide resource availability"), + (_check_dns_health, "CoreDNS health"), + (_check_dns_resolution, "DNS resolution"), + (_check_default_storage_class, "Default StorageClass"), + (_check_csi_drivers, "CSI drivers"), + (_check_cert_manager, "cert-manager installation"), + (_check_arc_dependencies, "Azure Arc dependencies"), + (_check_wo_namespace, "WO namespace exists"), + (_check_protected_namespace, "Protected namespace check"), + (_check_wo_pods, "WO pods running"), + (_check_wo_services_deployments, "WO services and deployments"), + (_check_wo_webhooks, "WO webhook health"), + (_check_admission_controllers, "Admission controller detection"), + (_check_psa_labels, "Pod Security Admission labels"), + (_check_resource_quotas, "Resource quotas on WO namespace"), + (_check_image_pull_secrets, "Image pull secrets"), + (_check_proxy_settings, "Proxy configuration"), + ] + + for check_fn, description in checks: + try: + result = check_fn(clients, bundle_dir, cluster_info, capabilities) + results.append(result) + status_icon = { + STATUS_PASS: "✓", STATUS_FAIL: "✗", STATUS_WARN: "⚠", + STATUS_SKIP: "—", STATUS_ERROR: "!" + }.get(result["status"], "?") + logger.info(" %s %s: %s", status_icon, description, result["message"]) + except Exception as ex: # pylint: disable=broad-exception-caught + err_result = write_check_result( + bundle_dir, "error", description.replace(" ", "-").lower(), + STATUS_ERROR, f"Check crashed: {ex}" + ) + results.append(err_result) + logger.warning(" ! 
%s: crashed (%s)", description, ex) + + return results + + +# --------------------------------------------------------------------------- +# Individual checks +# --------------------------------------------------------------------------- + +def _check_k8s_version(clients, bundle_dir, cluster_info, capabilities): + """Check Kubernetes version is in supported range.""" + version_info = cluster_info.get("server_version", {}) + git_version = version_info.get("git_version", "unknown") + + try: + major = int(version_info.get("major", "0").rstrip("+")) + minor = int(version_info.get("minor", "0").rstrip("+")) + except ValueError: + return write_check_result( + bundle_dir, CATEGORY_CLUSTER_INFO, "k8s-version", + STATUS_WARN, f"Could not parse version: {git_version}" + ) + + # WO supports K8s 1.24+ + if major == 1 and minor >= 24: + return write_check_result( + bundle_dir, CATEGORY_CLUSTER_INFO, "k8s-version", + STATUS_PASS, f"Kubernetes {git_version} is supported (>=1.24)" + ) + + return write_check_result( + bundle_dir, CATEGORY_CLUSTER_INFO, "k8s-version", + STATUS_FAIL, f"Kubernetes {git_version} may not be supported (require >=1.24)", + details={"major": major, "minor": minor} + ) + + +def _check_node_readiness(clients, bundle_dir, cluster_info, capabilities): + """Check all nodes are Ready with no pressure conditions.""" + nodes = cluster_info.get("nodes") or [] + if not nodes: + return write_check_result( + bundle_dir, CATEGORY_NODE_HEALTH, "node-readiness", + STATUS_FAIL, "No nodes found in cluster" + ) + + not_ready = [n["name"] for n in nodes if n.get("ready") != "True"] + pressure_nodes = [] + for n in nodes: + conditions = n.get("conditions", {}) + pressures = [ + ctype for ctype in ("DiskPressure", "MemoryPressure", "PIDPressure") + if conditions.get(ctype) == "True" + ] + if pressures: + pressure_nodes.append({"node": n["name"], "pressures": pressures}) + + if not_ready: + return write_check_result( + bundle_dir, CATEGORY_NODE_HEALTH, "node-readiness", + 
STATUS_FAIL, f"{len(not_ready)} node(s) not Ready: {', '.join(not_ready)}", + details={"not_ready": not_ready, "pressure_nodes": pressure_nodes} + ) + + if pressure_nodes: + return write_check_result( + bundle_dir, CATEGORY_NODE_HEALTH, "node-readiness", + STATUS_WARN, f"{len(pressure_nodes)} node(s) have pressure conditions", + details={"pressure_nodes": pressure_nodes} + ) + + node_count = len(nodes) + msg = f"All {node_count} node(s) Ready, no pressure conditions" + if node_count < MIN_NODE_COUNT_PROD: + msg += f" (note: {node_count} nodes, recommend {MIN_NODE_COUNT_PROD}+ for production)" + + return write_check_result( + bundle_dir, CATEGORY_NODE_HEALTH, "node-readiness", + STATUS_PASS, msg + ) + + +def _check_node_capacity(clients, bundle_dir, cluster_info, capabilities): + """Check nodes have minimum CPU and memory.""" + nodes = cluster_info.get("nodes") or [] + if not nodes: + return write_check_result( + bundle_dir, CATEGORY_NODE_HEALTH, "node-capacity", + STATUS_SKIP, "No nodes to check" + ) + + low_cpu = [] + low_mem = [] + for n in nodes: + cpu = parse_cpu(n.get("allocatable_cpu", "0")) + mem = parse_memory_gi(n.get("allocatable_memory", "0")) + if cpu < MIN_CPU_CORES: + low_cpu.append(f"{n['name']} ({cpu:.1f} cores)") + if mem < MIN_MEMORY_GI: + low_mem.append(f"{n['name']} ({mem:.1f} Gi)") + + issues = [] + if low_cpu: + issues.append(f"Low CPU: {', '.join(low_cpu)} (min {MIN_CPU_CORES} cores)") + if low_mem: + issues.append(f"Low memory: {', '.join(low_mem)} (min {MIN_MEMORY_GI} Gi)") + + if issues: + return write_check_result( + bundle_dir, CATEGORY_NODE_HEALTH, "node-capacity", + STATUS_WARN, "; ".join(issues), + details={"low_cpu": low_cpu, "low_mem": low_mem} + ) + + return write_check_result( + bundle_dir, CATEGORY_NODE_HEALTH, "node-capacity", + STATUS_PASS, f"All {len(nodes)} nodes meet minimum requirements (CPU>={MIN_CPU_CORES}, Mem>={MIN_MEMORY_GI}Gi)" + ) + + +def _check_dns_health(clients, bundle_dir, cluster_info, capabilities): + """Check 
CoreDNS pods are running and DNS service exists.""" + core = clients["core_v1"] + + # Find DNS pods by label (works across most distros) + result, err = safe_api_call( + core.list_namespaced_pod, "kube-system", + label_selector=DNS_SERVICE_LABEL, + description="list DNS pods", + ) + + if err: + return write_check_result( + bundle_dir, CATEGORY_DNS_HEALTH, "dns-pods", + STATUS_WARN, f"Could not check DNS pods: {err}" + ) + + dns_pods = result.items if result else [] + + if not dns_pods: + # Fallback: try searching by name pattern (OpenShift, RKE2, etc.) + result, err = safe_api_call( + core.list_namespaced_pod, "kube-system", + description="list all kube-system pods for DNS fallback", + ) + if result: + dns_pods = [ + p for p in result.items + if "dns" in p.metadata.name.lower() + or "coredns" in p.metadata.name.lower() + ] + + if not dns_pods: + return write_check_result( + bundle_dir, CATEGORY_DNS_HEALTH, "dns-pods", + STATUS_FAIL, "No DNS pods found in kube-system (checked label k8s-app=kube-dns and name pattern)" + ) + + running = [p for p in dns_pods if p.status.phase == "Running"] + if len(running) < len(dns_pods): + not_running = [p.metadata.name for p in dns_pods if p.status.phase != "Running"] + return write_check_result( + bundle_dir, CATEGORY_DNS_HEALTH, "dns-pods", + STATUS_WARN, f"{len(running)}/{len(dns_pods)} DNS pods Running (not running: {', '.join(not_running)})" + ) + + return write_check_result( + bundle_dir, CATEGORY_DNS_HEALTH, "dns-pods", + STATUS_PASS, f"{len(running)} DNS pod(s) Running" + ) + + +def _check_default_storage_class(clients, bundle_dir, cluster_info, capabilities): + """Check a default StorageClass exists.""" + storage = clients["storage_v1"] + result, err = safe_api_call(storage.list_storage_class, description="list storage classes") + if err: + return write_check_result( + bundle_dir, CATEGORY_STORAGE, "default-storage-class", + STATUS_WARN, f"Could not list StorageClasses: {err}" + ) + + from 
azext_workload_orchestration.support.consts import SC_DEFAULT_ANNOTATION_V1, SC_DEFAULT_ANNOTATION_BETA + + scs = result.items if result else [] + defaults = [] + for sc in scs: + ann = sc.metadata.annotations or {} + if ann.get(SC_DEFAULT_ANNOTATION_V1) == "true" or ann.get(SC_DEFAULT_ANNOTATION_BETA) == "true": + defaults.append(sc.metadata.name) + + if not defaults: + return write_check_result( + bundle_dir, CATEGORY_STORAGE, "default-storage-class", + STATUS_WARN, f"No default StorageClass found ({len(scs)} classes exist)", + details={"storage_classes": [sc.metadata.name for sc in scs]} + ) + + return write_check_result( + bundle_dir, CATEGORY_STORAGE, "default-storage-class", + STATUS_PASS, f"Default StorageClass: {', '.join(defaults)}" + ) + + +def _check_cert_manager(clients, bundle_dir, cluster_info, capabilities): + """Check cert-manager is installed and healthy.""" + if not capabilities.get("has_cert_manager"): + return write_check_result( + bundle_dir, CATEGORY_CERT_MANAGER, "cert-manager-installed", + STATUS_FAIL, "cert-manager CRDs not found (cert-manager.io API group missing)" + ) + + core = clients["core_v1"] + # Check pods in cert-manager namespace + result, err = safe_api_call( + core.list_namespaced_pod, CERT_MANAGER_NAMESPACE, + description="list cert-manager pods", + ) + + if err or not result or not result.items: + return write_check_result( + bundle_dir, CATEGORY_CERT_MANAGER, "cert-manager-installed", + STATUS_WARN, "cert-manager CRDs exist but no pods found in cert-manager namespace" + ) + + pods = result.items + running = [p for p in pods if p.status.phase == "Running"] + + if len(running) < len(pods): + return write_check_result( + bundle_dir, CATEGORY_CERT_MANAGER, "cert-manager-installed", + STATUS_WARN, + f"cert-manager: {len(running)}/{len(pods)} pods Running", + details={"pods": [{"name": p.metadata.name, "phase": p.status.phase} for p in pods]} + ) + + return write_check_result( + bundle_dir, CATEGORY_CERT_MANAGER, 
"cert-manager-installed", + STATUS_PASS, f"cert-manager healthy: {len(running)} pod(s) Running" + ) + + +# --------------------------------------------------------------------------- +# Azure Arc dependency checks +# --------------------------------------------------------------------------- + +ARC_DEPENDENCY_NAMESPACES = ["azure-arc", "azure-extensions"] + + +def _check_arc_dependencies(clients, bundle_dir, cluster_info, capabilities): + """Check that Azure Arc prerequisite namespaces and components exist.""" + namespaces = cluster_info.get("namespaces") or [] + ns_names = {ns["name"] for ns in namespaces} + + missing = [ns for ns in ARC_DEPENDENCY_NAMESPACES if ns not in ns_names] + found = [ns for ns in ARC_DEPENDENCY_NAMESPACES if ns in ns_names] + + if missing and not found: + return write_check_result( + bundle_dir, CATEGORY_CONNECTIVITY, "arc-dependencies", + STATUS_FAIL, + f"Azure Arc namespaces missing: {', '.join(missing)}. " + "WO requires an Arc-enabled cluster. Run 'az connectedk8s connect' first.", + ) + + if missing: + return write_check_result( + bundle_dir, CATEGORY_CONNECTIVITY, "arc-dependencies", + STATUS_WARN, + f"Partial Arc setup: found {', '.join(found)}, " + f"missing {', '.join(missing)}", + details={"found": found, "missing": missing}, + ) + + # Check azure-arc namespace has healthy pods + core = clients["core_v1"] + result, err = safe_api_call( + core.list_namespaced_pod, "azure-arc", + description="list pods in azure-arc", + ) + if err: + return write_check_result( + bundle_dir, CATEGORY_CONNECTIVITY, "arc-dependencies", + STATUS_WARN, f"Arc namespaces exist but could not verify pods: {err}", + ) + + pods = result.items or [] + running = [p for p in pods if p.status.phase == "Running"] + not_running = [p for p in pods if p.status.phase != "Running"] + + if not_running: + names = [p.metadata.name for p in not_running[:5]] + return write_check_result( + bundle_dir, CATEGORY_CONNECTIVITY, "arc-dependencies", + STATUS_WARN, + f"Arc 
namespaces present, {len(running)} pod(s) Running, " + f"{len(not_running)} not Running: {', '.join(names)}", + details={"running": len(running), "not_running_pods": names}, + ) + + return write_check_result( + bundle_dir, CATEGORY_CONNECTIVITY, "arc-dependencies", + STATUS_PASS, + f"Azure Arc healthy: namespaces {', '.join(found)} present, " + f"{len(running)} pod(s) Running", + ) + + +def _check_wo_namespace(clients, bundle_dir, cluster_info, capabilities): + """Check the WO namespace exists.""" + namespaces = cluster_info.get("namespaces") or [] + wo_ns = [ns for ns in namespaces if ns["name"] == WO_NAMESPACE] + + if not wo_ns: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-namespace", + STATUS_FAIL, f"Namespace '{WO_NAMESPACE}' not found" + ) + + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-namespace", + STATUS_PASS, f"Namespace '{WO_NAMESPACE}' exists (status: {wo_ns[0]['status']})" + ) + + +def _check_wo_pods(clients, bundle_dir, cluster_info, capabilities): + """Check WO pods are running.""" + core = clients["core_v1"] + result, err = safe_api_call( + core.list_namespaced_pod, WO_NAMESPACE, + description="list WO pods", + ) + + if err: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-pods", + STATUS_WARN, f"Could not list WO pods: {err}" + ) + + pods = result.items if result else [] + if not pods: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-pods", + STATUS_FAIL, f"No pods found in {WO_NAMESPACE}" + ) + + running = [p for p in pods if p.status.phase == "Running"] + not_running = [ + {"name": p.metadata.name, "phase": p.status.phase} + for p in pods if p.status.phase != "Running" + ] + + if not_running: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-pods", + STATUS_WARN, + f"{len(running)}/{len(pods)} WO pods Running", + details={"not_running": not_running} + ) + + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-pods", + 
STATUS_PASS, f"All {len(running)} WO pods Running" + ) + + +def _check_wo_services_deployments(clients, bundle_dir, cluster_info, capabilities): + """Check WO services and deployments are healthy.""" + core = clients["core_v1"] + apps = clients["apps_v1"] + + issues = [] + + # Check deployments + result, err = safe_api_call( + apps.list_namespaced_deployment, WO_NAMESPACE, + description=f"list deployments in {WO_NAMESPACE}", + ) + if err: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-services-deployments", + STATUS_WARN, f"Could not check WO deployments: {err}" + ) + + deployments = result.items or [] + dep_details = [] + for d in deployments: + desired = d.spec.replicas or 0 + ready = d.status.ready_replicas or 0 + dep_details.append({ + "name": d.metadata.name, + "desired": desired, + "ready": ready, + }) + if ready < desired: + issues.append(f"Deployment {d.metadata.name}: {ready}/{desired} ready") + + # Check services + result, err = safe_api_call( + core.list_namespaced_service, WO_NAMESPACE, + description=f"list services in {WO_NAMESPACE}", + ) + svc_count = len(result.items) if result else 0 + + if issues: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-services-deployments", + STATUS_WARN, + f"{len(deployments)} deployment(s), {svc_count} service(s) — " + f"issues: {'; '.join(issues)}", + details={"deployments": dep_details, "services": svc_count, "issues": issues}, + ) + + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-services-deployments", + STATUS_PASS, + f"{len(deployments)} deployment(s) all healthy, {svc_count} service(s)", + details={"deployments": dep_details, "services": svc_count}, + ) + + +def _check_wo_webhooks(clients, bundle_dir, cluster_info, capabilities): + """Check Symphony validating/mutating webhooks are configured.""" + admission = clients["admissionregistration_v1"] + + result, err = safe_api_call( + admission.list_validating_webhook_configuration, + 
description="list validating webhooks for WO check", + ) + if err: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-webhooks", + STATUS_WARN, f"Could not list webhooks: {err}" + ) + + vwcs = result.items if result else [] + symphony_vwc = [w for w in vwcs if "symphony" in w.metadata.name.lower()] + + if not symphony_vwc: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-webhooks", + STATUS_WARN, "No Symphony validating webhook configurations found" + ) + + total_hooks = sum(len(w.webhooks or []) for w in symphony_vwc) + fail_hooks = sum( + 1 for w in symphony_vwc for wh in (w.webhooks or []) if wh.failure_policy == "Fail" + ) + + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "wo-webhooks", + STATUS_PASS, + f"Symphony webhooks configured: {total_hooks} hooks ({fail_hooks} with failurePolicy=Fail)", + details={"configs": [w.metadata.name for w in symphony_vwc]} + ) + + +def _check_admission_controllers(clients, bundle_dir, cluster_info, capabilities): + """Detect and report admission controllers.""" + detected = [] + if capabilities.get("has_gatekeeper"): + detected.append("Gatekeeper") + if capabilities.get("has_kyverno"): + detected.append("Kyverno") + if capabilities.get("has_openshift"): + detected.append("OpenShift SCC") + + if not detected: + return write_check_result( + bundle_dir, CATEGORY_ADMISSION_CONTROLLERS, "policy-engines", + STATUS_PASS, "No additional policy engines detected (Gatekeeper/Kyverno/OpenShift SCC)" + ) + + return write_check_result( + bundle_dir, CATEGORY_ADMISSION_CONTROLLERS, "policy-engines", + STATUS_PASS, f"Policy engines detected: {', '.join(detected)}", + details={"engines": detected} + ) + + +def _check_psa_labels(clients, bundle_dir, cluster_info, capabilities): + """Check PSA enforcement labels on WO-relevant namespaces.""" + namespaces = cluster_info.get("namespaces") or [] + enforced = [] + + for ns in namespaces: + if ns["name"] not in (WO_NAMESPACE, 
CERT_MANAGER_NAMESPACE, "default"): + continue + labels = ns.get("labels", {}) + enforce = labels.get(f"{PSA_LABEL_PREFIX}enforce") + if enforce: + enforced.append({"namespace": ns["name"], "level": enforce}) + + if not enforced: + return write_check_result( + bundle_dir, CATEGORY_ADMISSION_CONTROLLERS, "psa-labels", + STATUS_PASS, "No PSA enforce labels on WO-relevant namespaces" + ) + + restricted = [e for e in enforced if e["level"] == "restricted"] + if restricted: + return write_check_result( + bundle_dir, CATEGORY_ADMISSION_CONTROLLERS, "psa-labels", + STATUS_WARN, + f"PSA enforce=restricted on: {', '.join(e['namespace'] for e in restricted)} " + "(test pods may need explicit securityContext)", + details={"enforced": enforced} + ) + + return write_check_result( + bundle_dir, CATEGORY_ADMISSION_CONTROLLERS, "psa-labels", + STATUS_PASS, f"PSA labels found but not restricted: {enforced}", + details={"enforced": enforced} + ) + + +def _check_dns_resolution(clients, bundle_dir, cluster_info, capabilities): + """Check DNS resolution works for internal and external names (client-side).""" + import socket + + from azext_workload_orchestration.support.consts import DNS_EXTERNAL_HOST + + results_detail = {} + + # External DNS check (from the client machine running az cli) + try: + addr = socket.getaddrinfo(DNS_EXTERNAL_HOST, 443, socket.AF_INET) + results_detail["external_dns"] = { + "host": DNS_EXTERNAL_HOST, "resolved": True, + "addresses": list({a[4][0] for a in addr}), + } + except (socket.gaierror, socket.timeout, OSError) as ex: + results_detail["external_dns"] = { + "host": DNS_EXTERNAL_HOST, "resolved": False, "error": str(ex), + } + return write_check_result( + bundle_dir, CATEGORY_DNS_HEALTH, "dns-resolution", + STATUS_WARN, + f"Cannot resolve {DNS_EXTERNAL_HOST} from client (may be expected in air-gapped environments)", + details=results_detail, + ) + + return write_check_result( + bundle_dir, CATEGORY_DNS_HEALTH, "dns-resolution", + STATUS_PASS, + f"DNS 
resolution OK: {DNS_EXTERNAL_HOST} resolves from client", + details=results_detail, + ) + + +def _check_resource_quotas(clients, bundle_dir, cluster_info, capabilities): + """Check if resource quotas exist on the WO namespace that could limit pods.""" + core = clients["core_v1"] + + result, err = safe_api_call( + core.list_namespaced_resource_quota, WO_NAMESPACE, + description="list resource quotas on WO namespace", + ) + + if err: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "resource-quotas", + STATUS_SKIP, f"Could not check resource quotas: {err}" + ) + + quotas = result.items if result else [] + if not quotas: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "resource-quotas", + STATUS_PASS, f"No resource quotas on {WO_NAMESPACE} namespace" + ) + + # Check if any quota is near its limit + warnings = [] + for rq in quotas: + hard = rq.status.hard or {} + used = rq.status.used or {} + for resource, limit_str in hard.items(): + used_str = used.get(resource, "0") + try: + limit_val = float(limit_str) + used_val = float(used_str) + if limit_val > 0 and used_val / limit_val > 0.8: + warnings.append(f"{resource}: {used_str}/{limit_str} ({used_val / limit_val * 100:.0f}%)") + except (ValueError, ZeroDivisionError): + pass + + if warnings: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "resource-quotas", + STATUS_WARN, + f"Resource quotas >80% utilized on {WO_NAMESPACE}: {'; '.join(warnings)}", + details={"quotas": [rq.metadata.name for rq in quotas], "warnings": warnings} + ) + + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "resource-quotas", + STATUS_PASS, + f"{len(quotas)} resource quota(s) on {WO_NAMESPACE}, all within limits" + ) + + +def _check_cluster_resources(clients, bundle_dir, cluster_info, capabilities): + """Check cluster-wide aggregate CPU and memory against minimums.""" + nodes = cluster_info.get("nodes") or [] + if not nodes: + return write_check_result( + bundle_dir, 
CATEGORY_NODE_HEALTH, "cluster-resources", + STATUS_SKIP, "No nodes to check" + ) + + total_cpu = 0.0 + total_mem = 0.0 + for n in nodes: + total_cpu += parse_cpu(n.get("allocatable_cpu", "0")) + total_mem += parse_memory_gi(n.get("allocatable_memory", "0")) + + issues = [] + if total_cpu < MIN_CPU_CORES: + issues.append(f"Total CPU {total_cpu:.1f} cores < {MIN_CPU_CORES} minimum") + if total_mem < MIN_MEMORY_GI: + issues.append(f"Total memory {total_mem:.1f}Gi < {MIN_MEMORY_GI}Gi minimum") + + if issues: + return write_check_result( + bundle_dir, CATEGORY_NODE_HEALTH, "cluster-resources", + STATUS_WARN, "; ".join(issues), + details={"total_cpu": round(total_cpu, 2), "total_memory_gi": round(total_mem, 2)} + ) + + return write_check_result( + bundle_dir, CATEGORY_NODE_HEALTH, "cluster-resources", + STATUS_PASS, + f"Cluster total: {total_cpu:.1f} CPU cores, {total_mem:.1f}Gi memory " + f"across {len(nodes)} node(s)", + details={"total_cpu": round(total_cpu, 2), "total_memory_gi": round(total_mem, 2)} + ) + + +def _check_protected_namespace(clients, bundle_dir, cluster_info, capabilities): + """Check that the WO namespace is not a protected system namespace.""" + if WO_NAMESPACE in PROTECTED_NAMESPACES or \ + WO_NAMESPACE.startswith("kube-") or \ + WO_NAMESPACE.startswith("azure-"): + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "protected-namespace", + STATUS_FAIL, + f"WO namespace '{WO_NAMESPACE}' is a protected/system namespace", + details={"protected_namespaces": PROTECTED_NAMESPACES} + ) + + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "protected-namespace", + STATUS_PASS, + f"WO namespace '{WO_NAMESPACE}' is not a protected system namespace" + ) + + +def _check_csi_drivers(clients, bundle_dir, cluster_info, capabilities): + """Check for installed CSI drivers.""" + storage = clients["storage_v1"] + result, err = safe_api_call(storage.list_csi_driver, description="list CSI drivers") + + if err: + return write_check_result( 
+ bundle_dir, CATEGORY_STORAGE, "csi-drivers", + STATUS_SKIP, f"Could not list CSI drivers: {err}" + ) + + drivers = result.items if result else [] + if not drivers: + return write_check_result( + bundle_dir, CATEGORY_STORAGE, "csi-drivers", + STATUS_WARN, "No CSI drivers found in cluster" + ) + + driver_names = [d.metadata.name for d in drivers] + return write_check_result( + bundle_dir, CATEGORY_STORAGE, "csi-drivers", + STATUS_PASS, f"{len(drivers)} CSI driver(s): {', '.join(driver_names)}", + details={"drivers": driver_names} + ) + + +def _check_image_pull_secrets(clients, bundle_dir, cluster_info, capabilities): + """Check for image pull secrets across relevant namespaces.""" + core = clients["core_v1"] + pull_secrets = {} + + for ns in [WO_NAMESPACE, CERT_MANAGER_NAMESPACE]: + result, _err = safe_api_call( + core.list_namespaced_secret, ns, + field_selector="type=kubernetes.io/dockerconfigjson", + description=f"list pull secrets in {ns}", + ) + if result and result.items: + pull_secrets[ns] = [s.metadata.name for s in result.items] + + if pull_secrets: + parts = [f"{ns}: {', '.join(names)}" for ns, names in pull_secrets.items()] + return write_check_result( + bundle_dir, CATEGORY_REGISTRY_ACCESS, "image-pull-secrets", + STATUS_PASS, f"Image pull secrets found: {'; '.join(parts)}", + details={"secrets": pull_secrets} + ) + + return write_check_result( + bundle_dir, CATEGORY_REGISTRY_ACCESS, "image-pull-secrets", + STATUS_PASS, + "No image pull secrets in WO namespaces (using default service account credentials)" + ) + + +def _check_proxy_settings(clients, bundle_dir, cluster_info, capabilities): + """Check for HTTP proxy configuration in WO pods.""" + core = clients["core_v1"] + result, err = safe_api_call( + core.list_namespaced_pod, WO_NAMESPACE, + description="list WO pods for proxy check", + ) + + if err or not result or not result.items: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "proxy-settings", + STATUS_SKIP, f"Could not check 
proxy settings: {err or 'no pods found'}" + ) + + proxy_vars = ("HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", + "http_proxy", "https_proxy", "no_proxy") + pods_with_proxy = [] + + for pod in result.items: + for container in (pod.spec.containers or []): + for env in (container.env or []): + if env.name in proxy_vars: + pods_with_proxy.append({ + "pod": pod.metadata.name, + "container": container.name, + "var": env.name, + "value": env.value or "(from ref)", + }) + + if pods_with_proxy: + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "proxy-settings", + STATUS_WARN, + f"Proxy env vars found in {len(pods_with_proxy)} container(s) — " + "verify proxy allows access to mcr.microsoft.com", + details={"proxy_configs": pods_with_proxy} + ) + + return write_check_result( + bundle_dir, CATEGORY_WO_COMPONENTS, "proxy-settings", + STATUS_PASS, "No proxy environment variables in WO pods" + ) diff --git a/src/workload-orchestration/azext_workload_orchestration/tests/conftest.py b/src/workload-orchestration/azext_workload_orchestration/tests/conftest.py new file mode 100644 index 00000000000..71d8f39fc35 --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/tests/conftest.py @@ -0,0 +1,17 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +"""Pytest conftest - ensure extension root is on sys.path. + +Mock module setup is handled by the root-level conftest.py at +src/workload-orchestration/conftest.py which runs before this file. 
+""" + +import os +import sys + +# Ensure the extension package root is on sys.path +_ext_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _ext_root not in sys.path: + sys.path.insert(0, _ext_root) diff --git a/src/workload-orchestration/azext_workload_orchestration/tests/test_support_bundle.py b/src/workload-orchestration/azext_workload_orchestration/tests/test_support_bundle.py new file mode 100644 index 00000000000..b7b290be2bb --- /dev/null +++ b/src/workload-orchestration/azext_workload_orchestration/tests/test_support_bundle.py @@ -0,0 +1,2218 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +"""Unit tests for the support bundle feature.""" + +import json +import os +import shutil +import sys +import tempfile +import types +import unittest +from unittest.mock import MagicMock, patch, PropertyMock + +# Ensure the extension package is importable regardless of how the test is invoked +_ext_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _ext_root not in sys.path: + sys.path.insert(0, _ext_root) + +# Mock azure CLI modules +_azure = types.ModuleType("azure") +_azure_cli = types.ModuleType("azure.cli") +_azure_cli_core = types.ModuleType("azure.cli.core") +_azure_cli_core.AzCommandsLoader = type("AzCommandsLoader", (), {}) +_azure_cli_commands = types.ModuleType("azure.cli.core.commands") +_azure_cli_commands.CliCommandType = type("CliCommandType", (), {"__init__": lambda self, **kw: None}) +_azure_cli_aaz = types.ModuleType("azure.cli.core.aaz") +_azure_cli_aaz.load_aaz_command_table = lambda **kw: None +_azure_cli_params = types.ModuleType("azure.cli.core.commands.parameters") 
_azure_cli_params.get_enum_type = lambda x: x
_azure_cli_azclierror = types.ModuleType("azure.cli.core.azclierror")
_azure_cli_azclierror.CLIError = Exception
_knack = types.ModuleType("knack")
_knack_log = types.ModuleType("knack.log")
import logging  # noqa: E402
_knack_log.get_logger = logging.getLogger
_knack_help = types.ModuleType("knack.help_files")
_knack_help.helps = {}

# Register every stub under its dotted name so that imports of azure.cli.*
# and knack.* resolve to these fakes instead of requiring the real CLI
# packages to be installed. Insertion order matches the original list, so
# parent packages are registered before their children.
_MOCK_MODULES = {
    "azure": _azure,
    "azure.cli": _azure_cli,
    "azure.cli.core": _azure_cli_core,
    "azure.cli.core.commands": _azure_cli_commands,
    "azure.cli.core.aaz": _azure_cli_aaz,
    "azure.cli.core.commands.parameters": _azure_cli_params,
    "azure.cli.core.azclierror": _azure_cli_azclierror,
    "knack": _knack,
    "knack.log": _knack_log,
    "knack.help_files": _knack_help,
}
sys.modules.update(_MOCK_MODULES)


# ---------------------------------------------------------------------------
# Tests for _support_consts
# ---------------------------------------------------------------------------

class TestConstants(unittest.TestCase):
    """Pin the public constants exposed by ``support.consts``.

    Imports are deliberately performed inside each test method so they run
    only after the azure/knack module stubs above have been installed.
    """

    def test_default_namespaces(self):
        from azext_workload_orchestration.support.consts import DEFAULT_NAMESPACES
        self.assertEqual(len(DEFAULT_NAMESPACES), 3)
        for expected_ns in ("kube-system", "workloadorchestration", "cert-manager"):
            self.assertIn(expected_ns, DEFAULT_NAMESPACES)

    def test_default_tail_lines(self):
        from azext_workload_orchestration.support.consts import DEFAULT_TAIL_LINES
        self.assertEqual(DEFAULT_TAIL_LINES, 1000)

    def test_status_constants(self):
        from azext_workload_orchestration.support.consts import (
            STATUS_PASS, STATUS_FAIL, STATUS_WARN, STATUS_SKIP, STATUS_ERROR,
        )
        # Each status constant must equal its literal wire value.
        for constant, expected in (
            (STATUS_PASS, "PASS"),
            (STATUS_FAIL, "FAIL"),
            (STATUS_WARN, "WARN"),
            (STATUS_SKIP, "SKIP"),
            (STATUS_ERROR, "ERROR"),
        ):
            self.assertEqual(constant, expected)


# 
--------------------------------------------------------------------------- +# Tests for _support_utils +# --------------------------------------------------------------------------- + +class TestParseCpu(unittest.TestCase): + def test_millicores(self): + from azext_workload_orchestration.support.utils import parse_cpu + self.assertAlmostEqual(parse_cpu("3860m"), 3.86) + self.assertAlmostEqual(parse_cpu("500m"), 0.5) + self.assertAlmostEqual(parse_cpu("100m"), 0.1) + + def test_whole_cores(self): + from azext_workload_orchestration.support.utils import parse_cpu + self.assertEqual(parse_cpu("4"), 4.0) + self.assertEqual(parse_cpu("1"), 1.0) + + def test_empty_and_none(self): + from azext_workload_orchestration.support.utils import parse_cpu + self.assertEqual(parse_cpu(""), 0.0) + self.assertEqual(parse_cpu(None), 0.0) + + +class TestParseMemory(unittest.TestCase): + def test_ki(self): + from azext_workload_orchestration.support.utils import parse_memory_gi + result = parse_memory_gi("27601704Ki") + self.assertAlmostEqual(result, 26.32, places=1) + + def test_mi(self): + from azext_workload_orchestration.support.utils import parse_memory_gi + self.assertAlmostEqual(parse_memory_gi("4096Mi"), 4.0) + + def test_gi(self): + from azext_workload_orchestration.support.utils import parse_memory_gi + self.assertEqual(parse_memory_gi("4Gi"), 4.0) + self.assertEqual(parse_memory_gi("16Gi"), 16.0) + + def test_ti(self): + from azext_workload_orchestration.support.utils import parse_memory_gi + self.assertEqual(parse_memory_gi("1Ti"), 1024.0) + + def test_empty_and_none(self): + from azext_workload_orchestration.support.utils import parse_memory_gi + self.assertEqual(parse_memory_gi(""), 0.0) + self.assertEqual(parse_memory_gi(None), 0.0) + + +class TestFormatBytes(unittest.TestCase): + def test_bytes(self): + from azext_workload_orchestration.support.utils import format_bytes + self.assertEqual(format_bytes(500), "500 B") + + def test_kb(self): + from 
azext_workload_orchestration.support.utils import format_bytes + self.assertEqual(format_bytes(1536), "1.5 KB") + + def test_mb(self): + from azext_workload_orchestration.support.utils import format_bytes + self.assertEqual(format_bytes(3660710), "3.5 MB") + + def test_gb(self): + from azext_workload_orchestration.support.utils import format_bytes + result = format_bytes(2 * 1024 * 1024 * 1024) + self.assertEqual(result, "2.0 GB") + + +class TestBundleDirectory(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_creates_structure(self): + from azext_workload_orchestration.support.utils import create_bundle_directory + from azext_workload_orchestration.support.consts import ( + FOLDER_LOGS, FOLDER_RESOURCES, FOLDER_CHECKS, FOLDER_CLUSTER_INFO, + ) + bundle_dir, bundle_name = create_bundle_directory(self.tmpdir) + self.assertTrue(os.path.isdir(bundle_dir)) + self.assertTrue(os.path.isdir(os.path.join(bundle_dir, FOLDER_LOGS))) + self.assertTrue(os.path.isdir(os.path.join(bundle_dir, FOLDER_RESOURCES))) + self.assertTrue(os.path.isdir(os.path.join(bundle_dir, FOLDER_CHECKS))) + self.assertTrue(os.path.isdir(os.path.join(bundle_dir, FOLDER_CLUSTER_INFO))) + self.assertTrue(bundle_name.startswith("wo-support-bundle-")) + + def test_zip_bundle(self): + from azext_workload_orchestration.support.utils import ( + create_bundle_directory, create_zip_bundle, write_text, + ) + bundle_dir, bundle_name = create_bundle_directory(self.tmpdir) + write_text(os.path.join(bundle_dir, "test.txt"), "hello") + zip_path = create_zip_bundle(bundle_dir, bundle_name, self.tmpdir) + self.assertTrue(os.path.isfile(zip_path)) + self.assertTrue(zip_path.endswith(".zip")) + # Raw dir should be cleaned up + self.assertFalse(os.path.isdir(bundle_dir)) + + +class TestSafeApiCall(unittest.TestCase): + def test_success(self): + from azext_workload_orchestration.support.utils import safe_api_call 
+ mock_fn = MagicMock(return_value="result") + result, err = safe_api_call(mock_fn, "arg1", description="test") + self.assertEqual(result, "result") + self.assertIsNone(err) + + def test_403(self): + from azext_workload_orchestration.support.utils import safe_api_call + from kubernetes.client.exceptions import ApiException + mock_fn = MagicMock(side_effect=ApiException(status=403, reason="Forbidden")) + result, err = safe_api_call(mock_fn, description="test") + self.assertIsNone(result) + self.assertIn("403", err) + + def test_404(self): + from azext_workload_orchestration.support.utils import safe_api_call + from kubernetes.client.exceptions import ApiException + mock_fn = MagicMock(side_effect=ApiException(status=404, reason="Not Found")) + result, err = safe_api_call(mock_fn, description="test") + self.assertIsNone(result) + self.assertIn("404", err) + + def test_generic_exception(self): + from azext_workload_orchestration.support.utils import safe_api_call + mock_fn = MagicMock(side_effect=RuntimeError("boom")) + result, err = safe_api_call(mock_fn, description="test") + self.assertIsNone(result) + self.assertIn("boom", err) + + +class TestWriteCheckResult(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_writes_json(self): + from azext_workload_orchestration.support.utils import write_check_result + result = write_check_result( + self.bundle_dir, "test-cat", "test-check", "PASS", "all good" + ) + self.assertEqual(result["status"], "PASS") + self.assertEqual(result["category"], "test-cat") + filepath = os.path.join(self.bundle_dir, "checks", "test-cat--test-check.json") + self.assertTrue(os.path.isfile(filepath)) + with open(filepath) as f: + data = json.load(f) + self.assertEqual(data["status"], "PASS") + + def 
test_with_details(self): + from azext_workload_orchestration.support.utils import write_check_result + result = write_check_result( + self.bundle_dir, "cat", "chk", "WARN", "not great", + details={"nodes": ["n1", "n2"]} + ) + self.assertEqual(result["details"]["nodes"], ["n1", "n2"]) + + +class TestCheckDiskSpace(unittest.TestCase): + def test_enough_space(self): + from azext_workload_orchestration.support.utils import check_disk_space + ok, free = check_disk_space(tempfile.gettempdir(), 1024) + self.assertTrue(ok) + self.assertGreater(free, 0) + + +class TestDetectCapabilities(unittest.TestCase): + def test_detects_groups(self): + from azext_workload_orchestration.support.utils import detect_cluster_capabilities + + # Mock the API response + mock_group = MagicMock() + mock_group.name = "cert-manager.io" + mock_result = MagicMock() + mock_result.groups = [mock_group] + + mock_apis = MagicMock() + mock_apis.get_api_versions.return_value = mock_result + + clients = {"apis": mock_apis} + caps = detect_cluster_capabilities(clients) + self.assertTrue(caps["has_cert_manager"]) + self.assertFalse(caps["has_gatekeeper"]) + self.assertFalse(caps["has_openshift"]) + + +# --------------------------------------------------------------------------- +# Tests for _support_validators +# --------------------------------------------------------------------------- + +class TestKubernetesVersionCheck(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_supported_version(self): + from azext_workload_orchestration.support.validators import _check_k8s_version + info = {"server_version": {"major": "1", "minor": "33", "git_version": "v1.33.5"}} + result = _check_k8s_version(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], 
"PASS") + + def test_old_version(self): + from azext_workload_orchestration.support.validators import _check_k8s_version + info = {"server_version": {"major": "1", "minor": "22", "git_version": "v1.22.0"}} + result = _check_k8s_version(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "FAIL") + + def test_edge_version_124(self): + from azext_workload_orchestration.support.validators import _check_k8s_version + info = {"server_version": {"major": "1", "minor": "24", "git_version": "v1.24.0"}} + result = _check_k8s_version(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "PASS") + + +class TestNodeReadinessCheck(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_all_ready(self): + from azext_workload_orchestration.support.validators import _check_node_readiness + info = { + "nodes": [ + {"name": "node1", "ready": "True", "conditions": {"Ready": "True"}}, + {"name": "node2", "ready": "True", "conditions": {"Ready": "True"}}, + ] + } + result = _check_node_readiness(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "PASS") + + def test_node_not_ready(self): + from azext_workload_orchestration.support.validators import _check_node_readiness + info = { + "nodes": [ + {"name": "node1", "ready": "True", "conditions": {"Ready": "True"}}, + {"name": "node2", "ready": "False", "conditions": {"Ready": "False"}}, + ] + } + result = _check_node_readiness(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "FAIL") + self.assertIn("node2", result["message"]) + + def test_node_pressure(self): + from azext_workload_orchestration.support.validators import _check_node_readiness + info = { + "nodes": [ + { + "name": "node1", "ready": "True", + "conditions": 
{"Ready": "True", "DiskPressure": "True"}, + }, + ] + } + result = _check_node_readiness(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "WARN") + + def test_no_nodes(self): + from azext_workload_orchestration.support.validators import _check_node_readiness + result = _check_node_readiness(None, self.bundle_dir, {"nodes": []}, {}) + self.assertEqual(result["status"], "FAIL") + + +class TestNodeCapacityCheck(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_sufficient_capacity(self): + from azext_workload_orchestration.support.validators import _check_node_capacity + info = {"nodes": [ + {"name": "n1", "allocatable_cpu": "4", "allocatable_memory": "16Gi"}, + ]} + result = _check_node_capacity(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "PASS") + + def test_low_cpu(self): + from azext_workload_orchestration.support.validators import _check_node_capacity + info = {"nodes": [ + {"name": "n1", "allocatable_cpu": "1", "allocatable_memory": "16Gi"}, + ]} + result = _check_node_capacity(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "WARN") + self.assertIn("Low CPU", result["message"]) + + +class TestCertManagerCheck(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_not_installed(self): + from azext_workload_orchestration.support.validators import _check_cert_manager + result = _check_cert_manager(None, self.bundle_dir, {}, {"has_cert_manager": False}) + self.assertEqual(result["status"], "FAIL") + + def 
test_installed_and_healthy(self): + from azext_workload_orchestration.support.validators import _check_cert_manager + mock_pod = MagicMock() + mock_pod.metadata.name = "cert-manager-xyz" + mock_pod.status.phase = "Running" + + mock_result = MagicMock() + mock_result.items = [mock_pod] + + mock_core = MagicMock() + mock_core.list_namespaced_pod.return_value = mock_result + + clients = {"core_v1": mock_core} + result = _check_cert_manager(clients, self.bundle_dir, {}, {"has_cert_manager": True}) + self.assertEqual(result["status"], "PASS") + + +class TestAdmissionControllersCheck(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_no_engines(self): + from azext_workload_orchestration.support.validators import _check_admission_controllers + caps = {"has_gatekeeper": False, "has_kyverno": False, "has_openshift": False} + result = _check_admission_controllers(None, self.bundle_dir, {}, caps) + self.assertEqual(result["status"], "PASS") + self.assertIn("No additional", result["message"]) + + def test_gatekeeper_detected(self): + from azext_workload_orchestration.support.validators import _check_admission_controllers + caps = {"has_gatekeeper": True, "has_kyverno": False, "has_openshift": False} + result = _check_admission_controllers(None, self.bundle_dir, {}, caps) + self.assertEqual(result["status"], "PASS") + self.assertIn("Gatekeeper", result["message"]) + + +class TestPsaLabelsCheck(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_no_psa(self): + from 
azext_workload_orchestration.support.validators import _check_psa_labels + info = {"namespaces": [ + {"name": "workloadorchestration", "labels": {}}, + {"name": "cert-manager", "labels": {}}, + ]} + result = _check_psa_labels(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "PASS") + + def test_restricted_psa(self): + from azext_workload_orchestration.support.validators import _check_psa_labels + info = {"namespaces": [ + {"name": "workloadorchestration", "labels": { + "pod-security.kubernetes.io/enforce": "restricted" + }}, + ]} + result = _check_psa_labels(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "WARN") + + +# --------------------------------------------------------------------------- +# Tests for _support_collectors (with mocked K8s API) +# --------------------------------------------------------------------------- + +class TestCollectClusterInfo(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_collects_version_and_nodes(self): + from azext_workload_orchestration.support.collectors import collect_cluster_info + + # Mock version + mock_version = MagicMock() + mock_version.major = "1" + mock_version.minor = "33" + mock_version.git_version = "v1.33.5" + mock_version.platform = "linux/amd64" + mock_version_api = MagicMock() + mock_version_api.get_code.return_value = mock_version + + # Mock node + mock_node = MagicMock() + mock_node.metadata.name = "node1" + mock_node.metadata.labels = {"node-role.kubernetes.io/control-plane": ""} + mock_node.status.conditions = [MagicMock(type="Ready", status="True")] + mock_node.status.node_info.os_image = "AzureLinux 3" + mock_node.status.node_info.container_runtime_version = "containerd://2.0" + 
mock_node.status.node_info.kubelet_version = "v1.33.5" + mock_node.status.allocatable = {"cpu": "4", "memory": "16Gi"} + mock_node_list = MagicMock() + mock_node_list.items = [mock_node] + + # Mock namespace + mock_ns = MagicMock() + mock_ns.metadata.name = "default" + mock_ns.metadata.labels = {} + mock_ns.status.phase = "Active" + mock_ns_list = MagicMock() + mock_ns_list.items = [mock_ns] + + mock_core = MagicMock() + mock_core.list_node.return_value = mock_node_list + mock_core.list_namespace.return_value = mock_ns_list + + clients = {"core_v1": mock_core, "version": mock_version_api} + info = collect_cluster_info(clients, self.bundle_dir) + + self.assertEqual(info["server_version"]["git_version"], "v1.33.5") + self.assertEqual(info["node_count"], 1) + self.assertEqual(info["nodes"][0]["name"], "node1") + self.assertIn("control-plane", info["nodes"][0]["roles"]) + + # Verify file written + filepath = os.path.join(self.bundle_dir, "cluster-info", "cluster-info.json") + self.assertTrue(os.path.isfile(filepath)) + + +# --------------------------------------------------------------------------- +# Error handling / resilience tests +# --------------------------------------------------------------------------- + +class TestWriteJsonResilience(unittest.TestCase): + """Test that write_json handles I/O errors gracefully.""" + + def test_returns_true_on_success(self): + from azext_workload_orchestration.support.utils import write_json + import tempfile + fd, path = tempfile.mkstemp(suffix=".json") + os.close(fd) + try: + result = write_json(path, {"key": "value"}) + self.assertTrue(result) + finally: + os.unlink(path) + + def test_returns_false_on_bad_path(self): + from azext_workload_orchestration.support.utils import write_json + result = write_json("/nonexistent/dir/file.json", {"key": "value"}) + self.assertFalse(result) + + def test_handles_non_serializable_data(self): + from azext_workload_orchestration.support.utils import write_json + import tempfile + fd, path = 
tempfile.mkstemp(suffix=".json") + os.close(fd) + try: + # default=str should handle this + result = write_json(path, {"dt": object()}) + self.assertTrue(result) + finally: + os.unlink(path) + + +class TestWriteTextResilience(unittest.TestCase): + """Test that write_text handles I/O errors gracefully.""" + + def test_returns_true_on_success(self): + from azext_workload_orchestration.support.utils import write_text + import tempfile + fd, path = tempfile.mkstemp(suffix=".txt") + os.close(fd) + try: + result = write_text(path, "hello") + self.assertTrue(result) + finally: + os.unlink(path) + + def test_returns_false_on_bad_path(self): + from azext_workload_orchestration.support.utils import write_text + result = write_text("/nonexistent/dir/file.txt", "hello") + self.assertFalse(result) + + def test_handles_none_text(self): + from azext_workload_orchestration.support.utils import write_text + import tempfile + fd, path = tempfile.mkstemp(suffix=".txt") + os.close(fd) + try: + result = write_text(path, None) + self.assertTrue(result) + with open(path) as f: + self.assertEqual(f.read(), "") + finally: + os.unlink(path) + + +class TestSafeApiCallRBAC(unittest.TestCase): + """Test RBAC-specific error handling in safe_api_call.""" + + def test_401_unauthorized(self): + from azext_workload_orchestration.support.utils import safe_api_call + from kubernetes.client.exceptions import ApiException + fn = MagicMock(side_effect=ApiException(status=401, reason="Unauthorized")) + result, err = safe_api_call(fn, description="test auth") + self.assertIsNone(result) + self.assertIn("401", err) + + def test_500_server_error(self): + from azext_workload_orchestration.support.utils import safe_api_call + from kubernetes.client.exceptions import ApiException + fn = MagicMock(side_effect=ApiException(status=500, reason="Internal Server Error")) + result, err = safe_api_call(fn, description="test server err") + self.assertIsNone(result) + self.assertIn("500", err) + + def 
test_timeout_error(self): + from azext_workload_orchestration.support.utils import safe_api_call + from urllib3.exceptions import MaxRetryError, NewConnectionError + fn = MagicMock(side_effect=MaxRetryError(None, None, "timed out")) + result, err = safe_api_call(fn, description="test timeout") + self.assertIsNone(result) + self.assertIn("timed out", err) + + def test_connection_refused(self): + from azext_workload_orchestration.support.utils import safe_api_call + fn = MagicMock(side_effect=ConnectionRefusedError("refused")) + result, err = safe_api_call(fn, description="test refused") + self.assertIsNone(result) + self.assertIn("refused", err) + + +class TestDetectCapabilitiesResilience(unittest.TestCase): + """Test detect_cluster_capabilities handles failures.""" + + def test_api_failure_returns_all_false(self): + from azext_workload_orchestration.support.utils import detect_cluster_capabilities + from kubernetes.client.exceptions import ApiException + mock_apis = MagicMock() + mock_apis.get_api_versions.side_effect = ApiException(status=403, reason="Forbidden") + caps = detect_cluster_capabilities({"apis": mock_apis}) + self.assertFalse(caps.get("has_gatekeeper")) + self.assertFalse(caps.get("has_cert_manager")) + self.assertFalse(caps.get("has_symphony")) + + def test_empty_groups_returns_all_false(self): + from azext_workload_orchestration.support.utils import detect_cluster_capabilities + mock_apis = MagicMock() + mock_result = MagicMock() + mock_result.groups = [] + mock_apis.get_api_versions.return_value = mock_result + caps = detect_cluster_capabilities({"apis": mock_apis}) + self.assertFalse(caps["has_gatekeeper"]) + self.assertFalse(caps["has_cert_manager"]) + + def test_none_groups_returns_all_false(self): + from azext_workload_orchestration.support.utils import detect_cluster_capabilities + mock_apis = MagicMock() + mock_result = MagicMock() + mock_result.groups = None + mock_apis.get_api_versions.return_value = mock_result + caps = 
detect_cluster_capabilities({"apis": mock_apis}) + self.assertFalse(caps["has_gatekeeper"]) + + +class TestNodeChecksWithNoneData(unittest.TestCase): + """Test validators handle None/missing cluster_info gracefully.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_node_readiness_with_none_nodes(self): + from azext_workload_orchestration.support.validators import _check_node_readiness + result = _check_node_readiness(None, self.bundle_dir, {"nodes": None}, {}) + self.assertEqual(result["status"], "FAIL") + + def test_node_capacity_with_none_nodes(self): + from azext_workload_orchestration.support.validators import _check_node_capacity + result = _check_node_capacity(None, self.bundle_dir, {"nodes": None}, {}) + self.assertEqual(result["status"], "SKIP") + + def test_wo_namespace_with_none_namespaces(self): + from azext_workload_orchestration.support.validators import _check_wo_namespace + result = _check_wo_namespace(None, self.bundle_dir, {"namespaces": None}, {}) + self.assertEqual(result["status"], "FAIL") + + def test_psa_labels_with_none_namespaces(self): + from azext_workload_orchestration.support.validators import _check_psa_labels + result = _check_psa_labels(None, self.bundle_dir, {"namespaces": None}, {}) + self.assertEqual(result["status"], "PASS") + + def test_cluster_resources_with_none_nodes(self): + from azext_workload_orchestration.support.validators import _check_cluster_resources + result = _check_cluster_resources(None, self.bundle_dir, {"nodes": None}, {}) + self.assertEqual(result["status"], "SKIP") + + def test_empty_cluster_info(self): + from azext_workload_orchestration.support.validators import _check_k8s_version + result = _check_k8s_version(None, self.bundle_dir, {}, {}) + # Empty version info → can't parse 
→ WARN or FAIL (both acceptable) + self.assertIn(result["status"], ("WARN", "FAIL")) + + def test_version_with_plus_suffix(self): + from azext_workload_orchestration.support.validators import _check_k8s_version + info = {"server_version": {"major": "1", "minor": "28+", "git_version": "v1.28.2-gke.1"}} + result = _check_k8s_version(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "PASS") + + +class TestProtectedNamespaceCheck(unittest.TestCase): + """Test protected namespace validation.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_wo_namespace_is_not_protected(self): + from azext_workload_orchestration.support.validators import _check_protected_namespace + result = _check_protected_namespace(None, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + + +class TestCsiDriversCheck(unittest.TestCase): + """Test CSI driver check.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_no_drivers(self): + from azext_workload_orchestration.support.validators import _check_csi_drivers + mock_storage = MagicMock() + mock_result = MagicMock() + mock_result.items = [] + mock_storage.list_csi_driver.return_value = mock_result + result = _check_csi_drivers({"storage_v1": mock_storage}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "WARN") + + def test_with_drivers(self): + from azext_workload_orchestration.support.validators import _check_csi_drivers + mock_storage = MagicMock() + mock_driver = MagicMock() + mock_driver.metadata.name = "disk.csi.azure.com" + 
mock_result = MagicMock() + mock_result.items = [mock_driver] + mock_storage.list_csi_driver.return_value = mock_result + result = _check_csi_drivers({"storage_v1": mock_storage}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + self.assertIn("disk.csi.azure.com", result["message"]) + + def test_rbac_denied(self): + from azext_workload_orchestration.support.validators import _check_csi_drivers + from kubernetes.client.exceptions import ApiException + mock_storage = MagicMock() + mock_storage.list_csi_driver.side_effect = ApiException(status=403, reason="Forbidden") + result = _check_csi_drivers({"storage_v1": mock_storage}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "SKIP") + + +class TestProxyCheck(unittest.TestCase): + """Test proxy settings check.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_no_proxy(self): + from azext_workload_orchestration.support.validators import _check_proxy_settings + mock_core = MagicMock() + mock_pod = MagicMock() + mock_pod.metadata.name = "pod1" + mock_container = MagicMock() + mock_container.name = "c1" + mock_container.env = [] + mock_pod.spec.containers = [mock_container] + mock_result = MagicMock() + mock_result.items = [mock_pod] + mock_core.list_namespaced_pod.return_value = mock_result + result = _check_proxy_settings({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + + def test_with_proxy(self): + from azext_workload_orchestration.support.validators import _check_proxy_settings + mock_core = MagicMock() + mock_pod = MagicMock() + mock_pod.metadata.name = "pod1" + mock_env = MagicMock() + mock_env.name = "HTTP_PROXY" + mock_env.value = "http://proxy:8080" + mock_container = MagicMock() + 
mock_container.name = "c1" + mock_container.env = [mock_env] + mock_pod.spec.containers = [mock_container] + mock_result = MagicMock() + mock_result.items = [mock_pod] + mock_core.list_namespaced_pod.return_value = mock_result + result = _check_proxy_settings({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "WARN") + self.assertIn("proxy", result["message"].lower()) + + +class TestZipBundleResilience(unittest.TestCase): + """Test zip bundle creation handles edge cases.""" + + def test_empty_bundle_dir(self): + """Zip creation works even with empty bundle directory.""" + from azext_workload_orchestration.support.utils import ( + create_bundle_directory, create_zip_bundle, + ) + tmpdir = tempfile.mkdtemp() + try: + bundle_dir, bundle_name = create_bundle_directory(tmpdir) + zip_path = create_zip_bundle(bundle_dir, bundle_name, tmpdir) + self.assertTrue(os.path.isfile(zip_path)) + self.assertFalse(os.path.isdir(bundle_dir)) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +class TestClusterResourcesCheck(unittest.TestCase): + """Test cluster-wide aggregate resource check.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_sufficient_total(self): + from azext_workload_orchestration.support.validators import _check_cluster_resources + info = {"nodes": [ + {"name": "n1", "allocatable_cpu": "4", "allocatable_memory": "16Gi"}, + {"name": "n2", "allocatable_cpu": "4", "allocatable_memory": "16Gi"}, + ]} + result = _check_cluster_resources(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "PASS") + self.assertIn("8.0 CPU", result["message"]) + + def test_insufficient_total(self): + from azext_workload_orchestration.support.validators import _check_cluster_resources + info = 
{"nodes": [ + {"name": "n1", "allocatable_cpu": "500m", "allocatable_memory": "1Gi"}, + ]} + result = _check_cluster_resources(None, self.bundle_dir, info, {}) + self.assertEqual(result["status"], "WARN") + + +# --------------------------------------------------------------------------- +# Collector helper function tests +# --------------------------------------------------------------------------- + +class TestGetNodeRoles(unittest.TestCase): + """Test _get_node_roles helper.""" + + def test_control_plane_role(self): + from azext_workload_orchestration.support.collectors import _get_node_roles + node = MagicMock() + node.metadata.labels = {"node-role.kubernetes.io/control-plane": ""} + self.assertEqual(_get_node_roles(node), ["control-plane"]) + + def test_multiple_roles(self): + from azext_workload_orchestration.support.collectors import _get_node_roles + node = MagicMock() + node.metadata.labels = { + "node-role.kubernetes.io/control-plane": "", + "node-role.kubernetes.io/master": "", + } + roles = _get_node_roles(node) + self.assertIn("control-plane", roles) + self.assertIn("master", roles) + + def test_no_roles(self): + from azext_workload_orchestration.support.collectors import _get_node_roles + node = MagicMock() + node.metadata.labels = {"kubernetes.io/os": "linux"} + self.assertEqual(_get_node_roles(node), [""]) + + def test_no_labels(self): + from azext_workload_orchestration.support.collectors import _get_node_roles + node = MagicMock() + node.metadata.labels = None + self.assertEqual(_get_node_roles(node), [""]) + + +class TestPodReadyCount(unittest.TestCase): + """Test _pod_ready_count helper.""" + + def test_all_ready(self): + from azext_workload_orchestration.support.collectors import _pod_ready_count + pod = MagicMock() + pod.spec.containers = [MagicMock(), MagicMock()] + cs1 = MagicMock(); cs1.ready = True + cs2 = MagicMock(); cs2.ready = True + pod.status.container_statuses = [cs1, cs2] + self.assertEqual(_pod_ready_count(pod), "2/2") + + def 
test_partial_ready(self): + from azext_workload_orchestration.support.collectors import _pod_ready_count + pod = MagicMock() + pod.spec.containers = [MagicMock(), MagicMock(), MagicMock()] + cs1 = MagicMock(); cs1.ready = True + cs2 = MagicMock(); cs2.ready = False + pod.status.container_statuses = [cs1, cs2] + self.assertEqual(_pod_ready_count(pod), "1/3") + + def test_no_container_statuses(self): + from azext_workload_orchestration.support.collectors import _pod_ready_count + pod = MagicMock() + pod.spec.containers = [MagicMock()] + pod.status.container_statuses = None + self.assertEqual(_pod_ready_count(pod), "0/1") + + +class TestPodRestartCount(unittest.TestCase): + """Test _pod_restart_count helper.""" + + def test_no_restarts(self): + from azext_workload_orchestration.support.collectors import _pod_restart_count + pod = MagicMock() + cs = MagicMock(); cs.restart_count = 0 + pod.status.container_statuses = [cs] + self.assertEqual(_pod_restart_count(pod), 0) + + def test_high_restarts(self): + from azext_workload_orchestration.support.collectors import _pod_restart_count + pod = MagicMock() + cs1 = MagicMock(); cs1.restart_count = 15 + cs2 = MagicMock(); cs2.restart_count = 3 + pod.status.container_statuses = [cs1, cs2] + self.assertEqual(_pod_restart_count(pod), 18) + + def test_none_statuses(self): + from azext_workload_orchestration.support.collectors import _pod_restart_count + pod = MagicMock() + pod.status.container_statuses = None + self.assertEqual(_pod_restart_count(pod), 0) + + +class TestIsDefaultSC(unittest.TestCase): + """Test _is_default_sc helper.""" + + def test_v1_annotation(self): + from azext_workload_orchestration.support.collectors import _is_default_sc + sc = MagicMock() + sc.metadata.annotations = {"storageclass.kubernetes.io/is-default-class": "true"} + self.assertTrue(_is_default_sc(sc)) + + def test_beta_annotation(self): + from azext_workload_orchestration.support.collectors import _is_default_sc + sc = MagicMock() + 
sc.metadata.annotations = {"storageclass.beta.kubernetes.io/is-default-class": "true"} + self.assertTrue(_is_default_sc(sc)) + + def test_not_default(self): + from azext_workload_orchestration.support.collectors import _is_default_sc + sc = MagicMock() + sc.metadata.annotations = {} + self.assertFalse(_is_default_sc(sc)) + + def test_none_annotations(self): + from azext_workload_orchestration.support.collectors import _is_default_sc + sc = MagicMock() + sc.metadata.annotations = None + self.assertFalse(_is_default_sc(sc)) + + +class TestCertIssuerReady(unittest.TestCase): + """Test _cert_issuer_ready helper.""" + + def test_ready_true(self): + from azext_workload_orchestration.support.collectors import _cert_issuer_ready + issuer = {"status": {"conditions": [{"type": "Ready", "status": "True"}]}} + self.assertTrue(_cert_issuer_ready(issuer)) + + def test_ready_false(self): + from azext_workload_orchestration.support.collectors import _cert_issuer_ready + issuer = {"status": {"conditions": [{"type": "Ready", "status": "False"}]}} + self.assertFalse(_cert_issuer_ready(issuer)) + + def test_no_conditions(self): + from azext_workload_orchestration.support.collectors import _cert_issuer_ready + self.assertFalse(_cert_issuer_ready({"status": {}})) + + def test_no_status(self): + from azext_workload_orchestration.support.collectors import _cert_issuer_ready + self.assertFalse(_cert_issuer_ready({})) + + +class TestCreateNamespaceLogDir(unittest.TestCase): + """Test create_namespace_log_dir.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_creates_dir(self): + from azext_workload_orchestration.support.utils import create_namespace_log_dir + log_dir = create_namespace_log_dir(self.bundle_dir, "kube-system") + 
self.assertTrue(os.path.isdir(log_dir)) + self.assertTrue(log_dir.endswith("kube-system")) + + def test_idempotent(self): + from azext_workload_orchestration.support.utils import create_namespace_log_dir + d1 = create_namespace_log_dir(self.bundle_dir, "test-ns") + d2 = create_namespace_log_dir(self.bundle_dir, "test-ns") + self.assertEqual(d1, d2) + + +# --------------------------------------------------------------------------- +# Validator edge case tests +# --------------------------------------------------------------------------- + +class TestDnsHealthCheck(unittest.TestCase): + """Test _check_dns_health with various scenarios.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_dns_pods_running(self): + from azext_workload_orchestration.support.validators import _check_dns_health + mock_core = MagicMock() + pod = MagicMock() + pod.metadata.name = "coredns-abc" + pod.status.phase = "Running" + result_obj = MagicMock() + result_obj.items = [pod] + mock_core.list_namespaced_pod.return_value = result_obj + result = _check_dns_health({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + + def test_dns_pods_not_running(self): + from azext_workload_orchestration.support.validators import _check_dns_health + mock_core = MagicMock() + pod = MagicMock() + pod.metadata.name = "coredns-abc" + pod.status.phase = "Pending" + result_obj = MagicMock() + result_obj.items = [pod] + mock_core.list_namespaced_pod.return_value = result_obj + result = _check_dns_health({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "WARN") + + def test_no_dns_pods_fallback_by_name(self): + from azext_workload_orchestration.support.validators import _check_dns_health + mock_core = 
MagicMock() + empty = MagicMock(); empty.items = [] + dns_pod = MagicMock() + dns_pod.metadata.name = "coredns-xyz" + dns_pod.status.phase = "Running" + all_pods = MagicMock(); all_pods.items = [dns_pod] + mock_core.list_namespaced_pod.side_effect = [empty, all_pods] + result = _check_dns_health({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + + def test_rbac_denied(self): + from azext_workload_orchestration.support.validators import _check_dns_health + from kubernetes.client.exceptions import ApiException + mock_core = MagicMock() + mock_core.list_namespaced_pod.side_effect = ApiException(status=403, reason="Forbidden") + result = _check_dns_health({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "WARN") + + +class TestDefaultStorageClassCheck(unittest.TestCase): + """Test _check_default_storage_class.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_has_default(self): + from azext_workload_orchestration.support.validators import _check_default_storage_class + mock_storage = MagicMock() + sc = MagicMock() + sc.metadata.name = "default" + sc.metadata.annotations = {"storageclass.kubernetes.io/is-default-class": "true"} + result_obj = MagicMock(); result_obj.items = [sc] + mock_storage.list_storage_class.return_value = result_obj + result = _check_default_storage_class({"storage_v1": mock_storage}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + self.assertIn("default", result["message"]) + + def test_no_default(self): + from azext_workload_orchestration.support.validators import _check_default_storage_class + mock_storage = MagicMock() + sc = MagicMock() + sc.metadata.name = "managed-premium" + sc.metadata.annotations = {} + 
result_obj = MagicMock(); result_obj.items = [sc] + mock_storage.list_storage_class.return_value = result_obj + result = _check_default_storage_class({"storage_v1": mock_storage}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "WARN") + + def test_no_storage_classes(self): + from azext_workload_orchestration.support.validators import _check_default_storage_class + mock_storage = MagicMock() + result_obj = MagicMock(); result_obj.items = [] + mock_storage.list_storage_class.return_value = result_obj + result = _check_default_storage_class({"storage_v1": mock_storage}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "WARN") + + +class TestWoPodsCheck(unittest.TestCase): + """Test _check_wo_pods.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_all_running(self): + from azext_workload_orchestration.support.validators import _check_wo_pods + mock_core = MagicMock() + p1 = MagicMock(); p1.metadata.name = "sym-api"; p1.status.phase = "Running" + p2 = MagicMock(); p2.metadata.name = "sym-ctrl"; p2.status.phase = "Running" + result_obj = MagicMock(); result_obj.items = [p1, p2] + mock_core.list_namespaced_pod.return_value = result_obj + result = _check_wo_pods({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + + def test_some_pending(self): + from azext_workload_orchestration.support.validators import _check_wo_pods + mock_core = MagicMock() + p1 = MagicMock(); p1.metadata.name = "sym-api"; p1.status.phase = "Running" + p2 = MagicMock(); p2.metadata.name = "sym-ctrl"; p2.status.phase = "Pending" + result_obj = MagicMock(); result_obj.items = [p1, p2] + mock_core.list_namespaced_pod.return_value = result_obj + result = _check_wo_pods({"core_v1": mock_core}, 
self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "WARN") + + def test_no_pods(self): + from azext_workload_orchestration.support.validators import _check_wo_pods + mock_core = MagicMock() + result_obj = MagicMock(); result_obj.items = [] + mock_core.list_namespaced_pod.return_value = result_obj + result = _check_wo_pods({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "FAIL") + + def test_rbac_denied(self): + from azext_workload_orchestration.support.validators import _check_wo_pods + from kubernetes.client.exceptions import ApiException + mock_core = MagicMock() + mock_core.list_namespaced_pod.side_effect = ApiException(status=403, reason="Forbidden") + result = _check_wo_pods({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "WARN") + + +class TestWoWebhooksCheck(unittest.TestCase): + """Test _check_wo_webhooks.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_symphony_webhooks_found(self): + from azext_workload_orchestration.support.validators import _check_wo_webhooks + mock_adm = MagicMock() + wh = MagicMock() + wh.metadata.name = "symphony-validating-webhook" + hook1 = MagicMock(); hook1.failure_policy = "Fail" + hook2 = MagicMock(); hook2.failure_policy = "Fail" + wh.webhooks = [hook1, hook2] + result_obj = MagicMock(); result_obj.items = [wh] + mock_adm.list_validating_webhook_configuration.return_value = result_obj + result = _check_wo_webhooks({"admissionregistration_v1": mock_adm}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + self.assertIn("2 hooks", result["message"]) + + def test_no_symphony_webhooks(self): + from azext_workload_orchestration.support.validators import _check_wo_webhooks + mock_adm = 
MagicMock() + wh = MagicMock() + wh.metadata.name = "gatekeeper-validating" + wh.webhooks = [] + result_obj = MagicMock(); result_obj.items = [wh] + mock_adm.list_validating_webhook_configuration.return_value = result_obj + result = _check_wo_webhooks({"admissionregistration_v1": mock_adm}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "WARN") + + +class TestResourceQuotasCheck(unittest.TestCase): + """Test _check_resource_quotas edge cases.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_no_quotas(self): + from azext_workload_orchestration.support.validators import _check_resource_quotas + mock_core = MagicMock() + result_obj = MagicMock(); result_obj.items = [] + mock_core.list_namespaced_resource_quota.return_value = result_obj + result = _check_resource_quotas({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + + def test_quota_over_80_percent(self): + from azext_workload_orchestration.support.validators import _check_resource_quotas + mock_core = MagicMock() + rq = MagicMock() + rq.metadata.name = "compute-quota" + rq.status.hard = {"cpu": "10"} + rq.status.used = {"cpu": "9"} + result_obj = MagicMock(); result_obj.items = [rq] + mock_core.list_namespaced_resource_quota.return_value = result_obj + result = _check_resource_quotas({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "WARN") + + def test_quota_under_80_percent(self): + from azext_workload_orchestration.support.validators import _check_resource_quotas + mock_core = MagicMock() + rq = MagicMock() + rq.metadata.name = "compute-quota" + rq.status.hard = {"cpu": "10"} + rq.status.used = {"cpu": "5"} + result_obj = MagicMock(); result_obj.items = [rq] + 
mock_core.list_namespaced_resource_quota.return_value = result_obj + result = _check_resource_quotas({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + + +class TestImagePullSecretsCheck(unittest.TestCase): + """Test _check_image_pull_secrets.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_no_secrets(self): + from azext_workload_orchestration.support.validators import _check_image_pull_secrets + mock_core = MagicMock() + result_obj = MagicMock(); result_obj.items = [] + mock_core.list_namespaced_secret.return_value = result_obj + result = _check_image_pull_secrets({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + self.assertIn("default service account", result["message"]) + + def test_has_secrets(self): + from azext_workload_orchestration.support.validators import _check_image_pull_secrets + mock_core = MagicMock() + sec = MagicMock(); sec.metadata.name = "acr-creds" + result_with = MagicMock(); result_with.items = [sec] + result_empty = MagicMock(); result_empty.items = [] + mock_core.list_namespaced_secret.side_effect = [result_with, result_empty] + result = _check_image_pull_secrets({"core_v1": mock_core}, self.bundle_dir, {}, {}) + self.assertEqual(result["status"], "PASS") + self.assertIn("acr-creds", result["message"]) + + +class TestCollectNamespaceResources(unittest.TestCase): + """Test collect_namespace_resources with mocked API.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_empty_namespace(self): 
+ from azext_workload_orchestration.support.collectors import collect_namespace_resources + mock_core = MagicMock() + mock_apps = MagicMock() + empty = MagicMock(); empty.items = [] + mock_core.list_namespaced_pod.return_value = empty + mock_apps.list_namespaced_deployment.return_value = empty + mock_core.list_namespaced_service.return_value = empty + mock_apps.list_namespaced_daemon_set.return_value = empty + mock_core.list_namespaced_event.return_value = empty + mock_core.list_namespaced_config_map.return_value = empty + result = collect_namespace_resources( + {"core_v1": mock_core, "apps_v1": mock_apps}, + self.bundle_dir, "test-ns" + ) + self.assertEqual(result.get("pods"), []) + self.assertEqual(result.get("deployments"), []) + + def test_namespace_with_pod(self): + from azext_workload_orchestration.support.collectors import collect_namespace_resources + mock_core = MagicMock() + mock_apps = MagicMock() + pod = MagicMock() + pod.metadata.name = "test-pod" + pod.status.phase = "Running" + pod.spec.node_name = "node1" + pod.spec.containers = [MagicMock(name="c1")] + cs = MagicMock(); cs.ready = True; cs.restart_count = 0 + pod.status.container_statuses = [cs] + pod_list = MagicMock(); pod_list.items = [pod] + empty = MagicMock(); empty.items = [] + mock_core.list_namespaced_pod.return_value = pod_list + mock_apps.list_namespaced_deployment.return_value = empty + mock_core.list_namespaced_service.return_value = empty + mock_apps.list_namespaced_daemon_set.return_value = empty + mock_core.list_namespaced_event.return_value = empty + mock_core.list_namespaced_config_map.return_value = empty + result = collect_namespace_resources( + {"core_v1": mock_core, "apps_v1": mock_apps}, + self.bundle_dir, "test-ns" + ) + self.assertEqual(len(result["pods"]), 1) + self.assertEqual(result["pods"][0]["name"], "test-pod") + + +class TestCollectPreviousLogs(unittest.TestCase): + """Test collect_previous_logs.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from 
azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_no_restarted_containers(self): + from azext_workload_orchestration.support.collectors import collect_previous_logs + mock_core = MagicMock() + pod = MagicMock() + pod.metadata.name = "pod1" + cs = MagicMock(); cs.restart_count = 0; cs.name = "c1" + pod.status.container_statuses = [cs] + result_obj = MagicMock(); result_obj.items = [pod] + mock_core.list_namespaced_pod.return_value = result_obj + count = collect_previous_logs({"core_v1": mock_core}, self.bundle_dir, "test-ns") + self.assertEqual(count, 0) + + def test_restarted_container_collects(self): + from azext_workload_orchestration.support.collectors import collect_previous_logs + mock_core = MagicMock() + pod = MagicMock() + pod.metadata.name = "crash-pod" + cs = MagicMock(); cs.restart_count = 5; cs.name = "app" + pod.status.container_statuses = [cs] + result_obj = MagicMock(); result_obj.items = [pod] + mock_core.list_namespaced_pod.return_value = result_obj + mock_core.read_namespaced_pod_log.return_value = "error log line\npanic" + count = collect_previous_logs({"core_v1": mock_core}, self.bundle_dir, "test-ns") + self.assertEqual(count, 1) + log_dir = os.path.join(self.bundle_dir, "logs", "test-ns") + self.assertTrue(os.path.isdir(log_dir)) + + def test_previous_log_api_fails(self): + from azext_workload_orchestration.support.collectors import collect_previous_logs + from kubernetes.client.exceptions import ApiException + mock_core = MagicMock() + pod = MagicMock() + pod.metadata.name = "crash-pod" + cs = MagicMock(); cs.restart_count = 3; cs.name = "app" + pod.status.container_statuses = [cs] + result_obj = MagicMock(); result_obj.items = [pod] + mock_core.list_namespaced_pod.return_value = result_obj + mock_core.read_namespaced_pod_log.side_effect = ApiException(status=400, reason="Bad 
Request") + count = collect_previous_logs({"core_v1": mock_core}, self.bundle_dir, "test-ns") + self.assertEqual(count, 0) + + +class TestLogTruncation(unittest.TestCase): + """Test container log size truncation.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + from azext_workload_orchestration.support.utils import create_bundle_directory + self.bundle_dir, _ = create_bundle_directory(self.tmpdir) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_large_log_gets_truncated(self): + from azext_workload_orchestration.support.collectors import collect_container_logs + from azext_workload_orchestration.support.consts import DEFAULT_MAX_LOG_SIZE_BYTES + mock_core = MagicMock() + pod = MagicMock() + pod.metadata.name = "chatty-pod" + mock_container = MagicMock() + mock_container.name = "app" + pod.spec.containers = [mock_container] + result_obj = MagicMock(); result_obj.items = [pod] + mock_core.list_namespaced_pod.return_value = result_obj + # Create a log bigger than max size + big_log = "X" * (DEFAULT_MAX_LOG_SIZE_BYTES + 1000) + mock_core.read_namespaced_pod_log.return_value = big_log + count = collect_container_logs({"core_v1": mock_core}, self.bundle_dir, "test-ns", tail_lines=None) + self.assertEqual(count, 1) + log_file = os.path.join(self.bundle_dir, "logs", "test-ns", "chatty-pod--app.log") + self.assertTrue(os.path.isfile(log_file)) + with open(log_file) as f: + content = f.read() + self.assertIn("[TRUNCATED", content) + + +class TestParseCpuEdgeCases(unittest.TestCase): + """Additional edge cases for CPU parsing.""" + + def test_zero(self): + from azext_workload_orchestration.support.utils import parse_cpu + self.assertEqual(parse_cpu("0"), 0.0) + self.assertEqual(parse_cpu("0m"), 0.0) + + def test_large_millicores(self): + from azext_workload_orchestration.support.utils import parse_cpu + self.assertAlmostEqual(parse_cpu("32000m"), 32.0) + + def test_decimal_cores(self): + from 
azext_workload_orchestration.support.utils import parse_cpu + self.assertAlmostEqual(parse_cpu("0.5"), 0.5) + + def test_whitespace(self): + from azext_workload_orchestration.support.utils import parse_cpu + self.assertAlmostEqual(parse_cpu(" 4 "), 4.0) + self.assertAlmostEqual(parse_cpu(" 500m "), 0.5) + + +class TestParseMemoryEdgeCases(unittest.TestCase): + """Additional edge cases for memory parsing.""" + + def test_plain_bytes(self): + from azext_workload_orchestration.support.utils import parse_memory_gi + result = parse_memory_gi("1073741824") + self.assertAlmostEqual(result, 1.0, places=1) + + def test_invalid_string(self): + from azext_workload_orchestration.support.utils import parse_memory_gi + self.assertEqual(parse_memory_gi("not-a-number"), 0.0) + + def test_zero(self): + from azext_workload_orchestration.support.utils import parse_memory_gi + self.assertEqual(parse_memory_gi("0"), 0.0) + self.assertEqual(parse_memory_gi("0Ki"), 0.0) + + + + +def _skip_if_no_cluster(): + """Return True if we should skip live cluster tests.""" + if os.environ.get("SKIP_LIVE_TESTS", "").lower() in ("1", "true", "yes"): + return True + try: + from kubernetes import config, client + config.load_kube_config() + v1 = client.VersionApi() + v1.get_code() + return False + except Exception: + return True + + +_NO_CLUSTER = _skip_if_no_cluster() + + +@unittest.skipIf(_NO_CLUSTER, "No live Kubernetes cluster available") +class IntegrationTestFullBundle(unittest.TestCase): + """End-to-end integration tests against a real cluster. + + These tests validate that every collector and validator works against + real Kubernetes API responses — not mocks. They are safe (read-only) + and create no resources on the cluster. 
+ """ + + @classmethod + def setUpClass(cls): + from azext_workload_orchestration.support.utils import ( + get_kubernetes_client, create_bundle_directory, + detect_cluster_capabilities, + ) + from azext_workload_orchestration.support.collectors import collect_cluster_info + + cls.tmpdir = tempfile.mkdtemp(prefix="wo-integration-test-") + cls.bundle_dir, cls.bundle_name = create_bundle_directory(cls.tmpdir) + cls.clients = get_kubernetes_client() + cls.cluster_info = collect_cluster_info(cls.clients, cls.bundle_dir) + cls.capabilities = detect_cluster_capabilities(cls.clients) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir, ignore_errors=True) + + # -- Cluster info -------------------------------------------------------- + + def test_cluster_info_has_version(self): + self.assertIn("server_version", self.cluster_info) + sv = self.cluster_info["server_version"] + self.assertIn("major", sv) + self.assertIn("minor", sv) + self.assertIn("git_version", sv) + + def test_cluster_info_has_nodes(self): + self.assertIn("nodes", self.cluster_info) + self.assertGreater(len(self.cluster_info["nodes"]), 0) + node = self.cluster_info["nodes"][0] + for key in ("name", "ready", "roles", "os", "container_runtime", + "kubelet_version", "allocatable_cpu", "allocatable_memory"): + self.assertIn(key, node, f"Missing key '{key}' in node info") + + def test_cluster_info_has_namespaces(self): + self.assertIn("namespaces", self.cluster_info) + ns_names = [ns["name"] for ns in self.cluster_info["namespaces"]] + self.assertIn("kube-system", ns_names) + self.assertIn("default", ns_names) + + # -- Capabilities -------------------------------------------------------- + + def test_capabilities_detected(self): + for key in ("has_gatekeeper", "has_kyverno", "has_cert_manager", + "has_symphony", "has_openshift", "has_metrics"): + self.assertIn(key, self.capabilities, f"Missing capability '{key}'") + self.assertIsInstance(self.capabilities[key], bool) + + # -- Prerequisite checks 
------------------------------------------------- + + def test_all_checks_run_without_crash(self): + from azext_workload_orchestration.support.validators import run_all_checks + from azext_workload_orchestration.support.consts import ( + STATUS_PASS, STATUS_FAIL, STATUS_WARN, STATUS_SKIP, STATUS_ERROR, + ) + valid_statuses = {STATUS_PASS, STATUS_FAIL, STATUS_WARN, STATUS_SKIP, STATUS_ERROR} + + results = run_all_checks( + self.clients, self.bundle_dir, self.cluster_info, self.capabilities, + ) + self.assertGreaterEqual(len(results), 10, "Expected at least 10 checks") + + for r in results: + self.assertIn("status", r) + self.assertIn("message", r) + self.assertIn(r["status"], valid_statuses, + f"Invalid status '{r['status']}' for check '{r.get('check_name')}'") + # No check should crash (ERROR status) + self.assertNotEqual(r["status"], STATUS_ERROR, + f"Check crashed: {r.get('check_name')} — {r['message']}") + + def test_k8s_version_passes(self): + from azext_workload_orchestration.support.validators import _check_k8s_version + result = _check_k8s_version(self.clients, self.bundle_dir, + self.cluster_info, self.capabilities) + self.assertEqual(result["status"], "PASS") + + def test_node_readiness_returns_valid_status(self): + from azext_workload_orchestration.support.validators import _check_node_readiness + result = _check_node_readiness(self.clients, self.bundle_dir, + self.cluster_info, self.capabilities) + self.assertIn(result["status"], ("PASS", "WARN", "FAIL")) + + # -- Collectors ---------------------------------------------------------- + + def test_collect_cluster_resources(self): + from azext_workload_orchestration.support.collectors import collect_cluster_resources + cr = collect_cluster_resources(self.clients, self.bundle_dir) + self.assertIn("storage_classes", cr) + self.assertIn("validating_webhooks", cr) + self.assertIn("crds", cr) + self.assertIsInstance(cr["storage_classes"], list) + + def test_collect_namespace_resources_kube_system(self): + from 
azext_workload_orchestration.support.collectors import collect_namespace_resources + nr = collect_namespace_resources(self.clients, self.bundle_dir, "kube-system") + self.assertIn("pods", nr) + self.assertGreater(len(nr["pods"]), 0, "kube-system should have pods") + pod = nr["pods"][0] + for key in ("name", "phase", "ready", "restarts", "containers"): + self.assertIn(key, pod) + + def test_collect_container_logs(self): + from azext_workload_orchestration.support.collectors import collect_container_logs + count = collect_container_logs( + self.clients, self.bundle_dir, "kube-system", tail_lines=10, + ) + self.assertGreater(count, 0, "Should collect at least 1 log from kube-system") + log_dir = os.path.join(self.bundle_dir, "logs", "kube-system") + self.assertTrue(os.path.isdir(log_dir)) + log_files = os.listdir(log_dir) + self.assertGreater(len(log_files), 0) + + def test_collect_metrics_if_available(self): + from azext_workload_orchestration.support.collectors import collect_metrics + m = collect_metrics(self.clients, self.bundle_dir, self.capabilities) + if self.capabilities.get("has_metrics"): + self.assertIn("node_metrics", m) + self.assertGreater(len(m["node_metrics"]), 0) + else: + self.assertEqual(m, {}) + + def test_collect_resource_quotas(self): + from azext_workload_orchestration.support.collectors import collect_resource_quotas + # Should not crash on any namespace + q = collect_resource_quotas(self.clients, self.bundle_dir, "kube-system") + self.assertIsInstance(q, dict) + + def test_collect_pvcs(self): + from azext_workload_orchestration.support.collectors import collect_pvcs + p = collect_pvcs(self.clients, self.bundle_dir, "kube-system") + self.assertIsInstance(p, list) + + def test_collect_wo_components(self): + from azext_workload_orchestration.support.collectors import collect_wo_components + wo = collect_wo_components(self.clients, self.bundle_dir, self.capabilities) + self.assertIsInstance(wo, dict) + + # -- Bundle zip 
---------------------------------------------------------- + + def test_bundle_creates_valid_zip(self): + """Create a fresh bundle, collect data, zip it, and validate contents. + + Uses its own temp directory so it doesn't destroy the shared bundle_dir + that other tests rely on. + """ + import zipfile + from azext_workload_orchestration.support.utils import ( + create_bundle_directory, create_zip_bundle, detect_cluster_capabilities, + write_json, + ) + from azext_workload_orchestration.support.collectors import ( + collect_cluster_info, collect_namespace_resources, + collect_cluster_resources, collect_container_logs, + ) + from azext_workload_orchestration.support.validators import run_all_checks + + zip_tmpdir = tempfile.mkdtemp(prefix="wo-zip-test-") + try: + bdir, bname = create_bundle_directory(zip_tmpdir) + + # Collect enough data so the zip has content + info = collect_cluster_info(self.clients, bdir) + caps = detect_cluster_capabilities(self.clients) + write_json(os.path.join(bdir, "cluster-info", "capabilities.json"), caps) + run_all_checks(self.clients, bdir, info, caps) + collect_cluster_resources(self.clients, bdir) + collect_namespace_resources(self.clients, bdir, "kube-system") + collect_container_logs(self.clients, bdir, "kube-system", tail_lines=10) + + zip_path = create_zip_bundle(bdir, bname, zip_tmpdir) + self.assertTrue(os.path.isfile(zip_path)) + self.assertTrue(zip_path.endswith(".zip")) + + with zipfile.ZipFile(zip_path) as zf: + names = zf.namelist() + has_checks = any("checks/" in n for n in names) + has_cluster_info = any("cluster-info/" in n for n in names) + has_resources = any("resources/" in n for n in names) + has_logs = any("logs/" in n for n in names) + self.assertTrue(has_checks, "Zip missing checks/ folder") + self.assertTrue(has_cluster_info, "Zip missing cluster-info/ folder") + self.assertTrue(has_resources, "Zip missing resources/ folder") + self.assertTrue(has_logs, "Zip missing logs/ folder") + self.assertGreater(len(names), 
20, + f"Expected 20+ files in bundle, got {len(names)}") + finally: + shutil.rmtree(zip_tmpdir, ignore_errors=True) + + +# =========================================================================== +# Tests for retry + timeout in safe_api_call +# =========================================================================== + + +class TestSafeApiCallRetry(unittest.TestCase): + """Test retry logic in safe_api_call.""" + + def test_retries_on_500_error(self): + """safe_api_call retries on 500 server error.""" + from azext_workload_orchestration.support.utils import safe_api_call + from kubernetes.client.exceptions import ApiException + + call_count = [0] + + def side_effect_func(*args, **kwargs): + call_count[0] += 1 + if call_count[0] <= 2: + raise ApiException(status=500, reason="Internal Server Error") + return "success" + + func = MagicMock(side_effect=side_effect_func) + result, err = safe_api_call( + func, description="test-500", max_retries=2, timeout_seconds=5 + ) + self.assertEqual(result, "success") + self.assertIsNone(err) + self.assertEqual(call_count[0], 3) + + def test_no_retry_on_403(self): + """safe_api_call does NOT retry on 403 Forbidden.""" + from azext_workload_orchestration.support.utils import safe_api_call + + call_count = [0] + + def side_effect_func(*args, **kwargs): + call_count[0] += 1 + exc = Exception("forbidden") + exc.status = 403 + exc.reason = "Forbidden" + # Need to raise proper ApiException + from kubernetes.client.exceptions import ApiException + raise ApiException(status=403, reason="Forbidden") + + func = MagicMock(side_effect=side_effect_func) + result, err = safe_api_call(func, description="test-403", max_retries=3) + self.assertIsNone(result) + self.assertIn("403", err) + self.assertEqual(call_count[0], 1) # no retries + + def test_no_retry_on_404(self): + """safe_api_call does NOT retry on 404.""" + from azext_workload_orchestration.support.utils import safe_api_call + + call_count = [0] + + def side_effect_func(*args, 
**kwargs): + call_count[0] += 1 + from kubernetes.client.exceptions import ApiException + raise ApiException(status=404, reason="Not Found") + + func = MagicMock(side_effect=side_effect_func) + result, err = safe_api_call(func, description="test-404", max_retries=3) + self.assertIsNone(result) + self.assertIn("404", err) + self.assertEqual(call_count[0], 1) + + def test_retries_on_connection_error(self): + """safe_api_call retries on ConnectionError.""" + from azext_workload_orchestration.support.utils import safe_api_call + + call_count = [0] + + def side_effect_func(*args, **kwargs): + call_count[0] += 1 + if call_count[0] <= 1: + raise ConnectionError("refused") + return "recovered" + + func = MagicMock(side_effect=side_effect_func) + result, err = safe_api_call(func, description="conn-err", max_retries=2, timeout_seconds=5) + self.assertEqual(result, "recovered") + self.assertIsNone(err) + self.assertEqual(call_count[0], 2) + + def test_retries_on_timeout_error(self): + """safe_api_call retries on TimeoutError.""" + from azext_workload_orchestration.support.utils import safe_api_call + + call_count = [0] + + def side_effect_func(*args, **kwargs): + call_count[0] += 1 + if call_count[0] <= 1: + raise TimeoutError("timed out") + return "ok" + + func = MagicMock(side_effect=side_effect_func) + result, err = safe_api_call(func, description="timeout-err", max_retries=2, timeout_seconds=5) + self.assertEqual(result, "ok") + self.assertIsNone(err) + + def test_exhausted_retries_returns_error(self): + """safe_api_call returns error after exhausting retries.""" + from azext_workload_orchestration.support.utils import safe_api_call + + func = MagicMock(side_effect=ConnectionError("always fails")) + result, err = safe_api_call(func, description="always-fail", max_retries=2, timeout_seconds=5) + self.assertIsNone(result) + self.assertIn("always fails", err) + self.assertEqual(func.call_count, 3) # initial + 2 retries + + def test_no_retry_on_generic_exception(self): + 
"""safe_api_call does NOT retry on generic exceptions like ValueError.""" + from azext_workload_orchestration.support.utils import safe_api_call + + func = MagicMock(side_effect=ValueError("bad value")) + result, err = safe_api_call(func, description="val-err", max_retries=3) + self.assertIsNone(result) + self.assertIn("ValueError", err) + self.assertEqual(func.call_count, 1) + + def test_timeout_is_passed_to_api_call(self): + """safe_api_call passes _request_timeout to the underlying API call.""" + from azext_workload_orchestration.support.utils import safe_api_call + + func = MagicMock(return_value="ok") + result, err = safe_api_call(func, description="timeout-test", timeout_seconds=42) + self.assertEqual(result, "ok") + # Verify _request_timeout was injected + _, kwargs = func.call_args + self.assertEqual(kwargs.get("_request_timeout"), 42) + + def test_existing_request_timeout_not_overwritten(self): + """safe_api_call doesn't overwrite an existing _request_timeout.""" + from azext_workload_orchestration.support.utils import safe_api_call + + func = MagicMock(return_value="ok") + result, err = safe_api_call( + func, description="timeout-existing", + _request_timeout=99, timeout_seconds=42 + ) + self.assertEqual(result, "ok") + _, kwargs = func.call_args + self.assertEqual(kwargs.get("_request_timeout"), 99) + + def test_max_retries_zero_means_no_retry(self): + """max_retries=0 means try once, no retries.""" + from azext_workload_orchestration.support.utils import safe_api_call + + func = MagicMock(side_effect=ConnectionError("fail")) + result, err = safe_api_call(func, description="no-retry", max_retries=0) + self.assertIsNone(result) + self.assertEqual(func.call_count, 1) + + +# =========================================================================== +# Tests for namespace validation +# =========================================================================== + + +class TestValidateNamespaces(unittest.TestCase): + """Test pre-flight namespace 
validation.""" + + def _make_ns(self, name, phase="Active"): + ns = MagicMock() + ns.metadata.name = name + ns.status.phase = phase + return ns + + def test_all_valid(self): + from azext_workload_orchestration.support.collectors import validate_namespaces + + clients = {"core_v1": MagicMock()} + clients["core_v1"].read_namespace = MagicMock( + side_effect=lambda ns, **kw: self._make_ns(ns) + ) + + valid, skipped = validate_namespaces(clients, ["kube-system", "default"]) + self.assertEqual(valid, ["kube-system", "default"]) + self.assertEqual(skipped, []) + + def test_nonexistent_namespace_skipped(self): + from azext_workload_orchestration.support.collectors import validate_namespaces + from kubernetes.client.exceptions import ApiException + + def read_ns(ns, **kwargs): + if ns == "missing-ns": + raise ApiException(status=404, reason="Not Found") + return self._make_ns(ns) + + clients = {"core_v1": MagicMock()} + clients["core_v1"].read_namespace = MagicMock(side_effect=read_ns) + + valid, skipped = validate_namespaces(clients, ["kube-system", "missing-ns", "default"]) + self.assertEqual(valid, ["kube-system", "default"]) + self.assertEqual(len(skipped), 1) + self.assertEqual(skipped[0][0], "missing-ns") + + def test_terminating_namespace_skipped(self): + from azext_workload_orchestration.support.collectors import validate_namespaces + + def read_ns(ns, **kwargs): + if ns == "dying-ns": + return self._make_ns(ns, phase="Terminating") + return self._make_ns(ns) + + clients = {"core_v1": MagicMock()} + clients["core_v1"].read_namespace = MagicMock(side_effect=read_ns) + + valid, skipped = validate_namespaces(clients, ["kube-system", "dying-ns"]) + self.assertEqual(valid, ["kube-system"]) + self.assertEqual(len(skipped), 1) + self.assertIn("terminating", skipped[0][1]) + + def test_all_namespaces_invalid(self): + from azext_workload_orchestration.support.collectors import validate_namespaces + from kubernetes.client.exceptions import ApiException + + clients = 
{"core_v1": MagicMock()} + clients["core_v1"].read_namespace = MagicMock( + side_effect=ApiException(status=404, reason="Not Found") + ) + + valid, skipped = validate_namespaces(clients, ["ns1", "ns2"]) + self.assertEqual(valid, []) + self.assertEqual(len(skipped), 2) + + def test_empty_namespace_list(self): + from azext_workload_orchestration.support.collectors import validate_namespaces + + clients = {"core_v1": MagicMock()} + valid, skipped = validate_namespaces(clients, []) + self.assertEqual(valid, []) + self.assertEqual(skipped, []) + + def test_rbac_denied_namespace(self): + from azext_workload_orchestration.support.collectors import validate_namespaces + from kubernetes.client.exceptions import ApiException + + clients = {"core_v1": MagicMock()} + clients["core_v1"].read_namespace = MagicMock( + side_effect=ApiException(status=403, reason="Forbidden") + ) + + valid, skipped = validate_namespaces(clients, ["secret-ns"]) + self.assertEqual(valid, []) + self.assertEqual(len(skipped), 1) + self.assertIn("403", skipped[0][1]) + + +# =========================================================================== +# Tests for new resource collectors (ReplicaSets, Jobs, etc.) 
+# =========================================================================== + + +class TestCollectReplicaSets(unittest.TestCase): + """Test ReplicaSet collection in collect_namespace_resources.""" + + def test_replicasets_collected(self): + from azext_workload_orchestration.support.collectors import collect_namespace_resources + + # Build mock replicaset + rs = MagicMock() + rs.metadata.name = "nginx-abc123" + rs.spec.replicas = 3 + rs.status.ready_replicas = 3 + rs.status.available_replicas = 3 + owner = MagicMock() + owner.kind = "Deployment" + owner.name = "nginx" + rs.metadata.owner_references = [owner] + + rs_list = MagicMock() + rs_list.items = [rs] + + clients = _make_clients() + clients["apps_v1"].list_namespaced_replica_set = MagicMock(return_value=rs_list) + + with tempfile.TemporaryDirectory() as tmpdir: + os.makedirs(os.path.join(tmpdir, "resources"), exist_ok=True) + result = collect_namespace_resources(clients, tmpdir, "default") + + self.assertIn("replicasets", result) + self.assertEqual(len(result["replicasets"]), 1) + self.assertEqual(result["replicasets"][0]["name"], "nginx-abc123") + self.assertEqual(result["replicasets"][0]["owner"]["kind"], "Deployment") + + +class TestCollectJobs(unittest.TestCase): + """Test Job and CronJob collection.""" + + @patch("azext_workload_orchestration.support.collectors.safe_api_call") + def test_jobs_collected(self, mock_safe_call): + from azext_workload_orchestration.support.collectors import collect_namespace_resources + + # Build mock job + job = MagicMock() + job.metadata.name = "backup-job" + job.status.active = 0 + job.status.succeeded = 1 + job.status.failed = 0 + job.spec.completions = 1 + job.status.start_time = "2026-01-01T00:00:00Z" + job.status.completion_time = "2026-01-01T00:05:00Z" + + job_list = MagicMock() + job_list.items = [job] + + # Build mock empty responses for standard resources + empty_list = MagicMock() + empty_list.items = [] + + def mock_safe(func, *args, **kwargs): + desc = 
kwargs.get("description", "") + if "jobs" in desc: + return job_list, None + return empty_list, None + + mock_safe_call.side_effect = mock_safe + + clients = _make_clients() + + with tempfile.TemporaryDirectory() as tmpdir: + os.makedirs(os.path.join(tmpdir, "resources"), exist_ok=True) + # Since safe_api_call is mocked, batch API calls go through mock + result = collect_namespace_resources(clients, tmpdir, "default") + + # Jobs may or may not be collected depending on batch API availability + # The test verifies no crash occurs + + +class TestCollectIngresses(unittest.TestCase): + """Test Ingress collection.""" + + def test_no_crash_on_missing_networking_api(self): + """Ingress collection handles missing networking API gracefully.""" + from azext_workload_orchestration.support.collectors import collect_namespace_resources + + clients = _make_clients() + + with tempfile.TemporaryDirectory() as tmpdir: + os.makedirs(os.path.join(tmpdir, "resources"), exist_ok=True) + # Should not crash even if networking API isn't available + result = collect_namespace_resources(clients, tmpdir, "default") + + self.assertIsInstance(result, dict) + + +class TestCollectServiceAccounts(unittest.TestCase): + """Test ServiceAccount collection.""" + + def test_service_accounts_collected(self): + from azext_workload_orchestration.support.collectors import collect_namespace_resources + + sa = MagicMock() + sa.metadata.name = "default" + sa.secrets = [MagicMock()] + ips = MagicMock() + ips.name = "registry-secret" + sa.image_pull_secrets = [ips] + + sa_list = MagicMock() + sa_list.items = [sa] + + clients = _make_clients() + clients["core_v1"].list_namespaced_service_account = MagicMock(return_value=sa_list) + + with tempfile.TemporaryDirectory() as tmpdir: + os.makedirs(os.path.join(tmpdir, "resources"), exist_ok=True) + result = collect_namespace_resources(clients, tmpdir, "default") + + self.assertIn("service_accounts", result) + self.assertEqual(len(result["service_accounts"]), 1) + 
self.assertEqual(result["service_accounts"][0]["name"], "default") + self.assertEqual(result["service_accounts"][0]["image_pull_secrets"], ["registry-secret"]) + + +class TestGetOwnerRef(unittest.TestCase): + """Test _get_owner_ref helper.""" + + def test_with_owner(self): + from azext_workload_orchestration.support.collectors import _get_owner_ref + + resource = MagicMock() + owner = MagicMock() + owner.kind = "Deployment" + owner.name = "nginx" + resource.metadata.owner_references = [owner] + + result = _get_owner_ref(resource) + self.assertEqual(result, {"kind": "Deployment", "name": "nginx"}) + + def test_without_owner(self): + from azext_workload_orchestration.support.collectors import _get_owner_ref + + resource = MagicMock() + resource.metadata.owner_references = [] + + result = _get_owner_ref(resource) + self.assertIsNone(result) + + def test_none_owner_refs(self): + from azext_workload_orchestration.support.collectors import _get_owner_ref + + resource = MagicMock() + resource.metadata.owner_references = None + + result = _get_owner_ref(resource) + self.assertIsNone(result) + + +# =========================================================================== +# Tests for health summary +# =========================================================================== + + +class TestHealthSummary(unittest.TestCase): + """Test _compute_health_summary.""" + + def test_all_pass(self): + from azext_workload_orchestration.support.bundle import _compute_health_summary + + checks = [ + {"status": "PASS", "check_name": "c1"}, + {"status": "PASS", "check_name": "c2"}, + {"status": "PASS", "check_name": "c3"}, + ] + result = _compute_health_summary(checks, []) + self.assertEqual(result["checks_passed"], 3) + self.assertEqual(result["checks_failed"], 0) + self.assertEqual(result["checks_warned"], 0) + self.assertEqual(result["checks_total"], 3) + + def test_mixed_statuses(self): + from azext_workload_orchestration.support.bundle import _compute_health_summary + + checks = [ 
+ {"status": "PASS", "check_name": "c1"}, + {"status": "WARN", "check_name": "c2"}, + {"status": "FAIL", "check_name": "c3"}, + ] + result = _compute_health_summary(checks, []) + self.assertEqual(result["checks_passed"], 1) + self.assertEqual(result["checks_failed"], 1) + self.assertEqual(result["checks_warned"], 1) + + def test_no_checks(self): + from azext_workload_orchestration.support.bundle import _compute_health_summary + + result = _compute_health_summary([], []) + self.assertEqual(result["checks_total"], 0) + + def test_collection_errors_counted(self): + from azext_workload_orchestration.support.bundle import _compute_health_summary + + checks = [{"status": "PASS", "check_name": "c1"}] + result = _compute_health_summary(checks, ["err1", "err2"]) + self.assertEqual(result["collection_errors"], 2) + + +# =========================================================================== +# Tests for new consts +# =========================================================================== + + +class TestNewConstants(unittest.TestCase): + """Verify new constants are properly defined.""" + + def test_api_timeout_constant(self): + from azext_workload_orchestration.support.consts import DEFAULT_API_TIMEOUT_SECONDS + self.assertEqual(DEFAULT_API_TIMEOUT_SECONDS, 30) + + def test_log_timeout_constant(self): + from azext_workload_orchestration.support.consts import DEFAULT_LOG_TIMEOUT_SECONDS + self.assertEqual(DEFAULT_LOG_TIMEOUT_SECONDS, 60) + + def test_retry_constants(self): + from azext_workload_orchestration.support.consts import ( + DEFAULT_MAX_RETRIES, + DEFAULT_RETRY_BACKOFF_BASE, + ) + self.assertEqual(DEFAULT_MAX_RETRIES, 3) + self.assertEqual(DEFAULT_RETRY_BACKOFF_BASE, 1.0) + + +# =========================================================================== +# Helper to build mock clients +# =========================================================================== + + +def _make_clients(): + """Create a standard mock clients dict for tests.""" + empty_list = 
MagicMock() + empty_list.items = [] + + core = MagicMock() + core.list_namespaced_pod = MagicMock(return_value=empty_list) + core.list_namespaced_service = MagicMock(return_value=empty_list) + core.list_namespaced_config_map = MagicMock(return_value=empty_list) + core.list_namespaced_event = MagicMock(return_value=empty_list) + core.list_namespaced_service_account = MagicMock(return_value=empty_list) + + apps = MagicMock() + apps.list_namespaced_deployment = MagicMock(return_value=empty_list) + apps.list_namespaced_daemon_set = MagicMock(return_value=empty_list) + apps.list_namespaced_stateful_set = MagicMock(return_value=empty_list) + apps.list_namespaced_replica_set = MagicMock(return_value=empty_list) + + return { + "core_v1": core, + "apps_v1": apps, + "custom_objects": MagicMock(), + "storage_v1": MagicMock(), + "admissionregistration_v1": MagicMock(), + "apis": MagicMock(), + "version": MagicMock(), + } + + +if __name__ == "__main__": + unittest.main() diff --git a/src/workload-orchestration/conftest.py b/src/workload-orchestration/conftest.py new file mode 100644 index 00000000000..f4eaeb15bdf --- /dev/null +++ b/src/workload-orchestration/conftest.py @@ -0,0 +1,67 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +"""Root conftest - install mock azure.cli modules before package discovery.""" + +import logging +import sys +import types + +# Create mock azure.cli modules so the extension __init__.py can be imported +# without a full azure-cli installation. 
# Build a skeleton of the azure.cli / knack module tree out of plain
# ModuleType objects. An empty __path__ marks each one as a package so the
# import machinery will resolve dotted submodule imports against it.
_azure = types.ModuleType("azure")
_azure.__path__ = []
_azure_cli = types.ModuleType("azure.cli")
_azure_cli.__path__ = []
_azure_cli_core = types.ModuleType("azure.cli.core")
_azure_cli_core.__path__ = []
# Minimal AzCommandsLoader stand-in: no-op __init__ plus the attributes the
# extension's command loader touches.
_azure_cli_core.AzCommandsLoader = type("AzCommandsLoader", (), {
    "__init__": lambda self, *a, **kw: None,
    "command_table": {},
    "load_command_table": lambda self, args: {},
    "load_arguments": lambda self, command: None,
})
_azure_cli_commands = types.ModuleType("azure.cli.core.commands")
_azure_cli_commands.__path__ = []
_azure_cli_commands.CliCommandType = type("CliCommandType", (), {"__init__": lambda self, **kw: None})
_azure_cli_aaz = types.ModuleType("azure.cli.core.aaz")
_azure_cli_aaz.__path__ = []
_azure_cli_aaz.load_aaz_command_table = lambda **kw: None
# Mock AAZ decorators and base classes used by __cmd_group.py files
# (the decorators return the class unchanged).
_azure_cli_aaz.register_command_group = lambda *a, **kw: (lambda cls: cls)
_azure_cli_aaz.register_command = lambda *a, **kw: (lambda cls: cls)
_azure_cli_aaz.AAZCommandGroup = type("AAZCommandGroup", (), {})
_azure_cli_aaz.AAZCommand = type("AAZCommand", (), {
    "__init__": lambda self, *a, **kw: None,
})
# Expose as module globals for `from azure.cli.core.aaz import *`
_azure_cli_aaz.__all__ = [
    "register_command_group", "register_command", "AAZCommandGroup",
    "AAZCommand", "load_aaz_command_table",
]
_azure_cli_params = types.ModuleType("azure.cli.core.commands.parameters")
# Identity stub: tests only need get_enum_type to be callable.
_azure_cli_params.get_enum_type = lambda x: x
_azure_cli_azclierror = types.ModuleType("azure.cli.core.azclierror")
# NOTE(review): CLIError is aliased to plain Exception here — presumably the
# extension only raises/catches it; confirm no test inspects CLIError-specific
# attributes.
_azure_cli_azclierror.CLIError = Exception
_knack = types.ModuleType("knack")
_knack.__path__ = []
_knack_log = types.ModuleType("knack.log")
# Route knack's get_logger straight to the stdlib logging module.
_knack_log.get_logger = logging.getLogger
_knack_help = types.ModuleType("knack.help_files")
_knack_help.helps = {}

for mod_name, mod in [
    ("azure", _azure),
    ("azure.cli", _azure_cli),
    ("azure.cli.core", _azure_cli_core),
    ("azure.cli.core.commands", _azure_cli_commands),
    ("azure.cli.core.aaz", _azure_cli_aaz),
    ("azure.cli.core.commands.parameters", _azure_cli_params),
    ("azure.cli.core.azclierror", _azure_cli_azclierror),
    ("knack", _knack),
    ("knack.log", _knack_log),
    ("knack.help_files", _knack_help),
]:
    # Install mocks — use setdefault to not break if real modules exist
    sys.modules.setdefault(mod_name, mod)
diff --git a/src/workload-orchestration/setup.py b/src/workload-orchestration/setup.py
index 32448955dc4..7a0cfc1110b 100644
--- a/src/workload-orchestration/setup.py
+++ b/src/workload-orchestration/setup.py
@@ -10,7 +10,7 @@
 # HISTORY.rst entry.
 
 
-VERSION = '5.1.0'
+VERSION = '5.1.1'
 
 # The full list of classifiers is available at
 # https://pypi.python.org/pypi?%3Aaction=list_classifiers
@@ -26,7 +26,9 @@
     'License :: OSI Approved :: MIT License',
 ]
 
-DEPENDENCIES = []
+DEPENDENCIES = [
+    'kubernetes>=24.2.0',
+]
 
 with open('README.md', 'r', encoding='utf-8') as f:
     README = f.read()