Rework hook managment and add some tests
This commit is contained in:
parent
926648517f
commit
43c42e1397
@ -242,7 +242,7 @@ def cmd_unalias(msg):
|
||||
|
||||
## Alias replacement
|
||||
|
||||
@hook.add("pre_Command")
|
||||
@hook.add(["pre","Command"])
|
||||
def treat_alias(msg):
|
||||
if msg.cmd in context.data.getNode("aliases").index:
|
||||
txt = context.data.getNode("aliases").index[msg.cmd]["origin"]
|
||||
|
@ -28,12 +28,14 @@ def load(CONF, add_hook):
|
||||
URL_WHOIS = URL_WHOIS % (urllib.parse.quote(CONF.getNode("whoisxmlapi")["username"]), urllib.parse.quote(CONF.getNode("whoisxmlapi")["password"]))
|
||||
|
||||
import nemubot.hooks
|
||||
add_hook("in_Command", nemubot.hooks.Command(cmd_whois, "netwhois",
|
||||
add_hook(nemubot.hooks.Command(cmd_whois, "netwhois",
|
||||
help="Get whois information about given domains",
|
||||
help_usage={"DOMAIN": "Return whois information on the given DOMAIN"}))
|
||||
add_hook("in_Command", nemubot.hooks.Command(cmd_avail, "domain_available",
|
||||
help_usage={"DOMAIN": "Return whois information on the given DOMAIN"}),
|
||||
"in","Command")
|
||||
add_hook(nemubot.hooks.Command(cmd_avail, "domain_available",
|
||||
help="Domain availability check using whoisxmlapi.com",
|
||||
help_usage={"DOMAIN": "Check if the given DOMAIN is available or not"}))
|
||||
help_usage={"DOMAIN": "Check if the given DOMAIN is available or not"}),
|
||||
"in","Command")
|
||||
|
||||
|
||||
# MODULE CORE #########################################################
|
||||
|
@ -30,8 +30,8 @@ def load(context):
|
||||
context.data.getNode("pics").setIndex("login", "pict")
|
||||
|
||||
import nemubot.hooks
|
||||
context.add_hook("in_Command",
|
||||
nemubot.hooks.Command(cmd_whois, "whois"))
|
||||
context.add_hook(nemubot.hooks.Command(cmd_whois, "whois"),
|
||||
"in","Command")
|
||||
|
||||
class Login:
|
||||
|
||||
|
@ -474,7 +474,7 @@ class Bot(threading.Thread):
|
||||
# Register decorated functions
|
||||
import nemubot.hooks
|
||||
for s, h in nemubot.hooks.hook.last_registered:
|
||||
module.__nemubot_context__.add_hook(s, h)
|
||||
module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
|
||||
nemubot.hooks.hook.last_registered = []
|
||||
|
||||
# Launch the module
|
||||
|
@ -35,19 +35,19 @@ class hook:
|
||||
def add(store, *args, **kwargs):
|
||||
return hook._add(store, Abstract, *args, **kwargs)
|
||||
|
||||
def ask(*args, store="in_DirectAsk", **kwargs):
|
||||
def ask(*args, store=["in","DirectAsk"], **kwargs):
|
||||
return hook._add(store, Message, *args, **kwargs)
|
||||
|
||||
def command(*args, store="in_Command", **kwargs):
|
||||
def command(*args, store=["in","Command"], **kwargs):
|
||||
return hook._add(store, Command, *args, **kwargs)
|
||||
|
||||
def message(*args, store="in_Text", **kwargs):
|
||||
def message(*args, store=["in","Text"], **kwargs):
|
||||
return hook._add(store, Message, *args, **kwargs)
|
||||
|
||||
def post(*args, store="post", **kwargs):
|
||||
def post(*args, store=["post"], **kwargs):
|
||||
return hook._add(store, Abstract, *args, **kwargs)
|
||||
|
||||
def pre(*args, store="pre", **kwargs):
|
||||
def pre(*args, store=["pre"], **kwargs):
|
||||
return hook._add(store, Abstract, *args, **kwargs)
|
||||
|
||||
|
||||
|
@ -14,15 +14,47 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class HooksManager:
|
||||
|
||||
"""Class to manage hooks"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, name="core"):
|
||||
"""Initialize the manager"""
|
||||
|
||||
self.hooks = dict()
|
||||
self.logger = logging.getLogger("nemubot.hooks.manager." + name)
|
||||
|
||||
|
||||
def _access(self, *triggers):
|
||||
"""Access to the given triggers chain"""
|
||||
|
||||
h = self.hooks
|
||||
for t in triggers:
|
||||
if t not in h:
|
||||
h[t] = dict()
|
||||
h = h[t]
|
||||
|
||||
if "__end__" not in h:
|
||||
h["__end__"] = list()
|
||||
|
||||
return h
|
||||
|
||||
|
||||
def _search(self, hook, *where, start=None):
|
||||
"""Search all occurence of the given hook"""
|
||||
|
||||
if start is None:
|
||||
start = self.hooks
|
||||
|
||||
for k in start:
|
||||
if k == "__end__":
|
||||
if hook in start[k]:
|
||||
yield where
|
||||
else:
|
||||
yield from self._search(hook, *where + (k,), start=start[k])
|
||||
|
||||
|
||||
def add_hook(self, hook, *triggers):
|
||||
@ -33,20 +65,19 @@ class HooksManager:
|
||||
triggers -- string that trigger the hook
|
||||
"""
|
||||
|
||||
trigger = "_".join(triggers)
|
||||
assert hook is not None, hook
|
||||
|
||||
if trigger not in self.hooks:
|
||||
self.hooks[trigger] = list()
|
||||
h = self._access(*triggers)
|
||||
|
||||
self.hooks[trigger].append(hook)
|
||||
h["__end__"].append(hook)
|
||||
|
||||
self.logger.debug("New hook successfully added in %s: %s",
|
||||
"/".join(triggers), hook)
|
||||
|
||||
|
||||
def del_hook(self, hook=None, *triggers):
|
||||
def del_hooks(self, *triggers, hook=None):
|
||||
"""Remove the given hook from the manager
|
||||
|
||||
Return:
|
||||
Boolean value reporting the deletion success
|
||||
|
||||
Argument:
|
||||
triggers -- trigger string to remove
|
||||
|
||||
@ -54,15 +85,20 @@ class HooksManager:
|
||||
hook -- a Hook instance to remove from the trigger string
|
||||
"""
|
||||
|
||||
trigger = "_".join(triggers)
|
||||
assert hook is not None or len(triggers)
|
||||
|
||||
self.logger.debug("Trying to delete hook in %s: %s",
|
||||
"/".join(triggers), hook)
|
||||
|
||||
if hook is not None:
|
||||
for h in self._search(hook, *triggers, start=self._access(*triggers)):
|
||||
self._access(*h)["__end__"].remove(hook)
|
||||
|
||||
if trigger in self.hooks:
|
||||
if hook is None:
|
||||
del self.hooks[trigger]
|
||||
else:
|
||||
self.hooks[trigger].remove(hook)
|
||||
return True
|
||||
return False
|
||||
if len(triggers):
|
||||
del self._access(*triggers[:-1])[triggers[-1]]
|
||||
else:
|
||||
self.hooks = dict()
|
||||
|
||||
|
||||
def get_hooks(self, *triggers):
|
||||
@ -70,35 +106,29 @@ class HooksManager:
|
||||
|
||||
Argument:
|
||||
triggers -- the trigger string
|
||||
|
||||
Keyword argument:
|
||||
data -- Data to pass to the hook as argument
|
||||
"""
|
||||
|
||||
trigger = "_".join(triggers)
|
||||
|
||||
res = list()
|
||||
|
||||
for key in self.hooks:
|
||||
if trigger.find(key) == 0:
|
||||
res += self.hooks[key]
|
||||
|
||||
return res
|
||||
for n in range(len(triggers) + 1):
|
||||
i = self._access(*triggers[:n])
|
||||
for h in i["__end__"]:
|
||||
yield h
|
||||
|
||||
|
||||
def exec_hook(self, *triggers, **data):
|
||||
"""Trigger hooks that match the given trigger string
|
||||
def get_reverse_hooks(self, *triggers, exclude_first=False):
|
||||
"""Returns list of triggered hooks that are bellow or at the same level
|
||||
|
||||
Argument:
|
||||
trigger -- the trigger string
|
||||
triggers -- the trigger string
|
||||
|
||||
Keyword argument:
|
||||
data -- Data to pass to the hook as argument
|
||||
Keyword arguments:
|
||||
exclude_first -- start reporting hook at the next level
|
||||
"""
|
||||
|
||||
trigger = "_".join(triggers)
|
||||
|
||||
for key in self.hooks:
|
||||
if trigger.find(key) == 0:
|
||||
for hook in self.hooks[key]:
|
||||
hook.run(**data)
|
||||
h = self._access(*triggers)
|
||||
for k in h:
|
||||
if k == "__end__":
|
||||
if not exclude_first:
|
||||
for hk in h[k]:
|
||||
yield hk
|
||||
else:
|
||||
yield from self.get_reverse_hooks(*triggers + (k,))
|
||||
|
115
nemubot/hooks/manager_test.py
Executable file
115
nemubot/hooks/manager_test.py
Executable file
@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import unittest
|
||||
|
||||
from nemubot.hooks.manager import HooksManager
|
||||
|
||||
class TestHookManager(unittest.TestCase):
|
||||
|
||||
|
||||
def test_access(self):
|
||||
hm = HooksManager()
|
||||
|
||||
h1 = "HOOK1"
|
||||
h2 = "HOOK2"
|
||||
h3 = "HOOK3"
|
||||
|
||||
hm.add_hook(h1)
|
||||
hm.add_hook(h2, "pre")
|
||||
hm.add_hook(h3, "pre", "Text")
|
||||
hm.add_hook(h2, "post", "Text")
|
||||
|
||||
self.assertIn("__end__", hm._access())
|
||||
self.assertIn("__end__", hm._access("pre"))
|
||||
self.assertIn("__end__", hm._access("pre", "Text"))
|
||||
self.assertIn("__end__", hm._access("post", "Text"))
|
||||
|
||||
self.assertFalse(hm._access("inexistant")["__end__"])
|
||||
self.assertTrue(hm._access()["__end__"])
|
||||
self.assertTrue(hm._access("pre")["__end__"])
|
||||
self.assertTrue(hm._access("pre", "Text")["__end__"])
|
||||
self.assertTrue(hm._access("post", "Text")["__end__"])
|
||||
|
||||
|
||||
def test_search(self):
|
||||
hm = HooksManager()
|
||||
|
||||
h1 = "HOOK1"
|
||||
h2 = "HOOK2"
|
||||
h3 = "HOOK3"
|
||||
h4 = "HOOK4"
|
||||
|
||||
hm.add_hook(h1)
|
||||
hm.add_hook(h2, "pre")
|
||||
hm.add_hook(h3, "pre", "Text")
|
||||
hm.add_hook(h2, "post", "Text")
|
||||
|
||||
self.assertTrue([h for h in hm._search(h1)])
|
||||
self.assertFalse([h for h in hm._search(h4)])
|
||||
self.assertEqual(2, len([h for h in hm._search(h2)]))
|
||||
self.assertEqual([("pre", "Text")], [h for h in hm._search(h3)])
|
||||
|
||||
|
||||
def test_delete(self):
|
||||
hm = HooksManager()
|
||||
|
||||
h1 = "HOOK1"
|
||||
h2 = "HOOK2"
|
||||
h3 = "HOOK3"
|
||||
h4 = "HOOK4"
|
||||
|
||||
hm.add_hook(h1)
|
||||
hm.add_hook(h2, "pre")
|
||||
hm.add_hook(h3, "pre", "Text")
|
||||
hm.add_hook(h2, "post", "Text")
|
||||
|
||||
hm.del_hooks(hook=h4)
|
||||
|
||||
self.assertTrue(hm._access("pre")["__end__"])
|
||||
self.assertTrue(hm._access("pre", "Text")["__end__"])
|
||||
hm.del_hooks("pre")
|
||||
self.assertFalse(hm._access("pre")["__end__"])
|
||||
|
||||
self.assertTrue(hm._access("post", "Text")["__end__"])
|
||||
hm.del_hooks("post", "Text", hook=h2)
|
||||
self.assertFalse(hm._access("post", "Text")["__end__"])
|
||||
|
||||
self.assertTrue(hm._access()["__end__"])
|
||||
hm.del_hooks(hook=h1)
|
||||
self.assertFalse(hm._access()["__end__"])
|
||||
|
||||
|
||||
def test_get(self):
|
||||
hm = HooksManager()
|
||||
|
||||
h1 = "HOOK1"
|
||||
h2 = "HOOK2"
|
||||
h3 = "HOOK3"
|
||||
|
||||
hm.add_hook(h1)
|
||||
hm.add_hook(h2, "pre")
|
||||
hm.add_hook(h3, "pre", "Text")
|
||||
hm.add_hook(h2, "post", "Text")
|
||||
|
||||
self.assertEqual([h1, h2], [h for h in hm.get_hooks("pre")])
|
||||
self.assertEqual([h1, h2, h3], [h for h in hm.get_hooks("pre", "Text")])
|
||||
|
||||
|
||||
def test_get_rev(self):
|
||||
hm = HooksManager()
|
||||
|
||||
h1 = "HOOK1"
|
||||
h2 = "HOOK2"
|
||||
h3 = "HOOK3"
|
||||
|
||||
hm.add_hook(h1)
|
||||
hm.add_hook(h2, "pre")
|
||||
hm.add_hook(h3, "pre", "Text")
|
||||
hm.add_hook(h2, "post", "Text")
|
||||
|
||||
self.assertEqual([h2, h3], [h for h in hm.get_reverse_hooks("pre")])
|
||||
self.assertEqual([h3], [h for h in hm.get_reverse_hooks("pre", exclude_first=True)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -26,6 +26,8 @@ class ModuleContext:
|
||||
|
||||
if module is not None:
|
||||
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
|
||||
else:
|
||||
module_name = ""
|
||||
|
||||
# Load module configuration if exists
|
||||
if (context is not None and
|
||||
@ -39,26 +41,23 @@ class ModuleContext:
|
||||
self.events = list()
|
||||
self.debug = context.verbosity > 0 if context is not None else False
|
||||
|
||||
from nemubot.hooks import Abstract as AbstractHook
|
||||
|
||||
# Define some callbacks
|
||||
if context is not None:
|
||||
# Load module data
|
||||
self.data = context.datastore.load(module_name)
|
||||
|
||||
def add_hook(store, hook):
|
||||
self.hooks.append((store, hook))
|
||||
return context.treater.hm.add_hook(hook, store)
|
||||
def del_hook(store, hook):
|
||||
self.hooks.remove((store, hook))
|
||||
return context.treater.hm.del_hook(hook, store)
|
||||
def call_hook(store, msg):
|
||||
for h in context.treater.hm.get_hooks(store):
|
||||
if h.match(msg):
|
||||
res = h.run(msg)
|
||||
if isinstance(res, list):
|
||||
for i in res:
|
||||
yield i
|
||||
else:
|
||||
yield res
|
||||
def add_hook(hook, *triggers):
|
||||
assert isinstance(hook, AbstractHook), hook
|
||||
self.hooks.append((triggers, hook))
|
||||
return context.treater.hm.add_hook(hook, *triggers)
|
||||
|
||||
def del_hook(hook, *triggers):
|
||||
assert isinstance(hook, AbstractHook), hook
|
||||
self.hooks.remove((triggers, hook))
|
||||
return context.treater.hm.del_hooks(*triggers, hook=hook)
|
||||
|
||||
def subtreat(msg):
|
||||
yield from context.treater.treat_msg(msg)
|
||||
def add_event(evt, eid=None):
|
||||
@ -80,13 +79,12 @@ class ModuleContext:
|
||||
from nemubot.tools.xmlparser import module_state
|
||||
self.data = module_state.ModuleState("nemubotstate")
|
||||
|
||||
def add_hook(store, hook):
|
||||
self.hooks.append((store, hook))
|
||||
def del_hook(store, hook):
|
||||
self.hooks.remove((store, hook))
|
||||
def call_hook(store, msg):
|
||||
# TODO: what can we do here?
|
||||
return None
|
||||
def add_hook(hook, *triggers):
|
||||
assert isinstance(hook, AbstractHook), hook
|
||||
self.hooks.append((triggers, hook))
|
||||
def del_hook(hook, *triggers):
|
||||
assert isinstance(hook, AbstractHook), hook
|
||||
self.hooks.remove((triggers, hook))
|
||||
def subtreat(msg):
|
||||
return None
|
||||
def add_event(evt, eid=None):
|
||||
@ -106,7 +104,6 @@ class ModuleContext:
|
||||
self.del_event = del_event
|
||||
self.save = save
|
||||
self.send_response = send_response
|
||||
self.call_hook = call_hook
|
||||
self.subtreat = subtreat
|
||||
|
||||
|
||||
@ -115,7 +112,7 @@ class ModuleContext:
|
||||
|
||||
# Remove registered hooks
|
||||
for (s, h) in self.hooks:
|
||||
self.del_hook(s, h)
|
||||
self.del_hook(h, *s)
|
||||
|
||||
# Remove registered events
|
||||
for e in self.events:
|
||||
|
Loading…
Reference in New Issue
Block a user