diff --git a/homeassistant/components/derivative/sensor.py b/homeassistant/components/derivative/sensor.py index e0b4a19de64..8515b54295a 100644 --- a/homeassistant/components/derivative/sensor.py +++ b/homeassistant/components/derivative/sensor.py @@ -10,13 +10,16 @@ import voluptuous as vol from homeassistant.components.sensor import ( ATTR_STATE_CLASS, + DEVICE_CLASS_UNITS, PLATFORM_SCHEMA as SENSOR_PLATFORM_SCHEMA, RestoreSensor, + SensorDeviceClass, SensorEntity, SensorStateClass, ) from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( + ATTR_DEVICE_CLASS, ATTR_UNIT_OF_MEASUREMENT, CONF_NAME, CONF_SOURCE, @@ -83,6 +86,17 @@ UNIT_TIME = { UnitOfTime.DAYS: 24 * 60 * 60, } +DERIVED_CLASS = { + SensorDeviceClass.ENERGY: SensorDeviceClass.POWER, + SensorDeviceClass.ENERGY_STORAGE: SensorDeviceClass.POWER, + SensorDeviceClass.DATA_SIZE: SensorDeviceClass.DATA_RATE, + SensorDeviceClass.DISTANCE: SensorDeviceClass.SPEED, + SensorDeviceClass.WATER: SensorDeviceClass.VOLUME_FLOW_RATE, + SensorDeviceClass.GAS: SensorDeviceClass.VOLUME_FLOW_RATE, + SensorDeviceClass.VOLUME: SensorDeviceClass.VOLUME_FLOW_RATE, + SensorDeviceClass.VOLUME_STORAGE: SensorDeviceClass.VOLUME_FLOW_RATE, +} + DEFAULT_ROUND = 3 DEFAULT_TIME_WINDOW = 0 @@ -203,10 +217,11 @@ class DerivativeSensor(RestoreSensor, SensorEntity): self._attr_name = name if name is not None else f"{source_entity} derivative" self._attr_extra_state_attributes = {ATTR_SOURCE_ID: source_entity} - self._unit_template: str | None = None + self._string_unit_prefix: str | None = None + self._string_unit_time: str | None = None if unit_of_measurement is None: - final_unit_prefix = "" if unit_prefix is None else unit_prefix - self._unit_template = f"{final_unit_prefix}{{}}/{unit_time}" + self._string_unit_prefix = "" if unit_prefix is None else unit_prefix + self._string_unit_time = unit_time # we postpone the definition of unit_of_measurement to later self._attr_native_unit_of_measurement = None else: @@ -225,12 +240,40 @@ class DerivativeSensor(RestoreSensor, SensorEntity): ) def _derive_and_set_attributes_from_state(self, source_state: State | None) -> None: - if self._unit_template and source_state: + if not source_state: + return + + source_class_raw = source_state.attributes.get(ATTR_DEVICE_CLASS) + source_class: SensorDeviceClass | None = None + if isinstance(source_class_raw, str): + try: + source_class = SensorDeviceClass(source_class_raw) + except ValueError: + source_class = None + if self._string_unit_prefix is not None and self._string_unit_time is not None: original_unit = self._attr_native_unit_of_measurement source_unit = source_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) - self._attr_native_unit_of_measurement = self._unit_template.format( - "" if source_unit is None else source_unit - ) + if ( + ( + source_class + in (SensorDeviceClass.ENERGY, SensorDeviceClass.ENERGY_STORAGE) + ) + and self._string_unit_time == UnitOfTime.HOURS + and source_unit + and source_unit.endswith("Wh") + ): + self._attr_native_unit_of_measurement = ( + f"{self._string_unit_prefix}{source_unit[:-1]}" + ) + + else: + unit_template = ( + f"{self._string_unit_prefix}{{}}/{self._string_unit_time}" + ) + self._attr_native_unit_of_measurement = unit_template.format( + "" if source_unit is None else source_unit + ) + if original_unit != self._attr_native_unit_of_measurement: _LOGGER.debug( "%s: Derivative sensor switched UoM from %s to %s, resetting state to 0", @@ -241,6 +284,16 @@ class DerivativeSensor(RestoreSensor, SensorEntity): self._state_list = [] self._attr_native_value = round(Decimal(0), self._round_digits) + self._attr_device_class = None + if source_class: + derived_class = DERIVED_CLASS.get(source_class) + if ( + derived_class + and self._attr_native_unit_of_measurement + in DEVICE_CLASS_UNITS[derived_class] + ): + self._attr_device_class = derived_class + def _calc_derivative_from_state_list(self, current_time: datetime) -> Decimal: def calculate_weight(start: datetime, end: datetime, now: datetime) -> float: window_start = now - timedelta(seconds=self._time_window) @@ -309,6 +362,10 @@ class DerivativeSensor(RestoreSensor, SensorEntity): except InvalidOperation, TypeError: self._attr_native_value = None + last_state = await self.async_get_last_state() + if last_state: + self._attr_device_class = last_state.attributes.get(ATTR_DEVICE_CLASS) + async def async_added_to_hass(self) -> None: """Handle entity which will be added.""" await super().async_added_to_hass() diff --git a/tests/components/derivative/test_sensor.py b/tests/components/derivative/test_sensor.py index 4b282582789..29337d5d369 100644 --- a/tests/components/derivative/test_sensor.py +++ b/tests/components/derivative/test_sensor.py @@ -11,13 +11,22 @@ import pytest from homeassistant import config as hass_config, core as ha from homeassistant.components.derivative.const import DOMAIN -from homeassistant.components.sensor import ATTR_STATE_CLASS, SensorStateClass +from homeassistant.components.sensor import ( + ATTR_STATE_CLASS, + SensorDeviceClass, + SensorStateClass, +) from homeassistant.const import ( + ATTR_DEVICE_CLASS, SERVICE_RELOAD, STATE_UNAVAILABLE, STATE_UNKNOWN, + UnitOfDataRate, + UnitOfEnergy, UnitOfPower, + UnitOfSpeed, UnitOfTime, + UnitOfVolumeFlowRate, ) from homeassistant.core import ( Event, @@ -642,6 +651,137 @@ async def test_sub_intervals_with_time_window(hass: HomeAssistant) -> None: assert expect_min <= derivative <= expect_max, f"Failed at time {time}" +@pytest.mark.parametrize( + ("extra_config", "source_unit", "source_class", "derived_unit", "derived_class"), + [ + ( + {}, + UnitOfEnergy.KILO_WATT_HOUR, + SensorDeviceClass.ENERGY, + UnitOfPower.KILO_WATT, + SensorDeviceClass.POWER, + ), + ( + {}, + UnitOfEnergy.TERA_WATT_HOUR, + SensorDeviceClass.ENERGY, + UnitOfPower.TERA_WATT, + SensorDeviceClass.POWER, + ), + ( + {"unit_prefix": "m"}, + UnitOfEnergy.WATT_HOUR, + SensorDeviceClass.ENERGY_STORAGE, + UnitOfPower.MILLIWATT, + SensorDeviceClass.POWER, + ), + ( + {"unit_prefix": "k"}, + UnitOfEnergy.WATT_HOUR, + SensorDeviceClass.ENERGY, + UnitOfPower.KILO_WATT, + SensorDeviceClass.POWER, + ), + ( + {"unit_prefix": "n"}, + UnitOfEnergy.WATT_HOUR, + SensorDeviceClass.ENERGY, + "nW", + None, + ), + ( + {}, + "GB", + SensorDeviceClass.DATA_SIZE, + "GB/h", + None, + ), + ( + {"unit_time": "s"}, + "GB", + SensorDeviceClass.DATA_SIZE, + UnitOfDataRate.GIGABYTES_PER_SECOND, + SensorDeviceClass.DATA_RATE, + ), + ( + {}, + "km", + SensorDeviceClass.DISTANCE, + UnitOfSpeed.KILOMETERS_PER_HOUR, + SensorDeviceClass.SPEED, + ), + ( + {}, + "m³", + SensorDeviceClass.GAS, + UnitOfVolumeFlowRate.CUBIC_METERS_PER_HOUR, + SensorDeviceClass.VOLUME_FLOW_RATE, + ), + ( + {"unit_time": "min"}, + "gal", + SensorDeviceClass.WATER, + UnitOfVolumeFlowRate.GALLONS_PER_MINUTE, + SensorDeviceClass.VOLUME_FLOW_RATE, + ), + ( + {}, + UnitOfEnergy.KILO_WATT_HOUR, + "not_a_real_device_class", + "kWh/h", + None, + ), + ], +) +async def test_device_classes( + extra_config: dict[str, Any], + source_unit: str, + source_class: str, + derived_unit: str, + derived_class: str, + hass: HomeAssistant, +) -> None: + """Test derivative sensor handles unit conversions and device classes.""" + config = { + "sensor": { + "platform": "derivative", + "name": "derivative", + "source": "sensor.source", + "round": 2, + "unit_time": "h", + **extra_config, + } + } + + assert await async_setup_component(hass, "sensor", config) + entity_id = config["sensor"]["source"] + base = dt_util.utcnow() + with freeze_time(base) as freezer: + hass.states.async_set( + entity_id, + 1000, + { + "unit_of_measurement": source_unit, + "device_class": source_class, + }, + ) + await hass.async_block_till_done() + freezer.move_to(dt_util.utcnow() + timedelta(seconds=3600)) + hass.states.async_set( + entity_id, + 2000, + { + "unit_of_measurement": source_unit, + "device_class": source_class, + }, + ) + await hass.async_block_till_done() + state = hass.states.get("sensor.derivative") + assert state is not None + assert state.attributes.get("unit_of_measurement") == derived_unit + assert state.attributes.get("device_class") == derived_class + + async def test_prefix(hass: HomeAssistant) -> None: """Test derivative sensor state using a power source.""" config = { @@ -885,13 +1025,11 @@ async def test_unavailable_boot( State( "sensor.power", restore_state, - { - "unit_of_measurement": "kWh/s", - }, + {"unit_of_measurement": "kW", "device_class": "power"}, ), { "native_value": restore_state, - "native_unit_of_measurement": "kWh/s", + "native_unit_of_measurement": "kW", }, ), ], @@ -902,12 +1040,16 @@ async def test_unavailable_boot( "name": "power", "source": "sensor.energy", "round": 2, - "unit_time": "s", + "unit_time": "h", } config = {"sensor": config} entity_id = config["sensor"]["source"] - hass.states.async_set(entity_id, STATE_UNAVAILABLE, {"unit_of_measurement": "kWh"}) + hass.states.async_set( + entity_id, + STATE_UNAVAILABLE, + {"unit_of_measurement": "kWh", "device_class": "energy"}, + ) await hass.async_block_till_done() assert await async_setup_component(hass, "sensor", config) @@ -917,11 +1059,14 @@ async def test_unavailable_boot( assert state is not None # Sensor is unavailable as source is unavailable assert state.state == STATE_UNAVAILABLE + assert state.attributes.get(ATTR_DEVICE_CLASS) == "power" base = dt_util.utcnow() with freeze_time(base) as freezer: - freezer.move_to(base + timedelta(seconds=1)) - hass.states.async_set(entity_id, 10, {"unit_of_measurement": "kWh"}) + freezer.move_to(base + timedelta(hours=1)) + hass.states.async_set( + entity_id, 10, {"unit_of_measurement": "kWh", "device_class": "energy"} + ) await hass.async_block_till_done() state = hass.states.get("sensor.power") @@ -930,15 +1075,17 @@ async def test_unavailable_boot( # so just hold until the next tick assert state.state == restore_state - freezer.move_to(base + timedelta(seconds=2)) - hass.states.async_set(entity_id, 15, {"unit_of_measurement": "kWh"}) + freezer.move_to(base + timedelta(hours=2)) + hass.states.async_set( + entity_id, 15, {"unit_of_measurement": "kWh", "device_class": "energy"} + ) await hass.async_block_till_done() state = hass.states.get("sensor.power") assert state is not None # Now that the source sensor has two valid datapoints, we can calculate derivative assert state.state == "5.00" - assert state.attributes.get("unit_of_measurement") == "kWh/s" + assert state.attributes.get("unit_of_measurement") == "kW" async def test_source_unit_change(