Source code for tq42.channel

from __future__ import annotations

import asyncio
import logging
from typing import List, Callable, Optional, Awaitable

import grpc.aio
from google.protobuf import empty_pb2

from tq42.utils.exception_handling import handle_generic_sdk_errors
from com.terraquantum.channel.v1alpha1.create_channel_pb2 import CreateChannelResponse

# important for re-export
from com.terraquantum.channel.v1alpha1.channel_message_pb2 import (
    ChannelMessage,
    Ask,
    Parameter,  # pylint:disable=unused-import # noqa: F401
    Tell,
    DataAcknowledge,
)

from typing import TYPE_CHECKING

# only import the stuff for type hints -> avoid circular imports
if TYPE_CHECKING:
    from tq42.client import TQ42Client


from tq42.utils.timers import AsyncTimedIterable

# gRPC status codes after which `Channel.connect` tries to re-establish the
# stream instead of propagating the error right away.
_RETRYABLE_ERROR_CODES = (
    grpc.StatusCode.DATA_LOSS,
    grpc.StatusCode.UNAVAILABLE,
    grpc.StatusCode.INTERNAL,
)
# Upper bound used by the reconnect loop in `Channel._reestablish_connection`.
_MAX_RECONNECT_RETRIES = 5


class Channel:
    """
    Reference an existing channel with its `id`. A channel can be created with :py:func:`create`.

    :param client: a client instance
    :param id: existing channel id
    """

    id: str
    """ID of the channel"""

    _client: TQ42Client
    # Id of the last ASK message that was accepted; `connect` ignores any ASK
    # whose id is not strictly greater than this value.
    _sequential_message_id: int = 0

    def __init__(self, client: TQ42Client, id: str):
        self._client = client
        self.id = id

    def __repr__(self) -> str:
        return f"<Channel Id={self.id}>"

    def __str__(self) -> str:
        return self.id

    @staticmethod
    @handle_generic_sdk_errors
    async def create(client: TQ42Client) -> Channel:
        """
        Create a new channel

        :param client: a client instance
        :returns: a new channel
        """
        empty = empty_pb2.Empty()
        res: CreateChannelResponse = await client.channel_client.CreateChannel(
            request=empty, metadata=client.metadata
        )
        return Channel(client=client, id=res.channel_id)

    async def connect(
        self,
        callback: Callable[[Ask], Awaitable[Tell]],
        finish_callback: Callable,
        max_duration_in_sec: Optional[int] = None,
        message_timeout_in_sec: Optional[int] = None,
    ) -> None:
        """
        Connects to the stream and handles every message with the provided callback to create an answer.
        ASK gets into the callback and then we expect a TELL answer

        :param callback: Async callback that handles an ASK message and returns a TELL message
        :param finish_callback: Callback that is called when channel is completed
        :param int max_duration_in_sec: Timeout for whole connection in seconds. `None` -> no timeout for overall flow
        :param int message_timeout_in_sec: Timeout between messages in seconds. `None` -> no timeout between messages
        :raises TimeoutError: if no message arrives within `message_timeout_in_sec`
        :raises asyncio.TimeoutError: if the whole flow exceeds `max_duration_in_sec`
        :raises grpc.aio.AioRpcError: if the stream fails with a non-retryable status code
            or the connection could not be re-established
        """
        call = await self._establish_connection()

        async def _acknowledge_message(msg: ChannelMessage) -> None:
            # `call` is looked up in the enclosing scope at call time, so the
            # acknowledgement goes to the re-established stream after a reconnect.
            ack_msg = ChannelMessage(
                acknowledge_data=DataAcknowledge(id=msg.sequential_message_id)
            )
            await call.write(ack_msg)
            logging.debug(f"User Sent ack {msg.sequential_message_id=}")

        async def _handle():
            # Rebinds `call` when the connection is re-established.
            nonlocal call
            try:
                timed_stream = AsyncTimedIterable(call, timeout=message_timeout_in_sec)
                incoming: ChannelMessage
                async for incoming in timed_stream:
                    logging.debug(f"User received {incoming=}")
                    data_field_name = incoming.WhichOneof("data")

                    if data_field_name == "completion_data":
                        await _acknowledge_message(msg=incoming)
                        logging.debug(
                            "Message indicated channel completion. Closing channel connection"
                        )
                        break
                    elif data_field_name == "ask_data":
                        # Drop messages that were already handled (e.g. re-sent
                        # after a reconnect) or that arrive out of order.
                        if (
                            self._sequential_message_id
                            >= incoming.sequential_message_id
                        ):
                            logging.debug(
                                "Message id is not sequential. Ignoring message"
                            )
                            continue

                        self._sequential_message_id = incoming.sequential_message_id
                        await _acknowledge_message(msg=incoming)

                        tell = await callback(incoming.ask_data)
                        tell_msg = ChannelMessage(
                            sequential_message_id=(incoming.sequential_message_id + 1),
                            tell_data=tell,
                        )
                        await call.write(tell_msg)
            except asyncio.TimeoutError:
                logging.debug("Stream finished because of the provided timeouts")
                raise TimeoutError(
                    f"Channel was closed due to exceeding {message_timeout_in_sec}s timeout between messages"
                ) from None
            except grpc.aio.AioRpcError as e:
                if e.code() in _RETRYABLE_ERROR_CODES:
                    call = await self._reestablish_connection()
                    if call:
                        # Resume consuming messages on the fresh stream.
                        return await _handle()
                # Not retryable, or every reconnect attempt failed.
                raise

        await asyncio.wait_for(_handle(), timeout=max_duration_in_sec)
        await call.done_writing()
        finish_callback()

    async def _establish_connection(self) -> grpc.aio.StreamStreamCall:
        """
        Open the bidirectional stream for this channel.

        The channel id is sent as `channel-id` metadata next to the client's own
        metadata, and an empty `ChannelMessage` is written right after opening.

        :returns: the opened stream call
        """
        metadata: tuple = (
            *self._client.metadata,
            ("channel-id", self.id),
        )
        call = self._client.channel_client.ConnectChannelCustomer(metadata=metadata)
        await call.write(ChannelMessage())
        return call

    async def _reestablish_connection(self) -> Optional[grpc.aio.StreamStreamCall]:
        """
        Try to reconnect to the channel after the stream was lost.

        Makes up to `_MAX_RECONNECT_RETRIES` attempts. Attempt `n` waits at most
        `n * 2` seconds for the underlying gRPC channel to become ready.

        :returns: a new stream call, or `None` if every attempt timed out
        """
        # Inclusive upper bound: the number of attempts made must match the
        # "max attempts" value reported in the log message below.
        for attempt in range(1, _MAX_RECONNECT_RETRIES + 1):
            logging.warning(
                f"Lost connection to channel, retrying (attempt {attempt}, max attempts: {_MAX_RECONNECT_RETRIES})"
            )
            try:
                await asyncio.wait_for(
                    self._client.channels_channel.channel_ready(), attempt * 2
                )
                call = await self._establish_connection()
                logging.info("Reconnected to channel.")
                return call
            except asyncio.TimeoutError:
                logging.error("Failed to reconnect.")
        return None
@handle_generic_sdk_errors
def list_all(client: TQ42Client) -> List[Channel]:
    """
    Return every channel that belongs to the given user.

    Placeholder only: the function body unconditionally raises
    `NotImplementedError`.

    :param client: a client instance

    :meta private:
    # FIXME remove once implemented
    """
    not_available_yet = NotImplementedError("This is a functionality still to come")
    raise not_available_yet