Compare commits

...

22 commits

Author SHA1 Message Date
39936e9d39 Add type for use with mypy 2016-07-10 00:17:37 +02:00
a8706d6213 New printer and parser for bot data, XML-based 2016-07-10 00:17:37 +02:00
e103d22bf2 Use fileno instead of name to index existing servers 2016-07-10 00:15:14 +02:00
1e8cb3a12a Use super() instead of parent class name 2016-07-10 00:15:14 +02:00
fab747fcfd Documentation 2016-07-10 00:15:14 +02:00
c0e489f6b6 In debug mode, display running thread at exit 2016-07-08 22:42:44 +02:00
509446a0f4 Handle case where frm and to have not been filled 2016-07-08 22:27:57 +02:00
4473d9547e Review consumer errors 2016-07-08 22:27:57 +02:00
95fc044783 Remove reload feature
As reload shoudl be done in a particular order, to keep valid types, and because maintaining such system is too complex (currently, it doesn't work for a while), now, a reload is just reload configuration file (and possibly modules)
2016-07-08 22:27:57 +02:00
1dd06f1621 New keywords class that accepts any keywords 2016-07-08 22:27:57 +02:00
f9837abba8 [rnd] Add new function choiceres which pick a random response returned by a given subcommand 2016-07-08 22:27:57 +02:00
670a1319a2 Can attach to the main process 2016-07-08 22:27:57 +02:00
2b1469c03f Remove legacy prompt 2016-07-08 22:27:57 +02:00
8983b9b67c Fix and improve reload process 2016-07-08 22:27:57 +02:00
c0fed51fde New argument: --socketfile that create a socket for internal communication 2016-07-08 22:27:57 +02:00
0b14207c88 New CLI argument: --pidfile, path to store the daemon PID 2016-07-08 22:27:57 +02:00
e17368cf26 Catch SIGUSR1: log threads stack traces 2016-07-08 22:27:57 +02:00
220bc7356e Extract deamonize to a dedicated function that can be called from anywhere 2016-07-08 22:27:57 +02:00
29913bd943 Catch SIGHUP: deep reload 2016-07-08 22:27:57 +02:00
7440bd4222 Do a proper close on SIGINT and SIGTERM 2016-07-08 22:27:57 +02:00
179397e96a Remove prompt at launch 2016-07-08 22:27:57 +02:00
992c847e27 Introducing daemon mode 2016-07-08 22:27:57 +02:00
47 changed files with 1051 additions and 1216 deletions

View file

@ -1,202 +0,0 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import traceback
import sys
from nemubot.hooks import hook
nemubotversion = 3.4
NODATA = True
def getserver(toks, context, prompt, mandatory=False, **kwargs):
"""Choose the server in toks or prompt.
This function modify the tokens list passed as argument"""
if len(toks) > 1 and toks[1] in context.servers:
return context.servers[toks.pop(1)]
elif not mandatory or prompt.selectedServer:
return prompt.selectedServer
else:
from nemubot.prompt.error import PromptError
raise PromptError("Please SELECT a server or give its name in argument.")
@hook("prompt_cmd", "close")
def close(toks, context, **kwargs):
"""Disconnect and forget (remove from the servers list) the server"""
srv = getserver(toks, context=context, mandatory=True, **kwargs)
if srv.close():
del context.servers[srv.id]
return 0
return 1
@hook("prompt_cmd", "connect")
def connect(toks, **kwargs):
"""Make the connexion to a server"""
srv = getserver(toks, mandatory=True, **kwargs)
return not srv.open()
@hook("prompt_cmd", "disconnect")
def disconnect(toks, **kwargs):
"""Close the connection to a server"""
srv = getserver(toks, mandatory=True, **kwargs)
return not srv.close()
@hook("prompt_cmd", "discover")
def discover(toks, context, **kwargs):
"""Discover a new bot on a server"""
srv = getserver(toks, context=context, mandatory=True, **kwargs)
if len(toks) > 1 and "!" in toks[1]:
bot = context.add_networkbot(srv, name)
return not bot.connect()
else:
print(" %s is not a valid fullname, for example: "
"nemubot!nemubotV3@bot.nemunai.re" % ''.join(toks[1:1]))
return 1
@hook("prompt_cmd", "join")
@hook("prompt_cmd", "leave")
@hook("prompt_cmd", "part")
def join(toks, **kwargs):
"""Join or leave a channel"""
srv = getserver(toks, mandatory=True, **kwargs)
if len(toks) <= 2:
print("%s: not enough arguments." % toks[0])
return 1
if toks[0] == "join":
if len(toks) > 2:
srv.write("JOIN %s %s" % (toks[1], toks[2]))
else:
srv.write("JOIN %s" % toks[1])
elif toks[0] == "leave" or toks[0] == "part":
if len(toks) > 2:
srv.write("PART %s :%s" % (toks[1], " ".join(toks[2:])))
else:
srv.write("PART %s" % toks[1])
return 0
@hook("prompt_cmd", "save")
def save_mod(toks, context, **kwargs):
"""Force save module data"""
if len(toks) < 2:
print("save: not enough arguments.")
return 1
wrn = 0
for mod in toks[1:]:
if mod in context.modules:
context.modules[mod].save()
print("save: module `%s´ saved successfully" % mod)
else:
wrn += 1
print("save: no module named `%s´" % mod)
return wrn
@hook("prompt_cmd", "send")
def send(toks, **kwargs):
"""Send a message on a channel"""
srv = getserver(toks, mandatory=True, **kwargs)
# Check the server is connected
if not srv.connected:
print ("send: server `%s' not connected." % srv.id)
return 2
if len(toks) <= 3:
print ("send: not enough arguments.")
return 1
if toks[1] not in srv.channels:
print ("send: channel `%s' not authorized in server `%s'."
% (toks[1], srv.id))
return 3
from nemubot.message import Text
srv.send_response(Text(" ".join(toks[2:]), server=None,
to=[toks[1]]))
return 0
@hook("prompt_cmd", "zap")
def zap(toks, **kwargs):
"""Hard change connexion state"""
srv = getserver(toks, mandatory=True, **kwargs)
srv.connected = not srv.connected
@hook("prompt_cmd", "top")
def top(toks, context, **kwargs):
"""Display consumers load information"""
print("Queue size: %d, %d thread(s) running (counter: %d)" %
(context.cnsr_queue.qsize(),
len(context.cnsr_thrd),
context.cnsr_thrd_size))
if len(context.events) > 0:
print("Events registered: %d, next in %d seconds" %
(len(context.events),
context.events[0].time_left.seconds))
else:
print("No events registered")
for th in context.cnsr_thrd:
if th.is_alive():
print(("#" * 15 + " Stack trace for thread %u " + "#" * 15) %
th.ident)
traceback.print_stack(sys._current_frames()[th.ident])
@hook("prompt_cmd", "netstat")
def netstat(toks, context, **kwargs):
"""Display sockets in use and many other things"""
if len(context.network) > 0:
print("Distant bots connected: %d:" % len(context.network))
for name, bot in context.network.items():
print("# %s:" % name)
print(" * Declared hooks:")
lvl = 0
for hlvl in bot.hooks:
lvl += 1
for hook in (hlvl.all_pre + hlvl.all_post + hlvl.cmd_rgxp +
hlvl.cmd_default + hlvl.ask_rgxp +
hlvl.ask_default + hlvl.msg_rgxp +
hlvl.msg_default):
print(" %s- %s" % (' ' * lvl * 2, hook))
for kind in ["irc_hook", "cmd_hook", "ask_hook", "msg_hook"]:
print(" %s- <%s> %s" % (' ' * lvl * 2, kind,
", ".join(hlvl.__dict__[kind].keys())))
print(" * My tag: %d" % bot.my_tag)
print(" * Tags in use (%d):" % bot.inc_tag)
for tag, (cmd, data) in bot.tags.items():
print(" - %11s: %s « %s »" % (tag, cmd, data))
else:
print("No distant bot connected")

View file

@ -8,7 +8,6 @@ import shlex
from nemubot import context from nemubot import context
from nemubot.exception import IMException from nemubot.exception import IMException
from nemubot.hooks import hook from nemubot.hooks import hook
from nemubot.message import Command
from more import Response from more import Response
@ -32,8 +31,24 @@ def cmd_choicecmd(msg):
choice = shlex.split(random.choice(msg.args)) choice = shlex.split(random.choice(msg.args))
return [x for x in context.subtreat(Command(choice[0][1:], return [x for x in context.subtreat(context.subparse(msg, choice))]
choice[1:],
to_response=msg.to_response,
frm=msg.frm, @hook.command("choiceres")
server=msg.server))] def cmd_choiceres(msg):
if not len(msg.args):
raise IMException("indicate some command to pick a message from!")
rl = [x for x in context.subtreat(context.subparse(msg, " ".join(msg.args)))]
if len(rl) <= 0:
return rl
r = random.choice(rl)
if isinstance(r, Response):
for i in range(len(r.messages) - 1, -1, -1):
if isinstance(r.messages[i], list):
r.messages = [ random.choice(random.choice(r.messages)) ]
elif isinstance(r.messages[i], str):
r.messages = [ random.choice(r.messages) ]
return r

View file

@ -17,11 +17,15 @@
__version__ = '4.0.dev3' __version__ = '4.0.dev3'
__author__ = 'nemunaire' __author__ = 'nemunaire'
from typing import Optional
from nemubot.modulecontext import ModuleContext from nemubot.modulecontext import ModuleContext
context = ModuleContext(None, None) context = ModuleContext(None, None)
def requires_version(min=None, max=None): def requires_version(min: Optional[int] = None,
max: Optional[int] = None) -> None:
"""Raise ImportError if the current version is not in the given range """Raise ImportError if the current version is not in the given range
Keyword arguments: Keyword arguments:
@ -38,62 +42,98 @@ def requires_version(min=None, max=None):
"but this is nemubot v%s." % (str(max), __version__)) "but this is nemubot v%s." % (str(max), __version__))
def reload(): def attach(pid: int, socketfile: str) -> int:
"""Reload code of all Python modules used by nemubot import socket
import sys
print("nemubot is already launched with PID %d. Attaching to Unix socket at: %s" % (pid, socketfile))
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.connect(socketfile)
except socket.error as e:
sys.stderr.write(str(e))
sys.stderr.write("\n")
return 1
from select import select
try:
print("Connection established.")
while True:
rl, wl, xl = select([sys.stdin, sock], [], [])
if sys.stdin in rl:
line = sys.stdin.readline().strip()
if line == "exit" or line == "quit":
return 0
elif line == "reload":
import os, signal
os.kill(pid, signal.SIGHUP)
print("Reload signal sent. Please wait...")
elif line == "shutdown":
import os, signal
os.kill(pid, signal.SIGTERM)
print("Shutdown signal sent. Please wait...")
elif line == "kill":
import os, signal
os.kill(pid, signal.SIGKILL)
print("Signal sent...")
return 0
elif line == "stack" or line == "stacks":
import os, signal
os.kill(pid, signal.SIGUSR1)
print("Debug signal sent. Consult logs.")
else:
sock.send(line.encode() + b'\r\n')
if sock in rl:
sys.stdout.write(sock.recv(2048).decode())
except KeyboardInterrupt:
pass
except:
return 1
finally:
sock.close()
return 0
def daemonize() -> None:
"""Detach the running process to run as a daemon
""" """
import imp import os
import sys
import nemubot.channel try:
imp.reload(nemubot.channel) pid = os.fork()
if pid > 0:
sys.exit(0)
except OSError as err:
sys.stderr.write("Unable to fork: %s\n" % err)
sys.exit(1)
import nemubot.config os.setsid()
imp.reload(nemubot.config) os.umask(0)
os.chdir('/')
nemubot.config.reload() try:
pid = os.fork()
if pid > 0:
sys.exit(0)
except OSError as err:
sys.stderr.write("Unable to fork: %s\n" % err)
sys.exit(1)
import nemubot.consumer sys.stdout.flush()
imp.reload(nemubot.consumer) sys.stderr.flush()
si = open(os.devnull, 'r')
so = open(os.devnull, 'a+')
se = open(os.devnull, 'a+')
import nemubot.datastore os.dup2(si.fileno(), sys.stdin.fileno())
imp.reload(nemubot.datastore) os.dup2(so.fileno(), sys.stdout.fileno())
os.dup2(se.fileno(), sys.stderr.fileno())
nemubot.datastore.reload()
import nemubot.event
imp.reload(nemubot.event)
import nemubot.exception
imp.reload(nemubot.exception)
nemubot.exception.reload()
import nemubot.hooks
imp.reload(nemubot.hooks)
nemubot.hooks.reload()
import nemubot.importer
imp.reload(nemubot.importer)
import nemubot.message
imp.reload(nemubot.message)
nemubot.message.reload()
import nemubot.prompt
imp.reload(nemubot.prompt)
nemubot.prompt.reload()
import nemubot.server
rl, wl, xl = nemubot.server._rlist, nemubot.server._wlist, nemubot.server._xlist
imp.reload(nemubot.server)
nemubot.server._rlist, nemubot.server._wlist, nemubot.server._xlist = rl, wl, xl
nemubot.server.reload()
import nemubot.tools
imp.reload(nemubot.tools)
nemubot.tools.reload()

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,8 +14,9 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
def main(): def main() -> None:
import os import os
import signal
import sys import sys
# Parse command line arguments # Parse command line arguments
@ -36,6 +37,15 @@ def main():
default=["./modules/"], default=["./modules/"],
help="directory to use as modules store") help="directory to use as modules store")
parser.add_argument("-d", "--debug", action="store_true",
help="don't deamonize, keep in foreground")
parser.add_argument("-P", "--pidfile", default="./nemubot.pid",
help="Path to the file where store PID")
parser.add_argument("-S", "--socketfile", default="./nemubot.sock",
help="path where open the socket for internal communication")
parser.add_argument("-l", "--logfile", default="./nemubot.log", parser.add_argument("-l", "--logfile", default="./nemubot.log",
help="Path to store logs") help="Path to store logs")
@ -58,10 +68,34 @@ def main():
# Resolve relatives paths # Resolve relatives paths
args.data_path = os.path.abspath(os.path.expanduser(args.data_path)) args.data_path = os.path.abspath(os.path.expanduser(args.data_path))
args.pidfile = os.path.abspath(os.path.expanduser(args.pidfile))
args.socketfile = os.path.abspath(os.path.expanduser(args.socketfile))
args.logfile = os.path.abspath(os.path.expanduser(args.logfile)) args.logfile = os.path.abspath(os.path.expanduser(args.logfile))
args.files = [ x for x in map(os.path.abspath, args.files)] args.files = [ x for x in map(os.path.abspath, args.files)]
args.modules_path = [ x for x in map(os.path.abspath, args.modules_path)] args.modules_path = [ x for x in map(os.path.abspath, args.modules_path)]
# Check if an instance is already launched
if args.pidfile is not None and os.path.isfile(args.pidfile):
with open(args.pidfile, "r") as f:
pid = int(f.readline())
try:
os.kill(pid, 0)
except OSError:
pass
else:
from nemubot import attach
sys.exit(attach(pid, args.socketfile))
# Daemonize
if not args.debug:
from nemubot import daemonize
daemonize()
# Store PID to pidfile
if args.pidfile is not None:
with open(args.pidfile, "w+") as f:
f.write(str(os.getpid()))
# Setup loggin interface # Setup loggin interface
import logging import logging
logger = logging.getLogger("nemubot") logger = logging.getLogger("nemubot")
@ -70,11 +104,12 @@ def main():
formatter = logging.Formatter( formatter = logging.Formatter(
'%(asctime)s %(name)s %(levelname)s %(message)s') '%(asctime)s %(name)s %(levelname)s %(message)s')
ch = logging.StreamHandler() if args.debug:
ch.setFormatter(formatter) ch = logging.StreamHandler()
if args.verbose < 2: ch.setFormatter(formatter)
ch.setLevel(logging.INFO) if args.verbose < 2:
logger.addHandler(ch) ch.setLevel(logging.INFO)
logger.addHandler(ch)
fh = logging.FileHandler(args.logfile) fh = logging.FileHandler(args.logfile)
fh.setFormatter(formatter) fh.setFormatter(formatter)
@ -98,13 +133,10 @@ def main():
if args.no_connect: if args.no_connect:
context.noautoconnect = True context.noautoconnect = True
# Load the prompt
import nemubot.prompt
prmpt = nemubot.prompt.Prompt()
# Register the hook for futur import # Register the hook for futur import
from nemubot.importer import ModuleFinder from nemubot.importer import ModuleFinder
sys.meta_path.append(ModuleFinder(context.modules_paths, context.add_module)) module_finder = ModuleFinder(context.modules_paths, context.add_module)
sys.meta_path.append(module_finder)
# Load requested configuration files # Load requested configuration files
for path in args.files: for path in args.files:
@ -117,36 +149,57 @@ def main():
for module in args.module: for module in args.module:
__import__(module) __import__(module)
print ("Nemubot v%s ready, my PID is %i!" % (nemubot.__version__, # Signals handling
os.getpid())) def sigtermhandler(signum, frame):
while True: """On SIGTERM and SIGINT, quit nicely"""
from nemubot.prompt.reset import PromptReset sigusr1handler(signum, frame)
try: context.quit()
context.start() signal.signal(signal.SIGINT, sigtermhandler)
if prmpt.run(context): signal.signal(signal.SIGTERM, sigtermhandler)
break
except PromptReset as e:
if e.type == "quit":
break
try: def sighuphandler(signum, frame):
import imp """On SIGHUP, perform a deep reload"""
# Reload all other modules nonlocal context
imp.reload(nemubot)
imp.reload(nemubot.prompt)
nemubot.reload()
import nemubot.bot
context = nemubot.bot.hotswap(context)
prmpt = nemubot.prompt.hotswap(prmpt)
print("\033[1;32mContext reloaded\033[0m, now in Nemubot %s" %
nemubot.__version__)
except:
logger.exception("\033[1;31mUnable to reload the prompt due to "
"errors.\033[0m Fix them before trying to reload "
"the prompt.")
context.quit() logger.debug("SIGHUP receive, iniate reload procedure...")
print("Waiting for other threads shuts down...")
# Reload configuration file
for path in args.files:
if os.path.isfile(path):
context.sync_queue.put_nowait(["loadconf", path])
signal.signal(signal.SIGHUP, sighuphandler)
def sigusr1handler(signum, frame):
"""On SIGHUSR1, display stacktraces"""
import threading, traceback
for threadId, stack in sys._current_frames().items():
thName = "#%d" % threadId
for th in threading.enumerate():
if th.ident == threadId:
thName = th.name
break
logger.debug("########### Thread %s:\n%s",
thName,
"".join(traceback.format_stack(stack)))
signal.signal(signal.SIGUSR1, sigusr1handler)
if args.socketfile:
from nemubot.server.socket import UnixSocketListener
context.add_server(UnixSocketListener(new_server_cb=context.add_server,
location=args.socketfile,
name="master_socket"))
# context can change when performing an hotswap, always join the latest context
oldcontext = None
while oldcontext != context:
oldcontext = context
context.start()
context.join()
# Wait for consumers
logger.info("Waiting for other threads shuts down...")
if args.debug:
sigusr1handler(0, None)
sys.exit(0) sys.exit(0)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -15,9 +15,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime, timezone from datetime import datetime, timezone
import ipaddress
import logging import logging
from multiprocessing import JoinableQueue
import threading import threading
import select
import sys import sys
from typing import Any, Mapping, Optional, Sequence
from nemubot import __version__ from nemubot import __version__
from nemubot.consumer import Consumer, EventConsumer, MessageConsumer from nemubot.consumer import Consumer, EventConsumer, MessageConsumer
@ -26,22 +30,33 @@ import nemubot.hooks
logger = logging.getLogger("nemubot") logger = logging.getLogger("nemubot")
sync_queue = JoinableQueue()
def sync_act(*args):
if isinstance(act, bytes):
act = act.decode()
sync_queue.put(act)
class Bot(threading.Thread): class Bot(threading.Thread):
"""Class containing the bot context and ensuring key goals""" """Class containing the bot context and ensuring key goals"""
def __init__(self, ip="127.0.0.1", modules_paths=list(), def __init__(self,
data_store=datastore.Abstract(), verbosity=0): ip: Optional[ipaddress] = None,
modules_paths: Sequence[str] = list(),
data_store: Optional[datastore.Abstract] = None,
verbosity: int = 0):
"""Initialize the bot context """Initialize the bot context
Keyword arguments: Keyword arguments:
ip -- The external IP of the bot (default: 127.0.0.1) ip -- The external IP of the bot (default: 127.0.0.1)
modules_paths -- Paths to all directories where looking for module modules_paths -- Paths to all directories where looking for modules
data_store -- An instance of the nemubot datastore for bot's modules data_store -- An instance of the nemubot datastore for bot's modules
verbosity -- verbosity level
""" """
threading.Thread.__init__(self) super().__init__(name="Nemubot main")
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)", logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
__version__, __version__,
@ -51,16 +66,16 @@ class Bot(threading.Thread):
self.stop = None self.stop = None
# External IP for accessing this bot # External IP for accessing this bot
import ipaddress self.ip = ip if ip is not None else ipaddress.ip_address("127.0.0.1")
self.ip = ipaddress.ip_address(ip)
# Context paths # Context paths
self.modules_paths = modules_paths self.modules_paths = modules_paths
self.datastore = data_store self.datastore = data_store if data_store is not None else datastore.Abstract()
self.datastore.open() self.datastore.open()
# Keep global context: servers and modules # Keep global context: servers and modules
self.servers = dict() self._poll = select.poll()
self.servers = dict() # types: Mapping[str, AbstractServer]
self.modules = dict() self.modules = dict()
self.modules_configuration = dict() self.modules_configuration = dict()
@ -137,59 +152,76 @@ class Bot(threading.Thread):
self.cnsr_queue = Queue() self.cnsr_queue = Queue()
self.cnsr_thrd = list() self.cnsr_thrd = list()
self.cnsr_thrd_size = -1 self.cnsr_thrd_size = -1
# Synchrone actions to be treated by main thread
self.sync_queue = Queue()
def run(self): def run(self):
from select import select self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI)
from nemubot.server import _lock, _rlist, _wlist, _xlist
logger.info("Starting main loop")
self.stop = False self.stop = False
while not self.stop: while not self.stop:
with _lock: for fd, flag in self._poll.poll():
try: print("poll")
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1) # Handle internal socket passing orders
except: if fd != sync_queue._reader.fileno():
logger.error("Something went wrong in select") srv = self.servers[fd]
fnd_smth = False
# Looking for invalid server
for r in _rlist:
if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
_rlist.remove(r)
logger.error("Found invalid object in _rlist: " + str(r))
fnd_smth = True
for w in _wlist:
if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
_wlist.remove(w)
logger.error("Found invalid object in _wlist: " + str(w))
fnd_smth = True
for x in _xlist:
if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
_xlist.remove(x)
logger.error("Found invalid object in _xlist: " + str(x))
fnd_smth = True
if not fnd_smth:
logger.exception("Can't continue, sorry")
self.quit()
continue
for x in xl: if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
try:
x.exception()
except:
logger.exception("Uncatched exception on server exception")
for w in wl:
try:
w.write_select()
except:
logger.exception("Uncatched exception on server write")
for r in rl:
for i in r.read():
try: try:
self.receive_message(r, i) srv.exception(flag)
except: except:
logger.exception("Uncatched exception on server read") logger.exception("Uncatched exception on server exception")
if srv.fileno() > 0:
if flag & (select.POLLOUT):
try:
srv.async_write()
except:
logger.exception("Uncatched exception on server write")
if flag & (select.POLLIN | select.POLLPRI):
try:
for i in srv.async_read():
self.receive_message(srv, i)
except:
logger.exception("Uncatched exception on server read")
else:
del self.servers[fd]
# Always check the sync queue
while not sync_queue.empty():
import shlex
args = shlex.split(sync_queue.get())
action = args.pop(0)
logger.info("action: %s: %s", action, args)
if action == "sckt" and len(args) >= 2:
try:
if args[0] == "write":
self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "unwrite":
self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "register":
self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "unregister":
self._poll.unregister(int(args[1]))
except:
logger.exception("Unhandled excpetion during action:")
elif action == "exit":
self.quit()
elif action == "loadconf":
for path in action.args:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
sync_queue.task_done()
# Launch new consumer threads if necessary # Launch new consumer threads if necessary
@ -200,17 +232,7 @@ class Bot(threading.Thread):
c = Consumer(self) c = Consumer(self)
self.cnsr_thrd.append(c) self.cnsr_thrd.append(c)
c.start() c.start()
logger.info("Ending main loop")
while self.sync_queue.qsize() > 0:
action = self.sync_queue.get_nowait()
if action[0] == "exit":
self.quit()
elif action[0] == "loadconf":
for path in action[1:]:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
self.sync_queue.task_done()
@ -252,9 +274,9 @@ class Bot(threading.Thread):
srv = server.server(config) srv = server.server(config)
# Add the server in the context # Add the server in the context
if self.add_server(srv, server.autoconnect): if self.add_server(srv, server.autoconnect):
logger.info("Server '%s' successfully added." % srv.id) logger.info("Server '%s' successfully added." % srv.name)
else: else:
logger.error("Can't add server '%s'." % srv.id) logger.error("Can't add server '%s'." % srv.name)
# Load module and their configuration # Load module and their configuration
for mod in config.modules: for mod in config.modules:
@ -303,7 +325,7 @@ class Bot(threading.Thread):
if type(eid) is uuid.UUID: if type(eid) is uuid.UUID:
evt.id = str(eid) evt.id = str(eid)
else: else:
# Ok, this is quite useless... # Ok, this is quiet useless...
try: try:
evt.id = str(uuid.UUID(eid)) evt.id = str(uuid.UUID(eid))
except ValueError: except ValueError:
@ -411,10 +433,11 @@ class Bot(threading.Thread):
autoconnect -- connect after add? autoconnect -- connect after add?
""" """
if srv.id not in self.servers: fileno = srv.fileno()
self.servers[srv.id] = srv if fileno not in self.servers:
self.servers[fileno] = srv
if autoconnect and not hasattr(self, "noautoconnect"): if autoconnect and not hasattr(self, "noautoconnect"):
srv.open() srv.connect()
return True return True
else: else:
@ -436,10 +459,10 @@ class Bot(threading.Thread):
__import__(name) __import__(name)
def add_module(self, module): def add_module(self, mdl: Any):
"""Add a module to the context, if already exists, unload the """Add a module to the context, if already exists, unload the
old one before""" old one before"""
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__ module_name = mdl.__spec__.name if hasattr(mdl, "__spec__") else mdl.__name__
if hasattr(self, "stop") and self.stop: if hasattr(self, "stop") and self.stop:
logger.warn("The bot is stopped, can't register new modules") logger.warn("The bot is stopped, can't register new modules")
@ -451,40 +474,40 @@ class Bot(threading.Thread):
# Overwrite print built-in # Overwrite print built-in
def prnt(*args): def prnt(*args):
if hasattr(module, "logger"): if hasattr(mdl, "logger"):
module.logger.info(" ".join([str(s) for s in args])) mdl.logger.info(" ".join([str(s) for s in args]))
else: else:
logger.info("[%s] %s", module_name, " ".join([str(s) for s in args])) logger.info("[%s] %s", module_name, " ".join([str(s) for s in args]))
module.print = prnt mdl.print = prnt
# Create module context # Create module context
from nemubot.modulecontext import ModuleContext from nemubot.modulecontext import ModuleContext
module.__nemubot_context__ = ModuleContext(self, module) mdl.__nemubot_context__ = ModuleContext(self, mdl)
if not hasattr(module, "logger"): if not hasattr(mdl, "logger"):
module.logger = logging.getLogger("nemubot.module." + module_name) mdl.logger = logging.getLogger("nemubot.module." + module_name)
# Replace imported context by real one # Replace imported context by real one
for attr in module.__dict__: for attr in mdl.__dict__:
if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext: if attr != "__nemubot_context__" and type(mdl.__dict__[attr]) == ModuleContext:
module.__dict__[attr] = module.__nemubot_context__ mdl.__dict__[attr] = mdl.__nemubot_context__
# Register decorated functions # Register decorated functions
import nemubot.hooks import nemubot.hooks
for s, h in nemubot.hooks.hook.last_registered: for s, h in nemubot.hooks.hook.last_registered:
module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s) mdl.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
nemubot.hooks.hook.last_registered = [] nemubot.hooks.hook.last_registered = []
# Launch the module # Launch the module
if hasattr(module, "load"): if hasattr(mdl, "load"):
try: try:
module.load(module.__nemubot_context__) mdl.load(mdl.__nemubot_context__)
except: except:
module.__nemubot_context__.unload() mdl.__nemubot_context__.unload()
raise raise
# Save a reference to the module # Save a reference to the module
self.modules[module_name] = module self.modules[module_name] = mdl
def unload_module(self, name): def unload_module(self, name):
@ -527,28 +550,28 @@ class Bot(threading.Thread):
def quit(self): def quit(self):
"""Save and unload modules and disconnect servers""" """Save and unload modules and disconnect servers"""
self.datastore.close()
if self.event_timer is not None: if self.event_timer is not None:
logger.info("Stop the event timer...") logger.info("Stop the event timer...")
self.event_timer.cancel() self.event_timer.cancel()
logger.info("Save and unload all modules...")
for mod in self.modules.items():
self.unload_module(mod)
logger.info("Close all servers connection...")
for k in self.servers:
self.servers[k].close()
logger.info("Stop consumers") logger.info("Stop consumers")
k = self.cnsr_thrd k = self.cnsr_thrd
for cnsr in k: for cnsr in k:
cnsr.stop = True cnsr.stop = True
logger.info("Save and unload all modules...") self.datastore.close()
k = list(self.modules.keys())
for mod in k:
self.unload_module(mod)
logger.info("Close all servers connection...")
k = list(self.servers.keys())
for srv in k:
self.servers[srv].close()
self.stop = True self.stop = True
sync_queue.put("end")
sync_queue.join()
# Treatment # Treatment
@ -562,20 +585,3 @@ class Bot(threading.Thread):
del store[hook.name] del store[hook.name]
elif isinstance(store, list): elif isinstance(store, list):
store.remove(hook) store.remove(hook)
def hotswap(bak):
bak.stop = True
if bak.event_timer is not None:
bak.event_timer.cancel()
bak.datastore.close()
new = Bot(str(bak.ip), bak.modules_paths, bak.datastore)
new.servers = bak.servers
new.modules = bak.modules
new.modules_configuration = bak.modules_configuration
new.events = bak.events
new.hooks = bak.hooks
new._update_event_timer()
return new

View file

@ -15,13 +15,18 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging import logging
from typing import Optional
from nemubot.message import Abstract as AbstractMessage
class Channel: class Channel:
"""A chat room""" """A chat room"""
def __init__(self, name, password=None, encoding=None): def __init__(self,
name: str,
password: Optional[str] = None,
encoding: Optional[str] = None):
"""Initialize the channel """Initialize the channel
Arguments: Arguments:
@ -37,7 +42,8 @@ class Channel:
self.topic = "" self.topic = ""
self.logger = logging.getLogger("nemubot.channel." + name) self.logger = logging.getLogger("nemubot.channel." + name)
def treat(self, cmd, msg):
def treat(self, cmd: str, msg: AbstractMessage) -> None:
"""Treat a incoming IRC command """Treat a incoming IRC command
Arguments: Arguments:
@ -60,7 +66,8 @@ class Channel:
elif cmd == "TOPIC": elif cmd == "TOPIC":
self.topic = self.text self.topic = self.text
def join(self, nick, level=0):
def join(self, nick: str, level: int = 0) -> None:
"""Someone join the channel """Someone join the channel
Argument: Argument:
@ -71,7 +78,8 @@ class Channel:
self.logger.debug("%s join", nick) self.logger.debug("%s join", nick)
self.people[nick] = level self.people[nick] = level
def chtopic(self, newtopic):
def chtopic(self, newtopic: str) -> None:
"""Send command to change the topic """Send command to change the topic
Arguments: Arguments:
@ -81,7 +89,8 @@ class Channel:
self.srv.send_msg(self.name, newtopic, "TOPIC") self.srv.send_msg(self.name, newtopic, "TOPIC")
self.topic = newtopic self.topic = newtopic
def nick(self, oldnick, newnick):
def nick(self, oldnick: str, newnick: str) -> None:
"""Someone change his nick """Someone change his nick
Arguments: Arguments:
@ -95,7 +104,8 @@ class Channel:
del self.people[oldnick] del self.people[oldnick]
self.people[newnick] = lvl self.people[newnick] = lvl
def part(self, nick):
def part(self, nick: str) -> None:
"""Someone leave the channel """Someone leave the channel
Argument: Argument:
@ -106,7 +116,8 @@ class Channel:
self.logger.debug("%s has left", nick) self.logger.debug("%s has left", nick)
del self.people[nick] del self.people[nick]
def mode(self, msg):
def mode(self, msg: AbstractMessage) -> None:
"""Channel or user mode change """Channel or user mode change
Argument: Argument:
@ -132,7 +143,8 @@ class Channel:
elif msg.text[0] == "-v": elif msg.text[0] == "-v":
self.people[msg.nick] &= ~1 self.people[msg.nick] &= ~1
def parse332(self, msg):
def parse332(self, msg: AbstractMessage) -> None:
"""Parse RPL_TOPIC message """Parse RPL_TOPIC message
Argument: Argument:
@ -141,7 +153,8 @@ class Channel:
self.topic = msg.text self.topic = msg.text
def parse353(self, msg):
def parse353(self, msg: AbstractMessage) -> None:
"""Parse RPL_ENDOFWHO message """Parse RPL_ENDOFWHO message
Argument: Argument:

View file

@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
def get_boolean(s): def get_boolean(s) -> bool:
if isinstance(s, bool): if isinstance(s, bool):
return s return s
else: else:
@ -24,24 +24,3 @@ from nemubot.config.include import Include
from nemubot.config.module import Module from nemubot.config.module import Module
from nemubot.config.nemubot import Nemubot from nemubot.config.nemubot import Nemubot
from nemubot.config.server import Server from nemubot.config.server import Server
def reload():
global Include, Module, Nemubot, Server
import imp
import nemubot.config.include
imp.reload(nemubot.config.include)
Include = nemubot.config.include.Include
import nemubot.config.module
imp.reload(nemubot.config.module)
Module = nemubot.config.module.Module
import nemubot.config.nemubot
imp.reload(nemubot.config.nemubot)
Nemubot = nemubot.config.nemubot.Nemubot
import nemubot.config.server
imp.reload(nemubot.config.server)
Server = nemubot.config.server.Server

View file

@ -16,5 +16,5 @@
class Include: class Include:
def __init__(self, path): def __init__(self, path: str):
self.path = path self.path = path

View file

@ -15,12 +15,16 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.config import get_boolean from nemubot.config import get_boolean
from nemubot.tools.xmlparser.genericnode import GenericNode from nemubot.datastore.nodes.generic import GenericNode
class Module(GenericNode): class Module(GenericNode):
def __init__(self, name, autoload=True, **kwargs): def __init__(self,
name: str,
autoload: bool = True,
**kwargs):
super().__init__(None, **kwargs) super().__init__(None, **kwargs)
self.name = name self.name = name
self.autoload = get_boolean(autoload) self.autoload = get_boolean(autoload)

View file

@ -14,15 +14,23 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.config.include import Include from typing import Optional, Sequence, Union
from nemubot.config.module import Module
from nemubot.config.server import Server import nemubot.config.include
import nemubot.config.module
import nemubot.config.server
class Nemubot: class Nemubot:
def __init__(self, nick="nemubot", realname="nemubot", owner=None, def __init__(self,
ip=None, ssl=False, caps=None, encoding="utf-8"): nick: str = "nemubot",
realname: str = "nemubot",
owner: Optional[str] = None,
ip: Optional[str] = None,
ssl: bool = False,
caps: Optional[Sequence[str]] = None,
encoding: str = "utf-8"):
self.nick = nick self.nick = nick
self.realname = realname self.realname = realname
self.owner = owner self.owner = owner
@ -34,13 +42,13 @@ class Nemubot:
self.includes = [] self.includes = []
def addChild(self, name, child): def addChild(self, name: str, child: Union[nemubot.config.module.Module, nemubot.config.server.Server, nemubot.config.include.Include]):
if name == "module" and isinstance(child, Module): if name == "module" and isinstance(child, nemubot.config.module.Module):
self.modules.append(child) self.modules.append(child)
return True return True
elif name == "server" and isinstance(child, Server): elif name == "server" and isinstance(child, nemubot.config.server.Server):
self.servers.append(child) self.servers.append(child)
return True return True
elif name == "include" and isinstance(child, Include): elif name == "include" and isinstance(child, nemubot.config.include.Include):
self.includes.append(child) self.includes.append(child)
return True return True

View file

@ -14,12 +14,19 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Optional, Sequence
from nemubot.channel import Channel from nemubot.channel import Channel
import nemubot.config.nemubot
class Server: class Server:
def __init__(self, uri="irc://nemubot@localhost/", autoconnect=True, caps=None, **kwargs): def __init__(self,
uri: str = "irc://nemubot@localhost/",
autoconnect: bool = True,
caps: Optional[Sequence[str]] = None,
**kwargs):
self.uri = uri self.uri = uri
self.autoconnect = autoconnect self.autoconnect = autoconnect
self.caps = caps.split(" ") if caps is not None else [] self.caps = caps.split(" ") if caps is not None else []
@ -27,7 +34,7 @@ class Server:
self.channels = [] self.channels = []
def addChild(self, name, child): def addChild(self, name: str, child: Channel):
if name == "channel" and isinstance(child, Channel): if name == "channel" and isinstance(child, Channel):
self.channels.append(child) self.channels.append(child)
return True return True

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -17,6 +17,12 @@
import logging import logging
import queue import queue
import threading import threading
from typing import List
from nemubot.bot import Bot
from nemubot.event import ModuleEvent
from nemubot.message.abstract import Abstract as AbstractMessage
from nemubot.server.abstract import AbstractServer
logger = logging.getLogger("nemubot.consumer") logger = logging.getLogger("nemubot.consumer")
@ -25,18 +31,15 @@ class MessageConsumer:
"""Store a message before treating""" """Store a message before treating"""
def __init__(self, srv, msg): def __init__(self, srv: AbstractServer, msg: AbstractMessage):
self.srv = srv self.srv = srv
self.orig = msg self.orig = msg
def run(self, context): def run(self, context: Bot) -> None:
"""Create, parse and treat the message""" """Create, parse and treat the message"""
from nemubot.bot import Bot msgs = [] # type: List[AbstractMessage]
assert isinstance(context, Bot)
msgs = []
# Parse the message # Parse the message
try: try:
@ -44,14 +47,14 @@ class MessageConsumer:
msgs.append(msg) msgs.append(msg)
except: except:
logger.exception("Error occurred during the processing of the %s: " logger.exception("Error occurred during the processing of the %s: "
"%s", type(self.msgs[0]).__name__, self.msgs[0]) "%s", type(self.orig).__name__, self.orig)
if len(msgs) <= 0: if len(msgs) <= 0:
return return
# Qualify the message # Qualify the message
if not hasattr(msg, "server") or msg.server is None: if not hasattr(msg, "server") or msg.server is None:
msg.server = self.srv.id msg.server = self.srv.name
if hasattr(msg, "frm_owner"): if hasattr(msg, "frm_owner"):
msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm) msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm)
@ -62,15 +65,19 @@ class MessageConsumer:
to_server = None to_server = None
if isinstance(res, str): if isinstance(res, str):
to_server = self.srv to_server = self.srv
elif not hasattr(res, "server"):
logger.error("No server defined for response of type %s: %s", type(res).__name__, res)
continue
elif res.server is None: elif res.server is None:
to_server = self.srv to_server = self.srv
res.server = self.srv.id res.server = self.srv.name
elif isinstance(res.server, str) and res.server in context.servers: elif res.server in context.servers:
to_server = context.servers[res.server] to_server = context.servers[res.server]
else:
to_server = res.server
if to_server is None: if to_server is None or not hasattr(to_server, "send_response") or not callable(to_server.send_response):
logger.error("The server defined in this response doesn't " logger.error("The server defined in this response doesn't exist: %s", res.server)
"exist: %s", res.server)
continue continue
# Sent the message only if treat_post authorize it # Sent the message only if treat_post authorize it
@ -81,12 +88,12 @@ class EventConsumer:
"""Store a event before treating""" """Store a event before treating"""
def __init__(self, evt, timeout=20): def __init__(self, evt: ModuleEvent, timeout: int = 20):
self.evt = evt self.evt = evt
self.timeout = timeout self.timeout = timeout
def run(self, context): def run(self, context: Bot) -> None:
try: try:
self.evt.check() self.evt.check()
except: except:
@ -107,13 +114,13 @@ class Consumer(threading.Thread):
"""Dequeue and exec requested action""" """Dequeue and exec requested action"""
def __init__(self, context): def __init__(self, context: Bot):
self.context = context self.context = context
self.stop = False self.stop = False
threading.Thread.__init__(self) super().__init__(name="Nemubot consumer")
def run(self): def run(self) -> None:
try: try:
while not self.stop: while not self.stop:
stm = self.context.cnsr_queue.get(True, 1) stm = self.context.cnsr_queue.get(True, 1)

View file

@ -16,16 +16,3 @@
from nemubot.datastore.abstract import Abstract from nemubot.datastore.abstract import Abstract
from nemubot.datastore.xml import XML from nemubot.datastore.xml import XML
def reload():
global Abstract, XML
import imp
import nemubot.datastore.abstract
imp.reload(nemubot.datastore.abstract)
Abstract = nemubot.datastore.abstract.Abstract
import nemubot.datastore.xml
imp.reload(nemubot.datastore.xml)
XML = nemubot.datastore.xml.XML

View file

@ -23,14 +23,16 @@ class Abstract:
"""Initialize a new empty storage tree """Initialize a new empty storage tree
""" """
from nemubot.tools.xmlparser import module_state return None
return module_state.ModuleState("nemubotstate")
def open(self):
return
def close(self): def open(self) -> bool:
return return True
def close(self) -> bool:
return True
def load(self, module): def load(self, module):
"""Load data for the given module """Load data for the given module
@ -44,6 +46,7 @@ class Abstract:
return self.new() return self.new()
def save(self, module, data): def save(self, module, data):
"""Load data for the given module """Load data for the given module
@ -57,9 +60,11 @@ class Abstract:
return True return True
def __enter__(self): def __enter__(self):
self.open() self.open()
return self return self
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
self.close() self.close()

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,8 +14,5 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
class PromptError(Exception): from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable
def __init__(self, message):
super(PromptError, self).__init__(message)
self.message = message

View file

@ -14,27 +14,38 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
class ListNode: from typing import Any, Mapping, Sequence
from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable
class ListNode(Serializable):
"""XML node representing a Python dictionnnary """XML node representing a Python dictionnnary
""" """
serializetag = "list"
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.items = list() self.items = list() # type: Sequence
def addChild(self, name, child): def addChild(self, name: str, child) -> bool:
self.items.append(child) self.items.append(child)
return True return True
def parsedForm(self) -> Sequence:
return self.items
def __len__(self):
def __len__(self) -> int:
return len(self.items) return len(self.items)
def __getitem__(self, item): def __getitem__(self, item: int) -> Any:
return self.items[item] return self.items[item]
def __setitem__(self, item, v): def __setitem__(self, item: int, v: Any) -> None:
self.items[item] = v self.items[item] = v
def __contains__(self, item): def __contains__(self, item):
@ -44,65 +55,60 @@ class ListNode:
return self.items.__repr__() return self.items.__repr__()
class DictNode: def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
for i in self.items:
node.children.append(ParsingNode.serialize_node(i))
return node
class DictNode(Serializable):
"""XML node representing a Python dictionnnary """XML node representing a Python dictionnnary
""" """
serializetag = "dict"
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.items = dict() self.items = dict()
self._cur = None self._cur = None
def startElement(self, name, attrs): def startElement(self, name: str, attrs: Mapping[str, str]):
if self._cur is None and "key" in attrs: if self._cur is None and "key" in attrs:
self._cur = (attrs["key"], "") self._cur = attrs["key"]
return True
return False return False
def addChild(self, name: str, child: Any):
def characters(self, content):
if self._cur is not None:
key, cnt = self._cur
if isinstance(cnt, str):
cnt += content
self._cur = key, cnt
def endElement(self, name):
if name is None or self._cur is None:
return
key, cnt = self._cur
if isinstance(cnt, list) and len(cnt) == 1:
self.items[key] = cnt
else:
self.items[key] = cnt
self._cur = None
return True
def addChild(self, name, child):
if self._cur is None: if self._cur is None:
return False return False
key, cnt = self._cur self.items[self._cur] = child
if not isinstance(cnt, list): self._cur = None
cnt = []
cnt.append(child)
self._cur = key, cnt
return True return True
def parsedForm(self) -> Mapping:
return self.items
def __getitem__(self, item):
def __getitem__(self, item: str) -> Any:
return self.items[item] return self.items[item]
def __setitem__(self, item, v): def __setitem__(self, item: str, v: str) -> None:
self.items[item] = v self.items[item] = v
def __contains__(self, item): def __contains__(self, item: str) -> bool:
return item in self.items return item in self.items
def __repr__(self): def __repr__(self) -> str:
return self.items.__repr__() return self.items.__repr__()
def serialize(self) -> ParsingNode:
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
for k in self.items:
chld = ParsingNode.serialize_node(self.items[k])
chld.attrs["key"] = k
node.children.append(chld)
return node

View file

@ -14,57 +14,109 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Any, Optional, Mapping, Union
from nemubot.datastore.nodes.serializable import Serializable
class ParsingNode: class ParsingNode:
"""Allow any kind of subtags, just keep parsed ones """Allow any kind of subtags, just keep parsed ones
""" """
def __init__(self, tag=None, **kwargs): def __init__(self,
tag: Optional[str] = None,
**kwargs):
self.tag = tag self.tag = tag
self.attrs = kwargs self.attrs = kwargs
self.content = "" self.content = ""
self.children = [] self.children = []
def characters(self, content): def characters(self, content: str) -> None:
self.content += content self.content += content
def addChild(self, name, child): def addChild(self, name: str, child: Any) -> bool:
self.children.append(child) self.children.append(child)
return True return True
def hasNode(self, nodename): def hasNode(self, nodename: str) -> bool:
return self.getNode(nodename) is not None return self.getNode(nodename) is not None
def getNode(self, nodename): def getNode(self, nodename: str) -> Optional[Any]:
for c in self.children: for c in self.children:
if c is not None and c.tag == nodename: if c is not None and c.tag == nodename:
return c return c
return None return None
def __getitem__(self, item): def __getitem__(self, item: str) -> Any:
return self.attrs[item] return self.attrs[item]
def __contains__(self, item): def __contains__(self, item: str) -> bool:
return item in self.attrs return item in self.attrs
def serialize_node(node: Union[Serializable, str, int, float, list, dict],
**def_kwargs):
"""Serialize any node or basic data to a ParsingNode instance"""
if isinstance(node, Serializable):
node = node.serialize()
if isinstance(node, str):
from nemubot.datastore.nodes.python import StringNode
pn = StringNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, int):
from nemubot.datastore.nodes.python import IntNode
pn = IntNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, float):
from nemubot.datastore.nodes.python import FloatNode
pn = FloatNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, list):
from nemubot.datastore.nodes.basic import ListNode
pn = ListNode(**def_kwargs)
pn.items = node
return pn.serialize()
elif isinstance(node, dict):
from nemubot.datastore.nodes.basic import DictNode
pn = DictNode(**def_kwargs)
pn.items = node
return pn.serialize()
else:
assert isinstance(node, ParsingNode)
return node
class GenericNode(ParsingNode): class GenericNode(ParsingNode):
"""Consider all subtags as dictionnary """Consider all subtags as dictionnary
""" """
def __init__(self, tag, **kwargs): def __init__(self,
tag: str,
**kwargs):
super().__init__(tag, **kwargs) super().__init__(tag, **kwargs)
self._cur = None self._cur = None
self._deep_cur = 0 self._deep_cur = 0
def startElement(self, name, attrs): def startElement(self, name: str, attrs: Mapping[str, str]):
if self._cur is None: if self._cur is None:
self._cur = GenericNode(name, **attrs) self._cur = GenericNode(name, **attrs)
self._deep_cur = 0 self._deep_cur = 0
@ -74,14 +126,14 @@ class GenericNode(ParsingNode):
return True return True
def characters(self, content): def characters(self, content: str):
if self._cur is None: if self._cur is None:
super().characters(content) super().characters(content)
else: else:
self._cur.characters(content) self._cur.characters(content)
def endElement(self, name): def endElement(self, name: str):
if name is None: if name is None:
return return

View file

@ -0,0 +1,89 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable
class PythonTypeNode(Serializable):
"""XML node representing a Python simple type
"""
def __init__(self, **kwargs):
self.value = None
self._cnt = ""
def characters(self, content: str) -> None:
self._cnt += content
def endElement(self, name: str) -> None:
raise NotImplemented
def __repr__(self) -> str:
return self.value.__repr__()
def parsedForm(self):
return self.value
def serialize(self):
raise NotImplemented
class IntNode(PythonTypeNode):
serializetag = "int"
def endElement(self, name: str) -> bool:
self.value = int(self._cnt)
return True
def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
class FloatNode(PythonTypeNode):
serializetag = "float"
def endElement(self, name: str) -> bool:
self.value = float(self._cnt)
return True
def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
class StringNode(PythonTypeNode):
serializetag = "str"
def endElement(self, name):
self.value = str(self._cnt)
return True
def serialize(self):
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node

View file

@ -1,7 +1,5 @@
# coding=utf-8
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -16,8 +14,9 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
class PromptReset(Exception):
def __init__(self, type): class Serializable:
super(PromptReset, self).__init__("Prompt reset asked")
self.type = type def serialize(self):
# Implementations of this function should return ParsingNode items
return NotImplemented

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -17,6 +17,7 @@
import fcntl import fcntl
import logging import logging
import os import os
from typing import Any, Mapping
import xml.parsers.expat import xml.parsers.expat
from nemubot.datastore.abstract import Abstract from nemubot.datastore.abstract import Abstract
@ -28,7 +29,9 @@ class XML(Abstract):
"""A concrete implementation of a data store that relies on XML files""" """A concrete implementation of a data store that relies on XML files"""
def __init__(self, basedir, rotate=True): def __init__(self,
basedir: str,
rotate: bool = True):
"""Initialize the datastore """Initialize the datastore
Arguments: Arguments:
@ -36,17 +39,24 @@ class XML(Abstract):
rotate -- auto-backup files? rotate -- auto-backup files?
""" """
self.basedir = basedir self.basedir = os.path.abspath(basedir)
self.rotate = rotate self.rotate = rotate
self.nb_save = 0 self.nb_save = 0
def open(self): logger.info("Initiate XML datastore at %s, rotation %s",
self.basedir,
"enabled" if self.rotate else "disabled")
def open(self) -> bool:
"""Lock the directory""" """Lock the directory"""
if not os.path.isdir(self.basedir): if not os.path.isdir(self.basedir):
logger.debug("Datastore directory not found, creating: %s", self.basedir)
os.mkdir(self.basedir) os.mkdir(self.basedir)
lock_path = os.path.join(self.basedir, ".used_by_nemubot") lock_path = self._get_lock_file_path()
logger.debug("Locking datastore directory via %s", lock_path)
self.lock_file = open(lock_path, 'a+') self.lock_file = open(lock_path, 'a+')
ok = True ok = True
@ -64,57 +74,92 @@ class XML(Abstract):
self.lock_file.write(str(os.getpid())) self.lock_file.write(str(os.getpid()))
self.lock_file.flush() self.lock_file.flush()
logger.info("Datastore successfuly opened at %s", self.basedir)
return True return True
def close(self):
def close(self) -> bool:
"""Release a locked path""" """Release a locked path"""
if hasattr(self, "lock_file"): if hasattr(self, "lock_file"):
self.lock_file.close() self.lock_file.close()
lock_path = os.path.join(self.basedir, ".used_by_nemubot") lock_path = self._get_lock_file_path()
if os.path.isdir(self.basedir) and os.path.exists(lock_path): if os.path.isdir(self.basedir) and os.path.exists(lock_path):
os.unlink(lock_path) os.unlink(lock_path)
del self.lock_file del self.lock_file
logger.info("Datastore successfully closed at %s", self.basedir)
return True return True
else:
logger.warn("Datastore not open/locked or lock file not found")
return False return False
def _get_data_file_path(self, module):
def _get_data_file_path(self, module: str) -> str:
"""Get the path to the module data file""" """Get the path to the module data file"""
return os.path.join(self.basedir, module + ".xml") return os.path.join(self.basedir, module + ".xml")
def load(self, module):
def _get_lock_file_path(self) -> str:
"""Get the path to the datastore lock file"""
return os.path.join(self.basedir, ".used_by_nemubot")
def load(self, module: str, extendsTags: Mapping[str, Any] = {}) -> Abstract:
"""Load data for the given module """Load data for the given module
Argument: Argument:
module -- the module name of data to load module -- the module name of data to load
""" """
logger.debug("Trying to load data for %s%s",
module,
(" with tags: " + ", ".join(extendsTags.keys())) if len(extendsTags) else "")
data_file = self._get_data_file_path(module) data_file = self._get_data_file_path(module)
def parse(path: str):
from nemubot.tools.xmlparser import XMLParser
from nemubot.datastore.nodes import basic as basicNodes
from nemubot.datastore.nodes import python as pythonNodes
d = {
basicNodes.ListNode.serializetag: basicNodes.ListNode,
basicNodes.DictNode.serializetag: basicNodes.DictNode,
pythonNodes.IntNode.serializetag: pythonNodes.IntNode,
pythonNodes.FloatNode.serializetag: pythonNodes.FloatNode,
pythonNodes.StringNode.serializetag: pythonNodes.StringNode,
}
d.update(extendsTags)
p = XMLParser(d)
return p.parse_file(path)
# Try to load original file # Try to load original file
if os.path.isfile(data_file): if os.path.isfile(data_file):
from nemubot.tools.xmlparser import parse_file
try: try:
return parse_file(data_file) return parse(data_file)
except xml.parsers.expat.ExpatError: except xml.parsers.expat.ExpatError:
# Try to load from backup # Try to load from backup
for i in range(10): for i in range(10):
path = data_file + "." + str(i) path = data_file + "." + str(i)
if os.path.isfile(path): if os.path.isfile(path):
try: try:
cnt = parse_file(path) cnt = parse(path)
logger.warn("Restoring from backup: %s", path) logger.warn("Restoring data from backup: %s", path)
return cnt return cnt
except xml.parsers.expat.ExpatError: except xml.parsers.expat.ExpatError:
continue continue
# Default case: initialize a new empty datastore # Default case: initialize a new empty datastore
logger.warn("No data found in store for %s, creating new set", module)
return Abstract.load(self, module) return Abstract.load(self, module)
def _rotate(self, path):
def _rotate(self, path: str) -> None:
"""Backup given path """Backup given path
Argument: Argument:
@ -130,7 +175,26 @@ class XML(Abstract):
if os.path.isfile(src): if os.path.isfile(src):
os.rename(src, dst) os.rename(src, dst)
def save(self, module, data):
def _save_node(self, gen, node: Any):
from nemubot.datastore.nodes.generic import ParsingNode
# First, get the serialized form of the node
node = ParsingNode.serialize_node(node)
assert node.tag is not None, "Undefined tag name"
gen.startElement(node.tag, {k: str(node.attrs[k]) for k in node.attrs})
gen.characters(node.content)
for child in node.children:
self._save_node(gen, child)
gen.endElement(node.tag)
def save(self, module: str, data: Any) -> bool:
"""Load data for the given module """Load data for the given module
Argument: Argument:
@ -139,8 +203,22 @@ class XML(Abstract):
""" """
path = self._get_data_file_path(module) path = self._get_data_file_path(module)
logger.debug("Trying to save data for module %s in %s", module, path)
if self.rotate: if self.rotate:
self._rotate(path) self._rotate(path)
return data.save(path) import tempfile
_, tmpath = tempfile.mkstemp()
with open(tmpath, "w") as f:
import xml.sax.saxutils
gen = xml.sax.saxutils.XMLGenerator(f, "utf-8")
gen.startDocument()
self._save_node(gen, data)
gen.endDocument()
# Atomic save
import shutil
shutil.move(tmpath, path)
return True

View file

@ -15,15 +15,23 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Optional
class ModuleEvent: class ModuleEvent:
"""Representation of a event initiated by a bot module""" """Representation of a event initiated by a bot module"""
def __init__(self, call=None, call_data=None, func=None, func_data=None, def __init__(self,
cmp=None, cmp_data=None, interval=60, offset=0, times=1): call: Callable = None,
call_data: Callable = None,
func: Callable = None,
func_data: Any = None,
cmp: Any = None,
cmp_data: Any = None,
interval: int = 60,
offset: int = 0,
times: int = 1):
"""Initialize the event """Initialize the event
Keyword arguments: Keyword arguments:
@ -70,8 +78,9 @@ class ModuleEvent:
# How many times do this event? # How many times do this event?
self.times = times self.times = times
@property @property
def current(self): def current(self) -> Optional[datetime.datetime]:
"""Return the date of the near check""" """Return the date of the near check"""
if self.times != 0: if self.times != 0:
if self._end is None: if self._end is None:
@ -79,8 +88,9 @@ class ModuleEvent:
return self._end return self._end
return None return None
@property @property
def next(self): def next(self) -> Optional[datetime.datetime]:
"""Return the date of the next check""" """Return the date of the next check"""
if self.times != 0: if self.times != 0:
if self._end is None: if self._end is None:
@ -90,14 +100,16 @@ class ModuleEvent:
return self._end return self._end
return None return None
@property @property
def time_left(self): def time_left(self) -> Union[datetime.datetime, int]:
"""Return the time left before/after the near check""" """Return the time left before/after the near check"""
if self.current is not None: if self.current is not None:
return self.current - datetime.now(timezone.utc) return self.current - datetime.now(timezone.utc)
return 99999 # TODO: 99999 is not a valid time to return return 99999 # TODO: 99999 is not a valid time to return
def check(self):
def check(self) -> None:
"""Run a check and realized the event if this is time""" """Run a check and realized the event if this is time"""
# Get initial data # Get initial data

View file

@ -32,10 +32,3 @@ class IMException(Exception):
from nemubot.message import Text from nemubot.message import Text
return Text(*self.args, return Text(*self.args,
server=msg.server, to=msg.to_response) server=msg.server, to=msg.to_response)
def reload():
import imp
import nemubot.exception.Keyword
imp.reload(nemubot.exception.printer.IRC)

View file

@ -49,23 +49,3 @@ class hook:
def pre(*args, store=["pre"], **kwargs): def pre(*args, store=["pre"], **kwargs):
return hook._add(store, Abstract, *args, **kwargs) return hook._add(store, Abstract, *args, **kwargs)
def reload():
import imp
import nemubot.hooks.abstract
imp.reload(nemubot.hooks.abstract)
import nemubot.hooks.command
imp.reload(nemubot.hooks.command)
import nemubot.hooks.message
imp.reload(nemubot.hooks.message)
import nemubot.hooks.keywords
imp.reload(nemubot.hooks.keywords)
nemubot.hooks.keywords.reload()
import nemubot.hooks.manager
imp.reload(nemubot.hooks.manager)

View file

@ -26,11 +26,22 @@ class NoKeyword(Abstract):
return super().check(mkw) return super().check(mkw)
def reload(): class AnyKeyword(Abstract):
import imp
import nemubot.hooks.keywords.abstract def __init__(self, h):
imp.reload(nemubot.hooks.keywords.abstract) """Class that accepts any passed keywords
import nemubot.hooks.keywords.dict Arguments:
imp.reload(nemubot.hooks.keywords.dict) h -- Help string
"""
super().__init__()
self.h = h
def check(self, mkw):
return super().check(mkw)
def help(self):
return self.h

View file

@ -21,7 +21,7 @@ class HooksManager:
"""Class to manage hooks""" """Class to manage hooks"""
def __init__(self, name="core"): def __init__(self, name: str = "core"):
"""Initialize the manager""" """Initialize the manager"""
self.hooks = dict() self.hooks = dict()

View file

@ -18,16 +18,18 @@ from importlib.abc import Finder
from importlib.machinery import SourceFileLoader from importlib.machinery import SourceFileLoader
import logging import logging
import os import os
from typing import Callable
logger = logging.getLogger("nemubot.importer") logger = logging.getLogger("nemubot.importer")
class ModuleFinder(Finder): class ModuleFinder(Finder):
def __init__(self, modules_paths, add_module): def __init__(self, modules_paths: str, add_module: Callable[]):
self.modules_paths = modules_paths self.modules_paths = modules_paths
self.add_module = add_module self.add_module = add_module
def find_module(self, fullname, path=None): def find_module(self, fullname, path=None):
# Search only for new nemubot modules (packages init) # Search only for new nemubot modules (packages init)
if path is None: if path is None:

View file

@ -19,27 +19,3 @@ from nemubot.message.text import Text
from nemubot.message.directask import DirectAsk from nemubot.message.directask import DirectAsk
from nemubot.message.command import Command from nemubot.message.command import Command
from nemubot.message.command import OwnerCommand from nemubot.message.command import OwnerCommand
def reload():
global Abstract, Text, DirectAsk, Command, OwnerCommand
import imp
import nemubot.message.abstract
imp.reload(nemubot.message.abstract)
Abstract = nemubot.message.abstract.Abstract
imp.reload(nemubot.message.text)
Text = nemubot.message.text.Text
imp.reload(nemubot.message.directask)
DirectAsk = nemubot.message.directask.DirectAsk
imp.reload(nemubot.message.command)
Command = nemubot.message.command.Command
OwnerCommand = nemubot.message.command.OwnerCommand
import nemubot.message.visitor
imp.reload(nemubot.message.visitor)
import nemubot.message.printer
imp.reload(nemubot.message.printer)
nemubot.message.printer.reload()

View file

@ -22,7 +22,7 @@ class Command(Abstract):
"""This class represents a specialized TextMessage""" """This class represents a specialized TextMessage"""
def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs): def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs):
Abstract.__init__(self, *nargs, **kargs) super().__init__(*nargs, **kargs)
self.cmd = cmd self.cmd = cmd
self.args = args if args is not None else list() self.args = args if args is not None else list()

View file

@ -28,7 +28,7 @@ class DirectAsk(Text):
designated -- the user designated by the message designated -- the user designated by the message
""" """
Text.__init__(self, *args, **kargs) super().__init__(*args, **kargs)
self.designated = designated self.designated = designated

View file

@ -22,4 +22,4 @@ class IRC(SocketPrinter):
def visit_Text(self, msg): def visit_Text(self, msg):
self.pp += "PRIVMSG %s :" % ",".join(msg.to) self.pp += "PRIVMSG %s :" % ",".join(msg.to)
SocketPrinter.visit_Text(self, msg) super().visit_Text(msg)

View file

@ -13,12 +13,3 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
def reload():
import imp
import nemubot.message.printer.IRC
imp.reload(nemubot.message.printer.IRC)
import nemubot.message.printer.socket
imp.reload(nemubot.message.printer.socket)

View file

@ -35,7 +35,7 @@ class Socket(AbstractVisitor):
others = [to for to in msg.to if to != msg.designated] others = [to for to in msg.to if to != msg.designated]
# Avoid nick starting message when discussing on user channel # Avoid nick starting message when discussing on user channel
if len(others) != len(msg.to): if len(others) == 0 or len(others) != len(msg.to):
res = Text(msg.message, res = Text(msg.message,
server=msg.server, date=msg.date, server=msg.server, date=msg.date,
to=msg.to, frm=msg.frm) to=msg.to, frm=msg.frm)

View file

@ -0,0 +1,34 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.message.abstract import Abstract
class Response(Abstract):
def __init__(self, cmd, args=None, *nargs, **kargs):
super().__init__(*nargs, **kargs)
self.cmd = cmd
self.args = args if args is not None else list()
def __str__(self):
return self.cmd + " @" + ",@".join(self.args)
@property
def cmds(self):
# TODO: this is for legacy modules
return [self.cmd] + self.args

View file

@ -28,7 +28,7 @@ class Text(Abstract):
message -- the parsed message message -- the parsed message
""" """
Abstract.__init__(self, *args, **kargs) super().__init__(*args, **kargs)
self.message = message self.message = message

View file

@ -39,6 +39,7 @@ class ModuleContext:
self.hooks = list() self.hooks = list()
self.events = list() self.events = list()
self.extendtags = dict()
self.debug = context.verbosity > 0 if context is not None else False self.debug = context.verbosity > 0 if context is not None else False
from nemubot.hooks import Abstract as AbstractHook from nemubot.hooks import Abstract as AbstractHook
@ -46,7 +47,7 @@ class ModuleContext:
# Define some callbacks # Define some callbacks
if context is not None: if context is not None:
def load_data(): def load_data():
return context.datastore.load(module_name) return context.datastore.load(module_name, extendsTags=self.extendtags)
def add_hook(hook, *triggers): def add_hook(hook, *triggers):
assert isinstance(hook, AbstractHook), hook assert isinstance(hook, AbstractHook), hook
@ -77,8 +78,7 @@ class ModuleContext:
else: # Used when using outside of nemubot else: # Used when using outside of nemubot
def load_data(): def load_data():
from nemubot.tools.xmlparser import module_state return None
return module_state.ModuleState("nemubotstate")
def add_hook(hook, *triggers): def add_hook(hook, *triggers):
assert isinstance(hook, AbstractHook), hook assert isinstance(hook, AbstractHook), hook
@ -97,7 +97,9 @@ class ModuleContext:
module.logger.info("Send response: %s", res) module.logger.info("Send response: %s", res)
def save(): def save():
context.datastore.save(module_name, self.data) # Don't save if no data has been access
if hasattr(self, "_data"):
context.datastore.save(module_name, self.data)
def subparse(orig, cnt): def subparse(orig, cnt):
if orig.server in context.servers: if orig.server in context.servers:
@ -120,6 +122,21 @@ class ModuleContext:
self._data = self.load_data() self._data = self.load_data()
return self._data return self._data
@data.setter
def data(self, value):
assert value is not None
self._data = value
def register_tags(self, **tags):
self.extendtags.update(tags)
def unregister_tags(self, *tags):
for t in tags:
del self.extendtags[t]
def unload(self): def unload(self):
"""Perform actions for unloading the module""" """Perform actions for unloading the module"""

View file

@ -1,142 +0,0 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import shlex
import sys
import traceback
from nemubot.prompt import builtins
class Prompt:
def __init__(self):
self.selectedServer = None
self.lastretcode = 0
self.HOOKS_CAPS = dict()
self.HOOKS_LIST = dict()
def add_cap_hook(self, name, call, data=None):
self.HOOKS_CAPS[name] = lambda t, c: call(t, data=data,
context=c, prompt=self)
def add_list_hook(self, name, call):
self.HOOKS_LIST[name] = call
def lex_cmd(self, line):
"""Return an array of tokens
Argument:
line -- the line to lex
"""
try:
cmds = shlex.split(line)
except:
exc_type, exc_value, _ = sys.exc_info()
sys.stderr.write(traceback.format_exception_only(exc_type,
exc_value)[0])
return
bgn = 0
# Separate commands (command separator: ;)
for i in range(0, len(cmds)):
if cmds[i][-1] == ';':
if i != bgn:
yield cmds[bgn:i]
bgn = i + 1
# Return rest of the command (that not end with a ;)
if bgn != len(cmds):
yield cmds[bgn:]
def exec_cmd(self, toks, context):
"""Execute the command
Arguments:
toks -- lexed tokens to executes
context -- current bot context
"""
if toks[0] in builtins.CAPS:
self.lastretcode = builtins.CAPS[toks[0]](toks, context, self)
elif toks[0] in self.HOOKS_CAPS:
self.lastretcode = self.HOOKS_CAPS[toks[0]](toks, context)
else:
print("Unknown command: `%s'" % toks[0])
self.lastretcode = 127
def getPS1(self):
"""Get the PS1 associated to the selected server"""
if self.selectedServer is None:
return "nemubot"
else:
return self.selectedServer.id
def run(self, context):
"""Launch the prompt
Argument:
context -- current bot context
"""
from nemubot.prompt.error import PromptError
from nemubot.prompt.reset import PromptReset
while True: # Stopped by exception
try:
line = input("\033[0;33m%s\033[0;%d\033[0m " %
(self.getPS1(), 31 if self.lastretcode else 32))
cmds = self.lex_cmd(line.strip())
for toks in cmds:
try:
self.exec_cmd(toks, context)
except PromptReset:
raise
except PromptError as e:
print(e.message)
self.lastretcode = 128
except:
exc_type, exc_value, exc_traceback = sys.exc_info()
traceback.print_exception(exc_type, exc_value,
exc_traceback)
except KeyboardInterrupt:
print("")
except EOFError:
print("quit")
return True
def hotswap(bak):
p = Prompt()
p.HOOKS_CAPS = bak.HOOKS_CAPS
p.HOOKS_LIST = bak.HOOKS_LIST
return p
def reload():
import imp
import nemubot.prompt.builtins
imp.reload(nemubot.prompt.builtins)
import nemubot.prompt.error
imp.reload(nemubot.prompt.error)
import nemubot.prompt.reset
imp.reload(nemubot.prompt.reset)

View file

@ -1,128 +0,0 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
def end(toks, context, prompt):
"""Quit the prompt for reload or exit"""
from nemubot.prompt.reset import PromptReset
if toks[0] == "refresh":
raise PromptReset("refresh")
elif toks[0] == "reset":
raise PromptReset("reset")
raise PromptReset("quit")
def liste(toks, context, prompt):
"""Show some lists"""
if len(toks) > 1:
for l in toks[1:]:
l = l.lower()
if l == "server" or l == "servers":
for srv in context.servers.keys():
print (" - %s (state: %s) ;" % (srv,
"connected" if context.servers[srv].connected else "disconnected"))
if len(context.servers) == 0:
print (" > No server loaded")
elif l == "mod" or l == "mods" or l == "module" or l == "modules":
for mod in context.modules.keys():
print (" - %s ;" % mod)
if len(context.modules) == 0:
print (" > No module loaded")
elif l in prompt.HOOKS_LIST:
f, d = prompt.HOOKS_LIST[l]
f(d, context, prompt)
else:
print (" Unknown list `%s'" % l)
return 2
return 0
else:
print (" Please give a list to show: servers, ...")
return 1
def load(toks, context, prompt):
"""Load an XML configuration file"""
if len(toks) > 1:
for filename in toks[1:]:
context.load_file(filename)
else:
print ("Not enough arguments. `load' takes a filename.")
return 1
def select(toks, context, prompt):
"""Select the current server"""
if (len(toks) == 2 and toks[1] != "None" and
toks[1] != "nemubot" and toks[1] != "none"):
if toks[1] in context.servers:
prompt.selectedServer = context.servers[toks[1]]
else:
print ("select: server `%s' not found." % toks[1])
return 1
else:
prompt.selectedServer = None
def unload(toks, context, prompt):
"""Unload a module"""
if len(toks) == 2 and toks[1] == "all":
for name in context.modules.keys():
context.unload_module(name)
elif len(toks) > 1:
for name in toks[1:]:
if context.unload_module(name):
print (" Module `%s' successfully unloaded." % name)
else:
print (" No module `%s' loaded, can't unload!" % name)
return 2
else:
print ("Not enough arguments. `unload' takes a module name.")
return 1
def debug(toks, context, prompt):
"""Enable/Disable debug mode on a module"""
if len(toks) > 1:
for name in toks[1:]:
if name in context.modules:
context.modules[name].DEBUG = not context.modules[name].DEBUG
if context.modules[name].DEBUG:
print (" Module `%s' now in DEBUG mode." % name)
else:
print (" Debug for module module `%s' disabled." % name)
else:
print (" No module `%s' loaded, can't debug!" % name)
return 2
else:
print ("Not enough arguments. `debug' takes a module name.")
return 1
# Register build-ins
CAPS = {
'quit': end, # Disconnect all server and quit
'exit': end, # Alias for quit
'reset': end, # Reload the prompt
'refresh': end, # Reload the prompt but save modules
'load': load, # Load a servers or module configuration file
'unload': unload, # Unload a module and remove it from the list
'select': select, # Select a server
'list': liste, # Show lists
'debug': debug, # Pass a module in debug mode
}

View file

@ -31,7 +31,7 @@ PORTS = list()
class DCC(server.AbstractServer): class DCC(server.AbstractServer):
def __init__(self, srv, dest, socket=None): def __init__(self, srv, dest, socket=None):
server.Server.__init__(self) super().__init__(name="Nemubot DCC server")
self.error = False # An error has occur, closing the connection? self.error = False # An error has occur, closing the connection?
self.messages = list() # Message queued before connexion self.messages = list() # Message queued before connexion

View file

@ -27,10 +27,10 @@ class IRC(SocketServer):
"""Concrete implementation of a connexion to an IRC server""" """Concrete implementation of a connexion to an IRC server"""
def __init__(self, host="localhost", port=6667, ssl=False, owner=None, def __init__(self, host="localhost", port=6667, owner=None,
nick="nemubot", username=None, password=None, nick="nemubot", username=None, password=None,
realname="Nemubot", encoding="utf-8", caps=None, realname="Nemubot", encoding="utf-8", caps=None,
channels=list(), on_connect=None): channels=list(), on_connect=None, **kwargs):
"""Prepare a connection with an IRC server """Prepare a connection with an IRC server
Keyword arguments: Keyword arguments:
@ -54,8 +54,8 @@ class IRC(SocketServer):
self.owner = owner self.owner = owner
self.realname = realname self.realname = realname
self.id = self.username + "@" + host + ":" + str(port) super().__init__(name=self.username + "@" + host + ":" + str(port),
SocketServer.__init__(self, host=host, port=port, ssl=ssl) host=host, port=port, **kwargs)
self.printer = IRCPrinter self.printer = IRCPrinter
self.encoding = encoding self.encoding = encoding
@ -232,29 +232,29 @@ class IRC(SocketServer):
# Open/close # Open/close
def _open(self): def connect(self):
if SocketServer._open(self): super().connect()
if self.password is not None:
self.write("PASS :" + self.password) if self.password is not None:
if self.capabilities is not None: self.write("PASS :" + self.password)
self.write("CAP LS") if self.capabilities is not None:
self.write("NICK :" + self.nick) self.write("CAP LS")
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) self.write("NICK :" + self.nick)
return True self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
return False
def _close(self): def close(self):
if self.connected: self.write("QUIT") if not self._closed:
return SocketServer._close(self) self.write("QUIT")
return super().close()
# Writes: as inherited # Writes: as inherited
# Read # Read
def read(self): def async_read(self):
for line in SocketServer.read(self): for line in super().async_read():
# PING should be handled here, so start parsing here :/ # PING should be handled here, so start parsing here :/
msg = IRCMessage(line, self.encoding) msg = IRCMessage(line, self.encoding)

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,20 +14,13 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
_lock = threading.Lock() def factory(uri, ssl=False, **init_args):
# Lists for select
_rlist = []
_wlist = []
_xlist = []
def factory(uri, **init_args):
from urllib.parse import urlparse, unquote from urllib.parse import urlparse, unquote
o = urlparse(uri) o = urlparse(uri)
srv = None
if o.scheme == "irc" or o.scheme == "ircs": if o.scheme == "irc" or o.scheme == "ircs":
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt # http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html # http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
@ -36,7 +29,7 @@ def factory(uri, **init_args):
modifiers = o.path.split(",") modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:]) target = unquote(modifiers.pop(0)[1:])
if o.scheme == "ircs": args["ssl"] = True if o.scheme == "ircs": ssl = True
if o.hostname is not None: args["host"] = o.hostname if o.hostname is not None: args["host"] = o.hostname
if o.port is not None: args["port"] = o.port if o.port is not None: args["port"] = o.port
if o.username is not None: args["username"] = o.username if o.username is not None: args["username"] = o.username
@ -65,24 +58,11 @@ def factory(uri, **init_args):
args["channels"] = [ target ] args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer from nemubot.server.IRC import IRC as IRCServer
return IRCServer(**args) srv = IRCServer(**args)
else:
return None
if ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
return ctx.wrap_socket(srv)
def reload(): return srv
import imp
import nemubot.server.abstract
imp.reload(nemubot.server.abstract)
import nemubot.server.socket
imp.reload(nemubot.server.socket)
import nemubot.server.IRC
imp.reload(nemubot.server.IRC)
import nemubot.server.message
imp.reload(nemubot.server.message)
nemubot.server.message.reload()

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,71 +14,68 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import io
import logging import logging
import queue import queue
from nemubot.server import _lock, _rlist, _wlist, _xlist from nemubot.bot import sync_act
# Extends from IOBase in order to be compatible with select function
class AbstractServer(io.IOBase): class AbstractServer:
"""An abstract server: handle communication with an IM server""" """An abstract server: handle communication with an IM server"""
def __init__(self, send_callback=None): def __init__(self, name=None, **kwargs):
"""Initialize an abstract server """Initialize an abstract server
Keyword argument: Keyword argument:
send_callback -- Callback when developper want to send a message name -- Identifier of the socket, for convinience
""" """
if not hasattr(self, "id"): self._name = name
raise Exception("No id defined for this server. Please set one!")
self.logger = logging.getLogger("nemubot.server." + self.id) super().__init__(**kwargs)
self.logger = logging.getLogger("nemubot.server." + str(self.name))
self._readbuffer = b''
self._sending_queue = queue.Queue() self._sending_queue = queue.Queue()
if send_callback is not None:
self._send_callback = send_callback
def __del__(self):
print("Server deleted")
@property
def name(self):
if self._name is not None:
return self._name
else: else:
self._send_callback = self._write_select return self.fileno()
# Open/close # Open/close
def __enter__(self): def connect(self, *args, **kwargs):
self.open() """Register the server in _poll"""
return self
self.logger.info("Opening connection")
super().connect(*args, **kwargs)
self._connected()
def _connected(self):
sync_act("sckt register %d" % self.fileno())
def __exit__(self, type, value, traceback): def close(self, *args, **kwargs):
self.close() """Unregister the server from _poll"""
self.logger.info("Closing connection")
def open(self): if self.fileno() > 0:
"""Generic open function that register the server un _rlist in case sync_act("sckt unregister %d" % self.fileno())
of successful _open"""
self.logger.info("Opening connection to %s", self.id)
if not hasattr(self, "_open") or self._open():
_rlist.append(self)
_xlist.append(self)
return True
return False
super().close(*args, **kwargs)
def close(self):
"""Generic close function that register the server un _{r,w,x}list in
case of successful _close"""
self.logger.info("Closing connection to %s", self.id)
with _lock:
if not hasattr(self, "_close") or self._close():
if self in _rlist:
_rlist.remove(self)
if self in _wlist:
_wlist.remove(self)
if self in _xlist:
_xlist.remove(self)
return True
return False
# Writes # Writes
@ -90,13 +87,16 @@ class AbstractServer(io.IOBase):
message -- message to send message -- message to send
""" """
self._send_callback(message) self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
sync_act("sckt write %d" % self.fileno())
def write_select(self): def async_write(self):
"""Internal function used by the select function""" """Internal function used when the file descriptor is writable"""
try: try:
_wlist.remove(self) sync_act("sckt unwrite %d" % self.fileno())
while not self._sending_queue.empty(): while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait()) self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done() self._sending_queue.task_done()
@ -105,19 +105,6 @@ class AbstractServer(io.IOBase):
pass pass
def _write_select(self, message):
"""Send a message to the server safely through select
Argument:
message -- message to send
"""
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to Queue", message)
if self not in _wlist:
_wlist.append(self)
def send_response(self, response): def send_response(self, response):
"""Send a formated Message class """Send a formated Message class
@ -140,13 +127,39 @@ class AbstractServer(io.IOBase):
# Read # Read
def async_read(self):
"""Internal function used when the file descriptor is readable
Returns:
A list of fully received messages
"""
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
for r in ret:
yield r
def lex(self, buf):
"""Assume lexing in default case is per line
Argument:
buf -- buffer to lex
"""
msgs = buf.split(b'\r\n')
partial = msgs.pop()
return msgs, partial
def parse(self, msg): def parse(self, msg):
raise NotImplemented raise NotImplemented
# Exceptions # Exceptions
def exception(self): def exception(self, flags):
"""Exception occurs in fd""" """Exception occurs on fd"""
self.logger.warning("Unhandle file descriptor exception on server %s",
self.id) self.close()

View file

@ -146,7 +146,7 @@ class IRC(Abstract):
receivers = self.decode(self.params[0]).split(',') receivers = self.decode(self.params[0]).split(',')
common_args = { common_args = {
"server": srv.id, "server": srv.name,
"date": self.tags["time"], "date": self.tags["time"],
"to": receivers, "to": receivers,
"to_response": [r if r != srv.nick else self.nick for r in receivers], "to_response": [r if r != srv.nick else self.nick for r in receivers],

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,99 +14,35 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import socket
import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter from nemubot.message.printer.socket import Socket as SocketPrinter
from nemubot.server.abstract import AbstractServer from nemubot.server.abstract import AbstractServer
class SocketServer(AbstractServer): class _Socket(AbstractServer, socket.socket):
"""Concrete implementation of a socket connexion (can be wrapped with TLS)""" """Concrete implementation of a socket connection"""
def __init__(self, sock_location=None, host=None, port=None, ssl=False, socket=None, id=None): def __init__(self, **kwargs):
if id is not None: """Create a server socket
self.id = id
AbstractServer.__init__(self) Keyword arguments:
if sock_location is not None: ssl -- Should TLS connection enabled
self.filename = sock_location """
elif host is not None:
self.host = host super().__init__(**kwargs)
self.port = int(port)
self.ssl = ssl
self.socket = socket
self.readbuffer = b'' self.readbuffer = b''
self.printer = SocketPrinter self.printer = SocketPrinter
def fileno(self):
return self.socket.fileno() if self.socket else None
@property
def connected(self):
"""Indicator of the connection aliveness"""
return self.socket is not None
# Open/close
def _open(self):
import os
import socket
if self.connected:
return True
try:
if hasattr(self, "filename"):
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.socket.connect(self.filename)
self.logger.info("Connected to %s", self.filename)
else:
self.socket = socket.create_connection((self.host, self.port))
self.logger.info("Connected to %s:%d", self.host, self.port)
except socket.error as e:
self.socket = None
self.logger.critical("Unable to connect to %s:%d: %s",
self.host, self.port,
os.strerror(e.errno))
return False
# Wrap the socket for SSL
if self.ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.socket = ctx.wrap_socket(self.socket)
return True
def _close(self):
import socket
from nemubot.server import _lock
_lock.release()
self._sending_queue.join()
_lock.acquire()
if self.connected:
try:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
except socket.error:
pass
self.socket = None
return True
# Write # Write
def _write(self, cnt): def _write(self, cnt):
if not self.connected: self.sendall(cnt)
return
self.socket.send(cnt)
def format(self, txt): def format(self, txt):
@ -118,80 +54,113 @@ class SocketServer(AbstractServer):
# Read # Read
def read(self, n=1024):
return self.recv(n)
def parse(self, line):
"""Implement a default behaviour for socket"""
import shlex
line = line.strip().decode()
try:
args = shlex.split(line)
except ValueError:
args = line.split(' ')
if len(args):
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
class SocketServer(_Socket):
def __init__(self, host, port, bind=None, **kwargs):
super().__init__(family=socket.AF_INET, **kwargs)
assert(host is not None)
assert(isinstance(port, int))
self._host = host
self._port = port
self._bind = bind
def connect(self):
self.logger.info("Connection to %s:%d", self._host, self._port)
super().connect((self._host, self._port))
if self._bind:
super().bind(self._bind)
class UnixSocket(_Socket):
def __init__(self, location, **kwargs):
super().__init__(family=socket.AF_UNIX, **kwargs)
self._socket_path = location
def connect(self):
self.logger.info("Connection to unix://%s", self._socket_path)
super().connect(self._socket_path)
class _Listener(_Socket):
def __init__(self, new_server_cb, instanciate=_Socket, **kwargs):
super().__init__(**kwargs)
self._instanciate = instanciate
self._new_server_cb = new_server_cb
def read(self): def read(self):
if not self.connected: conn, addr = self.accept()
return [] self.logger.info("Accept new connection from %s", addr)
raw = self.socket.recv(1024) fileno = conn.fileno()
temp = (self.readbuffer + raw).split(b'\r\n') ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
self.readbuffer = temp.pop() ss.connect = ss._connected
self._new_server_cb(ss, autoconnect=True)
for line in temp: return b''
yield line
class SocketListener(AbstractServer): class UnixSocketListener(_Listener, UnixSocket):
def __init__(self, new_server_cb, id, sock_location=None, host=None, port=None, ssl=None): def __init__(self, **kwargs):
self.id = id super().__init__(**kwargs)
AbstractServer.__init__(self)
self.new_server_cb = new_server_cb
self.sock_location = sock_location
self.host = host
self.port = port
self.ssl = ssl
self.nb_son = 0
def fileno(self): def connect(self):
return self.socket.fileno() if self.socket else None self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
try:
os.remove(self._socket_path)
except FileNotFoundError:
pass
self.bind(self._socket_path)
self.listen(5)
self.logger.info("Socket ready for accepting new connections")
self._connected()
@property def close(self):
def connected(self):
"""Indicator of the connection aliveness"""
return self.socket is not None
def _open(self):
import os
import socket
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
if self.sock_location is not None:
try:
os.remove(self.sock_location)
except FileNotFoundError:
pass
self.socket.bind(self.sock_location)
elif self.host is not None and self.port is not None:
self.socket.bind((self.host, self.port))
self.socket.listen(5)
return True
def _close(self):
import os import os
import socket import socket
try: try:
self.socket.shutdown(socket.SHUT_RDWR) self.shutdown(socket.SHUT_RDWR)
self.socket.close()
if self.sock_location is not None:
os.remove(self.sock_location)
except socket.error: except socket.error:
pass pass
# Read super().close()
def read(self): try:
if not self.connected: if self._socket_path is not None:
return [] os.remove(self._socket_path)
except:
conn, addr = self.socket.accept() pass
self.nb_son += 1
ss = SocketServer(id=self.id + "#" + str(self.nb_son), socket=conn)
self.new_server_cb(ss)
return []

View file

@ -13,29 +13,3 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
def reload():
import imp
import nemubot.tools.config
imp.reload(nemubot.tools.config)
import nemubot.tools.countdown
imp.reload(nemubot.tools.countdown)
import nemubot.tools.feed
imp.reload(nemubot.tools.feed)
import nemubot.tools.date
imp.reload(nemubot.tools.date)
import nemubot.tools.human
imp.reload(nemubot.tools.human)
import nemubot.tools.web
imp.reload(nemubot.tools.web)
import nemubot.tools.xmlparser
imp.reload(nemubot.tools.xmlparser)
import nemubot.tools.xmlparser.node
imp.reload(nemubot.tools.xmlparser.node)

View file

@ -57,11 +57,15 @@ def word_distance(str1, str2):
return d[len(str1)][len(str2)] return d[len(str1)][len(str2)]
def guess(pattern, expect): def guess(pattern, expect, max_depth=0):
if max_depth == 0:
max_depth = 1 + len(pattern) / 4
elif max_depth <= -1:
max_depth = len(pattern) - max_depth
if len(expect): if len(expect):
se = sorted([(e, word_distance(pattern, e)) for e in expect], key=lambda x: x[1]) se = sorted([(e, word_distance(pattern, e)) for e in expect], key=lambda x: x[1])
_, m = se[0] _, m = se[0]
for e, wd in se: for e, wd in se:
if wd > m or wd > 1 + len(pattern) / 4: if wd > m or wd > max_depth:
break break
yield e yield e

View file

@ -51,11 +51,13 @@ class XMLParser:
def __init__(self, knodes): def __init__(self, knodes):
self.knodes = knodes self.knodes = knodes
def _reset(self):
self.stack = list() self.stack = list()
self.child = 0 self.child = 0
def parse_file(self, path): def parse_file(self, path):
self._reset()
p = xml.parsers.expat.ParserCreate() p = xml.parsers.expat.ParserCreate()
p.StartElementHandler = self.startElement p.StartElementHandler = self.startElement
@ -69,6 +71,7 @@ class XMLParser:
def parse_string(self, s): def parse_string(self, s):
self._reset()
p = xml.parsers.expat.ParserCreate() p = xml.parsers.expat.ParserCreate()
p.StartElementHandler = self.startElement p.StartElementHandler = self.startElement
@ -126,10 +129,13 @@ class XMLParser:
if hasattr(self.current, "endElement"): if hasattr(self.current, "endElement"):
self.current.endElement(None) self.current.endElement(None)
if hasattr(self.current, "parsedForm") and callable(self.current.parsedForm):
self.stack[-1] = self.current.parsedForm()
# Don't remove root # Don't remove root
if len(self.stack) > 1: if len(self.stack) > 1:
last = self.stack.pop() last = self.stack.pop()
if hasattr(self.current, "addChild"): if hasattr(self.current, "addChild") and callable(self.current.addChild):
if self.current.addChild(name, last): if self.current.addChild(name, last):
return return
raise TypeError(name + " tag not expected in " + self.display_stack()) raise TypeError(name + " tag not expected in " + self.display_stack())

View file

@ -63,13 +63,13 @@ setup(
'nemubot', 'nemubot',
'nemubot.config', 'nemubot.config',
'nemubot.datastore', 'nemubot.datastore',
'nemubot.datastore.nodes',
'nemubot.event', 'nemubot.event',
'nemubot.exception', 'nemubot.exception',
'nemubot.hooks', 'nemubot.hooks',
'nemubot.hooks.keywords', 'nemubot.hooks.keywords',
'nemubot.message', 'nemubot.message',
'nemubot.message.printer', 'nemubot.message.printer',
'nemubot.prompt',
'nemubot.server', 'nemubot.server',
'nemubot.server.message', 'nemubot.server.message',
'nemubot.tools', 'nemubot.tools',