diff --git a/tests/components/sql/test_services.py b/tests/components/sql/test_services.py index 0ef2f144a01..1d8199a3db0 100644 --- a/tests/components/sql/test_services.py +++ b/tests/components/sql/test_services.py @@ -13,6 +13,7 @@ from voluptuous import MultipleInvalid from homeassistant.components.recorder import Recorder from homeassistant.components.sql.const import DOMAIN from homeassistant.components.sql.services import SERVICE_QUERY +from homeassistant.components.sql.util import generate_lambda_stmt from homeassistant.core import HomeAssistant from homeassistant.exceptions import ServiceValidationError from homeassistant.setup import async_setup_component @@ -86,6 +87,35 @@ async def test_query_service_external_db(hass: HomeAssistant, tmp_path: Path) -> } +async def test_query_service_rollback_on_error(hass: HomeAssistant) -> None: + """Test the query service.""" + await async_setup_component(hass, DOMAIN, {}) + await hass.async_block_till_done() + + with ( + patch( + "homeassistant.components.sql.services.generate_lambda_stmt", + return_value=generate_lambda_stmt("Faulty syntax create operational issue"), + ), + pytest.raises( + ServiceValidationError, match="An error occurred when executing the query" + ), + patch("sqlalchemy.orm.session.Session.rollback") as mock_session_rollback, + ): + await hass.services.async_call( + DOMAIN, + SERVICE_QUERY, + { + "query": "SELECT name, age FROM users ORDER BY age", + "db_url": "sqlite:///", + }, + blocking=True, + return_response=True, + ) + + mock_session_rollback.assert_called_once() + + async def test_query_service_data_conversion( hass: HomeAssistant, tmp_path: Path ) -> None: