From 6aaec5c1595dcaece511766dc30c494925aac8eb Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 24 Jan 2026 22:03:55 -0800 Subject: [PATCH] Add support for HTTP Streamable to MCP integration (#161547) Co-authored-by: Paulus Schoutsen Co-authored-by: Paulus Schoutsen Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- homeassistant/components/mcp/config_flow.py | 4 +- homeassistant/components/mcp/coordinator.py | 44 +++++++++-- homeassistant/helpers/httpx_client.py | 5 +- tests/components/mcp/conftest.py | 29 +++++++- tests/components/mcp/test_init.py | 81 ++++++++++++++++++++- tests/helpers/test_httpx_client.py | 20 +++++ 6 files changed, 168 insertions(+), 15 deletions(-) diff --git a/homeassistant/components/mcp/config_flow.py b/homeassistant/components/mcp/config_flow.py index a44b22baadc..8ade969bf9f 100644 --- a/homeassistant/components/mcp/config_flow.py +++ b/homeassistant/components/mcp/config_flow.py @@ -47,7 +47,7 @@ MCP_DISCOVERY_HEADERS = { "MCP-Protocol-Version": "2025-03-26", } -EXAMPLE_URL = "http://example/sse" +EXAMPLE_URL = "http://example/mcp" @dataclass @@ -122,7 +122,7 @@ async def validate_input( except vol.Invalid as error: raise InvalidUrl from error try: - async with mcp_client(url, token_manager=token_manager) as session: + async with mcp_client(hass, url, token_manager=token_manager) as session: response = await session.initialize() except httpx.TimeoutException as error: _LOGGER.info("Timeout connecting to MCP server: %s", error) diff --git a/homeassistant/components/mcp/coordinator.py b/homeassistant/components/mcp/coordinator.py index df4ec2c0f2f..6c3303c647d 100644 --- a/homeassistant/components/mcp/coordinator.py +++ b/homeassistant/components/mcp/coordinator.py @@ -9,6 +9,7 @@ import logging import httpx from mcp.client.session import ClientSession from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamable_http_client import voluptuous as vol from voluptuous_openapi import convert_to_voluptuous @@ -17,6 +18,7 @@ from homeassistant.const import CONF_URL from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError from homeassistant.helpers import llm +from homeassistant.helpers.httpx_client import create_async_httpx_client from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.util.json import JsonObjectType @@ -32,10 +34,11 @@ type TokenManager = Callable[[], Awaitable[str]] @asynccontextmanager async def mcp_client( + hass: HomeAssistant, url: str, token_manager: TokenManager | None = None, ) -> AsyncGenerator[ClientSession]: - """Create a server-sent event MCP client. + """Create an MCP client. This is an asynccontext manager that exists to wrap other async context managers so that the coordinator has a single object to manage. @@ -44,16 +47,39 @@ async def mcp_client( if token_manager is not None: token = await token_manager() headers["Authorization"] = f"Bearer {token}" + try: async with ( - sse_client(url=url, headers=headers) as streams, - ClientSession(*streams) as session, + streamable_http_client( + url=url, + http_client=create_async_httpx_client(hass, headers=headers), + ) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): await session.initialize() yield session - except ExceptionGroup as err: - _LOGGER.debug("Error creating MCP client: %s", err) - raise err.exceptions[0] from err + except ExceptionGroup as streamable_err: + main_error = streamable_err.exceptions[0] + # Method not Allowed likely means this is not a streamable HTTP server, + # but it may be an SSE server. This is part of the MCP Transport + # backwards compatibility specification. + if ( + isinstance(main_error, httpx.HTTPStatusError) + and main_error.response.status_code == 405 + ): + try: + async with ( + sse_client(url=url, headers=headers) as streams, + ClientSession(*streams) as session, + ): + await session.initialize() + yield session + except ExceptionGroup as sse_err: + _LOGGER.debug("Error creating SSE MCP client: %s", sse_err) + raise sse_err.exceptions[0] from sse_err + else: + _LOGGER.debug("Error creating MCP client: %s", streamable_err) + raise main_error from streamable_err class ModelContextProtocolTool(llm.Tool): @@ -83,7 +109,9 @@ class ModelContextProtocolTool(llm.Tool): """Call the tool.""" try: async with asyncio.timeout(TIMEOUT): - async with mcp_client(self.server_url, self.token_manager) as session: + async with mcp_client( + hass, self.server_url, self.token_manager + ) as session: result = await session.call_tool( tool_input.tool_name, tool_input.tool_args ) @@ -126,7 +154,7 @@ class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]): try: async with asyncio.timeout(TIMEOUT): async with mcp_client( - self.config_entry.data[CONF_URL], self.token_manager + self.hass, self.config_entry.data[CONF_URL], self.token_manager ) as session: result = await session.list_tools() except TimeoutError as error: diff --git a/homeassistant/helpers/httpx_client.py b/homeassistant/helpers/httpx_client.py index 690b2579d12..d253c3377aa 100644 --- a/homeassistant/helpers/httpx_client.py +++ b/homeassistant/helpers/httpx_client.py @@ -117,7 +117,10 @@ def create_async_httpx_client( kwargs.setdefault("http2", True) client = HassHttpXAsyncClient( verify=ssl_context, - headers={USER_AGENT: SERVER_SOFTWARE}, + headers={ + USER_AGENT: SERVER_SOFTWARE, + **kwargs.pop("headers", {}), + }, limits=DEFAULT_LIMITS, **kwargs, ) diff --git a/tests/components/mcp/conftest.py b/tests/components/mcp/conftest.py index c179936f7d6..48700ce007a 100644 --- a/tests/components/mcp/conftest.py +++ b/tests/components/mcp/conftest.py @@ -2,6 +2,7 @@ from collections.abc import Generator import datetime +from typing import Any from unittest.mock import AsyncMock, patch import pytest @@ -42,10 +43,34 @@ def mock_setup_entry() -> Generator[AsyncMock]: @pytest.fixture -def mock_mcp_client() -> Generator[AsyncMock]: +def mock_sse_client() -> Generator[AsyncMock]: + """Fixture to mock the MCP client.""" + with patch( + "homeassistant.components.mcp.coordinator.sse_client" + ) as mock_sse_client: + yield mock_sse_client + + +@pytest.fixture +def mock_http_streamable_client() -> Generator[AsyncMock]: + """Fixture to mock the MCP client.""" + with patch( + "homeassistant.components.mcp.coordinator.streamable_http_client" + ) as mock_streamable_client: + mock_streamable_client.return_value.__aenter__.return_value = ( + AsyncMock(), + AsyncMock(), + AsyncMock(), + ) + yield mock_streamable_client + + +@pytest.fixture +def mock_mcp_client( + mock_sse_client: Any, mock_http_streamable_client: Any +) -> Generator[AsyncMock]: """Fixture to mock the MCP client.""" with ( - patch("homeassistant.components.mcp.coordinator.sse_client"), patch("homeassistant.components.mcp.coordinator.ClientSession") as mock_session, patch("homeassistant.components.mcp.coordinator.TIMEOUT", 1), ): diff --git a/tests/components/mcp/test_init.py b/tests/components/mcp/test_init.py index 00666e71d05..24aebe0101f 100644 --- a/tests/components/mcp/test_init.py +++ b/tests/components/mcp/test_init.py @@ -1,7 +1,7 @@ """Tests for the Model Context Protocol component.""" import re -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import httpx from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool @@ -90,13 +90,90 @@ async def test_mcp_server_failure( mock_mcp_client: Mock, side_effect: Exception, ) -> None: - """Test the integration fails to setup if the server fails initialization.""" + """Test the integration fails to setup if the server fails initialization. + + This tests generic failure types that are independent of transport. + """ mock_mcp_client.side_effect = side_effect await hass.config_entries.async_setup(config_entry.entry_id) assert config_entry.state is ConfigEntryState.SETUP_RETRY +async def test_mcp_server_http_transport_failure( + hass: HomeAssistant, + config_entry: MockConfigEntry, + mock_http_streamable_client: AsyncMock, +) -> None: + """Test the integration fails to setup if the HTTP transport fails.""" + mock_http_streamable_client.side_effect = ExceptionGroup( + "Connection error", [httpx.ConnectError("Connection failed")] + ) + + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.SETUP_RETRY + + +async def test_mcp_server_sse_transport_failure( + hass: HomeAssistant, + config_entry: MockConfigEntry, + mock_http_streamable_client: AsyncMock, + mock_sse_client: AsyncMock, +) -> None: + """Test the integration fails to setup if the SSE transport fails. + + This exercises the case where the HTTP transport fails with method not + allowed, indicating an SSE server, then also fails with SSE. + """ + http_405 = httpx.HTTPStatusError( + "Method not allowed", request=None, response=httpx.Response(405) + ) + mock_http_streamable_client.side_effect = ExceptionGroup( + "Method not allowed", [http_405] + ) + + mock_sse_client.side_effect = ExceptionGroup( + "Connection error", [httpx.ConnectError("Connection failed")] + ) + + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.SETUP_RETRY + + +async def test_mcp_client_fallback_to_sse_success( + hass: HomeAssistant, + config_entry: MockConfigEntry, + mock_http_streamable_client: AsyncMock, + mock_sse_client: AsyncMock, + mock_mcp_client: Mock, +) -> None: + """Test mcp_client falls back to SSE on method not allowed error. + + This exercises the backwards compatibility part of the MCP Transport + specification. + """ + http_405 = httpx.HTTPStatusError( + "Method not allowed", + request=None, # type: ignore[arg-type] + response=httpx.Response(405), + ) + mock_http_streamable_client.side_effect = ExceptionGroup( + "Method not allowed", [http_405] + ) + + # Setup mocks for SSE fallback + mock_sse_client.return_value.__aenter__.return_value = ("read", "write") + mock_mcp_client.return_value.list_tools.return_value = ListToolsResult( + tools=[SEARCH_MEMORY_TOOL] + ) + + await hass.config_entries.async_setup(config_entry.entry_id) + assert config_entry.state is ConfigEntryState.LOADED + + assert mock_http_streamable_client.called + assert mock_sse_client.called + + async def test_mcp_server_authentication_failure( hass: HomeAssistant, credential: None, diff --git a/tests/helpers/test_httpx_client.py b/tests/helpers/test_httpx_client.py index 2e2ab82a7cd..fc87b2bc963 100644 --- a/tests/helpers/test_httpx_client.py +++ b/tests/helpers/test_httpx_client.py @@ -57,6 +57,26 @@ async def test_create_async_httpx_client_without_ssl_and_cookies( assert hass.data[client.DATA_ASYNC_CLIENT][(False, SSL_ALPN_HTTP11)] != httpx_client +async def test_create_async_httpx_client_default_headers( + hass: HomeAssistant, +) -> None: + """Test init async client with default headers.""" + httpx_client = client.create_async_httpx_client(hass) + assert isinstance(httpx_client, httpx.AsyncClient) + assert httpx_client.headers[client.USER_AGENT] == client.SERVER_SOFTWARE + + +async def test_create_async_httpx_client_with_headers( + hass: HomeAssistant, +) -> None: + """Test init async client with headers.""" + httpx_client = client.create_async_httpx_client(hass, headers={"x-test": "true"}) + assert isinstance(httpx_client, httpx.AsyncClient) + assert httpx_client.headers["x-test"] == "true" + # Default headers are preserved + assert httpx_client.headers[client.USER_AGENT] == client.SERVER_SOFTWARE + + async def test_get_async_client_cleanup(hass: HomeAssistant) -> None: """Test init async client with ssl.""" client.get_async_client(hass)