Use fileno instead of name to index existing servers

This commit is contained in:
nemunaire 2016-05-16 17:35:24 +02:00
parent 1c21231f31
commit 6d8dca211d
6 changed files with 78 additions and 67 deletions

View File

@ -255,9 +255,9 @@ class Bot(threading.Thread):
srv = server.server(config) srv = server.server(config)
# Add the server in the context # Add the server in the context
if self.add_server(srv, server.autoconnect): if self.add_server(srv, server.autoconnect):
logger.info("Server '%s' successfully added." % srv.id) logger.info("Server '%s' successfully added." % srv.name)
else: else:
logger.error("Can't add server '%s'." % srv.id) logger.error("Can't add server '%s'." % srv.name)
# Load module and their configuration # Load module and their configuration
for mod in config.modules: for mod in config.modules:
@ -306,7 +306,7 @@ class Bot(threading.Thread):
if type(eid) is uuid.UUID: if type(eid) is uuid.UUID:
evt.id = str(eid) evt.id = str(eid)
else: else:
# Ok, this is quite useless... # Ok, this is quiet useless...
try: try:
evt.id = str(uuid.UUID(eid)) evt.id = str(uuid.UUID(eid))
except ValueError: except ValueError:
@ -414,8 +414,10 @@ class Bot(threading.Thread):
autoconnect -- connect after add? autoconnect -- connect after add?
""" """
if srv.id not in self.servers: fileno = srv.fileno()
self.servers[srv.id] = srv if fileno not in self.servers:
self.servers[fileno] = srv
self.servers[srv.name] = srv
if autoconnect and not hasattr(self, "noautoconnect"): if autoconnect and not hasattr(self, "noautoconnect"):
srv.open() srv.open()
return True return True

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -38,7 +38,7 @@ class MessageConsumer:
msgs = [] msgs = []
# Parse the message # Parse message
try: try:
for msg in self.srv.parse(self.orig): for msg in self.srv.parse(self.orig):
msgs.append(msg) msgs.append(msg)
@ -46,21 +46,10 @@ class MessageConsumer:
logger.exception("Error occurred during the processing of the %s: " logger.exception("Error occurred during the processing of the %s: "
"%s", type(self.orig).__name__, self.orig) "%s", type(self.orig).__name__, self.orig)
if len(msgs) <= 0: # Treat message
return
# Qualify the message
if not hasattr(msg, "server") or msg.server is None:
msg.server = self.srv.id
if hasattr(msg, "frm_owner"):
msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm)
from nemubot.server.abstract import AbstractServer
# Treat the message
for msg in msgs: for msg in msgs:
for res in context.treater.treat_msg(msg): for res in context.treater.treat_msg(msg):
# Identify the destination # Identify destination
to_server = None to_server = None
if isinstance(res, str): if isinstance(res, str):
to_server = self.srv to_server = self.srv
@ -69,8 +58,8 @@ class MessageConsumer:
continue continue
elif res.server is None: elif res.server is None:
to_server = self.srv to_server = self.srv
res.server = self.srv.id res.server = self.srv.fileno()
elif isinstance(res.server, str) and res.server in context.servers: elif res.server in context.servers:
to_server = context.servers[res.server] to_server = context.servers[res.server]
else: else:
to_server = res.server to_server = res.server
@ -79,7 +68,7 @@ class MessageConsumer:
logger.error("The server defined in this response doesn't exist: %s", res.server) logger.error("The server defined in this response doesn't exist: %s", res.server)
continue continue
# Sent the message only if treat_post authorize it # Sent message
to_server.send_response(res) to_server.send_response(res)
@ -116,7 +105,7 @@ class Consumer(threading.Thread):
def __init__(self, context): def __init__(self, context):
self.context = context self.context = context
self.stop = False self.stop = False
threading.Thread.__init__(self) super().__init__(name="Nemubot consumer")
def run(self): def run(self):

View File

@ -54,8 +54,7 @@ class IRC(SocketServer):
self.owner = owner self.owner = owner
self.realname = realname self.realname = realname
self.id = self.username + "@" + host + ":" + str(port) super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port))
super().__init__(host=host, port=port, ssl=ssl)
self.printer = IRCPrinter self.printer = IRCPrinter
self.encoding = encoding self.encoding = encoding

View File

@ -25,19 +25,18 @@ class AbstractServer(io.IOBase):
"""An abstract server: handle communication with an IM server""" """An abstract server: handle communication with an IM server"""
def __init__(self, send_callback=None): def __init__(self, name=None, send_callback=None):
"""Initialize an abstract server """Initialize an abstract server
Keyword argument: Keyword argument:
send_callback -- Callback when developper want to send a message send_callback -- Callback when developper want to send a message
""" """
self._name = name
super().__init__() super().__init__()
if not hasattr(self, "id"): self.logger = logging.getLogger("nemubot.server." + self.name)
raise Exception("No id defined for this server. Please set one!")
self.logger = logging.getLogger("nemubot.server." + self.id)
self._sending_queue = queue.Queue() self._sending_queue = queue.Queue()
if send_callback is not None: if send_callback is not None:
self._send_callback = send_callback self._send_callback = send_callback
@ -45,6 +44,14 @@ class AbstractServer(io.IOBase):
self._send_callback = self._write_select self._send_callback = self._write_select
@property
def name(self):
if self._name is not None:
return self._name
else:
return self.fileno()
# Open/close # Open/close
def __enter__(self): def __enter__(self):
@ -151,4 +158,4 @@ class AbstractServer(io.IOBase):
def exception(self): def exception(self):
"""Exception occurs in fd""" """Exception occurs in fd"""
self.logger.warning("Unhandle file descriptor exception on server %s", self.logger.warning("Unhandle file descriptor exception on server %s",
self.id) self.name)

View File

@ -146,7 +146,7 @@ class IRC(Abstract):
receivers = self.decode(self.params[0]).split(',') receivers = self.decode(self.params[0]).split(',')
common_args = { common_args = {
"server": srv.id, "server": srv.name,
"date": self.tags["time"], "date": self.tags["time"],
"to": receivers, "to": receivers,
"to_response": [r if r != srv.nick else self.nick for r in receivers], "to_response": [r if r != srv.nick else self.nick for r in receivers],

View File

@ -23,18 +23,43 @@ class SocketServer(AbstractServer):
"""Concrete implementation of a socket connexion (can be wrapped with TLS)""" """Concrete implementation of a socket connexion (can be wrapped with TLS)"""
def __init__(self, sock_location=None, host=None, port=None, ssl=False, socket=None, id=None): def __init__(self, sock_location=None,
if id is not None: host=None, port=None,
self.id = id sock=None,
super().__init__() ssl=False,
name=None):
"""Create a server socket
Keyword arguments:
sock_location -- Path to the UNIX socket
host -- Hostname of the INET socket
port -- Port of the INET socket
sock -- Already connected socket
ssl -- Should TLS connection enabled
name -- Convinience name
"""
import socket
assert(sock is None or isinstance(sock, socket.SocketType))
assert(port is None or isinstance(port, int))
super().__init__(name=name)
if sock is None:
if sock_location is not None: if sock_location is not None:
self.filename = sock_location self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.connect_to = sock_location
elif host is not None: elif host is not None:
self.host = host for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
self.port = int(port) self.socket = socket.socket(af, socktype, proto)
self.connect_to = sa
break
else:
self.socket = sock
self.ssl = ssl self.ssl = ssl
self.socket = socket
self.readbuffer = b'' self.readbuffer = b''
self.printer = SocketPrinter self.printer = SocketPrinter
@ -46,33 +71,22 @@ class SocketServer(AbstractServer):
@property @property
def closed(self): def closed(self):
"""Indicator of the connection aliveness""" """Indicator of the connection aliveness"""
return self.socket is None return self.socket._closed
# Open/close # Open/close
def open(self): def open(self):
import socket
if not self.closed: if not self.closed:
return True return True
try: try:
if hasattr(self, "filename"): self.socket.connect(self.connect_to)
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.logger.info("Connected to %s", self.connect_to)
self.socket.connect(self.filename)
self.logger.info("Connected to %s", self.filename)
else:
self.socket = socket.create_connection((self.host, self.port))
self.logger.info("Connected to %s:%d", self.host, self.port)
except: except:
self.socket = None self.socket.close()
if hasattr(self, "filename"):
self.logger.exception("Unable to connect to %s", self.logger.exception("Unable to connect to %s",
self.filename) self.connect_to)
else:
self.logger.exception("Unable to connect to %s:%d",
self.host, self.port)
return False return False
# Wrap the socket for SSL # Wrap the socket for SSL
@ -87,18 +101,19 @@ class SocketServer(AbstractServer):
def close(self): def close(self):
import socket import socket
# Flush the sending queue before close
from nemubot.server import _lock from nemubot.server import _lock
_lock.release() _lock.release()
self._sending_queue.join() self._sending_queue.join()
_lock.acquire() _lock.acquire()
if not self.closed: if not self.closed:
try: try:
self.socket.shutdown(socket.SHUT_RDWR) self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
except socket.error: except socket.error:
pass pass
self.socket = None self.socket.close()
return super().close() return super().close()
@ -142,14 +157,13 @@ class SocketServer(AbstractServer):
except ValueError: except ValueError:
args = line.split(' ') args = line.split(' ')
yield message.Command(cmd=args[0], args=args[1:], server=self.id, to=["you"], frm="you") yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you")
class SocketListener(AbstractServer): class SocketListener(AbstractServer):
def __init__(self, new_server_cb, id, sock_location=None, host=None, port=None, ssl=None): def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None):
self.id = id super().__init__(name=name)
super().__init__()
self.new_server_cb = new_server_cb self.new_server_cb = new_server_cb
self.sock_location = sock_location self.sock_location = sock_location
self.host = host self.host = host
@ -210,7 +224,7 @@ class SocketListener(AbstractServer):
conn, addr = self.socket.accept() conn, addr = self.socket.accept()
self.nb_son += 1 self.nb_son += 1
ss = SocketServer(id=self.id + "#" + str(self.nb_son), socket=conn) ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
self.new_server_cb(ss) self.new_server_cb(ss)
return [] return []