diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 23cf4a0..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "modules/nextstop/external"] - path = modules/nextstop/external - url = git://github.com/nbr23/NextStop.git diff --git a/README.md b/README.md index aa3b141..1d40faf 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,8 @@ Requirements *nemubot* requires at least Python 3.3 to work. +Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629). + Some modules (like `cve`, `nextstop` or `laposte`) require the [BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/), but the core and framework has no dependency. diff --git a/modules/events.py b/modules/events.py index 2887514..a35c28b 100644 --- a/modules/events.py +++ b/modules/events.py @@ -16,7 +16,7 @@ from more import Response def help_full (): - return "This module store a lot of events: ny, we, " + (", ".join(context.datas.index.keys())) + "\n!eventslist: gets list of timer\n!start /something/: launch a timer" + return "This module store a lot of events: ny, we, " + (", ".join(context.datas.index.keys() if hasattr(context, "datas") else [])) + "\n!eventslist: gets list of timer\n!start /something/: launch a timer" def load(context): diff --git a/modules/nextstop.xml b/modules/nextstop.xml deleted file mode 100644 index d34e8ae..0000000 --- a/modules/nextstop.xml +++ /dev/null @@ -1,4 +0,0 @@ - - - - diff --git a/modules/nextstop/__init__.py b/modules/nextstop/__init__.py deleted file mode 100644 index 9530ab8..0000000 --- a/modules/nextstop/__init__.py +++ /dev/null @@ -1,55 +0,0 @@ -# coding=utf-8 - -"""Informe les usagers des prochains passages des transports en communs de la RATP""" - -from nemubot.exception import IMException -from nemubot.hooks import hook -from more import Response - -nemubotversion = 3.4 - -from .external.src import ratp - -def help_full (): - return "!ratp transport line [station]: Donne des informations sur les prochains passages du transport en commun séléctionné à l'arrêt désiré. Si aucune station n'est précisée, les liste toutes." - - -@hook.command("ratp") -def ask_ratp(msg): - """Hook entry from !ratp""" - if len(msg.args) >= 3: - transport = msg.args[0] - line = msg.args[1] - station = msg.args[2] - if len(msg.args) == 4: - times = ratp.getNextStopsAtStation(transport, line, station, msg.args[3]) - else: - times = ratp.getNextStopsAtStation(transport, line, station) - - if len(times) == 0: - raise IMException("la station %s n'existe pas sur le %s ligne %s." % (station, transport, line)) - - (time, direction, stationname) = times[0] - return Response(message=["\x03\x02%s\x03\x02 direction %s" % (time, direction) for time, direction, stationname in times], - title="Prochains passages du %s ligne %s à l'arrêt %s" % (transport, line, stationname), - channel=msg.channel) - - elif len(msg.args) == 2: - stations = ratp.getAllStations(msg.args[0], msg.args[1]) - - if len(stations) == 0: - raise IMException("aucune station trouvée.") - return Response([s for s in stations], title="Stations", channel=msg.channel) - - else: - raise IMException("Mauvais usage, merci de spécifier un type de transport et une ligne, ou de consulter l'aide du module.") - -@hook.command("ratp_alert") -def ratp_alert(msg): - if len(msg.args) == 2: - transport = msg.args[0] - cause = msg.args[1] - incidents = ratp.getDisturbance(cause, transport) - return Response(incidents, channel=msg.channel, nomore="No more incidents", count=" (%d more incidents)") - else: - raise IMException("Mauvais usage, merci de spécifier un type de transport et un type d'alerte (alerte, manif, travaux), ou de consulter l'aide du module.") diff --git a/modules/nextstop/external b/modules/nextstop/external deleted file mode 160000 index 3d5c9b2..0000000 --- a/modules/nextstop/external +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3d5c9b2d52fbd214f5aaad00e5f3952de919b3e5 diff --git a/modules/ratp.py b/modules/ratp.py new file mode 100644 index 0000000..7f4b211 --- /dev/null +++ b/modules/ratp.py @@ -0,0 +1,74 @@ +"""Informe les usagers des prochains passages des transports en communs de la RATP""" + +# PYTHON STUFFS ####################################################### + +from nemubot.exception import IMException +from nemubot.hooks import hook +from more import Response + +from nextstop import ratp + +@hook.command("ratp", + help="Affiche les prochains horaires de passage", + help_usage={ + "TRANSPORT": "Affiche les lignes du moyen de transport donné", + "TRANSPORT LINE": "Affiche les stations sur la ligne de transport donnée", + "TRANSPORT LINE STATION": "Affiche les prochains horaires de passage à l'arrêt donné", + "TRANSPORT LINE STATION DESTINATION": "Affiche les prochains horaires de passage dans la direction donnée", + }) +def ask_ratp(msg): + l = len(msg.args) + + transport = msg.args[0] if l > 0 else None + line = msg.args[1] if l > 1 else None + station = msg.args[2] if l > 2 else None + direction = msg.args[3] if l > 3 else None + + if station is not None: + times = sorted(ratp.getNextStopsAtStation(transport, line, station, direction), key=lambda i: i[0]) + + if len(times) == 0: + raise IMException("la station %s n'existe pas sur le %s ligne %s." % (station, transport, line)) + + (time, direction, stationname) = times[0] + return Response(message=["\x03\x02%s\x03\x02 direction %s" % (time, direction) for time, direction, stationname in times], + title="Prochains passages du %s ligne %s à l'arrêt %s" % (transport, line, stationname), + channel=msg.channel) + + elif line is not None: + stations = ratp.getAllStations(transport, line) + + if len(stations) == 0: + raise IMException("aucune station trouvée.") + return Response(stations, title="Stations", channel=msg.channel) + + elif transport is not None: + lines = ratp.getTransportLines(transport) + if len(lines) == 0: + raise IMException("aucune ligne trouvée.") + return Response(lines, title="Lignes", channel=msg.channel) + + else: + raise IMException("précise au moins un moyen de transport.") + + +@hook.command("ratp_alert", + help="Affiche les perturbations en cours sur le réseau") +def ratp_alert(msg): + if len(msg.args) == 0: + raise IMException("précise au moins un moyen de transport.") + + l = len(msg.args) + transport = msg.args[0] if l > 0 else None + line = msg.args[1] if l > 1 else None + + if line is not None: + d = ratp.getDisturbanceFromLine(transport, line) + if "date" in d and d["date"] is not None: + incidents = "Au {date[date]}, {title}: {message}".format(**d) + else: + incidents = "{title}: {message}".format(**d) + else: + incidents = ratp.getDisturbance(None, transport) + + return Response(incidents, channel=msg.channel, nomore="No more incidents", count=" (%d more incidents)") diff --git a/modules/suivi.py b/modules/suivi.py index 79910d4..9e517da 100644 --- a/modules/suivi.py +++ b/modules/suivi.py @@ -2,14 +2,14 @@ # PYTHON STUFF ############################################ -import urllib.request +import json import urllib.parse from bs4 import BeautifulSoup import re from nemubot.hooks import hook from nemubot.exception import IMException -from nemubot.tools.web import getURLContent +from nemubot.tools.web import getURLContent, getJSON from more import Response @@ -17,8 +17,7 @@ from more import Response def get_tnt_info(track_id): values = [] - data = getURLContent('www.tnt.fr/public/suivi_colis/recherche/' - 'visubontransport.do?bonTransport=%s' % track_id) + data = getURLContent('www.tnt.fr/public/suivi_colis/recherche/visubontransport.do?bonTransport=%s' % track_id) soup = BeautifulSoup(data) status_list = soup.find('div', class_='result__content') if not status_list: @@ -32,8 +31,7 @@ def get_tnt_info(track_id): def get_colissimo_info(colissimo_id): - colissimo_data = getURLContent("http://www.colissimo.fr/portail_colissimo/" - "suivre.do?colispart=%s" % colissimo_id) + colissimo_data = getURLContent("http://www.colissimo.fr/portail_colissimo/suivre.do?colispart=%s" % colissimo_id) soup = BeautifulSoup(colissimo_data) dataArray = soup.find(class_='dataArray') @@ -47,9 +45,8 @@ def get_colissimo_info(colissimo_id): def get_chronopost_info(track_id): data = urllib.parse.urlencode({'listeNumeros': track_id}) - track_baseurl = "http://www.chronopost.fr/expedier/" \ - "inputLTNumbersNoJahia.do?lang=fr_FR" - track_data = urllib.request.urlopen(track_baseurl, data.encode('utf-8')) + track_baseurl = "http://www.chronopost.fr/expedier/inputLTNumbersNoJahia.do?lang=fr_FR" + track_data = getURLContent(track_baseurl, data.encode('utf-8')) soup = BeautifulSoup(track_data) infoClass = soup.find(class_='numeroColi2') @@ -65,9 +62,8 @@ def get_chronopost_info(track_id): def get_colisprive_info(track_id): data = urllib.parse.urlencode({'numColis': track_id}) - track_baseurl = "https://www.colisprive.com/moncolis/pages/" \ - "detailColis.aspx" - track_data = urllib.request.urlopen(track_baseurl, data.encode('utf-8')) + track_baseurl = "https://www.colisprive.com/moncolis/pages/detailColis.aspx" + track_data = getURLContent(track_baseurl, data.encode('utf-8')) soup = BeautifulSoup(track_data) dataArray = soup.find(class_='BandeauInfoColis') @@ -82,8 +78,7 @@ def get_laposte_info(laposte_id): data = urllib.parse.urlencode({'id': laposte_id}) laposte_baseurl = "http://www.part.csuivi.courrier.laposte.fr/suivi/index" - laposte_data = urllib.request.urlopen(laposte_baseurl, - data.encode('utf-8')) + laposte_data = getURLContent(laposte_baseurl, data.encode('utf-8')) soup = BeautifulSoup(laposte_data) search_res = soup.find(class_='resultat_rech_simple_table').tbody.tr if (soup.find(class_='resultat_rech_simple_table').thead @@ -112,8 +107,7 @@ def get_postnl_info(postnl_id): data = urllib.parse.urlencode({'barcodes': postnl_id}) postnl_baseurl = "http://www.postnl.post/details/" - postnl_data = urllib.request.urlopen(postnl_baseurl, - data.encode('utf-8')) + postnl_data = getURLContent(postnl_baseurl, data.encode('utf-8')) soup = BeautifulSoup(postnl_data) if (soup.find(id='datatables') and soup.find(id='datatables').tbody @@ -132,6 +126,42 @@ def get_postnl_info(postnl_id): return (post_status.lower(), post_destination, post_date) +def get_fedex_info(fedex_id, lang="en_US"): + data = urllib.parse.urlencode({ + 'data': json.dumps({ + "TrackPackagesRequest": { + "appType": "WTRK", + "appDeviceType": "DESKTOP", + "uniqueKey": "", + "processingParameters": {}, + "trackingInfoList": [ + { + "trackNumberInfo": { + "trackingNumber": str(fedex_id), + "trackingQualifier": "", + "trackingCarrier": "" + } + } + ] + } + }), + 'action': "trackpackages", + 'locale': lang, + 'version': 1, + 'format': "json" + }) + fedex_baseurl = "https://www.fedex.com/trackingCal/track" + + fedex_data = getJSON(fedex_baseurl, data.encode('utf-8')) + + if ("TrackPackagesResponse" in fedex_data and + "packageList" in fedex_data["TrackPackagesResponse"] and + len(fedex_data["TrackPackagesResponse"]["packageList"]) and + not fedex_data["TrackPackagesResponse"]["packageList"][0]["isInvalid"] + ): + return fedex_data["TrackPackagesResponse"]["packageList"][0] + + # TRACKING HANDLERS ################################################### def handle_tnt(tracknum): @@ -189,6 +219,17 @@ def handle_coliprive(tracknum): return ("Colis Privé: \x02%s\x0F : \x02%s\x0F." % (tracknum, info)) +def handle_fedex(tracknum): + info = get_fedex_info(tracknum) + if info: + if info["displayActDeliveryDateTime"] != "": + return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: in \x02{statusLocationCity}, {statusLocationCntryCD}\x0F, delivered on: {displayActDeliveryDateTime}.".format(**info)) + elif info["statusLocationCity"] != "": + return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: estimated delivery: {displayEstDeliveryDateTime}.".format(**info)) + else: + return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: in \x02{statusLocationCity}, {statusLocationCntryCD}\x0F, estimated delivery: {displayEstDeliveryDateTime}.".format(**info)) + + TRACKING_HANDLERS = { 'laposte': handle_laposte, 'postnl': handle_postnl, @@ -196,6 +237,7 @@ TRACKING_HANDLERS = { 'chronopost': handle_chronopost, 'coliprive': handle_coliprive, 'tnt': handle_tnt, + 'fedex': handle_fedex, } diff --git a/modules/weather.py b/modules/weather.py index 1fadc71..1de0eb7 100644 --- a/modules/weather.py +++ b/modules/weather.py @@ -1,6 +1,6 @@ # coding=utf-8 -"""The weather module""" +"""The weather module. Powered by Dark Sky """ import datetime import re @@ -17,7 +17,7 @@ nemubotversion = 4.0 from more import Response -URL_DSAPI = "https://api.forecast.io/forecast/%s/%%s,%%s" +URL_DSAPI = "https://api.darksky.net/forecast/%s/%%s,%%s?lang=%%s&units=%%s" def load(context): if not context.config or "darkskyapikey" not in context.config: @@ -30,34 +30,19 @@ def load(context): URL_DSAPI = URL_DSAPI % context.config["darkskyapikey"] -def help_full (): - return "!weather /city/: Display the current weather in /city/." - - -def fahrenheit2celsius(temp): - return int((temp - 32) * 50/9)/10 - - -def mph2kmph(speed): - return int(speed * 160.9344)/100 - - -def inh2mmh(size): - return int(size * 254)/10 - - def format_wth(wth): - return ("%s °C %s; precipitation (%s %% chance) intensity: %s mm/h; relative humidity: %s %%; wind speed: %s km/h %s°; cloud coverage: %s %%; pressure: %s hPa; ozone: %s DU" % + return ("%s °C %s; precipitation (%s %% chance) intensity: %s mm/h; relative humidity: %s %%; wind speed: %s m/s %s°; cloud coverage: %s %%; pressure: %s hPa; visibility: %s km; ozone: %s DU" % ( - fahrenheit2celsius(wth["temperature"]), + wth["temperature"], wth["summary"], int(wth["precipProbability"] * 100), - inh2mmh(wth["precipIntensity"]), + wth["precipIntensity"], int(wth["humidity"] * 100), - mph2kmph(wth["windSpeed"]), + wth["windSpeed"], wth["windBearing"], int(wth["cloudCover"] * 100), int(wth["pressure"]), + int(wth["visibility"]), int(wth["ozone"]) )) @@ -66,7 +51,7 @@ def format_forecast_daily(wth): return ("%s; between %s-%s °C; precipitation (%s %% chance) intensity: maximum %s mm/h; relative humidity: %s %%; wind speed: %s km/h %s°; cloud coverage: %s %%; pressure: %s hPa; ozone: %s DU" % ( wth["summary"], - fahrenheit2celsius(wth["temperatureMin"]), fahrenheit2celsius(wth["temperatureMax"]), + wth["temperatureMin"], wth["temperatureMax"], int(wth["precipProbability"] * 100), inh2mmh(wth["precipIntensityMax"]), int(wth["humidity"] * 100), @@ -126,8 +111,8 @@ def treat_coord(msg): raise IMException("indique-moi un nom de ville ou des coordonnées.") -def get_json_weather(coords): - wth = web.getJSON(URL_DSAPI % (float(coords[0]), float(coords[1]))) +def get_json_weather(coords, lang="en", units="auto"): + wth = web.getJSON(URL_DSAPI % (float(coords[0]), float(coords[1]), lang, units)) # First read flags if wth is None or "darksky-unavailable" in wth["flags"]: @@ -149,10 +134,16 @@ def cmd_coordinates(msg): return Response("Les coordonnées de %s sont %s,%s" % (msg.args[0], coords["lat"], coords["long"]), channel=msg.channel) -@hook.command("alert") +@hook.command("alert", + keywords={ + "lang=LANG": "change the output language of weather sumarry; default: en", + "units=UNITS": "return weather conditions in the requested units; default: auto", + }) def cmd_alert(msg): loc, coords, specific = treat_coord(msg) - wth = get_json_weather(coords) + wth = get_json_weather(coords, + lang=msg.kwargs["lang"] if "lang" in msg.kwargs else "en", + units=msg.kwargs["units"] if "units" in msg.kwargs else "auto") res = Response(channel=msg.channel, nomore="No more weather alert", count=" (%d more alerts)") @@ -166,10 +157,20 @@ def cmd_alert(msg): return res -@hook.command("météo") +@hook.command("météo", + help="Display current weather and previsions", + help_usage={ + "CITY": "Display the current weather and previsions in CITY", + }, + keywords={ + "lang=LANG": "change the output language of weather sumarry; default: en", + "units=UNITS": "return weather conditions in the requested units; default: auto", + }) def cmd_weather(msg): loc, coords, specific = treat_coord(msg) - wth = get_json_weather(coords) + wth = get_json_weather(coords, + lang=msg.kwargs["lang"] if "lang" in msg.kwargs else "en", + units=msg.kwargs["units"] if "units" in msg.kwargs else "auto") res = Response(channel=msg.channel, nomore="No more weather information") @@ -243,3 +244,7 @@ def parseask(msg): context.save() return Response("ok, j'ai bien noté les coordonnées de %s" % res.group("city"), msg.channel, msg.nick) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/nemubot/__init__.py b/nemubot/__init__.py index a56c472..48de6ea 100644 --- a/nemubot/__init__.py +++ b/nemubot/__init__.py @@ -17,9 +17,9 @@ __version__ = '4.0.dev3' __author__ = 'nemunaire' -from nemubot.modulecontext import ModuleContext +from nemubot.modulecontext import _ModuleContext -context = ModuleContext(None, None) +context = _ModuleContext() def requires_version(min=None, max=None): @@ -53,41 +53,50 @@ def attach(pid, socketfile): sys.stderr.write("\n") return 1 - from select import select + import select + mypoll = select.poll() + + mypoll.register(sys.stdin.fileno(), select.POLLIN | select.POLLPRI) + mypoll.register(sock.fileno(), select.POLLIN | select.POLLPRI) try: while True: - rl, wl, xl = select([sys.stdin, sock], [], []) + for fd, flag in mypoll.poll(): + if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL): + sock.close() + print("Connection closed.") + return 1 - if sys.stdin in rl: - line = sys.stdin.readline().strip() - if line == "exit" or line == "quit": - return 0 - elif line == "reload": - import os, signal - os.kill(pid, signal.SIGHUP) - print("Reload signal sent. Please wait...") + if fd == sys.stdin.fileno(): + line = sys.stdin.readline().strip() + if line == "exit" or line == "quit": + return 0 + elif line == "reload": + import os, signal + os.kill(pid, signal.SIGHUP) + print("Reload signal sent. Please wait...") - elif line == "shutdown": - import os, signal - os.kill(pid, signal.SIGTERM) - print("Shutdown signal sent. Please wait...") + elif line == "shutdown": + import os, signal + os.kill(pid, signal.SIGTERM) + print("Shutdown signal sent. Please wait...") - elif line == "kill": - import os, signal - os.kill(pid, signal.SIGKILL) - print("Signal sent...") - return 0 + elif line == "kill": + import os, signal + os.kill(pid, signal.SIGKILL) + print("Signal sent...") + return 0 - elif line == "stack" or line == "stacks": - import os, signal - os.kill(pid, signal.SIGUSR1) - print("Debug signal sent. Consult logs.") + elif line == "stack" or line == "stacks": + import os, signal + os.kill(pid, signal.SIGUSR1) + print("Debug signal sent. Consult logs.") - else: - sock.send(line.encode() + b'\r\n') + else: + sock.send(line.encode() + b'\r\n') + + if fd == sock.fileno(): + sys.stdout.write(sock.recv(2048).decode()) - if sock in rl: - sys.stdout.write(sock.recv(2048).decode()) except KeyboardInterrupt: pass except: @@ -97,13 +106,28 @@ def attach(pid, socketfile): return 0 -def daemonize(): +def daemonize(socketfile=None, autoattach=True): """Detach the running process to run as a daemon """ import os import sys + if socketfile is not None: + try: + pid = os.fork() + if pid > 0: + if autoattach: + import time + os.waitpid(pid, 0) + time.sleep(1) + sys.exit(attach(pid, socketfile)) + else: + sys.exit(0) + except OSError as err: + sys.stderr.write("Unable to fork: %s\n" % err) + sys.exit(1) + try: pid = os.fork() if pid > 0: diff --git a/nemubot/__main__.py b/nemubot/__main__.py index 5a236f4..e1576fb 100644 --- a/nemubot/__main__.py +++ b/nemubot/__main__.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2017 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -37,6 +37,9 @@ def main(): default=["./modules/"], help="directory to use as modules store") + parser.add_argument("-A", "--no-attach", action="store_true", + help="don't attach after fork") + parser.add_argument("-d", "--debug", action="store_true", help="don't deamonize, keep in foreground") @@ -71,32 +74,10 @@ def main(): args.pidfile = os.path.abspath(os.path.expanduser(args.pidfile)) args.socketfile = os.path.abspath(os.path.expanduser(args.socketfile)) args.logfile = os.path.abspath(os.path.expanduser(args.logfile)) - args.files = [ x for x in map(os.path.abspath, args.files)] - args.modules_path = [ x for x in map(os.path.abspath, args.modules_path)] + args.files = [x for x in map(os.path.abspath, args.files)] + args.modules_path = [x for x in map(os.path.abspath, args.modules_path)] - # Check if an instance is already launched - if args.pidfile is not None and os.path.isfile(args.pidfile): - with open(args.pidfile, "r") as f: - pid = int(f.readline()) - try: - os.kill(pid, 0) - except OSError: - pass - else: - from nemubot import attach - sys.exit(attach(pid, args.socketfile)) - - # Daemonize - if not args.debug: - from nemubot import daemonize - daemonize() - - # Store PID to pidfile - if args.pidfile is not None: - with open(args.pidfile, "w+") as f: - f.write(str(os.getpid())) - - # Setup loggin interface + # Setup logging interface import logging logger = logging.getLogger("nemubot") logger.setLevel(logging.DEBUG) @@ -115,6 +96,18 @@ def main(): fh.setFormatter(formatter) logger.addHandler(fh) + # Check if an instance is already launched + if args.pidfile is not None and os.path.isfile(args.pidfile): + with open(args.pidfile, "r") as f: + pid = int(f.readline()) + try: + os.kill(pid, 0) + except OSError: + pass + else: + from nemubot import attach + sys.exit(attach(pid, args.socketfile)) + # Add modules dir paths modules_paths = list() for path in args.modules_path: @@ -125,7 +118,7 @@ def main(): # Create bot context from nemubot import datastore - from nemubot.bot import Bot + from nemubot.bot import Bot, sync_act context = Bot(modules_paths=modules_paths, data_store=datastore.XML(args.data_path), verbosity=args.verbose) @@ -141,7 +134,7 @@ def main(): # Load requested configuration files for path in args.files: if os.path.isfile(path): - context.sync_queue.put_nowait(["loadconf", path]) + sync_act("loadconf", path) else: logger.error("%s is not a readable file", path) @@ -149,6 +142,17 @@ def main(): for module in args.module: __import__(module) + if args.socketfile: + from nemubot.server.socket import UnixSocketListener + context.add_server(UnixSocketListener(new_server_cb=context.add_server, + location=args.socketfile, + name="master_socket")) + + # Daemonize + if not args.debug: + from nemubot import daemonize + daemonize(args.socketfile, not args.no_attach) + # Signals handling def sigtermhandler(signum, frame): """On SIGTERM and SIGINT, quit nicely""" @@ -165,22 +169,27 @@ def main(): # Reload configuration file for path in args.files: if os.path.isfile(path): - context.sync_queue.put_nowait(["loadconf", path]) + sync_act("loadconf", path) signal.signal(signal.SIGHUP, sighuphandler) def sigusr1handler(signum, frame): """On SIGHUSR1, display stacktraces""" - import traceback + import threading, traceback for threadId, stack in sys._current_frames().items(): - logger.debug("########### Thread %d:\n%s", - threadId, + thName = "#%d" % threadId + for th in threading.enumerate(): + if th.ident == threadId: + thName = th.name + break + logger.debug("########### Thread %s:\n%s", + thName, "".join(traceback.format_stack(stack))) signal.signal(signal.SIGUSR1, sigusr1handler) - if args.socketfile: - from nemubot.server.socket import SocketListener - context.add_server(SocketListener(context.add_server, "master_socket", - sock_location=args.socketfile)) + # Store PID to pidfile + if args.pidfile is not None: + with open(args.pidfile, "w+") as f: + f.write(str(os.getpid())) # context can change when performing an hotswap, always join the latest context oldcontext = None @@ -195,5 +204,6 @@ def main(): sigusr1handler(0, None) sys.exit(0) + if __name__ == "__main__": main() diff --git a/nemubot/bot.py b/nemubot/bot.py index 2657d52..b0d3915 100644 --- a/nemubot/bot.py +++ b/nemubot/bot.py @@ -16,7 +16,9 @@ from datetime import datetime, timezone import logging +from multiprocessing import JoinableQueue import threading +import select import sys from nemubot import __version__ @@ -26,6 +28,11 @@ import nemubot.hooks logger = logging.getLogger("nemubot") +sync_queue = JoinableQueue() + +def sync_act(*args): + sync_queue.put(list(args)) + class Bot(threading.Thread): @@ -42,7 +49,7 @@ class Bot(threading.Thread): verbosity -- verbosity level """ - threading.Thread.__init__(self) + super().__init__(name="Nemubot main") logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)", __version__, @@ -61,6 +68,7 @@ class Bot(threading.Thread): self.datastore.open() # Keep global context: servers and modules + self._poll = select.poll() self.servers = dict() self.modules = dict() self.modules_configuration = dict() @@ -138,60 +146,84 @@ class Bot(threading.Thread): self.cnsr_queue = Queue() self.cnsr_thrd = list() self.cnsr_thrd_size = -1 - # Synchrone actions to be treated by main thread - self.sync_queue = Queue() def run(self): - from select import select - from nemubot.server import _lock, _rlist, _wlist, _xlist + global sync_queue + + # Rewrite the sync_queue, as the daemonization process tend to disturb it + old_sync_queue, sync_queue = sync_queue, JoinableQueue() + while not old_sync_queue.empty(): + sync_queue.put_nowait(old_sync_queue.get()) + + self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI) logger.info("Starting main loop") self.stop = False while not self.stop: - with _lock: - try: - rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1) - except: - logger.error("Something went wrong in select") - fnd_smth = False - # Looking for invalid server - for r in _rlist: - if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0: - _rlist.remove(r) - logger.error("Found invalid object in _rlist: " + str(r)) - fnd_smth = True - for w in _wlist: - if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0: - _wlist.remove(w) - logger.error("Found invalid object in _wlist: " + str(w)) - fnd_smth = True - for x in _xlist: - if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0: - _xlist.remove(x) - logger.error("Found invalid object in _xlist: " + str(x)) - fnd_smth = True - if not fnd_smth: - logger.exception("Can't continue, sorry") - self.quit() - continue + for fd, flag in self._poll.poll(): + # Handle internal socket passing orders + if fd != sync_queue._reader.fileno() and fd in self.servers: + srv = self.servers[fd] - for x in xl: - try: - x.exception() - except: - logger.exception("Uncatched exception on server exception") - for w in wl: - try: - w.write_select() - except: - logger.exception("Uncatched exception on server write") - for r in rl: - for i in r.read(): + if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL): try: - self.receive_message(r, i) + srv.exception(flag) except: - logger.exception("Uncatched exception on server read") + logger.exception("Uncatched exception on server exception") + + if srv.fileno() > 0: + if flag & (select.POLLOUT): + try: + srv.async_write() + except: + logger.exception("Uncatched exception on server write") + + if flag & (select.POLLIN | select.POLLPRI): + try: + for i in srv.async_read(): + self.receive_message(srv, i) + except: + logger.exception("Uncatched exception on server read") + + else: + del self.servers[fd] + + + # Always check the sync queue + while not sync_queue.empty(): + args = sync_queue.get() + action = args.pop(0) + + logger.debug("Executing sync_queue action %s%s", action, args) + + if action == "sckt" and len(args) >= 2: + try: + if args[0] == "write": + self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI) + elif args[0] == "unwrite": + self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI) + + elif args[0] == "register": + self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI) + elif args[0] == "unregister": + self._poll.unregister(int(args[1])) + except: + logger.exception("Unhandled excpetion during action:") + + elif action == "exit": + self.quit() + + elif action == "launch_consumer": + pass # This is treated after the loop + + elif action == "loadconf": + for path in args: + logger.debug("Load configuration from %s", path) + self.load_file(path) + logger.info("Configurations successfully loaded") + + sync_queue.task_done() # Launch new consumer threads if necessary @@ -202,17 +234,7 @@ class Bot(threading.Thread): c = Consumer(self) self.cnsr_thrd.append(c) c.start() - - while self.sync_queue.qsize() > 0: - action = self.sync_queue.get_nowait() - if action[0] == "exit": - self.quit() - elif action[0] == "loadconf": - for path in action[1:]: - logger.debug("Load configuration from %s", path) - self.load_file(path) - logger.info("Configurations successfully loaded") - self.sync_queue.task_done() + sync_queue = None logger.info("Ending main loop") @@ -385,7 +407,13 @@ class Bot(threading.Thread): self.event_timer.cancel() if len(self.events): - remaining = self.events[0].time_left.total_seconds() + try: + remaining = self.events[0].time_left.total_seconds() + except: + logger.exception("An error occurs during event time calculation:") + self.events.pop(0) + return self._update_event_timer() + logger.debug("Update timer: next event in %d seconds", remaining) self.event_timer = threading.Timer(remaining if remaining > 0 else 0, self._end_event_timer) self.event_timer.start() @@ -400,6 +428,7 @@ class Bot(threading.Thread): while len(self.events) > 0 and datetime.now(timezone.utc) >= self.events[0].current: evt = self.events.pop(0) self.cnsr_queue.put_nowait(EventConsumer(evt)) + sync_act("launch_consumer") self._update_event_timer() @@ -419,7 +448,7 @@ class Bot(threading.Thread): self.servers[fileno] = srv self.servers[srv.name] = srv if autoconnect and not hasattr(self, "noautoconnect"): - srv.open() + srv.connect() return True else: @@ -463,7 +492,7 @@ class Bot(threading.Thread): module.print = prnt # Create module context - from nemubot.modulecontext import ModuleContext + from nemubot.modulecontext import _ModuleContext, ModuleContext module.__nemubot_context__ = ModuleContext(self, module) if not hasattr(module, "logger"): @@ -471,7 +500,7 @@ class Bot(threading.Thread): # Replace imported context by real one for attr in module.__dict__: - if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext: + if attr != "__nemubot_context__" and type(module.__dict__[attr]) == _ModuleContext: module.__dict__[attr] = module.__nemubot_context__ # Register decorated functions @@ -532,28 +561,29 @@ class Bot(threading.Thread): def quit(self): """Save and unload modules and disconnect servers""" - self.datastore.close() - if self.event_timer is not None: logger.info("Stop the event timer...") self.event_timer.cancel() + logger.info("Save and unload all modules...") + for mod in self.modules.items(): + self.unload_module(mod) + + logger.info("Close all servers connection...") + for srv in [self.servers[k] for k in self.servers]: + srv.close() + logger.info("Stop consumers") k = self.cnsr_thrd for cnsr in k: cnsr.stop = True - logger.info("Save and unload all modules...") - k = list(self.modules.keys()) - for mod in k: - self.unload_module(mod) + self.datastore.close() - logger.info("Close all servers connection...") - k = list(self.servers.keys()) - for srv in k: - self.servers[srv].close() - - self.stop = True + if self.stop is False or sync_queue is not None: + self.stop = True + sync_act("end") + sync_queue.join() # Treatment diff --git a/nemubot/config/module.py b/nemubot/config/module.py index ab51971..7586697 100644 --- a/nemubot/config/module.py +++ b/nemubot/config/module.py @@ -15,7 +15,7 @@ # along with this program. If not, see . from nemubot.config import get_boolean -from nemubot.tools.xmlparser.genericnode import GenericNode +from nemubot.datastore.nodes.generic import GenericNode class Module(GenericNode): diff --git a/nemubot/datastore/abstract.py b/nemubot/datastore/abstract.py index 96e2c0d..f54bbcd 100644 --- a/nemubot/datastore/abstract.py +++ b/nemubot/datastore/abstract.py @@ -23,8 +23,7 @@ class Abstract: """Initialize a new empty storage tree """ - from nemubot.tools.xmlparser import module_state - return module_state.ModuleState("nemubotstate") + return None def open(self): return diff --git a/nemubot/datastore/nodes/__init__.py b/nemubot/datastore/nodes/__init__.py new file mode 100644 index 0000000..e4b2788 --- /dev/null +++ b/nemubot/datastore/nodes/__init__.py @@ -0,0 +1,18 @@ +# Nemubot is a smart and modulable IM bot. +# Copyright (C) 2012-2016 Mercier Pierre-Olivier +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from nemubot.datastore.nodes.generic import ParsingNode +from nemubot.datastore.nodes.serializable import Serializable diff --git a/nemubot/tools/xmlparser/basic.py b/nemubot/datastore/nodes/basic.py similarity index 67% rename from nemubot/tools/xmlparser/basic.py rename to nemubot/datastore/nodes/basic.py index 8456629..6fbd136 100644 --- a/nemubot/tools/xmlparser/basic.py +++ b/nemubot/datastore/nodes/basic.py @@ -14,11 +14,16 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -class ListNode: +from nemubot.datastore.nodes.serializable import Serializable + + +class ListNode(Serializable): """XML node representing a Python dictionnnary """ + serializetag = "list" + def __init__(self, **kwargs): self.items = list() @@ -27,6 +32,9 @@ class ListNode: self.items.append(child) return True + def parsedForm(self): + return self.items + def __len__(self): return len(self.items) @@ -44,11 +52,21 @@ class ListNode: return self.items.__repr__() -class DictNode: + def serialize(self): + from nemubot.datastore.nodes.generic import ParsingNode + node = ParsingNode(tag=self.serializetag) + for i in self.items: + node.children.append(ParsingNode.serialize_node(i)) + return node + + +class DictNode(Serializable): """XML node representing a Python dictionnnary """ + serializetag = "dict" + def __init__(self, **kwargs): self.items = dict() self._cur = None @@ -56,44 +74,20 @@ class DictNode: def startElement(self, name, attrs): if self._cur is None and "key" in attrs: - self._cur = (attrs["key"], "") - return True + self._cur = attrs["key"] return False - - def characters(self, content): - if self._cur is not None: - key, cnt = self._cur - if isinstance(cnt, str): - cnt += content - self._cur = key, cnt - - - def endElement(self, name): - if name is None or self._cur is None: - return - - key, cnt = self._cur - if isinstance(cnt, list) and len(cnt) == 1: - self.items[key] = cnt - else: - self.items[key] = cnt - - self._cur = None - return True - - def addChild(self, name, child): if self._cur is None: return False - key, cnt = self._cur - if not isinstance(cnt, list): - cnt = [] - cnt.append(child) - self._cur = key, cnt + self.items[self._cur] = child + self._cur = None return True + def parsedForm(self): + return self.items + def __getitem__(self, item): return self.items[item] @@ -106,3 +100,13 @@ class DictNode: def __repr__(self): return self.items.__repr__() + + + def serialize(self): + from nemubot.datastore.nodes.generic import ParsingNode + node = ParsingNode(tag=self.serializetag) + for k in self.items: + chld = ParsingNode.serialize_node(self.items[k]) + chld.attrs["key"] = k + node.children.append(chld) + return node diff --git a/nemubot/tools/xmlparser/genericnode.py b/nemubot/datastore/nodes/generic.py similarity index 64% rename from nemubot/tools/xmlparser/genericnode.py rename to nemubot/datastore/nodes/generic.py index 9c29a23..c9840bc 100644 --- a/nemubot/tools/xmlparser/genericnode.py +++ b/nemubot/datastore/nodes/generic.py @@ -14,6 +14,9 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from nemubot.datastore.nodes.serializable import Serializable + + class ParsingNode: """Allow any kind of subtags, just keep parsed ones @@ -53,6 +56,47 @@ class ParsingNode: return item in self.attrs + def serialize_node(node, **def_kwargs): + """Serialize any node or basic data to a ParsingNode instance""" + + if isinstance(node, Serializable): + node = node.serialize() + + if isinstance(node, str): + from nemubot.datastore.nodes.python import StringNode + pn = StringNode(**def_kwargs) + pn.value = node + return pn + + elif isinstance(node, int): + from nemubot.datastore.nodes.python import IntNode + pn = IntNode(**def_kwargs) + pn.value = node + return pn + + elif isinstance(node, float): + from nemubot.datastore.nodes.python import FloatNode + pn = FloatNode(**def_kwargs) + pn.value = node + return pn + + elif isinstance(node, list): + from nemubot.datastore.nodes.basic import ListNode + pn = ListNode(**def_kwargs) + pn.items = node + return pn.serialize() + + elif isinstance(node, dict): + from nemubot.datastore.nodes.basic import DictNode + pn = DictNode(**def_kwargs) + pn.items = node + return pn.serialize() + + else: + assert isinstance(node, ParsingNode) + return node + + class GenericNode(ParsingNode): """Consider all subtags as dictionnary diff --git a/nemubot/datastore/nodes/python.py b/nemubot/datastore/nodes/python.py new file mode 100644 index 0000000..6e4278b --- /dev/null +++ b/nemubot/datastore/nodes/python.py @@ -0,0 +1,91 @@ +# Nemubot is a smart and modulable IM bot. +# Copyright (C) 2012-2016 Mercier Pierre-Olivier +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from nemubot.datastore.nodes.serializable import Serializable + + +class PythonTypeNode(Serializable): + + """XML node representing a Python simple type + """ + + def __init__(self, **kwargs): + self.value = None + self._cnt = "" + + + def characters(self, content): + self._cnt += content + + + def endElement(self, name): + raise NotImplemented + + + def __repr__(self): + return self.value.__repr__() + + + def parsedForm(self): + return self.value + + def serialize(self): + raise NotImplemented + + +class IntNode(PythonTypeNode): + + serializetag = "int" + + def endElement(self, name): + self.value = int(self._cnt) + return True + + def serialize(self): + from nemubot.datastore.nodes.generic import ParsingNode + node = ParsingNode(tag=self.serializetag) + node.content = str(self.value) + return node + + +class FloatNode(PythonTypeNode): + + serializetag = "float" + + def endElement(self, name): + self.value = float(self._cnt) + return True + + def serialize(self): + from nemubot.datastore.nodes.generic import ParsingNode + node = ParsingNode(tag=self.serializetag) + node.content = str(self.value) + return node + + +class StringNode(PythonTypeNode): + + serializetag = "str" + + def endElement(self, name): + self.value = str(self._cnt) + return True + + def serialize(self): + from nemubot.datastore.nodes.generic import ParsingNode + node = ParsingNode(tag=self.serializetag) + node.content = str(self.value) + return node diff --git a/nemubot/datastore/nodes/serializable.py b/nemubot/datastore/nodes/serializable.py new file mode 100644 index 0000000..e543699 --- /dev/null +++ b/nemubot/datastore/nodes/serializable.py @@ -0,0 +1,22 @@ +# Nemubot is a smart and modulable IM bot. +# Copyright (C) 2012-2016 Mercier Pierre-Olivier +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + + +class Serializable: + + def serialize(self): + # Implementations of this function should return ParsingNode items + return NotImplemented diff --git a/nemubot/datastore/xml.py b/nemubot/datastore/xml.py index 46dca70..a82318d 100644 --- a/nemubot/datastore/xml.py +++ b/nemubot/datastore/xml.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -36,17 +36,24 @@ class XML(Abstract): rotate -- auto-backup files? """ - self.basedir = basedir + self.basedir = os.path.abspath(basedir) self.rotate = rotate self.nb_save = 0 + logger.info("Initiate XML datastore at %s, rotation %s", + self.basedir, + "enabled" if self.rotate else "disabled") + + def open(self): """Lock the directory""" if not os.path.isdir(self.basedir): + logger.debug("Datastore directory not found, creating: %s", self.basedir) os.mkdir(self.basedir) - lock_path = os.path.join(self.basedir, ".used_by_nemubot") + lock_path = self._get_lock_file_path() + logger.debug("Locking datastore directory via %s", lock_path) self.lock_file = open(lock_path, 'a+') ok = True @@ -64,56 +71,95 @@ class XML(Abstract): self.lock_file.write(str(os.getpid())) self.lock_file.flush() + logger.info("Datastore successfuly opened at %s", self.basedir) return True + def close(self): """Release a locked path""" if hasattr(self, "lock_file"): self.lock_file.close() - lock_path = os.path.join(self.basedir, ".used_by_nemubot") + lock_path = self._get_lock_file_path() if os.path.isdir(self.basedir) and os.path.exists(lock_path): os.unlink(lock_path) del self.lock_file + logger.info("Datastore successfully closed at %s", self.basedir) return True + else: + logger.warn("Datastore not open/locked or lock file not found") return False + def _get_data_file_path(self, module): """Get the path to the module data file""" return os.path.join(self.basedir, module + ".xml") - def load(self, module): + + def _get_lock_file_path(self): + """Get the path to the datastore lock file""" + + return os.path.join(self.basedir, ".used_by_nemubot") + + + def load(self, module, extendsTags={}): """Load data for the given module Argument: module -- the module name of data to load """ + logger.debug("Trying to load data for %s%s", + module, + (" with tags: " + ", ".join(extendsTags.keys())) if len(extendsTags) else "") + data_file = self._get_data_file_path(module) + def parse(path): + from nemubot.tools.xmlparser import XMLParser + from nemubot.datastore.nodes import basic as basicNodes + from nemubot.datastore.nodes import python as pythonNodes + from nemubot.message.command import Command + from nemubot.scope import Scope + + d = { + basicNodes.ListNode.serializetag: basicNodes.ListNode, + basicNodes.DictNode.serializetag: basicNodes.DictNode, + pythonNodes.IntNode.serializetag: pythonNodes.IntNode, + pythonNodes.FloatNode.serializetag: pythonNodes.FloatNode, + pythonNodes.StringNode.serializetag: pythonNodes.StringNode, + Command.serializetag: Command, + Scope.serializetag: Scope, + } + d.update(extendsTags) + + p = XMLParser(d) + return p.parse_file(path) + # Try to load original file if os.path.isfile(data_file): - from nemubot.tools.xmlparser import parse_file try: - return parse_file(data_file) + return parse(data_file) except xml.parsers.expat.ExpatError: # Try to load from backup for i in range(10): path = data_file + "." + str(i) if os.path.isfile(path): try: - cnt = parse_file(path) + cnt = parse(path) - logger.warn("Restoring from backup: %s", path) + logger.warn("Restoring data from backup: %s", path) return cnt except xml.parsers.expat.ExpatError: continue # Default case: initialize a new empty datastore + logger.warn("No data found in store for %s, creating new set", module) return Abstract.load(self, module) + def _rotate(self, path): """Backup given path @@ -130,6 +176,25 @@ class XML(Abstract): if os.path.isfile(src): os.rename(src, dst) + + def _save_node(self, gen, node): + from nemubot.datastore.nodes.generic import ParsingNode + + # First, get the serialized form of the node + node = ParsingNode.serialize_node(node) + + assert node.tag is not None, "Undefined tag name" + + gen.startElement(node.tag, {k: str(node.attrs[k]) for k in node.attrs}) + + gen.characters(node.content) + + for child in node.children: + self._save_node(gen, child) + + gen.endElement(node.tag) + + def save(self, module, data): """Load data for the given module @@ -139,8 +204,22 @@ class XML(Abstract): """ path = self._get_data_file_path(module) + logger.debug("Trying to save data for module %s in %s", module, path) if self.rotate: self._rotate(path) - return data.save(path) + import tempfile + _, tmpath = tempfile.mkstemp() + with open(tmpath, "w") as f: + import xml.sax.saxutils + gen = xml.sax.saxutils.XMLGenerator(f, "utf-8") + gen.startDocument() + self._save_node(gen, data) + gen.endDocument() + + # Atomic save + import shutil + shutil.move(tmpath, path) + + return True diff --git a/nemubot/message/abstract.py b/nemubot/message/abstract.py index 5d74549..bf9a030 100644 --- a/nemubot/message/abstract.py +++ b/nemubot/message/abstract.py @@ -16,12 +16,17 @@ from datetime import datetime, timezone +from nemubot.datastore.nodes import Serializable -class Abstract: + +class Abstract(Serializable): """This class represents an abstract message""" - def __init__(self, server=None, date=None, to=None, to_response=None, frm=None): + serializetag = "nemubotAMessage" + + + def __init__(self, server=None, date=None, to=None, to_response=None, frm=None, frm_owner=False): """Initialize an abstract message Arguments: @@ -40,7 +45,7 @@ class Abstract: else [ to_response ]) self.frm = frm # None allowed when it designate this bot - self.frm_owner = False # Filled later, in consumer + self.frm_owner = frm_owner @property @@ -65,6 +70,14 @@ class Abstract: return self.frm + @property + def scope(self): + from nemubot.scope import Scope + return Scope(server=self.server, + channel=self.to_response[0], + nick=self.frm) + + def accept(self, visitor): visitor.visit(self) @@ -78,7 +91,8 @@ class Abstract: "date": self.date, "to": self.to, "to_response": self._to_response, - "frm": self.frm + "frm": self.frm, + "frm_owner": self.frm_owner, } for w in without: @@ -86,3 +100,8 @@ class Abstract: del ret[w] return ret + + + def serialize(self): + from nemubot.datastore.nodes import ParsingNode + return ParsingNode(tag=Abstract.serializetag, **self.export_args()) diff --git a/nemubot/message/command.py b/nemubot/message/command.py index 6c208b2..2fe8893 100644 --- a/nemubot/message/command.py +++ b/nemubot/message/command.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -21,6 +21,9 @@ class Command(Abstract): """This class represents a specialized TextMessage""" + serializetag = "nemubotCommand" + + def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs): super().__init__(*nargs, **kargs) @@ -28,17 +31,35 @@ class Command(Abstract): self.args = args if args is not None else list() self.kwargs = kwargs if kwargs is not None else dict() - def __str__(self): + + def __repr__(self): return self.cmd + " @" + ",@".join(self.args) - @property - def cmds(self): - # TODO: this is for legacy modules - return [self.cmd] + self.args + + def addChild(self, name, child): + if name == "list": + self.args = child + elif name == "dict": + self.kwargs = child + else: + return False + return True + + + def serialize(self): + from nemubot.datastore.nodes import ParsingNode + node = ParsingNode(tag=Command.serializetag, cmd=self.cmd) + if len(self.args): + node.children.append(ParsingNode.serialize_node(self.args)) + if len(self.kwargs): + node.children.append(ParsingNode.serialize_node(self.kwargs)) + return node class OwnerCommand(Command): """This class represents a special command incomming from the owner""" + serializetag = "nemubotOCommand" + pass diff --git a/nemubot/modulecontext.py b/nemubot/modulecontext.py index 1d1b3d0..7befe18 100644 --- a/nemubot/modulecontext.py +++ b/nemubot/modulecontext.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2017 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,105 +14,62 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -class ModuleContext: +class _ModuleContext: - def __init__(self, context, module): - """Initialize the module context - - arguments: - context -- the bot context - module -- the module - """ + def __init__(self, module=None): + self.module = module if module is not None: - module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__ + self.module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__ else: - module_name = "" - - # Load module configuration if exists - if (context is not None and - module_name in context.modules_configuration): - self.config = context.modules_configuration[module_name] - else: - from nemubot.config.module import Module - self.config = Module(module_name) + self.module_name = "" self.hooks = list() self.events = list() - self.debug = context.verbosity > 0 if context is not None else False + self.extendtags = dict() + self.debug = False + from nemubot.config.module import Module + self.config = Module(self.module_name) + + def load_data(self): + return None + + def add_hook(self, hook, *triggers): from nemubot.hooks import Abstract as AbstractHook + assert isinstance(hook, AbstractHook), hook + self.hooks.append((triggers, hook)) - # Define some callbacks - if context is not None: - def load_data(): - return context.datastore.load(module_name) + def del_hook(self, hook, *triggers): + from nemubot.hooks import Abstract as AbstractHook + assert isinstance(hook, AbstractHook), hook + self.hooks.remove((triggers, hook)) - def add_hook(hook, *triggers): - assert isinstance(hook, AbstractHook), hook - self.hooks.append((triggers, hook)) - return context.treater.hm.add_hook(hook, *triggers) + def subtreat(self, msg): + return None - def del_hook(hook, *triggers): - assert isinstance(hook, AbstractHook), hook - self.hooks.remove((triggers, hook)) - return context.treater.hm.del_hooks(*triggers, hook=hook) + def add_event(self, evt, eid=None): + return self.events.append((evt, eid)) - def subtreat(msg): - yield from context.treater.treat_msg(msg) - def add_event(evt, eid=None): - return context.add_event(evt, eid, module_src=module) - def del_event(evt): - return context.del_event(evt, module_src=module) + def del_event(self, evt): + for i in self.events: + e, eid = i + if e == evt: + self.events.remove(i) + return True + return False - def send_response(server, res): - if server in context.servers: - if res.server is not None: - return context.servers[res.server].send_response(res) - else: - return context.servers[server].send_response(res) - else: - module.logger.error("Try to send a message to the unknown server: %s", server) - return False + def send_response(self, server, res): + self.module.logger.info("Send response: %s", res) - else: # Used when using outside of nemubot - def load_data(): - from nemubot.tools.xmlparser import module_state - return module_state.ModuleState("nemubotstate") - - def add_hook(hook, *triggers): - assert isinstance(hook, AbstractHook), hook - self.hooks.append((triggers, hook)) - def del_hook(hook, *triggers): - assert isinstance(hook, AbstractHook), hook - self.hooks.remove((triggers, hook)) - def subtreat(msg): - return None - def add_event(evt, eid=None): - return context.add_event(evt, eid, module_src=module) - def del_event(evt): - return context.del_event(evt, module_src=module) - - def send_response(server, res): - module.logger.info("Send response: %s", res) - - def save(): - context.datastore.save(module_name, self.data) - - def subparse(orig, cnt): - if orig.server in context.servers: - return context.servers[orig.server].subparse(orig, cnt) - - self.load_data = load_data - self.add_hook = add_hook - self.del_hook = del_hook - self.add_event = add_event - self.del_event = del_event - self.save = save - self.send_response = send_response - self.subtreat = subtreat - self.subparse = subparse + def save(self): + # Don't save if no data has been access + if hasattr(self, "_data"): + context.datastore.save(self.module_name, self.data) + def subparse(self, orig, cnt): + if orig.server in self.context.servers: + return self.context.servers[orig.server].subparse(orig, cnt) @property def data(self): @@ -120,6 +77,21 @@ class ModuleContext: self._data = self.load_data() return self._data + @data.setter + def data(self, value): + assert value is not None + + self._data = value + + + def register_tags(self, **tags): + self.extendtags.update(tags) + + + def unregister_tags(self, *tags): + for t in tags: + del self.extendtags[t] + def unload(self): """Perform actions for unloading the module""" @@ -129,7 +101,62 @@ class ModuleContext: self.del_hook(h, *s) # Remove registered events - for e in self.events: - self.del_event(e) + for evt, eid, module_src in self.events: + self.del_event(evt) self.save() + + +class ModuleContext(_ModuleContext): + + def __init__(self, context, *args, **kwargs): + """Initialize the module context + + arguments: + context -- the bot context + module -- the module + """ + + super().__init__(*args, **kwargs) + + # Load module configuration if exists + if self.module_name in context.modules_configuration: + self.config = context.modules_configuration[self.module_name] + + self.context = context + self.debug = context.verbosity > 0 + + + def load_data(self): + return self.context.datastore.load(self.module_name, extendsTags=self.extendtags) + + def add_hook(self, hook, *triggers): + from nemubot.hooks import Abstract as AbstractHook + assert isinstance(hook, AbstractHook), hook + self.hooks.append((triggers, hook)) + return self.context.treater.hm.add_hook(hook, *triggers) + + def del_hook(self, hook, *triggers): + from nemubot.hooks import Abstract as AbstractHook + assert isinstance(hook, AbstractHook), hook + self.hooks.remove((triggers, hook)) + return self.context.treater.hm.del_hooks(*triggers, hook=hook) + + def subtreat(self, msg): + yield from self.context.treater.treat_msg(msg) + + def add_event(self, evt, eid=None): + return self.context.add_event(evt, eid, module_src=self.module) + + def del_event(self, evt): + return self.context.del_event(evt, module_src=self.module) + + def send_response(self, server, res): + if server in self.context.servers: + if res.server is not None: + return self.context.servers[res.server].send_response(res) + else: + return self.context.servers[server].send_response(res) + else: + self.module.logger.error("Try to send a message to the unknown server: %s", server) + return False diff --git a/nemubot/scope.py b/nemubot/scope.py new file mode 100644 index 0000000..5da1542 --- /dev/null +++ b/nemubot/scope.py @@ -0,0 +1,83 @@ +# Nemubot is a smart and modulable IM bot. +# Copyright (C) 2012-2016 Mercier Pierre-Olivier +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from nemubot.datastore.nodes import Serializable + + +class Scope(Serializable): + + + serializetag = "nemubot-scope" + default_limit = "channel" + + + def __init__(self, server, channel, nick, limit=default_limit): + self._server = server + self._channel = channel + self._nick = nick + self._limit = limit + + + def sameServer(self, server): + return self._server is None or self._server == server + + + def sameChannel(self, server, channel): + return self.sameServer(server) and (self._channel is None or self._channel == channel) + + + def sameNick(self, server, channel, nick): + return self.sameChannel(server, channel) and (self._nick is None or self._nick == nick) + + + def check(self, scope, limit=None): + return self.checkScope(scope._server, scope._channel, scope._nick, limit) + + + def checkScope(self, server, channel, nick, limit=None): + if limit is None: limit = self._limit + assert limit == "global" or limit == "server" or limit == "channel" or limit == "nick" + + if limit == "server": + return self.sameServer(server) + elif limit == "channel": + return self.sameChannel(server, channel) + elif limit == "nick": + return self.sameNick(server, channel, nick) + else: + return True + + + def narrow(self, scope): + return scope is None or ( + scope._limit == "global" or + (scope._limit == "server" and (self._limit == "nick" or self._limit == "channel")) or + (scope._limit == "channel" and self._limit == "nick") + ) + + + def serialize(self): + from nemubot.datastore.nodes import ParsingNode + args = {} + if self._server is not None: + args["server"] = self._server + if self._channel is not None: + args["channel"] = self._channel + if self._nick is not None: + args["nick"] = self._nick + if self._limit is not None: + args["limit"] = self._limit + return ParsingNode(tag=self.serializetag, **args) diff --git a/nemubot/server/DCC.py b/nemubot/server/DCC.py index 644a8cb..c1a6852 100644 --- a/nemubot/server/DCC.py +++ b/nemubot/server/DCC.py @@ -31,7 +31,7 @@ PORTS = list() class DCC(server.AbstractServer): def __init__(self, srv, dest, socket=None): - super().__init__(self) + super().__init__(name="Nemubot DCC server") self.error = False # An error has occur, closing the connection? self.messages = list() # Message queued before connexion diff --git a/nemubot/server/IRC.py b/nemubot/server/IRC.py index 08e2bc5..7469abc 100644 --- a/nemubot/server/IRC.py +++ b/nemubot/server/IRC.py @@ -16,21 +16,22 @@ from datetime import datetime import re +import socket from nemubot.channel import Channel from nemubot.message.printer.IRC import IRC as IRCPrinter from nemubot.server.message.IRC import IRC as IRCMessage -from nemubot.server.socket import SocketServer +from nemubot.server.socket import SocketServer, SecureSocketServer -class IRC(SocketServer): +class _IRC: """Concrete implementation of a connexion to an IRC server""" - def __init__(self, host="localhost", port=6667, ssl=False, owner=None, + def __init__(self, host="localhost", port=6667, owner=None, nick="nemubot", username=None, password=None, realname="Nemubot", encoding="utf-8", caps=None, - channels=list(), on_connect=None): + channels=list(), on_connect=None, **kwargs): """Prepare a connection with an IRC server Keyword arguments: @@ -54,7 +55,8 @@ class IRC(SocketServer): self.owner = owner self.realname = realname - super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port)) + super().__init__(name=self.username + "@" + host + ":" + str(port), + host=host, port=port, **kwargs) self.printer = IRCPrinter self.encoding = encoding @@ -231,20 +233,19 @@ class IRC(SocketServer): # Open/close - def open(self): - if super().open(): - if self.password is not None: - self.write("PASS :" + self.password) - if self.capabilities is not None: - self.write("CAP LS") - self.write("NICK :" + self.nick) - self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) - return True - return False + def connect(self): + super().connect() + + if self.password is not None: + self.write("PASS :" + self.password) + if self.capabilities is not None: + self.write("CAP LS") + self.write("NICK :" + self.nick) + self.write("USER %s %s bla :%s" % (self.username, socket.getfqdn(), self.realname)) def close(self): - if not self.closed: + if not self._closed: self.write("QUIT") return super().close() @@ -253,8 +254,8 @@ class IRC(SocketServer): # Read - def read(self): - for line in super().read(): + def async_read(self): + for line in super().async_read(): # PING should be handled here, so start parsing here :/ msg = IRCMessage(line, self.encoding) @@ -273,3 +274,10 @@ class IRC(SocketServer): def subparse(self, orig, cnt): msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding) return msg.to_bot_message(self) + + +class IRC(_IRC, SocketServer): + pass + +class IRC_secure(_IRC, SecureSocketServer): + pass diff --git a/nemubot/server/__init__.py b/nemubot/server/__init__.py index 3c88138..6998ef1 100644 --- a/nemubot/server/__init__.py +++ b/nemubot/server/__init__.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,57 +14,64 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import threading -_lock = threading.Lock() - -# Lists for select -_rlist = [] -_wlist = [] -_xlist = [] - - -def factory(uri, **init_args): - from urllib.parse import urlparse, unquote +def factory(uri, ssl=False, **init_args): + from urllib.parse import urlparse, unquote, parse_qs o = urlparse(uri) + srv = None + if o.scheme == "irc" or o.scheme == "ircs": # http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt # http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html args = init_args - modifiers = o.path.split(",") - target = unquote(modifiers.pop(0)[1:]) - - if o.scheme == "ircs": args["ssl"] = True + if o.scheme == "ircs": ssl = True if o.hostname is not None: args["host"] = o.hostname if o.port is not None: args["port"] = o.port if o.username is not None: args["username"] = o.username if o.password is not None: args["password"] = o.password - queries = o.query.split("&") - for q in queries: - if "=" in q: - key, val = tuple(q.split("=", 1)) - else: - key, val = q, "" - if key == "msg": - if "on_connect" not in args: - args["on_connect"] = [] - args["on_connect"].append("PRIVMSG %s :%s" % (target, unquote(val))) - elif key == "key": - if "channels" not in args: - args["channels"] = [] - args["channels"].append((target, unquote(val))) - elif key == "pass": - args["password"] = unquote(val) - elif key == "charset": - args["encoding"] = unquote(val) + if ssl: + try: + from ssl import create_default_context + args["_context"] = create_default_context() + except ImportError: + # Python 3.3 compat + from ssl import SSLContext, PROTOCOL_TLSv1 + args["_context"] = SSLContext(PROTOCOL_TLSv1) + modifiers = o.path.split(",") + target = unquote(modifiers.pop(0)[1:]) + + # Read query string + params = parse_qs(o.query) + + if "msg" in params: + if "on_connect" not in args: + args["on_connect"] = [] + args["on_connect"].append("PRIVMSG %s :%s" % (target, params["msg"])) + + if "key" in params: + if "channels" not in args: + args["channels"] = [] + args["channels"].append((target, params["key"])) + + if "pass" in params: + args["password"] = params["pass"] + + if "charset" in params: + args["encoding"] = params["charset"] + + # if "channels" not in args and "isnick" not in modifiers: args["channels"] = [ target ] - from nemubot.server.IRC import IRC as IRCServer - return IRCServer(**args) - else: - return None + if ssl: + from nemubot.server.IRC import IRC_secure as SecureIRCServer + srv = SecureIRCServer(**args) + else: + from nemubot.server.IRC import IRC as IRCServer + srv = IRCServer(**args) + + return srv diff --git a/nemubot/server/abstract.py b/nemubot/server/abstract.py index dc2081d..fd25c2d 100644 --- a/nemubot/server/abstract.py +++ b/nemubot/server/abstract.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,34 +14,30 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import io import logging import queue -from nemubot.server import _lock, _rlist, _wlist, _xlist +from nemubot.bot import sync_act -# Extends from IOBase in order to be compatible with select function -class AbstractServer(io.IOBase): + +class AbstractServer: """An abstract server: handle communication with an IM server""" - def __init__(self, name=None, send_callback=None): + def __init__(self, name=None, **kwargs): """Initialize an abstract server Keyword argument: - send_callback -- Callback when developper want to send a message + name -- Identifier of the socket, for convinience """ self._name = name - super().__init__() + super().__init__(**kwargs) - self.logger = logging.getLogger("nemubot.server." + self.name) + self.logger = logging.getLogger("nemubot.server." + str(self.name)) + self._readbuffer = b'' self._sending_queue = queue.Queue() - if send_callback is not None: - self._send_callback = send_callback - else: - self._send_callback = self._write_select @property @@ -54,40 +50,28 @@ class AbstractServer(io.IOBase): # Open/close - def __enter__(self): - self.open() - return self + def connect(self, *args, **kwargs): + """Register the server in _poll""" + + self.logger.info("Opening connection") + + super().connect(*args, **kwargs) + + self._on_connect() + + def _on_connect(self): + sync_act("sckt", "register", self.fileno()) - def __exit__(self, type, value, traceback): - self.close() + def close(self, *args, **kwargs): + """Unregister the server from _poll""" + self.logger.info("Closing connection") - def open(self): - """Generic open function that register the server un _rlist in case - of successful _open""" - self.logger.info("Opening connection to %s", self.id) - if not hasattr(self, "_open") or self._open(): - _rlist.append(self) - _xlist.append(self) - return True - return False + if self.fileno() > 0: + sync_act("sckt", "unregister", self.fileno()) - - def close(self): - """Generic close function that register the server un _{r,w,x}list in - case of successful _close""" - self.logger.info("Closing connection to %s", self.id) - with _lock: - if not hasattr(self, "_close") or self._close(): - if self in _rlist: - _rlist.remove(self) - if self in _wlist: - _wlist.remove(self) - if self in _xlist: - _xlist.remove(self) - return True - return False + super().close(*args, **kwargs) # Writes @@ -99,13 +83,16 @@ class AbstractServer(io.IOBase): message -- message to send """ - self._send_callback(message) + self._sending_queue.put(self.format(message)) + self.logger.debug("Message '%s' appended to write queue", message) + sync_act("sckt", "write", self.fileno()) - def write_select(self): - """Internal function used by the select function""" + def async_write(self): + """Internal function used when the file descriptor is writable""" + try: - _wlist.remove(self) + sync_act("sckt", "unwrite", self.fileno()) while not self._sending_queue.empty(): self._write(self._sending_queue.get_nowait()) self._sending_queue.task_done() @@ -114,19 +101,6 @@ class AbstractServer(io.IOBase): pass - def _write_select(self, message): - """Send a message to the server safely through select - - Argument: - message -- message to send - """ - - self._sending_queue.put(self.format(message)) - self.logger.debug("Message '%s' appended to write queue", message) - if self not in _wlist: - _wlist.append(self) - - def send_response(self, response): """Send a formated Message class @@ -149,13 +123,39 @@ class AbstractServer(io.IOBase): # Read + def async_read(self): + """Internal function used when the file descriptor is readable + + Returns: + A list of fully received messages + """ + + ret, self._readbuffer = self.lex(self._readbuffer + self.read()) + + for r in ret: + yield r + + + def lex(self, buf): + """Assume lexing in default case is per line + + Argument: + buf -- buffer to lex + """ + + msgs = buf.split(b'\r\n') + partial = msgs.pop() + + return msgs, partial + + def parse(self, msg): raise NotImplemented # Exceptions - def exception(self): - """Exception occurs in fd""" - self.logger.warning("Unhandle file descriptor exception on server %s", - self.name) + def exception(self, flags): + """Exception occurs on fd""" + + self.close() diff --git a/nemubot/server/factory_test.py b/nemubot/server/factory_test.py index cc7d35b..e2b6752 100644 --- a/nemubot/server/factory_test.py +++ b/nemubot/server/factory_test.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import socket import unittest from nemubot.server import factory @@ -22,34 +23,36 @@ class TestFactory(unittest.TestCase): def test_IRC1(self): from nemubot.server.IRC import IRC as IRCServer + from nemubot.server.IRC import IRC_secure as IRCSServer # : If omitted, the client must connect to a prespecified default IRC server. server = factory("irc:///") self.assertIsInstance(server, IRCServer) - self.assertEqual(server.host, "localhost") - self.assertFalse(server.ssl) + self.assertEqual(server._sockaddr, + socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4]) server = factory("ircs:///") - self.assertIsInstance(server, IRCServer) - self.assertEqual(server.host, "localhost") - self.assertTrue(server.ssl) + self.assertIsInstance(server, IRCSServer) + self.assertEqual(server._sockaddr, + socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4]) - server = factory("irc://host1") + server = factory("irc://freenode.net") self.assertIsInstance(server, IRCServer) - self.assertEqual(server.host, "host1") - self.assertFalse(server.ssl) + self.assertEqual(server._sockaddr, + socket.getaddrinfo("freenode.net", 6667, proto=socket.IPPROTO_TCP)[0][4]) - server = factory("irc://host2:6667") + server = factory("irc://freenode.org:1234") self.assertIsInstance(server, IRCServer) - self.assertEqual(server.host, "host2") - self.assertEqual(server.port, 6667) - self.assertFalse(server.ssl) + self.assertEqual(server._sockaddr, + socket.getaddrinfo("freenode.org", 1234, proto=socket.IPPROTO_TCP)[0][4]) - server = factory("ircs://host3:194/") - self.assertIsInstance(server, IRCServer) - self.assertEqual(server.host, "host3") - self.assertEqual(server.port, 194) - self.assertTrue(server.ssl) + server = factory("ircs://nemunai.re:194/") + self.assertIsInstance(server, IRCSServer) + self.assertEqual(server._sockaddr, + socket.getaddrinfo("nemunai.re", 194, proto=socket.IPPROTO_TCP)[0][4]) + + with self.assertRaises(socket.gaierror): + factory("irc://_nonexistent.nemunai.re") if __name__ == '__main__': diff --git a/nemubot/server/message/IRC.py b/nemubot/server/message/IRC.py index 4c9e280..5ccd735 100644 --- a/nemubot/server/message/IRC.py +++ b/nemubot/server/message/IRC.py @@ -150,7 +150,8 @@ class IRC(Abstract): "date": self.tags["time"], "to": receivers, "to_response": [r if r != srv.nick else self.nick for r in receivers], - "frm": self.nick + "frm": self.nick, + "frm_owner": self.nick == srv.owner } # If CTCP, remove 0x01 diff --git a/nemubot/server/message/__init__.py b/nemubot/server/message/__init__.py new file mode 100644 index 0000000..57f3468 --- /dev/null +++ b/nemubot/server/message/__init__.py @@ -0,0 +1,15 @@ +# Nemubot is a smart and modulable IM bot. +# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . diff --git a/nemubot/server/socket.py b/nemubot/server/socket.py index 13ac9bd..612f4cb 100644 --- a/nemubot/server/socket.py +++ b/nemubot/server/socket.py @@ -1,5 +1,5 @@ # Nemubot is a smart and modulable IM bot. -# Copyright (C) 2012-2015 Mercier Pierre-Olivier +# Copyright (C) 2012-2016 Mercier Pierre-Olivier # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,117 +14,33 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import os +import socket +import ssl + import nemubot.message as message from nemubot.message.printer.socket import Socket as SocketPrinter from nemubot.server.abstract import AbstractServer -class SocketServer(AbstractServer): +class _Socket(AbstractServer): - """Concrete implementation of a socket connexion (can be wrapped with TLS)""" + """Concrete implementation of a socket connection""" - def __init__(self, sock_location=None, - host=None, port=None, - sock=None, - ssl=False, - name=None): + def __init__(self, printer=SocketPrinter, **kwargs): """Create a server socket - - Keyword arguments: - sock_location -- Path to the UNIX socket - host -- Hostname of the INET socket - port -- Port of the INET socket - sock -- Already connected socket - ssl -- Should TLS connection enabled - name -- Convinience name """ - import socket - - assert(sock is None or isinstance(sock, socket.SocketType)) - assert(port is None or isinstance(port, int)) - - super().__init__(name=name) - - if sock is None: - if sock_location is not None: - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.connect_to = sock_location - elif host is not None: - for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): - self.socket = socket.socket(af, socktype, proto) - self.connect_to = sa - break - else: - self.socket = sock - - self.ssl = ssl + super().__init__(**kwargs) self.readbuffer = b'' - self.printer = SocketPrinter - - - def fileno(self): - return self.socket.fileno() if self.socket else None - - - @property - def closed(self): - """Indicator of the connection aliveness""" - return self.socket._closed - - - # Open/close - - def open(self): - if not self.closed: - return True - - try: - self.socket.connect(self.connect_to) - self.logger.info("Connected to %s", self.connect_to) - except: - self.socket.close() - self.logger.exception("Unable to connect to %s", - self.connect_to) - return False - - # Wrap the socket for SSL - if self.ssl: - import ssl - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - self.socket = ctx.wrap_socket(self.socket) - - return super().open() - - - def close(self): - import socket - - # Flush the sending queue before close - from nemubot.server import _lock - _lock.release() - self._sending_queue.join() - _lock.acquire() - - if not self.closed: - try: - self.socket.shutdown(socket.SHUT_RDWR) - except socket.error: - pass - - self.socket.close() - - return super().close() + self.printer = printer # Write def _write(self, cnt): - if self.closed: - return - - self.socket.sendall(cnt) + self.sendall(cnt) def format(self, txt): @@ -136,19 +52,12 @@ class SocketServer(AbstractServer): # Read - def read(self): - if self.closed: - return [] - - raw = self.socket.recv(1024) - temp = (self.readbuffer + raw).split(b'\r\n') - self.readbuffer = temp.pop() - - for line in temp: - yield line + def recv(self, n=1024): + return super().recv(n) def parse(self, line): + """Implement a default behaviour for socket""" import shlex line = line.strip().decode() @@ -157,48 +66,107 @@ class SocketServer(AbstractServer): except ValueError: args = line.split(' ') - yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you") + if len(args): + yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you") -class SocketListener(AbstractServer): - - def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None): - super().__init__(name=name) - self.new_server_cb = new_server_cb - self.sock_location = sock_location - self.host = host - self.port = port - self.ssl = ssl - self.nb_son = 0 + def subparse(self, orig, cnt): + for m in self.parse(cnt): + m.to = orig.to + m.frm = orig.frm + m.date = orig.date + yield m - def fileno(self): - return self.socket.fileno() if self.socket else None +class _SocketServer(_Socket): + + def __init__(self, host, port, bind=None, **kwargs): + (family, type, proto, canonname, sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0] + + if isinstance(self, ssl.SSLSocket) and "server_hostname" not in kwargs: + kwargs["server_hostname"] = host + + super().__init__(family=family, type=type, proto=proto, **kwargs) + + self._sockaddr = sockaddr + self._bind = bind - @property - def closed(self): - """Indicator of the connection aliveness""" - return self.socket is None + def connect(self): + self.logger.info("Connection to %s:%d", *self._sockaddr[:2]) + super().connect(self._sockaddr) + + if self._bind: + super().bind(self._bind) - def open(self): - import os - import socket +class SocketServer(_SocketServer, socket.socket): + pass - if self.sock_location is not None: - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - try: - os.remove(self.sock_location) - except FileNotFoundError: - pass - self.socket.bind(self.sock_location) - elif self.host is not None and self.port is not None: - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.bind((self.host, self.port)) - self.socket.listen(5) - return super().open() +class SecureSocketServer(_SocketServer, ssl.SSLSocket): + pass + + +class UnixSocket: + + def __init__(self, location, **kwargs): + super().__init__(family=socket.AF_UNIX, **kwargs) + + self._socket_path = location + + + def connect(self): + self.logger.info("Connection to unix://%s", self._socket_path) + super().connect(self._socket_path) + + +class SocketClient(_Socket, socket.socket): + + def read(self): + return self.recv() + + +class _Listener: + + def __init__(self, new_server_cb, instanciate=SocketClient, **kwargs): + super().__init__(**kwargs) + + self._instanciate = instanciate + self._new_server_cb = new_server_cb + + + def read(self): + conn, addr = self.accept() + fileno = conn.fileno() + self.logger.info("Accept new connection from %s (fd=%d)", addr, fileno) + + ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach()) + ss.connect = ss._on_connect + self._new_server_cb(ss, autoconnect=True) + + return b'' + + +class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + + def connect(self): + self.logger.info("Creating Unix socket at unix://%s", self._socket_path) + + try: + os.remove(self._socket_path) + except FileNotFoundError: + pass + + self.bind(self._socket_path) + self.listen(5) + self.logger.info("Socket ready for accepting new connections") + + self._on_connect() def close(self): @@ -206,25 +174,14 @@ class SocketListener(AbstractServer): import socket try: - self.socket.shutdown(socket.SHUT_RDWR) - self.socket.close() - if self.sock_location is not None: - os.remove(self.sock_location) + self.shutdown(socket.SHUT_RDWR) except socket.error: pass - return super().close() + super().close() - - # Read - - def read(self): - if self.closed: - return [] - - conn, addr = self.socket.accept() - self.nb_son += 1 - ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn) - self.new_server_cb(ss) - - return [] + try: + if self._socket_path is not None: + os.remove(self._socket_path) + except: + pass diff --git a/nemubot/tools/web.py b/nemubot/tools/web.py index d35740c..0852664 100644 --- a/nemubot/tools/web.py +++ b/nemubot/tools/web.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from urllib.parse import urlparse, urlsplit, urlunsplit +from urllib.parse import urljoin, urlparse, urlsplit, urlunsplit from nemubot.exception import IMException @@ -108,6 +108,9 @@ def getURLContent(url, body=None, timeout=7, header=None): elif "User-agent" not in header: header["User-agent"] = "Nemubot v%s" % __version__ + if body is not None and "Content-Type" not in header: + header["Content-Type"] = "application/x-www-form-urlencoded" + import socket try: if o.query != '': @@ -156,21 +159,23 @@ def getURLContent(url, body=None, timeout=7, header=None): elif ((res.status == http.client.FOUND or res.status == http.client.MOVED_PERMANENTLY) and res.getheader("Location") != url): - return getURLContent(res.getheader("Location"), timeout=timeout) + return getURLContent( + urljoin(url, res.getheader("Location")), + body=body, + timeout=timeout, + header=header) else: raise IMException("A HTTP error occurs: %d - %s" % (res.status, http.client.responses[res.status])) -def getXML(url, timeout=7): +def getXML(*args, **kwargs): """Get content page and return XML parsed content - Arguments: - url -- the URL to get - timeout -- maximum number of seconds to wait before returning an exception + Arguments: same as getURLContent """ - cnt = getURLContent(url, timeout=timeout) + cnt = getURLContent(*args, **kwargs) if cnt is None: return None else: @@ -178,15 +183,13 @@ def getXML(url, timeout=7): return parseString(cnt) -def getJSON(url, timeout=7): +def getJSON(*args, **kwargs): """Get content page and return JSON content - Arguments: - url -- the URL to get - timeout -- maximum number of seconds to wait before returning an exception + Arguments: same as getURLContent """ - cnt = getURLContent(url, timeout=timeout) + cnt = getURLContent(*args, **kwargs) if cnt is None: return None else: diff --git a/nemubot/tools/xmlparser/__init__.py b/nemubot/tools/xmlparser/__init__.py index abc5bb9..687bf63 100644 --- a/nemubot/tools/xmlparser/__init__.py +++ b/nemubot/tools/xmlparser/__init__.py @@ -51,11 +51,13 @@ class XMLParser: def __init__(self, knodes): self.knodes = knodes + def _reset(self): self.stack = list() self.child = 0 def parse_file(self, path): + self._reset() p = xml.parsers.expat.ParserCreate() p.StartElementHandler = self.startElement @@ -69,6 +71,7 @@ class XMLParser: def parse_string(self, s): + self._reset() p = xml.parsers.expat.ParserCreate() p.StartElementHandler = self.startElement @@ -126,10 +129,13 @@ class XMLParser: if hasattr(self.current, "endElement"): self.current.endElement(None) + if hasattr(self.current, "parsedForm") and callable(self.current.parsedForm): + self.stack[-1] = self.current.parsedForm() + # Don't remove root if len(self.stack) > 1: last = self.stack.pop() - if hasattr(self.current, "addChild"): + if hasattr(self.current, "addChild") and callable(self.current.addChild): if self.current.addChild(name, last): return raise TypeError(name + " tag not expected in " + self.display_stack()) diff --git a/nemubot/treatment.py b/nemubot/treatment.py index 2c1955d..4f629e0 100644 --- a/nemubot/treatment.py +++ b/nemubot/treatment.py @@ -15,6 +15,7 @@ # along with this program. If not, see . import logging +import types logger = logging.getLogger("nemubot.treatment") @@ -108,6 +109,9 @@ class MessageTreater: msg -- message to treat """ + if hasattr(msg, "frm_owner"): + msg.frm_owner = (not hasattr(msg.server, "owner") or msg.server.owner == msg.frm) + while hook is not None: res = hook.run(msg) @@ -116,10 +120,18 @@ class MessageTreater: yield r elif res is not None: - if not hasattr(res, "server") or res.server is None: - res.server = msg.server + if isinstance(res, types.GeneratorType): + for r in res: + if not hasattr(r, "server") or r.server is None: + r.server = msg.server - yield res + yield r + + else: + if not hasattr(res, "server") or res.server is None: + res.server = msg.server + + yield res hook = next(hook_gen, None) diff --git a/setup.py b/setup.py index 36dddb4..a400c3c 100755 --- a/setup.py +++ b/setup.py @@ -63,6 +63,7 @@ setup( 'nemubot', 'nemubot.config', 'nemubot.datastore', + 'nemubot.datastore.nodes', 'nemubot.event', 'nemubot.exception', 'nemubot.hooks',