611 lines
24 KiB
Python
611 lines
24 KiB
Python
# coding: utf-8
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import queue
|
|
import socket
|
|
import struct
|
|
import threading
|
|
import time
|
|
from collections import deque
|
|
from datetime import datetime, timedelta
|
|
|
|
from conf import settings
|
|
from lib.translation import ugettext as _
|
|
from parcel import constants, parcelqueue
|
|
from parcel.base import KnownConnect, LastConnect, Parcel
|
|
from worker.utils import disconnect_handle, replication_after_reconnect
|
|
from worker.audit import Audit
|
|
from xmlrpc_wrapper import ServerProxy
|
|
|
|
try:
    from sendfile import sendfile
except ImportError:
    sendfile = None

# True when the optional pysendfile package could be imported; TCPWorker then
# streams attachments with its sendfile() wrapper instead of read()/sendall().
sendfile_available = sendfile is not None

log = logging.getLogger('enserver.worker.tcp')

# Module-wide registry: every TCPClientWorker / TCPServerWorker appends itself
# here when constructed.
workers = []
|
|
|
|
|
|
class TCPClientWorker(threading.Thread):
    """
    Maintains uplink connection.

    Keeps a reader/writer pair of TCPWorker threads connected to
    ``(settings.host, settings.port)``. When either worker dies, or the auth
    provider stops accepting ``settings.id``, the pair is torn down and a new
    one is started on a later pass of the loop in :meth:`run`.
    """

    def __init__(self, reconnect_timeout=30.0, auth_timeout=60.0):
        """
        :param reconnect_timeout: seconds to wait between loop passes.
        :param auth_timeout: minimum seconds between two auth-provider checks.
        """
        super(TCPClientWorker, self).__init__()
        self.connect_address = (settings.host, settings.port)
        self.reconnect_timeout = reconnect_timeout
        # [time of last auth check, seconds between checks, last check result]
        self.userid_check = [0, auth_timeout, False]
        self.stop_event = threading.Event()
        self.workers = []
        # Last audit message emitted; used to avoid repeating the same warning.
        self.last_message = None
        workers.append(self)

    def run(self):
        """Supervise the worker pair until :meth:`join` sets the stop event."""
        while not self.stop_event.is_set():
            if self.workers:
                if not self.workers_alive() or not self.userid_correct():
                    # Tear down now; the next pass starts a fresh pair.
                    self.stop_workers()
                else:
                    self.stop_event.wait(self.reconnect_timeout)
            else:
                self.start_workers()
                self.stop_event.wait(self.reconnect_timeout)
        # A pair may have been started while join() was tearing down.
        self.stop_workers()

    def join(self, timeout=None):
        """Stop the workers and this thread; *timeout* is passed to Thread.join."""
        # Signal first so run() cannot start a new pair during shutdown.
        self.stop_event.set()
        self.stop_workers()
        super(TCPClientWorker, self).join(timeout)

    def start_workers(self):
        """Open the reader and writer connections; on failure leave none running."""
        self.workers = []
        if not self.userid_correct():
            return
        try:
            for state in (TCPWorker.STATE_READER, TCPWorker.STATE_WRITER):
                self.workers.append(self.start_worker(state))
                log.info('Worker type %s connected to %s', state, self.connect_address[0])
        except socket.error:
            log.warning('Unable to connect to %s', self.connect_address[0])
            msg = _('Unable to connect to %s') % self.connect_address[0]
            if self.last_message != msg:
                self.last_message = msg
                Audit.add_message(4001, self.last_message, warning=True)
            self.stop_workers()
            disconnect_handle(is_client=True)
        except Exception as exc:
            log.exception('Unexpected error: %s, cls: %s', exc, self.__class__)
            # Do not keep a half-built pair (e.g. reader up, writer failed).
            self.stop_workers()
        else:
            self.last_message = ''

    def stop_workers(self):
        """Join every current worker and forget them."""
        for worker in list(self.workers):
            worker.join()
        self.workers = []

    def workers_alive(self):
        """Return False (and audit a disconnect) if any worker thread has died."""
        for worker in self.workers:
            if not worker.is_alive():
                self.last_message = _('Disconnected from %s') % self.connect_address[0]
                Audit.add_message(4001, self.last_message, warning=True)
                return False
        return True

    def start_worker(self, state):
        """Connect a new socket and start a TCPWorker with the given *state* on it."""
        client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Set the timeout before connecting so an unreachable host cannot
            # block this thread (and join()) for the OS default connect timeout.
            client_socket.settimeout(settings.tcp_timeout)
            client_socket.connect(self.connect_address)
            worker = TCPWorker(client_socket, state=state)
        except Exception:
            # No worker owns the socket yet; release it before propagating.
            client_socket.close()
            raise
        worker.start()
        return worker

    def userid_correct(self):
        """
        Return whether the auth provider accepts ``settings.id``.

        The provider is queried at most once per ``auth_timeout`` seconds; in
        between, the cached result is returned.
        """
        if sum(self.userid_check[:2]) < time.time():
            self.userid_check[0] = time.time()
            try:
                s = ServerProxy(settings.auth_provider)
                response = s.auth_response('', settings.id, False)
            except Exception as exc:
                response = {'error': True, 'msg': _('Authentication server is unavailable')}
                log.warning('Authentication server is unavailable, exception: %s', exc)
            self.userid_check[2] = not response.get('error', True)
            if not self.userid_check[2]:
                log.warning('Authentication is unavailable for user %s', settings.id)
                # 'msg' may be absent or None in an error reply.
                msg = (_('Authentication is unavailable for user %s') % settings.id
                       + '<br>' + (response.get('msg') or ''))
                if self.last_message != msg:
                    self.last_message = msg
                    Audit.add_message(4001, self.last_message, warning=True)
        return self.userid_check[2]

    def connections(self):
        """Return the current worker threads."""
        return self.workers
|
|
|
|
|
|
class TCPServerWorker(threading.Thread):
    """
    Listens on the server TCP socket and accepts connections from clients.

    Each accepted connection is served by its own TCPWorker thread. A client
    normally opens two connections: one served as a reader and one as a writer.
    """

    def __init__(self, disconnect_timeout=30):
        """:param disconnect_timeout: minimum seconds between disconnect sweeps."""
        super(TCPServerWorker, self).__init__()
        self.listen_address = (settings.tcp_host, settings.tcp_port)
        self.static_dir = settings.storage
        self.stop_event = threading.Event()
        workers.append(self)
        # TCPWorker threads for accepted connections; dead ones stay here
        # until process_disconnected() prunes them.
        self.workers = []
        self.disconnect_timeout = disconnect_timeout
        self.last_timestamp = datetime.now()
        # Listening socket, created in run().
        self.socket = None

    def run(self):
        """Serve until join() is called; the listening socket is always closed."""
        try:
            log.info('Serving on %s:%s', *self.listen_address)
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                self.socket.bind(self.listen_address)
                self.socket.listen(10)
                # Short accept() timeout so stop_event is re-checked often.
                self.socket.settimeout(0.2)
                self._accept_loop()
            finally:
                self.socket.close()
        except Exception as exc:
            log.exception('Unexpected error: %s, cls: %s', exc, self.__class__)
        log.info('Bye!')

    def _accept_loop(self):
        """Accept clients; when idle, run disconnect handling at most once per interval."""
        sweep_interval = timedelta(seconds=self.disconnect_timeout)
        while not self.stop_event.is_set():
            try:
                client = self.socket.accept()
                self.process_new_client(*client)
            except socket.error:
                # Typically the 0.2 s accept() timeout expiring.
                if self.last_timestamp + sweep_interval > datetime.now():
                    continue
                self.last_timestamp = datetime.now()
                disconnect_handle(is_client=False, workers=self.workers)
                replication_after_reconnect(self.workers)

    def join(self, timeout=None):
        """Ask run() to stop and wait up to *timeout* seconds for it."""
        self.stop_event.set()
        super(TCPServerWorker, self).join(timeout)

    def process_new_client(self, client_socket, client_address):
        """Start a TCPWorker for a freshly accepted connection."""
        log.info('New connection from %s', client_address)
        try:
            client_socket.settimeout(settings.tcp_timeout)
            worker = TCPWorker(client_socket, is_incoming=True)
        except Exception:
            # No worker took ownership of the socket; do not leak it.
            client_socket.close()
            raise
        worker.start()
        self.workers.append(worker)

    def connections(self):
        """Return the workers of accepted connections."""
        return self.workers

    def process_disconnected(self):
        """Drop workers whose threads have finished (keeps a real list on Py3)."""
        self.workers = [worker for worker in self.workers if worker.is_alive()]
|
|
|
|
|
|
class AuthError(Exception):
    """
    Raised when authentication of a peer fails.

    The keyword argument ``send_error`` (default ``False``) is stored on the
    instance; the connection handler uses it to decide whether to send the
    peer an ``{'auth_error': True}`` message. Other keyword arguments are
    ignored, and positional arguments become the exception's ``args``.
    """

    def __init__(self, *args, **kwargs):
        notify_peer = kwargs.get('send_error', False)
        Exception.__init__(self, *args)
        self.send_error = notify_peer
|
|
|
|
|
|
class TCPWorker(threading.Thread):
    """
    Serves one direction of an authenticated TCP link to a peer node.

    After handshake() the worker is either a reader (receives parcels and
    their attached files, see reader_workflow) or a writer (sends parcels
    taken from the ``tcp://<peer id>`` queue, see writer_workflow). The
    connecting side picks the direction through ``state``; the accepting side
    (``state=None``) takes the opposite one during the handshake.

    Control messages are JSON documents prefixed with their byte length,
    packed with struct format ``'q'`` (native size and byte order). File
    contents are sent as raw, unframed bytes after a file request.
    """

    STATE_READER = 1
    STATE_WRITER = -1

    IDLE = 'idle'
    META = 'meta'

    def __init__(self, socket, keepalive=1.0, state=None, auth_timeout=60.0, is_incoming=False):
        """
        :param socket: connected TCP socket; closed by the worker when run() ends.
        :param keepalive: seconds of writer inactivity before a ping is sent.
        :param state: STATE_READER / STATE_WRITER for outgoing links, or None
            for accepted connections (direction negotiated in handshake()).
        :param auth_timeout: seconds between re-authentications of the peer
            on accepted reader links.
        :param is_incoming: reported as the 'direction' field of info().
        """
        super(TCPWorker, self).__init__()
        assert state in (None, self.STATE_READER, self.STATE_WRITER)
        self.socket = socket
        self.address = socket.getpeername()[0]
        self.in_q = None
        self.out_q = parcelqueue.manager.get(constants.POSTMAN)
        # [time of last writer activity, ping interval in seconds]
        self.keepalive = [0, keepalive]
        self.state = state
        self.is_client = state is not None
        self.parcel = None
        self.server_id = settings.id
        self.connected_id = None
        self.static_dir = settings.storage
        self.bufsize = settings.tcp_bufsize
        # Complete frames read from the socket but not yet consumed.
        self.bufdata = deque()
        self.stop_event = threading.Event()
        self.workflow = None
        # [time of last peer authentication, re-authentication interval]
        self.user_check = [0, auth_timeout]
        self.established = None
        self.status = self.IDLE
        self.data_pos = 0
        self.data_size = 0
        self.total = 0
        self.bitrate = 0
        self.is_incoming = is_incoming

    def info(self):
        """Return a snapshot of this connection's state for status reporting."""
        if self.established is not None:
            established = time.mktime(self.established.timetuple())
        else:
            # run() has not started yet; report 0 instead of raising.
            established = 0.0
        return {'parcel_id': self.parcel.id if self.parcel else '',
                'type': 'up' if self.state == self.STATE_WRITER else 'down',
                'address': self.address,
                'queue': self.in_q.queue_id if self.in_q else '',
                'status': self.status,
                'established': established,
                'user': self.connected_id,
                'bitrate': self.bitrate,
                'total': float(self.total),
                'data_pos': float(self.data_pos),
                'data_size': float(self.data_size),
                'direction': 'incoming' if self.is_incoming else 'outgoing'}

    def run(self):
        """Authenticate the peer, then run the workflow loop until stopped."""
        # Use pysendfile's sendfile() when it is installed.
        if sendfile_available:
            self.send_file = self.send_file_syscall
        else:
            self.send_file = self.send_file_native

        # Same value as socket.getpeername()[0], captured in __init__; reusing
        # it avoids a call that fails if the peer has already disconnected.
        self.name = self.address
        self.established = datetime.now()

        try:
            self._authenticate()
            while not self.stop_event.is_set():
                self._step()
        finally:
            self.socket.close()
            log.info('%s - Bye!', self.name)

    def _authenticate(self):
        """Run handshake(); on failure, log/audit it and request a stop."""
        try:
            self.handshake()
        except (socket.error, KeyError) as ex:
            log.error('Handshake error: %s', ex)
            event_type = 4001 if self.is_client else 4002
            Audit.add_message(event_type, _('Protocol error when connecting to %s') % self.name, warning=True)
            self.stop_event.set()
        except AuthError as ex:
            self._handle_auth_error(ex)
        except Exception as exc:
            # e.g. a malformed JSON message from the peer.
            log.exception('Unexpected handshake error: %s', exc)
            self.stop_event.set()

    def _step(self):
        """Run one workflow iteration, converting failures into a stop where needed."""
        try:
            self.workflow()
        except socket.error:
            self.stop_event.set()
            # Put an unfinished outgoing parcel back so it is retried.
            if self.state == self.STATE_WRITER and self.parcel is not None:
                self.in_q.put(self.parcel)
        except AuthError as ex:
            # Raised by the periodic re-authentication in reader_workflow().
            self._handle_auth_error(ex)
        except Exception as exc:
            log.exception('Unexpected error: %s', exc)

    def _handle_auth_error(self, ex):
        """Log an AuthError, optionally notify the peer, and request a stop."""
        log.error('%s - Authentication failed, exception: %s', self.name, ex)
        if ex.send_error:
            try:
                self.send_data({'auth_error': True})
            except socket.error as exc:
                # Best effort: the peer may already have disconnected.
                log.debug('%s - Could not report auth error: %s', self.name, exc)
        self.stop_event.set()

    def handshake(self):
        """
        Mutually authenticate with the peer via challenge/response.

        Both sides obtain challenges and responses from the auth provider.
        On success the direction (reader/writer) is fixed and ``workflow`` is
        selected; writers also bind ``in_q`` to the peer's queue.

        :raises AuthError: authentication failed.
        :raises socket.error: transport failure, or the peer rejected us.
        :raises KeyError: a required field is missing from a peer message.
        """
        if self.state is not None:
            #
            # client case
            #
            c1 = self.auth_challenge()
            self.send_data({'flow_state': self.state,
                            'server_id': self.server_id,
                            'challenge': c1})

            data = self.receive_data()
            if data.get('auth_error', False):
                Audit.add_message(4001, _('Authentication error when connecting to %s') % self.name, warning=True)
                raise AuthError(send_error=False)

            self.connected_id = data['server_id']
            r1, err_msg = self.auth_response(c1, self.server_id, False)
            if err_msg is not None or data['response'] != r1:
                msg = _('Authentication error when connecting to %s') % self.name
                if err_msg is not None:
                    msg = msg + '<br>' + err_msg
                Audit.add_message(4001, msg, warning=True)
                raise AuthError(send_error=True)
            else:
                log.debug('%s - Server authenticated', self.name)

            r2, err_msg = self.auth_response(data['challenge'], self.server_id, False)
            if err_msg is not None:
                msg = _('Authentication error when connecting to %s') % self.name
                Audit.add_message(4001, msg + '<br>' + err_msg, warning=True)
                raise AuthError(send_error=True)

            self.send_data({'response': r2})

            data = self.receive_data()
            if data.get('auth_error', False):
                Audit.add_message(4001, _('Authentication error when connecting to %s') % self.name, warning=True)
                raise AuthError(send_error=False)
            else:
                Audit.add_message(4001, _('Successfully connected to %s') % '%s, %s' % (self.connected_id, self.name))
                log.info('%s - Client authenticated', self.name)
        else:
            #
            # server case
            #
            data = self.receive_data()
            self.connected_id = data['server_id']
            r1, err_msg = self.auth_response(data['challenge'], self.connected_id, True)
            if err_msg is not None:
                msg = _('Error when authenticating client %s') % '%s %s' % (self.name, self.connected_id)
                Audit.add_message(4002, msg + '<br>' + err_msg, warning=True)
                raise AuthError(send_error=True)

            c1 = self.auth_challenge()
            self.send_data({'server_id': self.server_id,
                            'challenge': c1, 'response': r1})
            # Serve the opposite direction of the one the client asked for.
            self.state = data['flow_state'] * -1

            data = self.receive_data()
            if data.get('auth_error', False):
                msg = _('Error when authenticating client %s') % '%s, %s' % (self.connected_id, self.name)
                Audit.add_message(4002, msg, warning=True)
                raise socket.error('Client %s failed to authenticate server' % self.connected_id)

            r2, err_msg = self.auth_response(c1, self.connected_id, True)
            if err_msg is not None or data['response'] != r2:
                msg = _('Error when authenticating client %s') % '%s, %s' % (self.connected_id, self.name)
                if err_msg is not None:
                    msg = msg + '<br>' + err_msg
                Audit.add_message(4002, msg, warning=True)
                raise AuthError(send_error=True)
            else:
                self.send_data({'auth_error': False})
                Audit.add_message(4002, _('Successfully connected to %s') % '%s, %s' % (self.connected_id, self.name))
                log.info('%s - Client authenticated', self.name)
                self.user_check[0] = time.time()

        if self.state == self.STATE_READER:
            self.name = '%s reader' % self.connected_id
            self.workflow = self.reader_workflow
        elif self.state == self.STATE_WRITER:
            self.name = '%s writer' % self.connected_id
            self.workflow = self.writer_workflow
            self.in_q = parcelqueue.manager.get('tcp://%s' % self.connected_id)
            known_connect = KnownConnect(self.in_q.queue_id, self.is_client)
            parcelqueue.manager.put_known_connect(known_connect)
        else:
            log.error('Incorrect state value %s', self.state)
            self.stop_event.set()

    def reader_workflow(self):
        """
        Handle one incoming message: a ping, or a parcel plus its files.

        :raises AuthError: periodic re-authentication of the peer failed.
        """
        # receive parcel meta
        data = self.receive_data()

        # Accepted links re-validate the peer once per auth interval.
        if not self.is_client and sum(self.user_check) < time.time():
            err_msg = self.auth_response('', self.connected_id, True)[1]
            if err_msg is not None:
                msg = _('Error when authenticating client %s') % self.connected_id
                Audit.add_message(4002, msg + '<br>' + err_msg, warning=True)
                raise AuthError(send_error=True)
            self.user_check[0] = time.time()

        # if it is a ping-request, sending a reply
        if 'ping' in data:
            self.send_pong()
            return

        self.parcel = Parcel(data['params'], data['files'],
                             data['callback'], state=data['state'])

        if 'delivered' in self.parcel.params:
            self.out_q.put(self.parcel)
            self.send_transfer_notice()
            return

        if not self.check_incoming_parcel(self.parcel):
            log.warning('Incorrect parcel address')
            self.send_data({'error': True, 'msg': 'Incorrect parcel address'})
            return

        parcelqueue.manager.put_parcel(self.parcel, constants.TRANSFER)
        log.info('%s - Got a new parcel: %s', self.name, self.parcel)

        # receive parcel files (attachments)
        for f in self.parcel.files:
            log.info('%s - Receiving file: parcel_id = %s, filename = %s, size = %s',
                     self.name, self.parcel.id, f['name'], f['size'])
            self.status = f['name']
            self.receive_file(f, self.parcel.id)

        log.info('%s - Finished processing parcel: %s', self.name, self.parcel)

        self.out_q.put(self.parcel)
        self.send_transfer_notice()
        self.status = self.IDLE

    def writer_workflow(self):
        """Send the next queued parcel, or ping the peer when the queue is idle."""
        try:
            self.parcel = self.in_q.get(block=True, timeout=0.1)
        except queue.Empty:
            curr_time = time.time()
            if curr_time >= sum(self.keepalive):
                self.send_data({'ping': True})
                if 'pong' not in self.receive_data():
                    raise socket.error('Ping failed')
                self.keepalive[0] = curr_time
            return

        if 'delivered' not in self.parcel.params:
            self.parcel.params['from'] = 'tcp://%s' % self.server_id
        self.keepalive[0] = time.time()
        for f in self.parcel.files:
            file_path = os.path.join(self.static_dir, self.parcel.id, f['name'])
            f['size'] = os.path.getsize(file_path)
        log.info('%s - Sending parcel %s', self.name, self.parcel)
        self.send_data({'params': self.parcel.params,
                        'files': self.parcel.files,
                        'callback': self.parcel.callback_url,
                        'state': self.parcel.state})

        # Serve file requests until the reader confirms or rejects the parcel.
        while True:
            data = self.receive_data()
            if 'error' in data:
                log.warning('Sending parcel failed: %s', data.get('msg'))
                parcelqueue.manager.put_parcel(self.parcel, parcelqueue.ARCHIVE)
                break
            if 'transferred' in data:
                parcelqueue.manager.put_parcel(self.parcel, parcelqueue.ARCHIVE)
                break
            if 'name' in data:
                for f in self.parcel.files:
                    if f['name'] == data['name']:
                        self.status = f['name']
                        filepath = os.path.join(self.static_dir,
                                                self.parcel.id, f['name'])
                        offset = data.get('offset', 0)
                        log.info('%s - Sending file %s, offset = %s, '
                                 'size = %s', self.name, f['name'],
                                 offset, os.path.getsize(filepath))
                        self.send_file(filepath, offset)
        self.status = self.IDLE
        self.parcel = None

    def join(self, timeout=None):
        """Ask the worker to stop and wait up to *timeout* seconds for it."""
        self.stop_event.set()
        super(TCPWorker, self).join(timeout)

    def auth_challenge(self):
        """
        Obtain a fresh challenge from the auth provider.

        :raises socket.error: the provider could not be reached.
        """
        try:
            s = ServerProxy(settings.auth_provider)
            return s.auth_challenge()
        except Exception:
            raise socket.error('Authentication provider not available, disconnecting')

    def auth_response(self, challenge, server_id, is_server):
        """
        Ask the auth provider for the response to *challenge*.

        :returns: ``(response, None)`` on success, ``(None, message)`` on
            failure. The message is never None for a failure, because callers
            treat ``err_msg is None`` as success.
        """
        try:
            s = ServerProxy(settings.auth_provider)
            result = s.auth_response(challenge, server_id, is_server)
        except Exception as exc:
            log.warning('Authentication server is unavailable, exception: %s', exc)
            result = {'error': True, 'msg': _('Authentication server is unavailable')}
        if result.get('error'):
            return None, result.get('msg') or ''
        return result.get('response'), None

    def check_incoming_parcel(self, parcel):
        """
        Return True if *parcel* is addressed from the peer to this server and
        all of its file names stay inside the parcel directory.
        """
        params = parcel.params
        if params.get('from') != 'tcp://%s' % self.connected_id:
            return False
        if params.get('to') != 'tcp://%s' % self.server_id:
            return False
        # File names come from the peer and are joined into local paths.
        return all(self._is_safe_file_name(f.get('name')) for f in parcel.files)

    @staticmethod
    def _is_safe_file_name(name):
        """Return True if *name* is a non-empty relative path without '..' escapes."""
        if not name or os.path.isabs(name):
            return False
        normalized = os.path.normpath(name)
        return (normalized not in (os.curdir, os.pardir)
                and not normalized.startswith(os.pardir + os.sep))

    def read_socket(self):
        """Read up to ``bufsize`` bytes; raise socket.error if the peer closed."""
        buf = self.socket.recv(self.bufsize)
        if not buf:
            raise socket.error('Disconnected')
        return buf

    def receive_data(self):
        """
        Return the next framed JSON message as a dict.

        If several frames arrive together, the extra ones are kept in
        ``bufdata`` and returned by later calls.
        """
        if self.bufdata:
            frame = self.bufdata.popleft()
            log.debug('%s - Received data: %s', self.name, frame)
            return json.loads(frame.decode('utf-8'))
        data = b''
        self.data_pos = 0
        self.data_size = 0
        while True:
            # 8-byte length header (struct 'q').
            while len(data) < 8:
                block = self.read_socket()
                data += block
                self.total += len(block)
            self.data_size = data_size = struct.unpack('q', data[:8])[0]
            data = data[8:]
            while len(data) < data_size:
                block = self.read_socket()
                data += block
                self.total += len(block)
                self.data_pos += len(block)
            self.bufdata.append(data[:data_size])
            data = data[data_size:]
            if not data:
                break
        frame = self.bufdata.popleft()
        if b'ping' not in frame and b'pong' not in frame:
            log.debug('%s - Received data: %s', self.name, frame)
        return json.loads(frame.decode('utf-8'))

    def receive_file(self, f, parcel_id):
        """
        Request attachment *f* of parcel *parcel_id* and append it to local storage.

        A partial local file is resumed from its current size. An existing
        file that is already complete is left untouched and keeps f['size'];
        otherwise f['size'] is removed once the file is received.

        :raises socket.error: the peer disconnected mid-transfer.
        """
        file_dir = os.path.join(self.static_dir, parcel_id)
        file_path = os.path.join(file_dir, f['name'])
        file_size = f['size']
        if os.path.exists(file_path):
            real_size = os.path.getsize(file_path)
            if real_size >= file_size:
                return
            f['offset'] = real_size
        self.send_file_request(f)

        offset = f.get('offset', 0)
        if offset > 0:
            log.warning('Receiving file with offset: parcel_id = %s, '
                        'filename = %s, size = %s, offset = %s',
                        parcel_id, f['name'], f['size'], offset)
        self.data_size = file_size

        if not os.path.exists(file_dir):
            # Octal mode rwxr-x--- (subject to umask).
            os.mkdir(file_dir, 0o750)
        bytes_received = offset
        self.data_pos = bytes_received
        with open(file_path, mode='ab') as file_obj:
            while bytes_received < file_size:
                # Never read past this file: later bytes belong to the next message.
                chunk = self.socket.recv(min(self.bufsize, file_size - bytes_received))
                if not chunk:
                    raise socket.error('Disconnected')
                file_obj.write(chunk)
                bytes_received += len(chunk)
                self.total += len(chunk)
                self.data_pos = bytes_received
        log.info('Received file: parcel_id = %s, filename = %s, '
                 'size = %s', parcel_id, f['name'], f['size'])
        del f['size']

    def send_file_native(self, file_path, offset=0):
        """Stream *file_path* from *offset* with read() + sendall()."""
        with open(file_path, 'rb') as file_obj:
            file_obj.seek(offset)
            self.data_pos = offset
            self.data_size = os.path.getsize(file_path)
            while True:
                chunk = file_obj.read(self.bufsize)
                if not chunk:
                    break
                self.socket.sendall(chunk)
                self.total += len(chunk)
                self.data_pos += len(chunk)

    def send_file_syscall(self, file_path, offset=0):
        """Stream *file_path* from *offset* with pysendfile's sendfile()."""
        import errno
        import select

        with open(file_path, 'rb') as file_obj:
            read_offset = offset
            self.data_pos = offset
            self.data_size = os.path.getsize(file_path)
            sock_fd = self.socket.fileno()
            while True:
                try:
                    sent = sendfile(sock_fd, file_obj.fileno(),
                                    read_offset, self.bufsize)
                except (OSError, IOError) as exc:
                    if exc.errno not in (errno.EAGAIN, errno.EWOULDBLOCK):
                        raise
                    # A socket with a timeout is non-blocking at the OS level,
                    # so a full send buffer yields EAGAIN; wait until writable.
                    if not select.select([], [sock_fd], [], self.socket.gettimeout())[1]:
                        raise socket.timeout('timed out')
                    continue
                if sent == 0:
                    break
                read_offset += sent
                self.total += sent
                self.data_pos += sent

    def send_data(self, data):
        """Send *data* as one length-prefixed JSON frame."""
        self.data_size = 0
        payload = json.dumps(data)
        if 'ping' not in payload and 'pong' not in payload:
            log.debug('%s - Sending data: %s', self.name, payload)
        # Frame bytes, not text, so the concatenation also works on Python 3.
        payload = payload.encode('utf-8')
        self.socket.sendall(struct.pack('q', len(payload)) + payload)

    def send_pong(self):
        """Answer a keep-alive ping."""
        self.send_data({'pong': True})

    def send_file_request(self, f):
        """Ask the writer to stream file *f* from f['offset'] (default 0)."""
        log.debug('Send file request: filename = %s, size = %s, offset = %s',
                  f['name'], f['size'], f.get('offset', 0))
        self.send_data({'name': f['name'],
                        'offset': f.get('offset', 0)})

    def send_transfer_notice(self):
        """Tell the writer that the current parcel was received."""
        self.send_data({'transferred': True,
                        'parcel_id': self.parcel.id})

    def send_delivery_notice(self):
        """Send a delivery notice with the current parcel's id and state."""
        self.send_data({'notice': True,
                        'parcel_id': self.parcel.id,
                        'parcel_state': self.parcel.state})