servers: use proxy design pattern instead of inheritance, because Python ssl patch has benn refused
This commit is contained in:
parent
7a4b27510c
commit
12ddf40ef4
@ -9,8 +9,6 @@ Requirements
|
|||||||
|
|
||||||
*nemubot* requires at least Python 3.3 to work.
|
*nemubot* requires at least Python 3.3 to work.
|
||||||
|
|
||||||
Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629).
|
|
||||||
|
|
||||||
Some modules (like `cve`, `nextstop` or `laposte`) require the
|
Some modules (like `cve`, `nextstop` or `laposte`) require the
|
||||||
[BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/),
|
[BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/),
|
||||||
but the core and framework has no dependency.
|
but the core and framework has no dependency.
|
||||||
|
@ -21,10 +21,10 @@ import socket
|
|||||||
from nemubot.channel import Channel
|
from nemubot.channel import Channel
|
||||||
from nemubot.message.printer.IRC import IRC as IRCPrinter
|
from nemubot.message.printer.IRC import IRC as IRCPrinter
|
||||||
from nemubot.server.message.IRC import IRC as IRCMessage
|
from nemubot.server.message.IRC import IRC as IRCMessage
|
||||||
from nemubot.server.socket import SocketServer, SecureSocketServer
|
from nemubot.server.socket import SocketServer
|
||||||
|
|
||||||
|
|
||||||
class _IRC:
|
class IRC(SocketServer):
|
||||||
|
|
||||||
"""Concrete implementation of a connexion to an IRC server"""
|
"""Concrete implementation of a connexion to an IRC server"""
|
||||||
|
|
||||||
@ -245,7 +245,7 @@ class _IRC:
|
|||||||
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if not self._closed:
|
if not self._fd._closed:
|
||||||
self.write("QUIT")
|
self.write("QUIT")
|
||||||
return super().close()
|
return super().close()
|
||||||
|
|
||||||
@ -274,10 +274,3 @@ class _IRC:
|
|||||||
def subparse(self, orig, cnt):
|
def subparse(self, orig, cnt):
|
||||||
msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding)
|
msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding)
|
||||||
return msg.to_bot_message(self)
|
return msg.to_bot_message(self)
|
||||||
|
|
||||||
|
|
||||||
class IRC(_IRC, SocketServer):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class IRC_secure(_IRC, SecureSocketServer):
|
|
||||||
pass
|
|
||||||
|
@ -32,16 +32,6 @@ def factory(uri, ssl=False, **init_args):
|
|||||||
if o.username is not None: args["username"] = o.username
|
if o.username is not None: args["username"] = o.username
|
||||||
if o.password is not None: args["password"] = o.password
|
if o.password is not None: args["password"] = o.password
|
||||||
|
|
||||||
if ssl:
|
|
||||||
try:
|
|
||||||
from ssl import create_default_context
|
|
||||||
args["_context"] = create_default_context()
|
|
||||||
except ImportError:
|
|
||||||
# Python 3.3 compat
|
|
||||||
from ssl import SSLContext, PROTOCOL_TLSv1
|
|
||||||
args["_context"] = SSLContext(PROTOCOL_TLSv1)
|
|
||||||
args["server_hostname"] = o.hostname
|
|
||||||
|
|
||||||
modifiers = o.path.split(",")
|
modifiers = o.path.split(",")
|
||||||
target = unquote(modifiers.pop(0)[1:])
|
target = unquote(modifiers.pop(0)[1:])
|
||||||
|
|
||||||
@ -68,11 +58,19 @@ def factory(uri, ssl=False, **init_args):
|
|||||||
if "channels" not in args and "isnick" not in modifiers:
|
if "channels" not in args and "isnick" not in modifiers:
|
||||||
args["channels"] = [ target ]
|
args["channels"] = [ target ]
|
||||||
|
|
||||||
if ssl:
|
|
||||||
from nemubot.server.IRC import IRC_secure as SecureIRCServer
|
|
||||||
srv = SecureIRCServer(**args)
|
|
||||||
else:
|
|
||||||
from nemubot.server.IRC import IRC as IRCServer
|
from nemubot.server.IRC import IRC as IRCServer
|
||||||
srv = IRCServer(**args)
|
srv = IRCServer(**args)
|
||||||
|
|
||||||
|
if ssl:
|
||||||
|
try:
|
||||||
|
from ssl import create_default_context
|
||||||
|
context = create_default_context()
|
||||||
|
except ImportError:
|
||||||
|
# Python 3.3 compat
|
||||||
|
from ssl import SSLContext, PROTOCOL_TLSv1
|
||||||
|
context = SSLContext(PROTOCOL_TLSv1)
|
||||||
|
from ssl import wrap_socket
|
||||||
|
srv._fd = context.wrap_socket(srv._fd, server_hostname=o.hostname)
|
||||||
|
|
||||||
|
|
||||||
return srv
|
return srv
|
||||||
|
@ -24,17 +24,16 @@ class AbstractServer:
|
|||||||
|
|
||||||
"""An abstract server: handle communication with an IM server"""
|
"""An abstract server: handle communication with an IM server"""
|
||||||
|
|
||||||
def __init__(self, name=None, **kwargs):
|
def __init__(self, name, fdClass, **kwargs):
|
||||||
"""Initialize an abstract server
|
"""Initialize an abstract server
|
||||||
|
|
||||||
Keyword argument:
|
Keyword argument:
|
||||||
name -- Identifier of the socket, for convinience
|
name -- Identifier of the socket, for convinience
|
||||||
|
fdClass -- Class to instantiate as support file
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._name = name
|
self._name = name
|
||||||
self._socket = socket
|
self._fd = fdClass(**kwargs)
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self._logger = logging.getLogger("nemubot.server." + str(self.name))
|
self._logger = logging.getLogger("nemubot.server." + str(self.name))
|
||||||
self._readbuffer = b''
|
self._readbuffer = b''
|
||||||
@ -46,7 +45,7 @@ class AbstractServer:
|
|||||||
if self._name is not None:
|
if self._name is not None:
|
||||||
return self._name
|
return self._name
|
||||||
else:
|
else:
|
||||||
return self.fileno()
|
return self._fd.fileno()
|
||||||
|
|
||||||
|
|
||||||
# Open/close
|
# Open/close
|
||||||
@ -56,12 +55,12 @@ class AbstractServer:
|
|||||||
|
|
||||||
self._logger.info("Opening connection")
|
self._logger.info("Opening connection")
|
||||||
|
|
||||||
super().connect(*args, **kwargs)
|
self._fd.connect(*args, **kwargs)
|
||||||
|
|
||||||
self._on_connect()
|
self._on_connect()
|
||||||
|
|
||||||
def _on_connect(self):
|
def _on_connect(self):
|
||||||
sync_act("sckt", "register", self.fileno())
|
sync_act("sckt", "register", self._fd.fileno())
|
||||||
|
|
||||||
|
|
||||||
def close(self, *args, **kwargs):
|
def close(self, *args, **kwargs):
|
||||||
@ -69,10 +68,10 @@ class AbstractServer:
|
|||||||
|
|
||||||
self._logger.info("Closing connection")
|
self._logger.info("Closing connection")
|
||||||
|
|
||||||
if self.fileno() > 0:
|
if self._fd.fileno() > 0:
|
||||||
sync_act("sckt", "unregister", self.fileno())
|
sync_act("sckt", "unregister", self._fd.fileno())
|
||||||
|
|
||||||
super().close(*args, **kwargs)
|
self._fd.close(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# Writes
|
# Writes
|
||||||
@ -86,14 +85,14 @@ class AbstractServer:
|
|||||||
|
|
||||||
self._sending_queue.put(self.format(message))
|
self._sending_queue.put(self.format(message))
|
||||||
self._logger.debug("Message '%s' appended to write queue", message)
|
self._logger.debug("Message '%s' appended to write queue", message)
|
||||||
sync_act("sckt", "write", self.fileno())
|
sync_act("sckt", "write", self._fd.fileno())
|
||||||
|
|
||||||
|
|
||||||
def async_write(self):
|
def async_write(self):
|
||||||
"""Internal function used when the file descriptor is writable"""
|
"""Internal function used when the file descriptor is writable"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sync_act("sckt", "unwrite", self.fileno())
|
sync_act("sckt", "unwrite", self._fd.fileno())
|
||||||
while not self._sending_queue.empty():
|
while not self._sending_queue.empty():
|
||||||
self._write(self._sending_queue.get_nowait())
|
self._write(self._sending_queue.get_nowait())
|
||||||
self._sending_queue.task_done()
|
self._sending_queue.task_done()
|
||||||
@ -131,7 +130,7 @@ class AbstractServer:
|
|||||||
A list of fully received messages
|
A list of fully received messages
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
|
ret, self._readbuffer = self.lex(self._readbuffer + self._fd.read())
|
||||||
|
|
||||||
for r in ret:
|
for r in ret:
|
||||||
yield r
|
yield r
|
||||||
@ -159,4 +158,9 @@ class AbstractServer:
|
|||||||
def exception(self, flags):
|
def exception(self, flags):
|
||||||
"""Exception occurs on fd"""
|
"""Exception occurs on fd"""
|
||||||
|
|
||||||
self.close()
|
self._fd.close()
|
||||||
|
|
||||||
|
# Proxy
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return self._fd.fileno()
|
||||||
|
@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import ssl
|
|
||||||
|
|
||||||
import nemubot.message as message
|
import nemubot.message as message
|
||||||
from nemubot.message.printer.socket import Socket as SocketPrinter
|
from nemubot.message.printer.socket import Socket as SocketPrinter
|
||||||
@ -40,7 +39,7 @@ class _Socket(AbstractServer):
|
|||||||
# Write
|
# Write
|
||||||
|
|
||||||
def _write(self, cnt):
|
def _write(self, cnt):
|
||||||
self.sendall(cnt)
|
self._fd.sendall(cnt)
|
||||||
|
|
||||||
|
|
||||||
def format(self, txt):
|
def format(self, txt):
|
||||||
@ -52,8 +51,8 @@ class _Socket(AbstractServer):
|
|||||||
|
|
||||||
# Read
|
# Read
|
||||||
|
|
||||||
def recv(self, n=1024):
|
def recv(self, *args, **kwargs):
|
||||||
return super().recv(n)
|
return self._fd.recv(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def parse(self, line):
|
def parse(self, line):
|
||||||
@ -67,7 +66,7 @@ class _Socket(AbstractServer):
|
|||||||
args = line.split(' ')
|
args = line.split(' ')
|
||||||
|
|
||||||
if len(args):
|
if len(args):
|
||||||
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
|
yield message.Command(cmd=args[0], args=args[1:], server=self._fd.fileno(), to=["you"], frm="you")
|
||||||
|
|
||||||
|
|
||||||
def subparse(self, orig, cnt):
|
def subparse(self, orig, cnt):
|
||||||
@ -78,50 +77,46 @@ class _Socket(AbstractServer):
|
|||||||
yield m
|
yield m
|
||||||
|
|
||||||
|
|
||||||
class _SocketServer(_Socket):
|
class SocketServer(_Socket):
|
||||||
|
|
||||||
def __init__(self, host, port, bind=None, **kwargs):
|
def __init__(self, host, port, bind=None, **kwargs):
|
||||||
(family, type, proto, canonname, sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0]
|
(family, type, proto, canonname, self._sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0]
|
||||||
|
|
||||||
super().__init__(family=family, type=type, proto=proto, **kwargs)
|
super().__init__(fdClass=socket.socket, family=family, type=type, proto=proto, **kwargs)
|
||||||
|
|
||||||
self._sockaddr = sockaddr
|
|
||||||
self._bind = bind
|
self._bind = bind
|
||||||
|
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
self._logger.info("Connection to %s:%d", *self._sockaddr[:2])
|
self._logger.info("Connecting to %s:%d", *self._sockaddr[:2])
|
||||||
super().connect(self._sockaddr)
|
super().connect(self._sockaddr)
|
||||||
|
self._logger.info("Connected to %s:%d", *self._sockaddr[:2])
|
||||||
|
|
||||||
if self._bind:
|
if self._bind:
|
||||||
super().bind(self._bind)
|
self._fd.bind(self._bind)
|
||||||
|
|
||||||
|
|
||||||
class SocketServer(_SocketServer, socket.socket):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SecureSocketServer(_SocketServer, ssl.SSLSocket):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UnixSocket:
|
class UnixSocket:
|
||||||
|
|
||||||
def __init__(self, location, **kwargs):
|
def __init__(self, location, **kwargs):
|
||||||
super().__init__(family=socket.AF_UNIX, **kwargs)
|
super().__init__(fdClass=socket.socket, family=socket.AF_UNIX, **kwargs)
|
||||||
|
|
||||||
self._socket_path = location
|
self._socket_path = location
|
||||||
|
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
self._logger.info("Connection to unix://%s", self._socket_path)
|
self._logger.info("Connection to unix://%s", self._socket_path)
|
||||||
super().connect(self._socket_path)
|
self.connect(self._socket_path)
|
||||||
|
|
||||||
|
|
||||||
class SocketClient(_Socket, socket.socket):
|
class SocketClient(_Socket):
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(fdClass=socket.socket, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def read(self):
|
def read(self):
|
||||||
return self.recv()
|
return self._fd.recv()
|
||||||
|
|
||||||
|
|
||||||
class _Listener:
|
class _Listener:
|
||||||
@ -134,7 +129,7 @@ class _Listener:
|
|||||||
|
|
||||||
|
|
||||||
def read(self):
|
def read(self):
|
||||||
conn, addr = self.accept()
|
conn, addr = self._fd.accept()
|
||||||
fileno = conn.fileno()
|
fileno = conn.fileno()
|
||||||
self._logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
|
self._logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
|
||||||
|
|
||||||
@ -145,11 +140,7 @@ class _Listener:
|
|||||||
return b''
|
return b''
|
||||||
|
|
||||||
|
|
||||||
class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
|
class UnixSocketListener(_Listener, UnixSocket, _Socket):
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
self._logger.info("Creating Unix socket at unix://%s", self._socket_path)
|
self._logger.info("Creating Unix socket at unix://%s", self._socket_path)
|
||||||
@ -159,8 +150,8 @@ class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
|
|||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.bind(self._socket_path)
|
self._fd.bind(self._socket_path)
|
||||||
self.listen(5)
|
self._fd.listen(5)
|
||||||
self._logger.info("Socket ready for accepting new connections")
|
self._logger.info("Socket ready for accepting new connections")
|
||||||
|
|
||||||
self._on_connect()
|
self._on_connect()
|
||||||
@ -171,7 +162,7 @@ class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
|
|||||||
import socket
|
import socket
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.shutdown(socket.SHUT_RDWR)
|
self._fd.shutdown(socket.SHUT_RDWR)
|
||||||
except socket.error:
|
except socket.error:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user