Rework hook managment and add some tests
This commit is contained in:
parent
926648517f
commit
43c42e1397
8 changed files with 222 additions and 78 deletions
|
|
@ -35,19 +35,19 @@ class hook:
|
|||
def add(store, *args, **kwargs):
|
||||
return hook._add(store, Abstract, *args, **kwargs)
|
||||
|
||||
def ask(*args, store="in_DirectAsk", **kwargs):
|
||||
def ask(*args, store=["in","DirectAsk"], **kwargs):
|
||||
return hook._add(store, Message, *args, **kwargs)
|
||||
|
||||
def command(*args, store="in_Command", **kwargs):
|
||||
def command(*args, store=["in","Command"], **kwargs):
|
||||
return hook._add(store, Command, *args, **kwargs)
|
||||
|
||||
def message(*args, store="in_Text", **kwargs):
|
||||
def message(*args, store=["in","Text"], **kwargs):
|
||||
return hook._add(store, Message, *args, **kwargs)
|
||||
|
||||
def post(*args, store="post", **kwargs):
|
||||
def post(*args, store=["post"], **kwargs):
|
||||
return hook._add(store, Abstract, *args, **kwargs)
|
||||
|
||||
def pre(*args, store="pre", **kwargs):
|
||||
def pre(*args, store=["pre"], **kwargs):
|
||||
return hook._add(store, Abstract, *args, **kwargs)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,15 +14,47 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class HooksManager:
|
||||
|
||||
"""Class to manage hooks"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, name="core"):
|
||||
"""Initialize the manager"""
|
||||
|
||||
self.hooks = dict()
|
||||
self.logger = logging.getLogger("nemubot.hooks.manager." + name)
|
||||
|
||||
|
||||
def _access(self, *triggers):
|
||||
"""Access to the given triggers chain"""
|
||||
|
||||
h = self.hooks
|
||||
for t in triggers:
|
||||
if t not in h:
|
||||
h[t] = dict()
|
||||
h = h[t]
|
||||
|
||||
if "__end__" not in h:
|
||||
h["__end__"] = list()
|
||||
|
||||
return h
|
||||
|
||||
|
||||
def _search(self, hook, *where, start=None):
|
||||
"""Search all occurence of the given hook"""
|
||||
|
||||
if start is None:
|
||||
start = self.hooks
|
||||
|
||||
for k in start:
|
||||
if k == "__end__":
|
||||
if hook in start[k]:
|
||||
yield where
|
||||
else:
|
||||
yield from self._search(hook, *where + (k,), start=start[k])
|
||||
|
||||
|
||||
def add_hook(self, hook, *triggers):
|
||||
|
|
@ -33,20 +65,19 @@ class HooksManager:
|
|||
triggers -- string that trigger the hook
|
||||
"""
|
||||
|
||||
trigger = "_".join(triggers)
|
||||
assert hook is not None, hook
|
||||
|
||||
if trigger not in self.hooks:
|
||||
self.hooks[trigger] = list()
|
||||
h = self._access(*triggers)
|
||||
|
||||
self.hooks[trigger].append(hook)
|
||||
h["__end__"].append(hook)
|
||||
|
||||
self.logger.debug("New hook successfully added in %s: %s",
|
||||
"/".join(triggers), hook)
|
||||
|
||||
|
||||
def del_hook(self, hook=None, *triggers):
|
||||
def del_hooks(self, *triggers, hook=None):
|
||||
"""Remove the given hook from the manager
|
||||
|
||||
Return:
|
||||
Boolean value reporting the deletion success
|
||||
|
||||
Argument:
|
||||
triggers -- trigger string to remove
|
||||
|
||||
|
|
@ -54,15 +85,20 @@ class HooksManager:
|
|||
hook -- a Hook instance to remove from the trigger string
|
||||
"""
|
||||
|
||||
trigger = "_".join(triggers)
|
||||
assert hook is not None or len(triggers)
|
||||
|
||||
if trigger in self.hooks:
|
||||
if hook is None:
|
||||
del self.hooks[trigger]
|
||||
self.logger.debug("Trying to delete hook in %s: %s",
|
||||
"/".join(triggers), hook)
|
||||
|
||||
if hook is not None:
|
||||
for h in self._search(hook, *triggers, start=self._access(*triggers)):
|
||||
self._access(*h)["__end__"].remove(hook)
|
||||
|
||||
else:
|
||||
if len(triggers):
|
||||
del self._access(*triggers[:-1])[triggers[-1]]
|
||||
else:
|
||||
self.hooks[trigger].remove(hook)
|
||||
return True
|
||||
return False
|
||||
self.hooks = dict()
|
||||
|
||||
|
||||
def get_hooks(self, *triggers):
|
||||
|
|
@ -70,35 +106,29 @@ class HooksManager:
|
|||
|
||||
Argument:
|
||||
triggers -- the trigger string
|
||||
|
||||
Keyword argument:
|
||||
data -- Data to pass to the hook as argument
|
||||
"""
|
||||
|
||||
trigger = "_".join(triggers)
|
||||
|
||||
res = list()
|
||||
|
||||
for key in self.hooks:
|
||||
if trigger.find(key) == 0:
|
||||
res += self.hooks[key]
|
||||
|
||||
return res
|
||||
for n in range(len(triggers) + 1):
|
||||
i = self._access(*triggers[:n])
|
||||
for h in i["__end__"]:
|
||||
yield h
|
||||
|
||||
|
||||
def exec_hook(self, *triggers, **data):
|
||||
"""Trigger hooks that match the given trigger string
|
||||
def get_reverse_hooks(self, *triggers, exclude_first=False):
|
||||
"""Returns list of triggered hooks that are bellow or at the same level
|
||||
|
||||
Argument:
|
||||
trigger -- the trigger string
|
||||
triggers -- the trigger string
|
||||
|
||||
Keyword argument:
|
||||
data -- Data to pass to the hook as argument
|
||||
Keyword arguments:
|
||||
exclude_first -- start reporting hook at the next level
|
||||
"""
|
||||
|
||||
trigger = "_".join(triggers)
|
||||
|
||||
for key in self.hooks:
|
||||
if trigger.find(key) == 0:
|
||||
for hook in self.hooks[key]:
|
||||
hook.run(**data)
|
||||
h = self._access(*triggers)
|
||||
for k in h:
|
||||
if k == "__end__":
|
||||
if not exclude_first:
|
||||
for hk in h[k]:
|
||||
yield hk
|
||||
else:
|
||||
yield from self.get_reverse_hooks(*triggers + (k,))
|
||||
|
|
|
|||
115
nemubot/hooks/manager_test.py
Executable file
115
nemubot/hooks/manager_test.py
Executable file
|
|
@ -0,0 +1,115 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import unittest
|
||||
|
||||
from nemubot.hooks.manager import HooksManager
|
||||
|
||||
class TestHookManager(unittest.TestCase):
|
||||
|
||||
|
||||
def test_access(self):
|
||||
hm = HooksManager()
|
||||
|
||||
h1 = "HOOK1"
|
||||
h2 = "HOOK2"
|
||||
h3 = "HOOK3"
|
||||
|
||||
hm.add_hook(h1)
|
||||
hm.add_hook(h2, "pre")
|
||||
hm.add_hook(h3, "pre", "Text")
|
||||
hm.add_hook(h2, "post", "Text")
|
||||
|
||||
self.assertIn("__end__", hm._access())
|
||||
self.assertIn("__end__", hm._access("pre"))
|
||||
self.assertIn("__end__", hm._access("pre", "Text"))
|
||||
self.assertIn("__end__", hm._access("post", "Text"))
|
||||
|
||||
self.assertFalse(hm._access("inexistant")["__end__"])
|
||||
self.assertTrue(hm._access()["__end__"])
|
||||
self.assertTrue(hm._access("pre")["__end__"])
|
||||
self.assertTrue(hm._access("pre", "Text")["__end__"])
|
||||
self.assertTrue(hm._access("post", "Text")["__end__"])
|
||||
|
||||
|
||||
def test_search(self):
|
||||
hm = HooksManager()
|
||||
|
||||
h1 = "HOOK1"
|
||||
h2 = "HOOK2"
|
||||
h3 = "HOOK3"
|
||||
h4 = "HOOK4"
|
||||
|
||||
hm.add_hook(h1)
|
||||
hm.add_hook(h2, "pre")
|
||||
hm.add_hook(h3, "pre", "Text")
|
||||
hm.add_hook(h2, "post", "Text")
|
||||
|
||||
self.assertTrue([h for h in hm._search(h1)])
|
||||
self.assertFalse([h for h in hm._search(h4)])
|
||||
self.assertEqual(2, len([h for h in hm._search(h2)]))
|
||||
self.assertEqual([("pre", "Text")], [h for h in hm._search(h3)])
|
||||
|
||||
|
||||
def test_delete(self):
|
||||
hm = HooksManager()
|
||||
|
||||
h1 = "HOOK1"
|
||||
h2 = "HOOK2"
|
||||
h3 = "HOOK3"
|
||||
h4 = "HOOK4"
|
||||
|
||||
hm.add_hook(h1)
|
||||
hm.add_hook(h2, "pre")
|
||||
hm.add_hook(h3, "pre", "Text")
|
||||
hm.add_hook(h2, "post", "Text")
|
||||
|
||||
hm.del_hooks(hook=h4)
|
||||
|
||||
self.assertTrue(hm._access("pre")["__end__"])
|
||||
self.assertTrue(hm._access("pre", "Text")["__end__"])
|
||||
hm.del_hooks("pre")
|
||||
self.assertFalse(hm._access("pre")["__end__"])
|
||||
|
||||
self.assertTrue(hm._access("post", "Text")["__end__"])
|
||||
hm.del_hooks("post", "Text", hook=h2)
|
||||
self.assertFalse(hm._access("post", "Text")["__end__"])
|
||||
|
||||
self.assertTrue(hm._access()["__end__"])
|
||||
hm.del_hooks(hook=h1)
|
||||
self.assertFalse(hm._access()["__end__"])
|
||||
|
||||
|
||||
def test_get(self):
|
||||
hm = HooksManager()
|
||||
|
||||
h1 = "HOOK1"
|
||||
h2 = "HOOK2"
|
||||
h3 = "HOOK3"
|
||||
|
||||
hm.add_hook(h1)
|
||||
hm.add_hook(h2, "pre")
|
||||
hm.add_hook(h3, "pre", "Text")
|
||||
hm.add_hook(h2, "post", "Text")
|
||||
|
||||
self.assertEqual([h1, h2], [h for h in hm.get_hooks("pre")])
|
||||
self.assertEqual([h1, h2, h3], [h for h in hm.get_hooks("pre", "Text")])
|
||||
|
||||
|
||||
def test_get_rev(self):
|
||||
hm = HooksManager()
|
||||
|
||||
h1 = "HOOK1"
|
||||
h2 = "HOOK2"
|
||||
h3 = "HOOK3"
|
||||
|
||||
hm.add_hook(h1)
|
||||
hm.add_hook(h2, "pre")
|
||||
hm.add_hook(h3, "pre", "Text")
|
||||
hm.add_hook(h2, "post", "Text")
|
||||
|
||||
self.assertEqual([h2, h3], [h for h in hm.get_reverse_hooks("pre")])
|
||||
self.assertEqual([h3], [h for h in hm.get_reverse_hooks("pre", exclude_first=True)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue