1
0
Fork 0

Refactor file/socket management (use poll instead of select)

This commit is contained in:
nemunaire 2016-05-30 22:06:35 +02:00
parent 6d8dca211d
commit 764e6f070b
10 changed files with 328 additions and 339 deletions

View File

@ -9,6 +9,8 @@ Requirements
*nemubot* requires at least Python 3.3 to work.
Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629).
Some modules (like `cve`, `nextstop` or `laposte`) require the
[BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/),
but the core and framework has no dependency.

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -125,7 +125,7 @@ def main():
# Create bot context
from nemubot import datastore
from nemubot.bot import Bot
from nemubot.bot import Bot, sync_act
context = Bot(modules_paths=modules_paths,
data_store=datastore.XML(args.data_path),
verbosity=args.verbose)
@ -141,7 +141,7 @@ def main():
# Load requested configuration files
for path in args.files:
if os.path.isfile(path):
context.sync_queue.put_nowait(["loadconf", path])
sync_act("loadconf", path)
else:
logger.error("%s is not a readable file", path)
@ -165,22 +165,28 @@ def main():
# Reload configuration file
for path in args.files:
if os.path.isfile(path):
context.sync_queue.put_nowait(["loadconf", path])
sync_act("loadconf", path)
signal.signal(signal.SIGHUP, sighuphandler)
def sigusr1handler(signum, frame):
"""On SIGHUSR1, display stacktraces"""
import traceback
import threading, traceback
for threadId, stack in sys._current_frames().items():
logger.debug("########### Thread %d:\n%s",
threadId,
thName = "#%d" % threadId
for th in threading.enumerate():
if th.ident == threadId:
thName = th.name
break
logger.debug("########### Thread %s:\n%s",
thName,
"".join(traceback.format_stack(stack)))
signal.signal(signal.SIGUSR1, sigusr1handler)
if args.socketfile:
from nemubot.server.socket import SocketListener
context.add_server(SocketListener(context.add_server, "master_socket",
sock_location=args.socketfile))
from nemubot.server.socket import UnixSocketListener
context.add_server(UnixSocketListener(new_server_cb=context.add_server,
location=args.socketfile,
name="master_socket"))
# context can change when performing an hotswap, always join the latest context
oldcontext = None

View File

@ -16,7 +16,9 @@
from datetime import datetime, timezone
import logging
from multiprocessing import JoinableQueue
import threading
import select
import sys
from nemubot import __version__
@ -26,6 +28,11 @@ import nemubot.hooks
logger = logging.getLogger("nemubot")
sync_queue = JoinableQueue()
def sync_act(*args):
sync_queue.put(list(args))
class Bot(threading.Thread):
@ -42,7 +49,7 @@ class Bot(threading.Thread):
verbosity -- verbosity level
"""
threading.Thread.__init__(self)
super().__init__(name="Nemubot main")
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
__version__,
@ -61,6 +68,7 @@ class Bot(threading.Thread):
self.datastore.open()
# Keep global context: servers and modules
self._poll = select.poll()
self.servers = dict()
self.modules = dict()
self.modules_configuration = dict()
@ -138,60 +146,72 @@ class Bot(threading.Thread):
self.cnsr_queue = Queue()
self.cnsr_thrd = list()
self.cnsr_thrd_size = -1
# Synchrone actions to be treated by main thread
self.sync_queue = Queue()
def run(self):
from select import select
from nemubot.server import _lock, _rlist, _wlist, _xlist
self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI)
logger.info("Starting main loop")
self.stop = False
while not self.stop:
with _lock:
try:
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1)
except:
logger.error("Something went wrong in select")
fnd_smth = False
# Looking for invalid server
for r in _rlist:
if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
_rlist.remove(r)
logger.error("Found invalid object in _rlist: " + str(r))
fnd_smth = True
for w in _wlist:
if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
_wlist.remove(w)
logger.error("Found invalid object in _wlist: " + str(w))
fnd_smth = True
for x in _xlist:
if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
_xlist.remove(x)
logger.error("Found invalid object in _xlist: " + str(x))
fnd_smth = True
if not fnd_smth:
logger.exception("Can't continue, sorry")
self.quit()
continue
for fd, flag in self._poll.poll():
# Handle internal socket passing orders
if fd != sync_queue._reader.fileno() and fd in self.servers:
srv = self.servers[fd]
for x in xl:
try:
x.exception()
except:
logger.exception("Uncatched exception on server exception")
for w in wl:
try:
w.write_select()
except:
logger.exception("Uncatched exception on server write")
for r in rl:
for i in r.read():
if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
try:
self.receive_message(r, i)
srv.exception(flag)
except:
logger.exception("Uncatched exception on server read")
logger.exception("Uncatched exception on server exception")
if srv.fileno() > 0:
if flag & (select.POLLOUT):
try:
srv.async_write()
except:
logger.exception("Uncatched exception on server write")
if flag & (select.POLLIN | select.POLLPRI):
try:
for i in srv.async_read():
self.receive_message(srv, i)
except:
logger.exception("Uncatched exception on server read")
else:
del self.servers[fd]
# Always check the sync queue
while not sync_queue.empty():
args = sync_queue.get()
action = args.pop(0)
if action == "sckt" and len(args) >= 2:
try:
if args[0] == "write":
self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI)
elif args[0] == "unwrite":
self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI)
elif args[0] == "register":
self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI)
elif args[0] == "unregister":
self._poll.unregister(int(args[1]))
except:
logger.exception("Unhandled excpetion during action:")
elif action == "exit":
self.quit()
elif action == "loadconf":
for path in args:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
sync_queue.task_done()
# Launch new consumer threads if necessary
@ -202,17 +222,6 @@ class Bot(threading.Thread):
c = Consumer(self)
self.cnsr_thrd.append(c)
c.start()
while self.sync_queue.qsize() > 0:
action = self.sync_queue.get_nowait()
if action[0] == "exit":
self.quit()
elif action[0] == "loadconf":
for path in action[1:]:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
self.sync_queue.task_done()
logger.info("Ending main loop")
@ -419,7 +428,7 @@ class Bot(threading.Thread):
self.servers[fileno] = srv
self.servers[srv.name] = srv
if autoconnect and not hasattr(self, "noautoconnect"):
srv.open()
srv.connect()
return True
else:
@ -532,28 +541,28 @@ class Bot(threading.Thread):
def quit(self):
"""Save and unload modules and disconnect servers"""
self.datastore.close()
if self.event_timer is not None:
logger.info("Stop the event timer...")
self.event_timer.cancel()
logger.info("Save and unload all modules...")
for mod in self.modules.items():
self.unload_module(mod)
logger.info("Close all servers connection...")
for srv in [self.servers[k] for k in self.servers]:
srv.close()
logger.info("Stop consumers")
k = self.cnsr_thrd
for cnsr in k:
cnsr.stop = True
logger.info("Save and unload all modules...")
k = list(self.modules.keys())
for mod in k:
self.unload_module(mod)
logger.info("Close all servers connection...")
k = list(self.servers.keys())
for srv in k:
self.servers[srv].close()
self.datastore.close()
self.stop = True
sync_act("end")
sync_queue.join()
# Treatment

View File

@ -31,7 +31,7 @@ PORTS = list()
class DCC(server.AbstractServer):
def __init__(self, srv, dest, socket=None):
super().__init__(self)
super().__init__(name="Nemubot DCC server")
self.error = False # An error has occur, closing the connection?
self.messages = list() # Message queued before connexion

View File

@ -20,17 +20,17 @@ import re
from nemubot.channel import Channel
from nemubot.message.printer.IRC import IRC as IRCPrinter
from nemubot.server.message.IRC import IRC as IRCMessage
from nemubot.server.socket import SocketServer
from nemubot.server.socket import SocketServer, SecureSocketServer
class IRC(SocketServer):
class _IRC:
"""Concrete implementation of a connexion to an IRC server"""
def __init__(self, host="localhost", port=6667, ssl=False, owner=None,
def __init__(self, host="localhost", port=6667, owner=None,
nick="nemubot", username=None, password=None,
realname="Nemubot", encoding="utf-8", caps=None,
channels=list(), on_connect=None):
channels=list(), on_connect=None, **kwargs):
"""Prepare a connection with an IRC server
Keyword arguments:
@ -54,7 +54,8 @@ class IRC(SocketServer):
self.owner = owner
self.realname = realname
super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port))
super().__init__(name=self.username + "@" + host + ":" + str(port),
host=host, port=port, **kwargs)
self.printer = IRCPrinter
self.encoding = encoding
@ -231,20 +232,19 @@ class IRC(SocketServer):
# Open/close
def open(self):
if super().open():
if self.password is not None:
self.write("PASS :" + self.password)
if self.capabilities is not None:
self.write("CAP LS")
self.write("NICK :" + self.nick)
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
return True
return False
def connect(self):
super().connect()
if self.password is not None:
self.write("PASS :" + self.password)
if self.capabilities is not None:
self.write("CAP LS")
self.write("NICK :" + self.nick)
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
def close(self):
if not self.closed:
if not self._closed:
self.write("QUIT")
return super().close()
@ -253,8 +253,8 @@ class IRC(SocketServer):
# Read
def read(self):
for line in super().read():
def async_read(self):
for line in super().async_read():
# PING should be handled here, so start parsing here :/
msg = IRCMessage(line, self.encoding)
@ -273,3 +273,10 @@ class IRC(SocketServer):
def subparse(self, orig, cnt):
msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding)
return msg.to_bot_message(self)
class IRC(_IRC, SocketServer):
pass
class IRC_secure(_IRC, SecureSocketServer):
pass

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,34 +14,37 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
_lock = threading.Lock()
# Lists for select
_rlist = []
_wlist = []
_xlist = []
def factory(uri, **init_args):
def factory(uri, ssl=False, **init_args):
from urllib.parse import urlparse, unquote
o = urlparse(uri)
srv = None
if o.scheme == "irc" or o.scheme == "ircs":
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
args = init_args
modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:])
if o.scheme == "ircs": args["ssl"] = True
if o.scheme == "ircs": ssl = True
if o.hostname is not None: args["host"] = o.hostname
if o.port is not None: args["port"] = o.port
if o.username is not None: args["username"] = o.username
if o.password is not None: args["password"] = o.password
if ssl:
try:
from ssl import create_default_context
args["_context"] = create_default_context()
except ImportError:
# Python 3.3 compat
from ssl import SSLContext, PROTOCOL_TLSv1
args["_context"] = SSLContext(PROTOCOL_TLSv1)
args["server_hostname"] = o.hostname
modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:])
queries = o.query.split("&")
for q in queries:
if "=" in q:
@ -64,7 +67,11 @@ def factory(uri, **init_args):
if "channels" not in args and "isnick" not in modifiers:
args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer
return IRCServer(**args)
else:
return None
if ssl:
from nemubot.server.IRC import IRC_secure as SecureIRCServer
srv = SecureIRCServer(**args)
else:
from nemubot.server.IRC import IRC as IRCServer
srv = IRCServer(**args)
return srv

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,34 +14,30 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import io
import logging
import queue
from nemubot.server import _lock, _rlist, _wlist, _xlist
from nemubot.bot import sync_act
# Extends from IOBase in order to be compatible with select function
class AbstractServer(io.IOBase):
class AbstractServer:
"""An abstract server: handle communication with an IM server"""
def __init__(self, name=None, send_callback=None):
def __init__(self, name=None, **kwargs):
"""Initialize an abstract server
Keyword argument:
send_callback -- Callback when developper want to send a message
name -- Identifier of the socket, for convinience
"""
self._name = name
super().__init__()
super().__init__(**kwargs)
self.logger = logging.getLogger("nemubot.server." + self.name)
self.logger = logging.getLogger("nemubot.server." + str(self.name))
self._readbuffer = b''
self._sending_queue = queue.Queue()
if send_callback is not None:
self._send_callback = send_callback
else:
self._send_callback = self._write_select
@property
@ -54,40 +50,28 @@ class AbstractServer(io.IOBase):
# Open/close
def __enter__(self):
self.open()
return self
def connect(self, *args, **kwargs):
"""Register the server in _poll"""
self.logger.info("Opening connection")
super().connect(*args, **kwargs)
self._on_connect()
def _on_connect(self):
sync_act("sckt", "register", self.fileno())
def __exit__(self, type, value, traceback):
self.close()
def close(self, *args, **kwargs):
"""Unregister the server from _poll"""
self.logger.info("Closing connection")
def open(self):
"""Generic open function that register the server un _rlist in case
of successful _open"""
self.logger.info("Opening connection to %s", self.id)
if not hasattr(self, "_open") or self._open():
_rlist.append(self)
_xlist.append(self)
return True
return False
if self.fileno() > 0:
sync_act("sckt", "unregister", self.fileno())
def close(self):
"""Generic close function that register the server un _{r,w,x}list in
case of successful _close"""
self.logger.info("Closing connection to %s", self.id)
with _lock:
if not hasattr(self, "_close") or self._close():
if self in _rlist:
_rlist.remove(self)
if self in _wlist:
_wlist.remove(self)
if self in _xlist:
_xlist.remove(self)
return True
return False
super().close(*args, **kwargs)
# Writes
@ -99,13 +83,16 @@ class AbstractServer(io.IOBase):
message -- message to send
"""
self._send_callback(message)
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
sync_act("sckt", "write", self.fileno())
def write_select(self):
"""Internal function used by the select function"""
def async_write(self):
"""Internal function used when the file descriptor is writable"""
try:
_wlist.remove(self)
sync_act("sckt", "unwrite", self.fileno())
while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done()
@ -114,19 +101,6 @@ class AbstractServer(io.IOBase):
pass
def _write_select(self, message):
"""Send a message to the server safely through select
Argument:
message -- message to send
"""
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
if self not in _wlist:
_wlist.append(self)
def send_response(self, response):
"""Send a formated Message class
@ -149,13 +123,39 @@ class AbstractServer(io.IOBase):
# Read
def async_read(self):
"""Internal function used when the file descriptor is readable
Returns:
A list of fully received messages
"""
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
for r in ret:
yield r
def lex(self, buf):
"""Assume lexing in default case is per line
Argument:
buf -- buffer to lex
"""
msgs = buf.split(b'\r\n')
partial = msgs.pop()
return msgs, partial
def parse(self, msg):
raise NotImplemented
# Exceptions
def exception(self):
"""Exception occurs in fd"""
self.logger.warning("Unhandle file descriptor exception on server %s",
self.name)
def exception(self, flags):
"""Exception occurs on fd"""
self.close()

View File

@ -22,34 +22,30 @@ class TestFactory(unittest.TestCase):
def test_IRC1(self):
from nemubot.server.IRC import IRC as IRCServer
from nemubot.server.IRC import IRC_secure as IRCSServer
# <host>: If omitted, the client must connect to a prespecified default IRC server.
server = factory("irc:///")
self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "localhost")
self.assertFalse(server.ssl)
server = factory("ircs:///")
self.assertIsInstance(server, IRCServer)
self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "localhost")
self.assertTrue(server.ssl)
server = factory("irc://host1")
self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host1")
self.assertFalse(server.ssl)
server = factory("irc://host2:6667")
self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host2")
self.assertEqual(server.port, 6667)
self.assertFalse(server.ssl)
server = factory("ircs://host3:194/")
self.assertIsInstance(server, IRCServer)
self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "host3")
self.assertEqual(server.port, 194)
self.assertTrue(server.ssl)
if __name__ == '__main__':

View File

@ -0,0 +1,15 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,117 +14,33 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import socket
import ssl
import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter
from nemubot.server.abstract import AbstractServer
class SocketServer(AbstractServer):
class _Socket(AbstractServer):
"""Concrete implementation of a socket connexion (can be wrapped with TLS)"""
"""Concrete implementation of a socket connection"""
def __init__(self, sock_location=None,
host=None, port=None,
sock=None,
ssl=False,
name=None):
def __init__(self, printer=SocketPrinter, **kwargs):
"""Create a server socket
Keyword arguments:
sock_location -- Path to the UNIX socket
host -- Hostname of the INET socket
port -- Port of the INET socket
sock -- Already connected socket
ssl -- Should TLS connection enabled
name -- Convinience name
"""
import socket
assert(sock is None or isinstance(sock, socket.SocketType))
assert(port is None or isinstance(port, int))
super().__init__(name=name)
if sock is None:
if sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.connect_to = sock_location
elif host is not None:
for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
self.socket = socket.socket(af, socktype, proto)
self.connect_to = sa
break
else:
self.socket = sock
self.ssl = ssl
super().__init__(**kwargs)
self.readbuffer = b''
self.printer = SocketPrinter
def fileno(self):
return self.socket.fileno() if self.socket else None
@property
def closed(self):
"""Indicator of the connection aliveness"""
return self.socket._closed
# Open/close
def open(self):
if not self.closed:
return True
try:
self.socket.connect(self.connect_to)
self.logger.info("Connected to %s", self.connect_to)
except:
self.socket.close()
self.logger.exception("Unable to connect to %s",
self.connect_to)
return False
# Wrap the socket for SSL
if self.ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.socket = ctx.wrap_socket(self.socket)
return super().open()
def close(self):
import socket
# Flush the sending queue before close
from nemubot.server import _lock
_lock.release()
self._sending_queue.join()
_lock.acquire()
if not self.closed:
try:
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.socket.close()
return super().close()
self.printer = printer
# Write
def _write(self, cnt):
if self.closed:
return
self.socket.sendall(cnt)
self.sendall(cnt)
def format(self, txt):
@ -136,19 +52,12 @@ class SocketServer(AbstractServer):
# Read
def read(self):
if self.closed:
return []
raw = self.socket.recv(1024)
temp = (self.readbuffer + raw).split(b'\r\n')
self.readbuffer = temp.pop()
for line in temp:
yield line
def recv(self, n=1024):
return super().recv(n)
def parse(self, line):
"""Implement a default behaviour for socket"""
import shlex
line = line.strip().decode()
@ -157,48 +66,97 @@ class SocketServer(AbstractServer):
except ValueError:
args = line.split(' ')
yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you")
if len(args):
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
class SocketListener(AbstractServer):
class _SocketServer(_Socket):
def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None):
super().__init__(name=name)
self.new_server_cb = new_server_cb
self.sock_location = sock_location
self.host = host
self.port = port
self.ssl = ssl
self.nb_son = 0
def __init__(self, host, port, bind=None, **kwargs):
super().__init__(family=socket.AF_INET, **kwargs)
assert(host is not None)
assert(isinstance(port, int))
def fileno(self):
return self.socket.fileno() if self.socket else None
self._host = host
self._port = port
self._bind = bind
@property
def closed(self):
"""Indicator of the connection aliveness"""
return self.socket is None
def host(self):
return self._host
def open(self):
import os
import socket
def connect(self):
self.logger.info("Connection to %s:%d", self._host, self._port)
super().connect((self._host, self._port))
if self.sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
os.remove(self.sock_location)
except FileNotFoundError:
pass
self.socket.bind(self.sock_location)
elif self.host is not None and self.port is not None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.bind((self.host, self.port))
self.socket.listen(5)
if self._bind:
super().bind(self._bind)
return super().open()
class SocketServer(_SocketServer, socket.socket):
pass
class SecureSocketServer(_SocketServer, ssl.SSLSocket):
pass
class UnixSocket:
def __init__(self, location, **kwargs):
super().__init__(family=socket.AF_UNIX, **kwargs)
self._socket_path = location
def connect(self):
self.logger.info("Connection to unix://%s", self._socket_path)
super().connect(self._socket_path)
class _Listener:
def __init__(self, new_server_cb, instanciate=_Socket, **kwargs):
super().__init__(**kwargs)
self._instanciate = instanciate
self._new_server_cb = new_server_cb
def read(self):
conn, addr = self.accept()
fileno = conn.fileno()
self.logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
ss.connect = ss._on_connect
self._new_server_cb(ss, autoconnect=True)
return b''
class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def connect(self):
self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
try:
os.remove(self._socket_path)
except FileNotFoundError:
pass
self.bind(self._socket_path)
self.listen(5)
self.logger.info("Socket ready for accepting new connections")
self._on_connect()
def close(self):
@ -206,25 +164,14 @@ class SocketListener(AbstractServer):
import socket
try:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
if self.sock_location is not None:
os.remove(self.sock_location)
self.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
return super().close()
super().close()
# Read
def read(self):
if self.closed:
return []
conn, addr = self.socket.accept()
self.nb_son += 1
ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
self.new_server_cb(ss)
return []
try:
if self._socket_path is not None:
os.remove(self._socket_path)
except:
pass