servers: use proxy design pattern instead of inheritance, because Python ssl patch has benn refused

This commit is contained in:
nemunaire 2017-09-16 19:32:58 +02:00 committed by Pierre-Olivier Mercier
parent 7a4b27510c
commit 12ddf40ef4
5 changed files with 57 additions and 73 deletions

View File

@ -9,8 +9,6 @@ Requirements
*nemubot* requires at least Python 3.3 to work. *nemubot* requires at least Python 3.3 to work.
Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629).
Some modules (like `cve`, `nextstop` or `laposte`) require the Some modules (like `cve`, `nextstop` or `laposte`) require the
[BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/), [BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/),
but the core and framework has no dependency. but the core and framework has no dependency.

View File

@ -21,10 +21,10 @@ import socket
from nemubot.channel import Channel from nemubot.channel import Channel
from nemubot.message.printer.IRC import IRC as IRCPrinter from nemubot.message.printer.IRC import IRC as IRCPrinter
from nemubot.server.message.IRC import IRC as IRCMessage from nemubot.server.message.IRC import IRC as IRCMessage
from nemubot.server.socket import SocketServer, SecureSocketServer from nemubot.server.socket import SocketServer
class _IRC: class IRC(SocketServer):
"""Concrete implementation of a connexion to an IRC server""" """Concrete implementation of a connexion to an IRC server"""
@ -245,7 +245,7 @@ class _IRC:
def close(self): def close(self):
if not self._closed: if not self._fd._closed:
self.write("QUIT") self.write("QUIT")
return super().close() return super().close()
@ -274,10 +274,3 @@ class _IRC:
def subparse(self, orig, cnt): def subparse(self, orig, cnt):
msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding) msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding)
return msg.to_bot_message(self) return msg.to_bot_message(self)
class IRC(_IRC, SocketServer):
pass
class IRC_secure(_IRC, SecureSocketServer):
pass

View File

@ -32,16 +32,6 @@ def factory(uri, ssl=False, **init_args):
if o.username is not None: args["username"] = o.username if o.username is not None: args["username"] = o.username
if o.password is not None: args["password"] = o.password if o.password is not None: args["password"] = o.password
if ssl:
try:
from ssl import create_default_context
args["_context"] = create_default_context()
except ImportError:
# Python 3.3 compat
from ssl import SSLContext, PROTOCOL_TLSv1
args["_context"] = SSLContext(PROTOCOL_TLSv1)
args["server_hostname"] = o.hostname
modifiers = o.path.split(",") modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:]) target = unquote(modifiers.pop(0)[1:])
@ -68,11 +58,19 @@ def factory(uri, ssl=False, **init_args):
if "channels" not in args and "isnick" not in modifiers: if "channels" not in args and "isnick" not in modifiers:
args["channels"] = [ target ] args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer
srv = IRCServer(**args)
if ssl: if ssl:
from nemubot.server.IRC import IRC_secure as SecureIRCServer try:
srv = SecureIRCServer(**args) from ssl import create_default_context
else: context = create_default_context()
from nemubot.server.IRC import IRC as IRCServer except ImportError:
srv = IRCServer(**args) # Python 3.3 compat
from ssl import SSLContext, PROTOCOL_TLSv1
context = SSLContext(PROTOCOL_TLSv1)
from ssl import wrap_socket
srv._fd = context.wrap_socket(srv._fd, server_hostname=o.hostname)
return srv return srv

View File

@ -24,17 +24,16 @@ class AbstractServer:
"""An abstract server: handle communication with an IM server""" """An abstract server: handle communication with an IM server"""
def __init__(self, name=None, **kwargs): def __init__(self, name, fdClass, **kwargs):
"""Initialize an abstract server """Initialize an abstract server
Keyword argument: Keyword argument:
name -- Identifier of the socket, for convinience name -- Identifier of the socket, for convinience
fdClass -- Class to instantiate as support file
""" """
self._name = name self._name = name
self._socket = socket self._fd = fdClass(**kwargs)
super().__init__(**kwargs)
self._logger = logging.getLogger("nemubot.server." + str(self.name)) self._logger = logging.getLogger("nemubot.server." + str(self.name))
self._readbuffer = b'' self._readbuffer = b''
@ -46,7 +45,7 @@ class AbstractServer:
if self._name is not None: if self._name is not None:
return self._name return self._name
else: else:
return self.fileno() return self._fd.fileno()
# Open/close # Open/close
@ -56,12 +55,12 @@ class AbstractServer:
self._logger.info("Opening connection") self._logger.info("Opening connection")
super().connect(*args, **kwargs) self._fd.connect(*args, **kwargs)
self._on_connect() self._on_connect()
def _on_connect(self): def _on_connect(self):
sync_act("sckt", "register", self.fileno()) sync_act("sckt", "register", self._fd.fileno())
def close(self, *args, **kwargs): def close(self, *args, **kwargs):
@ -69,10 +68,10 @@ class AbstractServer:
self._logger.info("Closing connection") self._logger.info("Closing connection")
if self.fileno() > 0: if self._fd.fileno() > 0:
sync_act("sckt", "unregister", self.fileno()) sync_act("sckt", "unregister", self._fd.fileno())
super().close(*args, **kwargs) self._fd.close(*args, **kwargs)
# Writes # Writes
@ -86,14 +85,14 @@ class AbstractServer:
self._sending_queue.put(self.format(message)) self._sending_queue.put(self.format(message))
self._logger.debug("Message '%s' appended to write queue", message) self._logger.debug("Message '%s' appended to write queue", message)
sync_act("sckt", "write", self.fileno()) sync_act("sckt", "write", self._fd.fileno())
def async_write(self): def async_write(self):
"""Internal function used when the file descriptor is writable""" """Internal function used when the file descriptor is writable"""
try: try:
sync_act("sckt", "unwrite", self.fileno()) sync_act("sckt", "unwrite", self._fd.fileno())
while not self._sending_queue.empty(): while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait()) self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done() self._sending_queue.task_done()
@ -131,7 +130,7 @@ class AbstractServer:
A list of fully received messages A list of fully received messages
""" """
ret, self._readbuffer = self.lex(self._readbuffer + self.read()) ret, self._readbuffer = self.lex(self._readbuffer + self._fd.read())
for r in ret: for r in ret:
yield r yield r
@ -159,4 +158,9 @@ class AbstractServer:
def exception(self, flags): def exception(self, flags):
"""Exception occurs on fd""" """Exception occurs on fd"""
self.close() self._fd.close()
# Proxy
def fileno(self):
return self._fd.fileno()

View File

@ -16,7 +16,6 @@
import os import os
import socket import socket
import ssl
import nemubot.message as message import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter from nemubot.message.printer.socket import Socket as SocketPrinter
@ -40,7 +39,7 @@ class _Socket(AbstractServer):
# Write # Write
def _write(self, cnt): def _write(self, cnt):
self.sendall(cnt) self._fd.sendall(cnt)
def format(self, txt): def format(self, txt):
@ -52,8 +51,8 @@ class _Socket(AbstractServer):
# Read # Read
def recv(self, n=1024): def recv(self, *args, **kwargs):
return super().recv(n) return self._fd.recv(*args, **kwargs)
def parse(self, line): def parse(self, line):
@ -67,7 +66,7 @@ class _Socket(AbstractServer):
args = line.split(' ') args = line.split(' ')
if len(args): if len(args):
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you") yield message.Command(cmd=args[0], args=args[1:], server=self._fd.fileno(), to=["you"], frm="you")
def subparse(self, orig, cnt): def subparse(self, orig, cnt):
@ -78,50 +77,46 @@ class _Socket(AbstractServer):
yield m yield m
class _SocketServer(_Socket): class SocketServer(_Socket):
def __init__(self, host, port, bind=None, **kwargs): def __init__(self, host, port, bind=None, **kwargs):
(family, type, proto, canonname, sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0] (family, type, proto, canonname, self._sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0]
super().__init__(family=family, type=type, proto=proto, **kwargs) super().__init__(fdClass=socket.socket, family=family, type=type, proto=proto, **kwargs)
self._sockaddr = sockaddr
self._bind = bind self._bind = bind
def connect(self): def connect(self):
self._logger.info("Connection to %s:%d", *self._sockaddr[:2]) self._logger.info("Connecting to %s:%d", *self._sockaddr[:2])
super().connect(self._sockaddr) super().connect(self._sockaddr)
self._logger.info("Connected to %s:%d", *self._sockaddr[:2])
if self._bind: if self._bind:
super().bind(self._bind) self._fd.bind(self._bind)
class SocketServer(_SocketServer, socket.socket):
pass
class SecureSocketServer(_SocketServer, ssl.SSLSocket):
pass
class UnixSocket: class UnixSocket:
def __init__(self, location, **kwargs): def __init__(self, location, **kwargs):
super().__init__(family=socket.AF_UNIX, **kwargs) super().__init__(fdClass=socket.socket, family=socket.AF_UNIX, **kwargs)
self._socket_path = location self._socket_path = location
def connect(self): def connect(self):
self._logger.info("Connection to unix://%s", self._socket_path) self._logger.info("Connection to unix://%s", self._socket_path)
super().connect(self._socket_path) self.connect(self._socket_path)
class SocketClient(_Socket, socket.socket): class SocketClient(_Socket):
def __init__(self, **kwargs):
super().__init__(fdClass=socket.socket, **kwargs)
def read(self): def read(self):
return self.recv() return self._fd.recv()
class _Listener: class _Listener:
@ -134,7 +129,7 @@ class _Listener:
def read(self): def read(self):
conn, addr = self.accept() conn, addr = self._fd.accept()
fileno = conn.fileno() fileno = conn.fileno()
self._logger.info("Accept new connection from %s (fd=%d)", addr, fileno) self._logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
@ -145,11 +140,7 @@ class _Listener:
return b'' return b''
class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket): class UnixSocketListener(_Listener, UnixSocket, _Socket):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def connect(self): def connect(self):
self._logger.info("Creating Unix socket at unix://%s", self._socket_path) self._logger.info("Creating Unix socket at unix://%s", self._socket_path)
@ -159,8 +150,8 @@ class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
except FileNotFoundError: except FileNotFoundError:
pass pass
self.bind(self._socket_path) self._fd.bind(self._socket_path)
self.listen(5) self._fd.listen(5)
self._logger.info("Socket ready for accepting new connections") self._logger.info("Socket ready for accepting new connections")
self._on_connect() self._on_connect()
@ -171,7 +162,7 @@ class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
import socket import socket
try: try:
self.shutdown(socket.SHUT_RDWR) self._fd.shutdown(socket.SHUT_RDWR)
except socket.error: except socket.error:
pass pass