1
0
Fork 0

Add type for use with mypy

This commit is contained in:
nemunaire 2016-06-17 19:26:29 +02:00
parent a8706d6213
commit 39936e9d39
24 changed files with 482 additions and 446 deletions

View File

@ -17,12 +17,15 @@
__version__ = '4.0.dev3'
__author__ = 'nemunaire'
from typing import Optional
from nemubot.modulecontext import ModuleContext
context = ModuleContext(None, None)
def requires_version(min=None, max=None):
def requires_version(min: Optional[int] = None,
max: Optional[int] = None) -> None:
"""Raise ImportError if the current version is not in the given range
Keyword arguments:
@ -39,7 +42,7 @@ def requires_version(min=None, max=None):
"but this is nemubot v%s." % (str(max), __version__))
def attach(pid, socketfile):
def attach(pid: int, socketfile: str) -> int:
import socket
import sys
@ -98,7 +101,7 @@ def attach(pid, socketfile):
return 0
def daemonize():
def daemonize() -> None:
"""Detach the running process to run as a daemon
"""

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
def main():
def main() -> None:
import os
import signal
import sys
@ -152,6 +152,7 @@ def main():
# Signals handling
def sigtermhandler(signum, frame):
"""On SIGTERM and SIGINT, quit nicely"""
sigusr1handler(signum, frame)
context.quit()
signal.signal(signal.SIGINT, sigtermhandler)
signal.signal(signal.SIGTERM, sigtermhandler)
@ -170,17 +171,23 @@ def main():
def sigusr1handler(signum, frame):
"""On SIGHUSR1, display stacktraces"""
import traceback
import threading, traceback
for threadId, stack in sys._current_frames().items():
logger.debug("########### Thread %d:\n%s",
threadId,
thName = "#%d" % threadId
for th in threading.enumerate():
if th.ident == threadId:
thName = th.name
break
logger.debug("########### Thread %s:\n%s",
thName,
"".join(traceback.format_stack(stack)))
signal.signal(signal.SIGUSR1, sigusr1handler)
if args.socketfile:
from nemubot.server.socket import SocketListener
context.add_server(SocketListener(context.add_server, "master_socket",
sock_location=args.socketfile))
from nemubot.server.socket import UnixSocketListener
context.add_server(UnixSocketListener(new_server_cb=context.add_server,
location=args.socketfile,
name="master_socket"))
# context can change when performing an hotswap, always join the latest context
oldcontext = None

View File

@ -15,9 +15,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime, timezone
import ipaddress
import logging
from multiprocessing import JoinableQueue
import threading
import select
import sys
from typing import Any, Mapping, Optional, Sequence
from nemubot import __version__
from nemubot.consumer import Consumer, EventConsumer, MessageConsumer
@ -26,13 +30,23 @@ import nemubot.hooks
logger = logging.getLogger("nemubot")
sync_queue = JoinableQueue()
def sync_act(*args):
if isinstance(act, bytes):
act = act.decode()
sync_queue.put(act)
class Bot(threading.Thread):
"""Class containing the bot context and ensuring key goals"""
def __init__(self, ip="127.0.0.1", modules_paths=list(),
data_store=datastore.Abstract(), verbosity=0):
def __init__(self,
ip: Optional[ipaddress] = None,
modules_paths: Sequence[str] = list(),
data_store: Optional[datastore.Abstract] = None,
verbosity: int = 0):
"""Initialize the bot context
Keyword arguments:
@ -42,7 +56,7 @@ class Bot(threading.Thread):
verbosity -- verbosity level
"""
threading.Thread.__init__(self)
super().__init__(name="Nemubot main")
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
__version__,
@ -52,16 +66,16 @@ class Bot(threading.Thread):
self.stop = None
# External IP for accessing this bot
import ipaddress
self.ip = ipaddress.ip_address(ip)
self.ip = ip if ip is not None else ipaddress.ip_address("127.0.0.1")
# Context paths
self.modules_paths = modules_paths
self.datastore = data_store
self.datastore = data_store if data_store is not None else datastore.Abstract()
self.datastore.open()
# Keep global context: servers and modules
self.servers = dict()
self._poll = select.poll()
self.servers = dict() # types: Mapping[str, AbstractServer]
self.modules = dict()
self.modules_configuration = dict()
@ -138,60 +152,76 @@ class Bot(threading.Thread):
self.cnsr_queue = Queue()
self.cnsr_thrd = list()
self.cnsr_thrd_size = -1
# Synchrone actions to be treated by main thread
self.sync_queue = Queue()
def run(self):
from select import select
from nemubot.server import _lock, _rlist, _wlist, _xlist
self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI)
logger.info("Starting main loop")
self.stop = False
while not self.stop:
with _lock:
try:
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1)
except:
logger.error("Something went wrong in select")
fnd_smth = False
# Looking for invalid server
for r in _rlist:
if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
_rlist.remove(r)
logger.error("Found invalid object in _rlist: " + str(r))
fnd_smth = True
for w in _wlist:
if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
_wlist.remove(w)
logger.error("Found invalid object in _wlist: " + str(w))
fnd_smth = True
for x in _xlist:
if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
_xlist.remove(x)
logger.error("Found invalid object in _xlist: " + str(x))
fnd_smth = True
if not fnd_smth:
logger.exception("Can't continue, sorry")
self.quit()
continue
for fd, flag in self._poll.poll():
print("poll")
# Handle internal socket passing orders
if fd != sync_queue._reader.fileno():
srv = self.servers[fd]
for x in xl:
try:
x.exception()
except:
logger.exception("Uncatched exception on server exception")
for w in wl:
try:
w.write_select()
except:
logger.exception("Uncatched exception on server write")
for r in rl:
for i in r.read():
if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
try:
self.receive_message(r, i)
srv.exception(flag)
except:
logger.exception("Uncatched exception on server read")
logger.exception("Uncatched exception on server exception")
if srv.fileno() > 0:
if flag & (select.POLLOUT):
try:
srv.async_write()
except:
logger.exception("Uncatched exception on server write")
if flag & (select.POLLIN | select.POLLPRI):
try:
for i in srv.async_read():
self.receive_message(srv, i)
except:
logger.exception("Uncatched exception on server read")
else:
del self.servers[fd]
# Always check the sync queue
while not sync_queue.empty():
import shlex
args = shlex.split(sync_queue.get())
action = args.pop(0)
logger.info("action: %s: %s", action, args)
if action == "sckt" and len(args) >= 2:
try:
if args[0] == "write":
self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "unwrite":
self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "register":
self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "unregister":
self._poll.unregister(int(args[1]))
except:
logger.exception("Unhandled excpetion during action:")
elif action == "exit":
self.quit()
elif action == "loadconf":
for path in action.args:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
sync_queue.task_done()
# Launch new consumer threads if necessary
@ -202,17 +232,6 @@ class Bot(threading.Thread):
c = Consumer(self)
self.cnsr_thrd.append(c)
c.start()
while self.sync_queue.qsize() > 0:
action = self.sync_queue.get_nowait()
if action[0] == "exit":
self.quit()
elif action[0] == "loadconf":
for path in action[1:]:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
self.sync_queue.task_done()
logger.info("Ending main loop")
@ -414,10 +433,11 @@ class Bot(threading.Thread):
autoconnect -- connect after add?
"""
if srv.fileno not in self.servers:
self.servers[srv.fileno] = srv
fileno = srv.fileno()
if fileno not in self.servers:
self.servers[fileno] = srv
if autoconnect and not hasattr(self, "noautoconnect"):
srv.open()
srv.connect()
return True
else:
@ -439,10 +459,10 @@ class Bot(threading.Thread):
__import__(name)
def add_module(self, module):
def add_module(self, mdl: Any):
"""Add a module to the context, if already exists, unload the
old one before"""
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
module_name = mdl.__spec__.name if hasattr(mdl, "__spec__") else mdl.__name__
if hasattr(self, "stop") and self.stop:
logger.warn("The bot is stopped, can't register new modules")
@ -454,40 +474,40 @@ class Bot(threading.Thread):
# Overwrite print built-in
def prnt(*args):
if hasattr(module, "logger"):
module.logger.info(" ".join([str(s) for s in args]))
if hasattr(mdl, "logger"):
mdl.logger.info(" ".join([str(s) for s in args]))
else:
logger.info("[%s] %s", module_name, " ".join([str(s) for s in args]))
module.print = prnt
mdl.print = prnt
# Create module context
from nemubot.modulecontext import ModuleContext
module.__nemubot_context__ = ModuleContext(self, module)
mdl.__nemubot_context__ = ModuleContext(self, mdl)
if not hasattr(module, "logger"):
module.logger = logging.getLogger("nemubot.module." + module_name)
if not hasattr(mdl, "logger"):
mdl.logger = logging.getLogger("nemubot.module." + module_name)
# Replace imported context by real one
for attr in module.__dict__:
if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext:
module.__dict__[attr] = module.__nemubot_context__
for attr in mdl.__dict__:
if attr != "__nemubot_context__" and type(mdl.__dict__[attr]) == ModuleContext:
mdl.__dict__[attr] = mdl.__nemubot_context__
# Register decorated functions
import nemubot.hooks
for s, h in nemubot.hooks.hook.last_registered:
module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
mdl.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
nemubot.hooks.hook.last_registered = []
# Launch the module
if hasattr(module, "load"):
if hasattr(mdl, "load"):
try:
module.load(module.__nemubot_context__)
mdl.load(mdl.__nemubot_context__)
except:
module.__nemubot_context__.unload()
mdl.__nemubot_context__.unload()
raise
# Save a reference to the module
self.modules[module_name] = module
self.modules[module_name] = mdl
def unload_module(self, name):
@ -530,28 +550,28 @@ class Bot(threading.Thread):
def quit(self):
"""Save and unload modules and disconnect servers"""
self.datastore.close()
if self.event_timer is not None:
logger.info("Stop the event timer...")
self.event_timer.cancel()
logger.info("Save and unload all modules...")
for mod in self.modules.items():
self.unload_module(mod)
logger.info("Close all servers connection...")
for k in self.servers:
self.servers[k].close()
logger.info("Stop consumers")
k = self.cnsr_thrd
for cnsr in k:
cnsr.stop = True
logger.info("Save and unload all modules...")
k = list(self.modules.keys())
for mod in k:
self.unload_module(mod)
logger.info("Close all servers connection...")
k = list(self.servers.keys())
for srv in k:
self.servers[srv].close()
self.datastore.close()
self.stop = True
sync_queue.put("end")
sync_queue.join()
# Treatment

View File

@ -15,13 +15,18 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
from typing import Optional
from nemubot.message import Abstract as AbstractMessage
class Channel:
"""A chat room"""
def __init__(self, name, password=None, encoding=None):
def __init__(self,
name: str,
password: Optional[str] = None,
encoding: Optional[str] = None):
"""Initialize the channel
Arguments:
@ -37,7 +42,8 @@ class Channel:
self.topic = ""
self.logger = logging.getLogger("nemubot.channel." + name)
def treat(self, cmd, msg):
def treat(self, cmd: str, msg: AbstractMessage) -> None:
"""Treat a incoming IRC command
Arguments:
@ -60,7 +66,8 @@ class Channel:
elif cmd == "TOPIC":
self.topic = self.text
def join(self, nick, level=0):
def join(self, nick: str, level: int = 0) -> None:
"""Someone join the channel
Argument:
@ -71,7 +78,8 @@ class Channel:
self.logger.debug("%s join", nick)
self.people[nick] = level
def chtopic(self, newtopic):
def chtopic(self, newtopic: str) -> None:
"""Send command to change the topic
Arguments:
@ -81,7 +89,8 @@ class Channel:
self.srv.send_msg(self.name, newtopic, "TOPIC")
self.topic = newtopic
def nick(self, oldnick, newnick):
def nick(self, oldnick: str, newnick: str) -> None:
"""Someone change his nick
Arguments:
@ -95,7 +104,8 @@ class Channel:
del self.people[oldnick]
self.people[newnick] = lvl
def part(self, nick):
def part(self, nick: str) -> None:
"""Someone leave the channel
Argument:
@ -106,7 +116,8 @@ class Channel:
self.logger.debug("%s has left", nick)
del self.people[nick]
def mode(self, msg):
def mode(self, msg: AbstractMessage) -> None:
"""Channel or user mode change
Argument:
@ -132,7 +143,8 @@ class Channel:
elif msg.text[0] == "-v":
self.people[msg.nick] &= ~1
def parse332(self, msg):
def parse332(self, msg: AbstractMessage) -> None:
"""Parse RPL_TOPIC message
Argument:
@ -141,7 +153,8 @@ class Channel:
self.topic = msg.text
def parse353(self, msg):
def parse353(self, msg: AbstractMessage) -> None:
"""Parse RPL_ENDOFWHO message
Argument:

View File

@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
def get_boolean(s):
def get_boolean(s) -> bool:
if isinstance(s, bool):
return s
else:

View File

@ -16,5 +16,5 @@
class Include:
def __init__(self, path):
def __init__(self, path: str):
self.path = path

View File

@ -20,7 +20,11 @@ from nemubot.datastore.nodes.generic import GenericNode
class Module(GenericNode):
def __init__(self, name, autoload=True, **kwargs):
def __init__(self,
name: str,
autoload: bool = True,
**kwargs):
super().__init__(None, **kwargs)
self.name = name
self.autoload = get_boolean(autoload)

View File

@ -14,15 +14,23 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.config.include import Include
from nemubot.config.module import Module
from nemubot.config.server import Server
from typing import Optional, Sequence, Union
import nemubot.config.include
import nemubot.config.module
import nemubot.config.server
class Nemubot:
def __init__(self, nick="nemubot", realname="nemubot", owner=None,
ip=None, ssl=False, caps=None, encoding="utf-8"):
def __init__(self,
nick: str = "nemubot",
realname: str = "nemubot",
owner: Optional[str] = None,
ip: Optional[str] = None,
ssl: bool = False,
caps: Optional[Sequence[str]] = None,
encoding: str = "utf-8"):
self.nick = nick
self.realname = realname
self.owner = owner
@ -34,13 +42,13 @@ class Nemubot:
self.includes = []
def addChild(self, name, child):
if name == "module" and isinstance(child, Module):
def addChild(self, name: str, child: Union[nemubot.config.module.Module, nemubot.config.server.Server, nemubot.config.include.Include]):
if name == "module" and isinstance(child, nemubot.config.module.Module):
self.modules.append(child)
return True
elif name == "server" and isinstance(child, Server):
elif name == "server" and isinstance(child, nemubot.config.server.Server):
self.servers.append(child)
return True
elif name == "include" and isinstance(child, Include):
elif name == "include" and isinstance(child, nemubot.config.include.Include):
self.includes.append(child)
return True

View File

@ -14,12 +14,19 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Optional, Sequence
from nemubot.channel import Channel
import nemubot.config.nemubot
class Server:
def __init__(self, uri="irc://nemubot@localhost/", autoconnect=True, caps=None, **kwargs):
def __init__(self,
uri: str = "irc://nemubot@localhost/",
autoconnect: bool = True,
caps: Optional[Sequence[str]] = None,
**kwargs):
self.uri = uri
self.autoconnect = autoconnect
self.caps = caps.split(" ") if caps is not None else []
@ -27,7 +34,7 @@ class Server:
self.channels = []
def addChild(self, name, child):
def addChild(self, name: str, child: Channel):
if name == "channel" and isinstance(child, Channel):
self.channels.append(child)
return True

View File

@ -17,6 +17,12 @@
import logging
import queue
import threading
from typing import List
from nemubot.bot import Bot
from nemubot.event import ModuleEvent
from nemubot.message.abstract import Abstract as AbstractMessage
from nemubot.server.abstract import AbstractServer
logger = logging.getLogger("nemubot.consumer")
@ -25,18 +31,15 @@ class MessageConsumer:
"""Store a message before treating"""
def __init__(self, srv, msg):
def __init__(self, srv: AbstractServer, msg: AbstractMessage):
self.srv = srv
self.orig = msg
def run(self, context):
def run(self, context: Bot) -> None:
"""Create, parse and treat the message"""
from nemubot.bot import Bot
assert isinstance(context, Bot)
msgs = []
msgs = [] # type: List[AbstractMessage]
# Parse the message
try:
@ -55,8 +58,6 @@ class MessageConsumer:
if hasattr(msg, "frm_owner"):
msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm)
from nemubot.server.abstract import AbstractServer
# Treat the message
for msg in msgs:
for res in context.treater.treat_msg(msg):
@ -87,12 +88,12 @@ class EventConsumer:
"""Store a event before treating"""
def __init__(self, evt, timeout=20):
def __init__(self, evt: ModuleEvent, timeout: int = 20):
self.evt = evt
self.timeout = timeout
def run(self, context):
def run(self, context: Bot) -> None:
try:
self.evt.check()
except:
@ -113,13 +114,13 @@ class Consumer(threading.Thread):
"""Dequeue and exec requested action"""
def __init__(self, context):
def __init__(self, context: Bot):
self.context = context
self.stop = False
super().__init__(name="Nemubot consumer")
def run(self):
def run(self) -> None:
try:
while not self.stop:
stm = self.context.cnsr_queue.get(True, 1)

View File

@ -25,11 +25,14 @@ class Abstract:
return None
def open(self):
return
def close(self):
return
def open(self) -> bool:
return True
def close(self) -> bool:
return True
def load(self, module):
"""Load data for the given module
@ -43,6 +46,7 @@ class Abstract:
return self.new()
def save(self, module, data):
"""Load data for the given module
@ -56,9 +60,11 @@ class Abstract:
return True
def __enter__(self):
self.open()
return self
def __exit__(self, type, value, traceback):
self.close()

View File

@ -14,6 +14,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Any, Mapping, Sequence
from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable
@ -25,24 +28,24 @@ class ListNode(Serializable):
serializetag = "list"
def __init__(self, **kwargs):
self.items = list()
self.items = list() # type: Sequence
def addChild(self, name, child):
def addChild(self, name: str, child) -> bool:
self.items.append(child)
return True
def parsedForm(self):
def parsedForm(self) -> Sequence:
return self.items
def __len__(self):
def __len__(self) -> int:
return len(self.items)
def __getitem__(self, item):
def __getitem__(self, item: int) -> Any:
return self.items[item]
def __setitem__(self, item, v):
def __setitem__(self, item: int, v: Any) -> None:
self.items[item] = v
def __contains__(self, item):
@ -52,8 +55,7 @@ class ListNode(Serializable):
return self.items.__repr__()
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
for i in self.items:
node.children.append(ParsingNode.serialize_node(i))
@ -72,12 +74,12 @@ class DictNode(Serializable):
self._cur = None
def startElement(self, name, attrs):
def startElement(self, name: str, attrs: Mapping[str, str]):
if self._cur is None and "key" in attrs:
self._cur = attrs["key"]
return False
def addChild(self, name, child):
def addChild(self, name: str, child: Any):
if self._cur is None:
return False
@ -85,24 +87,24 @@ class DictNode(Serializable):
self._cur = None
return True
def parsedForm(self):
def parsedForm(self) -> Mapping:
return self.items
def __getitem__(self, item):
def __getitem__(self, item: str) -> Any:
return self.items[item]
def __setitem__(self, item, v):
def __setitem__(self, item: str, v: str) -> None:
self.items[item] = v
def __contains__(self, item):
def __contains__(self, item: str) -> bool:
return item in self.items
def __repr__(self):
def __repr__(self) -> str:
return self.items.__repr__()
def serialize(self):
def serialize(self) -> ParsingNode:
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
for k in self.items:

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Any, Optional, Mapping, Union
from nemubot.datastore.nodes.serializable import Serializable
@ -22,41 +24,44 @@ class ParsingNode:
"""Allow any kind of subtags, just keep parsed ones
"""
def __init__(self, tag=None, **kwargs):
def __init__(self,
tag: Optional[str] = None,
**kwargs):
self.tag = tag
self.attrs = kwargs
self.content = ""
self.children = []
def characters(self, content):
def characters(self, content: str) -> None:
self.content += content
def addChild(self, name, child):
def addChild(self, name: str, child: Any) -> bool:
self.children.append(child)
return True
def hasNode(self, nodename):
def hasNode(self, nodename: str) -> bool:
return self.getNode(nodename) is not None
def getNode(self, nodename):
def getNode(self, nodename: str) -> Optional[Any]:
for c in self.children:
if c is not None and c.tag == nodename:
return c
return None
def __getitem__(self, item):
def __getitem__(self, item: str) -> Any:
return self.attrs[item]
def __contains__(self, item):
def __contains__(self, item: str) -> bool:
return item in self.attrs
def serialize_node(node, **def_kwargs):
def serialize_node(node: Union[Serializable, str, int, float, list, dict],
**def_kwargs):
"""Serialize any node or basic data to a ParsingNode instance"""
if isinstance(node, Serializable):
@ -102,13 +107,16 @@ class GenericNode(ParsingNode):
"""Consider all subtags as dictionnary
"""
def __init__(self, tag, **kwargs):
def __init__(self,
tag: str,
**kwargs):
super().__init__(tag, **kwargs)
self._cur = None
self._deep_cur = 0
def startElement(self, name, attrs):
def startElement(self, name: str, attrs: Mapping[str, str]):
if self._cur is None:
self._cur = GenericNode(name, **attrs)
self._deep_cur = 0
@ -118,14 +126,14 @@ class GenericNode(ParsingNode):
return True
def characters(self, content):
def characters(self, content: str):
if self._cur is None:
super().characters(content)
else:
self._cur.characters(content)
def endElement(self, name):
def endElement(self, name: str):
if name is None:
return

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable
@ -27,15 +28,15 @@ class PythonTypeNode(Serializable):
self._cnt = ""
def characters(self, content):
def characters(self, content: str) -> None:
self._cnt += content
def endElement(self, name):
def endElement(self, name: str) -> None:
raise NotImplemented
def __repr__(self):
def __repr__(self) -> str:
return self.value.__repr__()
@ -50,12 +51,11 @@ class IntNode(PythonTypeNode):
serializetag = "int"
def endElement(self, name):
def endElement(self, name: str) -> bool:
self.value = int(self._cnt)
return True
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
@ -65,12 +65,11 @@ class FloatNode(PythonTypeNode):
serializetag = "float"
def endElement(self, name):
def endElement(self, name: str) -> bool:
self.value = float(self._cnt)
return True
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
@ -85,7 +84,6 @@ class StringNode(PythonTypeNode):
return True
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node

View File

@ -17,6 +17,7 @@
import fcntl
import logging
import os
from typing import Any, Mapping
import xml.parsers.expat
from nemubot.datastore.abstract import Abstract
@ -28,7 +29,9 @@ class XML(Abstract):
"""A concrete implementation of a data store that relies on XML files"""
def __init__(self, basedir, rotate=True):
def __init__(self,
basedir: str,
rotate: bool = True):
"""Initialize the datastore
Arguments:
@ -45,7 +48,7 @@ class XML(Abstract):
"enabled" if self.rotate else "disabled")
def open(self):
def open(self) -> bool:
"""Lock the directory"""
if not os.path.isdir(self.basedir):
@ -75,7 +78,7 @@ class XML(Abstract):
return True
def close(self):
def close(self) -> bool:
"""Release a locked path"""
if hasattr(self, "lock_file"):
@ -91,19 +94,19 @@ class XML(Abstract):
return False
def _get_data_file_path(self, module):
def _get_data_file_path(self, module: str) -> str:
"""Get the path to the module data file"""
return os.path.join(self.basedir, module + ".xml")
def _get_lock_file_path(self):
def _get_lock_file_path(self) -> str:
"""Get the path to the datastore lock file"""
return os.path.join(self.basedir, ".used_by_nemubot")
def load(self, module, extendsTags={}):
def load(self, module: str, extendsTags: Mapping[str, Any] = {}) -> Abstract:
"""Load data for the given module
Argument:
@ -116,7 +119,7 @@ class XML(Abstract):
data_file = self._get_data_file_path(module)
def parse(path):
def parse(path: str):
from nemubot.tools.xmlparser import XMLParser
from nemubot.datastore.nodes import basic as basicNodes
from nemubot.datastore.nodes import python as pythonNodes
@ -156,7 +159,7 @@ class XML(Abstract):
return Abstract.load(self, module)
def _rotate(self, path):
def _rotate(self, path: str) -> None:
"""Backup given path
Argument:
@ -173,7 +176,7 @@ class XML(Abstract):
os.rename(src, dst)
def _save_node(self, gen, node):
def _save_node(self, gen, node: Any):
from nemubot.datastore.nodes.generic import ParsingNode
# First, get the serialized form of the node
@ -191,7 +194,7 @@ class XML(Abstract):
gen.endElement(node.tag)
def save(self, module, data):
def save(self, module: str, data: Any) -> bool:
"""Load data for the given module
Argument:

View File

@ -15,15 +15,23 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Optional
class ModuleEvent:
"""Representation of a event initiated by a bot module"""
def __init__(self, call=None, call_data=None, func=None, func_data=None,
cmp=None, cmp_data=None, interval=60, offset=0, times=1):
def __init__(self,
call: Callable = None,
call_data: Callable = None,
func: Callable = None,
func_data: Any = None,
cmp: Any = None,
cmp_data: Any = None,
interval: int = 60,
offset: int = 0,
times: int = 1):
"""Initialize the event
Keyword arguments:
@ -70,8 +78,9 @@ class ModuleEvent:
# How many times do this event?
self.times = times
@property
def current(self):
def current(self) -> Optional[datetime.datetime]:
"""Return the date of the near check"""
if self.times != 0:
if self._end is None:
@ -79,8 +88,9 @@ class ModuleEvent:
return self._end
return None
@property
def next(self):
def next(self) -> Optional[datetime.datetime]:
"""Return the date of the next check"""
if self.times != 0:
if self._end is None:
@ -90,14 +100,16 @@ class ModuleEvent:
return self._end
return None
@property
def time_left(self):
def time_left(self) -> Union[datetime.datetime, int]:
"""Return the time left before/after the near check"""
if self.current is not None:
return self.current - datetime.now(timezone.utc)
return 99999 # TODO: 99999 is not a valid time to return
def check(self):
def check(self) -> None:
"""Run a check and realized the event if this is time"""
# Get initial data

View File

@ -21,7 +21,7 @@ class HooksManager:
"""Class to manage hooks"""
def __init__(self, name="core"):
def __init__(self, name: str = "core"):
"""Initialize the manager"""
self.hooks = dict()

View File

@ -18,16 +18,18 @@ from importlib.abc import Finder
from importlib.machinery import SourceFileLoader
import logging
import os
from typing import Callable
logger = logging.getLogger("nemubot.importer")
class ModuleFinder(Finder):
def __init__(self, modules_paths, add_module):
def __init__(self, modules_paths: str, add_module: Callable[]):
self.modules_paths = modules_paths
self.add_module = add_module
def find_module(self, fullname, path=None):
# Search only for new nemubot modules (packages init)
if path is None:

View File

@ -31,7 +31,7 @@ PORTS = list()
class DCC(server.AbstractServer):
def __init__(self, srv, dest, socket=None):
super().__init__(self)
super().__init__(name="Nemubot DCC server")
self.error = False # An error has occur, closing the connection?
self.messages = list() # Message queued before connexion

View File

@ -27,10 +27,10 @@ class IRC(SocketServer):
"""Concrete implementation of a connexion to an IRC server"""
def __init__(self, host="localhost", port=6667, ssl=False, owner=None,
def __init__(self, host="localhost", port=6667, owner=None,
nick="nemubot", username=None, password=None,
realname="Nemubot", encoding="utf-8", caps=None,
channels=list(), on_connect=None):
channels=list(), on_connect=None, **kwargs):
"""Prepare a connection with an IRC server
Keyword arguments:
@ -54,7 +54,8 @@ class IRC(SocketServer):
self.owner = owner
self.realname = realname
super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port))
super().__init__(name=self.username + "@" + host + ":" + str(port),
host=host, port=port, **kwargs)
self.printer = IRCPrinter
self.encoding = encoding
@ -231,20 +232,19 @@ class IRC(SocketServer):
# Open/close
def open(self):
if super().open():
if self.password is not None:
self.write("PASS :" + self.password)
if self.capabilities is not None:
self.write("CAP LS")
self.write("NICK :" + self.nick)
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
return True
return False
def connect(self):
super().connect()
if self.password is not None:
self.write("PASS :" + self.password)
if self.capabilities is not None:
self.write("CAP LS")
self.write("NICK :" + self.nick)
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
def close(self):
if not self.closed:
if not self._closed:
self.write("QUIT")
return super().close()
@ -253,8 +253,8 @@ class IRC(SocketServer):
# Read
def read(self):
for line in super().read():
def async_read(self):
for line in super().async_read():
# PING should be handled here, so start parsing here :/
msg = IRCMessage(line, self.encoding)

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,20 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
_lock = threading.Lock()
# Lists for select
_rlist = []
_wlist = []
_xlist = []
def factory(uri, **init_args):
def factory(uri, ssl=False, **init_args):
from urllib.parse import urlparse, unquote
o = urlparse(uri)
srv = None
if o.scheme == "irc" or o.scheme == "ircs":
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
@ -36,7 +29,7 @@ def factory(uri, **init_args):
modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:])
if o.scheme == "ircs": args["ssl"] = True
if o.scheme == "ircs": ssl = True
if o.hostname is not None: args["host"] = o.hostname
if o.port is not None: args["port"] = o.port
if o.username is not None: args["username"] = o.username
@ -65,6 +58,11 @@ def factory(uri, **init_args):
args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer
return IRCServer(**args)
else:
return None
srv = IRCServer(**args)
if ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
return ctx.wrap_socket(srv)
return srv

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,34 +14,34 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import io
import logging
import queue
from nemubot.server import _lock, _rlist, _wlist, _xlist
from nemubot.bot import sync_act
# Extends from IOBase in order to be compatible with select function
class AbstractServer(io.IOBase):
class AbstractServer:
"""An abstract server: handle communication with an IM server"""
def __init__(self, name=None, send_callback=None):
def __init__(self, name=None, **kwargs):
"""Initialize an abstract server
Keyword argument:
send_callback -- Callback when developper want to send a message
name -- Identifier of the socket, for convinience
"""
self._name = name
super().__init__()
super().__init__(**kwargs)
self.logger = logging.getLogger("nemubot.server." + self.name)
self.logger = logging.getLogger("nemubot.server." + str(self.name))
self._readbuffer = b''
self._sending_queue = queue.Queue()
if send_callback is not None:
self._send_callback = send_callback
else:
self._send_callback = self._write_select
def __del__(self):
print("Server deleted")
@property
@ -54,40 +54,28 @@ class AbstractServer(io.IOBase):
# Open/close
def __enter__(self):
self.open()
return self
def connect(self, *args, **kwargs):
"""Register the server in _poll"""
self.logger.info("Opening connection")
super().connect(*args, **kwargs)
self._connected()
def _connected(self):
sync_act("sckt register %d" % self.fileno())
def __exit__(self, type, value, traceback):
self.close()
def close(self, *args, **kwargs):
"""Unregister the server from _poll"""
self.logger.info("Closing connection")
def open(self):
"""Generic open function that register the server un _rlist in case
of successful _open"""
self.logger.info("Opening connection to %s", self.id)
if not hasattr(self, "_open") or self._open():
_rlist.append(self)
_xlist.append(self)
return True
return False
if self.fileno() > 0:
sync_act("sckt unregister %d" % self.fileno())
def close(self):
"""Generic close function that register the server un _{r,w,x}list in
case of successful _close"""
self.logger.info("Closing connection to %s", self.id)
with _lock:
if not hasattr(self, "_close") or self._close():
if self in _rlist:
_rlist.remove(self)
if self in _wlist:
_wlist.remove(self)
if self in _xlist:
_xlist.remove(self)
return True
return False
super().close(*args, **kwargs)
# Writes
@ -99,13 +87,16 @@ class AbstractServer(io.IOBase):
message -- message to send
"""
self._send_callback(message)
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
sync_act("sckt write %d" % self.fileno())
def write_select(self):
"""Internal function used by the select function"""
def async_write(self):
"""Internal function used when the file descriptor is writable"""
try:
_wlist.remove(self)
sync_act("sckt unwrite %d" % self.fileno())
while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done()
@ -114,19 +105,6 @@ class AbstractServer(io.IOBase):
pass
def _write_select(self, message):
"""Send a message to the server safely through select
Argument:
message -- message to send
"""
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
if self not in _wlist:
_wlist.append(self)
def send_response(self, response):
"""Send a formated Message class
@ -149,13 +127,39 @@ class AbstractServer(io.IOBase):
# Read
def async_read(self):
"""Internal function used when the file descriptor is readable
Returns:
A list of fully received messages
"""
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
for r in ret:
yield r
def lex(self, buf):
"""Assume lexing in default case is per line
Argument:
buf -- buffer to lex
"""
msgs = buf.split(b'\r\n')
partial = msgs.pop()
return msgs, partial
def parse(self, msg):
raise NotImplemented
# Exceptions
def exception(self):
"""Exception occurs in fd"""
self.logger.warning("Unhandle file descriptor exception on server %s",
self.name)
def exception(self, flags):
"""Exception occurs on fd"""
self.close()

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,117 +14,35 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import socket
import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter
from nemubot.server.abstract import AbstractServer
class SocketServer(AbstractServer):
class _Socket(AbstractServer, socket.socket):
"""Concrete implementation of a socket connexion (can be wrapped with TLS)"""
"""Concrete implementation of a socket connection"""
def __init__(self, sock_location=None,
host=None, port=None,
sock=None,
ssl=False,
name=None):
def __init__(self, **kwargs):
"""Create a server socket
Keyword arguments:
sock_location -- Path to the UNIX socket
host -- Hostname of the INET socket
port -- Port of the INET socket
sock -- Already connected socket
ssl -- Should TLS connection enabled
name -- Convinience name
"""
import socket
assert(sock is None or isinstance(sock, socket.SocketType))
assert(port is None or isinstance(port, int))
super().__init__(name=name)
if sock is None:
if sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.connect_to = sock_location
elif host is not None:
for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
self.socket = socket.socket(af, socktype, proto)
self.connect_to = sa
break
else:
self.socket = sock
self.ssl = ssl
super().__init__(**kwargs)
self.readbuffer = b''
self.printer = SocketPrinter
def fileno(self):
return self.socket.fileno() if self.socket else None
@property
def closed(self):
"""Indicator of the connection aliveness"""
return self.socket._closed
# Open/close
def open(self):
if not self.closed:
return True
try:
self.socket.connect(self.connect_to)
self.logger.info("Connected to %s", self.connect_to)
except:
self.socket.close()
self.logger.exception("Unable to connect to %s",
self.connect_to)
return False
# Wrap the socket for SSL
if self.ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.socket = ctx.wrap_socket(self.socket)
return super().open()
def close(self):
import socket
# Flush the sending queue before close
from nemubot.server import _lock
_lock.release()
self._sending_queue.join()
_lock.acquire()
if not self.closed:
try:
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.socket.close()
return super().close()
# Write
def _write(self, cnt):
if self.closed:
return
self.socket.sendall(cnt)
self.sendall(cnt)
def format(self, txt):
@ -136,19 +54,12 @@ class SocketServer(AbstractServer):
# Read
def read(self):
if self.closed:
return []
raw = self.socket.recv(1024)
temp = (self.readbuffer + raw).split(b'\r\n')
self.readbuffer = temp.pop()
for line in temp:
yield line
def read(self, n=1024):
return self.recv(n)
def parse(self, line):
"""Implement a default behaviour for socket"""
import shlex
line = line.strip().decode()
@ -157,48 +68,84 @@ class SocketServer(AbstractServer):
except ValueError:
args = line.split(' ')
yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you")
if len(args):
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
class SocketListener(AbstractServer):
class SocketServer(_Socket):
def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None):
super().__init__(name=name)
self.new_server_cb = new_server_cb
self.sock_location = sock_location
self.host = host
self.port = port
self.ssl = ssl
self.nb_son = 0
def __init__(self, host, port, bind=None, **kwargs):
super().__init__(family=socket.AF_INET, **kwargs)
assert(host is not None)
assert(isinstance(port, int))
self._host = host
self._port = port
self._bind = bind
def fileno(self):
return self.socket.fileno() if self.socket else None
def connect(self):
self.logger.info("Connection to %s:%d", self._host, self._port)
super().connect((self._host, self._port))
if self._bind:
super().bind(self._bind)
@property
def closed(self):
"""Indicator of the connection aliveness"""
return self.socket is None
class UnixSocket(_Socket):
def __init__(self, location, **kwargs):
super().__init__(family=socket.AF_UNIX, **kwargs)
self._socket_path = location
def open(self):
import os
import socket
def connect(self):
self.logger.info("Connection to unix://%s", self._socket_path)
super().connect(self._socket_path)
if self.sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
os.remove(self.sock_location)
except FileNotFoundError:
pass
self.socket.bind(self.sock_location)
elif self.host is not None and self.port is not None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.bind((self.host, self.port))
self.socket.listen(5)
return super().open()
class _Listener(_Socket):
def __init__(self, new_server_cb, instanciate=_Socket, **kwargs):
super().__init__(**kwargs)
self._instanciate = instanciate
self._new_server_cb = new_server_cb
def read(self):
conn, addr = self.accept()
self.logger.info("Accept new connection from %s", addr)
fileno = conn.fileno()
ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
ss.connect = ss._connected
self._new_server_cb(ss, autoconnect=True)
return b''
class UnixSocketListener(_Listener, UnixSocket):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def connect(self):
self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
try:
os.remove(self._socket_path)
except FileNotFoundError:
pass
self.bind(self._socket_path)
self.listen(5)
self.logger.info("Socket ready for accepting new connections")
self._connected()
def close(self):
@ -206,25 +153,14 @@ class SocketListener(AbstractServer):
import socket
try:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
if self.sock_location is not None:
os.remove(self.sock_location)
self.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
return super().close()
super().close()
# Read
def read(self):
if self.closed:
return []
conn, addr = self.socket.accept()
self.nb_son += 1
ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
self.new_server_cb(ss)
return []
try:
if self._socket_path is not None:
os.remove(self._socket_path)
except:
pass

View File

@ -57,11 +57,15 @@ def word_distance(str1, str2):
return d[len(str1)][len(str2)]
def guess(pattern, expect):
def guess(pattern, expect, max_depth=0):
if max_depth == 0:
max_depth = 1 + len(pattern) / 4
elif max_depth <= -1:
max_depth = len(pattern) - max_depth
if len(expect):
se = sorted([(e, word_distance(pattern, e)) for e in expect], key=lambda x: x[1])
_, m = se[0]
for e, wd in se:
if wd > m or wd > 1 + len(pattern) / 4:
if wd > m or wd > max_depth:
break
yield e