fix: fixed issue with multiple ping loops running

This commit is contained in:
Luke Bonaccorsi 2024-02-26 15:26:23 +00:00
parent 79892aa98f
commit 3e2b923255
1 changed files with 43 additions and 26 deletions

View File

@ -633,6 +633,7 @@ class TuyaDevice:
self.writer = None
self._response_task = None
self._recieve_task = None
self._ping_task = None
self._handlers: dict[int, Callable[[Message], Coroutine]] = {
Message.GRATUITOUS_UPDATE: self.async_gratuitous_update_state,
Message.PING_COMMAND: self._async_pong_received,
@ -718,11 +719,15 @@ class TuyaDevice:
loop.create_connection
self.reader, self.writer = await asyncio.open_connection(sock=sock)
self._connected = True
asyncio.create_task(self.async_ping(self.ping_interval))
if self._ping_task is None:
self.ping_task = asyncio.create_task(self.async_ping(self.ping_interval))
asyncio.create_task(self._async_handle_message())
async def async_disable(self):
self._enabled = False
await self.async_disconnect()
async def async_disconnect(self):
@ -732,11 +737,6 @@ class TuyaDevice:
_LOGGER.debug("Disconnected from {}".format(self))
self._connected = False
self.last_pong = 0
if self._response_task is not None:
self._response_task.cancel()
if self._recieve_task is not None:
self._recieve_task.cancel()
if self.writer is not None:
self.writer.close()
@ -780,7 +780,7 @@ class TuyaDevice:
self._queue.append(message)
await asyncio.sleep(ping_interval)
asyncio.create_task(self.async_ping(self.ping_interval))
self.ping_task = asyncio.create_task(self.async_ping(self.ping_interval))
if self.last_pong < self.last_ping:
await self.async_disconnect()
@ -792,7 +792,11 @@ class TuyaDevice:
await self.update_entity_state_cb()
async def async_update_state(self, state_message, _=None):
if state_message.payload and state_message.payload["dps"]:
if (
state_message is not None
and state_message.payload
and state_message.payload["dps"]
):
self._dps.update(state_message.payload["dps"])
_LOGGER.debug("Received updated state {}: {}".format(self, self._dps))
@ -805,7 +809,7 @@ class TuyaDevice:
asyncio.create_task(self.async_set(new_values))
async def _async_handle_message(self):
if self._enabled is False:
if self._enabled is False or self._connected is False:
return
try:
@ -815,16 +819,15 @@ class TuyaDevice:
await self._response_task
response_data = self._response_task.result()
message = Message.from_bytes(response_data, self.cipher)
self._response_task = None
except InvalidMessage as e:
except Exception as e:
if isinstance(e, InvalidMessage):
_LOGGER.debug("Invalid message from {}: {}".format(self, e))
except MessageDecodeFailed as e:
elif isinstance(e, MessageDecodeFailed):
_LOGGER.debug("Failed to decrypt message from {}".format(self))
except asyncio.IncompleteReadError as e:
self._response_task = None
elif isinstance(e, asyncio.IncompleteReadError):
if self._connected:
_LOGGER.debug("Incomplete read")
except ConnectionResetError as e:
elif isinstance(e, ConnectionResetError):
_LOGGER.debug("Connection reset")
else:
@ -839,6 +842,7 @@ class TuyaDevice:
if handler is not None:
asyncio.create_task(handler(message))
self._response_task = None
asyncio.create_task(self._async_handle_message())
async def _async_send(self, message, retries=2):
@ -850,7 +854,8 @@ class TuyaDevice:
except Exception as e:
if retries == 0:
if isinstance(e, socket.error):
asyncio.create_task(self.async_disconnect())
await self.async_disconnect()
raise ConnectionException(
"Connection to {} failed: {}".format(self, e)
)
@ -881,18 +886,30 @@ class TuyaDevice:
await self._async_send(message, retries=retries - 1)
async def async_recieve(self, message):
if self._connected is False:
return
if message.expect_response is True:
try:
self._recieve_task = asyncio.create_task(
asyncio.wait_for(message.listener.acquire(), timeout=self.timeout)
)
await self._recieve_task
return self._listeners.pop(message.sequence)
response = self._listeners.pop(message.sequence)
if isinstance(response, Exception):
raise response
return response
except Exception as e:
del self._listeners[message.sequence]
await self.async_disconnect()
if isinstance(e, TimeoutError):
raise ResponseTimeoutException(
"Timed out waiting for response to sequence number {}".format(
message.sequence
)
)
raise e