diff --git a/modules/alias.py b/modules/alias.py index 24c8fa3..f7aeddb 100644 --- a/modules/alias.py +++ b/modules/alias.py @@ -242,7 +242,7 @@ def cmd_unalias(msg): ## Alias replacement -@hook.add("pre_Command") +@hook.add(["pre","Command"]) def treat_alias(msg): if msg.cmd in context.data.getNode("aliases").index: txt = context.data.getNode("aliases").index[msg.cmd]["origin"] diff --git a/modules/networking/whois.py b/modules/networking/whois.py index 2e2970a..d3d30b1 100644 --- a/modules/networking/whois.py +++ b/modules/networking/whois.py @@ -28,12 +28,14 @@ def load(CONF, add_hook): URL_WHOIS = URL_WHOIS % (urllib.parse.quote(CONF.getNode("whoisxmlapi")["username"]), urllib.parse.quote(CONF.getNode("whoisxmlapi")["password"])) import nemubot.hooks - add_hook("in_Command", nemubot.hooks.Command(cmd_whois, "netwhois", - help="Get whois information about given domains", - help_usage={"DOMAIN": "Return whois information on the given DOMAIN"})) - add_hook("in_Command", nemubot.hooks.Command(cmd_avail, "domain_available", - help="Domain availability check using whoisxmlapi.com", - help_usage={"DOMAIN": "Check if the given DOMAIN is available or not"})) + add_hook(nemubot.hooks.Command(cmd_whois, "netwhois", + help="Get whois information about given domains", + help_usage={"DOMAIN": "Return whois information on the given DOMAIN"}), + "in","Command") + add_hook(nemubot.hooks.Command(cmd_avail, "domain_available", + help="Domain availability check using whoisxmlapi.com", + help_usage={"DOMAIN": "Check if the given DOMAIN is available or not"}), + "in","Command") # MODULE CORE ######################################################### diff --git a/modules/whois.py b/modules/whois.py index 4a13e9c..a51b838 100644 --- a/modules/whois.py +++ b/modules/whois.py @@ -30,8 +30,8 @@ def load(context): context.data.getNode("pics").setIndex("login", "pict") import nemubot.hooks - context.add_hook("in_Command", - nemubot.hooks.Command(cmd_whois, "whois")) + context.add_hook(nemubot.hooks.Command(cmd_whois, "whois"), + "in","Command") class Login: diff --git a/nemubot/bot.py b/nemubot/bot.py index f9569b7..8d45f3d 100644 --- a/nemubot/bot.py +++ b/nemubot/bot.py @@ -474,7 +474,7 @@ class Bot(threading.Thread): # Register decorated functions import nemubot.hooks for s, h in nemubot.hooks.hook.last_registered: - module.__nemubot_context__.add_hook(s, h) + module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s) nemubot.hooks.hook.last_registered = [] # Launch the module diff --git a/nemubot/hooks/__init__.py b/nemubot/hooks/__init__.py index 9904119..e9113eb 100644 --- a/nemubot/hooks/__init__.py +++ b/nemubot/hooks/__init__.py @@ -35,19 +35,19 @@ class hook: def add(store, *args, **kwargs): return hook._add(store, Abstract, *args, **kwargs) - def ask(*args, store="in_DirectAsk", **kwargs): + def ask(*args, store=["in","DirectAsk"], **kwargs): return hook._add(store, Message, *args, **kwargs) - def command(*args, store="in_Command", **kwargs): + def command(*args, store=["in","Command"], **kwargs): return hook._add(store, Command, *args, **kwargs) - def message(*args, store="in_Text", **kwargs): + def message(*args, store=["in","Text"], **kwargs): return hook._add(store, Message, *args, **kwargs) - def post(*args, store="post", **kwargs): + def post(*args, store=["post"], **kwargs): return hook._add(store, Abstract, *args, **kwargs) - def pre(*args, store="pre", **kwargs): + def pre(*args, store=["pre"], **kwargs): return hook._add(store, Abstract, *args, **kwargs) diff --git a/nemubot/hooks/manager.py b/nemubot/hooks/manager.py index 8859d19..6a57d2a 100644 --- a/nemubot/hooks/manager.py +++ b/nemubot/hooks/manager.py @@ -14,15 +14,47 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import logging + class HooksManager: """Class to manage hooks""" - def __init__(self): + def __init__(self, name="core"): """Initialize the manager""" self.hooks = dict() + self.logger = logging.getLogger("nemubot.hooks.manager." + name) + + + def _access(self, *triggers): + """Access to the given triggers chain""" + + h = self.hooks + for t in triggers: + if t not in h: + h[t] = dict() + h = h[t] + + if "__end__" not in h: + h["__end__"] = list() + + return h + + + def _search(self, hook, *where, start=None): + """Search all occurence of the given hook""" + + if start is None: + start = self.hooks + + for k in start: + if k == "__end__": + if hook in start[k]: + yield where + else: + yield from self._search(hook, *where + (k,), start=start[k]) def add_hook(self, hook, *triggers): @@ -33,20 +65,19 @@ class HooksManager: triggers -- string that trigger the hook """ - trigger = "_".join(triggers) + assert hook is not None, hook - if trigger not in self.hooks: - self.hooks[trigger] = list() + h = self._access(*triggers) - self.hooks[trigger].append(hook) + h["__end__"].append(hook) + + self.logger.debug("New hook successfully added in %s: %s", + "/".join(triggers), hook) - def del_hook(self, hook=None, *triggers): + def del_hooks(self, *triggers, hook=None): """Remove the given hook from the manager - Return: - Boolean value reporting the deletion success - Argument: triggers -- trigger string to remove @@ -54,15 +85,20 @@ class HooksManager: hook -- a Hook instance to remove from the trigger string """ - trigger = "_".join(triggers) + assert hook is not None or len(triggers) - if trigger in self.hooks: - if hook is None: - del self.hooks[trigger] + self.logger.debug("Trying to delete hook in %s: %s", + "/".join(triggers), hook) + + if hook is not None: + for h in self._search(hook, *triggers, start=self._access(*triggers)): + self._access(*h)["__end__"].remove(hook) + + else: + if len(triggers): + del self._access(*triggers[:-1])[triggers[-1]] else: - self.hooks[trigger].remove(hook) - return True - return False + self.hooks = dict() def get_hooks(self, *triggers): @@ -70,35 +106,29 @@ class HooksManager: Argument: triggers -- the trigger string - - Keyword argument: - data -- Data to pass to the hook as argument """ - trigger = "_".join(triggers) - - res = list() - - for key in self.hooks: - if trigger.find(key) == 0: - res += self.hooks[key] - - return res + for n in range(len(triggers) + 1): + i = self._access(*triggers[:n]) + for h in i["__end__"]: + yield h - def exec_hook(self, *triggers, **data): - """Trigger hooks that match the given trigger string + def get_reverse_hooks(self, *triggers, exclude_first=False): + """Returns list of triggered hooks that are bellow or at the same level Argument: - trigger -- the trigger string + triggers -- the trigger string - Keyword argument: - data -- Data to pass to the hook as argument + Keyword arguments: + exclude_first -- start reporting hook at the next level """ - trigger = "_".join(triggers) - - for key in self.hooks: - if trigger.find(key) == 0: - for hook in self.hooks[key]: - hook.run(**data) + h = self._access(*triggers) + for k in h: + if k == "__end__": + if not exclude_first: + for hk in h[k]: + yield hk + else: + yield from self.get_reverse_hooks(*triggers + (k,)) diff --git a/nemubot/hooks/manager_test.py b/nemubot/hooks/manager_test.py new file mode 100755 index 0000000..a0f38d7 --- /dev/null +++ b/nemubot/hooks/manager_test.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 + +import unittest + +from nemubot.hooks.manager import HooksManager + +class TestHookManager(unittest.TestCase): + + + def test_access(self): + hm = HooksManager() + + h1 = "HOOK1" + h2 = "HOOK2" + h3 = "HOOK3" + + hm.add_hook(h1) + hm.add_hook(h2, "pre") + hm.add_hook(h3, "pre", "Text") + hm.add_hook(h2, "post", "Text") + + self.assertIn("__end__", hm._access()) + self.assertIn("__end__", hm._access("pre")) + self.assertIn("__end__", hm._access("pre", "Text")) + self.assertIn("__end__", hm._access("post", "Text")) + + self.assertFalse(hm._access("inexistant")["__end__"]) + self.assertTrue(hm._access()["__end__"]) + self.assertTrue(hm._access("pre")["__end__"]) + self.assertTrue(hm._access("pre", "Text")["__end__"]) + self.assertTrue(hm._access("post", "Text")["__end__"]) + + + def test_search(self): + hm = HooksManager() + + h1 = "HOOK1" + h2 = "HOOK2" + h3 = "HOOK3" + h4 = "HOOK4" + + hm.add_hook(h1) + hm.add_hook(h2, "pre") + hm.add_hook(h3, "pre", "Text") + hm.add_hook(h2, "post", "Text") + + self.assertTrue([h for h in hm._search(h1)]) + self.assertFalse([h for h in hm._search(h4)]) + self.assertEqual(2, len([h for h in hm._search(h2)])) + self.assertEqual([("pre", "Text")], [h for h in hm._search(h3)]) + + + def test_delete(self): + hm = HooksManager() + + h1 = "HOOK1" + h2 = "HOOK2" + h3 = "HOOK3" + h4 = "HOOK4" + + hm.add_hook(h1) + hm.add_hook(h2, "pre") + hm.add_hook(h3, "pre", "Text") + hm.add_hook(h2, "post", "Text") + + hm.del_hooks(hook=h4) + + self.assertTrue(hm._access("pre")["__end__"]) + self.assertTrue(hm._access("pre", "Text")["__end__"]) + hm.del_hooks("pre") + self.assertFalse(hm._access("pre")["__end__"]) + + self.assertTrue(hm._access("post", "Text")["__end__"]) + hm.del_hooks("post", "Text", hook=h2) + self.assertFalse(hm._access("post", "Text")["__end__"]) + + self.assertTrue(hm._access()["__end__"]) + hm.del_hooks(hook=h1) + self.assertFalse(hm._access()["__end__"]) + + + def test_get(self): + hm = HooksManager() + + h1 = "HOOK1" + h2 = "HOOK2" + h3 = "HOOK3" + + hm.add_hook(h1) + hm.add_hook(h2, "pre") + hm.add_hook(h3, "pre", "Text") + hm.add_hook(h2, "post", "Text") + + self.assertEqual([h1, h2], [h for h in hm.get_hooks("pre")]) + self.assertEqual([h1, h2, h3], [h for h in hm.get_hooks("pre", "Text")]) + + + def test_get_rev(self): + hm = HooksManager() + + h1 = "HOOK1" + h2 = "HOOK2" + h3 = "HOOK3" + + hm.add_hook(h1) + hm.add_hook(h2, "pre") + hm.add_hook(h3, "pre", "Text") + hm.add_hook(h2, "post", "Text") + + self.assertEqual([h2, h3], [h for h in hm.get_reverse_hooks("pre")]) + self.assertEqual([h3], [h for h in hm.get_reverse_hooks("pre", exclude_first=True)]) + + +if __name__ == '__main__': + unittest.main() diff --git a/nemubot/modulecontext.py b/nemubot/modulecontext.py index 9c1f844..d562a98 100644 --- a/nemubot/modulecontext.py +++ b/nemubot/modulecontext.py @@ -26,6 +26,8 @@ class ModuleContext: if module is not None: module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__ + else: + module_name = "" # Load module configuration if exists if (context is not None and @@ -39,26 +41,23 @@ class ModuleContext: self.events = list() self.debug = context.verbosity > 0 if context is not None else False + from nemubot.hooks import Abstract as AbstractHook + # Define some callbacks if context is not None: # Load module data self.data = context.datastore.load(module_name) - def add_hook(store, hook): - self.hooks.append((store, hook)) - return context.treater.hm.add_hook(hook, store) - def del_hook(store, hook): - self.hooks.remove((store, hook)) - return context.treater.hm.del_hook(hook, store) - def call_hook(store, msg): - for h in context.treater.hm.get_hooks(store): - if h.match(msg): - res = h.run(msg) - if isinstance(res, list): - for i in res: - yield i - else: - yield res + def add_hook(hook, *triggers): + assert isinstance(hook, AbstractHook), hook + self.hooks.append((triggers, hook)) + return context.treater.hm.add_hook(hook, *triggers) + + def del_hook(hook, *triggers): + assert isinstance(hook, AbstractHook), hook + self.hooks.remove((triggers, hook)) + return context.treater.hm.del_hooks(*triggers, hook=hook) + def subtreat(msg): yield from context.treater.treat_msg(msg) def add_event(evt, eid=None): @@ -80,13 +79,12 @@ class ModuleContext: from nemubot.tools.xmlparser import module_state self.data = module_state.ModuleState("nemubotstate") - def add_hook(store, hook): - self.hooks.append((store, hook)) - def del_hook(store, hook): - self.hooks.remove((store, hook)) - def call_hook(store, msg): - # TODO: what can we do here? - return None + def add_hook(hook, *triggers): + assert isinstance(hook, AbstractHook), hook + self.hooks.append((triggers, hook)) + def del_hook(hook, *triggers): + assert isinstance(hook, AbstractHook), hook + self.hooks.remove((triggers, hook)) def subtreat(msg): return None def add_event(evt, eid=None): @@ -106,7 +104,6 @@ class ModuleContext: self.del_event = del_event self.save = save self.send_response = send_response - self.call_hook = call_hook self.subtreat = subtreat @@ -115,7 +112,7 @@ class ModuleContext: # Remove registered hooks for (s, h) in self.hooks: - self.del_hook(s, h) + self.del_hook(h, *s) # Remove registered events for e in self.events: