Support linking a sensor to multiple sources

This commit is contained in:
mampfes
2022-12-17 13:45:36 +01:00
parent 505f618f03
commit 1d404c7a05
6 changed files with 224 additions and 153 deletions

View File

@@ -18,13 +18,13 @@ from homeassistant.helpers.event import async_track_time_change # isort:skip
# add module directory to path
package_dir = Path(__file__).resolve().parents[0]
site.addsitedir(str(package_dir))
from waste_collection_schedule import Customize, Scraper # type: ignore # isort:skip # noqa: E402
from waste_collection_schedule import Customize, SourceShell # type: ignore # isort:skip # noqa: E402
_LOGGER = logging.getLogger(__name__)
CONF_SOURCES = "sources"
CONF_SOURCE_NAME = "name"
CONF_SOURCE_ARGS = "args" # scraper-source arguments
CONF_SOURCE_ARGS = "args" # source arguments
CONF_SOURCE_CALENDAR_TITLE = "calendar_title"
CONF_SEPARATOR = "separator"
CONF_FETCH_TIME = "fetch_time"
@@ -92,7 +92,7 @@ async def async_setup(hass: HomeAssistant, config: dict):
day_switch_time=config[DOMAIN][CONF_DAY_SWITCH_TIME],
)
# create scraper(s)
# create shells for source(s)
for source in config[DOMAIN][CONF_SOURCES]:
# create customize object
customize = {}
@@ -106,7 +106,7 @@ async def async_setup(hass: HomeAssistant, config: dict):
use_dedicated_calendar=c.get(CONF_USE_DEDICATED_CALENDAR, False),
dedicated_calendar_title=c.get(CONF_DEDICATED_CALENDAR_TITLE, False),
)
api.add_scraper(
api.add_source_shell(
source_name=source[CONF_SOURCE_NAME],
customize=customize,
calendar_title=source.get(CONF_SOURCE_CALENDAR_TITLE),
@@ -132,7 +132,7 @@ class WasteCollectionApi:
self, hass, separator, fetch_time, random_fetch_time_offset, day_switch_time
):
self._hass = hass
self._scrapers = []
self._source_shells = []
self._separator = separator
self._fetch_time = fetch_time
self._random_fetch_time_offset = random_fetch_time_offset
@@ -183,15 +183,15 @@ class WasteCollectionApi:
"""When to hide entries for today."""
return self._day_switch_time
def add_scraper(
def add_source_shell(
self,
source_name,
customize,
source_args,
calendar_title,
):
self._scrapers.append(
Scraper.create(
self._source_shells.append(
SourceShell.create(
source_name=source_name,
customize=customize,
source_args=source_args,
@@ -200,17 +200,17 @@ class WasteCollectionApi:
)
def _fetch(self, *_):
for scraper in self._scrapers:
scraper.fetch()
for shell in self._source_shells:
shell.fetch()
self._update_sensors_callback()
@property
def scrapers(self):
return self._scrapers
def shells(self):
return self._source_shells
def get_scraper(self, index):
return self._scrapers[index] if index < len(self._scrapers) else None
def get_shell(self, index):
return self._source_shells[index] if index < len(self._source_shells) else None
@callback
def _fetch_callback(self, *_):

View File

@@ -6,7 +6,13 @@ from datetime import datetime, timedelta
from homeassistant.components.calendar import CalendarEntity, CalendarEvent
from homeassistant.core import HomeAssistant
from custom_components.waste_collection_schedule.waste_collection_schedule.scraper import Scraper
# fmt: off
from custom_components.waste_collection_schedule.waste_collection_schedule.collection_aggregator import \
CollectionAggregator
from custom_components.waste_collection_schedule.waste_collection_schedule.source_shell import \
SourceShell
# fmt: on
_LOGGER = logging.getLogger(__name__)
@@ -21,26 +27,29 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
api = discovery_info["api"]
for scraper in api.scrapers:
dedicated_calendar_types = scraper.get_dedicated_calendar_types()
for shell in api.shells:
dedicated_calendar_types = shell.get_dedicated_calendar_types()
for type in dedicated_calendar_types:
entities.append(
WasteCollectionCalendar(
api=api,
scraper=scraper,
name=scraper.get_calendar_title_for_type(type),
include_types={scraper.get_collection_type(type)},
unique_id=calc_unique_calendar_id(scraper, type),
aggregator=CollectionAggregator([shell]),
name=shell.get_calendar_title_for_type(type),
include_types={shell.get_collection_type_name(type)},
unique_id=calc_unique_calendar_id(shell, type),
)
)
entities.append(
WasteCollectionCalendar(
api=api,
scraper=scraper,
name=scraper.calendar_title,
exclude_types={scraper.get_collection_type(type) for type in dedicated_calendar_types},
unique_id=calc_unique_calendar_id(scraper),
aggregator=CollectionAggregator([shell]),
name=shell.calendar_title,
exclude_types={
shell.get_collection_type_name(type)
for type in dedicated_calendar_types
},
unique_id=calc_unique_calendar_id(shell),
)
)
@@ -50,9 +59,17 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
class WasteCollectionCalendar(CalendarEntity):
"""Calendar entity class."""
def __init__(self, api, scraper, name, unique_id: str, include_types=None, exclude_types=None):
def __init__(
self,
api,
aggregator,
name,
unique_id: str,
include_types=None,
exclude_types=None,
):
self._api = api
self._scraper = scraper
self._aggregator = aggregator
self._name = name
self._include_types = include_types
self._exclude_types = exclude_types
@@ -67,8 +84,11 @@ class WasteCollectionCalendar(CalendarEntity):
@property
def event(self):
"""Return next collection event."""
collections = self._scraper.get_upcoming(
count=1, include_today=True, include_types=self._include_types, exclude_types=self._exclude_types
collections = self._aggregator.get_upcoming(
count=1,
include_today=True,
include_types=self._include_types,
exclude_types=self._exclude_types,
)
if len(collections) == 0:
@@ -82,8 +102,10 @@ class WasteCollectionCalendar(CalendarEntity):
"""Return all events within specified time span."""
events = []
for collection in self._scraper.get_upcoming(
include_today=True, include_types=self._include_types, exclude_types=self._exclude_types
for collection in self._aggregator.get_upcoming(
include_today=True,
include_types=self._include_types,
exclude_types=self._exclude_types,
):
event = self._convert(collection)
@@ -101,5 +123,5 @@ class WasteCollectionCalendar(CalendarEntity):
)
def calc_unique_calendar_id(scraper: Scraper, type: str = None):
return scraper.unique_id + ("_" + type if type is not None else "") + "_calendar"
def calc_unique_calendar_id(shell: SourceShell, type: str = None):
    """Build a unique entity id for a calendar, optionally scoped to one waste type."""
    parts = [shell.unique_id]
    if type is not None:
        parts.append(type)
    parts.append("calendar")
    return "_".join(parts)

View File

@@ -11,8 +11,15 @@ from homeassistant.const import CONF_NAME, CONF_VALUE_TEMPLATE
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
# fmt: off
from custom_components.waste_collection_schedule.waste_collection_schedule.collection_aggregator import \
CollectionAggregator
from .const import DOMAIN, UPDATE_SENSORS_SIGNAL
# fmt: on
_LOGGER = logging.getLogger(__name__)
CONF_SOURCE_INDEX = "source_index"
@@ -35,7 +42,9 @@ class DetailsFormat(Enum):
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_NAME): cv.string,
vol.Optional(CONF_SOURCE_INDEX, default=0): cv.positive_int,
vol.Optional(CONF_SOURCE_INDEX, default=0): vol.Any(
cv.positive_int, vol.All(cv.ensure_list, [cv.positive_int])
), # can be a scalar or a list
vol.Optional(CONF_DETAILS_FORMAT, default="upcoming"): cv.enum(DetailsFormat),
vol.Optional(CONF_COUNT): cv.positive_int,
vol.Optional(CONF_LEADTIME): cv.positive_int,
@@ -56,14 +65,22 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
if date_template is not None:
date_template.hass = hass
api = hass.data[DOMAIN]
# create aggregator for all sources
source_index = config[CONF_SOURCE_INDEX]
if not isinstance(source_index, list):
source_index = [source_index]
aggregator = CollectionAggregator([api.get_shell(i) for i in source_index])
entities = []
entities.append(
ScheduleSensor(
hass=hass,
api=hass.data[DOMAIN],
api=api,
name=config[CONF_NAME],
source_index=config[CONF_SOURCE_INDEX],
aggregator=aggregator,
details_format=config[CONF_DETAILS_FORMAT],
count=config.get(CONF_COUNT),
leadtime=config.get(CONF_LEADTIME),
@@ -85,7 +102,7 @@ class ScheduleSensor(SensorEntity):
hass,
api,
name,
source_index,
aggregator,
details_format,
count,
leadtime,
@@ -96,7 +113,7 @@ class ScheduleSensor(SensorEntity):
):
"""Initialize the entity."""
self._api = api
self._source_index = source_index
self._aggregator = aggregator
self._details_format = details_format
self._count = count
self._leadtime = leadtime
@@ -123,10 +140,6 @@ class ScheduleSensor(SensorEntity):
"""Entities have been added to hass."""
self._update_sensor()
@property
def _scraper(self):
return self._api.get_scraper(self._source_index)
@property
def _separator(self):
"""Return separator string used to join waste types."""
@@ -140,8 +153,8 @@ class ScheduleSensor(SensorEntity):
def _add_refreshtime(self):
"""Add refresh-time (= last fetch time) to device-state-attributes."""
refreshtime = ""
if self._scraper.refreshtime is not None:
refreshtime = self._scraper.refreshtime.strftime("%x %X")
if self._aggregator.refreshtime is not None:
refreshtime = self._aggregator.refreshtime.strftime("%x %X")
self._attr_attribution = f"Last update: {refreshtime}"
def _set_state(self, upcoming):
@@ -179,14 +192,15 @@ class ScheduleSensor(SensorEntity):
def _update_sensor(self):
"""Update the state and the device-state-attributes of the entity.
Called if a new data has been fetched from the scraper source.
Called if a new data has been fetched from the source.
"""
if self._scraper is None:
_LOGGER.error(f"source_index {self._source_index} out of range")
if self._aggregator is None:
return None
upcoming1 = self._scraper.get_upcoming_group_by_day(
count=1, include_types=self._collection_types, include_today=self._include_today,
upcoming1 = self._aggregator.get_upcoming_group_by_day(
count=1,
include_types=self._collection_types,
include_today=self._include_today,
)
self._set_state(upcoming1)
@@ -194,14 +208,14 @@ class ScheduleSensor(SensorEntity):
attributes = {}
collection_types = (
sorted(self._scraper.get_types())
sorted(self._aggregator.types)
if self._collection_types is None
else self._collection_types
)
if self._details_format == DetailsFormat.upcoming:
# show upcoming events list in details
upcoming = self._scraper.get_upcoming_group_by_day(
upcoming = self._aggregator.get_upcoming_group_by_day(
count=self._count,
leadtime=self._leadtime,
include_types=self._collection_types,
@@ -214,7 +228,7 @@ class ScheduleSensor(SensorEntity):
elif self._details_format == DetailsFormat.appointment_types:
# show list of collections in details
for t in collection_types:
collections = self._scraper.get_upcoming(
collections = self._aggregator.get_upcoming(
count=1, include_types=[t], include_today=self._include_today
)
date = (
@@ -224,15 +238,15 @@ class ScheduleSensor(SensorEntity):
elif self._details_format == DetailsFormat.generic:
# insert generic attributes into details
attributes["types"] = collection_types
attributes["upcoming"] = self._scraper.get_upcoming(
attributes["upcoming"] = self._aggregator.get_upcoming(
count=self._count,
leadtime=self._leadtime,
include_types=self._collection_types,
include_today=self._include_today,
)
refreshtime = ""
if self._scraper.refreshtime is not None:
refreshtime = self._scraper.refreshtime.isoformat(timespec="seconds")
if self._aggregator.refreshtime is not None:
refreshtime = self._aggregator.refreshtime.isoformat(timespec="seconds")
attributes["last_update"] = refreshtime
if len(upcoming1) > 0:

View File

@@ -1,2 +1,3 @@
from .collection import Collection, CollectionBase, CollectionGroup # type: ignore # isort:skip # noqa: F401
from .scraper import Customize, Scraper # noqa: F401
from .collection_aggregator import CollectionAggregator # noqa: F401
from .source_shell import Customize, SourceShell # noqa: F401

View File

@@ -0,0 +1,121 @@
import itertools
import logging
from datetime import datetime, timedelta
from .collection import CollectionGroup
_LOGGER = logging.getLogger(__name__)
class CollectionAggregator:
    """Aggregate collection entries from multiple source shells.

    Provides a merged, filtered view on the entries of all connected
    sources so that a single sensor or calendar entity can present data
    coming from more than one source.
    """

    def __init__(self, shells):
        # shells: list of SourceShell instances whose entries are merged.
        self._shells = shells

    @property
    def _entries(self):
        """Merge all entries from all connected sources."""
        return [e for s in self._shells for e in s._entries]

    @property
    def refreshtime(self):
        """Return the fetch timestamp of the first source.

        All shells are fetched together by the API, so the first shell's
        timestamp is representative. Returns None if no shells are attached.
        """
        return self._shells[0].refreshtime if self._shells else None

    @property
    def types(self):
        """Return set() of all collection types."""
        return {e.type for e in self._entries}

    def get_upcoming(
        self,
        count=None,
        leadtime=None,
        include_types=None,
        exclude_types=None,
        include_today=False,
    ):
        """Return list of all entries, limited by count and/or leadtime.

        Keyword arguments:
        count -- limits the number of returned entries (default: no limit)
        leadtime -- limits the timespan in days of returned entries
                    (default: no limit, 0 = today only)
        include_types -- only return entries of these waste types
        exclude_types -- skip entries of these waste types
        include_today -- also return entries collected today
        """
        return self._filter(
            self._entries,
            count=count,
            leadtime=leadtime,
            include_types=include_types,
            exclude_types=exclude_types,
            include_today=include_today,
        )

    def get_upcoming_group_by_day(
        self,
        count=None,
        leadtime=None,
        include_types=None,
        exclude_types=None,
        include_today=False,
    ):
        """Return list of all entries, grouped by day, limited by count and/or leadtime."""
        entries = []

        # _filter returns entries sorted by date, which groupby requires
        iterator = itertools.groupby(
            self._filter(
                self._entries,
                leadtime=leadtime,
                include_types=include_types,
                exclude_types=exclude_types,
                include_today=include_today,
            ),
            lambda e: e.date,
        )

        for key, group in iterator:
            entries.append(CollectionGroup.create(list(group)))
        if count is not None:
            entries = entries[:count]

        return entries

    def _filter(
        self,
        entries,
        count=None,
        leadtime=None,
        include_types=None,
        exclude_types=None,
        include_today=False,
    ):
        """Filter and sort *entries*; see get_upcoming for argument semantics."""
        # keep only wanted waste types
        # NOTE: filter `entries` (not self._entries) so include and exclude
        # filters compose instead of the second one discarding the first
        if include_types is not None:
            include = set(include_types)
            entries = [e for e in entries if e.type in include]

        # remove unwanted waste types
        if exclude_types is not None:
            exclude = set(exclude_types)
            entries = [e for e in entries if e.type not in exclude]

        # remove expired entries
        now = datetime.now().date()
        if include_today:
            entries = [e for e in entries if e.date >= now]
        else:
            entries = [e for e in entries if e.date > now]

        # remove entries which are too far in the future (0 = today)
        if leadtime is not None:
            limit = now + timedelta(days=leadtime)
            entries = [e for e in entries if e.date <= limit]

        # ensure that entries are sorted by date
        entries.sort(key=lambda e: e.date)

        # remove surplus entries
        if count is not None:
            entries = entries[:count]

        return entries

View File

@@ -1,13 +1,10 @@
#!/usr/bin/env python3
import datetime
import importlib
import itertools
import logging
import traceback
from typing import Dict, List, Optional
from .collection import Collection, CollectionGroup
from .collection import Collection
_LOGGER = logging.getLogger(__name__)
@@ -85,7 +82,7 @@ def customize_function(entry: Collection, customize: Dict[str, Customize]):
return entry
class Scraper:
class SourceShell:
def __init__(
self,
source,
@@ -106,10 +103,6 @@ class Scraper:
self._refreshtime = None
self._entries: List[Collection] = []
@property
def source(self):
return self._source
@property
def refreshtime(self):
return self._refreshtime
@@ -158,14 +151,8 @@ class Scraper:
self._entries = list(entries)
def get_types(self):
"""Return set() of all collection types."""
types = set()
for e in self._entries:
types.add(e.type)
return types
def get_dedicated_calendar_types(self):
"""Return set of waste types with a dedicated calendar."""
types = set()
for key, customize in self._customize.items():
@@ -174,92 +161,21 @@ class Scraper:
return types
def get_upcoming(self, count=None, leadtime=None, include_types=None, exclude_types=None, include_today=False):
"""Return list of all entries, limited by count and/or leadtime.
Keyword arguments:
count -- limits the number of returned entries (default=10)
leadtime -- limits the timespan in days of returned entries (default=7, 0 = today)
"""
return self._filter(
self._entries,
count=count,
leadtime=leadtime,
include_types=include_types,
exclude_types=exclude_types,
include_today=include_today,
)
def get_upcoming_group_by_day(
self, count=None, leadtime=None, include_types=None, exclude_types=None, include_today=False
):
"""Return list of all entries, grouped by day, limited by count and/or leadtime."""
entries = []
iterator = itertools.groupby(
self._filter(
self._entries,
leadtime=leadtime,
include_types=include_types,
exclude_types=exclude_types,
include_today=include_today,
),
lambda e: e.date,
)
for key, group in iterator:
entries.append(CollectionGroup.create(list(group)))
if count is not None:
entries = entries[:count]
return entries
def get_calendar_title_for_type(self, type):
"""Return calendar title for waste type (used for dedicated calendars)."""
c = self._customize.get(type)
if c is not None and c.dedicated_calendar_title:
return c.dedicated_calendar_title
return self.get_collection_type(type)
return self.get_collection_type_name(type)
def get_collection_type(self, type):
def get_collection_type_name(self, type):
c = self._customize.get(type)
if c is not None and c.alias:
return c.alias
return type
def _filter(
self, entries, count=None, leadtime=None, include_types=None, exclude_types=None, include_today=False
):
# remove unwanted waste types from include list
if include_types is not None:
entries = list(filter(lambda e: e.type in set(include_types), self._entries))
# remove unwanted waste types from exclude list
if exclude_types is not None:
entries = list(filter(lambda e: e.type not in set(exclude_types), self._entries))
# remove expired entries
now = datetime.datetime.now().date()
if include_today:
entries = list(filter(lambda e: e.date >= now, entries))
else:
entries = list(filter(lambda e: e.date > now, entries))
# remove entries which are too far in the future (0 = today)
if leadtime is not None:
x = now + datetime.timedelta(days=leadtime)
entries = list(filter(lambda e: e.date <= x, entries))
# ensure that entries are sorted by date
entries.sort(key=lambda e: e.date)
# remove surplus entries
if count is not None:
entries = entries[:count]
return entries
@staticmethod
def create(
source_name: str,
@@ -268,9 +184,6 @@ class Scraper:
calendar_title: Optional[str] = None,
):
# load source module
# for home-assistant, use the last 3 folders, e.g. custom_component/wave_collection_schedule/waste_collection_schedule
# otherwise, only use waste_collection_schedule
try:
source_module = importlib.import_module(
f"waste_collection_schedule.source.{source_name}"
@@ -282,19 +195,19 @@ class Scraper:
# create source
source = source_module.Source(**source_args) # type: ignore
# create scraper
g = Scraper(
# create source shell
g = SourceShell(
source=source,
customize=customize,
title=source_module.TITLE, # type: ignore[attr-defined]
description=source_module.DESCRIPTION, # type: ignore[attr-defined]
url=source_module.URL, # type: ignore[attr-defined]
calendar_title=calendar_title,
unique_id=calc_unique_scraper_id(source_name, source_args),
unique_id=calc_unique_source_id(source_name, source_args),
)
return g
def calc_unique_scraper_id(source_name, source_args):
def calc_unique_source_id(source_name, source_args):
    """Build a stable unique id from the source name and its sorted arguments."""
    sorted_args = sorted(source_args.items())
    return f"{source_name}{sorted_args}"