mirror of
https://github.com/Electric-Special/ha-core.git
synced 2026-03-21 03:03:17 +01:00
Add timeout to dnsip (to handle stale connections) (#153086)
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from ipaddress import IPv4Address, IPv6Address
|
from ipaddress import IPv4Address, IPv6Address
|
||||||
import logging
|
import logging
|
||||||
@@ -88,8 +89,8 @@ class WanIpSensor(SensorEntity):
|
|||||||
self._attr_name = "IPv6" if ipv6 else None
|
self._attr_name = "IPv6" if ipv6 else None
|
||||||
self._attr_unique_id = f"{hostname}_{ipv6}"
|
self._attr_unique_id = f"{hostname}_{ipv6}"
|
||||||
self.hostname = hostname
|
self.hostname = hostname
|
||||||
self.resolver = aiodns.DNSResolver(tcp_port=port, udp_port=port)
|
self.port = port
|
||||||
self.resolver.nameservers = [resolver]
|
self._resolver = resolver
|
||||||
self.querytype: Literal["A", "AAAA"] = "AAAA" if ipv6 else "A"
|
self.querytype: Literal["A", "AAAA"] = "AAAA" if ipv6 else "A"
|
||||||
self._retries = DEFAULT_RETRIES
|
self._retries = DEFAULT_RETRIES
|
||||||
self._attr_extra_state_attributes = {
|
self._attr_extra_state_attributes = {
|
||||||
@@ -103,14 +104,26 @@ class WanIpSensor(SensorEntity):
|
|||||||
model=aiodns.__version__,
|
model=aiodns.__version__,
|
||||||
name=name,
|
name=name,
|
||||||
)
|
)
|
||||||
|
self.resolver: aiodns.DNSResolver
|
||||||
|
self.create_dns_resolver()
|
||||||
|
|
||||||
|
def create_dns_resolver(self) -> None:
|
||||||
|
"""Create the DNS resolver."""
|
||||||
|
self.resolver = aiodns.DNSResolver(tcp_port=self.port, udp_port=self.port)
|
||||||
|
self.resolver.nameservers = [self._resolver]
|
||||||
|
|
||||||
async def async_update(self) -> None:
|
async def async_update(self) -> None:
|
||||||
"""Get the current DNS IP address for hostname."""
|
"""Get the current DNS IP address for hostname."""
|
||||||
|
if self.resolver._closed: # noqa: SLF001
|
||||||
|
self.create_dns_resolver()
|
||||||
|
response = None
|
||||||
try:
|
try:
|
||||||
response = await self.resolver.query(self.hostname, self.querytype)
|
async with asyncio.timeout(10):
|
||||||
|
response = await self.resolver.query(self.hostname, self.querytype)
|
||||||
|
except TimeoutError:
|
||||||
|
await self.resolver.close()
|
||||||
except DNSError as err:
|
except DNSError as err:
|
||||||
_LOGGER.warning("Exception while resolving host: %s", err)
|
_LOGGER.warning("Exception while resolving host: %s", err)
|
||||||
response = None
|
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
sorted_ips = sort_ips(
|
sorted_ips = sort_ips(
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ class RetrieveDNS:
|
|||||||
self.nameservers = nameservers
|
self.nameservers = nameservers
|
||||||
self._nameservers = ["1.2.3.4"]
|
self._nameservers = ["1.2.3.4"]
|
||||||
self.error = error
|
self.error = error
|
||||||
|
self._closed = False
|
||||||
|
|
||||||
async def query(self, hostname, qtype) -> list[QueryResult]:
|
async def query(self, hostname, qtype) -> list[QueryResult]:
|
||||||
"""Return information."""
|
"""Return information."""
|
||||||
@@ -47,3 +48,7 @@ class RetrieveDNS:
|
|||||||
@nameservers.setter
|
@nameservers.setter
|
||||||
def nameservers(self, value: list[str]) -> None:
|
def nameservers(self, value: list[str]) -> None:
|
||||||
self._nameservers = value
|
self._nameservers = value
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the resolver."""
|
||||||
|
self._closed = True
|
||||||
|
|||||||
@@ -171,3 +171,70 @@ async def test_sensor_no_response(
|
|||||||
|
|
||||||
state = hass.states.get("sensor.home_assistant_io")
|
state = hass.states.get("sensor.home_assistant_io")
|
||||||
assert state.state == STATE_UNAVAILABLE
|
assert state.state == STATE_UNAVAILABLE
|
||||||
|
|
||||||
|
|
||||||
|
async def test_sensor_timeout(
|
||||||
|
hass: HomeAssistant, freezer: FrozenDateTimeFactory
|
||||||
|
) -> None:
|
||||||
|
"""Test the DNS IP sensor with timeout."""
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain=DOMAIN,
|
||||||
|
source=SOURCE_USER,
|
||||||
|
data={
|
||||||
|
CONF_HOSTNAME: "home-assistant.io",
|
||||||
|
CONF_NAME: "home-assistant.io",
|
||||||
|
CONF_IPV4: True,
|
||||||
|
CONF_IPV6: False,
|
||||||
|
},
|
||||||
|
options={
|
||||||
|
CONF_RESOLVER: "208.67.222.222",
|
||||||
|
CONF_RESOLVER_IPV6: "2620:119:53::53",
|
||||||
|
CONF_PORT: 53,
|
||||||
|
CONF_PORT_IPV6: 53,
|
||||||
|
},
|
||||||
|
entry_id="1",
|
||||||
|
unique_id="home-assistant.io",
|
||||||
|
)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
dns_mock = RetrieveDNS()
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.dnsip.sensor.aiodns.DNSResolver",
|
||||||
|
return_value=dns_mock,
|
||||||
|
):
|
||||||
|
await hass.config_entries.async_setup(entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
state = hass.states.get("sensor.home_assistant_io")
|
||||||
|
|
||||||
|
assert state.state == "1.1.1.1"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.dnsip.sensor.aiodns.DNSResolver",
|
||||||
|
return_value=dns_mock,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.dnsip.sensor.asyncio.timeout",
|
||||||
|
side_effect=TimeoutError(),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
freezer.tick(timedelta(seconds=SCAN_INTERVAL.seconds))
|
||||||
|
async_fire_time_changed(hass)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Allows 2 retries before going unavailable
|
||||||
|
state = hass.states.get("sensor.home_assistant_io")
|
||||||
|
assert state.state == "1.1.1.1"
|
||||||
|
assert state.attributes["ip_addresses"] == ["1.1.1.1", "1.2.3.4"]
|
||||||
|
|
||||||
|
freezer.tick(timedelta(seconds=SCAN_INTERVAL.seconds))
|
||||||
|
async_fire_time_changed(hass)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
freezer.tick(timedelta(seconds=SCAN_INTERVAL.seconds))
|
||||||
|
async_fire_time_changed(hass)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
state = hass.states.get("sensor.home_assistant_io")
|
||||||
|
assert state.state == STATE_UNAVAILABLE
|
||||||
|
|||||||
Reference in New Issue
Block a user