diff --git a/custom_components/robovac/tuyalocalapi.py b/custom_components/robovac/tuyalocalapi.py index d03926e..55c7abb 100644 --- a/custom_components/robovac/tuyalocalapi.py +++ b/custom_components/robovac/tuyalocalapi.py @@ -633,6 +633,7 @@ class TuyaDevice: self.writer = None self._response_task = None self._recieve_task = None + self._ping_task = None self._handlers: dict[int, Callable[[Message], Coroutine]] = { Message.GRATUITOUS_UPDATE: self.async_gratuitous_update_state, Message.PING_COMMAND: self._async_pong_received, @@ -718,11 +719,15 @@ class TuyaDevice: loop.create_connection self.reader, self.writer = await asyncio.open_connection(sock=sock) self._connected = True - asyncio.create_task(self.async_ping(self.ping_interval)) + + if self._ping_task is None: + self.ping_task = asyncio.create_task(self.async_ping(self.ping_interval)) + asyncio.create_task(self._async_handle_message()) async def async_disable(self): self._enabled = False + await self.async_disconnect() async def async_disconnect(self): @@ -732,11 +737,6 @@ class TuyaDevice: _LOGGER.debug("Disconnected from {}".format(self)) self._connected = False self.last_pong = 0 - if self._response_task is not None: - self._response_task.cancel() - - if self._recieve_task is not None: - self._recieve_task.cancel() if self.writer is not None: self.writer.close() @@ -780,7 +780,7 @@ class TuyaDevice: self._queue.append(message) await asyncio.sleep(ping_interval) - asyncio.create_task(self.async_ping(self.ping_interval)) + self.ping_task = asyncio.create_task(self.async_ping(self.ping_interval)) if self.last_pong < self.last_ping: await self.async_disconnect() @@ -792,7 +792,11 @@ class TuyaDevice: await self.update_entity_state_cb() async def async_update_state(self, state_message, _=None): - if state_message.payload and state_message.payload["dps"]: + if ( + state_message is not None + and state_message.payload + and state_message.payload["dps"] + ): self._dps.update(state_message.payload["dps"]) _LOGGER.debug("Received updated state {}: {}".format(self, self._dps)) @@ -805,7 +809,7 @@ class TuyaDevice: asyncio.create_task(self.async_set(new_values)) async def _async_handle_message(self): - if self._enabled is False: + if self._enabled is False or self._connected is False: return try: @@ -815,17 +819,16 @@ class TuyaDevice: await self._response_task response_data = self._response_task.result() message = Message.from_bytes(response_data, self.cipher) - self._response_task = None - except InvalidMessage as e: - _LOGGER.debug("Invalid message from {}: {}".format(self, e)) - except MessageDecodeFailed as e: - _LOGGER.debug("Failed to decrypt message from {}".format(self)) - except asyncio.IncompleteReadError as e: - self._response_task = None - if self._connected: - _LOGGER.debug("Incomplete read") - except ConnectionResetError as e: - _LOGGER.debug("Connection reset") + except Exception as e: + if isinstance(e, InvalidMessage): + _LOGGER.debug("Invalid message from {}: {}".format(self, e)) + elif isinstance(e, MessageDecodeFailed): + _LOGGER.debug("Failed to decrypt message from {}".format(self)) + elif isinstance(e, asyncio.IncompleteReadError): + if self._connected: + _LOGGER.debug("Incomplete read") + elif isinstance(e, ConnectionResetError): + _LOGGER.debug("Connection reset") else: _LOGGER.debug("Received message from {}: {}".format(self, message)) @@ -839,6 +842,7 @@ class TuyaDevice: if handler is not None: asyncio.create_task(handler(message)) + self._response_task = None asyncio.create_task(self._async_handle_message()) async def _async_send(self, message, retries=2): @@ -850,7 +854,8 @@ class TuyaDevice: except Exception as e: if retries == 0: if isinstance(e, socket.error): - asyncio.create_task(self.async_disconnect()) + await self.async_disconnect() + raise ConnectionException( "Connection to {} failed: {}".format(self, e) ) @@ -881,18 +886,30 @@ class TuyaDevice: await self._async_send(message, retries=retries - 1) async def async_recieve(self, message): + if self._connected is False: + return + if message.expect_response is True: try: self._recieve_task = asyncio.create_task( asyncio.wait_for(message.listener.acquire(), timeout=self.timeout) ) await self._recieve_task - return self._listeners.pop(message.sequence) + response = self._listeners.pop(message.sequence) + + if isinstance(response, Exception): + raise response + + return response except Exception as e: del self._listeners[message.sequence] + await self.async_disconnect() - raise ResponseTimeoutException( - "Timed out waiting for response to sequence number {}".format( - message.sequence + if isinstance(e, TimeoutError): + raise ResponseTimeoutException( + "Timed out waiting for response to sequence number {}".format( + message.sequence + ) ) - ) + + raise e