mirror of
https://github.com/Electric-Special/ha-core.git
synced 2026-03-21 08:06:00 +01:00
Fix JSON serialization of datetime objects in Google Generative AI tool results (#162495)
This commit is contained in:
@@ -7,6 +7,7 @@ import base64
|
||||
import codecs
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Callable
|
||||
from dataclasses import dataclass, replace
|
||||
import datetime
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
@@ -181,13 +182,25 @@ def _escape_decode(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _validate_tool_results(value: Any) -> Any:
|
||||
"""Recursively convert non-json-serializable types."""
|
||||
if isinstance(value, (datetime.time, datetime.date)):
|
||||
return value.isoformat()
|
||||
if isinstance(value, list):
|
||||
return [_validate_tool_results(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {k: _validate_tool_results(v) for k, v in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
def _create_google_tool_response_parts(
|
||||
parts: list[conversation.ToolResultContent],
|
||||
) -> list[Part]:
|
||||
"""Create Google tool response parts."""
|
||||
return [
|
||||
Part.from_function_response(
|
||||
name=tool_result.tool_name, response=tool_result.tool_result
|
||||
name=tool_result.tool_name,
|
||||
response=_validate_tool_results(tool_result.tool_result),
|
||||
)
|
||||
for tool_result in parts
|
||||
]
|
||||
|
||||
@@ -1,6 +1,57 @@
|
||||
# serializer version: 1
|
||||
# name: test_function_call
|
||||
list([
|
||||
Content(
|
||||
parts=[
|
||||
Part(
|
||||
text='What time is it?'
|
||||
),
|
||||
],
|
||||
role='user'
|
||||
),
|
||||
Content(
|
||||
parts=[
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
args={},
|
||||
name='HassGetCurrentTime'
|
||||
)
|
||||
),
|
||||
],
|
||||
role='model'
|
||||
),
|
||||
Content(
|
||||
parts=[
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
name='HassGetCurrentTime',
|
||||
response={
|
||||
'data': {
|
||||
'failed': [],
|
||||
'success': [],
|
||||
'targets': []
|
||||
},
|
||||
'response_type': 'action_done',
|
||||
'speech': {
|
||||
'plain': {<... 2 items at Max depth ...>}
|
||||
},
|
||||
'speech_slots': {
|
||||
'time': '16:24:17.813343'
|
||||
}
|
||||
}
|
||||
)
|
||||
),
|
||||
],
|
||||
role='user'
|
||||
),
|
||||
Content(
|
||||
parts=[
|
||||
Part(
|
||||
text='4:24 PM'
|
||||
),
|
||||
],
|
||||
role='model'
|
||||
),
|
||||
Content(
|
||||
parts=[
|
||||
Part(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for the Google Generative AI Conversation integration conversation platform."""
|
||||
|
||||
import datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
@@ -8,7 +9,11 @@ import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation import UserContent
|
||||
from homeassistant.components.conversation import (
|
||||
AssistantContent,
|
||||
ToolResultContent,
|
||||
UserContent,
|
||||
)
|
||||
from homeassistant.components.google_generative_ai_conversation.entity import (
|
||||
ERROR_GETTING_RESPONSE,
|
||||
_escape_decode,
|
||||
@@ -17,6 +22,7 @@ from homeassistant.components.google_generative_ai_conversation.entity import (
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.helpers.llm import ToolInput
|
||||
|
||||
from . import API_ERROR_500, CLIENT_ERROR_BAD_REQUEST
|
||||
|
||||
@@ -87,6 +93,41 @@ async def test_function_call(
|
||||
agent_id = "conversation.google_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
# Add some pre-existing content from conversation.default_agent
|
||||
mock_chat_log.async_add_user_content(UserContent(content="What time is it?"))
|
||||
mock_chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
agent_id=agent_id,
|
||||
tool_calls=[
|
||||
ToolInput(
|
||||
tool_name="HassGetCurrentTime",
|
||||
tool_args={},
|
||||
id="01KGW7TFC1VVVK7ANHVMDA4DJ6",
|
||||
external=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
mock_chat_log.async_add_assistant_content_without_tools(
|
||||
ToolResultContent(
|
||||
agent_id=agent_id,
|
||||
tool_call_id="01KGW7TFC1VVVK7ANHVMDA4DJ6",
|
||||
tool_name="HassGetCurrentTime",
|
||||
tool_result={
|
||||
"speech": {"plain": {"speech": "4:24 PM", "extra_data": None}},
|
||||
"response_type": "action_done",
|
||||
"speech_slots": {"time": datetime.time(16, 24, 17, 813343)},
|
||||
"data": {"targets": [], "success": [], "failed": []},
|
||||
},
|
||||
)
|
||||
)
|
||||
mock_chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
agent_id=agent_id,
|
||||
content="4:24 PM",
|
||||
)
|
||||
)
|
||||
|
||||
messages = [
|
||||
# Function call stream
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user