Use fileno instead of name to index existing servers
This commit is contained in:
parent
1c21231f31
commit
6d8dca211d
@ -255,9 +255,9 @@ class Bot(threading.Thread):
|
|||||||
srv = server.server(config)
|
srv = server.server(config)
|
||||||
# Add the server in the context
|
# Add the server in the context
|
||||||
if self.add_server(srv, server.autoconnect):
|
if self.add_server(srv, server.autoconnect):
|
||||||
logger.info("Server '%s' successfully added." % srv.id)
|
logger.info("Server '%s' successfully added." % srv.name)
|
||||||
else:
|
else:
|
||||||
logger.error("Can't add server '%s'." % srv.id)
|
logger.error("Can't add server '%s'." % srv.name)
|
||||||
|
|
||||||
# Load module and their configuration
|
# Load module and their configuration
|
||||||
for mod in config.modules:
|
for mod in config.modules:
|
||||||
@ -306,7 +306,7 @@ class Bot(threading.Thread):
|
|||||||
if type(eid) is uuid.UUID:
|
if type(eid) is uuid.UUID:
|
||||||
evt.id = str(eid)
|
evt.id = str(eid)
|
||||||
else:
|
else:
|
||||||
# Ok, this is quite useless...
|
# Ok, this is quiet useless...
|
||||||
try:
|
try:
|
||||||
evt.id = str(uuid.UUID(eid))
|
evt.id = str(uuid.UUID(eid))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -414,8 +414,10 @@ class Bot(threading.Thread):
|
|||||||
autoconnect -- connect after add?
|
autoconnect -- connect after add?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if srv.id not in self.servers:
|
fileno = srv.fileno()
|
||||||
self.servers[srv.id] = srv
|
if fileno not in self.servers:
|
||||||
|
self.servers[fileno] = srv
|
||||||
|
self.servers[srv.name] = srv
|
||||||
if autoconnect and not hasattr(self, "noautoconnect"):
|
if autoconnect and not hasattr(self, "noautoconnect"):
|
||||||
srv.open()
|
srv.open()
|
||||||
return True
|
return True
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Nemubot is a smart and modulable IM bot.
|
# Nemubot is a smart and modulable IM bot.
|
||||||
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
|
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
@ -38,7 +38,7 @@ class MessageConsumer:
|
|||||||
|
|
||||||
msgs = []
|
msgs = []
|
||||||
|
|
||||||
# Parse the message
|
# Parse message
|
||||||
try:
|
try:
|
||||||
for msg in self.srv.parse(self.orig):
|
for msg in self.srv.parse(self.orig):
|
||||||
msgs.append(msg)
|
msgs.append(msg)
|
||||||
@ -46,21 +46,10 @@ class MessageConsumer:
|
|||||||
logger.exception("Error occurred during the processing of the %s: "
|
logger.exception("Error occurred during the processing of the %s: "
|
||||||
"%s", type(self.orig).__name__, self.orig)
|
"%s", type(self.orig).__name__, self.orig)
|
||||||
|
|
||||||
if len(msgs) <= 0:
|
# Treat message
|
||||||
return
|
|
||||||
|
|
||||||
# Qualify the message
|
|
||||||
if not hasattr(msg, "server") or msg.server is None:
|
|
||||||
msg.server = self.srv.id
|
|
||||||
if hasattr(msg, "frm_owner"):
|
|
||||||
msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm)
|
|
||||||
|
|
||||||
from nemubot.server.abstract import AbstractServer
|
|
||||||
|
|
||||||
# Treat the message
|
|
||||||
for msg in msgs:
|
for msg in msgs:
|
||||||
for res in context.treater.treat_msg(msg):
|
for res in context.treater.treat_msg(msg):
|
||||||
# Identify the destination
|
# Identify destination
|
||||||
to_server = None
|
to_server = None
|
||||||
if isinstance(res, str):
|
if isinstance(res, str):
|
||||||
to_server = self.srv
|
to_server = self.srv
|
||||||
@ -69,8 +58,8 @@ class MessageConsumer:
|
|||||||
continue
|
continue
|
||||||
elif res.server is None:
|
elif res.server is None:
|
||||||
to_server = self.srv
|
to_server = self.srv
|
||||||
res.server = self.srv.id
|
res.server = self.srv.fileno()
|
||||||
elif isinstance(res.server, str) and res.server in context.servers:
|
elif res.server in context.servers:
|
||||||
to_server = context.servers[res.server]
|
to_server = context.servers[res.server]
|
||||||
else:
|
else:
|
||||||
to_server = res.server
|
to_server = res.server
|
||||||
@ -79,7 +68,7 @@ class MessageConsumer:
|
|||||||
logger.error("The server defined in this response doesn't exist: %s", res.server)
|
logger.error("The server defined in this response doesn't exist: %s", res.server)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Sent the message only if treat_post authorize it
|
# Sent message
|
||||||
to_server.send_response(res)
|
to_server.send_response(res)
|
||||||
|
|
||||||
|
|
||||||
@ -116,7 +105,7 @@ class Consumer(threading.Thread):
|
|||||||
def __init__(self, context):
|
def __init__(self, context):
|
||||||
self.context = context
|
self.context = context
|
||||||
self.stop = False
|
self.stop = False
|
||||||
threading.Thread.__init__(self)
|
super().__init__(name="Nemubot consumer")
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
@ -54,8 +54,7 @@ class IRC(SocketServer):
|
|||||||
self.owner = owner
|
self.owner = owner
|
||||||
self.realname = realname
|
self.realname = realname
|
||||||
|
|
||||||
self.id = self.username + "@" + host + ":" + str(port)
|
super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port))
|
||||||
super().__init__(host=host, port=port, ssl=ssl)
|
|
||||||
self.printer = IRCPrinter
|
self.printer = IRCPrinter
|
||||||
|
|
||||||
self.encoding = encoding
|
self.encoding = encoding
|
||||||
|
@ -25,19 +25,18 @@ class AbstractServer(io.IOBase):
|
|||||||
|
|
||||||
"""An abstract server: handle communication with an IM server"""
|
"""An abstract server: handle communication with an IM server"""
|
||||||
|
|
||||||
def __init__(self, send_callback=None):
|
def __init__(self, name=None, send_callback=None):
|
||||||
"""Initialize an abstract server
|
"""Initialize an abstract server
|
||||||
|
|
||||||
Keyword argument:
|
Keyword argument:
|
||||||
send_callback -- Callback when developper want to send a message
|
send_callback -- Callback when developper want to send a message
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
self._name = name
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if not hasattr(self, "id"):
|
self.logger = logging.getLogger("nemubot.server." + self.name)
|
||||||
raise Exception("No id defined for this server. Please set one!")
|
|
||||||
|
|
||||||
self.logger = logging.getLogger("nemubot.server." + self.id)
|
|
||||||
self._sending_queue = queue.Queue()
|
self._sending_queue = queue.Queue()
|
||||||
if send_callback is not None:
|
if send_callback is not None:
|
||||||
self._send_callback = send_callback
|
self._send_callback = send_callback
|
||||||
@ -45,6 +44,14 @@ class AbstractServer(io.IOBase):
|
|||||||
self._send_callback = self._write_select
|
self._send_callback = self._write_select
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
if self._name is not None:
|
||||||
|
return self._name
|
||||||
|
else:
|
||||||
|
return self.fileno()
|
||||||
|
|
||||||
|
|
||||||
# Open/close
|
# Open/close
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@ -151,4 +158,4 @@ class AbstractServer(io.IOBase):
|
|||||||
def exception(self):
|
def exception(self):
|
||||||
"""Exception occurs in fd"""
|
"""Exception occurs in fd"""
|
||||||
self.logger.warning("Unhandle file descriptor exception on server %s",
|
self.logger.warning("Unhandle file descriptor exception on server %s",
|
||||||
self.id)
|
self.name)
|
||||||
|
@ -146,7 +146,7 @@ class IRC(Abstract):
|
|||||||
receivers = self.decode(self.params[0]).split(',')
|
receivers = self.decode(self.params[0]).split(',')
|
||||||
|
|
||||||
common_args = {
|
common_args = {
|
||||||
"server": srv.id,
|
"server": srv.name,
|
||||||
"date": self.tags["time"],
|
"date": self.tags["time"],
|
||||||
"to": receivers,
|
"to": receivers,
|
||||||
"to_response": [r if r != srv.nick else self.nick for r in receivers],
|
"to_response": [r if r != srv.nick else self.nick for r in receivers],
|
||||||
|
@ -23,18 +23,43 @@ class SocketServer(AbstractServer):
|
|||||||
|
|
||||||
"""Concrete implementation of a socket connexion (can be wrapped with TLS)"""
|
"""Concrete implementation of a socket connexion (can be wrapped with TLS)"""
|
||||||
|
|
||||||
def __init__(self, sock_location=None, host=None, port=None, ssl=False, socket=None, id=None):
|
def __init__(self, sock_location=None,
|
||||||
if id is not None:
|
host=None, port=None,
|
||||||
self.id = id
|
sock=None,
|
||||||
super().__init__()
|
ssl=False,
|
||||||
|
name=None):
|
||||||
|
"""Create a server socket
|
||||||
|
|
||||||
|
Keyword arguments:
|
||||||
|
sock_location -- Path to the UNIX socket
|
||||||
|
host -- Hostname of the INET socket
|
||||||
|
port -- Port of the INET socket
|
||||||
|
sock -- Already connected socket
|
||||||
|
ssl -- Should TLS connection enabled
|
||||||
|
name -- Convinience name
|
||||||
|
"""
|
||||||
|
|
||||||
|
import socket
|
||||||
|
|
||||||
|
assert(sock is None or isinstance(sock, socket.SocketType))
|
||||||
|
assert(port is None or isinstance(port, int))
|
||||||
|
|
||||||
|
super().__init__(name=name)
|
||||||
|
|
||||||
|
if sock is None:
|
||||||
if sock_location is not None:
|
if sock_location is not None:
|
||||||
self.filename = sock_location
|
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||||
|
self.connect_to = sock_location
|
||||||
elif host is not None:
|
elif host is not None:
|
||||||
self.host = host
|
for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
|
||||||
self.port = int(port)
|
self.socket = socket.socket(af, socktype, proto)
|
||||||
|
self.connect_to = sa
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.socket = sock
|
||||||
|
|
||||||
self.ssl = ssl
|
self.ssl = ssl
|
||||||
|
|
||||||
self.socket = socket
|
|
||||||
self.readbuffer = b''
|
self.readbuffer = b''
|
||||||
self.printer = SocketPrinter
|
self.printer = SocketPrinter
|
||||||
|
|
||||||
@ -46,33 +71,22 @@ class SocketServer(AbstractServer):
|
|||||||
@property
|
@property
|
||||||
def closed(self):
|
def closed(self):
|
||||||
"""Indicator of the connection aliveness"""
|
"""Indicator of the connection aliveness"""
|
||||||
return self.socket is None
|
return self.socket._closed
|
||||||
|
|
||||||
|
|
||||||
# Open/close
|
# Open/close
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
import socket
|
|
||||||
|
|
||||||
if not self.closed:
|
if not self.closed:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self, "filename"):
|
self.socket.connect(self.connect_to)
|
||||||
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
self.logger.info("Connected to %s", self.connect_to)
|
||||||
self.socket.connect(self.filename)
|
|
||||||
self.logger.info("Connected to %s", self.filename)
|
|
||||||
else:
|
|
||||||
self.socket = socket.create_connection((self.host, self.port))
|
|
||||||
self.logger.info("Connected to %s:%d", self.host, self.port)
|
|
||||||
except:
|
except:
|
||||||
self.socket = None
|
self.socket.close()
|
||||||
if hasattr(self, "filename"):
|
|
||||||
self.logger.exception("Unable to connect to %s",
|
self.logger.exception("Unable to connect to %s",
|
||||||
self.filename)
|
self.connect_to)
|
||||||
else:
|
|
||||||
self.logger.exception("Unable to connect to %s:%d",
|
|
||||||
self.host, self.port)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Wrap the socket for SSL
|
# Wrap the socket for SSL
|
||||||
@ -87,18 +101,19 @@ class SocketServer(AbstractServer):
|
|||||||
def close(self):
|
def close(self):
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
|
# Flush the sending queue before close
|
||||||
from nemubot.server import _lock
|
from nemubot.server import _lock
|
||||||
_lock.release()
|
_lock.release()
|
||||||
self._sending_queue.join()
|
self._sending_queue.join()
|
||||||
_lock.acquire()
|
_lock.acquire()
|
||||||
|
|
||||||
if not self.closed:
|
if not self.closed:
|
||||||
try:
|
try:
|
||||||
self.socket.shutdown(socket.SHUT_RDWR)
|
self.socket.shutdown(socket.SHUT_RDWR)
|
||||||
self.socket.close()
|
|
||||||
except socket.error:
|
except socket.error:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.socket = None
|
self.socket.close()
|
||||||
|
|
||||||
return super().close()
|
return super().close()
|
||||||
|
|
||||||
@ -142,14 +157,13 @@ class SocketServer(AbstractServer):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
args = line.split(' ')
|
args = line.split(' ')
|
||||||
|
|
||||||
yield message.Command(cmd=args[0], args=args[1:], server=self.id, to=["you"], frm="you")
|
yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you")
|
||||||
|
|
||||||
|
|
||||||
class SocketListener(AbstractServer):
|
class SocketListener(AbstractServer):
|
||||||
|
|
||||||
def __init__(self, new_server_cb, id, sock_location=None, host=None, port=None, ssl=None):
|
def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None):
|
||||||
self.id = id
|
super().__init__(name=name)
|
||||||
super().__init__()
|
|
||||||
self.new_server_cb = new_server_cb
|
self.new_server_cb = new_server_cb
|
||||||
self.sock_location = sock_location
|
self.sock_location = sock_location
|
||||||
self.host = host
|
self.host = host
|
||||||
@ -210,7 +224,7 @@ class SocketListener(AbstractServer):
|
|||||||
|
|
||||||
conn, addr = self.socket.accept()
|
conn, addr = self.socket.accept()
|
||||||
self.nb_son += 1
|
self.nb_son += 1
|
||||||
ss = SocketServer(id=self.id + "#" + str(self.nb_son), socket=conn)
|
ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
|
||||||
self.new_server_cb(ss)
|
self.new_server_cb(ss)
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
Loading…
Reference in New Issue
Block a user