Add support for HTTP Streamable to MCP integration (#161547)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Allen Porter
2026-01-24 22:03:55 -08:00
committed by GitHub
parent aca0232e71
commit 6aaec5c159
6 changed files with 168 additions and 15 deletions

View File

@@ -47,7 +47,7 @@ MCP_DISCOVERY_HEADERS = {
"MCP-Protocol-Version": "2025-03-26",
}
EXAMPLE_URL = "http://example/sse"
EXAMPLE_URL = "http://example/mcp"
@dataclass
@@ -122,7 +122,7 @@ async def validate_input(
except vol.Invalid as error:
raise InvalidUrl from error
try:
async with mcp_client(url, token_manager=token_manager) as session:
async with mcp_client(hass, url, token_manager=token_manager) as session:
response = await session.initialize()
except httpx.TimeoutException as error:
_LOGGER.info("Timeout connecting to MCP server: %s", error)

View File

@@ -9,6 +9,7 @@ import logging
import httpx
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamable_http_client
import voluptuous as vol
from voluptuous_openapi import convert_to_voluptuous
@@ -17,6 +18,7 @@ from homeassistant.const import CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.httpx_client import create_async_httpx_client
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from homeassistant.util.json import JsonObjectType
@@ -32,10 +34,11 @@ type TokenManager = Callable[[], Awaitable[str]]
@asynccontextmanager
async def mcp_client(
hass: HomeAssistant,
url: str,
token_manager: TokenManager | None = None,
) -> AsyncGenerator[ClientSession]:
"""Create a server-sent event MCP client.
"""Create an MCP client.
This is an asynccontext manager that exists to wrap other async context managers
so that the coordinator has a single object to manage.
@@ -44,16 +47,39 @@ async def mcp_client(
if token_manager is not None:
token = await token_manager()
headers["Authorization"] = f"Bearer {token}"
try:
async with (
sse_client(url=url, headers=headers) as streams,
ClientSession(*streams) as session,
streamable_http_client(
url=url,
http_client=create_async_httpx_client(hass, headers=headers),
) as (read_stream, write_stream, _),
ClientSession(read_stream, write_stream) as session,
):
await session.initialize()
yield session
except ExceptionGroup as err:
_LOGGER.debug("Error creating MCP client: %s", err)
raise err.exceptions[0] from err
except ExceptionGroup as streamable_err:
main_error = streamable_err.exceptions[0]
# Method not Allowed likely means this is not a streamable HTTP server,
# but it may be an SSE server. This is part of the MCP Transport
# backwards compatibility specification.
if (
isinstance(main_error, httpx.HTTPStatusError)
and main_error.response.status_code == 405
):
try:
async with (
sse_client(url=url, headers=headers) as streams,
ClientSession(*streams) as session,
):
await session.initialize()
yield session
except ExceptionGroup as sse_err:
_LOGGER.debug("Error creating SSE MCP client: %s", sse_err)
raise sse_err.exceptions[0] from sse_err
else:
_LOGGER.debug("Error creating MCP client: %s", streamable_err)
raise main_error from streamable_err
class ModelContextProtocolTool(llm.Tool):
@@ -83,7 +109,9 @@ class ModelContextProtocolTool(llm.Tool):
"""Call the tool."""
try:
async with asyncio.timeout(TIMEOUT):
async with mcp_client(self.server_url, self.token_manager) as session:
async with mcp_client(
hass, self.server_url, self.token_manager
) as session:
result = await session.call_tool(
tool_input.tool_name, tool_input.tool_args
)
@@ -126,7 +154,7 @@ class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
try:
async with asyncio.timeout(TIMEOUT):
async with mcp_client(
self.config_entry.data[CONF_URL], self.token_manager
self.hass, self.config_entry.data[CONF_URL], self.token_manager
) as session:
result = await session.list_tools()
except TimeoutError as error:

View File

@@ -117,7 +117,10 @@ def create_async_httpx_client(
kwargs.setdefault("http2", True)
client = HassHttpXAsyncClient(
verify=ssl_context,
headers={USER_AGENT: SERVER_SOFTWARE},
headers={
USER_AGENT: SERVER_SOFTWARE,
**kwargs.pop("headers", {}),
},
limits=DEFAULT_LIMITS,
**kwargs,
)

View File

@@ -2,6 +2,7 @@
from collections.abc import Generator
import datetime
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
@@ -42,10 +43,34 @@ def mock_setup_entry() -> Generator[AsyncMock]:
@pytest.fixture
def mock_mcp_client() -> Generator[AsyncMock]:
def mock_sse_client() -> Generator[AsyncMock]:
"""Fixture to mock the MCP client."""
with patch(
"homeassistant.components.mcp.coordinator.sse_client"
) as mock_sse_client:
yield mock_sse_client
@pytest.fixture
def mock_http_streamable_client() -> Generator[AsyncMock]:
"""Fixture to mock the MCP client."""
with patch(
"homeassistant.components.mcp.coordinator.streamable_http_client"
) as mock_streamable_client:
mock_streamable_client.return_value.__aenter__.return_value = (
AsyncMock(),
AsyncMock(),
AsyncMock(),
)
yield mock_streamable_client
@pytest.fixture
def mock_mcp_client(
mock_sse_client: Any, mock_http_streamable_client: Any
) -> Generator[AsyncMock]:
"""Fixture to mock the MCP client."""
with (
patch("homeassistant.components.mcp.coordinator.sse_client"),
patch("homeassistant.components.mcp.coordinator.ClientSession") as mock_session,
patch("homeassistant.components.mcp.coordinator.TIMEOUT", 1),
):

View File

@@ -1,7 +1,7 @@
"""Tests for the Model Context Protocol component."""
import re
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, Mock, patch
import httpx
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
@@ -90,13 +90,90 @@ async def test_mcp_server_failure(
mock_mcp_client: Mock,
side_effect: Exception,
) -> None:
"""Test the integration fails to setup if the server fails initialization."""
"""Test the integration fails to setup if the server fails initialization.
This tests generic failure types that are independent of transport.
"""
mock_mcp_client.side_effect = side_effect
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY
async def test_mcp_server_http_transport_failure(
hass: HomeAssistant,
config_entry: MockConfigEntry,
mock_http_streamable_client: AsyncMock,
) -> None:
"""Test the integration fails to setup if the HTTP transport fails."""
mock_http_streamable_client.side_effect = ExceptionGroup(
"Connection error", [httpx.ConnectError("Connection failed")]
)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY
async def test_mcp_server_sse_transport_failure(
hass: HomeAssistant,
config_entry: MockConfigEntry,
mock_http_streamable_client: AsyncMock,
mock_sse_client: AsyncMock,
) -> None:
"""Test the integration fails to setup if the SSE transport fails.
This exercises the case where the HTTP transport fails with method not
allowed, indicating an SSE server, then also fails with SSE.
"""
http_405 = httpx.HTTPStatusError(
"Method not allowed", request=None, response=httpx.Response(405)
)
mock_http_streamable_client.side_effect = ExceptionGroup(
"Method not allowed", [http_405]
)
mock_sse_client.side_effect = ExceptionGroup(
"Connection error", [httpx.ConnectError("Connection failed")]
)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY
async def test_mcp_client_fallback_to_sse_success(
hass: HomeAssistant,
config_entry: MockConfigEntry,
mock_http_streamable_client: AsyncMock,
mock_sse_client: AsyncMock,
mock_mcp_client: Mock,
) -> None:
"""Test mcp_client falls back to SSE on method not allowed error.
This exercises the backwards compatibility part of the MCP Transport
specification.
"""
http_405 = httpx.HTTPStatusError(
"Method not allowed",
request=None, # type: ignore[arg-type]
response=httpx.Response(405),
)
mock_http_streamable_client.side_effect = ExceptionGroup(
"Method not allowed", [http_405]
)
# Setup mocks for SSE fallback
mock_sse_client.return_value.__aenter__.return_value = ("read", "write")
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
tools=[SEARCH_MEMORY_TOOL]
)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.LOADED
assert mock_http_streamable_client.called
assert mock_sse_client.called
async def test_mcp_server_authentication_failure(
hass: HomeAssistant,
credential: None,

View File

@@ -57,6 +57,26 @@ async def test_create_async_httpx_client_without_ssl_and_cookies(
assert hass.data[client.DATA_ASYNC_CLIENT][(False, SSL_ALPN_HTTP11)] != httpx_client
async def test_create_async_httpx_client_default_headers(
hass: HomeAssistant,
) -> None:
"""Test init async client with default headers."""
httpx_client = client.create_async_httpx_client(hass)
assert isinstance(httpx_client, httpx.AsyncClient)
assert httpx_client.headers[client.USER_AGENT] == client.SERVER_SOFTWARE
async def test_create_async_httpx_client_with_headers(
hass: HomeAssistant,
) -> None:
"""Test init async client with headers."""
httpx_client = client.create_async_httpx_client(hass, headers={"x-test": "true"})
assert isinstance(httpx_client, httpx.AsyncClient)
assert httpx_client.headers["x-test"] == "true"
# Default headers are preserved
assert httpx_client.headers[client.USER_AGENT] == client.SERVER_SOFTWARE
async def test_get_async_client_cleanup(hass: HomeAssistant) -> None:
"""Test init async client with ssl."""
client.get_async_client(hass)