From e103d22bf2d56da70a5a815bf1909641e43bfeac Mon Sep 17 00:00:00 2001 From: nemunaire Date: Mon, 16 May 2016 17:35:24 +0200 Subject: [PATCH] Use fileno instead of name to index existing servers --- nemubot/bot.py | 10 ++--- nemubot/consumer.py | 10 ++--- nemubot/server/IRC.py | 3 +- nemubot/server/abstract.py | 19 +++++--- nemubot/server/message/IRC.py | 2 +- nemubot/server/socket.py | 82 ++++++++++++++++++++--------------- 6 files changed, 73 insertions(+), 53 deletions(-) diff --git a/nemubot/bot.py b/nemubot/bot.py index d92cc35..b7c71b9 100644 --- a/nemubot/bot.py +++ b/nemubot/bot.py @@ -255,9 +255,9 @@ class Bot(threading.Thread): srv = server.server(config) # Add the server in the context if self.add_server(srv, server.autoconnect): - logger.info("Server '%s' successfully added." % srv.id) + logger.info("Server '%s' successfully added." % srv.name) else: - logger.error("Can't add server '%s'." % srv.id) + logger.error("Can't add server '%s'." % srv.name) # Load module and their configuration for mod in config.modules: @@ -306,7 +306,7 @@ class Bot(threading.Thread): if type(eid) is uuid.UUID: evt.id = str(eid) else: - # Ok, this is quite useless... + # Ok, this is quiet useless... try: evt.id = str(uuid.UUID(eid)) except ValueError: @@ -414,8 +414,8 @@ class Bot(threading.Thread): autoconnect -- connect after add? """ - if srv.id not in self.servers: - self.servers[srv.id] = srv + if srv.fileno not in self.servers: + self.servers[srv.fileno] = srv if autoconnect and not hasattr(self, "noautoconnect"): srv.open() return True diff --git a/nemubot/consumer.py b/nemubot/consumer.py index 431db82..0cd4ed5 100644 --- a/nemubot/consumer.py +++ b/nemubot/consumer.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -51,7 +51,7 @@ class MessageConsumer: # Qualify the message if not hasattr(msg, "server") or msg.server is None: - msg.server = self.srv.id + msg.server = self.srv.name if hasattr(msg, "frm_owner"): msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm) @@ -69,8 +69,8 @@ class MessageConsumer: continue elif res.server is None: to_server = self.srv - res.server = self.srv.id - elif isinstance(res.server, str) and res.server in context.servers: + res.server = self.srv.name + elif res.server in context.servers: to_server = context.servers[res.server] else: to_server = res.server @@ -116,7 +116,7 @@ class Consumer(threading.Thread): def __init__(self, context): self.context = context self.stop = False - threading.Thread.__init__(self) + super().__init__(name="Nemubot consumer") def run(self): diff --git a/nemubot/server/IRC.py b/nemubot/server/IRC.py index e09c77e..08e2bc5 100644 --- a/nemubot/server/IRC.py +++ b/nemubot/server/IRC.py @@ -54,8 +54,7 @@ class IRC(SocketServer): self.owner = owner self.realname = realname - self.id = self.username + "@" + host + ":" + str(port) - super().__init__(host=host, port=port, ssl=ssl) + super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port)) self.printer = IRCPrinter self.encoding = encoding diff --git a/nemubot/server/abstract.py b/nemubot/server/abstract.py index 518d7d6..dc2081d 100644 --- a/nemubot/server/abstract.py +++ b/nemubot/server/abstract.py @@ -25,19 +25,18 @@ class AbstractServer(io.IOBase): """An abstract server: handle communication with an IM server""" - def __init__(self, send_callback=None): + def __init__(self, name=None, send_callback=None): """Initialize an abstract server Keyword argument: send_callback -- Callback when developper want to send a message """ + self._name = name + super().__init__() - if not hasattr(self, "id"): - raise Exception("No id defined for this server. Please set one!") - - self.logger = logging.getLogger("nemubot.server." + self.id) + self.logger = logging.getLogger("nemubot.server." + self.name) self._sending_queue = queue.Queue() if send_callback is not None: self._send_callback = send_callback @@ -45,6 +44,14 @@ class AbstractServer(io.IOBase): self._send_callback = self._write_select + @property + def name(self): + if self._name is not None: + return self._name + else: + return self.fileno() + + # Open/close def __enter__(self): @@ -151,4 +158,4 @@ class AbstractServer(io.IOBase): def exception(self): """Exception occurs in fd""" self.logger.warning("Unhandle file descriptor exception on server %s", - self.id) + self.name) diff --git a/nemubot/server/message/IRC.py b/nemubot/server/message/IRC.py index f6d562f..4c9e280 100644 --- a/nemubot/server/message/IRC.py +++ b/nemubot/server/message/IRC.py @@ -146,7 +146,7 @@ class IRC(Abstract): receivers = self.decode(self.params[0]).split(',') common_args = { - "server": srv.id, + "server": srv.name, "date": self.tags["time"], "to": receivers, "to_response": [r if r != srv.nick else self.nick for r in receivers], diff --git a/nemubot/server/socket.py b/nemubot/server/socket.py index 6876d2f..13ac9bd 100644 --- a/nemubot/server/socket.py +++ b/nemubot/server/socket.py @@ -23,18 +23,43 @@ class SocketServer(AbstractServer): """Concrete implementation of a socket connexion (can be wrapped with TLS)""" - def __init__(self, sock_location=None, host=None, port=None, ssl=False, socket=None, id=None): - if id is not None: - self.id = id - super().__init__() - if sock_location is not None: - self.filename = sock_location - elif host is not None: - self.host = host - self.port = int(port) + def __init__(self, sock_location=None, + host=None, port=None, + sock=None, + ssl=False, + name=None): + """Create a server socket + + Keyword arguments: + sock_location -- Path to the UNIX socket + host -- Hostname of the INET socket + port -- Port of the INET socket + sock -- Already connected socket + ssl -- Should TLS connection enabled + name -- Convinience name + """ + + import socket + + assert(sock is None or isinstance(sock, socket.SocketType)) + assert(port is None or isinstance(port, int)) + + super().__init__(name=name) + + if sock is None: + if sock_location is not None: + self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.connect_to = sock_location + elif host is not None: + for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): + self.socket = socket.socket(af, socktype, proto) + self.connect_to = sa + break + else: + self.socket = sock + self.ssl = ssl - self.socket = socket self.readbuffer = b'' self.printer = SocketPrinter @@ -46,33 +71,22 @@ class SocketServer(AbstractServer): @property def closed(self): """Indicator of the connection aliveness""" - return self.socket is None + return self.socket._closed # Open/close def open(self): - import socket - if not self.closed: return True try: - if hasattr(self, "filename"): - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.socket.connect(self.filename) - self.logger.info("Connected to %s", self.filename) - else: - self.socket = socket.create_connection((self.host, self.port)) - self.logger.info("Connected to %s:%d", self.host, self.port) + self.socket.connect(self.connect_to) + self.logger.info("Connected to %s", self.connect_to) except: - self.socket = None - if hasattr(self, "filename"): - self.logger.exception("Unable to connect to %s", - self.filename) - else: - self.logger.exception("Unable to connect to %s:%d", - self.host, self.port) + self.socket.close() + self.logger.exception("Unable to connect to %s", + self.connect_to) return False # Wrap the socket for SSL @@ -87,18 +101,19 @@ class SocketServer(AbstractServer): def close(self): import socket + # Flush the sending queue before close from nemubot.server import _lock _lock.release() self._sending_queue.join() _lock.acquire() + if not self.closed: try: self.socket.shutdown(socket.SHUT_RDWR) - self.socket.close() except socket.error: pass - self.socket = None + self.socket.close() return super().close() @@ -142,14 +157,13 @@ class SocketServer(AbstractServer): except ValueError: args = line.split(' ') - yield message.Command(cmd=args[0], args=args[1:], server=self.id, to=["you"], frm="you") + yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you") class SocketListener(AbstractServer): - def __init__(self, new_server_cb, id, sock_location=None, host=None, port=None, ssl=None): - self.id = id - super().__init__() + def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None): + super().__init__(name=name) self.new_server_cb = new_server_cb self.sock_location = sock_location self.host = host @@ -210,7 +224,7 @@ class SocketListener(AbstractServer): conn, addr = self.socket.accept() self.nb_son += 1 - ss = SocketServer(id=self.id + "#" + str(self.nb_son), socket=conn) + ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn) self.new_server_cb(ss) return []