From 81593a493b1b3e98f3932e2fc584c8b5e456354f Mon Sep 17 00:00:00 2001 From: nemunaire Date: Sat, 30 Aug 2014 19:15:14 +0200 Subject: [PATCH] (wip) use select instead of a thread by server. Currently read and write seems to work properly, we lost some code (like more, next, ...) that never should be in server part --- bot.py | 33 +++++- consumer.py | 33 +++--- nemubot.py | 1 + prompt/builtins.py | 3 +- response.py | 7 ++ server/IRC.py | 282 ++++----------------------------------------- server/__init__.py | 188 ++++++++---------------------- server/socket.py | 77 +++++++++++++ 8 files changed, 205 insertions(+), 419 deletions(-) create mode 100644 server/socket.py diff --git a/bot.py b/bot.py index 2ac55b1..980f1bb 100644 --- a/bot.py +++ b/bot.py @@ -21,6 +21,7 @@ from datetime import timedelta import logging from queue import Queue import re +from select import select import threading import time import uuid @@ -38,7 +39,7 @@ import response logger = logging.getLogger("nemubot.bot") -class Bot: +class Bot(threading.Thread): """Class containing the bot context and ensuring key goals""" @@ -51,6 +52,8 @@ class Bot: data_path -- Path to directory where store bot context data """ + threading.Thread.__init__(self) + logger.info("Initiate nemubot v%s", __version__) # External IP for accessing this bot @@ -90,6 +93,22 @@ class Bot: self) + def run(self): + from server import _rlist, _wlist, _xlist + + self.stop = False + while not self.stop: + rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1) + + for x in xl: + x.exception() + for w in wl: + w.write_select() + for r in rl: + for i in r.read(): + self.receive_message(r, i) + + def init_ctcp_capabilities(self): """Reset existing CTCP capabilities to default one""" @@ -277,17 +296,16 @@ class Bot: c.start() - def add_server(self, node, nick, owner, realname, ssl=False): + def add_server(self, node, nick, owner, realname): """Add a new server to the context""" - srv = IRCServer(node, nick, owner, realname, ssl) + srv = IRCServer(node, nick, owner, realname) srv.add_hook = lambda h: self.hooks.add_hook("irc_hook", h, self) srv.add_networkbot = self.add_networkbot srv.send_bot = lambda d: self.send_networkbot(srv, d) - srv.register_hooks() + #srv.register_hooks() if srv.id not in self.servers: self.servers[srv.id] = srv - if srv.autoconnect: - srv.launch(self.receive_message) + srv.open() return True else: return False @@ -378,6 +396,9 @@ class Bot: for srv in k: self.servers[srv].disconnect() + self.stop = True + + # Hooks cache def create_cache(self, name): diff --git a/consumer.py b/consumer.py index 9ad8304..5d41891 100644 --- a/consumer.py +++ b/consumer.py @@ -46,9 +46,8 @@ class MessageConsumer: def treat_in(self, context, msg): """Treat the input message""" if msg.cmd == "PING": - self.srv.send_pong(msg.params[0]) + self.srv.write("%s :%s" % ("PONG", msg.params[0])) elif hasattr(msg, "receivers"): - msg.receivers = [ receiver for receiver in msg.receivers if self.srv.accepted_channel(receiver) ] if msg.receivers: # All messages context.treat_pre(msg, self.srv) @@ -62,21 +61,29 @@ class MessageConsumer: if r is not None: self.treat_out(context, r) elif isinstance(res, response.Response): - # Define the destination server - if (res.server is not None and - isinstance(res.server, str) and res.server in context.servers): - res.server = context.servers[res.server] - if (res.server is not None and - not isinstance(res.server, server.Server)): + # Define the destination server + to_server = None + if res.server is None: + to_server = self.srv + res.server = self.srv.id + elif isinstance(res.server, str) and res.server in context.servers: + to_server = context.servers[res.server] + + if to_server is None: logger.error("The server defined in this response doesn't " "exist: %s", res.server) - res.server = None - if res.server is None: - res.server = self.srv + return False # Sent the message only if treat_post authorize it if context.treat_post(res): - res.server.send_response(res, self.data) + if type(res.channel) != list: + res.channel = [ res.channel ] + for channel in res.channel: + if channel != to_server.nick: + to_server.write("%s %s :%s" % ("PRIVMSG", channel, res.get_message())) + else: + channel = res.sender + to_server.write("%s %s :%s" % ("NOTICE" if res.is_ctcp else "PRIVMSG", channel, res.get_message())) elif res is not None: logger.error("Unrecognized response type: %s", res) @@ -98,7 +105,7 @@ class MessageConsumer: self.treat_out(context, res) # Inform that the message has been treated - self.srv.msg_treated(self.data) + #self.srv.msg_treated(self.data) diff --git a/nemubot.py b/nemubot.py index 9d17ef8..f04899c 100755 --- a/nemubot.py +++ b/nemubot.py @@ -68,6 +68,7 @@ if __name__ == "__main__": print ("Nemubot v%s ready, my PID is %i!" % (bot.__version__, os.getpid())) + context.start() while prmpt.run(context): try: # Reload context diff --git a/prompt/builtins.py b/prompt/builtins.py index 7633e46..f23400d 100644 --- a/prompt/builtins.py +++ b/prompt/builtins.py @@ -69,8 +69,7 @@ def load_file(filename, context): nick = server["nick"] if server.hasAttribute("nick") else config["nick"] owner = server["owner"] if server.hasAttribute("owner") else config["owner"] realname = server["realname"] if server.hasAttribute("realname") else config["realname"] - if context.add_server(server, nick, owner, realname, - server.hasAttribute("ssl")): + if context.add_server(server, nick, owner, realname): print("Server `%s:%s' successfully added." % (server["server"], server["port"])) else: diff --git a/response.py b/response.py index 4533c86..781c4d8 100644 --- a/response.py +++ b/response.py @@ -42,6 +42,13 @@ class Response: self.set_sender(sender) self.count = count + @property + def receivers(self): + if type(self.channel) is list: + return self.channel + else: + return [ self.channel ] + @property def content(self): #FIXME: error when messages in self.messages are list! diff --git a/server/IRC.py b/server/IRC.py index 16bcabd..f05c2c8 100644 --- a/server/IRC.py +++ b/server/IRC.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -# Nemubot is a modulable IRC bot, built around XML configuration files. -# Copyright (C) 2012 Mercier Pierre-Olivier +# Nemubot is a smart and modulable IM bot. +# Copyright (C) 2012-2014 nemunaire # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -16,272 +16,32 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import errno -import os -import ssl -import socket -import threading -import traceback - -from channel import Channel -from server.DCC import DCC -from hooks import Hook -import message import server -import xmlparser +from server.socket import SocketServer -class IRCServer(server.Server): - """Class to interact with an IRC server""" +class IRCServer(SocketServer): - def __init__(self, node, nick, owner, realname, ssl=False): - """Initialize an IRC server - - Arguments: - node -- server node from XML configuration - nick -- nick used by the bot on this server - owner -- nick used by the bot owner on this server - realname -- string used as realname on this server - ssl -- require SSL? - """ - self.node = node - - server.Server.__init__(self) + def __init__(self, node, nick, owner, realname): + SocketServer.__init__(self, + node["host"], + node["port"], + node["password"], + node.hasAttribute("ssl") and node["ssl"].lower() == "true") self.nick = nick self.owner = owner self.realname = realname - self.ssl = ssl + self.id = "TODO" - # Listen private messages? - self.listen_nick = True - - self.dcc_clients = dict() - - self.channels = dict() - for chn in self.node.getNodes("channel"): - chan = Channel(chn["name"], chn["password"]) - self.channels[chan.name] = chan - - - @property - def host(self): - """Return the server hostname""" - if self.node is not None and self.node.hasAttribute("server"): - return self.node["server"] - else: - return "localhost" - - @property - def port(self): - """Return the connection port used on this server""" - if self.node is not None and self.node.hasAttribute("port"): - return self.node.getInt("port") - else: - return "6667" - - @property - def password(self): - """Return the password used to connect to this server""" - if self.node is not None and self.node.hasAttribute("password"): - return self.node["password"] - else: - return None - - @property - def allow_all(self): - """If True, treat message from all channels, not only listed one""" - return (self.node is not None and self.node.hasAttribute("allowall") - and self.node["allowall"] == "true") - - @property - def autoconnect(self): - """Autoconnect the server when added""" - if self.node is not None and self.node.hasAttribute("autoconnect"): - value = self.node["autoconnect"].lower() - return value != "no" and value != "off" and value != "false" - else: - return False - - @property - def id(self): - """Gives the server identifiant""" - return self.host + ":" + str(self.port) - - def register_hooks(self): - self.add_hook(Hook(self.evt_channel, "JOIN")) - self.add_hook(Hook(self.evt_channel, "PART")) - self.add_hook(Hook(self.evt_server, "NICK")) - self.add_hook(Hook(self.evt_server, "QUIT")) - self.add_hook(Hook(self.evt_channel, "332")) - self.add_hook(Hook(self.evt_channel, "353")) - - def evt_server(self, msg, srv): - for chan in self.channels: - self.channels[chan].treat(msg.cmd, msg) - - def evt_channel(self, msg, srv): - if msg.receivers is not None: - for receiver in msg.receivers: - if receiver in self.channels: - self.channels[receiver].treat(msg.cmd, msg) - - def accepted_channel(self, chan, sender=None): - """Return True if the channel (or the user) is authorized""" - return (self.allow_all or - (chan in self.channels and (sender is None or sender in self.channels[chan].people)) or - (self.listen_nick and chan == self.nick)) - - def join(self, chan, password=None, force=False): - """Join a channel""" - if force or (chan is not None and - self.connected and chan not in self.channels): - self.channels[chan] = Channel(chan, password) - if password is not None: - self.s.send(("JOIN %s %s\r\n" % (chan, password)).encode()) - else: - self.s.send(("JOIN %s\r\n" % chan).encode()) + def _open(self): + if SocketServer._open(self): + if self.password is not None: + self.write("PASS :" + self.password) + self.write("NICK :" + self.nick) + self.write("USER %s %s bla :%s" % (self.nick, self.host, self.realname)) return True - else: - return False + return False - def leave(self, chan): - """Leave a channel""" - if chan is not None and self.connected and chan in self.channels: - if isinstance(chan, list): - for c in chan: - self.leave(c) - else: - self.s.send(("PART %s\r\n" % self.channels[chan].name).encode()) - del self.channels[chan] - return True - else: - return False - -# Main loop - def run(self): - if not self.connected: - self.s = socket.socket() #Create the socket - if self.ssl: - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - self.s = ctx.wrap_socket(self.s) - try: - self.s.connect((self.host, self.port)) #Connect to server - except socket.error as e: - self.s = None - self.logger.critical("Unable to connect to %s:%d: %s", - self.host, self.port, os.strerror(e.errno)) - return - self.stopping.clear() - - if self.password != None: - self.s.send(b"PASS " + self.password.encode () + b"\r\n") - self.s.send(("NICK %s\r\n" % self.nick).encode ()) - self.s.send(("USER %s %s bla :%s\r\n" % (self.nick, self.host, - self.realname)).encode()) - raw = self.s.recv(1024) - if not raw: - self.logger.critical("Unable to connect to %s:%d", self.host, self.port) - return - self.connected = True - self.logger.info("Connection to %s:%d completed", self.host, self.port) - - if len(self.channels) > 0: - for chn in self.channels.keys(): - self.join(self.channels[chn].name, - self.channels[chn].password, force=True) - - - readbuffer = b'' #Here we store all the messages from server - try: - while not self.stop: - readbuffer = readbuffer + raw - temp = readbuffer.split(b'\n') - readbuffer = temp.pop() - - for line in temp: - self.treat_msg(line) - raw = self.s.recv(1024) #recieve server messages - except socket.error: - pass - - if self.connected: - self.s.close() - self.connected = False - if self.closing_event is not None: - self.closing_event() - self.logger.info("Server `%s' successfully stopped.", self.id) - self.stopping.set() - # Rearm Thread - threading.Thread.__init__(self) - - -# Overwritted methods - - def disconnect(self): - """Close the socket with the server and all DCC client connections""" - #Close all DCC connection - clts = [c for c in self.dcc_clients] - for clt in clts: - self.dcc_clients[clt].disconnect() - return server.Server.disconnect(self) - - - -# Abstract methods - - def send_pong(self, cnt): - """Send a PONG command to the server with argument cnt""" - self.s.send(("PONG %s\r\n" % cnt).encode()) - - def msg_treated(self, origin): - """Do nothing; here for implement abstract class""" - pass - - def send_dcc(self, msg, to): - """Send a message through DCC connection""" - if msg is not None and to is not None: - realname = to.split("!")[1] - if realname not in self.dcc_clients.keys(): - d = DCC(self, to) - self.dcc_clients[realname] = d - self.dcc_clients[realname].send_dcc(msg) - - def send_msg_final(self, channel, line, cmd="PRIVMSG", endl="\r\n"): - """Send a message without checks or format""" - #TODO: add something for post message treatment here - if channel == self.nick: - self.logger.warn("Nemubot talks to himself: %s", line, stack_info=True) - if line is not None and channel is not None: - if self.s is None: - self.logger.warn("Attempt to send message on a non connected server: %s: %s", self.id, line, stack_info=True) - elif len(line) < 442: - self.s.send(("%s %s :%s%s" % (cmd, channel, line, endl)).encode ()) - else: - self.logger.warn("Message truncated due to size (%d ; max : 442) : %s", len(line), line, stack_info=True) - self.s.send (("%s %s :%s%s" % (cmd, channel, line[0:442]+"<…>", endl)).encode ()) - - def send_msg_usr(self, user, msg): - """Send a message to a user instead of a channel""" - if user is not None and user[0] != "#": - realname = user.split("!")[1] - if realname in self.dcc_clients or user in self.dcc_clients: - self.send_dcc(msg, user) - else: - for line in msg.split("\n"): - if line != "": - self.send_msg_final(user.split('!')[0], msg) - - def send_msg(self, channel, msg, cmd="PRIVMSG", endl="\r\n"): - """Send a message to a channel""" - if self.accepted_channel(channel): - server.Server.send_msg(self, channel, msg, cmd, endl) - - def send_msg_verified(self, sender, channel, msg, cmd = "PRIVMSG", endl = "\r\n"): - """Send a message to a channel, only if the source user is on this channel too""" - if self.accepted_channel(channel, sender): - self.send_msg_final(channel, msg, cmd, endl) - - def send_global(self, msg, cmd="PRIVMSG", endl="\r\n"): - """Send a message to all channels on this server""" - for channel in self.channels.keys(): - self.send_msg(channel, msg, cmd, endl) + def _close(self): + self.write("QUIT") + SocketServer._close(self) diff --git a/server/__init__.py b/server/__init__.py index 8fffe1d..9f423c2 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -# Nemubot is a modulable IRC bot, built around XML configuration files. -# Copyright (C) 2012 Mercier Pierre-Olivier +# Nemubot is a smart and modulable IM bot. +# Copyright (C) 2012-2014 nemunaire # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -16,156 +16,70 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import io import logging import socket -import threading +import queue -class Server(threading.Thread): - def __init__(self, socket = None): - self.stop = False - self.stopping = threading.Event() - self.s = socket - self.connected = self.s is not None - self.closing_event = None +# Lists for select +_rlist = [] +_wlist = [] +_xlist = [] - self.moremessages = dict() +# Extends from IOBase in order to be compatible with select function +class AbstractServer(io.IOBase): - self.logger = logging.getLogger("nemubot.server." + self.id) + """An abstract server: handle communication with an IM server""" - threading.Thread.__init__(self) + def __init__(self, send_callback=None): + """Initialize an abstract server - def isDCC(self, to=None): - return to is not None and to in self.dcc_clients + Keyword argument: + send_callback -- Callback when developper want to send a message + """ - @property - def ip(self): - """Convert common IP representation to little-endian integer representation""" - sum = 0 - if self.node.hasAttribute("ip"): - ip = self.node["ip"] + self.logger = logging.getLogger("nemubot.server.TODO") + self._sending_queue = queue.Queue() + if send_callback is not None: + self._send_callback = send_callback else: - #TODO: find the external IP - ip = "0.0.0.0" - for b in ip.split("."): - sum = 256 * sum + int(b) - return sum + self._send_callback = self._write_select - def toIP(self, input): - """Convert little-endian int to IPv4 adress""" - ip = "" - for i in range(0,4): - mod = input % 256 - ip = "%d.%s" % (mod, ip) - input = (input - mod) / 256 - return ip[:len(ip) - 1] - @property - def id(self): - """Gives the server identifiant""" - raise NotImplemented() + def open(self): + """Generic open function that register the server un _rlist in case of successful _open""" + if self._open(): + _rlist.append(self) - def accepted_channel(self, msg, sender=None): - return True - def msg_treated(self, origin): - """Action done on server when a message was treated""" - raise NotImplemented() + def close(self): + """Generic close function that register the server un _{r,w,x}list in case of successful _close""" + if self._close(): + if self in _rlist: + _rlist.remove(self) + if self in _wlist: + _wlist.remove(self) + if self in _xlist: + _xlist.remove(self) - def send_response(self, res, origin): - """Analyse a Response and send it""" - if type(res.channel) != list: - res.channel = [ res.channel ] - for channel in res.channel: - if channel != self.nick: - self.send_msg(channel, res.get_message()) - else: - channel = res.sender - self.send_msg_usr(channel, res.get_message(), "NOTICE" if res.is_ctcp else "PRIVMSG") + def write(self, message): + """Send a message to the server using send_callback""" + self._send_callback(message) - if not res.alone: - if hasattr(self, "send_bot"): - self.send_bot("NOMORE %s" % res.channel) - self.moremessages[channel] = res + def write_select(self): + """Internal function used by the select function""" + try: + while not self._sending_queue.empty(): + self._write(self._sending_queue.get_nowait()) + _wlist.remove(self) - def send_ctcp(self, to, msg, cmd="NOTICE", endl="\r\n"): - """Send a message as CTCP response""" - if msg is not None and to is not None: - for line in msg.split("\n"): - if line != "": - self.send_msg_final(to.split("!")[0], "\x01" + line + "\x01", cmd, endl) + except queue.Empty: + pass - def send_dcc(self, msg, to): - """Send a message through DCC connection""" - raise NotImplemented() - - def send_msg_final(self, channel, msg, cmd="PRIVMSG", endl="\r\n"): - """Send a message without checks or format""" - raise NotImplemented() - - def send_msg_usr(self, user, msg): - """Send a message to a user instead of a channel""" - raise NotImplemented() - - def send_msg(self, channel, msg, cmd="PRIVMSG", endl="\r\n"): - """Send a message to a channel""" - if msg is not None: - for line in msg.split("\n"): - if line != "": - self.send_msg_final(channel, line, cmd, endl) - - def send_msg_verified(self, sender, channel, msg, cmd="PRIVMSG", endl="\r\n"): - """A more secure way to send messages""" - raise NotImplemented() - - def send_global(self, msg, cmd="PRIVMSG", endl="\r\n"): - """Send a message to all channels on this server""" - raise NotImplemented() - - def disconnect(self): - """Close the socket with the server""" - if self.connected: - self.stop = True - try: - self.s.shutdown(socket.SHUT_RDWR) - except socket.error: - pass - - self.stopping.wait() - return True - else: - return False - - def kill(self): - """Just stop the main loop, don't close the socket directly""" - if self.connected: - self.stop = True - self.connected = False - #Send a message in order to close the socket - try: - self.s.send(("Bye!\r\n").encode ()) - except: - pass - self.stopping.wait() - return True - else: - return False - - def launch(self, receive_action, verb=True): - """Connect to the server if it is no yet connected""" - self._receive_action = receive_action - if not self.connected: - self.stop = False - self.logger.info("Entering main loop for server") - try: - self.start() - except RuntimeError: - pass - elif verb: - print (" Already connected.") - - def treat_msg(self, line, private=False): - self._receive_action(self, line, private) - - def run(self): - raise NotImplemented() + def _write_select(self, message): + """Send a message to the server safely through select""" + self._sending_queue.put(self.format(message)) + self.logger.debug("Message '%s' appended to Queue", message) + if self not in _wlist: + _wlist.append(self) diff --git a/server/socket.py b/server/socket.py new file mode 100644 index 0000000..21d3d0c --- /dev/null +++ b/server/socket.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +# Nemubot is a smart and modulable IM bot. +# Copyright (C) 2012-2014 nemunaire +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import ssl +import socket + +from server import AbstractServer + +class SocketServer(AbstractServer): + + def __init__(self, host, port=6667, password=None, ssl=False): + AbstractServer.__init__(self) + self.host = host + self.port = int(port) + self.password = password + self.ssl = ssl + + self.socket = None + self.readbuffer = b'' + + def fileno(self): + return self.socket.fileno() if self.socket else None + + def _open(self): + # Create the socket + self.socket = socket.socket() + + # Wrap the socket for SSL + if self.ssl: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.socket = ctx.wrap_socket(self.socket) + + try: + self.socket.connect((self.host, self.port)) #Connect to server + self.logger.info("Connected to %s:%d", self.host, self.port) + except socket.error as e: + self.socket = None + self.logger.critical("Unable to connect to %s:%d: %s", + self.host, self.port, os.strerror(e.errno)) + return False + + return True + + def _close(self): + if self.socket is not None: + self.socket.shutdown(SHUT_RDWR) + self.socket.close() + self.socket = None + + def _write(self, cnt): + self.socket.send(cnt) + + def format(self, txt): + return txt.encode() + b'\r\n' + + def read(self): + raw = self.socket.recv(1024) + temp = (self.readbuffer + raw).split(b'\r\n') + self.readbuffer = temp.pop() + + for line in temp: + yield line