servers: use proxy design pattern instead of inheritance, because Python ssl patch has benn refused

This commit is contained in:
nemunaire 2017-09-16 19:32:58 +02:00 committed by Pierre-Olivier Mercier
parent 7a4b27510c
commit 12ddf40ef4
5 changed files with 57 additions and 73 deletions

View File

@ -9,8 +9,6 @@ Requirements
*nemubot* requires at least Python 3.3 to work.
Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629).
Some modules (like `cve`, `nextstop` or `laposte`) require the
[BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/),
but the core and framework has no dependency.

View File

@ -21,10 +21,10 @@ import socket
from nemubot.channel import Channel
from nemubot.message.printer.IRC import IRC as IRCPrinter
from nemubot.server.message.IRC import IRC as IRCMessage
from nemubot.server.socket import SocketServer, SecureSocketServer
from nemubot.server.socket import SocketServer
class _IRC:
class IRC(SocketServer):
"""Concrete implementation of a connexion to an IRC server"""
@ -245,7 +245,7 @@ class _IRC:
def close(self):
if not self._closed:
if not self._fd._closed:
self.write("QUIT")
return super().close()
@ -274,10 +274,3 @@ class _IRC:
def subparse(self, orig, cnt):
msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding)
return msg.to_bot_message(self)
class IRC(_IRC, SocketServer):
pass
class IRC_secure(_IRC, SecureSocketServer):
pass

View File

@ -32,16 +32,6 @@ def factory(uri, ssl=False, **init_args):
if o.username is not None: args["username"] = o.username
if o.password is not None: args["password"] = o.password
if ssl:
try:
from ssl import create_default_context
args["_context"] = create_default_context()
except ImportError:
# Python 3.3 compat
from ssl import SSLContext, PROTOCOL_TLSv1
args["_context"] = SSLContext(PROTOCOL_TLSv1)
args["server_hostname"] = o.hostname
modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:])
@ -68,11 +58,19 @@ def factory(uri, ssl=False, **init_args):
if "channels" not in args and "isnick" not in modifiers:
args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer
srv = IRCServer(**args)
if ssl:
from nemubot.server.IRC import IRC_secure as SecureIRCServer
srv = SecureIRCServer(**args)
else:
from nemubot.server.IRC import IRC as IRCServer
srv = IRCServer(**args)
try:
from ssl import create_default_context
context = create_default_context()
except ImportError:
# Python 3.3 compat
from ssl import SSLContext, PROTOCOL_TLSv1
context = SSLContext(PROTOCOL_TLSv1)
from ssl import wrap_socket
srv._fd = context.wrap_socket(srv._fd, server_hostname=o.hostname)
return srv

View File

@ -24,17 +24,16 @@ class AbstractServer:
"""An abstract server: handle communication with an IM server"""
def __init__(self, name=None, **kwargs):
def __init__(self, name, fdClass, **kwargs):
"""Initialize an abstract server
Keyword argument:
name -- Identifier of the socket, for convinience
fdClass -- Class to instantiate as support file
"""
self._name = name
self._socket = socket
super().__init__(**kwargs)
self._fd = fdClass(**kwargs)
self._logger = logging.getLogger("nemubot.server." + str(self.name))
self._readbuffer = b''
@ -46,7 +45,7 @@ class AbstractServer:
if self._name is not None:
return self._name
else:
return self.fileno()
return self._fd.fileno()
# Open/close
@ -56,12 +55,12 @@ class AbstractServer:
self._logger.info("Opening connection")
super().connect(*args, **kwargs)
self._fd.connect(*args, **kwargs)
self._on_connect()
def _on_connect(self):
sync_act("sckt", "register", self.fileno())
sync_act("sckt", "register", self._fd.fileno())
def close(self, *args, **kwargs):
@ -69,10 +68,10 @@ class AbstractServer:
self._logger.info("Closing connection")
if self.fileno() > 0:
sync_act("sckt", "unregister", self.fileno())
if self._fd.fileno() > 0:
sync_act("sckt", "unregister", self._fd.fileno())
super().close(*args, **kwargs)
self._fd.close(*args, **kwargs)
# Writes
@ -86,14 +85,14 @@ class AbstractServer:
self._sending_queue.put(self.format(message))
self._logger.debug("Message '%s' appended to write queue", message)
sync_act("sckt", "write", self.fileno())
sync_act("sckt", "write", self._fd.fileno())
def async_write(self):
"""Internal function used when the file descriptor is writable"""
try:
sync_act("sckt", "unwrite", self.fileno())
sync_act("sckt", "unwrite", self._fd.fileno())
while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done()
@ -131,7 +130,7 @@ class AbstractServer:
A list of fully received messages
"""
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
ret, self._readbuffer = self.lex(self._readbuffer + self._fd.read())
for r in ret:
yield r
@ -159,4 +158,9 @@ class AbstractServer:
def exception(self, flags):
"""Exception occurs on fd"""
self.close()
self._fd.close()
# Proxy
def fileno(self):
return self._fd.fileno()

View File

@ -16,7 +16,6 @@
import os
import socket
import ssl
import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter
@ -40,7 +39,7 @@ class _Socket(AbstractServer):
# Write
def _write(self, cnt):
self.sendall(cnt)
self._fd.sendall(cnt)
def format(self, txt):
@ -52,8 +51,8 @@ class _Socket(AbstractServer):
# Read
def recv(self, n=1024):
return super().recv(n)
def recv(self, *args, **kwargs):
return self._fd.recv(*args, **kwargs)
def parse(self, line):
@ -67,7 +66,7 @@ class _Socket(AbstractServer):
args = line.split(' ')
if len(args):
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
yield message.Command(cmd=args[0], args=args[1:], server=self._fd.fileno(), to=["you"], frm="you")
def subparse(self, orig, cnt):
@ -78,50 +77,46 @@ class _Socket(AbstractServer):
yield m
class _SocketServer(_Socket):
class SocketServer(_Socket):
def __init__(self, host, port, bind=None, **kwargs):
(family, type, proto, canonname, sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0]
(family, type, proto, canonname, self._sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0]
super().__init__(family=family, type=type, proto=proto, **kwargs)
super().__init__(fdClass=socket.socket, family=family, type=type, proto=proto, **kwargs)
self._sockaddr = sockaddr
self._bind = bind
def connect(self):
self._logger.info("Connection to %s:%d", *self._sockaddr[:2])
self._logger.info("Connecting to %s:%d", *self._sockaddr[:2])
super().connect(self._sockaddr)
self._logger.info("Connected to %s:%d", *self._sockaddr[:2])
if self._bind:
super().bind(self._bind)
class SocketServer(_SocketServer, socket.socket):
pass
class SecureSocketServer(_SocketServer, ssl.SSLSocket):
pass
self._fd.bind(self._bind)
class UnixSocket:
def __init__(self, location, **kwargs):
super().__init__(family=socket.AF_UNIX, **kwargs)
super().__init__(fdClass=socket.socket, family=socket.AF_UNIX, **kwargs)
self._socket_path = location
def connect(self):
self._logger.info("Connection to unix://%s", self._socket_path)
super().connect(self._socket_path)
self.connect(self._socket_path)
class SocketClient(_Socket, socket.socket):
class SocketClient(_Socket):
def __init__(self, **kwargs):
super().__init__(fdClass=socket.socket, **kwargs)
def read(self):
return self.recv()
return self._fd.recv()
class _Listener:
@ -134,7 +129,7 @@ class _Listener:
def read(self):
conn, addr = self.accept()
conn, addr = self._fd.accept()
fileno = conn.fileno()
self._logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
@ -145,11 +140,7 @@ class _Listener:
return b''
class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
def __init__(self, **kwargs):
super().__init__(**kwargs)
class UnixSocketListener(_Listener, UnixSocket, _Socket):
def connect(self):
self._logger.info("Creating Unix socket at unix://%s", self._socket_path)
@ -159,8 +150,8 @@ class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
except FileNotFoundError:
pass
self.bind(self._socket_path)
self.listen(5)
self._fd.bind(self._socket_path)
self._fd.listen(5)
self._logger.info("Socket ready for accepting new connections")
self._on_connect()
@ -171,7 +162,7 @@ class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
import socket
try:
self.shutdown(socket.SHUT_RDWR)
self._fd.shutdown(socket.SHUT_RDWR)
except socket.error:
pass