Add type for use with mypy
This commit is contained in:
parent
a8706d6213
commit
39936e9d39
@ -17,12 +17,15 @@
|
|||||||
__version__ = '4.0.dev3'
|
__version__ = '4.0.dev3'
|
||||||
__author__ = 'nemunaire'
|
__author__ = 'nemunaire'
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from nemubot.modulecontext import ModuleContext
|
from nemubot.modulecontext import ModuleContext
|
||||||
|
|
||||||
context = ModuleContext(None, None)
|
context = ModuleContext(None, None)
|
||||||
|
|
||||||
|
|
||||||
def requires_version(min=None, max=None):
|
def requires_version(min: Optional[int] = None,
|
||||||
|
max: Optional[int] = None) -> None:
|
||||||
"""Raise ImportError if the current version is not in the given range
|
"""Raise ImportError if the current version is not in the given range
|
||||||
|
|
||||||
Keyword arguments:
|
Keyword arguments:
|
||||||
@ -39,7 +42,7 @@ def requires_version(min=None, max=None):
|
|||||||
"but this is nemubot v%s." % (str(max), __version__))
|
"but this is nemubot v%s." % (str(max), __version__))
|
||||||
|
|
||||||
|
|
||||||
def attach(pid, socketfile):
|
def attach(pid: int, socketfile: str) -> int:
|
||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@ -98,7 +101,7 @@ def attach(pid, socketfile):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def daemonize():
|
def daemonize() -> None:
|
||||||
"""Detach the running process to run as a daemon
|
"""Detach the running process to run as a daemon
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Nemubot is a smart and modulable IM bot.
|
# Nemubot is a smart and modulable IM bot.
|
||||||
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
|
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
@ -14,7 +14,7 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
@ -152,6 +152,7 @@ def main():
|
|||||||
# Signals handling
|
# Signals handling
|
||||||
def sigtermhandler(signum, frame):
|
def sigtermhandler(signum, frame):
|
||||||
"""On SIGTERM and SIGINT, quit nicely"""
|
"""On SIGTERM and SIGINT, quit nicely"""
|
||||||
|
sigusr1handler(signum, frame)
|
||||||
context.quit()
|
context.quit()
|
||||||
signal.signal(signal.SIGINT, sigtermhandler)
|
signal.signal(signal.SIGINT, sigtermhandler)
|
||||||
signal.signal(signal.SIGTERM, sigtermhandler)
|
signal.signal(signal.SIGTERM, sigtermhandler)
|
||||||
@ -170,17 +171,23 @@ def main():
|
|||||||
|
|
||||||
def sigusr1handler(signum, frame):
|
def sigusr1handler(signum, frame):
|
||||||
"""On SIGHUSR1, display stacktraces"""
|
"""On SIGHUSR1, display stacktraces"""
|
||||||
import traceback
|
import threading, traceback
|
||||||
for threadId, stack in sys._current_frames().items():
|
for threadId, stack in sys._current_frames().items():
|
||||||
logger.debug("########### Thread %d:\n%s",
|
thName = "#%d" % threadId
|
||||||
threadId,
|
for th in threading.enumerate():
|
||||||
|
if th.ident == threadId:
|
||||||
|
thName = th.name
|
||||||
|
break
|
||||||
|
logger.debug("########### Thread %s:\n%s",
|
||||||
|
thName,
|
||||||
"".join(traceback.format_stack(stack)))
|
"".join(traceback.format_stack(stack)))
|
||||||
signal.signal(signal.SIGUSR1, sigusr1handler)
|
signal.signal(signal.SIGUSR1, sigusr1handler)
|
||||||
|
|
||||||
if args.socketfile:
|
if args.socketfile:
|
||||||
from nemubot.server.socket import SocketListener
|
from nemubot.server.socket import UnixSocketListener
|
||||||
context.add_server(SocketListener(context.add_server, "master_socket",
|
context.add_server(UnixSocketListener(new_server_cb=context.add_server,
|
||||||
sock_location=args.socketfile))
|
location=args.socketfile,
|
||||||
|
name="master_socket"))
|
||||||
|
|
||||||
# context can change when performing an hotswap, always join the latest context
|
# context can change when performing an hotswap, always join the latest context
|
||||||
oldcontext = None
|
oldcontext = None
|
||||||
|
190
nemubot/bot.py
190
nemubot/bot.py
@ -15,9 +15,13 @@
|
|||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
|
from multiprocessing import JoinableQueue
|
||||||
import threading
|
import threading
|
||||||
|
import select
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Any, Mapping, Optional, Sequence
|
||||||
|
|
||||||
from nemubot import __version__
|
from nemubot import __version__
|
||||||
from nemubot.consumer import Consumer, EventConsumer, MessageConsumer
|
from nemubot.consumer import Consumer, EventConsumer, MessageConsumer
|
||||||
@ -26,13 +30,23 @@ import nemubot.hooks
|
|||||||
|
|
||||||
logger = logging.getLogger("nemubot")
|
logger = logging.getLogger("nemubot")
|
||||||
|
|
||||||
|
sync_queue = JoinableQueue()
|
||||||
|
|
||||||
|
def sync_act(*args):
|
||||||
|
if isinstance(act, bytes):
|
||||||
|
act = act.decode()
|
||||||
|
sync_queue.put(act)
|
||||||
|
|
||||||
|
|
||||||
class Bot(threading.Thread):
|
class Bot(threading.Thread):
|
||||||
|
|
||||||
"""Class containing the bot context and ensuring key goals"""
|
"""Class containing the bot context and ensuring key goals"""
|
||||||
|
|
||||||
def __init__(self, ip="127.0.0.1", modules_paths=list(),
|
def __init__(self,
|
||||||
data_store=datastore.Abstract(), verbosity=0):
|
ip: Optional[ipaddress] = None,
|
||||||
|
modules_paths: Sequence[str] = list(),
|
||||||
|
data_store: Optional[datastore.Abstract] = None,
|
||||||
|
verbosity: int = 0):
|
||||||
"""Initialize the bot context
|
"""Initialize the bot context
|
||||||
|
|
||||||
Keyword arguments:
|
Keyword arguments:
|
||||||
@ -42,7 +56,7 @@ class Bot(threading.Thread):
|
|||||||
verbosity -- verbosity level
|
verbosity -- verbosity level
|
||||||
"""
|
"""
|
||||||
|
|
||||||
threading.Thread.__init__(self)
|
super().__init__(name="Nemubot main")
|
||||||
|
|
||||||
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
|
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
|
||||||
__version__,
|
__version__,
|
||||||
@ -52,16 +66,16 @@ class Bot(threading.Thread):
|
|||||||
self.stop = None
|
self.stop = None
|
||||||
|
|
||||||
# External IP for accessing this bot
|
# External IP for accessing this bot
|
||||||
import ipaddress
|
self.ip = ip if ip is not None else ipaddress.ip_address("127.0.0.1")
|
||||||
self.ip = ipaddress.ip_address(ip)
|
|
||||||
|
|
||||||
# Context paths
|
# Context paths
|
||||||
self.modules_paths = modules_paths
|
self.modules_paths = modules_paths
|
||||||
self.datastore = data_store
|
self.datastore = data_store if data_store is not None else datastore.Abstract()
|
||||||
self.datastore.open()
|
self.datastore.open()
|
||||||
|
|
||||||
# Keep global context: servers and modules
|
# Keep global context: servers and modules
|
||||||
self.servers = dict()
|
self._poll = select.poll()
|
||||||
|
self.servers = dict() # types: Mapping[str, AbstractServer]
|
||||||
self.modules = dict()
|
self.modules = dict()
|
||||||
self.modules_configuration = dict()
|
self.modules_configuration = dict()
|
||||||
|
|
||||||
@ -138,61 +152,77 @@ class Bot(threading.Thread):
|
|||||||
self.cnsr_queue = Queue()
|
self.cnsr_queue = Queue()
|
||||||
self.cnsr_thrd = list()
|
self.cnsr_thrd = list()
|
||||||
self.cnsr_thrd_size = -1
|
self.cnsr_thrd_size = -1
|
||||||
# Synchrone actions to be treated by main thread
|
|
||||||
self.sync_queue = Queue()
|
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
from select import select
|
self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI)
|
||||||
from nemubot.server import _lock, _rlist, _wlist, _xlist
|
|
||||||
|
|
||||||
logger.info("Starting main loop")
|
logger.info("Starting main loop")
|
||||||
self.stop = False
|
self.stop = False
|
||||||
while not self.stop:
|
while not self.stop:
|
||||||
with _lock:
|
for fd, flag in self._poll.poll():
|
||||||
try:
|
print("poll")
|
||||||
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1)
|
# Handle internal socket passing orders
|
||||||
except:
|
if fd != sync_queue._reader.fileno():
|
||||||
logger.error("Something went wrong in select")
|
srv = self.servers[fd]
|
||||||
fnd_smth = False
|
|
||||||
# Looking for invalid server
|
|
||||||
for r in _rlist:
|
|
||||||
if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
|
|
||||||
_rlist.remove(r)
|
|
||||||
logger.error("Found invalid object in _rlist: " + str(r))
|
|
||||||
fnd_smth = True
|
|
||||||
for w in _wlist:
|
|
||||||
if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
|
|
||||||
_wlist.remove(w)
|
|
||||||
logger.error("Found invalid object in _wlist: " + str(w))
|
|
||||||
fnd_smth = True
|
|
||||||
for x in _xlist:
|
|
||||||
if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
|
|
||||||
_xlist.remove(x)
|
|
||||||
logger.error("Found invalid object in _xlist: " + str(x))
|
|
||||||
fnd_smth = True
|
|
||||||
if not fnd_smth:
|
|
||||||
logger.exception("Can't continue, sorry")
|
|
||||||
self.quit()
|
|
||||||
continue
|
|
||||||
|
|
||||||
for x in xl:
|
if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
|
||||||
try:
|
try:
|
||||||
x.exception()
|
srv.exception(flag)
|
||||||
except:
|
except:
|
||||||
logger.exception("Uncatched exception on server exception")
|
logger.exception("Uncatched exception on server exception")
|
||||||
for w in wl:
|
|
||||||
|
if srv.fileno() > 0:
|
||||||
|
if flag & (select.POLLOUT):
|
||||||
try:
|
try:
|
||||||
w.write_select()
|
srv.async_write()
|
||||||
except:
|
except:
|
||||||
logger.exception("Uncatched exception on server write")
|
logger.exception("Uncatched exception on server write")
|
||||||
for r in rl:
|
|
||||||
for i in r.read():
|
if flag & (select.POLLIN | select.POLLPRI):
|
||||||
try:
|
try:
|
||||||
self.receive_message(r, i)
|
for i in srv.async_read():
|
||||||
|
self.receive_message(srv, i)
|
||||||
except:
|
except:
|
||||||
logger.exception("Uncatched exception on server read")
|
logger.exception("Uncatched exception on server read")
|
||||||
|
|
||||||
|
else:
|
||||||
|
del self.servers[fd]
|
||||||
|
|
||||||
|
|
||||||
|
# Always check the sync queue
|
||||||
|
while not sync_queue.empty():
|
||||||
|
import shlex
|
||||||
|
args = shlex.split(sync_queue.get())
|
||||||
|
action = args.pop(0)
|
||||||
|
|
||||||
|
logger.info("action: %s: %s", action, args)
|
||||||
|
|
||||||
|
if action == "sckt" and len(args) >= 2:
|
||||||
|
try:
|
||||||
|
if args[0] == "write":
|
||||||
|
self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
|
||||||
|
elif args[0] == "unwrite":
|
||||||
|
self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
|
||||||
|
|
||||||
|
elif args[0] == "register":
|
||||||
|
self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
|
||||||
|
elif args[0] == "unregister":
|
||||||
|
self._poll.unregister(int(args[1]))
|
||||||
|
except:
|
||||||
|
logger.exception("Unhandled excpetion during action:")
|
||||||
|
|
||||||
|
elif action == "exit":
|
||||||
|
self.quit()
|
||||||
|
|
||||||
|
elif action == "loadconf":
|
||||||
|
for path in action.args:
|
||||||
|
logger.debug("Load configuration from %s", path)
|
||||||
|
self.load_file(path)
|
||||||
|
logger.info("Configurations successfully loaded")
|
||||||
|
|
||||||
|
sync_queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
# Launch new consumer threads if necessary
|
# Launch new consumer threads if necessary
|
||||||
while self.cnsr_queue.qsize() > self.cnsr_thrd_size:
|
while self.cnsr_queue.qsize() > self.cnsr_thrd_size:
|
||||||
@ -202,17 +232,6 @@ class Bot(threading.Thread):
|
|||||||
c = Consumer(self)
|
c = Consumer(self)
|
||||||
self.cnsr_thrd.append(c)
|
self.cnsr_thrd.append(c)
|
||||||
c.start()
|
c.start()
|
||||||
|
|
||||||
while self.sync_queue.qsize() > 0:
|
|
||||||
action = self.sync_queue.get_nowait()
|
|
||||||
if action[0] == "exit":
|
|
||||||
self.quit()
|
|
||||||
elif action[0] == "loadconf":
|
|
||||||
for path in action[1:]:
|
|
||||||
logger.debug("Load configuration from %s", path)
|
|
||||||
self.load_file(path)
|
|
||||||
logger.info("Configurations successfully loaded")
|
|
||||||
self.sync_queue.task_done()
|
|
||||||
logger.info("Ending main loop")
|
logger.info("Ending main loop")
|
||||||
|
|
||||||
|
|
||||||
@ -414,10 +433,11 @@ class Bot(threading.Thread):
|
|||||||
autoconnect -- connect after add?
|
autoconnect -- connect after add?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if srv.fileno not in self.servers:
|
fileno = srv.fileno()
|
||||||
self.servers[srv.fileno] = srv
|
if fileno not in self.servers:
|
||||||
|
self.servers[fileno] = srv
|
||||||
if autoconnect and not hasattr(self, "noautoconnect"):
|
if autoconnect and not hasattr(self, "noautoconnect"):
|
||||||
srv.open()
|
srv.connect()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -439,10 +459,10 @@ class Bot(threading.Thread):
|
|||||||
__import__(name)
|
__import__(name)
|
||||||
|
|
||||||
|
|
||||||
def add_module(self, module):
|
def add_module(self, mdl: Any):
|
||||||
"""Add a module to the context, if already exists, unload the
|
"""Add a module to the context, if already exists, unload the
|
||||||
old one before"""
|
old one before"""
|
||||||
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
|
module_name = mdl.__spec__.name if hasattr(mdl, "__spec__") else mdl.__name__
|
||||||
|
|
||||||
if hasattr(self, "stop") and self.stop:
|
if hasattr(self, "stop") and self.stop:
|
||||||
logger.warn("The bot is stopped, can't register new modules")
|
logger.warn("The bot is stopped, can't register new modules")
|
||||||
@ -454,40 +474,40 @@ class Bot(threading.Thread):
|
|||||||
|
|
||||||
# Overwrite print built-in
|
# Overwrite print built-in
|
||||||
def prnt(*args):
|
def prnt(*args):
|
||||||
if hasattr(module, "logger"):
|
if hasattr(mdl, "logger"):
|
||||||
module.logger.info(" ".join([str(s) for s in args]))
|
mdl.logger.info(" ".join([str(s) for s in args]))
|
||||||
else:
|
else:
|
||||||
logger.info("[%s] %s", module_name, " ".join([str(s) for s in args]))
|
logger.info("[%s] %s", module_name, " ".join([str(s) for s in args]))
|
||||||
module.print = prnt
|
mdl.print = prnt
|
||||||
|
|
||||||
# Create module context
|
# Create module context
|
||||||
from nemubot.modulecontext import ModuleContext
|
from nemubot.modulecontext import ModuleContext
|
||||||
module.__nemubot_context__ = ModuleContext(self, module)
|
mdl.__nemubot_context__ = ModuleContext(self, mdl)
|
||||||
|
|
||||||
if not hasattr(module, "logger"):
|
if not hasattr(mdl, "logger"):
|
||||||
module.logger = logging.getLogger("nemubot.module." + module_name)
|
mdl.logger = logging.getLogger("nemubot.module." + module_name)
|
||||||
|
|
||||||
# Replace imported context by real one
|
# Replace imported context by real one
|
||||||
for attr in module.__dict__:
|
for attr in mdl.__dict__:
|
||||||
if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext:
|
if attr != "__nemubot_context__" and type(mdl.__dict__[attr]) == ModuleContext:
|
||||||
module.__dict__[attr] = module.__nemubot_context__
|
mdl.__dict__[attr] = mdl.__nemubot_context__
|
||||||
|
|
||||||
# Register decorated functions
|
# Register decorated functions
|
||||||
import nemubot.hooks
|
import nemubot.hooks
|
||||||
for s, h in nemubot.hooks.hook.last_registered:
|
for s, h in nemubot.hooks.hook.last_registered:
|
||||||
module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
|
mdl.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
|
||||||
nemubot.hooks.hook.last_registered = []
|
nemubot.hooks.hook.last_registered = []
|
||||||
|
|
||||||
# Launch the module
|
# Launch the module
|
||||||
if hasattr(module, "load"):
|
if hasattr(mdl, "load"):
|
||||||
try:
|
try:
|
||||||
module.load(module.__nemubot_context__)
|
mdl.load(mdl.__nemubot_context__)
|
||||||
except:
|
except:
|
||||||
module.__nemubot_context__.unload()
|
mdl.__nemubot_context__.unload()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Save a reference to the module
|
# Save a reference to the module
|
||||||
self.modules[module_name] = module
|
self.modules[module_name] = mdl
|
||||||
|
|
||||||
|
|
||||||
def unload_module(self, name):
|
def unload_module(self, name):
|
||||||
@ -530,28 +550,28 @@ class Bot(threading.Thread):
|
|||||||
def quit(self):
|
def quit(self):
|
||||||
"""Save and unload modules and disconnect servers"""
|
"""Save and unload modules and disconnect servers"""
|
||||||
|
|
||||||
self.datastore.close()
|
|
||||||
|
|
||||||
if self.event_timer is not None:
|
if self.event_timer is not None:
|
||||||
logger.info("Stop the event timer...")
|
logger.info("Stop the event timer...")
|
||||||
self.event_timer.cancel()
|
self.event_timer.cancel()
|
||||||
|
|
||||||
|
logger.info("Save and unload all modules...")
|
||||||
|
for mod in self.modules.items():
|
||||||
|
self.unload_module(mod)
|
||||||
|
|
||||||
|
logger.info("Close all servers connection...")
|
||||||
|
for k in self.servers:
|
||||||
|
self.servers[k].close()
|
||||||
|
|
||||||
logger.info("Stop consumers")
|
logger.info("Stop consumers")
|
||||||
k = self.cnsr_thrd
|
k = self.cnsr_thrd
|
||||||
for cnsr in k:
|
for cnsr in k:
|
||||||
cnsr.stop = True
|
cnsr.stop = True
|
||||||
|
|
||||||
logger.info("Save and unload all modules...")
|
self.datastore.close()
|
||||||
k = list(self.modules.keys())
|
|
||||||
for mod in k:
|
|
||||||
self.unload_module(mod)
|
|
||||||
|
|
||||||
logger.info("Close all servers connection...")
|
|
||||||
k = list(self.servers.keys())
|
|
||||||
for srv in k:
|
|
||||||
self.servers[srv].close()
|
|
||||||
|
|
||||||
self.stop = True
|
self.stop = True
|
||||||
|
sync_queue.put("end")
|
||||||
|
sync_queue.join()
|
||||||
|
|
||||||
|
|
||||||
# Treatment
|
# Treatment
|
||||||
|
@ -15,13 +15,18 @@
|
|||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from nemubot.message import Abstract as AbstractMessage
|
||||||
|
|
||||||
class Channel:
|
class Channel:
|
||||||
|
|
||||||
"""A chat room"""
|
"""A chat room"""
|
||||||
|
|
||||||
def __init__(self, name, password=None, encoding=None):
|
def __init__(self,
|
||||||
|
name: str,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
encoding: Optional[str] = None):
|
||||||
"""Initialize the channel
|
"""Initialize the channel
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -37,7 +42,8 @@ class Channel:
|
|||||||
self.topic = ""
|
self.topic = ""
|
||||||
self.logger = logging.getLogger("nemubot.channel." + name)
|
self.logger = logging.getLogger("nemubot.channel." + name)
|
||||||
|
|
||||||
def treat(self, cmd, msg):
|
|
||||||
|
def treat(self, cmd: str, msg: AbstractMessage) -> None:
|
||||||
"""Treat a incoming IRC command
|
"""Treat a incoming IRC command
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -60,7 +66,8 @@ class Channel:
|
|||||||
elif cmd == "TOPIC":
|
elif cmd == "TOPIC":
|
||||||
self.topic = self.text
|
self.topic = self.text
|
||||||
|
|
||||||
def join(self, nick, level=0):
|
|
||||||
|
def join(self, nick: str, level: int = 0) -> None:
|
||||||
"""Someone join the channel
|
"""Someone join the channel
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
@ -71,7 +78,8 @@ class Channel:
|
|||||||
self.logger.debug("%s join", nick)
|
self.logger.debug("%s join", nick)
|
||||||
self.people[nick] = level
|
self.people[nick] = level
|
||||||
|
|
||||||
def chtopic(self, newtopic):
|
|
||||||
|
def chtopic(self, newtopic: str) -> None:
|
||||||
"""Send command to change the topic
|
"""Send command to change the topic
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -81,7 +89,8 @@ class Channel:
|
|||||||
self.srv.send_msg(self.name, newtopic, "TOPIC")
|
self.srv.send_msg(self.name, newtopic, "TOPIC")
|
||||||
self.topic = newtopic
|
self.topic = newtopic
|
||||||
|
|
||||||
def nick(self, oldnick, newnick):
|
|
||||||
|
def nick(self, oldnick: str, newnick: str) -> None:
|
||||||
"""Someone change his nick
|
"""Someone change his nick
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -95,7 +104,8 @@ class Channel:
|
|||||||
del self.people[oldnick]
|
del self.people[oldnick]
|
||||||
self.people[newnick] = lvl
|
self.people[newnick] = lvl
|
||||||
|
|
||||||
def part(self, nick):
|
|
||||||
|
def part(self, nick: str) -> None:
|
||||||
"""Someone leave the channel
|
"""Someone leave the channel
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
@ -106,7 +116,8 @@ class Channel:
|
|||||||
self.logger.debug("%s has left", nick)
|
self.logger.debug("%s has left", nick)
|
||||||
del self.people[nick]
|
del self.people[nick]
|
||||||
|
|
||||||
def mode(self, msg):
|
|
||||||
|
def mode(self, msg: AbstractMessage) -> None:
|
||||||
"""Channel or user mode change
|
"""Channel or user mode change
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
@ -132,7 +143,8 @@ class Channel:
|
|||||||
elif msg.text[0] == "-v":
|
elif msg.text[0] == "-v":
|
||||||
self.people[msg.nick] &= ~1
|
self.people[msg.nick] &= ~1
|
||||||
|
|
||||||
def parse332(self, msg):
|
|
||||||
|
def parse332(self, msg: AbstractMessage) -> None:
|
||||||
"""Parse RPL_TOPIC message
|
"""Parse RPL_TOPIC message
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
@ -141,7 +153,8 @@ class Channel:
|
|||||||
|
|
||||||
self.topic = msg.text
|
self.topic = msg.text
|
||||||
|
|
||||||
def parse353(self, msg):
|
|
||||||
|
def parse353(self, msg: AbstractMessage) -> None:
|
||||||
"""Parse RPL_ENDOFWHO message
|
"""Parse RPL_ENDOFWHO message
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
def get_boolean(s):
|
def get_boolean(s) -> bool:
|
||||||
if isinstance(s, bool):
|
if isinstance(s, bool):
|
||||||
return s
|
return s
|
||||||
else:
|
else:
|
||||||
|
@ -16,5 +16,5 @@
|
|||||||
|
|
||||||
class Include:
|
class Include:
|
||||||
|
|
||||||
def __init__(self, path):
|
def __init__(self, path: str):
|
||||||
self.path = path
|
self.path = path
|
||||||
|
@ -20,7 +20,11 @@ from nemubot.datastore.nodes.generic import GenericNode
|
|||||||
|
|
||||||
class Module(GenericNode):
|
class Module(GenericNode):
|
||||||
|
|
||||||
def __init__(self, name, autoload=True, **kwargs):
|
def __init__(self,
|
||||||
|
name: str,
|
||||||
|
autoload: bool = True,
|
||||||
|
**kwargs):
|
||||||
super().__init__(None, **kwargs)
|
super().__init__(None, **kwargs)
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.autoload = get_boolean(autoload)
|
self.autoload = get_boolean(autoload)
|
||||||
|
@ -14,15 +14,23 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from nemubot.config.include import Include
|
from typing import Optional, Sequence, Union
|
||||||
from nemubot.config.module import Module
|
|
||||||
from nemubot.config.server import Server
|
import nemubot.config.include
|
||||||
|
import nemubot.config.module
|
||||||
|
import nemubot.config.server
|
||||||
|
|
||||||
|
|
||||||
class Nemubot:
|
class Nemubot:
|
||||||
|
|
||||||
def __init__(self, nick="nemubot", realname="nemubot", owner=None,
|
def __init__(self,
|
||||||
ip=None, ssl=False, caps=None, encoding="utf-8"):
|
nick: str = "nemubot",
|
||||||
|
realname: str = "nemubot",
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
ip: Optional[str] = None,
|
||||||
|
ssl: bool = False,
|
||||||
|
caps: Optional[Sequence[str]] = None,
|
||||||
|
encoding: str = "utf-8"):
|
||||||
self.nick = nick
|
self.nick = nick
|
||||||
self.realname = realname
|
self.realname = realname
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
@ -34,13 +42,13 @@ class Nemubot:
|
|||||||
self.includes = []
|
self.includes = []
|
||||||
|
|
||||||
|
|
||||||
def addChild(self, name, child):
|
def addChild(self, name: str, child: Union[nemubot.config.module.Module, nemubot.config.server.Server, nemubot.config.include.Include]):
|
||||||
if name == "module" and isinstance(child, Module):
|
if name == "module" and isinstance(child, nemubot.config.module.Module):
|
||||||
self.modules.append(child)
|
self.modules.append(child)
|
||||||
return True
|
return True
|
||||||
elif name == "server" and isinstance(child, Server):
|
elif name == "server" and isinstance(child, nemubot.config.server.Server):
|
||||||
self.servers.append(child)
|
self.servers.append(child)
|
||||||
return True
|
return True
|
||||||
elif name == "include" and isinstance(child, Include):
|
elif name == "include" and isinstance(child, nemubot.config.include.Include):
|
||||||
self.includes.append(child)
|
self.includes.append(child)
|
||||||
return True
|
return True
|
||||||
|
@ -14,12 +14,19 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
from nemubot.channel import Channel
|
from nemubot.channel import Channel
|
||||||
|
import nemubot.config.nemubot
|
||||||
|
|
||||||
|
|
||||||
class Server:
|
class Server:
|
||||||
|
|
||||||
def __init__(self, uri="irc://nemubot@localhost/", autoconnect=True, caps=None, **kwargs):
|
def __init__(self,
|
||||||
|
uri: str = "irc://nemubot@localhost/",
|
||||||
|
autoconnect: bool = True,
|
||||||
|
caps: Optional[Sequence[str]] = None,
|
||||||
|
**kwargs):
|
||||||
self.uri = uri
|
self.uri = uri
|
||||||
self.autoconnect = autoconnect
|
self.autoconnect = autoconnect
|
||||||
self.caps = caps.split(" ") if caps is not None else []
|
self.caps = caps.split(" ") if caps is not None else []
|
||||||
@ -27,7 +34,7 @@ class Server:
|
|||||||
self.channels = []
|
self.channels = []
|
||||||
|
|
||||||
|
|
||||||
def addChild(self, name, child):
|
def addChild(self, name: str, child: Channel):
|
||||||
if name == "channel" and isinstance(child, Channel):
|
if name == "channel" and isinstance(child, Channel):
|
||||||
self.channels.append(child)
|
self.channels.append(child)
|
||||||
return True
|
return True
|
||||||
|
@ -17,6 +17,12 @@
|
|||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from nemubot.bot import Bot
|
||||||
|
from nemubot.event import ModuleEvent
|
||||||
|
from nemubot.message.abstract import Abstract as AbstractMessage
|
||||||
|
from nemubot.server.abstract import AbstractServer
|
||||||
|
|
||||||
logger = logging.getLogger("nemubot.consumer")
|
logger = logging.getLogger("nemubot.consumer")
|
||||||
|
|
||||||
@ -25,18 +31,15 @@ class MessageConsumer:
|
|||||||
|
|
||||||
"""Store a message before treating"""
|
"""Store a message before treating"""
|
||||||
|
|
||||||
def __init__(self, srv, msg):
|
def __init__(self, srv: AbstractServer, msg: AbstractMessage):
|
||||||
self.srv = srv
|
self.srv = srv
|
||||||
self.orig = msg
|
self.orig = msg
|
||||||
|
|
||||||
|
|
||||||
def run(self, context):
|
def run(self, context: Bot) -> None:
|
||||||
"""Create, parse and treat the message"""
|
"""Create, parse and treat the message"""
|
||||||
|
|
||||||
from nemubot.bot import Bot
|
msgs = [] # type: List[AbstractMessage]
|
||||||
assert isinstance(context, Bot)
|
|
||||||
|
|
||||||
msgs = []
|
|
||||||
|
|
||||||
# Parse the message
|
# Parse the message
|
||||||
try:
|
try:
|
||||||
@ -55,8 +58,6 @@ class MessageConsumer:
|
|||||||
if hasattr(msg, "frm_owner"):
|
if hasattr(msg, "frm_owner"):
|
||||||
msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm)
|
msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm)
|
||||||
|
|
||||||
from nemubot.server.abstract import AbstractServer
|
|
||||||
|
|
||||||
# Treat the message
|
# Treat the message
|
||||||
for msg in msgs:
|
for msg in msgs:
|
||||||
for res in context.treater.treat_msg(msg):
|
for res in context.treater.treat_msg(msg):
|
||||||
@ -87,12 +88,12 @@ class EventConsumer:
|
|||||||
|
|
||||||
"""Store a event before treating"""
|
"""Store a event before treating"""
|
||||||
|
|
||||||
def __init__(self, evt, timeout=20):
|
def __init__(self, evt: ModuleEvent, timeout: int = 20):
|
||||||
self.evt = evt
|
self.evt = evt
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
|
|
||||||
def run(self, context):
|
def run(self, context: Bot) -> None:
|
||||||
try:
|
try:
|
||||||
self.evt.check()
|
self.evt.check()
|
||||||
except:
|
except:
|
||||||
@ -113,13 +114,13 @@ class Consumer(threading.Thread):
|
|||||||
|
|
||||||
"""Dequeue and exec requested action"""
|
"""Dequeue and exec requested action"""
|
||||||
|
|
||||||
def __init__(self, context):
|
def __init__(self, context: Bot):
|
||||||
self.context = context
|
self.context = context
|
||||||
self.stop = False
|
self.stop = False
|
||||||
super().__init__(name="Nemubot consumer")
|
super().__init__(name="Nemubot consumer")
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self) -> None:
|
||||||
try:
|
try:
|
||||||
while not self.stop:
|
while not self.stop:
|
||||||
stm = self.context.cnsr_queue.get(True, 1)
|
stm = self.context.cnsr_queue.get(True, 1)
|
||||||
|
@ -25,11 +25,14 @@ class Abstract:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def open(self):
|
|
||||||
return
|
|
||||||
|
|
||||||
def close(self):
|
def open(self) -> bool:
|
||||||
return
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def close(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def load(self, module):
|
def load(self, module):
|
||||||
"""Load data for the given module
|
"""Load data for the given module
|
||||||
@ -43,6 +46,7 @@ class Abstract:
|
|||||||
|
|
||||||
return self.new()
|
return self.new()
|
||||||
|
|
||||||
|
|
||||||
def save(self, module, data):
|
def save(self, module, data):
|
||||||
"""Load data for the given module
|
"""Load data for the given module
|
||||||
|
|
||||||
@ -56,9 +60,11 @@ class Abstract:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.open()
|
self.open()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
self.close()
|
self.close()
|
||||||
|
@ -14,6 +14,9 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
from typing import Any, Mapping, Sequence
|
||||||
|
|
||||||
|
from nemubot.datastore.nodes.generic import ParsingNode
|
||||||
from nemubot.datastore.nodes.serializable import Serializable
|
from nemubot.datastore.nodes.serializable import Serializable
|
||||||
|
|
||||||
|
|
||||||
@ -25,24 +28,24 @@ class ListNode(Serializable):
|
|||||||
serializetag = "list"
|
serializetag = "list"
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.items = list()
|
self.items = list() # type: Sequence
|
||||||
|
|
||||||
|
|
||||||
def addChild(self, name, child):
|
def addChild(self, name: str, child) -> bool:
|
||||||
self.items.append(child)
|
self.items.append(child)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def parsedForm(self):
|
def parsedForm(self) -> Sequence:
|
||||||
return self.items
|
return self.items
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return len(self.items)
|
return len(self.items)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item: int) -> Any:
|
||||||
return self.items[item]
|
return self.items[item]
|
||||||
|
|
||||||
def __setitem__(self, item, v):
|
def __setitem__(self, item: int, v: Any) -> None:
|
||||||
self.items[item] = v
|
self.items[item] = v
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item):
|
||||||
@ -52,8 +55,7 @@ class ListNode(Serializable):
|
|||||||
return self.items.__repr__()
|
return self.items.__repr__()
|
||||||
|
|
||||||
|
|
||||||
def serialize(self):
|
def serialize(self) -> ParsingNode:
|
||||||
from nemubot.datastore.nodes.generic import ParsingNode
|
|
||||||
node = ParsingNode(tag=self.serializetag)
|
node = ParsingNode(tag=self.serializetag)
|
||||||
for i in self.items:
|
for i in self.items:
|
||||||
node.children.append(ParsingNode.serialize_node(i))
|
node.children.append(ParsingNode.serialize_node(i))
|
||||||
@ -72,12 +74,12 @@ class DictNode(Serializable):
|
|||||||
self._cur = None
|
self._cur = None
|
||||||
|
|
||||||
|
|
||||||
def startElement(self, name, attrs):
|
def startElement(self, name: str, attrs: Mapping[str, str]):
|
||||||
if self._cur is None and "key" in attrs:
|
if self._cur is None and "key" in attrs:
|
||||||
self._cur = attrs["key"]
|
self._cur = attrs["key"]
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def addChild(self, name, child):
|
def addChild(self, name: str, child: Any):
|
||||||
if self._cur is None:
|
if self._cur is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -85,24 +87,24 @@ class DictNode(Serializable):
|
|||||||
self._cur = None
|
self._cur = None
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def parsedForm(self):
|
def parsedForm(self) -> Mapping:
|
||||||
return self.items
|
return self.items
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item: str) -> Any:
|
||||||
return self.items[item]
|
return self.items[item]
|
||||||
|
|
||||||
def __setitem__(self, item, v):
|
def __setitem__(self, item: str, v: str) -> None:
|
||||||
self.items[item] = v
|
self.items[item] = v
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item: str) -> bool:
|
||||||
return item in self.items
|
return item in self.items
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return self.items.__repr__()
|
return self.items.__repr__()
|
||||||
|
|
||||||
|
|
||||||
def serialize(self):
|
def serialize(self) -> ParsingNode:
|
||||||
from nemubot.datastore.nodes.generic import ParsingNode
|
from nemubot.datastore.nodes.generic import ParsingNode
|
||||||
node = ParsingNode(tag=self.serializetag)
|
node = ParsingNode(tag=self.serializetag)
|
||||||
for k in self.items:
|
for k in self.items:
|
||||||
|
@ -14,6 +14,8 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
from typing import Any, Optional, Mapping, Union
|
||||||
|
|
||||||
from nemubot.datastore.nodes.serializable import Serializable
|
from nemubot.datastore.nodes.serializable import Serializable
|
||||||
|
|
||||||
|
|
||||||
@ -22,41 +24,44 @@ class ParsingNode:
|
|||||||
"""Allow any kind of subtags, just keep parsed ones
|
"""Allow any kind of subtags, just keep parsed ones
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tag=None, **kwargs):
|
def __init__(self,
|
||||||
|
tag: Optional[str] = None,
|
||||||
|
**kwargs):
|
||||||
self.tag = tag
|
self.tag = tag
|
||||||
self.attrs = kwargs
|
self.attrs = kwargs
|
||||||
self.content = ""
|
self.content = ""
|
||||||
self.children = []
|
self.children = []
|
||||||
|
|
||||||
|
|
||||||
def characters(self, content):
|
def characters(self, content: str) -> None:
|
||||||
self.content += content
|
self.content += content
|
||||||
|
|
||||||
|
|
||||||
def addChild(self, name, child):
|
def addChild(self, name: str, child: Any) -> bool:
|
||||||
self.children.append(child)
|
self.children.append(child)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def hasNode(self, nodename):
|
def hasNode(self, nodename: str) -> bool:
|
||||||
return self.getNode(nodename) is not None
|
return self.getNode(nodename) is not None
|
||||||
|
|
||||||
|
|
||||||
def getNode(self, nodename):
|
def getNode(self, nodename: str) -> Optional[Any]:
|
||||||
for c in self.children:
|
for c in self.children:
|
||||||
if c is not None and c.tag == nodename:
|
if c is not None and c.tag == nodename:
|
||||||
return c
|
return c
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item: str) -> Any:
|
||||||
return self.attrs[item]
|
return self.attrs[item]
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item: str) -> bool:
|
||||||
return item in self.attrs
|
return item in self.attrs
|
||||||
|
|
||||||
|
|
||||||
def serialize_node(node, **def_kwargs):
|
def serialize_node(node: Union[Serializable, str, int, float, list, dict],
|
||||||
|
**def_kwargs):
|
||||||
"""Serialize any node or basic data to a ParsingNode instance"""
|
"""Serialize any node or basic data to a ParsingNode instance"""
|
||||||
|
|
||||||
if isinstance(node, Serializable):
|
if isinstance(node, Serializable):
|
||||||
@ -102,13 +107,16 @@ class GenericNode(ParsingNode):
|
|||||||
"""Consider all subtags as dictionnary
|
"""Consider all subtags as dictionnary
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tag, **kwargs):
|
def __init__(self,
|
||||||
|
tag: str,
|
||||||
|
**kwargs):
|
||||||
super().__init__(tag, **kwargs)
|
super().__init__(tag, **kwargs)
|
||||||
|
|
||||||
self._cur = None
|
self._cur = None
|
||||||
self._deep_cur = 0
|
self._deep_cur = 0
|
||||||
|
|
||||||
|
|
||||||
def startElement(self, name, attrs):
|
def startElement(self, name: str, attrs: Mapping[str, str]):
|
||||||
if self._cur is None:
|
if self._cur is None:
|
||||||
self._cur = GenericNode(name, **attrs)
|
self._cur = GenericNode(name, **attrs)
|
||||||
self._deep_cur = 0
|
self._deep_cur = 0
|
||||||
@ -118,14 +126,14 @@ class GenericNode(ParsingNode):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def characters(self, content):
|
def characters(self, content: str):
|
||||||
if self._cur is None:
|
if self._cur is None:
|
||||||
super().characters(content)
|
super().characters(content)
|
||||||
else:
|
else:
|
||||||
self._cur.characters(content)
|
self._cur.characters(content)
|
||||||
|
|
||||||
|
|
||||||
def endElement(self, name):
|
def endElement(self, name: str):
|
||||||
if name is None:
|
if name is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
from nemubot.datastore.nodes.generic import ParsingNode
|
||||||
from nemubot.datastore.nodes.serializable import Serializable
|
from nemubot.datastore.nodes.serializable import Serializable
|
||||||
|
|
||||||
|
|
||||||
@ -27,15 +28,15 @@ class PythonTypeNode(Serializable):
|
|||||||
self._cnt = ""
|
self._cnt = ""
|
||||||
|
|
||||||
|
|
||||||
def characters(self, content):
|
def characters(self, content: str) -> None:
|
||||||
self._cnt += content
|
self._cnt += content
|
||||||
|
|
||||||
|
|
||||||
def endElement(self, name):
|
def endElement(self, name: str) -> None:
|
||||||
raise NotImplemented
|
raise NotImplemented
|
||||||
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return self.value.__repr__()
|
return self.value.__repr__()
|
||||||
|
|
||||||
|
|
||||||
@ -50,12 +51,11 @@ class IntNode(PythonTypeNode):
|
|||||||
|
|
||||||
serializetag = "int"
|
serializetag = "int"
|
||||||
|
|
||||||
def endElement(self, name):
|
def endElement(self, name: str) -> bool:
|
||||||
self.value = int(self._cnt)
|
self.value = int(self._cnt)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def serialize(self):
|
def serialize(self) -> ParsingNode:
|
||||||
from nemubot.datastore.nodes.generic import ParsingNode
|
|
||||||
node = ParsingNode(tag=self.serializetag)
|
node = ParsingNode(tag=self.serializetag)
|
||||||
node.content = str(self.value)
|
node.content = str(self.value)
|
||||||
return node
|
return node
|
||||||
@ -65,12 +65,11 @@ class FloatNode(PythonTypeNode):
|
|||||||
|
|
||||||
serializetag = "float"
|
serializetag = "float"
|
||||||
|
|
||||||
def endElement(self, name):
|
def endElement(self, name: str) -> bool:
|
||||||
self.value = float(self._cnt)
|
self.value = float(self._cnt)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def serialize(self):
|
def serialize(self) -> ParsingNode:
|
||||||
from nemubot.datastore.nodes.generic import ParsingNode
|
|
||||||
node = ParsingNode(tag=self.serializetag)
|
node = ParsingNode(tag=self.serializetag)
|
||||||
node.content = str(self.value)
|
node.content = str(self.value)
|
||||||
return node
|
return node
|
||||||
@ -85,7 +84,6 @@ class StringNode(PythonTypeNode):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def serialize(self):
|
def serialize(self):
|
||||||
from nemubot.datastore.nodes.generic import ParsingNode
|
|
||||||
node = ParsingNode(tag=self.serializetag)
|
node = ParsingNode(tag=self.serializetag)
|
||||||
node.content = str(self.value)
|
node.content = str(self.value)
|
||||||
return node
|
return node
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
import fcntl
|
import fcntl
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, Mapping
|
||||||
import xml.parsers.expat
|
import xml.parsers.expat
|
||||||
|
|
||||||
from nemubot.datastore.abstract import Abstract
|
from nemubot.datastore.abstract import Abstract
|
||||||
@ -28,7 +29,9 @@ class XML(Abstract):
|
|||||||
|
|
||||||
"""A concrete implementation of a data store that relies on XML files"""
|
"""A concrete implementation of a data store that relies on XML files"""
|
||||||
|
|
||||||
def __init__(self, basedir, rotate=True):
|
def __init__(self,
|
||||||
|
basedir: str,
|
||||||
|
rotate: bool = True):
|
||||||
"""Initialize the datastore
|
"""Initialize the datastore
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -45,7 +48,7 @@ class XML(Abstract):
|
|||||||
"enabled" if self.rotate else "disabled")
|
"enabled" if self.rotate else "disabled")
|
||||||
|
|
||||||
|
|
||||||
def open(self):
|
def open(self) -> bool:
|
||||||
"""Lock the directory"""
|
"""Lock the directory"""
|
||||||
|
|
||||||
if not os.path.isdir(self.basedir):
|
if not os.path.isdir(self.basedir):
|
||||||
@ -75,7 +78,7 @@ class XML(Abstract):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> bool:
|
||||||
"""Release a locked path"""
|
"""Release a locked path"""
|
||||||
|
|
||||||
if hasattr(self, "lock_file"):
|
if hasattr(self, "lock_file"):
|
||||||
@ -91,19 +94,19 @@ class XML(Abstract):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _get_data_file_path(self, module):
|
def _get_data_file_path(self, module: str) -> str:
|
||||||
"""Get the path to the module data file"""
|
"""Get the path to the module data file"""
|
||||||
|
|
||||||
return os.path.join(self.basedir, module + ".xml")
|
return os.path.join(self.basedir, module + ".xml")
|
||||||
|
|
||||||
|
|
||||||
def _get_lock_file_path(self):
|
def _get_lock_file_path(self) -> str:
|
||||||
"""Get the path to the datastore lock file"""
|
"""Get the path to the datastore lock file"""
|
||||||
|
|
||||||
return os.path.join(self.basedir, ".used_by_nemubot")
|
return os.path.join(self.basedir, ".used_by_nemubot")
|
||||||
|
|
||||||
|
|
||||||
def load(self, module, extendsTags={}):
|
def load(self, module: str, extendsTags: Mapping[str, Any] = {}) -> Abstract:
|
||||||
"""Load data for the given module
|
"""Load data for the given module
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
@ -116,7 +119,7 @@ class XML(Abstract):
|
|||||||
|
|
||||||
data_file = self._get_data_file_path(module)
|
data_file = self._get_data_file_path(module)
|
||||||
|
|
||||||
def parse(path):
|
def parse(path: str):
|
||||||
from nemubot.tools.xmlparser import XMLParser
|
from nemubot.tools.xmlparser import XMLParser
|
||||||
from nemubot.datastore.nodes import basic as basicNodes
|
from nemubot.datastore.nodes import basic as basicNodes
|
||||||
from nemubot.datastore.nodes import python as pythonNodes
|
from nemubot.datastore.nodes import python as pythonNodes
|
||||||
@ -156,7 +159,7 @@ class XML(Abstract):
|
|||||||
return Abstract.load(self, module)
|
return Abstract.load(self, module)
|
||||||
|
|
||||||
|
|
||||||
def _rotate(self, path):
|
def _rotate(self, path: str) -> None:
|
||||||
"""Backup given path
|
"""Backup given path
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
@ -173,7 +176,7 @@ class XML(Abstract):
|
|||||||
os.rename(src, dst)
|
os.rename(src, dst)
|
||||||
|
|
||||||
|
|
||||||
def _save_node(self, gen, node):
|
def _save_node(self, gen, node: Any):
|
||||||
from nemubot.datastore.nodes.generic import ParsingNode
|
from nemubot.datastore.nodes.generic import ParsingNode
|
||||||
|
|
||||||
# First, get the serialized form of the node
|
# First, get the serialized form of the node
|
||||||
@ -191,7 +194,7 @@ class XML(Abstract):
|
|||||||
gen.endElement(node.tag)
|
gen.endElement(node.tag)
|
||||||
|
|
||||||
|
|
||||||
def save(self, module, data):
|
def save(self, module: str, data: Any) -> bool:
|
||||||
"""Load data for the given module
|
"""Load data for the given module
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
|
@ -15,15 +15,23 @@
|
|||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
class ModuleEvent:
|
class ModuleEvent:
|
||||||
|
|
||||||
"""Representation of a event initiated by a bot module"""
|
"""Representation of a event initiated by a bot module"""
|
||||||
|
|
||||||
def __init__(self, call=None, call_data=None, func=None, func_data=None,
|
def __init__(self,
|
||||||
cmp=None, cmp_data=None, interval=60, offset=0, times=1):
|
call: Callable = None,
|
||||||
|
call_data: Callable = None,
|
||||||
|
func: Callable = None,
|
||||||
|
func_data: Any = None,
|
||||||
|
cmp: Any = None,
|
||||||
|
cmp_data: Any = None,
|
||||||
|
interval: int = 60,
|
||||||
|
offset: int = 0,
|
||||||
|
times: int = 1):
|
||||||
"""Initialize the event
|
"""Initialize the event
|
||||||
|
|
||||||
Keyword arguments:
|
Keyword arguments:
|
||||||
@ -70,8 +78,9 @@ class ModuleEvent:
|
|||||||
# How many times do this event?
|
# How many times do this event?
|
||||||
self.times = times
|
self.times = times
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current(self):
|
def current(self) -> Optional[datetime.datetime]:
|
||||||
"""Return the date of the near check"""
|
"""Return the date of the near check"""
|
||||||
if self.times != 0:
|
if self.times != 0:
|
||||||
if self._end is None:
|
if self._end is None:
|
||||||
@ -79,8 +88,9 @@ class ModuleEvent:
|
|||||||
return self._end
|
return self._end
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def next(self):
|
def next(self) -> Optional[datetime.datetime]:
|
||||||
"""Return the date of the next check"""
|
"""Return the date of the next check"""
|
||||||
if self.times != 0:
|
if self.times != 0:
|
||||||
if self._end is None:
|
if self._end is None:
|
||||||
@ -90,14 +100,16 @@ class ModuleEvent:
|
|||||||
return self._end
|
return self._end
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def time_left(self):
|
def time_left(self) -> Union[datetime.datetime, int]:
|
||||||
"""Return the time left before/after the near check"""
|
"""Return the time left before/after the near check"""
|
||||||
if self.current is not None:
|
if self.current is not None:
|
||||||
return self.current - datetime.now(timezone.utc)
|
return self.current - datetime.now(timezone.utc)
|
||||||
return 99999 # TODO: 99999 is not a valid time to return
|
return 99999 # TODO: 99999 is not a valid time to return
|
||||||
|
|
||||||
def check(self):
|
|
||||||
|
def check(self) -> None:
|
||||||
"""Run a check and realized the event if this is time"""
|
"""Run a check and realized the event if this is time"""
|
||||||
|
|
||||||
# Get initial data
|
# Get initial data
|
||||||
|
@ -21,7 +21,7 @@ class HooksManager:
|
|||||||
|
|
||||||
"""Class to manage hooks"""
|
"""Class to manage hooks"""
|
||||||
|
|
||||||
def __init__(self, name="core"):
|
def __init__(self, name: str = "core"):
|
||||||
"""Initialize the manager"""
|
"""Initialize the manager"""
|
||||||
|
|
||||||
self.hooks = dict()
|
self.hooks = dict()
|
||||||
|
@ -18,16 +18,18 @@ from importlib.abc import Finder
|
|||||||
from importlib.machinery import SourceFileLoader
|
from importlib.machinery import SourceFileLoader
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
logger = logging.getLogger("nemubot.importer")
|
logger = logging.getLogger("nemubot.importer")
|
||||||
|
|
||||||
|
|
||||||
class ModuleFinder(Finder):
|
class ModuleFinder(Finder):
|
||||||
|
|
||||||
def __init__(self, modules_paths, add_module):
|
def __init__(self, modules_paths: str, add_module: Callable[]):
|
||||||
self.modules_paths = modules_paths
|
self.modules_paths = modules_paths
|
||||||
self.add_module = add_module
|
self.add_module = add_module
|
||||||
|
|
||||||
|
|
||||||
def find_module(self, fullname, path=None):
|
def find_module(self, fullname, path=None):
|
||||||
# Search only for new nemubot modules (packages init)
|
# Search only for new nemubot modules (packages init)
|
||||||
if path is None:
|
if path is None:
|
||||||
|
@ -31,7 +31,7 @@ PORTS = list()
|
|||||||
|
|
||||||
class DCC(server.AbstractServer):
|
class DCC(server.AbstractServer):
|
||||||
def __init__(self, srv, dest, socket=None):
|
def __init__(self, srv, dest, socket=None):
|
||||||
super().__init__(self)
|
super().__init__(name="Nemubot DCC server")
|
||||||
|
|
||||||
self.error = False # An error has occur, closing the connection?
|
self.error = False # An error has occur, closing the connection?
|
||||||
self.messages = list() # Message queued before connexion
|
self.messages = list() # Message queued before connexion
|
||||||
|
@ -27,10 +27,10 @@ class IRC(SocketServer):
|
|||||||
|
|
||||||
"""Concrete implementation of a connexion to an IRC server"""
|
"""Concrete implementation of a connexion to an IRC server"""
|
||||||
|
|
||||||
def __init__(self, host="localhost", port=6667, ssl=False, owner=None,
|
def __init__(self, host="localhost", port=6667, owner=None,
|
||||||
nick="nemubot", username=None, password=None,
|
nick="nemubot", username=None, password=None,
|
||||||
realname="Nemubot", encoding="utf-8", caps=None,
|
realname="Nemubot", encoding="utf-8", caps=None,
|
||||||
channels=list(), on_connect=None):
|
channels=list(), on_connect=None, **kwargs):
|
||||||
"""Prepare a connection with an IRC server
|
"""Prepare a connection with an IRC server
|
||||||
|
|
||||||
Keyword arguments:
|
Keyword arguments:
|
||||||
@ -54,7 +54,8 @@ class IRC(SocketServer):
|
|||||||
self.owner = owner
|
self.owner = owner
|
||||||
self.realname = realname
|
self.realname = realname
|
||||||
|
|
||||||
super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port))
|
super().__init__(name=self.username + "@" + host + ":" + str(port),
|
||||||
|
host=host, port=port, **kwargs)
|
||||||
self.printer = IRCPrinter
|
self.printer = IRCPrinter
|
||||||
|
|
||||||
self.encoding = encoding
|
self.encoding = encoding
|
||||||
@ -231,20 +232,19 @@ class IRC(SocketServer):
|
|||||||
|
|
||||||
# Open/close
|
# Open/close
|
||||||
|
|
||||||
def open(self):
|
def connect(self):
|
||||||
if super().open():
|
super().connect()
|
||||||
|
|
||||||
if self.password is not None:
|
if self.password is not None:
|
||||||
self.write("PASS :" + self.password)
|
self.write("PASS :" + self.password)
|
||||||
if self.capabilities is not None:
|
if self.capabilities is not None:
|
||||||
self.write("CAP LS")
|
self.write("CAP LS")
|
||||||
self.write("NICK :" + self.nick)
|
self.write("NICK :" + self.nick)
|
||||||
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
|
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if not self.closed:
|
if not self._closed:
|
||||||
self.write("QUIT")
|
self.write("QUIT")
|
||||||
return super().close()
|
return super().close()
|
||||||
|
|
||||||
@ -253,8 +253,8 @@ class IRC(SocketServer):
|
|||||||
|
|
||||||
# Read
|
# Read
|
||||||
|
|
||||||
def read(self):
|
def async_read(self):
|
||||||
for line in super().read():
|
for line in super().async_read():
|
||||||
# PING should be handled here, so start parsing here :/
|
# PING should be handled here, so start parsing here :/
|
||||||
msg = IRCMessage(line, self.encoding)
|
msg = IRCMessage(line, self.encoding)
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Nemubot is a smart and modulable IM bot.
|
# Nemubot is a smart and modulable IM bot.
|
||||||
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
|
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
@ -14,20 +14,13 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import threading
|
|
||||||
|
|
||||||
_lock = threading.Lock()
|
def factory(uri, ssl=False, **init_args):
|
||||||
|
|
||||||
# Lists for select
|
|
||||||
_rlist = []
|
|
||||||
_wlist = []
|
|
||||||
_xlist = []
|
|
||||||
|
|
||||||
|
|
||||||
def factory(uri, **init_args):
|
|
||||||
from urllib.parse import urlparse, unquote
|
from urllib.parse import urlparse, unquote
|
||||||
o = urlparse(uri)
|
o = urlparse(uri)
|
||||||
|
|
||||||
|
srv = None
|
||||||
|
|
||||||
if o.scheme == "irc" or o.scheme == "ircs":
|
if o.scheme == "irc" or o.scheme == "ircs":
|
||||||
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
|
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
|
||||||
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
|
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
|
||||||
@ -36,7 +29,7 @@ def factory(uri, **init_args):
|
|||||||
modifiers = o.path.split(",")
|
modifiers = o.path.split(",")
|
||||||
target = unquote(modifiers.pop(0)[1:])
|
target = unquote(modifiers.pop(0)[1:])
|
||||||
|
|
||||||
if o.scheme == "ircs": args["ssl"] = True
|
if o.scheme == "ircs": ssl = True
|
||||||
if o.hostname is not None: args["host"] = o.hostname
|
if o.hostname is not None: args["host"] = o.hostname
|
||||||
if o.port is not None: args["port"] = o.port
|
if o.port is not None: args["port"] = o.port
|
||||||
if o.username is not None: args["username"] = o.username
|
if o.username is not None: args["username"] = o.username
|
||||||
@ -65,6 +58,11 @@ def factory(uri, **init_args):
|
|||||||
args["channels"] = [ target ]
|
args["channels"] = [ target ]
|
||||||
|
|
||||||
from nemubot.server.IRC import IRC as IRCServer
|
from nemubot.server.IRC import IRC as IRCServer
|
||||||
return IRCServer(**args)
|
srv = IRCServer(**args)
|
||||||
else:
|
|
||||||
return None
|
if ssl:
|
||||||
|
import ssl
|
||||||
|
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
|
||||||
|
return ctx.wrap_socket(srv)
|
||||||
|
|
||||||
|
return srv
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Nemubot is a smart and modulable IM bot.
|
# Nemubot is a smart and modulable IM bot.
|
||||||
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
|
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
@ -14,34 +14,34 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
|
|
||||||
from nemubot.server import _lock, _rlist, _wlist, _xlist
|
from nemubot.bot import sync_act
|
||||||
|
|
||||||
# Extends from IOBase in order to be compatible with select function
|
|
||||||
class AbstractServer(io.IOBase):
|
class AbstractServer:
|
||||||
|
|
||||||
"""An abstract server: handle communication with an IM server"""
|
"""An abstract server: handle communication with an IM server"""
|
||||||
|
|
||||||
def __init__(self, name=None, send_callback=None):
|
def __init__(self, name=None, **kwargs):
|
||||||
"""Initialize an abstract server
|
"""Initialize an abstract server
|
||||||
|
|
||||||
Keyword argument:
|
Keyword argument:
|
||||||
send_callback -- Callback when developper want to send a message
|
name -- Identifier of the socket, for convinience
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._name = name
|
self._name = name
|
||||||
|
|
||||||
super().__init__()
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.logger = logging.getLogger("nemubot.server." + self.name)
|
self.logger = logging.getLogger("nemubot.server." + str(self.name))
|
||||||
|
self._readbuffer = b''
|
||||||
self._sending_queue = queue.Queue()
|
self._sending_queue = queue.Queue()
|
||||||
if send_callback is not None:
|
|
||||||
self._send_callback = send_callback
|
|
||||||
else:
|
def __del__(self):
|
||||||
self._send_callback = self._write_select
|
print("Server deleted")
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -54,40 +54,28 @@ class AbstractServer(io.IOBase):
|
|||||||
|
|
||||||
# Open/close
|
# Open/close
|
||||||
|
|
||||||
def __enter__(self):
|
def connect(self, *args, **kwargs):
|
||||||
self.open()
|
"""Register the server in _poll"""
|
||||||
return self
|
|
||||||
|
self.logger.info("Opening connection")
|
||||||
|
|
||||||
|
super().connect(*args, **kwargs)
|
||||||
|
|
||||||
|
self._connected()
|
||||||
|
|
||||||
|
def _connected(self):
|
||||||
|
sync_act("sckt register %d" % self.fileno())
|
||||||
|
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def close(self, *args, **kwargs):
|
||||||
self.close()
|
"""Unregister the server from _poll"""
|
||||||
|
|
||||||
|
self.logger.info("Closing connection")
|
||||||
|
|
||||||
def open(self):
|
if self.fileno() > 0:
|
||||||
"""Generic open function that register the server un _rlist in case
|
sync_act("sckt unregister %d" % self.fileno())
|
||||||
of successful _open"""
|
|
||||||
self.logger.info("Opening connection to %s", self.id)
|
|
||||||
if not hasattr(self, "_open") or self._open():
|
|
||||||
_rlist.append(self)
|
|
||||||
_xlist.append(self)
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
super().close(*args, **kwargs)
|
||||||
def close(self):
|
|
||||||
"""Generic close function that register the server un _{r,w,x}list in
|
|
||||||
case of successful _close"""
|
|
||||||
self.logger.info("Closing connection to %s", self.id)
|
|
||||||
with _lock:
|
|
||||||
if not hasattr(self, "_close") or self._close():
|
|
||||||
if self in _rlist:
|
|
||||||
_rlist.remove(self)
|
|
||||||
if self in _wlist:
|
|
||||||
_wlist.remove(self)
|
|
||||||
if self in _xlist:
|
|
||||||
_xlist.remove(self)
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# Writes
|
# Writes
|
||||||
@ -99,13 +87,16 @@ class AbstractServer(io.IOBase):
|
|||||||
message -- message to send
|
message -- message to send
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._send_callback(message)
|
self._sending_queue.put(self.format(message))
|
||||||
|
self.logger.debug("Message '%s' appended to write queue", message)
|
||||||
|
sync_act("sckt write %d" % self.fileno())
|
||||||
|
|
||||||
|
|
||||||
def write_select(self):
|
def async_write(self):
|
||||||
"""Internal function used by the select function"""
|
"""Internal function used when the file descriptor is writable"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_wlist.remove(self)
|
sync_act("sckt unwrite %d" % self.fileno())
|
||||||
while not self._sending_queue.empty():
|
while not self._sending_queue.empty():
|
||||||
self._write(self._sending_queue.get_nowait())
|
self._write(self._sending_queue.get_nowait())
|
||||||
self._sending_queue.task_done()
|
self._sending_queue.task_done()
|
||||||
@ -114,19 +105,6 @@ class AbstractServer(io.IOBase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _write_select(self, message):
|
|
||||||
"""Send a message to the server safely through select
|
|
||||||
|
|
||||||
Argument:
|
|
||||||
message -- message to send
|
|
||||||
"""
|
|
||||||
|
|
||||||
self._sending_queue.put(self.format(message))
|
|
||||||
self.logger.debug("Message '%s' appended to write queue", message)
|
|
||||||
if self not in _wlist:
|
|
||||||
_wlist.append(self)
|
|
||||||
|
|
||||||
|
|
||||||
def send_response(self, response):
|
def send_response(self, response):
|
||||||
"""Send a formated Message class
|
"""Send a formated Message class
|
||||||
|
|
||||||
@ -149,13 +127,39 @@ class AbstractServer(io.IOBase):
|
|||||||
|
|
||||||
# Read
|
# Read
|
||||||
|
|
||||||
|
def async_read(self):
|
||||||
|
"""Internal function used when the file descriptor is readable
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of fully received messages
|
||||||
|
"""
|
||||||
|
|
||||||
|
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
|
||||||
|
|
||||||
|
for r in ret:
|
||||||
|
yield r
|
||||||
|
|
||||||
|
|
||||||
|
def lex(self, buf):
|
||||||
|
"""Assume lexing in default case is per line
|
||||||
|
|
||||||
|
Argument:
|
||||||
|
buf -- buffer to lex
|
||||||
|
"""
|
||||||
|
|
||||||
|
msgs = buf.split(b'\r\n')
|
||||||
|
partial = msgs.pop()
|
||||||
|
|
||||||
|
return msgs, partial
|
||||||
|
|
||||||
|
|
||||||
def parse(self, msg):
|
def parse(self, msg):
|
||||||
raise NotImplemented
|
raise NotImplemented
|
||||||
|
|
||||||
|
|
||||||
# Exceptions
|
# Exceptions
|
||||||
|
|
||||||
def exception(self):
|
def exception(self, flags):
|
||||||
"""Exception occurs in fd"""
|
"""Exception occurs on fd"""
|
||||||
self.logger.warning("Unhandle file descriptor exception on server %s",
|
|
||||||
self.name)
|
self.close()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Nemubot is a smart and modulable IM bot.
|
# Nemubot is a smart and modulable IM bot.
|
||||||
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
|
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
@ -14,117 +14,35 @@
|
|||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
|
||||||
import nemubot.message as message
|
import nemubot.message as message
|
||||||
from nemubot.message.printer.socket import Socket as SocketPrinter
|
from nemubot.message.printer.socket import Socket as SocketPrinter
|
||||||
from nemubot.server.abstract import AbstractServer
|
from nemubot.server.abstract import AbstractServer
|
||||||
|
|
||||||
|
|
||||||
class SocketServer(AbstractServer):
|
class _Socket(AbstractServer, socket.socket):
|
||||||
|
|
||||||
"""Concrete implementation of a socket connexion (can be wrapped with TLS)"""
|
"""Concrete implementation of a socket connection"""
|
||||||
|
|
||||||
def __init__(self, sock_location=None,
|
def __init__(self, **kwargs):
|
||||||
host=None, port=None,
|
|
||||||
sock=None,
|
|
||||||
ssl=False,
|
|
||||||
name=None):
|
|
||||||
"""Create a server socket
|
"""Create a server socket
|
||||||
|
|
||||||
Keyword arguments:
|
Keyword arguments:
|
||||||
sock_location -- Path to the UNIX socket
|
|
||||||
host -- Hostname of the INET socket
|
|
||||||
port -- Port of the INET socket
|
|
||||||
sock -- Already connected socket
|
|
||||||
ssl -- Should TLS connection enabled
|
ssl -- Should TLS connection enabled
|
||||||
name -- Convinience name
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import socket
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
assert(sock is None or isinstance(sock, socket.SocketType))
|
|
||||||
assert(port is None or isinstance(port, int))
|
|
||||||
|
|
||||||
super().__init__(name=name)
|
|
||||||
|
|
||||||
if sock is None:
|
|
||||||
if sock_location is not None:
|
|
||||||
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
|
||||||
self.connect_to = sock_location
|
|
||||||
elif host is not None:
|
|
||||||
for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
|
|
||||||
self.socket = socket.socket(af, socktype, proto)
|
|
||||||
self.connect_to = sa
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
self.socket = sock
|
|
||||||
|
|
||||||
self.ssl = ssl
|
|
||||||
|
|
||||||
self.readbuffer = b''
|
self.readbuffer = b''
|
||||||
self.printer = SocketPrinter
|
self.printer = SocketPrinter
|
||||||
|
|
||||||
|
|
||||||
def fileno(self):
|
|
||||||
return self.socket.fileno() if self.socket else None
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def closed(self):
|
|
||||||
"""Indicator of the connection aliveness"""
|
|
||||||
return self.socket._closed
|
|
||||||
|
|
||||||
|
|
||||||
# Open/close
|
|
||||||
|
|
||||||
def open(self):
|
|
||||||
if not self.closed:
|
|
||||||
return True
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.socket.connect(self.connect_to)
|
|
||||||
self.logger.info("Connected to %s", self.connect_to)
|
|
||||||
except:
|
|
||||||
self.socket.close()
|
|
||||||
self.logger.exception("Unable to connect to %s",
|
|
||||||
self.connect_to)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Wrap the socket for SSL
|
|
||||||
if self.ssl:
|
|
||||||
import ssl
|
|
||||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
|
|
||||||
self.socket = ctx.wrap_socket(self.socket)
|
|
||||||
|
|
||||||
return super().open()
|
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
import socket
|
|
||||||
|
|
||||||
# Flush the sending queue before close
|
|
||||||
from nemubot.server import _lock
|
|
||||||
_lock.release()
|
|
||||||
self._sending_queue.join()
|
|
||||||
_lock.acquire()
|
|
||||||
|
|
||||||
if not self.closed:
|
|
||||||
try:
|
|
||||||
self.socket.shutdown(socket.SHUT_RDWR)
|
|
||||||
except socket.error:
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.socket.close()
|
|
||||||
|
|
||||||
return super().close()
|
|
||||||
|
|
||||||
|
|
||||||
# Write
|
# Write
|
||||||
|
|
||||||
def _write(self, cnt):
|
def _write(self, cnt):
|
||||||
if self.closed:
|
self.sendall(cnt)
|
||||||
return
|
|
||||||
|
|
||||||
self.socket.sendall(cnt)
|
|
||||||
|
|
||||||
|
|
||||||
def format(self, txt):
|
def format(self, txt):
|
||||||
@ -136,19 +54,12 @@ class SocketServer(AbstractServer):
|
|||||||
|
|
||||||
# Read
|
# Read
|
||||||
|
|
||||||
def read(self):
|
def read(self, n=1024):
|
||||||
if self.closed:
|
return self.recv(n)
|
||||||
return []
|
|
||||||
|
|
||||||
raw = self.socket.recv(1024)
|
|
||||||
temp = (self.readbuffer + raw).split(b'\r\n')
|
|
||||||
self.readbuffer = temp.pop()
|
|
||||||
|
|
||||||
for line in temp:
|
|
||||||
yield line
|
|
||||||
|
|
||||||
|
|
||||||
def parse(self, line):
|
def parse(self, line):
|
||||||
|
"""Implement a default behaviour for socket"""
|
||||||
import shlex
|
import shlex
|
||||||
|
|
||||||
line = line.strip().decode()
|
line = line.strip().decode()
|
||||||
@ -157,48 +68,84 @@ class SocketServer(AbstractServer):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
args = line.split(' ')
|
args = line.split(' ')
|
||||||
|
|
||||||
yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you")
|
if len(args):
|
||||||
|
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
|
||||||
|
|
||||||
|
|
||||||
class SocketListener(AbstractServer):
|
class SocketServer(_Socket):
|
||||||
|
|
||||||
def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None):
|
def __init__(self, host, port, bind=None, **kwargs):
|
||||||
super().__init__(name=name)
|
super().__init__(family=socket.AF_INET, **kwargs)
|
||||||
self.new_server_cb = new_server_cb
|
|
||||||
self.sock_location = sock_location
|
assert(host is not None)
|
||||||
self.host = host
|
assert(isinstance(port, int))
|
||||||
self.port = port
|
|
||||||
self.ssl = ssl
|
self._host = host
|
||||||
self.nb_son = 0
|
self._port = port
|
||||||
|
self._bind = bind
|
||||||
|
|
||||||
|
|
||||||
def fileno(self):
|
def connect(self):
|
||||||
return self.socket.fileno() if self.socket else None
|
self.logger.info("Connection to %s:%d", self._host, self._port)
|
||||||
|
super().connect((self._host, self._port))
|
||||||
|
|
||||||
|
if self._bind:
|
||||||
|
super().bind(self._bind)
|
||||||
|
|
||||||
|
|
||||||
@property
|
class UnixSocket(_Socket):
|
||||||
def closed(self):
|
|
||||||
"""Indicator of the connection aliveness"""
|
def __init__(self, location, **kwargs):
|
||||||
return self.socket is None
|
super().__init__(family=socket.AF_UNIX, **kwargs)
|
||||||
|
|
||||||
|
self._socket_path = location
|
||||||
|
|
||||||
|
|
||||||
def open(self):
|
def connect(self):
|
||||||
import os
|
self.logger.info("Connection to unix://%s", self._socket_path)
|
||||||
import socket
|
super().connect(self._socket_path)
|
||||||
|
|
||||||
|
|
||||||
|
class _Listener(_Socket):
|
||||||
|
|
||||||
|
def __init__(self, new_server_cb, instanciate=_Socket, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self._instanciate = instanciate
|
||||||
|
self._new_server_cb = new_server_cb
|
||||||
|
|
||||||
|
|
||||||
|
def read(self):
|
||||||
|
conn, addr = self.accept()
|
||||||
|
self.logger.info("Accept new connection from %s", addr)
|
||||||
|
|
||||||
|
fileno = conn.fileno()
|
||||||
|
ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
|
||||||
|
ss.connect = ss._connected
|
||||||
|
self._new_server_cb(ss, autoconnect=True)
|
||||||
|
|
||||||
|
return b''
|
||||||
|
|
||||||
|
|
||||||
|
class UnixSocketListener(_Listener, UnixSocket):
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
|
||||||
|
|
||||||
if self.sock_location is not None:
|
|
||||||
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
|
||||||
try:
|
try:
|
||||||
os.remove(self.sock_location)
|
os.remove(self._socket_path)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
self.socket.bind(self.sock_location)
|
|
||||||
elif self.host is not None and self.port is not None:
|
|
||||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
self.socket.bind((self.host, self.port))
|
|
||||||
self.socket.listen(5)
|
|
||||||
|
|
||||||
return super().open()
|
self.bind(self._socket_path)
|
||||||
|
self.listen(5)
|
||||||
|
self.logger.info("Socket ready for accepting new connections")
|
||||||
|
|
||||||
|
self._connected()
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
@ -206,25 +153,14 @@ class SocketListener(AbstractServer):
|
|||||||
import socket
|
import socket
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.socket.shutdown(socket.SHUT_RDWR)
|
self.shutdown(socket.SHUT_RDWR)
|
||||||
self.socket.close()
|
|
||||||
if self.sock_location is not None:
|
|
||||||
os.remove(self.sock_location)
|
|
||||||
except socket.error:
|
except socket.error:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return super().close()
|
super().close()
|
||||||
|
|
||||||
|
try:
|
||||||
# Read
|
if self._socket_path is not None:
|
||||||
|
os.remove(self._socket_path)
|
||||||
def read(self):
|
except:
|
||||||
if self.closed:
|
pass
|
||||||
return []
|
|
||||||
|
|
||||||
conn, addr = self.socket.accept()
|
|
||||||
self.nb_son += 1
|
|
||||||
ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
|
|
||||||
self.new_server_cb(ss)
|
|
||||||
|
|
||||||
return []
|
|
||||||
|
@ -57,11 +57,15 @@ def word_distance(str1, str2):
|
|||||||
return d[len(str1)][len(str2)]
|
return d[len(str1)][len(str2)]
|
||||||
|
|
||||||
|
|
||||||
def guess(pattern, expect):
|
def guess(pattern, expect, max_depth=0):
|
||||||
|
if max_depth == 0:
|
||||||
|
max_depth = 1 + len(pattern) / 4
|
||||||
|
elif max_depth <= -1:
|
||||||
|
max_depth = len(pattern) - max_depth
|
||||||
if len(expect):
|
if len(expect):
|
||||||
se = sorted([(e, word_distance(pattern, e)) for e in expect], key=lambda x: x[1])
|
se = sorted([(e, word_distance(pattern, e)) for e in expect], key=lambda x: x[1])
|
||||||
_, m = se[0]
|
_, m = se[0]
|
||||||
for e, wd in se:
|
for e, wd in se:
|
||||||
if wd > m or wd > 1 + len(pattern) / 4:
|
if wd > m or wd > max_depth:
|
||||||
break
|
break
|
||||||
yield e
|
yield e
|
||||||
|
Loading…
x
Reference in New Issue
Block a user