"""Source code for redis.connection."""

from distutils.version import LooseVersion
from itertools import chain
from time import time
from queue import LifoQueue, Empty, Full
from urllib.parse import parse_qs, unquote, urlparse
import copy
import errno
import io
import os
import socket
import threading
import weakref

from redis.exceptions import (
    AuthenticationError,
    AuthenticationWrongNumberOfArgsError,
    BusyLoadingError,
    ChildDeadlockedError,
    ConnectionError,
    DataError,
    ExecAbortError,
    InvalidResponse,
    NoPermissionError,
    NoScriptError,
    ReadOnlyError,
    RedisError,
    ResponseError,
    TimeoutError,
    ModuleError,
)
from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
from redis.backoff import NoBackoff
from redis.retry import Retry

try:
    import ssl
    ssl_available = True
except ImportError:
    ssl_available = False

# Maps each exception type a non-blocking socket read may raise when no data
# is available yet to the ``errno`` value that exception carries. The socket
# readers in this module only swallow such an exception (returning False)
# when raise_on_timeout is False and ``ex.errno`` equals the mapped value.
NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {
    BlockingIOError: errno.EWOULDBLOCK,
}

if ssl_available:
    if hasattr(ssl, 'SSLWantReadError'):
        # The ssl module sets ``errno`` on these exceptions to the OpenSSL
        # error code, which differs between "want read" and "want write".
        # Use the named constants so SSLWantWriteError (errno
        # SSL_ERROR_WANT_WRITE, not SSL_ERROR_WANT_READ) is recognised too.
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = \
            ssl.SSL_ERROR_WANT_READ
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = \
            ssl.SSL_ERROR_WANT_WRITE
    else:
        # very old ssl modules only raise the generic SSLError
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2

NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())

if HIREDIS_AVAILABLE:
    import hiredis

    # Feature flags derived from the installed hiredis version; the
    # HiredisParser checks them to decide which Reader features to use.
    hiredis_version = LooseVersion(hiredis.__version__)
    HIREDIS_SUPPORTS_CALLABLE_ERRORS = (
        hiredis_version >= LooseVersion('0.1.3'))
    HIREDIS_SUPPORTS_BYTE_BUFFER = (
        hiredis_version >= LooseVersion('0.1.4'))
    HIREDIS_SUPPORTS_ENCODING_ERRORS = (
        hiredis_version >= LooseVersion('1.0.0'))

    # feed hiredis from a reusable bytearray (via recv_into) exactly when
    # the installed version supports byte buffers
    HIREDIS_USE_BYTE_BUFFER = HIREDIS_SUPPORTS_BYTE_BUFFER

# RESP framing tokens
SYM_STAR = b'*'
SYM_DOLLAR = b'$'
SYM_CRLF = b'\r\n'
SYM_EMPTY = b''

SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."

# Unique default-argument marker, used where None is itself a meaningful
# value (e.g. a ``timeout`` of None).
SENTINEL = object()

# Exact MODULE LOAD/UNLOAD error texts returned by the server; the parser's
# exception table maps them to ModuleError.
MODULE_LOAD_ERROR = (
    'Error loading the extension. '
    'Please check the server logs.'
)
NO_SUCH_MODULE_ERROR = 'Error unloading module: no such module with that name'
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = (
    'Error unloading module: operation not '
    'possible.'
)
MODULE_EXPORTS_DATA_TYPES_ERROR = (
    "Error unloading module: the module "
    "exports one or more module-side data "
    "types, can't unload"
)


class Encoder:
    """Turn command arguments into bytes and, optionally, replies into str.

    ``encoding`` and ``encoding_errors`` are handed to ``str.encode`` and
    ``bytes.decode``; ``decode_responses`` decides whether ``decode``
    converts values by default.
    """

    def __init__(self, encoding, encoding_errors, decode_responses):
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses

    def encode(self, value):
        """Return a bytes (or bytes-like) form of ``value``.

        bytes and memoryview objects are returned untouched, int and float
        are rendered through ``repr``, and str is encoded with the
        configured codec. Raises DataError for bool and any other type.
        """
        if isinstance(value, (bytes, memoryview)):
            return value
        # bool subclasses int, so it has to be rejected before the numeric
        # branch would accept it
        if isinstance(value, bool):
            raise DataError("Invalid input of type: 'bool'. Convert to a "
                            "bytes, string, int or float first.")
        if isinstance(value, (int, float)):
            return repr(value).encode()
        if isinstance(value, str):
            return value.encode(self.encoding, self.encoding_errors)
        # a value we don't know how to deal with
        raise DataError("Invalid input of type: '%s'. Convert to a "
                        "bytes, string, int or float first."
                        % type(value).__name__)

    def decode(self, value, force=False):
        """Return ``value`` as str when decoding is enabled.

        Decoding happens when ``decode_responses`` is set or ``force`` is
        true; a memoryview is copied to bytes first. Values that are not
        bytes are returned unchanged.
        """
        if not (self.decode_responses or force):
            return value
        if isinstance(value, memoryview):
            value = value.tobytes()
        if isinstance(value, bytes):
            return value.decode(self.encoding, self.encoding_errors)
        return value
class BaseParser:
    """Error-reply parsing shared by PythonParser and HiredisParser."""

    # Maps the first word of an error reply to the exception class to build.
    # 'ERR' maps to a nested dict keyed by the full remaining message text;
    # unmatched 'ERR' messages fall back to ResponseError in parse_error().
    EXCEPTION_CLASSES = {
        'ERR': {
            'max number of clients reached': ConnectionError,
            'Client sent AUTH, but no password is set': AuthenticationError,
            'invalid password': AuthenticationError,
            # some Redis server versions report invalid command syntax
            # in lowercase
            'wrong number of arguments for \'auth\' command':
                AuthenticationWrongNumberOfArgsError,
            # some Redis server versions report invalid command syntax
            # in uppercase
            'wrong number of arguments for \'AUTH\' command':
                AuthenticationWrongNumberOfArgsError,
            MODULE_LOAD_ERROR: ModuleError,
            MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
            NO_SUCH_MODULE_ERROR: ModuleError,
            MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
        },
        'EXECABORT': ExecAbortError,
        'LOADING': BusyLoadingError,
        'NOSCRIPT': NoScriptError,
        'READONLY': ReadOnlyError,
        'NOAUTH': AuthenticationError,
        'NOPERM': NoPermissionError,
    }

    def parse_error(self, response):
        """Parse an error response.

        ``response`` is the error message text (a str, since it is split on
        ' '). Returns (does not raise) an exception instance. When the first
        word is a known code it is stripped from the message passed to the
        exception; otherwise a ResponseError carrying the full text is
        returned.
        """
        error_code = response.split(' ')[0]
        if error_code in self.EXCEPTION_CLASSES:
            response = response[len(error_code) + 1:]
            exception_class = self.EXCEPTION_CLASSES[error_code]
            if isinstance(exception_class, dict):
                exception_class = exception_class.get(response, ResponseError)
            return exception_class(response)
        return ResponseError(response)


class SocketBuffer:
    """In-memory buffer of bytes received from a socket.

    Data is appended to an ``io.BytesIO`` at offset ``bytes_written`` and
    consumed from offset ``bytes_read``; the buffer is purged whenever
    everything written has been read.
    """

    def __init__(self, socket, socket_read_size, socket_timeout):
        self._sock = socket
        self.socket_read_size = socket_read_size
        # timeout restored on the socket after a read that used a custom one
        self.socket_timeout = socket_timeout
        self._buffer = io.BytesIO()
        # number of bytes written to the buffer from the socket
        self.bytes_written = 0
        # number of bytes read from the buffer
        self.bytes_read = 0

    @property
    def length(self):
        """Number of buffered bytes not yet consumed."""
        return self.bytes_written - self.bytes_read

    def _read_from_socket(self, length=None, timeout=SENTINEL,
                          raise_on_timeout=True):
        """Receive data from the socket and append it to the buffer.

        With ``length`` set, keeps calling recv() until at least ``length``
        new bytes were appended by this call; otherwise returns after one
        successful recv(). ``timeout`` (when passed) temporarily overrides
        the socket timeout.

        Returns True when data was read. Returns False on a timeout or a
        non-blocking "no data" error when ``raise_on_timeout`` is False.
        Raises TimeoutError on a timeout when ``raise_on_timeout`` is True,
        and ConnectionError when the peer closed the socket or another
        NONBLOCKING_EXCEPTIONS error occurred.
        """
        sock = self._sock
        socket_read_size = self.socket_read_size
        buf = self._buffer
        # append after everything already written
        buf.seek(self.bytes_written)
        # bytes appended during this call
        marker = 0
        custom_timeout = timeout is not SENTINEL

        try:
            if custom_timeout:
                sock.settimeout(timeout)
            while True:
                data = self._sock.recv(socket_read_size)
                # an empty string indicates the server shutdown the socket
                if isinstance(data, bytes) and len(data) == 0:
                    raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                data_length = len(data)
                self.bytes_written += data_length
                marker += data_length

                if length is not None and length > marker:
                    continue
                return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError("Error while reading from socket: %s" %
                                  (ex.args,))
        finally:
            # restore the connection's configured timeout
            if custom_timeout:
                sock.settimeout(self.socket_timeout)

    def can_read(self, timeout):
        """Return True if unread data is buffered or a socket read within
        ``timeout`` succeeds (timeouts yield False instead of raising)."""
        return bool(self.length) or \
            self._read_from_socket(timeout=timeout,
                                   raise_on_timeout=False)

    def read(self, length):
        """Return the next ``length`` payload bytes.

        The two bytes following the payload (the \\r\\n terminator) are
        consumed and stripped; they are not validated.
        """
        length = length + 2  # make sure to read the \r\n terminator
        # make sure we've read enough data from the socket
        if length > self.length:
            self._read_from_socket(length - self.length)

        self._buffer.seek(self.bytes_read)
        data = self._buffer.read(length)
        self.bytes_read += len(data)

        # purge the buffer when we've consumed it all so it doesn't
        # grow forever
        if self.bytes_read == self.bytes_written:
            self.purge()

        return data[:-2]

    def readline(self):
        """Return the next CRLF-terminated line without its terminator,
        reading from the socket until a complete line is buffered."""
        buf = self._buffer
        buf.seek(self.bytes_read)
        data = buf.readline()
        while not data.endswith(SYM_CRLF):
            # there's more data in the socket that we need
            self._read_from_socket()
            buf.seek(self.bytes_read)
            data = buf.readline()

        self.bytes_read += len(data)

        # purge the buffer when we've consumed it all so it doesn't
        # grow forever
        if self.bytes_read == self.bytes_written:
            self.purge()

        return data[:-2]

    def purge(self):
        """Empty the buffer and reset both offsets to zero."""
        self._buffer.seek(0)
        self._buffer.truncate()
        self.bytes_written = 0
        self.bytes_read = 0

    def close(self):
        """Release the buffer and drop the socket reference (best effort)."""
        try:
            self.purge()
            self._buffer.close()
        except Exception:
            # issue #633 suggests the purge/close somehow raised a
            # BadFileDescriptor error. Perhaps the client ran out of
            # memory or something else? It's probably OK to ignore
            # any error being raised from purge/close since we're
            # removing the reference to the instance below.
            pass
        self._buffer = None
        self._sock = None
class PythonParser(BaseParser):
    """Pure-Python RESP reply parser reading through a SocketBuffer."""

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        # populated by on_connect() and cleared again by on_disconnect()
        self.encoder = None
        self._sock = None
        self._buffer = None

    def __del__(self):
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Attach to ``connection``'s socket and encoder after it connects."""
        sock = connection._sock
        self._sock = sock
        self._buffer = SocketBuffer(sock, self.socket_read_size,
                                    connection.socket_timeout)
        self.encoder = connection.encoder

    def on_disconnect(self):
        """Detach from the socket, closing the read buffer if present."""
        self._sock = None
        buffer = self._buffer
        if buffer is not None:
            buffer.close()
            self._buffer = None
        self.encoder = None

    def can_read(self, timeout):
        """Return a truthy value when a reply can be read (falsy value
        when not connected)."""
        buffer = self._buffer
        return buffer and buffer.can_read(timeout)

    def read_response(self):
        """Read and parse a single RESP reply.

        Error replies are returned as exception instances, except
        ConnectionError subclasses, which are raised immediately. Bulk and
        simple-string payloads pass through ``self.encoder.decode``;
        integers come back as int and multi-bulk replies as lists; null
        bulk/multi-bulk replies give None.
        """
        line = self._buffer.readline()
        if not line:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        prefix = line[:1]
        payload = line[1:]
        if prefix not in (b'-', b'+', b':', b'$', b'*'):
            raise InvalidResponse("Protocol Error: %r" % line)

        if prefix == b'-':
            error = self.parse_error(payload.decode('utf-8', errors='replace'))
            # connection-level errors must surface right away; anything else
            # may belong to a pipeline and is raised later by the caller
            if isinstance(error, ConnectionError):
                raise error
            return error

        if prefix == b':':
            return int(payload)

        if prefix in (b'$', b'*'):
            count = int(payload)
            if count == -1:
                return None
            if prefix == b'*':
                return [self.read_response() for _ in range(count)]
            payload = self._buffer.read(count)

        # simple-string (+) and bulk ($) payloads are bytes at this point
        return self.encoder.decode(payload)
class HiredisParser(BaseParser):
    """Reply parser that delegates RESP decoding to ``hiredis.Reader``.

    Raises RedisError when hiredis is not installed.
    """

    def __init__(self, socket_read_size):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size
        # reusable receive buffer for recv_into(), only when hiredis can be
        # fed from a byte buffer
        if HIREDIS_USE_BYTE_BUFFER:
            self._buffer = bytearray(socket_read_size)

    def __del__(self):
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Bind to ``connection``'s socket and create a fresh hiredis Reader."""
        self._sock = connection._sock
        self._socket_timeout = connection.socket_timeout
        encoder = connection.encoder

        # hiredis < 0.1.3 doesn't support functions that create exceptions
        reply_error = (self.parse_error if HIREDIS_SUPPORTS_CALLABLE_ERRORS
                       else ResponseError)
        reader_options = {
            'protocolError': InvalidResponse,
            'replyError': reply_error,
        }
        if encoder.decode_responses:
            reader_options['encoding'] = encoder.encoding
        if HIREDIS_SUPPORTS_ENCODING_ERRORS:
            reader_options['errors'] = encoder.encoding_errors

        self._reader = hiredis.Reader(**reader_options)
        self._next_response = False

    def on_disconnect(self):
        """Drop the socket and reader, discarding any cached reply."""
        self._sock = None
        self._reader = None
        self._next_response = False

    def can_read(self, timeout):
        """Return True when a complete reply is available, parsing and
        caching it if necessary; otherwise try one socket read within
        ``timeout`` and report whether data arrived."""
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        if self._next_response is not False:
            return True
        # try to parse a reply from data hiredis already holds, caching it
        # for the next read_response() call
        self._next_response = self._reader.gets()
        if self._next_response is not False:
            return True
        return self.read_from_socket(timeout=timeout, raise_on_timeout=False)

    def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
        """Receive one chunk from the socket and feed it to the reader.

        Returns True when data was fed; False on a timeout or non-blocking
        "no data" error when ``raise_on_timeout`` is False.
        """
        sock = self._sock
        override_timeout = timeout is not SENTINEL
        try:
            if override_timeout:
                sock.settimeout(timeout)
            if HIREDIS_USE_BYTE_BUFFER:
                nbytes = self._sock.recv_into(self._buffer)
                if nbytes == 0:
                    raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
                self._reader.feed(self._buffer, 0, nbytes)
            else:
                chunk = self._sock.recv(self.socket_read_size)
                # empty (or non-bytes) data means the server closed the socket
                if not isinstance(chunk, bytes) or len(chunk) == 0:
                    raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
                self._reader.feed(chunk)
            # data was read from the socket and handed to hiredis
            return True
        except socket.timeout:
            if not raise_on_timeout:
                return False
            raise TimeoutError("Timeout reading from socket")
        except NONBLOCKING_EXCEPTIONS as ex:
            # in non-blocking mode a "would block" error just means there
            # is nothing to read yet; anything else is a real failure
            expected = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == expected:
                return False
            raise ConnectionError("Error while reading from socket: %s" %
                                  (ex.args,))
        finally:
            if override_timeout:
                sock.settimeout(self._socket_timeout)

    def read_response(self):
        """Return the next parsed reply, reading from the socket as needed.

        ConnectionError replies (alone, or as the first element of a list
        reply) are raised instead of returned.
        """
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        # a reply may already have been parsed and cached by can_read()
        cached = self._next_response
        if cached is not False:
            self._next_response = False
            return cached

        response = self._reader.gets()
        while response is False:
            self.read_from_socket()
            response = self._reader.gets()

        is_nonempty_list = isinstance(response, list) and bool(response)

        # older hiredis builds only produce plain ResponseErrors; map them
        # onto the specific exception classes ourselves
        if not HIREDIS_SUPPORTS_CALLABLE_ERRORS:
            if isinstance(response, ResponseError):
                response = self.parse_error(response.args[0])
            elif is_nonempty_list and isinstance(response[0], ResponseError):
                response[0] = self.parse_error(response[0].args[0])

        # a ConnectionError means something bad happened to the connection
        if isinstance(response, ConnectionError):
            raise response
        if is_nonempty_list and isinstance(response[0], ConnectionError):
            raise response[0]
        return response
# Prefer the hiredis-backed parser whenever hiredis is importable.
DefaultParser = HiredisParser if HIREDIS_AVAILABLE else PythonParser
class Connection:
    "Manages TCP communication to and from a Redis server"

    def __init__(self, host='localhost', port=6379, db=0, password=None,
                 socket_timeout=None, socket_connect_timeout=None,
                 socket_keepalive=False, socket_keepalive_options=None,
                 socket_type=0, retry_on_timeout=False, encoding='utf-8',
                 encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser, socket_read_size=65536,
                 health_check_interval=0, client_name=None, username=None,
                 retry=None):
        """
        Initialize a new Connection.
        To specify a retry policy, first set `retry_on_timeout` to `True`
        then set `retry` to a valid `Retry` object

        No socket is opened here; the send/read helpers call connect()
        when needed. ``socket_connect_timeout`` falls back to
        ``socket_timeout`` when not given.
        """
        # disconnect() only calls shutdown() from the process that created
        # the connection; other processes (e.g. after fork) just close()
        self.pid = os.getpid()
        self.host = host
        self.port = int(port)
        self.db = db
        self.username = username
        self.client_name = client_name
        self.password = password
        self.socket_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout or socket_timeout
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        self.retry_on_timeout = retry_on_timeout
        if retry_on_timeout:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)
        self._connect_callbacks = []
        # arguments (and accumulated buffers) longer than this many bytes
        # are emitted as separate chunks by pack_command()/pack_commands()
        self._buffer_cutoff = 6000

    def __repr__(self):
        repr_args = ','.join(['%s=%s' % (k, v) for k, v in self.repr_pieces()])
        return '%s<%s>' % (self.__class__.__name__, repr_args)

    def repr_pieces(self):
        "Return the (name, value) pairs rendered by __repr__"
        pieces = [
            ('host', self.host),
            ('port', self.port),
            ('db', self.db)
        ]
        if self.client_name:
            pieces.append(('client_name', self.client_name))
        return pieces

    def __del__(self):
        try:
            self.disconnect()
        except Exception:
            pass

    def register_connect_callback(self, callback):
        """Call ``callback(self)`` after every successful connect().

        Only a ``weakref.WeakMethod`` is stored, so ``callback`` must be a
        bound method; it is skipped once its owner has been collected.
        """
        self._connect_callbacks.append(weakref.WeakMethod(callback))

    def clear_connect_callbacks(self):
        "Forget all callbacks registered with register_connect_callback"
        self._connect_callbacks = []

    def connect(self):
        "Connects to the Redis server if not already connected"
        if self._sock:
            return
        try:
            sock = self._connect()
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except socket.error as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            self.on_connect()
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None
        for res in socket.getaddrinfo(self.host, self.port, self.socket_type,
                                      socket.SOCK_STREAM):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as _:
                # remember the error and try the next resolved address
                err = _
                if sock is not None:
                    sock.close()

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _error_message(self, exception):
        # args for socket.error can either be (errno, "message") or just
        # "message"; also tolerate an empty args tuple so building the
        # message can't raise IndexError and hide the original failure
        if not exception.args:
            return "Error connecting to %s:%s. %s." % \
                (self.host, self.port, type(exception).__name__)
        if len(exception.args) == 1:
            return "Error connecting to %s:%s. %s." % \
                (self.host, self.port, exception.args[0])
        return "Error %s connecting to %s:%s. %s." % \
            (exception.args[0], self.host, self.port, exception.args[1])

    def on_connect(self):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)

        # if username and/or password are set, authenticate
        if self.username or self.password:
            if self.username:
                auth_args = (self.username, self.password or '')
            else:
                auth_args = (self.password,)
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command('AUTH', *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command('AUTH', self.password, check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != 'OK':
                raise AuthenticationError('Invalid Username or Password')

        # if a client_name is given, set it
        if self.client_name:
            self.send_command('CLIENT', 'SETNAME', self.client_name)
            if str_if_bytes(self.read_response()) != 'OK':
                raise ConnectionError('Error setting client name')

        # if a database is specified, switch to it
        if self.db:
            self.send_command('SELECT', self.db)
            if str_if_bytes(self.read_response()) != 'OK':
                raise ConnectionError('Invalid Database')

    def disconnect(self):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()
        if self._sock is None:
            return
        try:
            # only shut the socket down from the process that opened it;
            # a forked child just closes its own copy of the descriptor
            if os.getpid() == self.pid:
                self._sock.shutdown(socket.SHUT_RDWR)
            self._sock.close()
        except OSError:
            pass
        self._sock = None

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        self.send_command('PING', check_health=False)
        if str_if_bytes(self.read_response()) != 'PONG':
            raise ConnectionError('Bad response from PING health check')

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG

        Only runs when ``health_check_interval`` is set and
        ``next_health_check`` (pushed forward by read_response) has passed.
        The PING goes through ``self.retry`` with ``_ping_failed`` (which
        disconnects) as the failure hook.
        """
        if self.health_check_interval and time() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server

        ``command`` is an iterable of bytes-like chunks (as produced by
        pack_command/pack_commands) or a single str. Any failure
        disconnects the socket before the error propagates.
        """
        if not self._sock:
            self.connect()
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            # NOTE(review): sendall() only accepts bytes-like objects, so a
            # str command wrapped here will still fail with TypeError
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except socket.error as e:
            self.disconnect()
            # args are usually (errno, message) or just (message,); an
            # empty args tuple must not turn into an IndexError
            if len(e.args) >= 2:
                err_no, errmsg = e.args[0], e.args[1]
            elif e.args:
                err_no, errmsg = 'UNKNOWN', e.args[0]
            else:
                err_no, errmsg = 'UNKNOWN', repr(e)
            raise ConnectionError("Error %s while writing to socket. %s." %
                                  (err_no, errmsg))
        except BaseException:
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        """Pack and send a command to the Redis server"""
        self.send_packed_command(self.pack_command(*args),
                                 check_health=kwargs.get('check_health', True))

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read."""
        sock = self._sock
        if not sock:
            self.connect()
        return self._parser.can_read(timeout)

    def read_response(self):
        """Read the response from a previously sent command

        Error replies returned by the parser as ResponseError instances are
        raised here. Socket failures disconnect before raising.
        """
        try:
            response = self._parser.read_response()
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout reading from %s:%s" %
                               (self.host, self.port))
        except socket.error as e:
            self.disconnect()
            raise ConnectionError("Error while reading from %s:%s : %s" %
                                  (self.host, self.port, e.args))
        except BaseException:
            self.disconnect()
            raise

        # a successful read counts as a health check
        if self.health_check_interval:
            self.next_health_check = time() + self.health_check_interval

        if isinstance(response, ResponseError):
            raise response
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol

        Returns a list of bytes-like chunks to send in order. Arguments are
        encoded with ``self.encoder``; large arguments and memoryviews are
        kept as their own chunks instead of being copied into one buffer.
        """
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b' ' in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (len(buff) > buffer_cutoff or arg_length > buffer_cutoff
                    or isinstance(arg, memoryview)):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF))
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(),
                     SYM_CRLF, arg, SYM_CRLF))
        output.append(buff)
        return output

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol

        Small chunks from consecutive commands are concatenated into single
        bytes objects. Chunks larger than the buffer cutoff, and
        memoryviews, are passed through as separate items.
        """
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (buffer_length > buffer_cutoff or chunklen > buffer_cutoff
                        or isinstance(chunk, memoryview)):
                    # flush what has accumulated; skip it when nothing is
                    # pending so no empty b'' item (and pointless
                    # sendall() call) is produced
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output
class SSLConnection(Connection):
    """Connection whose TCP socket is wrapped with TLS.

    ``ssl_cert_reqs`` accepts None (treated as ``ssl.CERT_NONE``), one of
    the strings 'none', 'optional' or 'required', or a value that is used
    as-is for ``SSLContext.verify_mode``. Remaining keyword arguments are
    passed to ``Connection``.
    """

    def __init__(self, ssl_keyfile=None, ssl_certfile=None,
                 ssl_cert_reqs='required', ssl_ca_certs=None,
                 ssl_check_hostname=False, **kwargs):
        if not ssl_available:
            raise RedisError("Python wasn't built with SSL support")

        super().__init__(**kwargs)

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {
                'none': ssl.CERT_NONE,
                'optional': ssl.CERT_OPTIONAL,
                'required': ssl.CERT_REQUIRED
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    "Invalid SSL Certificate Requirements Flag: %s" %
                    ssl_cert_reqs)
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.check_hostname = ssl_check_hostname

    def _connect(self):
        "Wrap the socket with SSL support"
        sock = super()._connect()
        try:
            context = ssl.create_default_context()
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            # a client certificate is only loaded when both files are given
            if self.certfile and self.keyfile:
                context.load_cert_chain(certfile=self.certfile,
                                        keyfile=self.keyfile)
            if self.ca_certs:
                context.load_verify_locations(self.ca_certs)
            return context.wrap_socket(sock, server_hostname=self.host)
        except BaseException:
            # the plain TCP socket is already connected: close it instead of
            # leaking the descriptor when context setup (e.g. an unreadable
            # certificate file, conflicting verify settings) or the TLS
            # handshake fails. Closing a socket that wrap_socket has
            # already taken over is expected to be a harmless no-op.
            sock.close()
            raise
class UnixDomainSocketConnection(Connection):
    """Connection to a Redis server over the Unix domain socket ``path``."""

    def __init__(self, path='', db=0, username=None, password=None,
                 socket_timeout=None, encoding='utf-8',
                 encoding_errors='strict', decode_responses=False,
                 retry_on_timeout=False,
                 parser_class=DefaultParser, socket_read_size=65536,
                 health_check_interval=0, client_name=None,
                 retry=None):
        """
        Initialize a new UnixDomainSocketConnection.
        To specify a retry policy, first set `retry_on_timeout` to `True`
        then set `retry` to a valid `Retry` object
        """
        # NOTE: Connection.__init__ is not called; the attributes the
        # inherited methods rely on are initialised here directly, so keep
        # this in sync with Connection.__init__.
        self.pid = os.getpid()
        self.path = path
        self.db = db
        self.username = username
        self.client_name = client_name
        self.password = password
        self.socket_timeout = socket_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_timeout:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)
        self._connect_callbacks = []
        self._buffer_cutoff = 6000

    def repr_pieces(self):
        "Return the (name, value) pairs rendered by __repr__"
        pieces = [
            ('path', self.path),
            ('db', self.db),
        ]
        if self.client_name:
            pieces.append(('client_name', self.client_name))
        return pieces

    def _connect(self):
        "Create a Unix domain socket connection"
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            # socket_timeout also bounds the connect() call
            sock.settimeout(self.socket_timeout)
            sock.connect(self.path)
        except OSError:
            # like Connection._connect, don't leak the descriptor when the
            # connection attempt fails
            sock.close()
            raise
        return sock

    def _error_message(self, exception):
        # args for socket.error can either be (errno, "message") or just
        # "message"; also tolerate an empty args tuple so building the
        # message can't raise IndexError and hide the original failure
        if not exception.args:
            return "Error connecting to unix socket: %s. %s." % \
                (self.path, type(exception).__name__)
        if len(exception.args) == 1:
            return "Error connecting to unix socket: %s. %s." % \
                (self.path, exception.args[0])
        return "Error %s connecting to unix socket: %s. %s." % \
            (exception.args[0], self.path, exception.args[1])
FALSE_STRINGS = ('0', 'F', 'FALSE', 'N', 'NO')


def to_bool(value):
    """Interpret a querystring value as a boolean.

    None and '' give None; strings whose upper-cased form is listed in
    FALSE_STRINGS give False; everything else is passed to ``bool()``.
    """
    if value is None or value == '':
        return None
    looks_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if looks_false else bool(value)


URL_QUERY_ARGUMENT_PARSERS = {
    'db': int,
    'socket_timeout': float,
    'socket_connect_timeout': float,
    'socket_keepalive': to_bool,
    'retry_on_timeout': to_bool,
    'max_connections': int,
    'health_check_interval': int,
    'ssl_check_hostname': to_bool,
}


def parse_url(url):
    """Translate a redis://, rediss:// or unix:// URL into connection kwargs.

    Querystring values (first occurrence only) are unquoted and converted
    with URL_QUERY_ARGUMENT_PARSERS when a converter exists. Raises
    ValueError for a value that fails conversion or for an unsupported
    scheme.
    """
    parsed = urlparse(url)
    options = {}

    for key, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        convert = URL_QUERY_ARGUMENT_PARSERS.get(key)
        if not convert:
            options[key] = raw
            continue
        try:
            options[key] = convert(raw)
        except (TypeError, ValueError):
            raise ValueError(
                "Invalid value for `%s` in connection URL." % key
            )

    if parsed.username:
        options['username'] = unquote(parsed.username)
    if parsed.password:
        options['password'] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    scheme = parsed.scheme
    if scheme == 'unix':
        if parsed.path:
            options['path'] = unquote(parsed.path)
        options['connection_class'] = UnixDomainSocketConnection
        return options

    if scheme not in ('redis', 'rediss'):
        valid_schemes = 'redis://, rediss://, unix://'
        raise ValueError('Redis URL must specify one of the following '
                         'schemes (%s)' % valid_schemes)

    if parsed.hostname:
        options['host'] = unquote(parsed.hostname)
    if parsed.port:
        options['port'] = int(parsed.port)

    # use the path ("/N") as the db unless ?db= was already given
    if parsed.path and 'db' not in options:
        try:
            options['db'] = int(unquote(parsed.path).replace('/', ''))
        except (AttributeError, ValueError):
            pass

    if scheme == 'rediss':
        options['connection_class'] = SSLConnection
    return options
class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls, url, **kwargs):
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[[username]:[password]]@/path/to/socket.sock?db=0

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates a SSL wrapped TCP socket connection. See more
          at: <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0
            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0
            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # URL-derived options override explicitly passed keyword arguments
        kwargs.update(url_options)

        return cls(**kwargs)

    def __init__(self, connection_class=Connection, max_connections=None,
                 **connection_kwargs):
        """Create an empty pool.

        A falsy ``max_connections`` (None or 0) means 2 ** 31. A negative
        value or a non-int raises ValueError. Remaining keyword arguments
        are stored and used to construct ``connection_class`` instances.
        """
        max_connections = max_connections or 2 ** 31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.
        self._fork_lock = threading.Lock()
        self.reset()

    def __repr__(self):
        # NOTE: builds a throwaway connection_class instance just to render
        # its repr (the built-in Connection classes don't open a socket in
        # __init__)
        return "%s<%s>" % (
            type(self).__name__,
            repr(self.connection_class(**self.connection_kwargs)),
        )

    def reset(self):
        """(Re)initialize the pool's lock and connection bookkeeping, then
        record the current process id."""
        self._lock = threading.Lock()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self):
        """Reset the pool if the current process is not the one that last
        reset it (i.e. after a fork)."""
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()
[docs] def get_connection(self, command_name, *keys, **options): "Get a connection from the pool" self._checkpid() with self._lock: try: connection = self._available_connections.pop() except IndexError: connection = self.make_connection() self._in_use_connections.add(connection) try: # ensure this connection is connected to Redis connection.connect() # connections that the pool provides should be ready to send # a command. if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. try: if connection.can_read(): raise ConnectionError('Connection has data') except ConnectionError: connection.disconnect() connection.connect() if connection.can_read(): raise ConnectionError('Connection not ready') except BaseException: # release the connection back to the pool so that we don't # leak it self.release(connection) raise return connection
[docs] def get_encoder(self): "Return an encoder based on encoding settings" kwargs = self.connection_kwargs return Encoder( encoding=kwargs.get('encoding', 'utf-8'), encoding_errors=kwargs.get('encoding_errors', 'strict'), decode_responses=kwargs.get('decode_responses', False) )
[docs] def make_connection(self): "Create a new connection" if self._created_connections >= self.max_connections: raise ConnectionError("Too many connections") self._created_connections += 1 return self.connection_class(**self.connection_kwargs)
[docs] def release(self, connection): "Releases the connection back to the pool" self._checkpid() with self._lock: try: self._in_use_connections.remove(connection) except KeyError: # Gracefully fail when a connection is returned to this pool # that the pool doesn't actually own pass if self.owns_connection(connection): self._available_connections.append(connection) else: # pool doesn't own this connection. do not add it back # to the pool and decrement the count so that another # connection can take its place if needed self._created_connections -= 1 connection.disconnect() return
def owns_connection(self, connection): return connection.pid == self.pid
[docs] def disconnect(self, inuse_connections=True): """ Disconnects connections in the pool If ``inuse_connections`` is True, disconnect connections that are current in use, potentially by other threads. Otherwise only disconnect connections that are idle in the pool. """ self._checkpid() with self._lock: if inuse_connections: connections = chain(self._available_connections, self._in_use_connections) else: connections = self._available_connections for connection in connections: connection.disconnect()
class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of the connections are in use, rather
    than raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever::

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(self, max_connections=50, timeout=20,
                 connection_class=Connection, queue_class=LifoQueue,
                 **connection_kwargs):
        self.queue_class = queue_class
        self.timeout = timeout
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs)

    def reset(self):
        """Rebuild the slot queue (all ``None`` placeholders) and tracking."""
        # Create and fill up a thread safe queue with ``None`` values.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        "Make a fresh connection."
        connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection

    def _put_placeholder(self):
        """Return an empty (``None``) slot to the queue, tolerating a full queue."""
        try:
            self.pool.put_nowait(None)
        except Full:
            # The queue was already refilled (e.g. by reset() after a fork),
            # so this placeholder is redundant.
            pass

    def get_connection(self, command_name, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.

        :raises ConnectionError: if no slot becomes free within
            ``self.timeout`` seconds, or if the connection cannot be made ready.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. If you never want
            # this to be raised, construct the pool with ``timeout=None`` to
            # block until a connection becomes available.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            try:
                connection = self.make_connection()
            except BaseException:
                # Hand the slot back; otherwise a failed construction would
                # permanently shrink the pool's capacity.
                self._put_placeholder()
                raise

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError('Connection has data')
            except ConnectionError:
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError('Connection not ready')
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        """
        Releases the connection back to the pool.

        A connection created under another pid (e.g. before a fork) is
        disconnected and replaced by a ``None`` placeholder. That placeholder
        is skipped if the queue is already full.
        """
        # Make sure we haven't changed process.
        self._checkpid()
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # its needed. after a fork, reset() has already filled the queue,
            # so a Full here must not propagate to the caller.
            connection.disconnect()
            self._put_placeholder()
            return

        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass

    def disconnect(self):
        "Disconnects all connections in the pool."
        self._checkpid()
        for connection in self._connections:
            connection.disconnect()