diff --git a/nemubot/__init__.py b/nemubot/__init__.py
index c4e1df9..cdc6265 100644
--- a/nemubot/__init__.py
+++ b/nemubot/__init__.py
@@ -17,12 +17,15 @@
__version__ = '4.0.dev3'
__author__ = 'nemunaire'
+from typing import Optional
+
from nemubot.modulecontext import ModuleContext
context = ModuleContext(None, None)
-def requires_version(min=None, max=None):
+def requires_version(min: Optional[int] = None,
+ max: Optional[int] = None) -> None:
"""Raise ImportError if the current version is not in the given range
Keyword arguments:
@@ -39,7 +42,7 @@ def requires_version(min=None, max=None):
"but this is nemubot v%s." % (str(max), __version__))
-def attach(pid, socketfile):
+def attach(pid: int, socketfile: str) -> int:
import socket
import sys
@@ -98,7 +101,7 @@ def attach(pid, socketfile):
return 0
-def daemonize():
+def daemonize() -> None:
"""Detach the running process to run as a daemon
"""
diff --git a/nemubot/__main__.py b/nemubot/__main__.py
index 5a236f4..cb9fae6 100644
--- a/nemubot/__main__.py
+++ b/nemubot/__main__.py
@@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
-# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-def main():
+def main() -> None:
import os
import signal
import sys
@@ -152,6 +152,7 @@ def main():
# Signals handling
def sigtermhandler(signum, frame):
"""On SIGTERM and SIGINT, quit nicely"""
+ sigusr1handler(signum, frame)
context.quit()
signal.signal(signal.SIGINT, sigtermhandler)
signal.signal(signal.SIGTERM, sigtermhandler)
@@ -170,17 +171,23 @@ def main():
def sigusr1handler(signum, frame):
"""On SIGHUSR1, display stacktraces"""
- import traceback
+ import threading, traceback
for threadId, stack in sys._current_frames().items():
- logger.debug("########### Thread %d:\n%s",
- threadId,
+ thName = "#%d" % threadId
+ for th in threading.enumerate():
+ if th.ident == threadId:
+ thName = th.name
+ break
+ logger.debug("########### Thread %s:\n%s",
+ thName,
"".join(traceback.format_stack(stack)))
signal.signal(signal.SIGUSR1, sigusr1handler)
if args.socketfile:
- from nemubot.server.socket import SocketListener
- context.add_server(SocketListener(context.add_server, "master_socket",
- sock_location=args.socketfile))
+ from nemubot.server.socket import UnixSocketListener
+ context.add_server(UnixSocketListener(new_server_cb=context.add_server,
+ location=args.socketfile,
+ name="master_socket"))
# context can change when performing an hotswap, always join the latest context
oldcontext = None
diff --git a/nemubot/bot.py b/nemubot/bot.py
index b7c71b9..10bcff7 100644
--- a/nemubot/bot.py
+++ b/nemubot/bot.py
@@ -15,9 +15,13 @@
# along with this program. If not, see .
from datetime import datetime, timezone
+import ipaddress
import logging
+from multiprocessing import JoinableQueue
import threading
+import select
import sys
+from typing import Any, Mapping, Optional, Sequence
from nemubot import __version__
from nemubot.consumer import Consumer, EventConsumer, MessageConsumer
@@ -26,13 +30,23 @@ import nemubot.hooks
logger = logging.getLogger("nemubot")
+sync_queue = JoinableQueue()
+
+def sync_act(*args):
+ if isinstance(act, bytes):
+ act = act.decode()
+ sync_queue.put(act)
+
class Bot(threading.Thread):
"""Class containing the bot context and ensuring key goals"""
- def __init__(self, ip="127.0.0.1", modules_paths=list(),
- data_store=datastore.Abstract(), verbosity=0):
+ def __init__(self,
+ ip: Optional[ipaddress] = None,
+ modules_paths: Sequence[str] = list(),
+ data_store: Optional[datastore.Abstract] = None,
+ verbosity: int = 0):
"""Initialize the bot context
Keyword arguments:
@@ -42,7 +56,7 @@ class Bot(threading.Thread):
verbosity -- verbosity level
"""
- threading.Thread.__init__(self)
+ super().__init__(name="Nemubot main")
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
__version__,
@@ -52,16 +66,16 @@ class Bot(threading.Thread):
self.stop = None
# External IP for accessing this bot
- import ipaddress
- self.ip = ipaddress.ip_address(ip)
+ self.ip = ip if ip is not None else ipaddress.ip_address("127.0.0.1")
# Context paths
self.modules_paths = modules_paths
- self.datastore = data_store
+ self.datastore = data_store if data_store is not None else datastore.Abstract()
self.datastore.open()
# Keep global context: servers and modules
- self.servers = dict()
+ self._poll = select.poll()
+ self.servers = dict() # types: Mapping[str, AbstractServer]
self.modules = dict()
self.modules_configuration = dict()
@@ -138,60 +152,76 @@ class Bot(threading.Thread):
self.cnsr_queue = Queue()
self.cnsr_thrd = list()
self.cnsr_thrd_size = -1
- # Synchrone actions to be treated by main thread
- self.sync_queue = Queue()
def run(self):
- from select import select
- from nemubot.server import _lock, _rlist, _wlist, _xlist
+ self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI)
logger.info("Starting main loop")
self.stop = False
while not self.stop:
- with _lock:
- try:
- rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1)
- except:
- logger.error("Something went wrong in select")
- fnd_smth = False
- # Looking for invalid server
- for r in _rlist:
- if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
- _rlist.remove(r)
- logger.error("Found invalid object in _rlist: " + str(r))
- fnd_smth = True
- for w in _wlist:
- if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
- _wlist.remove(w)
- logger.error("Found invalid object in _wlist: " + str(w))
- fnd_smth = True
- for x in _xlist:
- if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
- _xlist.remove(x)
- logger.error("Found invalid object in _xlist: " + str(x))
- fnd_smth = True
- if not fnd_smth:
- logger.exception("Can't continue, sorry")
- self.quit()
- continue
+ for fd, flag in self._poll.poll():
+ print("poll")
+ # Handle internal socket passing orders
+ if fd != sync_queue._reader.fileno():
+ srv = self.servers[fd]
- for x in xl:
- try:
- x.exception()
- except:
- logger.exception("Uncatched exception on server exception")
- for w in wl:
- try:
- w.write_select()
- except:
- logger.exception("Uncatched exception on server write")
- for r in rl:
- for i in r.read():
+ if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
try:
- self.receive_message(r, i)
+ srv.exception(flag)
except:
- logger.exception("Uncatched exception on server read")
+ logger.exception("Uncatched exception on server exception")
+
+ if srv.fileno() > 0:
+ if flag & (select.POLLOUT):
+ try:
+ srv.async_write()
+ except:
+ logger.exception("Uncatched exception on server write")
+
+ if flag & (select.POLLIN | select.POLLPRI):
+ try:
+ for i in srv.async_read():
+ self.receive_message(srv, i)
+ except:
+ logger.exception("Uncatched exception on server read")
+
+ else:
+ del self.servers[fd]
+
+
+ # Always check the sync queue
+ while not sync_queue.empty():
+ import shlex
+ args = shlex.split(sync_queue.get())
+ action = args.pop(0)
+
+ logger.info("action: %s: %s", action, args)
+
+ if action == "sckt" and len(args) >= 2:
+ try:
+ if args[0] == "write":
+ self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
+ elif args[0] == "unwrite":
+ self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
+
+ elif args[0] == "register":
+ self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
+ elif args[0] == "unregister":
+ self._poll.unregister(int(args[1]))
+ except:
+ logger.exception("Unhandled excpetion during action:")
+
+ elif action == "exit":
+ self.quit()
+
+ elif action == "loadconf":
+ for path in action.args:
+ logger.debug("Load configuration from %s", path)
+ self.load_file(path)
+ logger.info("Configurations successfully loaded")
+
+ sync_queue.task_done()
# Launch new consumer threads if necessary
@@ -202,17 +232,6 @@ class Bot(threading.Thread):
c = Consumer(self)
self.cnsr_thrd.append(c)
c.start()
-
- while self.sync_queue.qsize() > 0:
- action = self.sync_queue.get_nowait()
- if action[0] == "exit":
- self.quit()
- elif action[0] == "loadconf":
- for path in action[1:]:
- logger.debug("Load configuration from %s", path)
- self.load_file(path)
- logger.info("Configurations successfully loaded")
- self.sync_queue.task_done()
logger.info("Ending main loop")
@@ -414,10 +433,11 @@ class Bot(threading.Thread):
autoconnect -- connect after add?
"""
- if srv.fileno not in self.servers:
- self.servers[srv.fileno] = srv
+ fileno = srv.fileno()
+ if fileno not in self.servers:
+ self.servers[fileno] = srv
if autoconnect and not hasattr(self, "noautoconnect"):
- srv.open()
+ srv.connect()
return True
else:
@@ -439,10 +459,10 @@ class Bot(threading.Thread):
__import__(name)
- def add_module(self, module):
+ def add_module(self, mdl: Any):
"""Add a module to the context, if already exists, unload the
old one before"""
- module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
+ module_name = mdl.__spec__.name if hasattr(mdl, "__spec__") else mdl.__name__
if hasattr(self, "stop") and self.stop:
logger.warn("The bot is stopped, can't register new modules")
@@ -454,40 +474,40 @@ class Bot(threading.Thread):
# Overwrite print built-in
def prnt(*args):
- if hasattr(module, "logger"):
- module.logger.info(" ".join([str(s) for s in args]))
+ if hasattr(mdl, "logger"):
+ mdl.logger.info(" ".join([str(s) for s in args]))
else:
logger.info("[%s] %s", module_name, " ".join([str(s) for s in args]))
- module.print = prnt
+ mdl.print = prnt
# Create module context
from nemubot.modulecontext import ModuleContext
- module.__nemubot_context__ = ModuleContext(self, module)
+ mdl.__nemubot_context__ = ModuleContext(self, mdl)
- if not hasattr(module, "logger"):
- module.logger = logging.getLogger("nemubot.module." + module_name)
+ if not hasattr(mdl, "logger"):
+ mdl.logger = logging.getLogger("nemubot.module." + module_name)
# Replace imported context by real one
- for attr in module.__dict__:
- if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext:
- module.__dict__[attr] = module.__nemubot_context__
+ for attr in mdl.__dict__:
+ if attr != "__nemubot_context__" and type(mdl.__dict__[attr]) == ModuleContext:
+ mdl.__dict__[attr] = mdl.__nemubot_context__
# Register decorated functions
import nemubot.hooks
for s, h in nemubot.hooks.hook.last_registered:
- module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
+ mdl.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
nemubot.hooks.hook.last_registered = []
# Launch the module
- if hasattr(module, "load"):
+ if hasattr(mdl, "load"):
try:
- module.load(module.__nemubot_context__)
+ mdl.load(mdl.__nemubot_context__)
except:
- module.__nemubot_context__.unload()
+ mdl.__nemubot_context__.unload()
raise
# Save a reference to the module
- self.modules[module_name] = module
+ self.modules[module_name] = mdl
def unload_module(self, name):
@@ -530,28 +550,28 @@ class Bot(threading.Thread):
def quit(self):
"""Save and unload modules and disconnect servers"""
- self.datastore.close()
-
if self.event_timer is not None:
logger.info("Stop the event timer...")
self.event_timer.cancel()
+ logger.info("Save and unload all modules...")
+ for mod in self.modules.items():
+ self.unload_module(mod)
+
+ logger.info("Close all servers connection...")
+ for k in self.servers:
+ self.servers[k].close()
+
logger.info("Stop consumers")
k = self.cnsr_thrd
for cnsr in k:
cnsr.stop = True
- logger.info("Save and unload all modules...")
- k = list(self.modules.keys())
- for mod in k:
- self.unload_module(mod)
-
- logger.info("Close all servers connection...")
- k = list(self.servers.keys())
- for srv in k:
- self.servers[srv].close()
+ self.datastore.close()
self.stop = True
+ sync_queue.put("end")
+ sync_queue.join()
# Treatment
diff --git a/nemubot/channel.py b/nemubot/channel.py
index a070131..c01ac90 100644
--- a/nemubot/channel.py
+++ b/nemubot/channel.py
@@ -15,13 +15,18 @@
# along with this program. If not, see .
import logging
+from typing import Optional
+from nemubot.message import Abstract as AbstractMessage
class Channel:
"""A chat room"""
- def __init__(self, name, password=None, encoding=None):
+ def __init__(self,
+ name: str,
+ password: Optional[str] = None,
+ encoding: Optional[str] = None):
"""Initialize the channel
Arguments:
@@ -37,7 +42,8 @@ class Channel:
self.topic = ""
self.logger = logging.getLogger("nemubot.channel." + name)
- def treat(self, cmd, msg):
+
+ def treat(self, cmd: str, msg: AbstractMessage) -> None:
"""Treat a incoming IRC command
Arguments:
@@ -60,7 +66,8 @@ class Channel:
elif cmd == "TOPIC":
self.topic = self.text
- def join(self, nick, level=0):
+
+ def join(self, nick: str, level: int = 0) -> None:
"""Someone join the channel
Argument:
@@ -71,7 +78,8 @@ class Channel:
self.logger.debug("%s join", nick)
self.people[nick] = level
- def chtopic(self, newtopic):
+
+ def chtopic(self, newtopic: str) -> None:
"""Send command to change the topic
Arguments:
@@ -81,7 +89,8 @@ class Channel:
self.srv.send_msg(self.name, newtopic, "TOPIC")
self.topic = newtopic
- def nick(self, oldnick, newnick):
+
+ def nick(self, oldnick: str, newnick: str) -> None:
"""Someone change his nick
Arguments:
@@ -95,7 +104,8 @@ class Channel:
del self.people[oldnick]
self.people[newnick] = lvl
- def part(self, nick):
+
+ def part(self, nick: str) -> None:
"""Someone leave the channel
Argument:
@@ -106,7 +116,8 @@ class Channel:
self.logger.debug("%s has left", nick)
del self.people[nick]
- def mode(self, msg):
+
+ def mode(self, msg: AbstractMessage) -> None:
"""Channel or user mode change
Argument:
@@ -132,7 +143,8 @@ class Channel:
elif msg.text[0] == "-v":
self.people[msg.nick] &= ~1
- def parse332(self, msg):
+
+ def parse332(self, msg: AbstractMessage) -> None:
"""Parse RPL_TOPIC message
Argument:
@@ -141,7 +153,8 @@ class Channel:
self.topic = msg.text
- def parse353(self, msg):
+
+ def parse353(self, msg: AbstractMessage) -> None:
"""Parse RPL_ENDOFWHO message
Argument:
diff --git a/nemubot/config/__init__.py b/nemubot/config/__init__.py
index 6bbc1b2..ea6fed4 100644
--- a/nemubot/config/__init__.py
+++ b/nemubot/config/__init__.py
@@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-def get_boolean(s):
+def get_boolean(s) -> bool:
if isinstance(s, bool):
return s
else:
diff --git a/nemubot/config/include.py b/nemubot/config/include.py
index 408c09a..aca6468 100644
--- a/nemubot/config/include.py
+++ b/nemubot/config/include.py
@@ -16,5 +16,5 @@
class Include:
- def __init__(self, path):
+ def __init__(self, path: str):
self.path = path
diff --git a/nemubot/config/module.py b/nemubot/config/module.py
index 7586697..e67a45b 100644
--- a/nemubot/config/module.py
+++ b/nemubot/config/module.py
@@ -20,7 +20,11 @@ from nemubot.datastore.nodes.generic import GenericNode
class Module(GenericNode):
- def __init__(self, name, autoload=True, **kwargs):
+ def __init__(self,
+ name: str,
+ autoload: bool = True,
+ **kwargs):
super().__init__(None, **kwargs)
+
self.name = name
self.autoload = get_boolean(autoload)
diff --git a/nemubot/config/nemubot.py b/nemubot/config/nemubot.py
index 992cd8e..cc60f86 100644
--- a/nemubot/config/nemubot.py
+++ b/nemubot/config/nemubot.py
@@ -14,15 +14,23 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from nemubot.config.include import Include
-from nemubot.config.module import Module
-from nemubot.config.server import Server
+from typing import Optional, Sequence, Union
+
+import nemubot.config.include
+import nemubot.config.module
+import nemubot.config.server
class Nemubot:
- def __init__(self, nick="nemubot", realname="nemubot", owner=None,
- ip=None, ssl=False, caps=None, encoding="utf-8"):
+ def __init__(self,
+ nick: str = "nemubot",
+ realname: str = "nemubot",
+ owner: Optional[str] = None,
+ ip: Optional[str] = None,
+ ssl: bool = False,
+ caps: Optional[Sequence[str]] = None,
+ encoding: str = "utf-8"):
self.nick = nick
self.realname = realname
self.owner = owner
@@ -34,13 +42,13 @@ class Nemubot:
self.includes = []
- def addChild(self, name, child):
- if name == "module" and isinstance(child, Module):
+ def addChild(self, name: str, child: Union[nemubot.config.module.Module, nemubot.config.server.Server, nemubot.config.include.Include]):
+ if name == "module" and isinstance(child, nemubot.config.module.Module):
self.modules.append(child)
return True
- elif name == "server" and isinstance(child, Server):
+ elif name == "server" and isinstance(child, nemubot.config.server.Server):
self.servers.append(child)
return True
- elif name == "include" and isinstance(child, Include):
+ elif name == "include" and isinstance(child, nemubot.config.include.Include):
self.includes.append(child)
return True
diff --git a/nemubot/config/server.py b/nemubot/config/server.py
index 14ca9a8..b8df692 100644
--- a/nemubot/config/server.py
+++ b/nemubot/config/server.py
@@ -14,12 +14,19 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Optional, Sequence
+
from nemubot.channel import Channel
+import nemubot.config.nemubot
class Server:
- def __init__(self, uri="irc://nemubot@localhost/", autoconnect=True, caps=None, **kwargs):
+ def __init__(self,
+ uri: str = "irc://nemubot@localhost/",
+ autoconnect: bool = True,
+ caps: Optional[Sequence[str]] = None,
+ **kwargs):
self.uri = uri
self.autoconnect = autoconnect
self.caps = caps.split(" ") if caps is not None else []
@@ -27,7 +34,7 @@ class Server:
self.channels = []
- def addChild(self, name, child):
+ def addChild(self, name: str, child: Channel):
if name == "channel" and isinstance(child, Channel):
self.channels.append(child)
return True
diff --git a/nemubot/consumer.py b/nemubot/consumer.py
index 0cd4ed5..8ea5a40 100644
--- a/nemubot/consumer.py
+++ b/nemubot/consumer.py
@@ -17,6 +17,12 @@
import logging
import queue
import threading
+from typing import List
+
+from nemubot.bot import Bot
+from nemubot.event import ModuleEvent
+from nemubot.message.abstract import Abstract as AbstractMessage
+from nemubot.server.abstract import AbstractServer
logger = logging.getLogger("nemubot.consumer")
@@ -25,18 +31,15 @@ class MessageConsumer:
"""Store a message before treating"""
- def __init__(self, srv, msg):
+ def __init__(self, srv: AbstractServer, msg: AbstractMessage):
self.srv = srv
self.orig = msg
- def run(self, context):
+ def run(self, context: Bot) -> None:
"""Create, parse and treat the message"""
- from nemubot.bot import Bot
- assert isinstance(context, Bot)
-
- msgs = []
+ msgs = [] # type: List[AbstractMessage]
# Parse the message
try:
@@ -55,8 +58,6 @@ class MessageConsumer:
if hasattr(msg, "frm_owner"):
msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm)
- from nemubot.server.abstract import AbstractServer
-
# Treat the message
for msg in msgs:
for res in context.treater.treat_msg(msg):
@@ -87,12 +88,12 @@ class EventConsumer:
"""Store a event before treating"""
- def __init__(self, evt, timeout=20):
+ def __init__(self, evt: ModuleEvent, timeout: int = 20):
self.evt = evt
self.timeout = timeout
- def run(self, context):
+ def run(self, context: Bot) -> None:
try:
self.evt.check()
except:
@@ -113,13 +114,13 @@ class Consumer(threading.Thread):
"""Dequeue and exec requested action"""
- def __init__(self, context):
+ def __init__(self, context: Bot):
self.context = context
self.stop = False
super().__init__(name="Nemubot consumer")
- def run(self):
+ def run(self) -> None:
try:
while not self.stop:
stm = self.context.cnsr_queue.get(True, 1)
diff --git a/nemubot/datastore/abstract.py b/nemubot/datastore/abstract.py
index f54bbcd..856851f 100644
--- a/nemubot/datastore/abstract.py
+++ b/nemubot/datastore/abstract.py
@@ -25,11 +25,14 @@ class Abstract:
return None
- def open(self):
- return
- def close(self):
- return
+ def open(self) -> bool:
+ return True
+
+
+ def close(self) -> bool:
+ return True
+
def load(self, module):
"""Load data for the given module
@@ -43,6 +46,7 @@ class Abstract:
return self.new()
+
def save(self, module, data):
"""Load data for the given module
@@ -56,9 +60,11 @@ class Abstract:
return True
+
def __enter__(self):
self.open()
return self
+
def __exit__(self, type, value, traceback):
self.close()
diff --git a/nemubot/datastore/nodes/basic.py b/nemubot/datastore/nodes/basic.py
index 6fbd136..a4467b2 100644
--- a/nemubot/datastore/nodes/basic.py
+++ b/nemubot/datastore/nodes/basic.py
@@ -14,6 +14,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Any, Mapping, Sequence
+
+from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable
@@ -25,24 +28,24 @@ class ListNode(Serializable):
serializetag = "list"
def __init__(self, **kwargs):
- self.items = list()
+ self.items = list() # type: Sequence
- def addChild(self, name, child):
+ def addChild(self, name: str, child) -> bool:
self.items.append(child)
return True
- def parsedForm(self):
+ def parsedForm(self) -> Sequence:
return self.items
- def __len__(self):
+ def __len__(self) -> int:
return len(self.items)
- def __getitem__(self, item):
+ def __getitem__(self, item: int) -> Any:
return self.items[item]
- def __setitem__(self, item, v):
+ def __setitem__(self, item: int, v: Any) -> None:
self.items[item] = v
def __contains__(self, item):
@@ -52,8 +55,7 @@ class ListNode(Serializable):
return self.items.__repr__()
- def serialize(self):
- from nemubot.datastore.nodes.generic import ParsingNode
+ def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
for i in self.items:
node.children.append(ParsingNode.serialize_node(i))
@@ -72,12 +74,12 @@ class DictNode(Serializable):
self._cur = None
- def startElement(self, name, attrs):
+ def startElement(self, name: str, attrs: Mapping[str, str]):
if self._cur is None and "key" in attrs:
self._cur = attrs["key"]
return False
- def addChild(self, name, child):
+ def addChild(self, name: str, child: Any):
if self._cur is None:
return False
@@ -85,24 +87,24 @@ class DictNode(Serializable):
self._cur = None
return True
- def parsedForm(self):
+ def parsedForm(self) -> Mapping:
return self.items
- def __getitem__(self, item):
+ def __getitem__(self, item: str) -> Any:
return self.items[item]
- def __setitem__(self, item, v):
+ def __setitem__(self, item: str, v: str) -> None:
self.items[item] = v
- def __contains__(self, item):
+ def __contains__(self, item: str) -> bool:
return item in self.items
- def __repr__(self):
+ def __repr__(self) -> str:
return self.items.__repr__()
- def serialize(self):
+ def serialize(self) -> ParsingNode:
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
for k in self.items:
diff --git a/nemubot/datastore/nodes/generic.py b/nemubot/datastore/nodes/generic.py
index c9840bc..939019c 100644
--- a/nemubot/datastore/nodes/generic.py
+++ b/nemubot/datastore/nodes/generic.py
@@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from typing import Any, Optional, Mapping, Union
+
from nemubot.datastore.nodes.serializable import Serializable
@@ -22,41 +24,44 @@ class ParsingNode:
"""Allow any kind of subtags, just keep parsed ones
"""
- def __init__(self, tag=None, **kwargs):
+ def __init__(self,
+ tag: Optional[str] = None,
+ **kwargs):
self.tag = tag
self.attrs = kwargs
self.content = ""
self.children = []
- def characters(self, content):
+ def characters(self, content: str) -> None:
self.content += content
- def addChild(self, name, child):
+ def addChild(self, name: str, child: Any) -> bool:
self.children.append(child)
return True
- def hasNode(self, nodename):
+ def hasNode(self, nodename: str) -> bool:
return self.getNode(nodename) is not None
- def getNode(self, nodename):
+ def getNode(self, nodename: str) -> Optional[Any]:
for c in self.children:
if c is not None and c.tag == nodename:
return c
return None
- def __getitem__(self, item):
+ def __getitem__(self, item: str) -> Any:
return self.attrs[item]
- def __contains__(self, item):
+ def __contains__(self, item: str) -> bool:
return item in self.attrs
- def serialize_node(node, **def_kwargs):
+ def serialize_node(node: Union[Serializable, str, int, float, list, dict],
+ **def_kwargs):
"""Serialize any node or basic data to a ParsingNode instance"""
if isinstance(node, Serializable):
@@ -102,13 +107,16 @@ class GenericNode(ParsingNode):
"""Consider all subtags as dictionnary
"""
- def __init__(self, tag, **kwargs):
+ def __init__(self,
+ tag: str,
+ **kwargs):
super().__init__(tag, **kwargs)
+
self._cur = None
self._deep_cur = 0
- def startElement(self, name, attrs):
+ def startElement(self, name: str, attrs: Mapping[str, str]):
if self._cur is None:
self._cur = GenericNode(name, **attrs)
self._deep_cur = 0
@@ -118,14 +126,14 @@ class GenericNode(ParsingNode):
return True
- def characters(self, content):
+ def characters(self, content: str):
if self._cur is None:
super().characters(content)
else:
self._cur.characters(content)
- def endElement(self, name):
+ def endElement(self, name: str):
if name is None:
return
diff --git a/nemubot/datastore/nodes/python.py b/nemubot/datastore/nodes/python.py
index 6e4278b..819bf21 100644
--- a/nemubot/datastore/nodes/python.py
+++ b/nemubot/datastore/nodes/python.py
@@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable
@@ -27,15 +28,15 @@ class PythonTypeNode(Serializable):
self._cnt = ""
- def characters(self, content):
+ def characters(self, content: str) -> None:
self._cnt += content
- def endElement(self, name):
+ def endElement(self, name: str) -> None:
raise NotImplemented
- def __repr__(self):
+ def __repr__(self) -> str:
return self.value.__repr__()
@@ -50,12 +51,11 @@ class IntNode(PythonTypeNode):
serializetag = "int"
- def endElement(self, name):
+ def endElement(self, name: str) -> bool:
self.value = int(self._cnt)
return True
- def serialize(self):
- from nemubot.datastore.nodes.generic import ParsingNode
+ def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
@@ -65,12 +65,11 @@ class FloatNode(PythonTypeNode):
serializetag = "float"
- def endElement(self, name):
+ def endElement(self, name: str) -> bool:
self.value = float(self._cnt)
return True
- def serialize(self):
- from nemubot.datastore.nodes.generic import ParsingNode
+ def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
@@ -85,7 +84,6 @@ class StringNode(PythonTypeNode):
return True
def serialize(self):
- from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
diff --git a/nemubot/datastore/xml.py b/nemubot/datastore/xml.py
index 266c3ac..abf1492 100644
--- a/nemubot/datastore/xml.py
+++ b/nemubot/datastore/xml.py
@@ -17,6 +17,7 @@
import fcntl
import logging
import os
+from typing import Any, Mapping
import xml.parsers.expat
from nemubot.datastore.abstract import Abstract
@@ -28,7 +29,9 @@ class XML(Abstract):
"""A concrete implementation of a data store that relies on XML files"""
- def __init__(self, basedir, rotate=True):
+ def __init__(self,
+ basedir: str,
+ rotate: bool = True):
"""Initialize the datastore
Arguments:
@@ -45,7 +48,7 @@ class XML(Abstract):
"enabled" if self.rotate else "disabled")
- def open(self):
+ def open(self) -> bool:
"""Lock the directory"""
if not os.path.isdir(self.basedir):
@@ -75,7 +78,7 @@ class XML(Abstract):
return True
- def close(self):
+ def close(self) -> bool:
"""Release a locked path"""
if hasattr(self, "lock_file"):
@@ -91,19 +94,19 @@ class XML(Abstract):
return False
- def _get_data_file_path(self, module):
+ def _get_data_file_path(self, module: str) -> str:
"""Get the path to the module data file"""
return os.path.join(self.basedir, module + ".xml")
- def _get_lock_file_path(self):
+ def _get_lock_file_path(self) -> str:
"""Get the path to the datastore lock file"""
return os.path.join(self.basedir, ".used_by_nemubot")
- def load(self, module, extendsTags={}):
+ def load(self, module: str, extendsTags: Mapping[str, Any] = {}) -> Abstract:
"""Load data for the given module
Argument:
@@ -116,7 +119,7 @@ class XML(Abstract):
data_file = self._get_data_file_path(module)
- def parse(path):
+ def parse(path: str):
from nemubot.tools.xmlparser import XMLParser
from nemubot.datastore.nodes import basic as basicNodes
from nemubot.datastore.nodes import python as pythonNodes
@@ -156,7 +159,7 @@ class XML(Abstract):
return Abstract.load(self, module)
- def _rotate(self, path):
+ def _rotate(self, path: str) -> None:
"""Backup given path
Argument:
@@ -173,7 +176,7 @@ class XML(Abstract):
os.rename(src, dst)
- def _save_node(self, gen, node):
+ def _save_node(self, gen, node: Any):
from nemubot.datastore.nodes.generic import ParsingNode
# First, get the serialized form of the node
@@ -191,7 +194,7 @@ class XML(Abstract):
gen.endElement(node.tag)
- def save(self, module, data):
+ def save(self, module: str, data: Any) -> bool:
"""Load data for the given module
Argument:
diff --git a/nemubot/event/__init__.py b/nemubot/event/__init__.py
index 7b2adfd..ab96efb 100644
--- a/nemubot/event/__init__.py
+++ b/nemubot/event/__init__.py
@@ -15,15 +15,23 @@
# along with this program. If not, see .
from datetime import datetime, timedelta, timezone
+from typing import Any, Callable, Optional
class ModuleEvent:
"""Representation of a event initiated by a bot module"""
- def __init__(self, call=None, call_data=None, func=None, func_data=None,
- cmp=None, cmp_data=None, interval=60, offset=0, times=1):
-
+ def __init__(self,
+ call: Callable = None,
+ call_data: Callable = None,
+ func: Callable = None,
+ func_data: Any = None,
+ cmp: Any = None,
+ cmp_data: Any = None,
+ interval: int = 60,
+ offset: int = 0,
+ times: int = 1):
"""Initialize the event
Keyword arguments:
@@ -70,8 +78,9 @@ class ModuleEvent:
# How many times do this event?
self.times = times
+
@property
- def current(self):
+ def current(self) -> Optional[datetime.datetime]:
"""Return the date of the near check"""
if self.times != 0:
if self._end is None:
@@ -79,8 +88,9 @@ class ModuleEvent:
return self._end
return None
+
@property
- def next(self):
+ def next(self) -> Optional[datetime.datetime]:
"""Return the date of the next check"""
if self.times != 0:
if self._end is None:
@@ -90,14 +100,16 @@ class ModuleEvent:
return self._end
return None
+
@property
- def time_left(self):
+ def time_left(self) -> Union[datetime.datetime, int]:
"""Return the time left before/after the near check"""
if self.current is not None:
return self.current - datetime.now(timezone.utc)
return 99999 # TODO: 99999 is not a valid time to return
- def check(self):
+
+ def check(self) -> None:
"""Run a check and realized the event if this is time"""
# Get initial data
diff --git a/nemubot/hooks/manager.py b/nemubot/hooks/manager.py
index 6a57d2a..9d57483 100644
--- a/nemubot/hooks/manager.py
+++ b/nemubot/hooks/manager.py
@@ -21,7 +21,7 @@ class HooksManager:
"""Class to manage hooks"""
- def __init__(self, name="core"):
+ def __init__(self, name: str = "core"):
"""Initialize the manager"""
self.hooks = dict()
diff --git a/nemubot/importer.py b/nemubot/importer.py
index eaf1535..2827da9 100644
--- a/nemubot/importer.py
+++ b/nemubot/importer.py
@@ -18,16 +18,18 @@ from importlib.abc import Finder
from importlib.machinery import SourceFileLoader
import logging
import os
+from typing import Callable
logger = logging.getLogger("nemubot.importer")
class ModuleFinder(Finder):
- def __init__(self, modules_paths, add_module):
+ def __init__(self, modules_paths: str, add_module: Callable[]):
self.modules_paths = modules_paths
self.add_module = add_module
+
def find_module(self, fullname, path=None):
# Search only for new nemubot modules (packages init)
if path is None:
diff --git a/nemubot/server/DCC.py b/nemubot/server/DCC.py
index 644a8cb..c1a6852 100644
--- a/nemubot/server/DCC.py
+++ b/nemubot/server/DCC.py
@@ -31,7 +31,7 @@ PORTS = list()
class DCC(server.AbstractServer):
def __init__(self, srv, dest, socket=None):
- super().__init__(self)
+ super().__init__(name="Nemubot DCC server")
self.error = False # An error has occur, closing the connection?
self.messages = list() # Message queued before connexion
diff --git a/nemubot/server/IRC.py b/nemubot/server/IRC.py
index 08e2bc5..d09966a 100644
--- a/nemubot/server/IRC.py
+++ b/nemubot/server/IRC.py
@@ -27,10 +27,10 @@ class IRC(SocketServer):
"""Concrete implementation of a connexion to an IRC server"""
- def __init__(self, host="localhost", port=6667, ssl=False, owner=None,
+ def __init__(self, host="localhost", port=6667, owner=None,
nick="nemubot", username=None, password=None,
realname="Nemubot", encoding="utf-8", caps=None,
- channels=list(), on_connect=None):
+ channels=list(), on_connect=None, **kwargs):
"""Prepare a connection with an IRC server
Keyword arguments:
@@ -54,7 +54,8 @@ class IRC(SocketServer):
self.owner = owner
self.realname = realname
- super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port))
+ super().__init__(name=self.username + "@" + host + ":" + str(port),
+ host=host, port=port, **kwargs)
self.printer = IRCPrinter
self.encoding = encoding
@@ -231,20 +232,19 @@ class IRC(SocketServer):
# Open/close
- def open(self):
- if super().open():
- if self.password is not None:
- self.write("PASS :" + self.password)
- if self.capabilities is not None:
- self.write("CAP LS")
- self.write("NICK :" + self.nick)
- self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
- return True
- return False
+ def connect(self):
+ super().connect()
+
+ if self.password is not None:
+ self.write("PASS :" + self.password)
+ if self.capabilities is not None:
+ self.write("CAP LS")
+ self.write("NICK :" + self.nick)
+ self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
def close(self):
- if not self.closed:
+ if not self._closed:
self.write("QUIT")
return super().close()
@@ -253,8 +253,8 @@ class IRC(SocketServer):
# Read
- def read(self):
- for line in super().read():
+ def async_read(self):
+ for line in super().async_read():
# PING should be handled here, so start parsing here :/
msg = IRCMessage(line, self.encoding)
diff --git a/nemubot/server/__init__.py b/nemubot/server/__init__.py
index 3c88138..464c924 100644
--- a/nemubot/server/__init__.py
+++ b/nemubot/server/__init__.py
@@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
-# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -14,20 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import threading
-_lock = threading.Lock()
-
-# Lists for select
-_rlist = []
-_wlist = []
-_xlist = []
-
-
-def factory(uri, **init_args):
+def factory(uri, ssl=False, **init_args):
from urllib.parse import urlparse, unquote
o = urlparse(uri)
+ srv = None
+
if o.scheme == "irc" or o.scheme == "ircs":
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
@@ -36,7 +29,7 @@ def factory(uri, **init_args):
modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:])
- if o.scheme == "ircs": args["ssl"] = True
+ if o.scheme == "ircs": ssl = True
if o.hostname is not None: args["host"] = o.hostname
if o.port is not None: args["port"] = o.port
if o.username is not None: args["username"] = o.username
@@ -65,6 +58,11 @@ def factory(uri, **init_args):
args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer
- return IRCServer(**args)
- else:
- return None
+ srv = IRCServer(**args)
+
+ if ssl:
+ import ssl
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ return ctx.wrap_socket(srv)
+
+ return srv
diff --git a/nemubot/server/abstract.py b/nemubot/server/abstract.py
index dc2081d..7e31cda 100644
--- a/nemubot/server/abstract.py
+++ b/nemubot/server/abstract.py
@@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
-# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -14,34 +14,34 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import io
import logging
import queue
-from nemubot.server import _lock, _rlist, _wlist, _xlist
+from nemubot.bot import sync_act
-# Extends from IOBase in order to be compatible with select function
-class AbstractServer(io.IOBase):
+
+class AbstractServer:
"""An abstract server: handle communication with an IM server"""
- def __init__(self, name=None, send_callback=None):
+ def __init__(self, name=None, **kwargs):
"""Initialize an abstract server
Keyword argument:
- send_callback -- Callback when developper want to send a message
+ name -- Identifier of the socket, for convinience
"""
self._name = name
- super().__init__()
+ super().__init__(**kwargs)
- self.logger = logging.getLogger("nemubot.server." + self.name)
+ self.logger = logging.getLogger("nemubot.server." + str(self.name))
+ self._readbuffer = b''
self._sending_queue = queue.Queue()
- if send_callback is not None:
- self._send_callback = send_callback
- else:
- self._send_callback = self._write_select
+
+
+ def __del__(self):
+ print("Server deleted")
@property
@@ -54,40 +54,28 @@ class AbstractServer(io.IOBase):
# Open/close
- def __enter__(self):
- self.open()
- return self
+ def connect(self, *args, **kwargs):
+ """Register the server in _poll"""
+
+ self.logger.info("Opening connection")
+
+ super().connect(*args, **kwargs)
+
+ self._connected()
+
+ def _connected(self):
+ sync_act("sckt register %d" % self.fileno())
- def __exit__(self, type, value, traceback):
- self.close()
+ def close(self, *args, **kwargs):
+ """Unregister the server from _poll"""
+ self.logger.info("Closing connection")
- def open(self):
- """Generic open function that register the server un _rlist in case
- of successful _open"""
- self.logger.info("Opening connection to %s", self.id)
- if not hasattr(self, "_open") or self._open():
- _rlist.append(self)
- _xlist.append(self)
- return True
- return False
+ if self.fileno() > 0:
+ sync_act("sckt unregister %d" % self.fileno())
-
- def close(self):
- """Generic close function that register the server un _{r,w,x}list in
- case of successful _close"""
- self.logger.info("Closing connection to %s", self.id)
- with _lock:
- if not hasattr(self, "_close") or self._close():
- if self in _rlist:
- _rlist.remove(self)
- if self in _wlist:
- _wlist.remove(self)
- if self in _xlist:
- _xlist.remove(self)
- return True
- return False
+ super().close(*args, **kwargs)
# Writes
@@ -99,13 +87,16 @@ class AbstractServer(io.IOBase):
message -- message to send
"""
- self._send_callback(message)
+ self._sending_queue.put(self.format(message))
+ self.logger.debug("Message '%s' appended to write queue", message)
+ sync_act("sckt write %d" % self.fileno())
- def write_select(self):
- """Internal function used by the select function"""
+ def async_write(self):
+ """Internal function used when the file descriptor is writable"""
+
try:
- _wlist.remove(self)
+ sync_act("sckt unwrite %d" % self.fileno())
while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done()
@@ -114,19 +105,6 @@ class AbstractServer(io.IOBase):
pass
- def _write_select(self, message):
- """Send a message to the server safely through select
-
- Argument:
- message -- message to send
- """
-
- self._sending_queue.put(self.format(message))
- self.logger.debug("Message '%s' appended to write queue", message)
- if self not in _wlist:
- _wlist.append(self)
-
-
def send_response(self, response):
"""Send a formated Message class
@@ -149,13 +127,39 @@ class AbstractServer(io.IOBase):
# Read
+ def async_read(self):
+ """Internal function used when the file descriptor is readable
+
+ Returns:
+ A list of fully received messages
+ """
+
+ ret, self._readbuffer = self.lex(self._readbuffer + self.read())
+
+ for r in ret:
+ yield r
+
+
+ def lex(self, buf):
+ """Assume lexing in default case is per line
+
+ Argument:
+ buf -- buffer to lex
+ """
+
+ msgs = buf.split(b'\r\n')
+ partial = msgs.pop()
+
+ return msgs, partial
+
+
def parse(self, msg):
raise NotImplemented
# Exceptions
- def exception(self):
- """Exception occurs in fd"""
- self.logger.warning("Unhandle file descriptor exception on server %s",
- self.name)
+ def exception(self, flags):
+ """Exception occurs on fd"""
+
+ self.close()
diff --git a/nemubot/server/socket.py b/nemubot/server/socket.py
index 13ac9bd..aeb20e5 100644
--- a/nemubot/server/socket.py
+++ b/nemubot/server/socket.py
@@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
-# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -14,117 +14,35 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+import os
+import socket
+
import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter
from nemubot.server.abstract import AbstractServer
-class SocketServer(AbstractServer):
+class _Socket(AbstractServer, socket.socket):
- """Concrete implementation of a socket connexion (can be wrapped with TLS)"""
+ """Concrete implementation of a socket connection"""
- def __init__(self, sock_location=None,
- host=None, port=None,
- sock=None,
- ssl=False,
- name=None):
+ def __init__(self, **kwargs):
"""Create a server socket
Keyword arguments:
- sock_location -- Path to the UNIX socket
- host -- Hostname of the INET socket
- port -- Port of the INET socket
- sock -- Already connected socket
ssl -- Should TLS connection enabled
- name -- Convinience name
"""
- import socket
-
- assert(sock is None or isinstance(sock, socket.SocketType))
- assert(port is None or isinstance(port, int))
-
- super().__init__(name=name)
-
- if sock is None:
- if sock_location is not None:
- self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- self.connect_to = sock_location
- elif host is not None:
- for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
- self.socket = socket.socket(af, socktype, proto)
- self.connect_to = sa
- break
- else:
- self.socket = sock
-
- self.ssl = ssl
+ super().__init__(**kwargs)
self.readbuffer = b''
self.printer = SocketPrinter
- def fileno(self):
- return self.socket.fileno() if self.socket else None
-
-
- @property
- def closed(self):
- """Indicator of the connection aliveness"""
- return self.socket._closed
-
-
- # Open/close
-
- def open(self):
- if not self.closed:
- return True
-
- try:
- self.socket.connect(self.connect_to)
- self.logger.info("Connected to %s", self.connect_to)
- except:
- self.socket.close()
- self.logger.exception("Unable to connect to %s",
- self.connect_to)
- return False
-
- # Wrap the socket for SSL
- if self.ssl:
- import ssl
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- self.socket = ctx.wrap_socket(self.socket)
-
- return super().open()
-
-
- def close(self):
- import socket
-
- # Flush the sending queue before close
- from nemubot.server import _lock
- _lock.release()
- self._sending_queue.join()
- _lock.acquire()
-
- if not self.closed:
- try:
- self.socket.shutdown(socket.SHUT_RDWR)
- except socket.error:
- pass
-
- self.socket.close()
-
- return super().close()
-
-
# Write
def _write(self, cnt):
- if self.closed:
- return
-
- self.socket.sendall(cnt)
+ self.sendall(cnt)
def format(self, txt):
@@ -136,19 +54,12 @@ class SocketServer(AbstractServer):
# Read
- def read(self):
- if self.closed:
- return []
-
- raw = self.socket.recv(1024)
- temp = (self.readbuffer + raw).split(b'\r\n')
- self.readbuffer = temp.pop()
-
- for line in temp:
- yield line
+ def read(self, n=1024):
+ return self.recv(n)
def parse(self, line):
+ """Implement a default behaviour for socket"""
import shlex
line = line.strip().decode()
@@ -157,48 +68,84 @@ class SocketServer(AbstractServer):
except ValueError:
args = line.split(' ')
- yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you")
+ if len(args):
+ yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
-class SocketListener(AbstractServer):
+class SocketServer(_Socket):
- def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None):
- super().__init__(name=name)
- self.new_server_cb = new_server_cb
- self.sock_location = sock_location
- self.host = host
- self.port = port
- self.ssl = ssl
- self.nb_son = 0
+ def __init__(self, host, port, bind=None, **kwargs):
+ super().__init__(family=socket.AF_INET, **kwargs)
+
+ assert(host is not None)
+ assert(isinstance(port, int))
+
+ self._host = host
+ self._port = port
+ self._bind = bind
- def fileno(self):
- return self.socket.fileno() if self.socket else None
+ def connect(self):
+ self.logger.info("Connection to %s:%d", self._host, self._port)
+ super().connect((self._host, self._port))
+
+ if self._bind:
+ super().bind(self._bind)
- @property
- def closed(self):
- """Indicator of the connection aliveness"""
- return self.socket is None
+class UnixSocket(_Socket):
+
+ def __init__(self, location, **kwargs):
+ super().__init__(family=socket.AF_UNIX, **kwargs)
+
+ self._socket_path = location
- def open(self):
- import os
- import socket
+ def connect(self):
+ self.logger.info("Connection to unix://%s", self._socket_path)
+ super().connect(self._socket_path)
- if self.sock_location is not None:
- self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- try:
- os.remove(self.sock_location)
- except FileNotFoundError:
- pass
- self.socket.bind(self.sock_location)
- elif self.host is not None and self.port is not None:
- self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.socket.bind((self.host, self.port))
- self.socket.listen(5)
- return super().open()
+class _Listener(_Socket):
+
+ def __init__(self, new_server_cb, instanciate=_Socket, **kwargs):
+ super().__init__(**kwargs)
+
+ self._instanciate = instanciate
+ self._new_server_cb = new_server_cb
+
+
+ def read(self):
+ conn, addr = self.accept()
+ self.logger.info("Accept new connection from %s", addr)
+
+ fileno = conn.fileno()
+ ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
+ ss.connect = ss._connected
+ self._new_server_cb(ss, autoconnect=True)
+
+ return b''
+
+
+class UnixSocketListener(_Listener, UnixSocket):
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+
+ def connect(self):
+ self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
+
+ try:
+ os.remove(self._socket_path)
+ except FileNotFoundError:
+ pass
+
+ self.bind(self._socket_path)
+ self.listen(5)
+ self.logger.info("Socket ready for accepting new connections")
+
+ self._connected()
def close(self):
@@ -206,25 +153,14 @@ class SocketListener(AbstractServer):
import socket
try:
- self.socket.shutdown(socket.SHUT_RDWR)
- self.socket.close()
- if self.sock_location is not None:
- os.remove(self.sock_location)
+ self.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
- return super().close()
+ super().close()
-
- # Read
-
- def read(self):
- if self.closed:
- return []
-
- conn, addr = self.socket.accept()
- self.nb_son += 1
- ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
- self.new_server_cb(ss)
-
- return []
+ try:
+ if self._socket_path is not None:
+ os.remove(self._socket_path)
+ except:
+ pass
diff --git a/nemubot/tools/human.py b/nemubot/tools/human.py
index a18cde2..f0e947f 100644
--- a/nemubot/tools/human.py
+++ b/nemubot/tools/human.py
@@ -57,11 +57,15 @@ def word_distance(str1, str2):
return d[len(str1)][len(str2)]
-def guess(pattern, expect):
+def guess(pattern, expect, max_depth=0):
+ if max_depth == 0:
+ max_depth = 1 + len(pattern) / 4
+ elif max_depth <= -1:
+ max_depth = len(pattern) - max_depth
if len(expect):
se = sorted([(e, word_distance(pattern, e)) for e in expect], key=lambda x: x[1])
_, m = se[0]
for e, wd in se:
- if wd > m or wd > 1 + len(pattern) / 4:
+ if wd > m or wd > max_depth:
break
yield e