454 lines
17 KiB
Python
454 lines
17 KiB
Python
|
"""Allows the creation of generic variable entities."""
|
||
|
|
||
|
import logging
|
||
|
from typing import Union, Sequence
|
||
|
import asyncio
|
||
|
import json
|
||
|
|
||
|
import voluptuous as vol
|
||
|
|
||
|
from homeassistant.core import callback
|
||
|
from homeassistant.helpers import config_validation as cv
|
||
|
from homeassistant.helpers.config_validation import make_entity_service_schema
|
||
|
from homeassistant.const import (
|
||
|
ATTR_FRIENDLY_NAME, ATTR_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE,
|
||
|
CONF_ICON, CONF_ICON_TEMPLATE, ATTR_ENTITY_PICTURE,
|
||
|
CONF_ENTITY_PICTURE_TEMPLATE, ATTR_ENTITY_ID,
|
||
|
EVENT_HOMEASSISTANT_START, CONF_FRIENDLY_NAME_TEMPLATE, MATCH_ALL,
|
||
|
EVENT_STATE_CHANGED)
|
||
|
from homeassistant.exceptions import TemplateError
|
||
|
from homeassistant.helpers.entity import async_generate_entity_id
|
||
|
from homeassistant.helpers.entity_component import EntityComponent
|
||
|
from homeassistant.helpers.restore_state import RestoreEntity
|
||
|
from homeassistant.helpers.event import async_track_state_change
|
||
|
from homeassistant.helpers.json import JSONEncoder
|
||
|
from homeassistant.components import recorder
|
||
|
from homeassistant.components.recorder.models import Events
|
||
|
|
||
|
_LOGGER = logging.getLogger(__name__)
|
||
|
|
||
|
DOMAIN = 'var'
|
||
|
ENTITY_ID_FORMAT = DOMAIN + '.{}'
|
||
|
|
||
|
CONF_INITIAL_VALUE = 'initial_value'
|
||
|
CONF_RESTORE = 'restore'
|
||
|
CONF_FORCE_UPDATE = 'force_update'
|
||
|
CONF_QUERY = 'query'
|
||
|
CONF_COLUMN = 'column'
|
||
|
CONF_TRACKED_ENTITY_ID = 'tracked_entity_id'
|
||
|
CONF_TRACKED_EVENT_TYPE = 'tracked_event_type'
|
||
|
|
||
|
ATTR_VALUE = 'value'
|
||
|
|
||
|
def validate_event_types(value: Union[str, Sequence]) -> Sequence[str]:
|
||
|
"""Validate event types."""
|
||
|
if value is None:
|
||
|
raise vol.Invalid('Event types can not be None')
|
||
|
if isinstance(value, str):
|
||
|
value = [event_type.strip() for event_type in value.split(',')]
|
||
|
|
||
|
return [event_type for event_type in value]
|
||
|
|
||
|
def validate_sql_select(value):
|
||
|
"""Validate that value is a SQL SELECT query."""
|
||
|
if not value.lstrip().lower().startswith('select'):
|
||
|
raise vol.Invalid('Only SELECT queries allowed')
|
||
|
return value
|
||
|
|
||
|
SERVICE_SET = "set"
|
||
|
SERVICE_SET_SCHEMA = make_entity_service_schema({
|
||
|
vol.Optional(ATTR_VALUE): cv.match_all,
|
||
|
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
|
||
|
vol.Optional(CONF_QUERY): vol.All(cv.string, validate_sql_select),
|
||
|
vol.Optional(CONF_COLUMN): cv.string,
|
||
|
vol.Optional(ATTR_UNIT_OF_MEASUREMENT): cv.string,
|
||
|
vol.Optional(CONF_RESTORE): cv.boolean,
|
||
|
vol.Optional(CONF_FORCE_UPDATE): cv.boolean,
|
||
|
vol.Optional(ATTR_FRIENDLY_NAME): cv.string,
|
||
|
vol.Optional(CONF_FRIENDLY_NAME_TEMPLATE): cv.template,
|
||
|
vol.Optional(CONF_ICON): cv.icon,
|
||
|
vol.Optional(CONF_ICON_TEMPLATE): cv.template,
|
||
|
vol.Optional(ATTR_ENTITY_PICTURE): cv.string,
|
||
|
vol.Optional(CONF_ENTITY_PICTURE_TEMPLATE): cv.template,
|
||
|
vol.Optional(CONF_TRACKED_ENTITY_ID): cv.entity_ids,
|
||
|
vol.Optional(CONF_TRACKED_EVENT_TYPE): validate_event_types,
|
||
|
})
|
||
|
|
||
|
SERVICE_UPDATE = "update"
|
||
|
SERVICE_UPDATE_SCHEMA = make_entity_service_schema({})
|
||
|
|
||
|
CONFIG_SCHEMA = vol.Schema({
|
||
|
DOMAIN: vol.Schema({
|
||
|
cv.slug: vol.Any({
|
||
|
vol.Optional(CONF_INITIAL_VALUE): cv.match_all,
|
||
|
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
|
||
|
vol.Optional(CONF_QUERY): vol.All(cv.string, validate_sql_select),
|
||
|
vol.Optional(CONF_COLUMN): cv.string,
|
||
|
vol.Optional(ATTR_UNIT_OF_MEASUREMENT): cv.string,
|
||
|
vol.Optional(CONF_RESTORE): cv.boolean,
|
||
|
vol.Optional(CONF_FORCE_UPDATE): cv.boolean,
|
||
|
vol.Optional(ATTR_FRIENDLY_NAME): cv.string,
|
||
|
vol.Optional(CONF_FRIENDLY_NAME_TEMPLATE): cv.template,
|
||
|
vol.Optional(CONF_ICON): cv.icon,
|
||
|
vol.Optional(CONF_ICON_TEMPLATE): cv.template,
|
||
|
vol.Optional(ATTR_ENTITY_PICTURE): cv.string,
|
||
|
vol.Optional(CONF_ENTITY_PICTURE_TEMPLATE): cv.template,
|
||
|
vol.Optional(CONF_TRACKED_ENTITY_ID): cv.entity_ids,
|
||
|
vol.Optional(CONF_TRACKED_EVENT_TYPE): validate_event_types,
|
||
|
}, None)
|
||
|
})
|
||
|
}, extra=vol.ALLOW_EXTRA)
|
||
|
|
||
|
async def async_setup(hass, config):
|
||
|
"""Set up variables from config."""
|
||
|
component = EntityComponent(_LOGGER, DOMAIN, hass)
|
||
|
await component.async_setup(config)
|
||
|
|
||
|
entities = []
|
||
|
|
||
|
for object_id, cfg in config[DOMAIN].items():
|
||
|
if not cfg:
|
||
|
cfg = {}
|
||
|
|
||
|
initial_value = cfg.get(CONF_INITIAL_VALUE)
|
||
|
unit = cfg.get(ATTR_UNIT_OF_MEASUREMENT)
|
||
|
restore = cfg.get(CONF_RESTORE, True)
|
||
|
force_update = cfg.get(CONF_FORCE_UPDATE, False)
|
||
|
friendly_name = cfg.get(ATTR_FRIENDLY_NAME, object_id)
|
||
|
icon = cfg.get(CONF_ICON)
|
||
|
entity_picture = cfg.get(ATTR_ENTITY_PICTURE)
|
||
|
|
||
|
value_template = cfg.get(CONF_VALUE_TEMPLATE)
|
||
|
friendly_name_template = cfg.get(CONF_FRIENDLY_NAME_TEMPLATE)
|
||
|
icon_template = cfg.get(CONF_ICON_TEMPLATE)
|
||
|
entity_picture_template = cfg.get(CONF_ENTITY_PICTURE_TEMPLATE)
|
||
|
for template in (value_template,
|
||
|
icon_template,
|
||
|
entity_picture_template,
|
||
|
friendly_name_template,
|
||
|
):
|
||
|
if template is not None:
|
||
|
template.hass = hass
|
||
|
|
||
|
manual_entity_ids = cfg.get(CONF_TRACKED_ENTITY_ID)
|
||
|
|
||
|
tracked_entity_ids = list()
|
||
|
if manual_entity_ids is not None:
|
||
|
tracked_entity_ids = list(set(manual_entity_ids))
|
||
|
|
||
|
tracked_event_types = cfg.get(CONF_TRACKED_EVENT_TYPE)
|
||
|
if tracked_event_types is not None:
|
||
|
tracked_event_types = list(set(tracked_event_types))
|
||
|
|
||
|
query = cfg.get(CONF_QUERY)
|
||
|
column = cfg.get(CONF_COLUMN)
|
||
|
|
||
|
session = hass.data[recorder.DATA_INSTANCE].get_session()
|
||
|
|
||
|
entities.append(
|
||
|
Variable(
|
||
|
hass,
|
||
|
object_id,
|
||
|
initial_value,
|
||
|
value_template,
|
||
|
session,
|
||
|
query,
|
||
|
column,
|
||
|
unit,
|
||
|
restore,
|
||
|
force_update,
|
||
|
friendly_name,
|
||
|
friendly_name_template,
|
||
|
icon,
|
||
|
icon_template,
|
||
|
entity_picture,
|
||
|
entity_picture_template,
|
||
|
tracked_entity_ids,
|
||
|
tracked_event_types)
|
||
|
)
|
||
|
|
||
|
if not entities:
|
||
|
return False
|
||
|
|
||
|
component.async_register_entity_service(
|
||
|
SERVICE_SET, SERVICE_SET_SCHEMA,
|
||
|
'async_set'
|
||
|
)
|
||
|
|
||
|
component.async_register_entity_service(
|
||
|
SERVICE_UPDATE, SERVICE_UPDATE_SCHEMA,
|
||
|
'async_force_update'
|
||
|
)
|
||
|
|
||
|
await component.async_add_entities(entities)
|
||
|
return True
|
||
|
|
||
|
class Variable(RestoreEntity):
|
||
|
"""Representation of a variable."""
|
||
|
|
||
|
def __init__(self, hass, object_id, initial_value, value_template,
|
||
|
session, query, column, unit, restore, force_update,
|
||
|
friendly_name, friendly_name_template, icon,
|
||
|
icon_template, entity_picture, entity_picture_template,
|
||
|
tracked_entity_ids, tracked_event_types):
|
||
|
"""Initialize a variable."""
|
||
|
self.hass = hass
|
||
|
self.entity_id = ENTITY_ID_FORMAT.format(object_id)
|
||
|
self._value = initial_value
|
||
|
self._initial_value = initial_value
|
||
|
self._value_template = value_template
|
||
|
self._session = session
|
||
|
if query is not None and not 'LIMIT' in query:
|
||
|
self._query = query.replace(";", " LIMIT 1;")
|
||
|
else:
|
||
|
self._query = query
|
||
|
self._column = column
|
||
|
self._unit = unit
|
||
|
self._restore = restore
|
||
|
self._force_update = force_update
|
||
|
self._friendly_name = friendly_name
|
||
|
self._friendly_name_template = friendly_name_template
|
||
|
self._icon = icon
|
||
|
self._icon_template = icon_template
|
||
|
self._entity_picture = entity_picture
|
||
|
self._entity_picture_template = entity_picture_template
|
||
|
self._tracked_entity_ids = tracked_entity_ids
|
||
|
self._stop_track_state_change = None
|
||
|
self._tracked_event_types = tracked_event_types
|
||
|
self._stop_track_events = []
|
||
|
|
||
|
def _is_event_in_db(self, event):
|
||
|
"""Query the database to see if the event has been written."""
|
||
|
event_id = self._session.query(Events.event_id).filter_by(
|
||
|
event_type=event.event_type, time_fired=event.time_fired,
|
||
|
event_data=json.dumps(event.data, cls=JSONEncoder)).scalar()
|
||
|
return event_id is not None
|
||
|
|
||
|
def _get_variable_event_listener(self):
|
||
|
@callback
|
||
|
def listener(event):
|
||
|
"""Update variable once monitored event fires and is recorded to the database."""
|
||
|
# Skip untracked state changes
|
||
|
if (event.event_type == EVENT_STATE_CHANGED and
|
||
|
self._tracked_entity_ids is not None and
|
||
|
event.data['entity_id'] not in self._tracked_entity_ids):
|
||
|
return
|
||
|
|
||
|
# Schedule update immediately if there is no query
|
||
|
if self._query is None:
|
||
|
self.async_schedule_update_ha_state(True)
|
||
|
return
|
||
|
|
||
|
# Otherwise poll the database scheduling update once event has been committed
|
||
|
async def update_var():
|
||
|
"""Poll the database until the event shows up."""
|
||
|
while not self._is_event_in_db(event):
|
||
|
await asyncio.sleep(1)
|
||
|
self.async_schedule_update_ha_state(True)
|
||
|
|
||
|
self.hass.add_job(update_var)
|
||
|
return listener
|
||
|
|
||
|
async def async_added_to_hass(self):
|
||
|
"""Register callbacks."""
|
||
|
|
||
|
@callback
|
||
|
def variable_startup(event):
|
||
|
"""Update variable event listeners on startup."""
|
||
|
if self._tracked_entity_ids is not None:
|
||
|
listener = self._get_variable_event_listener()
|
||
|
stop = self.hass.bus.async_listen(EVENT_STATE_CHANGED, listener)
|
||
|
self._stop_track_state_change = stop
|
||
|
if self._tracked_event_types is not None:
|
||
|
listener = self._get_variable_event_listener()
|
||
|
for event_type in self._tracked_event_types:
|
||
|
stop = self.hass.bus.async_listen(event_type, listener)
|
||
|
self._stop_track_events.append(stop)
|
||
|
|
||
|
self.hass.bus.async_listen_once(
|
||
|
EVENT_HOMEASSISTANT_START, variable_startup)
|
||
|
|
||
|
# Restore previous value on startup
|
||
|
await super().async_added_to_hass()
|
||
|
if self._restore == True:
|
||
|
state = await self.async_get_last_state()
|
||
|
if state:
|
||
|
self._value = state.state
|
||
|
|
||
|
@property
|
||
|
def should_poll(self):
|
||
|
"""If entity should be polled."""
|
||
|
return False
|
||
|
|
||
|
@property
|
||
|
def force_update(self):
|
||
|
"""Return True if state updates should be forced.
|
||
|
If True, a state change will be triggered anytime the state property is
|
||
|
updated, not just when the value changes.
|
||
|
"""
|
||
|
return self._force_update
|
||
|
|
||
|
@property
|
||
|
def name(self):
|
||
|
"""Return the name of the variable."""
|
||
|
return self._friendly_name
|
||
|
|
||
|
@property
|
||
|
def icon(self):
|
||
|
"""Return the icon to be used for this entity."""
|
||
|
return self._icon
|
||
|
|
||
|
@property
|
||
|
def entity_picture(self):
|
||
|
"""Return the entity_picture to be used for this entity."""
|
||
|
return self._entity_picture
|
||
|
|
||
|
@property
|
||
|
def state(self):
|
||
|
"""Return the state of the component."""
|
||
|
return self._value
|
||
|
|
||
|
@property
|
||
|
def unit_of_measurement(self):
|
||
|
"""Return the unit the value is expressed in."""
|
||
|
return self._unit
|
||
|
|
||
|
@property
|
||
|
def _templates_dict(self):
|
||
|
return {'_value': self._value_template,
|
||
|
'_name': self._friendly_name_template,
|
||
|
'_icon': self._icon_template,
|
||
|
'_entity_picture': self._entity_picture_template}
|
||
|
|
||
|
async def async_set(self,
|
||
|
value=None,
|
||
|
value_template=None,
|
||
|
query=None,
|
||
|
column=None,
|
||
|
unit=None,
|
||
|
restore=None,
|
||
|
force_update=None,
|
||
|
friendly_name=None,
|
||
|
friendly_name_template=None,
|
||
|
icon=None,
|
||
|
icon_template=None,
|
||
|
entity_picture=None,
|
||
|
entity_picture_template=None,
|
||
|
manual_tracked_entity_ids=None,
|
||
|
tracked_event_types=None):
|
||
|
"""Set new attributes for the variable."""
|
||
|
if value is not None:
|
||
|
self._value = value
|
||
|
if unit is not None:
|
||
|
self._unit = unit
|
||
|
if restore is not None:
|
||
|
self._restore = restore
|
||
|
if force_update is not None:
|
||
|
self._force_update = force_update
|
||
|
if friendly_name is not None:
|
||
|
self._friendly_name = friendly_name
|
||
|
if icon is not None:
|
||
|
self._icon = icon
|
||
|
if entity_picture is not None:
|
||
|
self._entity_picture = entity_picture
|
||
|
templates_dict = {
|
||
|
'_value': value_template,
|
||
|
'_name': friendly_name_template,
|
||
|
'_icon': icon_template,
|
||
|
'_entity_picture': entity_picture_template}
|
||
|
for property_name, template in templates_dict.items():
|
||
|
if template is not None:
|
||
|
template.hass = self.hass
|
||
|
setattr(self, property_name, template.async_render())
|
||
|
if query is not None:
|
||
|
self._query = query
|
||
|
if column is not None:
|
||
|
self._column = column
|
||
|
|
||
|
tracked_entity_ids = None
|
||
|
if manual_tracked_entity_ids is not None:
|
||
|
tracked_entity_ids = manual_tracked_entity_ids
|
||
|
|
||
|
if tracked_entity_ids is not None:
|
||
|
if self._stop_track_state_change:
|
||
|
self._stop_track_state_change()
|
||
|
self._tracked_entity_ids = tracked_entity_ids
|
||
|
listener = self._get_variable_event_listener()
|
||
|
stop = self.hass.bus.async_listen(EVENT_STATE_CHANGED, listener)
|
||
|
self._stop_track_state_change = stop
|
||
|
|
||
|
if tracked_event_types is not None:
|
||
|
if self._stop_track_events:
|
||
|
for stop in self._stop_track_events:
|
||
|
stop()
|
||
|
self._tracked_event_types = tracked_event_types
|
||
|
listener = self._get_variable_event_listener()
|
||
|
for event_type in self._tracked_event_types:
|
||
|
stop = self.hass.bus.async_listen(event_type, listener)
|
||
|
self._stop_track_events.append(stop)
|
||
|
|
||
|
await self.async_update_ha_state()
|
||
|
|
||
|
async def async_force_update(self):
|
||
|
await self.async_update_ha_state(True)
|
||
|
|
||
|
async def async_update(self):
|
||
|
"""Update the state and attributes from the templates."""
|
||
|
|
||
|
# Run the db query
|
||
|
db_value = None
|
||
|
if self._query is not None:
|
||
|
import sqlalchemy
|
||
|
try:
|
||
|
result = self._session.execute(self._query)
|
||
|
|
||
|
if not result.returns_rows or result.rowcount == 0:
|
||
|
_LOGGER.warning("%s returned no results", self._query)
|
||
|
self._state = None
|
||
|
return
|
||
|
|
||
|
for res in result:
|
||
|
_LOGGER.debug("result = %s", res.items())
|
||
|
db_value = res[self._column]
|
||
|
except sqlalchemy.exc.SQLAlchemyError as err:
|
||
|
_LOGGER.error("Error executing query %s: %s", self._query, err)
|
||
|
return
|
||
|
finally:
|
||
|
self._session.close()
|
||
|
|
||
|
# Update the state and attributes from their templates
|
||
|
for property_name, template in self._templates_dict.items():
|
||
|
if property_name != '_value' and template is None:
|
||
|
continue
|
||
|
elif property_name == '_value' and template is None and db_value is None:
|
||
|
continue
|
||
|
|
||
|
try:
|
||
|
rendered_template = None
|
||
|
if template is not None:
|
||
|
if db_value is not None:
|
||
|
rendered_template = template.async_render_with_possible_json_value(db_value, None)
|
||
|
else:
|
||
|
rendered_template = template.async_render()
|
||
|
|
||
|
if rendered_template is not None:
|
||
|
setattr(self, property_name, rendered_template)
|
||
|
elif property_name == '_value' and db_value is not None:
|
||
|
setattr(self, property_name, db_value)
|
||
|
except TemplateError as ex:
|
||
|
friendly_property_name = property_name[1:].replace('_', ' ')
|
||
|
if ex.args and ex.args[0].startswith(
|
||
|
"UndefinedError: 'None' has no attribute"):
|
||
|
# Common during HA startup - so just a warning
|
||
|
_LOGGER.warning('Could not render %s template %s,'
|
||
|
' the state is unknown.',
|
||
|
friendly_property_name, self._friendly_name)
|
||
|
continue
|
||
|
|
||
|
try:
|
||
|
setattr(self, property_name, getattr(super(), property_name))
|
||
|
except AttributeError:
|
||
|
_LOGGER.error('Could not render %s template %s: %s',
|
||
|
friendly_property_name, self._friendly_name, ex)
|
||
|
|