Improve network handling

This commit is contained in:
Luke Bonaccorsi 2023-09-28 10:15:45 +01:00
parent b48d12e05e
commit 40480cc319
2 changed files with 104 additions and 109 deletions

View File

@ -496,36 +496,23 @@ class Message:
__bytes__ = bytes __bytes__ = bytes
class AsyncWrappedCallback: async def async_send(self, device):
def __init__(self, request, callback): device._listeners[self.sequence] = asyncio.Semaphore(0)
self.request = request
self.callback = callback
self.devices = []
def register(self, device):
self.devices.append(device)
device._handlers.setdefault(self.request.command, [])
device._handlers[self.request.command].append(self)
def unregister(self, device):
self.devices.remove(device)
device._handlers[self.request.command].remove(self)
def unregister_all(self):
while self.devices:
device = self.devices.pop()
device._handlers[self.request.command].remove(self)
async def __call__(self, response, device):
if response.sequence == self.request.sequence:
asyncio.ensure_future(self.callback(response, device))
self.unregister(device)
async def async_send(self, device, callback=None):
if callback is not None:
wrapped = self.AsyncWrappedCallback(self, callback)
wrapped.register(device)
await device._async_send(self) await device._async_send(self)
try:
await asyncio.wait_for(
device._listeners[self.sequence].acquire(), timeout=device.timeout
)
except:
_LOGGER.debug(
"Timed out waiting for response to sequence number {}".format(
self.sequence
)
)
del device._listeners[self.sequence]
raise
return device._listeners.pop(self.sequence)
@classmethod @classmethod
def from_bytes(cls, data, cipher=None): def from_bytes(cls, data, cipher=None):
@ -626,6 +613,7 @@ class TuyaDevice:
host, host,
timeout, timeout,
ping_interval, ping_interval,
update_entity_state,
local_key=None, local_key=None,
port=6668, port=6668,
gateway_id=None, gateway_id=None,
@ -642,6 +630,7 @@ class TuyaDevice:
self.timeout = timeout self.timeout = timeout
self.last_pong = 0 self.last_pong = 0
self.ping_interval = ping_interval self.ping_interval = ping_interval
self.update_entity_state_cb = update_entity_state
if len(local_key) != 16: if len(local_key) != 16:
raise InvalidKey("Local key should be a 16-character string") raise InvalidKey("Local key should be a 16-character string")
@ -649,12 +638,12 @@ class TuyaDevice:
self.cipher = TuyaCipher(local_key, self.version) self.cipher = TuyaCipher(local_key, self.version)
self.writer = None self.writer = None
self._handlers = { self._handlers = {
Message.GET_COMMAND: [self.async_update_state], Message.GRATUITOUS_UPDATE: [self.async_gratuitous_update_state],
Message.GRATUITOUS_UPDATE: [self.async_update_state],
Message.PING_COMMAND: [self._async_pong_received], Message.PING_COMMAND: [self._async_pong_received],
} }
self._dps = {} self._dps = {}
self._connected = False self._connected = False
self._listeners = {}
def __repr__(self): def __repr__(self):
return "{}({!r}, {!r}, {!r}, {!r})".format( return "{}({!r}, {!r}, {!r}, {!r})".format(
@ -681,6 +670,8 @@ class TuyaDevice:
raise ConnectionTimeoutException("Connection timed out") from e raise ConnectionTimeoutException("Connection timed out") from e
self.reader, self.writer = await asyncio.open_connection(sock=sock) self.reader, self.writer = await asyncio.open_connection(sock=sock)
self._connected = True self._connected = True
asyncio.ensure_future(self._async_ping(self.ping_interval))
asyncio.ensure_future(self._async_handle_message())
async def async_disconnect(self): async def async_disconnect(self):
_LOGGER.debug("Disconnected from {}".format(self)) _LOGGER.debug("Disconnected from {}".format(self))
@ -689,17 +680,18 @@ class TuyaDevice:
if self.writer is not None: if self.writer is not None:
self.writer.close() self.writer.close()
async def async_get(self, callback=None): async def async_get(self):
payload = {"gwId": self.gateway_id, "devId": self.device_id} payload = {"gwId": self.gateway_id, "devId": self.device_id}
maybe_self = None if self.version < (3, 3) else self maybe_self = None if self.version < (3, 3) else self
message = Message(Message.GET_COMMAND, payload, encrypt_for=maybe_self) message = Message(Message.GET_COMMAND, payload, encrypt_for=maybe_self)
return await message.async_send(self, callback) response = await message.async_send(self)
await self.async_update_state(response)
async def async_set(self, dps, callback=None): async def async_set(self, dps):
t = int(time.time()) t = int(time.time())
payload = {"devId": self.device_id, "uid": "", "t": t, "dps": dps} payload = {"devId": self.device_id, "uid": "", "t": t, "dps": dps}
message = Message(Message.SET_COMMAND, payload, encrypt_for=self) message = Message(Message.SET_COMMAND, payload, encrypt_for=self)
await message.async_send(self, callback) await message.async_send(self)
def set(self, dps): def set(self, dps):
_call_async(self.async_set, dps) _call_async(self.async_set, dps)
@ -721,7 +713,11 @@ class TuyaDevice:
async def _async_pong_received(self, message, device): async def _async_pong_received(self, message, device):
self.last_pong = time.time() self.last_pong = time.time()
async def async_update_state(self, state_message, _): async def async_gratuitous_update_state(self, state_message, _):
await self.async_update_state(state_message)
await self.update_entity_state_cb()
async def async_update_state(self, state_message, _=None):
_LOGGER.info("Received updated state {}: {}".format(self, self._dps)) _LOGGER.info("Received updated state {}: {}".format(self, self._dps))
self._dps.update(state_message.payload["dps"]) self._dps.update(state_message.payload["dps"])
@ -734,9 +730,8 @@ class TuyaDevice:
asyncio.ensure_future(self.async_set(new_values)) asyncio.ensure_future(self.async_set(new_values))
async def _async_handle_message(self): async def _async_handle_message(self):
response_data = await self.reader.readuntil(MAGIC_SUFFIX_BYTES)
try: try:
response_data = await self.reader.readuntil(MAGIC_SUFFIX_BYTES)
message = Message.from_bytes(response_data, self.cipher) message = Message.from_bytes(response_data, self.cipher)
except InvalidMessage as e: except InvalidMessage as e:
_LOGGER.error("Invalid message from {}: {}".format(self, e)) _LOGGER.error("Invalid message from {}: {}".format(self, e))
@ -744,16 +739,23 @@ class TuyaDevice:
_LOGGER.error("Failed to decrypt message from {}".format(self)) _LOGGER.error("Failed to decrypt message from {}".format(self))
else: else:
_LOGGER.debug("Received message from {}: {}".format(self, message)) _LOGGER.debug("Received message from {}: {}".format(self, message))
if message.sequence in self._listeners:
sem = self._listeners[message.sequence]
if isinstance(sem, asyncio.Semaphore):
self._listeners[message.sequence] = message
sem.release()
else:
for c in self._handlers.get(message.command, []): for c in self._handlers.get(message.command, []):
asyncio.ensure_future(c(message, self)) asyncio.ensure_future(c(message, self))
asyncio.ensure_future(self._async_handle_message())
async def _async_send(self, message, retries=4): async def _async_send(self, message, retries=4):
try: try:
await self.async_connect() await self.async_connect()
_LOGGER.debug("Sending to {}: {}".format(self, message)) _LOGGER.debug("Sending to {}: {}".format(self, message))
self.writer.write(message.bytes()) self.writer.write(message.bytes())
await self.writer.drain() await self.writer.drain()
await self._async_handle_message()
except Exception as e: except Exception as e:
if retries == 0: if retries == 0:
if isinstance(e, socket.error): if isinstance(e, socket.error):

View File

@ -274,8 +274,9 @@ class RoboVacEntity(StateVacuumEntity):
host=self.ip_address, host=self.ip_address,
local_key=self.access_token, local_key=self.access_token,
timeout=2, timeout=2,
ping_interval=REFRESH_RATE, ping_interval=REFRESH_RATE / 2,
model_code=self.model_code[0:5], model_code=self.model_code[0:5],
update_entity_state=self.pushed_update_handler,
) )
except ModelNotSupportedException: except ModelNotSupportedException:
self.error_code = "UNSUPPORTED_MODEL" self.error_code = "UNSUPPORTED_MODEL"
@ -311,8 +312,25 @@ class RoboVacEntity(StateVacuumEntity):
try: try:
await self.vacuum.async_get() await self.vacuum.async_get()
self.update_failures = 0 self.update_failures = 0
self.update_entity_values()
except TuyaException as e:
self.update_failures += 1
_LOGGER.debug(
"Update errored. Current failure count: {}. Reason: {}".format(
self.update_failures, e
)
)
if self.update_failures == UPDATE_RETRIES:
self.update_failures = 0
self.error_code = "CONNECTION_FAILED"
raise e
async def pushed_update_handler(self):
self.update_entity_values()
self.async_write_ha_state()
def update_entity_values(self):
self.tuyastatus = self.vacuum._dps self.tuyastatus = self.vacuum._dps
# for 15C # for 15C
@ -342,49 +360,30 @@ class RoboVacEntity(StateVacuumEntity):
and self.tuyastatus.get(CONSUMABLE_CODE) is not None and self.tuyastatus.get(CONSUMABLE_CODE) is not None
): ):
self._attr_consumables = ast.literal_eval( self._attr_consumables = ast.literal_eval(
base64.b64decode( base64.b64decode(self.tuyastatus.get(CONSUMABLE_CODE)).decode(
self.tuyastatus.get(CONSUMABLE_CODE) "ascii"
).decode("ascii") )
)["consumable"]["duration"] )["consumable"]["duration"]
self.async_write_ha_state()
except TuyaException as e:
self.update_failures += 1
_LOGGER.debug(
"Update errored. Current failure count: {}. Reason: {}".format(
self.update_failures, e
)
)
if self.update_failures == UPDATE_RETRIES:
self.update_failures = 0
self.error_code = "CONNECTION_FAILED"
raise e
async def async_locate(self, **kwargs): async def async_locate(self, **kwargs):
"""Locate the vacuum cleaner.""" """Locate the vacuum cleaner."""
_LOGGER.info("Locate Pressed") _LOGGER.info("Locate Pressed")
if self.tuyastatus.get("103"): if self.tuyastatus.get("103"):
await self.vacuum.async_set({"103": False}, None) await self.vacuum.async_set({"103": False})
else: else:
await self.vacuum.async_set({"103": True}, None) await self.vacuum.async_set({"103": True})
async def async_return_to_base(self, **kwargs): async def async_return_to_base(self, **kwargs):
"""Set the vacuum cleaner to return to the dock.""" """Set the vacuum cleaner to return to the dock."""
_LOGGER.info("Return home Pressed") _LOGGER.info("Return home Pressed")
await self.vacuum.async_set({"101": True}, None) await self.vacuum.async_set({"101": True})
await asyncio.sleep(1)
self.async_update
async def async_start(self, **kwargs): async def async_start(self, **kwargs):
self._attr_mode = "auto" self._attr_mode = "auto"
await self.vacuum.async_set({"5": self.mode}, None) await self.vacuum.async_set({"5": self.mode})
await asyncio.sleep(1)
self.async_update
async def async_pause(self, **kwargs): async def async_pause(self, **kwargs):
await self.vacuum.async_set({"2": False}, None) await self.vacuum.async_set({"2": False})
await asyncio.sleep(1)
self.async_update
async def async_stop(self, **kwargs): async def async_stop(self, **kwargs):
await self.async_return_to_base() await self.async_return_to_base()
@ -392,9 +391,7 @@ class RoboVacEntity(StateVacuumEntity):
async def async_clean_spot(self, **kwargs): async def async_clean_spot(self, **kwargs):
"""Perform a spot clean-up.""" """Perform a spot clean-up."""
_LOGGER.info("Spot Clean Pressed") _LOGGER.info("Spot Clean Pressed")
await self.vacuum.async_set({"5": "Spot"}, None) await self.vacuum.async_set({"5": "Spot"})
await asyncio.sleep(1)
self.async_update
async def async_set_fan_speed(self, fan_speed, **kwargs): async def async_set_fan_speed(self, fan_speed, **kwargs):
"""Set fan speed.""" """Set fan speed."""
@ -405,9 +402,7 @@ class RoboVacEntity(StateVacuumEntity):
fan_speed = "Boost_IQ" fan_speed = "Boost_IQ"
elif fan_speed == "Pure": elif fan_speed == "Pure":
fan_speed = "Quiet" fan_speed = "Quiet"
await self.vacuum.async_set({"102": fan_speed}, None) await self.vacuum.async_set({"102": fan_speed})
await asyncio.sleep(1)
self.async_update
async def async_send_command( async def async_send_command(
self, command: str, params: dict | list | None = None, **kwargs self, command: str, params: dict | list | None = None, **kwargs
@ -415,28 +410,28 @@ class RoboVacEntity(StateVacuumEntity):
"""Send a command to a vacuum cleaner.""" """Send a command to a vacuum cleaner."""
_LOGGER.info("Send Command %s Pressed", command) _LOGGER.info("Send Command %s Pressed", command)
if command == "edgeClean": if command == "edgeClean":
await self.vacuum.async_set({"5": "Edge"}, None) await self.vacuum.async_set({"5": "Edge"})
elif command == "smallRoomClean": elif command == "smallRoomClean":
await self.vacuum.async_set({"5": "SmallRoom"}, None) await self.vacuum.async_set({"5": "SmallRoom"})
elif command == "autoClean": elif command == "autoClean":
await self.vacuum.async_set({"5": "auto"}, None) await self.vacuum.async_set({"5": "auto"})
elif command == "autoReturn": elif command == "autoReturn":
if self.auto_return: if self.auto_return:
await self.vacuum.async_set({"135": False}, None) await self.vacuum.async_set({"135": False})
else: else:
await self.vacuum.async_set({"135": True}, None) await self.vacuum.async_set({"135": True})
elif command == "doNotDisturb": elif command == "doNotDisturb":
if self.do_not_disturb: if self.do_not_disturb:
await self.vacuum.async_set({"139": "MEQ4MDAwMDAw"}, None) await self.vacuum.async_set({"139": "MEQ4MDAwMDAw"})
await self.vacuum.async_set({"107": False}, None) await self.vacuum.async_set({"107": False})
else: else:
await self.vacuum.async_set({"139": "MTAwMDAwMDAw"}, None) await self.vacuum.async_set({"139": "MTAwMDAwMDAw"})
await self.vacuum.async_set({"107": True}, None) await self.vacuum.async_set({"107": True})
elif command == "boostIQ": elif command == "boostIQ":
if self.boost_iq: if self.boost_iq:
await self.vacuum.async_set({"118": False}, None) await self.vacuum.async_set({"118": False})
else: else:
await self.vacuum.async_set({"118": True}, None) await self.vacuum.async_set({"118": True})
elif command == "roomClean": elif command == "roomClean":
roomIds = params.get("roomIds", [1]) roomIds = params.get("roomIds", [1])
count = params.get("count", 1) count = params.get("count", 1)
@ -449,9 +444,7 @@ class RoboVacEntity(StateVacuumEntity):
json_str = json.dumps(method_call, separators=(",", ":")) json_str = json.dumps(method_call, separators=(",", ":"))
base64_str = base64.b64encode(json_str.encode("utf8")).decode("utf8") base64_str = base64.b64encode(json_str.encode("utf8")).decode("utf8")
_LOGGER.info("roomClean call %s", json_str) _LOGGER.info("roomClean call %s", json_str)
await self.vacuum.async_set({"124": base64_str}, None) await self.vacuum.async_set({"124": base64_str})
await asyncio.sleep(1)
self.async_update
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
await self.vacuum.async_disconnect() await self.vacuum.async_disconnect()