add more typehints especially in calendar.py

This commit is contained in:
5ila5
2024-09-20 13:24:20 +02:00
committed by 5ila5
parent ed5fc43d1b
commit 6178dc2095
4 changed files with 67 additions and 40 deletions

View File

@@ -5,7 +5,14 @@ import uuid
from datetime import datetime, timedelta
from homeassistant.components.calendar import CalendarEntity, CalendarEvent
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from custom_components.waste_collection_schedule.waste_collection_schedule import (
Collection,
)
# fmt: off
from custom_components.waste_collection_schedule.waste_collection_schedule.collection_aggregator import (
@@ -29,13 +36,13 @@ class WasteCollectionCalendar(CalendarEntity):
def __init__(
self,
aggregator,
name,
aggregator: CollectionAggregator,
name: str,
unique_id: str,
coordinator=None,
api=None,
include_types=None,
exclude_types=None,
coordinator: WCSCoordinator | None = None,
api: WasteCollectionApi | None = None,
include_types: set[str] | None = None,
exclude_types: set[str] | None = None,
):
self._api = api
self._coordinator = coordinator
@@ -86,9 +93,9 @@ class WasteCollectionCalendar(CalendarEntity):
async def async_get_events(
self, hass: HomeAssistant, start_date: datetime, end_date: datetime
):
) -> list[CalendarEvent]:
"""Return all events within specified time span."""
events = []
events: list[CalendarEvent] = []
for collection in self._aggregator.get_upcoming(
include_today=True,
@@ -102,7 +109,7 @@ class WasteCollectionCalendar(CalendarEntity):
return events
def _convert(self, collection) -> CalendarEvent:
def _convert(self, collection: Collection) -> CalendarEvent:
"""Convert an collection into a Home Assistant calendar event."""
return CalendarEvent(
summary=collection.type,
@@ -117,7 +124,7 @@ def create_calendar_entries(
api: WasteCollectionApi | None = None,
coordinator: WCSCoordinator | None = None,
) -> list[WasteCollectionCalendar]:
entities = []
entities: list[WasteCollectionCalendar] = []
for shell in shells:
dedicated_calendar_types = shell.get_dedicated_calendar_types()
for type in dedicated_calendar_types:
@@ -149,8 +156,10 @@ def create_calendar_entries(
# Config flow setup
async def async_setup_entry(hass, config, async_add_entities):
coordinator = hass.data[DOMAIN][config.entry_id]
async def async_setup_entry(
hass: HomeAssistant, config: ConfigEntry, async_add_entities: AddEntitiesCallback
):
coordinator: WCSCoordinator = hass.data[DOMAIN][config.entry_id]
shell = coordinator.shell
entities = create_calendar_entries([shell], coordinator=coordinator)
@@ -159,7 +168,12 @@ async def async_setup_entry(hass, config, async_add_entities):
# YAML setup
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
async def async_setup_platform(
hass: HomeAssistant,
config: ConfigType,
async_add_entities: AddEntitiesCallback,
discovery_info: DiscoveryInfoType | None = None,
):
"""Set up calendar platform."""
# We only want this platform to be set up via discovery.
if discovery_info is None:
@@ -167,7 +181,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
entities = []
api = discovery_info["api"]
api: WasteCollectionApi = discovery_info["api"]
entities = create_calendar_entries(api.shells, api=api)
async_add_entities(entities)

View File

@@ -83,7 +83,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True
async def async_update_listener(hass: HomeAssistant, entry: ConfigEntry):
async def async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Reload this instance
await hass.config_entries.async_reload(entry.entry_id)

View File

@@ -71,7 +71,7 @@ CONFIG_SCHEMA = vol.Schema(
)
async def async_setup(hass: HomeAssistant, config: dict):
async def async_setup(hass: HomeAssistant, config: dict) -> bool:
"""Set up the component. config contains data from configuration.yaml."""
# Skip for config flow
if const.DOMAIN not in config:

View File

@@ -2,25 +2,38 @@ import datetime
import importlib
import logging
import traceback
from typing import Dict, List, Optional
from typing import Dict, Iterable, List, Optional, Protocol
from .collection import Collection
_LOGGER = logging.getLogger(__name__)
class Fetchable(Protocol):
def fetch(self) -> list[Collection]:
...
class SourceModule(Protocol):
TITLE: str
DESCRIPTION: str
URL: str
Source: Fetchable
class Customize:
"""Customize one waste collection type."""
def __init__(
self,
waste_type,
alias=None,
show=True,
icon=None,
picture=None,
use_dedicated_calendar=False,
dedicated_calendar_title=None,
waste_type: str,
alias: str | None = None,
show: bool = True,
icon: str | None = None,
picture: str | None = None,
use_dedicated_calendar: bool = False,
dedicated_calendar_title: str | None = None,
):
self._waste_type = waste_type
self._alias = alias
@@ -82,7 +95,7 @@ def customize_function(entry: Collection, customize: Dict[str, Customize]):
return entry
def apply_day_offset(entry: Collection, day_offset: int):
def apply_day_offset(entry: Collection, day_offset: int) -> Collection:
entry.set_date(entry.date + datetime.timedelta(days=day_offset))
return entry
@@ -90,7 +103,7 @@ def apply_day_offset(entry: Collection, day_offset: int):
class SourceShell:
def __init__(
self,
source,
source: Fetchable,
customize: Dict[str, Customize],
title: str,
description: str,
@@ -106,7 +119,7 @@ class SourceShell:
self._url = url
self._calendar_title = calendar_title
self._unique_id = unique_id
self._refreshtime = None
self._refreshtime: datetime.datetime | None = None
self._entries: List[Collection] = []
self._day_offset = day_offset
@@ -138,11 +151,11 @@ class SourceShell:
def day_offset(self):
return self._day_offset
def fetch(self):
def fetch(self) -> None:
"""Fetch data from source."""
try:
# fetch returns a list of Collection's
entries = self._source.fetch()
entries: Iterable[Collection] = self._source.fetch()
except Exception:
_LOGGER.error(
f"fetch failed for source {self._title}:\n{traceback.format_exc()}"
@@ -166,7 +179,7 @@ class SourceShell:
self._entries = list(entries)
def get_dedicated_calendar_types(self):
def get_dedicated_calendar_types(self) -> set[str]:
"""Return set of waste types with a dedicated calendar."""
types = set()
@@ -176,7 +189,7 @@ class SourceShell:
return types
def get_calendar_title_for_type(self, type):
def get_calendar_title_for_type(self, type: str) -> str:
"""Return calendar title for waste type (used for dedicated calendars)."""
c = self._customize.get(type)
if c is not None and c.dedicated_calendar_title:
@@ -184,7 +197,7 @@ class SourceShell:
return self.get_collection_type_name(type)
def get_collection_type_name(self, type):
def get_collection_type_name(self, type: str) -> str:
c = self._customize.get(type)
if c is not None and c.alias:
return c.alias
@@ -198,26 +211,26 @@ class SourceShell:
source_args,
calendar_title: Optional[str] = None,
day_offset: int = 0,
):
) -> "SourceShell | None":
# load source module
try:
source_module = importlib.import_module(
source_module: SourceModule = importlib.import_module(
f"waste_collection_schedule.source.{source_name}"
)
except ImportError:
_LOGGER.error(f"source not found: {source_name}")
return
return None
# create source
source = source_module.Source(**source_args) # type: ignore
source: Fetchable = source_module.Source(**source_args) # type: ignore
# create source shell
g = SourceShell(
source=source,
customize=customize,
title=source_module.TITLE, # type: ignore[attr-defined]
description=source_module.DESCRIPTION, # type: ignore[attr-defined]
url=source_module.URL, # type: ignore[attr-defined]
title=source_module.TITLE,
description=source_module.DESCRIPTION,
url=source_module.URL,
calendar_title=calendar_title,
unique_id=calc_unique_source_id(source_name, source_args),
day_offset=day_offset,
@@ -226,5 +239,5 @@ class SourceShell:
return g
def calc_unique_source_id(source_name, source_args):
def calc_unique_source_id(source_name: str, source_args) -> str:
return source_name + str(sorted(source_args.items()))