Use fileno instead of name to index existing servers

This commit is contained in:
nemunaire 2016-05-16 17:35:24 +02:00
parent 1c21231f31
commit 6d8dca211d
6 changed files with 78 additions and 67 deletions

View File

@ -255,9 +255,9 @@ class Bot(threading.Thread):
srv = server.server(config)
# Add the server in the context
if self.add_server(srv, server.autoconnect):
logger.info("Server '%s' successfully added." % srv.id)
logger.info("Server '%s' successfully added." % srv.name)
else:
logger.error("Can't add server '%s'." % srv.id)
logger.error("Can't add server '%s'." % srv.name)
# Load module and their configuration
for mod in config.modules:
@ -306,7 +306,7 @@ class Bot(threading.Thread):
if type(eid) is uuid.UUID:
evt.id = str(eid)
else:
# Ok, this is quite useless...
# Ok, this is quiet useless...
try:
evt.id = str(uuid.UUID(eid))
except ValueError:
@ -414,8 +414,10 @@ class Bot(threading.Thread):
autoconnect -- connect after add?
"""
if srv.id not in self.servers:
self.servers[srv.id] = srv
fileno = srv.fileno()
if fileno not in self.servers:
self.servers[fileno] = srv
self.servers[srv.name] = srv
if autoconnect and not hasattr(self, "noautoconnect"):
srv.open()
return True

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -38,7 +38,7 @@ class MessageConsumer:
msgs = []
# Parse the message
# Parse message
try:
for msg in self.srv.parse(self.orig):
msgs.append(msg)
@ -46,21 +46,10 @@ class MessageConsumer:
logger.exception("Error occurred during the processing of the %s: "
"%s", type(self.orig).__name__, self.orig)
if len(msgs) <= 0:
return
# Qualify the message
if not hasattr(msg, "server") or msg.server is None:
msg.server = self.srv.id
if hasattr(msg, "frm_owner"):
msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm)
from nemubot.server.abstract import AbstractServer
# Treat the message
# Treat message
for msg in msgs:
for res in context.treater.treat_msg(msg):
# Identify the destination
# Identify destination
to_server = None
if isinstance(res, str):
to_server = self.srv
@ -69,8 +58,8 @@ class MessageConsumer:
continue
elif res.server is None:
to_server = self.srv
res.server = self.srv.id
elif isinstance(res.server, str) and res.server in context.servers:
res.server = self.srv.fileno()
elif res.server in context.servers:
to_server = context.servers[res.server]
else:
to_server = res.server
@ -79,7 +68,7 @@ class MessageConsumer:
logger.error("The server defined in this response doesn't exist: %s", res.server)
continue
# Sent the message only if treat_post authorize it
# Sent message
to_server.send_response(res)
@ -116,7 +105,7 @@ class Consumer(threading.Thread):
def __init__(self, context):
self.context = context
self.stop = False
threading.Thread.__init__(self)
super().__init__(name="Nemubot consumer")
def run(self):

View File

@ -54,8 +54,7 @@ class IRC(SocketServer):
self.owner = owner
self.realname = realname
self.id = self.username + "@" + host + ":" + str(port)
super().__init__(host=host, port=port, ssl=ssl)
super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port))
self.printer = IRCPrinter
self.encoding = encoding

View File

@ -25,19 +25,18 @@ class AbstractServer(io.IOBase):
"""An abstract server: handle communication with an IM server"""
def __init__(self, send_callback=None):
def __init__(self, name=None, send_callback=None):
"""Initialize an abstract server
Keyword argument:
send_callback -- Callback when developper want to send a message
"""
self._name = name
super().__init__()
if not hasattr(self, "id"):
raise Exception("No id defined for this server. Please set one!")
self.logger = logging.getLogger("nemubot.server." + self.id)
self.logger = logging.getLogger("nemubot.server." + self.name)
self._sending_queue = queue.Queue()
if send_callback is not None:
self._send_callback = send_callback
@ -45,6 +44,14 @@ class AbstractServer(io.IOBase):
self._send_callback = self._write_select
@property
def name(self):
if self._name is not None:
return self._name
else:
return self.fileno()
# Open/close
def __enter__(self):
@ -151,4 +158,4 @@ class AbstractServer(io.IOBase):
def exception(self):
"""Exception occurs in fd"""
self.logger.warning("Unhandle file descriptor exception on server %s",
self.id)
self.name)

View File

@ -146,7 +146,7 @@ class IRC(Abstract):
receivers = self.decode(self.params[0]).split(',')
common_args = {
"server": srv.id,
"server": srv.name,
"date": self.tags["time"],
"to": receivers,
"to_response": [r if r != srv.nick else self.nick for r in receivers],

View File

@ -23,18 +23,43 @@ class SocketServer(AbstractServer):
"""Concrete implementation of a socket connexion (can be wrapped with TLS)"""
def __init__(self, sock_location=None, host=None, port=None, ssl=False, socket=None, id=None):
if id is not None:
self.id = id
super().__init__()
if sock_location is not None:
self.filename = sock_location
elif host is not None:
self.host = host
self.port = int(port)
def __init__(self, sock_location=None,
host=None, port=None,
sock=None,
ssl=False,
name=None):
"""Create a server socket
Keyword arguments:
sock_location -- Path to the UNIX socket
host -- Hostname of the INET socket
port -- Port of the INET socket
sock -- Already connected socket
ssl -- Should TLS connection enabled
name -- Convinience name
"""
import socket
assert(sock is None or isinstance(sock, socket.SocketType))
assert(port is None or isinstance(port, int))
super().__init__(name=name)
if sock is None:
if sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.connect_to = sock_location
elif host is not None:
for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
self.socket = socket.socket(af, socktype, proto)
self.connect_to = sa
break
else:
self.socket = sock
self.ssl = ssl
self.socket = socket
self.readbuffer = b''
self.printer = SocketPrinter
@ -46,33 +71,22 @@ class SocketServer(AbstractServer):
@property
def closed(self):
"""Indicator of the connection aliveness"""
return self.socket is None
return self.socket._closed
# Open/close
def open(self):
import socket
if not self.closed:
return True
try:
if hasattr(self, "filename"):
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.socket.connect(self.filename)
self.logger.info("Connected to %s", self.filename)
else:
self.socket = socket.create_connection((self.host, self.port))
self.logger.info("Connected to %s:%d", self.host, self.port)
self.socket.connect(self.connect_to)
self.logger.info("Connected to %s", self.connect_to)
except:
self.socket = None
if hasattr(self, "filename"):
self.logger.exception("Unable to connect to %s",
self.filename)
else:
self.logger.exception("Unable to connect to %s:%d",
self.host, self.port)
self.socket.close()
self.logger.exception("Unable to connect to %s",
self.connect_to)
return False
# Wrap the socket for SSL
@ -87,18 +101,19 @@ class SocketServer(AbstractServer):
def close(self):
import socket
# Flush the sending queue before close
from nemubot.server import _lock
_lock.release()
self._sending_queue.join()
_lock.acquire()
if not self.closed:
try:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
except socket.error:
pass
self.socket = None
self.socket.close()
return super().close()
@ -142,14 +157,13 @@ class SocketServer(AbstractServer):
except ValueError:
args = line.split(' ')
yield message.Command(cmd=args[0], args=args[1:], server=self.id, to=["you"], frm="you")
yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you")
class SocketListener(AbstractServer):
def __init__(self, new_server_cb, id, sock_location=None, host=None, port=None, ssl=None):
self.id = id
super().__init__()
def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None):
super().__init__(name=name)
self.new_server_cb = new_server_cb
self.sock_location = sock_location
self.host = host
@ -210,7 +224,7 @@ class SocketListener(AbstractServer):
conn, addr = self.socket.accept()
self.nb_son += 1
ss = SocketServer(id=self.id + "#" + str(self.nb_son), socket=conn)
ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
self.new_server_cb(ss)
return []