diff --git a/modules/cmd_server.py b/modules/cmd_server.py deleted file mode 100644 index 6580c18..0000000 --- a/modules/cmd_server.py +++ /dev/null @@ -1,202 +0,0 @@ -# Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . - -import traceback -import sys - -from nemubot.hooks import hook - -nemubotversion = 3.4 -NODATA = True - - -def getserver(toks, context, prompt, mandatory=False, **kwargs): - """Choose the server in toks or prompt. - This function modify the tokens list passed as argument""" - - if len(toks) > 1 and toks[1] in context.servers: - return context.servers[toks.pop(1)] - elif not mandatory or prompt.selectedServer: - return prompt.selectedServer - else: - from nemubot.prompt.error import PromptError - raise PromptError("Please SELECT a server or give its name in argument.") - - -@hook("prompt_cmd", "close") -def close(toks, context, **kwargs): - """Disconnect and forget (remove from the servers list) the server""" - srv = getserver(toks, context=context, mandatory=True, **kwargs) - - if srv.close(): - del context.servers[srv.id] - return 0 - return 1 - - -@hook("prompt_cmd", "connect") -def connect(toks, **kwargs): - """Make the connexion to a server""" - srv = getserver(toks, mandatory=True, **kwargs) - - return not srv.open() - - -@hook("prompt_cmd", "disconnect") -def disconnect(toks, **kwargs): - """Close the connection to a server""" - srv = getserver(toks, mandatory=True, **kwargs) - - return not srv.close() - - -@hook("prompt_cmd", "discover") -def discover(toks, context, **kwargs): - """Discover a new bot on a server""" - srv = getserver(toks, context=context, mandatory=True, **kwargs) - - if len(toks) > 1 and "!" in toks[1]: - bot = context.add_networkbot(srv, name) - return not bot.connect() - else: - print(" %s is not a valid fullname, for example: " - "nemubot!nemubotV3@bot.nemunai.re" % ''.join(toks[1:1])) - return 1 - - -@hook("prompt_cmd", "join") -@hook("prompt_cmd", "leave") -@hook("prompt_cmd", "part") -def join(toks, **kwargs): - """Join or leave a channel""" - srv = getserver(toks, mandatory=True, **kwargs) - - if len(toks) <= 2: - print("%s: not enough arguments." % toks[0]) - return 1 - - if toks[0] == "join": - if len(toks) > 2: - srv.write("JOIN %s %s" % (toks[1], toks[2])) - else: - srv.write("JOIN %s" % toks[1]) - - elif toks[0] == "leave" or toks[0] == "part": - if len(toks) > 2: - srv.write("PART %s :%s" % (toks[1], " ".join(toks[2:]))) - else: - srv.write("PART %s" % toks[1]) - - return 0 - - -@hook("prompt_cmd", "save") -def save_mod(toks, context, **kwargs): - """Force save module data""" - if len(toks) < 2: - print("save: not enough arguments.") - return 1 - - wrn = 0 - for mod in toks[1:]: - if mod in context.modules: - context.modules[mod].save() - print("save: module `%s´ saved successfully" % mod) - else: - wrn += 1 - print("save: no module named `%s´" % mod) - return wrn - - -@hook("prompt_cmd", "send") -def send(toks, **kwargs): - """Send a message on a channel""" - srv = getserver(toks, mandatory=True, **kwargs) - - # Check the server is connected - if not srv.connected: - print ("send: server `%s' not connected." % srv.id) - return 2 - - if len(toks) <= 3: - print ("send: not enough arguments.") - return 1 - - if toks[1] not in srv.channels: - print ("send: channel `%s' not authorized in server `%s'." - % (toks[1], srv.id)) - return 3 - - from nemubot.message import Text - srv.send_response(Text(" ".join(toks[2:]), server=None, - to=[toks[1]])) - return 0 - - -@hook("prompt_cmd", "zap") -def zap(toks, **kwargs): - """Hard change connexion state""" - srv = getserver(toks, mandatory=True, **kwargs) - - srv.connected = not srv.connected - - -@hook("prompt_cmd", "top") -def top(toks, context, **kwargs): - """Display consumers load information""" - print("Queue size: %d, %d thread(s) running (counter: %d)" % - (context.cnsr_queue.qsize(), - len(context.cnsr_thrd), - context.cnsr_thrd_size)) - if len(context.events) > 0: - print("Events registered: %d, next in %d seconds" % - (len(context.events), - context.events[0].time_left.seconds)) - else: - print("No events registered") - - for th in context.cnsr_thrd: - if th.is_alive(): - print(("#" * 15 + " Stack trace for thread %u " + "#" * 15) % - th.ident) - traceback.print_stack(sys._current_frames()[th.ident]) - - -@hook("prompt_cmd", "netstat") -def netstat(toks, context, **kwargs): - """Display sockets in use and many other things""" - if len(context.network) > 0: - print("Distant bots connected: %d:" % len(context.network)) - for name, bot in context.network.items(): - print("# %s:" % name) - print(" * Declared hooks:") - lvl = 0 - for hlvl in bot.hooks: - lvl += 1 - for hook in (hlvl.all_pre + hlvl.all_post + hlvl.cmd_rgxp + - hlvl.cmd_default + hlvl.ask_rgxp + - hlvl.ask_default + hlvl.msg_rgxp + - hlvl.msg_default): - print(" %s- %s" % (' ' * lvl * 2, hook)) - for kind in ["irc_hook", "cmd_hook", "ask_hook", "msg_hook"]: - print(" %s- <%s> %s" % (' ' * lvl * 2, kind, - ", ".join(hlvl.__dict__[kind].keys()))) - print(" * My tag: %d" % bot.my_tag) - print(" * Tags in use (%d):" % bot.inc_tag) - for tag, (cmd, data) in bot.tags.items(): - print(" - %11s: %s « %s »" % (tag, cmd, data)) - else: - print("No distant bot connected") diff --git a/modules/rnd.py b/modules/rnd.py index 32c2adf..5329b06 100644 --- a/modules/rnd.py +++ b/modules/rnd.py @@ -8,7 +8,6 @@ import shlex from nemubot import context from nemubot.exception import IMException from nemubot.hooks import hook -from nemubot.message import Command from more import Response @@ -32,8 +31,24 @@ def cmd_choicecmd(msg): choice = shlex.split(random.choice(msg.args)) - return [x for x in context.subtreat(Command(choice[0][1:], - choice[1:], - to_response=msg.to_response, - frm=msg.frm, - server=msg.server))] + return [x for x in context.subtreat(context.subparse(msg, choice))] + + +@hook.command("choiceres") +def cmd_choiceres(msg): + if not len(msg.args): + raise IMException("indicate some command to pick a message from!") + + rl = [x for x in context.subtreat(context.subparse(msg, " ".join(msg.args)))] + if len(rl) <= 0: + return rl + + r = random.choice(rl) + + if isinstance(r, Response): + for i in range(len(r.messages) - 1, -1, -1): + if isinstance(r.messages[i], list): + r.messages = [ random.choice(random.choice(r.messages)) ] + elif isinstance(r.messages[i], str): + r.messages = [ random.choice(r.messages) ] + return r diff --git a/nemubot/__init__.py b/nemubot/__init__.py index d0a2072..cdc6265 100644 --- a/nemubot/__init__.py +++ b/nemubot/__init__.py @@ -17,11 +17,15 @@ __version__ = '4.0.dev3' __author__ = 'nemunaire' +from typing import Optional + from nemubot.modulecontext import ModuleContext + context = ModuleContext(None, None) -def requires_version(min=None, max=None): +def requires_version(min: Optional[int] = None, + max: Optional[int] = None) -> None: """Raise ImportError if the current version is not in the given range Keyword arguments: @@ -38,62 +42,98 @@ def requires_version(min=None, max=None): "but this is nemubot v%s." % (str(max), __version__)) -def reload(): - """Reload code of all Python modules used by nemubot +def attach(pid: int, socketfile: str) -> int: + import socket + import sys + + print("nemubot is already launched with PID %d. Attaching to Unix socket at: %s" % (pid, socketfile)) + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.connect(socketfile) + except socket.error as e: + sys.stderr.write(str(e)) + sys.stderr.write("\n") + return 1 + + from select import select + try: + print("Connection established.") + while True: + rl, wl, xl = select([sys.stdin, sock], [], []) + + if sys.stdin in rl: + line = sys.stdin.readline().strip() + if line == "exit" or line == "quit": + return 0 + elif line == "reload": + import os, signal + os.kill(pid, signal.SIGHUP) + print("Reload signal sent. Please wait...") + + elif line == "shutdown": + import os, signal + os.kill(pid, signal.SIGTERM) + print("Shutdown signal sent. Please wait...") + + elif line == "kill": + import os, signal + os.kill(pid, signal.SIGKILL) + print("Signal sent...") + return 0 + + elif line == "stack" or line == "stacks": + import os, signal + os.kill(pid, signal.SIGUSR1) + print("Debug signal sent. Consult logs.") + + else: + sock.send(line.encode() + b'\r\n') + + if sock in rl: + sys.stdout.write(sock.recv(2048).decode()) + except KeyboardInterrupt: + pass + except: + return 1 + finally: + sock.close() + return 0 + + +def daemonize() -> None: + """Detach the running process to run as a daemon """ - import imp + import os + import sys - import nemubot.channel - imp.reload(nemubot.channel) + try: + pid = os.fork() + if pid > 0: + sys.exit(0) + except OSError as err: + sys.stderr.write("Unable to fork: %s\n" % err) + sys.exit(1) - import nemubot.config - imp.reload(nemubot.config) + os.setsid() + os.umask(0) + os.chdir('/') - nemubot.config.reload() + try: + pid = os.fork() + if pid > 0: + sys.exit(0) + except OSError as err: + sys.stderr.write("Unable to fork: %s\n" % err) + sys.exit(1) - import nemubot.consumer - imp.reload(nemubot.consumer) + sys.stdout.flush() + sys.stderr.flush() + si = open(os.devnull, 'r') + so = open(os.devnull, 'a+') + se = open(os.devnull, 'a+') - import nemubot.datastore - imp.reload(nemubot.datastore) - - nemubot.datastore.reload() - - import nemubot.event - imp.reload(nemubot.event) - - import nemubot.exception - imp.reload(nemubot.exception) - - nemubot.exception.reload() - - import nemubot.hooks - imp.reload(nemubot.hooks) - - nemubot.hooks.reload() - - import nemubot.importer - imp.reload(nemubot.importer) - - import nemubot.message - imp.reload(nemubot.message) - - nemubot.message.reload() - - import nemubot.prompt - imp.reload(nemubot.prompt) - - nemubot.prompt.reload() - - import nemubot.server - rl, wl, xl = nemubot.server._rlist, nemubot.server._wlist, nemubot.server._xlist - imp.reload(nemubot.server) - nemubot.server._rlist, nemubot.server._wlist, nemubot.server._xlist = rl, wl, xl - - nemubot.server.reload() - - import nemubot.tools - imp.reload(nemubot.tools) - - nemubot.tools.reload() + os.dup2(si.fileno(), sys.stdin.fileno()) + os.dup2(so.fileno(), sys.stdout.fileno()) + os.dup2(se.fileno(), sys.stderr.fileno()) diff --git a/nemubot/__main__.py b/nemubot/__main__.py index 1809bee..cb9fae6 100644 --- a/nemubot/__main__.py +++ b/nemubot/__main__.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,8 +14,9 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -def main(): +def main() -> None: import os + import signal import sys # Parse command line arguments @@ -36,6 +37,15 @@ def main(): default=["./modules/"], help="directory to use as modules store") + parser.add_argument("-d", "--debug", action="store_true", + help="don't deamonize, keep in foreground") + + parser.add_argument("-P", "--pidfile", default="./nemubot.pid", + help="Path to the file where store PID") + + parser.add_argument("-S", "--socketfile", default="./nemubot.sock", + help="path where open the socket for internal communication") + parser.add_argument("-l", "--logfile", default="./nemubot.log", help="Path to store logs") @@ -58,10 +68,34 @@ def main(): # Resolve relatives paths args.data_path = os.path.abspath(os.path.expanduser(args.data_path)) + args.pidfile = os.path.abspath(os.path.expanduser(args.pidfile)) + args.socketfile = os.path.abspath(os.path.expanduser(args.socketfile)) args.logfile = os.path.abspath(os.path.expanduser(args.logfile)) args.files = [ x for x in map(os.path.abspath, args.files)] args.modules_path = [ x for x in map(os.path.abspath, args.modules_path)] + # Check if an instance is already launched + if args.pidfile is not None and os.path.isfile(args.pidfile): + with open(args.pidfile, "r") as f: + pid = int(f.readline()) + try: + os.kill(pid, 0) + except OSError: + pass + else: + from nemubot import attach + sys.exit(attach(pid, args.socketfile)) + + # Daemonize + if not args.debug: + from nemubot import daemonize + daemonize() + + # Store PID to pidfile + if args.pidfile is not None: + with open(args.pidfile, "w+") as f: + f.write(str(os.getpid())) + # Setup loggin interface import logging logger = logging.getLogger("nemubot") @@ -70,11 +104,12 @@ def main(): formatter = logging.Formatter( '%(asctime)s %(name)s %(levelname)s %(message)s') - ch = logging.StreamHandler() - ch.setFormatter(formatter) - if args.verbose < 2: - ch.setLevel(logging.INFO) - logger.addHandler(ch) + if args.debug: + ch = logging.StreamHandler() + ch.setFormatter(formatter) + if args.verbose < 2: + ch.setLevel(logging.INFO) + logger.addHandler(ch) fh = logging.FileHandler(args.logfile) fh.setFormatter(formatter) @@ -98,13 +133,10 @@ def main(): if args.no_connect: context.noautoconnect = True - # Load the prompt - import nemubot.prompt - prmpt = nemubot.prompt.Prompt() - # Register the hook for futur import from nemubot.importer import ModuleFinder - sys.meta_path.append(ModuleFinder(context.modules_paths, context.add_module)) + module_finder = ModuleFinder(context.modules_paths, context.add_module) + sys.meta_path.append(module_finder) # Load requested configuration files for path in args.files: @@ -117,36 +149,57 @@ def main(): for module in args.module: __import__(module) - print ("Nemubot v%s ready, my PID is %i!" % (nemubot.__version__, - os.getpid())) - while True: - from nemubot.prompt.reset import PromptReset - try: - context.start() - if prmpt.run(context): - break - except PromptReset as e: - if e.type == "quit": - break + # Signals handling + def sigtermhandler(signum, frame): + """On SIGTERM and SIGINT, quit nicely""" + sigusr1handler(signum, frame) + context.quit() + signal.signal(signal.SIGINT, sigtermhandler) + signal.signal(signal.SIGTERM, sigtermhandler) - try: - import imp - # Reload all other modules - imp.reload(nemubot) - imp.reload(nemubot.prompt) - nemubot.reload() - import nemubot.bot - context = nemubot.bot.hotswap(context) - prmpt = nemubot.prompt.hotswap(prmpt) - print("\033[1;32mContext reloaded\033[0m, now in Nemubot %s" % - nemubot.__version__) - except: - logger.exception("\033[1;31mUnable to reload the prompt due to " - "errors.\033[0m Fix them before trying to reload " - "the prompt.") + def sighuphandler(signum, frame): + """On SIGHUP, perform a deep reload""" + nonlocal context - context.quit() - print("Waiting for other threads shuts down...") + logger.debug("SIGHUP receive, iniate reload procedure...") + + # Reload configuration file + for path in args.files: + if os.path.isfile(path): + context.sync_queue.put_nowait(["loadconf", path]) + signal.signal(signal.SIGHUP, sighuphandler) + + def sigusr1handler(signum, frame): + """On SIGHUSR1, display stacktraces""" + import threading, traceback + for threadId, stack in sys._current_frames().items(): + thName = "#%d" % threadId + for th in threading.enumerate(): + if th.ident == threadId: + thName = th.name + break + logger.debug("########### Thread %s:\n%s", + thName, + "".join(traceback.format_stack(stack))) + signal.signal(signal.SIGUSR1, sigusr1handler) + + if args.socketfile: + from nemubot.server.socket import UnixSocketListener + context.add_server(UnixSocketListener(new_server_cb=context.add_server, + location=args.socketfile, + name="master_socket")) + + # context can change when performing an hotswap, always join the latest context + oldcontext = None + while oldcontext != context: + oldcontext = context + context.start() + context.join() + + # Wait for consumers + logger.info("Waiting for other threads shuts down...") + if args.debug: + sigusr1handler(0, None) sys.exit(0) if __name__ == "__main__": diff --git a/nemubot/bot.py b/nemubot/bot.py index 0adb587..10bcff7 100644 --- a/nemubot/bot.py +++ b/nemubot/bot.py @@ -15,9 +15,13 @@ # along with this program. If not, see . from datetime import datetime, timezone +import ipaddress import logging +from multiprocessing import JoinableQueue import threading +import select import sys +from typing import Any, Mapping, Optional, Sequence from nemubot import __version__ from nemubot.consumer import Consumer, EventConsumer, MessageConsumer @@ -26,22 +30,33 @@ import nemubot.hooks logger = logging.getLogger("nemubot") +sync_queue = JoinableQueue() + +def sync_act(*args): + if isinstance(act, bytes): + act = act.decode() + sync_queue.put(act) + class Bot(threading.Thread): """Class containing the bot context and ensuring key goals""" - def __init__(self, ip="127.0.0.1", modules_paths=list(), - data_store=datastore.Abstract(), verbosity=0): + def __init__(self, + ip: Optional[ipaddress] = None, + modules_paths: Sequence[str] = list(), + data_store: Optional[datastore.Abstract] = None, + verbosity: int = 0): """Initialize the bot context Keyword arguments: ip -- The external IP of the bot (default: 127.0.0.1) - modules_paths -- Paths to all directories where looking for module + modules_paths -- Paths to all directories where looking for modules data_store -- An instance of the nemubot datastore for bot's modules + verbosity -- verbosity level """ - threading.Thread.__init__(self) + super().__init__(name="Nemubot main") logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)", __version__, @@ -51,16 +66,16 @@ class Bot(threading.Thread): self.stop = None # External IP for accessing this bot - import ipaddress - self.ip = ipaddress.ip_address(ip) + self.ip = ip if ip is not None else ipaddress.ip_address("127.0.0.1") # Context paths self.modules_paths = modules_paths - self.datastore = data_store + self.datastore = data_store if data_store is not None else datastore.Abstract() self.datastore.open() # Keep global context: servers and modules - self.servers = dict() + self._poll = select.poll() + self.servers = dict() # types: Mapping[str, AbstractServer] self.modules = dict() self.modules_configuration = dict() @@ -137,59 +152,76 @@ class Bot(threading.Thread): self.cnsr_queue = Queue() self.cnsr_thrd = list() self.cnsr_thrd_size = -1 - # Synchrone actions to be treated by main thread - self.sync_queue = Queue() def run(self): - from select import select - from nemubot.server import _lock, _rlist, _wlist, _xlist + self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI) + logger.info("Starting main loop") self.stop = False while not self.stop: - with _lock: - try: - rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1) - except: - logger.error("Something went wrong in select") - fnd_smth = False - # Looking for invalid server - for r in _rlist: - if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0: - _rlist.remove(r) - logger.error("Found invalid object in _rlist: " + str(r)) - fnd_smth = True - for w in _wlist: - if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0: - _wlist.remove(w) - logger.error("Found invalid object in _wlist: " + str(w)) - fnd_smth = True - for x in _xlist: - if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0: - _xlist.remove(x) - logger.error("Found invalid object in _xlist: " + str(x)) - fnd_smth = True - if not fnd_smth: - logger.exception("Can't continue, sorry") - self.quit() - continue + for fd, flag in self._poll.poll(): + print("poll") + # Handle internal socket passing orders + if fd != sync_queue._reader.fileno(): + srv = self.servers[fd] - for x in xl: - try: - x.exception() - except: - logger.exception("Uncatched exception on server exception") - for w in wl: - try: - w.write_select() - except: - logger.exception("Uncatched exception on server write") - for r in rl: - for i in r.read(): + if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL): try: - self.receive_message(r, i) + srv.exception(flag) except: - logger.exception("Uncatched exception on server read") + logger.exception("Uncatched exception on server exception") + + if srv.fileno() > 0: + if flag & (select.POLLOUT): + try: + srv.async_write() + except: + logger.exception("Uncatched exception on server write") + + if flag & (select.POLLIN | select.POLLPRI): + try: + for i in srv.async_read(): + self.receive_message(srv, i) + except: + logger.exception("Uncatched exception on server read") + + else: + del self.servers[fd] + + + # Always check the sync queue + while not sync_queue.empty(): + import shlex + args = shlex.split(sync_queue.get()) + action = args.pop(0) + + logger.info("action: %s: %s", action, args) + + if action == "sckt" and len(args) >= 2: + try: + if args[0] == "write": + self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR) + elif args[0] == "unwrite": + self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR) + + elif args[0] == "register": + self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR) + elif args[0] == "unregister": + self._poll.unregister(int(args[1])) + except: + logger.exception("Unhandled excpetion during action:") + + elif action == "exit": + self.quit() + + elif action == "loadconf": + for path in action.args: + logger.debug("Load configuration from %s", path) + self.load_file(path) + logger.info("Configurations successfully loaded") + + sync_queue.task_done() # Launch new consumer threads if necessary @@ -200,17 +232,7 @@ class Bot(threading.Thread): c = Consumer(self) self.cnsr_thrd.append(c) c.start() - - while self.sync_queue.qsize() > 0: - action = self.sync_queue.get_nowait() - if action[0] == "exit": - self.quit() - elif action[0] == "loadconf": - for path in action[1:]: - logger.debug("Load configuration from %s", path) - self.load_file(path) - logger.info("Configurations successfully loaded") - self.sync_queue.task_done() + logger.info("Ending main loop") @@ -252,9 +274,9 @@ class Bot(threading.Thread): srv = server.server(config) # Add the server in the context if self.add_server(srv, server.autoconnect): - logger.info("Server '%s' successfully added." % srv.id) + logger.info("Server '%s' successfully added." % srv.name) else: - logger.error("Can't add server '%s'." % srv.id) + logger.error("Can't add server '%s'." % srv.name) # Load module and their configuration for mod in config.modules: @@ -303,7 +325,7 @@ class Bot(threading.Thread): if type(eid) is uuid.UUID: evt.id = str(eid) else: - # Ok, this is quite useless... + # Ok, this is quiet useless... try: evt.id = str(uuid.UUID(eid)) except ValueError: @@ -411,10 +433,11 @@ class Bot(threading.Thread): autoconnect -- connect after add? """ - if srv.id not in self.servers: - self.servers[srv.id] = srv + fileno = srv.fileno() + if fileno not in self.servers: + self.servers[fileno] = srv if autoconnect and not hasattr(self, "noautoconnect"): - srv.open() + srv.connect() return True else: @@ -436,10 +459,10 @@ class Bot(threading.Thread): __import__(name) - def add_module(self, module): + def add_module(self, mdl: Any): """Add a module to the context, if already exists, unload the old one before""" - module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__ + module_name = mdl.__spec__.name if hasattr(mdl, "__spec__") else mdl.__name__ if hasattr(self, "stop") and self.stop: logger.warn("The bot is stopped, can't register new modules") @@ -451,40 +474,40 @@ class Bot(threading.Thread): # Overwrite print built-in def prnt(*args): - if hasattr(module, "logger"): - module.logger.info(" ".join([str(s) for s in args])) + if hasattr(mdl, "logger"): + mdl.logger.info(" ".join([str(s) for s in args])) else: logger.info("[%s] %s", module_name, " ".join([str(s) for s in args])) - module.print = prnt + mdl.print = prnt # Create module context from nemubot.modulecontext import ModuleContext - module.__nemubot_context__ = ModuleContext(self, module) + mdl.__nemubot_context__ = ModuleContext(self, mdl) - if not hasattr(module, "logger"): - module.logger = logging.getLogger("nemubot.module." + module_name) + if not hasattr(mdl, "logger"): + mdl.logger = logging.getLogger("nemubot.module." + module_name) # Replace imported context by real one - for attr in module.__dict__: - if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext: - module.__dict__[attr] = module.__nemubot_context__ + for attr in mdl.__dict__: + if attr != "__nemubot_context__" and type(mdl.__dict__[attr]) == ModuleContext: + mdl.__dict__[attr] = mdl.__nemubot_context__ # Register decorated functions import nemubot.hooks for s, h in nemubot.hooks.hook.last_registered: - module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s) + mdl.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s) nemubot.hooks.hook.last_registered = [] # Launch the module - if hasattr(module, "load"): + if hasattr(mdl, "load"): try: - module.load(module.__nemubot_context__) + mdl.load(mdl.__nemubot_context__) except: - module.__nemubot_context__.unload() + mdl.__nemubot_context__.unload() raise # Save a reference to the module - self.modules[module_name] = module + self.modules[module_name] = mdl def unload_module(self, name): @@ -527,28 +550,28 @@ class Bot(threading.Thread): def quit(self): """Save and unload modules and disconnect servers""" - self.datastore.close() - if self.event_timer is not None: logger.info("Stop the event timer...") self.event_timer.cancel() + logger.info("Save and unload all modules...") + for mod in self.modules.items(): + self.unload_module(mod) + + logger.info("Close all servers connection...") + for k in self.servers: + self.servers[k].close() + logger.info("Stop consumers") k = self.cnsr_thrd for cnsr in k: cnsr.stop = True - logger.info("Save and unload all modules...") - k = list(self.modules.keys()) - for mod in k: - self.unload_module(mod) - - logger.info("Close all servers connection...") - k = list(self.servers.keys()) - for srv in k: - self.servers[srv].close() + self.datastore.close() self.stop = True + sync_queue.put("end") + sync_queue.join() # Treatment @@ -562,20 +585,3 @@ class Bot(threading.Thread): del store[hook.name] elif isinstance(store, list): store.remove(hook) - - -def hotswap(bak): - bak.stop = True - if bak.event_timer is not None: - bak.event_timer.cancel() - bak.datastore.close() - - new = Bot(str(bak.ip), bak.modules_paths, bak.datastore) - new.servers = bak.servers - new.modules = bak.modules - new.modules_configuration = bak.modules_configuration - new.events = bak.events - new.hooks = bak.hooks - - new._update_event_timer() - return new diff --git a/nemubot/channel.py b/nemubot/channel.py index a070131..c01ac90 100644 --- a/nemubot/channel.py +++ b/nemubot/channel.py @@ -15,13 +15,18 @@ # along with this program. If not, see . import logging +from typing import Optional +from nemubot.message import Abstract as AbstractMessage class Channel: """A chat room""" - def __init__(self, name, password=None, encoding=None): + def __init__(self, + name: str, + password: Optional[str] = None, + encoding: Optional[str] = None): """Initialize the channel Arguments: @@ -37,7 +42,8 @@ class Channel: self.topic = "" self.logger = logging.getLogger("nemubot.channel." + name) - def treat(self, cmd, msg): + + def treat(self, cmd: str, msg: AbstractMessage) -> None: """Treat a incoming IRC command Arguments: @@ -60,7 +66,8 @@ class Channel: elif cmd == "TOPIC": self.topic = self.text - def join(self, nick, level=0): + + def join(self, nick: str, level: int = 0) -> None: """Someone join the channel Argument: @@ -71,7 +78,8 @@ class Channel: self.logger.debug("%s join", nick) self.people[nick] = level - def chtopic(self, newtopic): + + def chtopic(self, newtopic: str) -> None: """Send command to change the topic Arguments: @@ -81,7 +89,8 @@ class Channel: self.srv.send_msg(self.name, newtopic, "TOPIC") self.topic = newtopic - def nick(self, oldnick, newnick): + + def nick(self, oldnick: str, newnick: str) -> None: """Someone change his nick Arguments: @@ -95,7 +104,8 @@ class Channel: del self.people[oldnick] self.people[newnick] = lvl - def part(self, nick): + + def part(self, nick: str) -> None: """Someone leave the channel Argument: @@ -106,7 +116,8 @@ class Channel: self.logger.debug("%s has left", nick) del self.people[nick] - def mode(self, msg): + + def mode(self, msg: AbstractMessage) -> None: """Channel or user mode change Argument: @@ -132,7 +143,8 @@ class Channel: elif msg.text[0] == "-v": self.people[msg.nick] &= ~1 - def parse332(self, msg): + + def parse332(self, msg: AbstractMessage) -> None: """Parse RPL_TOPIC message Argument: @@ -141,7 +153,8 @@ class Channel: self.topic = msg.text - def parse353(self, msg): + + def parse353(self, msg: AbstractMessage) -> None: """Parse RPL_ENDOFWHO message Argument: diff --git a/nemubot/config/__init__.py b/nemubot/config/__init__.py index 7e0b74a..ea6fed4 100644 --- a/nemubot/config/__init__.py +++ b/nemubot/config/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -def get_boolean(s): +def get_boolean(s) -> bool: if isinstance(s, bool): return s else: @@ -24,24 +24,3 @@ from nemubot.config.include import Include from nemubot.config.module import Module from nemubot.config.nemubot import Nemubot from nemubot.config.server import Server - -def reload(): - global Include, Module, Nemubot, Server - - import imp - - import nemubot.config.include - imp.reload(nemubot.config.include) - Include = nemubot.config.include.Include - - import nemubot.config.module - imp.reload(nemubot.config.module) - Module = nemubot.config.module.Module - - import nemubot.config.nemubot - imp.reload(nemubot.config.nemubot) - Nemubot = nemubot.config.nemubot.Nemubot - - import nemubot.config.server - imp.reload(nemubot.config.server) - Server = nemubot.config.server.Server diff --git a/nemubot/config/include.py b/nemubot/config/include.py index 408c09a..aca6468 100644 --- a/nemubot/config/include.py +++ b/nemubot/config/include.py @@ -16,5 +16,5 @@ class Include: - def __init__(self, path): + def __init__(self, path: str): self.path = path diff --git a/nemubot/config/module.py b/nemubot/config/module.py index ab51971..e67a45b 100644 --- a/nemubot/config/module.py +++ b/nemubot/config/module.py @@ -15,12 +15,16 @@ # along with this program. If not, see . from nemubot.config import get_boolean -from nemubot.tools.xmlparser.genericnode import GenericNode +from nemubot.datastore.nodes.generic import GenericNode class Module(GenericNode): - def __init__(self, name, autoload=True, **kwargs): + def __init__(self, + name: str, + autoload: bool = True, + **kwargs): super().__init__(None, **kwargs) + self.name = name self.autoload = get_boolean(autoload) diff --git a/nemubot/config/nemubot.py b/nemubot/config/nemubot.py index 992cd8e..cc60f86 100644 --- a/nemubot/config/nemubot.py +++ b/nemubot/config/nemubot.py @@ -14,15 +14,23 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from nemubot.config.include import Include -from nemubot.config.module import Module -from nemubot.config.server import Server +from typing import Optional, Sequence, Union + +import nemubot.config.include +import nemubot.config.module +import nemubot.config.server class Nemubot: - def __init__(self, nick="nemubot", realname="nemubot", owner=None, - ip=None, ssl=False, caps=None, encoding="utf-8"): + def __init__(self, + nick: str = "nemubot", + realname: str = "nemubot", + owner: Optional[str] = None, + ip: Optional[str] = None, + ssl: bool = False, + caps: Optional[Sequence[str]] = None, + encoding: str = "utf-8"): self.nick = nick self.realname = realname self.owner = owner @@ -34,13 +42,13 @@ class Nemubot: self.includes = [] - def addChild(self, name, child): - if name == "module" and isinstance(child, Module): + def addChild(self, name: str, child: Union[nemubot.config.module.Module, nemubot.config.server.Server, nemubot.config.include.Include]): + if name == "module" and isinstance(child, nemubot.config.module.Module): self.modules.append(child) return True - elif name == "server" and isinstance(child, Server): + elif name == "server" and isinstance(child, nemubot.config.server.Server): self.servers.append(child) return True - elif name == "include" and isinstance(child, Include): + elif name == "include" and isinstance(child, nemubot.config.include.Include): self.includes.append(child) return True diff --git a/nemubot/config/server.py b/nemubot/config/server.py index 14ca9a8..b8df692 100644 --- a/nemubot/config/server.py +++ b/nemubot/config/server.py @@ -14,12 +14,19 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Optional, Sequence + from nemubot.channel import Channel +import nemubot.config.nemubot class Server: - def __init__(self, uri="irc://nemubot@localhost/", autoconnect=True, caps=None, **kwargs): + def __init__(self, + uri: str = "irc://nemubot@localhost/", + autoconnect: bool = True, + caps: Optional[Sequence[str]] = None, + **kwargs): self.uri = uri self.autoconnect = autoconnect self.caps = caps.split(" ") if caps is not None else [] @@ -27,7 +34,7 @@ class Server: self.channels = [] - def addChild(self, name, child): + def addChild(self, name: str, child: Channel): if name == "channel" and isinstance(child, Channel): self.channels.append(child) return True diff --git a/nemubot/consumer.py b/nemubot/consumer.py index 886c4cf..8ea5a40 100644 --- a/nemubot/consumer.py +++ b/nemubot/consumer.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -17,6 +17,12 @@ import logging import queue import threading +from typing import List + +from nemubot.bot import Bot +from nemubot.event import ModuleEvent +from nemubot.message.abstract import Abstract as AbstractMessage +from nemubot.server.abstract import AbstractServer logger = logging.getLogger("nemubot.consumer") @@ -25,18 +31,15 @@ class MessageConsumer: """Store a message before treating""" - def __init__(self, srv, msg): + def __init__(self, srv: AbstractServer, msg: AbstractMessage): self.srv = srv self.orig = msg - def run(self, context): + def run(self, context: Bot) -> None: """Create, parse and treat the message""" - from nemubot.bot import Bot - assert isinstance(context, Bot) - - msgs = [] + msgs = [] # type: List[AbstractMessage] # Parse the message try: @@ -44,14 +47,14 @@ class MessageConsumer: msgs.append(msg) except: logger.exception("Error occurred during the processing of the %s: " - "%s", type(self.msgs[0]).__name__, self.msgs[0]) + "%s", type(self.orig).__name__, self.orig) if len(msgs) <= 0: return # Qualify the message if not hasattr(msg, "server") or msg.server is None: - msg.server = self.srv.id + msg.server = self.srv.name if hasattr(msg, "frm_owner"): msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm) @@ -62,15 +65,19 @@ class MessageConsumer: to_server = None if isinstance(res, str): to_server = self.srv + elif not hasattr(res, "server"): + logger.error("No server defined for response of type %s: %s", type(res).__name__, res) + continue elif res.server is None: to_server = self.srv - res.server = self.srv.id - elif isinstance(res.server, str) and res.server in context.servers: + res.server = self.srv.name + elif res.server in context.servers: to_server = context.servers[res.server] + else: + to_server = res.server - if to_server is None: - logger.error("The server defined in this response doesn't " - "exist: %s", res.server) + if to_server is None or not hasattr(to_server, "send_response") or not callable(to_server.send_response): + logger.error("The server defined in this response doesn't exist: %s", res.server) continue # Sent the message only if treat_post authorize it @@ -81,12 +88,12 @@ class EventConsumer: """Store a event before treating""" - def __init__(self, evt, timeout=20): + def __init__(self, evt: ModuleEvent, timeout: int = 20): self.evt = evt self.timeout = timeout - def run(self, context): + def run(self, context: Bot) -> None: try: self.evt.check() except: @@ -107,13 +114,13 @@ class Consumer(threading.Thread): """Dequeue and exec requested action""" - def __init__(self, context): + def __init__(self, context: Bot): self.context = context self.stop = False - threading.Thread.__init__(self) + super().__init__(name="Nemubot consumer") - def run(self): + def run(self) -> None: try: while not self.stop: stm = self.context.cnsr_queue.get(True, 1) diff --git a/nemubot/datastore/__init__.py b/nemubot/datastore/__init__.py index 411eab1..3e38ad2 100644 --- a/nemubot/datastore/__init__.py +++ b/nemubot/datastore/__init__.py @@ -16,16 +16,3 @@ from nemubot.datastore.abstract import Abstract from nemubot.datastore.xml import XML - - -def reload(): - global Abstract, XML - import imp - - import nemubot.datastore.abstract - imp.reload(nemubot.datastore.abstract) - Abstract = nemubot.datastore.abstract.Abstract - - import nemubot.datastore.xml - imp.reload(nemubot.datastore.xml) - XML = nemubot.datastore.xml.XML diff --git a/nemubot/datastore/abstract.py b/nemubot/datastore/abstract.py index 96e2c0d..856851f 100644 --- a/nemubot/datastore/abstract.py +++ b/nemubot/datastore/abstract.py @@ -23,14 +23,16 @@ class Abstract: """Initialize a new empty storage tree """ - from nemubot.tools.xmlparser import module_state - return module_state.ModuleState("nemubotstate") + return None - def open(self): - return - def close(self): - return + def open(self) -> bool: + return True + + + def close(self) -> bool: + return True + def load(self, module): """Load data for the given module @@ -44,6 +46,7 @@ class Abstract: return self.new() + def save(self, module, data): """Load data for the given module @@ -57,9 +60,11 @@ class Abstract: return True + def __enter__(self): self.open() return self + def __exit__(self, type, value, traceback): self.close() diff --git a/nemubot/prompt/error.py b/nemubot/datastore/nodes/__init__.py similarity index 78% rename from nemubot/prompt/error.py rename to nemubot/datastore/nodes/__init__.py index f86b5a1..e4b2788 100644 --- a/nemubot/prompt/error.py +++ b/nemubot/datastore/nodes/__init__.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,8 +14,5 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -class PromptError(Exception): - - def __init__(self, message): - super(PromptError, self).__init__(message) - self.message = message +from nemubot.datastore.nodes.generic import ParsingNode +from nemubot.datastore.nodes.serializable import Serializable diff --git a/nemubot/tools/xmlparser/basic.py b/nemubot/datastore/nodes/basic.py similarity index 51% rename from nemubot/tools/xmlparser/basic.py rename to nemubot/datastore/nodes/basic.py index 8456629..a4467b2 100644 --- a/nemubot/tools/xmlparser/basic.py +++ b/nemubot/datastore/nodes/basic.py @@ -14,27 +14,38 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -class ListNode: +from typing import Any, Mapping, Sequence + +from nemubot.datastore.nodes.generic import ParsingNode +from nemubot.datastore.nodes.serializable import Serializable + + +class ListNode(Serializable): """XML node representing a Python dictionnnary """ + serializetag = "list" + def __init__(self, **kwargs): - self.items = list() + self.items = list() # type: Sequence - def addChild(self, name, child): + def addChild(self, name: str, child) -> bool: self.items.append(child) return True + def parsedForm(self) -> Sequence: + return self.items - def __len__(self): + + def __len__(self) -> int: return len(self.items) - def __getitem__(self, item): + def __getitem__(self, item: int) -> Any: return self.items[item] - def __setitem__(self, item, v): + def __setitem__(self, item: int, v: Any) -> None: self.items[item] = v def __contains__(self, item): @@ -44,65 +55,60 @@ class ListNode: return self.items.__repr__() -class DictNode: + def serialize(self) -> ParsingNode: + node = ParsingNode(tag=self.serializetag) + for i in self.items: + node.children.append(ParsingNode.serialize_node(i)) + return node + + +class DictNode(Serializable): """XML node representing a Python dictionnnary """ + serializetag = "dict" + def __init__(self, **kwargs): self.items = dict() self._cur = None - def startElement(self, name, attrs): + def startElement(self, name: str, attrs: Mapping[str, str]): if self._cur is None and "key" in attrs: - self._cur = (attrs["key"], "") - return True + self._cur = attrs["key"] return False - - def characters(self, content): - if self._cur is not None: - key, cnt = self._cur - if isinstance(cnt, str): - cnt += content - self._cur = key, cnt - - - def endElement(self, name): - if name is None or self._cur is None: - return - - key, cnt = self._cur - if isinstance(cnt, list) and len(cnt) == 1: - self.items[key] = cnt - else: - self.items[key] = cnt - - self._cur = None - return True - - - def addChild(self, name, child): + def addChild(self, name: str, child: Any): if self._cur is None: return False - key, cnt = self._cur - if not isinstance(cnt, list): - cnt = [] - cnt.append(child) - self._cur = key, cnt + self.items[self._cur] = child + self._cur = None return True + def parsedForm(self) -> Mapping: + return self.items - def __getitem__(self, item): + + def __getitem__(self, item: str) -> Any: return self.items[item] - def __setitem__(self, item, v): + def __setitem__(self, item: str, v: str) -> None: self.items[item] = v - def __contains__(self, item): + def __contains__(self, item: str) -> bool: return item in self.items - def __repr__(self): + def __repr__(self) -> str: return self.items.__repr__() + + + def serialize(self) -> ParsingNode: + from nemubot.datastore.nodes.generic import ParsingNode + node = ParsingNode(tag=self.serializetag) + for k in self.items: + chld = ParsingNode.serialize_node(self.items[k]) + chld.attrs["key"] = k + node.children.append(chld) + return node diff --git a/nemubot/tools/xmlparser/genericnode.py b/nemubot/datastore/nodes/generic.py similarity index 50% rename from nemubot/tools/xmlparser/genericnode.py rename to nemubot/datastore/nodes/generic.py index 9c29a23..939019c 100644 --- a/nemubot/tools/xmlparser/genericnode.py +++ b/nemubot/datastore/nodes/generic.py @@ -14,57 +14,109 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Any, Optional, Mapping, Union + +from nemubot.datastore.nodes.serializable import Serializable + + class ParsingNode: """Allow any kind of subtags, just keep parsed ones """ - def __init__(self, tag=None, **kwargs): + def __init__(self, + tag: Optional[str] = None, + **kwargs): self.tag = tag self.attrs = kwargs self.content = "" self.children = [] - def characters(self, content): + def characters(self, content: str) -> None: self.content += content - def addChild(self, name, child): + def addChild(self, name: str, child: Any) -> bool: self.children.append(child) return True - def hasNode(self, nodename): + def hasNode(self, nodename: str) -> bool: return self.getNode(nodename) is not None - def getNode(self, nodename): + def getNode(self, nodename: str) -> Optional[Any]: for c in self.children: if c is not None and c.tag == nodename: return c return None - def __getitem__(self, item): + def __getitem__(self, item: str) -> Any: return self.attrs[item] - def __contains__(self, item): + def __contains__(self, item: str) -> bool: return item in self.attrs + def serialize_node(node: Union[Serializable, str, int, float, list, dict], + **def_kwargs): + """Serialize any node or basic data to a ParsingNode instance""" + + if isinstance(node, Serializable): + node = node.serialize() + + if isinstance(node, str): + from nemubot.datastore.nodes.python import StringNode + pn = StringNode(**def_kwargs) + pn.value = node + return pn + + elif isinstance(node, int): + from nemubot.datastore.nodes.python import IntNode + pn = IntNode(**def_kwargs) + pn.value = node + return pn + + elif isinstance(node, float): + from nemubot.datastore.nodes.python import FloatNode + pn = FloatNode(**def_kwargs) + pn.value = node + return pn + + elif isinstance(node, list): + from nemubot.datastore.nodes.basic import ListNode + pn = ListNode(**def_kwargs) + pn.items = node + return pn.serialize() + + elif isinstance(node, dict): + from nemubot.datastore.nodes.basic import DictNode + pn = DictNode(**def_kwargs) + pn.items = node + return pn.serialize() + + else: + assert isinstance(node, ParsingNode) + return node + + class GenericNode(ParsingNode): """Consider all subtags as dictionnary """ - def __init__(self, tag, **kwargs): + def __init__(self, + tag: str, + **kwargs): super().__init__(tag, **kwargs) + self._cur = None self._deep_cur = 0 - def startElement(self, name, attrs): + def startElement(self, name: str, attrs: Mapping[str, str]): if self._cur is None: self._cur = GenericNode(name, **attrs) self._deep_cur = 0 @@ -74,14 +126,14 @@ class GenericNode(ParsingNode): return True - def characters(self, content): + def characters(self, content: str): if self._cur is None: super().characters(content) else: self._cur.characters(content) - def endElement(self, name): + def endElement(self, name: str): if name is None: return diff --git a/nemubot/datastore/nodes/python.py b/nemubot/datastore/nodes/python.py new file mode 100644 index 0000000..819bf21 --- /dev/null +++ b/nemubot/datastore/nodes/python.py @@ -0,0 +1,89 @@ +# Nemubot is a smart and modulable IM bot. +# Copyright (C) 2012-2016 Mercier Pierre-Olivier +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from nemubot.datastore.nodes.generic import ParsingNode +from nemubot.datastore.nodes.serializable import Serializable + + +class PythonTypeNode(Serializable): + + """XML node representing a Python simple type + """ + + def __init__(self, **kwargs): + self.value = None + self._cnt = "" + + + def characters(self, content: str) -> None: + self._cnt += content + + + def endElement(self, name: str) -> None: + raise NotImplemented + + + def __repr__(self) -> str: + return self.value.__repr__() + + + def parsedForm(self): + return self.value + + def serialize(self): + raise NotImplemented + + +class IntNode(PythonTypeNode): + + serializetag = "int" + + def endElement(self, name: str) -> bool: + self.value = int(self._cnt) + return True + + def serialize(self) -> ParsingNode: + node = ParsingNode(tag=self.serializetag) + node.content = str(self.value) + return node + + +class FloatNode(PythonTypeNode): + + serializetag = "float" + + def endElement(self, name: str) -> bool: + self.value = float(self._cnt) + return True + + def serialize(self) -> ParsingNode: + node = ParsingNode(tag=self.serializetag) + node.content = str(self.value) + return node + + +class StringNode(PythonTypeNode): + + serializetag = "str" + + def endElement(self, name): + self.value = str(self._cnt) + return True + + def serialize(self): + node = ParsingNode(tag=self.serializetag) + node.content = str(self.value) + return node diff --git a/nemubot/prompt/reset.py b/nemubot/datastore/nodes/serializable.py similarity index 76% rename from nemubot/prompt/reset.py rename to nemubot/datastore/nodes/serializable.py index 57da9f8..e543699 100644 --- a/nemubot/prompt/reset.py +++ b/nemubot/datastore/nodes/serializable.py @@ -1,7 +1,5 @@ -# coding=utf-8 - # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -16,8 +14,9 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -class PromptReset(Exception): - def __init__(self, type): - super(PromptReset, self).__init__("Prompt reset asked") - self.type = type +class Serializable: + + def serialize(self): + # Implementations of this function should return ParsingNode items + return NotImplemented diff --git a/nemubot/datastore/xml.py b/nemubot/datastore/xml.py index 46dca70..abf1492 100644 --- a/nemubot/datastore/xml.py +++ b/nemubot/datastore/xml.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -17,6 +17,7 @@ import fcntl import logging import os +from typing import Any, Mapping import xml.parsers.expat from nemubot.datastore.abstract import Abstract @@ -28,7 +29,9 @@ class XML(Abstract): """A concrete implementation of a data store that relies on XML files""" - def __init__(self, basedir, rotate=True): + def __init__(self, + basedir: str, + rotate: bool = True): """Initialize the datastore Arguments: @@ -36,17 +39,24 @@ class XML(Abstract): rotate -- auto-backup files? """ - self.basedir = basedir + self.basedir = os.path.abspath(basedir) self.rotate = rotate self.nb_save = 0 - def open(self): + logger.info("Initiate XML datastore at %s, rotation %s", + self.basedir, + "enabled" if self.rotate else "disabled") + + + def open(self) -> bool: """Lock the directory""" if not os.path.isdir(self.basedir): + logger.debug("Datastore directory not found, creating: %s", self.basedir) os.mkdir(self.basedir) - lock_path = os.path.join(self.basedir, ".used_by_nemubot") + lock_path = self._get_lock_file_path() + logger.debug("Locking datastore directory via %s", lock_path) self.lock_file = open(lock_path, 'a+') ok = True @@ -64,57 +74,92 @@ class XML(Abstract): self.lock_file.write(str(os.getpid())) self.lock_file.flush() + logger.info("Datastore successfuly opened at %s", self.basedir) return True - def close(self): + + def close(self) -> bool: """Release a locked path""" if hasattr(self, "lock_file"): self.lock_file.close() - lock_path = os.path.join(self.basedir, ".used_by_nemubot") + lock_path = self._get_lock_file_path() if os.path.isdir(self.basedir) and os.path.exists(lock_path): os.unlink(lock_path) del self.lock_file + logger.info("Datastore successfully closed at %s", self.basedir) return True + else: + logger.warn("Datastore not open/locked or lock file not found") return False - def _get_data_file_path(self, module): + + def _get_data_file_path(self, module: str) -> str: """Get the path to the module data file""" return os.path.join(self.basedir, module + ".xml") - def load(self, module): + + def _get_lock_file_path(self) -> str: + """Get the path to the datastore lock file""" + + return os.path.join(self.basedir, ".used_by_nemubot") + + + def load(self, module: str, extendsTags: Mapping[str, Any] = {}) -> Abstract: """Load data for the given module Argument: module -- the module name of data to load """ + logger.debug("Trying to load data for %s%s", + module, + (" with tags: " + ", ".join(extendsTags.keys())) if len(extendsTags) else "") + data_file = self._get_data_file_path(module) + def parse(path: str): + from nemubot.tools.xmlparser import XMLParser + from nemubot.datastore.nodes import basic as basicNodes + from nemubot.datastore.nodes import python as pythonNodes + + d = { + basicNodes.ListNode.serializetag: basicNodes.ListNode, + basicNodes.DictNode.serializetag: basicNodes.DictNode, + pythonNodes.IntNode.serializetag: pythonNodes.IntNode, + pythonNodes.FloatNode.serializetag: pythonNodes.FloatNode, + pythonNodes.StringNode.serializetag: pythonNodes.StringNode, + } + d.update(extendsTags) + + p = XMLParser(d) + return p.parse_file(path) + # Try to load original file if os.path.isfile(data_file): - from nemubot.tools.xmlparser import parse_file try: - return parse_file(data_file) + return parse(data_file) except xml.parsers.expat.ExpatError: # Try to load from backup for i in range(10): path = data_file + "." + str(i) if os.path.isfile(path): try: - cnt = parse_file(path) + cnt = parse(path) - logger.warn("Restoring from backup: %s", path) + logger.warn("Restoring data from backup: %s", path) return cnt except xml.parsers.expat.ExpatError: continue # Default case: initialize a new empty datastore + logger.warn("No data found in store for %s, creating new set", module) return Abstract.load(self, module) - def _rotate(self, path): + + def _rotate(self, path: str) -> None: """Backup given path Argument: @@ -130,7 +175,26 @@ class XML(Abstract): if os.path.isfile(src): os.rename(src, dst) - def save(self, module, data): + + def _save_node(self, gen, node: Any): + from nemubot.datastore.nodes.generic import ParsingNode + + # First, get the serialized form of the node + node = ParsingNode.serialize_node(node) + + assert node.tag is not None, "Undefined tag name" + + gen.startElement(node.tag, {k: str(node.attrs[k]) for k in node.attrs}) + + gen.characters(node.content) + + for child in node.children: + self._save_node(gen, child) + + gen.endElement(node.tag) + + + def save(self, module: str, data: Any) -> bool: """Load data for the given module Argument: @@ -139,8 +203,22 @@ class XML(Abstract): """ path = self._get_data_file_path(module) + logger.debug("Trying to save data for module %s in %s", module, path) if self.rotate: self._rotate(path) - return data.save(path) + import tempfile + _, tmpath = tempfile.mkstemp() + with open(tmpath, "w") as f: + import xml.sax.saxutils + gen = xml.sax.saxutils.XMLGenerator(f, "utf-8") + gen.startDocument() + self._save_node(gen, data) + gen.endDocument() + + # Atomic save + import shutil + shutil.move(tmpath, path) + + return True diff --git a/nemubot/event/__init__.py b/nemubot/event/__init__.py index 7b2adfd..ab96efb 100644 --- a/nemubot/event/__init__.py +++ b/nemubot/event/__init__.py @@ -15,15 +15,23 @@ # along with this program. If not, see . from datetime import datetime, timedelta, timezone +from typing import Any, Callable, Optional class ModuleEvent: """Representation of a event initiated by a bot module""" - def __init__(self, call=None, call_data=None, func=None, func_data=None, - cmp=None, cmp_data=None, interval=60, offset=0, times=1): - + def __init__(self, + call: Callable = None, + call_data: Callable = None, + func: Callable = None, + func_data: Any = None, + cmp: Any = None, + cmp_data: Any = None, + interval: int = 60, + offset: int = 0, + times: int = 1): """Initialize the event Keyword arguments: @@ -70,8 +78,9 @@ class ModuleEvent: # How many times do this event? self.times = times + @property - def current(self): + def current(self) -> Optional[datetime.datetime]: """Return the date of the near check""" if self.times != 0: if self._end is None: @@ -79,8 +88,9 @@ class ModuleEvent: return self._end return None + @property - def next(self): + def next(self) -> Optional[datetime.datetime]: """Return the date of the next check""" if self.times != 0: if self._end is None: @@ -90,14 +100,16 @@ class ModuleEvent: return self._end return None + @property - def time_left(self): + def time_left(self) -> Union[datetime.datetime, int]: """Return the time left before/after the near check""" if self.current is not None: return self.current - datetime.now(timezone.utc) return 99999 # TODO: 99999 is not a valid time to return - def check(self): + + def check(self) -> None: """Run a check and realized the event if this is time""" # Get initial data diff --git a/nemubot/exception/__init__.py b/nemubot/exception/__init__.py index 1e34923..84464a0 100644 --- a/nemubot/exception/__init__.py +++ b/nemubot/exception/__init__.py @@ -32,10 +32,3 @@ class IMException(Exception): from nemubot.message import Text return Text(*self.args, server=msg.server, to=msg.to_response) - - -def reload(): - import imp - - import nemubot.exception.Keyword - imp.reload(nemubot.exception.printer.IRC) diff --git a/nemubot/hooks/__init__.py b/nemubot/hooks/__init__.py index e9113eb..9024494 100644 --- a/nemubot/hooks/__init__.py +++ b/nemubot/hooks/__init__.py @@ -49,23 +49,3 @@ class hook: def pre(*args, store=["pre"], **kwargs): return hook._add(store, Abstract, *args, **kwargs) - - -def reload(): - import imp - - import nemubot.hooks.abstract - imp.reload(nemubot.hooks.abstract) - - import nemubot.hooks.command - imp.reload(nemubot.hooks.command) - - import nemubot.hooks.message - imp.reload(nemubot.hooks.message) - - import nemubot.hooks.keywords - imp.reload(nemubot.hooks.keywords) - nemubot.hooks.keywords.reload() - - import nemubot.hooks.manager - imp.reload(nemubot.hooks.manager) diff --git a/nemubot/hooks/keywords/__init__.py b/nemubot/hooks/keywords/__init__.py index 4b6419a..598b04f 100644 --- a/nemubot/hooks/keywords/__init__.py +++ b/nemubot/hooks/keywords/__init__.py @@ -26,11 +26,22 @@ class NoKeyword(Abstract): return super().check(mkw) -def reload(): - import imp +class AnyKeyword(Abstract): - import nemubot.hooks.keywords.abstract - imp.reload(nemubot.hooks.keywords.abstract) + def __init__(self, h): + """Class that accepts any passed keywords - import nemubot.hooks.keywords.dict - imp.reload(nemubot.hooks.keywords.dict) + Arguments: + h -- Help string + """ + + super().__init__() + self.h = h + + + def check(self, mkw): + return super().check(mkw) + + + def help(self): + return self.h diff --git a/nemubot/hooks/manager.py b/nemubot/hooks/manager.py index 6a57d2a..9d57483 100644 --- a/nemubot/hooks/manager.py +++ b/nemubot/hooks/manager.py @@ -21,7 +21,7 @@ class HooksManager: """Class to manage hooks""" - def __init__(self, name="core"): + def __init__(self, name: str = "core"): """Initialize the manager""" self.hooks = dict() diff --git a/nemubot/importer.py b/nemubot/importer.py index eaf1535..2827da9 100644 --- a/nemubot/importer.py +++ b/nemubot/importer.py @@ -18,16 +18,18 @@ from importlib.abc import Finder from importlib.machinery import SourceFileLoader import logging import os +from typing import Callable logger = logging.getLogger("nemubot.importer") class ModuleFinder(Finder): - def __init__(self, modules_paths, add_module): + def __init__(self, modules_paths: str, add_module: Callable[]): self.modules_paths = modules_paths self.add_module = add_module + def find_module(self, fullname, path=None): # Search only for new nemubot modules (packages init) if path is None: diff --git a/nemubot/message/__init__.py b/nemubot/message/__init__.py index 31d7313..4d69dbb 100644 --- a/nemubot/message/__init__.py +++ b/nemubot/message/__init__.py @@ -19,27 +19,3 @@ from nemubot.message.text import Text from nemubot.message.directask import DirectAsk from nemubot.message.command import Command from nemubot.message.command import OwnerCommand - - -def reload(): - global Abstract, Text, DirectAsk, Command, OwnerCommand - import imp - - import nemubot.message.abstract - imp.reload(nemubot.message.abstract) - Abstract = nemubot.message.abstract.Abstract - imp.reload(nemubot.message.text) - Text = nemubot.message.text.Text - imp.reload(nemubot.message.directask) - DirectAsk = nemubot.message.directask.DirectAsk - imp.reload(nemubot.message.command) - Command = nemubot.message.command.Command - OwnerCommand = nemubot.message.command.OwnerCommand - - import nemubot.message.visitor - imp.reload(nemubot.message.visitor) - - import nemubot.message.printer - imp.reload(nemubot.message.printer) - - nemubot.message.printer.reload() diff --git a/nemubot/message/command.py b/nemubot/message/command.py index 895d16e..6c208b2 100644 --- a/nemubot/message/command.py +++ b/nemubot/message/command.py @@ -22,7 +22,7 @@ class Command(Abstract): """This class represents a specialized TextMessage""" def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs): - Abstract.__init__(self, *nargs, **kargs) + super().__init__(*nargs, **kargs) self.cmd = cmd self.args = args if args is not None else list() diff --git a/nemubot/message/directask.py b/nemubot/message/directask.py index 03c7902..3b1fabb 100644 --- a/nemubot/message/directask.py +++ b/nemubot/message/directask.py @@ -28,7 +28,7 @@ class DirectAsk(Text): designated -- the user designated by the message """ - Text.__init__(self, *args, **kargs) + super().__init__(*args, **kargs) self.designated = designated diff --git a/nemubot/message/printer/IRC.py b/nemubot/message/printer/IRC.py index 320366c..df9cb9f 100644 --- a/nemubot/message/printer/IRC.py +++ b/nemubot/message/printer/IRC.py @@ -22,4 +22,4 @@ class IRC(SocketPrinter): def visit_Text(self, msg): self.pp += "PRIVMSG %s :" % ",".join(msg.to) - SocketPrinter.visit_Text(self, msg) + super().visit_Text(msg) diff --git a/nemubot/message/printer/__init__.py b/nemubot/message/printer/__init__.py index 060118b..e0fbeef 100644 --- a/nemubot/message/printer/__init__.py +++ b/nemubot/message/printer/__init__.py @@ -13,12 +13,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - -def reload(): - import imp - - import nemubot.message.printer.IRC - imp.reload(nemubot.message.printer.IRC) - - import nemubot.message.printer.socket - imp.reload(nemubot.message.printer.socket) diff --git a/nemubot/message/printer/socket.py b/nemubot/message/printer/socket.py index cb9bc4c..6884c88 100644 --- a/nemubot/message/printer/socket.py +++ b/nemubot/message/printer/socket.py @@ -35,7 +35,7 @@ class Socket(AbstractVisitor): others = [to for to in msg.to if to != msg.designated] # Avoid nick starting message when discussing on user channel - if len(others) != len(msg.to): + if len(others) == 0 or len(others) != len(msg.to): res = Text(msg.message, server=msg.server, date=msg.date, to=msg.to, frm=msg.frm) diff --git a/nemubot/message/response.py b/nemubot/message/response.py new file mode 100644 index 0000000..fba864b --- /dev/null +++ b/nemubot/message/response.py @@ -0,0 +1,34 @@ +# Nemubot is a smart and modulable IM bot. +# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from nemubot.message.abstract import Abstract + + +class Response(Abstract): + + def __init__(self, cmd, args=None, *nargs, **kargs): + super().__init__(*nargs, **kargs) + + self.cmd = cmd + self.args = args if args is not None else list() + + def __str__(self): + return self.cmd + " @" + ",@".join(self.args) + + @property + def cmds(self): + # TODO: this is for legacy modules + return [self.cmd] + self.args diff --git a/nemubot/message/text.py b/nemubot/message/text.py index ec90a36..f691a04 100644 --- a/nemubot/message/text.py +++ b/nemubot/message/text.py @@ -28,7 +28,7 @@ class Text(Abstract): message -- the parsed message """ - Abstract.__init__(self, *args, **kargs) + super().__init__(*args, **kargs) self.message = message diff --git a/nemubot/modulecontext.py b/nemubot/modulecontext.py index 1d1b3d0..9a4b7d4 100644 --- a/nemubot/modulecontext.py +++ b/nemubot/modulecontext.py @@ -39,6 +39,7 @@ class ModuleContext: self.hooks = list() self.events = list() + self.extendtags = dict() self.debug = context.verbosity > 0 if context is not None else False from nemubot.hooks import Abstract as AbstractHook @@ -46,7 +47,7 @@ class ModuleContext: # Define some callbacks if context is not None: def load_data(): - return context.datastore.load(module_name) + return context.datastore.load(module_name, extendsTags=self.extendtags) def add_hook(hook, *triggers): assert isinstance(hook, AbstractHook), hook @@ -77,8 +78,7 @@ class ModuleContext: else: # Used when using outside of nemubot def load_data(): - from nemubot.tools.xmlparser import module_state - return module_state.ModuleState("nemubotstate") + return None def add_hook(hook, *triggers): assert isinstance(hook, AbstractHook), hook @@ -97,7 +97,9 @@ class ModuleContext: module.logger.info("Send response: %s", res) def save(): - context.datastore.save(module_name, self.data) + # Don't save if no data has been access + if hasattr(self, "_data"): + context.datastore.save(module_name, self.data) def subparse(orig, cnt): if orig.server in context.servers: @@ -120,6 +122,21 @@ class ModuleContext: self._data = self.load_data() return self._data + @data.setter + def data(self, value): + assert value is not None + + self._data = value + + + def register_tags(self, **tags): + self.extendtags.update(tags) + + + def unregister_tags(self, *tags): + for t in tags: + del self.extendtags[t] + def unload(self): """Perform actions for unloading the module""" diff --git a/nemubot/prompt/__init__.py b/nemubot/prompt/__init__.py deleted file mode 100644 index 27f7919..0000000 --- a/nemubot/prompt/__init__.py +++ /dev/null @@ -1,142 +0,0 @@ -# Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . - -import shlex -import sys -import traceback - -from nemubot.prompt import builtins - - -class Prompt: - - def __init__(self): - self.selectedServer = None - self.lastretcode = 0 - - self.HOOKS_CAPS = dict() - self.HOOKS_LIST = dict() - - def add_cap_hook(self, name, call, data=None): - self.HOOKS_CAPS[name] = lambda t, c: call(t, data=data, - context=c, prompt=self) - - def add_list_hook(self, name, call): - self.HOOKS_LIST[name] = call - - def lex_cmd(self, line): - """Return an array of tokens - - Argument: - line -- the line to lex - """ - - try: - cmds = shlex.split(line) - except: - exc_type, exc_value, _ = sys.exc_info() - sys.stderr.write(traceback.format_exception_only(exc_type, - exc_value)[0]) - return - - bgn = 0 - - # Separate commands (command separator: ;) - for i in range(0, len(cmds)): - if cmds[i][-1] == ';': - if i != bgn: - yield cmds[bgn:i] - bgn = i + 1 - - # Return rest of the command (that not end with a ;) - if bgn != len(cmds): - yield cmds[bgn:] - - def exec_cmd(self, toks, context): - """Execute the command - - Arguments: - toks -- lexed tokens to executes - context -- current bot context - """ - - if toks[0] in builtins.CAPS: - self.lastretcode = builtins.CAPS[toks[0]](toks, context, self) - elif toks[0] in self.HOOKS_CAPS: - self.lastretcode = self.HOOKS_CAPS[toks[0]](toks, context) - else: - print("Unknown command: `%s'" % toks[0]) - self.lastretcode = 127 - - def getPS1(self): - """Get the PS1 associated to the selected server""" - if self.selectedServer is None: - return "nemubot" - else: - return self.selectedServer.id - - def run(self, context): - """Launch the prompt - - Argument: - context -- current bot context - """ - - from nemubot.prompt.error import PromptError - from nemubot.prompt.reset import PromptReset - - while True: # Stopped by exception - try: - line = input("\033[0;33m%s\033[0;%dm§\033[0m " % - (self.getPS1(), 31 if self.lastretcode else 32)) - cmds = self.lex_cmd(line.strip()) - for toks in cmds: - try: - self.exec_cmd(toks, context) - except PromptReset: - raise - except PromptError as e: - print(e.message) - self.lastretcode = 128 - except: - exc_type, exc_value, exc_traceback = sys.exc_info() - traceback.print_exception(exc_type, exc_value, - exc_traceback) - except KeyboardInterrupt: - print("") - except EOFError: - print("quit") - return True - - -def hotswap(bak): - p = Prompt() - p.HOOKS_CAPS = bak.HOOKS_CAPS - p.HOOKS_LIST = bak.HOOKS_LIST - return p - - -def reload(): - import imp - - import nemubot.prompt.builtins - imp.reload(nemubot.prompt.builtins) - - import nemubot.prompt.error - imp.reload(nemubot.prompt.error) - - import nemubot.prompt.reset - imp.reload(nemubot.prompt.reset) diff --git a/nemubot/prompt/builtins.py b/nemubot/prompt/builtins.py deleted file mode 100644 index a020fb9..0000000 --- a/nemubot/prompt/builtins.py +++ /dev/null @@ -1,128 +0,0 @@ -# Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . - -def end(toks, context, prompt): - """Quit the prompt for reload or exit""" - from nemubot.prompt.reset import PromptReset - - if toks[0] == "refresh": - raise PromptReset("refresh") - elif toks[0] == "reset": - raise PromptReset("reset") - raise PromptReset("quit") - - -def liste(toks, context, prompt): - """Show some lists""" - if len(toks) > 1: - for l in toks[1:]: - l = l.lower() - if l == "server" or l == "servers": - for srv in context.servers.keys(): - print (" - %s (state: %s) ;" % (srv, - "connected" if context.servers[srv].connected else "disconnected")) - if len(context.servers) == 0: - print (" > No server loaded") - - elif l == "mod" or l == "mods" or l == "module" or l == "modules": - for mod in context.modules.keys(): - print (" - %s ;" % mod) - if len(context.modules) == 0: - print (" > No module loaded") - - elif l in prompt.HOOKS_LIST: - f, d = prompt.HOOKS_LIST[l] - f(d, context, prompt) - - else: - print (" Unknown list `%s'" % l) - return 2 - return 0 - else: - print (" Please give a list to show: servers, ...") - return 1 - - -def load(toks, context, prompt): - """Load an XML configuration file""" - if len(toks) > 1: - for filename in toks[1:]: - context.load_file(filename) - else: - print ("Not enough arguments. `load' takes a filename.") - return 1 - - -def select(toks, context, prompt): - """Select the current server""" - if (len(toks) == 2 and toks[1] != "None" and - toks[1] != "nemubot" and toks[1] != "none"): - if toks[1] in context.servers: - prompt.selectedServer = context.servers[toks[1]] - else: - print ("select: server `%s' not found." % toks[1]) - return 1 - else: - prompt.selectedServer = None - - -def unload(toks, context, prompt): - """Unload a module""" - if len(toks) == 2 and toks[1] == "all": - for name in context.modules.keys(): - context.unload_module(name) - elif len(toks) > 1: - for name in toks[1:]: - if context.unload_module(name): - print (" Module `%s' successfully unloaded." % name) - else: - print (" No module `%s' loaded, can't unload!" % name) - return 2 - else: - print ("Not enough arguments. `unload' takes a module name.") - return 1 - - -def debug(toks, context, prompt): - """Enable/Disable debug mode on a module""" - if len(toks) > 1: - for name in toks[1:]: - if name in context.modules: - context.modules[name].DEBUG = not context.modules[name].DEBUG - if context.modules[name].DEBUG: - print (" Module `%s' now in DEBUG mode." % name) - else: - print (" Debug for module module `%s' disabled." % name) - else: - print (" No module `%s' loaded, can't debug!" % name) - return 2 - else: - print ("Not enough arguments. `debug' takes a module name.") - return 1 - - -# Register build-ins -CAPS = { - 'quit': end, # Disconnect all server and quit - 'exit': end, # Alias for quit - 'reset': end, # Reload the prompt - 'refresh': end, # Reload the prompt but save modules - 'load': load, # Load a servers or module configuration file - 'unload': unload, # Unload a module and remove it from the list - 'select': select, # Select a server - 'list': liste, # Show lists - 'debug': debug, # Pass a module in debug mode -} diff --git a/nemubot/server/DCC.py b/nemubot/server/DCC.py index 6655d52..c1a6852 100644 --- a/nemubot/server/DCC.py +++ b/nemubot/server/DCC.py @@ -31,7 +31,7 @@ PORTS = list() class DCC(server.AbstractServer): def __init__(self, srv, dest, socket=None): - server.Server.__init__(self) + super().__init__(name="Nemubot DCC server") self.error = False # An error has occur, closing the connection? self.messages = list() # Message queued before connexion diff --git a/nemubot/server/IRC.py b/nemubot/server/IRC.py index e433176..d09966a 100644 --- a/nemubot/server/IRC.py +++ b/nemubot/server/IRC.py @@ -27,10 +27,10 @@ class IRC(SocketServer): """Concrete implementation of a connexion to an IRC server""" - def __init__(self, host="localhost", port=6667, ssl=False, owner=None, + def __init__(self, host="localhost", port=6667, owner=None, nick="nemubot", username=None, password=None, realname="Nemubot", encoding="utf-8", caps=None, - channels=list(), on_connect=None): + channels=list(), on_connect=None, **kwargs): """Prepare a connection with an IRC server Keyword arguments: @@ -54,8 +54,8 @@ class IRC(SocketServer): self.owner = owner self.realname = realname - self.id = self.username + "@" + host + ":" + str(port) - SocketServer.__init__(self, host=host, port=port, ssl=ssl) + super().__init__(name=self.username + "@" + host + ":" + str(port), + host=host, port=port, **kwargs) self.printer = IRCPrinter self.encoding = encoding @@ -232,29 +232,29 @@ class IRC(SocketServer): # Open/close - def _open(self): - if SocketServer._open(self): - if self.password is not None: - self.write("PASS :" + self.password) - if self.capabilities is not None: - self.write("CAP LS") - self.write("NICK :" + self.nick) - self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) - return True - return False + def connect(self): + super().connect() + + if self.password is not None: + self.write("PASS :" + self.password) + if self.capabilities is not None: + self.write("CAP LS") + self.write("NICK :" + self.nick) + self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) - def _close(self): - if self.connected: self.write("QUIT") - return SocketServer._close(self) + def close(self): + if not self._closed: + self.write("QUIT") + return super().close() # Writes: as inherited # Read - def read(self): - for line in SocketServer.read(self): + def async_read(self): + for line in super().async_read(): # PING should be handled here, so start parsing here :/ msg = IRCMessage(line, self.encoding) diff --git a/nemubot/server/__init__.py b/nemubot/server/__init__.py index b9a8fe4..464c924 100644 --- a/nemubot/server/__init__.py +++ b/nemubot/server/__init__.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,20 +14,13 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import threading -_lock = threading.Lock() - -# Lists for select -_rlist = [] -_wlist = [] -_xlist = [] - - -def factory(uri, **init_args): +def factory(uri, ssl=False, **init_args): from urllib.parse import urlparse, unquote o = urlparse(uri) + srv = None + if o.scheme == "irc" or o.scheme == "ircs": # http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt # http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html @@ -36,7 +29,7 @@ def factory(uri, **init_args): modifiers = o.path.split(",") target = unquote(modifiers.pop(0)[1:]) - if o.scheme == "ircs": args["ssl"] = True + if o.scheme == "ircs": ssl = True if o.hostname is not None: args["host"] = o.hostname if o.port is not None: args["port"] = o.port if o.username is not None: args["username"] = o.username @@ -65,24 +58,11 @@ def factory(uri, **init_args): args["channels"] = [ target ] from nemubot.server.IRC import IRC as IRCServer - return IRCServer(**args) - else: - return None + srv = IRCServer(**args) + if ssl: + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + return ctx.wrap_socket(srv) -def reload(): - import imp - - import nemubot.server.abstract - imp.reload(nemubot.server.abstract) - - import nemubot.server.socket - imp.reload(nemubot.server.socket) - - import nemubot.server.IRC - imp.reload(nemubot.server.IRC) - - import nemubot.server.message - imp.reload(nemubot.server.message) - - nemubot.server.message.reload() + return srv diff --git a/nemubot/server/abstract.py b/nemubot/server/abstract.py index ebcb427..7e31cda 100644 --- a/nemubot/server/abstract.py +++ b/nemubot/server/abstract.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,71 +14,68 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import io import logging import queue -from nemubot.server import _lock, _rlist, _wlist, _xlist +from nemubot.bot import sync_act -# Extends from IOBase in order to be compatible with select function -class AbstractServer(io.IOBase): + +class AbstractServer: """An abstract server: handle communication with an IM server""" - def __init__(self, send_callback=None): + def __init__(self, name=None, **kwargs): """Initialize an abstract server Keyword argument: - send_callback -- Callback when developper want to send a message + name -- Identifier of the socket, for convinience """ - if not hasattr(self, "id"): - raise Exception("No id defined for this server. Please set one!") + self._name = name - self.logger = logging.getLogger("nemubot.server." + self.id) + super().__init__(**kwargs) + + self.logger = logging.getLogger("nemubot.server." + str(self.name)) + self._readbuffer = b'' self._sending_queue = queue.Queue() - if send_callback is not None: - self._send_callback = send_callback + + + def __del__(self): + print("Server deleted") + + + @property + def name(self): + if self._name is not None: + return self._name else: - self._send_callback = self._write_select + return self.fileno() # Open/close - def __enter__(self): - self.open() - return self + def connect(self, *args, **kwargs): + """Register the server in _poll""" + + self.logger.info("Opening connection") + + super().connect(*args, **kwargs) + + self._connected() + + def _connected(self): + sync_act("sckt register %d" % self.fileno()) - def __exit__(self, type, value, traceback): - self.close() + def close(self, *args, **kwargs): + """Unregister the server from _poll""" + self.logger.info("Closing connection") - def open(self): - """Generic open function that register the server un _rlist in case - of successful _open""" - self.logger.info("Opening connection to %s", self.id) - if not hasattr(self, "_open") or self._open(): - _rlist.append(self) - _xlist.append(self) - return True - return False + if self.fileno() > 0: + sync_act("sckt unregister %d" % self.fileno()) - - def close(self): - """Generic close function that register the server un _{r,w,x}list in - case of successful _close""" - self.logger.info("Closing connection to %s", self.id) - with _lock: - if not hasattr(self, "_close") or self._close(): - if self in _rlist: - _rlist.remove(self) - if self in _wlist: - _wlist.remove(self) - if self in _xlist: - _xlist.remove(self) - return True - return False + super().close(*args, **kwargs) # Writes @@ -90,13 +87,16 @@ class AbstractServer(io.IOBase): message -- message to send """ - self._send_callback(message) + self._sending_queue.put(self.format(message)) + self.logger.debug("Message '%s' appended to write queue", message) + sync_act("sckt write %d" % self.fileno()) - def write_select(self): - """Internal function used by the select function""" + def async_write(self): + """Internal function used when the file descriptor is writable""" + try: - _wlist.remove(self) + sync_act("sckt unwrite %d" % self.fileno()) while not self._sending_queue.empty(): self._write(self._sending_queue.get_nowait()) self._sending_queue.task_done() @@ -105,19 +105,6 @@ class AbstractServer(io.IOBase): pass - def _write_select(self, message): - """Send a message to the server safely through select - - Argument: - message -- message to send - """ - - self._sending_queue.put(self.format(message)) - self.logger.debug("Message '%s' appended to Queue", message) - if self not in _wlist: - _wlist.append(self) - - def send_response(self, response): """Send a formated Message class @@ -140,13 +127,39 @@ class AbstractServer(io.IOBase): # Read + def async_read(self): + """Internal function used when the file descriptor is readable + + Returns: + A list of fully received messages + """ + + ret, self._readbuffer = self.lex(self._readbuffer + self.read()) + + for r in ret: + yield r + + + def lex(self, buf): + """Assume lexing in default case is per line + + Argument: + buf -- buffer to lex + """ + + msgs = buf.split(b'\r\n') + partial = msgs.pop() + + return msgs, partial + + def parse(self, msg): raise NotImplemented # Exceptions - def exception(self): - """Exception occurs in fd""" - self.logger.warning("Unhandle file descriptor exception on server %s", - self.id) + def exception(self, flags): + """Exception occurs on fd""" + + self.close() diff --git a/nemubot/server/message/IRC.py b/nemubot/server/message/IRC.py index f6d562f..4c9e280 100644 --- a/nemubot/server/message/IRC.py +++ b/nemubot/server/message/IRC.py @@ -146,7 +146,7 @@ class IRC(Abstract): receivers = self.decode(self.params[0]).split(',') common_args = { - "server": srv.id, + "server": srv.name, "date": self.tags["time"], "to": receivers, "to_response": [r if r != srv.nick else self.nick for r in receivers], diff --git a/nemubot/server/socket.py b/nemubot/server/socket.py index b6c00d4..aeb20e5 100644 --- a/nemubot/server/socket.py +++ b/nemubot/server/socket.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,99 +14,35 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import os +import socket + +import nemubot.message as message from nemubot.message.printer.socket import Socket as SocketPrinter from nemubot.server.abstract import AbstractServer -class SocketServer(AbstractServer): +class _Socket(AbstractServer, socket.socket): - """Concrete implementation of a socket connexion (can be wrapped with TLS)""" + """Concrete implementation of a socket connection""" - def __init__(self, sock_location=None, host=None, port=None, ssl=False, socket=None, id=None): - if id is not None: - self.id = id - AbstractServer.__init__(self) - if sock_location is not None: - self.filename = sock_location - elif host is not None: - self.host = host - self.port = int(port) - self.ssl = ssl + def __init__(self, **kwargs): + """Create a server socket + + Keyword arguments: + ssl -- Should TLS connection enabled + """ + + super().__init__(**kwargs) - self.socket = socket self.readbuffer = b'' self.printer = SocketPrinter - def fileno(self): - return self.socket.fileno() if self.socket else None - - - @property - def connected(self): - """Indicator of the connection aliveness""" - return self.socket is not None - - - # Open/close - - def _open(self): - import os - import socket - - if self.connected: - return True - - try: - if hasattr(self, "filename"): - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.socket.connect(self.filename) - self.logger.info("Connected to %s", self.filename) - else: - self.socket = socket.create_connection((self.host, self.port)) - self.logger.info("Connected to %s:%d", self.host, self.port) - except socket.error as e: - self.socket = None - self.logger.critical("Unable to connect to %s:%d: %s", - self.host, self.port, - os.strerror(e.errno)) - return False - - # Wrap the socket for SSL - if self.ssl: - import ssl - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - self.socket = ctx.wrap_socket(self.socket) - - return True - - - def _close(self): - import socket - - from nemubot.server import _lock - _lock.release() - self._sending_queue.join() - _lock.acquire() - if self.connected: - try: - self.socket.shutdown(socket.SHUT_RDWR) - self.socket.close() - except socket.error: - pass - - self.socket = None - - return True - - # Write def _write(self, cnt): - if not self.connected: - return - - self.socket.send(cnt) + self.sendall(cnt) def format(self, txt): @@ -118,80 +54,113 @@ class SocketServer(AbstractServer): # Read + def read(self, n=1024): + return self.recv(n) + + + def parse(self, line): + """Implement a default behaviour for socket""" + import shlex + + line = line.strip().decode() + try: + args = shlex.split(line) + except ValueError: + args = line.split(' ') + + if len(args): + yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you") + + +class SocketServer(_Socket): + + def __init__(self, host, port, bind=None, **kwargs): + super().__init__(family=socket.AF_INET, **kwargs) + + assert(host is not None) + assert(isinstance(port, int)) + + self._host = host + self._port = port + self._bind = bind + + + def connect(self): + self.logger.info("Connection to %s:%d", self._host, self._port) + super().connect((self._host, self._port)) + + if self._bind: + super().bind(self._bind) + + +class UnixSocket(_Socket): + + def __init__(self, location, **kwargs): + super().__init__(family=socket.AF_UNIX, **kwargs) + + self._socket_path = location + + + def connect(self): + self.logger.info("Connection to unix://%s", self._socket_path) + super().connect(self._socket_path) + + +class _Listener(_Socket): + + def __init__(self, new_server_cb, instanciate=_Socket, **kwargs): + super().__init__(**kwargs) + + self._instanciate = instanciate + self._new_server_cb = new_server_cb + + def read(self): - if not self.connected: - return [] + conn, addr = self.accept() + self.logger.info("Accept new connection from %s", addr) - raw = self.socket.recv(1024) - temp = (self.readbuffer + raw).split(b'\r\n') - self.readbuffer = temp.pop() + fileno = conn.fileno() + ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach()) + ss.connect = ss._connected + self._new_server_cb(ss, autoconnect=True) - for line in temp: - yield line + return b'' -class SocketListener(AbstractServer): +class UnixSocketListener(_Listener, UnixSocket): - def __init__(self, new_server_cb, id, sock_location=None, host=None, port=None, ssl=None): - self.id = id - AbstractServer.__init__(self) - self.new_server_cb = new_server_cb - self.sock_location = sock_location - self.host = host - self.port = port - self.ssl = ssl - self.nb_son = 0 + def __init__(self, **kwargs): + super().__init__(**kwargs) - def fileno(self): - return self.socket.fileno() if self.socket else None + def connect(self): + self.logger.info("Creating Unix socket at unix://%s", self._socket_path) + + try: + os.remove(self._socket_path) + except FileNotFoundError: + pass + + self.bind(self._socket_path) + self.listen(5) + self.logger.info("Socket ready for accepting new connections") + + self._connected() - @property - def connected(self): - """Indicator of the connection aliveness""" - return self.socket is not None - - - def _open(self): - import os - import socket - - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - if self.sock_location is not None: - try: - os.remove(self.sock_location) - except FileNotFoundError: - pass - self.socket.bind(self.sock_location) - elif self.host is not None and self.port is not None: - self.socket.bind((self.host, self.port)) - self.socket.listen(5) - - return True - - - def _close(self): + def close(self): import os import socket try: - self.socket.shutdown(socket.SHUT_RDWR) - self.socket.close() - if self.sock_location is not None: - os.remove(self.sock_location) + self.shutdown(socket.SHUT_RDWR) except socket.error: pass - # Read + super().close() - def read(self): - if not self.connected: - return [] - - conn, addr = self.socket.accept() - self.nb_son += 1 - ss = SocketServer(id=self.id + "#" + str(self.nb_son), socket=conn) - self.new_server_cb(ss) - - return [] + try: + if self._socket_path is not None: + os.remove(self._socket_path) + except: + pass diff --git a/nemubot/tools/__init__.py b/nemubot/tools/__init__.py index 127154c..57f3468 100644 --- a/nemubot/tools/__init__.py +++ b/nemubot/tools/__init__.py @@ -13,29 +13,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - -def reload(): - import imp - - import nemubot.tools.config - imp.reload(nemubot.tools.config) - - import nemubot.tools.countdown - imp.reload(nemubot.tools.countdown) - - import nemubot.tools.feed - imp.reload(nemubot.tools.feed) - - import nemubot.tools.date - imp.reload(nemubot.tools.date) - - import nemubot.tools.human - imp.reload(nemubot.tools.human) - - import nemubot.tools.web - imp.reload(nemubot.tools.web) - - import nemubot.tools.xmlparser - imp.reload(nemubot.tools.xmlparser) - import nemubot.tools.xmlparser.node - imp.reload(nemubot.tools.xmlparser.node) diff --git a/nemubot/tools/human.py b/nemubot/tools/human.py index a18cde2..f0e947f 100644 --- a/nemubot/tools/human.py +++ b/nemubot/tools/human.py @@ -57,11 +57,15 @@ def word_distance(str1, str2): return d[len(str1)][len(str2)] -def guess(pattern, expect): +def guess(pattern, expect, max_depth=0): + if max_depth == 0: + max_depth = 1 + len(pattern) / 4 + elif max_depth <= -1: + max_depth = len(pattern) - max_depth if len(expect): se = sorted([(e, word_distance(pattern, e)) for e in expect], key=lambda x: x[1]) _, m = se[0] for e, wd in se: - if wd > m or wd > 1 + len(pattern) / 4: + if wd > m or wd > max_depth: break yield e diff --git a/nemubot/tools/xmlparser/__init__.py b/nemubot/tools/xmlparser/__init__.py index abc5bb9..687bf63 100644 --- a/nemubot/tools/xmlparser/__init__.py +++ b/nemubot/tools/xmlparser/__init__.py @@ -51,11 +51,13 @@ class XMLParser: def __init__(self, knodes): self.knodes = knodes + def _reset(self): self.stack = list() self.child = 0 def parse_file(self, path): + self._reset() p = xml.parsers.expat.ParserCreate() p.StartElementHandler = self.startElement @@ -69,6 +71,7 @@ class XMLParser: def parse_string(self, s): + self._reset() p = xml.parsers.expat.ParserCreate() p.StartElementHandler = self.startElement @@ -126,10 +129,13 @@ class XMLParser: if hasattr(self.current, "endElement"): self.current.endElement(None) + if hasattr(self.current, "parsedForm") and callable(self.current.parsedForm): + self.stack[-1] = self.current.parsedForm() + # Don't remove root if len(self.stack) > 1: last = self.stack.pop() - if hasattr(self.current, "addChild"): + if hasattr(self.current, "addChild") and callable(self.current.addChild): if self.current.addChild(name, last): return raise TypeError(name + " tag not expected in " + self.display_stack()) diff --git a/setup.py b/setup.py index b39a163..a400c3c 100755 --- a/setup.py +++ b/setup.py @@ -63,13 +63,13 @@ setup( 'nemubot', 'nemubot.config', 'nemubot.datastore', + 'nemubot.datastore.nodes', 'nemubot.event', 'nemubot.exception', 'nemubot.hooks', 'nemubot.hooks.keywords', 'nemubot.message', 'nemubot.message.printer', - 'nemubot.prompt', 'nemubot.server', 'nemubot.server.message', 'nemubot.tools',