Source code for pulsar.async.protocols

import asyncio

from pulsar.utils.internet import nice_address, format_address

from .futures import task, Future, ensure_future
from .events import EventHandler, AbortEvent
from .mixins import FlowControl, Timeout


__all__ = ['ProtocolConsumer',
           'Protocol',
           'DatagramProtocol',
           'Connection',
           'Producer',
           'TcpServer',
           'DatagramServer',
           'AbortRequest']


CLOSE_TIMEOUT = 3


class AbortRequest(AbortEvent):
    pass


[docs]class ProtocolConsumer(EventHandler): """The consumer of data for a server or client :class:`.Connection`. It is responsible for receiving incoming data from an end point via the :meth:`Connection.data_received` method, decoding (parsing) and, possibly, writing back to the client or server via the :attr:`transport` attribute. .. note:: For server consumers, :meth:`data_received` is the only method to implement. For client consumers, :meth:`start_request` should also be implemented. A :class:`ProtocolConsumer` is a subclass of :class:`.EventHandler` and it has two default :ref:`one time events <one-time-event>`: * ``pre_request`` fired when the request is received (for servers) or just before is sent (for clients). This occurs just before the :meth:`start_request` method. * ``post_request`` fired when the request is done. The :attr:`on_finished` attribute is a shortcut for the ``post_request`` :class:`.OneTime` event and therefore can be used to wait for the request to have received a full response (clients). In addition, it has two :ref:`many times events <many-times-event>`: * ``data_received`` fired when new data is received from the transport but not yet processed (before the :meth:`data_received` method is invoked) * ``data_processed`` fired just after data has been consumed (after the :meth:`data_received` method) .. note:: A useful example on how to use the ``data_received`` event is the :ref:`wsgi proxy server <tutorials-proxy-server>`. """ _connection = None _data_received_count = 0 ONE_TIME_EVENTS = ('pre_request', 'post_request') MANY_TIMES_EVENTS = ('data_received', 'data_processed') @property def connection(self): """The :class:`Connection` of this consumer. """ return self._connection @property def request(self): """The request. Used for clients only and available only after the :meth:`start` method is invoked. """ return getattr(self, '_request', None) @property def transport(self): """The :class:`Transport` of this consumer """ if self._connection: return self._connection.transport @property def address(self): if self._connection: return self._connection.address @property def producer(self): """The :class:`Producer` of this consumer. """ if self._connection: return self._connection.producer @property def on_finished(self): """Event fired once a full response to a request is received. It is the ``post_request`` one time event. """ return self.event('post_request')
[docs] def connection_made(self, connection): """Called by a :class:`Connection` when it starts using this consumer. By default it does nothing. """
[docs] def data_received(self, data): """Called when some data is received. **This method must be implemented by subclasses** for both server and client consumers. The argument is a bytes object. """
[docs] def start_request(self): """Starts a new request. Invoked by the :meth:`start` method to kick start the request with remote server. For server :class:`ProtocolConsumer` this method is not invoked at all. **For clients this method should be implemented** and it is critical method where errors caused by stale socket connections can arise. **This method should not be called directly.** Use :meth:`start` instead. Typically one writes some data from the :attr:`request` into the transport. Something like this:: self.transport.write(self.request.encode()) """ raise NotImplementedError
[docs] def start(self, request=None): """Starts processing the request for this protocol consumer. There is no need to override this method, implement :meth:`start_request` instead. If either :attr:`connection` or :attr:`transport` are missing, a :class:`RuntimeError` occurs. For server side consumer, this method simply fires the ``pre_request`` event. """ if hasattr(self, '_request'): raise RuntimeError('%s already requested %s' % (self, self._request)) conn = self._connection if not conn: raise RuntimeError('Cannot start new request. No connection.') if not conn.transport: raise RuntimeError('%s has no transport.' % conn) conn._processed += 1 if conn._producer: p = getattr(conn._producer, '_requests_processed', 0) conn._producer._requests_processed = p + 1 self.bind_event('post_request', self._finished) self._request = request return ensure_future(self._start(), loop=self._loop)
[docs] def abort_request(self): """Abort the request. This method can be called during the pre-request stage """ future = self.events['pre_request'] if future.done(): raise RuntimeError('Request already sent') future.add_done_callback(self._abort_request) raise AbortRequest
[docs] def connection_lost(self, exc): """Called by the :attr:`connection` when the transport is closed. By default it calls the :meth:`finished` method. It can be overwritten to handle the potential exception ``exc``. """ # TODO: decide how to handle connection_lost when no exception occurs # Set the first positional parameter to None so that if the # connection was dropped without exception it returns None # rather than the protocol consumer return self.finished(None)
[docs] def finished(self, *arg, **kw): """Fire the ``post_request`` event if it wasn't already fired. """ if not self.done(): return self.fire_event('post_request', *arg, **kw)
def done(self): return self.event('post_request').fired()
[docs] def write(self, data): """Delegate writing to the underlying :class:`.Connection` Return an empty tuple or a :class:`~asyncio.Future` """ c = self._connection if c: return c.write(data) else: raise RuntimeError('No connection')
async def _start(self): try: await self.fire_event('pre_request') except AbortEvent: self.logger.debug('Abort request %s', self.request) else: if self._request is not None: try: self.start_request() except Exception as exc: self.finished(exc=exc) def _data_received(self, data): # Called by Connection, it updates the counters and invoke # the high level data_received method which must be implemented # by subclasses if not hasattr(self, '_request'): self.start() self._data_received_count += 1 self.fire_event('data_received', data=data) result = self.data_received(data) self.fire_event('data_processed', data=data) return result def _finished(self, _, exc=None): c = self._connection if c and c._current_consumer is self: c._current_consumer = None @task async def _abort_request(self, fut): exc = fut.exception() for event in self.ONE_TIME_EVENTS: if not self.event(event).fired(): try: await self.fire_event(event, exc=exc) except AbortRequest: pass
[docs]class PulsarProtocol(EventHandler, FlowControl): """A mixin class for both :class:`.Protocol` and :class:`.DatagramProtocol`. A :class:`PulsarProtocol` is an :class:`.EventHandler` which has two :ref:`one time events <one-time-event>`: * ``connection_made`` * ``connection_lost`` """ ONE_TIME_EVENTS = ('connection_made', 'connection_lost') MANY_TIMES_EVENTS = ('data_received', 'data_processed', 'before_write', 'after_write') _transport = None _address = None _closed = None def __init__(self, loop, session=1, producer=None, logger=None, **kw): super().__init__(loop) FlowControl.__init__(self, **kw) self._logger = logger self._session = session self._producer = producer def __repr__(self): address = self._address if address: return '%s session %s' % (nice_address(address), self._session) else: return '<pending> session %s' % self._session __str__ = __repr__ @property def session(self): """Connection session number. Passed during initialisation by the :attr:`producer`. Usually an integer representing the number of separate connections the producer has processed at the time it created this :class:`Protocol`. """ return self._session @property def transport(self): """The :ref:`transport <asyncio-transport>` for this protocol. Available once the :meth:`connection_made` is called. """ return self._transport @property def sock(self): """The socket of :attr:`transport`. """ if self._transport: return self._transport.get_extra_info('socket') @property def address(self): """The address of the :attr:`transport`. """ return self._address @property def producer(self): """The producer of this :class:`Protocol`. """ return self._producer @property def closed(self): """``True`` if the :attr:`transport` is closed. """ if self._transport: if hasattr(self._transport, 'is_closing'): return self._transport.is_closing() return False return True
[docs] def close(self): """Close by closing the :attr:`transport` Return the ``connection_lost`` event which can be used to wait for complete transport closure. """ if not self._closed: if self._transport: if self.debug: self.logger.debug('Closing connection %s', self) if self._transport.can_write_eof(): try: self._transport.write_eof() except Exception: pass try: self._transport.close() except Exception: pass self._closed = ensure_future(self._close()) return self.event('connection_lost')
[docs] def abort(self): """Abort by aborting the :attr:`transport` """ if self._transport: self._transport.abort()
[docs] def connection_made(self, transport): """Sets the :attr:`transport`, fire the ``connection_made`` event and adds a :attr:`timeout` for idle connections. """ self._transport = transport addr = self._transport.get_extra_info('peername') if not addr: addr = self._transport.get_extra_info('sockname') self._address = addr # let everyone know we have a connection with endpoint self.fire_event('connection_made')
[docs] def connection_lost(self, exc=None): """Fires the ``connection_lost`` event. """ if self.debug and not exc: self.logger.debug('Lost connection %s', self) self.fire_event('connection_lost')
[docs] def eof_received(self): """The socket was closed from the remote end """
def info(self): info = {'connection': {'session': self._session}} if self._producer: info.update(self._producer.info()) return info async def _close(self): try: await asyncio.wait_for( self.event('connection_lost'), CLOSE_TIMEOUT ) except asyncio.TimeoutError: self.logger.warning('Abort connection %s', self) self.abort()
[docs]class Protocol(PulsarProtocol, asyncio.Protocol): """An :class:`asyncio.Protocol` with :ref:`events <event-handling>` """ _data_received_count = 0
[docs] def write(self, data): """Write ``data`` into the wire. Returns an empty tuple or a :class:`~asyncio.Future` if this protocol has paused writing. """ t = self._transport if t: if self._paused: # # Uses private variable once again! # This occurs when the protocol is paused from writing # but another data ready callback is fired in the same # event-loop frame self.logger.debug('protocol cannot write, add data to the ' 'transport buffer') t._buffer.extend(data) else: self.fire_event('before_write') t.write(data) self.fire_event('after_write') return self._write_waiter or () else: raise ConnectionResetError('No Transport')
[docs]class DatagramProtocol(PulsarProtocol, asyncio.DatagramProtocol): """An ``asyncio.DatagramProtocol`` with events` """
[docs]class Connection(Protocol, Timeout): """A :class:`.FlowControl` to handle multiple TCP requests/responses. It is a class which acts as bridge between a :ref:`transport <asyncio-transport>` and a :class:`.ProtocolConsumer`. It routes data arriving from the transport to the :meth:`current_consumer`. .. attribute:: _consumer_factory A factory of :class:`.ProtocolConsumer`. .. attribute:: _processed number of separate requests processed. """ def __init__(self, consumer_factory=None, timeout=None, **kw): super().__init__(**kw) self.bind_event('connection_lost', self._connection_lost) self._processed = 0 self._current_consumer = None self._consumer_factory = consumer_factory self.timeout = timeout @property def requests_processed(self): return self._processed
[docs] def current_consumer(self): """The :class:`ProtocolConsumer` currently handling incoming data. This instance will receive data when this connection get data from the :attr:`~PulsarProtocol.transport` via the :meth:`data_received` method. If no consumer is available, build a new one and return it. """ if self._current_consumer is None: self._build_consumer(None) return self._current_consumer
[docs] def data_received(self, data): """Delegates handling of data to the :meth:`current_consumer`. Once done set a timeout for idle connections when a :attr:`~Protocol.timeout` is a positive number (of seconds). """ self._data_received_count = self._data_received_count + 1 self.fire_event('data_received', data=data) toprocess = data while toprocess: consumer = self.current_consumer() toprocess = consumer._data_received(toprocess) if isinstance(toprocess, Future): break self.fire_event('data_processed', data=data)
[docs] def upgrade(self, consumer_factory): """Upgrade the :func:`_consumer_factory` callable. This method can be used when the protocol specification changes during a response (an example is a WebSocket request/response, or HTTP tunneling). This method adds a ``post_request`` callback to the :meth:`current_consumer` to build a new consumer with the new :func:`_consumer_factory`. :param consumer_factory: the new consumer factory (a callable accepting no parameters) :return: ``None``. """ self._consumer_factory = consumer_factory consumer = self._current_consumer if consumer: consumer.bind_event('post_request', self._build_consumer) else: self._build_consumer(None)
def info(self): info = super().info() c = info['connection'] c['request_processed'] = self._processed c['data_processed_count'] = self._data_received_count c['timeout'] = self.timeout return info def _build_consumer(self, _, exc=None): if not exc or isinstance(exc, AbortEvent): consumer = self._producer.build_consumer(self._consumer_factory) assert self._current_consumer is None, 'Consumer is not None' self._current_consumer = consumer consumer._connection = self consumer.connection_made(self) def _connection_lost(self, _, exc=None): """It performs these actions in the following order: * Fires the ``connection_lost`` :ref:`one time event <one-time-event>` if not fired before, with ``exc`` as event data. * Cancel the idle timeout if set. * Invokes the :meth:`ProtocolConsumer.connection_lost` method in the :meth:`current_consumer`. """ if self._current_consumer: self._current_consumer.connection_lost(exc)
[docs]class Producer(EventHandler): """An Abstract :class:`.EventHandler` class for all producers of socket (client and servers) """ protocol_factory = None """A callable producing protocols. The signature of the protocol factory callable must be:: protocol_factory(session, producer, **params) """ def __init__(self, loop=None, protocol_factory=None, name=None, max_requests=None, logger=None): super().__init__(loop or asyncio.get_event_loop()) self.protocol_factory = protocol_factory or self.protocol_factory self._name = name or self.__class__.__name__ self._requests_processed = 0 self._sessions = 0 self._max_requests = max_requests self._logger = logger @property def sessions(self): """Total number of protocols created by the :class:`Producer`. """ return self._sessions @property def requests_processed(self): """Total number of requests processed. """ return self._requests_processed
[docs] def create_protocol(self, **kw): """Create a new protocol via the :meth:`protocol_factory` This method increase the count of :attr:`sessions` and build the protocol passing ``self`` as the producer. """ self._sessions = self._sessions + 1 kw['session'] = self.sessions kw['producer'] = self kw['loop'] = self._loop kw['logger'] = self._logger return self.protocol_factory(**kw)
[docs] def build_consumer(self, consumer_factory): """Build a consumer for a protocol. This method can be used by protocols which handle several requests, for example the :class:`Connection` class. :param consumer_factory: consumer factory to use. """ consumer = consumer_factory(loop=self._loop) consumer._logger = self._logger consumer.copy_many_times_events(self) return consumer
[docs]class TcpServer(Producer): """A :class:`.Producer` of server :class:`Connection` for TCP servers. .. attribute:: _server A :class:`.Server` managed by this Tcp wrapper. Available once the :meth:`start_serving` method has returned. """ ONE_TIME_EVENTS = ('start', 'stop') MANY_TIMES_EVENTS = ('connection_made', 'pre_request', 'post_request', 'connection_lost') _server = None _started = None def __init__(self, protocol_factory, loop, address=None, name=None, sockets=None, max_requests=None, keep_alive=None, logger=None): super().__init__(loop, protocol_factory, name=name, max_requests=max_requests, logger=logger) self._params = {'address': address, 'sockets': sockets} self._keep_alive = max(keep_alive or 0, 0) self._concurrent_connections = set() def __repr__(self): address = self.address if address: return '%s %s' % (self.__class__.__name__, address) else: return self.__class__.__name__ __str_ = __repr__ @property def address(self): """Socket address of this server. It is obtained from the first socket ``getsockname`` method. """ if self._server is not None: return self._server.sockets[0].getsockname()
[docs] async def start_serving(self, backlog=100, sslcontext=None): """Start serving. :param backlog: Number of maximum connections :param sslcontext: optional SSLContext object. :return: a :class:`.Future` called back when the server is serving the socket. """ assert not self._server if hasattr(self, '_params'): address = self._params['address'] sockets = self._params['sockets'] del self._params create_server = self._loop.create_server if sockets: server = None for sock in sockets: srv = await create_server(self.create_protocol, sock=sock, backlog=backlog, ssl=sslcontext) if server: server.sockets.extend(srv.sockets) else: server = srv else: if isinstance(address, tuple): server = await create_server(self.create_protocol, host=address[0], port=address[1], backlog=backlog, ssl=sslcontext) else: raise NotImplementedError self._server = server self._started = self._loop.time() for sock in server.sockets: address = sock.getsockname() self.logger.info('%s serving on %s', self._name, format_address(address)) self._loop.call_soon(self.fire_event, 'start')
[docs] async def close(self): """Stop serving the :attr:`.Server.sockets`. """ if self._server: self._server.close() self._server = None coro = self._close_connections() if coro: await coro self.fire_event('stop')
def info(self): sockets = [] up = int(self._loop.time() - self._started) if self._started else 0 server = {'uptime_in_seconds': up, 'sockets': sockets, 'max_requests': self._max_requests, 'keep_alive': self._keep_alive} clients = {'processed_clients': self._sessions, 'connected_clients': len(self._concurrent_connections), 'requests_processed': self._requests_processed} if self._server: for sock in self._server.sockets: sockets.append({ 'address': format_address(sock.getsockname())}) return {'server': server, 'clients': clients}
[docs] def create_protocol(self): """Override :meth:`Producer.create_protocol`. """ protocol = super().create_protocol(timeout=self._keep_alive) protocol.bind_event('connection_made', self._connection_made) protocol.bind_event('connection_lost', self._connection_lost) protocol.copy_many_times_events(self) if (self._server and self._max_requests and self._sessions >= self._max_requests): self.logger.info('Reached maximum number of connections %s. ' 'Stop serving.' % self._max_requests) self.close() return protocol
# INTERNALS def _connection_made(self, connection, exc=None): if not exc: self._concurrent_connections.add(connection) def _connection_lost(self, connection, exc=None): self._concurrent_connections.discard(connection) def _close_connections(self, connection=None, timeout=5): """Close ``connection`` if specified, otherwise close all connections. Return a list of :class:`.Future` called back once the connection/s are closed. """ all = [] if connection: all.append(connection.event('connection_lost')) connection.close() else: connections = list(self._concurrent_connections) self._concurrent_connections = set() for connection in connections: all.append(connection.event('connection_lost')) connection.close() if all: self.logger.info('%s closing %d connections', self, len(all)) return asyncio.wait(all, timeout=timeout, loop=self._loop)
[docs]class DatagramServer(Producer): """An :class:`.Producer` for serving UDP sockets. .. attribute:: _transports A list of :class:`.DatagramTransport`. Available once the :meth:`create_endpoint` method has returned. """ _transports = None _started = None ONE_TIME_EVENTS = ('start', 'stop') MANY_TIMES_EVENTS = ('pre_request', 'post_request') def __init__(self, protocol_factory, loop=None, address=None, name=None, sockets=None, max_requests=None, logger=None): super().__init__(loop, protocol_factory, name=name, max_requests=max_requests, logger=logger) self._params = {'address': address, 'sockets': sockets}
[docs] async def create_endpoint(self, **kw): """create the server endpoint. :return: a :class:`~asyncio.Future` called back when the server is serving the socket. """ if hasattr(self, '_params'): address = self._params['address'] sockets = self._params['sockets'] del self._params try: transports = [] loop = self._loop if sockets: for sock in sockets: transport, _ = await loop.create_datagram_endpoint( self.create_protocol, sock=sock) transports.append(transport) else: transport, _ = await loop.create_datagram_endpoint( self.create_protocol, local_addr=address) transports.append(transport) self._transports = transports self._started = loop.time() for transport in self._transports: address = transport.get_extra_info('sockname') self.logger.info('%s serving on %s', self._name, format_address(address)) self.fire_event('start') except Exception as exc: self.logger.exception('Error while starting UDP server') self.fire_event('start', exc=exc) self.fire_event('stop')
[docs] async def close(self): """Stop serving the :attr:`.Server.sockets` and close all concurrent connections. """ if not self.fired_event('stop'): transports, self._transports = self._transports, None if transports: for transport in transports: transport.close() self.fire_event('stop')
def info(self): sockets = [] up = int(self._loop.time() - self._started) if self._started else 0 server = {'uptime_in_seconds': up, 'sockets': sockets, 'max_requests': self._max_requests} clients = {'requests_processed': self._requests_processed} if self._transports: for transport in self._transports: sock = transport.get_extra_info('socket') if sock: sockets.append({ 'address': format_address(sock.getsockname()) }) return {'server': server, 'clients': clients}