feat: rewrite network code to make it more linear and handle being offline better
This commit is contained in:
parent
8676cb3fad
commit
88ef4a6e25
|
|
@ -46,13 +46,16 @@ import socket
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from typing import Callable, Coroutine
|
||||||
|
|
||||||
from cryptography.hazmat.backends.openssl import backend as openssl_backend
|
from cryptography.hazmat.backends.openssl import backend as openssl_backend
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
from cryptography.hazmat.primitives.hashes import Hash, MD5
|
from cryptography.hazmat.primitives.hashes import Hash, MD5
|
||||||
from cryptography.hazmat.primitives.padding import PKCS7
|
from cryptography.hazmat.primitives.padding import PKCS7
|
||||||
|
|
||||||
|
INITIAL_BACKOFF = 5
|
||||||
|
INITIAL_QUEUE_TIME = 0.1
|
||||||
|
BACKOFF_MULTIPLIER = 1.70224
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
MESSAGE_PREFIX_FORMAT = ">IIII"
|
MESSAGE_PREFIX_FORMAT = ">IIII"
|
||||||
MESSAGE_SUFFIX_FORMAT = ">II"
|
MESSAGE_SUFFIX_FORMAT = ">II"
|
||||||
|
|
@ -351,6 +354,10 @@ class ResponseTimeoutException(TuyaException):
|
||||||
"""Did not recieve a response to the request within the timeout"""
|
"""Did not recieve a response to the request within the timeout"""
|
||||||
|
|
||||||
|
|
||||||
|
class BackoffException(TuyaException):
|
||||||
|
"""Backoff time not reached"""
|
||||||
|
|
||||||
|
|
||||||
class TuyaCipher:
|
class TuyaCipher:
|
||||||
"""Tuya cryptographic helpers."""
|
"""Tuya cryptographic helpers."""
|
||||||
|
|
||||||
|
|
@ -444,7 +451,16 @@ class Message:
|
||||||
SET_COMMAND = 0x07
|
SET_COMMAND = 0x07
|
||||||
GRATUITOUS_UPDATE = 0x08
|
GRATUITOUS_UPDATE = 0x08
|
||||||
|
|
||||||
def __init__(self, command, payload=None, sequence=None, encrypt_for=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
command,
|
||||||
|
payload=None,
|
||||||
|
sequence=None,
|
||||||
|
encrypt=False,
|
||||||
|
device=None,
|
||||||
|
expect_response=True,
|
||||||
|
ttl=5,
|
||||||
|
):
|
||||||
if payload is None:
|
if payload is None:
|
||||||
payload = b""
|
payload = b""
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
|
|
@ -454,11 +470,15 @@ class Message:
|
||||||
self.set_sequence()
|
self.set_sequence()
|
||||||
else:
|
else:
|
||||||
self.sequence = sequence
|
self.sequence = sequence
|
||||||
self.encrypt = False
|
self.encrypt = encrypt
|
||||||
self.device = None
|
self.device = device
|
||||||
if encrypt_for is not None:
|
self.expiry = int(time.time()) + ttl
|
||||||
self.device = encrypt_for
|
self.expect_response = expect_response
|
||||||
self.encrypt = True
|
self.listener = None
|
||||||
|
if expect_response is True:
|
||||||
|
self.listener = asyncio.Semaphore(0)
|
||||||
|
if device is not None:
|
||||||
|
device._listeners[self.sequence] = self.listener
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{}({}, {!r}, {!r}, {})".format(
|
return "{}({}, {!r}, {!r}, {})".format(
|
||||||
|
|
@ -503,34 +523,8 @@ class Message:
|
||||||
|
|
||||||
__bytes__ = bytes
|
__bytes__ = bytes
|
||||||
|
|
||||||
async def async_send(self, device, retries=4):
|
async def async_send(self):
|
||||||
device._listeners[self.sequence] = asyncio.Semaphore(0)
|
await self.device._async_send(self)
|
||||||
await device._async_send(self)
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(
|
|
||||||
device._listeners[self.sequence].acquire(), timeout=device.timeout
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
del device._listeners[self.sequence]
|
|
||||||
if retries == 0:
|
|
||||||
raise ResponseTimeoutException(
|
|
||||||
"Timed out waiting for response to sequence number {}".format(
|
|
||||||
self.sequence
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
_LOGGER.debug(
|
|
||||||
"Timed out waiting for response to sequence number {}. Retrying".format(
|
|
||||||
self.sequence
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.original_sequence is None:
|
|
||||||
self.set_sequence()
|
|
||||||
|
|
||||||
return self.async_send(device, retries - 1)
|
|
||||||
|
|
||||||
return device._listeners.pop(self.sequence)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes(cls, data, cipher=None):
|
def from_bytes(cls, data, cipher=None):
|
||||||
|
|
@ -604,27 +598,9 @@ class Message:
|
||||||
return cls(command, payload, sequence)
|
return cls(command, payload, sequence)
|
||||||
|
|
||||||
|
|
||||||
def _call_async(fn, *args):
|
|
||||||
loop = None
|
|
||||||
if sys.version_info >= (3, 7):
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
def wrapper(fn, *args):
|
|
||||||
asyncio.ensure_future(fn(*args))
|
|
||||||
|
|
||||||
loop.call_soon(wrapper, fn, *args)
|
|
||||||
|
|
||||||
|
|
||||||
class TuyaDevice:
|
class TuyaDevice:
|
||||||
"""Represents a generic Tuya device."""
|
"""Represents a generic Tuya device."""
|
||||||
|
|
||||||
# PING_INTERVAL = 10
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device_id,
|
device_id,
|
||||||
|
|
@ -655,13 +631,22 @@ class TuyaDevice:
|
||||||
|
|
||||||
self.cipher = TuyaCipher(local_key, self.version)
|
self.cipher = TuyaCipher(local_key, self.version)
|
||||||
self.writer = None
|
self.writer = None
|
||||||
self._handlers = {
|
self._response_task = None
|
||||||
Message.GRATUITOUS_UPDATE: [self.async_gratuitous_update_state],
|
self._recieve_task = None
|
||||||
Message.PING_COMMAND: [self._async_pong_received],
|
self._handlers: dict[int, Callable[[Message], Coroutine]] = {
|
||||||
|
Message.GRATUITOUS_UPDATE: self.async_gratuitous_update_state,
|
||||||
|
Message.PING_COMMAND: self._async_pong_received,
|
||||||
}
|
}
|
||||||
self._dps = {}
|
self._dps = {}
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
self._enabled = True
|
||||||
|
self._queue = []
|
||||||
self._listeners = {}
|
self._listeners = {}
|
||||||
|
self._backoff = False
|
||||||
|
self._queue_interval = INITIAL_QUEUE_TIME
|
||||||
|
self._failures = 0
|
||||||
|
|
||||||
|
asyncio.create_task(self.process_queue())
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{}({!r}, {!r}, {!r}, {!r})".format(
|
return "{}({!r}, {!r}, {!r}, {!r})".format(
|
||||||
|
|
@ -675,70 +660,141 @@ class TuyaDevice:
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "{} ({}:{})".format(self.device_id, self.host, self.port)
|
return "{} ({}:{})".format(self.device_id, self.host, self.port)
|
||||||
|
|
||||||
async def async_connect(self, callback=None):
|
async def process_queue(self):
|
||||||
if self._connected:
|
if self._enabled is False:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self.clean_queue()
|
||||||
|
|
||||||
|
if len(self._queue) > 0:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Processing queue. Current length: {}".format(len(self._queue))
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
message = self._queue.pop(0)
|
||||||
|
await message.async_send()
|
||||||
|
self._failures = 0
|
||||||
|
self._queue_interval = INITIAL_QUEUE_TIME
|
||||||
|
self._backoff = False
|
||||||
|
except Exception as e:
|
||||||
|
self._failures += 1
|
||||||
|
_LOGGER.debug("{} failures. Most recent: {}".format(self._failures, e))
|
||||||
|
if self._failures > 3:
|
||||||
|
self._backoff = True
|
||||||
|
self._queue_interval = min(
|
||||||
|
INITIAL_BACKOFF * (BACKOFF_MULTIPLIER ** (self._failures - 4)),
|
||||||
|
600,
|
||||||
|
)
|
||||||
|
_LOGGER.warn(
|
||||||
|
"{} failures, backing off for {} seconds".format(
|
||||||
|
self._failures, self._queue_interval
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(self._queue_interval)
|
||||||
|
asyncio.create_task(self.process_queue())
|
||||||
|
|
||||||
|
def clean_queue(self):
|
||||||
|
cleaned_queue = []
|
||||||
|
now = int(time.time())
|
||||||
|
for item in self._queue:
|
||||||
|
if item.expiry > now:
|
||||||
|
cleaned_queue.append(item)
|
||||||
|
self._queue = cleaned_queue
|
||||||
|
|
||||||
|
async def async_connect(self):
|
||||||
|
if self._connected is True or self._enabled is False:
|
||||||
|
return
|
||||||
|
|
||||||
sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
|
sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
|
||||||
sock.settimeout(self.timeout)
|
sock.settimeout(self.timeout)
|
||||||
_LOGGER.debug("Connecting to {}".format(self))
|
_LOGGER.debug("Connecting to {}".format(self))
|
||||||
try:
|
try:
|
||||||
sock.connect((self.host, self.port))
|
sock.connect((self.host, self.port))
|
||||||
except socket.timeout as e:
|
except (socket.timeout, TimeoutError) as e:
|
||||||
self._dps["106"] = "CONNECTION_FAILED"
|
self._dps["106"] = "CONNECTION_FAILED"
|
||||||
raise ConnectionTimeoutException("Connection timed out") from e
|
raise ConnectionTimeoutException("Connection timed out")
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
loop.create_connection
|
||||||
self.reader, self.writer = await asyncio.open_connection(sock=sock)
|
self.reader, self.writer = await asyncio.open_connection(sock=sock)
|
||||||
self._connected = True
|
self._connected = True
|
||||||
asyncio.ensure_future(self._async_ping(self.ping_interval))
|
asyncio.create_task(self.async_ping(self.ping_interval))
|
||||||
asyncio.ensure_future(self._async_handle_message())
|
asyncio.create_task(self._async_handle_message())
|
||||||
|
|
||||||
|
async def async_disable(self):
|
||||||
|
self._enabled = False
|
||||||
|
await self.async_disconnect()
|
||||||
|
|
||||||
async def async_disconnect(self):
|
async def async_disconnect(self):
|
||||||
|
if self._connected is False:
|
||||||
|
return
|
||||||
|
|
||||||
_LOGGER.debug("Disconnected from {}".format(self))
|
_LOGGER.debug("Disconnected from {}".format(self))
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self.last_pong = 0
|
self.last_pong = 0
|
||||||
|
if self._response_task is not None:
|
||||||
|
self._response_task.cancel()
|
||||||
|
|
||||||
|
if self._recieve_task is not None:
|
||||||
|
self._recieve_task.cancel()
|
||||||
|
|
||||||
if self.writer is not None:
|
if self.writer is not None:
|
||||||
self.writer.close()
|
self.writer.close()
|
||||||
|
|
||||||
async def async_get(self):
|
async def async_get(self):
|
||||||
payload = {"gwId": self.gateway_id, "devId": self.device_id}
|
payload = {"gwId": self.gateway_id, "devId": self.device_id}
|
||||||
maybe_self = None if self.version < (3, 3) else self
|
encrypt = False if self.version < (3, 3) else True
|
||||||
message = Message(Message.GET_COMMAND, payload, encrypt_for=maybe_self)
|
message = Message(Message.GET_COMMAND, payload, encrypt=encrypt, device=self)
|
||||||
response = await message.async_send(self)
|
self._queue.append(message)
|
||||||
await self.async_update_state(response)
|
response = await self.async_recieve(message)
|
||||||
|
asyncio.create_task(self.async_update_state(response))
|
||||||
|
|
||||||
async def async_set(self, dps):
|
async def async_set(self, dps):
|
||||||
t = int(time.time())
|
t = int(time.time())
|
||||||
payload = {"devId": self.device_id, "uid": "", "t": t, "dps": dps}
|
payload = {"devId": self.device_id, "uid": "", "t": t, "dps": dps}
|
||||||
message = Message(Message.SET_COMMAND, payload, encrypt_for=self)
|
message = Message(
|
||||||
await message.async_send(self)
|
Message.SET_COMMAND,
|
||||||
|
payload,
|
||||||
|
encrypt=True,
|
||||||
|
device=self,
|
||||||
|
expect_response=False,
|
||||||
|
)
|
||||||
|
self._queue.append(message)
|
||||||
|
|
||||||
def set(self, dps):
|
async def async_ping(self, ping_interval):
|
||||||
_call_async(self.async_set, dps)
|
if self._enabled is False:
|
||||||
|
|
||||||
async def _async_ping(self, ping_interval):
|
|
||||||
if not self._connected:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self.last_ping = time.time()
|
if self._backoff is True:
|
||||||
maybe_self = None if self.version < (3, 3) else self
|
_LOGGER.debug("Currently in backoff, not adding ping to queue")
|
||||||
message = Message(Message.PING_COMMAND, sequence=0, encrypt_for=maybe_self)
|
else:
|
||||||
await self._async_send(message)
|
self.last_ping = time.time()
|
||||||
|
encrypt = False if self.version < (3, 3) else True
|
||||||
|
message = Message(
|
||||||
|
Message.PING_COMMAND,
|
||||||
|
sequence=0,
|
||||||
|
encrypt=encrypt,
|
||||||
|
device=self,
|
||||||
|
expect_response=False,
|
||||||
|
)
|
||||||
|
self._queue.append(message)
|
||||||
|
|
||||||
await asyncio.sleep(ping_interval)
|
await asyncio.sleep(ping_interval)
|
||||||
|
asyncio.create_task(self.async_ping(self.ping_interval))
|
||||||
if self.last_pong < self.last_ping:
|
if self.last_pong < self.last_ping:
|
||||||
await self.async_disconnect()
|
await self.async_disconnect()
|
||||||
else:
|
|
||||||
asyncio.ensure_future(self._async_ping(self.ping_interval))
|
|
||||||
|
|
||||||
async def _async_pong_received(self, message, device):
|
async def _async_pong_received(self, message):
|
||||||
self.last_pong = time.time()
|
self.last_pong = time.time()
|
||||||
|
|
||||||
async def async_gratuitous_update_state(self, state_message, _):
|
async def async_gratuitous_update_state(self, state_message):
|
||||||
await self.async_update_state(state_message)
|
await self.async_update_state(state_message)
|
||||||
await self.update_entity_state_cb()
|
await self.update_entity_state_cb()
|
||||||
|
|
||||||
async def async_update_state(self, state_message, _=None):
|
async def async_update_state(self, state_message, _=None):
|
||||||
if state_message.payload and state_message.payload["dps"]:
|
if state_message.payload and state_message.payload["dps"]:
|
||||||
self._dps.update(state_message.payload["dps"])
|
self._dps.update(state_message.payload["dps"])
|
||||||
_LOGGER.info("Received updated state {}: {}".format(self, self._dps))
|
_LOGGER.debug("Received updated state {}: {}".format(self, self._dps))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self):
|
def state(self):
|
||||||
|
|
@ -746,16 +802,31 @@ class TuyaDevice:
|
||||||
|
|
||||||
@state.setter
|
@state.setter
|
||||||
def state_setter(self, new_values):
|
def state_setter(self, new_values):
|
||||||
asyncio.ensure_future(self.async_set(new_values))
|
asyncio.create_task(self.async_set(new_values))
|
||||||
|
|
||||||
async def _async_handle_message(self):
|
async def _async_handle_message(self):
|
||||||
|
if self._enabled is False:
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response_data = await self.reader.readuntil(MAGIC_SUFFIX_BYTES)
|
self._response_task = asyncio.create_task(
|
||||||
|
self.reader.readuntil(MAGIC_SUFFIX_BYTES)
|
||||||
|
)
|
||||||
|
await self._response_task
|
||||||
|
response_data = self._response_task.result()
|
||||||
message = Message.from_bytes(response_data, self.cipher)
|
message = Message.from_bytes(response_data, self.cipher)
|
||||||
|
self._response_task = None
|
||||||
except InvalidMessage as e:
|
except InvalidMessage as e:
|
||||||
_LOGGER.debug("Invalid message from {}: {}".format(self, e))
|
_LOGGER.debug("Invalid message from {}: {}".format(self, e))
|
||||||
except MessageDecodeFailed as e:
|
except MessageDecodeFailed as e:
|
||||||
_LOGGER.debug("Failed to decrypt message from {}".format(self))
|
_LOGGER.debug("Failed to decrypt message from {}".format(self))
|
||||||
|
except asyncio.IncompleteReadError as e:
|
||||||
|
self._response_task = None
|
||||||
|
if self._connected:
|
||||||
|
_LOGGER.debug("Incomplete read")
|
||||||
|
except ConnectionResetError as e:
|
||||||
|
_LOGGER.debug("Connection reset")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_LOGGER.debug("Received message from {}: {}".format(self, message))
|
_LOGGER.debug("Received message from {}: {}".format(self, message))
|
||||||
if message.sequence in self._listeners:
|
if message.sequence in self._listeners:
|
||||||
|
|
@ -764,21 +835,22 @@ class TuyaDevice:
|
||||||
self._listeners[message.sequence] = message
|
self._listeners[message.sequence] = message
|
||||||
sem.release()
|
sem.release()
|
||||||
else:
|
else:
|
||||||
for c in self._handlers.get(message.command, []):
|
handler = self._handlers.get(message.command, None)
|
||||||
asyncio.ensure_future(c(message, self))
|
if handler is not None:
|
||||||
|
asyncio.create_task(handler(message))
|
||||||
|
|
||||||
asyncio.ensure_future(self._async_handle_message())
|
asyncio.create_task(self._async_handle_message())
|
||||||
|
|
||||||
async def _async_send(self, message, retries=4):
|
async def _async_send(self, message, retries=2):
|
||||||
|
_LOGGER.debug("Sending to {}: {}".format(self, message))
|
||||||
try:
|
try:
|
||||||
await self.async_connect()
|
await self.async_connect()
|
||||||
_LOGGER.debug("Sending to {}: {}".format(self, message))
|
|
||||||
self.writer.write(message.bytes())
|
self.writer.write(message.bytes())
|
||||||
await self.writer.drain()
|
await self.writer.drain()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retries == 0:
|
if retries == 0:
|
||||||
if isinstance(e, socket.error):
|
if isinstance(e, socket.error):
|
||||||
asyncio.ensure_future(self.async_disconnect())
|
asyncio.create_task(self.async_disconnect())
|
||||||
raise ConnectionException(
|
raise ConnectionException(
|
||||||
"Connection to {} failed: {}".format(self, e)
|
"Connection to {} failed: {}".format(self, e)
|
||||||
)
|
)
|
||||||
|
|
@ -805,4 +877,22 @@ class TuyaDevice:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Retrying send due to error. Failed to send data to {}".format(self)
|
"Retrying send due to error. Failed to send data to {}".format(self)
|
||||||
)
|
)
|
||||||
|
await asyncio.sleep(0.25)
|
||||||
await self._async_send(message, retries=retries - 1)
|
await self._async_send(message, retries=retries - 1)
|
||||||
|
|
||||||
|
async def async_recieve(self, message):
|
||||||
|
if message.expect_response is True:
|
||||||
|
try:
|
||||||
|
self._recieve_task = asyncio.create_task(
|
||||||
|
asyncio.wait_for(message.listener.acquire(), timeout=self.timeout)
|
||||||
|
)
|
||||||
|
await self._recieve_task
|
||||||
|
return self._listeners.pop(message.sequence)
|
||||||
|
except Exception as e:
|
||||||
|
del self._listeners[message.sequence]
|
||||||
|
|
||||||
|
raise ResponseTimeoutException(
|
||||||
|
"Timed out waiting for response to sequence number {}".format(
|
||||||
|
message.sequence
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -328,15 +328,13 @@ class RoboVacEntity(StateVacuumEntity):
|
||||||
self.update_entity_values()
|
self.update_entity_values()
|
||||||
except TuyaException as e:
|
except TuyaException as e:
|
||||||
self.update_failures += 1
|
self.update_failures += 1
|
||||||
_LOGGER.debug(
|
_LOGGER.warn(
|
||||||
"Update errored. Current failure count: {}. Reason: {}".format(
|
"Update errored. Current update failure count: {}. Reason: {}".format(
|
||||||
self.update_failures, e
|
self.update_failures, e
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if self.update_failures == UPDATE_RETRIES:
|
if self.update_failures >= UPDATE_RETRIES:
|
||||||
self.update_failures = 0
|
|
||||||
self.error_code = "CONNECTION_FAILED"
|
self.error_code = "CONNECTION_FAILED"
|
||||||
raise e
|
|
||||||
|
|
||||||
async def pushed_update_handler(self):
|
async def pushed_update_handler(self):
|
||||||
self.update_entity_values()
|
self.update_entity_values()
|
||||||
|
|
@ -479,4 +477,4 @@ class RoboVacEntity(StateVacuumEntity):
|
||||||
await self.vacuum.async_set({"124": base64_str})
|
await self.vacuum.async_set({"124": base64_str})
|
||||||
|
|
||||||
async def async_will_remove_from_hass(self):
|
async def async_will_remove_from_hass(self):
|
||||||
await self.vacuum.async_disconnect()
|
await self.vacuum.async_disable()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue