Source code for pulsar.apps.data.channels

import re
import logging
import json
from functools import wraps
from enum import Enum
from asyncio import gather
from collections import namedtuple, OrderedDict

from pulsar import ProtocolError
from pulsar.apps.ds import redis_to_py_pattern

from .store import PubSubClient


# Registry entry kept per event pattern by ``Channel``: the event name as
# registered, its python regex pattern, the compiled regex and the list of
# callbacks to fire when it matches.
event_callbacks = namedtuple(
    'event_callbacks', ['name', 'pattern', 'regex', 'callbacks']
)


LOGGER = logging.getLogger('pulsar.channels')


# Lifecycle states of a ``Channels`` instance. The functional API numbers
# the members 1..5 in the order they are listed here.
StatusType = Enum(
    'StatusType',
    'initialised connecting connected disconnected closed',
)


can_connect = frozenset([StatusType.initialised, StatusType.disconnected])  # states where Channels.connect() may start


DEFAULT_NAMESPACE = ''  # used when neither an argument nor the store URL gives one
DEFAULT_CHANNEL = 'server'  # name of the status channel subscribed on connect
RECONNECT_LAG = 2  # first retry delay in seconds; later ones go through backoff()


def backoff(value):
    """Return the delay (seconds) to wait before the next reconnection.

    The delay grows by a quarter of a second on every call and is capped
    at 16 seconds.
    """
    grown = value + 0.25
    return 16 if grown > 16 else grown


class CallbackError(Exception):
    """Raise from a channel callback to have it removed cleanly.

    ``Channel.fire`` drops a callback that raises this exception without
    logging it, whereas any other exception is logged as well.
    """


class Json:
    """JSON codec for channel messages."""

    def encode(self, msg):
        """Serialise *msg* to a JSON string."""
        return json.dumps(msg)

    def decode(self, msg):
        """Deserialise a JSON payload given as ``str`` or UTF-8 ``bytes``.

        :raise ProtocolError: if *msg* is not valid UTF-8 or not valid JSON.
        """
        try:
            # Decoding happens inside the ``try`` so that malformed UTF-8 is
            # reported as a protocol error rather than UnicodeDecodeError.
            if isinstance(msg, bytes):
                msg = msg.decode('utf-8')
            return json.loads(msg)
        except Exception:
            raise ProtocolError('Invalid JSON') from None


class Connector:
    """Mixin holding the namespace and connection string of a store.

    The namespace comes from the ``namespace`` argument, the store's
    ``namespace`` URL parameter, or ``DEFAULT_NAMESPACE``, in that order.
    It is lower-cased and, when non-empty, ends with
    ``namespace_delimiter``.
    """
    namespace_delimiter = '_'

    def __init__(self, store, namespace=None):
        self.connection_error = False
        chosen = (namespace or store.urlparams.get('namespace') or
                  DEFAULT_NAMESPACE).lower()
        delimiter = self.namespace_delimiter
        if chosen and not chosen.endswith(delimiter):
            chosen = chosen + delimiter
        self.namespace = chosen
        # connection string rebuilt by the store with the final namespace
        self.dns = store.buildurl(namespace=chosen)

    def __repr__(self):
        return self.dns

    def __str__(self):
        return self.__repr__()

    def prefixed(self, name):
        """Return *name* prefixed with the namespace, unless it already is."""
        prefix = self.namespace
        if not prefix or name.startswith(prefix):
            return name
        return prefix + name

    def connection_ok(self):
        """Clear a previously flagged connection error.

        Returns ``True`` when no error was flagged. When an error was
        flagged, it logs the recovery, resets ``connection_error`` and
        returns ``None``. Relies on ``self.logger`` being provided by the
        subclass.
        """
        if not self.connection_error:
            return True
        self.logger.warning(
            'connection with %s established - all good',
            self
        )
        self.connection_error = False


class Channels(Connector, PubSubClient):
    """Manage channels for publish/subscribe

    Channel objects live in ``channels``, keyed by lower-cased name.
    ``status_channel`` is subscribed by :meth:`_connect` every time a
    connection is (re-)established.
    """
    statusType = StatusType

    def __init__(self, store, namespace=None, status_channel=None,
                 logger=None):
        super().__init__(store, namespace=namespace)
        self.store = store
        self.channels = OrderedDict()
        self.logger = logger or LOGGER
        self.status_channel = self.channel(status_channel or DEFAULT_CHANNEL)
        self.status = self.statusType.initialised

    @property
    def _loop(self):
        """Event loop of the underlying store."""
        return self.store._loop

    def __repr__(self):
        return self.dns

    def __len__(self):
        return len(self.channels)

    def __contains__(self, name):
        # ``channel`` stores names lower-cased; normalise string lookups the
        # same way so ``'Foo' in channels`` agrees with ``channel('Foo')``.
        if isinstance(name, str):
            name = name.lower()
        return name in self.channels

    def __iter__(self):
        return iter(self.channels.values())

    def __call__(self, channel_name, message):
        """Dispatch *message* received on the remote *channel_name*.

        The namespace prefix is stripped and the message is forwarded to
        the matching local channel. Names outside the namespace, unknown
        channels and channels without callbacks are ignored.
        """
        if channel_name.startswith(self.namespace):
            name = channel_name[len(self.namespace):]
            channel = self.channels.get(name)
            # a Channel is falsy when it has no callbacks (see its __len__)
            if channel:
                channel(message)

    async def register(self, channel, event, callback):
        """Register a callback to ``channel_name`` and ``event``.

        A prefix will be added to the channel name if not already available
        or the prefix is an empty string

        :param channel: channel name
        :param event: event name
        :param callback: callback to execute when event on channel occurs
        :return: a coroutine which results in the channel where the callback
            was registered
        """
        channel = self.channel(channel)
        event = channel.register(event, callback)
        await channel.connect(event.name)
        return channel

    async def unregister(self, channel, event, callback):
        """Safely unregister a callback from the list of ``event``
        callbacks for ``channel_name``.

        When the channel is left without callbacks it is disconnected and
        forgotten, except for the status channel, which must stay subscribed
        and dispatchable for as long as this object is connected.

        :param channel: channel name
        :param event: event name
        :param callback: callback to execute when event on channel occurs
        :return: a coroutine which results in the channel object where the
            ``callback`` was removed (if found)
        """
        channel = self.channel(channel, create=False)
        if channel:
            channel.unregister(event, callback)
            if not channel and channel is not self.status_channel:
                await channel.disconnect()
                self.channels.pop(channel.name)
        return channel

    async def connect(self, next_time=None):
        """Connect with store

        Only acts when the current status is in ``can_connect`` and the
        event loop is running.

        :param next_time: delay used for the previous retry, if any; passed
            to :meth:`_connect` to compute the next back-off
        :return: a coroutine and therefore it must be awaited
        """
        if self.status in can_connect:
            loop = self._loop
            if loop.is_running():
                self.status = StatusType.connecting
                await self._connect(next_time)

    async def publish(self, channel, event, data=None):
        """Publish a new ``event`` on a ``channel``

        :param channel: channel name
        :param event: event name
        :param data: optional payload to include in the event
        :return: a coroutine and therefore it must be awaited
        """
        raise NotImplementedError

    async def _subscribe(self, channel, event=None):
        """Subscribe to the remote server
        """
        raise NotImplementedError

    async def _unsubscribe(self, channel):
        raise NotImplementedError

    async def close(self):
        """Close channels and underlying store handler

        :return: a coroutine and therefore it must be awaited
        """
        self.status = self.statusType.closed

    def channel(self, name, create=True):
        """Return the channel called *name* (lower-cased).

        A new :class:`Channel` is created and stored when missing and
        *create* is true; otherwise ``None`` is returned for a missing one.
        """
        name = name.lower()
        channel = self.channels.get(name)
        if channel is None and create:
            channel = Channel(self, name)
            self.channels[channel.name] = channel
        return channel

    def event_pattern(self, event):
        """Channel pattern for an event name
        """
        return redis_to_py_pattern(event)

    def _connection_lost(self, *args):
        """Mark the connection as lost and schedule a reconnection."""
        self.status = self.statusType.disconnected
        self._loop.create_task(self.connect())

    async def _connect(self, next_time):
        """Subscribe the status channel, then every other channel.

        On ``ConnectionError`` the status goes back to ``disconnected`` and a
        new :meth:`connect` is scheduled after a back-off delay.
        """
        try:
            self.status = StatusType.connecting
            await self._subscribe(self.status_channel)
            self.status = StatusType.connected
            self.logger.warning(
                '%s ready and listening for events on %s - all good',
                self,
                self.status_channel.name
            )
        except ConnectionError:
            self.status = StatusType.disconnected
            next_time = backoff(next_time) if next_time else RECONNECT_LAG
            self.logger.critical(
                '%s cannot subscribe - connection error - '
                'try again in %s seconds',
                self,
                next_time
            )
            # Create the coroutine only when the timer fires so that no
            # never-awaited coroutine is left behind if the loop stops first.
            self._loop.call_later(
                next_time,
                lambda: self._loop.create_task(self.connect(next_time))
            )
        else:
            await gather(*[c.connect() for c in self.channels.values()
                           if c.name != self.status_channel.name])
def safe_execution(method):
    """Decorate a :class:`Channel` coroutine method against lost connections.

    If the wrapped method raises ``ConnectionError``, the owning
    ``self.channels`` is marked disconnected and its ``connect`` is awaited;
    the error itself is not re-raised. Otherwise the method's result is
    returned.
    """
    @wraps(method)
    async def _(self, *args, **kwargs):
        try:
            return await method(self, *args, **kwargs)
        except ConnectionError:
            self.channels.status = StatusType.disconnected
            await self.channels.connect()
    return _


class Channel:
    """Channel

    .. attribute:: channels

        the channels container

    .. attribute:: name

        channel name

    .. attribute:: callbacks

        dictionary mapping event patterns to ``event_callbacks`` entries
    """
    def __init__(self, channels, name):
        self.channels = channels
        self.name = name
        self.callbacks = OrderedDict()

    @property
    def events(self):
        """List of event names this channel is registered with
        """
        return tuple(e.name for e in self.callbacks.values())

    def __repr__(self):
        return repr(self.callbacks)

    def __len__(self):
        return len(self.callbacks)

    def __contains__(self, pattern):
        return pattern in self.callbacks

    def __iter__(self):
        # Iterate this channel's own entries, consistent with __len__ and
        # __contains__ (the container in ``self.channels`` has no values()).
        return iter(self.callbacks.values())

    def __call__(self, message):
        """Fire the ``event`` of *message* with its ``data`` payload.

        The ``event`` key is popped from *message* (default ``''``).
        """
        event = message.pop('event', '')
        data = message.get('data')
        self.fire(event, data)

    def fire(self, event, data=None):
        """Call every callback whose pattern regex matches *event*.

        Callbacks receive ``(channel, matched_text, data)``. A callback that
        raises :class:`CallbackError` is removed silently; one raising any
        other exception is removed and the exception is logged.
        """
        for entry in tuple(self.callbacks.values()):
            match = entry.regex.match(event)
            if match:
                match = match.group()
                # iterate over a copy: callbacks may be removed below
                for callback in tuple(entry.callbacks):
                    try:
                        callback(self, match, data)
                    except CallbackError:
                        self._remove_callback(entry, callback)
                    except Exception:
                        self._remove_callback(entry, callback)
                        self.channels.logger.exception(
                            'callback exception: channel "%s" event "%s"',
                            self.name, event)

    @safe_execution
    async def connect(self, event=None):
        """Subscribe this channel remotely, if the container is connected."""
        channels = self.channels
        if channels.status == StatusType.connected:
            await channels._subscribe(self, event)

    @safe_execution
    async def disconnect(self):
        """Unsubscribe this channel remotely, if the container is connected."""
        channels = self.channels
        if channels.status == StatusType.connected:
            await channels._unsubscribe(self)

    def register(self, event, callback):
        """Register a ``callback`` for ``event``

        :return: the ``event_callbacks`` entry holding the callback
        """
        pattern = self.channels.event_pattern(event)
        entry = self.callbacks.get(pattern)
        if entry is None:
            entry = event_callbacks(event, pattern, re.compile(pattern), [])
            self.callbacks[entry.pattern] = entry
        if callback not in entry.callbacks:
            entry.callbacks.append(callback)
        return entry

    def unregister(self, event, callback):
        """Remove *callback* from *event*.

        :return: the entry it was removed from, or ``None``
        """
        pattern = self.channels.event_pattern(event)
        entry = self.callbacks.get(pattern)
        if entry:
            return self._remove_callback(entry, callback)

    def _remove_callback(self, entry, callback):
        """Drop *callback* from *entry*, and the entry itself once empty."""
        if callback in entry.callbacks:
            entry.callbacks.remove(callback)
            if not entry.callbacks:
                self.callbacks.pop(entry.pattern)
            return entry