diff --git a/custom_components/waste_collection_schedule/config_flow.py b/custom_components/waste_collection_schedule/config_flow.py index 08ac262f..05705980 100644 --- a/custom_components/waste_collection_schedule/config_flow.py +++ b/custom_components/waste_collection_schedule/config_flow.py @@ -5,7 +5,7 @@ import logging import types from datetime import date, datetime from pathlib import Path -from typing import Any, Tuple, TypedDict, cast +from typing import Any, Literal, Tuple, TypedDict, Union, cast, get_origin import homeassistant.helpers.config_validation as cv import voluptuous as vol @@ -34,6 +34,13 @@ from homeassistant.helpers.selector import ( ) from homeassistant.helpers.translation import async_get_translations from voluptuous.schema_builder import UNDEFINED +from waste_collection_schedule.collection import Collection +from waste_collection_schedule.exceptions import ( + SourceArgumentException, + SourceArgumentExceptionMultiple, + SourceArgumentRequired, + SourceArgumentSuggestionsExceptionBase, +) from .const import ( CONF_ADD_DAYS_TO, @@ -289,6 +296,7 @@ class WasteCollectionConfigFlow(ConfigFlow, domain=DOMAIN): # type: ignore[call _source: str | None = None _sources: dict[str, list[SourceDict]] = {} + _error_suggestions: dict[str, list[Any]] def __init__(self, *args: list, **kwargs: dict): super().__init__(*args, **kwargs) @@ -373,6 +381,54 @@ class WasteCollectionConfigFlow(ConfigFlow, domain=DOMAIN): # type: ignore[call return self.async_show_form(step_id="source", data_schema=SCHEMA, errors=errors) + async def __get_simple_annotation_type(self, annotation: Any) -> Any: + if annotation in SUPPORTED_ARG_TYPES: + return SUPPORTED_ARG_TYPES[annotation] + if ( + isinstance(annotation, types.GenericAlias) + and annotation.__origin__ in SUPPORTED_ARG_TYPES + ): + return SUPPORTED_ARG_TYPES[annotation.__origin__] + + if getattr(annotation, "__origin__", None) is Literal: + return SelectSelector( + SelectSelectorConfig( + options=[ + SelectOptionDict(label=x, value=x) for x in annotation.__args__ + ], + custom_value=False, + multiple=False, + ) + ) + return None + + async def __get_type_by_annotation(self, annotation: Any) -> Any: + if a := await self.__get_simple_annotation_type(annotation): + return a + if ( + (isinstance(annotation, types.GenericAlias)) + or ( + get_origin(annotation) is not None and hasattr(annotation, "__origin__") + ) + and (a := await self.__get_simple_annotation_type(annotation.__origin__)) + ): + return a + return_val = None + is_string = False + + if ( + isinstance(annotation, types.UnionType) + or getattr(annotation, "__origin__", None) is Union + ): + for arg in annotation.__args__: + if a := await self.__get_type_by_annotation(arg): + if isinstance(a, ObjectSelector): + return a + if not is_string: + return_val = a + is_string = a == cv.string + return return_val + async def __get_arg_schema( self, source: str, @@ -391,6 +447,14 @@ class WasteCollectionConfigFlow(ConfigFlow, domain=DOMAIN): # type: ignore[call Returns: Tuple[vol.Schema, types.ModuleType]: schema, module """ + suggestions: dict[str, list[Any]] = {} + if hasattr(self, "_error_suggestions"): + suggestions = { + key: value + for key, value in self._error_suggestions.items() + if len(value) > 0 + } + # Import source and get arguments module = await self.hass.async_add_executor_job( importlib.import_module, f"waste_collection_schedule.source.{source}" @@ -436,30 +500,10 @@ class WasteCollectionConfigFlow(ConfigFlow, domain=DOMAIN): # type: ignore[call description = { "suggested_value": pre_filled[args[arg].name], } - - if ( - default == inspect.Signature.empty or default is None - ) and annotation != inspect._empty: - if annotation in SUPPORTED_ARG_TYPES: - field_type = SUPPORTED_ARG_TYPES[annotation] - elif ( - isinstance(annotation, types.GenericAlias) - and annotation.__origin__ in SUPPORTED_ARG_TYPES - ): - field_type = SUPPORTED_ARG_TYPES[annotation.__origin__] - elif isinstance(annotation, types.UnionType): - for a in annotation.__args__: - _LOGGER.debug(f"{args[arg].name} UnionType: {a}, {type(a)}") - if a in SUPPORTED_ARG_TYPES: - field_type = SUPPORTED_ARG_TYPES[a] - if a == str: - break - elif ( - isinstance(a, types.GenericAlias) - and a.__origin__ in SUPPORTED_ARG_TYPES - ): - field_type = SUPPORTED_ARG_TYPES[a.__origin__] - + if annotation != inspect._empty: + field_type = ( + await self.__get_type_by_annotation(annotation) or field_type + ) _LOGGER.debug( f"Default for {args[arg].name}: {type(default) if default is not inspect.Signature.empty else inspect.Signature.empty}" ) @@ -482,11 +526,29 @@ class WasteCollectionConfigFlow(ConfigFlow, domain=DOMAIN): # type: ignore[call if field_type is None: field_type = SUPPORTED_ARG_TYPES.get(type(default)) + if (field_type or str) in (str, cv.string) and args[ + arg + ].name in suggestions: + _LOGGER.debug( + f"Adding suggestions to {args[arg].name}: {suggestions[args[arg].name]}" + ) + # Add suggestions to the field if fetch/init raised an Exception with suggestions + field_type = SelectSelector( + SelectSelectorConfig( + options=[ + SelectOptionDict(label=x, value=x) + for x in suggestions[args[arg].name] + ], + mode=SelectSelectorMode.DROPDOWN, + custom_value=True, + multiple=False, + ) + ) + if default == inspect.Signature.empty: vol_args[vol.Required(args[arg].name, description=description)] = ( field_type or str ) - _LOGGER.debug(f"Required: {args[arg].name} as default type: str") elif field_type or (default is None): # Handle boolean, int, string, date, datetime, list defaults @@ -538,11 +600,34 @@ class WasteCollectionConfigFlow(ConfigFlow, domain=DOMAIN): # type: ignore[call try: instance = module.Source(**args_input) - resp = await self.hass.async_add_executor_job(instance.fetch) + resp: list[Collection] = await self.hass.async_add_executor_job( + instance.fetch + ) if len(resp) == 0: errors["base"] = "fetch_empty" self._fetched_types = list({x.type.strip() for x in resp}) + except SourceArgumentSuggestionsExceptionBase as e: + if not hasattr(self, "_error_suggestions"): + self._error_suggestions = {} + self._error_suggestions.update({e.argument: e.suggestions}) + errors[e.argument] = "invalid_arg" + description_placeholders["invalid_arg_message"] = e.simple_message + if e.suggestion_type != str and e.suggestion_type != int: + description_placeholders["invalid_arg_message"] = e.message + except SourceArgumentRequired as e: + errors[e.argument] = "invalid_arg" + description_placeholders["invalid_arg_message"] = e.message + except SourceArgumentException as e: + errors[e.argument] = "invalid_arg" + description_placeholders["invalid_arg_message"] = e.message + except SourceArgumentExceptionMultiple as e: + description_placeholders["invalid_arg_message"] = e.message + if len(e.arguments) == 0: + errors["base"] = "invalid_arg" + else: + for arg in e.arguments: + errors[arg] = "invalid_arg" except Exception as e: errors["base"] = "fetch_error" description_placeholders["fetch_error_message"] = str(e) @@ -565,7 +650,12 @@ class WasteCollectionConfigFlow(ConfigFlow, domain=DOMAIN): # type: ignore[call description_placeholders, options, ) = await self.__validate_args_user_input(self._source, args_input, module) - if len(errors) == 0: + + if len(errors) > 0: + schema, module = await self.__get_arg_schema( + self._source, self._extra_info_default_params, args_input + ) + else: self._args_data = { CONF_SOURCE_NAME: self._source, CONF_SOURCE_ARGS: args_input, diff --git a/custom_components/waste_collection_schedule/waste_collection_schedule/exceptions.py b/custom_components/waste_collection_schedule/waste_collection_schedule/exceptions.py new file mode 100644 index 00000000..533decef --- /dev/null +++ b/custom_components/waste_collection_schedule/waste_collection_schedule/exceptions.py @@ -0,0 +1,133 @@ +from typing import Any, Generic, Iterable, Type, TypeVar + +T = TypeVar("T") + + +class SourceArgumentExceptionMultiple(Exception): + def __init__(self, arguments: Iterable[str], message: str): + self._arguments = arguments + self.message = message + super().__init__(self.message) + + @property + def arguments(self) -> Iterable[str]: + return self._arguments + + +class SourceArgumentException(Exception): + def __init__(self, argument, message): + self._argument = argument + self.message = message + super().__init__(self.message) + + @property + def argument(self) -> str: + return self._argument + + +class SourceArgumentSuggestionsExceptionBase(SourceArgumentException, Generic[T]): + def __init__( + self, + argument: str, + message: str, + suggestions: Iterable[T], + message_addition: str = "", + ): + self._simple_message = message + message += f", {message_addition}" if message_addition else "" + super().__init__(argument=argument, message=message) + self._suggestions = suggestions + self._suggestion_type: Type[T] | None = ( + type(list(suggestions)[0]) if suggestions else None + ) + + @property + def suggestions(self) -> Iterable[T]: + return self._suggestions + + @property + def suggestion_type(self) -> Type[T] | None: + return self._suggestion_type + + @property + def simple_message(self) -> str: + return self._simple_message + + +class SourceArgumentNotFound(SourceArgumentException): + """Invalid arguments provided.""" + + def __init__( + self, + argument: str, + value: Any, + message_addition="please check the spelling and try again.", + ) -> None: + self._simple_message = f"We could not find values for the argument '{argument}' with the value '{value}'" + self.message = self._simple_message + if message_addition: + self.message += f", {message_addition}" + super().__init__(argument, self.message) + + @property + def simple_message(self) -> str: + return self._simple_message + + +class SourceArgumentNotFoundWithSuggestions(SourceArgumentSuggestionsExceptionBase): + def __init__(self, argument: str, value: Any, suggestions: Iterable[T]) -> None: + message = f"We could not find values for the argument '{argument}' with the value '{value}'" + suggestions = list(suggestions) + if len(suggestions) == 0: + message += ", We could not find any suggestions. Please also check other arguments." + message_addition = "" + else: + message_addition = ( + f"you may want to use one of the following: {suggestions}" + ) + super().__init__( + argument=argument, + message=message, + message_addition=message_addition, + suggestions=suggestions, + ) + + +class SourceArgAmbiguousWithSuggestions(SourceArgumentSuggestionsExceptionBase): + def __init__(self, argument: str, value: Any, suggestions: Iterable[T]) -> None: + message = f"Multiple values found for the argument '{argument}' with the value '{value}'" + message_addition = f"please specify one of: {suggestions}" + super().__init__( + argument=argument, + message=message, + suggestions=suggestions, + message_addition=message_addition, + ) + + +class SourceArgumentRequired(SourceArgumentException): + """Argument must be provided.""" + + def __init__(self, argument: str, reason: str) -> None: + self.message = f"Argument '{argument}' must be provided" + if reason: + self.message += f", {reason}" + super().__init__(argument, self.message) + + +class SourceArgumentRequiredWithSuggestions(SourceArgumentSuggestionsExceptionBase): + """Argument must be provided.""" + + def __init__(self, argument: str, reason: str, suggestions: Iterable[T]) -> None: + message = f"Argument '{argument}' must be provided" + message_addition = ( + f"you may want to use one of the following: {list(suggestions)}" + ) + if reason: + message += f", {reason}" + super().__init__( + argument=argument, + message=message, + message_addition=message_addition, + suggestions=suggestions, + )