Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from __future__ import annotations

import os
from typing import Any

import voluptuous as vol
from aiofiles.ospath import exists

from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT
Expand Down Expand Up @@ -41,7 +41,7 @@ async def _validate_service_data(data: dict[str, Any]) -> None:
translation_key="command_or_input",
)

if has_key_file and not await exists(data[CONF_KEY_FILE]):
if has_key_file and not os.path.exists(data[CONF_KEY_FILE]):
raise ServiceValidationError(
"Could not find key file.",
translation_domain=DOMAIN,
Expand Down
9 changes: 3 additions & 6 deletions coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from pathlib import Path
from typing import Any

from aiofiles import open as aioopen
from aiofiles.ospath import exists
from asyncssh import HostKeyNotVerifiable, PermissionDenied, connect, read_known_hosts

from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT
Expand Down Expand Up @@ -64,9 +62,8 @@ async def async_execute(self, data: dict[str, Any]) -> dict[str, Any]:
timeout = data.get(CONF_TIMEOUT, CONST_DEFAULT_TIMEOUT)

if input_data:
if await exists(input_data):
async with aioopen(input_data, 'r') as sf:
input_data = await sf.read()
if await self.hass.async_add_executor_job(Path(input_data).exists):
input_data = await self.hass.async_add_executor_job(Path(input_data).read_text)

conn_kwargs = {
CONF_HOST: host,
Expand Down Expand Up @@ -131,6 +128,6 @@ async def _resolve_known_hosts(self, check_known_hosts: bool, known_hosts: str |
return None
if not known_hosts:
known_hosts = str(Path("~", ".ssh", "known_hosts").expanduser())
if await exists(known_hosts):
if await self.hass.async_add_executor_job(Path(known_hosts).exists):
return await self.hass.async_add_executor_job(read_known_hosts, known_hosts)
return known_hosts
2 changes: 1 addition & 1 deletion manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"iot_class": "calculated",
"issue_tracker": "https://github.com/gensyn/ssh_command/issues",
"quality_scale": "bronze",
"requirements": ["asyncssh==2.22.0", "aiofiles==25.1.0"],
"requirements": ["asyncssh==2.22.0"],
"ssdp": [],
"version": "0.0.0",
"zeroconf": []
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
aiofiles==25.1.0
asyncssh==2.22.0
41 changes: 16 additions & 25 deletions test/test_async_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ async def test_success(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)):
with patch("ssh_command.coordinator.exists", return_value=False):
result = await self.handler(service_call)
result = await self.handler(service_call)

self.assertEqual(result[CONF_OUTPUT], "hello\n")
self.assertEqual(result[CONF_ERROR], "")
Expand All @@ -96,29 +95,26 @@ async def test_host_key_not_verifiable(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(HostKeyNotVerifiable("test"))):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)

self.assertEqual(ctx.exception.translation_key, "host_key_not_verifiable")

async def test_permission_denied(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(PermissionDenied("auth failed"))):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)

self.assertEqual(ctx.exception.translation_key, "login_failed")

async def test_timeout(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(TimeoutError())):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)

self.assertEqual(ctx.exception.translation_key, "connection_timed_out")

Expand All @@ -127,9 +123,8 @@ async def test_name_resolution_failure(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)
with self.assertRaises(ServiceValidationError) as ctx:
await self.handler(service_call)

self.assertEqual(ctx.exception.translation_key, "host_not_reachable")

Expand All @@ -138,9 +133,8 @@ async def test_other_oserror_is_reraised(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(OSError):
await self.handler(service_call)
with self.assertRaises(OSError):
await self.handler(service_call)

async def test_input_from_file(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tf:
Expand All @@ -153,8 +147,7 @@ async def test_input_from_file(self):
service_call = self._make_service_call(data)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)):
with patch("ssh_command.coordinator.exists", return_value=True):
await self.handler(service_call)
await self.handler(service_call)

call_kwargs = mock_conn.run.call_args[1]
self.assertEqual(call_kwargs["input"], "file content\n")
Expand All @@ -167,8 +160,7 @@ async def test_input_string_not_file(self):
service_call = self._make_service_call(data)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)):
with patch("ssh_command.coordinator.exists", return_value=False):
await self.handler(service_call)
await self.handler(service_call)

call_kwargs = mock_conn.run.call_args[1]
self.assertEqual(call_kwargs["input"], "inline input")
Expand All @@ -178,8 +170,7 @@ async def test_check_known_hosts_false(self):
service_call = self._make_service_call(SERVICE_DATA_BASE)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("ssh_command.coordinator.exists", return_value=False):
await self.handler(service_call)
await self.handler(service_call)

call_kwargs = mock_connect.call_args[1]
self.assertIsNone(call_kwargs["known_hosts"])
Expand All @@ -191,7 +182,7 @@ async def test_known_hosts_file_exists(self):
service_call = self._make_service_call(data)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("ssh_command.coordinator.exists", return_value=True):
with patch("pathlib.Path.exists", return_value=True):
with patch("ssh_command.coordinator.read_known_hosts", return_value=mock_known_hosts) as mock_rkh:
await self.handler(service_call)

Expand All @@ -205,7 +196,7 @@ async def test_check_known_hosts_default_path_missing(self):
service_call = self._make_service_call(data)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)) as mock_connect:
with patch("ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
await self.handler(service_call)

call_kwargs = mock_connect.call_args[1]
Expand Down
34 changes: 14 additions & 20 deletions test/test_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,54 +72,48 @@ async def test_async_execute_success(self):
mock_conn = self._make_mock_conn(stdout="hello\n", stderr="", exit_status=0)

with patch("ssh_command.coordinator.connect", return_value=_MockConnect(mock_conn)):
with patch("ssh_command.coordinator.exists", return_value=False):
result = await self.coordinator.async_execute(EXECUTE_DATA_BASE)
result = await self.coordinator.async_execute(EXECUTE_DATA_BASE)

self.assertEqual(result[CONF_OUTPUT], "hello\n")
self.assertEqual(result[CONF_ERROR], "")
self.assertEqual(result[CONF_EXIT_STATUS], 0)

async def test_async_execute_host_key_not_verifiable(self):
with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(HostKeyNotVerifiable("test"))):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.coordinator.async_execute(EXECUTE_DATA_BASE)
with self.assertRaises(ServiceValidationError) as ctx:
await self.coordinator.async_execute(EXECUTE_DATA_BASE)

self.assertEqual(ctx.exception.translation_key, "host_key_not_verifiable")

async def test_async_execute_permission_denied(self):
with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(PermissionDenied("auth failed"))):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.coordinator.async_execute(EXECUTE_DATA_BASE)
with self.assertRaises(ServiceValidationError) as ctx:
await self.coordinator.async_execute(EXECUTE_DATA_BASE)

self.assertEqual(ctx.exception.translation_key, "login_failed")

async def test_async_execute_timeout(self):
with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(TimeoutError())):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.coordinator.async_execute(EXECUTE_DATA_BASE)
with self.assertRaises(ServiceValidationError) as ctx:
await self.coordinator.async_execute(EXECUTE_DATA_BASE)

self.assertEqual(ctx.exception.translation_key, "connection_timed_out")

async def test_async_execute_name_resolution_failure(self):
err = socket.gaierror("Name or service not known")

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await self.coordinator.async_execute(EXECUTE_DATA_BASE)
with self.assertRaises(ServiceValidationError) as ctx:
await self.coordinator.async_execute(EXECUTE_DATA_BASE)

self.assertEqual(ctx.exception.translation_key, "host_not_reachable")

async def test_async_execute_other_oserror_reraised(self):
err = OSError("something else")

with patch("ssh_command.coordinator.connect", return_value=_MockConnectRaises(err)):
with patch("ssh_command.coordinator.exists", return_value=False):
with self.assertRaises(OSError):
await self.coordinator.async_execute(EXECUTE_DATA_BASE)
with self.assertRaises(OSError):
await self.coordinator.async_execute(EXECUTE_DATA_BASE)

async def test_resolve_known_hosts_check_disabled(self):
result = await self.coordinator._resolve_known_hosts(False, None)
Expand All @@ -128,21 +122,21 @@ async def test_resolve_known_hosts_check_disabled(self):
async def test_resolve_known_hosts_file_exists(self):
mock_known_hosts = MagicMock()

with patch("ssh_command.coordinator.exists", return_value=True):
with patch("pathlib.Path.exists", return_value=True):
with patch("ssh_command.coordinator.read_known_hosts", return_value=mock_known_hosts) as mock_rkh:
result = await self.coordinator._resolve_known_hosts(True, "/home/user/.ssh/known_hosts")

mock_rkh.assert_called_once_with("/home/user/.ssh/known_hosts")
self.assertIs(result, mock_known_hosts)

async def test_resolve_known_hosts_file_missing(self):
with patch("ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
result = await self.coordinator._resolve_known_hosts(True, "/nonexistent/known_hosts")

self.assertEqual(result, "/nonexistent/known_hosts")

async def test_resolve_known_hosts_default_path(self):
with patch("ssh_command.coordinator.exists", return_value=False):
with patch("pathlib.Path.exists", return_value=False):
result = await self.coordinator._resolve_known_hosts(True, None)

self.assertIsInstance(result, str)
Expand Down
19 changes: 9 additions & 10 deletions test/test_validate_service_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,26 @@ async def test_no_command_no_input_raises(self):
self.assertEqual(ctx.exception.translation_key, "command_or_input")

async def test_key_file_not_found_raises(self):
with patch("ssh_command.exists", return_value=False):
with patch("os.path.exists", return_value=False):
with self.assertRaises(ServiceValidationError) as ctx:
await _validate_service_data({"key_file": "/nonexistent/key", "command": "ls"})
self.assertEqual(ctx.exception.translation_key, "key_file_not_found")

async def test_known_hosts_with_check_disabled_raises(self):
with patch("ssh_command.exists", return_value=True):
with self.assertRaises(ServiceValidationError) as ctx:
await _validate_service_data({
"password": "secret",
"command": "ls",
"known_hosts": "/etc/ssh/known_hosts",
"check_known_hosts": False,
})
with self.assertRaises(ServiceValidationError) as ctx:
await _validate_service_data({
"password": "secret",
"command": "ls",
"known_hosts": "/etc/ssh/known_hosts",
"check_known_hosts": False,
})
self.assertEqual(ctx.exception.translation_key, "known_hosts_with_check_disabled")

async def test_valid_password_and_command(self):
await _validate_service_data({"password": "secret", "command": "echo hi"})

async def test_valid_key_file_and_input(self):
with patch("ssh_command.exists", return_value=True):
with patch("os.path.exists", return_value=True):
await _validate_service_data({"key_file": "/home/user/.ssh/id_rsa", "input": "some text"})

async def test_valid_known_hosts_with_check_enabled(self):
Expand Down
Loading