Refactor file/socket management (use poll instead of select)

This commit is contained in:
nemunaire 2016-05-30 22:06:35 +02:00
parent 6d8dca211d
commit 449cb684f9
10 changed files with 332 additions and 339 deletions

View File

@ -9,6 +9,8 @@ Requirements
*nemubot* requires at least Python 3.3 to work. *nemubot* requires at least Python 3.3 to work.
Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629).
Some modules (like `cve`, `nextstop` or `laposte`) require the Some modules (like `cve`, `nextstop` or `laposte`) require the
[BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/), [BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/),
but the core and framework has no dependency. but the core and framework has no dependency.

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -125,7 +125,7 @@ def main():
# Create bot context # Create bot context
from nemubot import datastore from nemubot import datastore
from nemubot.bot import Bot from nemubot.bot import Bot, sync_act
context = Bot(modules_paths=modules_paths, context = Bot(modules_paths=modules_paths,
data_store=datastore.XML(args.data_path), data_store=datastore.XML(args.data_path),
verbosity=args.verbose) verbosity=args.verbose)
@ -141,7 +141,7 @@ def main():
# Load requested configuration files # Load requested configuration files
for path in args.files: for path in args.files:
if os.path.isfile(path): if os.path.isfile(path):
context.sync_queue.put_nowait(["loadconf", path]) sync_act("loadconf", path)
else: else:
logger.error("%s is not a readable file", path) logger.error("%s is not a readable file", path)
@ -165,22 +165,28 @@ def main():
# Reload configuration file # Reload configuration file
for path in args.files: for path in args.files:
if os.path.isfile(path): if os.path.isfile(path):
context.sync_queue.put_nowait(["loadconf", path]) sync_act("loadconf", path)
signal.signal(signal.SIGHUP, sighuphandler) signal.signal(signal.SIGHUP, sighuphandler)
def sigusr1handler(signum, frame): def sigusr1handler(signum, frame):
"""On SIGHUSR1, display stacktraces""" """On SIGHUSR1, display stacktraces"""
import traceback import threading, traceback
for threadId, stack in sys._current_frames().items(): for threadId, stack in sys._current_frames().items():
logger.debug("########### Thread %d:\n%s", thName = "#%d" % threadId
threadId, for th in threading.enumerate():
if th.ident == threadId:
thName = th.name
break
logger.debug("########### Thread %s:\n%s",
thName,
"".join(traceback.format_stack(stack))) "".join(traceback.format_stack(stack)))
signal.signal(signal.SIGUSR1, sigusr1handler) signal.signal(signal.SIGUSR1, sigusr1handler)
if args.socketfile: if args.socketfile:
from nemubot.server.socket import SocketListener from nemubot.server.socket import UnixSocketListener
context.add_server(SocketListener(context.add_server, "master_socket", context.add_server(UnixSocketListener(new_server_cb=context.add_server,
sock_location=args.socketfile)) location=args.socketfile,
name="master_socket"))
# context can change when performing an hotswap, always join the latest context # context can change when performing an hotswap, always join the latest context
oldcontext = None oldcontext = None

View File

@ -16,7 +16,9 @@
from datetime import datetime, timezone from datetime import datetime, timezone
import logging import logging
from multiprocessing import JoinableQueue
import threading import threading
import select
import sys import sys
from nemubot import __version__ from nemubot import __version__
@ -26,6 +28,11 @@ import nemubot.hooks
logger = logging.getLogger("nemubot") logger = logging.getLogger("nemubot")
sync_queue = JoinableQueue()
def sync_act(*args):
sync_queue.put(list(args))
class Bot(threading.Thread): class Bot(threading.Thread):
@ -42,7 +49,7 @@ class Bot(threading.Thread):
verbosity -- verbosity level verbosity -- verbosity level
""" """
threading.Thread.__init__(self) super().__init__(name="Nemubot main")
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)", logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
__version__, __version__,
@ -61,6 +68,7 @@ class Bot(threading.Thread):
self.datastore.open() self.datastore.open()
# Keep global context: servers and modules # Keep global context: servers and modules
self._poll = select.poll()
self.servers = dict() self.servers = dict()
self.modules = dict() self.modules = dict()
self.modules_configuration = dict() self.modules_configuration = dict()
@ -138,60 +146,72 @@ class Bot(threading.Thread):
self.cnsr_queue = Queue() self.cnsr_queue = Queue()
self.cnsr_thrd = list() self.cnsr_thrd = list()
self.cnsr_thrd_size = -1 self.cnsr_thrd_size = -1
# Synchrone actions to be treated by main thread
self.sync_queue = Queue()
def run(self): def run(self):
from select import select self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI)
from nemubot.server import _lock, _rlist, _wlist, _xlist
logger.info("Starting main loop") logger.info("Starting main loop")
self.stop = False self.stop = False
while not self.stop: while not self.stop:
with _lock: for fd, flag in self._poll.poll():
try: # Handle internal socket passing orders
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1) if fd != sync_queue._reader.fileno() and fd in self.servers:
except: srv = self.servers[fd]
logger.error("Something went wrong in select")
fnd_smth = False
# Looking for invalid server
for r in _rlist:
if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
_rlist.remove(r)
logger.error("Found invalid object in _rlist: " + str(r))
fnd_smth = True
for w in _wlist:
if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
_wlist.remove(w)
logger.error("Found invalid object in _wlist: " + str(w))
fnd_smth = True
for x in _xlist:
if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
_xlist.remove(x)
logger.error("Found invalid object in _xlist: " + str(x))
fnd_smth = True
if not fnd_smth:
logger.exception("Can't continue, sorry")
self.quit()
continue
for x in xl: if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
try:
x.exception()
except:
logger.exception("Uncatched exception on server exception")
for w in wl:
try:
w.write_select()
except:
logger.exception("Uncatched exception on server write")
for r in rl:
for i in r.read():
try: try:
self.receive_message(r, i) srv.exception(flag)
except: except:
logger.exception("Uncatched exception on server read") logger.exception("Uncatched exception on server exception")
if srv.fileno() > 0:
if flag & (select.POLLOUT):
try:
srv.async_write()
except:
logger.exception("Uncatched exception on server write")
if flag & (select.POLLIN | select.POLLPRI):
try:
for i in srv.async_read():
self.receive_message(srv, i)
except:
logger.exception("Uncatched exception on server read")
else:
del self.servers[fd]
# Always check the sync queue
while not sync_queue.empty():
args = sync_queue.get()
action = args.pop(0)
if action == "sckt" and len(args) >= 2:
try:
if args[0] == "write":
self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI)
elif args[0] == "unwrite":
self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI)
elif args[0] == "register":
self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI)
elif args[0] == "unregister":
self._poll.unregister(int(args[1]))
except:
logger.exception("Unhandled excpetion during action:")
elif action == "exit":
self.quit()
elif action == "loadconf":
for path in args:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
sync_queue.task_done()
# Launch new consumer threads if necessary # Launch new consumer threads if necessary
@ -202,17 +222,6 @@ class Bot(threading.Thread):
c = Consumer(self) c = Consumer(self)
self.cnsr_thrd.append(c) self.cnsr_thrd.append(c)
c.start() c.start()
while self.sync_queue.qsize() > 0:
action = self.sync_queue.get_nowait()
if action[0] == "exit":
self.quit()
elif action[0] == "loadconf":
for path in action[1:]:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
self.sync_queue.task_done()
logger.info("Ending main loop") logger.info("Ending main loop")
@ -419,7 +428,7 @@ class Bot(threading.Thread):
self.servers[fileno] = srv self.servers[fileno] = srv
self.servers[srv.name] = srv self.servers[srv.name] = srv
if autoconnect and not hasattr(self, "noautoconnect"): if autoconnect and not hasattr(self, "noautoconnect"):
srv.open() srv.connect()
return True return True
else: else:
@ -532,28 +541,28 @@ class Bot(threading.Thread):
def quit(self): def quit(self):
"""Save and unload modules and disconnect servers""" """Save and unload modules and disconnect servers"""
self.datastore.close()
if self.event_timer is not None: if self.event_timer is not None:
logger.info("Stop the event timer...") logger.info("Stop the event timer...")
self.event_timer.cancel() self.event_timer.cancel()
logger.info("Save and unload all modules...")
for mod in self.modules.items():
self.unload_module(mod)
logger.info("Close all servers connection...")
for srv in [self.servers[k] for k in self.servers]:
srv.close()
logger.info("Stop consumers") logger.info("Stop consumers")
k = self.cnsr_thrd k = self.cnsr_thrd
for cnsr in k: for cnsr in k:
cnsr.stop = True cnsr.stop = True
logger.info("Save and unload all modules...") self.datastore.close()
k = list(self.modules.keys())
for mod in k:
self.unload_module(mod)
logger.info("Close all servers connection...")
k = list(self.servers.keys())
for srv in k:
self.servers[srv].close()
self.stop = True self.stop = True
sync_act("end")
sync_queue.join()
# Treatment # Treatment

View File

@ -31,7 +31,7 @@ PORTS = list()
class DCC(server.AbstractServer): class DCC(server.AbstractServer):
def __init__(self, srv, dest, socket=None): def __init__(self, srv, dest, socket=None):
super().__init__(self) super().__init__(name="Nemubot DCC server")
self.error = False # An error has occur, closing the connection? self.error = False # An error has occur, closing the connection?
self.messages = list() # Message queued before connexion self.messages = list() # Message queued before connexion

View File

@ -20,17 +20,17 @@ import re
from nemubot.channel import Channel from nemubot.channel import Channel
from nemubot.message.printer.IRC import IRC as IRCPrinter from nemubot.message.printer.IRC import IRC as IRCPrinter
from nemubot.server.message.IRC import IRC as IRCMessage from nemubot.server.message.IRC import IRC as IRCMessage
from nemubot.server.socket import SocketServer from nemubot.server.socket import SocketServer, SecureSocketServer
class IRC(SocketServer): class _IRC:
"""Concrete implementation of a connexion to an IRC server""" """Concrete implementation of a connexion to an IRC server"""
def __init__(self, host="localhost", port=6667, ssl=False, owner=None, def __init__(self, host="localhost", port=6667, owner=None,
nick="nemubot", username=None, password=None, nick="nemubot", username=None, password=None,
realname="Nemubot", encoding="utf-8", caps=None, realname="Nemubot", encoding="utf-8", caps=None,
channels=list(), on_connect=None): channels=list(), on_connect=None, **kwargs):
"""Prepare a connection with an IRC server """Prepare a connection with an IRC server
Keyword arguments: Keyword arguments:
@ -54,7 +54,8 @@ class IRC(SocketServer):
self.owner = owner self.owner = owner
self.realname = realname self.realname = realname
super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port)) super().__init__(name=self.username + "@" + host + ":" + str(port),
host=host, port=port, **kwargs)
self.printer = IRCPrinter self.printer = IRCPrinter
self.encoding = encoding self.encoding = encoding
@ -231,20 +232,19 @@ class IRC(SocketServer):
# Open/close # Open/close
def open(self): def connect(self):
if super().open(): super().connect()
if self.password is not None:
self.write("PASS :" + self.password) if self.password is not None:
if self.capabilities is not None: self.write("PASS :" + self.password)
self.write("CAP LS") if self.capabilities is not None:
self.write("NICK :" + self.nick) self.write("CAP LS")
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) self.write("NICK :" + self.nick)
return True self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
return False
def close(self): def close(self):
if not self.closed: if not self._closed:
self.write("QUIT") self.write("QUIT")
return super().close() return super().close()
@ -253,8 +253,8 @@ class IRC(SocketServer):
# Read # Read
def read(self): def async_read(self):
for line in super().read(): for line in super().async_read():
# PING should be handled here, so start parsing here :/ # PING should be handled here, so start parsing here :/
msg = IRCMessage(line, self.encoding) msg = IRCMessage(line, self.encoding)
@ -273,3 +273,10 @@ class IRC(SocketServer):
def subparse(self, orig, cnt): def subparse(self, orig, cnt):
msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding) msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding)
return msg.to_bot_message(self) return msg.to_bot_message(self)
class IRC(_IRC, SocketServer):
pass
class IRC_secure(_IRC, SecureSocketServer):
pass

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,34 +14,36 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
_lock = threading.Lock() def factory(uri, ssl=False, **init_args):
# Lists for select
_rlist = []
_wlist = []
_xlist = []
def factory(uri, **init_args):
from urllib.parse import urlparse, unquote from urllib.parse import urlparse, unquote
o = urlparse(uri) o = urlparse(uri)
srv = None
if o.scheme == "irc" or o.scheme == "ircs": if o.scheme == "irc" or o.scheme == "ircs":
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt # http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html # http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
args = init_args args = init_args
modifiers = o.path.split(",") if o.scheme == "ircs": ssl = True
target = unquote(modifiers.pop(0)[1:])
if o.scheme == "ircs": args["ssl"] = True
if o.hostname is not None: args["host"] = o.hostname if o.hostname is not None: args["host"] = o.hostname
if o.port is not None: args["port"] = o.port if o.port is not None: args["port"] = o.port
if o.username is not None: args["username"] = o.username if o.username is not None: args["username"] = o.username
if o.password is not None: args["password"] = o.password if o.password is not None: args["password"] = o.password
if ssl:
try:
from ssl import create_default_context
args["_context"] = create_default_context()
except ImportError:
# Python 3.3 compat
from ssl import SSLContext, PROTOCOL_TLSv1
args["_context"] = SSLContext(PROTOCOL_TLSv1)
modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:])
queries = o.query.split("&") queries = o.query.split("&")
for q in queries: for q in queries:
if "=" in q: if "=" in q:
@ -64,7 +66,11 @@ def factory(uri, **init_args):
if "channels" not in args and "isnick" not in modifiers: if "channels" not in args and "isnick" not in modifiers:
args["channels"] = [ target ] args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer if ssl:
return IRCServer(**args) from nemubot.server.IRC import IRC_secure as SecureIRCServer
else: srv = SecureIRCServer(**args)
return None else:
from nemubot.server.IRC import IRC as IRCServer
srv = IRCServer(**args)
return srv

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,34 +14,30 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import io
import logging import logging
import queue import queue
from nemubot.server import _lock, _rlist, _wlist, _xlist from nemubot.bot import sync_act
# Extends from IOBase in order to be compatible with select function
class AbstractServer(io.IOBase): class AbstractServer:
"""An abstract server: handle communication with an IM server""" """An abstract server: handle communication with an IM server"""
def __init__(self, name=None, send_callback=None): def __init__(self, name=None, **kwargs):
"""Initialize an abstract server """Initialize an abstract server
Keyword argument: Keyword argument:
send_callback -- Callback when developper want to send a message name -- Identifier of the socket, for convinience
""" """
self._name = name self._name = name
super().__init__() super().__init__(**kwargs)
self.logger = logging.getLogger("nemubot.server." + self.name) self.logger = logging.getLogger("nemubot.server." + str(self.name))
self._readbuffer = b''
self._sending_queue = queue.Queue() self._sending_queue = queue.Queue()
if send_callback is not None:
self._send_callback = send_callback
else:
self._send_callback = self._write_select
@property @property
@ -54,40 +50,28 @@ class AbstractServer(io.IOBase):
# Open/close # Open/close
def __enter__(self): def connect(self, *args, **kwargs):
self.open() """Register the server in _poll"""
return self
self.logger.info("Opening connection")
super().connect(*args, **kwargs)
self._on_connect()
def _on_connect(self):
sync_act("sckt", "register", self.fileno())
def __exit__(self, type, value, traceback): def close(self, *args, **kwargs):
self.close() """Unregister the server from _poll"""
self.logger.info("Closing connection")
def open(self): if self.fileno() > 0:
"""Generic open function that register the server un _rlist in case sync_act("sckt", "unregister", self.fileno())
of successful _open"""
self.logger.info("Opening connection to %s", self.id)
if not hasattr(self, "_open") or self._open():
_rlist.append(self)
_xlist.append(self)
return True
return False
super().close(*args, **kwargs)
def close(self):
"""Generic close function that register the server un _{r,w,x}list in
case of successful _close"""
self.logger.info("Closing connection to %s", self.id)
with _lock:
if not hasattr(self, "_close") or self._close():
if self in _rlist:
_rlist.remove(self)
if self in _wlist:
_wlist.remove(self)
if self in _xlist:
_xlist.remove(self)
return True
return False
# Writes # Writes
@ -99,13 +83,16 @@ class AbstractServer(io.IOBase):
message -- message to send message -- message to send
""" """
self._send_callback(message) self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
sync_act("sckt", "write", self.fileno())
def write_select(self): def async_write(self):
"""Internal function used by the select function""" """Internal function used when the file descriptor is writable"""
try: try:
_wlist.remove(self) sync_act("sckt", "unwrite", self.fileno())
while not self._sending_queue.empty(): while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait()) self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done() self._sending_queue.task_done()
@ -114,19 +101,6 @@ class AbstractServer(io.IOBase):
pass pass
def _write_select(self, message):
"""Send a message to the server safely through select
Argument:
message -- message to send
"""
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
if self not in _wlist:
_wlist.append(self)
def send_response(self, response): def send_response(self, response):
"""Send a formated Message class """Send a formated Message class
@ -149,13 +123,39 @@ class AbstractServer(io.IOBase):
# Read # Read
def async_read(self):
"""Internal function used when the file descriptor is readable
Returns:
A list of fully received messages
"""
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
for r in ret:
yield r
def lex(self, buf):
"""Assume lexing in default case is per line
Argument:
buf -- buffer to lex
"""
msgs = buf.split(b'\r\n')
partial = msgs.pop()
return msgs, partial
def parse(self, msg): def parse(self, msg):
raise NotImplemented raise NotImplemented
# Exceptions # Exceptions
def exception(self): def exception(self, flags):
"""Exception occurs in fd""" """Exception occurs on fd"""
self.logger.warning("Unhandle file descriptor exception on server %s",
self.name) self.close()

View File

@ -22,34 +22,30 @@ class TestFactory(unittest.TestCase):
def test_IRC1(self): def test_IRC1(self):
from nemubot.server.IRC import IRC as IRCServer from nemubot.server.IRC import IRC as IRCServer
from nemubot.server.IRC import IRC_secure as IRCSServer
# <host>: If omitted, the client must connect to a prespecified default IRC server. # <host>: If omitted, the client must connect to a prespecified default IRC server.
server = factory("irc:///") server = factory("irc:///")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "localhost") self.assertEqual(server.host, "localhost")
self.assertFalse(server.ssl)
server = factory("ircs:///") server = factory("ircs:///")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "localhost") self.assertEqual(server.host, "localhost")
self.assertTrue(server.ssl)
server = factory("irc://host1") server = factory("irc://host1")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host1") self.assertEqual(server.host, "host1")
self.assertFalse(server.ssl)
server = factory("irc://host2:6667") server = factory("irc://host2:6667")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host2") self.assertEqual(server.host, "host2")
self.assertEqual(server.port, 6667) self.assertEqual(server.port, 6667)
self.assertFalse(server.ssl)
server = factory("ircs://host3:194/") server = factory("ircs://host3:194/")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "host3") self.assertEqual(server.host, "host3")
self.assertEqual(server.port, 194) self.assertEqual(server.port, 194)
self.assertTrue(server.ssl)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -0,0 +1,15 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,117 +14,33 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import socket
import ssl
import nemubot.message as message import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter from nemubot.message.printer.socket import Socket as SocketPrinter
from nemubot.server.abstract import AbstractServer from nemubot.server.abstract import AbstractServer
class SocketServer(AbstractServer): class _Socket(AbstractServer):
"""Concrete implementation of a socket connexion (can be wrapped with TLS)""" """Concrete implementation of a socket connection"""
def __init__(self, sock_location=None, def __init__(self, printer=SocketPrinter, **kwargs):
host=None, port=None,
sock=None,
ssl=False,
name=None):
"""Create a server socket """Create a server socket
Keyword arguments:
sock_location -- Path to the UNIX socket
host -- Hostname of the INET socket
port -- Port of the INET socket
sock -- Already connected socket
ssl -- Should TLS connection enabled
name -- Convinience name
""" """
import socket super().__init__(**kwargs)
assert(sock is None or isinstance(sock, socket.SocketType))
assert(port is None or isinstance(port, int))
super().__init__(name=name)
if sock is None:
if sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.connect_to = sock_location
elif host is not None:
for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
self.socket = socket.socket(af, socktype, proto)
self.connect_to = sa
break
else:
self.socket = sock
self.ssl = ssl
self.readbuffer = b'' self.readbuffer = b''
self.printer = SocketPrinter self.printer = printer
def fileno(self):
return self.socket.fileno() if self.socket else None
@property
def closed(self):
"""Indicator of the connection aliveness"""
return self.socket._closed
# Open/close
def open(self):
if not self.closed:
return True
try:
self.socket.connect(self.connect_to)
self.logger.info("Connected to %s", self.connect_to)
except:
self.socket.close()
self.logger.exception("Unable to connect to %s",
self.connect_to)
return False
# Wrap the socket for SSL
if self.ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.socket = ctx.wrap_socket(self.socket)
return super().open()
def close(self):
import socket
# Flush the sending queue before close
from nemubot.server import _lock
_lock.release()
self._sending_queue.join()
_lock.acquire()
if not self.closed:
try:
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.socket.close()
return super().close()
# Write # Write
def _write(self, cnt): def _write(self, cnt):
if self.closed: self.sendall(cnt)
return
self.socket.sendall(cnt)
def format(self, txt): def format(self, txt):
@ -136,19 +52,12 @@ class SocketServer(AbstractServer):
# Read # Read
def read(self): def recv(self, n=1024):
if self.closed: return super().recv(n)
return []
raw = self.socket.recv(1024)
temp = (self.readbuffer + raw).split(b'\r\n')
self.readbuffer = temp.pop()
for line in temp:
yield line
def parse(self, line): def parse(self, line):
"""Implement a default behaviour for socket"""
import shlex import shlex
line = line.strip().decode() line = line.strip().decode()
@ -157,48 +66,102 @@ class SocketServer(AbstractServer):
except ValueError: except ValueError:
args = line.split(' ') args = line.split(' ')
yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you") if len(args):
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
class SocketListener(AbstractServer): class _SocketServer(_Socket):
def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None): def __init__(self, host, port, bind=None, **kwargs):
super().__init__(name=name) super().__init__(family=socket.AF_INET, **kwargs)
self.new_server_cb = new_server_cb
self.sock_location = sock_location
self.host = host
self.port = port
self.ssl = ssl
self.nb_son = 0
assert(host is not None)
assert(isinstance(port, int))
def fileno(self): if isinstance(self, ssl.SSLSocket) and "server_hostname" not in kwargs:
return self.socket.fileno() if self.socket else None kwargs["server_hostname"] = host
super().__init__(family=family, type=type, proto=proto, **kwargs)
self._host = host
self._port = port
self._bind = bind
@property @property
def closed(self): def host(self):
"""Indicator of the connection aliveness""" return self._host
return self.socket is None
def open(self): def connect(self):
import os self.logger.info("Connection to %s:%d", self._host, self._port)
import socket super().connect((self._host, self._port))
if self.sock_location is not None: if self._bind:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) super().bind(self._bind)
try:
os.remove(self.sock_location)
except FileNotFoundError:
pass
self.socket.bind(self.sock_location)
elif self.host is not None and self.port is not None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.bind((self.host, self.port))
self.socket.listen(5)
return super().open()
class SocketServer(_SocketServer, socket.socket):
pass
class SecureSocketServer(_SocketServer, ssl.SSLSocket):
pass
class UnixSocket:
def __init__(self, location, **kwargs):
super().__init__(family=socket.AF_UNIX, **kwargs)
self._socket_path = location
def connect(self):
self.logger.info("Connection to unix://%s", self._socket_path)
super().connect(self._socket_path)
class _Listener:
def __init__(self, new_server_cb, instanciate=_Socket, **kwargs):
super().__init__(**kwargs)
self._instanciate = instanciate
self._new_server_cb = new_server_cb
def read(self):
conn, addr = self.accept()
fileno = conn.fileno()
self.logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
ss.connect = ss._on_connect
self._new_server_cb(ss, autoconnect=True)
return b''
class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def connect(self):
self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
try:
os.remove(self._socket_path)
except FileNotFoundError:
pass
self.bind(self._socket_path)
self.listen(5)
self.logger.info("Socket ready for accepting new connections")
self._on_connect()
def close(self): def close(self):
@ -206,25 +169,14 @@ class SocketListener(AbstractServer):
import socket import socket
try: try:
self.socket.shutdown(socket.SHUT_RDWR) self.shutdown(socket.SHUT_RDWR)
self.socket.close()
if self.sock_location is not None:
os.remove(self.sock_location)
except socket.error: except socket.error:
pass pass
return super().close() super().close()
try:
# Read if self._socket_path is not None:
os.remove(self._socket_path)
def read(self): except:
if self.closed: pass
return []
conn, addr = self.socket.accept()
self.nb_son += 1
ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
self.new_server_cb(ss)
return []