Files
enserver/enserver/worker/tcp.py
Ivan Vazhenin 54fa589261 First commit
2023-03-12 16:40:33 +03:00

611 lines
24 KiB
Python

# coding: utf-8
import json
import logging
import os
import queue
import socket
import struct
import threading
import time
from collections import deque
from datetime import datetime, timedelta
from conf import settings
from lib.translation import ugettext as _
from parcel import constants, parcelqueue
from parcel.base import KnownConnect, LastConnect, Parcel
from worker.utils import disconnect_handle, replication_after_reconnect
from worker.audit import Audit
from xmlrpc_wrapper import ServerProxy
# pysendfile is an optional dependency: when it cannot be imported the
# module falls back to a plain read()/sendall() loop (see
# TCPWorker.send_file_native).
try:
    from sendfile import sendfile
    sendfile_available = True
except ImportError:
    sendfile = None
    sendfile_available = False

log = logging.getLogger('enserver.worker.tcp')

# Module-level registry; TCPClientWorker and TCPServerWorker append
# themselves to it from their constructors.
workers = []
class TCPClientWorker(threading.Thread):
    """
    Maintains uplink connection.

    Supervises a reader/writer pair of TCPWorker threads connected to
    ``(settings.host, settings.port)``.  The pair is torn down when one of
    the workers dies or when the local id stops passing the auth provider
    check, and is started again on the next loop iteration.
    """

    def __init__(self, reconnect_timeout=30.0, auth_timeout=60.0):
        """
        :param reconnect_timeout: seconds to wait between supervision passes
        :param auth_timeout: seconds a cached auth provider answer stays valid
        """
        super(TCPClientWorker, self).__init__()
        self.connect_address = (settings.host, settings.port)
        self.reconnect_timeout = reconnect_timeout
        # [time of the last check, validity period, result of the last check]
        self.userid_check = [0, auth_timeout, False]
        self.stop_event = threading.Event()
        self.workers = []
        # Last audit message, kept to avoid repeating the same message.
        self.last_message = None
        workers.append(self)

    def run(self):
        """Supervision loop; runs until ``stop_event`` is set."""
        while not self.stop_event.is_set():
            if self.workers:
                if not self.workers_alive() or not self.userid_correct():
                    self.stop_workers()
                else:
                    self.stop_event.wait(self.reconnect_timeout)
            else:
                self.start_workers()
                self.stop_event.wait(self.reconnect_timeout)

    def join(self, timeout=None):
        """
        Stop the supervisor and its workers.

        The stop flag is raised first so that the run loop cannot start a
        fresh pair of workers while the current ones are being stopped, and
        ``timeout`` is forwarded to ``Thread.join``.
        """
        self.stop_event.set()
        super(TCPClientWorker, self).join(timeout)
        self.stop_workers()

    def start_workers(self):
        """Start one reader and one writer worker, if the user id is valid."""
        self.workers = []
        if not self.userid_correct():
            return
        try:
            for state in (TCPWorker.STATE_READER, TCPWorker.STATE_WRITER):
                self.workers.append(self.start_worker(state))
                log.info('Worker type %s connected to %s', state, self.connect_address[0])
        except socket.error:
            log.warning('Unable to connect to %s', self.connect_address[0])
            msg = _('Unable to connect to %s') % self.connect_address[0]
            if self.last_message != msg:
                self.last_message = msg
                Audit.add_message(4001, self.last_message, warning=True)
            self.stop_workers()
            disconnect_handle(is_client=True)
        except Exception as exc:
            log.exception('Unexpected error: %s, cls: %s', exc, self.__class__)
            # Do not keep a half-started pair around.
            self.stop_workers()
        else:
            self.last_message = ''

    def stop_workers(self):
        """Join every worker and forget them."""
        for worker in self.workers:
            worker.join()
        self.workers = []

    def workers_alive(self):
        """Return False (and audit the disconnect) if any worker has died."""
        for worker in self.workers:
            if not worker.is_alive():
                self.last_message = _('Disconnected from %s') % self.connect_address[0]
                Audit.add_message(4001, self.last_message, warning=True)
                return False
        return True

    def start_worker(self, state):
        """
        Connect a new socket and start a TCPWorker on it.

        :param state: TCPWorker.STATE_READER or TCPWorker.STATE_WRITER
        :raises socket.error: when the connection cannot be established
        """
        client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            client_socket.connect(self.connect_address)
            client_socket.settimeout(settings.tcp_timeout)
            worker = TCPWorker(client_socket, state=state)
            worker.start()
        except Exception:
            # The socket has no owner yet, so close it here before re-raising.
            client_socket.close()
            raise
        return worker

    def userid_correct(self):
        """
        Ask the auth provider whether ``settings.id`` may authenticate.

        The provider is queried at most once per validity period; in between
        the cached result is returned.
        """
        if sum(self.userid_check[:2]) < time.time():
            self.userid_check[0] = time.time()
            try:
                s = ServerProxy(settings.auth_provider)
                response = s.auth_response('', settings.id, False)
            except Exception as exc:
                response = {'error': True, 'msg': _('Authentication server is unavailable')}
                log.warning('Authentication server is unavailable, exception: %s', exc)
            self.userid_check[2] = not response.get('error', True)
            if not self.userid_check[2]:
                log.warning('Authentication is unavailable for user %s', settings.id)
                # The provider may omit 'msg'; never concatenate None.
                detail = response.get('msg') or ''
                msg = _('Authentication is unavailable for user %s') % settings.id + '<br>' + detail
                if self.last_message != msg:
                    self.last_message = msg
                    Audit.add_message(4001, self.last_message, warning=True)
        return self.userid_check[2]

    def connections(self):
        """Return the list of currently supervised workers."""
        return self.workers
class TCPServerWorker(threading.Thread):
    """
    Listens server tcp socket, accepts connections from clients.

    Every accepted connection is handed to its own TCPWorker thread.
    """

    def __init__(self, disconnect_timeout=30):
        """
        :param disconnect_timeout: minimum number of seconds between two
            runs of the disconnect/replication housekeeping
        """
        super(TCPServerWorker, self).__init__()
        self.listen_address = (settings.tcp_host, settings.tcp_port)
        self.static_dir = settings.storage
        self.stop_event = threading.Event()
        workers.append(self)
        self.workers = []
        self.disconnect_timeout = disconnect_timeout
        self.last_timestamp = datetime.now()
        # Listening socket; created in run().
        self.socket = None

    def run(self):
        """Accept loop; runs until ``stop_event`` is set."""
        try:
            log.info('Serving on %s:%s', *self.listen_address)
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                self.socket.bind(self.listen_address)
                self.socket.listen(10)
                # Short accept timeout so the stop flag is polled regularly.
                self.socket.settimeout(0.2)
                while not self.stop_event.is_set():
                    try:
                        client = self.socket.accept()
                        self.process_new_client(*client)
                    except socket.error:
                        # Reached on accept timeouts as well; housekeeping is
                        # run at most once per disconnect_timeout.
                        if self.last_timestamp + timedelta(seconds=self.disconnect_timeout) > datetime.now():
                            continue
                        else:
                            self.last_timestamp = datetime.now()
                            disconnect_handle(is_client=False, workers=self.workers)
                            replication_after_reconnect(self.workers)
            finally:
                # Release the listening socket on every exit path.
                self.socket.close()
        except Exception as exc:
            log.exception('Unexpected error: %s, cls: %s', exc, self.__class__)
        log.info('Bye!')

    def join(self, timeout=None):
        """Request the accept loop to stop and wait for the thread."""
        self.stop_event.set()
        super(TCPServerWorker, self).join(timeout)

    def process_new_client(self, client_socket, client_address):
        """Start a TCPWorker for a freshly accepted connection."""
        log.info('New connection from %s', client_address)
        try:
            client_socket.settimeout(settings.tcp_timeout)
            worker = TCPWorker(client_socket, is_incoming=True)
            worker.start()
        except Exception:
            # No worker took ownership of the socket; do not leak it.
            client_socket.close()
            raise
        self.workers.append(worker)

    def connections(self):
        """Return the list of known client workers."""
        return self.workers

    def process_disconnected(self):
        """Drop workers whose threads are no longer alive."""
        # Must stay a real list: process_new_client() appends to it and a
        # Python 3 filter() object has no append().
        self.workers = [worker for worker in self.workers if worker.is_alive()]
class AuthError(Exception):
    """
    Raised when authentication of a connection fails.

    The optional ``send_error`` keyword (default False) is stored on the
    instance; positional arguments are passed on to ``Exception``.
    """

    def __init__(self, *args, **kwargs):
        super(AuthError, self).__init__(*args)
        if 'send_error' in kwargs:
            self.send_error = kwargs['send_error']
        else:
            self.send_error = False
class TCPWorker(threading.Thread):
    """
    Serves one TCP connection as either a parcel reader or a parcel writer.

    Control messages are JSON documents, each prefixed with its length
    packed with ``struct`` format ``'q'``; file contents follow as raw bytes
    after the receiving side has requested them.
    """

    STATE_READER = 1
    STATE_WRITER = -1
    IDLE = 'idle'
    META = 'meta'

    def __init__(self, socket, keepalive=1.0, state=None, auth_timeout=60.0, is_incoming=False):
        """
        :param socket: connected socket object served by this worker
        :param keepalive: idle seconds after which a writer sends a ping
        :param state: STATE_READER / STATE_WRITER for the connecting side,
            None when the state is taken from the peer during handshake
        :param auth_timeout: seconds between repeated peer auth checks
        :param is_incoming: True when the connection was accepted locally
        """
        super(TCPWorker, self).__init__()
        assert state in (None, self.STATE_READER, self.STATE_WRITER)
        self.socket = socket
        self.address = socket.getpeername()[0]
        self.in_q = None
        self.out_q = parcelqueue.manager.get(constants.POSTMAN)
        # [time of the last outgoing activity, ping interval]
        self.keepalive = [0, keepalive]
        self.state = state
        self.is_client = state is not None
        self.parcel = None
        self.server_id = settings.id
        self.connected_id = None
        self.static_dir = settings.storage
        self.bufsize = settings.tcp_bufsize
        # Frames that were received but not yet handed to the caller.
        self.bufdata = deque()
        self.stop_event = threading.Event()
        self.workflow = None
        # [time of the last peer auth check, interval between checks]
        self.user_check = [0, auth_timeout]
        self.established = None
        self.status = self.IDLE
        self.data_pos = 0
        self.data_size = 0
        self.total = 0
        self.bitrate = 0
        self.is_incoming = is_incoming

    def info(self):
        """Return a dict describing the current state of the connection."""
        # run() sets ``established``; report 0.0 until then instead of
        # failing on None.
        if self.established is not None:
            established = time.mktime(self.established.timetuple())
        else:
            established = 0.0
        return {'parcel_id': self.parcel.id if self.parcel else '',
                'type': 'up' if self.state == self.STATE_WRITER else 'down',
                'address': self.address,
                'queue': self.in_q.queue_id if self.in_q else '',
                'status': self.status,
                'established': established,
                'user': self.connected_id,
                'bitrate': self.bitrate,
                'total': float(self.total),
                'data_pos': float(self.data_pos),
                'data_size': float(self.data_size),
                'direction': 'incoming' if self.is_incoming else 'outgoing'}

    def run(self):
        """Perform the handshake, then run the workflow until stopped."""
        # check if pysendfile available
        if sendfile_available:
            self.send_file = self.send_file_syscall
        else:
            self.send_file = self.send_file_native
        # Same value the constructor obtained from getpeername().
        self.name = self.address
        self.established = datetime.now()
        try:
            self._guarded_handshake()
            while not self.stop_event.is_set():
                try:
                    self.workflow()
                except AuthError as ex:
                    # A failed periodic re-authentication ends the session.
                    log.error('%s - Authentication failed, exception: %s', self.name, ex)
                    self._report_auth_error(ex)
                    self.stop_event.set()
                except socket.error:
                    self.stop_event.set()
                    # Put the parcel back so that it is not lost.
                    if self.state == self.STATE_WRITER and self.parcel is not None:
                        self.in_q.put(self.parcel)
                except Exception as exc:
                    log.exception('Unexpected error: %s', exc)
        finally:
            # Close the socket on every exit path.
            self.socket.close()
            log.info('%s - Bye!', self.name)

    def _guarded_handshake(self):
        """Run handshake() and turn any failure into a stop request."""
        try:
            self.handshake()
        except AuthError as ex:
            log.error('%s - Authentication failed, exception: %s', self.name, ex)
            self._report_auth_error(ex)
            self.stop_event.set()
        except (socket.error, KeyError) as ex:
            log.error('Handshake error: %s', ex)
            event_type = 4001 if self.is_client else 4002
            Audit.add_message(event_type, _('Protocol error when connecting to %s') % self.name, warning=True)
            self.stop_event.set()
        except Exception as exc:
            # E.g. malformed JSON from the peer.
            log.exception('Unexpected handshake error: %s', exc)
            self.stop_event.set()

    def _report_auth_error(self, ex):
        """Tell the peer about an auth failure when the error asks for it."""
        if not ex.send_error:
            return
        try:
            self.send_data({'auth_error': True})
        except socket.error as exc:
            log.warning('%s - Unable to report authentication error: %s', self.name, exc)

    def handshake(self):
        """
        Authenticate both sides and select the workflow for this worker.

        :raises AuthError: when either side fails authentication
        :raises socket.error: on connection problems
        :raises KeyError: when the peer omits a required field
        """
        if self.state is not None:
            self._client_handshake()
        else:
            self._server_handshake()
        self.user_check[0] = time.time()
        if self.state == self.STATE_READER:
            self.name = '%s reader' % self.connected_id
            self.workflow = self.reader_workflow
        elif self.state == self.STATE_WRITER:
            self.name = '%s writer' % self.connected_id
            self.workflow = self.writer_workflow
            self.in_q = parcelqueue.manager.get('tcp://%s' % self.connected_id)
            known_connect = KnownConnect(self.in_q.queue_id, self.is_client)
            parcelqueue.manager.put_known_connect(known_connect)
        else:
            log.error('Incorrect state value %s', self.state)
            self.stop_event.set()

    def _client_handshake(self):
        """Handshake for the side that opened the connection."""
        c1 = self.auth_challenge()
        self.send_data({'flow_state': self.state,
                        'server_id': self.server_id,
                        'challenge': c1})
        data = self.receive_data()
        if data.get('auth_error', False):
            Audit.add_message(4001, _('Authentication error when connecting to %s') % self.name, warning=True)
            raise AuthError(send_error=False)
        self.connected_id = data['server_id']
        r1, err_msg = self.auth_response(c1, self.server_id, False)
        if err_msg is not None or data['response'] != r1:
            msg = _('Authentication error when connecting to %s') % self.name
            if err_msg is not None:
                msg = msg + '<br>' + err_msg
            Audit.add_message(4001, msg, warning=True)
            raise AuthError(send_error=True)
        log.debug('%s - Server authenticated', self.name)
        r2, err_msg = self.auth_response(data['challenge'], self.server_id, False)
        if err_msg is not None:
            msg = _('Authentication error when connecting to %s') % self.name
            Audit.add_message(4001, msg + '<br>' + err_msg, warning=True)
            raise AuthError(send_error=True)
        self.send_data({'response': r2})
        data = self.receive_data()
        if data.get('auth_error', False):
            Audit.add_message(4001, _('Authentication error when connecting to %s') % self.name, warning=True)
            raise AuthError(send_error=False)
        Audit.add_message(4001, _('Successfully connected to %s') % '%s, %s' % (self.connected_id, self.name))
        log.info('%s - Client authenticated', self.name)

    def _server_handshake(self):
        """Handshake for the side that accepted the connection."""
        data = self.receive_data()
        self.connected_id = data['server_id']
        r1, err_msg = self.auth_response(data['challenge'], self.connected_id, True)
        if err_msg is not None:
            msg = _('Error when authenticating client %s') % '%s %s' % (self.name, self.connected_id)
            Audit.add_message(4002, msg + '<br>' + err_msg, warning=True)
            raise AuthError(send_error=True)
        c1 = self.auth_challenge()
        self.send_data({'server_id': self.server_id,
                        'challenge': c1, 'response': r1})
        # Take the role opposite to the one announced by the peer.
        self.state = data['flow_state'] * -1
        data = self.receive_data()
        if data.get('auth_error', False):
            msg = _('Error when authenticating client %s') % '%s, %s' % (self.connected_id, self.name)
            Audit.add_message(4002, msg, warning=True)
            raise socket.error('Client %s failed to authenticate server' % self.connected_id)
        r2, err_msg = self.auth_response(c1, self.connected_id, True)
        if err_msg is not None or data['response'] != r2:
            msg = _('Error when authenticating client %s') % '%s, %s' % (self.connected_id, self.name)
            if err_msg is not None:
                msg = msg + '<br>' + err_msg
            Audit.add_message(4002, msg, warning=True)
            raise AuthError(send_error=True)
        self.send_data({'auth_error': False})
        Audit.add_message(4002, _('Successfully connected to %s') % '%s, %s' % (self.connected_id, self.name))
        log.info('%s - Client authenticated', self.name)

    def reader_workflow(self):
        """Receive one message: a ping, a delivery notice or a parcel."""
        # receive parcel meta
        data = self.receive_data()
        # Re-check the peer once the previous check has expired (same
        # pattern as TCPClientWorker.userid_correct).
        if not self.is_client and sum(self.user_check) < time.time():
            resp, err_msg = self.auth_response('', self.connected_id, True)
            if err_msg is not None:
                msg = _('Error when authenticating client %s') % '%s' % self.connected_id
                Audit.add_message(4002, msg + '<br>' + err_msg, warning=True)
                raise AuthError(send_error=True)
            # Index 0 holds the time of the last check, index 1 the interval.
            self.user_check[0] = time.time()
        # if it is a ping-request, sending a reply
        if 'ping' in data:
            self.send_pong()
            return
        self.parcel = Parcel(data['params'], data['files'],
                             data['callback'], state=data['state'])
        if 'delivered' in self.parcel.params:
            self.out_q.put(self.parcel)
            self.send_transfer_notice()
            return
        if not self.check_incoming_parcel(self.parcel):
            log.warning('Incorrect parcel address')
            self.send_data({'error': True, 'msg': 'Incorrect parcel address'})
            return
        parcelqueue.manager.put_parcel(self.parcel, constants.TRANSFER)
        log.info('%s - Got a new parcel: %s', self.name, self.parcel)
        # receive parcel files (attachments)
        for f in self.parcel.files:
            log.info('%s - Receiving file: parcel_id = %s, filename = %s, size = %s',
                     self.name, self.parcel.id, f['name'], f['size'])
            self.status = f['name']
            self.receive_file(f, self.parcel.id)
        log.info('%s - Finished processing parcel: %s', self.name, self.parcel)
        self.out_q.put(self.parcel)
        self.send_transfer_notice()
        self.status = self.IDLE

    def writer_workflow(self):
        """Send the next queued parcel, or a keepalive ping when idle."""
        try:
            self.parcel = self.in_q.get(block=True, timeout=0.1)
        except queue.Empty:
            self._ping_if_idle()
            return
        if 'delivered' not in self.parcel.params:
            self.parcel.params['from'] = 'tcp://%s' % self.server_id
        self.keepalive[0] = time.time()
        for f in self.parcel.files:
            file_path = os.path.join(self.static_dir, self.parcel.id, f['name'])
            f['size'] = os.path.getsize(file_path)
        log.info('%s - Sending parcel %s', self.name, self.parcel)
        self.send_data({'params': self.parcel.params,
                        'files': self.parcel.files,
                        'callback': self.parcel.callback_url,
                        'state': self.parcel.state})
        # Serve file requests until the peer reports success or failure.
        while True:
            data = self.receive_data()
            if 'error' in data:
                log.warning('Sending parcel failed: %s', data.get('msg'))
                parcelqueue.manager.put_parcel(self.parcel, parcelqueue.ARCHIVE)
                break
            if 'transferred' in data:
                parcelqueue.manager.put_parcel(self.parcel, parcelqueue.ARCHIVE)
                break
            if 'name' in data:
                for f in self.parcel.files:
                    if f['name'] == data['name']:
                        self.status = f['name']
                        filepath = os.path.join(self.static_dir,
                                                self.parcel.id, f['name'])
                        offset = data.get('offset', 0)
                        log.info('%s - Sending file %s, offset = %s, '
                                 'size = %s', self.name, f['name'],
                                 offset, os.path.getsize(filepath))
                        self.send_file(filepath, offset)
        self.status = self.IDLE
        self.parcel = None

    def _ping_if_idle(self):
        """Send a ping and require a pong once the keepalive interval passed."""
        curr_time = time.time()
        if curr_time >= sum(self.keepalive):
            self.send_data({'ping': True})
            if 'pong' not in self.receive_data():
                raise socket.error('Ping failed')
            self.keepalive[0] = curr_time

    def join(self, timeout=None):
        """Request the worker to stop and wait for the thread."""
        self.stop_event.set()
        super(TCPWorker, self).join(timeout)

    def auth_challenge(self):
        """
        Obtain a challenge from the auth provider.

        :raises socket.error: when the provider call fails
        """
        s = ServerProxy(settings.auth_provider)
        try:
            return s.auth_challenge()
        except Exception:
            raise socket.error('Authentication provider not available, disconnecting')

    def auth_response(self, challenge, server_id, is_server):
        """
        Ask the auth provider for the response to ``challenge``.

        :return: ``(response, None)`` on success, ``(None, message)`` on error
        """
        s = ServerProxy(settings.auth_provider)
        try:
            result = s.auth_response(challenge, server_id, is_server)
        except Exception:
            result = {'error': True, 'msg': _('Authentication server is unavailable')}
        if result.get('error'):
            # Callers detect failure by ``message is not None``, so an error
            # without text must still yield a string.
            return None, result.get('msg') or ''
        else:
            return result.get('response'), None

    def check_incoming_parcel(self, parcel):
        """Return True when the parcel goes from the peer to this server."""
        # A missing address counts as a wrong address.
        if parcel.params.get('from') != 'tcp://%s' % self.connected_id:
            return False
        if parcel.params.get('to') != 'tcp://%s' % self.server_id:
            return False
        return True

    def read_socket(self):
        """
        Read up to ``bufsize`` bytes from the socket.

        :raises socket.error: when the peer has closed the connection
        """
        buf = self.socket.recv(self.bufsize)
        if not buf:
            raise socket.error('Disconnected')
        return buf

    def receive_data(self):
        """Return the next control message decoded from JSON."""
        if not self.bufdata:
            self._fill_frame_buffer()
        return self._decode_frame(self.bufdata.popleft())

    def _fill_frame_buffer(self):
        """Read from the socket until only whole frames have been buffered."""
        header_size = struct.calcsize('q')
        # Socket data is bytes, so accumulate into bytes.
        data = b''
        self.data_pos = 0
        self.data_size = 0
        while True:
            while len(data) < header_size:
                block = self.read_socket()
                data += block
                self.total += len(block)
            self.data_size = data_size = struct.unpack('q', data[:header_size])[0]
            if data_size < 0:
                raise socket.error('Invalid frame size %s' % data_size)
            data = data[header_size:]
            while len(data) < data_size:
                block = self.read_socket()
                data += block
                self.total += len(block)
                self.data_pos += len(block)
            self.bufdata.append(data[:data_size])
            # Anything left over belongs to the next frame.
            data = data[data_size:]
            if not data:
                break

    def _decode_frame(self, raw):
        """Decode one frame; ping/pong frames are not logged."""
        text = raw.decode('utf-8')
        if 'ping' not in text and 'pong' not in text:
            log.debug('%s - Received data: %s', self.name, text)
        return json.loads(text)

    def receive_file(self, f, parcel_id):
        """
        Receive the contents of file ``f`` into the parcel directory.

        A partially present file is continued from its current size; a file
        that is already complete is not requested again.

        :param f: file description with 'name' and 'size'
        :param parcel_id: id of the parcel, used as the directory name
        :raises socket.error: when the peer disconnects mid-transfer
        """
        # NOTE(review): parcel_id and f['name'] come from the peer and are
        # joined into a path unchecked - confirm they are validated upstream.
        file_dir = os.path.join(self.static_dir, parcel_id)
        file_path = os.path.join(file_dir, f['name'])
        file_size = f['size']
        if os.path.exists(file_path):
            real_size = os.path.getsize(file_path)
            if real_size < file_size:
                f['offset'] = real_size
            else:
                return
        self.send_file_request(f)
        if f.get('offset', 0) > 0:
            log.warning('Receiving file with offset: parcel_id = %s, '
                        'filename = %s, size = %s, offset = %s',
                        parcel_id, f['name'], f['size'], f.get('offset', 0))
        self.data_size = f['size']
        if not os.path.exists(file_dir):
            # Octal mode; decimal 750 would set an unrelated bit pattern.
            os.mkdir(file_dir, 0o750)
        bytes_received = f.get('offset', 0)
        self.data_pos = bytes_received
        with open(file_path, mode='ab') as file_obj:
            while bytes_received < file_size:
                # Never read past the end of this file.
                chunk = self.socket.recv(min(self.bufsize, file_size - bytes_received))
                if not chunk:
                    raise socket.error('Disconnected')
                file_obj.write(chunk)
                bytes_received += len(chunk)
                self.total += len(chunk)
                self.data_pos = bytes_received
        log.info('Received file: parcel_id = %s, filename = %s, '
                 'size = %s', parcel_id, f['name'], f['size'])
        del f['size']

    def send_file_native(self, file_path, offset=0):
        """Send a file from ``offset`` using read() and sendall()."""
        self.data_pos = offset
        self.data_size = os.path.getsize(file_path)
        with open(file_path, 'rb') as file_obj:
            file_obj.seek(offset)
            while True:
                data = file_obj.read(self.bufsize)
                if not data:
                    break
                self.socket.sendall(data)
                self.total += len(data)
                self.data_pos += len(data)

    def send_file_syscall(self, file_path, offset=0):
        """Send a file from ``offset`` using the imported sendfile()."""
        read_offset = offset
        self.data_pos = offset
        self.data_size = os.path.getsize(file_path)
        with open(file_path, 'rb') as file_obj:
            while True:
                sent = sendfile(self.socket.fileno(), file_obj.fileno(),
                                read_offset, self.bufsize)
                if sent == 0:
                    break
                read_offset += sent
                self.total += sent
                self.data_pos += sent

    def send_data(self, data):
        """Serialize ``data`` to JSON and send it with a length prefix."""
        self.data_size = 0
        payload = json.dumps(data)
        if 'ping' not in payload and 'pong' not in payload:
            log.debug('%s - Sending data: %s', self.name, payload)
        # The length prefix counts bytes, so encode before measuring.
        encoded = payload.encode('utf-8')
        self.socket.sendall(struct.pack('q', len(encoded)) + encoded)

    def send_pong(self):
        """Reply to a ping."""
        self.send_data({'pong': True})

    def send_file_request(self, f):
        """Ask the peer to send file ``f`` starting from its offset."""
        log.debug('Send file request: filename = %s, size = %s, offset = %s',
                  f['name'], f['size'], f.get('offset', 0))
        self.send_data({'name': f['name'],
                        'offset': f.get('offset', 0)})

    def send_transfer_notice(self):
        """Tell the peer that the current parcel has been transferred."""
        self.send_data({'transferred': True,
                        'parcel_id': self.parcel.id})

    def send_delivery_notice(self):
        """Send a notice carrying the current parcel's state."""
        self.send_data({'notice': True,
                        'parcel_id': self.parcel.id,
                        'parcel_state': self.parcel.state})