Rework hook managment and add some tests

This commit is contained in:
nemunaire 2015-11-16 07:19:09 +01:00
parent 926648517f
commit 43c42e1397
8 changed files with 222 additions and 78 deletions

View File

@ -242,7 +242,7 @@ def cmd_unalias(msg):
## Alias replacement ## Alias replacement
@hook.add("pre_Command") @hook.add(["pre","Command"])
def treat_alias(msg): def treat_alias(msg):
if msg.cmd in context.data.getNode("aliases").index: if msg.cmd in context.data.getNode("aliases").index:
txt = context.data.getNode("aliases").index[msg.cmd]["origin"] txt = context.data.getNode("aliases").index[msg.cmd]["origin"]

View File

@ -28,12 +28,14 @@ def load(CONF, add_hook):
URL_WHOIS = URL_WHOIS % (urllib.parse.quote(CONF.getNode("whoisxmlapi")["username"]), urllib.parse.quote(CONF.getNode("whoisxmlapi")["password"])) URL_WHOIS = URL_WHOIS % (urllib.parse.quote(CONF.getNode("whoisxmlapi")["username"]), urllib.parse.quote(CONF.getNode("whoisxmlapi")["password"]))
import nemubot.hooks import nemubot.hooks
add_hook("in_Command", nemubot.hooks.Command(cmd_whois, "netwhois", add_hook(nemubot.hooks.Command(cmd_whois, "netwhois",
help="Get whois information about given domains", help="Get whois information about given domains",
help_usage={"DOMAIN": "Return whois information on the given DOMAIN"})) help_usage={"DOMAIN": "Return whois information on the given DOMAIN"}),
add_hook("in_Command", nemubot.hooks.Command(cmd_avail, "domain_available", "in","Command")
add_hook(nemubot.hooks.Command(cmd_avail, "domain_available",
help="Domain availability check using whoisxmlapi.com", help="Domain availability check using whoisxmlapi.com",
help_usage={"DOMAIN": "Check if the given DOMAIN is available or not"})) help_usage={"DOMAIN": "Check if the given DOMAIN is available or not"}),
"in","Command")
# MODULE CORE ######################################################### # MODULE CORE #########################################################

View File

@ -30,8 +30,8 @@ def load(context):
context.data.getNode("pics").setIndex("login", "pict") context.data.getNode("pics").setIndex("login", "pict")
import nemubot.hooks import nemubot.hooks
context.add_hook("in_Command", context.add_hook(nemubot.hooks.Command(cmd_whois, "whois"),
nemubot.hooks.Command(cmd_whois, "whois")) "in","Command")
class Login: class Login:

View File

@ -474,7 +474,7 @@ class Bot(threading.Thread):
# Register decorated functions # Register decorated functions
import nemubot.hooks import nemubot.hooks
for s, h in nemubot.hooks.hook.last_registered: for s, h in nemubot.hooks.hook.last_registered:
module.__nemubot_context__.add_hook(s, h) module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
nemubot.hooks.hook.last_registered = [] nemubot.hooks.hook.last_registered = []
# Launch the module # Launch the module

View File

@ -35,19 +35,19 @@ class hook:
def add(store, *args, **kwargs): def add(store, *args, **kwargs):
return hook._add(store, Abstract, *args, **kwargs) return hook._add(store, Abstract, *args, **kwargs)
def ask(*args, store="in_DirectAsk", **kwargs): def ask(*args, store=["in","DirectAsk"], **kwargs):
return hook._add(store, Message, *args, **kwargs) return hook._add(store, Message, *args, **kwargs)
def command(*args, store="in_Command", **kwargs): def command(*args, store=["in","Command"], **kwargs):
return hook._add(store, Command, *args, **kwargs) return hook._add(store, Command, *args, **kwargs)
def message(*args, store="in_Text", **kwargs): def message(*args, store=["in","Text"], **kwargs):
return hook._add(store, Message, *args, **kwargs) return hook._add(store, Message, *args, **kwargs)
def post(*args, store="post", **kwargs): def post(*args, store=["post"], **kwargs):
return hook._add(store, Abstract, *args, **kwargs) return hook._add(store, Abstract, *args, **kwargs)
def pre(*args, store="pre", **kwargs): def pre(*args, store=["pre"], **kwargs):
return hook._add(store, Abstract, *args, **kwargs) return hook._add(store, Abstract, *args, **kwargs)

View File

@ -14,15 +14,47 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
class HooksManager: class HooksManager:
"""Class to manage hooks""" """Class to manage hooks"""
def __init__(self): def __init__(self, name="core"):
"""Initialize the manager""" """Initialize the manager"""
self.hooks = dict() self.hooks = dict()
self.logger = logging.getLogger("nemubot.hooks.manager." + name)
def _access(self, *triggers):
"""Access to the given triggers chain"""
h = self.hooks
for t in triggers:
if t not in h:
h[t] = dict()
h = h[t]
if "__end__" not in h:
h["__end__"] = list()
return h
def _search(self, hook, *where, start=None):
"""Search all occurence of the given hook"""
if start is None:
start = self.hooks
for k in start:
if k == "__end__":
if hook in start[k]:
yield where
else:
yield from self._search(hook, *where + (k,), start=start[k])
def add_hook(self, hook, *triggers): def add_hook(self, hook, *triggers):
@ -33,20 +65,19 @@ class HooksManager:
triggers -- string that trigger the hook triggers -- string that trigger the hook
""" """
trigger = "_".join(triggers) assert hook is not None, hook
if trigger not in self.hooks: h = self._access(*triggers)
self.hooks[trigger] = list()
self.hooks[trigger].append(hook) h["__end__"].append(hook)
self.logger.debug("New hook successfully added in %s: %s",
"/".join(triggers), hook)
def del_hook(self, hook=None, *triggers): def del_hooks(self, *triggers, hook=None):
"""Remove the given hook from the manager """Remove the given hook from the manager
Return:
Boolean value reporting the deletion success
Argument: Argument:
triggers -- trigger string to remove triggers -- trigger string to remove
@ -54,15 +85,20 @@ class HooksManager:
hook -- a Hook instance to remove from the trigger string hook -- a Hook instance to remove from the trigger string
""" """
trigger = "_".join(triggers) assert hook is not None or len(triggers)
self.logger.debug("Trying to delete hook in %s: %s",
"/".join(triggers), hook)
if hook is not None:
for h in self._search(hook, *triggers, start=self._access(*triggers)):
self._access(*h)["__end__"].remove(hook)
if trigger in self.hooks:
if hook is None:
del self.hooks[trigger]
else: else:
self.hooks[trigger].remove(hook) if len(triggers):
return True del self._access(*triggers[:-1])[triggers[-1]]
return False else:
self.hooks = dict()
def get_hooks(self, *triggers): def get_hooks(self, *triggers):
@ -70,35 +106,29 @@ class HooksManager:
Argument: Argument:
triggers -- the trigger string triggers -- the trigger string
Keyword argument:
data -- Data to pass to the hook as argument
""" """
trigger = "_".join(triggers) for n in range(len(triggers) + 1):
i = self._access(*triggers[:n])
res = list() for h in i["__end__"]:
yield h
for key in self.hooks:
if trigger.find(key) == 0:
res += self.hooks[key]
return res
def exec_hook(self, *triggers, **data): def get_reverse_hooks(self, *triggers, exclude_first=False):
"""Trigger hooks that match the given trigger string """Returns list of triggered hooks that are bellow or at the same level
Argument: Argument:
trigger -- the trigger string triggers -- the trigger string
Keyword argument: Keyword arguments:
data -- Data to pass to the hook as argument exclude_first -- start reporting hook at the next level
""" """
trigger = "_".join(triggers) h = self._access(*triggers)
for k in h:
for key in self.hooks: if k == "__end__":
if trigger.find(key) == 0: if not exclude_first:
for hook in self.hooks[key]: for hk in h[k]:
hook.run(**data) yield hk
else:
yield from self.get_reverse_hooks(*triggers + (k,))

115
nemubot/hooks/manager_test.py Executable file
View File

@ -0,0 +1,115 @@
#!/usr/bin/env python3
import unittest
from nemubot.hooks.manager import HooksManager
class TestHookManager(unittest.TestCase):
def test_access(self):
hm = HooksManager()
h1 = "HOOK1"
h2 = "HOOK2"
h3 = "HOOK3"
hm.add_hook(h1)
hm.add_hook(h2, "pre")
hm.add_hook(h3, "pre", "Text")
hm.add_hook(h2, "post", "Text")
self.assertIn("__end__", hm._access())
self.assertIn("__end__", hm._access("pre"))
self.assertIn("__end__", hm._access("pre", "Text"))
self.assertIn("__end__", hm._access("post", "Text"))
self.assertFalse(hm._access("inexistant")["__end__"])
self.assertTrue(hm._access()["__end__"])
self.assertTrue(hm._access("pre")["__end__"])
self.assertTrue(hm._access("pre", "Text")["__end__"])
self.assertTrue(hm._access("post", "Text")["__end__"])
def test_search(self):
hm = HooksManager()
h1 = "HOOK1"
h2 = "HOOK2"
h3 = "HOOK3"
h4 = "HOOK4"
hm.add_hook(h1)
hm.add_hook(h2, "pre")
hm.add_hook(h3, "pre", "Text")
hm.add_hook(h2, "post", "Text")
self.assertTrue([h for h in hm._search(h1)])
self.assertFalse([h for h in hm._search(h4)])
self.assertEqual(2, len([h for h in hm._search(h2)]))
self.assertEqual([("pre", "Text")], [h for h in hm._search(h3)])
def test_delete(self):
hm = HooksManager()
h1 = "HOOK1"
h2 = "HOOK2"
h3 = "HOOK3"
h4 = "HOOK4"
hm.add_hook(h1)
hm.add_hook(h2, "pre")
hm.add_hook(h3, "pre", "Text")
hm.add_hook(h2, "post", "Text")
hm.del_hooks(hook=h4)
self.assertTrue(hm._access("pre")["__end__"])
self.assertTrue(hm._access("pre", "Text")["__end__"])
hm.del_hooks("pre")
self.assertFalse(hm._access("pre")["__end__"])
self.assertTrue(hm._access("post", "Text")["__end__"])
hm.del_hooks("post", "Text", hook=h2)
self.assertFalse(hm._access("post", "Text")["__end__"])
self.assertTrue(hm._access()["__end__"])
hm.del_hooks(hook=h1)
self.assertFalse(hm._access()["__end__"])
def test_get(self):
hm = HooksManager()
h1 = "HOOK1"
h2 = "HOOK2"
h3 = "HOOK3"
hm.add_hook(h1)
hm.add_hook(h2, "pre")
hm.add_hook(h3, "pre", "Text")
hm.add_hook(h2, "post", "Text")
self.assertEqual([h1, h2], [h for h in hm.get_hooks("pre")])
self.assertEqual([h1, h2, h3], [h for h in hm.get_hooks("pre", "Text")])
def test_get_rev(self):
hm = HooksManager()
h1 = "HOOK1"
h2 = "HOOK2"
h3 = "HOOK3"
hm.add_hook(h1)
hm.add_hook(h2, "pre")
hm.add_hook(h3, "pre", "Text")
hm.add_hook(h2, "post", "Text")
self.assertEqual([h2, h3], [h for h in hm.get_reverse_hooks("pre")])
self.assertEqual([h3], [h for h in hm.get_reverse_hooks("pre", exclude_first=True)])
if __name__ == '__main__':
unittest.main()

View File

@ -26,6 +26,8 @@ class ModuleContext:
if module is not None: if module is not None:
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__ module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
else:
module_name = ""
# Load module configuration if exists # Load module configuration if exists
if (context is not None and if (context is not None and
@ -39,26 +41,23 @@ class ModuleContext:
self.events = list() self.events = list()
self.debug = context.verbosity > 0 if context is not None else False self.debug = context.verbosity > 0 if context is not None else False
from nemubot.hooks import Abstract as AbstractHook
# Define some callbacks # Define some callbacks
if context is not None: if context is not None:
# Load module data # Load module data
self.data = context.datastore.load(module_name) self.data = context.datastore.load(module_name)
def add_hook(store, hook): def add_hook(hook, *triggers):
self.hooks.append((store, hook)) assert isinstance(hook, AbstractHook), hook
return context.treater.hm.add_hook(hook, store) self.hooks.append((triggers, hook))
def del_hook(store, hook): return context.treater.hm.add_hook(hook, *triggers)
self.hooks.remove((store, hook))
return context.treater.hm.del_hook(hook, store) def del_hook(hook, *triggers):
def call_hook(store, msg): assert isinstance(hook, AbstractHook), hook
for h in context.treater.hm.get_hooks(store): self.hooks.remove((triggers, hook))
if h.match(msg): return context.treater.hm.del_hooks(*triggers, hook=hook)
res = h.run(msg)
if isinstance(res, list):
for i in res:
yield i
else:
yield res
def subtreat(msg): def subtreat(msg):
yield from context.treater.treat_msg(msg) yield from context.treater.treat_msg(msg)
def add_event(evt, eid=None): def add_event(evt, eid=None):
@ -80,13 +79,12 @@ class ModuleContext:
from nemubot.tools.xmlparser import module_state from nemubot.tools.xmlparser import module_state
self.data = module_state.ModuleState("nemubotstate") self.data = module_state.ModuleState("nemubotstate")
def add_hook(store, hook): def add_hook(hook, *triggers):
self.hooks.append((store, hook)) assert isinstance(hook, AbstractHook), hook
def del_hook(store, hook): self.hooks.append((triggers, hook))
self.hooks.remove((store, hook)) def del_hook(hook, *triggers):
def call_hook(store, msg): assert isinstance(hook, AbstractHook), hook
# TODO: what can we do here? self.hooks.remove((triggers, hook))
return None
def subtreat(msg): def subtreat(msg):
return None return None
def add_event(evt, eid=None): def add_event(evt, eid=None):
@ -106,7 +104,6 @@ class ModuleContext:
self.del_event = del_event self.del_event = del_event
self.save = save self.save = save
self.send_response = send_response self.send_response = send_response
self.call_hook = call_hook
self.subtreat = subtreat self.subtreat = subtreat
@ -115,7 +112,7 @@ class ModuleContext:
# Remove registered hooks # Remove registered hooks
for (s, h) in self.hooks: for (s, h) in self.hooks:
self.del_hook(s, h) self.del_hook(h, *s)
# Remove registered events # Remove registered events
for e in self.events: for e in self.events: