1
0
Fork 0

Add type for use with mypy

This commit is contained in:
nemunaire 2016-06-17 19:26:29 +02:00
parent a8706d6213
commit 39936e9d39
24 changed files with 482 additions and 446 deletions

View File

@ -17,12 +17,15 @@
__version__ = '4.0.dev3' __version__ = '4.0.dev3'
__author__ = 'nemunaire' __author__ = 'nemunaire'
from typing import Optional
from nemubot.modulecontext import ModuleContext from nemubot.modulecontext import ModuleContext
context = ModuleContext(None, None) context = ModuleContext(None, None)
def requires_version(min=None, max=None): def requires_version(min: Optional[int] = None,
max: Optional[int] = None) -> None:
"""Raise ImportError if the current version is not in the given range """Raise ImportError if the current version is not in the given range
Keyword arguments: Keyword arguments:
@ -39,7 +42,7 @@ def requires_version(min=None, max=None):
"but this is nemubot v%s." % (str(max), __version__)) "but this is nemubot v%s." % (str(max), __version__))
def attach(pid, socketfile): def attach(pid: int, socketfile: str) -> int:
import socket import socket
import sys import sys
@ -98,7 +101,7 @@ def attach(pid, socketfile):
return 0 return 0
def daemonize(): def daemonize() -> None:
"""Detach the running process to run as a daemon """Detach the running process to run as a daemon
""" """

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
def main(): def main() -> None:
import os import os
import signal import signal
import sys import sys
@ -152,6 +152,7 @@ def main():
# Signals handling # Signals handling
def sigtermhandler(signum, frame): def sigtermhandler(signum, frame):
"""On SIGTERM and SIGINT, quit nicely""" """On SIGTERM and SIGINT, quit nicely"""
sigusr1handler(signum, frame)
context.quit() context.quit()
signal.signal(signal.SIGINT, sigtermhandler) signal.signal(signal.SIGINT, sigtermhandler)
signal.signal(signal.SIGTERM, sigtermhandler) signal.signal(signal.SIGTERM, sigtermhandler)
@ -170,17 +171,23 @@ def main():
def sigusr1handler(signum, frame): def sigusr1handler(signum, frame):
"""On SIGHUSR1, display stacktraces""" """On SIGHUSR1, display stacktraces"""
import traceback import threading, traceback
for threadId, stack in sys._current_frames().items(): for threadId, stack in sys._current_frames().items():
logger.debug("########### Thread %d:\n%s", thName = "#%d" % threadId
threadId, for th in threading.enumerate():
if th.ident == threadId:
thName = th.name
break
logger.debug("########### Thread %s:\n%s",
thName,
"".join(traceback.format_stack(stack))) "".join(traceback.format_stack(stack)))
signal.signal(signal.SIGUSR1, sigusr1handler) signal.signal(signal.SIGUSR1, sigusr1handler)
if args.socketfile: if args.socketfile:
from nemubot.server.socket import SocketListener from nemubot.server.socket import UnixSocketListener
context.add_server(SocketListener(context.add_server, "master_socket", context.add_server(UnixSocketListener(new_server_cb=context.add_server,
sock_location=args.socketfile)) location=args.socketfile,
name="master_socket"))
# context can change when performing an hotswap, always join the latest context # context can change when performing an hotswap, always join the latest context
oldcontext = None oldcontext = None

View File

@ -15,9 +15,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime, timezone from datetime import datetime, timezone
import ipaddress
import logging import logging
from multiprocessing import JoinableQueue
import threading import threading
import select
import sys import sys
from typing import Any, Mapping, Optional, Sequence
from nemubot import __version__ from nemubot import __version__
from nemubot.consumer import Consumer, EventConsumer, MessageConsumer from nemubot.consumer import Consumer, EventConsumer, MessageConsumer
@ -26,13 +30,23 @@ import nemubot.hooks
logger = logging.getLogger("nemubot") logger = logging.getLogger("nemubot")
sync_queue = JoinableQueue()
def sync_act(*args):
if isinstance(act, bytes):
act = act.decode()
sync_queue.put(act)
class Bot(threading.Thread): class Bot(threading.Thread):
"""Class containing the bot context and ensuring key goals""" """Class containing the bot context and ensuring key goals"""
def __init__(self, ip="127.0.0.1", modules_paths=list(), def __init__(self,
data_store=datastore.Abstract(), verbosity=0): ip: Optional[ipaddress] = None,
modules_paths: Sequence[str] = list(),
data_store: Optional[datastore.Abstract] = None,
verbosity: int = 0):
"""Initialize the bot context """Initialize the bot context
Keyword arguments: Keyword arguments:
@ -42,7 +56,7 @@ class Bot(threading.Thread):
verbosity -- verbosity level verbosity -- verbosity level
""" """
threading.Thread.__init__(self) super().__init__(name="Nemubot main")
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)", logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
__version__, __version__,
@ -52,16 +66,16 @@ class Bot(threading.Thread):
self.stop = None self.stop = None
# External IP for accessing this bot # External IP for accessing this bot
import ipaddress self.ip = ip if ip is not None else ipaddress.ip_address("127.0.0.1")
self.ip = ipaddress.ip_address(ip)
# Context paths # Context paths
self.modules_paths = modules_paths self.modules_paths = modules_paths
self.datastore = data_store self.datastore = data_store if data_store is not None else datastore.Abstract()
self.datastore.open() self.datastore.open()
# Keep global context: servers and modules # Keep global context: servers and modules
self.servers = dict() self._poll = select.poll()
self.servers = dict() # types: Mapping[str, AbstractServer]
self.modules = dict() self.modules = dict()
self.modules_configuration = dict() self.modules_configuration = dict()
@ -138,60 +152,76 @@ class Bot(threading.Thread):
self.cnsr_queue = Queue() self.cnsr_queue = Queue()
self.cnsr_thrd = list() self.cnsr_thrd = list()
self.cnsr_thrd_size = -1 self.cnsr_thrd_size = -1
# Synchrone actions to be treated by main thread
self.sync_queue = Queue()
def run(self): def run(self):
from select import select self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI)
from nemubot.server import _lock, _rlist, _wlist, _xlist
logger.info("Starting main loop") logger.info("Starting main loop")
self.stop = False self.stop = False
while not self.stop: while not self.stop:
with _lock: for fd, flag in self._poll.poll():
try: print("poll")
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1) # Handle internal socket passing orders
except: if fd != sync_queue._reader.fileno():
logger.error("Something went wrong in select") srv = self.servers[fd]
fnd_smth = False
# Looking for invalid server
for r in _rlist:
if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
_rlist.remove(r)
logger.error("Found invalid object in _rlist: " + str(r))
fnd_smth = True
for w in _wlist:
if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
_wlist.remove(w)
logger.error("Found invalid object in _wlist: " + str(w))
fnd_smth = True
for x in _xlist:
if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
_xlist.remove(x)
logger.error("Found invalid object in _xlist: " + str(x))
fnd_smth = True
if not fnd_smth:
logger.exception("Can't continue, sorry")
self.quit()
continue
for x in xl: if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
try:
x.exception()
except:
logger.exception("Uncatched exception on server exception")
for w in wl:
try:
w.write_select()
except:
logger.exception("Uncatched exception on server write")
for r in rl:
for i in r.read():
try: try:
self.receive_message(r, i) srv.exception(flag)
except: except:
logger.exception("Uncatched exception on server read") logger.exception("Uncatched exception on server exception")
if srv.fileno() > 0:
if flag & (select.POLLOUT):
try:
srv.async_write()
except:
logger.exception("Uncatched exception on server write")
if flag & (select.POLLIN | select.POLLPRI):
try:
for i in srv.async_read():
self.receive_message(srv, i)
except:
logger.exception("Uncatched exception on server read")
else:
del self.servers[fd]
# Always check the sync queue
while not sync_queue.empty():
import shlex
args = shlex.split(sync_queue.get())
action = args.pop(0)
logger.info("action: %s: %s", action, args)
if action == "sckt" and len(args) >= 2:
try:
if args[0] == "write":
self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "unwrite":
self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "register":
self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "unregister":
self._poll.unregister(int(args[1]))
except:
logger.exception("Unhandled excpetion during action:")
elif action == "exit":
self.quit()
elif action == "loadconf":
for path in action.args:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
sync_queue.task_done()
# Launch new consumer threads if necessary # Launch new consumer threads if necessary
@ -202,17 +232,6 @@ class Bot(threading.Thread):
c = Consumer(self) c = Consumer(self)
self.cnsr_thrd.append(c) self.cnsr_thrd.append(c)
c.start() c.start()
while self.sync_queue.qsize() > 0:
action = self.sync_queue.get_nowait()
if action[0] == "exit":
self.quit()
elif action[0] == "loadconf":
for path in action[1:]:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
self.sync_queue.task_done()
logger.info("Ending main loop") logger.info("Ending main loop")
@ -414,10 +433,11 @@ class Bot(threading.Thread):
autoconnect -- connect after add? autoconnect -- connect after add?
""" """
if srv.fileno not in self.servers: fileno = srv.fileno()
self.servers[srv.fileno] = srv if fileno not in self.servers:
self.servers[fileno] = srv
if autoconnect and not hasattr(self, "noautoconnect"): if autoconnect and not hasattr(self, "noautoconnect"):
srv.open() srv.connect()
return True return True
else: else:
@ -439,10 +459,10 @@ class Bot(threading.Thread):
__import__(name) __import__(name)
def add_module(self, module): def add_module(self, mdl: Any):
"""Add a module to the context, if already exists, unload the """Add a module to the context, if already exists, unload the
old one before""" old one before"""
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__ module_name = mdl.__spec__.name if hasattr(mdl, "__spec__") else mdl.__name__
if hasattr(self, "stop") and self.stop: if hasattr(self, "stop") and self.stop:
logger.warn("The bot is stopped, can't register new modules") logger.warn("The bot is stopped, can't register new modules")
@ -454,40 +474,40 @@ class Bot(threading.Thread):
# Overwrite print built-in # Overwrite print built-in
def prnt(*args): def prnt(*args):
if hasattr(module, "logger"): if hasattr(mdl, "logger"):
module.logger.info(" ".join([str(s) for s in args])) mdl.logger.info(" ".join([str(s) for s in args]))
else: else:
logger.info("[%s] %s", module_name, " ".join([str(s) for s in args])) logger.info("[%s] %s", module_name, " ".join([str(s) for s in args]))
module.print = prnt mdl.print = prnt
# Create module context # Create module context
from nemubot.modulecontext import ModuleContext from nemubot.modulecontext import ModuleContext
module.__nemubot_context__ = ModuleContext(self, module) mdl.__nemubot_context__ = ModuleContext(self, mdl)
if not hasattr(module, "logger"): if not hasattr(mdl, "logger"):
module.logger = logging.getLogger("nemubot.module." + module_name) mdl.logger = logging.getLogger("nemubot.module." + module_name)
# Replace imported context by real one # Replace imported context by real one
for attr in module.__dict__: for attr in mdl.__dict__:
if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext: if attr != "__nemubot_context__" and type(mdl.__dict__[attr]) == ModuleContext:
module.__dict__[attr] = module.__nemubot_context__ mdl.__dict__[attr] = mdl.__nemubot_context__
# Register decorated functions # Register decorated functions
import nemubot.hooks import nemubot.hooks
for s, h in nemubot.hooks.hook.last_registered: for s, h in nemubot.hooks.hook.last_registered:
module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s) mdl.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
nemubot.hooks.hook.last_registered = [] nemubot.hooks.hook.last_registered = []
# Launch the module # Launch the module
if hasattr(module, "load"): if hasattr(mdl, "load"):
try: try:
module.load(module.__nemubot_context__) mdl.load(mdl.__nemubot_context__)
except: except:
module.__nemubot_context__.unload() mdl.__nemubot_context__.unload()
raise raise
# Save a reference to the module # Save a reference to the module
self.modules[module_name] = module self.modules[module_name] = mdl
def unload_module(self, name): def unload_module(self, name):
@ -530,28 +550,28 @@ class Bot(threading.Thread):
def quit(self): def quit(self):
"""Save and unload modules and disconnect servers""" """Save and unload modules and disconnect servers"""
self.datastore.close()
if self.event_timer is not None: if self.event_timer is not None:
logger.info("Stop the event timer...") logger.info("Stop the event timer...")
self.event_timer.cancel() self.event_timer.cancel()
logger.info("Save and unload all modules...")
for mod in self.modules.items():
self.unload_module(mod)
logger.info("Close all servers connection...")
for k in self.servers:
self.servers[k].close()
logger.info("Stop consumers") logger.info("Stop consumers")
k = self.cnsr_thrd k = self.cnsr_thrd
for cnsr in k: for cnsr in k:
cnsr.stop = True cnsr.stop = True
logger.info("Save and unload all modules...") self.datastore.close()
k = list(self.modules.keys())
for mod in k:
self.unload_module(mod)
logger.info("Close all servers connection...")
k = list(self.servers.keys())
for srv in k:
self.servers[srv].close()
self.stop = True self.stop = True
sync_queue.put("end")
sync_queue.join()
# Treatment # Treatment

View File

@ -15,13 +15,18 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging import logging
from typing import Optional
from nemubot.message import Abstract as AbstractMessage
class Channel: class Channel:
"""A chat room""" """A chat room"""
def __init__(self, name, password=None, encoding=None): def __init__(self,
name: str,
password: Optional[str] = None,
encoding: Optional[str] = None):
"""Initialize the channel """Initialize the channel
Arguments: Arguments:
@ -37,7 +42,8 @@ class Channel:
self.topic = "" self.topic = ""
self.logger = logging.getLogger("nemubot.channel." + name) self.logger = logging.getLogger("nemubot.channel." + name)
def treat(self, cmd, msg):
def treat(self, cmd: str, msg: AbstractMessage) -> None:
"""Treat a incoming IRC command """Treat a incoming IRC command
Arguments: Arguments:
@ -60,7 +66,8 @@ class Channel:
elif cmd == "TOPIC": elif cmd == "TOPIC":
self.topic = self.text self.topic = self.text
def join(self, nick, level=0):
def join(self, nick: str, level: int = 0) -> None:
"""Someone join the channel """Someone join the channel
Argument: Argument:
@ -71,7 +78,8 @@ class Channel:
self.logger.debug("%s join", nick) self.logger.debug("%s join", nick)
self.people[nick] = level self.people[nick] = level
def chtopic(self, newtopic):
def chtopic(self, newtopic: str) -> None:
"""Send command to change the topic """Send command to change the topic
Arguments: Arguments:
@ -81,7 +89,8 @@ class Channel:
self.srv.send_msg(self.name, newtopic, "TOPIC") self.srv.send_msg(self.name, newtopic, "TOPIC")
self.topic = newtopic self.topic = newtopic
def nick(self, oldnick, newnick):
def nick(self, oldnick: str, newnick: str) -> None:
"""Someone change his nick """Someone change his nick
Arguments: Arguments:
@ -95,7 +104,8 @@ class Channel:
del self.people[oldnick] del self.people[oldnick]
self.people[newnick] = lvl self.people[newnick] = lvl
def part(self, nick):
def part(self, nick: str) -> None:
"""Someone leave the channel """Someone leave the channel
Argument: Argument:
@ -106,7 +116,8 @@ class Channel:
self.logger.debug("%s has left", nick) self.logger.debug("%s has left", nick)
del self.people[nick] del self.people[nick]
def mode(self, msg):
def mode(self, msg: AbstractMessage) -> None:
"""Channel or user mode change """Channel or user mode change
Argument: Argument:
@ -132,7 +143,8 @@ class Channel:
elif msg.text[0] == "-v": elif msg.text[0] == "-v":
self.people[msg.nick] &= ~1 self.people[msg.nick] &= ~1
def parse332(self, msg):
def parse332(self, msg: AbstractMessage) -> None:
"""Parse RPL_TOPIC message """Parse RPL_TOPIC message
Argument: Argument:
@ -141,7 +153,8 @@ class Channel:
self.topic = msg.text self.topic = msg.text
def parse353(self, msg):
def parse353(self, msg: AbstractMessage) -> None:
"""Parse RPL_ENDOFWHO message """Parse RPL_ENDOFWHO message
Argument: Argument:

View File

@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
def get_boolean(s): def get_boolean(s) -> bool:
if isinstance(s, bool): if isinstance(s, bool):
return s return s
else: else:

View File

@ -16,5 +16,5 @@
class Include: class Include:
def __init__(self, path): def __init__(self, path: str):
self.path = path self.path = path

View File

@ -20,7 +20,11 @@ from nemubot.datastore.nodes.generic import GenericNode
class Module(GenericNode): class Module(GenericNode):
def __init__(self, name, autoload=True, **kwargs): def __init__(self,
name: str,
autoload: bool = True,
**kwargs):
super().__init__(None, **kwargs) super().__init__(None, **kwargs)
self.name = name self.name = name
self.autoload = get_boolean(autoload) self.autoload = get_boolean(autoload)

View File

@ -14,15 +14,23 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.config.include import Include from typing import Optional, Sequence, Union
from nemubot.config.module import Module
from nemubot.config.server import Server import nemubot.config.include
import nemubot.config.module
import nemubot.config.server
class Nemubot: class Nemubot:
def __init__(self, nick="nemubot", realname="nemubot", owner=None, def __init__(self,
ip=None, ssl=False, caps=None, encoding="utf-8"): nick: str = "nemubot",
realname: str = "nemubot",
owner: Optional[str] = None,
ip: Optional[str] = None,
ssl: bool = False,
caps: Optional[Sequence[str]] = None,
encoding: str = "utf-8"):
self.nick = nick self.nick = nick
self.realname = realname self.realname = realname
self.owner = owner self.owner = owner
@ -34,13 +42,13 @@ class Nemubot:
self.includes = [] self.includes = []
def addChild(self, name, child): def addChild(self, name: str, child: Union[nemubot.config.module.Module, nemubot.config.server.Server, nemubot.config.include.Include]):
if name == "module" and isinstance(child, Module): if name == "module" and isinstance(child, nemubot.config.module.Module):
self.modules.append(child) self.modules.append(child)
return True return True
elif name == "server" and isinstance(child, Server): elif name == "server" and isinstance(child, nemubot.config.server.Server):
self.servers.append(child) self.servers.append(child)
return True return True
elif name == "include" and isinstance(child, Include): elif name == "include" and isinstance(child, nemubot.config.include.Include):
self.includes.append(child) self.includes.append(child)
return True return True

View File

@ -14,12 +14,19 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Optional, Sequence
from nemubot.channel import Channel from nemubot.channel import Channel
import nemubot.config.nemubot
class Server: class Server:
def __init__(self, uri="irc://nemubot@localhost/", autoconnect=True, caps=None, **kwargs): def __init__(self,
uri: str = "irc://nemubot@localhost/",
autoconnect: bool = True,
caps: Optional[Sequence[str]] = None,
**kwargs):
self.uri = uri self.uri = uri
self.autoconnect = autoconnect self.autoconnect = autoconnect
self.caps = caps.split(" ") if caps is not None else [] self.caps = caps.split(" ") if caps is not None else []
@ -27,7 +34,7 @@ class Server:
self.channels = [] self.channels = []
def addChild(self, name, child): def addChild(self, name: str, child: Channel):
if name == "channel" and isinstance(child, Channel): if name == "channel" and isinstance(child, Channel):
self.channels.append(child) self.channels.append(child)
return True return True

View File

@ -17,6 +17,12 @@
import logging import logging
import queue import queue
import threading import threading
from typing import List
from nemubot.bot import Bot
from nemubot.event import ModuleEvent
from nemubot.message.abstract import Abstract as AbstractMessage
from nemubot.server.abstract import AbstractServer
logger = logging.getLogger("nemubot.consumer") logger = logging.getLogger("nemubot.consumer")
@ -25,18 +31,15 @@ class MessageConsumer:
"""Store a message before treating""" """Store a message before treating"""
def __init__(self, srv, msg): def __init__(self, srv: AbstractServer, msg: AbstractMessage):
self.srv = srv self.srv = srv
self.orig = msg self.orig = msg
def run(self, context): def run(self, context: Bot) -> None:
"""Create, parse and treat the message""" """Create, parse and treat the message"""
from nemubot.bot import Bot msgs = [] # type: List[AbstractMessage]
assert isinstance(context, Bot)
msgs = []
# Parse the message # Parse the message
try: try:
@ -55,8 +58,6 @@ class MessageConsumer:
if hasattr(msg, "frm_owner"): if hasattr(msg, "frm_owner"):
msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm) msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm)
from nemubot.server.abstract import AbstractServer
# Treat the message # Treat the message
for msg in msgs: for msg in msgs:
for res in context.treater.treat_msg(msg): for res in context.treater.treat_msg(msg):
@ -87,12 +88,12 @@ class EventConsumer:
"""Store a event before treating""" """Store a event before treating"""
def __init__(self, evt, timeout=20): def __init__(self, evt: ModuleEvent, timeout: int = 20):
self.evt = evt self.evt = evt
self.timeout = timeout self.timeout = timeout
def run(self, context): def run(self, context: Bot) -> None:
try: try:
self.evt.check() self.evt.check()
except: except:
@ -113,13 +114,13 @@ class Consumer(threading.Thread):
"""Dequeue and exec requested action""" """Dequeue and exec requested action"""
def __init__(self, context): def __init__(self, context: Bot):
self.context = context self.context = context
self.stop = False self.stop = False
super().__init__(name="Nemubot consumer") super().__init__(name="Nemubot consumer")
def run(self): def run(self) -> None:
try: try:
while not self.stop: while not self.stop:
stm = self.context.cnsr_queue.get(True, 1) stm = self.context.cnsr_queue.get(True, 1)

View File

@ -25,11 +25,14 @@ class Abstract:
return None return None
def open(self):
return
def close(self): def open(self) -> bool:
return return True
def close(self) -> bool:
return True
def load(self, module): def load(self, module):
"""Load data for the given module """Load data for the given module
@ -43,6 +46,7 @@ class Abstract:
return self.new() return self.new()
def save(self, module, data): def save(self, module, data):
"""Load data for the given module """Load data for the given module
@ -56,9 +60,11 @@ class Abstract:
return True return True
def __enter__(self): def __enter__(self):
self.open() self.open()
return self return self
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
self.close() self.close()

View File

@ -14,6 +14,9 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Any, Mapping, Sequence
from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable from nemubot.datastore.nodes.serializable import Serializable
@ -25,24 +28,24 @@ class ListNode(Serializable):
serializetag = "list" serializetag = "list"
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.items = list() self.items = list() # type: Sequence
def addChild(self, name, child): def addChild(self, name: str, child) -> bool:
self.items.append(child) self.items.append(child)
return True return True
def parsedForm(self): def parsedForm(self) -> Sequence:
return self.items return self.items
def __len__(self): def __len__(self) -> int:
return len(self.items) return len(self.items)
def __getitem__(self, item): def __getitem__(self, item: int) -> Any:
return self.items[item] return self.items[item]
def __setitem__(self, item, v): def __setitem__(self, item: int, v: Any) -> None:
self.items[item] = v self.items[item] = v
def __contains__(self, item): def __contains__(self, item):
@ -52,8 +55,7 @@ class ListNode(Serializable):
return self.items.__repr__() return self.items.__repr__()
def serialize(self): def serialize(self) -> ParsingNode:
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag) node = ParsingNode(tag=self.serializetag)
for i in self.items: for i in self.items:
node.children.append(ParsingNode.serialize_node(i)) node.children.append(ParsingNode.serialize_node(i))
@ -72,12 +74,12 @@ class DictNode(Serializable):
self._cur = None self._cur = None
def startElement(self, name, attrs): def startElement(self, name: str, attrs: Mapping[str, str]):
if self._cur is None and "key" in attrs: if self._cur is None and "key" in attrs:
self._cur = attrs["key"] self._cur = attrs["key"]
return False return False
def addChild(self, name, child): def addChild(self, name: str, child: Any):
if self._cur is None: if self._cur is None:
return False return False
@ -85,24 +87,24 @@ class DictNode(Serializable):
self._cur = None self._cur = None
return True return True
def parsedForm(self): def parsedForm(self) -> Mapping:
return self.items return self.items
def __getitem__(self, item): def __getitem__(self, item: str) -> Any:
return self.items[item] return self.items[item]
def __setitem__(self, item, v): def __setitem__(self, item: str, v: str) -> None:
self.items[item] = v self.items[item] = v
def __contains__(self, item): def __contains__(self, item: str) -> bool:
return item in self.items return item in self.items
def __repr__(self): def __repr__(self) -> str:
return self.items.__repr__() return self.items.__repr__()
def serialize(self): def serialize(self) -> ParsingNode:
from nemubot.datastore.nodes.generic import ParsingNode from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag) node = ParsingNode(tag=self.serializetag)
for k in self.items: for k in self.items:

View File

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Any, Optional, Mapping, Union
from nemubot.datastore.nodes.serializable import Serializable from nemubot.datastore.nodes.serializable import Serializable
@ -22,41 +24,44 @@ class ParsingNode:
"""Allow any kind of subtags, just keep parsed ones """Allow any kind of subtags, just keep parsed ones
""" """
def __init__(self, tag=None, **kwargs): def __init__(self,
tag: Optional[str] = None,
**kwargs):
self.tag = tag self.tag = tag
self.attrs = kwargs self.attrs = kwargs
self.content = "" self.content = ""
self.children = [] self.children = []
def characters(self, content): def characters(self, content: str) -> None:
self.content += content self.content += content
def addChild(self, name, child): def addChild(self, name: str, child: Any) -> bool:
self.children.append(child) self.children.append(child)
return True return True
def hasNode(self, nodename): def hasNode(self, nodename: str) -> bool:
return self.getNode(nodename) is not None return self.getNode(nodename) is not None
def getNode(self, nodename): def getNode(self, nodename: str) -> Optional[Any]:
for c in self.children: for c in self.children:
if c is not None and c.tag == nodename: if c is not None and c.tag == nodename:
return c return c
return None return None
def __getitem__(self, item): def __getitem__(self, item: str) -> Any:
return self.attrs[item] return self.attrs[item]
def __contains__(self, item): def __contains__(self, item: str) -> bool:
return item in self.attrs return item in self.attrs
def serialize_node(node, **def_kwargs): def serialize_node(node: Union[Serializable, str, int, float, list, dict],
**def_kwargs):
"""Serialize any node or basic data to a ParsingNode instance""" """Serialize any node or basic data to a ParsingNode instance"""
if isinstance(node, Serializable): if isinstance(node, Serializable):
@ -102,13 +107,16 @@ class GenericNode(ParsingNode):
"""Consider all subtags as dictionnary """Consider all subtags as dictionnary
""" """
def __init__(self, tag, **kwargs): def __init__(self,
tag: str,
**kwargs):
super().__init__(tag, **kwargs) super().__init__(tag, **kwargs)
self._cur = None self._cur = None
self._deep_cur = 0 self._deep_cur = 0
def startElement(self, name, attrs): def startElement(self, name: str, attrs: Mapping[str, str]):
if self._cur is None: if self._cur is None:
self._cur = GenericNode(name, **attrs) self._cur = GenericNode(name, **attrs)
self._deep_cur = 0 self._deep_cur = 0
@ -118,14 +126,14 @@ class GenericNode(ParsingNode):
return True return True
def characters(self, content): def characters(self, content: str):
if self._cur is None: if self._cur is None:
super().characters(content) super().characters(content)
else: else:
self._cur.characters(content) self._cur.characters(content)
def endElement(self, name): def endElement(self, name: str):
if name is None: if name is None:
return return

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable from nemubot.datastore.nodes.serializable import Serializable
@ -27,15 +28,15 @@ class PythonTypeNode(Serializable):
self._cnt = "" self._cnt = ""
def characters(self, content): def characters(self, content: str) -> None:
self._cnt += content self._cnt += content
def endElement(self, name): def endElement(self, name: str) -> None:
raise NotImplemented raise NotImplemented
def __repr__(self): def __repr__(self) -> str:
return self.value.__repr__() return self.value.__repr__()
@ -50,12 +51,11 @@ class IntNode(PythonTypeNode):
serializetag = "int" serializetag = "int"
def endElement(self, name): def endElement(self, name: str) -> bool:
self.value = int(self._cnt) self.value = int(self._cnt)
return True return True
def serialize(self): def serialize(self) -> ParsingNode:
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag) node = ParsingNode(tag=self.serializetag)
node.content = str(self.value) node.content = str(self.value)
return node return node
@ -65,12 +65,11 @@ class FloatNode(PythonTypeNode):
serializetag = "float" serializetag = "float"
def endElement(self, name): def endElement(self, name: str) -> bool:
self.value = float(self._cnt) self.value = float(self._cnt)
return True return True
def serialize(self): def serialize(self) -> ParsingNode:
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag) node = ParsingNode(tag=self.serializetag)
node.content = str(self.value) node.content = str(self.value)
return node return node
@ -85,7 +84,6 @@ class StringNode(PythonTypeNode):
return True return True
def serialize(self): def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag) node = ParsingNode(tag=self.serializetag)
node.content = str(self.value) node.content = str(self.value)
return node return node

View File

@ -17,6 +17,7 @@
import fcntl import fcntl
import logging import logging
import os import os
from typing import Any, Mapping
import xml.parsers.expat import xml.parsers.expat
from nemubot.datastore.abstract import Abstract from nemubot.datastore.abstract import Abstract
@ -28,7 +29,9 @@ class XML(Abstract):
"""A concrete implementation of a data store that relies on XML files""" """A concrete implementation of a data store that relies on XML files"""
def __init__(self, basedir, rotate=True): def __init__(self,
basedir: str,
rotate: bool = True):
"""Initialize the datastore """Initialize the datastore
Arguments: Arguments:
@ -45,7 +48,7 @@ class XML(Abstract):
"enabled" if self.rotate else "disabled") "enabled" if self.rotate else "disabled")
def open(self): def open(self) -> bool:
"""Lock the directory""" """Lock the directory"""
if not os.path.isdir(self.basedir): if not os.path.isdir(self.basedir):
@ -75,7 +78,7 @@ class XML(Abstract):
return True return True
def close(self): def close(self) -> bool:
"""Release a locked path""" """Release a locked path"""
if hasattr(self, "lock_file"): if hasattr(self, "lock_file"):
@ -91,19 +94,19 @@ class XML(Abstract):
return False return False
def _get_data_file_path(self, module): def _get_data_file_path(self, module: str) -> str:
"""Get the path to the module data file""" """Get the path to the module data file"""
return os.path.join(self.basedir, module + ".xml") return os.path.join(self.basedir, module + ".xml")
def _get_lock_file_path(self): def _get_lock_file_path(self) -> str:
"""Get the path to the datastore lock file""" """Get the path to the datastore lock file"""
return os.path.join(self.basedir, ".used_by_nemubot") return os.path.join(self.basedir, ".used_by_nemubot")
def load(self, module, extendsTags={}): def load(self, module: str, extendsTags: Mapping[str, Any] = {}) -> Abstract:
"""Load data for the given module """Load data for the given module
Argument: Argument:
@ -116,7 +119,7 @@ class XML(Abstract):
data_file = self._get_data_file_path(module) data_file = self._get_data_file_path(module)
def parse(path): def parse(path: str):
from nemubot.tools.xmlparser import XMLParser from nemubot.tools.xmlparser import XMLParser
from nemubot.datastore.nodes import basic as basicNodes from nemubot.datastore.nodes import basic as basicNodes
from nemubot.datastore.nodes import python as pythonNodes from nemubot.datastore.nodes import python as pythonNodes
@ -156,7 +159,7 @@ class XML(Abstract):
return Abstract.load(self, module) return Abstract.load(self, module)
def _rotate(self, path): def _rotate(self, path: str) -> None:
"""Backup given path """Backup given path
Argument: Argument:
@ -173,7 +176,7 @@ class XML(Abstract):
os.rename(src, dst) os.rename(src, dst)
def _save_node(self, gen, node): def _save_node(self, gen, node: Any):
from nemubot.datastore.nodes.generic import ParsingNode from nemubot.datastore.nodes.generic import ParsingNode
# First, get the serialized form of the node # First, get the serialized form of the node
@ -191,7 +194,7 @@ class XML(Abstract):
gen.endElement(node.tag) gen.endElement(node.tag)
def save(self, module, data): def save(self, module: str, data: Any) -> bool:
"""Load data for the given module """Load data for the given module
Argument: Argument:

View File

@ -15,15 +15,23 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Optional
class ModuleEvent: class ModuleEvent:
"""Representation of a event initiated by a bot module""" """Representation of a event initiated by a bot module"""
def __init__(self, call=None, call_data=None, func=None, func_data=None, def __init__(self,
cmp=None, cmp_data=None, interval=60, offset=0, times=1): call: Callable = None,
call_data: Callable = None,
func: Callable = None,
func_data: Any = None,
cmp: Any = None,
cmp_data: Any = None,
interval: int = 60,
offset: int = 0,
times: int = 1):
"""Initialize the event """Initialize the event
Keyword arguments: Keyword arguments:
@ -70,8 +78,9 @@ class ModuleEvent:
# How many times do this event? # How many times do this event?
self.times = times self.times = times
@property @property
def current(self): def current(self) -> Optional[datetime.datetime]:
"""Return the date of the near check""" """Return the date of the near check"""
if self.times != 0: if self.times != 0:
if self._end is None: if self._end is None:
@ -79,8 +88,9 @@ class ModuleEvent:
return self._end return self._end
return None return None
@property @property
def next(self): def next(self) -> Optional[datetime.datetime]:
"""Return the date of the next check""" """Return the date of the next check"""
if self.times != 0: if self.times != 0:
if self._end is None: if self._end is None:
@ -90,14 +100,16 @@ class ModuleEvent:
return self._end return self._end
return None return None
@property @property
def time_left(self): def time_left(self) -> Union[datetime.datetime, int]:
"""Return the time left before/after the near check""" """Return the time left before/after the near check"""
if self.current is not None: if self.current is not None:
return self.current - datetime.now(timezone.utc) return self.current - datetime.now(timezone.utc)
return 99999 # TODO: 99999 is not a valid time to return return 99999 # TODO: 99999 is not a valid time to return
def check(self):
def check(self) -> None:
"""Run a check and realized the event if this is time""" """Run a check and realized the event if this is time"""
# Get initial data # Get initial data

View File

@ -21,7 +21,7 @@ class HooksManager:
"""Class to manage hooks""" """Class to manage hooks"""
def __init__(self, name="core"): def __init__(self, name: str = "core"):
"""Initialize the manager""" """Initialize the manager"""
self.hooks = dict() self.hooks = dict()

View File

@ -18,16 +18,18 @@ from importlib.abc import Finder
from importlib.machinery import SourceFileLoader from importlib.machinery import SourceFileLoader
import logging import logging
import os import os
from typing import Callable
logger = logging.getLogger("nemubot.importer") logger = logging.getLogger("nemubot.importer")
class ModuleFinder(Finder): class ModuleFinder(Finder):
def __init__(self, modules_paths, add_module): def __init__(self, modules_paths: str, add_module: Callable[]):
self.modules_paths = modules_paths self.modules_paths = modules_paths
self.add_module = add_module self.add_module = add_module
def find_module(self, fullname, path=None): def find_module(self, fullname, path=None):
# Search only for new nemubot modules (packages init) # Search only for new nemubot modules (packages init)
if path is None: if path is None:

View File

@ -31,7 +31,7 @@ PORTS = list()
class DCC(server.AbstractServer): class DCC(server.AbstractServer):
def __init__(self, srv, dest, socket=None): def __init__(self, srv, dest, socket=None):
super().__init__(self) super().__init__(name="Nemubot DCC server")
self.error = False # An error has occur, closing the connection? self.error = False # An error has occur, closing the connection?
self.messages = list() # Message queued before connexion self.messages = list() # Message queued before connexion

View File

@ -27,10 +27,10 @@ class IRC(SocketServer):
"""Concrete implementation of a connexion to an IRC server""" """Concrete implementation of a connexion to an IRC server"""
def __init__(self, host="localhost", port=6667, ssl=False, owner=None, def __init__(self, host="localhost", port=6667, owner=None,
nick="nemubot", username=None, password=None, nick="nemubot", username=None, password=None,
realname="Nemubot", encoding="utf-8", caps=None, realname="Nemubot", encoding="utf-8", caps=None,
channels=list(), on_connect=None): channels=list(), on_connect=None, **kwargs):
"""Prepare a connection with an IRC server """Prepare a connection with an IRC server
Keyword arguments: Keyword arguments:
@ -54,7 +54,8 @@ class IRC(SocketServer):
self.owner = owner self.owner = owner
self.realname = realname self.realname = realname
super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port)) super().__init__(name=self.username + "@" + host + ":" + str(port),
host=host, port=port, **kwargs)
self.printer = IRCPrinter self.printer = IRCPrinter
self.encoding = encoding self.encoding = encoding
@ -231,20 +232,19 @@ class IRC(SocketServer):
# Open/close # Open/close
def open(self): def connect(self):
if super().open(): super().connect()
if self.password is not None:
self.write("PASS :" + self.password) if self.password is not None:
if self.capabilities is not None: self.write("PASS :" + self.password)
self.write("CAP LS") if self.capabilities is not None:
self.write("NICK :" + self.nick) self.write("CAP LS")
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) self.write("NICK :" + self.nick)
return True self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
return False
def close(self): def close(self):
if not self.closed: if not self._closed:
self.write("QUIT") self.write("QUIT")
return super().close() return super().close()
@ -253,8 +253,8 @@ class IRC(SocketServer):
# Read # Read
def read(self): def async_read(self):
for line in super().read(): for line in super().async_read():
# PING should be handled here, so start parsing here :/ # PING should be handled here, so start parsing here :/
msg = IRCMessage(line, self.encoding) msg = IRCMessage(line, self.encoding)

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,20 +14,13 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
_lock = threading.Lock() def factory(uri, ssl=False, **init_args):
# Lists for select
_rlist = []
_wlist = []
_xlist = []
def factory(uri, **init_args):
from urllib.parse import urlparse, unquote from urllib.parse import urlparse, unquote
o = urlparse(uri) o = urlparse(uri)
srv = None
if o.scheme == "irc" or o.scheme == "ircs": if o.scheme == "irc" or o.scheme == "ircs":
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt # http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html # http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
@ -36,7 +29,7 @@ def factory(uri, **init_args):
modifiers = o.path.split(",") modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:]) target = unquote(modifiers.pop(0)[1:])
if o.scheme == "ircs": args["ssl"] = True if o.scheme == "ircs": ssl = True
if o.hostname is not None: args["host"] = o.hostname if o.hostname is not None: args["host"] = o.hostname
if o.port is not None: args["port"] = o.port if o.port is not None: args["port"] = o.port
if o.username is not None: args["username"] = o.username if o.username is not None: args["username"] = o.username
@ -65,6 +58,11 @@ def factory(uri, **init_args):
args["channels"] = [ target ] args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer from nemubot.server.IRC import IRC as IRCServer
return IRCServer(**args) srv = IRCServer(**args)
else:
return None if ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
return ctx.wrap_socket(srv)
return srv

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,34 +14,34 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import io
import logging import logging
import queue import queue
from nemubot.server import _lock, _rlist, _wlist, _xlist from nemubot.bot import sync_act
# Extends from IOBase in order to be compatible with select function
class AbstractServer(io.IOBase): class AbstractServer:
"""An abstract server: handle communication with an IM server""" """An abstract server: handle communication with an IM server"""
def __init__(self, name=None, send_callback=None): def __init__(self, name=None, **kwargs):
"""Initialize an abstract server """Initialize an abstract server
Keyword argument: Keyword argument:
send_callback -- Callback when developper want to send a message name -- Identifier of the socket, for convinience
""" """
self._name = name self._name = name
super().__init__() super().__init__(**kwargs)
self.logger = logging.getLogger("nemubot.server." + self.name) self.logger = logging.getLogger("nemubot.server." + str(self.name))
self._readbuffer = b''
self._sending_queue = queue.Queue() self._sending_queue = queue.Queue()
if send_callback is not None:
self._send_callback = send_callback
else: def __del__(self):
self._send_callback = self._write_select print("Server deleted")
@property @property
@ -54,40 +54,28 @@ class AbstractServer(io.IOBase):
# Open/close # Open/close
def __enter__(self): def connect(self, *args, **kwargs):
self.open() """Register the server in _poll"""
return self
self.logger.info("Opening connection")
super().connect(*args, **kwargs)
self._connected()
def _connected(self):
sync_act("sckt register %d" % self.fileno())
def __exit__(self, type, value, traceback): def close(self, *args, **kwargs):
self.close() """Unregister the server from _poll"""
self.logger.info("Closing connection")
def open(self): if self.fileno() > 0:
"""Generic open function that register the server un _rlist in case sync_act("sckt unregister %d" % self.fileno())
of successful _open"""
self.logger.info("Opening connection to %s", self.id)
if not hasattr(self, "_open") or self._open():
_rlist.append(self)
_xlist.append(self)
return True
return False
super().close(*args, **kwargs)
def close(self):
"""Generic close function that register the server un _{r,w,x}list in
case of successful _close"""
self.logger.info("Closing connection to %s", self.id)
with _lock:
if not hasattr(self, "_close") or self._close():
if self in _rlist:
_rlist.remove(self)
if self in _wlist:
_wlist.remove(self)
if self in _xlist:
_xlist.remove(self)
return True
return False
# Writes # Writes
@ -99,13 +87,16 @@ class AbstractServer(io.IOBase):
message -- message to send message -- message to send
""" """
self._send_callback(message) self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
sync_act("sckt write %d" % self.fileno())
def write_select(self): def async_write(self):
"""Internal function used by the select function""" """Internal function used when the file descriptor is writable"""
try: try:
_wlist.remove(self) sync_act("sckt unwrite %d" % self.fileno())
while not self._sending_queue.empty(): while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait()) self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done() self._sending_queue.task_done()
@ -114,19 +105,6 @@ class AbstractServer(io.IOBase):
pass pass
def _write_select(self, message):
"""Send a message to the server safely through select
Argument:
message -- message to send
"""
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
if self not in _wlist:
_wlist.append(self)
def send_response(self, response): def send_response(self, response):
"""Send a formated Message class """Send a formated Message class
@ -149,13 +127,39 @@ class AbstractServer(io.IOBase):
# Read # Read
def async_read(self):
"""Internal function used when the file descriptor is readable
Returns:
A list of fully received messages
"""
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
for r in ret:
yield r
def lex(self, buf):
"""Assume lexing in default case is per line
Argument:
buf -- buffer to lex
"""
msgs = buf.split(b'\r\n')
partial = msgs.pop()
return msgs, partial
def parse(self, msg): def parse(self, msg):
raise NotImplemented raise NotImplemented
# Exceptions # Exceptions
def exception(self): def exception(self, flags):
"""Exception occurs in fd""" """Exception occurs on fd"""
self.logger.warning("Unhandle file descriptor exception on server %s",
self.name) self.close()

View File

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,117 +14,35 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import socket
import nemubot.message as message import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter from nemubot.message.printer.socket import Socket as SocketPrinter
from nemubot.server.abstract import AbstractServer from nemubot.server.abstract import AbstractServer
class SocketServer(AbstractServer): class _Socket(AbstractServer, socket.socket):
"""Concrete implementation of a socket connexion (can be wrapped with TLS)""" """Concrete implementation of a socket connection"""
def __init__(self, sock_location=None, def __init__(self, **kwargs):
host=None, port=None,
sock=None,
ssl=False,
name=None):
"""Create a server socket """Create a server socket
Keyword arguments: Keyword arguments:
sock_location -- Path to the UNIX socket
host -- Hostname of the INET socket
port -- Port of the INET socket
sock -- Already connected socket
ssl -- Should TLS connection enabled ssl -- Should TLS connection enabled
name -- Convinience name
""" """
import socket super().__init__(**kwargs)
assert(sock is None or isinstance(sock, socket.SocketType))
assert(port is None or isinstance(port, int))
super().__init__(name=name)
if sock is None:
if sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.connect_to = sock_location
elif host is not None:
for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
self.socket = socket.socket(af, socktype, proto)
self.connect_to = sa
break
else:
self.socket = sock
self.ssl = ssl
self.readbuffer = b'' self.readbuffer = b''
self.printer = SocketPrinter self.printer = SocketPrinter
def fileno(self):
return self.socket.fileno() if self.socket else None
@property
def closed(self):
"""Indicator of the connection aliveness"""
return self.socket._closed
# Open/close
def open(self):
if not self.closed:
return True
try:
self.socket.connect(self.connect_to)
self.logger.info("Connected to %s", self.connect_to)
except:
self.socket.close()
self.logger.exception("Unable to connect to %s",
self.connect_to)
return False
# Wrap the socket for SSL
if self.ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.socket = ctx.wrap_socket(self.socket)
return super().open()
def close(self):
import socket
# Flush the sending queue before close
from nemubot.server import _lock
_lock.release()
self._sending_queue.join()
_lock.acquire()
if not self.closed:
try:
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.socket.close()
return super().close()
# Write # Write
def _write(self, cnt): def _write(self, cnt):
if self.closed: self.sendall(cnt)
return
self.socket.sendall(cnt)
def format(self, txt): def format(self, txt):
@ -136,19 +54,12 @@ class SocketServer(AbstractServer):
# Read # Read
def read(self): def read(self, n=1024):
if self.closed: return self.recv(n)
return []
raw = self.socket.recv(1024)
temp = (self.readbuffer + raw).split(b'\r\n')
self.readbuffer = temp.pop()
for line in temp:
yield line
def parse(self, line): def parse(self, line):
"""Implement a default behaviour for socket"""
import shlex import shlex
line = line.strip().decode() line = line.strip().decode()
@ -157,48 +68,84 @@ class SocketServer(AbstractServer):
except ValueError: except ValueError:
args = line.split(' ') args = line.split(' ')
yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you") if len(args):
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
class SocketListener(AbstractServer): class SocketServer(_Socket):
def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None): def __init__(self, host, port, bind=None, **kwargs):
super().__init__(name=name) super().__init__(family=socket.AF_INET, **kwargs)
self.new_server_cb = new_server_cb
self.sock_location = sock_location assert(host is not None)
self.host = host assert(isinstance(port, int))
self.port = port
self.ssl = ssl self._host = host
self.nb_son = 0 self._port = port
self._bind = bind
def fileno(self): def connect(self):
return self.socket.fileno() if self.socket else None self.logger.info("Connection to %s:%d", self._host, self._port)
super().connect((self._host, self._port))
if self._bind:
super().bind(self._bind)
@property class UnixSocket(_Socket):
def closed(self):
"""Indicator of the connection aliveness""" def __init__(self, location, **kwargs):
return self.socket is None super().__init__(family=socket.AF_UNIX, **kwargs)
self._socket_path = location
def open(self): def connect(self):
import os self.logger.info("Connection to unix://%s", self._socket_path)
import socket super().connect(self._socket_path)
if self.sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
os.remove(self.sock_location)
except FileNotFoundError:
pass
self.socket.bind(self.sock_location)
elif self.host is not None and self.port is not None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.bind((self.host, self.port))
self.socket.listen(5)
return super().open() class _Listener(_Socket):
def __init__(self, new_server_cb, instanciate=_Socket, **kwargs):
super().__init__(**kwargs)
self._instanciate = instanciate
self._new_server_cb = new_server_cb
def read(self):
conn, addr = self.accept()
self.logger.info("Accept new connection from %s", addr)
fileno = conn.fileno()
ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
ss.connect = ss._connected
self._new_server_cb(ss, autoconnect=True)
return b''
class UnixSocketListener(_Listener, UnixSocket):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def connect(self):
self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
try:
os.remove(self._socket_path)
except FileNotFoundError:
pass
self.bind(self._socket_path)
self.listen(5)
self.logger.info("Socket ready for accepting new connections")
self._connected()
def close(self): def close(self):
@ -206,25 +153,14 @@ class SocketListener(AbstractServer):
import socket import socket
try: try:
self.socket.shutdown(socket.SHUT_RDWR) self.shutdown(socket.SHUT_RDWR)
self.socket.close()
if self.sock_location is not None:
os.remove(self.sock_location)
except socket.error: except socket.error:
pass pass
return super().close() super().close()
try:
# Read if self._socket_path is not None:
os.remove(self._socket_path)
def read(self): except:
if self.closed: pass
return []
conn, addr = self.socket.accept()
self.nb_son += 1
ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
self.new_server_cb(ss)
return []

View File

@ -57,11 +57,15 @@ def word_distance(str1, str2):
return d[len(str1)][len(str2)] return d[len(str1)][len(str2)]
def guess(pattern, expect): def guess(pattern, expect, max_depth=0):
if max_depth == 0:
max_depth = 1 + len(pattern) / 4
elif max_depth <= -1:
max_depth = len(pattern) - max_depth
if len(expect): if len(expect):
se = sorted([(e, word_distance(pattern, e)) for e in expect], key=lambda x: x[1]) se = sorted([(e, word_distance(pattern, e)) for e in expect], key=lambda x: x[1])
_, m = se[0] _, m = se[0]
for e, wd in se: for e, wd in se:
if wd > m or wd > 1 + len(pattern) / 4: if wd > m or wd > max_depth:
break break
yield e yield e