%PDF- %PDF-
Direktori : /lib/python2.7/site-packages/salt/transport/ |
Current File : //lib/python2.7/site-packages/salt/transport/tcp.py |
# -*- coding: utf-8 -*- ''' TCP transport classes Wire protocol: "len(payload) msgpack({'head': SOMEHEADER, 'body': SOMEBODY})" ''' # Import Python Libs from __future__ import absolute_import, print_function, unicode_literals import errno import logging import socket import os import weakref import time import threading import traceback # Import Salt Libs import salt.crypt import salt.utils.asynchronous import salt.utils.event import salt.utils.files import salt.utils.platform import salt.utils.process import salt.utils.verify import salt.payload import salt.exceptions import salt.transport.frame import salt.transport.ipc import salt.transport.client import salt.transport.server import salt.transport.mixins.auth from salt.ext import six from salt.ext.six.moves import queue # pylint: disable=import-error from salt.exceptions import SaltReqTimeoutError, SaltClientError from salt.transport import iter_transport_opts # Import Tornado Libs import tornado import tornado.tcpserver import tornado.gen import tornado.concurrent import tornado.tcpclient import tornado.netutil import tornado.iostream # pylint: disable=import-error,no-name-in-module if six.PY2: import urlparse else: import urllib.parse as urlparse # pylint: enable=import-error,no-name-in-module # Import third party libs import msgpack try: from M2Crypto import RSA HAS_M2 = True except ImportError: HAS_M2 = False try: from Cryptodome.Cipher import PKCS1_OAEP except ImportError: from Crypto.Cipher import PKCS1_OAEP if six.PY3 and salt.utils.platform.is_windows(): USE_LOAD_BALANCER = True else: USE_LOAD_BALANCER = False if USE_LOAD_BALANCER: import threading import multiprocessing import tornado.util from salt.utils.process import SignalHandlingMultiprocessingProcess log = logging.getLogger(__name__) def _set_tcp_keepalive(sock, opts): ''' Ensure that TCP keepalives are set for the socket. ''' if hasattr(socket, 'SO_KEEPALIVE'): if opts.get('tcp_keepalive', False): sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) if hasattr(socket, 'SOL_TCP'): if hasattr(socket, 'TCP_KEEPIDLE'): tcp_keepalive_idle = opts.get('tcp_keepalive_idle', -1) if tcp_keepalive_idle > 0: sock.setsockopt( socket.SOL_TCP, socket.TCP_KEEPIDLE, int(tcp_keepalive_idle)) if hasattr(socket, 'TCP_KEEPCNT'): tcp_keepalive_cnt = opts.get('tcp_keepalive_cnt', -1) if tcp_keepalive_cnt > 0: sock.setsockopt( socket.SOL_TCP, socket.TCP_KEEPCNT, int(tcp_keepalive_cnt)) if hasattr(socket, 'TCP_KEEPINTVL'): tcp_keepalive_intvl = opts.get('tcp_keepalive_intvl', -1) if tcp_keepalive_intvl > 0: sock.setsockopt( socket.SOL_TCP, socket.TCP_KEEPINTVL, int(tcp_keepalive_intvl)) if hasattr(socket, 'SIO_KEEPALIVE_VALS'): # Windows doesn't support TCP_KEEPIDLE, TCP_KEEPCNT, nor # TCP_KEEPINTVL. Instead, it has its own proprietary # SIO_KEEPALIVE_VALS. tcp_keepalive_idle = opts.get('tcp_keepalive_idle', -1) tcp_keepalive_intvl = opts.get('tcp_keepalive_intvl', -1) # Windows doesn't support changing something equivalent to # TCP_KEEPCNT. if tcp_keepalive_idle > 0 or tcp_keepalive_intvl > 0: # Windows defaults may be found by using the link below. # Search for 'KeepAliveTime' and 'KeepAliveInterval'. # https://technet.microsoft.com/en-us/library/bb726981.aspx#EDAA # If one value is set and the other isn't, we still need # to send both values to SIO_KEEPALIVE_VALS and they both # need to be valid. So in that case, use the Windows # default. if tcp_keepalive_idle <= 0: tcp_keepalive_idle = 7200 if tcp_keepalive_intvl <= 0: tcp_keepalive_intvl = 1 # The values expected are in milliseconds, so multiply by # 1000. sock.ioctl(socket.SIO_KEEPALIVE_VALS, ( 1, int(tcp_keepalive_idle * 1000), int(tcp_keepalive_intvl * 1000))) else: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 0) if USE_LOAD_BALANCER: class LoadBalancerServer(SignalHandlingMultiprocessingProcess): ''' Raw TCP server which runs in its own process and will listen for incoming connections. Each incoming connection will be sent via multiprocessing queue to the workers. Since the queue is shared amongst workers, only one worker will handle a given connection. ''' # TODO: opts! # Based on default used in tornado.netutil.bind_sockets() backlog = 128 def __init__(self, opts, socket_queue, **kwargs): super(LoadBalancerServer, self).__init__(**kwargs) self.opts = opts self.socket_queue = socket_queue self._socket = None # __setstate__ and __getstate__ are only used on Windows. # We do this so that __init__ will be invoked on Windows in the child # process so that a register_after_fork() equivalent will work on # Windows. def __setstate__(self, state): self._is_child = True self.__init__( state['opts'], state['socket_queue'], log_queue=state['log_queue'], log_queue_level=state['log_queue_level'] ) def __getstate__(self): return { 'opts': self.opts, 'socket_queue': self.socket_queue, 'log_queue': self.log_queue, 'log_queue_level': self.log_queue_level } def close(self): if self._socket is not None: self._socket.shutdown(socket.SHUT_RDWR) self._socket.close() self._socket = None def __del__(self): self.close() def run(self): ''' Start the load balancer ''' self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) _set_tcp_keepalive(self._socket, self.opts) self._socket.setblocking(1) self._socket.bind((self.opts['interface'], int(self.opts['ret_port']))) self._socket.listen(self.backlog) while True: try: # Wait for a connection to occur since the socket is # blocking. connection, address = self._socket.accept() # Wait for a free slot to be available to put # the connection into. # Sockets are picklable on Windows in Python 3. self.socket_queue.put((connection, address), True, None) except socket.error as e: # ECONNABORTED indicates that there was a connection # but it was closed while still in the accept queue. # (observed on FreeBSD). if tornado.util.errno_from_exception(e) == errno.ECONNABORTED: continue raise # TODO: move serial down into message library class AsyncTCPReqChannel(salt.transport.client.ReqChannel): ''' Encapsulate sending routines to tcp. Note: this class returns a singleton ''' # This class is only a singleton per minion/master pair # mapping of io_loop -> {key -> channel} instance_map = weakref.WeakKeyDictionary() def __new__(cls, opts, **kwargs): ''' Only create one instance of channel per __key() ''' # do we have any mapping for this io_loop io_loop = kwargs.get('io_loop') or tornado.ioloop.IOLoop.current() if io_loop not in cls.instance_map: cls.instance_map[io_loop] = weakref.WeakValueDictionary() loop_instance_map = cls.instance_map[io_loop] key = cls.__key(opts, **kwargs) obj = loop_instance_map.get(key) if obj is None: log.debug('Initializing new AsyncTCPReqChannel for %s', key) # we need to make a local variable for this, as we are going to store # it in a WeakValueDictionary-- which will remove the item if no one # references it-- this forces a reference while we return to the caller obj = object.__new__(cls) obj.__singleton_init__(opts, **kwargs) obj._instance_key = key loop_instance_map[key] = obj obj._refcount = 1 obj._refcount_lock = threading.RLock() else: with obj._refcount_lock: obj._refcount += 1 log.debug('Re-using AsyncTCPReqChannel for %s', key) return obj @classmethod def __key(cls, opts, **kwargs): if 'master_uri' in kwargs: opts['master_uri'] = kwargs['master_uri'] return (opts['pki_dir'], # where the keys are stored opts['id'], # minion ID opts['master_uri'], kwargs.get('crypt', 'aes'), # TODO: use the same channel for crypt ) # has to remain empty for singletons, since __init__ will *always* be called def __init__(self, opts, **kwargs): pass # an init for the singleton instance to call def __singleton_init__(self, opts, **kwargs): self.opts = dict(opts) self.serial = salt.payload.Serial(self.opts) # crypt defaults to 'aes' self.crypt = kwargs.get('crypt', 'aes') self.io_loop = kwargs.get('io_loop') or tornado.ioloop.IOLoop.current() if self.crypt != 'clear': self.auth = salt.crypt.AsyncAuth(self.opts, io_loop=self.io_loop) resolver = kwargs.get('resolver') parse = urlparse.urlparse(self.opts['master_uri']) master_host, master_port = parse.netloc.rsplit(':', 1) self.master_addr = (master_host, int(master_port)) self._closing = False self.message_client = SaltMessageClientPool(self.opts, args=(self.opts, master_host, int(master_port),), kwargs={'io_loop': self.io_loop, 'resolver': resolver, 'source_ip': self.opts.get('source_ip'), 'source_port': self.opts.get('source_ret_port')}) def close(self): if self._closing: return if self._refcount > 1: # Decrease refcount with self._refcount_lock: self._refcount -= 1 log.debug( 'This is not the last %s instance. Not closing yet.', self.__class__.__name__ ) return log.debug('Closing %s instance', self.__class__.__name__) self._closing = True self.message_client.close() # Remove the entry from the instance map so that a closed entry may not # be reused. # This forces this operation even if the reference count of the entry # has not yet gone to zero. if self.io_loop in self.__class__.instance_map: loop_instance_map = self.__class__.instance_map[self.io_loop] if self._instance_key in loop_instance_map: del loop_instance_map[self._instance_key] if not loop_instance_map: del self.__class__.instance_map[self.io_loop] def __del__(self): with self._refcount_lock: # Make sure we actually close no matter if something # went wrong with our ref counting self._refcount = 1 try: self.close() except socket.error as exc: if exc.errno != errno.EBADF: # If its not a bad file descriptor error, raise raise def _package_load(self, load): return { 'enc': self.crypt, 'load': load, } @tornado.gen.coroutine def crypted_transfer_decode_dictentry(self, load, dictkey=None, tries=3, timeout=60): if not self.auth.authenticated: yield self.auth.authenticate() ret = yield self.message_client.send(self._package_load(self.auth.crypticle.dumps(load)), timeout=timeout) key = self.auth.get_keys() if HAS_M2: aes = key.private_decrypt(ret['key'], RSA.pkcs1_oaep_padding) else: cipher = PKCS1_OAEP.new(key) aes = cipher.decrypt(ret['key']) pcrypt = salt.crypt.Crypticle(self.opts, aes) data = pcrypt.loads(ret[dictkey]) if six.PY3: data = salt.transport.frame.decode_embedded_strs(data) raise tornado.gen.Return(data) @tornado.gen.coroutine def _crypted_transfer(self, load, tries=3, timeout=60): ''' In case of authentication errors, try to renegotiate authentication and retry the method. Indeed, we can fail too early in case of a master restart during a minion state execution call ''' @tornado.gen.coroutine def _do_transfer(): data = yield self.message_client.send(self._package_load(self.auth.crypticle.dumps(load)), timeout=timeout, ) # we may not have always data # as for example for saltcall ret submission, this is a blind # communication, we do not subscribe to return events, we just # upload the results to the master if data: data = self.auth.crypticle.loads(data) if six.PY3: data = salt.transport.frame.decode_embedded_strs(data) raise tornado.gen.Return(data) if not self.auth.authenticated: yield self.auth.authenticate() try: ret = yield _do_transfer() raise tornado.gen.Return(ret) except salt.crypt.AuthenticationError: yield self.auth.authenticate() ret = yield _do_transfer() raise tornado.gen.Return(ret) @tornado.gen.coroutine def _uncrypted_transfer(self, load, tries=3, timeout=60): ret = yield self.message_client.send(self._package_load(load), timeout=timeout) raise tornado.gen.Return(ret) @tornado.gen.coroutine def send(self, load, tries=3, timeout=60, raw=False): ''' Send a request, return a future which will complete when we send the message ''' try: if self.crypt == 'clear': ret = yield self._uncrypted_transfer(load, tries=tries, timeout=timeout) else: ret = yield self._crypted_transfer(load, tries=tries, timeout=timeout) except tornado.iostream.StreamClosedError: # Convert to 'SaltClientError' so that clients can handle this # exception more appropriately. raise SaltClientError('Connection to master lost') raise tornado.gen.Return(ret) class AsyncTCPPubChannel(salt.transport.mixins.auth.AESPubClientMixin, salt.transport.client.AsyncPubChannel): def __init__(self, opts, **kwargs): self.opts = opts self.serial = salt.payload.Serial(self.opts) self.crypt = kwargs.get('crypt', 'aes') self.io_loop = kwargs.get('io_loop') or tornado.ioloop.IOLoop.current() self.connected = False self._closing = False self._reconnected = False self.event = salt.utils.event.get_event( 'minion', opts=self.opts, listen=False ) def close(self): if self._closing: return self._closing = True if hasattr(self, 'message_client'): self.message_client.close() def __del__(self): self.close() def _package_load(self, load): return { 'enc': self.crypt, 'load': load, } @tornado.gen.coroutine def send_id(self, tok, force_auth): ''' Send the minion id to the master so that the master may better track the connection state of the minion. In case of authentication errors, try to renegotiate authentication and retry the method. ''' load = {'id': self.opts['id'], 'tok': tok} @tornado.gen.coroutine def _do_transfer(): msg = self._package_load(self.auth.crypticle.dumps(load)) package = salt.transport.frame.frame_msg(msg, header=None) yield self.message_client.write_to_stream(package) raise tornado.gen.Return(True) if force_auth or not self.auth.authenticated: count = 0 while count <= self.opts['tcp_authentication_retries'] or self.opts['tcp_authentication_retries'] < 0: try: yield self.auth.authenticate() break except SaltClientError as exc: log.debug(exc) count += 1 try: ret = yield _do_transfer() raise tornado.gen.Return(ret) except salt.crypt.AuthenticationError: yield self.auth.authenticate() ret = yield _do_transfer() raise tornado.gen.Return(ret) @tornado.gen.coroutine def connect_callback(self, result): if self._closing: return # Force re-auth on reconnect since the master # may have been restarted yield self.send_id(self.tok, self._reconnected) self.connected = True self.event.fire_event( {'master': self.opts['master']}, '__master_connected' ) if self._reconnected: # On reconnects, fire a master event to notify that the minion is # available. if self.opts.get('__role') == 'syndic': data = 'Syndic {0} started at {1}'.format( self.opts['id'], time.asctime() ) tag = salt.utils.event.tagify( [self.opts['id'], 'start'], 'syndic' ) else: data = 'Minion {0} started at {1}'.format( self.opts['id'], time.asctime() ) tag = salt.utils.event.tagify( [self.opts['id'], 'start'], 'minion' ) load = {'id': self.opts['id'], 'cmd': '_minion_event', 'pretag': None, 'tok': self.tok, 'data': data, 'tag': tag} req_channel = salt.utils.asynchronous.SyncWrapper( AsyncTCPReqChannel, (self.opts,) ) try: req_channel.send(load, timeout=60) except salt.exceptions.SaltReqTimeoutError: log.info('fire_master failed: master could not be contacted. Request timed out.') except Exception: log.info('fire_master failed: %s', traceback.format_exc()) finally: # SyncWrapper will call either close() or destroy(), whichever is available del req_channel else: self._reconnected = True def disconnect_callback(self): if self._closing: return self.connected = False self.event.fire_event( {'master': self.opts['master']}, '__master_disconnected' ) @tornado.gen.coroutine def connect(self): try: self.auth = salt.crypt.AsyncAuth(self.opts, io_loop=self.io_loop) self.tok = self.auth.gen_token(b'salt') if not self.auth.authenticated: yield self.auth.authenticate() if self.auth.authenticated: self.message_client = SaltMessageClientPool( self.opts, args=(self.opts, self.opts['master_ip'], int(self.auth.creds['publish_port']),), kwargs={'io_loop': self.io_loop, 'connect_callback': self.connect_callback, 'disconnect_callback': self.disconnect_callback, 'source_ip': self.opts.get('source_ip'), 'source_port': self.opts.get('source_publish_port')}) yield self.message_client.connect() # wait for the client to be connected self.connected = True # TODO: better exception handling... except KeyboardInterrupt: raise except Exception as exc: if '-|RETRY|-' not in six.text_type(exc): raise SaltClientError('Unable to sign_in to master: {0}'.format(exc)) # TODO: better error message def on_recv(self, callback): ''' Register an on_recv callback ''' if callback is None: return self.message_client.on_recv(callback) @tornado.gen.coroutine def wrap_callback(body): if not isinstance(body, dict): # TODO: For some reason we need to decode here for things # to work. Fix this. body = msgpack.loads(body) if six.PY3: body = salt.transport.frame.decode_embedded_strs(body) ret = yield self._decode_payload(body) callback(ret) return self.message_client.on_recv(wrap_callback) class TCPReqServerChannel(salt.transport.mixins.auth.AESReqServerMixin, salt.transport.server.ReqServerChannel): # TODO: opts! backlog = 5 def __init__(self, opts): salt.transport.server.ReqServerChannel.__init__(self, opts) self._socket = None @property def socket(self): return self._socket def close(self): if self._socket is not None: try: self._socket.shutdown(socket.SHUT_RDWR) except socket.error as exc: if exc.errno == errno.ENOTCONN: # We may try to shutdown a socket which is already disconnected. # Ignore this condition and continue. pass else: raise exc self._socket.close() self._socket = None if hasattr(self.req_server, 'shutdown'): try: self.req_server.shutdown() except Exception as exc: log.exception('TCPReqServerChannel close generated an exception: %s', str(exc)) elif hasattr(self.req_server, 'stop'): try: self.req_server.stop() except socket.error as exc: if exc.errno != 9: raise log.exception('TCPReqServerChannel close generated an exception: %s', str(exc)) def __del__(self): self.close() def pre_fork(self, process_manager): ''' Pre-fork we need to create the zmq router device ''' salt.transport.mixins.auth.AESReqServerMixin.pre_fork(self, process_manager) if USE_LOAD_BALANCER: self.socket_queue = multiprocessing.Queue() process_manager.add_process( LoadBalancerServer, args=(self.opts, self.socket_queue) ) elif not salt.utils.platform.is_windows(): self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) _set_tcp_keepalive(self._socket, self.opts) self._socket.setblocking(0) self._socket.bind((self.opts['interface'], int(self.opts['ret_port']))) def post_fork(self, payload_handler, io_loop): ''' After forking we need to create all of the local sockets to listen to the router payload_handler: function to call with your payloads ''' self.payload_handler = payload_handler self.io_loop = io_loop self.serial = salt.payload.Serial(self.opts) with salt.utils.asynchronous.current_ioloop(self.io_loop): if USE_LOAD_BALANCER: self.req_server = LoadBalancerWorker(self.socket_queue, self.handle_message, ssl_options=self.opts.get('ssl')) else: if salt.utils.platform.is_windows(): self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) _set_tcp_keepalive(self._socket, self.opts) self._socket.setblocking(0) self._socket.bind((self.opts['interface'], int(self.opts['ret_port']))) self.req_server = SaltMessageServer(self.handle_message, ssl_options=self.opts.get('ssl'), io_loop=self.io_loop) self.req_server.add_socket(self._socket) self._socket.listen(self.backlog) salt.transport.mixins.auth.AESReqServerMixin.post_fork(self, payload_handler, io_loop) @tornado.gen.coroutine def handle_message(self, stream, header, payload): ''' Handle incoming messages from underylying tcp streams ''' try: try: payload = self._decode_payload(payload) except Exception: stream.write(salt.transport.frame.frame_msg('bad load', header=header)) raise tornado.gen.Return() # TODO helper functions to normalize payload? if not isinstance(payload, dict) or not isinstance(payload.get('load'), dict): yield stream.write(salt.transport.frame.frame_msg( 'payload and load must be a dict', header=header)) raise tornado.gen.Return() try: id_ = payload['load'].get('id', '') if str('\0') in id_: log.error('Payload contains an id with a null byte: %s', payload) stream.send(self.serial.dumps('bad load: id contains a null byte')) raise tornado.gen.Return() except TypeError: log.error('Payload contains non-string id: %s', payload) stream.send(self.serial.dumps('bad load: id {0} is not a string'.format(id_))) raise tornado.gen.Return() # intercept the "_auth" commands, since the main daemon shouldn't know # anything about our key auth if payload['enc'] == 'clear' and payload.get('load', {}).get('cmd') == '_auth': yield stream.write(salt.transport.frame.frame_msg( self._auth(payload['load']), header=header)) raise tornado.gen.Return() # TODO: test try: ret, req_opts = yield self.payload_handler(payload) except Exception as e: # always attempt to return an error to the minion stream.write('Some exception handling minion payload') log.error('Some exception handling a payload from minion', exc_info=True) stream.close() raise tornado.gen.Return() req_fun = req_opts.get('fun', 'send') if req_fun == 'send_clear': stream.write(salt.transport.frame.frame_msg(ret, header=header)) elif req_fun == 'send': stream.write(salt.transport.frame.frame_msg(self.crypticle.dumps(ret), header=header)) elif req_fun == 'send_private': stream.write(salt.transport.frame.frame_msg(self._encrypt_private(ret, req_opts['key'], req_opts['tgt'], ), header=header)) else: log.error('Unknown req_fun %s', req_fun) # always attempt to return an error to the minion stream.write('Server-side exception handling payload') stream.close() except tornado.gen.Return: raise except tornado.iostream.StreamClosedError: # Stream was closed. This could happen if the remote side # closed the connection on its end (eg in a timeout or shutdown # situation). log.error('Connection was unexpectedly closed', exc_info=True) except Exception as exc: # pylint: disable=broad-except # Absorb any other exceptions log.error('Unexpected exception occurred: %s', exc, exc_info=True) raise tornado.gen.Return() class SaltMessageServer(tornado.tcpserver.TCPServer, object): ''' Raw TCP server which will receive all of the TCP streams and re-assemble messages that are sent through to us ''' def __init__(self, message_handler, *args, **kwargs): io_loop = kwargs.pop('io_loop', None) or tornado.ioloop.IOLoop.current() super(SaltMessageServer, self).__init__(*args, **kwargs) self.io_loop = io_loop self.clients = [] self.message_handler = message_handler self._shutting_down = False @tornado.gen.coroutine def handle_stream(self, stream, address): ''' Handle incoming streams and add messages to the incoming queue ''' log.trace('Req client %s connected', address) self.clients.append((stream, address)) unpacker = msgpack.Unpacker() try: while True: wire_bytes = yield stream.read_bytes(4096, partial=True) unpacker.feed(wire_bytes) for framed_msg in unpacker: if six.PY3: framed_msg = salt.transport.frame.decode_embedded_strs( framed_msg ) header = framed_msg['head'] self.io_loop.spawn_callback(self.message_handler, stream, header, framed_msg['body']) except tornado.iostream.StreamClosedError: log.trace('req client disconnected %s', address) self.remove_client((stream, address)) except Exception as e: log.trace('other master-side exception: %s', e) self.remove_client((stream, address)) stream.close() def remove_client(self, client): try: self.clients.remove(client) except ValueError: log.trace("Message server client was not in list to remove") def shutdown(self): ''' Shutdown the whole server ''' if self._shutting_down: return self._shutting_down = True for item in self.clients: client, address = item client.close() self.remove_client(item) try: self.stop() except socket.error as exc: if exc.errno != 9: raise if USE_LOAD_BALANCER: class LoadBalancerWorker(SaltMessageServer): ''' This will receive TCP connections from 'LoadBalancerServer' via a multiprocessing queue. Since the queue is shared amongst workers, only one worker will handle a given connection. ''' def __init__(self, socket_queue, message_handler, *args, **kwargs): super(LoadBalancerWorker, self).__init__( message_handler, *args, **kwargs) self.socket_queue = socket_queue self._stop = threading.Event() self.thread = threading.Thread(target=self.socket_queue_thread) self.thread.start() def stop(self): self._stop.set() self.thread.join() def socket_queue_thread(self): try: while True: try: client_socket, address = self.socket_queue.get(True, 1) except queue.Empty: if self._stop.is_set(): break continue # 'self.io_loop' initialized in super class # 'tornado.tcpserver.TCPServer'. # 'self._handle_connection' defined in same super class. self.io_loop.spawn_callback( self._handle_connection, client_socket, address) except (KeyboardInterrupt, SystemExit): pass class TCPClientKeepAlive(tornado.tcpclient.TCPClient): ''' Override _create_stream() in TCPClient to enable keep alive support. ''' def __init__(self, opts, resolver=None): self.opts = opts super(TCPClientKeepAlive, self).__init__(resolver=resolver) def _create_stream(self, max_buffer_size, af, addr, **kwargs): # pylint: disable=unused-argument ''' Override _create_stream() in TCPClient. Tornado 4.5 added the kwargs 'source_ip' and 'source_port'. Due to this, use **kwargs to swallow these and any future kwargs to maintain compatibility. ''' # Always connect in plaintext; we'll convert to ssl if necessary # after one connection has completed. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) _set_tcp_keepalive(sock, self.opts) stream = tornado.iostream.IOStream( sock, max_buffer_size=max_buffer_size) if tornado.version_info < (5,): return stream.connect(addr) return stream, stream.connect(addr) class SaltMessageClientPool(salt.transport.MessageClientPool): ''' Wrapper class of SaltMessageClient to avoid blocking waiting while writing data to socket. ''' def __init__(self, opts, args=None, kwargs=None): super(SaltMessageClientPool, self).__init__(SaltMessageClient, opts, args=args, kwargs=kwargs) def __del__(self): self.close() def close(self): for message_client in self.message_clients: message_client.close() self.message_clients = [] @tornado.gen.coroutine def connect(self): futures = [] for message_client in self.message_clients: futures.append(message_client.connect()) for future in futures: yield future raise tornado.gen.Return(None) def on_recv(self, *args, **kwargs): for message_client in self.message_clients: message_client.on_recv(*args, **kwargs) def send(self, *args, **kwargs): message_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue)) return message_clients[0].send(*args, **kwargs) def write_to_stream(self, *args, **kwargs): message_clients = sorted(self.message_clients, key=lambda x: len(x.send_queue)) return message_clients[0]._stream.write(*args, **kwargs) # TODO consolidate with IPCClient # TODO: limit in-flight messages. # TODO: singleton? Something to not re-create the tcp connection so much class SaltMessageClient(object): ''' Low-level message sending client ''' def __init__(self, opts, host, port, io_loop=None, resolver=None, connect_callback=None, disconnect_callback=None, source_ip=None, source_port=None): self.opts = opts self.host = host self.port = port self.source_ip = source_ip self.source_port = source_port self.connect_callback = connect_callback self.disconnect_callback = disconnect_callback self.io_loop = io_loop or tornado.ioloop.IOLoop.current() with salt.utils.asynchronous.current_ioloop(self.io_loop): self._tcp_client = TCPClientKeepAlive(opts, resolver=resolver) self._mid = 1 self._max_messages = int((1 << 31) - 2) # number of IDs before we wrap # TODO: max queue size self.send_queue = [] # queue of messages to be sent self.send_future_map = {} # mapping of request_id -> Future self.send_timeout_map = {} # request_id -> timeout_callback self._read_until_future = None self._on_recv = None self._closing = False self._connecting_future = self.connect() self._stream_return_future = tornado.concurrent.Future() self.io_loop.spawn_callback(self._stream_return) def _stop_io_loop(self): if self.io_loop is not None: self.io_loop.stop() # TODO: timeout inflight sessions def close(self): if self._closing: return self._closing = True if hasattr(self, '_stream') and not self._stream.closed(): # If _stream_return() hasn't completed, it means the IO # Loop is stopped (such as when using # 'salt.utils.asynchronous.SyncWrapper'). Ensure that # _stream_return() completes by restarting the IO Loop. # This will prevent potential errors on shutdown. try: orig_loop = tornado.ioloop.IOLoop.current() self.io_loop.make_current() self._stream.close() if self._read_until_future is not None: # This will prevent this message from showing up: # '[ERROR ] Future exception was never retrieved: # StreamClosedError' # This happens because the logic is always waiting to read # the next message and the associated read future is marked # 'StreamClosedError' when the stream is closed. if self._read_until_future.done(): self._read_until_future.exception() if (self.io_loop != tornado.ioloop.IOLoop.current(instance=False) or not self._stream_return_future.done()): self.io_loop.add_future( self._stream_return_future, lambda future: self._stop_io_loop() ) self.io_loop.start() except Exception as e: log.info('Exception caught in SaltMessageClient.close: %s', str(e)) finally: orig_loop.make_current() self._tcp_client.close() self.io_loop = None self._read_until_future = None # Clear callback references to allow the object that they belong to # to be deleted. self.connect_callback = None self.disconnect_callback = None def __del__(self): self.close() def connect(self): ''' Ask for this client to reconnect to the origin ''' if hasattr(self, '_connecting_future') and not self._connecting_future.done(): future = self._connecting_future else: future = tornado.concurrent.Future() self._connecting_future = future self.io_loop.add_callback(self._connect) # Add the callback only when a new future is created if self.connect_callback is not None: def handle_future(future): response = future.result() self.io_loop.add_callback(self.connect_callback, response) future.add_done_callback(handle_future) return future # TODO: tcp backoff opts @tornado.gen.coroutine def _connect(self): ''' Try to connect for the rest of time! ''' while True: if self._closing: break try: kwargs = {} if self.source_ip or self.source_port: if tornado.version_info >= (4, 5): ### source_ip and source_port are supported only in Tornado >= 4.5 # See http://www.tornadoweb.org/en/stable/releases/v4.5.0.html # Otherwise will just ignore these args kwargs = {'source_ip': self.source_ip, 'source_port': self.source_port} else: log.warning('If you need a certain source IP/port, consider upgrading Tornado >= 4.5') with salt.utils.asynchronous.current_ioloop(self.io_loop): self._stream = yield self._tcp_client.connect(self.host, self.port, ssl_options=self.opts.get('ssl'), **kwargs) self._connecting_future.set_result(True) break except Exception as exc: log.warn('TCP Message Client encountered an exception %r', exc) yield tornado.gen.sleep(1) # TODO: backoff #self._connecting_future.set_exception(e) @tornado.gen.coroutine def _stream_return(self): try: while not self._closing and ( not self._connecting_future.done() or self._connecting_future.result() is not True): yield self._connecting_future unpacker = msgpack.Unpacker() while not self._closing: try: self._read_until_future = self._stream.read_bytes(4096, partial=True) wire_bytes = yield self._read_until_future unpacker.feed(wire_bytes) for framed_msg in unpacker: if six.PY3: framed_msg = salt.transport.frame.decode_embedded_strs( framed_msg ) header = framed_msg['head'] body = framed_msg['body'] message_id = header.get('mid') if message_id in self.send_future_map: self.send_future_map.pop(message_id).set_result(body) self.remove_message_timeout(message_id) else: if self._on_recv is not None: self.io_loop.spawn_callback(self._on_recv, header, body) else: log.error('Got response for message_id %s that we are not tracking', message_id) except tornado.iostream.StreamClosedError as e: log.debug('tcp stream to %s:%s closed, unable to recv', self.host, self.port) for future in six.itervalues(self.send_future_map): future.set_exception(e) self.send_future_map = {} if self._closing: return if self.disconnect_callback: self.disconnect_callback() # if the last connect finished, then we need to make a new one if self._connecting_future.done(): self._connecting_future = self.connect() yield self._connecting_future except TypeError: # This is an invalid transport if 'detect_mode' in self.opts: log.info('There was an error trying to use TCP transport; ' 'attempting to fallback to another transport') else: raise SaltClientError except Exception as e: log.error('Exception parsing response', exc_info=True) for future in six.itervalues(self.send_future_map): future.set_exception(e) self.send_future_map = {} if self._closing: return if self.disconnect_callback: self.disconnect_callback() # if the last connect finished, then we need to make a new one if self._connecting_future.done(): self._connecting_future = self.connect() yield self._connecting_future finally: self._stream_return_future.set_result(True) @tornado.gen.coroutine def _stream_send(self): while not self._connecting_future.done() or self._connecting_future.result() is not True: yield self._connecting_future while len(self.send_queue) > 0: message_id, item = self.send_queue[0] try: yield self._stream.write(item) del self.send_queue[0] # if the connection is dead, lets fail this send, and make sure we # attempt to reconnect except tornado.iostream.StreamClosedError as e: if message_id in self.send_future_map: self.send_future_map.pop(message_id).set_exception(e) self.remove_message_timeout(message_id) del self.send_queue[0] if self._closing: return if self.disconnect_callback: self.disconnect_callback() # if the last connect finished, then we need to make a new one if self._connecting_future.done(): self._connecting_future = self.connect() yield self._connecting_future def _message_id(self): wrap = False while self._mid in self.send_future_map: if self._mid >= self._max_messages: if wrap: # this shouldn't ever happen, but just in case raise Exception('Unable to find available messageid') self._mid = 1 wrap = True else: self._mid += 1 return self._mid # TODO: return a message object which takes care of multiplexing? def on_recv(self, callback): ''' Register a callback for received messages (that we didn't initiate) ''' if callback is None: self._on_recv = callback else: def wrap_recv(header, body): callback(body) self._on_recv = wrap_recv def remove_message_timeout(self, message_id): if message_id not in self.send_timeout_map: return timeout = self.send_timeout_map.pop(message_id) self.io_loop.remove_timeout(timeout) def timeout_message(self, message_id): if message_id in self.send_timeout_map: del self.send_timeout_map[message_id] if message_id in self.send_future_map: self.send_future_map.pop(message_id).set_exception( SaltReqTimeoutError('Message timed out') ) def send(self, msg, timeout=None, callback=None, raw=False): ''' Send given message, and return a future ''' message_id = self._message_id() header = {'mid': message_id} future = tornado.concurrent.Future() if callback is not None: def handle_future(future): response = future.result() self.io_loop.add_callback(callback, response) future.add_done_callback(handle_future) # Add this future to the mapping self.send_future_map[message_id] = future if self.opts.get('detect_mode') is True: timeout = 1 if timeout is not None: send_timeout = self.io_loop.call_later(timeout, self.timeout_message, message_id) self.send_timeout_map[message_id] = send_timeout # if we don't have a send queue, we need to spawn the callback to do the sending if len(self.send_queue) == 0: self.io_loop.spawn_callback(self._stream_send) self.send_queue.append((message_id, salt.transport.frame.frame_msg(msg, header=header))) return future class Subscriber(object): ''' Client object for use with the TCP publisher server ''' def __init__(self, stream, address): self.stream = stream self.address = address self._closing = False self._read_until_future = None self.id_ = None def close(self): if self._closing: return self._closing = True if not self.stream.closed(): self.stream.close() if self._read_until_future is not None and self._read_until_future.done(): # This will prevent this message from showing up: # '[ERROR ] Future exception was never retrieved: # StreamClosedError' # This happens because the logic is always waiting to read # the next message and the associated read future is marked # 'StreamClosedError' when the stream is closed. self._read_until_future.exception() def __del__(self): self.close() class PubServer(tornado.tcpserver.TCPServer, object): ''' TCP publisher ''' def __init__(self, opts, io_loop=None): super(PubServer, self).__init__(ssl_options=opts.get('ssl')) self.io_loop = io_loop self.opts = opts self._closing = False self.clients = set() self.aes_funcs = salt.master.AESFuncs(self.opts) self.present = {} self.presence_events = False if self.opts.get('presence_events', False): tcp_only = True for transport, _ in iter_transport_opts(self.opts): if transport != 'tcp': tcp_only = False if tcp_only: # Only when the transport is TCP only, the presence events will # be handled here. Otherwise, it will be handled in the # 'Maintenance' process. self.presence_events = True if self.presence_events: self.event = salt.utils.event.get_event( 'master', opts=self.opts, listen=False ) def close(self): if self._closing: return self._closing = True def __del__(self): self.close() def _add_client_present(self, client): id_ = client.id_ if id_ in self.present: clients = self.present[id_] clients.add(client) else: self.present[id_] = {client} if self.presence_events: data = {'new': [id_], 'lost': []} self.event.fire_event( data, salt.utils.event.tagify('change', 'presence') ) data = {'present': list(self.present.keys())} self.event.fire_event( data, salt.utils.event.tagify('present', 'presence') ) def _remove_client_present(self, client): id_ = client.id_ if id_ is None or id_ not in self.present: # This is possible if _remove_client_present() is invoked # before the minion's id is validated. return clients = self.present[id_] if client not in clients: # Since _remove_client_present() is potentially called from # _stream_read() and/or publish_payload(), it is possible for # it to be called twice, in which case we will get here. # This is not an abnormal case, so no logging is required. return clients.remove(client) if len(clients) == 0: del self.present[id_] if self.presence_events: data = {'new': [], 'lost': [id_]} self.event.fire_event( data, salt.utils.event.tagify('change', 'presence') ) data = {'present': list(self.present.keys())} self.event.fire_event( data, salt.utils.event.tagify('present', 'presence') ) @tornado.gen.coroutine def _stream_read(self, client): unpacker = msgpack.Unpacker() while not self._closing: try: client._read_until_future = client.stream.read_bytes(4096, partial=True) wire_bytes = yield client._read_until_future unpacker.feed(wire_bytes) for framed_msg in unpacker: if six.PY3: framed_msg = salt.transport.frame.decode_embedded_strs( framed_msg ) body = framed_msg['body'] if body['enc'] != 'aes': # We only accept 'aes' encoded messages for 'id' continue crypticle = salt.crypt.Crypticle(self.opts, salt.master.SMaster.secrets['aes']['secret'].value) load = crypticle.loads(body['load']) if six.PY3: load = salt.transport.frame.decode_embedded_strs(load) if not self.aes_funcs.verify_minion(load['id'], load['tok']): continue client.id_ = load['id'] self._add_client_present(client) except tornado.iostream.StreamClosedError as e: log.debug('tcp stream to %s closed, unable to recv', client.address) client.close() self._remove_client_present(client) self.clients.discard(client) break except Exception as e: log.error('Exception parsing response', exc_info=True) continue def handle_stream(self, stream, address): log.trace('Subscriber at %s connected', address) client = Subscriber(stream, address) self.clients.add(client) self.io_loop.spawn_callback(self._stream_read, client) # TODO: ACK the publish through IPC @tornado.gen.coroutine def publish_payload(self, package, _): log.debug('TCP PubServer sending payload: %s', package) payload = salt.transport.frame.frame_msg(package['payload']) to_remove = [] if 'topic_lst' in package: topic_lst = package['topic_lst'] for topic in topic_lst: if topic in self.present: # This will rarely be a list of more than 1 item. It will # be more than 1 item if the minion disconnects from the # master in an unclean manner (eg cable yank), then # restarts and the master is yet to detect the disconnect # via TCP keep-alive. for client in self.present[topic]: try: # Write the packed str f = client.stream.write(payload) self.io_loop.add_future(f, lambda f: True) except tornado.iostream.StreamClosedError: to_remove.append(client) else: log.debug('Publish target %s not connected', topic) else: for client in self.clients: try: # Write the packed str f = client.stream.write(payload) self.io_loop.add_future(f, lambda f: True) except tornado.iostream.StreamClosedError: to_remove.append(client) for client in to_remove: log.debug('Subscriber at %s has disconnected from publisher', client.address) client.close() self._remove_client_present(client) self.clients.discard(client) log.trace('TCP PubServer finished publishing payload') class TCPPubServerChannel(salt.transport.server.PubServerChannel): # TODO: opts! # Based on default used in tornado.netutil.bind_sockets() backlog = 128 def __init__(self, opts): self.opts = opts self.serial = salt.payload.Serial(self.opts) # TODO: in init? self.ckminions = salt.utils.minions.CkMinions(opts) self.io_loop = None def __setstate__(self, state): salt.master.SMaster.secrets = state['secrets'] self.__init__(state['opts']) def __getstate__(self): return {'opts': self.opts, 'secrets': salt.master.SMaster.secrets} def _publish_daemon(self, **kwargs): ''' Bind to the interface specified in the configuration file ''' salt.utils.process.appendproctitle(self.__class__.__name__) log_queue = kwargs.get('log_queue') if log_queue is not None: salt.log.setup.set_multiprocessing_logging_queue(log_queue) log_queue_level = kwargs.get('log_queue_level') if log_queue_level is not None: salt.log.setup.set_multiprocessing_logging_level(log_queue_level) salt.log.setup.setup_multiprocessing_logging(log_queue) # Check if io_loop was set outside if self.io_loop is None: self.io_loop = tornado.ioloop.IOLoop.current() # Spin up the publisher pub_server = PubServer(self.opts, io_loop=self.io_loop) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) _set_tcp_keepalive(sock, self.opts) sock.setblocking(0) sock.bind((self.opts['interface'], int(self.opts['publish_port']))) sock.listen(self.backlog) # pub_server will take ownership of the socket pub_server.add_socket(sock) # Set up Salt IPC server if self.opts.get('ipc_mode', '') == 'tcp': pull_uri = int(self.opts.get('tcp_master_publish_pull', 4514)) else: pull_uri = os.path.join(self.opts['sock_dir'], 'publish_pull.ipc') pull_sock = salt.transport.ipc.IPCMessageServer( pull_uri, io_loop=self.io_loop, payload_handler=pub_server.publish_payload, ) # Securely create socket log.info('Starting the Salt Puller on %s', pull_uri) with salt.utils.files.set_umask(0o177): pull_sock.start() # run forever try: self.io_loop.start() except (KeyboardInterrupt, SystemExit): salt.log.setup.shutdown_multiprocessing_logging() def pre_fork(self, process_manager, kwargs=None): ''' Do anything necessary pre-fork. Since this is on the master side this will primarily be used to create IPC channels and create our daemon process to do the actual publishing ''' process_manager.add_process(self._publish_daemon, kwargs=kwargs) def publish(self, load): ''' Publish "load" to minions ''' payload = {'enc': 'aes'} crypticle = salt.crypt.Crypticle(self.opts, salt.master.SMaster.secrets['aes']['secret'].value) payload['load'] = crypticle.dumps(load) if self.opts['sign_pub_messages']: master_pem_path = os.path.join(self.opts['pki_dir'], 'master.pem') log.debug("Signing data packet") payload['sig'] = salt.crypt.sign_message(master_pem_path, payload['load']) # Use the Salt IPC server if self.opts.get('ipc_mode', '') == 'tcp': pull_uri = int(self.opts.get('tcp_master_publish_pull', 4514)) else: pull_uri = os.path.join(self.opts['sock_dir'], 'publish_pull.ipc') # TODO: switch to the actual asynchronous interface #pub_sock = salt.transport.ipc.IPCMessageClient(self.opts, io_loop=self.io_loop) pub_sock = salt.utils.asynchronous.SyncWrapper( salt.transport.ipc.IPCMessageClient, (pull_uri,) ) pub_sock.connect() int_payload = {'payload': self.serial.dumps(payload)} # add some targeting stuff for lists only (for now) if load['tgt_type'] == 'list': if isinstance(load['tgt'], six.string_types): # Fetch a list of minions that match _res = self.ckminions.check_minions(load['tgt'], tgt_type=load['tgt_type']) match_ids = _res['minions'] log.debug("Publish Side Match: %s", match_ids) # Send list of miions thru so zmq can target them int_payload['topic_lst'] = match_ids else: int_payload['topic_lst'] = load['tgt'] # Send it over IPC! pub_sock.send(int_payload)