diff --git a/imodule.py b/imodule.py index 00e7bb9..499846a 100644 --- a/imodule.py +++ b/imodule.py @@ -38,6 +38,9 @@ class ModuleBase: def load(self): return True + def reload(self): + return self.save() and self.load() + def save(self): return True @@ -64,3 +67,6 @@ class ModuleBase: def parselisten(self, msg): return False + + def parseadmin(self, msg): + return False diff --git a/message.py b/message.py index 76cdeb7..bb5de34 100644 --- a/message.py +++ b/message.py @@ -16,7 +16,7 @@ BANLIST = [] CREDITS = {} filename = "" -def load_module(datas_path): +def load(datas_path): global BANLIST, CREDITS, filename CREDITS = dict () @@ -341,21 +341,17 @@ class Message: self.send_snd("No help for command %s" % self.cmd[1]) else: self.send_snd("Pour me demander quelque chose, commencez votre message par mon nom ; je réagis à certain messages commençant par !, consulter l'aide de chaque module :") - for im in mods.keys(): - self.send_snd(" - !help %s: %s" % (im, mods[im].help_tiny ())) + for im in mods: + self.send_snd(" - !help %s: %s" % (im.name, im.help_tiny ())) else: - for im in mods.keys(): - if im == "alias": - continue - if mods[im].parseanswer(self): + for im in mods: + if im.parseanswer(self): return - if mods["alias"].parseanswer(self): - return else: - for im in mods.keys(): - if mods[im].parselisten(self): + for im in mods: + if im.parselisten(self): return diff --git a/nemubot.py b/nemubot.py index fb7d8bf..603a8fb 100755 --- a/nemubot.py +++ b/nemubot.py @@ -5,24 +5,22 @@ import sys import signal import os import re +import imp from datetime import date from datetime import datetime from datetime import timedelta from xml.dom.minidom import parse -imports = ["birthday", "qd", "events", "youtube", "watchWebsite", "soutenance", "whereis", "alias"] -imports_launch = ["watchWebsite", "events"] -mods = {} -import server, message - if len(sys.argv) != 2 and len(sys.argv) != 3: print ("This script takes exactly 1 arg: a XML config file") sys.exit(1) +imports = ["birthday", "qd", "events", "youtube", "watchWebsite", "soutenance", "whereis", "alias"] +imports_launch = ["watchWebsite", "events"] +mods = {} -def onSignal(signum, frame): - print ("\nSIGINT receive, saving states and close") - +def onClose(): + """Call when the bot quit; saving all modules""" for imp in mods.keys(): mods[imp].save_module () @@ -33,36 +31,43 @@ def onSignal(signum, frame): message.save_module () sys.exit (0) + +def onSignal(signum, frame): + print ("\nSIGINT receive, saving states and close") + onClose() signal.signal(signal.SIGINT, onSignal) +#Define working directory if len(sys.argv) == 3: basedir = sys.argv[2] else: basedir = "./" +#Load base modules +server = __import__ ("server") +message = __import__ ("message") +message.load (basedir + "/datas/") + +#Read configuration XML file dom = parse(sys.argv[1]) config = dom.getElementsByTagName('config')[0] servers = dict () -message.load_module (basedir + "/datas/") - +#Load modules for imp in imports: mod = __import__ (imp) mods[imp] = mod - mod.load_module (basedir + "/datas/") for serveur in config.getElementsByTagName('server'): srv = server.Server(serveur, config.getAttribute('nick'), config.getAttribute('owner'), config.getAttribute('realname')) srv.launch(mods, basedir + "/datas/") servers[srv.id] = srv -for imp in imports_launch: - mod = __import__ (imp) - mod.launch (servers) print ("Nemubot ready, my PID is %i!" % (os.getpid())) -prompt="" -while prompt != "quit": - prompt=sys.stdin.readlines () -sys.exit(0) +prompt = __import__ ("prompt") +while prompt.launch(servers): + imp.reload(prompt) + +onClose() diff --git a/prompt.py b/prompt.py new file mode 100644 index 0000000..fb0aeb9 --- /dev/null +++ b/prompt.py @@ -0,0 +1,149 @@ +import sys +import shlex +import traceback +import _thread +from xml.dom.minidom import parse + +import server + +selectedServer = None +MODS = list() + +def parsecmd(msg): + """Parse the command line""" + try: + cmds = shlex.split(msg) + if len(cmds) > 0: + cmds[0] = cmds[0].lower() + return cmds + except: + exc_type, exc_value, exc_traceback = sys.exc_info() + sys.stdout.write (traceback.format_exception_only(exc_type, exc_value)[0]) + return None + +def run(cmds, servers): + """Launch the command""" + if cmds[0] in CAPS: + return CAPS[cmds[0]](cmds, servers) + else: + print ("Unknown command: `%s'" % cmds[0]) + return "" + +def getPS1(): + """Get the PS1 associated to the selected server""" + if selectedServer is None: + return "nemubot" + else: + return selectedServer.id + +def launch(servers): + """Launch the prompt""" + ret = "" + cmds = list() + while ret != "quit" and ret != "reset": + sys.stdout.write("\033[0;33m%s§\033[0m " % getPS1()) + sys.stdout.flush() + try: + cmds = parsecmd(sys.stdin.readline().strip()) + except KeyboardInterrupt: + cmds = parsecmd("quit") + if cmds is not None and len(cmds) > 0: + try: + ret = run(cmds, servers) + except: + exc_type, exc_value, exc_traceback = sys.exc_info() + sys.stdout.write (traceback.format_exception_only(exc_type, exc_value)[0]) + return ret == "reset" + + +########################## +# # +# Permorming functions # +# # +########################## + +def load(cmds, servers): + if len(cmds) > 1: + for f in cmds[1:]: + dom = parse(f) + config = dom.getElementsByTagName('config')[0] + for serveur in config.getElementsByTagName('server'): + srv = server.Server(serveur, config.getAttribute('nick'), config.getAttribute('owner'), config.getAttribute('realname')) + if srv.id not in servers: + servers[srv.id] = srv + print (" Server `%s' successfully added." % srv.id) + else: + print (" Server `%s' already added, skiped." % srv.id) + else: + print ("Not enough arguments. `load' takes an filename.") + return + +def select(cmds, servers): + global selectedServer + if len(cmds) == 2 and cmds[1] != "None" and cmds[1] != "nemubot" and cmds[1] != "none": + if cmds[1] in servers: + selectedServer = servers[cmds[1]] + else: + print ("select: server `%s' not found." % cmds[1]) + else: + selectedServer = None + return + +def liste(cmds, servers): + if len(cmds) > 1: + for l in cmds[1:]: + l = l.lower() + if l == "server" or l == "servers": + for srv in servers.keys(): + print (" - %s ;" % srv) + else: + print (" Unknown list `%s'" % l) + else: + print (" Please give a list to show: servers, ...") + +def connect(cmds, servers): + if len(cmds) > 1: + for s in cmds[1:]: + if s in servers: + servers[s].launch(MODS) + else: + print ("connect: server `%s' not found." % s) + + elif selectedServer is not None: + selectedServer.launch(MODS) + else: + print (" Please SELECT a server or give its name in argument.") + +def disconnect(cmds, servers): + if len(cmds) > 1: + for s in cmds[1:]: + if s in servers: + if not servers[s].disconnect(): + print ("disconnect: server `%s' already disconnected." % s) + else: + print ("disconnect: server `%s' not found." % s) + elif selectedServer is not None: + if not selectedServer.disconnect(): + print ("disconnect: server `%s' already disconnected." % selectedServer.id) + else: + print (" Please SELECT a server or give its name in argument.") + +def end(cmds, servers): + if cmds[0] == "reset": + return "reset" + else: + for srv in servers.keys(): + servers[srv].disconnect() + return "quit" + +#Register build-ins +CAPS = { + 'quit': end, + 'exit': end, + 'reset': end, + 'load': load, + 'select': select, + 'list': liste, + 'connect': connect, + 'disconnect': disconnect, +} diff --git a/server.py b/server.py index c1ede7d..7ef8660 100644 --- a/server.py +++ b/server.py @@ -1,18 +1,16 @@ import sys import traceback import socket -import _thread +import threading import time import message -#class WaitedAnswer: -# def __init__(self, channel): -# self.channel = channel -# self.module - -class Server: +class Server(threading.Thread): def __init__(self, server, nick, owner, realname): + self.stop = False + self.stopping = threading.Event() + self.connected = False self.nick = nick self.owner = owner self.realname = realname @@ -30,8 +28,6 @@ class Server: else: self.password = None - self.waited_answer = list() - self.listen_nick = True self.partner = "nbr23" @@ -39,6 +35,8 @@ class Server: for channel in server.getElementsByTagName('channel'): self.channels.append(channel.getAttribute("name")) + threading.Thread.__init__(self) + @property def id(self): return self.host + ":" + str(self.port) @@ -65,43 +63,35 @@ class Server: self.send_msg (channel, msg, cmd, endl) - def register_answer(self, channel, ): - self.waited_answer.append(channel) - - def launch(self, mods, datas_dir): - self.datas_dir = datas_dir - _thread.start_new_thread(self.connect, (mods,)) - def accepted_channel(self, channel): if self.listen_nick: return self.channels.count(channel) or channel == self.nick else: return self.channels.count(channel) - def read(self, mods): - self.readbuffer = "" #Here we store all the messages from server - while 1: - try: - self.readbuffer = self.readbuffer + self.s.recv(1024).decode() #recieve server messages - except UnicodeDecodeError: - print ("ERREUR de décodage unicode") - continue - temp = self.readbuffer.split("\n") - self.readbuffer = temp.pop( ) + def disconnect(self): + if self.connected: + self.stop = True + self.s.shutdown(socket.SHUT_RDWR) + self.stopping.wait() + return True + else: + return False - for line in temp: - msg = message.Message (self, line) - try: - msg.treat (mods) - except: - print ("Une erreur est survenue lors du traitement du message : %s"%line) - exc_type, exc_value, exc_traceback = sys.exc_info() - traceback.print_exception(exc_type, exc_value, exc_traceback) + def launch(self, mods): + if not self.connected: + self.stop = False + #self.datas_dir = datas_dir #DEPRECATED + self.mods = mods + self.start() + else: + print (" Already connected.") - - def connect(self, mods): + def run(self): self.s = socket.socket( ) #Create the socket self.s.connect((self.host, self.port)) #Connect to server + self.stopping.clear() + self.connected = True if self.password != None: self.s.send(b"PASS " + self.password.encode () + b"\r\n") @@ -112,4 +102,26 @@ class Server: self.s.send(("JOIN %s\r\n" % ' '.join (self.channels)).encode ()) print ("Listen to channels: %s" % ' '.join (self.channels)) - self.read(mods) + readbuffer = "" #Here we store all the messages from server + while not self.stop: + try: + readbuffer = readbuffer + self.s.recv(1024).decode() #recieve server messages + except UnicodeDecodeError: + print ("ERREUR de décodage unicode") + continue + temp = readbuffer.split("\n") + readbuffer = temp.pop( ) + + for line in temp: + msg = message.Message (self, line) + try: + msg.treat (self.mods) + except: + print ("Une erreur est survenue lors du traitement du message : %s"%line) + exc_type, exc_value, exc_traceback = sys.exc_info() + traceback.print_exception(exc_type, exc_value, exc_traceback) + self.connected = False + print ("Server `%s' successfully stopped." % self.id) + self.stopping.set() + #Rearm Thread + threading.Thread.__init__(self)