diff --git a/custom_components/robovac/tuyalocalapi.py b/custom_components/robovac/tuyalocalapi.py index 6d09e74..d03926e 100644 --- a/custom_components/robovac/tuyalocalapi.py +++ b/custom_components/robovac/tuyalocalapi.py @@ -46,13 +46,16 @@ import socket import struct import sys import time +from typing import Callable, Coroutine from cryptography.hazmat.backends.openssl import backend as openssl_backend from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.hashes import Hash, MD5 from cryptography.hazmat.primitives.padding import PKCS7 - +INITIAL_BACKOFF = 5 +INITIAL_QUEUE_TIME = 0.1 +BACKOFF_MULTIPLIER = 1.70224 _LOGGER = logging.getLogger(__name__) MESSAGE_PREFIX_FORMAT = ">IIII" MESSAGE_SUFFIX_FORMAT = ">II" @@ -351,6 +354,10 @@ class ResponseTimeoutException(TuyaException): """Did not recieve a response to the request within the timeout""" +class BackoffException(TuyaException): + """Backoff time not reached""" + + class TuyaCipher: """Tuya cryptographic helpers.""" @@ -444,7 +451,16 @@ class Message: SET_COMMAND = 0x07 GRATUITOUS_UPDATE = 0x08 - def __init__(self, command, payload=None, sequence=None, encrypt_for=None): + def __init__( + self, + command, + payload=None, + sequence=None, + encrypt=False, + device=None, + expect_response=True, + ttl=5, + ): if payload is None: payload = b"" self.payload = payload @@ -454,11 +470,15 @@ class Message: self.set_sequence() else: self.sequence = sequence - self.encrypt = False - self.device = None - if encrypt_for is not None: - self.device = encrypt_for - self.encrypt = True + self.encrypt = encrypt + self.device = device + self.expiry = int(time.time()) + ttl + self.expect_response = expect_response + self.listener = None + if expect_response is True: + self.listener = asyncio.Semaphore(0) + if device is not None: + device._listeners[self.sequence] = self.listener def __repr__(self): return "{}({}, {!r}, {!r}, {})".format( @@ -503,34 +523,8 @@ class Message: __bytes__ = bytes - async def async_send(self, device, retries=4): - device._listeners[self.sequence] = asyncio.Semaphore(0) - await device._async_send(self) - try: - await asyncio.wait_for( - device._listeners[self.sequence].acquire(), timeout=device.timeout - ) - except: - del device._listeners[self.sequence] - if retries == 0: - raise ResponseTimeoutException( - "Timed out waiting for response to sequence number {}".format( - self.sequence - ) - ) - - _LOGGER.debug( - "Timed out waiting for response to sequence number {}. Retrying".format( - self.sequence - ) - ) - - if self.original_sequence is None: - self.set_sequence() - - return self.async_send(device, retries - 1) - - return device._listeners.pop(self.sequence) + async def async_send(self): + await self.device._async_send(self) @classmethod def from_bytes(cls, data, cipher=None): @@ -604,27 +598,9 @@ class Message: return cls(command, payload, sequence) -def _call_async(fn, *args): - loop = None - if sys.version_info >= (3, 7): - try: - loop = asyncio.get_running_loop() - except RuntimeError: - pass - - loop = asyncio.get_event_loop() - - def wrapper(fn, *args): - asyncio.ensure_future(fn(*args)) - - loop.call_soon(wrapper, fn, *args) - - class TuyaDevice: """Represents a generic Tuya device.""" - # PING_INTERVAL = 10 - def __init__( self, device_id, @@ -655,13 +631,22 @@ class TuyaDevice: self.cipher = TuyaCipher(local_key, self.version) self.writer = None - self._handlers = { - Message.GRATUITOUS_UPDATE: [self.async_gratuitous_update_state], - Message.PING_COMMAND: [self._async_pong_received], + self._response_task = None + self._recieve_task = None + self._handlers: dict[int, Callable[[Message], Coroutine]] = { + Message.GRATUITOUS_UPDATE: self.async_gratuitous_update_state, + Message.PING_COMMAND: self._async_pong_received, } self._dps = {} self._connected = False + self._enabled = True + self._queue = [] self._listeners = {} + self._backoff = False + self._queue_interval = INITIAL_QUEUE_TIME + self._failures = 0 + + asyncio.create_task(self.process_queue()) def __repr__(self): return "{}({!r}, {!r}, {!r}, {!r})".format( @@ -675,70 +660,141 @@ class TuyaDevice: def __str__(self): return "{} ({}:{})".format(self.device_id, self.host, self.port) - async def async_connect(self, callback=None): - if self._connected: + async def process_queue(self): + if self._enabled is False: return + + self.clean_queue() + + if len(self._queue) > 0: + _LOGGER.debug( + "Processing queue. Current length: {}".format(len(self._queue)) + ) + try: + message = self._queue.pop(0) + await message.async_send() + self._failures = 0 + self._queue_interval = INITIAL_QUEUE_TIME + self._backoff = False + except Exception as e: + self._failures += 1 + _LOGGER.debug("{} failures. Most recent: {}".format(self._failures, e)) + if self._failures > 3: + self._backoff = True + self._queue_interval = min( + INITIAL_BACKOFF * (BACKOFF_MULTIPLIER ** (self._failures - 4)), + 600, + ) + _LOGGER.warn( + "{} failures, backing off for {} seconds".format( + self._failures, self._queue_interval + ) + ) + + await asyncio.sleep(self._queue_interval) + asyncio.create_task(self.process_queue()) + + def clean_queue(self): + cleaned_queue = [] + now = int(time.time()) + for item in self._queue: + if item.expiry > now: + cleaned_queue.append(item) + self._queue = cleaned_queue + + async def async_connect(self): + if self._connected is True or self._enabled is False: + return + sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM) sock.settimeout(self.timeout) _LOGGER.debug("Connecting to {}".format(self)) try: sock.connect((self.host, self.port)) - except socket.timeout as e: + except (socket.timeout, TimeoutError) as e: self._dps["106"] = "CONNECTION_FAILED" - raise ConnectionTimeoutException("Connection timed out") from e + raise ConnectionTimeoutException("Connection timed out") + loop = asyncio.get_running_loop() + loop.create_connection self.reader, self.writer = await asyncio.open_connection(sock=sock) self._connected = True - asyncio.ensure_future(self._async_ping(self.ping_interval)) - asyncio.ensure_future(self._async_handle_message()) + asyncio.create_task(self.async_ping(self.ping_interval)) + asyncio.create_task(self._async_handle_message()) + + async def async_disable(self): + self._enabled = False + await self.async_disconnect() async def async_disconnect(self): + if self._connected is False: + return + _LOGGER.debug("Disconnected from {}".format(self)) self._connected = False self.last_pong = 0 + if self._response_task is not None: + self._response_task.cancel() + + if self._recieve_task is not None: + self._recieve_task.cancel() + if self.writer is not None: self.writer.close() async def async_get(self): payload = {"gwId": self.gateway_id, "devId": self.device_id} - maybe_self = None if self.version < (3, 3) else self - message = Message(Message.GET_COMMAND, payload, encrypt_for=maybe_self) - response = await message.async_send(self) - await self.async_update_state(response) + encrypt = False if self.version < (3, 3) else True + message = Message(Message.GET_COMMAND, payload, encrypt=encrypt, device=self) + self._queue.append(message) + response = await self.async_recieve(message) + asyncio.create_task(self.async_update_state(response)) async def async_set(self, dps): t = int(time.time()) payload = {"devId": self.device_id, "uid": "", "t": t, "dps": dps} - message = Message(Message.SET_COMMAND, payload, encrypt_for=self) - await message.async_send(self) + message = Message( + Message.SET_COMMAND, + payload, + encrypt=True, + device=self, + expect_response=False, + ) + self._queue.append(message) - def set(self, dps): - _call_async(self.async_set, dps) - - async def _async_ping(self, ping_interval): - if not self._connected: + async def async_ping(self, ping_interval): + if self._enabled is False: return - self.last_ping = time.time() - maybe_self = None if self.version < (3, 3) else self - message = Message(Message.PING_COMMAND, sequence=0, encrypt_for=maybe_self) - await self._async_send(message) + if self._backoff is True: + _LOGGER.debug("Currently in backoff, not adding ping to queue") + else: + self.last_ping = time.time() + encrypt = False if self.version < (3, 3) else True + message = Message( + Message.PING_COMMAND, + sequence=0, + encrypt=encrypt, + device=self, + expect_response=False, + ) + self._queue.append(message) + await asyncio.sleep(ping_interval) + asyncio.create_task(self.async_ping(self.ping_interval)) if self.last_pong < self.last_ping: await self.async_disconnect() - else: - asyncio.ensure_future(self._async_ping(self.ping_interval)) - async def _async_pong_received(self, message, device): + async def _async_pong_received(self, message): self.last_pong = time.time() - async def async_gratuitous_update_state(self, state_message, _): + async def async_gratuitous_update_state(self, state_message): await self.async_update_state(state_message) await self.update_entity_state_cb() async def async_update_state(self, state_message, _=None): if state_message.payload and state_message.payload["dps"]: self._dps.update(state_message.payload["dps"]) - _LOGGER.info("Received updated state {}: {}".format(self, self._dps)) + _LOGGER.debug("Received updated state {}: {}".format(self, self._dps)) @property def state(self): @@ -746,16 +802,31 @@ class TuyaDevice: @state.setter def state_setter(self, new_values): - asyncio.ensure_future(self.async_set(new_values)) + asyncio.create_task(self.async_set(new_values)) async def _async_handle_message(self): + if self._enabled is False: + return + try: - response_data = await self.reader.readuntil(MAGIC_SUFFIX_BYTES) + self._response_task = asyncio.create_task( + self.reader.readuntil(MAGIC_SUFFIX_BYTES) + ) + await self._response_task + response_data = self._response_task.result() message = Message.from_bytes(response_data, self.cipher) + self._response_task = None except InvalidMessage as e: _LOGGER.debug("Invalid message from {}: {}".format(self, e)) except MessageDecodeFailed as e: _LOGGER.debug("Failed to decrypt message from {}".format(self)) + except asyncio.IncompleteReadError as e: + self._response_task = None + if self._connected: + _LOGGER.debug("Incomplete read") + except ConnectionResetError as e: + _LOGGER.debug("Connection reset") + else: _LOGGER.debug("Received message from {}: {}".format(self, message)) if message.sequence in self._listeners: @@ -764,21 +835,22 @@ class TuyaDevice: self._listeners[message.sequence] = message sem.release() else: - for c in self._handlers.get(message.command, []): - asyncio.ensure_future(c(message, self)) + handler = self._handlers.get(message.command, None) + if handler is not None: + asyncio.create_task(handler(message)) - asyncio.ensure_future(self._async_handle_message()) + asyncio.create_task(self._async_handle_message()) - async def _async_send(self, message, retries=4): + async def _async_send(self, message, retries=2): + _LOGGER.debug("Sending to {}: {}".format(self, message)) try: await self.async_connect() - _LOGGER.debug("Sending to {}: {}".format(self, message)) self.writer.write(message.bytes()) await self.writer.drain() except Exception as e: if retries == 0: if isinstance(e, socket.error): - asyncio.ensure_future(self.async_disconnect()) + asyncio.create_task(self.async_disconnect()) raise ConnectionException( "Connection to {} failed: {}".format(self, e) ) @@ -805,4 +877,22 @@ class TuyaDevice: _LOGGER.debug( "Retrying send due to error. Failed to send data to {}".format(self) ) + await asyncio.sleep(0.25) await self._async_send(message, retries=retries - 1) + + async def async_recieve(self, message): + if message.expect_response is True: + try: + self._recieve_task = asyncio.create_task( + asyncio.wait_for(message.listener.acquire(), timeout=self.timeout) + ) + await self._recieve_task + return self._listeners.pop(message.sequence) + except Exception as e: + del self._listeners[message.sequence] + + raise ResponseTimeoutException( + "Timed out waiting for response to sequence number {}".format( + message.sequence + ) + ) diff --git a/custom_components/robovac/vacuum.py b/custom_components/robovac/vacuum.py index c4e40c2..e0f8752 100644 --- a/custom_components/robovac/vacuum.py +++ b/custom_components/robovac/vacuum.py @@ -328,15 +328,13 @@ class RoboVacEntity(StateVacuumEntity): self.update_entity_values() except TuyaException as e: self.update_failures += 1 - _LOGGER.debug( - "Update errored. Current failure count: {}. Reason: {}".format( + _LOGGER.warn( + "Update errored. Current update failure count: {}. Reason: {}".format( self.update_failures, e ) ) - if self.update_failures == UPDATE_RETRIES: - self.update_failures = 0 + if self.update_failures >= UPDATE_RETRIES: self.error_code = "CONNECTION_FAILED" - raise e async def pushed_update_handler(self): self.update_entity_values() @@ -479,4 +477,4 @@ class RoboVacEntity(StateVacuumEntity): await self.vacuum.async_set({"124": base64_str}) async def async_will_remove_from_hass(self): - await self.vacuum.async_disconnect() + await self.vacuum.async_disable()