Source code for fastiot.core.broker_connection

import asyncio
import concurrent.futures
import threading
from abc import ABC, abstractmethod
from asyncio import get_running_loop
from inspect import signature
from typing import Any, Callable, Coroutine, Optional

from nats.aio.client import Client as BrokerClient, Callback as BrokerCallback, ErrorCallback as BrokerErrorCallback
from nats.aio.msg import Msg as NatsBrokerMsg
from nats.aio.subscription import Subscription as BrokerSubscription

from fastiot.core.data_models import Msg, MsgPub, MsgReq, MsgResp, Subject, ReplySubject
from fastiot.core.serialization import serialize_from_bin, serialize_to_bin
from fastiot.env import env_broker

# Signature of a subscription callback. Depending on its number of parameters
# the callback is awaited either as ``cb(msg)`` or as ``cb(subject_name, msg)``,
# where ``subject_name`` is the concrete subject string the message arrived on
# and ``msg`` is the deserialized message object. It must return None.
SubscriptionCallback = Callable[..., Coroutine[None, None, None]]
# Signature of a reply callback: called like ``SubscriptionCallback`` but must
# return the reply message that is sent back to the requester.
SubscriptionReplyCallback = Callable[..., Coroutine[None, None, Msg]]


class Subscription(ABC):
    """
    Abstract handle for an active broker subscription.

    Concrete broker implementations return an instance of a subclass from their
    subscribe methods; it only needs to know how to cancel itself.
    """

    @abstractmethod
    async def unsubscribe(self):
        """
        Cancels the subscription
        """
class NatsBrokerSubscription(Subscription):
    """
    Subscription handle backed by a NATS subscription object.

    The underlying NATS subscription is attached after construction via
    ``_set_subscription``. ``subscription_error_cb`` is stored so subclasses can
    report failures that happen while handling received messages.
    """

    def __init__(self, subscription_error_cb: Optional[BrokerErrorCallback] = None, **kwargs):
        super().__init__(**kwargs)
        self._subscription_error_cb = subscription_error_cb
        # Filled in later by ``_set_subscription`` once the client has subscribed.
        self._subscription = None

    def _set_subscription(self, subscription: BrokerSubscription):
        self._subscription = subscription

    async def unsubscribe(self):
        """
        Cancels the underlying NATS subscription.

        :raises RuntimeError: If no NATS subscription has been attached yet.
        """
        nats_subscription = self._subscription
        if nats_subscription is None:
            raise RuntimeError("Expected a subscription object")
        await nats_subscription.unsubscribe()
class NatsBrokerSubscriptionSubject(NatsBrokerSubscription):
    """
    Subscription that deserializes incoming NATS messages and forwards them to a
    callback which must not return a value.
    """

    def __init__(self, subject: Subject, cb: SubscriptionCallback, **kwargs):
        super().__init__(**kwargs)
        self._subject = subject
        self._cb = cb
        # The callback's arity decides whether the subject name is passed along.
        self._num_cb_params = len(signature(cb).parameters)

    async def received_nats_msg_cb(self, nats_msg: NatsBrokerMsg):
        """
        Handles one raw NATS message.

        The payload is deserialized with ``self._subject.msg_cls`` and the callback
        is awaited as ``cb(msg)`` or ``cb(subject_name, msg)``. Any exception raised
        on the way (including a non-None callback result) is passed to the
        subscription error callback if one was configured; otherwise it is dropped.

        :param nats_msg: The message delivered by the NATS client.
        """
        failure = None
        try:
            decoded = serialize_from_bin(self._subject.msg_cls, nats_msg.data)
            arity = self._num_cb_params
            if arity == 1:
                cb_args = (decoded,)
            elif arity == 2:
                cb_args = (nats_msg.subject, decoded)
            else:
                raise NotImplementedError("Callbacks with more then two params are not intended yet.")
            outcome = await self._cb(*cb_args)
            if outcome is not None:
                raise TypeError(
                    f"Callbacks for subscriptions must return None. "
                    f"Got object of type {type(outcome)} instead. "
                    f"Maybe you need to use request pattern, e.g. @reply instead of @subscribe?"
                )
        except Exception as exc:
            failure = exc
        # Reported outside the except clause so the error callback does not run
        # inside the handled exception's context.
        if failure and self._subscription_error_cb:
            await self._subscription_error_cb(failure)
class NatsBrokerSubscriptionReplySubject(NatsBrokerSubscription):
    """
    Subscription that answers incoming requests.

    Each received NATS message is deserialized with ``subject.msg_cls`` and handed
    to the callback. The callback's return value must be an instance of
    ``subject.reply_cls``; it is sent to the message's reply address through
    ``send_reply_fn``. Failures are passed to the subscription error callback
    if one was configured.
    """

    def __init__(self, subject: ReplySubject, cb: SubscriptionReplyCallback,
                 send_reply_fn: Callable[[Subject, Msg], Coroutine[None, None, None]], **kwargs):
        super().__init__(**kwargs)
        self._subject = subject
        self._cb = cb
        self._send_reply_fn = send_reply_fn
        # Inspect the callback signature once; its arity decides whether the
        # subject name is passed along.
        self._num_cb_params = len(signature(cb).parameters)
        # Kept for backward compatibility; equivalent to ``_num_cb_params == 2``.
        self._cb_with_subject = self._num_cb_params == 2

    async def received_nats_msg_cb(self, nats_msg: NatsBrokerMsg):
        """
        Handles one request message and sends the callback's reply.

        The callback is awaited as ``cb(msg)`` or ``cb(subject_name, msg)``. Any
        exception raised in the process (deserialization, callback, wrong reply
        type, sending the reply) is passed to the subscription error callback if
        one was configured; otherwise it is dropped and no reply is sent.

        :param nats_msg: The request message delivered by the NATS client.
        """
        err = None
        try:
            msg = serialize_from_bin(self._subject.msg_cls, nats_msg.data)
            if self._num_cb_params == 1:
                reply_msg = await self._cb(msg)
            elif self._num_cb_params == 2:
                reply_msg = await self._cb(nats_msg.subject, msg)
            else:
                raise NotImplementedError("Callbacks with more then two params are not intended yet.")
            if not isinstance(reply_msg, self._subject.reply_cls):
                # Report the type of the offending return value, not of the request.
                raise TypeError(f"Callback has not returned correct type: Expected type {self._subject.reply_cls}, "
                                f"got {type(reply_msg)}")
            reply_subject = Subject(
                name=nats_msg.reply,
                msg_cls=self._subject.reply_cls
            )
            await self._send_reply_fn(reply_subject, reply_msg)
        except Exception as e:
            err = e
        if err and self._subscription_error_cb:
            await self._subscription_error_cb(err)
class BrokerConnection(ABC):
    """
    Abstract base class for a connection to a message broker.

    Subclasses implement the primitives ``subscribe``, ``subscribe_reply_cb``,
    ``_send`` and ``is_connected``. On top of those, this class provides the
    ``publish``/``request`` helpers and thread-safe synchronous wrappers
    (``run_threadsafe*``, ``*_sync``). The wrappers schedule coroutines on the
    event loop stored in ``self._loop``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Guards reads/writes of ``self._loop`` done from other threads.
        self._loop_mutex = threading.RLock()
        # Try to set loop for synchronous functionality. If no loop is available
        # (RuntimeError), the sync helpers raise until ``_set_loop`` is called.
        try:
            self._loop = asyncio.get_event_loop()
        except RuntimeError:
            self._loop = None

    def _set_loop(self, loop: asyncio.AbstractEventLoop):
        """Sets the event loop used by the thread-safe helpers."""
        with self._loop_mutex:
            self._loop = loop

    @abstractmethod
    async def subscribe(self, subject: Subject, cb: SubscriptionCallback) -> Subscription:
        """
        Subscribe to a subject.

        :param subject: The subject to subscribe to
        :param cb: The callback which is called when a message is received
        """

    async def subscribe_msg_queue(self, subject: Subject, msg_queue: asyncio.Queue[Msg]) -> Subscription:
        """
        Subscribe to a subject using a message queue instead of a callback.

        Use this if you prefer querying a msg_queue.

        :param subject: The subject to subscribe to
        :param msg_queue: The message queue where received messages are enqueued.
        :return: The subscription feeding ``msg_queue``.
        """

        async def cb(msg):
            # ``msg_queue`` is only read here, so no ``nonlocal`` is needed.
            await msg_queue.put(msg)

        return await self.subscribe(subject=subject, cb=cb)

    @abstractmethod
    async def subscribe_reply_cb(self, subject: ReplySubject, cb: SubscriptionReplyCallback) -> Subscription:
        """
        Subscribe to a reply subject. It is expected that the message will be answered.

        :param subject: The reply subject to subscribe to
        :param cb: The callback which is called when a request is received
        """

    @abstractmethod
    async def _send(self, subject: Subject, msg: Msg, reply: Optional[Subject] = None):
        """
        Low level method to send msg to broker

        :param subject: The subject to send to.
        :param msg: The message to send.
        :param reply: Optional subject the receiver should answer to (used by ``request``).
        """

    async def publish(self, subject: Subject, msg: MsgPub):
        """
        Publishes a message for a subject.

        :param subject: The subject info to publish to.
        :param msg: The message.
        """
        await self._send(subject=subject, msg=msg)

    async def request(self, subject: ReplySubject, msg: MsgReq,
                      timeout: float = env_broker.default_timeout) -> MsgResp:
        """
        Send a request on a subject.

        :param subject: The subject used for sending the request.
        :param msg: The request
        :param timeout: The time in seconds to wait for an answer.
        :return: The response
        :raises asyncio.TimeoutError: If no answer is received within ``timeout`` seconds
            (raised by ``asyncio.wait_for``).
        """
        inbox = subject.make_generic_reply_inbox()
        msg_queue = asyncio.Queue()
        # Subscribe to the reply inbox before sending, so the answer cannot be missed.
        sub = await self.subscribe_msg_queue(subject=inbox, msg_queue=msg_queue)
        try:
            await self._send(subject=subject, msg=msg, reply=inbox)
            return await asyncio.wait_for(msg_queue.get(), timeout=timeout)
        finally:
            # Always drop the temporary inbox subscription, also on timeout.
            await sub.unsubscribe()

    @property
    @abstractmethod
    def is_connected(self) -> bool:
        """
        Return the connection status e.g. for health checks
        """

    def run_threadsafe_nowait(self, coro: Coroutine) -> concurrent.futures.Future:
        """
        Runs a coroutine on brokers event loop. This method is thread-safe.

        It can be useful if you want to interact with the broker from another thread.

        :param coro: The coroutine to run thread-safe on brokers event loop, for example
            'broker_client.publish(...)'
        :return: A future that resolves to the coroutine's result.
        :raises RuntimeError: If no event loop is known to this connection.
        """
        with self._loop_mutex:
            if not self._loop:
                raise RuntimeError(
                    "No event loop has been detected. Please make sure the "
                    "connection is created in an asyncio context or the loop "
                    "has been set manually."
                )
            return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._loop)

    def run_threadsafe(self, coro: Coroutine, timeout: float = 0.0) -> Any:
        """
        Runs a coroutine on brokers event loop and blocks until it is done.

        This method is thread-safe. It can be useful if you want to interact with
        the broker from another thread. Do not call it from the thread running the
        broker's event loop: it blocks that thread while waiting, so the scheduled
        coroutine could never run.

        :param coro: The coroutine to run thread-safe on brokers event loop, for example
            'broker_client.publish(...)'
        :param timeout: The number of seconds to wait for the result to be done. Raises
            concurrent.futures.TimeoutError if timeout exceeds. A value of zero means wait forever.
        :return: Returns the result of the coroutine. If the coroutine raised an exception,
            it is reraised.
        """
        future = self.run_threadsafe_nowait(coro=coro)
        return future.result(timeout=timeout if timeout else None)

    def publish_sync(self, subject: Subject, msg: MsgPub, timeout: float = 0.0):
        """
        Publishes a message for a subject. This method is thread-safe.

        Under the hood, it uses run_threadsafe.

        :param subject: The subject info to publish to.
        :param msg: The message.
        :param timeout: The timeout in seconds; zero means wait forever.
        """
        return self.run_threadsafe(
            coro=self.publish(subject=subject, msg=msg),
            timeout=timeout
        )

    def publish_sync_nowait(self, subject: Subject, msg: MsgPub) -> concurrent.futures.Future:
        """
        Publishes a message for a subject without waiting. This method is thread-safe.

        Under the hood, it uses run_threadsafe_nowait.

        :param subject: The subject info to publish to.
        :param msg: The message.
        :return: A future for the publish operation.
        """
        return self.run_threadsafe_nowait(
            coro=self.publish(subject=subject, msg=msg)
        )

    def request_sync(self, subject: ReplySubject, msg: MsgReq,
                     timeout: float = env_broker.default_timeout) -> MsgResp:
        """
        Performs a request on the subject. This method is thread-safe.

        Under the hood, it uses run_threadsafe. Note that the timeout applies only
        to the request itself. For simplicity, this call waits forever if the event
        loop is too busy to schedule the request at all.

        :param subject: The subject info to publish the request.
        :param msg: The request message.
        :param timeout: The timeout for the broker call.
        :return: The requested message.
        """
        return self.run_threadsafe(
            coro=self.request(subject=subject, msg=msg, timeout=timeout)
        )
class NatsBrokerConnection(BrokerConnection):
    """
    Broker connection implemented on top of a NATS client.
    """

    @classmethod
    async def connect(cls, closed_cb: Optional[BrokerCallback] = None,
                      subscription_error_cb: Optional[BrokerErrorCallback] = None
                      ) -> "NatsBrokerConnection":
        """
        Connects a nats instance and returns a nats broker connection.

        :param closed_cb: Forwarded to the NATS client's ``connect``.
        :param subscription_error_cb: Awaited with the exception whenever handling a
            message in a subscription or reply callback fails.
        """
        client = BrokerClient()
        await client.connect(
            f"nats://{env_broker.host}:{env_broker.port}",
            closed_cb=closed_cb
        )
        return cls(
            client=client,
            subscription_error_cb=subscription_error_cb
        )

    def __init__(self, client: BrokerClient, subscription_error_cb: Optional[BrokerErrorCallback] = None):
        super().__init__()
        self._client = client
        # NOTE: attribute name kept unchanged for compatibility with existing code.
        self._subscription_error_cp = subscription_error_cb
        # get_running_loop() raises RuntimeError outside a running loop, so this
        # object must be created from async code (as ``connect`` does).
        self._set_loop(get_running_loop())

    async def close(self):
        """Closes the underlying NATS client."""
        await self._client.close()

    async def subscribe(self, subject: Subject, cb: SubscriptionCallback, ) -> Subscription:
        """
        Subscribe to a subject.

        :param subject: The subject to subscribe to
        :param cb: The callback which is called when a message is received
        """
        result = NatsBrokerSubscriptionSubject(
            subject=subject,
            cb=cb,
            subscription_error_cb=self._subscription_error_cp
        )
        subscription = await self._client.subscribe(
            subject=subject.name,
            cb=result.received_nats_msg_cb
        )
        result._set_subscription(subscription=subscription)
        return result

    async def subscribe_reply_cb(self, subject: ReplySubject, cb: SubscriptionReplyCallback) -> Subscription:
        """
        Subscribe to a reply subject. It is expected that the message will be answered.

        :param subject: The reply subject to subscribe to
        :param cb: The callback which is called when a request is received
        """
        result = NatsBrokerSubscriptionReplySubject(
            subject=subject,
            cb=cb,
            send_reply_fn=self._send,
            # Forward the error callback, as ``subscribe`` does. Otherwise reply
            # handler failures are silently dropped.
            subscription_error_cb=self._subscription_error_cp
        )
        subscription = await self._client.subscribe(
            subject=subject.name,
            cb=result.received_nats_msg_cb
        )
        result._set_subscription(subscription=subscription)
        return result

    async def _send(self, subject: Subject, msg: Msg, reply: Optional[Subject] = None):
        """Serializes ``msg`` and publishes it, with ``reply``'s name as reply address if given."""
        payload = serialize_to_bin(subject.msg_cls, msg)
        reply_str = '' if reply is None else reply.name
        await self._client.publish(
            subject=subject.name,
            payload=payload,
            reply=reply_str
        )

    @property
    def is_connected(self):
        """Connection status as reported by the NATS client."""
        return self._client.is_connected
class SubscriptionDummy(Subscription):
    """
    No-op subscription used by the dummy broker connection.
    """

    async def unsubscribe(self):
        """Does nothing; there is no real subscription to cancel."""
        return None

    def check_pending_error(self):
        """Does nothing; a dummy subscription never has a pending error."""
        return None

    async def raise_pending_error(self):
        """Waits forever, because a dummy subscription never produces an error to raise."""
        never_set = asyncio.Event()
        await never_set.wait()
class BrokerConnectionDummy(BrokerConnection):
    """
    A dummy broker implementation to mock dependencies.

    It always reports itself as connected, returns a ``SubscriptionDummy`` for
    every subscription and drops everything it is asked to send. Because nothing
    is ever delivered, ``request`` on this dummy waits until its timeout expires.
    """

    async def _send(self, subject: Subject, msg: Msg, reply: Optional[Subject] = None):
        # Intentionally discards the message.
        return None

    async def subscribe(self, subject: Subject, cb: SubscriptionCallback) -> Subscription:
        return SubscriptionDummy()

    async def subscribe_reply_cb(self, subject: ReplySubject, cb: SubscriptionReplyCallback) -> Subscription:
        return SubscriptionDummy()

    @property
    def is_connected(self) -> bool:
        return True