From 342bb9acdc88c779e44adf2e9f4a2b4e675d134c Mon Sep 17 00:00:00 2001 From: Pierre-Olivier Mercier Date: Sat, 10 Feb 2018 09:51:51 +0100 Subject: [PATCH] Refactor in treatment analysis --- modules/alias.py | 3 +- nemubot/hooks/abstract.py | 12 ++++++- nemubot/treatment.py | 66 +++++++++++++-------------------------- 3 files changed, 34 insertions(+), 47 deletions(-) diff --git a/modules/alias.py b/modules/alias.py index a246d2c..c432a85 100644 --- a/modules/alias.py +++ b/modules/alias.py @@ -272,7 +272,6 @@ def treat_alias(msg): # Avoid infinite recursion if not isinstance(rpl_msg, Command) or msg.cmd != rpl_msg.cmd: - # Also return origin message, if it can be treated as well - return [msg, rpl_msg] + return rpl_msg return msg diff --git a/nemubot/hooks/abstract.py b/nemubot/hooks/abstract.py index eac4b20..ffe79fb 100644 --- a/nemubot/hooks/abstract.py +++ b/nemubot/hooks/abstract.py @@ -14,6 +14,8 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import types + def call_game(call, *args, **kargs): """With given args, try to determine the right call to make @@ -119,10 +121,18 @@ class Abstract: try: if self.check(data1): ret = call_game(self.call, data1, self.data, *args) + if isinstance(ret, types.GeneratorType): + for r in ret: + yield r + ret = None except IMException as e: ret = e.fill_response(data1) finally: if self.times == 0: self.call_end(ret) - return ret + if isinstance(ret, list): + for r in ret: + yield ret + elif ret is not None: + yield ret diff --git a/nemubot/treatment.py b/nemubot/treatment.py index 4f629e0..ed7cacb 100644 --- a/nemubot/treatment.py +++ b/nemubot/treatment.py @@ -15,7 +15,6 @@ # along with this program. If not, see . import logging -import types logger = logging.getLogger("nemubot.treatment") @@ -79,19 +78,12 @@ class MessageTreater: for h in self.hm.get_hooks("pre", type(msg).__name__): if h.can_read(msg.to, msg.server) and h.match(msg): - res = h.run(msg) + for res in flatify(h.run(msg)): + if res is not None and res != msg: + yield from self._pre_treat(res) - if isinstance(res, list): - for i in range(len(res)): - # Avoid infinite loop - if res[i] != msg: - yield from self._pre_treat(res[i]) - - elif res is not None and res != msg: - yield from self._pre_treat(res) - - elif res is None or res is False: - break + elif res is None or res is False: + break else: yield msg @@ -113,25 +105,10 @@ class MessageTreater: msg.frm_owner = (not hasattr(msg.server, "owner") or msg.server.owner == msg.frm) while hook is not None: - res = hook.run(msg) - - if isinstance(res, list): - for r in res: - yield r - - elif res is not None: - if isinstance(res, types.GeneratorType): - for r in res: - if not hasattr(r, "server") or r.server is None: - r.server = msg.server - - yield r - - else: - if not hasattr(res, "server") or res.server is None: - res.server = msg.server - - yield res + for res in flatify(hook.run(msg)): + if not hasattr(res, "server") or res.server is None: + res.server = msg.server + yield res hook = next(hook_gen, None) @@ -165,19 +142,20 @@ class MessageTreater: for h in self.hm.get_hooks("post", type(msg).__name__): if h.can_write(msg.to, msg.server) and h.match(msg): - res = h.run(msg) + for res in flatify(h.run(msg)): + if res is not None and res != msg: + yield from self._post_treat(res) - if isinstance(res, list): - for i in range(len(res)): - # Avoid infinite loop - if res[i] != msg: - yield from self._post_treat(res[i]) - - elif res is not None and res != msg: - yield from self._post_treat(res) - - elif res is None or res is False: - break + elif res is None or res is False: + break else: yield msg + + +def flatify(g): + if hasattr(g, "__iter__"): + for i in g: + yield from flatify(i) + else: + yield g