Refactor file/socket management (use poll instead of select)
This commit is contained in:
parent
6d8dca211d
commit
764e6f070b
10 changed files with 324 additions and 335 deletions
|
|
@ -31,7 +31,7 @@ PORTS = list()
|
|||
|
||||
class DCC(server.AbstractServer):
|
||||
def __init__(self, srv, dest, socket=None):
|
||||
super().__init__(self)
|
||||
super().__init__(name="Nemubot DCC server")
|
||||
|
||||
self.error = False # An error has occur, closing the connection?
|
||||
self.messages = list() # Message queued before connexion
|
||||
|
|
|
|||
|
|
@ -20,17 +20,17 @@ import re
|
|||
from nemubot.channel import Channel
|
||||
from nemubot.message.printer.IRC import IRC as IRCPrinter
|
||||
from nemubot.server.message.IRC import IRC as IRCMessage
|
||||
from nemubot.server.socket import SocketServer
|
||||
from nemubot.server.socket import SocketServer, SecureSocketServer
|
||||
|
||||
|
||||
class IRC(SocketServer):
|
||||
class _IRC:
|
||||
|
||||
"""Concrete implementation of a connexion to an IRC server"""
|
||||
|
||||
def __init__(self, host="localhost", port=6667, ssl=False, owner=None,
|
||||
def __init__(self, host="localhost", port=6667, owner=None,
|
||||
nick="nemubot", username=None, password=None,
|
||||
realname="Nemubot", encoding="utf-8", caps=None,
|
||||
channels=list(), on_connect=None):
|
||||
channels=list(), on_connect=None, **kwargs):
|
||||
"""Prepare a connection with an IRC server
|
||||
|
||||
Keyword arguments:
|
||||
|
|
@ -54,7 +54,8 @@ class IRC(SocketServer):
|
|||
self.owner = owner
|
||||
self.realname = realname
|
||||
|
||||
super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port))
|
||||
super().__init__(name=self.username + "@" + host + ":" + str(port),
|
||||
host=host, port=port, **kwargs)
|
||||
self.printer = IRCPrinter
|
||||
|
||||
self.encoding = encoding
|
||||
|
|
@ -231,20 +232,19 @@ class IRC(SocketServer):
|
|||
|
||||
# Open/close
|
||||
|
||||
def open(self):
|
||||
if super().open():
|
||||
if self.password is not None:
|
||||
self.write("PASS :" + self.password)
|
||||
if self.capabilities is not None:
|
||||
self.write("CAP LS")
|
||||
self.write("NICK :" + self.nick)
|
||||
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
|
||||
return True
|
||||
return False
|
||||
def connect(self):
|
||||
super().connect()
|
||||
|
||||
if self.password is not None:
|
||||
self.write("PASS :" + self.password)
|
||||
if self.capabilities is not None:
|
||||
self.write("CAP LS")
|
||||
self.write("NICK :" + self.nick)
|
||||
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
|
||||
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
if not self._closed:
|
||||
self.write("QUIT")
|
||||
return super().close()
|
||||
|
||||
|
|
@ -253,8 +253,8 @@ class IRC(SocketServer):
|
|||
|
||||
# Read
|
||||
|
||||
def read(self):
|
||||
for line in super().read():
|
||||
def async_read(self):
|
||||
for line in super().async_read():
|
||||
# PING should be handled here, so start parsing here :/
|
||||
msg = IRCMessage(line, self.encoding)
|
||||
|
||||
|
|
@ -273,3 +273,10 @@ class IRC(SocketServer):
|
|||
def subparse(self, orig, cnt):
|
||||
msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding)
|
||||
return msg.to_bot_message(self)
|
||||
|
||||
|
||||
class IRC(_IRC, SocketServer):
|
||||
pass
|
||||
|
||||
class IRC_secure(_IRC, SecureSocketServer):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# Nemubot is a smart and modulable IM bot.
|
||||
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
|
||||
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
|
|
@ -14,34 +14,37 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import threading
|
||||
|
||||
_lock = threading.Lock()
|
||||
|
||||
# Lists for select
|
||||
_rlist = []
|
||||
_wlist = []
|
||||
_xlist = []
|
||||
|
||||
|
||||
def factory(uri, **init_args):
|
||||
def factory(uri, ssl=False, **init_args):
|
||||
from urllib.parse import urlparse, unquote
|
||||
o = urlparse(uri)
|
||||
|
||||
srv = None
|
||||
|
||||
if o.scheme == "irc" or o.scheme == "ircs":
|
||||
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
|
||||
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
|
||||
args = init_args
|
||||
|
||||
modifiers = o.path.split(",")
|
||||
target = unquote(modifiers.pop(0)[1:])
|
||||
|
||||
if o.scheme == "ircs": args["ssl"] = True
|
||||
if o.scheme == "ircs": ssl = True
|
||||
if o.hostname is not None: args["host"] = o.hostname
|
||||
if o.port is not None: args["port"] = o.port
|
||||
if o.username is not None: args["username"] = o.username
|
||||
if o.password is not None: args["password"] = o.password
|
||||
|
||||
if ssl:
|
||||
try:
|
||||
from ssl import create_default_context
|
||||
args["_context"] = create_default_context()
|
||||
except ImportError:
|
||||
# Python 3.3 compat
|
||||
from ssl import SSLContext, PROTOCOL_TLSv1
|
||||
args["_context"] = SSLContext(PROTOCOL_TLSv1)
|
||||
args["server_hostname"] = o.hostname
|
||||
|
||||
modifiers = o.path.split(",")
|
||||
target = unquote(modifiers.pop(0)[1:])
|
||||
|
||||
queries = o.query.split("&")
|
||||
for q in queries:
|
||||
if "=" in q:
|
||||
|
|
@ -64,7 +67,11 @@ def factory(uri, **init_args):
|
|||
if "channels" not in args and "isnick" not in modifiers:
|
||||
args["channels"] = [ target ]
|
||||
|
||||
from nemubot.server.IRC import IRC as IRCServer
|
||||
return IRCServer(**args)
|
||||
else:
|
||||
return None
|
||||
if ssl:
|
||||
from nemubot.server.IRC import IRC_secure as SecureIRCServer
|
||||
srv = SecureIRCServer(**args)
|
||||
else:
|
||||
from nemubot.server.IRC import IRC as IRCServer
|
||||
srv = IRCServer(**args)
|
||||
|
||||
return srv
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# Nemubot is a smart and modulable IM bot.
|
||||
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
|
||||
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
|
|
@ -14,34 +14,30 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import io
|
||||
import logging
|
||||
import queue
|
||||
|
||||
from nemubot.server import _lock, _rlist, _wlist, _xlist
|
||||
from nemubot.bot import sync_act
|
||||
|
||||
# Extends from IOBase in order to be compatible with select function
|
||||
class AbstractServer(io.IOBase):
|
||||
|
||||
class AbstractServer:
|
||||
|
||||
"""An abstract server: handle communication with an IM server"""
|
||||
|
||||
def __init__(self, name=None, send_callback=None):
|
||||
def __init__(self, name=None, **kwargs):
|
||||
"""Initialize an abstract server
|
||||
|
||||
Keyword argument:
|
||||
send_callback -- Callback when developper want to send a message
|
||||
name -- Identifier of the socket, for convinience
|
||||
"""
|
||||
|
||||
self._name = name
|
||||
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.logger = logging.getLogger("nemubot.server." + self.name)
|
||||
self.logger = logging.getLogger("nemubot.server." + str(self.name))
|
||||
self._readbuffer = b''
|
||||
self._sending_queue = queue.Queue()
|
||||
if send_callback is not None:
|
||||
self._send_callback = send_callback
|
||||
else:
|
||||
self._send_callback = self._write_select
|
||||
|
||||
|
||||
@property
|
||||
|
|
@ -54,40 +50,28 @@ class AbstractServer(io.IOBase):
|
|||
|
||||
# Open/close
|
||||
|
||||
def __enter__(self):
|
||||
self.open()
|
||||
return self
|
||||
def connect(self, *args, **kwargs):
|
||||
"""Register the server in _poll"""
|
||||
|
||||
self.logger.info("Opening connection")
|
||||
|
||||
super().connect(*args, **kwargs)
|
||||
|
||||
self._on_connect()
|
||||
|
||||
def _on_connect(self):
|
||||
sync_act("sckt", "register", self.fileno())
|
||||
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.close()
|
||||
def close(self, *args, **kwargs):
|
||||
"""Unregister the server from _poll"""
|
||||
|
||||
self.logger.info("Closing connection")
|
||||
|
||||
def open(self):
|
||||
"""Generic open function that register the server un _rlist in case
|
||||
of successful _open"""
|
||||
self.logger.info("Opening connection to %s", self.id)
|
||||
if not hasattr(self, "_open") or self._open():
|
||||
_rlist.append(self)
|
||||
_xlist.append(self)
|
||||
return True
|
||||
return False
|
||||
if self.fileno() > 0:
|
||||
sync_act("sckt", "unregister", self.fileno())
|
||||
|
||||
|
||||
def close(self):
|
||||
"""Generic close function that register the server un _{r,w,x}list in
|
||||
case of successful _close"""
|
||||
self.logger.info("Closing connection to %s", self.id)
|
||||
with _lock:
|
||||
if not hasattr(self, "_close") or self._close():
|
||||
if self in _rlist:
|
||||
_rlist.remove(self)
|
||||
if self in _wlist:
|
||||
_wlist.remove(self)
|
||||
if self in _xlist:
|
||||
_xlist.remove(self)
|
||||
return True
|
||||
return False
|
||||
super().close(*args, **kwargs)
|
||||
|
||||
|
||||
# Writes
|
||||
|
|
@ -99,13 +83,16 @@ class AbstractServer(io.IOBase):
|
|||
message -- message to send
|
||||
"""
|
||||
|
||||
self._send_callback(message)
|
||||
self._sending_queue.put(self.format(message))
|
||||
self.logger.debug("Message '%s' appended to write queue", message)
|
||||
sync_act("sckt", "write", self.fileno())
|
||||
|
||||
|
||||
def write_select(self):
|
||||
"""Internal function used by the select function"""
|
||||
def async_write(self):
|
||||
"""Internal function used when the file descriptor is writable"""
|
||||
|
||||
try:
|
||||
_wlist.remove(self)
|
||||
sync_act("sckt", "unwrite", self.fileno())
|
||||
while not self._sending_queue.empty():
|
||||
self._write(self._sending_queue.get_nowait())
|
||||
self._sending_queue.task_done()
|
||||
|
|
@ -114,19 +101,6 @@ class AbstractServer(io.IOBase):
|
|||
pass
|
||||
|
||||
|
||||
def _write_select(self, message):
|
||||
"""Send a message to the server safely through select
|
||||
|
||||
Argument:
|
||||
message -- message to send
|
||||
"""
|
||||
|
||||
self._sending_queue.put(self.format(message))
|
||||
self.logger.debug("Message '%s' appended to write queue", message)
|
||||
if self not in _wlist:
|
||||
_wlist.append(self)
|
||||
|
||||
|
||||
def send_response(self, response):
|
||||
"""Send a formated Message class
|
||||
|
||||
|
|
@ -149,13 +123,39 @@ class AbstractServer(io.IOBase):
|
|||
|
||||
# Read
|
||||
|
||||
def async_read(self):
|
||||
"""Internal function used when the file descriptor is readable
|
||||
|
||||
Returns:
|
||||
A list of fully received messages
|
||||
"""
|
||||
|
||||
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
|
||||
|
||||
for r in ret:
|
||||
yield r
|
||||
|
||||
|
||||
def lex(self, buf):
|
||||
"""Assume lexing in default case is per line
|
||||
|
||||
Argument:
|
||||
buf -- buffer to lex
|
||||
"""
|
||||
|
||||
msgs = buf.split(b'\r\n')
|
||||
partial = msgs.pop()
|
||||
|
||||
return msgs, partial
|
||||
|
||||
|
||||
def parse(self, msg):
|
||||
raise NotImplemented
|
||||
|
||||
|
||||
# Exceptions
|
||||
|
||||
def exception(self):
|
||||
"""Exception occurs in fd"""
|
||||
self.logger.warning("Unhandle file descriptor exception on server %s",
|
||||
self.name)
|
||||
def exception(self, flags):
|
||||
"""Exception occurs on fd"""
|
||||
|
||||
self.close()
|
||||
|
|
|
|||
|
|
@ -22,34 +22,30 @@ class TestFactory(unittest.TestCase):
|
|||
|
||||
def test_IRC1(self):
|
||||
from nemubot.server.IRC import IRC as IRCServer
|
||||
from nemubot.server.IRC import IRC_secure as IRCSServer
|
||||
|
||||
# <host>: If omitted, the client must connect to a prespecified default IRC server.
|
||||
server = factory("irc:///")
|
||||
self.assertIsInstance(server, IRCServer)
|
||||
self.assertEqual(server.host, "localhost")
|
||||
self.assertFalse(server.ssl)
|
||||
|
||||
server = factory("ircs:///")
|
||||
self.assertIsInstance(server, IRCServer)
|
||||
self.assertIsInstance(server, IRCSServer)
|
||||
self.assertEqual(server.host, "localhost")
|
||||
self.assertTrue(server.ssl)
|
||||
|
||||
server = factory("irc://host1")
|
||||
self.assertIsInstance(server, IRCServer)
|
||||
self.assertEqual(server.host, "host1")
|
||||
self.assertFalse(server.ssl)
|
||||
|
||||
server = factory("irc://host2:6667")
|
||||
self.assertIsInstance(server, IRCServer)
|
||||
self.assertEqual(server.host, "host2")
|
||||
self.assertEqual(server.port, 6667)
|
||||
self.assertFalse(server.ssl)
|
||||
|
||||
server = factory("ircs://host3:194/")
|
||||
self.assertIsInstance(server, IRCServer)
|
||||
self.assertIsInstance(server, IRCSServer)
|
||||
self.assertEqual(server.host, "host3")
|
||||
self.assertEqual(server.port, 194)
|
||||
self.assertTrue(server.ssl)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
15
nemubot/server/message/__init__.py
Normal file
15
nemubot/server/message/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# Nemubot is a smart and modulable IM bot.
|
||||
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
# Nemubot is a smart and modulable IM bot.
|
||||
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
|
||||
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
|
|
@ -14,117 +14,33 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import os
|
||||
import socket
|
||||
import ssl
|
||||
|
||||
import nemubot.message as message
|
||||
from nemubot.message.printer.socket import Socket as SocketPrinter
|
||||
from nemubot.server.abstract import AbstractServer
|
||||
|
||||
|
||||
class SocketServer(AbstractServer):
|
||||
class _Socket(AbstractServer):
|
||||
|
||||
"""Concrete implementation of a socket connexion (can be wrapped with TLS)"""
|
||||
"""Concrete implementation of a socket connection"""
|
||||
|
||||
def __init__(self, sock_location=None,
|
||||
host=None, port=None,
|
||||
sock=None,
|
||||
ssl=False,
|
||||
name=None):
|
||||
def __init__(self, printer=SocketPrinter, **kwargs):
|
||||
"""Create a server socket
|
||||
|
||||
Keyword arguments:
|
||||
sock_location -- Path to the UNIX socket
|
||||
host -- Hostname of the INET socket
|
||||
port -- Port of the INET socket
|
||||
sock -- Already connected socket
|
||||
ssl -- Should TLS connection enabled
|
||||
name -- Convinience name
|
||||
"""
|
||||
|
||||
import socket
|
||||
|
||||
assert(sock is None or isinstance(sock, socket.SocketType))
|
||||
assert(port is None or isinstance(port, int))
|
||||
|
||||
super().__init__(name=name)
|
||||
|
||||
if sock is None:
|
||||
if sock_location is not None:
|
||||
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
self.connect_to = sock_location
|
||||
elif host is not None:
|
||||
for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
|
||||
self.socket = socket.socket(af, socktype, proto)
|
||||
self.connect_to = sa
|
||||
break
|
||||
else:
|
||||
self.socket = sock
|
||||
|
||||
self.ssl = ssl
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.readbuffer = b''
|
||||
self.printer = SocketPrinter
|
||||
|
||||
|
||||
def fileno(self):
|
||||
return self.socket.fileno() if self.socket else None
|
||||
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
"""Indicator of the connection aliveness"""
|
||||
return self.socket._closed
|
||||
|
||||
|
||||
# Open/close
|
||||
|
||||
def open(self):
|
||||
if not self.closed:
|
||||
return True
|
||||
|
||||
try:
|
||||
self.socket.connect(self.connect_to)
|
||||
self.logger.info("Connected to %s", self.connect_to)
|
||||
except:
|
||||
self.socket.close()
|
||||
self.logger.exception("Unable to connect to %s",
|
||||
self.connect_to)
|
||||
return False
|
||||
|
||||
# Wrap the socket for SSL
|
||||
if self.ssl:
|
||||
import ssl
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
|
||||
self.socket = ctx.wrap_socket(self.socket)
|
||||
|
||||
return super().open()
|
||||
|
||||
|
||||
def close(self):
|
||||
import socket
|
||||
|
||||
# Flush the sending queue before close
|
||||
from nemubot.server import _lock
|
||||
_lock.release()
|
||||
self._sending_queue.join()
|
||||
_lock.acquire()
|
||||
|
||||
if not self.closed:
|
||||
try:
|
||||
self.socket.shutdown(socket.SHUT_RDWR)
|
||||
except socket.error:
|
||||
pass
|
||||
|
||||
self.socket.close()
|
||||
|
||||
return super().close()
|
||||
self.printer = printer
|
||||
|
||||
|
||||
# Write
|
||||
|
||||
def _write(self, cnt):
|
||||
if self.closed:
|
||||
return
|
||||
|
||||
self.socket.sendall(cnt)
|
||||
self.sendall(cnt)
|
||||
|
||||
|
||||
def format(self, txt):
|
||||
|
|
@ -136,19 +52,12 @@ class SocketServer(AbstractServer):
|
|||
|
||||
# Read
|
||||
|
||||
def read(self):
|
||||
if self.closed:
|
||||
return []
|
||||
|
||||
raw = self.socket.recv(1024)
|
||||
temp = (self.readbuffer + raw).split(b'\r\n')
|
||||
self.readbuffer = temp.pop()
|
||||
|
||||
for line in temp:
|
||||
yield line
|
||||
def recv(self, n=1024):
|
||||
return super().recv(n)
|
||||
|
||||
|
||||
def parse(self, line):
|
||||
"""Implement a default behaviour for socket"""
|
||||
import shlex
|
||||
|
||||
line = line.strip().decode()
|
||||
|
|
@ -157,48 +66,97 @@ class SocketServer(AbstractServer):
|
|||
except ValueError:
|
||||
args = line.split(' ')
|
||||
|
||||
yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you")
|
||||
if len(args):
|
||||
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
|
||||
|
||||
|
||||
class SocketListener(AbstractServer):
|
||||
class _SocketServer(_Socket):
|
||||
|
||||
def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None):
|
||||
super().__init__(name=name)
|
||||
self.new_server_cb = new_server_cb
|
||||
self.sock_location = sock_location
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.ssl = ssl
|
||||
self.nb_son = 0
|
||||
def __init__(self, host, port, bind=None, **kwargs):
|
||||
super().__init__(family=socket.AF_INET, **kwargs)
|
||||
|
||||
assert(host is not None)
|
||||
assert(isinstance(port, int))
|
||||
|
||||
def fileno(self):
|
||||
return self.socket.fileno() if self.socket else None
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._bind = bind
|
||||
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
"""Indicator of the connection aliveness"""
|
||||
return self.socket is None
|
||||
def host(self):
|
||||
return self._host
|
||||
|
||||
|
||||
def open(self):
|
||||
import os
|
||||
import socket
|
||||
def connect(self):
|
||||
self.logger.info("Connection to %s:%d", self._host, self._port)
|
||||
super().connect((self._host, self._port))
|
||||
|
||||
if self.sock_location is not None:
|
||||
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
try:
|
||||
os.remove(self.sock_location)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
self.socket.bind(self.sock_location)
|
||||
elif self.host is not None and self.port is not None:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.bind((self.host, self.port))
|
||||
self.socket.listen(5)
|
||||
if self._bind:
|
||||
super().bind(self._bind)
|
||||
|
||||
return super().open()
|
||||
|
||||
class SocketServer(_SocketServer, socket.socket):
|
||||
pass
|
||||
|
||||
|
||||
class SecureSocketServer(_SocketServer, ssl.SSLSocket):
|
||||
pass
|
||||
|
||||
|
||||
class UnixSocket:
|
||||
|
||||
def __init__(self, location, **kwargs):
|
||||
super().__init__(family=socket.AF_UNIX, **kwargs)
|
||||
|
||||
self._socket_path = location
|
||||
|
||||
|
||||
def connect(self):
|
||||
self.logger.info("Connection to unix://%s", self._socket_path)
|
||||
super().connect(self._socket_path)
|
||||
|
||||
|
||||
class _Listener:
|
||||
|
||||
def __init__(self, new_server_cb, instanciate=_Socket, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._instanciate = instanciate
|
||||
self._new_server_cb = new_server_cb
|
||||
|
||||
|
||||
def read(self):
|
||||
conn, addr = self.accept()
|
||||
fileno = conn.fileno()
|
||||
self.logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
|
||||
|
||||
ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
|
||||
ss.connect = ss._on_connect
|
||||
self._new_server_cb(ss, autoconnect=True)
|
||||
|
||||
return b''
|
||||
|
||||
|
||||
class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
def connect(self):
|
||||
self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
|
||||
|
||||
try:
|
||||
os.remove(self._socket_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
self.bind(self._socket_path)
|
||||
self.listen(5)
|
||||
self.logger.info("Socket ready for accepting new connections")
|
||||
|
||||
self._on_connect()
|
||||
|
||||
|
||||
def close(self):
|
||||
|
|
@ -206,25 +164,14 @@ class SocketListener(AbstractServer):
|
|||
import socket
|
||||
|
||||
try:
|
||||
self.socket.shutdown(socket.SHUT_RDWR)
|
||||
self.socket.close()
|
||||
if self.sock_location is not None:
|
||||
os.remove(self.sock_location)
|
||||
self.shutdown(socket.SHUT_RDWR)
|
||||
except socket.error:
|
||||
pass
|
||||
|
||||
return super().close()
|
||||
super().close()
|
||||
|
||||
|
||||
# Read
|
||||
|
||||
def read(self):
|
||||
if self.closed:
|
||||
return []
|
||||
|
||||
conn, addr = self.socket.accept()
|
||||
self.nb_son += 1
|
||||
ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
|
||||
self.new_server_cb(ss)
|
||||
|
||||
return []
|
||||
try:
|
||||
if self._socket_path is not None:
|
||||
os.remove(self._socket_path)
|
||||
except:
|
||||
pass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue