feat: rewrite network code to make it more linear and handle being offline better

This commit is contained in:
Luke Bonaccorsi 2024-01-14 03:14:34 +00:00
parent 8676cb3fad
commit 88ef4a6e25
2 changed files with 184 additions and 96 deletions

View File

@ -46,13 +46,16 @@ import socket
import struct import struct
import sys import sys
import time import time
from typing import Callable, Coroutine
from cryptography.hazmat.backends.openssl import backend as openssl_backend from cryptography.hazmat.backends.openssl import backend as openssl_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.hashes import Hash, MD5 from cryptography.hazmat.primitives.hashes import Hash, MD5
from cryptography.hazmat.primitives.padding import PKCS7 from cryptography.hazmat.primitives.padding import PKCS7
INITIAL_BACKOFF = 5
INITIAL_QUEUE_TIME = 0.1
BACKOFF_MULTIPLIER = 1.70224
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
MESSAGE_PREFIX_FORMAT = ">IIII" MESSAGE_PREFIX_FORMAT = ">IIII"
MESSAGE_SUFFIX_FORMAT = ">II" MESSAGE_SUFFIX_FORMAT = ">II"
@ -351,6 +354,10 @@ class ResponseTimeoutException(TuyaException):
"""Did not recieve a response to the request within the timeout""" """Did not recieve a response to the request within the timeout"""
class BackoffException(TuyaException):
"""Backoff time not reached"""
class TuyaCipher: class TuyaCipher:
"""Tuya cryptographic helpers.""" """Tuya cryptographic helpers."""
@ -444,7 +451,16 @@ class Message:
SET_COMMAND = 0x07 SET_COMMAND = 0x07
GRATUITOUS_UPDATE = 0x08 GRATUITOUS_UPDATE = 0x08
def __init__(self, command, payload=None, sequence=None, encrypt_for=None): def __init__(
self,
command,
payload=None,
sequence=None,
encrypt=False,
device=None,
expect_response=True,
ttl=5,
):
if payload is None: if payload is None:
payload = b"" payload = b""
self.payload = payload self.payload = payload
@ -454,11 +470,15 @@ class Message:
self.set_sequence() self.set_sequence()
else: else:
self.sequence = sequence self.sequence = sequence
self.encrypt = False self.encrypt = encrypt
self.device = None self.device = device
if encrypt_for is not None: self.expiry = int(time.time()) + ttl
self.device = encrypt_for self.expect_response = expect_response
self.encrypt = True self.listener = None
if expect_response is True:
self.listener = asyncio.Semaphore(0)
if device is not None:
device._listeners[self.sequence] = self.listener
def __repr__(self): def __repr__(self):
return "{}({}, {!r}, {!r}, {})".format( return "{}({}, {!r}, {!r}, {})".format(
@ -503,34 +523,8 @@ class Message:
__bytes__ = bytes __bytes__ = bytes
async def async_send(self, device, retries=4): async def async_send(self):
device._listeners[self.sequence] = asyncio.Semaphore(0) await self.device._async_send(self)
await device._async_send(self)
try:
await asyncio.wait_for(
device._listeners[self.sequence].acquire(), timeout=device.timeout
)
except:
del device._listeners[self.sequence]
if retries == 0:
raise ResponseTimeoutException(
"Timed out waiting for response to sequence number {}".format(
self.sequence
)
)
_LOGGER.debug(
"Timed out waiting for response to sequence number {}. Retrying".format(
self.sequence
)
)
if self.original_sequence is None:
self.set_sequence()
return self.async_send(device, retries - 1)
return device._listeners.pop(self.sequence)
@classmethod @classmethod
def from_bytes(cls, data, cipher=None): def from_bytes(cls, data, cipher=None):
@ -604,27 +598,9 @@ class Message:
return cls(command, payload, sequence) return cls(command, payload, sequence)
def _call_async(fn, *args):
loop = None
if sys.version_info >= (3, 7):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
pass
loop = asyncio.get_event_loop()
def wrapper(fn, *args):
asyncio.ensure_future(fn(*args))
loop.call_soon(wrapper, fn, *args)
class TuyaDevice: class TuyaDevice:
"""Represents a generic Tuya device.""" """Represents a generic Tuya device."""
# PING_INTERVAL = 10
def __init__( def __init__(
self, self,
device_id, device_id,
@ -655,13 +631,22 @@ class TuyaDevice:
self.cipher = TuyaCipher(local_key, self.version) self.cipher = TuyaCipher(local_key, self.version)
self.writer = None self.writer = None
self._handlers = { self._response_task = None
Message.GRATUITOUS_UPDATE: [self.async_gratuitous_update_state], self._recieve_task = None
Message.PING_COMMAND: [self._async_pong_received], self._handlers: dict[int, Callable[[Message], Coroutine]] = {
Message.GRATUITOUS_UPDATE: self.async_gratuitous_update_state,
Message.PING_COMMAND: self._async_pong_received,
} }
self._dps = {} self._dps = {}
self._connected = False self._connected = False
self._enabled = True
self._queue = []
self._listeners = {} self._listeners = {}
self._backoff = False
self._queue_interval = INITIAL_QUEUE_TIME
self._failures = 0
asyncio.create_task(self.process_queue())
def __repr__(self): def __repr__(self):
return "{}({!r}, {!r}, {!r}, {!r})".format( return "{}({!r}, {!r}, {!r}, {!r})".format(
@ -675,70 +660,141 @@ class TuyaDevice:
def __str__(self): def __str__(self):
return "{} ({}:{})".format(self.device_id, self.host, self.port) return "{} ({}:{})".format(self.device_id, self.host, self.port)
async def async_connect(self, callback=None): async def process_queue(self):
if self._connected: if self._enabled is False:
return return
self.clean_queue()
if len(self._queue) > 0:
_LOGGER.debug(
"Processing queue. Current length: {}".format(len(self._queue))
)
try:
message = self._queue.pop(0)
await message.async_send()
self._failures = 0
self._queue_interval = INITIAL_QUEUE_TIME
self._backoff = False
except Exception as e:
self._failures += 1
_LOGGER.debug("{} failures. Most recent: {}".format(self._failures, e))
if self._failures > 3:
self._backoff = True
self._queue_interval = min(
INITIAL_BACKOFF * (BACKOFF_MULTIPLIER ** (self._failures - 4)),
600,
)
_LOGGER.warn(
"{} failures, backing off for {} seconds".format(
self._failures, self._queue_interval
)
)
await asyncio.sleep(self._queue_interval)
asyncio.create_task(self.process_queue())
def clean_queue(self):
cleaned_queue = []
now = int(time.time())
for item in self._queue:
if item.expiry > now:
cleaned_queue.append(item)
self._queue = cleaned_queue
async def async_connect(self):
if self._connected is True or self._enabled is False:
return
sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM) sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
sock.settimeout(self.timeout) sock.settimeout(self.timeout)
_LOGGER.debug("Connecting to {}".format(self)) _LOGGER.debug("Connecting to {}".format(self))
try: try:
sock.connect((self.host, self.port)) sock.connect((self.host, self.port))
except socket.timeout as e: except (socket.timeout, TimeoutError) as e:
self._dps["106"] = "CONNECTION_FAILED" self._dps["106"] = "CONNECTION_FAILED"
raise ConnectionTimeoutException("Connection timed out") from e raise ConnectionTimeoutException("Connection timed out")
loop = asyncio.get_running_loop()
loop.create_connection
self.reader, self.writer = await asyncio.open_connection(sock=sock) self.reader, self.writer = await asyncio.open_connection(sock=sock)
self._connected = True self._connected = True
asyncio.ensure_future(self._async_ping(self.ping_interval)) asyncio.create_task(self.async_ping(self.ping_interval))
asyncio.ensure_future(self._async_handle_message()) asyncio.create_task(self._async_handle_message())
async def async_disable(self):
self._enabled = False
await self.async_disconnect()
async def async_disconnect(self): async def async_disconnect(self):
if self._connected is False:
return
_LOGGER.debug("Disconnected from {}".format(self)) _LOGGER.debug("Disconnected from {}".format(self))
self._connected = False self._connected = False
self.last_pong = 0 self.last_pong = 0
if self._response_task is not None:
self._response_task.cancel()
if self._recieve_task is not None:
self._recieve_task.cancel()
if self.writer is not None: if self.writer is not None:
self.writer.close() self.writer.close()
async def async_get(self): async def async_get(self):
payload = {"gwId": self.gateway_id, "devId": self.device_id} payload = {"gwId": self.gateway_id, "devId": self.device_id}
maybe_self = None if self.version < (3, 3) else self encrypt = False if self.version < (3, 3) else True
message = Message(Message.GET_COMMAND, payload, encrypt_for=maybe_self) message = Message(Message.GET_COMMAND, payload, encrypt=encrypt, device=self)
response = await message.async_send(self) self._queue.append(message)
await self.async_update_state(response) response = await self.async_recieve(message)
asyncio.create_task(self.async_update_state(response))
async def async_set(self, dps): async def async_set(self, dps):
t = int(time.time()) t = int(time.time())
payload = {"devId": self.device_id, "uid": "", "t": t, "dps": dps} payload = {"devId": self.device_id, "uid": "", "t": t, "dps": dps}
message = Message(Message.SET_COMMAND, payload, encrypt_for=self) message = Message(
await message.async_send(self) Message.SET_COMMAND,
payload,
encrypt=True,
device=self,
expect_response=False,
)
self._queue.append(message)
def set(self, dps): async def async_ping(self, ping_interval):
_call_async(self.async_set, dps) if self._enabled is False:
async def _async_ping(self, ping_interval):
if not self._connected:
return return
self.last_ping = time.time() if self._backoff is True:
maybe_self = None if self.version < (3, 3) else self _LOGGER.debug("Currently in backoff, not adding ping to queue")
message = Message(Message.PING_COMMAND, sequence=0, encrypt_for=maybe_self) else:
await self._async_send(message) self.last_ping = time.time()
encrypt = False if self.version < (3, 3) else True
message = Message(
Message.PING_COMMAND,
sequence=0,
encrypt=encrypt,
device=self,
expect_response=False,
)
self._queue.append(message)
await asyncio.sleep(ping_interval) await asyncio.sleep(ping_interval)
asyncio.create_task(self.async_ping(self.ping_interval))
if self.last_pong < self.last_ping: if self.last_pong < self.last_ping:
await self.async_disconnect() await self.async_disconnect()
else:
asyncio.ensure_future(self._async_ping(self.ping_interval))
async def _async_pong_received(self, message, device): async def _async_pong_received(self, message):
self.last_pong = time.time() self.last_pong = time.time()
async def async_gratuitous_update_state(self, state_message, _): async def async_gratuitous_update_state(self, state_message):
await self.async_update_state(state_message) await self.async_update_state(state_message)
await self.update_entity_state_cb() await self.update_entity_state_cb()
async def async_update_state(self, state_message, _=None): async def async_update_state(self, state_message, _=None):
if state_message.payload and state_message.payload["dps"]: if state_message.payload and state_message.payload["dps"]:
self._dps.update(state_message.payload["dps"]) self._dps.update(state_message.payload["dps"])
_LOGGER.info("Received updated state {}: {}".format(self, self._dps)) _LOGGER.debug("Received updated state {}: {}".format(self, self._dps))
@property @property
def state(self): def state(self):
@ -746,16 +802,31 @@ class TuyaDevice:
@state.setter @state.setter
def state_setter(self, new_values): def state_setter(self, new_values):
asyncio.ensure_future(self.async_set(new_values)) asyncio.create_task(self.async_set(new_values))
async def _async_handle_message(self): async def _async_handle_message(self):
if self._enabled is False:
return
try: try:
response_data = await self.reader.readuntil(MAGIC_SUFFIX_BYTES) self._response_task = asyncio.create_task(
self.reader.readuntil(MAGIC_SUFFIX_BYTES)
)
await self._response_task
response_data = self._response_task.result()
message = Message.from_bytes(response_data, self.cipher) message = Message.from_bytes(response_data, self.cipher)
self._response_task = None
except InvalidMessage as e: except InvalidMessage as e:
_LOGGER.debug("Invalid message from {}: {}".format(self, e)) _LOGGER.debug("Invalid message from {}: {}".format(self, e))
except MessageDecodeFailed as e: except MessageDecodeFailed as e:
_LOGGER.debug("Failed to decrypt message from {}".format(self)) _LOGGER.debug("Failed to decrypt message from {}".format(self))
except asyncio.IncompleteReadError as e:
self._response_task = None
if self._connected:
_LOGGER.debug("Incomplete read")
except ConnectionResetError as e:
_LOGGER.debug("Connection reset")
else: else:
_LOGGER.debug("Received message from {}: {}".format(self, message)) _LOGGER.debug("Received message from {}: {}".format(self, message))
if message.sequence in self._listeners: if message.sequence in self._listeners:
@ -764,21 +835,22 @@ class TuyaDevice:
self._listeners[message.sequence] = message self._listeners[message.sequence] = message
sem.release() sem.release()
else: else:
for c in self._handlers.get(message.command, []): handler = self._handlers.get(message.command, None)
asyncio.ensure_future(c(message, self)) if handler is not None:
asyncio.create_task(handler(message))
asyncio.ensure_future(self._async_handle_message()) asyncio.create_task(self._async_handle_message())
async def _async_send(self, message, retries=4): async def _async_send(self, message, retries=2):
_LOGGER.debug("Sending to {}: {}".format(self, message))
try: try:
await self.async_connect() await self.async_connect()
_LOGGER.debug("Sending to {}: {}".format(self, message))
self.writer.write(message.bytes()) self.writer.write(message.bytes())
await self.writer.drain() await self.writer.drain()
except Exception as e: except Exception as e:
if retries == 0: if retries == 0:
if isinstance(e, socket.error): if isinstance(e, socket.error):
asyncio.ensure_future(self.async_disconnect()) asyncio.create_task(self.async_disconnect())
raise ConnectionException( raise ConnectionException(
"Connection to {} failed: {}".format(self, e) "Connection to {} failed: {}".format(self, e)
) )
@ -805,4 +877,22 @@ class TuyaDevice:
_LOGGER.debug( _LOGGER.debug(
"Retrying send due to error. Failed to send data to {}".format(self) "Retrying send due to error. Failed to send data to {}".format(self)
) )
await asyncio.sleep(0.25)
await self._async_send(message, retries=retries - 1) await self._async_send(message, retries=retries - 1)
async def async_recieve(self, message):
if message.expect_response is True:
try:
self._recieve_task = asyncio.create_task(
asyncio.wait_for(message.listener.acquire(), timeout=self.timeout)
)
await self._recieve_task
return self._listeners.pop(message.sequence)
except Exception as e:
del self._listeners[message.sequence]
raise ResponseTimeoutException(
"Timed out waiting for response to sequence number {}".format(
message.sequence
)
)

View File

@ -328,15 +328,13 @@ class RoboVacEntity(StateVacuumEntity):
self.update_entity_values() self.update_entity_values()
except TuyaException as e: except TuyaException as e:
self.update_failures += 1 self.update_failures += 1
_LOGGER.debug( _LOGGER.warn(
"Update errored. Current failure count: {}. Reason: {}".format( "Update errored. Current update failure count: {}. Reason: {}".format(
self.update_failures, e self.update_failures, e
) )
) )
if self.update_failures == UPDATE_RETRIES: if self.update_failures >= UPDATE_RETRIES:
self.update_failures = 0
self.error_code = "CONNECTION_FAILED" self.error_code = "CONNECTION_FAILED"
raise e
async def pushed_update_handler(self): async def pushed_update_handler(self):
self.update_entity_values() self.update_entity_values()
@ -479,4 +477,4 @@ class RoboVacEntity(StateVacuumEntity):
await self.vacuum.async_set({"124": base64_str}) await self.vacuum.async_set({"124": base64_str})
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
await self.vacuum.async_disconnect() await self.vacuum.async_disable()