servers: use proxy design pattern instead of inheritance, because Python ssl patch has benn refused
This commit is contained in:
parent
7a4b27510c
commit
12ddf40ef4
@ -9,8 +9,6 @@ Requirements
|
||||
|
||||
*nemubot* requires at least Python 3.3 to work.
|
||||
|
||||
Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629).
|
||||
|
||||
Some modules (like `cve`, `nextstop` or `laposte`) require the
|
||||
[BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/),
|
||||
but the core and framework has no dependency.
|
||||
|
@ -21,10 +21,10 @@ import socket
|
||||
from nemubot.channel import Channel
|
||||
from nemubot.message.printer.IRC import IRC as IRCPrinter
|
||||
from nemubot.server.message.IRC import IRC as IRCMessage
|
||||
from nemubot.server.socket import SocketServer, SecureSocketServer
|
||||
from nemubot.server.socket import SocketServer
|
||||
|
||||
|
||||
class _IRC:
|
||||
class IRC(SocketServer):
|
||||
|
||||
"""Concrete implementation of a connexion to an IRC server"""
|
||||
|
||||
@ -245,7 +245,7 @@ class _IRC:
|
||||
|
||||
|
||||
def close(self):
|
||||
if not self._closed:
|
||||
if not self._fd._closed:
|
||||
self.write("QUIT")
|
||||
return super().close()
|
||||
|
||||
@ -274,10 +274,3 @@ class _IRC:
|
||||
def subparse(self, orig, cnt):
|
||||
msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding)
|
||||
return msg.to_bot_message(self)
|
||||
|
||||
|
||||
class IRC(_IRC, SocketServer):
|
||||
pass
|
||||
|
||||
class IRC_secure(_IRC, SecureSocketServer):
|
||||
pass
|
||||
|
@ -32,16 +32,6 @@ def factory(uri, ssl=False, **init_args):
|
||||
if o.username is not None: args["username"] = o.username
|
||||
if o.password is not None: args["password"] = o.password
|
||||
|
||||
if ssl:
|
||||
try:
|
||||
from ssl import create_default_context
|
||||
args["_context"] = create_default_context()
|
||||
except ImportError:
|
||||
# Python 3.3 compat
|
||||
from ssl import SSLContext, PROTOCOL_TLSv1
|
||||
args["_context"] = SSLContext(PROTOCOL_TLSv1)
|
||||
args["server_hostname"] = o.hostname
|
||||
|
||||
modifiers = o.path.split(",")
|
||||
target = unquote(modifiers.pop(0)[1:])
|
||||
|
||||
@ -68,11 +58,19 @@ def factory(uri, ssl=False, **init_args):
|
||||
if "channels" not in args and "isnick" not in modifiers:
|
||||
args["channels"] = [ target ]
|
||||
|
||||
if ssl:
|
||||
from nemubot.server.IRC import IRC_secure as SecureIRCServer
|
||||
srv = SecureIRCServer(**args)
|
||||
else:
|
||||
from nemubot.server.IRC import IRC as IRCServer
|
||||
srv = IRCServer(**args)
|
||||
|
||||
if ssl:
|
||||
try:
|
||||
from ssl import create_default_context
|
||||
context = create_default_context()
|
||||
except ImportError:
|
||||
# Python 3.3 compat
|
||||
from ssl import SSLContext, PROTOCOL_TLSv1
|
||||
context = SSLContext(PROTOCOL_TLSv1)
|
||||
from ssl import wrap_socket
|
||||
srv._fd = context.wrap_socket(srv._fd, server_hostname=o.hostname)
|
||||
|
||||
|
||||
return srv
|
||||
|
@ -24,17 +24,16 @@ class AbstractServer:
|
||||
|
||||
"""An abstract server: handle communication with an IM server"""
|
||||
|
||||
def __init__(self, name=None, **kwargs):
|
||||
def __init__(self, name, fdClass, **kwargs):
|
||||
"""Initialize an abstract server
|
||||
|
||||
Keyword argument:
|
||||
name -- Identifier of the socket, for convinience
|
||||
fdClass -- Class to instantiate as support file
|
||||
"""
|
||||
|
||||
self._name = name
|
||||
self._socket = socket
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._fd = fdClass(**kwargs)
|
||||
|
||||
self._logger = logging.getLogger("nemubot.server." + str(self.name))
|
||||
self._readbuffer = b''
|
||||
@ -46,7 +45,7 @@ class AbstractServer:
|
||||
if self._name is not None:
|
||||
return self._name
|
||||
else:
|
||||
return self.fileno()
|
||||
return self._fd.fileno()
|
||||
|
||||
|
||||
# Open/close
|
||||
@ -56,12 +55,12 @@ class AbstractServer:
|
||||
|
||||
self._logger.info("Opening connection")
|
||||
|
||||
super().connect(*args, **kwargs)
|
||||
self._fd.connect(*args, **kwargs)
|
||||
|
||||
self._on_connect()
|
||||
|
||||
def _on_connect(self):
|
||||
sync_act("sckt", "register", self.fileno())
|
||||
sync_act("sckt", "register", self._fd.fileno())
|
||||
|
||||
|
||||
def close(self, *args, **kwargs):
|
||||
@ -69,10 +68,10 @@ class AbstractServer:
|
||||
|
||||
self._logger.info("Closing connection")
|
||||
|
||||
if self.fileno() > 0:
|
||||
sync_act("sckt", "unregister", self.fileno())
|
||||
if self._fd.fileno() > 0:
|
||||
sync_act("sckt", "unregister", self._fd.fileno())
|
||||
|
||||
super().close(*args, **kwargs)
|
||||
self._fd.close(*args, **kwargs)
|
||||
|
||||
|
||||
# Writes
|
||||
@ -86,14 +85,14 @@ class AbstractServer:
|
||||
|
||||
self._sending_queue.put(self.format(message))
|
||||
self._logger.debug("Message '%s' appended to write queue", message)
|
||||
sync_act("sckt", "write", self.fileno())
|
||||
sync_act("sckt", "write", self._fd.fileno())
|
||||
|
||||
|
||||
def async_write(self):
|
||||
"""Internal function used when the file descriptor is writable"""
|
||||
|
||||
try:
|
||||
sync_act("sckt", "unwrite", self.fileno())
|
||||
sync_act("sckt", "unwrite", self._fd.fileno())
|
||||
while not self._sending_queue.empty():
|
||||
self._write(self._sending_queue.get_nowait())
|
||||
self._sending_queue.task_done()
|
||||
@ -131,7 +130,7 @@ class AbstractServer:
|
||||
A list of fully received messages
|
||||
"""
|
||||
|
||||
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
|
||||
ret, self._readbuffer = self.lex(self._readbuffer + self._fd.read())
|
||||
|
||||
for r in ret:
|
||||
yield r
|
||||
@ -159,4 +158,9 @@ class AbstractServer:
|
||||
def exception(self, flags):
|
||||
"""Exception occurs on fd"""
|
||||
|
||||
self.close()
|
||||
self._fd.close()
|
||||
|
||||
# Proxy
|
||||
|
||||
def fileno(self):
|
||||
return self._fd.fileno()
|
||||
|
@ -16,7 +16,6 @@
|
||||
|
||||
import os
|
||||
import socket
|
||||
import ssl
|
||||
|
||||
import nemubot.message as message
|
||||
from nemubot.message.printer.socket import Socket as SocketPrinter
|
||||
@ -40,7 +39,7 @@ class _Socket(AbstractServer):
|
||||
# Write
|
||||
|
||||
def _write(self, cnt):
|
||||
self.sendall(cnt)
|
||||
self._fd.sendall(cnt)
|
||||
|
||||
|
||||
def format(self, txt):
|
||||
@ -52,8 +51,8 @@ class _Socket(AbstractServer):
|
||||
|
||||
# Read
|
||||
|
||||
def recv(self, n=1024):
|
||||
return super().recv(n)
|
||||
def recv(self, *args, **kwargs):
|
||||
return self._fd.recv(*args, **kwargs)
|
||||
|
||||
|
||||
def parse(self, line):
|
||||
@ -67,7 +66,7 @@ class _Socket(AbstractServer):
|
||||
args = line.split(' ')
|
||||
|
||||
if len(args):
|
||||
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
|
||||
yield message.Command(cmd=args[0], args=args[1:], server=self._fd.fileno(), to=["you"], frm="you")
|
||||
|
||||
|
||||
def subparse(self, orig, cnt):
|
||||
@ -78,50 +77,46 @@ class _Socket(AbstractServer):
|
||||
yield m
|
||||
|
||||
|
||||
class _SocketServer(_Socket):
|
||||
class SocketServer(_Socket):
|
||||
|
||||
def __init__(self, host, port, bind=None, **kwargs):
|
||||
(family, type, proto, canonname, sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0]
|
||||
(family, type, proto, canonname, self._sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0]
|
||||
|
||||
super().__init__(family=family, type=type, proto=proto, **kwargs)
|
||||
super().__init__(fdClass=socket.socket, family=family, type=type, proto=proto, **kwargs)
|
||||
|
||||
self._sockaddr = sockaddr
|
||||
self._bind = bind
|
||||
|
||||
|
||||
def connect(self):
|
||||
self._logger.info("Connection to %s:%d", *self._sockaddr[:2])
|
||||
self._logger.info("Connecting to %s:%d", *self._sockaddr[:2])
|
||||
super().connect(self._sockaddr)
|
||||
self._logger.info("Connected to %s:%d", *self._sockaddr[:2])
|
||||
|
||||
if self._bind:
|
||||
super().bind(self._bind)
|
||||
|
||||
|
||||
class SocketServer(_SocketServer, socket.socket):
|
||||
pass
|
||||
|
||||
|
||||
class SecureSocketServer(_SocketServer, ssl.SSLSocket):
|
||||
pass
|
||||
self._fd.bind(self._bind)
|
||||
|
||||
|
||||
class UnixSocket:
|
||||
|
||||
def __init__(self, location, **kwargs):
|
||||
super().__init__(family=socket.AF_UNIX, **kwargs)
|
||||
super().__init__(fdClass=socket.socket, family=socket.AF_UNIX, **kwargs)
|
||||
|
||||
self._socket_path = location
|
||||
|
||||
|
||||
def connect(self):
|
||||
self._logger.info("Connection to unix://%s", self._socket_path)
|
||||
super().connect(self._socket_path)
|
||||
self.connect(self._socket_path)
|
||||
|
||||
|
||||
class SocketClient(_Socket, socket.socket):
|
||||
class SocketClient(_Socket):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(fdClass=socket.socket, **kwargs)
|
||||
|
||||
|
||||
def read(self):
|
||||
return self.recv()
|
||||
return self._fd.recv()
|
||||
|
||||
|
||||
class _Listener:
|
||||
@ -134,7 +129,7 @@ class _Listener:
|
||||
|
||||
|
||||
def read(self):
|
||||
conn, addr = self.accept()
|
||||
conn, addr = self._fd.accept()
|
||||
fileno = conn.fileno()
|
||||
self._logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
|
||||
|
||||
@ -145,11 +140,7 @@ class _Listener:
|
||||
return b''
|
||||
|
||||
|
||||
class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
class UnixSocketListener(_Listener, UnixSocket, _Socket):
|
||||
|
||||
def connect(self):
|
||||
self._logger.info("Creating Unix socket at unix://%s", self._socket_path)
|
||||
@ -159,8 +150,8 @@ class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
self.bind(self._socket_path)
|
||||
self.listen(5)
|
||||
self._fd.bind(self._socket_path)
|
||||
self._fd.listen(5)
|
||||
self._logger.info("Socket ready for accepting new connections")
|
||||
|
||||
self._on_connect()
|
||||
@ -171,7 +162,7 @@ class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
|
||||
import socket
|
||||
|
||||
try:
|
||||
self.shutdown(socket.SHUT_RDWR)
|
||||
self._fd.shutdown(socket.SHUT_RDWR)
|
||||
except socket.error:
|
||||
pass
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user