Source code for pulsar.apps.ds.server

'''
Pulsar-ds is a Python implementation of the popular redis_
data store. It uses the pulsar asynchronous framework to create a
single-threaded worker which responds to TCP requests in the same way
as redis does.

To run a standalone server create a script with the following code::


    from pulsar.apps.data import PulsarDS

    if __name__ == '__main__':
        PulsarDS().start()


More information on the :ref:`pulsar data store example <tutorials-pulsards>`.

Check out these benchmarks_

.. _benchmarks: https://gist.github.com/lsbardel/8068579

Implementation
===========================

Pulsar Data Store Server
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: PulsarDS
   :members:
   :member-order: bysource


Storage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Storage
   :members:
   :member-order: bysource


.. _redis: http://redis.io/
'''
import os
import re
import time
import math
import pickle
from random import choice
from itertools import islice, chain
from functools import partial, reduce
from collections import namedtuple
from itertools import zip_longest

import pulsar
from pulsar.apps.socket import SocketServer
from pulsar.utils.config import Global
from pulsar.utils.structures import Dict, Zset, Deque

from .parser import redis_parser
from .utils import sort_command, count_bytes, and_op, or_op, xor_op, save_data
from .client import (command, PulsarStoreClient, Blocked,
                     COMMANDS_INFO, check_input, redis_to_py_pattern)


# Default ``host:port`` of a pulsar data store server. It is the ``bind``
# address in the ``PulsarDS`` configuration and the fallback address used by
# ``pulsards_url`` when no other address is available.
DEFAULT_PULSAR_STORE_ADDRESS = '127.0.0.1:6410'


def pulsards_url(address=None):
    '''Return the ``pulsar://`` url of a pulsar data store.

    :param address: optional address, with or without the ``pulsar://``
        scheme. When empty, the ``data_store`` setting of the current
        actor (if there is one) is used, and finally
        ``DEFAULT_PULSAR_STORE_ADDRESS``.
    :return: the address, prefixed by ``pulsar://`` when it did not
        already start with it.
    '''
    if not address:
        # No explicit address: look at the configuration of the actor
        # running this code, when there is one
        current = pulsar.get_actor()
        if current:
            address = current.cfg.data_store
        if not address:
            address = DEFAULT_PULSAR_STORE_ADDRESS
    if address.startswith('pulsar://'):
        return address
    return 'pulsar://%s' % address


# Keyspace changes notification classes
# Exclusive upper bound used by the string commands of the store: the bit
# offset in ``getbit``/``setbit`` and the end offset in ``setrange`` are
# rejected when they reach this value.
STRING_LIMIT = 2**32

# Floating point not-a-number constant.
nan = float('nan')


class RedisParserSetting(Global):
    # Boolean command line switch (``store_true``, off by default).
    # ``TcpServer.__init__`` reads it as ``cfg.redis_py_parser`` and passes
    # the value to ``redis_parser`` to select the parser class.
    name = "redis_py_parser"
    flags = ["--redis-py-parser"]
    action = "store_true"
    default = False
    # Help text shown to the user ("rather than": the original sentence was
    # missing the word "than").
    desc = '''\
    Use the python redis parser rather than the C implementation.

    Mainly used for benchmarking purposes.
    '''


class PulsarDsSetting(pulsar.Setting):
    # Common base class for the data store settings defined in this module.
    # NOTE(review): ``virtual`` presumably tells ``pulsar.Setting`` not to
    # register this base class as a setting of its own -- confirm.
    virtual = True
    # Same value as ``PulsarDS.name`` and the entry in its ``apps`` list
    app = 'pulsards'
    section = "Pulsar data store server"


def validate_list_of_pairs(val):
    '''Normalise ``val`` into a list of ``(int, int)`` tuples.

    :param val: a false value (returns an empty list) or a list/tuple
        whose elements are lists/tuples of exactly two items.
    :return: a new list with one ``(int, int)`` tuple per element.
    :raise TypeError: if ``val`` or one of its elements is not a
        list/tuple, or if an element does not contain exactly two items.

    An exception raised by ``int`` while converting an item is propagated.
    '''
    if not val:
        return []
    if not isinstance(val, (list, tuple)):
        raise TypeError("Not a list: %s" % val)
    pairs = []
    for item in val:
        if not isinstance(item, (list, tuple)):
            raise TypeError("Not a list: %s" % item)
        if len(item) != 2:
            # ``str`` is needed so that a tuple is not unpacked by ``%``
            raise TypeError("Not a pair: %s" % str(item))
        first, second = item
        pairs.append((int(first), int(second)))
    return pairs


# #############################################################################
# #    CONFIGURATION PARAMETERS
class KeyValueDatabases(PulsarDsSetting):
    # Integer setting read by ``Storage.__init__`` as
    # ``cfg.key_value_databases``: one ``Db`` is created for every index in
    # ``range(cfg.key_value_databases)``.
    name = "key_value_databases"
    flags = ["--key-value-databases"]
    type = int
    default = 16
    desc = 'Number of databases for the key value store.'


class KeyValuePassword(PulsarDsSetting):
    # String setting read by ``Storage.__init__`` as
    # ``cfg.key_value_password`` and stored utf-8 encoded. The default is an
    # empty string.
    name = "key_value_password"
    flags = ["--key-value-password"]
    default = ''
    desc = 'Optional password for the database.'


class KeyValueSave(PulsarDsSetting):
    # Persistence save points as a list of pairs; ``validate_list_of_pairs``
    # converts every pair into an ``(int, int)`` tuple.
    # NOTE(review): the help text below describes three save points as the
    # default behaviour, but that default is commented out and the effective
    # default is an empty list -- confirm which one is intended.
    name = "key_value_save"
    # default = [(900, 1), (300, 10), (60, 10000)]
    default = []
    validator = validate_list_of_pairs
    desc = '''\
        List of pairs controlling data store persistence.

        Will save the DB if both the given number of seconds and the given
        number of write operations against the DB occurred.

        The default behaviour will be to save:
        after 900 sec (15 min) if at least 1 key changed
        after 300 sec (5 min) if at least 10 keys changed
        after 60 sec if at least 10000 keys changed

        You can disable saving at all by setting an empty list
    '''


class KeyValueFileName(PulsarDsSetting):
    # File name setting read by ``Storage.__init__`` as
    # ``cfg.key_value_filename``.
    name = "key_value_filename"
    flags = ["--key-value-filename"]
    default = 'pulsards.rdb'
    desc = '''The filename where to dump the DB.'''


class TcpServer(pulsar.TcpServer):
    '''A :class:`pulsar.TcpServer` which owns the key-value :class:`Storage`.

    :param cfg: the configuration object; kept as ``cfg`` attribute and used
        to select the redis parser class and to build the storage.

    Remaining positional and keyword arguments are forwarded, unchanged, to
    the parent class constructor.
    '''
    def __init__(self, cfg, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cfg = cfg
        # The parser class must be available before the storage is created,
        # because ``Storage.__init__`` instantiates ``server._parser_class``
        self._parser_class = redis_parser(cfg.redis_py_parser)
        self._key_value_store = Storage(self, cfg)

    def info(self):
        '''Return the parent server information updated with the
        information provided by the key-value store.
        '''
        server_info = super().info()
        store_info = self._key_value_store._info()
        server_info.update(store_info)
        return server_info


class PulsarDS(SocketServer):
    '''A :class:`.SocketServer` serving a pulsar datastore.
    '''
    name = 'pulsards'
    cfg = pulsar.Config(bind=DEFAULT_PULSAR_STORE_ADDRESS,
                        keep_alive=0,
                        apps=['socket', 'pulsards'])

    def server_factory(self, *args, **kw):
        '''Build the :class:`TcpServer`, passing this application ``cfg``
        as first argument followed by the given arguments.
        '''
        return TcpServer(self.cfg, *args, **kw)

    def protocol_factory(self):
        '''Return a callable creating :class:`.PulsarStoreClient` objects
        with this application ``cfg`` as first argument.
        '''
        return partial(PulsarStoreClient, self.cfg)

    def monitor_start(self, monitor):
        '''Limit the ``workers`` setting to at most one before delegating
        to the parent class implementation.
        '''
        cfg = self.cfg
        # min, not max: a configured value above 1 is reduced to 1,
        # smaller values are left untouched
        workers = min(1, cfg.workers)
        cfg.set('workers', workers)
        return super().monitor_start(monitor)
# #############################################################################
# #    DATA STORE
# Record with two fields, ``re`` and ``clients``.
pubsub_patterns = namedtuple('pubsub_patterns', 're clients')
[docs]class Storage: '''Implement redis commands. ''' def __init__(self, server, cfg): self.cfg = cfg self._password = cfg.key_value_password.encode('utf-8') self._filename = cfg.key_value_filename self._writer = None self._server = server self._loop = server._loop self._parser = server._parser_class() self._missed_keys = 0 self._hit_keys = 0 self._expired_keys = 0 self._dirty = 0 self._bpop_blocked_clients = 0 self._last_save = int(time.time()) self._channels = {} self._patterns = {} # The set of clients which are watching keys self._watching = set() # The set of clients which issued the monitor command self._monitors = set() self.logger = server.logger # self.NOTIFY_KEYSPACE = (1 << 0) self.NOTIFY_KEYEVENT = (1 << 1) self.NOTIFY_GENERIC = (1 << 2) self.NOTIFY_STRING = (1 << 3) self.NOTIFY_LIST = (1 << 4) self.NOTIFY_SET = (1 << 5) self.NOTIFY_HASH = (1 << 6) self.NOTIFY_ZSET = (1 << 7) self.NOTIFY_EXPIRED = (1 << 8) self.NOTIFY_EVICTED = (1 << 9) self.NOTIFY_ALL = (self.NOTIFY_GENERIC | self.NOTIFY_STRING | self.NOTIFY_LIST | self.NOTIFY_SET | self.NOTIFY_HASH | self.NOTIFY_ZSET | self.NOTIFY_EXPIRED | self.NOTIFY_EVICTED) self.MONITOR = (1 << 2) self.MULTI = (1 << 3) self.BLOCKED = (1 << 4) self.DIRTY_CAS = (1 << 5) # self._event_handlers = {self.NOTIFY_GENERIC: self._generic_event, self.NOTIFY_STRING: self._string_event, self.NOTIFY_SET: self._set_event, self.NOTIFY_HASH: self._hash_event, self.NOTIFY_LIST: self._list_event, self.NOTIFY_ZSET: self._zset_event} self._set_options = (b'ex', b'px', b'nx', b'xx') self.OK = b'+OK\r\n' self.QUEUED = b'+QUEUED\r\n' self.ZERO = b':0\r\n' self.ONE = b':1\r\n' self.NIL = b'$-1\r\n' self.NULL_ARRAY = b'*-1\r\n' self.INVALID_TIMEOUT = 'invalid expire time' self.PUBSUB_ONLY = ('only (P)SUBSCRIBE / (P)UNSUBSCRIBE / QUIT ' 'allowed in this context') self.INVALID_SCORE = 'Invalid score value' self.NOT_SUPPORTED = 'Command not yet supported' self.OUT_OF_BOUND = 'Out of bound' self.SYNTAX_ERROR = 'Syntax error' 
self.SUBSCRIBE_COMMANDS = ('psubscribe', 'punsubscribe', 'subscribe', 'unsubscribe', 'quit') self.encoder = pickle self.hash_type = Dict self.list_type = Deque self.zset_type = Zset self.data_types = (bytearray, set, self.hash_type, self.list_type, self.zset_type) self.zset_aggregate = {b'min': min, b'max': max, b'sum': sum} self._type_event_map = {bytearray: self.NOTIFY_STRING, self.hash_type: self.NOTIFY_HASH, self.list_type: self.NOTIFY_LIST, set: self.NOTIFY_SET, self.zset_type: self.NOTIFY_ZSET} self._type_name_map = {bytearray: 'string', self.hash_type: 'hash', self.list_type: 'list', set: 'set', self.zset_type: 'zset'} self.databases = dict(((num, Db(num, self)) for num in range(cfg.key_value_databases))) # Initialise lua self.lua = None self.version = '2.4.10' self._loaddb() self._cron() # ######################################################################### # # KEYS COMMANDS @command('Keys', True, name='del') def delete(self, client, request, N): check_input(request, not N) rem = client.db.rem result = reduce(lambda x, y: x + rem(y), request[1:], 0) client.reply_int(result) @command('Keys') def dump(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: client.reply_bulk() else: client.reply_bulk(self.encoder.dumps(value)) @command('Keys') def exists(self, client, request, N): check_input(request, N != 1) if client.db.exists(request[1]): client.reply_one() else: client.reply_zero() @command('Keys', True) def expire(self, client, request, N, m=1): check_input(request, N != 2) try: timeout = int(request[2]) except ValueError: client.reply_error(self.INVALID_TIMEOUT) else: if timeout: if timeout < 0: return client.reply_error(self.INVALID_TIMEOUT) if client.db.expire(request[1], m*timeout): return client.reply_one() client.reply_zero() @command('Keys', True) def expireat(self, client, request, N, M=1): check_input(request, N != 2) try: timeout = int(request[2]) except ValueError: 
client.reply_error(self.INVALID_TIMEOUT) else: if timeout: if timeout < 0: return client.reply_error(self.INVALID_TIMEOUT) timeout = M*timeout - time.time() if client.db.expire(request[1], timeout): return client.reply_one() client.reply_zero() @command('Keys') def keys(self, client, request, N): err = 'ignore' check_input(request, N != 1) pattern = request[1].decode('utf-8', err) allkeys = pattern == '*' gr = None if not allkeys: gr = re.compile(redis_to_py_pattern(pattern)) result = [key for key in client.db if allkeys or gr.search(key.decode('utf-8', err))] client.reply_multi_bulk(result) @command('Keys', supported=False) def migrate(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Keys', True) def move(self, client, request, N): check_input(request, N != 2) key = request[1] try: db2 = self.databases.get(int(request[2])) if db2 is None: raise ValueError except Exception: return client.reply_error('index out of range') db = client.db value = db.get(key) if db2.exists(key) or value is None: return client.reply_zero() assert value db.pop(key) self._signal(self.NOTIFY_GENERIC, db, 'del', key, 1) db2._data[key] = value self._signal(self._type_event_map[type(value)], db2, 'set', key, 1) client.reply_one() @command('Keys', supported=False) def object(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Keys', True) def persist(self, client, request, N): check_input(request, N != 1) if client.db.persist(request[1]): client.reply_one() else: client.reply_zero() @command('Keys', True) def pexpire(self, client, request, N): self.expire(client, request, N, 0.001) @command('Keys', True) def pexpireat(self, client, request, N, M=1): self.expireat(client, request, N, 0.001) @command('Keys') def pttl(self, client, request, N): check_input(request, N != 1) client.reply_int(client.db.ttl(request[1], 1000)) @command('Keys') def randomkey(self, client, request, N): check_input(request, N) keys = list(client.db) if keys: 
client.reply_bulk(choice(keys)) else: client.reply_bulk() @command('Keys', True) def rename(self, client, request, N, ex=False): check_input(request, N != 2) key1, key2 = request[1], request[2] db = client.db value = db.get(key1) if value is None: client.reply_error('Cannot rename key, not available') elif key1 == key2: client.reply_error('Cannot rename key') else: assert value if ex: if db.exists(key2): return client.reply_zero() result = 1 else: result = 0 if db.pop(key2) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key2) db.pop(key1) event = self._type_event_map[type(value)] dirty = 1 if event == self.NOTIFY_STRING else len(value) db._data[key2] = value self._signal(event, db, request[0], key2, dirty) client.reply_one() if result else client.reply_ok() @command('Keys', True) def renamenx(self, client, request, N): self.rename(client, request, N, True) @command('Keys', True) def restore(self, client, request, N): check_input(request, N != 3) key = request[1] db = client.db try: value = self.encoder.loads(request[3]) except Exception: value = None if not isinstance(value, self.data_types): return client.reply_error('Could not decode value') try: ttl = int(request[2]) except Exception: return client.reply_error(self.INVALID_TIMEOUT) if db.pop(key) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key) db._data[key] = value if ttl > 0: db.expire(key, ttl) client.reply_ok() @command('Keys', True) def sort(self, client, request, N): check_input(request, not N) value = client.db.get(request[1]) if value is None: value = self.list_type() elif not isinstance(value, (set, self.list_type, self.zset_type)): return client.reply_wrongtype() sort_command(self, client, request, value) @command('Keys', True) def ttl(self, client, request, N): check_input(request, N != 1) client.reply_int(client.db.ttl(request[1])) @command('Keys', True) def type(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: result 
= 'none' else: result = self._type_name_map[type(value)] client.reply_status(result) @command('Keys', supported=False) def scan(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) # ######################################################################### # # STRING COMMANDS @command('Strings', True) def append(self, client, request, N): check_input(request, N != 2,) key = request[1] db = client.db value = db.get(key) if value is None: value = bytearray(request[2]) db._data[key] = value elif not isinstance(value, bytearray): return client.reply_wrongtype() else: assert value value.extend(request[2]) self._signal(self.NOTIFY_STRING, db, request[0], key, 1) client.reply_int(len(value)) @command('Strings') def bitcount(self, client, request, N): check_input(request, N < 1 or N > 3) key = request[1] db = client.db value = db.get(key) if value is None: client.reply_int(0) elif not isinstance(value, bytearray): return client.reply_wrongtype() else: assert value if N > 1: start = request[2] end = request[3] if N == 3 else -1 start, end = self._range_values(value, start, end) value = value[start:end] client.reply_int(count_bytes(value)) @command('Strings', True) def bitop(self, client, request, N): check_input(request, N < 3) db = client.db op = request[1].lower() if op == b'and': reduce_op = and_op elif op == b'or': reduce_op = or_op elif op == b'xor': reduce_op = xor_op elif op == b'not': reduce_op = None check_input(request, N != 3) else: return client.reply_error('bad command') empty = bytearray() keys = [] for key in request[3:]: value = db.get(key) if value is None: keys.append(empty) elif isinstance(value, bytearray): keys.append(value) else: return client.reply_wrongtype() result = bytearray() if reduce_op is None: for value in keys[0]: result.append(~value & 255) else: for values in zip_longest(*keys, **{'fillvalue': 0}): result.append(reduce(reduce_op, values)) if result: dest = request[2] if db.pop(dest): self._signal(self.NOTIFY_GENERIC, db, 
'del', dest) db._data[dest] = result self._signal(self.NOTIFY_STRING, db, 'set', dest, 1) client.reply_int(len(result)) else: client.reply_zero() @command('Strings', True) def decr(self, client, request, N): check_input(request, N != 1) r = self._incrby(client, request[0], request[1], b'-1', int) client.reply_int(r) @command('Strings', True) def decrby(self, client, request, N): check_input(request, N != 2) try: val = str(-int(request[2])).encode('utf-8') except Exception: val = request[2] r = self._incrby(client, request[0], request[1], val, int) client.reply_int(r) @command('Strings') def get(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: client.reply_bulk() elif isinstance(value, bytearray): assert value client.reply_bulk(bytes(value)) else: client.reply_wrongtype() @command('Strings') def getbit(self, client, request, N): check_input(request, N != 2) try: bitoffset = int(request[2]) if bitoffset < 0 or bitoffset >= STRING_LIMIT: raise ValueError except Exception: return client.reply_error( "bit offset is not an integer or out of range") string = client.db.get(request[1]) if string is None: client.reply_zero() elif not isinstance(string, bytearray): client.reply_wrongtype() else: assert string byte = bitoffset >> 3 if len(string) > byte: bit = 7 - (bitoffset & 7) v = string[byte] & (1 << bit) client.reply_int(1 if v else 0) else: client.reply_zero() @command('Strings') def getrange(self, client, request, N): check_input(request, N != 3) try: start = int(request[2]) end = int(request[3]) except Exception: return client.reply_error("Wrong offset in '%s' command" % request[0]) string = client.db.get(request[1]) if string is None: client.reply_bulk(b'') elif not isinstance(string, bytearray): client.reply_wrongtype() else: if start < 0: start = len(string) + start if end < 0: end = len(string) + end + 1 else: end += 1 client.reply_bulk(bytes(string[start:end])) @command('Strings', True) def getset(self, 
client, request, N): check_input(request, N != 2) key = request[1] db = client.db value = db.get(key) if value is None: db._data[key] = bytearray(request[2]) self._signal(self.NOTIFY_STRING, db, 'set', key, 1) client.reply_bulk() elif isinstance(value, bytearray): db.pop(key) db._data[key] = bytearray(request[2]) self._signal(self.NOTIFY_STRING, db, 'set', key, 1) client.reply_bulk(bytes(value)) else: client.reply_wrongtype() @command('Strings', True) def incr(self, client, request, N): check_input(request, N != 1) r = self._incrby(client, request[0], request[1], b'1', int) client.reply_int(r) @command('Strings', True) def incrby(self, client, request, N): check_input(request, N != 2) r = self._incrby(client, request[0], request[1], request[2], int) client.reply_int(r) @command('Strings', True) def incrbyfloat(self, client, request, N): check_input(request, N != 2) r = self._incrby(client, request[0], request[1], request[2], float) client.reply_bulk(str(r).encode('utf-8')) @command('Strings') def mget(self, client, request, N): check_input(request, not N) get = client.db.get values = [] for key in request[1:]: value = get(key) if value is None: values.append(value) elif isinstance(value, bytearray): values.append(bytes(value)) else: return client.reply_wrongtype() client.reply_multi_bulk(values) @command('Strings', True) def mset(self, client, request, N): D = N // 2 check_input(request, N < 2 or D * 2 != N) db = client.db for key, value in zip(request[1::2], request[2::2]): db.pop(key) db._data[key] = bytearray(value) self._signal(self.NOTIFY_STRING, db, 'set', key, 1) client.reply_ok() @command('Strings', True) def msetnx(self, client, request, N): D = N // 2 check_input(request, N < 2 or D * 2 != N) db = client.db keys = request[1::2] exist = False for key in keys: exist = db.exists(key) if exist: break if exist: client.reply_zero() else: for key, value in zip(keys, request[2::2]): db._data[key] = bytearray(value) self._signal(self.NOTIFY_STRING, db, 'set', key, 
1) client.reply_one() @command('Strings', True) def psetex(self, client, request, N): check_input(request, N != 3) self._set(client, request[1], request[3], milliseconds=request[2]) client.reply_ok() @command('Strings', True) def set(self, client, request, N): check_input(request, N < 2 or N > 8) it = 2 extra = set(self._set_options) seconds = 0 milliseconds = 0 nx = False xx = False while N > it: it += 1 opt = request[it].lower() if opt in extra: extra.remove(opt) if opt == b'ex': it += 1 seconds = request[it] elif opt == b'px': it += 1 milliseconds = request[it] elif opt == b'nx': nx = True else: xx = True if self._set(client, request[1], request[2], seconds, milliseconds, nx, xx): client.reply_ok() else: client.reply_bulk() @command('Strings', True) def setbit(self, client, request, N): check_input(request, N != 3) key = request[1] try: bitoffset = int(request[2]) if bitoffset < 0 or bitoffset >= STRING_LIMIT: raise ValueError except Exception: return client.reply_error( "bit offset is not an integer or out of range") try: value = int(request[3]) if value not in (0, 1): raise ValueError except Exception: return client.reply_error("bit is not an integer or out of range") db = client.db string = db.get(key) if string is None: string = bytearray() db._data[key] = string elif not isinstance(string, bytearray): return client.reply_wrongtype() else: assert string # grow value to the right if necessary byte = bitoffset >> 3 num_bytes = len(string) if byte >= num_bytes: string.extend((byte + 1 - num_bytes)*b'\x00') # get current value byteval = string[byte] bit = 7 - (bitoffset & 7) bitval = byteval & (1 << bit) # update with new value byteval &= ~(1 << bit) byteval |= ((value & 1) << bit) string[byte] = byteval self._signal(self.NOTIFY_STRING, db, request[0], key, 1) client.reply_one() if bitval else client.reply_zero() @command('Strings', True) def setex(self, client, request, N): check_input(request, N != 3) self._set(client, request[1], request[3], 
seconds=request[2]) client.reply_ok() @command('Strings', True) def setnx(self, client, request, N): check_input(request, N != 2) if self._set(client, request[1], request[2], nx=True): client.reply_one() else: client.reply_zero() @command('Strings', True) def setrange(self, client, request, N): check_input(request, N != 3) key = request[1] value = request[3] try: offset = int(request[2]) T = offset + len(value) if offset < 0 or T >= STRING_LIMIT: raise ValueError except Exception: return client.reply_error("Wrong offset in '%s' command" % request[0]) db = client.db string = db.get(key) if string is None: string = bytearray(b'') db._data[key] = string elif not isinstance(string, bytearray): return client.reply_wrongtype() N = len(string) if N < T: string.extend((T - N)*b'\x00') string[offset:T] = value self._signal(self.NOTIFY_STRING, db, request[0], key, 1) client.reply_int(len(string)) @command('Strings') def strlen(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: client.reply_zero() elif isinstance(value, bytearray): client.reply_int(len(value)) else: return client.reply_wrongtype() # ######################################################################### # # HASHES COMMANDS @command('Hashes', True) def hdel(self, client, request, N): check_input(request, N < 2) key = request[1] db = client.db value = db.get(key) if value is None: client.reply_zero() elif isinstance(value, self.hash_type): rem = 0 for field in request[2:]: rem += 0 if value.pop(field, None) is None else 1 self._signal(self.NOTIFY_HASH, db, request[0], key, rem) if db.pop(key, value) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key) client.reply_int(rem) else: client.reply_wrongtype() @command('Hashes') def hexists(self, client, request, N): check_input(request, N != 2) value = client.db.get(request[1]) if value is None: client.reply_zero() elif isinstance(value, self.hash_type): client.reply_int(int(request[2] in value)) 
else: client.reply_wrongtype() @command('Hashes') def hget(self, client, request, N): check_input(request, N != 2) value = client.db.get(request[1]) if value is None: client.reply_bulk() elif isinstance(value, self.hash_type): client.reply_bulk(value.get(request[2])) else: client.reply_wrongtype() @command('Hashes') def hgetall(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: client.reply_multi_bulk(()) elif isinstance(value, self.hash_type): client.reply_multi_bulk(value.flat()) else: client.reply_wrongtype() @command('Hashes', True) def hincrby(self, client, request, N): result = self._hincrby(client, request, N, int) if result is not None: client.reply_int(result) @command('Hashes', True) def hincrbyfloat(self, client, request, N): result = self._hincrby(client, request, N, float) if result is not None: client.reply_bulk(str(result).encode('utf-8')) @command('Hashes') def hkeys(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: client.reply_multi_bulk(()) elif isinstance(value, self.hash_type): client.reply_multi_bulk(value) else: client.reply_wrongtype() @command('Hashes') def hlen(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: client.reply_zero() elif isinstance(value, self.hash_type): client.reply_int(len(value)) else: client.reply_wrongtype() @command('Hashes') def hmget(self, client, request, N): check_input(request, N < 3) value = client.db.get(request[1]) if value is None: client.reply_multi_bulk(()) elif isinstance(value, self.hash_type): result = value.mget(request[2:]) client.reply_multi_bulk(result) else: client.reply_wrongtype() @command('Hashes', True) def hmset(self, client, request, N): D = (N - 1) // 2 check_input(request, N < 3 or D * 2 != N - 1) key = request[1] db = client.db value = db.get(key) if value is None: value = self.hash_type() db._data[key] = value elif not 
isinstance(value, self.hash_type): return client.reply_wrongtype() it = iter(request[2:]) value.update(zip(it, it)) self._signal(self.NOTIFY_HASH, db, request[0], key, D) client.reply_ok() @command('Hashes', True) def hset(self, client, request, N): check_input(request, N != 3) key, field = request[1], request[2] db = client.db value = db.get(key) if value is None: value = self.hash_type() db._data[key] = value elif not isinstance(value, self.hash_type): return client.reply_wrongtype() avail = (field in value) value[field] = request[3] self._signal(self.NOTIFY_HASH, db, request[0], key, 1) client.reply_zero() if avail else client.reply_one() @command('Hashes', True) def hsetnx(self, client, request, N): check_input(request, N != 3) key, field = request[1], request[2] db = client.db value = db.get(key) if value is None: value = self.hash_type() db._data[key] = value elif not isinstance(value, self.hash_type): return client.reply_wrongtype() if field in value: client.reply_zero() else: value[field] = request[3] self._signal(self.NOTIFY_HASH, db, request[0], key, 1) client.reply_one() @command('Hashes') def hvals(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: client.reply_multi_bulk(()) elif isinstance(value, self.hash_type): client.reply_multi_bulk(tuple(value.values())) else: client.reply_wrongtype() @command('Hashes', supported=False) def hscan(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) # ######################################################################### # # LIST COMMANDS @command('Lists', True, script=0) def blpop(self, client, request, N): check_input(request, N < 2) try: timeout = max(0, int(request[-1])) except Exception: return client.reply_error(self.SYNTAX_ERROR) keys = request[1:-1] if not self._bpop(client, request, keys): client.blocked = Blocked(client, request[0], keys, timeout) @command('Lists', True, script=0) def brpop(self, client, request, N): return 
self.blpop(client, request, N) @command('Lists', True, script=0) def brpoplpush(self, client, request, N): check_input(request, N != 3) try: timeout = max(0, int(request[-1])) except Exception: return client.reply_error(self.SYNTAX_ERROR) key, dest = request[1:-1] keys = (key,) if not self._bpop(client, request, keys, dest): client.blocked = Blocked(client, request[0], keys, timeout, dest) @command('Lists') def lindex(self, client, request, N): check_input(request, N != 2) value = client.db.get(request[1]) if value is None: client.reply_bulk() elif isinstance(value, self.list_type): assert value index = int(request[2]) if index >= 0 and index < len(value): client.reply_bulk(value[index]) else: client.reply_bulk() else: client.reply_wrongtype() @command('Lists', True) def linsert(self, client, request, N): # This method is INEFFICIENT, but redis supported so we do # the same here check_input(request, N != 4) db = client.db key = request[1] value = db.get(key) if value is None: client.reply_zero() elif not isinstance(value, self.list_type): client.reply_wrongtype() else: assert value where = request[2].lower() l1 = len(value) if where == b'before': value.insert_before(request[3], request[4]) elif where == b'after': value.insert_after(request[3], request[4]) else: return client.reply_error('cannot insert to list') l2 = len(value) if l2 - l1: self._signal(self.NOTIFY_LIST, db, request[0], key, 1) client.reply_int(l2) else: client.reply_int(-1) @command('Lists') def llen(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: client.reply_zero() elif isinstance(value, self.list_type): assert value client.reply_int(len(value)) else: client.reply_wrongtype() @command('Lists', True) def lpop(self, client, request, N): check_input(request, N != 1) db = client.db key = request[1] value = db.get(key) if value is None: client.reply_bulk() elif not isinstance(value, self.list_type): client.reply_wrongtype() else: assert value 
if request[0] == 'lpop': result = value.popleft() else: result = value.pop() self._signal(self.NOTIFY_LIST, db, request[0], key, 1) if db.pop(key, value) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key) client.reply_bulk(result) @command('Lists', True) def rpop(self, client, request, N): return self.lpop(client, request, N) @command('Lists', True) def lpush(self, client, request, N): check_input(request, N < 2) key = request[1] db = client.db value = db.get(key) if value is None: value = self.list_type() db._data[key] = value elif not isinstance(value, self.list_type): return client.reply_wrongtype() else: assert value if request[0] == 'lpush': value.extendleft(request[2:]) else: value.extend(request[2:]) client.reply_int(len(value)) self._signal(self.NOTIFY_LIST, db, request[0], key, N - 1) @command('Lists', True) def rpush(self, client, request, N): return self.lpush(client, request, N) @command('Lists', True) def lpushx(self, client, request, N): check_input(request, N != 2) key = request[1] db = client.db value = db.get(key) if value is None: client.reply_zero() elif not isinstance(value, self.list_type): client.reply_wrongtype() else: assert value if request[0] == 'lpushx': value.appendleft(request[2]) else: value.append(request[2]) client.reply_int(len(value)) self._signal(self.NOTIFY_LIST, db, request[0], key, 1) @command('Lists', True) def rpushx(self, client, request, N): return self.lpushx(client, request, N) @command('Lists', True) def lrange(self, client, request, N): check_input(request, N != 3) db = client.db key = request[1] value = db.get(key) try: start, end = self._range_values(value, request[2], request[3]) except Exception: return client.reply_error('invalid range') if value is None: client.reply_multi_bulk(()) elif not isinstance(value, self.list_type): client.reply_wrongtype() else: assert value client.reply_multi_bulk(tuple(islice(value, start, end))) @command('Lists', True) def lrem(self, client, request, N): # This method is 
INEFFICIENT, but redis supported so we do # the same here check_input(request, N != 3) db = client.db key = request[1] value = db.get(key) if value is None: client.reply_zero() elif not isinstance(value, self.list_type): client.reply_wrongtype() else: assert value try: count = int(request[2]) except Exception: return client.reply_error('cannot remove from list') removed = value.remove(request[3], count) if removed: self._signal(self.NOTIFY_LIST, db, request[0], key, removed) client.reply_int(removed) if db.pop(key, value) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key) @command('Lists', True) def lset(self, client, request, N): check_input(request, N != 3) db = client.db key = request[1] value = db.get(key) if value is None: client.reply_error(self.OUT_OF_BOUND) elif not isinstance(value, self.list_type): client.reply_wrongtype() else: assert value try: index = int(request[2]) except Exception: index = -1 if index >= 0 and index < len(value): value[index] = request[3] self._signal(self.NOTIFY_LIST, db, request[0], key, 1) client.reply_ok() else: client.reply_error(self.OUT_OF_BOUND) @command('Lists', True) def ltrim(self, client, request, N): check_input(request, N != 3) db = client.db key = request[1] value = db.get(key) try: start, end = self._range_values(value, request[2], request[3]) except Exception: return client.reply_error('invalid range') if value is None: client.reply_ok() elif not isinstance(value, self.list_type): client.reply_wrongtype() else: assert value value.trim(start, end) self._signal(self.NOTIFY_LIST, db, request[0], key, start-len(value)) client.reply_ok() if db.pop(key, value) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key) @command('Lists', True) def rpoplpush(self, client, request, N): check_input(request, N != 2) key1, key2 = request[1], request[2] db = client.db orig = db.get(key1) dest = db.get(key2) if orig is None: client.reply_bulk() elif not isinstance(orig, self.list_type): client.reply_wrongtype() 
else: assert orig if dest is None: dest = self.list_type() db._data[key2] = dest elif not isinstance(dest, self.list_type): return client.reply_wrongtype() else: assert dest value = orig.pop() self._signal(self.NOTIFY_LIST, db, 'rpop', key1, 1) dest.appendleft(value) self._signal(self.NOTIFY_LIST, db, 'lpush', key2, 1) if db.pop(key1, orig) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key1) client.reply_bulk(value) # ######################################################################### # # SETS COMMANDS @command('Sets', True) def sadd(self, client, request, N): check_input(request, N < 2) key = request[1] db = client.db value = db.get(key) if value is None: value = set() db._data[key] = value elif not isinstance(value, set): return client.reply_wrongtype() n = len(value) value.update(request[2:]) n = len(value) - n self._signal(self.NOTIFY_SET, db, request[0], key, n) client.reply_int(n) @command('Sets') def scard(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: client.reply_zero() elif not isinstance(value, set): client.reply_wrongtype() else: client.reply_int(len(value)) @command('Sets') def sdiff(self, client, request, N): check_input(request, N < 1) self._setoper(client, 'difference', request[1:]) @command('Sets', True) def sdiffstore(self, client, request, N): check_input(request, N < 2) self._setoper(client, 'difference', request[2:], request[1]) @command('Sets') def sinter(self, client, request, N): check_input(request, N < 1) self._setoper(client, 'intersection', request[1:]) @command('Sets', True) def sinterstore(self, client, request, N): check_input(request, N < 2) self._setoper(client, 'intersection', request[2:], request[1]) @command('Sets') def sismember(self, client, request, N): check_input(request, N != 2) value = client.db.get(request[1]) if value is None: client.reply_zero() elif not isinstance(value, set): client.reply_wrongtype() else: client.reply_int(int(request[2] in 
value)) @command('Sets') def smembers(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: client.reply_multi_bulk(()) elif not isinstance(value, set): client.reply_wrongtype() else: client.reply_multi_bulk(value) @command('Sets', True) def smove(self, client, request, N): check_input(request, N != 3) db = client.db key1 = request[1] key2 = request[2] orig = db.get(key1) dest = db.get(key2) if orig is None: client.reply_zero() elif not isinstance(orig, set): client.reply_wrongtype() else: member = request[3] if member in orig: # we my be able to move if dest is None: dest = set() db._data[request[2]] = dest elif not isinstance(dest, set): return client.reply_wrongtype() orig.remove(member) dest.add(member) self._signal(self.NOTIFY_SET, db, 'srem', key1) self._signal(self.NOTIFY_SET, db, 'sadd', key2, 1) if db.pop(key1, orig) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key1) client.reply_one() else: client.reply_zero() @command('Sets', True) def spop(self, client, request, N): check_input(request, N != 1) key = request[1] db = client.db value = db.get(key) if value is None: client.reply_bulk() elif not isinstance(value, set): client.reply_wrongtype() else: result = value.pop() self._signal(self.NOTIFY_SET, db, request[0], key, 1) if db.pop(key, value) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key) client.reply_bulk(result) @command('Sets') def srandmember(self, client, request, N): check_input(request, N < 1 or N > 2) value = client.db.get(request[1]) if value is not None and not isinstance(value, set): return client.reply_wrongtype() if N == 2: try: count = int(request[2]) except Exception: return client.reply_error('Invalid count') if count < 0: count = -count if not value: result = (None,) * count else: result = [] for _ in range(count): el = value.pop() result.append(el) value.add(el) elif count > 0: if not value: result = (None,) elif len(value) <= count: result = list(value) 
result.extend((None,)*(count-len(value))) else: result = [] for _ in range(count): el = value.pop() result.append(el) value.update(result) else: result = [] client.reply_multi_bulk(result) else: if not value: result = None else: result = value.pop() value.add(result) client.reply_bulk(result) @command('Sets', True) def srem(self, client, request, N): check_input(request, N < 2) db = client.db key = request[1] value = db.get(key) if value is None: client.reply_zero() elif not isinstance(value, set): client.reply_wrongtype() else: start = len(value) value.difference_update(request[2:]) removed = start - len(value) self._signal(self.NOTIFY_SET, db, request[0], key, removed) if db.pop(key, value) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key) client.reply_int(removed) @command('Sets') def sunion(self, client, request, N): check_input(request, N < 1) self._setoper(client, 'union', request[1:]) @command('Sets', True) def sunionstore(self, client, request, N): check_input(request, N < 2) self._setoper(client, 'union', request[2:], request[1]) @command('Sets', supported=False) def sscan(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) # ######################################################################### # # SORTED SETS COMMANDS @command('Sorted Sets', True) def zadd(self, client, request, N): D = (N - 1) // 2 check_input(request, N < 3 or D * 2 != N - 1) key = request[1] db = client.db value = db.get(key) if value is None: value = self.zset_type() db._data[key] = value elif not isinstance(value, self.zset_type): return client.reply_wrongtype() start = len(value) value.update(zip(map(float, request[2::2]), request[3::2])) result = len(value) - start self._signal(self.NOTIFY_ZSET, db, request[0], key, result) client.reply_int(result) @command('Sorted Sets') def zcard(self, client, request, N): check_input(request, N != 1) value = client.db.get(request[1]) if value is None: client.reply_zero() elif not isinstance(value, self.zset_type): 
client.reply_wrongtype() else: client.reply_int(len(value)) @command('Sorted Sets') def zcount(self, client, request, N): check_input(request, N != 3) value = client.db.get(request[1]) if value is None: client.reply_zero() elif not isinstance(value, self.zset_type): client.reply_wrongtype() else: min_value, max_value = request[2], request[3] include_min = include_max = True if min_value and min_value[0] == 40: include_min = False min_value = min_value[1:] if max_value and max_value[0] == 40: include_max = False max_value = max_value[1:] try: mmin = float(min_value) mmax = float(max_value) except Exception: client.reply_error(self.INVALID_SCORE) else: client.reply_int(value.count(mmin, mmax, include_min, include_max)) @command('Sorted Sets', True) def zincrby(self, client, request, N): check_input(request, N != 3) key = request[1] db = client.db value = db.get(key) if value is None: db._data[key] = value = self.zset_type() elif not isinstance(value, self.zset_type): return client.reply_wrongtype() try: increment = float(request[2]) except Exception: client.reply_error(self.INVALID_SCORE) else: member = request[3] score = value.score(member, 0) + increment value.add(score, member) self._signal(self.NOTIFY_ZSET, db, request[0], key, 1) client.reply_bulk(str(score).encode('utf-8')) @command('Sorted Sets', True) def zinterstore(self, client, request, N): self._zsetoper(client, request, N) @command('Sorted Sets') def zrange(self, client, request, N): check_input(request, N < 3 or N > 4) value = client.db.get(request[1]) if value is None: client.reply_multi_bulk(()) elif not isinstance(value, self.zset_type): client.reply_wrongtype() else: try: start, end = self._range_values(value, request[2], request[3]) except Exception: return client.reply_error(self.SYNTAX_ERROR) # reverse = (request[0] == b'zrevrange') if N == 4: if request[4].lower() == b'withscores': result = [] [result.extend((v, score)) for score, v in value.range(start, end, scores=True)] else: return 
client.reply_error(self.SYNTAX_ERROR) else: result = list(value.range(start, end)) client.reply_multi_bulk(result) @command('Sorted Sets') def zrangebyscore(self, client, request, N): check_input(request, N < 3 or N > 7) value = client.db.get(request[1]) if value is None: client.reply_multi_bulk(()) elif not isinstance(value, self.zset_type): client.reply_wrongtype() else: try: minval, include_min, maxval, include_max = self._score_values( request[2], request[3]) except Exception: return client.reply_error(self.SYNTAX_ERROR) request = request[4:] withscores = False offset = 0 count = None while request: if request[0].lower() == b'withscores': withscores = True request = request[1:] elif request[0].lower() == b'limit': try: offset = int(request[1]) count = int(request[2]) except Exception: return client.reply_error(self.SYNTAX_ERROR) request = request[3:] else: return client.reply_error(self.SYNTAX_ERROR) if withscores: result = [] [result.extend((v, score)) for score, v in value.range_by_score(minval, maxval, scores=True, start=offset, num=count, include_min=include_min, include_max=include_max)] else: result = list(value.range_by_score(minval, maxval, start=offset, num=count, include_min=include_min, include_max=include_max)) client.reply_multi_bulk(result) @command('Sorted Sets') def zrank(self, client, request, N): check_input(request, N != 2) value = client.db.get(request[1]) if value is None: client.reply_bulk() elif not isinstance(value, self.zset_type): client.reply_wrongtype() else: rank = value.rank(request[2]) if rank is not None: client.reply_int(rank) else: client.reply_bulk() @command('Sorted Sets', True) def zrem(self, client, request, N): check_input(request, N < 2) key = request[1] db = client.db value = db.get(key) if value is None: client.reply_zero() elif not isinstance(value, self.zset_type): client.reply_wrongtype() else: removed = value.remove_items(request[2:]) if removed: self._signal(self.NOTIFY_ZSET, db, request[0], key, removed) if 
db.pop(key, value) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key) client.reply_int(removed) @command('Sorted Sets', True) def zremrangebyrank(self, client, request, N): check_input(request, N != 3) key = request[1] db = client.db value = db.get(key) if value is None: client.reply_zero() elif not isinstance(value, self.zset_type): client.reply_wrongtype() else: try: start, end = self._range_values(value, request[2], request[3]) except Exception: return client.reply_error(self.SYNTAX_ERROR) removed = value.remove_range(start, end) if removed: self._signal(self.NOTIFY_ZSET, db, request[0], key, removed) if db.pop(key, value) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key) client.reply_int(removed) @command('Sorted Sets', True) def zremrangebyscore(self, client, request, N): check_input(request, N != 3) key = request[1] db = client.db value = db.get(key) if value is None: client.reply_zero() elif not isinstance(value, self.zset_type): client.reply_wrongtype() else: try: minval, include_min, maxval, include_max = self._score_values( request[2], request[3]) except Exception: return client.reply_error(self.SYNTAX_ERROR) removed = value.remove_range_by_score(minval, maxval, include_min, include_max) if removed: self._signal(self.NOTIFY_ZSET, db, request[0], key, removed) if db.pop(key, value) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', key) client.reply_int(removed) @command('Sorted Sets', supported=False) def zrevrange(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Sorted Sets', supported=False) def zrevrangebyscore(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Sorted Sets') def zscore(self, client, request, N): check_input(request, N != 2) key = request[1] db = client.db value = db.get(key) if value is None: client.reply_bulk(None) elif not isinstance(value, self.zset_type): client.reply_wrongtype() else: score = value.score(request[2], None) if score is not None: 
score = str(score).encode('utf-8') client.reply_bulk(score) @command('Sorted Sets', True) def zunionstore(self, client, request, N): self._zsetoper(client, request, N) @command('Sorted Sets', supported=False) def zscan(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) # ######################################################################### # # PUBSUB COMMANDS @command('Pub/Sub', script=0) def psubscribe(self, client, request, N): check_input(request, not N) for pattern in request[1:]: p = self._patterns.get(pattern) if not p: pre = redis_to_py_pattern(pattern.decode('utf-8')) p = pubsub_patterns(re.compile(pre), set()) self._patterns[pattern] = p p.clients.add(client) client.patterns.add(pattern) count = reduce(lambda x, y: x + int(client in y.clients), self._patterns.values()) client.reply_multi_bulk((b'psubscribe', pattern, count)) @command('Pub/Sub') def pubsub(self, client, request, N): check_input(request, not N) subcommand = request[1].decode('utf-8').lower() if subcommand == 'channels': check_input(request, N > 2) if N == 2: pre = re.compile(redis_to_py_pattern( request[2].decode('utf-8'))) channels = [] for channel in self._channels: if pre.match(channel.decode('utf-8', 'ignore')): channels.append(channel) else: channels = list(self._channels) client.reply_multi_bulk(channels) elif subcommand == 'numsub': count = [] for channel in request[2:]: clients = self._channels.get(channel, ()) count.extend((channel, len(clients))) client.reply_multi_bulk(count) elif subcommand == 'numpat': check_input(request, N > 1) count = reduce(lambda x, y: x + len(y.clients), self._patterns.values()) client.reply_int(count) else: client.reply_error("Unknown command 'pubsub %s'" % subcommand) @command('Pub/Sub') def publish(self, client, request, N): check_input(request, N != 2) channel, message = request[1:] ch = channel.decode('utf-8') msg = self._parser.multi_bulk((b'message', channel, message)) count = self._publish_clients(msg, self._channels.get(channel, 
())) for pattern in self._patterns.values(): g = pattern.re.match(ch) if g: count += self._publish_clients(msg, pattern.clients) client.reply_int(count) @command('Pub/Sub', script=0) def punsubscribe(self, client, request, N): patterns = request[1:] if N else list(self._patterns) for pattern in patterns: if pattern in self._patterns: p = self._patterns[pattern] if client in p.clients: client.patterns.discard(pattern) p.clients.remove(client) if not p.clients: self._patterns.pop(pattern) client.reply_multi_bulk((b'punsubscribe', pattern)) @command('Pub/Sub', script=0) def subscribe(self, client, request, N): check_input(request, not N) for channel in request[1:]: clients = self._channels.get(channel) if not clients: self._channels[channel] = clients = set() clients.add(client) client.channels.add(channel) client.reply_multi_bulk((b'subscribe', channel, len(clients))) @command('Pub/Sub', script=0) def unsubscribe(self, client, request, N): channels = request[1:] if N else list(self._channels) for channel in channels: if channel in self._channels: clients = self._channels[channel] if client in clients: client.channels.discard(channel) clients.remove(client) if not clients: self._channels.pop(channel) client.reply_multi_bulk((b'unsubscribe', channel)) # ######################################################################### # # TRANSACTION COMMANDS @command('Transactions', script=0) def discard(self, client, request, N): check_input(request, N) if client.transaction is None: client.reply_error("DISCARD without MULTI") else: self._close_transaction(client) client.reply_ok() @command('Transactions', name='exec', script=0) def execute(self, client, request, N): check_input(request, N) if client.transaction is None: client.reply_error("EXEC without MULTI") else: requests = client.transaction if client.flag & self.DIRTY_CAS: self._close_transaction(client) client.reply_multi_bulk(()) else: self._close_transaction(client) client.reply_multi_bulk_len(len(requests)) for 
handle, request in requests: client._execute_command(handle, request) @command('Transactions', script=0) def multi(self, client, request, N): check_input(request, N) if client.transaction is None: client.reply_ok() client.transaction = [] else: self.error_replay("MULTI calls can not be nested") @command('Transactions', script=0) def watch(self, client, request, N): check_input(request, not N) if client.transaction is not None: client.reply_error("WATCH inside MULTI is not allowed") else: wkeys = client.watched_keys if not wkeys: client.watched_keys = wkeys = set() self._watching.add(client) wkeys.update(request[1:]) client.reply_ok() @command('Transactions', script=0) def unwatch(self, client, request, N): check_input(request, N) transaction = client.transaction self._close_transaction(client) client.transaction = transaction client.reply_ok() # ######################################################################### # # SCRIPTING @command('Scripting', supported=False) def eval(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Scripting', supported=False) def evalsha(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Scripting', supported=False) def script(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) # ######################################################################### # # CONNECTION COMMANDS @command('Connections', script=0) def auth(self, client, request, N): check_input(request, N != 1) client.password = request[1] if client.password != client._producer.password: client.reply_error("wrong password") else: client.reply_ok() @command('Connections') def echo(self, client, request, N): check_input(request, N != 1) client.reply_bulk(request[1]) @command('Connections') def ping(self, client, request, N): check_input(request, N) client.reply_status('PONG') @command('Connections', script=0) def quit(self, client, request, N): check_input(request, N) client.reply_ok() client.close() 
@command('Connections') def select(self, client, request, N): check_input(request, N != 1) D = len(self.databases) - 1 try: num = int(request[1]) if num < 0 or num > D: raise ValueError except ValueError: client.reply_error(('select requires a database number between ' '%s and %s' % (0, D))) else: client.database = num client.reply_ok() # ######################################################################### # # SERVER COMMANDS @command('Server', supported=False) def bgrewriteaof(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Server') def bgsave(self, client, request, N): check_input(request, N) self._save() client.reply_ok() @command('Server') def client(self, client, request, N): check_input(request, not N) subcommand = request[1].decode('utf-8').lower() if subcommand == 'list': check_input(request, N != 1) value = '\n'.join(self._client_list(client)) client.reply_bulk(value.encode('utf-8')) else: client.reply_error("unknown command 'client %s'" % subcommand) @command('Server') def config(self, client, request, N): check_input(request, not N) subcommand = request[1].decode('utf-8').lower() if subcommand == 'get': if N != 2: client.reply_error("'config get' no argument") else: value = self._get_config(request[2].decode('utf-8')) client.reply_bulk(value) elif subcommand == 'rewrite': client.reply_ok() elif subcommand == 'set': try: if N != 3: raise ValueError("'config set' no argument") self._set_config(request[2].decode('utf-8')) except Exception as e: client.reply_error(str(e)) else: client.reply_ok() elif subcommand == 'resetstat': self._hit_keys = 0 self._missed_keys = 0 self._expired_keys = 0 server = client._producer server._received = 0 server._requests_processed = 0 client.reply_ok() else: client.reply_error("'config %s' not valid" % subcommand) @command('Server') def dbsize(self, client, request, N): check_input(request, N != 0) client.reply_int(len(client.db)) @command('Server', supported=False, subcommands=['object', 
'segfault']) def debug(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Server', True) def flushdb(self, client, request, N): check_input(request, N) client.db.flush() client.reply_ok() @command('Server', True) def flushall(self, client, request, N): check_input(request, N) for db in self.databases.values(): db.flush() client.reply_ok() @command('Server') def info(self, client, request, N): check_input(request, N) info = '\n'.join(self._flat_info()) client.reply_bulk(info.encode('utf-8')) @command('Server') def lastsave(self, client, request, N): check_input(request, N) client.reply_int(self._last_save) @command('Server', script=0) def monitor(self, client, request, N): check_input(request, N) client.flag |= self.MONITOR self._monitors.add(client) client.reply_ok() @command('Server', script=0) def save(self, client, request, N): check_input(request, N) self._save(False) client.reply_ok() @command('Server', supported=False) def shutdown(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Server', supported=False) def slaveof(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Server', supported=False) def slowlog(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Server', supported=False) def sync(self, client, request, N): client.reply_error(self.NOT_SUPPORTED) @command('Server') def time(self, client, request, N): check_input(request, N != 0) t = time.time() seconds = math.floor(t) microseconds = int(1000000*(t-seconds)) client.reply_multi_bulk((seconds, microseconds)) # ######################################################################### # # INTERNALS def _cron(self): dirty = self._dirty if dirty: now = time.time() gap = now - self._last_save for interval, changes in self.cfg.key_value_save: if gap >= interval and dirty >= changes: self._save() break self._loop.call_later(1, self._cron) def _set(self, client, key, value, seconds=0, milliseconds=0, nx=False, 
xx=False): try: seconds = int(seconds) milliseconds = 0.001*int(milliseconds) if seconds < 0 or milliseconds < 0: raise ValueError except Exception: return client.reply_error('invalid expire time') else: timeout = seconds + milliseconds db = client.db exists = db.exists(key) skip = (exists and nx) or (not exists and xx) if not skip: if exists: db.pop(key) if timeout > 0: db._timer(timeout, key, bytearray(value)) self._signal(self.NOTIFY_STRING, db, 'expire', key) else: db._data[key] = bytearray(value) self._signal(self.NOTIFY_STRING, db, 'set', key, 1) return True def _incrby(self, client, name, key, value, type): try: tv = type(value) except Exception: return client.reply_error('invalid increment') db = client.db cur = db.get(key) if cur is None: db._data[key] = bytearray(value) elif isinstance(cur, bytearray): try: tv += type(cur) except Exception: return client.reply_error('invalid increment') db._data[key] = bytearray(str(tv).encode('utf-8')) else: return client.reply_wrongtype() self._signal(self.NOTIFY_STRING, db, name, key, 1) return tv def _bpop(self, client, request, keys, dest=None): list_type = self.list_type db = client.db for key in keys: value = db.get(key) if isinstance(value, list_type): self._block_callback(client, request[0], key, value, dest) return True elif value is not None: client.reply_wrongtype() return True return False def _block_callback(self, client, command, key, value, dest): db = client.db if command[:2] == 'br': if dest is not None: dval = db.get(dest) if dval is None: dval = self.list_type() db._data[dest] = dval elif not isinstance(dval, self.list_type): return client.reply_wrongtype() elem = value.pop() self._signal(self.NOTIFY_LIST, db, 'rpop', key, 1) if dest is not None: dval.appendleft(elem) self._signal(self.NOTIFY_LIST, db, 'lpush', dest, 1) else: elem = value.popleft() self._signal(self.NOTIFY_LIST, db, 'lpop', key, 1) if not value: db.pop(key) self._signal(self.NOTIFY_GENERIC, db, 'del', key, 1) if dest is None: 
client.reply_multi_bulk((key, elem)) else: client.reply_bulk(elem) def _range_values(self, value, start, end): start = int(start) end = int(end) if value is not None: if start < 0: start = len(value) + start if end < 0: end = len(value) + end + 1 else: end += 1 return start, end def _close_transaction(self, client): client.transaction = None client.watched_keys = None client.flag &= ~self.DIRTY_CAS self._watching.discard(client) def _flat_info(self): info = self._server.info() info['server']['redis_version'] = self.version e = self._encode_info_value for k, values in info.items(): if isinstance(values, dict): yield '#%s' % k for key, value in values.items(): if isinstance(value, (list, tuple)): value = ', '.join((e(v) for v in value)) elif isinstance(value, dict): value = ', '.join(('%s=%s' % (k, e(v)) for k, v in value.items())) else: value = e(value) yield '%s:%s' % (key, value) def _get_config(self, name): return b'' def _set_config(self, name, value): pass def _encode_info_value(self, value): return str(value).replace('=', ' ').replace(',', ' ').replace('\n', ' - ') def _hincrby(self, client, request, N, type): check_input(request, N != 3) key, field = request[1], request[2] try: increment = type(request[3]) except Exception: return client.reply_error( 'value is not an %s or out of range' % type.__name__) db = client.db hash = db.get(key) if hash is None: hash = self.hash_type() db._data[key] = hash elif not isinstance(hash, self.hash_type): return client.reply_wrongtype() if field in hash: try: value = type(hash[field]) except Exception: return client.reply_error( 'hash value is not an %s' % type.__name__) increment += value hash[field] = increment self._signal(self.NOTIFY_HASH, db, request[0], key, 1) return increment def _setoper(self, client, oper, keys, dest=None): db = client.db result = None for key in keys: value = db.get(key) if value is None: value = set() elif not isinstance(value, set): return client.reply_wrongtype() if result is None: result = 
value else: result = getattr(result, oper)(value) if dest is not None: db.pop(dest) if result: db._data[dest] = result client.reply_int(len(result)) else: client.reply_zero() else: client.reply_multi_bulk(result) def _zsetoper(self, client, request, N): check_input(request, N < 3) db = client.db cmnd = request[0] try: des = request[1] try: numkeys = int(request[2]) except Exception: numkeys = 0 if numkeys <= 0: raise ValueError('at least 1 input key is needed for ' 'ZUNIONSTORE/ZINTERSTORE') sets = [] for key in request[3:3+numkeys]: value = db.get(key) if value is None: value = self.zset_type() elif not isinstance(value, self.zset_type): return client.reply_wrongtype() sets.append(value) if len(sets) != numkeys: raise ValueError('numkeys does not match number of sets') op = set((b'weights', b'aggregate')) request = request[3+numkeys:] weights = None aggregate = sum while request: name = request[0].lower() if name in op: op.discard(name) if name == b'weights': weights = [float(v) for v in request[1:1+numkeys]] request = request[1+numkeys:] elif len(request) > 1: aggregate = self.zset_aggregate.get(request[1]) request = request[2:] else: raise ValueError(self.SYNTAX_ERROR) if not aggregate: raise ValueError(self.SYNTAX_ERRO) if weights is None: weights = [1]*numkeys elif len(weights) != numkeys: raise ValueError(self.SYNTAX_ERROR) except Exception as e: return client.reply_error(str(e)) if cmnd == b'zunionstore': result = self.zset_type.union(sets, weights, aggregate) else: result = self.zset_type.inter(sets, weights, aggregate) if db.pop(des) is not None: self._signal(self.NOTIFY_GENERIC, db, 'del', des, 1) db._data[des] = result self._signal(self.NOTIFY_ZSET, db, cmnd, des, len(result)) client.reply_int(len(result)) def _score_values(self, min_value, max_value): include_min = include_max = True if min_value and min_value[0] == 40: include_min = False min_value = min_value[1:] if max_value and max_value[0] == 40: include_max = False max_value = max_value[1:] return 
float(min_value), include_min, float(max_value), include_max def _info(self): keyspace = {} stats = {'keyspace_hits': self._hit_keys, 'keyspace_misses': self._missed_keys, 'expired_keys': self._expired_keys, 'keys_changed': self._dirty, 'pubsub_channels': len(self._channels), 'pubsub_patterns': len(self._patterns), 'blocked_clients': self._bpop_blocked_clients} persistance = {'rdb_changes_since_last_save': self._dirty, 'rdb_last_save_time': self._last_save} for db in self.databases.values(): if len(db): keyspace[str(db)] = db.info() return {'keyspace': keyspace, 'stats': stats, 'persistance': persistance} def _client_list(self, client): for client in client._producer._concurrent_connections: yield ' '.join(self._client_info(client)) def _client_info(self, client): yield 'addr=%s:%s' % client._transport.get_extra_info('addr') yield 'fd=%s' % client._transport._sock_fd yield 'age=%s' % int(time.time() - client.started) yield 'db=%s' % client.database yield 'sub=%s' % len(client.channels) yield 'psub=%s' % len(client.patterns) yield 'cmd=%s' % client.last_command def _save(self, async=True): writer = self._writer if writer and writer.is_alive(): self.logger.warning('Cannot save, background saving in progress') else: from multiprocessing import Process data = self._dbs() self._dirty = 0 self._last_save = int(time.time()) if async: self.logger.debug('Saving database in background process') self._writer = Process(target=save_data, args=(self.cfg, self._filename, data)) self._writer.start() else: self.logger.debug('Saving database') save_data(self.cfg, self._filename, data) def _dbs(self): data = [(db._num, db._data) for db in self.databases.values() if len(db._data)] return (1, data) def _loaddb(self): filename = self._filename if self.cfg.key_value_save and os.path.isfile(filename): self.logger.info('loading data from "%s"', filename) with open(filename, 'rb') as file: data = pickle.load(file) version, dbs = data for num, data in dbs: db = self.databases.get(num) if db 
is not None: db._data = data def _signal(self, type, db, command, key=None, dirty=0): self._dirty += dirty self._event_handlers[type](db, key, COMMANDS_INFO[command]) def _publish_clients(self, msg, clients): remove = set() count = 0 for client in clients: try: client._transport.write(msg) count += 1 except Exception: remove.add(client) if remove: clients.difference_update(remove) return count # EVENT HANDLERS def _modified_key(self, key): for client in self._watching: if key is None or key in client.watched_keys: client.flag |= self.DIRTY_CAS def _generic_event(self, db, key, command): if command.write: self._modified_key(key) _string_event = _generic_event _set_event = _generic_event _hash_event = _generic_event _zset_event = _generic_event def _list_event(self, db, key, command): if command.write: self._modified_key(key) # the key is blocking clients if key in db._blocking_keys: if key in db._data: value = db._data[key] elif key in self._expires: value = db._expires[key] else: value = None for client in db._blocking_keys.pop(key): client.blocked.unblock(client, key, value) def _remove_connection(self, client, _, **kw): # Remove a client from the server self._monitors.discard(client) self._watching.discard(client) for channel, clients in list(self._channels.items()): clients.discard(client) if not clients: self._channels.pop(channel) for pattern, p in list(self._patterns.items()): p.clients.discard(client) if not p.clients: self._patterns.pop(pattern) def _write_to_monitors(self, client, request): # addr = '%s:%s' % self._transport.get_extra_info('addr') cmds = b'" "'.join(request) message = '+%s [0 %s] "'.encode('utf-8') + cmds + b'"\r\n' remove = set() for m in self._monitors: try: m._transport.write(message) except Exception: remove.add(m) if remove: self._monitors.difference_update(remove)
class Db:
    '''A database.

    Values of keys without a timeout are stored in ``_data``; keys with a
    timeout are stored in ``_expires`` wrapped in a :class:`Timer`.

    :param num: the database number, used by ``repr``.
    :param store: the owning store; it provides the event loop
        (``_loop``), the statistics counters and ``_signal``.
    '''
    def __init__(self, num, store):
        self.store = store
        self._num = num
        self._loop = store._loop
        self._data = {}
        self._expires = {}
        self._events = {}
        self._blocking_keys = {}

    def __repr__(self):
        return 'db%s' % self._num
    __str__ = __repr__

    def __len__(self):
        '''Number of keys, with and without a timeout.'''
        return len(self._data) + len(self._expires)

    def __iter__(self):
        '''Iterate over all keys, those without a timeout first.'''
        return chain(self._data, self._expires)

    # #########################################################################
    # #    INTERNALS
    def flush(self):
        '''Remove every key and cancel all pending expiry timers.

        The store is signalled with a ``flushdb`` event whose ``dirty``
        count is the total number of removed keys, including the ones
        which had a timeout.
        '''
        removed = len(self)
        self._data.clear()
        for timer in self._expires.values():
            timer.handle.cancel()
        self._expires.clear()
        self.store._signal(self.store.NOTIFY_GENERIC, self, 'flushdb',
                           dirty=removed)

    def get(self, key, default=None):
        '''Return the value at ``key`` or ``default`` when missing,
        updating the store hit/miss counters.
        '''
        if key in self._data:
            self.store._hit_keys += 1
            return self._data[key]
        elif key in self._expires:
            self.store._hit_keys += 1
            return self._expires[key].value
        else:
            self.store._missed_keys += 1
            return default

    def exists(self, key):
        '''``True`` when ``key`` is in this database.'''
        return key in self._data or key in self._expires

    def expire(self, key, timeout):
        '''Set a ``timeout`` on ``key``, replacing any existing one.

        Returns ``False`` when the key does not exist, ``True`` otherwise.
        With a non-positive ``timeout`` the key is removed and no timer
        is scheduled.
        '''
        if key in self._expires:
            t = self._expires.pop(key)
            t.handle.cancel()
            value = t.value
        elif key in self._data:
            value = self._data.pop(key)
        else:
            return False
        if timeout > 0:
            self._timer(timeout, key, value)
        return True

    def persist(self, key):
        '''Remove the timeout from ``key``.

        Returns ``True`` only when a timeout was actually removed.
        '''
        if key in self._expires:
            self.store._hit_keys += 1
            t = self._expires.pop(key)
            t.handle.cancel()
            self._data[key] = t.value
            return True
        elif key in self._data:
            self.store._hit_keys += 1
        else:
            self.store._missed_keys += 1
        return False

    def ttl(self, key, m=1):
        '''Remaining time to live of ``key`` multiplied by ``m`` and
        truncated to a non-negative integer.

        Returns ``-1`` when the key has no timeout and ``-2`` when the key
        does not exist.
        '''
        if key in self._expires:
            self.store._hit_keys += 1
            t = self._expires[key]
            return max(0, int(m*(t.when - self._loop.time())))
        elif key in self._data:
            self.store._hit_keys += 1
            return -1
        else:
            self.store._missed_keys += 1
            return -2

    def info(self):
        '''Number of keys without and with a timeout.'''
        return {'Keys': len(self._data),
                'expires': len(self._expires)}

    def pop(self, key, value=None):
        '''Remove ``key`` and return its value, cancelling its timer if
        it had one.

        Returns ``None`` when the key is missing. When a truthy ``value``
        is passed nothing is removed and ``None`` is returned.
        '''
        if not value:
            if key in self._data:
                value = self._data.pop(key)
                return value
            elif key in self._expires:
                t = self._expires.pop(key)
                t.handle.cancel()
                return t.value

    def rem(self, key):
        '''Delete ``key`` and signal a ``del`` event to the store.

        Returns ``1`` when the key was removed, ``0`` when it was missing.
        '''
        if key in self._data:
            self.store._hit_keys += 1
            self._data.pop(key)
        elif key in self._expires:
            self.store._hit_keys += 1
            t = self._expires.pop(key)
            t.handle.cancel()
        else:
            self.store._missed_keys += 1
            return 0
        self.store._signal(self.store.NOTIFY_GENERIC, self, 'del', key, 1)
        return 1

    def _do_expire(self, key):
        '''Timer callback: drop ``key`` if it still has a timeout and
        count it as expired.
        '''
        if key in self._expires:
            t = self._expires.pop(key)
            t.handle.cancel()
            self.store._expired_keys += 1

    def _timer(self, timeout, key, value):
        '''Store ``value`` at ``key`` with an expiry scheduled ``timeout``
        loop-time units from now.
        '''
        loop = self._loop
        when = loop.time() + timeout
        handle = loop.call_at(when, self._do_expire, key)
        self._expires[key] = Timer(handle, value, when)


class Timer:
    '''Holds the ``value`` of an expiring key together with the event-loop
    ``handle`` of its expiry callback and the loop time ``when`` it fires.
    '''
    def __init__(self, handle, value, when):
        self.handle = handle
        self.value = value
        self.when = when