diff --git a/README.md b/README.md index aa3b141..1d40faf 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,8 @@ Requirements *nemubot* requires at least Python 3.3 to work. +Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629). + Some modules (like `cve`, `nextstop` or `laposte`) require the [BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/), but the core and framework has no dependency. diff --git a/modules/cmd_server.py b/modules/cmd_server.py deleted file mode 100644 index 6580c18..0000000 --- a/modules/cmd_server.py +++ /dev/null @@ -1,202 +0,0 @@ -# Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . - -import traceback -import sys - -from nemubot.hooks import hook - -nemubotversion = 3.4 -NODATA = True - - -def getserver(toks, context, prompt, mandatory=False, **kwargs): - """Choose the server in toks or prompt. - This function modify the tokens list passed as argument""" - - if len(toks) > 1 and toks[1] in context.servers: - return context.servers[toks.pop(1)] - elif not mandatory or prompt.selectedServer: - return prompt.selectedServer - else: - from nemubot.prompt.error import PromptError - raise PromptError("Please SELECT a server or give its name in argument.") - - -@hook("prompt_cmd", "close") -def close(toks, context, **kwargs): - """Disconnect and forget (remove from the servers list) the server""" - srv = getserver(toks, context=context, mandatory=True, **kwargs) - - if srv.close(): - del context.servers[srv.id] - return 0 - return 1 - - -@hook("prompt_cmd", "connect") -def connect(toks, **kwargs): - """Make the connexion to a server""" - srv = getserver(toks, mandatory=True, **kwargs) - - return not srv.open() - - -@hook("prompt_cmd", "disconnect") -def disconnect(toks, **kwargs): - """Close the connection to a server""" - srv = getserver(toks, mandatory=True, **kwargs) - - return not srv.close() - - -@hook("prompt_cmd", "discover") -def discover(toks, context, **kwargs): - """Discover a new bot on a server""" - srv = getserver(toks, context=context, mandatory=True, **kwargs) - - if len(toks) > 1 and "!" in toks[1]: - bot = context.add_networkbot(srv, name) - return not bot.connect() - else: - print(" %s is not a valid fullname, for example: " - "nemubot!nemubotV3@bot.nemunai.re" % ''.join(toks[1:1])) - return 1 - - -@hook("prompt_cmd", "join") -@hook("prompt_cmd", "leave") -@hook("prompt_cmd", "part") -def join(toks, **kwargs): - """Join or leave a channel""" - srv = getserver(toks, mandatory=True, **kwargs) - - if len(toks) <= 2: - print("%s: not enough arguments." % toks[0]) - return 1 - - if toks[0] == "join": - if len(toks) > 2: - srv.write("JOIN %s %s" % (toks[1], toks[2])) - else: - srv.write("JOIN %s" % toks[1]) - - elif toks[0] == "leave" or toks[0] == "part": - if len(toks) > 2: - srv.write("PART %s :%s" % (toks[1], " ".join(toks[2:]))) - else: - srv.write("PART %s" % toks[1]) - - return 0 - - -@hook("prompt_cmd", "save") -def save_mod(toks, context, **kwargs): - """Force save module data""" - if len(toks) < 2: - print("save: not enough arguments.") - return 1 - - wrn = 0 - for mod in toks[1:]: - if mod in context.modules: - context.modules[mod].save() - print("save: module `%s´ saved successfully" % mod) - else: - wrn += 1 - print("save: no module named `%s´" % mod) - return wrn - - -@hook("prompt_cmd", "send") -def send(toks, **kwargs): - """Send a message on a channel""" - srv = getserver(toks, mandatory=True, **kwargs) - - # Check the server is connected - if not srv.connected: - print ("send: server `%s' not connected." % srv.id) - return 2 - - if len(toks) <= 3: - print ("send: not enough arguments.") - return 1 - - if toks[1] not in srv.channels: - print ("send: channel `%s' not authorized in server `%s'." - % (toks[1], srv.id)) - return 3 - - from nemubot.message import Text - srv.send_response(Text(" ".join(toks[2:]), server=None, - to=[toks[1]])) - return 0 - - -@hook("prompt_cmd", "zap") -def zap(toks, **kwargs): - """Hard change connexion state""" - srv = getserver(toks, mandatory=True, **kwargs) - - srv.connected = not srv.connected - - -@hook("prompt_cmd", "top") -def top(toks, context, **kwargs): - """Display consumers load information""" - print("Queue size: %d, %d thread(s) running (counter: %d)" % - (context.cnsr_queue.qsize(), - len(context.cnsr_thrd), - context.cnsr_thrd_size)) - if len(context.events) > 0: - print("Events registered: %d, next in %d seconds" % - (len(context.events), - context.events[0].time_left.seconds)) - else: - print("No events registered") - - for th in context.cnsr_thrd: - if th.is_alive(): - print(("#" * 15 + " Stack trace for thread %u " + "#" * 15) % - th.ident) - traceback.print_stack(sys._current_frames()[th.ident]) - - -@hook("prompt_cmd", "netstat") -def netstat(toks, context, **kwargs): - """Display sockets in use and many other things""" - if len(context.network) > 0: - print("Distant bots connected: %d:" % len(context.network)) - for name, bot in context.network.items(): - print("# %s:" % name) - print(" * Declared hooks:") - lvl = 0 - for hlvl in bot.hooks: - lvl += 1 - for hook in (hlvl.all_pre + hlvl.all_post + hlvl.cmd_rgxp + - hlvl.cmd_default + hlvl.ask_rgxp + - hlvl.ask_default + hlvl.msg_rgxp + - hlvl.msg_default): - print(" %s- %s" % (' ' * lvl * 2, hook)) - for kind in ["irc_hook", "cmd_hook", "ask_hook", "msg_hook"]: - print(" %s- <%s> %s" % (' ' * lvl * 2, kind, - ", ".join(hlvl.__dict__[kind].keys()))) - print(" * My tag: %d" % bot.my_tag) - print(" * Tags in use (%d):" % bot.inc_tag) - for tag, (cmd, data) in bot.tags.items(): - print(" - %11s: %s « %s »" % (tag, cmd, data)) - else: - print("No distant bot connected") diff --git a/modules/rnd.py b/modules/rnd.py index 32c2adf..5329b06 100644 --- a/modules/rnd.py +++ b/modules/rnd.py @@ -8,7 +8,6 @@ import shlex from nemubot import context from nemubot.exception import IMException from nemubot.hooks import hook -from nemubot.message import Command from more import Response @@ -32,8 +31,24 @@ def cmd_choicecmd(msg): choice = shlex.split(random.choice(msg.args)) - return [x for x in context.subtreat(Command(choice[0][1:], - choice[1:], - to_response=msg.to_response, - frm=msg.frm, - server=msg.server))] + return [x for x in context.subtreat(context.subparse(msg, choice))] + + +@hook.command("choiceres") +def cmd_choiceres(msg): + if not len(msg.args): + raise IMException("indicate some command to pick a message from!") + + rl = [x for x in context.subtreat(context.subparse(msg, " ".join(msg.args)))] + if len(rl) <= 0: + return rl + + r = random.choice(rl) + + if isinstance(r, Response): + for i in range(len(r.messages) - 1, -1, -1): + if isinstance(r.messages[i], list): + r.messages = [ random.choice(random.choice(r.messages)) ] + elif isinstance(r.messages[i], str): + r.messages = [ random.choice(r.messages) ] + return r diff --git a/nemubot/__init__.py b/nemubot/__init__.py index d0a2072..8ec9c1a 100644 --- a/nemubot/__init__.py +++ b/nemubot/__init__.py @@ -18,7 +18,8 @@ __version__ = '4.0.dev3' __author__ = 'nemunaire' from nemubot.modulecontext import ModuleContext -context = ModuleContext(None, None) + +context = ModuleContext(None, None, None) def requires_version(min=None, max=None): @@ -38,62 +39,98 @@ def requires_version(min=None, max=None): "but this is nemubot v%s." % (str(max), __version__)) -def reload(): - """Reload code of all Python modules used by nemubot +def attach(pid, socketfile): + import socket + import sys + + print("nemubot is already launched with PID %d. Attaching to Unix socket at: %s" % (pid, socketfile)) + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.connect(socketfile) + except socket.error as e: + sys.stderr.write(str(e)) + sys.stderr.write("\n") + return 1 + + from select import select + try: + print("Connection established.") + while True: + rl, wl, xl = select([sys.stdin, sock], [], []) + + if sys.stdin in rl: + line = sys.stdin.readline().strip() + if line == "exit" or line == "quit": + return 0 + elif line == "reload": + import os, signal + os.kill(pid, signal.SIGHUP) + print("Reload signal sent. Please wait...") + + elif line == "shutdown": + import os, signal + os.kill(pid, signal.SIGTERM) + print("Shutdown signal sent. Please wait...") + + elif line == "kill": + import os, signal + os.kill(pid, signal.SIGKILL) + print("Signal sent...") + return 0 + + elif line == "stack" or line == "stacks": + import os, signal + os.kill(pid, signal.SIGUSR1) + print("Debug signal sent. Consult logs.") + + else: + sock.send(line.encode() + b'\r\n') + + if sock in rl: + sys.stdout.write(sock.recv(2048).decode()) + except KeyboardInterrupt: + pass + except: + return 1 + finally: + sock.close() + return 0 + + +def daemonize(): + """Detach the running process to run as a daemon """ - import imp + import os + import sys - import nemubot.channel - imp.reload(nemubot.channel) + try: + pid = os.fork() + if pid > 0: + sys.exit(0) + except OSError as err: + sys.stderr.write("Unable to fork: %s\n" % err) + sys.exit(1) - import nemubot.config - imp.reload(nemubot.config) + os.setsid() + os.umask(0) + os.chdir('/') - nemubot.config.reload() + try: + pid = os.fork() + if pid > 0: + sys.exit(0) + except OSError as err: + sys.stderr.write("Unable to fork: %s\n" % err) + sys.exit(1) - import nemubot.consumer - imp.reload(nemubot.consumer) + sys.stdout.flush() + sys.stderr.flush() + si = open(os.devnull, 'r') + so = open(os.devnull, 'a+') + se = open(os.devnull, 'a+') - import nemubot.datastore - imp.reload(nemubot.datastore) - - nemubot.datastore.reload() - - import nemubot.event - imp.reload(nemubot.event) - - import nemubot.exception - imp.reload(nemubot.exception) - - nemubot.exception.reload() - - import nemubot.hooks - imp.reload(nemubot.hooks) - - nemubot.hooks.reload() - - import nemubot.importer - imp.reload(nemubot.importer) - - import nemubot.message - imp.reload(nemubot.message) - - nemubot.message.reload() - - import nemubot.prompt - imp.reload(nemubot.prompt) - - nemubot.prompt.reload() - - import nemubot.server - rl, wl, xl = nemubot.server._rlist, nemubot.server._wlist, nemubot.server._xlist - imp.reload(nemubot.server) - nemubot.server._rlist, nemubot.server._wlist, nemubot.server._xlist = rl, wl, xl - - nemubot.server.reload() - - import nemubot.tools - imp.reload(nemubot.tools) - - nemubot.tools.reload() + os.dup2(si.fileno(), sys.stdin.fileno()) + os.dup2(so.fileno(), sys.stdout.fileno()) + os.dup2(se.fileno(), sys.stderr.fileno()) diff --git a/nemubot/__main__.py b/nemubot/__main__.py index 1809bee..e4acd5c 100644 --- a/nemubot/__main__.py +++ b/nemubot/__main__.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,9 @@ # along with this program. If not, see . def main(): + import functools import os + import signal import sys # Parse command line arguments @@ -36,6 +38,15 @@ def main(): default=["./modules/"], help="directory to use as modules store") + parser.add_argument("-d", "--debug", action="store_true", + help="don't deamonize, keep in foreground") + + parser.add_argument("-P", "--pidfile", default="./nemubot.pid", + help="Path to the file where store PID") + + parser.add_argument("-S", "--socketfile", default="./nemubot.sock", + help="path where open the socket for internal communication") + parser.add_argument("-l", "--logfile", default="./nemubot.log", help="Path to store logs") @@ -58,11 +69,35 @@ def main(): # Resolve relatives paths args.data_path = os.path.abspath(os.path.expanduser(args.data_path)) + args.pidfile = os.path.abspath(os.path.expanduser(args.pidfile)) + args.socketfile = os.path.abspath(os.path.expanduser(args.socketfile)) args.logfile = os.path.abspath(os.path.expanduser(args.logfile)) - args.files = [ x for x in map(os.path.abspath, args.files)] - args.modules_path = [ x for x in map(os.path.abspath, args.modules_path)] + args.files = [x for x in map(os.path.abspath, args.files)] + args.modules_path = [x for x in map(os.path.abspath, args.modules_path)] - # Setup loggin interface + # Check if an instance is already launched + if args.pidfile is not None and os.path.isfile(args.pidfile): + with open(args.pidfile, "r") as f: + pid = int(f.readline()) + try: + os.kill(pid, 0) + except OSError: + pass + else: + from nemubot import attach + sys.exit(attach(pid, args.socketfile)) + + # Daemonize + if not args.debug: + from nemubot import daemonize + daemonize() + + # Store PID to pidfile + if args.pidfile is not None: + with open(args.pidfile, "w+") as f: + f.write(str(os.getpid())) + + # Setup logging interface import logging logger = logging.getLogger("nemubot") logger.setLevel(logging.DEBUG) @@ -70,11 +105,12 @@ def main(): formatter = logging.Formatter( '%(asctime)s %(name)s %(levelname)s %(message)s') - ch = logging.StreamHandler() - ch.setFormatter(formatter) - if args.verbose < 2: - ch.setLevel(logging.INFO) - logger.addHandler(ch) + if args.debug: + ch = logging.StreamHandler() + ch.setFormatter(formatter) + if args.verbose < 2: + ch.setLevel(logging.INFO) + logger.addHandler(ch) fh = logging.FileHandler(args.logfile) fh.setFormatter(formatter) @@ -90,64 +126,89 @@ def main(): # Create bot context from nemubot import datastore - from nemubot.bot import Bot - context = Bot(modules_paths=modules_paths, - data_store=datastore.XML(args.data_path), - verbosity=args.verbose) + from nemubot.bot import Bot#, sync_act + context = Bot() if args.no_connect: context.noautoconnect = True - # Load the prompt - import nemubot.prompt - prmpt = nemubot.prompt.Prompt() - # Register the hook for futur import - from nemubot.importer import ModuleFinder - sys.meta_path.append(ModuleFinder(context.modules_paths, context.add_module)) + #from nemubot.importer import ModuleFinder + #module_finder = ModuleFinder(modules_paths, context.add_module) + #sys.meta_path.append(module_finder) # Load requested configuration files - for path in args.files: - if os.path.isfile(path): - context.sync_queue.put_nowait(["loadconf", path]) - else: - logger.error("%s is not a readable file", path) + #for path in args.files: + # if os.path.isfile(path): + # sync_act("loadconf", path) + # else: + # logger.error("%s is not a readable file", path) if args.module: for module in args.module: __import__(module) - print ("Nemubot v%s ready, my PID is %i!" % (nemubot.__version__, - os.getpid())) - while True: - from nemubot.prompt.reset import PromptReset - try: - context.start() - if prmpt.run(context): - break - except PromptReset as e: - if e.type == "quit": - break + # Signals handling ################################################ - try: - import imp - # Reload all other modules - imp.reload(nemubot) - imp.reload(nemubot.prompt) - nemubot.reload() - import nemubot.bot - context = nemubot.bot.hotswap(context) - prmpt = nemubot.prompt.hotswap(prmpt) - print("\033[1;32mContext reloaded\033[0m, now in Nemubot %s" % - nemubot.__version__) - except: - logger.exception("\033[1;31mUnable to reload the prompt due to " - "errors.\033[0m Fix them before trying to reload " - "the prompt.") + # SIGINT and SIGTERM - context.quit() - print("Waiting for other threads shuts down...") + def ask_exit(signame): + """On SIGTERM and SIGINT, quit nicely""" + context.stop() + + for sig in (signal.SIGINT, signal.SIGTERM): + context._loop.add_signal_handler(sig, functools.partial(ask_exit, sig)) + + + # SIGHUP + + def ask_reload(): + """Perform a deep reload""" + nonlocal context + + logger.debug("SIGHUP receive, iniate reload procedure...") + + # Reload configuration file + #for path in args.files: + # if os.path.isfile(path): + # sync_act("loadconf", path) + context._loop.add_signal_handler(signal.SIGHUP, ask_reload) + + + # SIGUSR1 + + def ask_debug_info(): + """Display debug informations and stacktraces""" + import threading, traceback + for threadId, stack in sys._current_frames().items(): + thName = "#%d" % threadId + for th in threading.enumerate(): + if th.ident == threadId: + thName = th.name + break + logger.debug("########### Thread %s:\n%s", + thName, + "".join(traceback.format_stack(stack))) + context._loop.add_signal_handler(signal.SIGUSR1, ask_debug_info) + + #if args.socketfile: + # from nemubot.server.socket import UnixSocketListener + # context.add_server(UnixSocketListener(new_server_cb=context.add_server, + # location=args.socketfile, + # name="master_socket")) + + # context can change when performing an hotswap, always join the latest context + oldcontext = None + while oldcontext != context: + oldcontext = context + context.run() + + # Wait for consumers + logger.info("Waiting for other threads shuts down...") + if args.debug: + ask_debug_info() sys.exit(0) + if __name__ == "__main__": main() diff --git a/nemubot/bot.py b/nemubot/bot.py index e449f35..dc7fa08 100644 --- a/nemubot/bot.py +++ b/nemubot/bot.py @@ -14,568 +14,34 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from datetime import datetime, timezone import logging -import threading import sys -from nemubot import __version__ -from nemubot.consumer import Consumer, EventConsumer, MessageConsumer -from nemubot import datastore -import nemubot.hooks -logger = logging.getLogger("nemubot") - - -class Bot(threading.Thread): +class Bot: """Class containing the bot context and ensuring key goals""" - def __init__(self, ip="127.0.0.1", modules_paths=list(), - data_store=datastore.Abstract(), verbosity=0): - """Initialize the bot context + logger = logging.getLogger("nemubot") - Keyword arguments: - ip -- The external IP of the bot (default: 127.0.0.1) - modules_paths -- Paths to all directories where looking for module - data_store -- An instance of the nemubot datastore for bot's modules + def __init__(self): + """Initialize the bot context """ - threading.Thread.__init__(self) + from nemubot import __version__ + Bot.logger.info("Initiate nemubot v%s, running on Python %s", + __version__, sys.version.split("\n")[0]) - logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)", - __version__, - sys.version_info.major, sys.version_info.minor, sys.version_info.micro) - - self.verbosity = verbosity - self.stop = None - - # External IP for accessing this bot - import ipaddress - self.ip = ipaddress.ip_address(ip) - - # Context paths - self.modules_paths = modules_paths - self.datastore = data_store - self.datastore.open() - - # Keep global context: servers and modules - self.servers = dict() - self.modules = dict() - self.modules_configuration = dict() - - # Events - self.events = list() - self.event_timer = None - - # Own hooks - from nemubot.treatment import MessageTreater - self.treater = MessageTreater() - - import re - def in_ping(msg): - return msg.respond("pong") - self.treater.hm.add_hook(nemubot.hooks.Message(in_ping, - match=lambda msg: re.match("^ *(m[' ]?entends?[ -]+tu|h?ear me|do you copy|ping)", - msg.message, re.I)), - "in", "DirectAsk") - - def in_echo(msg): - from nemubot.message import Text - return Text(msg.nick + ": " + " ".join(msg.args), to=msg.to_response) - self.treater.hm.add_hook(nemubot.hooks.Command(in_echo, "echo"), "in", "Command") - - def _help_msg(msg): - """Parse and response to help messages""" - from more import Response - res = Response(channel=msg.to_response) - if len(msg.args) >= 1: - if msg.args[0] in self.modules: - if hasattr(self.modules[msg.args[0]], "help_full"): - hlp = self.modules[msg.args[0]].help_full() - if isinstance(hlp, Response): - return hlp - else: - res.append_message(hlp) - else: - res.append_message([str(h) for s,h in self.modules[msg.args[0]].__nemubot_context__.hooks], title="Available commands for module " + msg.args[0]) - elif msg.args[0][0] == "!": - from nemubot.message.command import Command - for h in self.treater._in_hooks(Command(msg.args[0][1:])): - if h.help_usage: - lp = ["\x03\x02%s%s\x03\x02: %s" % (msg.args[0], (" " + k if k is not None else ""), h.help_usage[k]) for k in h.help_usage] - jp = h.keywords.help() - return res.append_message(lp + ([". Moreover, you can provides some optional parameters: "] + jp if len(jp) else []), title="Usage for command %s" % msg.args[0]) - elif h.help: - return res.append_message("Command %s: %s" % (msg.args[0], h.help)) - else: - return res.append_message("Sorry, there is currently no help for the command %s. Feel free to make a pull request at https://github.com/nemunaire/nemubot/compare" % msg.args[0]) - res.append_message("Sorry, there is no command %s" % msg.args[0]) - else: - res.append_message("Sorry, there is no module named %s" % msg.args[0]) - else: - res.append_message("Pour me demander quelque chose, commencez " - "votre message par mon nom ; je réagis " - "également à certaine commandes commençant par" - " !. Pour plus d'informations, envoyez le " - "message \"!more\".") - res.append_message("Mon code source est libre, publié sous " - "licence AGPL (http://www.gnu.org/licenses/). " - "Vous pouvez le consulter, le dupliquer, " - "envoyer des rapports de bogues ou bien " - "contribuer au projet sur GitHub : " - "http://github.com/nemunaire/nemubot/") - res.append_message(title="Pour plus de détails sur un module, " - "envoyez \"!help nomdumodule\". Voici la liste" - " de tous les modules disponibles localement", - message=["\x03\x02%s\x03\x02 (%s)" % (im, self.modules[im].__doc__) for im in self.modules if self.modules[im].__doc__]) - return res - self.treater.hm.add_hook(nemubot.hooks.Command(_help_msg, "help"), "in", "Command") - - from queue import Queue - # Messages to be treated - self.cnsr_queue = Queue() - self.cnsr_thrd = list() - self.cnsr_thrd_size = -1 - # Synchrone actions to be treated by main thread - self.sync_queue = Queue() + import asyncio + self._loop = asyncio.get_event_loop() def run(self): - from select import select - from nemubot.server import _lock, _rlist, _wlist, _xlist + self._loop.run_forever() + self._loop.close() - self.stop = False - while not self.stop: - with _lock: - try: - rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1) - except: - logger.error("Something went wrong in select") - fnd_smth = False - # Looking for invalid server - for r in _rlist: - if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0: - _rlist.remove(r) - logger.error("Found invalid object in _rlist: " + str(r)) - fnd_smth = True - for w in _wlist: - if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0: - _wlist.remove(w) - logger.error("Found invalid object in _wlist: " + str(w)) - fnd_smth = True - for x in _xlist: - if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0: - _xlist.remove(x) - logger.error("Found invalid object in _xlist: " + str(x)) - fnd_smth = True - if not fnd_smth: - logger.exception("Can't continue, sorry") - self.quit() - continue - for x in xl: - try: - x.exception() - except: - logger.exception("Uncatched exception on server exception") - for w in wl: - try: - w.write_select() - except: - logger.exception("Uncatched exception on server write") - for r in rl: - for i in r.read(): - try: - self.receive_message(r, i) - except: - logger.exception("Uncatched exception on server read") - - - # Launch new consumer threads if necessary - while self.cnsr_queue.qsize() > self.cnsr_thrd_size: - # Next launch if two more items in queue - self.cnsr_thrd_size += 2 - - c = Consumer(self) - self.cnsr_thrd.append(c) - c.start() - - while self.sync_queue.qsize() > 0: - action = self.sync_queue.get_nowait() - if action[0] == "exit": - self.quit() - elif action[0] == "loadconf": - for path in action[1:]: - logger.debug("Load configuration from %s", path) - self.load_file(path) - logger.info("Configurations successfully loaded") - self.sync_queue.task_done() - - - - # Config methods - - def load_file(self, filename): - """Load a configuration file - - Arguments: - filename -- the path to the file to load - """ - - import os - - # Unexisting file, assume a name was passed, import the module! - if not os.path.isfile(filename): - return self.import_module(filename) - - from nemubot.channel import Channel - from nemubot import config - from nemubot.tools.xmlparser import XMLParser - - try: - p = XMLParser({ - "nemubotconfig": config.Nemubot, - "server": config.Server, - "channel": Channel, - "module": config.Module, - "include": config.Include, - }) - config = p.parse_file(filename) - except: - logger.exception("Can't load `%s'; this is not a valid nemubot " - "configuration file." % filename) - return False - - # Preset each server in this file - for server in config.servers: - srv = server.server(config) - # Add the server in the context - if self.add_server(srv, server.autoconnect): - logger.info("Server '%s' successfully added." % srv.id) - else: - logger.error("Can't add server '%s'." % srv.id) - - # Load module and their configuration - for mod in config.modules: - self.modules_configuration[mod.name] = mod - if mod.autoload: - try: - __import__(mod.name) - except: - logger.exception("Exception occurs when loading module" - " '%s'", mod.name) - - - # Load files asked by the configuration file - for load in config.includes: - self.load_file(load.path) - - - # Events methods - - def add_event(self, evt, eid=None, module_src=None): - """Register an event and return its identifiant for futur update - - Return: - None if the event is not in the queue (eg. if it has been executed during the call) or - returns the event ID. - - Argument: - evt -- The event object to add - - Keyword arguments: - eid -- The desired event ID (object or string UUID) - module_src -- The module to which the event is attached to - """ - - if hasattr(self, "stop") and self.stop: - logger.warn("The bot is stopped, can't register new events") - return - - import uuid - - # Generate the event id if no given - if eid is None: - eid = uuid.uuid1() - - # Fill the id field of the event - if type(eid) is uuid.UUID: - evt.id = str(eid) - else: - # Ok, this is quite useless... - try: - evt.id = str(uuid.UUID(eid)) - except ValueError: - evt.id = eid - - # TODO: mutex here plz - - # Add the event in its place - t = evt.current - i = 0 # sentinel - for i in range(0, len(self.events)): - if self.events[i].current > t: - break - self.events.insert(i, evt) - - if i == 0: - # First event changed, reset timer - self._update_event_timer() - if len(self.events) <= 0 or self.events[i] != evt: - # Our event has been executed and removed from queue - return None - - # Register the event in the source module - if module_src is not None: - module_src.__nemubot_context__.events.append(evt.id) - evt.module_src = module_src - - logger.info("New event registered in %d position: %s", i, t) - return evt.id - - - def del_event(self, evt, module_src=None): - """Find and remove an event from list - - Return: - True if the event has been found and removed, False else - - Argument: - evt -- The ModuleEvent object to remove or just the event identifier - - Keyword arguments: - module_src -- The module to which the event is attached to (ignored if evt is a ModuleEvent) - """ - - logger.info("Removing event: %s from %s", evt, module_src) - - from nemubot.event import ModuleEvent - if type(evt) is ModuleEvent: - id = evt.id - module_src = evt.module_src - else: - id = evt - - if len(self.events) > 0 and id == self.events[0].id: - self.events.remove(self.events[0]) - self._update_event_timer() - if module_src is not None: - module_src.__nemubot_context__.events.remove(id) - return True - - for evt in self.events: - if evt.id == id: - self.events.remove(evt) - - if module_src is not None: - module_src.__nemubot_context__.events.remove(evt.id) - return True - return False - - - def _update_event_timer(self): - """(Re)launch the timer to end with the closest event""" - - # Reset the timer if this is the first item - if self.event_timer is not None: - self.event_timer.cancel() - - if len(self.events): - remaining = self.events[0].time_left.total_seconds() - logger.debug("Update timer: next event in %d seconds", remaining) - self.event_timer = threading.Timer(remaining if remaining > 0 else 0, self._end_event_timer) - self.event_timer.start() - - else: - logger.debug("Update timer: no timer left") - - - def _end_event_timer(self): - """Function called at the end of the event timer""" - - while len(self.events) > 0 and datetime.now(timezone.utc) >= self.events[0].current: - evt = self.events.pop(0) - self.cnsr_queue.put_nowait(EventConsumer(evt)) - - self._update_event_timer() - - - # Consumers methods - - def add_server(self, srv, autoconnect=True): - """Add a new server to the context - - Arguments: - srv -- a concrete AbstractServer instance - autoconnect -- connect after add? - """ - - if srv.id not in self.servers: - self.servers[srv.id] = srv - if autoconnect and not hasattr(self, "noautoconnect"): - srv.open() - return True - - else: - return False - - - # Modules methods - - def import_module(self, name): - """Load a module - - Argument: - name -- name of the module to load - """ - - if name in self.modules: - self.unload_module(name) - - __import__(name) - - - def add_module(self, module): - """Add a module to the context, if already exists, unload the - old one before""" - module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__ - - if hasattr(self, "stop") and self.stop: - logger.warn("The bot is stopped, can't register new modules") - return - - # Check if the module already exists - if module_name in self.modules: - self.unload_module(module_name) - - # Overwrite print built-in - def prnt(*args): - if hasattr(module, "logger"): - module.logger.info(" ".join([str(s) for s in args])) - else: - logger.info("[%s] %s", module_name, " ".join([str(s) for s in args])) - module.print = prnt - - # Create module context - from nemubot.modulecontext import ModuleContext - module.__nemubot_context__ = ModuleContext(self, module) - - if not hasattr(module, "logger"): - module.logger = logging.getLogger("nemubot.module." + module_name) - - # Replace imported context by real one - for attr in module.__dict__: - if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext: - module.__dict__[attr] = module.__nemubot_context__ - - # Register decorated functions - import nemubot.hooks - for s, h in nemubot.hooks.hook.last_registered: - module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s) - nemubot.hooks.hook.last_registered = [] - - # Launch the module - if hasattr(module, "load"): - try: - module.load(module.__nemubot_context__) - except: - module.__nemubot_context__.unload() - raise - - # Save a reference to the module - self.modules[module_name] = module - - - def unload_module(self, name): - """Unload a module""" - if name in self.modules: - self.modules[name].print("Unloading module %s" % name) - - # Call the user defined unload method - if hasattr(self.modules[name], "unload"): - self.modules[name].unload(self) - self.modules[name].__nemubot_context__.unload() - - # Remove from the nemubot dict - del self.modules[name] - - # Remove from the Python dict - del sys.modules[name] - for mod in [i for i in sys.modules]: - if mod[:len(name) + 1] == name + ".": - logger.debug("Module '%s' also removed from system modules list.", mod) - del sys.modules[mod] - - logger.info("Module `%s' successfully unloaded.", name) - - return True - return False - - - def receive_message(self, srv, msg): - """Queued the message for treatment - - Arguments: - srv -- The server where the message comes from - msg -- The message not parsed, as simple as possible - """ - - self.cnsr_queue.put_nowait(MessageConsumer(srv, msg)) - - - def quit(self): + def stop(self): """Save and unload modules and disconnect servers""" - self.datastore.close() - - if self.event_timer is not None: - logger.info("Stop the event timer...") - self.event_timer.cancel() - - logger.info("Stop consumers") - k = self.cnsr_thrd - for cnsr in k: - cnsr.stop = True - - logger.info("Save and unload all modules...") - k = list(self.modules.keys()) - for mod in k: - self.unload_module(mod) - - logger.info("Close all servers connection...") - k = list(self.servers.keys()) - for srv in k: - self.servers[srv].close() - - self.stop = True - - - # Treatment - - def check_rest_times(self, store, hook): - """Remove from store the hook if it has been executed given time""" - if hook.times == 0: - if isinstance(store, dict): - store[hook.name].remove(hook) - if len(store) == 0: - del store[hook.name] - elif isinstance(store, list): - store.remove(hook) - - -def hotswap(bak): - bak.stop = True - if bak.event_timer is not None: - bak.event_timer.cancel() - bak.datastore.close() - - new = Bot(str(bak.ip), bak.modules_paths, bak.datastore) - new.servers = bak.servers - new.modules = bak.modules - new.modules_configuration = bak.modules_configuration - new.events = bak.events - new.hooks = bak.hooks - - new._update_event_timer() - return new + self._loop.stop() diff --git a/nemubot/config/__init__.py b/nemubot/config/__init__.py index 7e0b74a..6bbc1b2 100644 --- a/nemubot/config/__init__.py +++ b/nemubot/config/__init__.py @@ -24,24 +24,3 @@ from nemubot.config.include import Include from nemubot.config.module import Module from nemubot.config.nemubot import Nemubot from nemubot.config.server import Server - -def reload(): - global Include, Module, Nemubot, Server - - import imp - - import nemubot.config.include - imp.reload(nemubot.config.include) - Include = nemubot.config.include.Include - - import nemubot.config.module - imp.reload(nemubot.config.module) - Module = nemubot.config.module.Module - - import nemubot.config.nemubot - imp.reload(nemubot.config.nemubot) - Nemubot = nemubot.config.nemubot.Nemubot - - import nemubot.config.server - imp.reload(nemubot.config.server) - Server = nemubot.config.server.Server diff --git a/nemubot/config/server.py b/nemubot/config/server.py index 14ca9a8..c08b40c 100644 --- a/nemubot/config/server.py +++ b/nemubot/config/server.py @@ -33,13 +33,11 @@ class Server: return True - def server(self, parent): + def server(self, caps=[], **kwargs): from nemubot.server import factory - for a in ["nick", "owner", "realname", "encoding"]: - if a not in self.args: - self.args[a] = getattr(parent, a) + caps += self.caps - self.caps += parent.caps + kwargs.update(self.args) - return factory(self.uri, caps=self.caps, channels=self.channels, **self.args) + return factory(self.uri, caps=caps, channels=self.channels, **kwargs) diff --git a/nemubot/consumer.py b/nemubot/consumer.py index 886c4cf..37234aa 100644 --- a/nemubot/consumer.py +++ b/nemubot/consumer.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -33,47 +33,47 @@ class MessageConsumer: def run(self, context): """Create, parse and treat the message""" + # Dereference weakptr + srv = self.srv() + if srv is None: + return + from nemubot.bot import Bot assert isinstance(context, Bot) msgs = [] - # Parse the message + # Parse message try: - for msg in self.srv.parse(self.orig): + for msg in srv.parse(self.orig): msgs.append(msg) except: logger.exception("Error occurred during the processing of the %s: " - "%s", type(self.msgs[0]).__name__, self.msgs[0]) + "%s", type(self.orig).__name__, self.orig) - if len(msgs) <= 0: - return - - # Qualify the message - if not hasattr(msg, "server") or msg.server is None: - msg.server = self.srv.id - if hasattr(msg, "frm_owner"): - msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm) - - # Treat the message + # Treat message for msg in msgs: for res in context.treater.treat_msg(msg): - # Identify the destination + # Identify destination to_server = None if isinstance(res, str): - to_server = self.srv + to_server = srv + elif not hasattr(res, "server"): + logger.error("No server defined for response of type %s: %s", type(res).__name__, res) + continue elif res.server is None: - to_server = self.srv - res.server = self.srv.id - elif isinstance(res.server, str) and res.server in context.servers: + to_server = srv + res.server = srv.fileno() + elif res.server in context.servers: to_server = context.servers[res.server] + else: + to_server = res.server - if to_server is None: - logger.error("The server defined in this response doesn't " - "exist: %s", res.server) + if to_server is None or not hasattr(to_server, "send_response") or not callable(to_server.send_response): + logger.error("The server defined in this response doesn't exist: %s", res.server) continue - # Sent the message only if treat_post authorize it + # Sent message to_server.send_response(res) @@ -94,12 +94,7 @@ class EventConsumer: # Reappend the event in the queue if it has next iteration if self.evt.next is not None: - context.add_event(self.evt, eid=self.evt.id) - - # Or remove reference of this event - elif (hasattr(self.evt, "module_src") and - self.evt.module_src is not None): - self.evt.module_src.__nemubot_context__.events.remove(self.evt.id) + context.add_event(self.evt) @@ -108,20 +103,23 @@ class Consumer(threading.Thread): """Dequeue and exec requested action""" def __init__(self, context): + super().__init__(name="Nemubot consumer") self.context = context - self.stop = False - threading.Thread.__init__(self) def run(self): try: - while not self.stop: - stm = self.context.cnsr_queue.get(True, 1) - stm.run(self.context) - self.context.cnsr_queue.task_done() + while True: + context = self.context() + if context is None: + break + + stm = context.cnsr_queue.get(True, 1) + stm.run(context) + context.cnsr_queue.task_done() except queue.Empty: pass finally: - self.context.cnsr_thrd_size -= 2 - self.context.cnsr_thrd.remove(self) + if self.context() is not None: + self.context().cnsr_thrd.remove(self) diff --git a/nemubot/datastore/__init__.py b/nemubot/datastore/__init__.py index 411eab1..3e38ad2 100644 --- a/nemubot/datastore/__init__.py +++ b/nemubot/datastore/__init__.py @@ -16,16 +16,3 @@ from nemubot.datastore.abstract import Abstract from nemubot.datastore.xml import XML - - -def reload(): - global Abstract, XML - import imp - - import nemubot.datastore.abstract - imp.reload(nemubot.datastore.abstract) - Abstract = nemubot.datastore.abstract.Abstract - - import nemubot.datastore.xml - imp.reload(nemubot.datastore.xml) - XML = nemubot.datastore.xml.XML diff --git a/nemubot/exception/__init__.py b/nemubot/exception/__init__.py index 1e34923..84464a0 100644 --- a/nemubot/exception/__init__.py +++ b/nemubot/exception/__init__.py @@ -32,10 +32,3 @@ class IMException(Exception): from nemubot.message import Text return Text(*self.args, server=msg.server, to=msg.to_response) - - -def reload(): - import imp - - import nemubot.exception.Keyword - imp.reload(nemubot.exception.printer.IRC) diff --git a/nemubot/hooks/__init__.py b/nemubot/hooks/__init__.py index e9113eb..9024494 100644 --- a/nemubot/hooks/__init__.py +++ b/nemubot/hooks/__init__.py @@ -49,23 +49,3 @@ class hook: def pre(*args, store=["pre"], **kwargs): return hook._add(store, Abstract, *args, **kwargs) - - -def reload(): - import imp - - import nemubot.hooks.abstract - imp.reload(nemubot.hooks.abstract) - - import nemubot.hooks.command - imp.reload(nemubot.hooks.command) - - import nemubot.hooks.message - imp.reload(nemubot.hooks.message) - - import nemubot.hooks.keywords - imp.reload(nemubot.hooks.keywords) - nemubot.hooks.keywords.reload() - - import nemubot.hooks.manager - imp.reload(nemubot.hooks.manager) diff --git a/nemubot/hooks/keywords/__init__.py b/nemubot/hooks/keywords/__init__.py index 4b6419a..598b04f 100644 --- a/nemubot/hooks/keywords/__init__.py +++ b/nemubot/hooks/keywords/__init__.py @@ -26,11 +26,22 @@ class NoKeyword(Abstract): return super().check(mkw) -def reload(): - import imp +class AnyKeyword(Abstract): - import nemubot.hooks.keywords.abstract - imp.reload(nemubot.hooks.keywords.abstract) + def __init__(self, h): + """Class that accepts any passed keywords - import nemubot.hooks.keywords.dict - imp.reload(nemubot.hooks.keywords.dict) + Arguments: + h -- Help string + """ + + super().__init__() + self.h = h + + + def check(self, mkw): + return super().check(mkw) + + + def help(self): + return self.h diff --git a/nemubot/message/__init__.py b/nemubot/message/__init__.py index 31d7313..4d69dbb 100644 --- a/nemubot/message/__init__.py +++ b/nemubot/message/__init__.py @@ -19,27 +19,3 @@ from nemubot.message.text import Text from nemubot.message.directask import DirectAsk from nemubot.message.command import Command from nemubot.message.command import OwnerCommand - - -def reload(): - global Abstract, Text, DirectAsk, Command, OwnerCommand - import imp - - import nemubot.message.abstract - imp.reload(nemubot.message.abstract) - Abstract = nemubot.message.abstract.Abstract - imp.reload(nemubot.message.text) - Text = nemubot.message.text.Text - imp.reload(nemubot.message.directask) - DirectAsk = nemubot.message.directask.DirectAsk - imp.reload(nemubot.message.command) - Command = nemubot.message.command.Command - OwnerCommand = nemubot.message.command.OwnerCommand - - import nemubot.message.visitor - imp.reload(nemubot.message.visitor) - - import nemubot.message.printer - imp.reload(nemubot.message.printer) - - nemubot.message.printer.reload() diff --git a/nemubot/message/command.py b/nemubot/message/command.py index 895d16e..6c208b2 100644 --- a/nemubot/message/command.py +++ b/nemubot/message/command.py @@ -22,7 +22,7 @@ class Command(Abstract): """This class represents a specialized TextMessage""" def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs): - Abstract.__init__(self, *nargs, **kargs) + super().__init__(*nargs, **kargs) self.cmd = cmd self.args = args if args is not None else list() diff --git a/nemubot/message/directask.py b/nemubot/message/directask.py index 03c7902..3b1fabb 100644 --- a/nemubot/message/directask.py +++ b/nemubot/message/directask.py @@ -28,7 +28,7 @@ class DirectAsk(Text): designated -- the user designated by the message """ - Text.__init__(self, *args, **kargs) + super().__init__(*args, **kargs) self.designated = designated diff --git a/nemubot/message/printer/IRC.py b/nemubot/message/printer/IRC.py index 320366c..df9cb9f 100644 --- a/nemubot/message/printer/IRC.py +++ b/nemubot/message/printer/IRC.py @@ -22,4 +22,4 @@ class IRC(SocketPrinter): def visit_Text(self, msg): self.pp += "PRIVMSG %s :" % ",".join(msg.to) - SocketPrinter.visit_Text(self, msg) + super().visit_Text(msg) diff --git a/nemubot/message/printer/__init__.py b/nemubot/message/printer/__init__.py index 060118b..e0fbeef 100644 --- a/nemubot/message/printer/__init__.py +++ b/nemubot/message/printer/__init__.py @@ -13,12 +13,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - -def reload(): - import imp - - import nemubot.message.printer.IRC - imp.reload(nemubot.message.printer.IRC) - - import nemubot.message.printer.socket - imp.reload(nemubot.message.printer.socket) diff --git a/nemubot/message/printer/socket.py b/nemubot/message/printer/socket.py index cb9bc4c..6884c88 100644 --- a/nemubot/message/printer/socket.py +++ b/nemubot/message/printer/socket.py @@ -35,7 +35,7 @@ class Socket(AbstractVisitor): others = [to for to in msg.to if to != msg.designated] # Avoid nick starting message when discussing on user channel - if len(others) != len(msg.to): + if len(others) == 0 or len(others) != len(msg.to): res = Text(msg.message, server=msg.server, date=msg.date, to=msg.to, frm=msg.frm) diff --git a/nemubot/prompt/reset.py b/nemubot/message/response.py similarity index 62% rename from nemubot/prompt/reset.py rename to nemubot/message/response.py index 57da9f8..fba864b 100644 --- a/nemubot/prompt/reset.py +++ b/nemubot/message/response.py @@ -1,5 +1,3 @@ -# coding=utf-8 - # Nemubot is a smart and modulable IM bot. # Copyright (C) 2012-2015 Mercier Pierre-Olivier # @@ -16,8 +14,21 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -class PromptReset(Exception): +from nemubot.message.abstract import Abstract - def __init__(self, type): - super(PromptReset, self).__init__("Prompt reset asked") - self.type = type + +class Response(Abstract): + + def __init__(self, cmd, args=None, *nargs, **kargs): + super().__init__(*nargs, **kargs) + + self.cmd = cmd + self.args = args if args is not None else list() + + def __str__(self): + return self.cmd + " @" + ",@".join(self.args) + + @property + def cmds(self): + # TODO: this is for legacy modules + return [self.cmd] + self.args diff --git a/nemubot/message/text.py b/nemubot/message/text.py index ec90a36..f691a04 100644 --- a/nemubot/message/text.py +++ b/nemubot/message/text.py @@ -28,7 +28,7 @@ class Text(Abstract): message -- the parsed message """ - Abstract.__init__(self, *args, **kargs) + super().__init__(*args, **kargs) self.message = message diff --git a/nemubot/modulecontext.py b/nemubot/modulecontext.py index 1d1b3d0..572b14b 100644 --- a/nemubot/modulecontext.py +++ b/nemubot/modulecontext.py @@ -16,19 +16,15 @@ class ModuleContext: - def __init__(self, context, module): + def __init__(self, context, module_name, logger): """Initialize the module context arguments: context -- the bot context - module -- the module + module_name -- the module name + logger -- a logger """ - if module is not None: - module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__ - else: - module_name = "" - # Load module configuration if exists if (context is not None and module_name in context.modules_configuration): @@ -60,10 +56,16 @@ class ModuleContext: def subtreat(msg): yield from context.treater.treat_msg(msg) - def add_event(evt, eid=None): - return context.add_event(evt, eid, module_src=module) + def add_event(call=None, **kwargs): + evt = context.add_event(call, **kwargs) + if evt is not None: + self.events.append(evt) + return evt def del_event(evt): - return context.del_event(evt, module_src=module) + if context.del_event(evt): + self._clean_events() + return True + return False def send_response(server, res): if server in context.servers: @@ -72,7 +74,7 @@ class ModuleContext: else: return context.servers[server].send_response(res) else: - module.logger.error("Try to send a message to the unknown server: %s", server) + logger.error("Try to send a message to the unknown server: %s", server) return False else: # Used when using outside of nemubot @@ -88,13 +90,13 @@ class ModuleContext: self.hooks.remove((triggers, hook)) def subtreat(msg): return None - def add_event(evt, eid=None): - return context.add_event(evt, eid, module_src=module) + def add_event(evt): + return context.add_event(evt) def del_event(evt): return context.del_event(evt, module_src=module) def send_response(server, res): - module.logger.info("Send response: %s", res) + logger.info("Send response: %s", res) def save(): context.datastore.save(module_name, self.data) @@ -121,6 +123,14 @@ class ModuleContext: return self._data + def _clean_events(self): + """Look for None weakref in the events list""" + + for i in range(len(self.events), 0, -1): + if self.events[i-1]() is None: + self.events.remove() + + def unload(self): """Perform actions for unloading the module""" diff --git a/nemubot/prompt/__init__.py b/nemubot/prompt/__init__.py deleted file mode 100644 index 27f7919..0000000 --- a/nemubot/prompt/__init__.py +++ /dev/null @@ -1,142 +0,0 @@ -# Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . - -import shlex -import sys -import traceback - -from nemubot.prompt import builtins - - -class Prompt: - - def __init__(self): - self.selectedServer = None - self.lastretcode = 0 - - self.HOOKS_CAPS = dict() - self.HOOKS_LIST = dict() - - def add_cap_hook(self, name, call, data=None): - self.HOOKS_CAPS[name] = lambda t, c: call(t, data=data, - context=c, prompt=self) - - def add_list_hook(self, name, call): - self.HOOKS_LIST[name] = call - - def lex_cmd(self, line): - """Return an array of tokens - - Argument: - line -- the line to lex - """ - - try: - cmds = shlex.split(line) - except: - exc_type, exc_value, _ = sys.exc_info() - sys.stderr.write(traceback.format_exception_only(exc_type, - exc_value)[0]) - return - - bgn = 0 - - # Separate commands (command separator: ;) - for i in range(0, len(cmds)): - if cmds[i][-1] == ';': - if i != bgn: - yield cmds[bgn:i] - bgn = i + 1 - - # Return rest of the command (that not end with a ;) - if bgn != len(cmds): - yield cmds[bgn:] - - def exec_cmd(self, toks, context): - """Execute the command - - Arguments: - toks -- lexed tokens to executes - context -- current bot context - """ - - if toks[0] in builtins.CAPS: - self.lastretcode = builtins.CAPS[toks[0]](toks, context, self) - elif toks[0] in self.HOOKS_CAPS: - self.lastretcode = self.HOOKS_CAPS[toks[0]](toks, context) - else: - print("Unknown command: `%s'" % toks[0]) - self.lastretcode = 127 - - def getPS1(self): - """Get the PS1 associated to the selected server""" - if self.selectedServer is None: - return "nemubot" - else: - return self.selectedServer.id - - def run(self, context): - """Launch the prompt - - Argument: - context -- current bot context - """ - - from nemubot.prompt.error import PromptError - from nemubot.prompt.reset import PromptReset - - while True: # Stopped by exception - try: - line = input("\033[0;33m%s\033[0;%dm§\033[0m " % - (self.getPS1(), 31 if self.lastretcode else 32)) - cmds = self.lex_cmd(line.strip()) - for toks in cmds: - try: - self.exec_cmd(toks, context) - except PromptReset: - raise - except PromptError as e: - print(e.message) - self.lastretcode = 128 - except: - exc_type, exc_value, exc_traceback = sys.exc_info() - traceback.print_exception(exc_type, exc_value, - exc_traceback) - except KeyboardInterrupt: - print("") - except EOFError: - print("quit") - return True - - -def hotswap(bak): - p = Prompt() - p.HOOKS_CAPS = bak.HOOKS_CAPS - p.HOOKS_LIST = bak.HOOKS_LIST - return p - - -def reload(): - import imp - - import nemubot.prompt.builtins - imp.reload(nemubot.prompt.builtins) - - import nemubot.prompt.error - imp.reload(nemubot.prompt.error) - - import nemubot.prompt.reset - imp.reload(nemubot.prompt.reset) diff --git a/nemubot/prompt/builtins.py b/nemubot/prompt/builtins.py deleted file mode 100644 index a020fb9..0000000 --- a/nemubot/prompt/builtins.py +++ /dev/null @@ -1,128 +0,0 @@ -# Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . - -def end(toks, context, prompt): - """Quit the prompt for reload or exit""" - from nemubot.prompt.reset import PromptReset - - if toks[0] == "refresh": - raise PromptReset("refresh") - elif toks[0] == "reset": - raise PromptReset("reset") - raise PromptReset("quit") - - -def liste(toks, context, prompt): - """Show some lists""" - if len(toks) > 1: - for l in toks[1:]: - l = l.lower() - if l == "server" or l == "servers": - for srv in context.servers.keys(): - print (" - %s (state: %s) ;" % (srv, - "connected" if context.servers[srv].connected else "disconnected")) - if len(context.servers) == 0: - print (" > No server loaded") - - elif l == "mod" or l == "mods" or l == "module" or l == "modules": - for mod in context.modules.keys(): - print (" - %s ;" % mod) - if len(context.modules) == 0: - print (" > No module loaded") - - elif l in prompt.HOOKS_LIST: - f, d = prompt.HOOKS_LIST[l] - f(d, context, prompt) - - else: - print (" Unknown list `%s'" % l) - return 2 - return 0 - else: - print (" Please give a list to show: servers, ...") - return 1 - - -def load(toks, context, prompt): - """Load an XML configuration file""" - if len(toks) > 1: - for filename in toks[1:]: - context.load_file(filename) - else: - print ("Not enough arguments. `load' takes a filename.") - return 1 - - -def select(toks, context, prompt): - """Select the current server""" - if (len(toks) == 2 and toks[1] != "None" and - toks[1] != "nemubot" and toks[1] != "none"): - if toks[1] in context.servers: - prompt.selectedServer = context.servers[toks[1]] - else: - print ("select: server `%s' not found." % toks[1]) - return 1 - else: - prompt.selectedServer = None - - -def unload(toks, context, prompt): - """Unload a module""" - if len(toks) == 2 and toks[1] == "all": - for name in context.modules.keys(): - context.unload_module(name) - elif len(toks) > 1: - for name in toks[1:]: - if context.unload_module(name): - print (" Module `%s' successfully unloaded." % name) - else: - print (" No module `%s' loaded, can't unload!" % name) - return 2 - else: - print ("Not enough arguments. `unload' takes a module name.") - return 1 - - -def debug(toks, context, prompt): - """Enable/Disable debug mode on a module""" - if len(toks) > 1: - for name in toks[1:]: - if name in context.modules: - context.modules[name].DEBUG = not context.modules[name].DEBUG - if context.modules[name].DEBUG: - print (" Module `%s' now in DEBUG mode." % name) - else: - print (" Debug for module module `%s' disabled." % name) - else: - print (" No module `%s' loaded, can't debug!" % name) - return 2 - else: - print ("Not enough arguments. `debug' takes a module name.") - return 1 - - -# Register build-ins -CAPS = { - 'quit': end, # Disconnect all server and quit - 'exit': end, # Alias for quit - 'reset': end, # Reload the prompt - 'refresh': end, # Reload the prompt but save modules - 'load': load, # Load a servers or module configuration file - 'unload': unload, # Unload a module and remove it from the list - 'select': select, # Select a server - 'list': liste, # Show lists - 'debug': debug, # Pass a module in debug mode -} diff --git a/nemubot/server/DCC.py b/nemubot/server/DCC.py index 6655d52..c1a6852 100644 --- a/nemubot/server/DCC.py +++ b/nemubot/server/DCC.py @@ -31,7 +31,7 @@ PORTS = list() class DCC(server.AbstractServer): def __init__(self, srv, dest, socket=None): - server.Server.__init__(self) + super().__init__(name="Nemubot DCC server") self.error = False # An error has occur, closing the connection? self.messages = list() # Message queued before connexion diff --git a/nemubot/server/IRC.py b/nemubot/server/IRC.py index e433176..6affe55 100644 --- a/nemubot/server/IRC.py +++ b/nemubot/server/IRC.py @@ -20,17 +20,17 @@ import re from nemubot.channel import Channel from nemubot.message.printer.IRC import IRC as IRCPrinter from nemubot.server.message.IRC import IRC as IRCMessage -from nemubot.server.socket import SocketServer +from nemubot.server.socket import SocketServer, SecureSocketServer -class IRC(SocketServer): +class IRC(): - """Concrete implementation of a connexion to an IRC server""" + """Concrete implementation of a connection to an IRC server""" - def __init__(self, host="localhost", port=6667, ssl=False, owner=None, + def __init__(self, host="localhost", port=6667, owner=None, nick="nemubot", username=None, password=None, realname="Nemubot", encoding="utf-8", caps=None, - channels=list(), on_connect=None): + channels=list(), on_connect=None, **kwargs): """Prepare a connection with an IRC server Keyword arguments: @@ -54,8 +54,8 @@ class IRC(SocketServer): self.owner = owner self.realname = realname - self.id = self.username + "@" + host + ":" + str(port) - SocketServer.__init__(self, host=host, port=port, ssl=ssl) + super().__init__(name=self.username + "@" + host + ":" + str(port), + host=host, port=port, **kwargs) self.printer = IRCPrinter self.encoding = encoding @@ -232,29 +232,29 @@ class IRC(SocketServer): # Open/close - def _open(self): - if SocketServer._open(self): - if self.password is not None: - self.write("PASS :" + self.password) - if self.capabilities is not None: - self.write("CAP LS") - self.write("NICK :" + self.nick) - self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) - return True - return False + def connect(self): + super().connect() + + if self.password is not None: + self.write("PASS :" + self.password) + if self.capabilities is not None: + self.write("CAP LS") + self.write("NICK :" + self.nick) + self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) - def _close(self): - if self.connected: self.write("QUIT") - return SocketServer._close(self) + def close(self): + if not self._closed: + self.write("QUIT") + return super().close() # Writes: as inherited # Read - def read(self): - for line in SocketServer.read(self): + def async_read(self): + for line in super().async_read(): # PING should be handled here, so start parsing here :/ msg = IRCMessage(line, self.encoding) diff --git a/nemubot/server/__init__.py b/nemubot/server/__init__.py index b9a8fe4..a533491 100644 --- a/nemubot/server/__init__.py +++ b/nemubot/server/__init__.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,75 +14,65 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import threading -_lock = threading.Lock() - -# Lists for select -_rlist = [] -_wlist = [] -_xlist = [] - - -def factory(uri, **init_args): - from urllib.parse import urlparse, unquote +def factory(uri, ssl=False, **init_args): + from urllib.parse import urlparse, unquote, parse_qs o = urlparse(uri) + srv = None + if o.scheme == "irc" or o.scheme == "ircs": # http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt # http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html args = init_args - modifiers = o.path.split(",") - target = unquote(modifiers.pop(0)[1:]) - - if o.scheme == "ircs": args["ssl"] = True + if o.scheme == "ircs": ssl = True if o.hostname is not None: args["host"] = o.hostname if o.port is not None: args["port"] = o.port if o.username is not None: args["username"] = o.username if o.password is not None: args["password"] = o.password - queries = o.query.split("&") - for q in queries: - if "=" in q: - key, val = tuple(q.split("=", 1)) - else: - key, val = q, "" - if key == "msg": - if "on_connect" not in args: - args["on_connect"] = [] - args["on_connect"].append("PRIVMSG %s :%s" % (target, unquote(val))) - elif key == "key": - if "channels" not in args: - args["channels"] = [] - args["channels"].append((target, unquote(val))) - elif key == "pass": - args["password"] = unquote(val) - elif key == "charset": - args["encoding"] = unquote(val) + if ssl: + try: + from ssl import create_default_context + args["_context"] = create_default_context() + except ImportError: + # Python 3.3 compat + from ssl import SSLContext, PROTOCOL_TLSv1 + args["_context"] = SSLContext(PROTOCOL_TLSv1) + args["server_hostname"] = o.hostname + modifiers = o.path.split(",") + target = unquote(modifiers.pop(0)[1:]) + + # Read query string + params = parse_qs(o.query) + + if "msg" in params: + if "on_connect" not in args: + args["on_connect"] = [] + args["on_connect"].append("PRIVMSG %s :%s" % (target, params["msg"])) + + if "key" in params: + if "channels" not in args: + args["channels"] = [] + args["channels"].append((target, params["key"])) + + if "pass" in params: + args["password"] = params["pass"] + + if "charset" in params: + args["encoding"] = params["charset"] + + # if "channels" not in args and "isnick" not in modifiers: args["channels"] = [ target ] - from nemubot.server.IRC import IRC as IRCServer - return IRCServer(**args) - else: - return None + if ssl: + from nemubot.server.IRC import IRC_secure as SecureIRCServer + srv = SecureIRCServer(**args) + else: + from nemubot.server.IRC import IRC as IRCServer + srv = IRCServer(**args) - -def reload(): - import imp - - import nemubot.server.abstract - imp.reload(nemubot.server.abstract) - - import nemubot.server.socket - imp.reload(nemubot.server.socket) - - import nemubot.server.IRC - imp.reload(nemubot.server.IRC) - - import nemubot.server.message - imp.reload(nemubot.server.message) - - nemubot.server.message.reload() + return srv diff --git a/nemubot/server/abstract.py b/nemubot/server/abstract.py index ebcb427..48c5104 100644 --- a/nemubot/server/abstract.py +++ b/nemubot/server/abstract.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,71 +14,64 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import io import logging import queue -from nemubot.server import _lock, _rlist, _wlist, _xlist +#from nemubot.bot import sync_act -# Extends from IOBase in order to be compatible with select function -class AbstractServer(io.IOBase): + +class AbstractServer: """An abstract server: handle communication with an IM server""" - def __init__(self, send_callback=None): + def __init__(self, name=None, **kwargs): """Initialize an abstract server Keyword argument: - send_callback -- Callback when developper want to send a message + name -- Identifier of the socket, for convinience """ - if not hasattr(self, "id"): - raise Exception("No id defined for this server. Please set one!") + self._name = name - self.logger = logging.getLogger("nemubot.server." + self.id) + super().__init__(**kwargs) + + self.logger = logging.getLogger("nemubot.server." + str(self.name)) + self._readbuffer = b'' self._sending_queue = queue.Queue() - if send_callback is not None: - self._send_callback = send_callback + + + @property + def name(self): + if self._name is not None: + return self._name else: - self._send_callback = self._write_select + return self.fileno() # Open/close - def __enter__(self): - self.open() - return self + def connect(self, *args, **kwargs): + """Register the server in _poll""" + + self.logger.info("Opening connection") + + super().connect(*args, **kwargs) + + self._on_connect() + + def _on_connect(self): + sync_act("sckt", "register", self.fileno()) - def __exit__(self, type, value, traceback): - self.close() + def close(self, *args, **kwargs): + """Unregister the server from _poll""" + self.logger.info("Closing connection") - def open(self): - """Generic open function that register the server un _rlist in case - of successful _open""" - self.logger.info("Opening connection to %s", self.id) - if not hasattr(self, "_open") or self._open(): - _rlist.append(self) - _xlist.append(self) - return True - return False + if self.fileno() > 0: + sync_act("sckt", "unregister", self.fileno()) - - def close(self): - """Generic close function that register the server un _{r,w,x}list in - case of successful _close""" - self.logger.info("Closing connection to %s", self.id) - with _lock: - if not hasattr(self, "_close") or self._close(): - if self in _rlist: - _rlist.remove(self) - if self in _wlist: - _wlist.remove(self) - if self in _xlist: - _xlist.remove(self) - return True - return False + super().close(*args, **kwargs) # Writes @@ -90,13 +83,16 @@ class AbstractServer(io.IOBase): message -- message to send """ - self._send_callback(message) + self._sending_queue.put(self.format(message)) + self.logger.debug("Message '%s' appended to write queue", message) + sync_act("sckt", "write", self.fileno()) - def write_select(self): - """Internal function used by the select function""" + def async_write(self): + """Internal function used when the file descriptor is writable""" + try: - _wlist.remove(self) + sync_act("sckt", "unwrite", self.fileno()) while not self._sending_queue.empty(): self._write(self._sending_queue.get_nowait()) self._sending_queue.task_done() @@ -105,19 +101,6 @@ class AbstractServer(io.IOBase): pass - def _write_select(self, message): - """Send a message to the server safely through select - - Argument: - message -- message to send - """ - - self._sending_queue.put(self.format(message)) - self.logger.debug("Message '%s' appended to Queue", message) - if self not in _wlist: - _wlist.append(self) - - def send_response(self, response): """Send a formated Message class @@ -140,13 +123,39 @@ class AbstractServer(io.IOBase): # Read + def async_read(self): + """Internal function used when the file descriptor is readable + + Returns: + A list of fully received messages + """ + + ret, self._readbuffer = self.lex(self._readbuffer + self.read()) + + for r in ret: + yield r + + + def lex(self, buf): + """Assume lexing in default case is per line + + Argument: + buf -- buffer to lex + """ + + msgs = buf.split(b'\r\n') + partial = msgs.pop() + + return msgs, partial + + def parse(self, msg): raise NotImplemented # Exceptions - def exception(self): - """Exception occurs in fd""" - self.logger.warning("Unhandle file descriptor exception on server %s", - self.id) + def exception(self, flags): + """Exception occurs on fd""" + + self.close() diff --git a/nemubot/server/factory_test.py b/nemubot/server/factory_test.py index cc7d35b..358591e 100644 --- a/nemubot/server/factory_test.py +++ b/nemubot/server/factory_test.py @@ -22,34 +22,30 @@ class TestFactory(unittest.TestCase): def test_IRC1(self): from nemubot.server.IRC import IRC as IRCServer + from nemubot.server.IRC import IRC_secure as IRCSServer # : If omitted, the client must connect to a prespecified default IRC server. server = factory("irc:///") self.assertIsInstance(server, IRCServer) self.assertEqual(server.host, "localhost") - self.assertFalse(server.ssl) server = factory("ircs:///") - self.assertIsInstance(server, IRCServer) + self.assertIsInstance(server, IRCSServer) self.assertEqual(server.host, "localhost") - self.assertTrue(server.ssl) server = factory("irc://host1") self.assertIsInstance(server, IRCServer) self.assertEqual(server.host, "host1") - self.assertFalse(server.ssl) server = factory("irc://host2:6667") self.assertIsInstance(server, IRCServer) self.assertEqual(server.host, "host2") self.assertEqual(server.port, 6667) - self.assertFalse(server.ssl) server = factory("ircs://host3:194/") - self.assertIsInstance(server, IRCServer) + self.assertIsInstance(server, IRCSServer) self.assertEqual(server.host, "host3") self.assertEqual(server.port, 194) - self.assertTrue(server.ssl) if __name__ == '__main__': diff --git a/nemubot/server/message/IRC.py b/nemubot/server/message/IRC.py index f6d562f..4c9e280 100644 --- a/nemubot/server/message/IRC.py +++ b/nemubot/server/message/IRC.py @@ -146,7 +146,7 @@ class IRC(Abstract): receivers = self.decode(self.params[0]).split(',') common_args = { - "server": srv.id, + "server": srv.name, "date": self.tags["time"], "to": receivers, "to_response": [r if r != srv.nick else self.nick for r in receivers], diff --git a/nemubot/prompt/error.py b/nemubot/server/message/__init__.py similarity index 83% rename from nemubot/prompt/error.py rename to nemubot/server/message/__init__.py index f86b5a1..57f3468 100644 --- a/nemubot/prompt/error.py +++ b/nemubot/server/message/__init__.py @@ -13,9 +13,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - -class PromptError(Exception): - - def __init__(self, message): - super(PromptError, self).__init__(message) - self.message = message diff --git a/nemubot/server/socket.py b/nemubot/server/socket.py index b6c00d4..f5b9c9a 100644 --- a/nemubot/server/socket.py +++ b/nemubot/server/socket.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,99 +14,33 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import os +import socket +import ssl + +import nemubot.message as message from nemubot.message.printer.socket import Socket as SocketPrinter from nemubot.server.abstract import AbstractServer -class SocketServer(AbstractServer): +class _Socket(asyncio.Protocol): - """Concrete implementation of a socket connexion (can be wrapped with TLS)""" + """Concrete implementation of a socket connection""" - def __init__(self, sock_location=None, host=None, port=None, ssl=False, socket=None, id=None): - if id is not None: - self.id = id - AbstractServer.__init__(self) - if sock_location is not None: - self.filename = sock_location - elif host is not None: - self.host = host - self.port = int(port) - self.ssl = ssl + def __init__(self, printer=SocketPrinter, **kwargs): + """Create a server socket + """ + + super().__init__(**kwargs) - self.socket = socket self.readbuffer = b'' - self.printer = SocketPrinter - - - def fileno(self): - return self.socket.fileno() if self.socket else None - - - @property - def connected(self): - """Indicator of the connection aliveness""" - return self.socket is not None - - - # Open/close - - def _open(self): - import os - import socket - - if self.connected: - return True - - try: - if hasattr(self, "filename"): - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.socket.connect(self.filename) - self.logger.info("Connected to %s", self.filename) - else: - self.socket = socket.create_connection((self.host, self.port)) - self.logger.info("Connected to %s:%d", self.host, self.port) - except socket.error as e: - self.socket = None - self.logger.critical("Unable to connect to %s:%d: %s", - self.host, self.port, - os.strerror(e.errno)) - return False - - # Wrap the socket for SSL - if self.ssl: - import ssl - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - self.socket = ctx.wrap_socket(self.socket) - - return True - - - def _close(self): - import socket - - from nemubot.server import _lock - _lock.release() - self._sending_queue.join() - _lock.acquire() - if self.connected: - try: - self.socket.shutdown(socket.SHUT_RDWR) - self.socket.close() - except socket.error: - pass - - self.socket = None - - return True + self.printer = printer # Write def _write(self, cnt): - if not self.connected: - return - - self.socket.send(cnt) + self.sendall(cnt) def format(self, txt): @@ -118,80 +52,134 @@ class SocketServer(AbstractServer): # Read - def read(self): - if not self.connected: - return [] - - raw = self.socket.recv(1024) - temp = (self.readbuffer + raw).split(b'\r\n') - self.readbuffer = temp.pop() - - for line in temp: - yield line + def recv(self, n=1024): + return super().recv(n) -class SocketListener(AbstractServer): + def parse(self, line): + """Implement a default behaviour for socket""" + import shlex - def __init__(self, new_server_cb, id, sock_location=None, host=None, port=None, ssl=None): - self.id = id - AbstractServer.__init__(self) - self.new_server_cb = new_server_cb - self.sock_location = sock_location - self.host = host - self.port = port - self.ssl = ssl - self.nb_son = 0 + line = line.strip().decode() + try: + args = shlex.split(line) + except ValueError: + args = line.split(' ') + + if len(args): + yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you") - def fileno(self): - return self.socket.fileno() if self.socket else None + def subparse(self, orig, cnt): + for m in self.parse(cnt): + m.to = orig.to + m.frm = orig.frm + m.date = orig.date + yield m + + +class _SocketServer(_Socket): + + def __init__(self, host, port, bind=None, **kwargs): + super().__init__(family=socket.AF_INET, **kwargs) + + assert(host is not None) + assert(isinstance(port, int)) + + self._host = host + self._port = port + self._bind = bind @property - def connected(self): - """Indicator of the connection aliveness""" - return self.socket is not None + def host(self): + return self._host - def _open(self): - import os - import socket + def connect(self): + self.logger.info("Connection to %s:%d", self._host, self._port) + super().connect((self._host, self._port)) - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - if self.sock_location is not None: - try: - os.remove(self.sock_location) - except FileNotFoundError: - pass - self.socket.bind(self.sock_location) - elif self.host is not None and self.port is not None: - self.socket.bind((self.host, self.port)) - self.socket.listen(5) - - return True + if self._bind: + super().bind(self._bind) - def _close(self): +class SocketServer(_SocketServer, socket.socket): + pass + + +class SecureSocketServer(_SocketServer, ssl.SSLSocket): + pass + + +class UnixSocket: + + def __init__(self, location, **kwargs): + super().__init__(family=socket.AF_UNIX, **kwargs) + + self._socket_path = location + + + def connect(self): + self.logger.info("Connection to unix://%s", self._socket_path) + super().connect(self._socket_path) + + +class _Listener: + + def __init__(self, new_server_cb, instanciate=_Socket, **kwargs): + super().__init__(**kwargs) + + self._instanciate = instanciate + self._new_server_cb = new_server_cb + + + def read(self): + conn, addr = self.accept() + fileno = conn.fileno() + self.logger.info("Accept new connection from %s (fd=%d)", addr, fileno) + + ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach()) + ss.connect = ss._on_connect + self._new_server_cb(ss, autoconnect=True) + + return b'' + + +class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + + def connect(self): + self.logger.info("Creating Unix socket at unix://%s", self._socket_path) + + try: + os.remove(self._socket_path) + except FileNotFoundError: + pass + + self.bind(self._socket_path) + self.listen(5) + self.logger.info("Socket ready for accepting new connections") + + self._on_connect() + + + def close(self): import os import socket try: - self.socket.shutdown(socket.SHUT_RDWR) - self.socket.close() - if self.sock_location is not None: - os.remove(self.sock_location) + self.shutdown(socket.SHUT_RDWR) except socket.error: pass - # Read + super().close() - def read(self): - if not self.connected: - return [] - - conn, addr = self.socket.accept() - self.nb_son += 1 - ss = SocketServer(id=self.id + "#" + str(self.nb_son), socket=conn) - self.new_server_cb(ss) - - return [] + try: + if self._socket_path is not None: + os.remove(self._socket_path) + except: + pass diff --git a/nemubot/tools/__init__.py b/nemubot/tools/__init__.py index 127154c..57f3468 100644 --- a/nemubot/tools/__init__.py +++ b/nemubot/tools/__init__.py @@ -13,29 +13,3 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - -def reload(): - import imp - - import nemubot.tools.config - imp.reload(nemubot.tools.config) - - import nemubot.tools.countdown - imp.reload(nemubot.tools.countdown) - - import nemubot.tools.feed - imp.reload(nemubot.tools.feed) - - import nemubot.tools.date - imp.reload(nemubot.tools.date) - - import nemubot.tools.human - imp.reload(nemubot.tools.human) - - import nemubot.tools.web - imp.reload(nemubot.tools.web) - - import nemubot.tools.xmlparser - imp.reload(nemubot.tools.xmlparser) - import nemubot.tools.xmlparser.node - imp.reload(nemubot.tools.xmlparser.node) diff --git a/setup.py b/setup.py index b39a163..36dddb4 100755 --- a/setup.py +++ b/setup.py @@ -69,7 +69,6 @@ setup( 'nemubot.hooks.keywords', 'nemubot.message', 'nemubot.message.printer', - 'nemubot.prompt', 'nemubot.server', 'nemubot.server.message', 'nemubot.tools',