fix: fixed issue with multiple ping loops running

This commit is contained in:
Luke Bonaccorsi 2024-02-26 15:26:23 +00:00
parent 79892aa98f
commit 3e2b923255
1 changed files with 43 additions and 26 deletions

View File

@ -633,6 +633,7 @@ class TuyaDevice:
self.writer = None self.writer = None
self._response_task = None self._response_task = None
self._recieve_task = None self._recieve_task = None
self._ping_task = None
self._handlers: dict[int, Callable[[Message], Coroutine]] = { self._handlers: dict[int, Callable[[Message], Coroutine]] = {
Message.GRATUITOUS_UPDATE: self.async_gratuitous_update_state, Message.GRATUITOUS_UPDATE: self.async_gratuitous_update_state,
Message.PING_COMMAND: self._async_pong_received, Message.PING_COMMAND: self._async_pong_received,
@ -718,11 +719,15 @@ class TuyaDevice:
loop.create_connection loop.create_connection
self.reader, self.writer = await asyncio.open_connection(sock=sock) self.reader, self.writer = await asyncio.open_connection(sock=sock)
self._connected = True self._connected = True
asyncio.create_task(self.async_ping(self.ping_interval))
if self._ping_task is None:
self.ping_task = asyncio.create_task(self.async_ping(self.ping_interval))
asyncio.create_task(self._async_handle_message()) asyncio.create_task(self._async_handle_message())
async def async_disable(self): async def async_disable(self):
self._enabled = False self._enabled = False
await self.async_disconnect() await self.async_disconnect()
async def async_disconnect(self): async def async_disconnect(self):
@ -732,11 +737,6 @@ class TuyaDevice:
_LOGGER.debug("Disconnected from {}".format(self)) _LOGGER.debug("Disconnected from {}".format(self))
self._connected = False self._connected = False
self.last_pong = 0 self.last_pong = 0
if self._response_task is not None:
self._response_task.cancel()
if self._recieve_task is not None:
self._recieve_task.cancel()
if self.writer is not None: if self.writer is not None:
self.writer.close() self.writer.close()
@ -780,7 +780,7 @@ class TuyaDevice:
self._queue.append(message) self._queue.append(message)
await asyncio.sleep(ping_interval) await asyncio.sleep(ping_interval)
asyncio.create_task(self.async_ping(self.ping_interval)) self.ping_task = asyncio.create_task(self.async_ping(self.ping_interval))
if self.last_pong < self.last_ping: if self.last_pong < self.last_ping:
await self.async_disconnect() await self.async_disconnect()
@ -792,7 +792,11 @@ class TuyaDevice:
await self.update_entity_state_cb() await self.update_entity_state_cb()
async def async_update_state(self, state_message, _=None): async def async_update_state(self, state_message, _=None):
if state_message.payload and state_message.payload["dps"]: if (
state_message is not None
and state_message.payload
and state_message.payload["dps"]
):
self._dps.update(state_message.payload["dps"]) self._dps.update(state_message.payload["dps"])
_LOGGER.debug("Received updated state {}: {}".format(self, self._dps)) _LOGGER.debug("Received updated state {}: {}".format(self, self._dps))
@ -805,7 +809,7 @@ class TuyaDevice:
asyncio.create_task(self.async_set(new_values)) asyncio.create_task(self.async_set(new_values))
async def _async_handle_message(self): async def _async_handle_message(self):
if self._enabled is False: if self._enabled is False or self._connected is False:
return return
try: try:
@ -815,17 +819,16 @@ class TuyaDevice:
await self._response_task await self._response_task
response_data = self._response_task.result() response_data = self._response_task.result()
message = Message.from_bytes(response_data, self.cipher) message = Message.from_bytes(response_data, self.cipher)
self._response_task = None except Exception as e:
except InvalidMessage as e: if isinstance(e, InvalidMessage):
_LOGGER.debug("Invalid message from {}: {}".format(self, e)) _LOGGER.debug("Invalid message from {}: {}".format(self, e))
except MessageDecodeFailed as e: elif isinstance(e, MessageDecodeFailed):
_LOGGER.debug("Failed to decrypt message from {}".format(self)) _LOGGER.debug("Failed to decrypt message from {}".format(self))
except asyncio.IncompleteReadError as e: elif isinstance(e, asyncio.IncompleteReadError):
self._response_task = None if self._connected:
if self._connected: _LOGGER.debug("Incomplete read")
_LOGGER.debug("Incomplete read") elif isinstance(e, ConnectionResetError):
except ConnectionResetError as e: _LOGGER.debug("Connection reset")
_LOGGER.debug("Connection reset")
else: else:
_LOGGER.debug("Received message from {}: {}".format(self, message)) _LOGGER.debug("Received message from {}: {}".format(self, message))
@ -839,6 +842,7 @@ class TuyaDevice:
if handler is not None: if handler is not None:
asyncio.create_task(handler(message)) asyncio.create_task(handler(message))
self._response_task = None
asyncio.create_task(self._async_handle_message()) asyncio.create_task(self._async_handle_message())
async def _async_send(self, message, retries=2): async def _async_send(self, message, retries=2):
@ -850,7 +854,8 @@ class TuyaDevice:
except Exception as e: except Exception as e:
if retries == 0: if retries == 0:
if isinstance(e, socket.error): if isinstance(e, socket.error):
asyncio.create_task(self.async_disconnect()) await self.async_disconnect()
raise ConnectionException( raise ConnectionException(
"Connection to {} failed: {}".format(self, e) "Connection to {} failed: {}".format(self, e)
) )
@ -881,18 +886,30 @@ class TuyaDevice:
await self._async_send(message, retries=retries - 1) await self._async_send(message, retries=retries - 1)
async def async_recieve(self, message): async def async_recieve(self, message):
if self._connected is False:
return
if message.expect_response is True: if message.expect_response is True:
try: try:
self._recieve_task = asyncio.create_task( self._recieve_task = asyncio.create_task(
asyncio.wait_for(message.listener.acquire(), timeout=self.timeout) asyncio.wait_for(message.listener.acquire(), timeout=self.timeout)
) )
await self._recieve_task await self._recieve_task
return self._listeners.pop(message.sequence) response = self._listeners.pop(message.sequence)
if isinstance(response, Exception):
raise response
return response
except Exception as e: except Exception as e:
del self._listeners[message.sequence] del self._listeners[message.sequence]
await self.async_disconnect()
raise ResponseTimeoutException( if isinstance(e, TimeoutError):
"Timed out waiting for response to sequence number {}".format( raise ResponseTimeoutException(
message.sequence "Timed out waiting for response to sequence number {}".format(
message.sequence
)
) )
)
raise e