Compare commits

...

22 commits

Author SHA1 Message Date
39936e9d39 Add type for use with mypy 2016-07-10 00:17:37 +02:00
a8706d6213 New printer and parser for bot data, XML-based 2016-07-10 00:17:37 +02:00
e103d22bf2 Use fileno instead of name to index existing servers 2016-07-10 00:15:14 +02:00
1e8cb3a12a Use super() instead of parent class name 2016-07-10 00:15:14 +02:00
fab747fcfd Documentation 2016-07-10 00:15:14 +02:00
c0e489f6b6 In debug mode, display running thread at exit 2016-07-08 22:42:44 +02:00
509446a0f4 Handle case where frm and to have not been filled 2016-07-08 22:27:57 +02:00
4473d9547e Review consumer errors 2016-07-08 22:27:57 +02:00
95fc044783 Remove reload feature
As reload shoudl be done in a particular order, to keep valid types, and because maintaining such system is too complex (currently, it doesn't work for a while), now, a reload is just reload configuration file (and possibly modules)
2016-07-08 22:27:57 +02:00
1dd06f1621 New keywords class that accepts any keywords 2016-07-08 22:27:57 +02:00
f9837abba8 [rnd] Add new function choiceres which pick a random response returned by a given subcommand 2016-07-08 22:27:57 +02:00
670a1319a2 Can attach to the main process 2016-07-08 22:27:57 +02:00
2b1469c03f Remove legacy prompt 2016-07-08 22:27:57 +02:00
8983b9b67c Fix and improve reload process 2016-07-08 22:27:57 +02:00
c0fed51fde New argument: --socketfile that create a socket for internal communication 2016-07-08 22:27:57 +02:00
0b14207c88 New CLI argument: --pidfile, path to store the daemon PID 2016-07-08 22:27:57 +02:00
e17368cf26 Catch SIGUSR1: log threads stack traces 2016-07-08 22:27:57 +02:00
220bc7356e Extract deamonize to a dedicated function that can be called from anywhere 2016-07-08 22:27:57 +02:00
29913bd943 Catch SIGHUP: deep reload 2016-07-08 22:27:57 +02:00
7440bd4222 Do a proper close on SIGINT and SIGTERM 2016-07-08 22:27:57 +02:00
179397e96a Remove prompt at launch 2016-07-08 22:27:57 +02:00
992c847e27 Introducing daemon mode 2016-07-08 22:27:57 +02:00
47 changed files with 1051 additions and 1216 deletions

View file

@ -1,202 +0,0 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import traceback
import sys
from nemubot.hooks import hook
nemubotversion = 3.4
NODATA = True
def getserver(toks, context, prompt, mandatory=False, **kwargs):
"""Choose the server in toks or prompt.
This function modify the tokens list passed as argument"""
if len(toks) > 1 and toks[1] in context.servers:
return context.servers[toks.pop(1)]
elif not mandatory or prompt.selectedServer:
return prompt.selectedServer
else:
from nemubot.prompt.error import PromptError
raise PromptError("Please SELECT a server or give its name in argument.")
@hook("prompt_cmd", "close")
def close(toks, context, **kwargs):
"""Disconnect and forget (remove from the servers list) the server"""
srv = getserver(toks, context=context, mandatory=True, **kwargs)
if srv.close():
del context.servers[srv.id]
return 0
return 1
@hook("prompt_cmd", "connect")
def connect(toks, **kwargs):
"""Make the connexion to a server"""
srv = getserver(toks, mandatory=True, **kwargs)
return not srv.open()
@hook("prompt_cmd", "disconnect")
def disconnect(toks, **kwargs):
"""Close the connection to a server"""
srv = getserver(toks, mandatory=True, **kwargs)
return not srv.close()
@hook("prompt_cmd", "discover")
def discover(toks, context, **kwargs):
"""Discover a new bot on a server"""
srv = getserver(toks, context=context, mandatory=True, **kwargs)
if len(toks) > 1 and "!" in toks[1]:
bot = context.add_networkbot(srv, name)
return not bot.connect()
else:
print(" %s is not a valid fullname, for example: "
"nemubot!nemubotV3@bot.nemunai.re" % ''.join(toks[1:1]))
return 1
@hook("prompt_cmd", "join")
@hook("prompt_cmd", "leave")
@hook("prompt_cmd", "part")
def join(toks, **kwargs):
"""Join or leave a channel"""
srv = getserver(toks, mandatory=True, **kwargs)
if len(toks) <= 2:
print("%s: not enough arguments." % toks[0])
return 1
if toks[0] == "join":
if len(toks) > 2:
srv.write("JOIN %s %s" % (toks[1], toks[2]))
else:
srv.write("JOIN %s" % toks[1])
elif toks[0] == "leave" or toks[0] == "part":
if len(toks) > 2:
srv.write("PART %s :%s" % (toks[1], " ".join(toks[2:])))
else:
srv.write("PART %s" % toks[1])
return 0
@hook("prompt_cmd", "save")
def save_mod(toks, context, **kwargs):
"""Force save module data"""
if len(toks) < 2:
print("save: not enough arguments.")
return 1
wrn = 0
for mod in toks[1:]:
if mod in context.modules:
context.modules[mod].save()
print("save: module `%s´ saved successfully" % mod)
else:
wrn += 1
print("save: no module named `%s´" % mod)
return wrn
@hook("prompt_cmd", "send")
def send(toks, **kwargs):
"""Send a message on a channel"""
srv = getserver(toks, mandatory=True, **kwargs)
# Check the server is connected
if not srv.connected:
print ("send: server `%s' not connected." % srv.id)
return 2
if len(toks) <= 3:
print ("send: not enough arguments.")
return 1
if toks[1] not in srv.channels:
print ("send: channel `%s' not authorized in server `%s'."
% (toks[1], srv.id))
return 3
from nemubot.message import Text
srv.send_response(Text(" ".join(toks[2:]), server=None,
to=[toks[1]]))
return 0
@hook("prompt_cmd", "zap")
def zap(toks, **kwargs):
"""Hard change connexion state"""
srv = getserver(toks, mandatory=True, **kwargs)
srv.connected = not srv.connected
@hook("prompt_cmd", "top")
def top(toks, context, **kwargs):
"""Display consumers load information"""
print("Queue size: %d, %d thread(s) running (counter: %d)" %
(context.cnsr_queue.qsize(),
len(context.cnsr_thrd),
context.cnsr_thrd_size))
if len(context.events) > 0:
print("Events registered: %d, next in %d seconds" %
(len(context.events),
context.events[0].time_left.seconds))
else:
print("No events registered")
for th in context.cnsr_thrd:
if th.is_alive():
print(("#" * 15 + " Stack trace for thread %u " + "#" * 15) %
th.ident)
traceback.print_stack(sys._current_frames()[th.ident])
@hook("prompt_cmd", "netstat")
def netstat(toks, context, **kwargs):
"""Display sockets in use and many other things"""
if len(context.network) > 0:
print("Distant bots connected: %d:" % len(context.network))
for name, bot in context.network.items():
print("# %s:" % name)
print(" * Declared hooks:")
lvl = 0
for hlvl in bot.hooks:
lvl += 1
for hook in (hlvl.all_pre + hlvl.all_post + hlvl.cmd_rgxp +
hlvl.cmd_default + hlvl.ask_rgxp +
hlvl.ask_default + hlvl.msg_rgxp +
hlvl.msg_default):
print(" %s- %s" % (' ' * lvl * 2, hook))
for kind in ["irc_hook", "cmd_hook", "ask_hook", "msg_hook"]:
print(" %s- <%s> %s" % (' ' * lvl * 2, kind,
", ".join(hlvl.__dict__[kind].keys())))
print(" * My tag: %d" % bot.my_tag)
print(" * Tags in use (%d):" % bot.inc_tag)
for tag, (cmd, data) in bot.tags.items():
print(" - %11s: %s « %s »" % (tag, cmd, data))
else:
print("No distant bot connected")

View file

@ -8,7 +8,6 @@ import shlex
from nemubot import context
from nemubot.exception import IMException
from nemubot.hooks import hook
from nemubot.message import Command
from more import Response
@ -32,8 +31,24 @@ def cmd_choicecmd(msg):
choice = shlex.split(random.choice(msg.args))
return [x for x in context.subtreat(Command(choice[0][1:],
choice[1:],
to_response=msg.to_response,
frm=msg.frm,
server=msg.server))]
return [x for x in context.subtreat(context.subparse(msg, choice))]
@hook.command("choiceres")
def cmd_choiceres(msg):
if not len(msg.args):
raise IMException("indicate some command to pick a message from!")
rl = [x for x in context.subtreat(context.subparse(msg, " ".join(msg.args)))]
if len(rl) <= 0:
return rl
r = random.choice(rl)
if isinstance(r, Response):
for i in range(len(r.messages) - 1, -1, -1):
if isinstance(r.messages[i], list):
r.messages = [ random.choice(random.choice(r.messages)) ]
elif isinstance(r.messages[i], str):
r.messages = [ random.choice(r.messages) ]
return r

View file

@ -17,11 +17,15 @@
__version__ = '4.0.dev3'
__author__ = 'nemunaire'
from typing import Optional
from nemubot.modulecontext import ModuleContext
context = ModuleContext(None, None)
def requires_version(min=None, max=None):
def requires_version(min: Optional[int] = None,
max: Optional[int] = None) -> None:
"""Raise ImportError if the current version is not in the given range
Keyword arguments:
@ -38,62 +42,98 @@ def requires_version(min=None, max=None):
"but this is nemubot v%s." % (str(max), __version__))
def reload():
"""Reload code of all Python modules used by nemubot
def attach(pid: int, socketfile: str) -> int:
import socket
import sys
print("nemubot is already launched with PID %d. Attaching to Unix socket at: %s" % (pid, socketfile))
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.connect(socketfile)
except socket.error as e:
sys.stderr.write(str(e))
sys.stderr.write("\n")
return 1
from select import select
try:
print("Connection established.")
while True:
rl, wl, xl = select([sys.stdin, sock], [], [])
if sys.stdin in rl:
line = sys.stdin.readline().strip()
if line == "exit" or line == "quit":
return 0
elif line == "reload":
import os, signal
os.kill(pid, signal.SIGHUP)
print("Reload signal sent. Please wait...")
elif line == "shutdown":
import os, signal
os.kill(pid, signal.SIGTERM)
print("Shutdown signal sent. Please wait...")
elif line == "kill":
import os, signal
os.kill(pid, signal.SIGKILL)
print("Signal sent...")
return 0
elif line == "stack" or line == "stacks":
import os, signal
os.kill(pid, signal.SIGUSR1)
print("Debug signal sent. Consult logs.")
else:
sock.send(line.encode() + b'\r\n')
if sock in rl:
sys.stdout.write(sock.recv(2048).decode())
except KeyboardInterrupt:
pass
except:
return 1
finally:
sock.close()
return 0
def daemonize() -> None:
"""Detach the running process to run as a daemon
"""
import imp
import os
import sys
import nemubot.channel
imp.reload(nemubot.channel)
try:
pid = os.fork()
if pid > 0:
sys.exit(0)
except OSError as err:
sys.stderr.write("Unable to fork: %s\n" % err)
sys.exit(1)
import nemubot.config
imp.reload(nemubot.config)
os.setsid()
os.umask(0)
os.chdir('/')
nemubot.config.reload()
try:
pid = os.fork()
if pid > 0:
sys.exit(0)
except OSError as err:
sys.stderr.write("Unable to fork: %s\n" % err)
sys.exit(1)
import nemubot.consumer
imp.reload(nemubot.consumer)
sys.stdout.flush()
sys.stderr.flush()
si = open(os.devnull, 'r')
so = open(os.devnull, 'a+')
se = open(os.devnull, 'a+')
import nemubot.datastore
imp.reload(nemubot.datastore)
nemubot.datastore.reload()
import nemubot.event
imp.reload(nemubot.event)
import nemubot.exception
imp.reload(nemubot.exception)
nemubot.exception.reload()
import nemubot.hooks
imp.reload(nemubot.hooks)
nemubot.hooks.reload()
import nemubot.importer
imp.reload(nemubot.importer)
import nemubot.message
imp.reload(nemubot.message)
nemubot.message.reload()
import nemubot.prompt
imp.reload(nemubot.prompt)
nemubot.prompt.reload()
import nemubot.server
rl, wl, xl = nemubot.server._rlist, nemubot.server._wlist, nemubot.server._xlist
imp.reload(nemubot.server)
nemubot.server._rlist, nemubot.server._wlist, nemubot.server._xlist = rl, wl, xl
nemubot.server.reload()
import nemubot.tools
imp.reload(nemubot.tools)
nemubot.tools.reload()
os.dup2(si.fileno(), sys.stdin.fileno())
os.dup2(so.fileno(), sys.stdout.fileno())
os.dup2(se.fileno(), sys.stderr.fileno())

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,8 +14,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
def main():
def main() -> None:
import os
import signal
import sys
# Parse command line arguments
@ -36,6 +37,15 @@ def main():
default=["./modules/"],
help="directory to use as modules store")
parser.add_argument("-d", "--debug", action="store_true",
help="don't deamonize, keep in foreground")
parser.add_argument("-P", "--pidfile", default="./nemubot.pid",
help="Path to the file where store PID")
parser.add_argument("-S", "--socketfile", default="./nemubot.sock",
help="path where open the socket for internal communication")
parser.add_argument("-l", "--logfile", default="./nemubot.log",
help="Path to store logs")
@ -58,10 +68,34 @@ def main():
# Resolve relatives paths
args.data_path = os.path.abspath(os.path.expanduser(args.data_path))
args.pidfile = os.path.abspath(os.path.expanduser(args.pidfile))
args.socketfile = os.path.abspath(os.path.expanduser(args.socketfile))
args.logfile = os.path.abspath(os.path.expanduser(args.logfile))
args.files = [ x for x in map(os.path.abspath, args.files)]
args.modules_path = [ x for x in map(os.path.abspath, args.modules_path)]
# Check if an instance is already launched
if args.pidfile is not None and os.path.isfile(args.pidfile):
with open(args.pidfile, "r") as f:
pid = int(f.readline())
try:
os.kill(pid, 0)
except OSError:
pass
else:
from nemubot import attach
sys.exit(attach(pid, args.socketfile))
# Daemonize
if not args.debug:
from nemubot import daemonize
daemonize()
# Store PID to pidfile
if args.pidfile is not None:
with open(args.pidfile, "w+") as f:
f.write(str(os.getpid()))
# Setup loggin interface
import logging
logger = logging.getLogger("nemubot")
@ -70,6 +104,7 @@ def main():
formatter = logging.Formatter(
'%(asctime)s %(name)s %(levelname)s %(message)s')
if args.debug:
ch = logging.StreamHandler()
ch.setFormatter(formatter)
if args.verbose < 2:
@ -98,13 +133,10 @@ def main():
if args.no_connect:
context.noautoconnect = True
# Load the prompt
import nemubot.prompt
prmpt = nemubot.prompt.Prompt()
# Register the hook for futur import
from nemubot.importer import ModuleFinder
sys.meta_path.append(ModuleFinder(context.modules_paths, context.add_module))
module_finder = ModuleFinder(context.modules_paths, context.add_module)
sys.meta_path.append(module_finder)
# Load requested configuration files
for path in args.files:
@ -117,36 +149,57 @@ def main():
for module in args.module:
__import__(module)
print ("Nemubot v%s ready, my PID is %i!" % (nemubot.__version__,
os.getpid()))
while True:
from nemubot.prompt.reset import PromptReset
try:
context.start()
if prmpt.run(context):
break
except PromptReset as e:
if e.type == "quit":
break
try:
import imp
# Reload all other modules
imp.reload(nemubot)
imp.reload(nemubot.prompt)
nemubot.reload()
import nemubot.bot
context = nemubot.bot.hotswap(context)
prmpt = nemubot.prompt.hotswap(prmpt)
print("\033[1;32mContext reloaded\033[0m, now in Nemubot %s" %
nemubot.__version__)
except:
logger.exception("\033[1;31mUnable to reload the prompt due to "
"errors.\033[0m Fix them before trying to reload "
"the prompt.")
# Signals handling
def sigtermhandler(signum, frame):
"""On SIGTERM and SIGINT, quit nicely"""
sigusr1handler(signum, frame)
context.quit()
print("Waiting for other threads shuts down...")
signal.signal(signal.SIGINT, sigtermhandler)
signal.signal(signal.SIGTERM, sigtermhandler)
def sighuphandler(signum, frame):
"""On SIGHUP, perform a deep reload"""
nonlocal context
logger.debug("SIGHUP receive, iniate reload procedure...")
# Reload configuration file
for path in args.files:
if os.path.isfile(path):
context.sync_queue.put_nowait(["loadconf", path])
signal.signal(signal.SIGHUP, sighuphandler)
def sigusr1handler(signum, frame):
"""On SIGHUSR1, display stacktraces"""
import threading, traceback
for threadId, stack in sys._current_frames().items():
thName = "#%d" % threadId
for th in threading.enumerate():
if th.ident == threadId:
thName = th.name
break
logger.debug("########### Thread %s:\n%s",
thName,
"".join(traceback.format_stack(stack)))
signal.signal(signal.SIGUSR1, sigusr1handler)
if args.socketfile:
from nemubot.server.socket import UnixSocketListener
context.add_server(UnixSocketListener(new_server_cb=context.add_server,
location=args.socketfile,
name="master_socket"))
# context can change when performing an hotswap, always join the latest context
oldcontext = None
while oldcontext != context:
oldcontext = context
context.start()
context.join()
# Wait for consumers
logger.info("Waiting for other threads shuts down...")
if args.debug:
sigusr1handler(0, None)
sys.exit(0)
if __name__ == "__main__":

View file

@ -15,9 +15,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime, timezone
import ipaddress
import logging
from multiprocessing import JoinableQueue
import threading
import select
import sys
from typing import Any, Mapping, Optional, Sequence
from nemubot import __version__
from nemubot.consumer import Consumer, EventConsumer, MessageConsumer
@ -26,22 +30,33 @@ import nemubot.hooks
logger = logging.getLogger("nemubot")
sync_queue = JoinableQueue()
def sync_act(*args):
if isinstance(act, bytes):
act = act.decode()
sync_queue.put(act)
class Bot(threading.Thread):
"""Class containing the bot context and ensuring key goals"""
def __init__(self, ip="127.0.0.1", modules_paths=list(),
data_store=datastore.Abstract(), verbosity=0):
def __init__(self,
ip: Optional[ipaddress] = None,
modules_paths: Sequence[str] = list(),
data_store: Optional[datastore.Abstract] = None,
verbosity: int = 0):
"""Initialize the bot context
Keyword arguments:
ip -- The external IP of the bot (default: 127.0.0.1)
modules_paths -- Paths to all directories where looking for module
modules_paths -- Paths to all directories where looking for modules
data_store -- An instance of the nemubot datastore for bot's modules
verbosity -- verbosity level
"""
threading.Thread.__init__(self)
super().__init__(name="Nemubot main")
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
__version__,
@ -51,16 +66,16 @@ class Bot(threading.Thread):
self.stop = None
# External IP for accessing this bot
import ipaddress
self.ip = ipaddress.ip_address(ip)
self.ip = ip if ip is not None else ipaddress.ip_address("127.0.0.1")
# Context paths
self.modules_paths = modules_paths
self.datastore = data_store
self.datastore = data_store if data_store is not None else datastore.Abstract()
self.datastore.open()
# Keep global context: servers and modules
self.servers = dict()
self._poll = select.poll()
self.servers = dict() # types: Mapping[str, AbstractServer]
self.modules = dict()
self.modules_configuration = dict()
@ -137,60 +152,77 @@ class Bot(threading.Thread):
self.cnsr_queue = Queue()
self.cnsr_thrd = list()
self.cnsr_thrd_size = -1
# Synchrone actions to be treated by main thread
self.sync_queue = Queue()
def run(self):
from select import select
from nemubot.server import _lock, _rlist, _wlist, _xlist
self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI)
logger.info("Starting main loop")
self.stop = False
while not self.stop:
with _lock:
try:
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1)
except:
logger.error("Something went wrong in select")
fnd_smth = False
# Looking for invalid server
for r in _rlist:
if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
_rlist.remove(r)
logger.error("Found invalid object in _rlist: " + str(r))
fnd_smth = True
for w in _wlist:
if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
_wlist.remove(w)
logger.error("Found invalid object in _wlist: " + str(w))
fnd_smth = True
for x in _xlist:
if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
_xlist.remove(x)
logger.error("Found invalid object in _xlist: " + str(x))
fnd_smth = True
if not fnd_smth:
logger.exception("Can't continue, sorry")
self.quit()
continue
for fd, flag in self._poll.poll():
print("poll")
# Handle internal socket passing orders
if fd != sync_queue._reader.fileno():
srv = self.servers[fd]
for x in xl:
if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
try:
x.exception()
srv.exception(flag)
except:
logger.exception("Uncatched exception on server exception")
for w in wl:
if srv.fileno() > 0:
if flag & (select.POLLOUT):
try:
w.write_select()
srv.async_write()
except:
logger.exception("Uncatched exception on server write")
for r in rl:
for i in r.read():
if flag & (select.POLLIN | select.POLLPRI):
try:
self.receive_message(r, i)
for i in srv.async_read():
self.receive_message(srv, i)
except:
logger.exception("Uncatched exception on server read")
else:
del self.servers[fd]
# Always check the sync queue
while not sync_queue.empty():
import shlex
args = shlex.split(sync_queue.get())
action = args.pop(0)
logger.info("action: %s: %s", action, args)
if action == "sckt" and len(args) >= 2:
try:
if args[0] == "write":
self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "unwrite":
self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "register":
self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
elif args[0] == "unregister":
self._poll.unregister(int(args[1]))
except:
logger.exception("Unhandled excpetion during action:")
elif action == "exit":
self.quit()
elif action == "loadconf":
for path in action.args:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
sync_queue.task_done()
# Launch new consumer threads if necessary
while self.cnsr_queue.qsize() > self.cnsr_thrd_size:
@ -200,17 +232,7 @@ class Bot(threading.Thread):
c = Consumer(self)
self.cnsr_thrd.append(c)
c.start()
while self.sync_queue.qsize() > 0:
action = self.sync_queue.get_nowait()
if action[0] == "exit":
self.quit()
elif action[0] == "loadconf":
for path in action[1:]:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
self.sync_queue.task_done()
logger.info("Ending main loop")
@ -252,9 +274,9 @@ class Bot(threading.Thread):
srv = server.server(config)
# Add the server in the context
if self.add_server(srv, server.autoconnect):
logger.info("Server '%s' successfully added." % srv.id)
logger.info("Server '%s' successfully added." % srv.name)
else:
logger.error("Can't add server '%s'." % srv.id)
logger.error("Can't add server '%s'." % srv.name)
# Load module and their configuration
for mod in config.modules:
@ -303,7 +325,7 @@ class Bot(threading.Thread):
if type(eid) is uuid.UUID:
evt.id = str(eid)
else:
# Ok, this is quite useless...
# Ok, this is quiet useless...
try:
evt.id = str(uuid.UUID(eid))
except ValueError:
@ -411,10 +433,11 @@ class Bot(threading.Thread):
autoconnect -- connect after add?
"""
if srv.id not in self.servers:
self.servers[srv.id] = srv
fileno = srv.fileno()
if fileno not in self.servers:
self.servers[fileno] = srv
if autoconnect and not hasattr(self, "noautoconnect"):
srv.open()
srv.connect()
return True
else:
@ -436,10 +459,10 @@ class Bot(threading.Thread):
__import__(name)
def add_module(self, module):
def add_module(self, mdl: Any):
"""Add a module to the context, if already exists, unload the
old one before"""
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
module_name = mdl.__spec__.name if hasattr(mdl, "__spec__") else mdl.__name__
if hasattr(self, "stop") and self.stop:
logger.warn("The bot is stopped, can't register new modules")
@ -451,40 +474,40 @@ class Bot(threading.Thread):
# Overwrite print built-in
def prnt(*args):
if hasattr(module, "logger"):
module.logger.info(" ".join([str(s) for s in args]))
if hasattr(mdl, "logger"):
mdl.logger.info(" ".join([str(s) for s in args]))
else:
logger.info("[%s] %s", module_name, " ".join([str(s) for s in args]))
module.print = prnt
mdl.print = prnt
# Create module context
from nemubot.modulecontext import ModuleContext
module.__nemubot_context__ = ModuleContext(self, module)
mdl.__nemubot_context__ = ModuleContext(self, mdl)
if not hasattr(module, "logger"):
module.logger = logging.getLogger("nemubot.module." + module_name)
if not hasattr(mdl, "logger"):
mdl.logger = logging.getLogger("nemubot.module." + module_name)
# Replace imported context by real one
for attr in module.__dict__:
if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext:
module.__dict__[attr] = module.__nemubot_context__
for attr in mdl.__dict__:
if attr != "__nemubot_context__" and type(mdl.__dict__[attr]) == ModuleContext:
mdl.__dict__[attr] = mdl.__nemubot_context__
# Register decorated functions
import nemubot.hooks
for s, h in nemubot.hooks.hook.last_registered:
module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
mdl.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
nemubot.hooks.hook.last_registered = []
# Launch the module
if hasattr(module, "load"):
if hasattr(mdl, "load"):
try:
module.load(module.__nemubot_context__)
mdl.load(mdl.__nemubot_context__)
except:
module.__nemubot_context__.unload()
mdl.__nemubot_context__.unload()
raise
# Save a reference to the module
self.modules[module_name] = module
self.modules[module_name] = mdl
def unload_module(self, name):
@ -527,28 +550,28 @@ class Bot(threading.Thread):
def quit(self):
"""Save and unload modules and disconnect servers"""
self.datastore.close()
if self.event_timer is not None:
logger.info("Stop the event timer...")
self.event_timer.cancel()
logger.info("Save and unload all modules...")
for mod in self.modules.items():
self.unload_module(mod)
logger.info("Close all servers connection...")
for k in self.servers:
self.servers[k].close()
logger.info("Stop consumers")
k = self.cnsr_thrd
for cnsr in k:
cnsr.stop = True
logger.info("Save and unload all modules...")
k = list(self.modules.keys())
for mod in k:
self.unload_module(mod)
logger.info("Close all servers connection...")
k = list(self.servers.keys())
for srv in k:
self.servers[srv].close()
self.datastore.close()
self.stop = True
sync_queue.put("end")
sync_queue.join()
# Treatment
@ -562,20 +585,3 @@ class Bot(threading.Thread):
del store[hook.name]
elif isinstance(store, list):
store.remove(hook)
def hotswap(bak):
bak.stop = True
if bak.event_timer is not None:
bak.event_timer.cancel()
bak.datastore.close()
new = Bot(str(bak.ip), bak.modules_paths, bak.datastore)
new.servers = bak.servers
new.modules = bak.modules
new.modules_configuration = bak.modules_configuration
new.events = bak.events
new.hooks = bak.hooks
new._update_event_timer()
return new

View file

@ -15,13 +15,18 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
from typing import Optional
from nemubot.message import Abstract as AbstractMessage
class Channel:
"""A chat room"""
def __init__(self, name, password=None, encoding=None):
def __init__(self,
name: str,
password: Optional[str] = None,
encoding: Optional[str] = None):
"""Initialize the channel
Arguments:
@ -37,7 +42,8 @@ class Channel:
self.topic = ""
self.logger = logging.getLogger("nemubot.channel." + name)
def treat(self, cmd, msg):
def treat(self, cmd: str, msg: AbstractMessage) -> None:
"""Treat a incoming IRC command
Arguments:
@ -60,7 +66,8 @@ class Channel:
elif cmd == "TOPIC":
self.topic = self.text
def join(self, nick, level=0):
def join(self, nick: str, level: int = 0) -> None:
"""Someone join the channel
Argument:
@ -71,7 +78,8 @@ class Channel:
self.logger.debug("%s join", nick)
self.people[nick] = level
def chtopic(self, newtopic):
def chtopic(self, newtopic: str) -> None:
"""Send command to change the topic
Arguments:
@ -81,7 +89,8 @@ class Channel:
self.srv.send_msg(self.name, newtopic, "TOPIC")
self.topic = newtopic
def nick(self, oldnick, newnick):
def nick(self, oldnick: str, newnick: str) -> None:
"""Someone change his nick
Arguments:
@ -95,7 +104,8 @@ class Channel:
del self.people[oldnick]
self.people[newnick] = lvl
def part(self, nick):
def part(self, nick: str) -> None:
"""Someone leave the channel
Argument:
@ -106,7 +116,8 @@ class Channel:
self.logger.debug("%s has left", nick)
del self.people[nick]
def mode(self, msg):
def mode(self, msg: AbstractMessage) -> None:
"""Channel or user mode change
Argument:
@ -132,7 +143,8 @@ class Channel:
elif msg.text[0] == "-v":
self.people[msg.nick] &= ~1
def parse332(self, msg):
def parse332(self, msg: AbstractMessage) -> None:
"""Parse RPL_TOPIC message
Argument:
@ -141,7 +153,8 @@ class Channel:
self.topic = msg.text
def parse353(self, msg):
def parse353(self, msg: AbstractMessage) -> None:
"""Parse RPL_ENDOFWHO message
Argument:

View file

@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
def get_boolean(s):
def get_boolean(s) -> bool:
if isinstance(s, bool):
return s
else:
@ -24,24 +24,3 @@ from nemubot.config.include import Include
from nemubot.config.module import Module
from nemubot.config.nemubot import Nemubot
from nemubot.config.server import Server
def reload():
global Include, Module, Nemubot, Server
import imp
import nemubot.config.include
imp.reload(nemubot.config.include)
Include = nemubot.config.include.Include
import nemubot.config.module
imp.reload(nemubot.config.module)
Module = nemubot.config.module.Module
import nemubot.config.nemubot
imp.reload(nemubot.config.nemubot)
Nemubot = nemubot.config.nemubot.Nemubot
import nemubot.config.server
imp.reload(nemubot.config.server)
Server = nemubot.config.server.Server

View file

@ -16,5 +16,5 @@
class Include:
def __init__(self, path):
def __init__(self, path: str):
self.path = path

View file

@ -15,12 +15,16 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.config import get_boolean
from nemubot.tools.xmlparser.genericnode import GenericNode
from nemubot.datastore.nodes.generic import GenericNode
class Module(GenericNode):
def __init__(self, name, autoload=True, **kwargs):
def __init__(self,
name: str,
autoload: bool = True,
**kwargs):
super().__init__(None, **kwargs)
self.name = name
self.autoload = get_boolean(autoload)

View file

@ -14,15 +14,23 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.config.include import Include
from nemubot.config.module import Module
from nemubot.config.server import Server
from typing import Optional, Sequence, Union
import nemubot.config.include
import nemubot.config.module
import nemubot.config.server
class Nemubot:
def __init__(self, nick="nemubot", realname="nemubot", owner=None,
ip=None, ssl=False, caps=None, encoding="utf-8"):
def __init__(self,
nick: str = "nemubot",
realname: str = "nemubot",
owner: Optional[str] = None,
ip: Optional[str] = None,
ssl: bool = False,
caps: Optional[Sequence[str]] = None,
encoding: str = "utf-8"):
self.nick = nick
self.realname = realname
self.owner = owner
@ -34,13 +42,13 @@ class Nemubot:
self.includes = []
def addChild(self, name, child):
if name == "module" and isinstance(child, Module):
def addChild(self, name: str, child: Union[nemubot.config.module.Module, nemubot.config.server.Server, nemubot.config.include.Include]):
if name == "module" and isinstance(child, nemubot.config.module.Module):
self.modules.append(child)
return True
elif name == "server" and isinstance(child, Server):
elif name == "server" and isinstance(child, nemubot.config.server.Server):
self.servers.append(child)
return True
elif name == "include" and isinstance(child, Include):
elif name == "include" and isinstance(child, nemubot.config.include.Include):
self.includes.append(child)
return True

View file

@ -14,12 +14,19 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Optional, Sequence
from nemubot.channel import Channel
import nemubot.config.nemubot
class Server:
def __init__(self, uri="irc://nemubot@localhost/", autoconnect=True, caps=None, **kwargs):
def __init__(self,
uri: str = "irc://nemubot@localhost/",
autoconnect: bool = True,
caps: Optional[Sequence[str]] = None,
**kwargs):
self.uri = uri
self.autoconnect = autoconnect
self.caps = caps.split(" ") if caps is not None else []
@ -27,7 +34,7 @@ class Server:
self.channels = []
def addChild(self, name, child):
def addChild(self, name: str, child: Channel):
if name == "channel" and isinstance(child, Channel):
self.channels.append(child)
return True

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -17,6 +17,12 @@
import logging
import queue
import threading
from typing import List
from nemubot.bot import Bot
from nemubot.event import ModuleEvent
from nemubot.message.abstract import Abstract as AbstractMessage
from nemubot.server.abstract import AbstractServer
logger = logging.getLogger("nemubot.consumer")
@ -25,18 +31,15 @@ class MessageConsumer:
"""Store a message before treating"""
def __init__(self, srv, msg):
def __init__(self, srv: AbstractServer, msg: AbstractMessage):
self.srv = srv
self.orig = msg
def run(self, context):
def run(self, context: Bot) -> None:
"""Create, parse and treat the message"""
from nemubot.bot import Bot
assert isinstance(context, Bot)
msgs = []
msgs = [] # type: List[AbstractMessage]
# Parse the message
try:
@ -44,14 +47,14 @@ class MessageConsumer:
msgs.append(msg)
except:
logger.exception("Error occurred during the processing of the %s: "
"%s", type(self.msgs[0]).__name__, self.msgs[0])
"%s", type(self.orig).__name__, self.orig)
if len(msgs) <= 0:
return
# Qualify the message
if not hasattr(msg, "server") or msg.server is None:
msg.server = self.srv.id
msg.server = self.srv.name
if hasattr(msg, "frm_owner"):
msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm)
@ -62,15 +65,19 @@ class MessageConsumer:
to_server = None
if isinstance(res, str):
to_server = self.srv
elif not hasattr(res, "server"):
logger.error("No server defined for response of type %s: %s", type(res).__name__, res)
continue
elif res.server is None:
to_server = self.srv
res.server = self.srv.id
elif isinstance(res.server, str) and res.server in context.servers:
res.server = self.srv.name
elif res.server in context.servers:
to_server = context.servers[res.server]
else:
to_server = res.server
if to_server is None:
logger.error("The server defined in this response doesn't "
"exist: %s", res.server)
if to_server is None or not hasattr(to_server, "send_response") or not callable(to_server.send_response):
logger.error("The server defined in this response doesn't exist: %s", res.server)
continue
# Sent the message only if treat_post authorize it
@ -81,12 +88,12 @@ class EventConsumer:
"""Store a event before treating"""
def __init__(self, evt, timeout=20):
def __init__(self, evt: ModuleEvent, timeout: int = 20):
self.evt = evt
self.timeout = timeout
def run(self, context):
def run(self, context: Bot) -> None:
try:
self.evt.check()
except:
@ -107,13 +114,13 @@ class Consumer(threading.Thread):
"""Dequeue and exec requested action"""
def __init__(self, context):
def __init__(self, context: Bot):
self.context = context
self.stop = False
threading.Thread.__init__(self)
super().__init__(name="Nemubot consumer")
def run(self):
def run(self) -> None:
try:
while not self.stop:
stm = self.context.cnsr_queue.get(True, 1)

View file

@ -16,16 +16,3 @@
from nemubot.datastore.abstract import Abstract
from nemubot.datastore.xml import XML
def reload():
global Abstract, XML
import imp
import nemubot.datastore.abstract
imp.reload(nemubot.datastore.abstract)
Abstract = nemubot.datastore.abstract.Abstract
import nemubot.datastore.xml
imp.reload(nemubot.datastore.xml)
XML = nemubot.datastore.xml.XML

View file

@ -23,14 +23,16 @@ class Abstract:
"""Initialize a new empty storage tree
"""
from nemubot.tools.xmlparser import module_state
return module_state.ModuleState("nemubotstate")
return None
def open(self):
return
def close(self):
return
def open(self) -> bool:
return True
def close(self) -> bool:
return True
def load(self, module):
"""Load data for the given module
@ -44,6 +46,7 @@ class Abstract:
return self.new()
def save(self, module, data):
"""Load data for the given module
@ -57,9 +60,11 @@ class Abstract:
return True
def __enter__(self):
self.open()
return self
def __exit__(self, type, value, traceback):
self.close()

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,8 +14,5 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
class PromptError(Exception):
def __init__(self, message):
super(PromptError, self).__init__(message)
self.message = message
from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable

View file

@ -14,27 +14,38 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
class ListNode:
from typing import Any, Mapping, Sequence
from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable
class ListNode(Serializable):
"""XML node representing a Python dictionnnary
"""
serializetag = "list"
def __init__(self, **kwargs):
self.items = list()
self.items = list() # type: Sequence
def addChild(self, name, child):
def addChild(self, name: str, child) -> bool:
self.items.append(child)
return True
def parsedForm(self) -> Sequence:
return self.items
def __len__(self):
def __len__(self) -> int:
return len(self.items)
def __getitem__(self, item):
def __getitem__(self, item: int) -> Any:
return self.items[item]
def __setitem__(self, item, v):
def __setitem__(self, item: int, v: Any) -> None:
self.items[item] = v
def __contains__(self, item):
@ -44,65 +55,60 @@ class ListNode:
return self.items.__repr__()
class DictNode:
def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
for i in self.items:
node.children.append(ParsingNode.serialize_node(i))
return node
class DictNode(Serializable):
"""XML node representing a Python dictionnnary
"""
serializetag = "dict"
def __init__(self, **kwargs):
self.items = dict()
self._cur = None
def startElement(self, name, attrs):
def startElement(self, name: str, attrs: Mapping[str, str]):
if self._cur is None and "key" in attrs:
self._cur = (attrs["key"], "")
return True
self._cur = attrs["key"]
return False
def characters(self, content):
if self._cur is not None:
key, cnt = self._cur
if isinstance(cnt, str):
cnt += content
self._cur = key, cnt
def endElement(self, name):
if name is None or self._cur is None:
return
key, cnt = self._cur
if isinstance(cnt, list) and len(cnt) == 1:
self.items[key] = cnt
else:
self.items[key] = cnt
self._cur = None
return True
def addChild(self, name, child):
def addChild(self, name: str, child: Any):
if self._cur is None:
return False
key, cnt = self._cur
if not isinstance(cnt, list):
cnt = []
cnt.append(child)
self._cur = key, cnt
self.items[self._cur] = child
self._cur = None
return True
def parsedForm(self) -> Mapping:
return self.items
def __getitem__(self, item):
def __getitem__(self, item: str) -> Any:
return self.items[item]
def __setitem__(self, item, v):
def __setitem__(self, item: str, v: str) -> None:
self.items[item] = v
def __contains__(self, item):
def __contains__(self, item: str) -> bool:
return item in self.items
def __repr__(self):
def __repr__(self) -> str:
return self.items.__repr__()
def serialize(self) -> ParsingNode:
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
for k in self.items:
chld = ParsingNode.serialize_node(self.items[k])
chld.attrs["key"] = k
node.children.append(chld)
return node

View file

@ -14,57 +14,109 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Any, Optional, Mapping, Union
from nemubot.datastore.nodes.serializable import Serializable
class ParsingNode:
"""Allow any kind of subtags, just keep parsed ones
"""
def __init__(self, tag=None, **kwargs):
def __init__(self,
tag: Optional[str] = None,
**kwargs):
self.tag = tag
self.attrs = kwargs
self.content = ""
self.children = []
def characters(self, content):
def characters(self, content: str) -> None:
self.content += content
def addChild(self, name, child):
def addChild(self, name: str, child: Any) -> bool:
self.children.append(child)
return True
def hasNode(self, nodename):
def hasNode(self, nodename: str) -> bool:
return self.getNode(nodename) is not None
def getNode(self, nodename):
def getNode(self, nodename: str) -> Optional[Any]:
for c in self.children:
if c is not None and c.tag == nodename:
return c
return None
def __getitem__(self, item):
def __getitem__(self, item: str) -> Any:
return self.attrs[item]
def __contains__(self, item):
def __contains__(self, item: str) -> bool:
return item in self.attrs
def serialize_node(node: Union[Serializable, str, int, float, list, dict],
**def_kwargs):
"""Serialize any node or basic data to a ParsingNode instance"""
if isinstance(node, Serializable):
node = node.serialize()
if isinstance(node, str):
from nemubot.datastore.nodes.python import StringNode
pn = StringNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, int):
from nemubot.datastore.nodes.python import IntNode
pn = IntNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, float):
from nemubot.datastore.nodes.python import FloatNode
pn = FloatNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, list):
from nemubot.datastore.nodes.basic import ListNode
pn = ListNode(**def_kwargs)
pn.items = node
return pn.serialize()
elif isinstance(node, dict):
from nemubot.datastore.nodes.basic import DictNode
pn = DictNode(**def_kwargs)
pn.items = node
return pn.serialize()
else:
assert isinstance(node, ParsingNode)
return node
class GenericNode(ParsingNode):
"""Consider all subtags as dictionnary
"""
def __init__(self, tag, **kwargs):
def __init__(self,
tag: str,
**kwargs):
super().__init__(tag, **kwargs)
self._cur = None
self._deep_cur = 0
def startElement(self, name, attrs):
def startElement(self, name: str, attrs: Mapping[str, str]):
if self._cur is None:
self._cur = GenericNode(name, **attrs)
self._deep_cur = 0
@ -74,14 +126,14 @@ class GenericNode(ParsingNode):
return True
def characters(self, content):
def characters(self, content: str):
if self._cur is None:
super().characters(content)
else:
self._cur.characters(content)
def endElement(self, name):
def endElement(self, name: str):
if name is None:
return

View file

@ -0,0 +1,89 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable
class PythonTypeNode(Serializable):
"""XML node representing a Python simple type
"""
def __init__(self, **kwargs):
self.value = None
self._cnt = ""
def characters(self, content: str) -> None:
self._cnt += content
def endElement(self, name: str) -> None:
raise NotImplemented
def __repr__(self) -> str:
return self.value.__repr__()
def parsedForm(self):
return self.value
def serialize(self):
raise NotImplemented
class IntNode(PythonTypeNode):
serializetag = "int"
def endElement(self, name: str) -> bool:
self.value = int(self._cnt)
return True
def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
class FloatNode(PythonTypeNode):
serializetag = "float"
def endElement(self, name: str) -> bool:
self.value = float(self._cnt)
return True
def serialize(self) -> ParsingNode:
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
class StringNode(PythonTypeNode):
serializetag = "str"
def endElement(self, name):
self.value = str(self._cnt)
return True
def serialize(self):
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node

View file

@ -1,7 +1,5 @@
# coding=utf-8
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -16,8 +14,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
class PromptReset(Exception):
def __init__(self, type):
super(PromptReset, self).__init__("Prompt reset asked")
self.type = type
class Serializable:
def serialize(self):
# Implementations of this function should return ParsingNode items
return NotImplemented

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -17,6 +17,7 @@
import fcntl
import logging
import os
from typing import Any, Mapping
import xml.parsers.expat
from nemubot.datastore.abstract import Abstract
@ -28,7 +29,9 @@ class XML(Abstract):
"""A concrete implementation of a data store that relies on XML files"""
def __init__(self, basedir, rotate=True):
def __init__(self,
basedir: str,
rotate: bool = True):
"""Initialize the datastore
Arguments:
@ -36,17 +39,24 @@ class XML(Abstract):
rotate -- auto-backup files?
"""
self.basedir = basedir
self.basedir = os.path.abspath(basedir)
self.rotate = rotate
self.nb_save = 0
def open(self):
logger.info("Initiate XML datastore at %s, rotation %s",
self.basedir,
"enabled" if self.rotate else "disabled")
def open(self) -> bool:
"""Lock the directory"""
if not os.path.isdir(self.basedir):
logger.debug("Datastore directory not found, creating: %s", self.basedir)
os.mkdir(self.basedir)
lock_path = os.path.join(self.basedir, ".used_by_nemubot")
lock_path = self._get_lock_file_path()
logger.debug("Locking datastore directory via %s", lock_path)
self.lock_file = open(lock_path, 'a+')
ok = True
@ -64,57 +74,92 @@ class XML(Abstract):
self.lock_file.write(str(os.getpid()))
self.lock_file.flush()
logger.info("Datastore successfuly opened at %s", self.basedir)
return True
def close(self):
def close(self) -> bool:
"""Release a locked path"""
if hasattr(self, "lock_file"):
self.lock_file.close()
lock_path = os.path.join(self.basedir, ".used_by_nemubot")
lock_path = self._get_lock_file_path()
if os.path.isdir(self.basedir) and os.path.exists(lock_path):
os.unlink(lock_path)
del self.lock_file
logger.info("Datastore successfully closed at %s", self.basedir)
return True
else:
logger.warn("Datastore not open/locked or lock file not found")
return False
def _get_data_file_path(self, module):
def _get_data_file_path(self, module: str) -> str:
"""Get the path to the module data file"""
return os.path.join(self.basedir, module + ".xml")
def load(self, module):
def _get_lock_file_path(self) -> str:
"""Get the path to the datastore lock file"""
return os.path.join(self.basedir, ".used_by_nemubot")
def load(self, module: str, extendsTags: Mapping[str, Any] = {}) -> Abstract:
"""Load data for the given module
Argument:
module -- the module name of data to load
"""
logger.debug("Trying to load data for %s%s",
module,
(" with tags: " + ", ".join(extendsTags.keys())) if len(extendsTags) else "")
data_file = self._get_data_file_path(module)
def parse(path: str):
from nemubot.tools.xmlparser import XMLParser
from nemubot.datastore.nodes import basic as basicNodes
from nemubot.datastore.nodes import python as pythonNodes
d = {
basicNodes.ListNode.serializetag: basicNodes.ListNode,
basicNodes.DictNode.serializetag: basicNodes.DictNode,
pythonNodes.IntNode.serializetag: pythonNodes.IntNode,
pythonNodes.FloatNode.serializetag: pythonNodes.FloatNode,
pythonNodes.StringNode.serializetag: pythonNodes.StringNode,
}
d.update(extendsTags)
p = XMLParser(d)
return p.parse_file(path)
# Try to load original file
if os.path.isfile(data_file):
from nemubot.tools.xmlparser import parse_file
try:
return parse_file(data_file)
return parse(data_file)
except xml.parsers.expat.ExpatError:
# Try to load from backup
for i in range(10):
path = data_file + "." + str(i)
if os.path.isfile(path):
try:
cnt = parse_file(path)
cnt = parse(path)
logger.warn("Restoring from backup: %s", path)
logger.warn("Restoring data from backup: %s", path)
return cnt
except xml.parsers.expat.ExpatError:
continue
# Default case: initialize a new empty datastore
logger.warn("No data found in store for %s, creating new set", module)
return Abstract.load(self, module)
def _rotate(self, path):
def _rotate(self, path: str) -> None:
"""Backup given path
Argument:
@ -130,7 +175,26 @@ class XML(Abstract):
if os.path.isfile(src):
os.rename(src, dst)
def save(self, module, data):
def _save_node(self, gen, node: Any):
from nemubot.datastore.nodes.generic import ParsingNode
# First, get the serialized form of the node
node = ParsingNode.serialize_node(node)
assert node.tag is not None, "Undefined tag name"
gen.startElement(node.tag, {k: str(node.attrs[k]) for k in node.attrs})
gen.characters(node.content)
for child in node.children:
self._save_node(gen, child)
gen.endElement(node.tag)
def save(self, module: str, data: Any) -> bool:
"""Load data for the given module
Argument:
@ -139,8 +203,22 @@ class XML(Abstract):
"""
path = self._get_data_file_path(module)
logger.debug("Trying to save data for module %s in %s", module, path)
if self.rotate:
self._rotate(path)
return data.save(path)
import tempfile
_, tmpath = tempfile.mkstemp()
with open(tmpath, "w") as f:
import xml.sax.saxutils
gen = xml.sax.saxutils.XMLGenerator(f, "utf-8")
gen.startDocument()
self._save_node(gen, data)
gen.endDocument()
# Atomic save
import shutil
shutil.move(tmpath, path)
return True

View file

@ -15,15 +15,23 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Optional
class ModuleEvent:
"""Representation of a event initiated by a bot module"""
def __init__(self, call=None, call_data=None, func=None, func_data=None,
cmp=None, cmp_data=None, interval=60, offset=0, times=1):
def __init__(self,
call: Callable = None,
call_data: Callable = None,
func: Callable = None,
func_data: Any = None,
cmp: Any = None,
cmp_data: Any = None,
interval: int = 60,
offset: int = 0,
times: int = 1):
"""Initialize the event
Keyword arguments:
@ -70,8 +78,9 @@ class ModuleEvent:
# How many times do this event?
self.times = times
@property
def current(self):
def current(self) -> Optional[datetime.datetime]:
"""Return the date of the near check"""
if self.times != 0:
if self._end is None:
@ -79,8 +88,9 @@ class ModuleEvent:
return self._end
return None
@property
def next(self):
def next(self) -> Optional[datetime.datetime]:
"""Return the date of the next check"""
if self.times != 0:
if self._end is None:
@ -90,14 +100,16 @@ class ModuleEvent:
return self._end
return None
@property
def time_left(self):
def time_left(self) -> Union[datetime.datetime, int]:
"""Return the time left before/after the near check"""
if self.current is not None:
return self.current - datetime.now(timezone.utc)
return 99999 # TODO: 99999 is not a valid time to return
def check(self):
def check(self) -> None:
"""Run a check and realized the event if this is time"""
# Get initial data

View file

@ -32,10 +32,3 @@ class IMException(Exception):
from nemubot.message import Text
return Text(*self.args,
server=msg.server, to=msg.to_response)
def reload():
import imp
import nemubot.exception.Keyword
imp.reload(nemubot.exception.printer.IRC)

View file

@ -49,23 +49,3 @@ class hook:
def pre(*args, store=["pre"], **kwargs):
return hook._add(store, Abstract, *args, **kwargs)
def reload():
import imp
import nemubot.hooks.abstract
imp.reload(nemubot.hooks.abstract)
import nemubot.hooks.command
imp.reload(nemubot.hooks.command)
import nemubot.hooks.message
imp.reload(nemubot.hooks.message)
import nemubot.hooks.keywords
imp.reload(nemubot.hooks.keywords)
nemubot.hooks.keywords.reload()
import nemubot.hooks.manager
imp.reload(nemubot.hooks.manager)

View file

@ -26,11 +26,22 @@ class NoKeyword(Abstract):
return super().check(mkw)
def reload():
import imp
class AnyKeyword(Abstract):
import nemubot.hooks.keywords.abstract
imp.reload(nemubot.hooks.keywords.abstract)
def __init__(self, h):
"""Class that accepts any passed keywords
import nemubot.hooks.keywords.dict
imp.reload(nemubot.hooks.keywords.dict)
Arguments:
h -- Help string
"""
super().__init__()
self.h = h
def check(self, mkw):
return super().check(mkw)
def help(self):
return self.h

View file

@ -21,7 +21,7 @@ class HooksManager:
"""Class to manage hooks"""
def __init__(self, name="core"):
def __init__(self, name: str = "core"):
"""Initialize the manager"""
self.hooks = dict()

View file

@ -18,16 +18,18 @@ from importlib.abc import Finder
from importlib.machinery import SourceFileLoader
import logging
import os
from typing import Callable
logger = logging.getLogger("nemubot.importer")
class ModuleFinder(Finder):
def __init__(self, modules_paths, add_module):
def __init__(self, modules_paths: str, add_module: Callable[]):
self.modules_paths = modules_paths
self.add_module = add_module
def find_module(self, fullname, path=None):
# Search only for new nemubot modules (packages init)
if path is None:

View file

@ -19,27 +19,3 @@ from nemubot.message.text import Text
from nemubot.message.directask import DirectAsk
from nemubot.message.command import Command
from nemubot.message.command import OwnerCommand
def reload():
global Abstract, Text, DirectAsk, Command, OwnerCommand
import imp
import nemubot.message.abstract
imp.reload(nemubot.message.abstract)
Abstract = nemubot.message.abstract.Abstract
imp.reload(nemubot.message.text)
Text = nemubot.message.text.Text
imp.reload(nemubot.message.directask)
DirectAsk = nemubot.message.directask.DirectAsk
imp.reload(nemubot.message.command)
Command = nemubot.message.command.Command
OwnerCommand = nemubot.message.command.OwnerCommand
import nemubot.message.visitor
imp.reload(nemubot.message.visitor)
import nemubot.message.printer
imp.reload(nemubot.message.printer)
nemubot.message.printer.reload()

View file

@ -22,7 +22,7 @@ class Command(Abstract):
"""This class represents a specialized TextMessage"""
def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs):
Abstract.__init__(self, *nargs, **kargs)
super().__init__(*nargs, **kargs)
self.cmd = cmd
self.args = args if args is not None else list()

View file

@ -28,7 +28,7 @@ class DirectAsk(Text):
designated -- the user designated by the message
"""
Text.__init__(self, *args, **kargs)
super().__init__(*args, **kargs)
self.designated = designated

View file

@ -22,4 +22,4 @@ class IRC(SocketPrinter):
def visit_Text(self, msg):
self.pp += "PRIVMSG %s :" % ",".join(msg.to)
SocketPrinter.visit_Text(self, msg)
super().visit_Text(msg)

View file

@ -13,12 +13,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
def reload():
import imp
import nemubot.message.printer.IRC
imp.reload(nemubot.message.printer.IRC)
import nemubot.message.printer.socket
imp.reload(nemubot.message.printer.socket)

View file

@ -35,7 +35,7 @@ class Socket(AbstractVisitor):
others = [to for to in msg.to if to != msg.designated]
# Avoid nick starting message when discussing on user channel
if len(others) != len(msg.to):
if len(others) == 0 or len(others) != len(msg.to):
res = Text(msg.message,
server=msg.server, date=msg.date,
to=msg.to, frm=msg.frm)

View file

@ -0,0 +1,34 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.message.abstract import Abstract
class Response(Abstract):
def __init__(self, cmd, args=None, *nargs, **kargs):
super().__init__(*nargs, **kargs)
self.cmd = cmd
self.args = args if args is not None else list()
def __str__(self):
return self.cmd + " @" + ",@".join(self.args)
@property
def cmds(self):
# TODO: this is for legacy modules
return [self.cmd] + self.args

View file

@ -28,7 +28,7 @@ class Text(Abstract):
message -- the parsed message
"""
Abstract.__init__(self, *args, **kargs)
super().__init__(*args, **kargs)
self.message = message

View file

@ -39,6 +39,7 @@ class ModuleContext:
self.hooks = list()
self.events = list()
self.extendtags = dict()
self.debug = context.verbosity > 0 if context is not None else False
from nemubot.hooks import Abstract as AbstractHook
@ -46,7 +47,7 @@ class ModuleContext:
# Define some callbacks
if context is not None:
def load_data():
return context.datastore.load(module_name)
return context.datastore.load(module_name, extendsTags=self.extendtags)
def add_hook(hook, *triggers):
assert isinstance(hook, AbstractHook), hook
@ -77,8 +78,7 @@ class ModuleContext:
else: # Used when using outside of nemubot
def load_data():
from nemubot.tools.xmlparser import module_state
return module_state.ModuleState("nemubotstate")
return None
def add_hook(hook, *triggers):
assert isinstance(hook, AbstractHook), hook
@ -97,6 +97,8 @@ class ModuleContext:
module.logger.info("Send response: %s", res)
def save():
# Don't save if no data has been access
if hasattr(self, "_data"):
context.datastore.save(module_name, self.data)
def subparse(orig, cnt):
@ -120,6 +122,21 @@ class ModuleContext:
self._data = self.load_data()
return self._data
@data.setter
def data(self, value):
assert value is not None
self._data = value
def register_tags(self, **tags):
self.extendtags.update(tags)
def unregister_tags(self, *tags):
for t in tags:
del self.extendtags[t]
def unload(self):
"""Perform actions for unloading the module"""

View file

@ -1,142 +0,0 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import shlex
import sys
import traceback
from nemubot.prompt import builtins
class Prompt:
def __init__(self):
self.selectedServer = None
self.lastretcode = 0
self.HOOKS_CAPS = dict()
self.HOOKS_LIST = dict()
def add_cap_hook(self, name, call, data=None):
self.HOOKS_CAPS[name] = lambda t, c: call(t, data=data,
context=c, prompt=self)
def add_list_hook(self, name, call):
self.HOOKS_LIST[name] = call
def lex_cmd(self, line):
"""Return an array of tokens
Argument:
line -- the line to lex
"""
try:
cmds = shlex.split(line)
except:
exc_type, exc_value, _ = sys.exc_info()
sys.stderr.write(traceback.format_exception_only(exc_type,
exc_value)[0])
return
bgn = 0
# Separate commands (command separator: ;)
for i in range(0, len(cmds)):
if cmds[i][-1] == ';':
if i != bgn:
yield cmds[bgn:i]
bgn = i + 1
# Return rest of the command (that not end with a ;)
if bgn != len(cmds):
yield cmds[bgn:]
def exec_cmd(self, toks, context):
"""Execute the command
Arguments:
toks -- lexed tokens to executes
context -- current bot context
"""
if toks[0] in builtins.CAPS:
self.lastretcode = builtins.CAPS[toks[0]](toks, context, self)
elif toks[0] in self.HOOKS_CAPS:
self.lastretcode = self.HOOKS_CAPS[toks[0]](toks, context)
else:
print("Unknown command: `%s'" % toks[0])
self.lastretcode = 127
def getPS1(self):
"""Get the PS1 associated to the selected server"""
if self.selectedServer is None:
return "nemubot"
else:
return self.selectedServer.id
def run(self, context):
"""Launch the prompt
Argument:
context -- current bot context
"""
from nemubot.prompt.error import PromptError
from nemubot.prompt.reset import PromptReset
while True: # Stopped by exception
try:
line = input("\033[0;33m%s\033[0;%d\033[0m " %
(self.getPS1(), 31 if self.lastretcode else 32))
cmds = self.lex_cmd(line.strip())
for toks in cmds:
try:
self.exec_cmd(toks, context)
except PromptReset:
raise
except PromptError as e:
print(e.message)
self.lastretcode = 128
except:
exc_type, exc_value, exc_traceback = sys.exc_info()
traceback.print_exception(exc_type, exc_value,
exc_traceback)
except KeyboardInterrupt:
print("")
except EOFError:
print("quit")
return True
def hotswap(bak):
p = Prompt()
p.HOOKS_CAPS = bak.HOOKS_CAPS
p.HOOKS_LIST = bak.HOOKS_LIST
return p
def reload():
import imp
import nemubot.prompt.builtins
imp.reload(nemubot.prompt.builtins)
import nemubot.prompt.error
imp.reload(nemubot.prompt.error)
import nemubot.prompt.reset
imp.reload(nemubot.prompt.reset)

View file

@ -1,128 +0,0 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
def end(toks, context, prompt):
"""Quit the prompt for reload or exit"""
from nemubot.prompt.reset import PromptReset
if toks[0] == "refresh":
raise PromptReset("refresh")
elif toks[0] == "reset":
raise PromptReset("reset")
raise PromptReset("quit")
def liste(toks, context, prompt):
"""Show some lists"""
if len(toks) > 1:
for l in toks[1:]:
l = l.lower()
if l == "server" or l == "servers":
for srv in context.servers.keys():
print (" - %s (state: %s) ;" % (srv,
"connected" if context.servers[srv].connected else "disconnected"))
if len(context.servers) == 0:
print (" > No server loaded")
elif l == "mod" or l == "mods" or l == "module" or l == "modules":
for mod in context.modules.keys():
print (" - %s ;" % mod)
if len(context.modules) == 0:
print (" > No module loaded")
elif l in prompt.HOOKS_LIST:
f, d = prompt.HOOKS_LIST[l]
f(d, context, prompt)
else:
print (" Unknown list `%s'" % l)
return 2
return 0
else:
print (" Please give a list to show: servers, ...")
return 1
def load(toks, context, prompt):
"""Load an XML configuration file"""
if len(toks) > 1:
for filename in toks[1:]:
context.load_file(filename)
else:
print ("Not enough arguments. `load' takes a filename.")
return 1
def select(toks, context, prompt):
"""Select the current server"""
if (len(toks) == 2 and toks[1] != "None" and
toks[1] != "nemubot" and toks[1] != "none"):
if toks[1] in context.servers:
prompt.selectedServer = context.servers[toks[1]]
else:
print ("select: server `%s' not found." % toks[1])
return 1
else:
prompt.selectedServer = None
def unload(toks, context, prompt):
"""Unload a module"""
if len(toks) == 2 and toks[1] == "all":
for name in context.modules.keys():
context.unload_module(name)
elif len(toks) > 1:
for name in toks[1:]:
if context.unload_module(name):
print (" Module `%s' successfully unloaded." % name)
else:
print (" No module `%s' loaded, can't unload!" % name)
return 2
else:
print ("Not enough arguments. `unload' takes a module name.")
return 1
def debug(toks, context, prompt):
"""Enable/Disable debug mode on a module"""
if len(toks) > 1:
for name in toks[1:]:
if name in context.modules:
context.modules[name].DEBUG = not context.modules[name].DEBUG
if context.modules[name].DEBUG:
print (" Module `%s' now in DEBUG mode." % name)
else:
print (" Debug for module module `%s' disabled." % name)
else:
print (" No module `%s' loaded, can't debug!" % name)
return 2
else:
print ("Not enough arguments. `debug' takes a module name.")
return 1
# Register build-ins
CAPS = {
'quit': end, # Disconnect all server and quit
'exit': end, # Alias for quit
'reset': end, # Reload the prompt
'refresh': end, # Reload the prompt but save modules
'load': load, # Load a servers or module configuration file
'unload': unload, # Unload a module and remove it from the list
'select': select, # Select a server
'list': liste, # Show lists
'debug': debug, # Pass a module in debug mode
}

View file

@ -31,7 +31,7 @@ PORTS = list()
class DCC(server.AbstractServer):
def __init__(self, srv, dest, socket=None):
server.Server.__init__(self)
super().__init__(name="Nemubot DCC server")
self.error = False # An error has occur, closing the connection?
self.messages = list() # Message queued before connexion

View file

@ -27,10 +27,10 @@ class IRC(SocketServer):
"""Concrete implementation of a connexion to an IRC server"""
def __init__(self, host="localhost", port=6667, ssl=False, owner=None,
def __init__(self, host="localhost", port=6667, owner=None,
nick="nemubot", username=None, password=None,
realname="Nemubot", encoding="utf-8", caps=None,
channels=list(), on_connect=None):
channels=list(), on_connect=None, **kwargs):
"""Prepare a connection with an IRC server
Keyword arguments:
@ -54,8 +54,8 @@ class IRC(SocketServer):
self.owner = owner
self.realname = realname
self.id = self.username + "@" + host + ":" + str(port)
SocketServer.__init__(self, host=host, port=port, ssl=ssl)
super().__init__(name=self.username + "@" + host + ":" + str(port),
host=host, port=port, **kwargs)
self.printer = IRCPrinter
self.encoding = encoding
@ -232,29 +232,29 @@ class IRC(SocketServer):
# Open/close
def _open(self):
if SocketServer._open(self):
def connect(self):
super().connect()
if self.password is not None:
self.write("PASS :" + self.password)
if self.capabilities is not None:
self.write("CAP LS")
self.write("NICK :" + self.nick)
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
return True
return False
def _close(self):
if self.connected: self.write("QUIT")
return SocketServer._close(self)
def close(self):
if not self._closed:
self.write("QUIT")
return super().close()
# Writes: as inherited
# Read
def read(self):
for line in SocketServer.read(self):
def async_read(self):
for line in super().async_read():
# PING should be handled here, so start parsing here :/
msg = IRCMessage(line, self.encoding)

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,20 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
_lock = threading.Lock()
# Lists for select
_rlist = []
_wlist = []
_xlist = []
def factory(uri, **init_args):
def factory(uri, ssl=False, **init_args):
from urllib.parse import urlparse, unquote
o = urlparse(uri)
srv = None
if o.scheme == "irc" or o.scheme == "ircs":
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
@ -36,7 +29,7 @@ def factory(uri, **init_args):
modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:])
if o.scheme == "ircs": args["ssl"] = True
if o.scheme == "ircs": ssl = True
if o.hostname is not None: args["host"] = o.hostname
if o.port is not None: args["port"] = o.port
if o.username is not None: args["username"] = o.username
@ -65,24 +58,11 @@ def factory(uri, **init_args):
args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer
return IRCServer(**args)
else:
return None
srv = IRCServer(**args)
if ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
return ctx.wrap_socket(srv)
def reload():
import imp
import nemubot.server.abstract
imp.reload(nemubot.server.abstract)
import nemubot.server.socket
imp.reload(nemubot.server.socket)
import nemubot.server.IRC
imp.reload(nemubot.server.IRC)
import nemubot.server.message
imp.reload(nemubot.server.message)
nemubot.server.message.reload()
return srv

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,71 +14,68 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import io
import logging
import queue
from nemubot.server import _lock, _rlist, _wlist, _xlist
from nemubot.bot import sync_act
# Extends from IOBase in order to be compatible with select function
class AbstractServer(io.IOBase):
class AbstractServer:
"""An abstract server: handle communication with an IM server"""
def __init__(self, send_callback=None):
def __init__(self, name=None, **kwargs):
"""Initialize an abstract server
Keyword argument:
send_callback -- Callback when developper want to send a message
name -- Identifier of the socket, for convinience
"""
if not hasattr(self, "id"):
raise Exception("No id defined for this server. Please set one!")
self._name = name
self.logger = logging.getLogger("nemubot.server." + self.id)
super().__init__(**kwargs)
self.logger = logging.getLogger("nemubot.server." + str(self.name))
self._readbuffer = b''
self._sending_queue = queue.Queue()
if send_callback is not None:
self._send_callback = send_callback
def __del__(self):
print("Server deleted")
@property
def name(self):
if self._name is not None:
return self._name
else:
self._send_callback = self._write_select
return self.fileno()
# Open/close
def __enter__(self):
self.open()
return self
def connect(self, *args, **kwargs):
"""Register the server in _poll"""
self.logger.info("Opening connection")
super().connect(*args, **kwargs)
self._connected()
def _connected(self):
sync_act("sckt register %d" % self.fileno())
def __exit__(self, type, value, traceback):
self.close()
def close(self, *args, **kwargs):
"""Unregister the server from _poll"""
self.logger.info("Closing connection")
def open(self):
"""Generic open function that register the server un _rlist in case
of successful _open"""
self.logger.info("Opening connection to %s", self.id)
if not hasattr(self, "_open") or self._open():
_rlist.append(self)
_xlist.append(self)
return True
return False
if self.fileno() > 0:
sync_act("sckt unregister %d" % self.fileno())
def close(self):
"""Generic close function that register the server un _{r,w,x}list in
case of successful _close"""
self.logger.info("Closing connection to %s", self.id)
with _lock:
if not hasattr(self, "_close") or self._close():
if self in _rlist:
_rlist.remove(self)
if self in _wlist:
_wlist.remove(self)
if self in _xlist:
_xlist.remove(self)
return True
return False
super().close(*args, **kwargs)
# Writes
@ -90,13 +87,16 @@ class AbstractServer(io.IOBase):
message -- message to send
"""
self._send_callback(message)
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
sync_act("sckt write %d" % self.fileno())
def write_select(self):
"""Internal function used by the select function"""
def async_write(self):
"""Internal function used when the file descriptor is writable"""
try:
_wlist.remove(self)
sync_act("sckt unwrite %d" % self.fileno())
while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done()
@ -105,19 +105,6 @@ class AbstractServer(io.IOBase):
pass
def _write_select(self, message):
"""Send a message to the server safely through select
Argument:
message -- message to send
"""
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to Queue", message)
if self not in _wlist:
_wlist.append(self)
def send_response(self, response):
"""Send a formated Message class
@ -140,13 +127,39 @@ class AbstractServer(io.IOBase):
# Read
def async_read(self):
"""Internal function used when the file descriptor is readable
Returns:
A list of fully received messages
"""
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
for r in ret:
yield r
def lex(self, buf):
"""Assume lexing in default case is per line
Argument:
buf -- buffer to lex
"""
msgs = buf.split(b'\r\n')
partial = msgs.pop()
return msgs, partial
def parse(self, msg):
raise NotImplemented
# Exceptions
def exception(self):
"""Exception occurs in fd"""
self.logger.warning("Unhandle file descriptor exception on server %s",
self.id)
def exception(self, flags):
"""Exception occurs on fd"""
self.close()

View file

@ -146,7 +146,7 @@ class IRC(Abstract):
receivers = self.decode(self.params[0]).split(',')
common_args = {
"server": srv.id,
"server": srv.name,
"date": self.tags["time"],
"to": receivers,
"to_response": [r if r != srv.nick else self.nick for r in receivers],

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,99 +14,35 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import socket
import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter
from nemubot.server.abstract import AbstractServer
class SocketServer(AbstractServer):
class _Socket(AbstractServer, socket.socket):
"""Concrete implementation of a socket connexion (can be wrapped with TLS)"""
"""Concrete implementation of a socket connection"""
def __init__(self, sock_location=None, host=None, port=None, ssl=False, socket=None, id=None):
if id is not None:
self.id = id
AbstractServer.__init__(self)
if sock_location is not None:
self.filename = sock_location
elif host is not None:
self.host = host
self.port = int(port)
self.ssl = ssl
def __init__(self, **kwargs):
"""Create a server socket
Keyword arguments:
ssl -- Should TLS connection enabled
"""
super().__init__(**kwargs)
self.socket = socket
self.readbuffer = b''
self.printer = SocketPrinter
def fileno(self):
return self.socket.fileno() if self.socket else None
@property
def connected(self):
"""Indicator of the connection aliveness"""
return self.socket is not None
# Open/close
def _open(self):
import os
import socket
if self.connected:
return True
try:
if hasattr(self, "filename"):
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.socket.connect(self.filename)
self.logger.info("Connected to %s", self.filename)
else:
self.socket = socket.create_connection((self.host, self.port))
self.logger.info("Connected to %s:%d", self.host, self.port)
except socket.error as e:
self.socket = None
self.logger.critical("Unable to connect to %s:%d: %s",
self.host, self.port,
os.strerror(e.errno))
return False
# Wrap the socket for SSL
if self.ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.socket = ctx.wrap_socket(self.socket)
return True
def _close(self):
import socket
from nemubot.server import _lock
_lock.release()
self._sending_queue.join()
_lock.acquire()
if self.connected:
try:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
except socket.error:
pass
self.socket = None
return True
# Write
def _write(self, cnt):
if not self.connected:
return
self.socket.send(cnt)
self.sendall(cnt)
def format(self, txt):
@ -118,80 +54,113 @@ class SocketServer(AbstractServer):
# Read
def read(self):
if not self.connected:
return []
raw = self.socket.recv(1024)
temp = (self.readbuffer + raw).split(b'\r\n')
self.readbuffer = temp.pop()
for line in temp:
yield line
def read(self, n=1024):
return self.recv(n)
class SocketListener(AbstractServer):
def parse(self, line):
"""Implement a default behaviour for socket"""
import shlex
def __init__(self, new_server_cb, id, sock_location=None, host=None, port=None, ssl=None):
self.id = id
AbstractServer.__init__(self)
self.new_server_cb = new_server_cb
self.sock_location = sock_location
self.host = host
self.port = port
self.ssl = ssl
self.nb_son = 0
def fileno(self):
return self.socket.fileno() if self.socket else None
@property
def connected(self):
"""Indicator of the connection aliveness"""
return self.socket is not None
def _open(self):
import os
import socket
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
if self.sock_location is not None:
line = line.strip().decode()
try:
os.remove(self.sock_location)
args = shlex.split(line)
except ValueError:
args = line.split(' ')
if len(args):
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
class SocketServer(_Socket):
def __init__(self, host, port, bind=None, **kwargs):
super().__init__(family=socket.AF_INET, **kwargs)
assert(host is not None)
assert(isinstance(port, int))
self._host = host
self._port = port
self._bind = bind
def connect(self):
self.logger.info("Connection to %s:%d", self._host, self._port)
super().connect((self._host, self._port))
if self._bind:
super().bind(self._bind)
class UnixSocket(_Socket):
def __init__(self, location, **kwargs):
super().__init__(family=socket.AF_UNIX, **kwargs)
self._socket_path = location
def connect(self):
self.logger.info("Connection to unix://%s", self._socket_path)
super().connect(self._socket_path)
class _Listener(_Socket):
def __init__(self, new_server_cb, instanciate=_Socket, **kwargs):
super().__init__(**kwargs)
self._instanciate = instanciate
self._new_server_cb = new_server_cb
def read(self):
conn, addr = self.accept()
self.logger.info("Accept new connection from %s", addr)
fileno = conn.fileno()
ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
ss.connect = ss._connected
self._new_server_cb(ss, autoconnect=True)
return b''
class UnixSocketListener(_Listener, UnixSocket):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def connect(self):
self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
try:
os.remove(self._socket_path)
except FileNotFoundError:
pass
self.socket.bind(self.sock_location)
elif self.host is not None and self.port is not None:
self.socket.bind((self.host, self.port))
self.socket.listen(5)
return True
self.bind(self._socket_path)
self.listen(5)
self.logger.info("Socket ready for accepting new connections")
self._connected()
def _close(self):
def close(self):
import os
import socket
try:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
if self.sock_location is not None:
os.remove(self.sock_location)
self.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
# Read
super().close()
def read(self):
if not self.connected:
return []
conn, addr = self.socket.accept()
self.nb_son += 1
ss = SocketServer(id=self.id + "#" + str(self.nb_son), socket=conn)
self.new_server_cb(ss)
return []
try:
if self._socket_path is not None:
os.remove(self._socket_path)
except:
pass

View file

@ -13,29 +13,3 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
def reload():
import imp
import nemubot.tools.config
imp.reload(nemubot.tools.config)
import nemubot.tools.countdown
imp.reload(nemubot.tools.countdown)
import nemubot.tools.feed
imp.reload(nemubot.tools.feed)
import nemubot.tools.date
imp.reload(nemubot.tools.date)
import nemubot.tools.human
imp.reload(nemubot.tools.human)
import nemubot.tools.web
imp.reload(nemubot.tools.web)
import nemubot.tools.xmlparser
imp.reload(nemubot.tools.xmlparser)
import nemubot.tools.xmlparser.node
imp.reload(nemubot.tools.xmlparser.node)

View file

@ -57,11 +57,15 @@ def word_distance(str1, str2):
return d[len(str1)][len(str2)]
def guess(pattern, expect):
def guess(pattern, expect, max_depth=0):
if max_depth == 0:
max_depth = 1 + len(pattern) / 4
elif max_depth <= -1:
max_depth = len(pattern) - max_depth
if len(expect):
se = sorted([(e, word_distance(pattern, e)) for e in expect], key=lambda x: x[1])
_, m = se[0]
for e, wd in se:
if wd > m or wd > 1 + len(pattern) / 4:
if wd > m or wd > max_depth:
break
yield e

View file

@ -51,11 +51,13 @@ class XMLParser:
def __init__(self, knodes):
self.knodes = knodes
def _reset(self):
self.stack = list()
self.child = 0
def parse_file(self, path):
self._reset()
p = xml.parsers.expat.ParserCreate()
p.StartElementHandler = self.startElement
@ -69,6 +71,7 @@ class XMLParser:
def parse_string(self, s):
self._reset()
p = xml.parsers.expat.ParserCreate()
p.StartElementHandler = self.startElement
@ -126,10 +129,13 @@ class XMLParser:
if hasattr(self.current, "endElement"):
self.current.endElement(None)
if hasattr(self.current, "parsedForm") and callable(self.current.parsedForm):
self.stack[-1] = self.current.parsedForm()
# Don't remove root
if len(self.stack) > 1:
last = self.stack.pop()
if hasattr(self.current, "addChild"):
if hasattr(self.current, "addChild") and callable(self.current.addChild):
if self.current.addChild(name, last):
return
raise TypeError(name + " tag not expected in " + self.display_stack())

View file

@ -63,13 +63,13 @@ setup(
'nemubot',
'nemubot.config',
'nemubot.datastore',
'nemubot.datastore.nodes',
'nemubot.event',
'nemubot.exception',
'nemubot.hooks',
'nemubot.hooks.keywords',
'nemubot.message',
'nemubot.message.printer',
'nemubot.prompt',
'nemubot.server',
'nemubot.server.message',
'nemubot.tools',