diff --git a/nemubot/__init__.py b/nemubot/__init__.py index c4e1df9..cdc6265 100644 --- a/nemubot/__init__.py +++ b/nemubot/__init__.py @@ -17,12 +17,15 @@ __version__ = '4.0.dev3' __author__ = 'nemunaire' +from typing import Optional + from nemubot.modulecontext import ModuleContext context = ModuleContext(None, None) -def requires_version(min=None, max=None): +def requires_version(min: Optional[int] = None, + max: Optional[int] = None) -> None: """Raise ImportError if the current version is not in the given range Keyword arguments: @@ -39,7 +42,7 @@ def requires_version(min=None, max=None): "but this is nemubot v%s." % (str(max), __version__)) -def attach(pid, socketfile): +def attach(pid: int, socketfile: str) -> int: import socket import sys @@ -98,7 +101,7 @@ def attach(pid, socketfile): return 0 -def daemonize(): +def daemonize() -> None: """Detach the running process to run as a daemon """ diff --git a/nemubot/__main__.py b/nemubot/__main__.py index 5a236f4..cb9fae6 100644 --- a/nemubot/__main__.py +++ b/nemubot/__main__.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -def main(): +def main() -> None: import os import signal import sys @@ -152,6 +152,7 @@ def main(): # Signals handling def sigtermhandler(signum, frame): """On SIGTERM and SIGINT, quit nicely""" + sigusr1handler(signum, frame) context.quit() signal.signal(signal.SIGINT, sigtermhandler) signal.signal(signal.SIGTERM, sigtermhandler) @@ -170,17 +171,23 @@ def main(): def sigusr1handler(signum, frame): """On SIGHUSR1, display stacktraces""" - import traceback + import threading, traceback for threadId, stack in sys._current_frames().items(): - logger.debug("########### Thread %d:\n%s", - threadId, + thName = "#%d" % threadId + for th in threading.enumerate(): + if th.ident == threadId: + thName = th.name + break + logger.debug("########### Thread %s:\n%s", + thName, "".join(traceback.format_stack(stack))) signal.signal(signal.SIGUSR1, sigusr1handler) if args.socketfile: - from nemubot.server.socket import SocketListener - context.add_server(SocketListener(context.add_server, "master_socket", - sock_location=args.socketfile)) + from nemubot.server.socket import UnixSocketListener + context.add_server(UnixSocketListener(new_server_cb=context.add_server, + location=args.socketfile, + name="master_socket")) # context can change when performing an hotswap, always join the latest context oldcontext = None diff --git a/nemubot/bot.py b/nemubot/bot.py index b7c71b9..10bcff7 100644 --- a/nemubot/bot.py +++ b/nemubot/bot.py @@ -15,9 +15,13 @@ # along with this program. If not, see . from datetime import datetime, timezone +import ipaddress import logging +from multiprocessing import JoinableQueue import threading +import select import sys +from typing import Any, Mapping, Optional, Sequence from nemubot import __version__ from nemubot.consumer import Consumer, EventConsumer, MessageConsumer @@ -26,13 +30,23 @@ import nemubot.hooks logger = logging.getLogger("nemubot") +sync_queue = JoinableQueue() + +def sync_act(*args): + if isinstance(act, bytes): + act = act.decode() + sync_queue.put(act) + class Bot(threading.Thread): """Class containing the bot context and ensuring key goals""" - def __init__(self, ip="127.0.0.1", modules_paths=list(), - data_store=datastore.Abstract(), verbosity=0): + def __init__(self, + ip: Optional[ipaddress] = None, + modules_paths: Sequence[str] = list(), + data_store: Optional[datastore.Abstract] = None, + verbosity: int = 0): """Initialize the bot context Keyword arguments: @@ -42,7 +56,7 @@ class Bot(threading.Thread): verbosity -- verbosity level """ - threading.Thread.__init__(self) + super().__init__(name="Nemubot main") logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)", __version__, @@ -52,16 +66,16 @@ class Bot(threading.Thread): self.stop = None # External IP for accessing this bot - import ipaddress - self.ip = ipaddress.ip_address(ip) + self.ip = ip if ip is not None else ipaddress.ip_address("127.0.0.1") # Context paths self.modules_paths = modules_paths - self.datastore = data_store + self.datastore = data_store if data_store is not None else datastore.Abstract() self.datastore.open() # Keep global context: servers and modules - self.servers = dict() + self._poll = select.poll() + self.servers = dict() # types: Mapping[str, AbstractServer] self.modules = dict() self.modules_configuration = dict() @@ -138,60 +152,76 @@ class Bot(threading.Thread): self.cnsr_queue = Queue() self.cnsr_thrd = list() self.cnsr_thrd_size = -1 - # Synchrone actions to be treated by main thread - self.sync_queue = Queue() def run(self): - from select import select - from nemubot.server import _lock, _rlist, _wlist, _xlist + self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI) logger.info("Starting main loop") self.stop = False while not self.stop: - with _lock: - try: - rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1) - except: - logger.error("Something went wrong in select") - fnd_smth = False - # Looking for invalid server - for r in _rlist: - if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0: - _rlist.remove(r) - logger.error("Found invalid object in _rlist: " + str(r)) - fnd_smth = True - for w in _wlist: - if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0: - _wlist.remove(w) - logger.error("Found invalid object in _wlist: " + str(w)) - fnd_smth = True - for x in _xlist: - if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0: - _xlist.remove(x) - logger.error("Found invalid object in _xlist: " + str(x)) - fnd_smth = True - if not fnd_smth: - logger.exception("Can't continue, sorry") - self.quit() - continue + for fd, flag in self._poll.poll(): + print("poll") + # Handle internal socket passing orders + if fd != sync_queue._reader.fileno(): + srv = self.servers[fd] - for x in xl: - try: - x.exception() - except: - logger.exception("Uncatched exception on server exception") - for w in wl: - try: - w.write_select() - except: - logger.exception("Uncatched exception on server write") - for r in rl: - for i in r.read(): + if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL): try: - self.receive_message(r, i) + srv.exception(flag) except: - logger.exception("Uncatched exception on server read") + logger.exception("Uncatched exception on server exception") + + if srv.fileno() > 0: + if flag & (select.POLLOUT): + try: + srv.async_write() + except: + logger.exception("Uncatched exception on server write") + + if flag & (select.POLLIN | select.POLLPRI): + try: + for i in srv.async_read(): + self.receive_message(srv, i) + except: + logger.exception("Uncatched exception on server read") + + else: + del self.servers[fd] + + + # Always check the sync queue + while not sync_queue.empty(): + import shlex + args = shlex.split(sync_queue.get()) + action = args.pop(0) + + logger.info("action: %s: %s", action, args) + + if action == "sckt" and len(args) >= 2: + try: + if args[0] == "write": + self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR) + elif args[0] == "unwrite": + self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR) + + elif args[0] == "register": + self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR) + elif args[0] == "unregister": + self._poll.unregister(int(args[1])) + except: + logger.exception("Unhandled excpetion during action:") + + elif action == "exit": + self.quit() + + elif action == "loadconf": + for path in action.args: + logger.debug("Load configuration from %s", path) + self.load_file(path) + logger.info("Configurations successfully loaded") + + sync_queue.task_done() # Launch new consumer threads if necessary @@ -202,17 +232,6 @@ class Bot(threading.Thread): c = Consumer(self) self.cnsr_thrd.append(c) c.start() - - while self.sync_queue.qsize() > 0: - action = self.sync_queue.get_nowait() - if action[0] == "exit": - self.quit() - elif action[0] == "loadconf": - for path in action[1:]: - logger.debug("Load configuration from %s", path) - self.load_file(path) - logger.info("Configurations successfully loaded") - self.sync_queue.task_done() logger.info("Ending main loop") @@ -414,10 +433,11 @@ class Bot(threading.Thread): autoconnect -- connect after add? """ - if srv.fileno not in self.servers: - self.servers[srv.fileno] = srv + fileno = srv.fileno() + if fileno not in self.servers: + self.servers[fileno] = srv if autoconnect and not hasattr(self, "noautoconnect"): - srv.open() + srv.connect() return True else: @@ -439,10 +459,10 @@ class Bot(threading.Thread): __import__(name) - def add_module(self, module): + def add_module(self, mdl: Any): """Add a module to the context, if already exists, unload the old one before""" - module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__ + module_name = mdl.__spec__.name if hasattr(mdl, "__spec__") else mdl.__name__ if hasattr(self, "stop") and self.stop: logger.warn("The bot is stopped, can't register new modules") @@ -454,40 +474,40 @@ class Bot(threading.Thread): # Overwrite print built-in def prnt(*args): - if hasattr(module, "logger"): - module.logger.info(" ".join([str(s) for s in args])) + if hasattr(mdl, "logger"): + mdl.logger.info(" ".join([str(s) for s in args])) else: logger.info("[%s] %s", module_name, " ".join([str(s) for s in args])) - module.print = prnt + mdl.print = prnt # Create module context from nemubot.modulecontext import ModuleContext - module.__nemubot_context__ = ModuleContext(self, module) + mdl.__nemubot_context__ = ModuleContext(self, mdl) - if not hasattr(module, "logger"): - module.logger = logging.getLogger("nemubot.module." + module_name) + if not hasattr(mdl, "logger"): + mdl.logger = logging.getLogger("nemubot.module." + module_name) # Replace imported context by real one - for attr in module.__dict__: - if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext: - module.__dict__[attr] = module.__nemubot_context__ + for attr in mdl.__dict__: + if attr != "__nemubot_context__" and type(mdl.__dict__[attr]) == ModuleContext: + mdl.__dict__[attr] = mdl.__nemubot_context__ # Register decorated functions import nemubot.hooks for s, h in nemubot.hooks.hook.last_registered: - module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s) + mdl.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s) nemubot.hooks.hook.last_registered = [] # Launch the module - if hasattr(module, "load"): + if hasattr(mdl, "load"): try: - module.load(module.__nemubot_context__) + mdl.load(mdl.__nemubot_context__) except: - module.__nemubot_context__.unload() + mdl.__nemubot_context__.unload() raise # Save a reference to the module - self.modules[module_name] = module + self.modules[module_name] = mdl def unload_module(self, name): @@ -530,28 +550,28 @@ class Bot(threading.Thread): def quit(self): """Save and unload modules and disconnect servers""" - self.datastore.close() - if self.event_timer is not None: logger.info("Stop the event timer...") self.event_timer.cancel() + logger.info("Save and unload all modules...") + for mod in self.modules.items(): + self.unload_module(mod) + + logger.info("Close all servers connection...") + for k in self.servers: + self.servers[k].close() + logger.info("Stop consumers") k = self.cnsr_thrd for cnsr in k: cnsr.stop = True - logger.info("Save and unload all modules...") - k = list(self.modules.keys()) - for mod in k: - self.unload_module(mod) - - logger.info("Close all servers connection...") - k = list(self.servers.keys()) - for srv in k: - self.servers[srv].close() + self.datastore.close() self.stop = True + sync_queue.put("end") + sync_queue.join() # Treatment diff --git a/nemubot/channel.py b/nemubot/channel.py index a070131..c01ac90 100644 --- a/nemubot/channel.py +++ b/nemubot/channel.py @@ -15,13 +15,18 @@ # along with this program. If not, see . import logging +from typing import Optional +from nemubot.message import Abstract as AbstractMessage class Channel: """A chat room""" - def __init__(self, name, password=None, encoding=None): + def __init__(self, + name: str, + password: Optional[str] = None, + encoding: Optional[str] = None): """Initialize the channel Arguments: @@ -37,7 +42,8 @@ class Channel: self.topic = "" self.logger = logging.getLogger("nemubot.channel." + name) - def treat(self, cmd, msg): + + def treat(self, cmd: str, msg: AbstractMessage) -> None: """Treat a incoming IRC command Arguments: @@ -60,7 +66,8 @@ class Channel: elif cmd == "TOPIC": self.topic = self.text - def join(self, nick, level=0): + + def join(self, nick: str, level: int = 0) -> None: """Someone join the channel Argument: @@ -71,7 +78,8 @@ class Channel: self.logger.debug("%s join", nick) self.people[nick] = level - def chtopic(self, newtopic): + + def chtopic(self, newtopic: str) -> None: """Send command to change the topic Arguments: @@ -81,7 +89,8 @@ class Channel: self.srv.send_msg(self.name, newtopic, "TOPIC") self.topic = newtopic - def nick(self, oldnick, newnick): + + def nick(self, oldnick: str, newnick: str) -> None: """Someone change his nick Arguments: @@ -95,7 +104,8 @@ class Channel: del self.people[oldnick] self.people[newnick] = lvl - def part(self, nick): + + def part(self, nick: str) -> None: """Someone leave the channel Argument: @@ -106,7 +116,8 @@ class Channel: self.logger.debug("%s has left", nick) del self.people[nick] - def mode(self, msg): + + def mode(self, msg: AbstractMessage) -> None: """Channel or user mode change Argument: @@ -132,7 +143,8 @@ class Channel: elif msg.text[0] == "-v": self.people[msg.nick] &= ~1 - def parse332(self, msg): + + def parse332(self, msg: AbstractMessage) -> None: """Parse RPL_TOPIC message Argument: @@ -141,7 +153,8 @@ class Channel: self.topic = msg.text - def parse353(self, msg): + + def parse353(self, msg: AbstractMessage) -> None: """Parse RPL_ENDOFWHO message Argument: diff --git a/nemubot/config/__init__.py b/nemubot/config/__init__.py index 6bbc1b2..ea6fed4 100644 --- a/nemubot/config/__init__.py +++ b/nemubot/config/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -def get_boolean(s): +def get_boolean(s) -> bool: if isinstance(s, bool): return s else: diff --git a/nemubot/config/include.py b/nemubot/config/include.py index 408c09a..aca6468 100644 --- a/nemubot/config/include.py +++ b/nemubot/config/include.py @@ -16,5 +16,5 @@ class Include: - def __init__(self, path): + def __init__(self, path: str): self.path = path diff --git a/nemubot/config/module.py b/nemubot/config/module.py index 7586697..e67a45b 100644 --- a/nemubot/config/module.py +++ b/nemubot/config/module.py @@ -20,7 +20,11 @@ from nemubot.datastore.nodes.generic import GenericNode class Module(GenericNode): - def __init__(self, name, autoload=True, **kwargs): + def __init__(self, + name: str, + autoload: bool = True, + **kwargs): super().__init__(None, **kwargs) + self.name = name self.autoload = get_boolean(autoload) diff --git a/nemubot/config/nemubot.py b/nemubot/config/nemubot.py index 992cd8e..cc60f86 100644 --- a/nemubot/config/nemubot.py +++ b/nemubot/config/nemubot.py @@ -14,15 +14,23 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from nemubot.config.include import Include -from nemubot.config.module import Module -from nemubot.config.server import Server +from typing import Optional, Sequence, Union + +import nemubot.config.include +import nemubot.config.module +import nemubot.config.server class Nemubot: - def __init__(self, nick="nemubot", realname="nemubot", owner=None, - ip=None, ssl=False, caps=None, encoding="utf-8"): + def __init__(self, + nick: str = "nemubot", + realname: str = "nemubot", + owner: Optional[str] = None, + ip: Optional[str] = None, + ssl: bool = False, + caps: Optional[Sequence[str]] = None, + encoding: str = "utf-8"): self.nick = nick self.realname = realname self.owner = owner @@ -34,13 +42,13 @@ class Nemubot: self.includes = [] - def addChild(self, name, child): - if name == "module" and isinstance(child, Module): + def addChild(self, name: str, child: Union[nemubot.config.module.Module, nemubot.config.server.Server, nemubot.config.include.Include]): + if name == "module" and isinstance(child, nemubot.config.module.Module): self.modules.append(child) return True - elif name == "server" and isinstance(child, Server): + elif name == "server" and isinstance(child, nemubot.config.server.Server): self.servers.append(child) return True - elif name == "include" and isinstance(child, Include): + elif name == "include" and isinstance(child, nemubot.config.include.Include): self.includes.append(child) return True diff --git a/nemubot/config/server.py b/nemubot/config/server.py index 14ca9a8..b8df692 100644 --- a/nemubot/config/server.py +++ b/nemubot/config/server.py @@ -14,12 +14,19 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Optional, Sequence + from nemubot.channel import Channel +import nemubot.config.nemubot class Server: - def __init__(self, uri="irc://nemubot@localhost/", autoconnect=True, caps=None, **kwargs): + def __init__(self, + uri: str = "irc://nemubot@localhost/", + autoconnect: bool = True, + caps: Optional[Sequence[str]] = None, + **kwargs): self.uri = uri self.autoconnect = autoconnect self.caps = caps.split(" ") if caps is not None else [] @@ -27,7 +34,7 @@ class Server: self.channels = [] - def addChild(self, name, child): + def addChild(self, name: str, child: Channel): if name == "channel" and isinstance(child, Channel): self.channels.append(child) return True diff --git a/nemubot/consumer.py b/nemubot/consumer.py index 0cd4ed5..8ea5a40 100644 --- a/nemubot/consumer.py +++ b/nemubot/consumer.py @@ -17,6 +17,12 @@ import logging import queue import threading +from typing import List + +from nemubot.bot import Bot +from nemubot.event import ModuleEvent +from nemubot.message.abstract import Abstract as AbstractMessage +from nemubot.server.abstract import AbstractServer logger = logging.getLogger("nemubot.consumer") @@ -25,18 +31,15 @@ class MessageConsumer: """Store a message before treating""" - def __init__(self, srv, msg): + def __init__(self, srv: AbstractServer, msg: AbstractMessage): self.srv = srv self.orig = msg - def run(self, context): + def run(self, context: Bot) -> None: """Create, parse and treat the message""" - from nemubot.bot import Bot - assert isinstance(context, Bot) - - msgs = [] + msgs = [] # type: List[AbstractMessage] # Parse the message try: @@ -55,8 +58,6 @@ class MessageConsumer: if hasattr(msg, "frm_owner"): msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm) - from nemubot.server.abstract import AbstractServer - # Treat the message for msg in msgs: for res in context.treater.treat_msg(msg): @@ -87,12 +88,12 @@ class EventConsumer: """Store a event before treating""" - def __init__(self, evt, timeout=20): + def __init__(self, evt: ModuleEvent, timeout: int = 20): self.evt = evt self.timeout = timeout - def run(self, context): + def run(self, context: Bot) -> None: try: self.evt.check() except: @@ -113,13 +114,13 @@ class Consumer(threading.Thread): """Dequeue and exec requested action""" - def __init__(self, context): + def __init__(self, context: Bot): self.context = context self.stop = False super().__init__(name="Nemubot consumer") - def run(self): + def run(self) -> None: try: while not self.stop: stm = self.context.cnsr_queue.get(True, 1) diff --git a/nemubot/datastore/abstract.py b/nemubot/datastore/abstract.py index f54bbcd..856851f 100644 --- a/nemubot/datastore/abstract.py +++ b/nemubot/datastore/abstract.py @@ -25,11 +25,14 @@ class Abstract: return None - def open(self): - return - def close(self): - return + def open(self) -> bool: + return True + + + def close(self) -> bool: + return True + def load(self, module): """Load data for the given module @@ -43,6 +46,7 @@ class Abstract: return self.new() + def save(self, module, data): """Load data for the given module @@ -56,9 +60,11 @@ class Abstract: return True + def __enter__(self): self.open() return self + def __exit__(self, type, value, traceback): self.close() diff --git a/nemubot/datastore/nodes/basic.py b/nemubot/datastore/nodes/basic.py index 6fbd136..a4467b2 100644 --- a/nemubot/datastore/nodes/basic.py +++ b/nemubot/datastore/nodes/basic.py @@ -14,6 +14,9 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Any, Mapping, Sequence + +from nemubot.datastore.nodes.generic import ParsingNode from nemubot.datastore.nodes.serializable import Serializable @@ -25,24 +28,24 @@ class ListNode(Serializable): serializetag = "list" def __init__(self, **kwargs): - self.items = list() + self.items = list() # type: Sequence - def addChild(self, name, child): + def addChild(self, name: str, child) -> bool: self.items.append(child) return True - def parsedForm(self): + def parsedForm(self) -> Sequence: return self.items - def __len__(self): + def __len__(self) -> int: return len(self.items) - def __getitem__(self, item): + def __getitem__(self, item: int) -> Any: return self.items[item] - def __setitem__(self, item, v): + def __setitem__(self, item: int, v: Any) -> None: self.items[item] = v def __contains__(self, item): @@ -52,8 +55,7 @@ class ListNode(Serializable): return self.items.__repr__() - def serialize(self): - from nemubot.datastore.nodes.generic import ParsingNode + def serialize(self) -> ParsingNode: node = ParsingNode(tag=self.serializetag) for i in self.items: node.children.append(ParsingNode.serialize_node(i)) @@ -72,12 +74,12 @@ class DictNode(Serializable): self._cur = None - def startElement(self, name, attrs): + def startElement(self, name: str, attrs: Mapping[str, str]): if self._cur is None and "key" in attrs: self._cur = attrs["key"] return False - def addChild(self, name, child): + def addChild(self, name: str, child: Any): if self._cur is None: return False @@ -85,24 +87,24 @@ class DictNode(Serializable): self._cur = None return True - def parsedForm(self): + def parsedForm(self) -> Mapping: return self.items - def __getitem__(self, item): + def __getitem__(self, item: str) -> Any: return self.items[item] - def __setitem__(self, item, v): + def __setitem__(self, item: str, v: str) -> None: self.items[item] = v - def __contains__(self, item): + def __contains__(self, item: str) -> bool: return item in self.items - def __repr__(self): + def __repr__(self) -> str: return self.items.__repr__() - def serialize(self): + def serialize(self) -> ParsingNode: from nemubot.datastore.nodes.generic import ParsingNode node = ParsingNode(tag=self.serializetag) for k in self.items: diff --git a/nemubot/datastore/nodes/generic.py b/nemubot/datastore/nodes/generic.py index c9840bc..939019c 100644 --- a/nemubot/datastore/nodes/generic.py +++ b/nemubot/datastore/nodes/generic.py @@ -14,6 +14,8 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import Any, Optional, Mapping, Union + from nemubot.datastore.nodes.serializable import Serializable @@ -22,41 +24,44 @@ class ParsingNode: """Allow any kind of subtags, just keep parsed ones """ - def __init__(self, tag=None, **kwargs): + def __init__(self, + tag: Optional[str] = None, + **kwargs): self.tag = tag self.attrs = kwargs self.content = "" self.children = [] - def characters(self, content): + def characters(self, content: str) -> None: self.content += content - def addChild(self, name, child): + def addChild(self, name: str, child: Any) -> bool: self.children.append(child) return True - def hasNode(self, nodename): + def hasNode(self, nodename: str) -> bool: return self.getNode(nodename) is not None - def getNode(self, nodename): + def getNode(self, nodename: str) -> Optional[Any]: for c in self.children: if c is not None and c.tag == nodename: return c return None - def __getitem__(self, item): + def __getitem__(self, item: str) -> Any: return self.attrs[item] - def __contains__(self, item): + def __contains__(self, item: str) -> bool: return item in self.attrs - def serialize_node(node, **def_kwargs): + def serialize_node(node: Union[Serializable, str, int, float, list, dict], + **def_kwargs): """Serialize any node or basic data to a ParsingNode instance""" if isinstance(node, Serializable): @@ -102,13 +107,16 @@ class GenericNode(ParsingNode): """Consider all subtags as dictionnary """ - def __init__(self, tag, **kwargs): + def __init__(self, + tag: str, + **kwargs): super().__init__(tag, **kwargs) + self._cur = None self._deep_cur = 0 - def startElement(self, name, attrs): + def startElement(self, name: str, attrs: Mapping[str, str]): if self._cur is None: self._cur = GenericNode(name, **attrs) self._deep_cur = 0 @@ -118,14 +126,14 @@ class GenericNode(ParsingNode): return True - def characters(self, content): + def characters(self, content: str): if self._cur is None: super().characters(content) else: self._cur.characters(content) - def endElement(self, name): + def endElement(self, name: str): if name is None: return diff --git a/nemubot/datastore/nodes/python.py b/nemubot/datastore/nodes/python.py index 6e4278b..819bf21 100644 --- a/nemubot/datastore/nodes/python.py +++ b/nemubot/datastore/nodes/python.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from nemubot.datastore.nodes.generic import ParsingNode from nemubot.datastore.nodes.serializable import Serializable @@ -27,15 +28,15 @@ class PythonTypeNode(Serializable): self._cnt = "" - def characters(self, content): + def characters(self, content: str) -> None: self._cnt += content - def endElement(self, name): + def endElement(self, name: str) -> None: raise NotImplemented - def __repr__(self): + def __repr__(self) -> str: return self.value.__repr__() @@ -50,12 +51,11 @@ class IntNode(PythonTypeNode): serializetag = "int" - def endElement(self, name): + def endElement(self, name: str) -> bool: self.value = int(self._cnt) return True - def serialize(self): - from nemubot.datastore.nodes.generic import ParsingNode + def serialize(self) -> ParsingNode: node = ParsingNode(tag=self.serializetag) node.content = str(self.value) return node @@ -65,12 +65,11 @@ class FloatNode(PythonTypeNode): serializetag = "float" - def endElement(self, name): + def endElement(self, name: str) -> bool: self.value = float(self._cnt) return True - def serialize(self): - from nemubot.datastore.nodes.generic import ParsingNode + def serialize(self) -> ParsingNode: node = ParsingNode(tag=self.serializetag) node.content = str(self.value) return node @@ -85,7 +84,6 @@ class StringNode(PythonTypeNode): return True def serialize(self): - from nemubot.datastore.nodes.generic import ParsingNode node = ParsingNode(tag=self.serializetag) node.content = str(self.value) return node diff --git a/nemubot/datastore/xml.py b/nemubot/datastore/xml.py index 266c3ac..abf1492 100644 --- a/nemubot/datastore/xml.py +++ b/nemubot/datastore/xml.py @@ -17,6 +17,7 @@ import fcntl import logging import os +from typing import Any, Mapping import xml.parsers.expat from nemubot.datastore.abstract import Abstract @@ -28,7 +29,9 @@ class XML(Abstract): """A concrete implementation of a data store that relies on XML files""" - def __init__(self, basedir, rotate=True): + def __init__(self, + basedir: str, + rotate: bool = True): """Initialize the datastore Arguments: @@ -45,7 +48,7 @@ class XML(Abstract): "enabled" if self.rotate else "disabled") - def open(self): + def open(self) -> bool: """Lock the directory""" if not os.path.isdir(self.basedir): @@ -75,7 +78,7 @@ class XML(Abstract): return True - def close(self): + def close(self) -> bool: """Release a locked path""" if hasattr(self, "lock_file"): @@ -91,19 +94,19 @@ class XML(Abstract): return False - def _get_data_file_path(self, module): + def _get_data_file_path(self, module: str) -> str: """Get the path to the module data file""" return os.path.join(self.basedir, module + ".xml") - def _get_lock_file_path(self): + def _get_lock_file_path(self) -> str: """Get the path to the datastore lock file""" return os.path.join(self.basedir, ".used_by_nemubot") - def load(self, module, extendsTags={}): + def load(self, module: str, extendsTags: Mapping[str, Any] = {}) -> Abstract: """Load data for the given module Argument: @@ -116,7 +119,7 @@ class XML(Abstract): data_file = self._get_data_file_path(module) - def parse(path): + def parse(path: str): from nemubot.tools.xmlparser import XMLParser from nemubot.datastore.nodes import basic as basicNodes from nemubot.datastore.nodes import python as pythonNodes @@ -156,7 +159,7 @@ class XML(Abstract): return Abstract.load(self, module) - def _rotate(self, path): + def _rotate(self, path: str) -> None: """Backup given path Argument: @@ -173,7 +176,7 @@ class XML(Abstract): os.rename(src, dst) - def _save_node(self, gen, node): + def _save_node(self, gen, node: Any): from nemubot.datastore.nodes.generic import ParsingNode # First, get the serialized form of the node @@ -191,7 +194,7 @@ class XML(Abstract): gen.endElement(node.tag) - def save(self, module, data): + def save(self, module: str, data: Any) -> bool: """Load data for the given module Argument: diff --git a/nemubot/event/__init__.py b/nemubot/event/__init__.py index 7b2adfd..ab96efb 100644 --- a/nemubot/event/__init__.py +++ b/nemubot/event/__init__.py @@ -15,15 +15,23 @@ # along with this program. If not, see . from datetime import datetime, timedelta, timezone +from typing import Any, Callable, Optional class ModuleEvent: """Representation of a event initiated by a bot module""" - def __init__(self, call=None, call_data=None, func=None, func_data=None, - cmp=None, cmp_data=None, interval=60, offset=0, times=1): - + def __init__(self, + call: Callable = None, + call_data: Callable = None, + func: Callable = None, + func_data: Any = None, + cmp: Any = None, + cmp_data: Any = None, + interval: int = 60, + offset: int = 0, + times: int = 1): """Initialize the event Keyword arguments: @@ -70,8 +78,9 @@ class ModuleEvent: # How many times do this event? self.times = times + @property - def current(self): + def current(self) -> Optional[datetime.datetime]: """Return the date of the near check""" if self.times != 0: if self._end is None: @@ -79,8 +88,9 @@ class ModuleEvent: return self._end return None + @property - def next(self): + def next(self) -> Optional[datetime.datetime]: """Return the date of the next check""" if self.times != 0: if self._end is None: @@ -90,14 +100,16 @@ class ModuleEvent: return self._end return None + @property - def time_left(self): + def time_left(self) -> Union[datetime.datetime, int]: """Return the time left before/after the near check""" if self.current is not None: return self.current - datetime.now(timezone.utc) return 99999 # TODO: 99999 is not a valid time to return - def check(self): + + def check(self) -> None: """Run a check and realized the event if this is time""" # Get initial data diff --git a/nemubot/hooks/manager.py b/nemubot/hooks/manager.py index 6a57d2a..9d57483 100644 --- a/nemubot/hooks/manager.py +++ b/nemubot/hooks/manager.py @@ -21,7 +21,7 @@ class HooksManager: """Class to manage hooks""" - def __init__(self, name="core"): + def __init__(self, name: str = "core"): """Initialize the manager""" self.hooks = dict() diff --git a/nemubot/importer.py b/nemubot/importer.py index eaf1535..2827da9 100644 --- a/nemubot/importer.py +++ b/nemubot/importer.py @@ -18,16 +18,18 @@ from importlib.abc import Finder from importlib.machinery import SourceFileLoader import logging import os +from typing import Callable logger = logging.getLogger("nemubot.importer") class ModuleFinder(Finder): - def __init__(self, modules_paths, add_module): + def __init__(self, modules_paths: str, add_module: Callable[]): self.modules_paths = modules_paths self.add_module = add_module + def find_module(self, fullname, path=None): # Search only for new nemubot modules (packages init) if path is None: diff --git a/nemubot/server/DCC.py b/nemubot/server/DCC.py index 644a8cb..c1a6852 100644 --- a/nemubot/server/DCC.py +++ b/nemubot/server/DCC.py @@ -31,7 +31,7 @@ PORTS = list() class DCC(server.AbstractServer): def __init__(self, srv, dest, socket=None): - super().__init__(self) + super().__init__(name="Nemubot DCC server") self.error = False # An error has occur, closing the connection? self.messages = list() # Message queued before connexion diff --git a/nemubot/server/IRC.py b/nemubot/server/IRC.py index 08e2bc5..d09966a 100644 --- a/nemubot/server/IRC.py +++ b/nemubot/server/IRC.py @@ -27,10 +27,10 @@ class IRC(SocketServer): """Concrete implementation of a connexion to an IRC server""" - def __init__(self, host="localhost", port=6667, ssl=False, owner=None, + def __init__(self, host="localhost", port=6667, owner=None, nick="nemubot", username=None, password=None, realname="Nemubot", encoding="utf-8", caps=None, - channels=list(), on_connect=None): + channels=list(), on_connect=None, **kwargs): """Prepare a connection with an IRC server Keyword arguments: @@ -54,7 +54,8 @@ class IRC(SocketServer): self.owner = owner self.realname = realname - super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port)) + super().__init__(name=self.username + "@" + host + ":" + str(port), + host=host, port=port, **kwargs) self.printer = IRCPrinter self.encoding = encoding @@ -231,20 +232,19 @@ class IRC(SocketServer): # Open/close - def open(self): - if super().open(): - if self.password is not None: - self.write("PASS :" + self.password) - if self.capabilities is not None: - self.write("CAP LS") - self.write("NICK :" + self.nick) - self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) - return True - return False + def connect(self): + super().connect() + + if self.password is not None: + self.write("PASS :" + self.password) + if self.capabilities is not None: + self.write("CAP LS") + self.write("NICK :" + self.nick) + self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) def close(self): - if not self.closed: + if not self._closed: self.write("QUIT") return super().close() @@ -253,8 +253,8 @@ class IRC(SocketServer): # Read - def read(self): - for line in super().read(): + def async_read(self): + for line in super().async_read(): # PING should be handled here, so start parsing here :/ msg = IRCMessage(line, self.encoding) diff --git a/nemubot/server/__init__.py b/nemubot/server/__init__.py index 3c88138..464c924 100644 --- a/nemubot/server/__init__.py +++ b/nemubot/server/__init__.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,20 +14,13 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import threading -_lock = threading.Lock() - -# Lists for select -_rlist = [] -_wlist = [] -_xlist = [] - - -def factory(uri, **init_args): +def factory(uri, ssl=False, **init_args): from urllib.parse import urlparse, unquote o = urlparse(uri) + srv = None + if o.scheme == "irc" or o.scheme == "ircs": # http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt # http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html @@ -36,7 +29,7 @@ def factory(uri, **init_args): modifiers = o.path.split(",") target = unquote(modifiers.pop(0)[1:]) - if o.scheme == "ircs": args["ssl"] = True + if o.scheme == "ircs": ssl = True if o.hostname is not None: args["host"] = o.hostname if o.port is not None: args["port"] = o.port if o.username is not None: args["username"] = o.username @@ -65,6 +58,11 @@ def factory(uri, **init_args): args["channels"] = [ target ] from nemubot.server.IRC import IRC as IRCServer - return IRCServer(**args) - else: - return None + srv = IRCServer(**args) + + if ssl: + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + return ctx.wrap_socket(srv) + + return srv diff --git a/nemubot/server/abstract.py b/nemubot/server/abstract.py index dc2081d..7e31cda 100644 --- a/nemubot/server/abstract.py +++ b/nemubot/server/abstract.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,34 +14,34 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import io import logging import queue -from nemubot.server import _lock, _rlist, _wlist, _xlist +from nemubot.bot import sync_act -# Extends from IOBase in order to be compatible with select function -class AbstractServer(io.IOBase): + +class AbstractServer: """An abstract server: handle communication with an IM server""" - def __init__(self, name=None, send_callback=None): + def __init__(self, name=None, **kwargs): """Initialize an abstract server Keyword argument: - send_callback -- Callback when developper want to send a message + name -- Identifier of the socket, for convinience """ self._name = name - super().__init__() + super().__init__(**kwargs) - self.logger = logging.getLogger("nemubot.server." + self.name) + self.logger = logging.getLogger("nemubot.server." + str(self.name)) + self._readbuffer = b'' self._sending_queue = queue.Queue() - if send_callback is not None: - self._send_callback = send_callback - else: - self._send_callback = self._write_select + + + def __del__(self): + print("Server deleted") @property @@ -54,40 +54,28 @@ class AbstractServer(io.IOBase): # Open/close - def __enter__(self): - self.open() - return self + def connect(self, *args, **kwargs): + """Register the server in _poll""" + + self.logger.info("Opening connection") + + super().connect(*args, **kwargs) + + self._connected() + + def _connected(self): + sync_act("sckt register %d" % self.fileno()) - def __exit__(self, type, value, traceback): - self.close() + def close(self, *args, **kwargs): + """Unregister the server from _poll""" + self.logger.info("Closing connection") - def open(self): - """Generic open function that register the server un _rlist in case - of successful _open""" - self.logger.info("Opening connection to %s", self.id) - if not hasattr(self, "_open") or self._open(): - _rlist.append(self) - _xlist.append(self) - return True - return False + if self.fileno() > 0: + sync_act("sckt unregister %d" % self.fileno()) - - def close(self): - """Generic close function that register the server un _{r,w,x}list in - case of successful _close""" - self.logger.info("Closing connection to %s", self.id) - with _lock: - if not hasattr(self, "_close") or self._close(): - if self in _rlist: - _rlist.remove(self) - if self in _wlist: - _wlist.remove(self) - if self in _xlist: - _xlist.remove(self) - return True - return False + super().close(*args, **kwargs) # Writes @@ -99,13 +87,16 @@ class AbstractServer(io.IOBase): message -- message to send """ - self._send_callback(message) + self._sending_queue.put(self.format(message)) + self.logger.debug("Message '%s' appended to write queue", message) + sync_act("sckt write %d" % self.fileno()) - def write_select(self): - """Internal function used by the select function""" + def async_write(self): + """Internal function used when the file descriptor is writable""" + try: - _wlist.remove(self) + sync_act("sckt unwrite %d" % self.fileno()) while not self._sending_queue.empty(): self._write(self._sending_queue.get_nowait()) self._sending_queue.task_done() @@ -114,19 +105,6 @@ class AbstractServer(io.IOBase): pass - def _write_select(self, message): - """Send a message to the server safely through select - - Argument: - message -- message to send - """ - - self._sending_queue.put(self.format(message)) - self.logger.debug("Message '%s' appended to write queue", message) - if self not in _wlist: - _wlist.append(self) - - def send_response(self, response): """Send a formated Message class @@ -149,13 +127,39 @@ class AbstractServer(io.IOBase): # Read + def async_read(self): + """Internal function used when the file descriptor is readable + + Returns: + A list of fully received messages + """ + + ret, self._readbuffer = self.lex(self._readbuffer + self.read()) + + for r in ret: + yield r + + + def lex(self, buf): + """Assume lexing in default case is per line + + Argument: + buf -- buffer to lex + """ + + msgs = buf.split(b'\r\n') + partial = msgs.pop() + + return msgs, partial + + def parse(self, msg): raise NotImplemented # Exceptions - def exception(self): - """Exception occurs in fd""" - self.logger.warning("Unhandle file descriptor exception on server %s", - self.name) + def exception(self, flags): + """Exception occurs on fd""" + + self.close() diff --git a/nemubot/server/socket.py b/nemubot/server/socket.py index 13ac9bd..aeb20e5 100644 --- a/nemubot/server/socket.py +++ b/nemubot/server/socket.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,117 +14,35 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import os +import socket + import nemubot.message as message from nemubot.message.printer.socket import Socket as SocketPrinter from nemubot.server.abstract import AbstractServer -class SocketServer(AbstractServer): +class _Socket(AbstractServer, socket.socket): - """Concrete implementation of a socket connexion (can be wrapped with TLS)""" + """Concrete implementation of a socket connection""" - def __init__(self, sock_location=None, - host=None, port=None, - sock=None, - ssl=False, - name=None): + def __init__(self, **kwargs): """Create a server socket Keyword arguments: - sock_location -- Path to the UNIX socket - host -- Hostname of the INET socket - port -- Port of the INET socket - sock -- Already connected socket ssl -- Should TLS connection enabled - name -- Convinience name """ - import socket - - assert(sock is None or isinstance(sock, socket.SocketType)) - assert(port is None or isinstance(port, int)) - - super().__init__(name=name) - - if sock is None: - if sock_location is not None: - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.connect_to = sock_location - elif host is not None: - for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): - self.socket = socket.socket(af, socktype, proto) - self.connect_to = sa - break - else: - self.socket = sock - - self.ssl = ssl + super().__init__(**kwargs) self.readbuffer = b'' self.printer = SocketPrinter - def fileno(self): - return self.socket.fileno() if self.socket else None - - - @property - def closed(self): - """Indicator of the connection aliveness""" - return self.socket._closed - - - # Open/close - - def open(self): - if not self.closed: - return True - - try: - self.socket.connect(self.connect_to) - self.logger.info("Connected to %s", self.connect_to) - except: - self.socket.close() - self.logger.exception("Unable to connect to %s", - self.connect_to) - return False - - # Wrap the socket for SSL - if self.ssl: - import ssl - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - self.socket = ctx.wrap_socket(self.socket) - - return super().open() - - - def close(self): - import socket - - # Flush the sending queue before close - from nemubot.server import _lock - _lock.release() - self._sending_queue.join() - _lock.acquire() - - if not self.closed: - try: - self.socket.shutdown(socket.SHUT_RDWR) - except socket.error: - pass - - self.socket.close() - - return super().close() - - # Write def _write(self, cnt): - if self.closed: - return - - self.socket.sendall(cnt) + self.sendall(cnt) def format(self, txt): @@ -136,19 +54,12 @@ class SocketServer(AbstractServer): # Read - def read(self): - if self.closed: - return [] - - raw = self.socket.recv(1024) - temp = (self.readbuffer + raw).split(b'\r\n') - self.readbuffer = temp.pop() - - for line in temp: - yield line + def read(self, n=1024): + return self.recv(n) def parse(self, line): + """Implement a default behaviour for socket""" import shlex line = line.strip().decode() @@ -157,48 +68,84 @@ class SocketServer(AbstractServer): except ValueError: args = line.split(' ') - yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you") + if len(args): + yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you") -class SocketListener(AbstractServer): +class SocketServer(_Socket): - def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None): - super().__init__(name=name) - self.new_server_cb = new_server_cb - self.sock_location = sock_location - self.host = host - self.port = port - self.ssl = ssl - self.nb_son = 0 + def __init__(self, host, port, bind=None, **kwargs): + super().__init__(family=socket.AF_INET, **kwargs) + + assert(host is not None) + assert(isinstance(port, int)) + + self._host = host + self._port = port + self._bind = bind - def fileno(self): - return self.socket.fileno() if self.socket else None + def connect(self): + self.logger.info("Connection to %s:%d", self._host, self._port) + super().connect((self._host, self._port)) + + if self._bind: + super().bind(self._bind) - @property - def closed(self): - """Indicator of the connection aliveness""" - return self.socket is None +class UnixSocket(_Socket): + + def __init__(self, location, **kwargs): + super().__init__(family=socket.AF_UNIX, **kwargs) + + self._socket_path = location - def open(self): - import os - import socket + def connect(self): + self.logger.info("Connection to unix://%s", self._socket_path) + super().connect(self._socket_path) - if self.sock_location is not None: - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - try: - os.remove(self.sock_location) - except FileNotFoundError: - pass - self.socket.bind(self.sock_location) - elif self.host is not None and self.port is not None: - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.bind((self.host, self.port)) - self.socket.listen(5) - return super().open() +class _Listener(_Socket): + + def __init__(self, new_server_cb, instanciate=_Socket, **kwargs): + super().__init__(**kwargs) + + self._instanciate = instanciate + self._new_server_cb = new_server_cb + + + def read(self): + conn, addr = self.accept() + self.logger.info("Accept new connection from %s", addr) + + fileno = conn.fileno() + ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach()) + ss.connect = ss._connected + self._new_server_cb(ss, autoconnect=True) + + return b'' + + +class UnixSocketListener(_Listener, UnixSocket): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + + def connect(self): + self.logger.info("Creating Unix socket at unix://%s", self._socket_path) + + try: + os.remove(self._socket_path) + except FileNotFoundError: + pass + + self.bind(self._socket_path) + self.listen(5) + self.logger.info("Socket ready for accepting new connections") + + self._connected() def close(self): @@ -206,25 +153,14 @@ class SocketListener(AbstractServer): import socket try: - self.socket.shutdown(socket.SHUT_RDWR) - self.socket.close() - if self.sock_location is not None: - os.remove(self.sock_location) + self.shutdown(socket.SHUT_RDWR) except socket.error: pass - return super().close() + super().close() - - # Read - - def read(self): - if self.closed: - return [] - - conn, addr = self.socket.accept() - self.nb_son += 1 - ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn) - self.new_server_cb(ss) - - return [] + try: + if self._socket_path is not None: + os.remove(self._socket_path) + except: + pass diff --git a/nemubot/tools/human.py b/nemubot/tools/human.py index a18cde2..f0e947f 100644 --- a/nemubot/tools/human.py +++ b/nemubot/tools/human.py @@ -57,11 +57,15 @@ def word_distance(str1, str2): return d[len(str1)][len(str2)] -def guess(pattern, expect): +def guess(pattern, expect, max_depth=0): + if max_depth == 0: + max_depth = 1 + len(pattern) / 4 + elif max_depth <= -1: + max_depth = len(pattern) - max_depth if len(expect): se = sorted([(e, word_distance(pattern, e)) for e in expect], key=lambda x: x[1]) _, m = se[0] for e, wd in se: - if wd > m or wd > 1 + len(pattern) / 4: + if wd > m or wd > max_depth: break yield e