diff --git a/__init__.py b/__init__.py index 61d9e4a..e33cde8 100644 --- a/__init__.py +++ b/__init__.py @@ -2,10 +2,10 @@ from __future__ import annotations +import os from typing import Any import voluptuous as vol -from aiofiles.ospath import exists from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT @@ -41,7 +41,7 @@ async def _validate_service_data(data: dict[str, Any]) -> None: translation_key="command_or_input", ) - if has_key_file and not await exists(data[CONF_KEY_FILE]): + if has_key_file and not os.path.exists(data[CONF_KEY_FILE]): raise ServiceValidationError( "Could not find key file.", translation_domain=DOMAIN, diff --git a/coordinator.py b/coordinator.py index 5992914..8c82258 100644 --- a/coordinator.py +++ b/coordinator.py @@ -15,8 +15,6 @@ from pathlib import Path from typing import Any -from aiofiles import open as aioopen -from aiofiles.ospath import exists from asyncssh import HostKeyNotVerifiable, PermissionDenied, connect, read_known_hosts from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT @@ -64,9 +62,8 @@ async def async_execute(self, data: dict[str, Any]) -> dict[str, Any]: timeout = data.get(CONF_TIMEOUT, CONST_DEFAULT_TIMEOUT) if input_data: - if await exists(input_data): - async with aioopen(input_data, 'r') as sf: - input_data = await sf.read() + if await self.hass.async_add_executor_job(Path(input_data).exists): + input_data = await self.hass.async_add_executor_job(Path(input_data).read_text) conn_kwargs = { CONF_HOST: host, @@ -131,6 +128,6 @@ async def _resolve_known_hosts(self, check_known_hosts: bool, known_hosts: str | return None if not known_hosts: known_hosts = str(Path("~", ".ssh", "known_hosts").expanduser()) - if await exists(known_hosts): + if await self.hass.async_add_executor_job(Path(known_hosts).exists): return await self.hass.async_add_executor_job(read_known_hosts, known_hosts) return known_hosts diff --git a/manifest.json b/manifest.json index d58238b..994f94e 100644 --- a/manifest.json +++ b/manifest.json @@ -11,7 +11,7 @@ "iot_class": "calculated", "issue_tracker": "https://github.com/gensyn/ssh_command/issues", "quality_scale": "bronze", - "requirements": ["asyncssh==2.22.0", "aiofiles==25.1.0"], + "requirements": ["asyncssh==2.22.0"], "ssdp": [], "version": "0.0.0", "zeroconf": [] diff --git a/requirements.txt b/requirements.txt index b7185b1..7e8976a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ -aiofiles==25.1.0 asyncssh==2.22.0 \ No newline at end of file diff --git a/test/test_async_execute.py b/test/test_async_execute.py index bc280c6..caa880a 100644 --- a/test/test_async_execute.py +++ b/test/test_async_execute.py @@ -85,8 +85,7 @@ async def test_success(self): service_call = self._make_service_call(SERVICE_DATA_BASE) with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)): - with patch("ssh_command.coordinator.exists", return_value=False): - result = await self.handler(service_call) + result = await self.handler(service_call) self.assertEqual(result[CONF_OUTPUT], "hello\n") self.assertEqual(result[CONF_ERROR], "") @@ -96,9 +95,8 @@ async def test_host_key_not_verifiable(self): service_call = self._make_service_call(SERVICE_DATA_BASE) with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(HostKeyNotVerifiable("test"))): - with patch("ssh_command.coordinator.exists", return_value=False): - with self.assertRaises(ServiceValidationError) as ctx: - await self.handler(service_call) + with self.assertRaises(ServiceValidationError) as ctx: + await self.handler(service_call) self.assertEqual(ctx.exception.translation_key, "host_key_not_verifiable") @@ -106,9 +104,8 @@ async def test_permission_denied(self): service_call = self._make_service_call(SERVICE_DATA_BASE) with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(PermissionDenied("auth failed"))): - with patch("ssh_command.coordinator.exists", return_value=False): - with self.assertRaises(ServiceValidationError) as ctx: - await self.handler(service_call) + with self.assertRaises(ServiceValidationError) as ctx: + await self.handler(service_call) self.assertEqual(ctx.exception.translation_key, "login_failed") @@ -116,9 +113,8 @@ async def test_timeout(self): service_call = self._make_service_call(SERVICE_DATA_BASE) with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(TimeoutError())): - with patch("ssh_command.coordinator.exists", return_value=False): - with self.assertRaises(ServiceValidationError) as ctx: - await self.handler(service_call) + with self.assertRaises(ServiceValidationError) as ctx: + await self.handler(service_call) self.assertEqual(ctx.exception.translation_key, "connection_timed_out") @@ -127,9 +123,8 @@ async def test_name_resolution_failure(self): service_call = self._make_service_call(SERVICE_DATA_BASE) with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)): - with patch("ssh_command.coordinator.exists", return_value=False): - with self.assertRaises(ServiceValidationError) as ctx: - await self.handler(service_call) + with self.assertRaises(ServiceValidationError) as ctx: + await self.handler(service_call) self.assertEqual(ctx.exception.translation_key, "host_not_reachable") @@ -138,9 +133,8 @@ async def test_other_oserror_is_reraised(self): service_call = self._make_service_call(SERVICE_DATA_BASE) with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)): - with patch("ssh_command.coordinator.exists", return_value=False): - with self.assertRaises(OSError): - await self.handler(service_call) + with self.assertRaises(OSError): + await self.handler(service_call) async def test_input_from_file(self): with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tf: @@ -153,8 +147,7 @@ async def test_input_from_file(self): service_call = self._make_service_call(data) with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)): - with patch("ssh_command.coordinator.exists", return_value=True): - await self.handler(service_call) + await self.handler(service_call) call_kwargs = mock_conn.run.call_args[1] self.assertEqual(call_kwargs["input"], "file content\n") @@ -167,8 +160,7 @@ async def test_input_string_not_file(self): service_call = self._make_service_call(data) with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)): - with patch("ssh_command.coordinator.exists", return_value=False): - await self.handler(service_call) + await self.handler(service_call) call_kwargs = mock_conn.run.call_args[1] self.assertEqual(call_kwargs["input"], "inline input") @@ -178,8 +170,7 @@ async def test_check_known_hosts_false(self): service_call = self._make_service_call(SERVICE_DATA_BASE) with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect: - with patch("ssh_command.coordinator.exists", return_value=False): - await self.handler(service_call) + await self.handler(service_call) call_kwargs = mock_connect.call_args[1] self.assertIsNone(call_kwargs["known_hosts"]) @@ -191,7 +182,7 @@ async def test_known_hosts_file_exists(self): service_call = self._make_service_call(data) with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect: - with patch("ssh_command.coordinator.exists", return_value=True): + with patch("pathlib.Path.exists", return_value=True): with patch("ssh_command.coordinator.read_known_hosts", return_value=mock_known_hosts) as mock_rkh: await self.handler(service_call) @@ -205,7 +196,7 @@ async def test_check_known_hosts_default_path_missing(self): service_call = self._make_service_call(data) with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect: - with patch("ssh_command.coordinator.exists", return_value=False): + with patch("pathlib.Path.exists", return_value=False): await self.handler(service_call) call_kwargs = mock_connect.call_args[1] diff --git a/test/test_coordinator.py b/test/test_coordinator.py index b6d6cb8..d446416 100644 --- a/test/test_coordinator.py +++ b/test/test_coordinator.py @@ -72,8 +72,7 @@ async def test_async_execute_success(self): mock_conn = self._make_mock_conn(stdout="hello\n", stderr="", exit_status=0) with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)): - with patch("ssh_command.coordinator.exists", return_value=False): - result = await self.coordinator.async_execute(EXECUTE_DATA_BASE) + result = await self.coordinator.async_execute(EXECUTE_DATA_BASE) self.assertEqual(result[CONF_OUTPUT], "hello\n") self.assertEqual(result[CONF_ERROR], "") @@ -81,25 +80,22 @@ async def test_async_execute_success(self): async def test_async_execute_host_key_not_verifiable(self): with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(HostKeyNotVerifiable("test"))): - with patch("ssh_command.coordinator.exists", return_value=False): - with self.assertRaises(ServiceValidationError) as ctx: - await self.coordinator.async_execute(EXECUTE_DATA_BASE) + with self.assertRaises(ServiceValidationError) as ctx: + await self.coordinator.async_execute(EXECUTE_DATA_BASE) self.assertEqual(ctx.exception.translation_key, "host_key_not_verifiable") async def test_async_execute_permission_denied(self): with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(PermissionDenied("auth failed"))): - with patch("ssh_command.coordinator.exists", return_value=False): - with self.assertRaises(ServiceValidationError) as ctx: - await self.coordinator.async_execute(EXECUTE_DATA_BASE) + with self.assertRaises(ServiceValidationError) as ctx: + await self.coordinator.async_execute(EXECUTE_DATA_BASE) self.assertEqual(ctx.exception.translation_key, "login_failed") async def test_async_execute_timeout(self): with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(TimeoutError())): - with patch("ssh_command.coordinator.exists", return_value=False): - with self.assertRaises(ServiceValidationError) as ctx: - await self.coordinator.async_execute(EXECUTE_DATA_BASE) + with self.assertRaises(ServiceValidationError) as ctx: + await self.coordinator.async_execute(EXECUTE_DATA_BASE) self.assertEqual(ctx.exception.translation_key, "connection_timed_out") @@ -107,9 +103,8 @@ async def test_async_execute_name_resolution_failure(self): err = socket.gaierror("Name or service not known") with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)): - with patch("ssh_command.coordinator.exists", return_value=False): - with self.assertRaises(ServiceValidationError) as ctx: - await self.coordinator.async_execute(EXECUTE_DATA_BASE) + with self.assertRaises(ServiceValidationError) as ctx: + await self.coordinator.async_execute(EXECUTE_DATA_BASE) self.assertEqual(ctx.exception.translation_key, "host_not_reachable") @@ -117,9 +112,8 @@ async def test_async_execute_other_oserror_reraised(self): err = OSError("something else") with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)): - with patch("ssh_command.coordinator.exists", return_value=False): - with self.assertRaises(OSError): - await self.coordinator.async_execute(EXECUTE_DATA_BASE) + with self.assertRaises(OSError): + await self.coordinator.async_execute(EXECUTE_DATA_BASE) async def test_resolve_known_hosts_check_disabled(self): result = await self.coordinator._resolve_known_hosts(False, None) @@ -128,7 +122,7 @@ async def test_resolve_known_hosts_check_disabled(self): async def test_resolve_known_hosts_file_exists(self): mock_known_hosts = MagicMock() - with patch("ssh_command.coordinator.exists", return_value=True): + with patch("pathlib.Path.exists", return_value=True): with patch("ssh_command.coordinator.read_known_hosts", return_value=mock_known_hosts) as mock_rkh: result = await self.coordinator._resolve_known_hosts(True, "/home/user/.ssh/known_hosts") @@ -136,13 +130,13 @@ async def test_resolve_known_hosts_file_exists(self): self.assertIs(result, mock_known_hosts) async def test_resolve_known_hosts_file_missing(self): - with patch("ssh_command.coordinator.exists", return_value=False): + with patch("pathlib.Path.exists", return_value=False): result = await self.coordinator._resolve_known_hosts(True, "/nonexistent/known_hosts") self.assertEqual(result, "/nonexistent/known_hosts") async def test_resolve_known_hosts_default_path(self): - with patch("ssh_command.coordinator.exists", return_value=False): + with patch("pathlib.Path.exists", return_value=False): result = await self.coordinator._resolve_known_hosts(True, None) self.assertIsInstance(result, str) diff --git a/test/test_validate_service_data.py b/test/test_validate_service_data.py index 1ea438d..d5665bb 100644 --- a/test/test_validate_service_data.py +++ b/test/test_validate_service_data.py @@ -27,27 +27,26 @@ async def test_no_command_no_input_raises(self): self.assertEqual(ctx.exception.translation_key, "command_or_input") async def test_key_file_not_found_raises(self): - with patch("ssh_command.exists", return_value=False): + with patch("os.path.exists", return_value=False): with self.assertRaises(ServiceValidationError) as ctx: await _validate_service_data({"key_file": "/nonexistent/key", "command": "ls"}) self.assertEqual(ctx.exception.translation_key, "key_file_not_found") async def test_known_hosts_with_check_disabled_raises(self): - with patch("ssh_command.exists", return_value=True): - with self.assertRaises(ServiceValidationError) as ctx: - await _validate_service_data({ - "password": "secret", - "command": "ls", - "known_hosts": "/etc/ssh/known_hosts", - "check_known_hosts": False, - }) + with self.assertRaises(ServiceValidationError) as ctx: + await _validate_service_data({ + "password": "secret", + "command": "ls", + "known_hosts": "/etc/ssh/known_hosts", + "check_known_hosts": False, + }) self.assertEqual(ctx.exception.translation_key, "known_hosts_with_check_disabled") async def test_valid_password_and_command(self): await _validate_service_data({"password": "secret", "command": "echo hi"}) async def test_valid_key_file_and_input(self): - with patch("ssh_command.exists", return_value=True): + with patch("os.path.exists", return_value=True): await _validate_service_data({"key_file": "/home/user/.ssh/id_rsa", "input": "some text"}) async def test_valid_known_hosts_with_check_enabled(self):