From 764e6f070b9f4e6a5ff0fc401aa15b8448ed78de Mon Sep 17 00:00:00 2001 From: nemunaire Date: Mon, 30 May 2016 22:06:35 +0200 Subject: [PATCH] Refactor file/socket management (use poll instead of select) --- README.md | 2 + nemubot/__main__.py | 26 +-- nemubot/bot.py | 145 +++++++++-------- nemubot/server/DCC.py | 2 +- nemubot/server/IRC.py | 43 +++-- nemubot/server/__init__.py | 45 ++--- nemubot/server/abstract.py | 126 +++++++------- nemubot/server/factory_test.py | 10 +- nemubot/server/message/__init__.py | 15 ++ nemubot/server/socket.py | 253 ++++++++++++----------------- 10 files changed, 328 insertions(+), 339 deletions(-) create mode 100644 nemubot/server/message/__init__.py diff --git a/README.md b/README.md index aa3b141..1d40faf 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,8 @@ Requirements *nemubot* requires at least Python 3.3 to work. +Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629). + Some modules (like `cve`, `nextstop` or `laposte`) require the [BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/), but the core and framework has no dependency. diff --git a/nemubot/__main__.py b/nemubot/__main__.py index 5a236f4..c39dd2f 100644 --- a/nemubot/__main__.py +++ b/nemubot/__main__.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -125,7 +125,7 @@ def main(): # Create bot context from nemubot import datastore - from nemubot.bot import Bot + from nemubot.bot import Bot, sync_act context = Bot(modules_paths=modules_paths, data_store=datastore.XML(args.data_path), verbosity=args.verbose) @@ -141,7 +141,7 @@ def main(): # Load requested configuration files for path in args.files: if os.path.isfile(path): - context.sync_queue.put_nowait(["loadconf", path]) + sync_act("loadconf", path) else: logger.error("%s is not a readable file", path) @@ -165,22 +165,28 @@ def main(): # Reload configuration file for path in args.files: if os.path.isfile(path): - context.sync_queue.put_nowait(["loadconf", path]) + sync_act("loadconf", path) signal.signal(signal.SIGHUP, sighuphandler) def sigusr1handler(signum, frame): """On SIGHUSR1, display stacktraces""" - import traceback + import threading, traceback for threadId, stack in sys._current_frames().items(): - logger.debug("########### Thread %d:\n%s", - threadId, + thName = "#%d" % threadId + for th in threading.enumerate(): + if th.ident == threadId: + thName = th.name + break + logger.debug("########### Thread %s:\n%s", + thName, "".join(traceback.format_stack(stack))) signal.signal(signal.SIGUSR1, sigusr1handler) if args.socketfile: - from nemubot.server.socket import SocketListener - context.add_server(SocketListener(context.add_server, "master_socket", - sock_location=args.socketfile)) + from nemubot.server.socket import UnixSocketListener + context.add_server(UnixSocketListener(new_server_cb=context.add_server, + location=args.socketfile, + name="master_socket")) # context can change when performing an hotswap, always join the latest context oldcontext = None diff --git a/nemubot/bot.py b/nemubot/bot.py index 2657d52..c8ede40 100644 --- a/nemubot/bot.py +++ b/nemubot/bot.py @@ -16,7 +16,9 @@ from datetime import datetime, timezone import logging +from multiprocessing import JoinableQueue import threading +import select import sys from nemubot import __version__ @@ -26,6 +28,11 @@ import nemubot.hooks logger = logging.getLogger("nemubot") +sync_queue = JoinableQueue() + +def sync_act(*args): + sync_queue.put(list(args)) + class Bot(threading.Thread): @@ -42,7 +49,7 @@ class Bot(threading.Thread): verbosity -- verbosity level """ - threading.Thread.__init__(self) + super().__init__(name="Nemubot main") logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)", __version__, @@ -61,6 +68,7 @@ class Bot(threading.Thread): self.datastore.open() # Keep global context: servers and modules + self._poll = select.poll() self.servers = dict() self.modules = dict() self.modules_configuration = dict() @@ -138,60 +146,72 @@ class Bot(threading.Thread): self.cnsr_queue = Queue() self.cnsr_thrd = list() self.cnsr_thrd_size = -1 - # Synchrone actions to be treated by main thread - self.sync_queue = Queue() def run(self): - from select import select - from nemubot.server import _lock, _rlist, _wlist, _xlist + self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI) logger.info("Starting main loop") self.stop = False while not self.stop: - with _lock: - try: - rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1) - except: - logger.error("Something went wrong in select") - fnd_smth = False - # Looking for invalid server - for r in _rlist: - if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0: - _rlist.remove(r) - logger.error("Found invalid object in _rlist: " + str(r)) - fnd_smth = True - for w in _wlist: - if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0: - _wlist.remove(w) - logger.error("Found invalid object in _wlist: " + str(w)) - fnd_smth = True - for x in _xlist: - if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0: - _xlist.remove(x) - logger.error("Found invalid object in _xlist: " + str(x)) - fnd_smth = True - if not fnd_smth: - logger.exception("Can't continue, sorry") - self.quit() - continue + for fd, flag in self._poll.poll(): + # Handle internal socket passing orders + if fd != sync_queue._reader.fileno() and fd in self.servers: + srv = self.servers[fd] - for x in xl: - try: - x.exception() - except: - logger.exception("Uncatched exception on server exception") - for w in wl: - try: - w.write_select() - except: - logger.exception("Uncatched exception on server write") - for r in rl: - for i in r.read(): + if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL): try: - self.receive_message(r, i) + srv.exception(flag) except: - logger.exception("Uncatched exception on server read") + logger.exception("Uncatched exception on server exception") + + if srv.fileno() > 0: + if flag & (select.POLLOUT): + try: + srv.async_write() + except: + logger.exception("Uncatched exception on server write") + + if flag & (select.POLLIN | select.POLLPRI): + try: + for i in srv.async_read(): + self.receive_message(srv, i) + except: + logger.exception("Uncatched exception on server read") + + else: + del self.servers[fd] + + + # Always check the sync queue + while not sync_queue.empty(): + args = sync_queue.get() + action = args.pop(0) + + if action == "sckt" and len(args) >= 2: + try: + if args[0] == "write": + self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI) + elif args[0] == "unwrite": + self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI) + + elif args[0] == "register": + self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI) + elif args[0] == "unregister": + self._poll.unregister(int(args[1])) + except: + logger.exception("Unhandled excpetion during action:") + + elif action == "exit": + self.quit() + + elif action == "loadconf": + for path in args: + logger.debug("Load configuration from %s", path) + self.load_file(path) + logger.info("Configurations successfully loaded") + + sync_queue.task_done() # Launch new consumer threads if necessary @@ -202,17 +222,6 @@ class Bot(threading.Thread): c = Consumer(self) self.cnsr_thrd.append(c) c.start() - - while self.sync_queue.qsize() > 0: - action = self.sync_queue.get_nowait() - if action[0] == "exit": - self.quit() - elif action[0] == "loadconf": - for path in action[1:]: - logger.debug("Load configuration from %s", path) - self.load_file(path) - logger.info("Configurations successfully loaded") - self.sync_queue.task_done() logger.info("Ending main loop") @@ -419,7 +428,7 @@ class Bot(threading.Thread): self.servers[fileno] = srv self.servers[srv.name] = srv if autoconnect and not hasattr(self, "noautoconnect"): - srv.open() + srv.connect() return True else: @@ -532,28 +541,28 @@ class Bot(threading.Thread): def quit(self): """Save and unload modules and disconnect servers""" - self.datastore.close() - if self.event_timer is not None: logger.info("Stop the event timer...") self.event_timer.cancel() + logger.info("Save and unload all modules...") + for mod in self.modules.items(): + self.unload_module(mod) + + logger.info("Close all servers connection...") + for srv in [self.servers[k] for k in self.servers]: + srv.close() + logger.info("Stop consumers") k = self.cnsr_thrd for cnsr in k: cnsr.stop = True - logger.info("Save and unload all modules...") - k = list(self.modules.keys()) - for mod in k: - self.unload_module(mod) - - logger.info("Close all servers connection...") - k = list(self.servers.keys()) - for srv in k: - self.servers[srv].close() + self.datastore.close() self.stop = True + sync_act("end") + sync_queue.join() # Treatment diff --git a/nemubot/server/DCC.py b/nemubot/server/DCC.py index 644a8cb..c1a6852 100644 --- a/nemubot/server/DCC.py +++ b/nemubot/server/DCC.py @@ -31,7 +31,7 @@ PORTS = list() class DCC(server.AbstractServer): def __init__(self, srv, dest, socket=None): - super().__init__(self) + super().__init__(name="Nemubot DCC server") self.error = False # An error has occur, closing the connection? self.messages = list() # Message queued before connexion diff --git a/nemubot/server/IRC.py b/nemubot/server/IRC.py index 08e2bc5..89eeab5 100644 --- a/nemubot/server/IRC.py +++ b/nemubot/server/IRC.py @@ -20,17 +20,17 @@ import re from nemubot.channel import Channel from nemubot.message.printer.IRC import IRC as IRCPrinter from nemubot.server.message.IRC import IRC as IRCMessage -from nemubot.server.socket import SocketServer +from nemubot.server.socket import SocketServer, SecureSocketServer -class IRC(SocketServer): +class _IRC: """Concrete implementation of a connexion to an IRC server""" - def __init__(self, host="localhost", port=6667, ssl=False, owner=None, + def __init__(self, host="localhost", port=6667, owner=None, nick="nemubot", username=None, password=None, realname="Nemubot", encoding="utf-8", caps=None, - channels=list(), on_connect=None): + channels=list(), on_connect=None, **kwargs): """Prepare a connection with an IRC server Keyword arguments: @@ -54,7 +54,8 @@ class IRC(SocketServer): self.owner = owner self.realname = realname - super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port)) + super().__init__(name=self.username + "@" + host + ":" + str(port), + host=host, port=port, **kwargs) self.printer = IRCPrinter self.encoding = encoding @@ -231,20 +232,19 @@ class IRC(SocketServer): # Open/close - def open(self): - if super().open(): - if self.password is not None: - self.write("PASS :" + self.password) - if self.capabilities is not None: - self.write("CAP LS") - self.write("NICK :" + self.nick) - self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) - return True - return False + def connect(self): + super().connect() + + if self.password is not None: + self.write("PASS :" + self.password) + if self.capabilities is not None: + self.write("CAP LS") + self.write("NICK :" + self.nick) + self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) def close(self): - if not self.closed: + if not self._closed: self.write("QUIT") return super().close() @@ -253,8 +253,8 @@ class IRC(SocketServer): # Read - def read(self): - for line in super().read(): + def async_read(self): + for line in super().async_read(): # PING should be handled here, so start parsing here :/ msg = IRCMessage(line, self.encoding) @@ -273,3 +273,10 @@ class IRC(SocketServer): def subparse(self, orig, cnt): msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding) return msg.to_bot_message(self) + + +class IRC(_IRC, SocketServer): + pass + +class IRC_secure(_IRC, SecureSocketServer): + pass diff --git a/nemubot/server/__init__.py b/nemubot/server/__init__.py index 3c88138..6b583b7 100644 --- a/nemubot/server/__init__.py +++ b/nemubot/server/__init__.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,34 +14,37 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import threading -_lock = threading.Lock() - -# Lists for select -_rlist = [] -_wlist = [] -_xlist = [] - - -def factory(uri, **init_args): +def factory(uri, ssl=False, **init_args): from urllib.parse import urlparse, unquote o = urlparse(uri) + srv = None + if o.scheme == "irc" or o.scheme == "ircs": # http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt # http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html args = init_args - modifiers = o.path.split(",") - target = unquote(modifiers.pop(0)[1:]) - - if o.scheme == "ircs": args["ssl"] = True + if o.scheme == "ircs": ssl = True if o.hostname is not None: args["host"] = o.hostname if o.port is not None: args["port"] = o.port if o.username is not None: args["username"] = o.username if o.password is not None: args["password"] = o.password + if ssl: + try: + from ssl import create_default_context + args["_context"] = create_default_context() + except ImportError: + # Python 3.3 compat + from ssl import SSLContext, PROTOCOL_TLSv1 + args["_context"] = SSLContext(PROTOCOL_TLSv1) + args["server_hostname"] = o.hostname + + modifiers = o.path.split(",") + target = unquote(modifiers.pop(0)[1:]) + queries = o.query.split("&") for q in queries: if "=" in q: @@ -64,7 +67,11 @@ def factory(uri, **init_args): if "channels" not in args and "isnick" not in modifiers: args["channels"] = [ target ] - from nemubot.server.IRC import IRC as IRCServer - return IRCServer(**args) - else: - return None + if ssl: + from nemubot.server.IRC import IRC_secure as SecureIRCServer + srv = SecureIRCServer(**args) + else: + from nemubot.server.IRC import IRC as IRCServer + srv = IRCServer(**args) + + return srv diff --git a/nemubot/server/abstract.py b/nemubot/server/abstract.py index dc2081d..fd25c2d 100644 --- a/nemubot/server/abstract.py +++ b/nemubot/server/abstract.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,34 +14,30 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import io import logging import queue -from nemubot.server import _lock, _rlist, _wlist, _xlist +from nemubot.bot import sync_act -# Extends from IOBase in order to be compatible with select function -class AbstractServer(io.IOBase): + +class AbstractServer: """An abstract server: handle communication with an IM server""" - def __init__(self, name=None, send_callback=None): + def __init__(self, name=None, **kwargs): """Initialize an abstract server Keyword argument: - send_callback -- Callback when developper want to send a message + name -- Identifier of the socket, for convinience """ self._name = name - super().__init__() + super().__init__(**kwargs) - self.logger = logging.getLogger("nemubot.server." + self.name) + self.logger = logging.getLogger("nemubot.server." + str(self.name)) + self._readbuffer = b'' self._sending_queue = queue.Queue() - if send_callback is not None: - self._send_callback = send_callback - else: - self._send_callback = self._write_select @property @@ -54,40 +50,28 @@ class AbstractServer(io.IOBase): # Open/close - def __enter__(self): - self.open() - return self + def connect(self, *args, **kwargs): + """Register the server in _poll""" + + self.logger.info("Opening connection") + + super().connect(*args, **kwargs) + + self._on_connect() + + def _on_connect(self): + sync_act("sckt", "register", self.fileno()) - def __exit__(self, type, value, traceback): - self.close() + def close(self, *args, **kwargs): + """Unregister the server from _poll""" + self.logger.info("Closing connection") - def open(self): - """Generic open function that register the server un _rlist in case - of successful _open""" - self.logger.info("Opening connection to %s", self.id) - if not hasattr(self, "_open") or self._open(): - _rlist.append(self) - _xlist.append(self) - return True - return False + if self.fileno() > 0: + sync_act("sckt", "unregister", self.fileno()) - - def close(self): - """Generic close function that register the server un _{r,w,x}list in - case of successful _close""" - self.logger.info("Closing connection to %s", self.id) - with _lock: - if not hasattr(self, "_close") or self._close(): - if self in _rlist: - _rlist.remove(self) - if self in _wlist: - _wlist.remove(self) - if self in _xlist: - _xlist.remove(self) - return True - return False + super().close(*args, **kwargs) # Writes @@ -99,13 +83,16 @@ class AbstractServer(io.IOBase): message -- message to send """ - self._send_callback(message) + self._sending_queue.put(self.format(message)) + self.logger.debug("Message '%s' appended to write queue", message) + sync_act("sckt", "write", self.fileno()) - def write_select(self): - """Internal function used by the select function""" + def async_write(self): + """Internal function used when the file descriptor is writable""" + try: - _wlist.remove(self) + sync_act("sckt", "unwrite", self.fileno()) while not self._sending_queue.empty(): self._write(self._sending_queue.get_nowait()) self._sending_queue.task_done() @@ -114,19 +101,6 @@ class AbstractServer(io.IOBase): pass - def _write_select(self, message): - """Send a message to the server safely through select - - Argument: - message -- message to send - """ - - self._sending_queue.put(self.format(message)) - self.logger.debug("Message '%s' appended to write queue", message) - if self not in _wlist: - _wlist.append(self) - - def send_response(self, response): """Send a formated Message class @@ -149,13 +123,39 @@ class AbstractServer(io.IOBase): # Read + def async_read(self): + """Internal function used when the file descriptor is readable + + Returns: + A list of fully received messages + """ + + ret, self._readbuffer = self.lex(self._readbuffer + self.read()) + + for r in ret: + yield r + + + def lex(self, buf): + """Assume lexing in default case is per line + + Argument: + buf -- buffer to lex + """ + + msgs = buf.split(b'\r\n') + partial = msgs.pop() + + return msgs, partial + + def parse(self, msg): raise NotImplemented # Exceptions - def exception(self): - """Exception occurs in fd""" - self.logger.warning("Unhandle file descriptor exception on server %s", - self.name) + def exception(self, flags): + """Exception occurs on fd""" + + self.close() diff --git a/nemubot/server/factory_test.py b/nemubot/server/factory_test.py index cc7d35b..358591e 100644 --- a/nemubot/server/factory_test.py +++ b/nemubot/server/factory_test.py @@ -22,34 +22,30 @@ class TestFactory(unittest.TestCase): def test_IRC1(self): from nemubot.server.IRC import IRC as IRCServer + from nemubot.server.IRC import IRC_secure as IRCSServer # : If omitted, the client must connect to a prespecified default IRC server. server = factory("irc:///") self.assertIsInstance(server, IRCServer) self.assertEqual(server.host, "localhost") - self.assertFalse(server.ssl) server = factory("ircs:///") - self.assertIsInstance(server, IRCServer) + self.assertIsInstance(server, IRCSServer) self.assertEqual(server.host, "localhost") - self.assertTrue(server.ssl) server = factory("irc://host1") self.assertIsInstance(server, IRCServer) self.assertEqual(server.host, "host1") - self.assertFalse(server.ssl) server = factory("irc://host2:6667") self.assertIsInstance(server, IRCServer) self.assertEqual(server.host, "host2") self.assertEqual(server.port, 6667) - self.assertFalse(server.ssl) server = factory("ircs://host3:194/") - self.assertIsInstance(server, IRCServer) + self.assertIsInstance(server, IRCSServer) self.assertEqual(server.host, "host3") self.assertEqual(server.port, 194) - self.assertTrue(server.ssl) if __name__ == '__main__': diff --git a/nemubot/server/message/__init__.py b/nemubot/server/message/__init__.py new file mode 100644 index 0000000..57f3468 --- /dev/null +++ b/nemubot/server/message/__init__.py @@ -0,0 +1,15 @@ +# Nemubot is a smart and modulable IM bot. +# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . diff --git a/nemubot/server/socket.py b/nemubot/server/socket.py index 13ac9bd..1137e36 100644 --- a/nemubot/server/socket.py +++ b/nemubot/server/socket.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,117 +14,33 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import os +import socket +import ssl + import nemubot.message as message from nemubot.message.printer.socket import Socket as SocketPrinter from nemubot.server.abstract import AbstractServer -class SocketServer(AbstractServer): +class _Socket(AbstractServer): - """Concrete implementation of a socket connexion (can be wrapped with TLS)""" + """Concrete implementation of a socket connection""" - def __init__(self, sock_location=None, - host=None, port=None, - sock=None, - ssl=False, - name=None): + def __init__(self, printer=SocketPrinter, **kwargs): """Create a server socket - - Keyword arguments: - sock_location -- Path to the UNIX socket - host -- Hostname of the INET socket - port -- Port of the INET socket - sock -- Already connected socket - ssl -- Should TLS connection enabled - name -- Convinience name """ - import socket - - assert(sock is None or isinstance(sock, socket.SocketType)) - assert(port is None or isinstance(port, int)) - - super().__init__(name=name) - - if sock is None: - if sock_location is not None: - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.connect_to = sock_location - elif host is not None: - for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): - self.socket = socket.socket(af, socktype, proto) - self.connect_to = sa - break - else: - self.socket = sock - - self.ssl = ssl + super().__init__(**kwargs) self.readbuffer = b'' - self.printer = SocketPrinter - - - def fileno(self): - return self.socket.fileno() if self.socket else None - - - @property - def closed(self): - """Indicator of the connection aliveness""" - return self.socket._closed - - - # Open/close - - def open(self): - if not self.closed: - return True - - try: - self.socket.connect(self.connect_to) - self.logger.info("Connected to %s", self.connect_to) - except: - self.socket.close() - self.logger.exception("Unable to connect to %s", - self.connect_to) - return False - - # Wrap the socket for SSL - if self.ssl: - import ssl - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - self.socket = ctx.wrap_socket(self.socket) - - return super().open() - - - def close(self): - import socket - - # Flush the sending queue before close - from nemubot.server import _lock - _lock.release() - self._sending_queue.join() - _lock.acquire() - - if not self.closed: - try: - self.socket.shutdown(socket.SHUT_RDWR) - except socket.error: - pass - - self.socket.close() - - return super().close() + self.printer = printer # Write def _write(self, cnt): - if self.closed: - return - - self.socket.sendall(cnt) + self.sendall(cnt) def format(self, txt): @@ -136,19 +52,12 @@ class SocketServer(AbstractServer): # Read - def read(self): - if self.closed: - return [] - - raw = self.socket.recv(1024) - temp = (self.readbuffer + raw).split(b'\r\n') - self.readbuffer = temp.pop() - - for line in temp: - yield line + def recv(self, n=1024): + return super().recv(n) def parse(self, line): + """Implement a default behaviour for socket""" import shlex line = line.strip().decode() @@ -157,48 +66,97 @@ class SocketServer(AbstractServer): except ValueError: args = line.split(' ') - yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you") + if len(args): + yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you") -class SocketListener(AbstractServer): +class _SocketServer(_Socket): - def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None): - super().__init__(name=name) - self.new_server_cb = new_server_cb - self.sock_location = sock_location - self.host = host - self.port = port - self.ssl = ssl - self.nb_son = 0 + def __init__(self, host, port, bind=None, **kwargs): + super().__init__(family=socket.AF_INET, **kwargs) + assert(host is not None) + assert(isinstance(port, int)) - def fileno(self): - return self.socket.fileno() if self.socket else None + self._host = host + self._port = port + self._bind = bind @property - def closed(self): - """Indicator of the connection aliveness""" - return self.socket is None + def host(self): + return self._host - def open(self): - import os - import socket + def connect(self): + self.logger.info("Connection to %s:%d", self._host, self._port) + super().connect((self._host, self._port)) - if self.sock_location is not None: - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - try: - os.remove(self.sock_location) - except FileNotFoundError: - pass - self.socket.bind(self.sock_location) - elif self.host is not None and self.port is not None: - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.bind((self.host, self.port)) - self.socket.listen(5) + if self._bind: + super().bind(self._bind) - return super().open() + +class SocketServer(_SocketServer, socket.socket): + pass + + +class SecureSocketServer(_SocketServer, ssl.SSLSocket): + pass + + +class UnixSocket: + + def __init__(self, location, **kwargs): + super().__init__(family=socket.AF_UNIX, **kwargs) + + self._socket_path = location + + + def connect(self): + self.logger.info("Connection to unix://%s", self._socket_path) + super().connect(self._socket_path) + + +class _Listener: + + def __init__(self, new_server_cb, instanciate=_Socket, **kwargs): + super().__init__(**kwargs) + + self._instanciate = instanciate + self._new_server_cb = new_server_cb + + + def read(self): + conn, addr = self.accept() + fileno = conn.fileno() + self.logger.info("Accept new connection from %s (fd=%d)", addr, fileno) + + ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach()) + ss.connect = ss._on_connect + self._new_server_cb(ss, autoconnect=True) + + return b'' + + +class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + + def connect(self): + self.logger.info("Creating Unix socket at unix://%s", self._socket_path) + + try: + os.remove(self._socket_path) + except FileNotFoundError: + pass + + self.bind(self._socket_path) + self.listen(5) + self.logger.info("Socket ready for accepting new connections") + + self._on_connect() def close(self): @@ -206,25 +164,14 @@ class SocketListener(AbstractServer): import socket try: - self.socket.shutdown(socket.SHUT_RDWR) - self.socket.close() - if self.sock_location is not None: - os.remove(self.sock_location) + self.shutdown(socket.SHUT_RDWR) except socket.error: pass - return super().close() + super().close() - - # Read - - def read(self): - if self.closed: - return [] - - conn, addr = self.socket.accept() - self.nb_son += 1 - ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn) - self.new_server_cb(ss) - - return [] + try: + if self._socket_path is not None: + os.remove(self._socket_path) + except: + pass