Rework hook managment and add some tests
This commit is contained in:
parent
926648517f
commit
43c42e1397
@ -242,7 +242,7 @@ def cmd_unalias(msg):
|
|||||||
|
|
||||||
## Alias replacement
|
## Alias replacement
|
||||||
|
|
||||||
@hook.add("pre_Command")
|
@hook.add(["pre","Command"])
|
||||||
def treat_alias(msg):
|
def treat_alias(msg):
|
||||||
if msg.cmd in context.data.getNode("aliases").index:
|
if msg.cmd in context.data.getNode("aliases").index:
|
||||||
txt = context.data.getNode("aliases").index[msg.cmd]["origin"]
|
txt = context.data.getNode("aliases").index[msg.cmd]["origin"]
|
||||||
|
@ -28,12 +28,14 @@ def load(CONF, add_hook):
|
|||||||
URL_WHOIS = URL_WHOIS % (urllib.parse.quote(CONF.getNode("whoisxmlapi")["username"]), urllib.parse.quote(CONF.getNode("whoisxmlapi")["password"]))
|
URL_WHOIS = URL_WHOIS % (urllib.parse.quote(CONF.getNode("whoisxmlapi")["username"]), urllib.parse.quote(CONF.getNode("whoisxmlapi")["password"]))
|
||||||
|
|
||||||
import nemubot.hooks
|
import nemubot.hooks
|
||||||
add_hook("in_Command", nemubot.hooks.Command(cmd_whois, "netwhois",
|
add_hook(nemubot.hooks.Command(cmd_whois, "netwhois",
|
||||||
help="Get whois information about given domains",
|
help="Get whois information about given domains",
|
||||||
help_usage={"DOMAIN": "Return whois information on the given DOMAIN"}))
|
help_usage={"DOMAIN": "Return whois information on the given DOMAIN"}),
|
||||||
add_hook("in_Command", nemubot.hooks.Command(cmd_avail, "domain_available",
|
"in","Command")
|
||||||
help="Domain availability check using whoisxmlapi.com",
|
add_hook(nemubot.hooks.Command(cmd_avail, "domain_available",
|
||||||
help_usage={"DOMAIN": "Check if the given DOMAIN is available or not"}))
|
help="Domain availability check using whoisxmlapi.com",
|
||||||
|
help_usage={"DOMAIN": "Check if the given DOMAIN is available or not"}),
|
||||||
|
"in","Command")
|
||||||
|
|
||||||
|
|
||||||
# MODULE CORE #########################################################
|
# MODULE CORE #########################################################
|
||||||
|
@ -30,8 +30,8 @@ def load(context):
|
|||||||
context.data.getNode("pics").setIndex("login", "pict")
|
context.data.getNode("pics").setIndex("login", "pict")
|
||||||
|
|
||||||
import nemubot.hooks
|
import nemubot.hooks
|
||||||
context.add_hook("in_Command",
|
context.add_hook(nemubot.hooks.Command(cmd_whois, "whois"),
|
||||||
nemubot.hooks.Command(cmd_whois, "whois"))
|
"in","Command")
|
||||||
|
|
||||||
class Login:
|
class Login:
|
||||||
|
|
||||||
|
@ -474,7 +474,7 @@ class Bot(threading.Thread):
|
|||||||
# Register decorated functions
|
# Register decorated functions
|
||||||
import nemubot.hooks
|
import nemubot.hooks
|
||||||
for s, h in nemubot.hooks.hook.last_registered:
|
for s, h in nemubot.hooks.hook.last_registered:
|
||||||
module.__nemubot_context__.add_hook(s, h)
|
module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
|
||||||
nemubot.hooks.hook.last_registered = []
|
nemubot.hooks.hook.last_registered = []
|
||||||
|
|
||||||
# Launch the module
|
# Launch the module
|
||||||
|
@ -35,19 +35,19 @@ class hook:
|
|||||||
def add(store, *args, **kwargs):
|
def add(store, *args, **kwargs):
|
||||||
return hook._add(store, Abstract, *args, **kwargs)
|
return hook._add(store, Abstract, *args, **kwargs)
|
||||||
|
|
||||||
def ask(*args, store="in_DirectAsk", **kwargs):
|
def ask(*args, store=["in","DirectAsk"], **kwargs):
|
||||||
return hook._add(store, Message, *args, **kwargs)
|
return hook._add(store, Message, *args, **kwargs)
|
||||||
|
|
||||||
def command(*args, store="in_Command", **kwargs):
|
def command(*args, store=["in","Command"], **kwargs):
|
||||||
return hook._add(store, Command, *args, **kwargs)
|
return hook._add(store, Command, *args, **kwargs)
|
||||||
|
|
||||||
def message(*args, store="in_Text", **kwargs):
|
def message(*args, store=["in","Text"], **kwargs):
|
||||||
return hook._add(store, Message, *args, **kwargs)
|
return hook._add(store, Message, *args, **kwargs)
|
||||||
|
|
||||||
def post(*args, store="post", **kwargs):
|
def post(*args, store=["post"], **kwargs):
|
||||||
return hook._add(store, Abstract, *args, **kwargs)
|
return hook._add(store, Abstract, *args, **kwargs)
|
||||||
|
|
||||||
def pre(*args, store="pre", **kwargs):
|
def pre(*args, store=["pre"], **kwargs):
|
||||||
return hook._add(store, Abstract, *args, **kwargs)
|
return hook._add(store, Abstract, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,15 +14,47 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
class HooksManager:
|
class HooksManager:
|
||||||
|
|
||||||
"""Class to manage hooks"""
|
"""Class to manage hooks"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, name="core"):
|
||||||
"""Initialize the manager"""
|
"""Initialize the manager"""
|
||||||
|
|
||||||
self.hooks = dict()
|
self.hooks = dict()
|
||||||
|
self.logger = logging.getLogger("nemubot.hooks.manager." + name)
|
||||||
|
|
||||||
|
|
||||||
|
def _access(self, *triggers):
|
||||||
|
"""Access to the given triggers chain"""
|
||||||
|
|
||||||
|
h = self.hooks
|
||||||
|
for t in triggers:
|
||||||
|
if t not in h:
|
||||||
|
h[t] = dict()
|
||||||
|
h = h[t]
|
||||||
|
|
||||||
|
if "__end__" not in h:
|
||||||
|
h["__end__"] = list()
|
||||||
|
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
def _search(self, hook, *where, start=None):
|
||||||
|
"""Search all occurence of the given hook"""
|
||||||
|
|
||||||
|
if start is None:
|
||||||
|
start = self.hooks
|
||||||
|
|
||||||
|
for k in start:
|
||||||
|
if k == "__end__":
|
||||||
|
if hook in start[k]:
|
||||||
|
yield where
|
||||||
|
else:
|
||||||
|
yield from self._search(hook, *where + (k,), start=start[k])
|
||||||
|
|
||||||
|
|
||||||
def add_hook(self, hook, *triggers):
|
def add_hook(self, hook, *triggers):
|
||||||
@ -33,20 +65,19 @@ class HooksManager:
|
|||||||
triggers -- string that trigger the hook
|
triggers -- string that trigger the hook
|
||||||
"""
|
"""
|
||||||
|
|
||||||
trigger = "_".join(triggers)
|
assert hook is not None, hook
|
||||||
|
|
||||||
if trigger not in self.hooks:
|
h = self._access(*triggers)
|
||||||
self.hooks[trigger] = list()
|
|
||||||
|
|
||||||
self.hooks[trigger].append(hook)
|
h["__end__"].append(hook)
|
||||||
|
|
||||||
|
self.logger.debug("New hook successfully added in %s: %s",
|
||||||
|
"/".join(triggers), hook)
|
||||||
|
|
||||||
|
|
||||||
def del_hook(self, hook=None, *triggers):
|
def del_hooks(self, *triggers, hook=None):
|
||||||
"""Remove the given hook from the manager
|
"""Remove the given hook from the manager
|
||||||
|
|
||||||
Return:
|
|
||||||
Boolean value reporting the deletion success
|
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
triggers -- trigger string to remove
|
triggers -- trigger string to remove
|
||||||
|
|
||||||
@ -54,15 +85,20 @@ class HooksManager:
|
|||||||
hook -- a Hook instance to remove from the trigger string
|
hook -- a Hook instance to remove from the trigger string
|
||||||
"""
|
"""
|
||||||
|
|
||||||
trigger = "_".join(triggers)
|
assert hook is not None or len(triggers)
|
||||||
|
|
||||||
if trigger in self.hooks:
|
self.logger.debug("Trying to delete hook in %s: %s",
|
||||||
if hook is None:
|
"/".join(triggers), hook)
|
||||||
del self.hooks[trigger]
|
|
||||||
|
if hook is not None:
|
||||||
|
for h in self._search(hook, *triggers, start=self._access(*triggers)):
|
||||||
|
self._access(*h)["__end__"].remove(hook)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if len(triggers):
|
||||||
|
del self._access(*triggers[:-1])[triggers[-1]]
|
||||||
else:
|
else:
|
||||||
self.hooks[trigger].remove(hook)
|
self.hooks = dict()
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_hooks(self, *triggers):
|
def get_hooks(self, *triggers):
|
||||||
@ -70,35 +106,29 @@ class HooksManager:
|
|||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
triggers -- the trigger string
|
triggers -- the trigger string
|
||||||
|
|
||||||
Keyword argument:
|
|
||||||
data -- Data to pass to the hook as argument
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
trigger = "_".join(triggers)
|
for n in range(len(triggers) + 1):
|
||||||
|
i = self._access(*triggers[:n])
|
||||||
res = list()
|
for h in i["__end__"]:
|
||||||
|
yield h
|
||||||
for key in self.hooks:
|
|
||||||
if trigger.find(key) == 0:
|
|
||||||
res += self.hooks[key]
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def exec_hook(self, *triggers, **data):
|
def get_reverse_hooks(self, *triggers, exclude_first=False):
|
||||||
"""Trigger hooks that match the given trigger string
|
"""Returns list of triggered hooks that are bellow or at the same level
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
trigger -- the trigger string
|
triggers -- the trigger string
|
||||||
|
|
||||||
Keyword argument:
|
Keyword arguments:
|
||||||
data -- Data to pass to the hook as argument
|
exclude_first -- start reporting hook at the next level
|
||||||
"""
|
"""
|
||||||
|
|
||||||
trigger = "_".join(triggers)
|
h = self._access(*triggers)
|
||||||
|
for k in h:
|
||||||
for key in self.hooks:
|
if k == "__end__":
|
||||||
if trigger.find(key) == 0:
|
if not exclude_first:
|
||||||
for hook in self.hooks[key]:
|
for hk in h[k]:
|
||||||
hook.run(**data)
|
yield hk
|
||||||
|
else:
|
||||||
|
yield from self.get_reverse_hooks(*triggers + (k,))
|
||||||
|
115
nemubot/hooks/manager_test.py
Executable file
115
nemubot/hooks/manager_test.py
Executable file
@ -0,0 +1,115 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from nemubot.hooks.manager import HooksManager
|
||||||
|
|
||||||
|
class TestHookManager(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
|
def test_access(self):
|
||||||
|
hm = HooksManager()
|
||||||
|
|
||||||
|
h1 = "HOOK1"
|
||||||
|
h2 = "HOOK2"
|
||||||
|
h3 = "HOOK3"
|
||||||
|
|
||||||
|
hm.add_hook(h1)
|
||||||
|
hm.add_hook(h2, "pre")
|
||||||
|
hm.add_hook(h3, "pre", "Text")
|
||||||
|
hm.add_hook(h2, "post", "Text")
|
||||||
|
|
||||||
|
self.assertIn("__end__", hm._access())
|
||||||
|
self.assertIn("__end__", hm._access("pre"))
|
||||||
|
self.assertIn("__end__", hm._access("pre", "Text"))
|
||||||
|
self.assertIn("__end__", hm._access("post", "Text"))
|
||||||
|
|
||||||
|
self.assertFalse(hm._access("inexistant")["__end__"])
|
||||||
|
self.assertTrue(hm._access()["__end__"])
|
||||||
|
self.assertTrue(hm._access("pre")["__end__"])
|
||||||
|
self.assertTrue(hm._access("pre", "Text")["__end__"])
|
||||||
|
self.assertTrue(hm._access("post", "Text")["__end__"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_search(self):
|
||||||
|
hm = HooksManager()
|
||||||
|
|
||||||
|
h1 = "HOOK1"
|
||||||
|
h2 = "HOOK2"
|
||||||
|
h3 = "HOOK3"
|
||||||
|
h4 = "HOOK4"
|
||||||
|
|
||||||
|
hm.add_hook(h1)
|
||||||
|
hm.add_hook(h2, "pre")
|
||||||
|
hm.add_hook(h3, "pre", "Text")
|
||||||
|
hm.add_hook(h2, "post", "Text")
|
||||||
|
|
||||||
|
self.assertTrue([h for h in hm._search(h1)])
|
||||||
|
self.assertFalse([h for h in hm._search(h4)])
|
||||||
|
self.assertEqual(2, len([h for h in hm._search(h2)]))
|
||||||
|
self.assertEqual([("pre", "Text")], [h for h in hm._search(h3)])
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete(self):
|
||||||
|
hm = HooksManager()
|
||||||
|
|
||||||
|
h1 = "HOOK1"
|
||||||
|
h2 = "HOOK2"
|
||||||
|
h3 = "HOOK3"
|
||||||
|
h4 = "HOOK4"
|
||||||
|
|
||||||
|
hm.add_hook(h1)
|
||||||
|
hm.add_hook(h2, "pre")
|
||||||
|
hm.add_hook(h3, "pre", "Text")
|
||||||
|
hm.add_hook(h2, "post", "Text")
|
||||||
|
|
||||||
|
hm.del_hooks(hook=h4)
|
||||||
|
|
||||||
|
self.assertTrue(hm._access("pre")["__end__"])
|
||||||
|
self.assertTrue(hm._access("pre", "Text")["__end__"])
|
||||||
|
hm.del_hooks("pre")
|
||||||
|
self.assertFalse(hm._access("pre")["__end__"])
|
||||||
|
|
||||||
|
self.assertTrue(hm._access("post", "Text")["__end__"])
|
||||||
|
hm.del_hooks("post", "Text", hook=h2)
|
||||||
|
self.assertFalse(hm._access("post", "Text")["__end__"])
|
||||||
|
|
||||||
|
self.assertTrue(hm._access()["__end__"])
|
||||||
|
hm.del_hooks(hook=h1)
|
||||||
|
self.assertFalse(hm._access()["__end__"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_get(self):
|
||||||
|
hm = HooksManager()
|
||||||
|
|
||||||
|
h1 = "HOOK1"
|
||||||
|
h2 = "HOOK2"
|
||||||
|
h3 = "HOOK3"
|
||||||
|
|
||||||
|
hm.add_hook(h1)
|
||||||
|
hm.add_hook(h2, "pre")
|
||||||
|
hm.add_hook(h3, "pre", "Text")
|
||||||
|
hm.add_hook(h2, "post", "Text")
|
||||||
|
|
||||||
|
self.assertEqual([h1, h2], [h for h in hm.get_hooks("pre")])
|
||||||
|
self.assertEqual([h1, h2, h3], [h for h in hm.get_hooks("pre", "Text")])
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_rev(self):
|
||||||
|
hm = HooksManager()
|
||||||
|
|
||||||
|
h1 = "HOOK1"
|
||||||
|
h2 = "HOOK2"
|
||||||
|
h3 = "HOOK3"
|
||||||
|
|
||||||
|
hm.add_hook(h1)
|
||||||
|
hm.add_hook(h2, "pre")
|
||||||
|
hm.add_hook(h3, "pre", "Text")
|
||||||
|
hm.add_hook(h2, "post", "Text")
|
||||||
|
|
||||||
|
self.assertEqual([h2, h3], [h for h in hm.get_reverse_hooks("pre")])
|
||||||
|
self.assertEqual([h3], [h for h in hm.get_reverse_hooks("pre", exclude_first=True)])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -26,6 +26,8 @@ class ModuleContext:
|
|||||||
|
|
||||||
if module is not None:
|
if module is not None:
|
||||||
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
|
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
|
||||||
|
else:
|
||||||
|
module_name = ""
|
||||||
|
|
||||||
# Load module configuration if exists
|
# Load module configuration if exists
|
||||||
if (context is not None and
|
if (context is not None and
|
||||||
@ -39,26 +41,23 @@ class ModuleContext:
|
|||||||
self.events = list()
|
self.events = list()
|
||||||
self.debug = context.verbosity > 0 if context is not None else False
|
self.debug = context.verbosity > 0 if context is not None else False
|
||||||
|
|
||||||
|
from nemubot.hooks import Abstract as AbstractHook
|
||||||
|
|
||||||
# Define some callbacks
|
# Define some callbacks
|
||||||
if context is not None:
|
if context is not None:
|
||||||
# Load module data
|
# Load module data
|
||||||
self.data = context.datastore.load(module_name)
|
self.data = context.datastore.load(module_name)
|
||||||
|
|
||||||
def add_hook(store, hook):
|
def add_hook(hook, *triggers):
|
||||||
self.hooks.append((store, hook))
|
assert isinstance(hook, AbstractHook), hook
|
||||||
return context.treater.hm.add_hook(hook, store)
|
self.hooks.append((triggers, hook))
|
||||||
def del_hook(store, hook):
|
return context.treater.hm.add_hook(hook, *triggers)
|
||||||
self.hooks.remove((store, hook))
|
|
||||||
return context.treater.hm.del_hook(hook, store)
|
def del_hook(hook, *triggers):
|
||||||
def call_hook(store, msg):
|
assert isinstance(hook, AbstractHook), hook
|
||||||
for h in context.treater.hm.get_hooks(store):
|
self.hooks.remove((triggers, hook))
|
||||||
if h.match(msg):
|
return context.treater.hm.del_hooks(*triggers, hook=hook)
|
||||||
res = h.run(msg)
|
|
||||||
if isinstance(res, list):
|
|
||||||
for i in res:
|
|
||||||
yield i
|
|
||||||
else:
|
|
||||||
yield res
|
|
||||||
def subtreat(msg):
|
def subtreat(msg):
|
||||||
yield from context.treater.treat_msg(msg)
|
yield from context.treater.treat_msg(msg)
|
||||||
def add_event(evt, eid=None):
|
def add_event(evt, eid=None):
|
||||||
@ -80,13 +79,12 @@ class ModuleContext:
|
|||||||
from nemubot.tools.xmlparser import module_state
|
from nemubot.tools.xmlparser import module_state
|
||||||
self.data = module_state.ModuleState("nemubotstate")
|
self.data = module_state.ModuleState("nemubotstate")
|
||||||
|
|
||||||
def add_hook(store, hook):
|
def add_hook(hook, *triggers):
|
||||||
self.hooks.append((store, hook))
|
assert isinstance(hook, AbstractHook), hook
|
||||||
def del_hook(store, hook):
|
self.hooks.append((triggers, hook))
|
||||||
self.hooks.remove((store, hook))
|
def del_hook(hook, *triggers):
|
||||||
def call_hook(store, msg):
|
assert isinstance(hook, AbstractHook), hook
|
||||||
# TODO: what can we do here?
|
self.hooks.remove((triggers, hook))
|
||||||
return None
|
|
||||||
def subtreat(msg):
|
def subtreat(msg):
|
||||||
return None
|
return None
|
||||||
def add_event(evt, eid=None):
|
def add_event(evt, eid=None):
|
||||||
@ -106,7 +104,6 @@ class ModuleContext:
|
|||||||
self.del_event = del_event
|
self.del_event = del_event
|
||||||
self.save = save
|
self.save = save
|
||||||
self.send_response = send_response
|
self.send_response = send_response
|
||||||
self.call_hook = call_hook
|
|
||||||
self.subtreat = subtreat
|
self.subtreat = subtreat
|
||||||
|
|
||||||
|
|
||||||
@ -115,7 +112,7 @@ class ModuleContext:
|
|||||||
|
|
||||||
# Remove registered hooks
|
# Remove registered hooks
|
||||||
for (s, h) in self.hooks:
|
for (s, h) in self.hooks:
|
||||||
self.del_hook(s, h)
|
self.del_hook(h, *s)
|
||||||
|
|
||||||
# Remove registered events
|
# Remove registered events
|
||||||
for e in self.events:
|
for e in self.events:
|
||||||
|
Loading…
Reference in New Issue
Block a user