mirror of
https://github.com/Electric-Special/ha-core.git
synced 2026-03-21 03:03:17 +01:00
Add support for HTTP Streamable to MCP integration (#161547)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com> Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -47,7 +47,7 @@ MCP_DISCOVERY_HEADERS = {
|
||||
"MCP-Protocol-Version": "2025-03-26",
|
||||
}
|
||||
|
||||
EXAMPLE_URL = "http://example/sse"
|
||||
EXAMPLE_URL = "http://example/mcp"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -122,7 +122,7 @@ async def validate_input(
|
||||
except vol.Invalid as error:
|
||||
raise InvalidUrl from error
|
||||
try:
|
||||
async with mcp_client(url, token_manager=token_manager) as session:
|
||||
async with mcp_client(hass, url, token_manager=token_manager) as session:
|
||||
response = await session.initialize()
|
||||
except httpx.TimeoutException as error:
|
||||
_LOGGER.info("Timeout connecting to MCP server: %s", error)
|
||||
|
||||
@@ -9,6 +9,7 @@ import logging
|
||||
import httpx
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert_to_voluptuous
|
||||
|
||||
@@ -17,6 +18,7 @@ from homeassistant.const import CONF_URL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.httpx_client import create_async_httpx_client
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
@@ -32,10 +34,11 @@ type TokenManager = Callable[[], Awaitable[str]]
|
||||
|
||||
@asynccontextmanager
|
||||
async def mcp_client(
|
||||
hass: HomeAssistant,
|
||||
url: str,
|
||||
token_manager: TokenManager | None = None,
|
||||
) -> AsyncGenerator[ClientSession]:
|
||||
"""Create a server-sent event MCP client.
|
||||
"""Create an MCP client.
|
||||
|
||||
This is an asynccontext manager that exists to wrap other async context managers
|
||||
so that the coordinator has a single object to manage.
|
||||
@@ -44,6 +47,26 @@ async def mcp_client(
|
||||
if token_manager is not None:
|
||||
token = await token_manager()
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
try:
|
||||
async with (
|
||||
streamable_http_client(
|
||||
url=url,
|
||||
http_client=create_async_httpx_client(hass, headers=headers),
|
||||
) as (read_stream, write_stream, _),
|
||||
ClientSession(read_stream, write_stream) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
yield session
|
||||
except ExceptionGroup as streamable_err:
|
||||
main_error = streamable_err.exceptions[0]
|
||||
# Method not Allowed likely means this is not a streamable HTTP server,
|
||||
# but it may be an SSE server. This is part of the MCP Transport
|
||||
# backwards compatibility specification.
|
||||
if (
|
||||
isinstance(main_error, httpx.HTTPStatusError)
|
||||
and main_error.response.status_code == 405
|
||||
):
|
||||
try:
|
||||
async with (
|
||||
sse_client(url=url, headers=headers) as streams,
|
||||
@@ -51,9 +74,12 @@ async def mcp_client(
|
||||
):
|
||||
await session.initialize()
|
||||
yield session
|
||||
except ExceptionGroup as err:
|
||||
_LOGGER.debug("Error creating MCP client: %s", err)
|
||||
raise err.exceptions[0] from err
|
||||
except ExceptionGroup as sse_err:
|
||||
_LOGGER.debug("Error creating SSE MCP client: %s", sse_err)
|
||||
raise sse_err.exceptions[0] from sse_err
|
||||
else:
|
||||
_LOGGER.debug("Error creating MCP client: %s", streamable_err)
|
||||
raise main_error from streamable_err
|
||||
|
||||
|
||||
class ModelContextProtocolTool(llm.Tool):
|
||||
@@ -83,7 +109,9 @@ class ModelContextProtocolTool(llm.Tool):
|
||||
"""Call the tool."""
|
||||
try:
|
||||
async with asyncio.timeout(TIMEOUT):
|
||||
async with mcp_client(self.server_url, self.token_manager) as session:
|
||||
async with mcp_client(
|
||||
hass, self.server_url, self.token_manager
|
||||
) as session:
|
||||
result = await session.call_tool(
|
||||
tool_input.tool_name, tool_input.tool_args
|
||||
)
|
||||
@@ -126,7 +154,7 @@ class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
|
||||
try:
|
||||
async with asyncio.timeout(TIMEOUT):
|
||||
async with mcp_client(
|
||||
self.config_entry.data[CONF_URL], self.token_manager
|
||||
self.hass, self.config_entry.data[CONF_URL], self.token_manager
|
||||
) as session:
|
||||
result = await session.list_tools()
|
||||
except TimeoutError as error:
|
||||
|
||||
@@ -117,7 +117,10 @@ def create_async_httpx_client(
|
||||
kwargs.setdefault("http2", True)
|
||||
client = HassHttpXAsyncClient(
|
||||
verify=ssl_context,
|
||||
headers={USER_AGENT: SERVER_SOFTWARE},
|
||||
headers={
|
||||
USER_AGENT: SERVER_SOFTWARE,
|
||||
**kwargs.pop("headers", {}),
|
||||
},
|
||||
limits=DEFAULT_LIMITS,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from collections.abc import Generator
|
||||
import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -42,10 +43,34 @@ def mock_setup_entry() -> Generator[AsyncMock]:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_client() -> Generator[AsyncMock]:
|
||||
def mock_sse_client() -> Generator[AsyncMock]:
|
||||
"""Fixture to mock the MCP client."""
|
||||
with patch(
|
||||
"homeassistant.components.mcp.coordinator.sse_client"
|
||||
) as mock_sse_client:
|
||||
yield mock_sse_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_streamable_client() -> Generator[AsyncMock]:
|
||||
"""Fixture to mock the MCP client."""
|
||||
with patch(
|
||||
"homeassistant.components.mcp.coordinator.streamable_http_client"
|
||||
) as mock_streamable_client:
|
||||
mock_streamable_client.return_value.__aenter__.return_value = (
|
||||
AsyncMock(),
|
||||
AsyncMock(),
|
||||
AsyncMock(),
|
||||
)
|
||||
yield mock_streamable_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_client(
|
||||
mock_sse_client: Any, mock_http_streamable_client: Any
|
||||
) -> Generator[AsyncMock]:
|
||||
"""Fixture to mock the MCP client."""
|
||||
with (
|
||||
patch("homeassistant.components.mcp.coordinator.sse_client"),
|
||||
patch("homeassistant.components.mcp.coordinator.ClientSession") as mock_session,
|
||||
patch("homeassistant.components.mcp.coordinator.TIMEOUT", 1),
|
||||
):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for the Model Context Protocol component."""
|
||||
|
||||
import re
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import httpx
|
||||
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
|
||||
@@ -90,13 +90,90 @@ async def test_mcp_server_failure(
|
||||
mock_mcp_client: Mock,
|
||||
side_effect: Exception,
|
||||
) -> None:
|
||||
"""Test the integration fails to setup if the server fails initialization."""
|
||||
"""Test the integration fails to setup if the server fails initialization.
|
||||
|
||||
This tests generic failure types that are independent of transport.
|
||||
"""
|
||||
mock_mcp_client.side_effect = side_effect
|
||||
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.SETUP_RETRY
|
||||
|
||||
|
||||
async def test_mcp_server_http_transport_failure(
|
||||
hass: HomeAssistant,
|
||||
config_entry: MockConfigEntry,
|
||||
mock_http_streamable_client: AsyncMock,
|
||||
) -> None:
|
||||
"""Test the integration fails to setup if the HTTP transport fails."""
|
||||
mock_http_streamable_client.side_effect = ExceptionGroup(
|
||||
"Connection error", [httpx.ConnectError("Connection failed")]
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.SETUP_RETRY
|
||||
|
||||
|
||||
async def test_mcp_server_sse_transport_failure(
|
||||
hass: HomeAssistant,
|
||||
config_entry: MockConfigEntry,
|
||||
mock_http_streamable_client: AsyncMock,
|
||||
mock_sse_client: AsyncMock,
|
||||
) -> None:
|
||||
"""Test the integration fails to setup if the SSE transport fails.
|
||||
|
||||
This exercises the case where the HTTP transport fails with method not
|
||||
allowed, indicating an SSE server, then also fails with SSE.
|
||||
"""
|
||||
http_405 = httpx.HTTPStatusError(
|
||||
"Method not allowed", request=None, response=httpx.Response(405)
|
||||
)
|
||||
mock_http_streamable_client.side_effect = ExceptionGroup(
|
||||
"Method not allowed", [http_405]
|
||||
)
|
||||
|
||||
mock_sse_client.side_effect = ExceptionGroup(
|
||||
"Connection error", [httpx.ConnectError("Connection failed")]
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.SETUP_RETRY
|
||||
|
||||
|
||||
async def test_mcp_client_fallback_to_sse_success(
|
||||
hass: HomeAssistant,
|
||||
config_entry: MockConfigEntry,
|
||||
mock_http_streamable_client: AsyncMock,
|
||||
mock_sse_client: AsyncMock,
|
||||
mock_mcp_client: Mock,
|
||||
) -> None:
|
||||
"""Test mcp_client falls back to SSE on method not allowed error.
|
||||
|
||||
This exercises the backwards compatibility part of the MCP Transport
|
||||
specification.
|
||||
"""
|
||||
http_405 = httpx.HTTPStatusError(
|
||||
"Method not allowed",
|
||||
request=None, # type: ignore[arg-type]
|
||||
response=httpx.Response(405),
|
||||
)
|
||||
mock_http_streamable_client.side_effect = ExceptionGroup(
|
||||
"Method not allowed", [http_405]
|
||||
)
|
||||
|
||||
# Setup mocks for SSE fallback
|
||||
mock_sse_client.return_value.__aenter__.return_value = ("read", "write")
|
||||
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
|
||||
tools=[SEARCH_MEMORY_TOOL]
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
assert config_entry.state is ConfigEntryState.LOADED
|
||||
|
||||
assert mock_http_streamable_client.called
|
||||
assert mock_sse_client.called
|
||||
|
||||
|
||||
async def test_mcp_server_authentication_failure(
|
||||
hass: HomeAssistant,
|
||||
credential: None,
|
||||
|
||||
@@ -57,6 +57,26 @@ async def test_create_async_httpx_client_without_ssl_and_cookies(
|
||||
assert hass.data[client.DATA_ASYNC_CLIENT][(False, SSL_ALPN_HTTP11)] != httpx_client
|
||||
|
||||
|
||||
async def test_create_async_httpx_client_default_headers(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test init async client with default headers."""
|
||||
httpx_client = client.create_async_httpx_client(hass)
|
||||
assert isinstance(httpx_client, httpx.AsyncClient)
|
||||
assert httpx_client.headers[client.USER_AGENT] == client.SERVER_SOFTWARE
|
||||
|
||||
|
||||
async def test_create_async_httpx_client_with_headers(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test init async client with headers."""
|
||||
httpx_client = client.create_async_httpx_client(hass, headers={"x-test": "true"})
|
||||
assert isinstance(httpx_client, httpx.AsyncClient)
|
||||
assert httpx_client.headers["x-test"] == "true"
|
||||
# Default headers are preserved
|
||||
assert httpx_client.headers[client.USER_AGENT] == client.SERVER_SOFTWARE
|
||||
|
||||
|
||||
async def test_get_async_client_cleanup(hass: HomeAssistant) -> None:
|
||||
"""Test init async client with ssl."""
|
||||
client.get_async_client(hass)
|
||||
|
||||
Reference in New Issue
Block a user