diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index 23cf4a0..0000000
--- a/.gitmodules
+++ /dev/null
@@ -1,3 +0,0 @@
-[submodule "modules/nextstop/external"]
- path = modules/nextstop/external
- url = git://github.com/nbr23/NextStop.git
diff --git a/README.md b/README.md
index aa3b141..1d40faf 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,8 @@ Requirements
*nemubot* requires at least Python 3.3 to work.
+Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629).
+
Some modules (like `cve`, `nextstop` or `laposte`) require the
[BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/),
but the core and framework has no dependency.
diff --git a/modules/events.py b/modules/events.py
index 2887514..a35c28b 100644
--- a/modules/events.py
+++ b/modules/events.py
@@ -16,7 +16,7 @@ from more import Response
def help_full ():
- return "This module store a lot of events: ny, we, " + (", ".join(context.datas.index.keys())) + "\n!eventslist: gets list of timer\n!start /something/: launch a timer"
+ return "This module store a lot of events: ny, we, " + (", ".join(context.datas.index.keys() if hasattr(context, "datas") else [])) + "\n!eventslist: gets list of timer\n!start /something/: launch a timer"
def load(context):
diff --git a/modules/nextstop.xml b/modules/nextstop.xml
deleted file mode 100644
index d34e8ae..0000000
--- a/modules/nextstop.xml
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-
-
diff --git a/modules/nextstop/__init__.py b/modules/nextstop/__init__.py
deleted file mode 100644
index 9530ab8..0000000
--- a/modules/nextstop/__init__.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# coding=utf-8
-
-"""Informe les usagers des prochains passages des transports en communs de la RATP"""
-
-from nemubot.exception import IMException
-from nemubot.hooks import hook
-from more import Response
-
-nemubotversion = 3.4
-
-from .external.src import ratp
-
-def help_full ():
- return "!ratp transport line [station]: Donne des informations sur les prochains passages du transport en commun séléctionné à l'arrêt désiré. Si aucune station n'est précisée, les liste toutes."
-
-
-@hook.command("ratp")
-def ask_ratp(msg):
- """Hook entry from !ratp"""
- if len(msg.args) >= 3:
- transport = msg.args[0]
- line = msg.args[1]
- station = msg.args[2]
- if len(msg.args) == 4:
- times = ratp.getNextStopsAtStation(transport, line, station, msg.args[3])
- else:
- times = ratp.getNextStopsAtStation(transport, line, station)
-
- if len(times) == 0:
- raise IMException("la station %s n'existe pas sur le %s ligne %s." % (station, transport, line))
-
- (time, direction, stationname) = times[0]
- return Response(message=["\x03\x02%s\x03\x02 direction %s" % (time, direction) for time, direction, stationname in times],
- title="Prochains passages du %s ligne %s à l'arrêt %s" % (transport, line, stationname),
- channel=msg.channel)
-
- elif len(msg.args) == 2:
- stations = ratp.getAllStations(msg.args[0], msg.args[1])
-
- if len(stations) == 0:
- raise IMException("aucune station trouvée.")
- return Response([s for s in stations], title="Stations", channel=msg.channel)
-
- else:
- raise IMException("Mauvais usage, merci de spécifier un type de transport et une ligne, ou de consulter l'aide du module.")
-
-@hook.command("ratp_alert")
-def ratp_alert(msg):
- if len(msg.args) == 2:
- transport = msg.args[0]
- cause = msg.args[1]
- incidents = ratp.getDisturbance(cause, transport)
- return Response(incidents, channel=msg.channel, nomore="No more incidents", count=" (%d more incidents)")
- else:
- raise IMException("Mauvais usage, merci de spécifier un type de transport et un type d'alerte (alerte, manif, travaux), ou de consulter l'aide du module.")
diff --git a/modules/nextstop/external b/modules/nextstop/external
deleted file mode 160000
index 3d5c9b2..0000000
--- a/modules/nextstop/external
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 3d5c9b2d52fbd214f5aaad00e5f3952de919b3e5
diff --git a/modules/ratp.py b/modules/ratp.py
new file mode 100644
index 0000000..7f4b211
--- /dev/null
+++ b/modules/ratp.py
@@ -0,0 +1,74 @@
+"""Informe les usagers des prochains passages des transports en communs de la RATP"""
+
+# PYTHON STUFFS #######################################################
+
+from nemubot.exception import IMException
+from nemubot.hooks import hook
+from more import Response
+
+from nextstop import ratp
+
+@hook.command("ratp",
+ help="Affiche les prochains horaires de passage",
+ help_usage={
+ "TRANSPORT": "Affiche les lignes du moyen de transport donné",
+ "TRANSPORT LINE": "Affiche les stations sur la ligne de transport donnée",
+ "TRANSPORT LINE STATION": "Affiche les prochains horaires de passage à l'arrêt donné",
+ "TRANSPORT LINE STATION DESTINATION": "Affiche les prochains horaires de passage dans la direction donnée",
+ })
+def ask_ratp(msg):
+ l = len(msg.args)
+
+ transport = msg.args[0] if l > 0 else None
+ line = msg.args[1] if l > 1 else None
+ station = msg.args[2] if l > 2 else None
+ direction = msg.args[3] if l > 3 else None
+
+ if station is not None:
+ times = sorted(ratp.getNextStopsAtStation(transport, line, station, direction), key=lambda i: i[0])
+
+ if len(times) == 0:
+ raise IMException("la station %s n'existe pas sur le %s ligne %s." % (station, transport, line))
+
+ (time, direction, stationname) = times[0]
+ return Response(message=["\x03\x02%s\x03\x02 direction %s" % (time, direction) for time, direction, stationname in times],
+ title="Prochains passages du %s ligne %s à l'arrêt %s" % (transport, line, stationname),
+ channel=msg.channel)
+
+ elif line is not None:
+ stations = ratp.getAllStations(transport, line)
+
+ if len(stations) == 0:
+ raise IMException("aucune station trouvée.")
+ return Response(stations, title="Stations", channel=msg.channel)
+
+ elif transport is not None:
+ lines = ratp.getTransportLines(transport)
+ if len(lines) == 0:
+ raise IMException("aucune ligne trouvée.")
+ return Response(lines, title="Lignes", channel=msg.channel)
+
+ else:
+ raise IMException("précise au moins un moyen de transport.")
+
+
+@hook.command("ratp_alert",
+ help="Affiche les perturbations en cours sur le réseau")
+def ratp_alert(msg):
+ if len(msg.args) == 0:
+ raise IMException("précise au moins un moyen de transport.")
+
+ l = len(msg.args)
+ transport = msg.args[0] if l > 0 else None
+ line = msg.args[1] if l > 1 else None
+
+ if line is not None:
+ d = ratp.getDisturbanceFromLine(transport, line)
+ if "date" in d and d["date"] is not None:
+ incidents = "Au {date[date]}, {title}: {message}".format(**d)
+ else:
+ incidents = "{title}: {message}".format(**d)
+ else:
+ incidents = ratp.getDisturbance(None, transport)
+
+ return Response(incidents, channel=msg.channel, nomore="No more incidents", count=" (%d more incidents)")
diff --git a/modules/suivi.py b/modules/suivi.py
index 79910d4..9e517da 100644
--- a/modules/suivi.py
+++ b/modules/suivi.py
@@ -2,14 +2,14 @@
# PYTHON STUFF ############################################
-import urllib.request
+import json
import urllib.parse
from bs4 import BeautifulSoup
import re
from nemubot.hooks import hook
from nemubot.exception import IMException
-from nemubot.tools.web import getURLContent
+from nemubot.tools.web import getURLContent, getJSON
from more import Response
@@ -17,8 +17,7 @@ from more import Response
def get_tnt_info(track_id):
values = []
- data = getURLContent('www.tnt.fr/public/suivi_colis/recherche/'
- 'visubontransport.do?bonTransport=%s' % track_id)
+ data = getURLContent('www.tnt.fr/public/suivi_colis/recherche/visubontransport.do?bonTransport=%s' % track_id)
soup = BeautifulSoup(data)
status_list = soup.find('div', class_='result__content')
if not status_list:
@@ -32,8 +31,7 @@ def get_tnt_info(track_id):
def get_colissimo_info(colissimo_id):
- colissimo_data = getURLContent("http://www.colissimo.fr/portail_colissimo/"
- "suivre.do?colispart=%s" % colissimo_id)
+ colissimo_data = getURLContent("http://www.colissimo.fr/portail_colissimo/suivre.do?colispart=%s" % colissimo_id)
soup = BeautifulSoup(colissimo_data)
dataArray = soup.find(class_='dataArray')
@@ -47,9 +45,8 @@ def get_colissimo_info(colissimo_id):
def get_chronopost_info(track_id):
data = urllib.parse.urlencode({'listeNumeros': track_id})
- track_baseurl = "http://www.chronopost.fr/expedier/" \
- "inputLTNumbersNoJahia.do?lang=fr_FR"
- track_data = urllib.request.urlopen(track_baseurl, data.encode('utf-8'))
+ track_baseurl = "http://www.chronopost.fr/expedier/inputLTNumbersNoJahia.do?lang=fr_FR"
+ track_data = getURLContent(track_baseurl, data.encode('utf-8'))
soup = BeautifulSoup(track_data)
infoClass = soup.find(class_='numeroColi2')
@@ -65,9 +62,8 @@ def get_chronopost_info(track_id):
def get_colisprive_info(track_id):
data = urllib.parse.urlencode({'numColis': track_id})
- track_baseurl = "https://www.colisprive.com/moncolis/pages/" \
- "detailColis.aspx"
- track_data = urllib.request.urlopen(track_baseurl, data.encode('utf-8'))
+ track_baseurl = "https://www.colisprive.com/moncolis/pages/detailColis.aspx"
+ track_data = getURLContent(track_baseurl, data.encode('utf-8'))
soup = BeautifulSoup(track_data)
dataArray = soup.find(class_='BandeauInfoColis')
@@ -82,8 +78,7 @@ def get_laposte_info(laposte_id):
data = urllib.parse.urlencode({'id': laposte_id})
laposte_baseurl = "http://www.part.csuivi.courrier.laposte.fr/suivi/index"
- laposte_data = urllib.request.urlopen(laposte_baseurl,
- data.encode('utf-8'))
+ laposte_data = getURLContent(laposte_baseurl, data.encode('utf-8'))
soup = BeautifulSoup(laposte_data)
search_res = soup.find(class_='resultat_rech_simple_table').tbody.tr
if (soup.find(class_='resultat_rech_simple_table').thead
@@ -112,8 +107,7 @@ def get_postnl_info(postnl_id):
data = urllib.parse.urlencode({'barcodes': postnl_id})
postnl_baseurl = "http://www.postnl.post/details/"
- postnl_data = urllib.request.urlopen(postnl_baseurl,
- data.encode('utf-8'))
+ postnl_data = getURLContent(postnl_baseurl, data.encode('utf-8'))
soup = BeautifulSoup(postnl_data)
if (soup.find(id='datatables')
and soup.find(id='datatables').tbody
@@ -132,6 +126,42 @@ def get_postnl_info(postnl_id):
return (post_status.lower(), post_destination, post_date)
+def get_fedex_info(fedex_id, lang="en_US"):
+ data = urllib.parse.urlencode({
+ 'data': json.dumps({
+ "TrackPackagesRequest": {
+ "appType": "WTRK",
+ "appDeviceType": "DESKTOP",
+ "uniqueKey": "",
+ "processingParameters": {},
+ "trackingInfoList": [
+ {
+ "trackNumberInfo": {
+ "trackingNumber": str(fedex_id),
+ "trackingQualifier": "",
+ "trackingCarrier": ""
+ }
+ }
+ ]
+ }
+ }),
+ 'action': "trackpackages",
+ 'locale': lang,
+ 'version': 1,
+ 'format': "json"
+ })
+ fedex_baseurl = "https://www.fedex.com/trackingCal/track"
+
+ fedex_data = getJSON(fedex_baseurl, data.encode('utf-8'))
+
+ if ("TrackPackagesResponse" in fedex_data and
+ "packageList" in fedex_data["TrackPackagesResponse"] and
+ len(fedex_data["TrackPackagesResponse"]["packageList"]) and
+ not fedex_data["TrackPackagesResponse"]["packageList"][0]["isInvalid"]
+ ):
+ return fedex_data["TrackPackagesResponse"]["packageList"][0]
+
+
# TRACKING HANDLERS ###################################################
def handle_tnt(tracknum):
@@ -189,6 +219,17 @@ def handle_coliprive(tracknum):
return ("Colis Privé: \x02%s\x0F : \x02%s\x0F." % (tracknum, info))
+def handle_fedex(tracknum):
+ info = get_fedex_info(tracknum)
+ if info:
+ if info["displayActDeliveryDateTime"] != "":
+ return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: in \x02{statusLocationCity}, {statusLocationCntryCD}\x0F, delivered on: {displayActDeliveryDateTime}.".format(**info))
+ elif info["statusLocationCity"] != "":
+ return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: estimated delivery: {displayEstDeliveryDateTime}.".format(**info))
+ else:
+ return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: in \x02{statusLocationCity}, {statusLocationCntryCD}\x0F, estimated delivery: {displayEstDeliveryDateTime}.".format(**info))
+
+
TRACKING_HANDLERS = {
'laposte': handle_laposte,
'postnl': handle_postnl,
@@ -196,6 +237,7 @@ TRACKING_HANDLERS = {
'chronopost': handle_chronopost,
'coliprive': handle_coliprive,
'tnt': handle_tnt,
+ 'fedex': handle_fedex,
}
diff --git a/modules/weather.py b/modules/weather.py
index 1fadc71..1de0eb7 100644
--- a/modules/weather.py
+++ b/modules/weather.py
@@ -1,6 +1,6 @@
# coding=utf-8
-"""The weather module"""
+"""The weather module. Powered by Dark Sky """
import datetime
import re
@@ -17,7 +17,7 @@ nemubotversion = 4.0
from more import Response
-URL_DSAPI = "https://api.forecast.io/forecast/%s/%%s,%%s"
+URL_DSAPI = "https://api.darksky.net/forecast/%s/%%s,%%s?lang=%%s&units=%%s"
def load(context):
if not context.config or "darkskyapikey" not in context.config:
@@ -30,34 +30,19 @@ def load(context):
URL_DSAPI = URL_DSAPI % context.config["darkskyapikey"]
-def help_full ():
- return "!weather /city/: Display the current weather in /city/."
-
-
-def fahrenheit2celsius(temp):
- return int((temp - 32) * 50/9)/10
-
-
-def mph2kmph(speed):
- return int(speed * 160.9344)/100
-
-
-def inh2mmh(size):
- return int(size * 254)/10
-
-
def format_wth(wth):
- return ("%s °C %s; precipitation (%s %% chance) intensity: %s mm/h; relative humidity: %s %%; wind speed: %s km/h %s°; cloud coverage: %s %%; pressure: %s hPa; ozone: %s DU" %
+ return ("%s °C %s; precipitation (%s %% chance) intensity: %s mm/h; relative humidity: %s %%; wind speed: %s m/s %s°; cloud coverage: %s %%; pressure: %s hPa; visibility: %s km; ozone: %s DU" %
(
- fahrenheit2celsius(wth["temperature"]),
+ wth["temperature"],
wth["summary"],
int(wth["precipProbability"] * 100),
- inh2mmh(wth["precipIntensity"]),
+ wth["precipIntensity"],
int(wth["humidity"] * 100),
- mph2kmph(wth["windSpeed"]),
+ wth["windSpeed"],
wth["windBearing"],
int(wth["cloudCover"] * 100),
int(wth["pressure"]),
+ int(wth["visibility"]),
int(wth["ozone"])
))
@@ -66,7 +51,7 @@ def format_forecast_daily(wth):
return ("%s; between %s-%s °C; precipitation (%s %% chance) intensity: maximum %s mm/h; relative humidity: %s %%; wind speed: %s km/h %s°; cloud coverage: %s %%; pressure: %s hPa; ozone: %s DU" %
(
wth["summary"],
- fahrenheit2celsius(wth["temperatureMin"]), fahrenheit2celsius(wth["temperatureMax"]),
+ wth["temperatureMin"], wth["temperatureMax"],
int(wth["precipProbability"] * 100),
inh2mmh(wth["precipIntensityMax"]),
int(wth["humidity"] * 100),
@@ -126,8 +111,8 @@ def treat_coord(msg):
raise IMException("indique-moi un nom de ville ou des coordonnées.")
-def get_json_weather(coords):
- wth = web.getJSON(URL_DSAPI % (float(coords[0]), float(coords[1])))
+def get_json_weather(coords, lang="en", units="auto"):
+ wth = web.getJSON(URL_DSAPI % (float(coords[0]), float(coords[1]), lang, units))
# First read flags
if wth is None or "darksky-unavailable" in wth["flags"]:
@@ -149,10 +134,16 @@ def cmd_coordinates(msg):
return Response("Les coordonnées de %s sont %s,%s" % (msg.args[0], coords["lat"], coords["long"]), channel=msg.channel)
-@hook.command("alert")
+@hook.command("alert",
+ keywords={
+ "lang=LANG": "change the output language of weather sumarry; default: en",
+ "units=UNITS": "return weather conditions in the requested units; default: auto",
+ })
def cmd_alert(msg):
loc, coords, specific = treat_coord(msg)
- wth = get_json_weather(coords)
+ wth = get_json_weather(coords,
+ lang=msg.kwargs["lang"] if "lang" in msg.kwargs else "en",
+ units=msg.kwargs["units"] if "units" in msg.kwargs else "auto")
res = Response(channel=msg.channel, nomore="No more weather alert", count=" (%d more alerts)")
@@ -166,10 +157,20 @@ def cmd_alert(msg):
return res
-@hook.command("météo")
+@hook.command("météo",
+ help="Display current weather and previsions",
+ help_usage={
+ "CITY": "Display the current weather and previsions in CITY",
+ },
+ keywords={
+ "lang=LANG": "change the output language of weather sumarry; default: en",
+ "units=UNITS": "return weather conditions in the requested units; default: auto",
+ })
def cmd_weather(msg):
loc, coords, specific = treat_coord(msg)
- wth = get_json_weather(coords)
+ wth = get_json_weather(coords,
+ lang=msg.kwargs["lang"] if "lang" in msg.kwargs else "en",
+ units=msg.kwargs["units"] if "units" in msg.kwargs else "auto")
res = Response(channel=msg.channel, nomore="No more weather information")
@@ -243,3 +244,7 @@ def parseask(msg):
context.save()
return Response("ok, j'ai bien noté les coordonnées de %s" % res.group("city"),
msg.channel, msg.nick)
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/nemubot/__init__.py b/nemubot/__init__.py
index a56c472..48de6ea 100644
--- a/nemubot/__init__.py
+++ b/nemubot/__init__.py
@@ -17,9 +17,9 @@
__version__ = '4.0.dev3'
__author__ = 'nemunaire'
-from nemubot.modulecontext import ModuleContext
+from nemubot.modulecontext import _ModuleContext
-context = ModuleContext(None, None)
+context = _ModuleContext()
def requires_version(min=None, max=None):
@@ -53,41 +53,50 @@ def attach(pid, socketfile):
sys.stderr.write("\n")
return 1
- from select import select
+ import select
+ mypoll = select.poll()
+
+ mypoll.register(sys.stdin.fileno(), select.POLLIN | select.POLLPRI)
+ mypoll.register(sock.fileno(), select.POLLIN | select.POLLPRI)
try:
while True:
- rl, wl, xl = select([sys.stdin, sock], [], [])
+ for fd, flag in mypoll.poll():
+ if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
+ sock.close()
+ print("Connection closed.")
+ return 1
- if sys.stdin in rl:
- line = sys.stdin.readline().strip()
- if line == "exit" or line == "quit":
- return 0
- elif line == "reload":
- import os, signal
- os.kill(pid, signal.SIGHUP)
- print("Reload signal sent. Please wait...")
+ if fd == sys.stdin.fileno():
+ line = sys.stdin.readline().strip()
+ if line == "exit" or line == "quit":
+ return 0
+ elif line == "reload":
+ import os, signal
+ os.kill(pid, signal.SIGHUP)
+ print("Reload signal sent. Please wait...")
- elif line == "shutdown":
- import os, signal
- os.kill(pid, signal.SIGTERM)
- print("Shutdown signal sent. Please wait...")
+ elif line == "shutdown":
+ import os, signal
+ os.kill(pid, signal.SIGTERM)
+ print("Shutdown signal sent. Please wait...")
- elif line == "kill":
- import os, signal
- os.kill(pid, signal.SIGKILL)
- print("Signal sent...")
- return 0
+ elif line == "kill":
+ import os, signal
+ os.kill(pid, signal.SIGKILL)
+ print("Signal sent...")
+ return 0
- elif line == "stack" or line == "stacks":
- import os, signal
- os.kill(pid, signal.SIGUSR1)
- print("Debug signal sent. Consult logs.")
+ elif line == "stack" or line == "stacks":
+ import os, signal
+ os.kill(pid, signal.SIGUSR1)
+ print("Debug signal sent. Consult logs.")
- else:
- sock.send(line.encode() + b'\r\n')
+ else:
+ sock.send(line.encode() + b'\r\n')
+
+ if fd == sock.fileno():
+ sys.stdout.write(sock.recv(2048).decode())
- if sock in rl:
- sys.stdout.write(sock.recv(2048).decode())
except KeyboardInterrupt:
pass
except:
@@ -97,13 +106,28 @@ def attach(pid, socketfile):
return 0
-def daemonize():
+def daemonize(socketfile=None, autoattach=True):
"""Detach the running process to run as a daemon
"""
import os
import sys
+ if socketfile is not None:
+ try:
+ pid = os.fork()
+ if pid > 0:
+ if autoattach:
+ import time
+ os.waitpid(pid, 0)
+ time.sleep(1)
+ sys.exit(attach(pid, socketfile))
+ else:
+ sys.exit(0)
+ except OSError as err:
+ sys.stderr.write("Unable to fork: %s\n" % err)
+ sys.exit(1)
+
try:
pid = os.fork()
if pid > 0:
diff --git a/nemubot/__main__.py b/nemubot/__main__.py
index 5a236f4..e1576fb 100644
--- a/nemubot/__main__.py
+++ b/nemubot/__main__.py
@@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
-# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+# Copyright (C) 2012-2017 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -37,6 +37,9 @@ def main():
default=["./modules/"],
help="directory to use as modules store")
+ parser.add_argument("-A", "--no-attach", action="store_true",
+ help="don't attach after fork")
+
parser.add_argument("-d", "--debug", action="store_true",
help="don't deamonize, keep in foreground")
@@ -71,32 +74,10 @@ def main():
args.pidfile = os.path.abspath(os.path.expanduser(args.pidfile))
args.socketfile = os.path.abspath(os.path.expanduser(args.socketfile))
args.logfile = os.path.abspath(os.path.expanduser(args.logfile))
- args.files = [ x for x in map(os.path.abspath, args.files)]
- args.modules_path = [ x for x in map(os.path.abspath, args.modules_path)]
+ args.files = [x for x in map(os.path.abspath, args.files)]
+ args.modules_path = [x for x in map(os.path.abspath, args.modules_path)]
- # Check if an instance is already launched
- if args.pidfile is not None and os.path.isfile(args.pidfile):
- with open(args.pidfile, "r") as f:
- pid = int(f.readline())
- try:
- os.kill(pid, 0)
- except OSError:
- pass
- else:
- from nemubot import attach
- sys.exit(attach(pid, args.socketfile))
-
- # Daemonize
- if not args.debug:
- from nemubot import daemonize
- daemonize()
-
- # Store PID to pidfile
- if args.pidfile is not None:
- with open(args.pidfile, "w+") as f:
- f.write(str(os.getpid()))
-
- # Setup loggin interface
+ # Setup logging interface
import logging
logger = logging.getLogger("nemubot")
logger.setLevel(logging.DEBUG)
@@ -115,6 +96,18 @@ def main():
fh.setFormatter(formatter)
logger.addHandler(fh)
+ # Check if an instance is already launched
+ if args.pidfile is not None and os.path.isfile(args.pidfile):
+ with open(args.pidfile, "r") as f:
+ pid = int(f.readline())
+ try:
+ os.kill(pid, 0)
+ except OSError:
+ pass
+ else:
+ from nemubot import attach
+ sys.exit(attach(pid, args.socketfile))
+
# Add modules dir paths
modules_paths = list()
for path in args.modules_path:
@@ -125,7 +118,7 @@ def main():
# Create bot context
from nemubot import datastore
- from nemubot.bot import Bot
+ from nemubot.bot import Bot, sync_act
context = Bot(modules_paths=modules_paths,
data_store=datastore.XML(args.data_path),
verbosity=args.verbose)
@@ -141,7 +134,7 @@ def main():
# Load requested configuration files
for path in args.files:
if os.path.isfile(path):
- context.sync_queue.put_nowait(["loadconf", path])
+ sync_act("loadconf", path)
else:
logger.error("%s is not a readable file", path)
@@ -149,6 +142,17 @@ def main():
for module in args.module:
__import__(module)
+ if args.socketfile:
+ from nemubot.server.socket import UnixSocketListener
+ context.add_server(UnixSocketListener(new_server_cb=context.add_server,
+ location=args.socketfile,
+ name="master_socket"))
+
+ # Daemonize
+ if not args.debug:
+ from nemubot import daemonize
+ daemonize(args.socketfile, not args.no_attach)
+
# Signals handling
def sigtermhandler(signum, frame):
"""On SIGTERM and SIGINT, quit nicely"""
@@ -165,22 +169,27 @@ def main():
# Reload configuration file
for path in args.files:
if os.path.isfile(path):
- context.sync_queue.put_nowait(["loadconf", path])
+ sync_act("loadconf", path)
signal.signal(signal.SIGHUP, sighuphandler)
def sigusr1handler(signum, frame):
"""On SIGHUSR1, display stacktraces"""
- import traceback
+ import threading, traceback
for threadId, stack in sys._current_frames().items():
- logger.debug("########### Thread %d:\n%s",
- threadId,
+ thName = "#%d" % threadId
+ for th in threading.enumerate():
+ if th.ident == threadId:
+ thName = th.name
+ break
+ logger.debug("########### Thread %s:\n%s",
+ thName,
"".join(traceback.format_stack(stack)))
signal.signal(signal.SIGUSR1, sigusr1handler)
- if args.socketfile:
- from nemubot.server.socket import SocketListener
- context.add_server(SocketListener(context.add_server, "master_socket",
- sock_location=args.socketfile))
+ # Store PID to pidfile
+ if args.pidfile is not None:
+ with open(args.pidfile, "w+") as f:
+ f.write(str(os.getpid()))
# context can change when performing an hotswap, always join the latest context
oldcontext = None
@@ -195,5 +204,6 @@ def main():
sigusr1handler(0, None)
sys.exit(0)
+
if __name__ == "__main__":
main()
diff --git a/nemubot/bot.py b/nemubot/bot.py
index 2657d52..b0d3915 100644
--- a/nemubot/bot.py
+++ b/nemubot/bot.py
@@ -16,7 +16,9 @@
from datetime import datetime, timezone
import logging
+from multiprocessing import JoinableQueue
import threading
+import select
import sys
from nemubot import __version__
@@ -26,6 +28,11 @@ import nemubot.hooks
logger = logging.getLogger("nemubot")
+sync_queue = JoinableQueue()
+
+def sync_act(*args):
+ sync_queue.put(list(args))
+
class Bot(threading.Thread):
@@ -42,7 +49,7 @@ class Bot(threading.Thread):
verbosity -- verbosity level
"""
- threading.Thread.__init__(self)
+ super().__init__(name="Nemubot main")
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
__version__,
@@ -61,6 +68,7 @@ class Bot(threading.Thread):
self.datastore.open()
# Keep global context: servers and modules
+ self._poll = select.poll()
self.servers = dict()
self.modules = dict()
self.modules_configuration = dict()
@@ -138,60 +146,84 @@ class Bot(threading.Thread):
self.cnsr_queue = Queue()
self.cnsr_thrd = list()
self.cnsr_thrd_size = -1
- # Synchrone actions to be treated by main thread
- self.sync_queue = Queue()
def run(self):
- from select import select
- from nemubot.server import _lock, _rlist, _wlist, _xlist
+ global sync_queue
+
+ # Rewrite the sync_queue, as the daemonization process tend to disturb it
+ old_sync_queue, sync_queue = sync_queue, JoinableQueue()
+ while not old_sync_queue.empty():
+ sync_queue.put_nowait(old_sync_queue.get())
+
+ self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI)
logger.info("Starting main loop")
self.stop = False
while not self.stop:
- with _lock:
- try:
- rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1)
- except:
- logger.error("Something went wrong in select")
- fnd_smth = False
- # Looking for invalid server
- for r in _rlist:
- if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
- _rlist.remove(r)
- logger.error("Found invalid object in _rlist: " + str(r))
- fnd_smth = True
- for w in _wlist:
- if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
- _wlist.remove(w)
- logger.error("Found invalid object in _wlist: " + str(w))
- fnd_smth = True
- for x in _xlist:
- if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
- _xlist.remove(x)
- logger.error("Found invalid object in _xlist: " + str(x))
- fnd_smth = True
- if not fnd_smth:
- logger.exception("Can't continue, sorry")
- self.quit()
- continue
+ for fd, flag in self._poll.poll():
+ # Handle internal socket passing orders
+ if fd != sync_queue._reader.fileno() and fd in self.servers:
+ srv = self.servers[fd]
- for x in xl:
- try:
- x.exception()
- except:
- logger.exception("Uncatched exception on server exception")
- for w in wl:
- try:
- w.write_select()
- except:
- logger.exception("Uncatched exception on server write")
- for r in rl:
- for i in r.read():
+ if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
try:
- self.receive_message(r, i)
+ srv.exception(flag)
except:
- logger.exception("Uncatched exception on server read")
+ logger.exception("Uncatched exception on server exception")
+
+ if srv.fileno() > 0:
+ if flag & (select.POLLOUT):
+ try:
+ srv.async_write()
+ except:
+ logger.exception("Uncatched exception on server write")
+
+ if flag & (select.POLLIN | select.POLLPRI):
+ try:
+ for i in srv.async_read():
+ self.receive_message(srv, i)
+ except:
+ logger.exception("Uncatched exception on server read")
+
+ else:
+ del self.servers[fd]
+
+
+ # Always check the sync queue
+ while not sync_queue.empty():
+ args = sync_queue.get()
+ action = args.pop(0)
+
+ logger.debug("Executing sync_queue action %s%s", action, args)
+
+ if action == "sckt" and len(args) >= 2:
+ try:
+ if args[0] == "write":
+ self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI)
+ elif args[0] == "unwrite":
+ self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI)
+
+ elif args[0] == "register":
+ self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI)
+ elif args[0] == "unregister":
+ self._poll.unregister(int(args[1]))
+ except:
+ logger.exception("Unhandled excpetion during action:")
+
+ elif action == "exit":
+ self.quit()
+
+ elif action == "launch_consumer":
+ pass # This is treated after the loop
+
+ elif action == "loadconf":
+ for path in args:
+ logger.debug("Load configuration from %s", path)
+ self.load_file(path)
+ logger.info("Configurations successfully loaded")
+
+ sync_queue.task_done()
# Launch new consumer threads if necessary
@@ -202,17 +234,7 @@ class Bot(threading.Thread):
c = Consumer(self)
self.cnsr_thrd.append(c)
c.start()
-
- while self.sync_queue.qsize() > 0:
- action = self.sync_queue.get_nowait()
- if action[0] == "exit":
- self.quit()
- elif action[0] == "loadconf":
- for path in action[1:]:
- logger.debug("Load configuration from %s", path)
- self.load_file(path)
- logger.info("Configurations successfully loaded")
- self.sync_queue.task_done()
+ sync_queue = None
logger.info("Ending main loop")
@@ -385,7 +407,13 @@ class Bot(threading.Thread):
self.event_timer.cancel()
if len(self.events):
- remaining = self.events[0].time_left.total_seconds()
+ try:
+ remaining = self.events[0].time_left.total_seconds()
+ except:
+ logger.exception("An error occurs during event time calculation:")
+ self.events.pop(0)
+ return self._update_event_timer()
+
logger.debug("Update timer: next event in %d seconds", remaining)
self.event_timer = threading.Timer(remaining if remaining > 0 else 0, self._end_event_timer)
self.event_timer.start()
@@ -400,6 +428,7 @@ class Bot(threading.Thread):
while len(self.events) > 0 and datetime.now(timezone.utc) >= self.events[0].current:
evt = self.events.pop(0)
self.cnsr_queue.put_nowait(EventConsumer(evt))
+ sync_act("launch_consumer")
self._update_event_timer()
@@ -419,7 +448,7 @@ class Bot(threading.Thread):
self.servers[fileno] = srv
self.servers[srv.name] = srv
if autoconnect and not hasattr(self, "noautoconnect"):
- srv.open()
+ srv.connect()
return True
else:
@@ -463,7 +492,7 @@ class Bot(threading.Thread):
module.print = prnt
# Create module context
- from nemubot.modulecontext import ModuleContext
+ from nemubot.modulecontext import _ModuleContext, ModuleContext
module.__nemubot_context__ = ModuleContext(self, module)
if not hasattr(module, "logger"):
@@ -471,7 +500,7 @@ class Bot(threading.Thread):
# Replace imported context by real one
for attr in module.__dict__:
- if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext:
+ if attr != "__nemubot_context__" and type(module.__dict__[attr]) == _ModuleContext:
module.__dict__[attr] = module.__nemubot_context__
# Register decorated functions
@@ -532,28 +561,29 @@ class Bot(threading.Thread):
def quit(self):
"""Save and unload modules and disconnect servers"""
- self.datastore.close()
-
if self.event_timer is not None:
logger.info("Stop the event timer...")
self.event_timer.cancel()
+ logger.info("Save and unload all modules...")
+ for mod in self.modules.items():
+ self.unload_module(mod)
+
+ logger.info("Close all servers connection...")
+ for srv in [self.servers[k] for k in self.servers]:
+ srv.close()
+
logger.info("Stop consumers")
k = self.cnsr_thrd
for cnsr in k:
cnsr.stop = True
- logger.info("Save and unload all modules...")
- k = list(self.modules.keys())
- for mod in k:
- self.unload_module(mod)
+ self.datastore.close()
- logger.info("Close all servers connection...")
- k = list(self.servers.keys())
- for srv in k:
- self.servers[srv].close()
-
- self.stop = True
+ if self.stop is False or sync_queue is not None:
+ self.stop = True
+ sync_act("end")
+ sync_queue.join()
# Treatment
diff --git a/nemubot/config/module.py b/nemubot/config/module.py
index ab51971..7586697 100644
--- a/nemubot/config/module.py
+++ b/nemubot/config/module.py
@@ -15,7 +15,7 @@
# along with this program. If not, see .
from nemubot.config import get_boolean
-from nemubot.tools.xmlparser.genericnode import GenericNode
+from nemubot.datastore.nodes.generic import GenericNode
class Module(GenericNode):
diff --git a/nemubot/datastore/abstract.py b/nemubot/datastore/abstract.py
index 96e2c0d..f54bbcd 100644
--- a/nemubot/datastore/abstract.py
+++ b/nemubot/datastore/abstract.py
@@ -23,8 +23,7 @@ class Abstract:
"""Initialize a new empty storage tree
"""
- from nemubot.tools.xmlparser import module_state
- return module_state.ModuleState("nemubotstate")
+ return None
def open(self):
return
diff --git a/nemubot/datastore/nodes/__init__.py b/nemubot/datastore/nodes/__init__.py
new file mode 100644
index 0000000..e4b2788
--- /dev/null
+++ b/nemubot/datastore/nodes/__init__.py
@@ -0,0 +1,18 @@
+# Nemubot is a smart and modulable IM bot.
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from nemubot.datastore.nodes.generic import ParsingNode
+from nemubot.datastore.nodes.serializable import Serializable
diff --git a/nemubot/tools/xmlparser/basic.py b/nemubot/datastore/nodes/basic.py
similarity index 67%
rename from nemubot/tools/xmlparser/basic.py
rename to nemubot/datastore/nodes/basic.py
index 8456629..6fbd136 100644
--- a/nemubot/tools/xmlparser/basic.py
+++ b/nemubot/datastore/nodes/basic.py
@@ -14,11 +14,16 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-class ListNode:
+from nemubot.datastore.nodes.serializable import Serializable
+
+
+class ListNode(Serializable):
"""XML node representing a Python dictionnnary
"""
+ serializetag = "list"
+
def __init__(self, **kwargs):
self.items = list()
@@ -27,6 +32,9 @@ class ListNode:
self.items.append(child)
return True
+ def parsedForm(self):
+ return self.items
+
def __len__(self):
return len(self.items)
@@ -44,11 +52,21 @@ class ListNode:
return self.items.__repr__()
-class DictNode:
+ def serialize(self):
+ from nemubot.datastore.nodes.generic import ParsingNode
+ node = ParsingNode(tag=self.serializetag)
+ for i in self.items:
+ node.children.append(ParsingNode.serialize_node(i))
+ return node
+
+
+class DictNode(Serializable):
"""XML node representing a Python dictionnnary
"""
+ serializetag = "dict"
+
def __init__(self, **kwargs):
self.items = dict()
self._cur = None
@@ -56,44 +74,20 @@ class DictNode:
def startElement(self, name, attrs):
if self._cur is None and "key" in attrs:
- self._cur = (attrs["key"], "")
- return True
+ self._cur = attrs["key"]
return False
-
- def characters(self, content):
- if self._cur is not None:
- key, cnt = self._cur
- if isinstance(cnt, str):
- cnt += content
- self._cur = key, cnt
-
-
- def endElement(self, name):
- if name is None or self._cur is None:
- return
-
- key, cnt = self._cur
- if isinstance(cnt, list) and len(cnt) == 1:
- self.items[key] = cnt
- else:
- self.items[key] = cnt
-
- self._cur = None
- return True
-
-
def addChild(self, name, child):
if self._cur is None:
return False
- key, cnt = self._cur
- if not isinstance(cnt, list):
- cnt = []
- cnt.append(child)
- self._cur = key, cnt
+ self.items[self._cur] = child
+ self._cur = None
return True
+ def parsedForm(self):
+ return self.items
+
def __getitem__(self, item):
return self.items[item]
@@ -106,3 +100,13 @@ class DictNode:
def __repr__(self):
return self.items.__repr__()
+
+
+ def serialize(self):
+ from nemubot.datastore.nodes.generic import ParsingNode
+ node = ParsingNode(tag=self.serializetag)
+ for k in self.items:
+ chld = ParsingNode.serialize_node(self.items[k])
+ chld.attrs["key"] = k
+ node.children.append(chld)
+ return node
diff --git a/nemubot/tools/xmlparser/genericnode.py b/nemubot/datastore/nodes/generic.py
similarity index 64%
rename from nemubot/tools/xmlparser/genericnode.py
rename to nemubot/datastore/nodes/generic.py
index 9c29a23..c9840bc 100644
--- a/nemubot/tools/xmlparser/genericnode.py
+++ b/nemubot/datastore/nodes/generic.py
@@ -14,6 +14,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from nemubot.datastore.nodes.serializable import Serializable
+
+
class ParsingNode:
"""Allow any kind of subtags, just keep parsed ones
@@ -53,6 +56,47 @@ class ParsingNode:
return item in self.attrs
+ def serialize_node(node, **def_kwargs):
+ """Serialize any node or basic data to a ParsingNode instance"""
+
+ if isinstance(node, Serializable):
+ node = node.serialize()
+
+ if isinstance(node, str):
+ from nemubot.datastore.nodes.python import StringNode
+ pn = StringNode(**def_kwargs)
+ pn.value = node
+ return pn
+
+ elif isinstance(node, int):
+ from nemubot.datastore.nodes.python import IntNode
+ pn = IntNode(**def_kwargs)
+ pn.value = node
+ return pn
+
+ elif isinstance(node, float):
+ from nemubot.datastore.nodes.python import FloatNode
+ pn = FloatNode(**def_kwargs)
+ pn.value = node
+ return pn
+
+ elif isinstance(node, list):
+ from nemubot.datastore.nodes.basic import ListNode
+ pn = ListNode(**def_kwargs)
+ pn.items = node
+ return pn.serialize()
+
+ elif isinstance(node, dict):
+ from nemubot.datastore.nodes.basic import DictNode
+ pn = DictNode(**def_kwargs)
+ pn.items = node
+ return pn.serialize()
+
+ else:
+ assert isinstance(node, ParsingNode)
+ return node
+
+
class GenericNode(ParsingNode):
"""Consider all subtags as dictionnary
diff --git a/nemubot/datastore/nodes/python.py b/nemubot/datastore/nodes/python.py
new file mode 100644
index 0000000..6e4278b
--- /dev/null
+++ b/nemubot/datastore/nodes/python.py
@@ -0,0 +1,91 @@
+# Nemubot is a smart and modulable IM bot.
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from nemubot.datastore.nodes.serializable import Serializable
+
+
+class PythonTypeNode(Serializable):
+
+ """XML node representing a Python simple type
+ """
+
+ def __init__(self, **kwargs):
+ self.value = None
+ self._cnt = ""
+
+
+ def characters(self, content):
+ self._cnt += content
+
+
+ def endElement(self, name):
+ raise NotImplemented
+
+
+ def __repr__(self):
+ return self.value.__repr__()
+
+
+ def parsedForm(self):
+ return self.value
+
+ def serialize(self):
+ raise NotImplemented
+
+
+class IntNode(PythonTypeNode):
+
+ serializetag = "int"
+
+ def endElement(self, name):
+ self.value = int(self._cnt)
+ return True
+
+ def serialize(self):
+ from nemubot.datastore.nodes.generic import ParsingNode
+ node = ParsingNode(tag=self.serializetag)
+ node.content = str(self.value)
+ return node
+
+
+class FloatNode(PythonTypeNode):
+
+ serializetag = "float"
+
+ def endElement(self, name):
+ self.value = float(self._cnt)
+ return True
+
+ def serialize(self):
+ from nemubot.datastore.nodes.generic import ParsingNode
+ node = ParsingNode(tag=self.serializetag)
+ node.content = str(self.value)
+ return node
+
+
+class StringNode(PythonTypeNode):
+
+ serializetag = "str"
+
+ def endElement(self, name):
+ self.value = str(self._cnt)
+ return True
+
+ def serialize(self):
+ from nemubot.datastore.nodes.generic import ParsingNode
+ node = ParsingNode(tag=self.serializetag)
+ node.content = str(self.value)
+ return node
diff --git a/nemubot/datastore/nodes/serializable.py b/nemubot/datastore/nodes/serializable.py
new file mode 100644
index 0000000..e543699
--- /dev/null
+++ b/nemubot/datastore/nodes/serializable.py
@@ -0,0 +1,22 @@
+# Nemubot is a smart and modulable IM bot.
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+
+class Serializable:
+
+ def serialize(self):
+ # Implementations of this function should return ParsingNode items
+ return NotImplemented
diff --git a/nemubot/datastore/xml.py b/nemubot/datastore/xml.py
index 46dca70..a82318d 100644
--- a/nemubot/datastore/xml.py
+++ b/nemubot/datastore/xml.py
@@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
-# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -36,17 +36,24 @@ class XML(Abstract):
rotate -- auto-backup files?
"""
- self.basedir = basedir
+ self.basedir = os.path.abspath(basedir)
self.rotate = rotate
self.nb_save = 0
+ logger.info("Initiate XML datastore at %s, rotation %s",
+ self.basedir,
+ "enabled" if self.rotate else "disabled")
+
+
def open(self):
"""Lock the directory"""
if not os.path.isdir(self.basedir):
+ logger.debug("Datastore directory not found, creating: %s", self.basedir)
os.mkdir(self.basedir)
- lock_path = os.path.join(self.basedir, ".used_by_nemubot")
+ lock_path = self._get_lock_file_path()
+ logger.debug("Locking datastore directory via %s", lock_path)
self.lock_file = open(lock_path, 'a+')
ok = True
@@ -64,56 +71,95 @@ class XML(Abstract):
self.lock_file.write(str(os.getpid()))
self.lock_file.flush()
+ logger.info("Datastore successfuly opened at %s", self.basedir)
return True
+
def close(self):
"""Release a locked path"""
if hasattr(self, "lock_file"):
self.lock_file.close()
- lock_path = os.path.join(self.basedir, ".used_by_nemubot")
+ lock_path = self._get_lock_file_path()
if os.path.isdir(self.basedir) and os.path.exists(lock_path):
os.unlink(lock_path)
del self.lock_file
+ logger.info("Datastore successfully closed at %s", self.basedir)
return True
+ else:
+ logger.warn("Datastore not open/locked or lock file not found")
return False
+
def _get_data_file_path(self, module):
"""Get the path to the module data file"""
return os.path.join(self.basedir, module + ".xml")
- def load(self, module):
+
+ def _get_lock_file_path(self):
+ """Get the path to the datastore lock file"""
+
+ return os.path.join(self.basedir, ".used_by_nemubot")
+
+
+ def load(self, module, extendsTags={}):
"""Load data for the given module
Argument:
module -- the module name of data to load
"""
+ logger.debug("Trying to load data for %s%s",
+ module,
+ (" with tags: " + ", ".join(extendsTags.keys())) if len(extendsTags) else "")
+
data_file = self._get_data_file_path(module)
+ def parse(path):
+ from nemubot.tools.xmlparser import XMLParser
+ from nemubot.datastore.nodes import basic as basicNodes
+ from nemubot.datastore.nodes import python as pythonNodes
+ from nemubot.message.command import Command
+ from nemubot.scope import Scope
+
+ d = {
+ basicNodes.ListNode.serializetag: basicNodes.ListNode,
+ basicNodes.DictNode.serializetag: basicNodes.DictNode,
+ pythonNodes.IntNode.serializetag: pythonNodes.IntNode,
+ pythonNodes.FloatNode.serializetag: pythonNodes.FloatNode,
+ pythonNodes.StringNode.serializetag: pythonNodes.StringNode,
+ Command.serializetag: Command,
+ Scope.serializetag: Scope,
+ }
+ d.update(extendsTags)
+
+ p = XMLParser(d)
+ return p.parse_file(path)
+
# Try to load original file
if os.path.isfile(data_file):
- from nemubot.tools.xmlparser import parse_file
try:
- return parse_file(data_file)
+ return parse(data_file)
except xml.parsers.expat.ExpatError:
# Try to load from backup
for i in range(10):
path = data_file + "." + str(i)
if os.path.isfile(path):
try:
- cnt = parse_file(path)
+ cnt = parse(path)
- logger.warn("Restoring from backup: %s", path)
+ logger.warn("Restoring data from backup: %s", path)
return cnt
except xml.parsers.expat.ExpatError:
continue
# Default case: initialize a new empty datastore
+ logger.warn("No data found in store for %s, creating new set", module)
return Abstract.load(self, module)
+
def _rotate(self, path):
"""Backup given path
@@ -130,6 +176,25 @@ class XML(Abstract):
if os.path.isfile(src):
os.rename(src, dst)
+
+ def _save_node(self, gen, node):
+ from nemubot.datastore.nodes.generic import ParsingNode
+
+ # First, get the serialized form of the node
+ node = ParsingNode.serialize_node(node)
+
+ assert node.tag is not None, "Undefined tag name"
+
+ gen.startElement(node.tag, {k: str(node.attrs[k]) for k in node.attrs})
+
+ gen.characters(node.content)
+
+ for child in node.children:
+ self._save_node(gen, child)
+
+ gen.endElement(node.tag)
+
+
def save(self, module, data):
"""Load data for the given module
@@ -139,8 +204,22 @@ class XML(Abstract):
"""
path = self._get_data_file_path(module)
+ logger.debug("Trying to save data for module %s in %s", module, path)
if self.rotate:
self._rotate(path)
- return data.save(path)
+ import tempfile
+ _, tmpath = tempfile.mkstemp()
+ with open(tmpath, "w") as f:
+ import xml.sax.saxutils
+ gen = xml.sax.saxutils.XMLGenerator(f, "utf-8")
+ gen.startDocument()
+ self._save_node(gen, data)
+ gen.endDocument()
+
+ # Atomic save
+ import shutil
+ shutil.move(tmpath, path)
+
+ return True
diff --git a/nemubot/message/abstract.py b/nemubot/message/abstract.py
index 5d74549..bf9a030 100644
--- a/nemubot/message/abstract.py
+++ b/nemubot/message/abstract.py
@@ -16,12 +16,17 @@
from datetime import datetime, timezone
+from nemubot.datastore.nodes import Serializable
-class Abstract:
+
+class Abstract(Serializable):
"""This class represents an abstract message"""
- def __init__(self, server=None, date=None, to=None, to_response=None, frm=None):
+ serializetag = "nemubotAMessage"
+
+
+ def __init__(self, server=None, date=None, to=None, to_response=None, frm=None, frm_owner=False):
"""Initialize an abstract message
Arguments:
@@ -40,7 +45,7 @@ class Abstract:
else [ to_response ])
self.frm = frm # None allowed when it designate this bot
- self.frm_owner = False # Filled later, in consumer
+ self.frm_owner = frm_owner
@property
@@ -65,6 +70,14 @@ class Abstract:
return self.frm
+ @property
+ def scope(self):
+ from nemubot.scope import Scope
+ return Scope(server=self.server,
+ channel=self.to_response[0],
+ nick=self.frm)
+
+
def accept(self, visitor):
visitor.visit(self)
@@ -78,7 +91,8 @@ class Abstract:
"date": self.date,
"to": self.to,
"to_response": self._to_response,
- "frm": self.frm
+ "frm": self.frm,
+ "frm_owner": self.frm_owner,
}
for w in without:
@@ -86,3 +100,8 @@ class Abstract:
del ret[w]
return ret
+
+
+ def serialize(self):
+ from nemubot.datastore.nodes import ParsingNode
+ return ParsingNode(tag=Abstract.serializetag, **self.export_args())
diff --git a/nemubot/message/command.py b/nemubot/message/command.py
index 6c208b2..2fe8893 100644
--- a/nemubot/message/command.py
+++ b/nemubot/message/command.py
@@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
-# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -21,6 +21,9 @@ class Command(Abstract):
"""This class represents a specialized TextMessage"""
+ serializetag = "nemubotCommand"
+
+
def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs):
super().__init__(*nargs, **kargs)
@@ -28,17 +31,35 @@ class Command(Abstract):
self.args = args if args is not None else list()
self.kwargs = kwargs if kwargs is not None else dict()
- def __str__(self):
+
+ def __repr__(self):
return self.cmd + " @" + ",@".join(self.args)
- @property
- def cmds(self):
- # TODO: this is for legacy modules
- return [self.cmd] + self.args
+
+ def addChild(self, name, child):
+ if name == "list":
+ self.args = child
+ elif name == "dict":
+ self.kwargs = child
+ else:
+ return False
+ return True
+
+
+ def serialize(self):
+ from nemubot.datastore.nodes import ParsingNode
+ node = ParsingNode(tag=Command.serializetag, cmd=self.cmd)
+ if len(self.args):
+ node.children.append(ParsingNode.serialize_node(self.args))
+ if len(self.kwargs):
+ node.children.append(ParsingNode.serialize_node(self.kwargs))
+ return node
class OwnerCommand(Command):
"""This class represents a special command incomming from the owner"""
+ serializetag = "nemubotOCommand"
+
pass
diff --git a/nemubot/modulecontext.py b/nemubot/modulecontext.py
index 1d1b3d0..7befe18 100644
--- a/nemubot/modulecontext.py
+++ b/nemubot/modulecontext.py
@@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
-# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+# Copyright (C) 2012-2017 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -14,105 +14,62 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-class ModuleContext:
+class _ModuleContext:
- def __init__(self, context, module):
- """Initialize the module context
-
- arguments:
- context -- the bot context
- module -- the module
- """
+ def __init__(self, module=None):
+ self.module = module
if module is not None:
- module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
+ self.module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
else:
- module_name = ""
-
- # Load module configuration if exists
- if (context is not None and
- module_name in context.modules_configuration):
- self.config = context.modules_configuration[module_name]
- else:
- from nemubot.config.module import Module
- self.config = Module(module_name)
+ self.module_name = ""
self.hooks = list()
self.events = list()
- self.debug = context.verbosity > 0 if context is not None else False
+ self.extendtags = dict()
+ self.debug = False
+ from nemubot.config.module import Module
+ self.config = Module(self.module_name)
+
+ def load_data(self):
+ return None
+
+ def add_hook(self, hook, *triggers):
from nemubot.hooks import Abstract as AbstractHook
+ assert isinstance(hook, AbstractHook), hook
+ self.hooks.append((triggers, hook))
- # Define some callbacks
- if context is not None:
- def load_data():
- return context.datastore.load(module_name)
+ def del_hook(self, hook, *triggers):
+ from nemubot.hooks import Abstract as AbstractHook
+ assert isinstance(hook, AbstractHook), hook
+ self.hooks.remove((triggers, hook))
- def add_hook(hook, *triggers):
- assert isinstance(hook, AbstractHook), hook
- self.hooks.append((triggers, hook))
- return context.treater.hm.add_hook(hook, *triggers)
+ def subtreat(self, msg):
+ return None
- def del_hook(hook, *triggers):
- assert isinstance(hook, AbstractHook), hook
- self.hooks.remove((triggers, hook))
- return context.treater.hm.del_hooks(*triggers, hook=hook)
+ def add_event(self, evt, eid=None):
+ return self.events.append((evt, eid))
- def subtreat(msg):
- yield from context.treater.treat_msg(msg)
- def add_event(evt, eid=None):
- return context.add_event(evt, eid, module_src=module)
- def del_event(evt):
- return context.del_event(evt, module_src=module)
+ def del_event(self, evt):
+ for i in self.events:
+ e, eid = i
+ if e == evt:
+ self.events.remove(i)
+ return True
+ return False
- def send_response(server, res):
- if server in context.servers:
- if res.server is not None:
- return context.servers[res.server].send_response(res)
- else:
- return context.servers[server].send_response(res)
- else:
- module.logger.error("Try to send a message to the unknown server: %s", server)
- return False
+ def send_response(self, server, res):
+ self.module.logger.info("Send response: %s", res)
- else: # Used when using outside of nemubot
- def load_data():
- from nemubot.tools.xmlparser import module_state
- return module_state.ModuleState("nemubotstate")
-
- def add_hook(hook, *triggers):
- assert isinstance(hook, AbstractHook), hook
- self.hooks.append((triggers, hook))
- def del_hook(hook, *triggers):
- assert isinstance(hook, AbstractHook), hook
- self.hooks.remove((triggers, hook))
- def subtreat(msg):
- return None
- def add_event(evt, eid=None):
- return context.add_event(evt, eid, module_src=module)
- def del_event(evt):
- return context.del_event(evt, module_src=module)
-
- def send_response(server, res):
- module.logger.info("Send response: %s", res)
-
- def save():
- context.datastore.save(module_name, self.data)
-
- def subparse(orig, cnt):
- if orig.server in context.servers:
- return context.servers[orig.server].subparse(orig, cnt)
-
- self.load_data = load_data
- self.add_hook = add_hook
- self.del_hook = del_hook
- self.add_event = add_event
- self.del_event = del_event
- self.save = save
- self.send_response = send_response
- self.subtreat = subtreat
- self.subparse = subparse
+ def save(self):
+ # Don't save if no data has been access
+ if hasattr(self, "_data"):
+ context.datastore.save(self.module_name, self.data)
+ def subparse(self, orig, cnt):
+ if orig.server in self.context.servers:
+ return self.context.servers[orig.server].subparse(orig, cnt)
@property
def data(self):
@@ -120,6 +77,21 @@ class ModuleContext:
self._data = self.load_data()
return self._data
+ @data.setter
+ def data(self, value):
+ assert value is not None
+
+ self._data = value
+
+
+ def register_tags(self, **tags):
+ self.extendtags.update(tags)
+
+
+ def unregister_tags(self, *tags):
+ for t in tags:
+ del self.extendtags[t]
+
def unload(self):
"""Perform actions for unloading the module"""
@@ -129,7 +101,62 @@ class ModuleContext:
self.del_hook(h, *s)
# Remove registered events
- for e in self.events:
- self.del_event(e)
+ for evt, eid, module_src in self.events:
+ self.del_event(evt)
self.save()
+
+
+class ModuleContext(_ModuleContext):
+
+ def __init__(self, context, *args, **kwargs):
+ """Initialize the module context
+
+ arguments:
+ context -- the bot context
+ module -- the module
+ """
+
+ super().__init__(*args, **kwargs)
+
+ # Load module configuration if exists
+ if self.module_name in context.modules_configuration:
+ self.config = context.modules_configuration[self.module_name]
+
+ self.context = context
+ self.debug = context.verbosity > 0
+
+
+ def load_data(self):
+ return self.context.datastore.load(self.module_name, extendsTags=self.extendtags)
+
+ def add_hook(self, hook, *triggers):
+ from nemubot.hooks import Abstract as AbstractHook
+ assert isinstance(hook, AbstractHook), hook
+ self.hooks.append((triggers, hook))
+ return self.context.treater.hm.add_hook(hook, *triggers)
+
+ def del_hook(self, hook, *triggers):
+ from nemubot.hooks import Abstract as AbstractHook
+ assert isinstance(hook, AbstractHook), hook
+ self.hooks.remove((triggers, hook))
+ return self.context.treater.hm.del_hooks(*triggers, hook=hook)
+
+ def subtreat(self, msg):
+ yield from self.context.treater.treat_msg(msg)
+
+ def add_event(self, evt, eid=None):
+ return self.context.add_event(evt, eid, module_src=self.module)
+
+ def del_event(self, evt):
+ return self.context.del_event(evt, module_src=self.module)
+
+ def send_response(self, server, res):
+ if server in self.context.servers:
+ if res.server is not None:
+ return self.context.servers[res.server].send_response(res)
+ else:
+ return self.context.servers[server].send_response(res)
+ else:
+ self.module.logger.error("Try to send a message to the unknown server: %s", server)
+ return False
diff --git a/nemubot/scope.py b/nemubot/scope.py
new file mode 100644
index 0000000..5da1542
--- /dev/null
+++ b/nemubot/scope.py
@@ -0,0 +1,83 @@
+# Nemubot is a smart and modulable IM bot.
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from nemubot.datastore.nodes import Serializable
+
+
+class Scope(Serializable):
+
+
+ serializetag = "nemubot-scope"
+ default_limit = "channel"
+
+
+ def __init__(self, server, channel, nick, limit=default_limit):
+ self._server = server
+ self._channel = channel
+ self._nick = nick
+ self._limit = limit
+
+
+ def sameServer(self, server):
+ return self._server is None or self._server == server
+
+
+ def sameChannel(self, server, channel):
+ return self.sameServer(server) and (self._channel is None or self._channel == channel)
+
+
+ def sameNick(self, server, channel, nick):
+ return self.sameChannel(server, channel) and (self._nick is None or self._nick == nick)
+
+
+ def check(self, scope, limit=None):
+ return self.checkScope(scope._server, scope._channel, scope._nick, limit)
+
+
+ def checkScope(self, server, channel, nick, limit=None):
+ if limit is None: limit = self._limit
+ assert limit == "global" or limit == "server" or limit == "channel" or limit == "nick"
+
+ if limit == "server":
+ return self.sameServer(server)
+ elif limit == "channel":
+ return self.sameChannel(server, channel)
+ elif limit == "nick":
+ return self.sameNick(server, channel, nick)
+ else:
+ return True
+
+
+ def narrow(self, scope):
+ return scope is None or (
+ scope._limit == "global" or
+ (scope._limit == "server" and (self._limit == "nick" or self._limit == "channel")) or
+ (scope._limit == "channel" and self._limit == "nick")
+ )
+
+
+ def serialize(self):
+ from nemubot.datastore.nodes import ParsingNode
+ args = {}
+ if self._server is not None:
+ args["server"] = self._server
+ if self._channel is not None:
+ args["channel"] = self._channel
+ if self._nick is not None:
+ args["nick"] = self._nick
+ if self._limit is not None:
+ args["limit"] = self._limit
+ return ParsingNode(tag=self.serializetag, **args)
diff --git a/nemubot/server/DCC.py b/nemubot/server/DCC.py
index 644a8cb..c1a6852 100644
--- a/nemubot/server/DCC.py
+++ b/nemubot/server/DCC.py
@@ -31,7 +31,7 @@ PORTS = list()
class DCC(server.AbstractServer):
def __init__(self, srv, dest, socket=None):
- super().__init__(self)
+ super().__init__(name="Nemubot DCC server")
self.error = False # An error has occur, closing the connection?
self.messages = list() # Message queued before connexion
diff --git a/nemubot/server/IRC.py b/nemubot/server/IRC.py
index 08e2bc5..7469abc 100644
--- a/nemubot/server/IRC.py
+++ b/nemubot/server/IRC.py
@@ -16,21 +16,22 @@
from datetime import datetime
import re
+import socket
from nemubot.channel import Channel
from nemubot.message.printer.IRC import IRC as IRCPrinter
from nemubot.server.message.IRC import IRC as IRCMessage
-from nemubot.server.socket import SocketServer
+from nemubot.server.socket import SocketServer, SecureSocketServer
-class IRC(SocketServer):
+class _IRC:
"""Concrete implementation of a connexion to an IRC server"""
- def __init__(self, host="localhost", port=6667, ssl=False, owner=None,
+ def __init__(self, host="localhost", port=6667, owner=None,
nick="nemubot", username=None, password=None,
realname="Nemubot", encoding="utf-8", caps=None,
- channels=list(), on_connect=None):
+ channels=list(), on_connect=None, **kwargs):
"""Prepare a connection with an IRC server
Keyword arguments:
@@ -54,7 +55,8 @@ class IRC(SocketServer):
self.owner = owner
self.realname = realname
- super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port))
+ super().__init__(name=self.username + "@" + host + ":" + str(port),
+ host=host, port=port, **kwargs)
self.printer = IRCPrinter
self.encoding = encoding
@@ -231,20 +233,19 @@ class IRC(SocketServer):
# Open/close
- def open(self):
- if super().open():
- if self.password is not None:
- self.write("PASS :" + self.password)
- if self.capabilities is not None:
- self.write("CAP LS")
- self.write("NICK :" + self.nick)
- self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
- return True
- return False
+ def connect(self):
+ super().connect()
+
+ if self.password is not None:
+ self.write("PASS :" + self.password)
+ if self.capabilities is not None:
+ self.write("CAP LS")
+ self.write("NICK :" + self.nick)
+ self.write("USER %s %s bla :%s" % (self.username, socket.getfqdn(), self.realname))
def close(self):
- if not self.closed:
+ if not self._closed:
self.write("QUIT")
return super().close()
@@ -253,8 +254,8 @@ class IRC(SocketServer):
# Read
- def read(self):
- for line in super().read():
+ def async_read(self):
+ for line in super().async_read():
# PING should be handled here, so start parsing here :/
msg = IRCMessage(line, self.encoding)
@@ -273,3 +274,10 @@ class IRC(SocketServer):
def subparse(self, orig, cnt):
msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding)
return msg.to_bot_message(self)
+
+
+class IRC(_IRC, SocketServer):
+ pass
+
+class IRC_secure(_IRC, SecureSocketServer):
+ pass
diff --git a/nemubot/server/__init__.py b/nemubot/server/__init__.py
index 3c88138..6998ef1 100644
--- a/nemubot/server/__init__.py
+++ b/nemubot/server/__init__.py
@@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
-# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -14,57 +14,64 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import threading
-_lock = threading.Lock()
-
-# Lists for select
-_rlist = []
-_wlist = []
-_xlist = []
-
-
-def factory(uri, **init_args):
- from urllib.parse import urlparse, unquote
+def factory(uri, ssl=False, **init_args):
+ from urllib.parse import urlparse, unquote, parse_qs
o = urlparse(uri)
+ srv = None
+
if o.scheme == "irc" or o.scheme == "ircs":
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
args = init_args
- modifiers = o.path.split(",")
- target = unquote(modifiers.pop(0)[1:])
-
- if o.scheme == "ircs": args["ssl"] = True
+ if o.scheme == "ircs": ssl = True
if o.hostname is not None: args["host"] = o.hostname
if o.port is not None: args["port"] = o.port
if o.username is not None: args["username"] = o.username
if o.password is not None: args["password"] = o.password
- queries = o.query.split("&")
- for q in queries:
- if "=" in q:
- key, val = tuple(q.split("=", 1))
- else:
- key, val = q, ""
- if key == "msg":
- if "on_connect" not in args:
- args["on_connect"] = []
- args["on_connect"].append("PRIVMSG %s :%s" % (target, unquote(val)))
- elif key == "key":
- if "channels" not in args:
- args["channels"] = []
- args["channels"].append((target, unquote(val)))
- elif key == "pass":
- args["password"] = unquote(val)
- elif key == "charset":
- args["encoding"] = unquote(val)
+ if ssl:
+ try:
+ from ssl import create_default_context
+ args["_context"] = create_default_context()
+ except ImportError:
+ # Python 3.3 compat
+ from ssl import SSLContext, PROTOCOL_TLSv1
+ args["_context"] = SSLContext(PROTOCOL_TLSv1)
+ modifiers = o.path.split(",")
+ target = unquote(modifiers.pop(0)[1:])
+
+ # Read query string
+ params = parse_qs(o.query)
+
+ if "msg" in params:
+ if "on_connect" not in args:
+ args["on_connect"] = []
+ args["on_connect"].append("PRIVMSG %s :%s" % (target, params["msg"]))
+
+ if "key" in params:
+ if "channels" not in args:
+ args["channels"] = []
+ args["channels"].append((target, params["key"]))
+
+ if "pass" in params:
+ args["password"] = params["pass"]
+
+ if "charset" in params:
+ args["encoding"] = params["charset"]
+
+ #
if "channels" not in args and "isnick" not in modifiers:
args["channels"] = [ target ]
- from nemubot.server.IRC import IRC as IRCServer
- return IRCServer(**args)
- else:
- return None
+ if ssl:
+ from nemubot.server.IRC import IRC_secure as SecureIRCServer
+ srv = SecureIRCServer(**args)
+ else:
+ from nemubot.server.IRC import IRC as IRCServer
+ srv = IRCServer(**args)
+
+ return srv
diff --git a/nemubot/server/abstract.py b/nemubot/server/abstract.py
index dc2081d..fd25c2d 100644
--- a/nemubot/server/abstract.py
+++ b/nemubot/server/abstract.py
@@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
-# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -14,34 +14,30 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import io
import logging
import queue
-from nemubot.server import _lock, _rlist, _wlist, _xlist
+from nemubot.bot import sync_act
-# Extends from IOBase in order to be compatible with select function
-class AbstractServer(io.IOBase):
+
+class AbstractServer:
"""An abstract server: handle communication with an IM server"""
- def __init__(self, name=None, send_callback=None):
+ def __init__(self, name=None, **kwargs):
"""Initialize an abstract server
Keyword argument:
- send_callback -- Callback when developper want to send a message
+ name -- Identifier of the socket, for convinience
"""
self._name = name
- super().__init__()
+ super().__init__(**kwargs)
- self.logger = logging.getLogger("nemubot.server." + self.name)
+ self.logger = logging.getLogger("nemubot.server." + str(self.name))
+ self._readbuffer = b''
self._sending_queue = queue.Queue()
- if send_callback is not None:
- self._send_callback = send_callback
- else:
- self._send_callback = self._write_select
@property
@@ -54,40 +50,28 @@ class AbstractServer(io.IOBase):
# Open/close
- def __enter__(self):
- self.open()
- return self
+ def connect(self, *args, **kwargs):
+ """Register the server in _poll"""
+
+ self.logger.info("Opening connection")
+
+ super().connect(*args, **kwargs)
+
+ self._on_connect()
+
+ def _on_connect(self):
+ sync_act("sckt", "register", self.fileno())
- def __exit__(self, type, value, traceback):
- self.close()
+ def close(self, *args, **kwargs):
+ """Unregister the server from _poll"""
+ self.logger.info("Closing connection")
- def open(self):
- """Generic open function that register the server un _rlist in case
- of successful _open"""
- self.logger.info("Opening connection to %s", self.id)
- if not hasattr(self, "_open") or self._open():
- _rlist.append(self)
- _xlist.append(self)
- return True
- return False
+ if self.fileno() > 0:
+ sync_act("sckt", "unregister", self.fileno())
-
- def close(self):
- """Generic close function that register the server un _{r,w,x}list in
- case of successful _close"""
- self.logger.info("Closing connection to %s", self.id)
- with _lock:
- if not hasattr(self, "_close") or self._close():
- if self in _rlist:
- _rlist.remove(self)
- if self in _wlist:
- _wlist.remove(self)
- if self in _xlist:
- _xlist.remove(self)
- return True
- return False
+ super().close(*args, **kwargs)
# Writes
@@ -99,13 +83,16 @@ class AbstractServer(io.IOBase):
message -- message to send
"""
- self._send_callback(message)
+ self._sending_queue.put(self.format(message))
+ self.logger.debug("Message '%s' appended to write queue", message)
+ sync_act("sckt", "write", self.fileno())
- def write_select(self):
- """Internal function used by the select function"""
+ def async_write(self):
+ """Internal function used when the file descriptor is writable"""
+
try:
- _wlist.remove(self)
+ sync_act("sckt", "unwrite", self.fileno())
while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done()
@@ -114,19 +101,6 @@ class AbstractServer(io.IOBase):
pass
- def _write_select(self, message):
- """Send a message to the server safely through select
-
- Argument:
- message -- message to send
- """
-
- self._sending_queue.put(self.format(message))
- self.logger.debug("Message '%s' appended to write queue", message)
- if self not in _wlist:
- _wlist.append(self)
-
-
def send_response(self, response):
"""Send a formated Message class
@@ -149,13 +123,39 @@ class AbstractServer(io.IOBase):
# Read
+ def async_read(self):
+ """Internal function used when the file descriptor is readable
+
+ Returns:
+ A list of fully received messages
+ """
+
+ ret, self._readbuffer = self.lex(self._readbuffer + self.read())
+
+ for r in ret:
+ yield r
+
+
+ def lex(self, buf):
+ """Assume lexing in default case is per line
+
+ Argument:
+ buf -- buffer to lex
+ """
+
+ msgs = buf.split(b'\r\n')
+ partial = msgs.pop()
+
+ return msgs, partial
+
+
def parse(self, msg):
raise NotImplemented
# Exceptions
- def exception(self):
- """Exception occurs in fd"""
- self.logger.warning("Unhandle file descriptor exception on server %s",
- self.name)
+ def exception(self, flags):
+ """Exception occurs on fd"""
+
+ self.close()
diff --git a/nemubot/server/factory_test.py b/nemubot/server/factory_test.py
index cc7d35b..e2b6752 100644
--- a/nemubot/server/factory_test.py
+++ b/nemubot/server/factory_test.py
@@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+import socket
import unittest
from nemubot.server import factory
@@ -22,34 +23,36 @@ class TestFactory(unittest.TestCase):
def test_IRC1(self):
from nemubot.server.IRC import IRC as IRCServer
+ from nemubot.server.IRC import IRC_secure as IRCSServer
# : If omitted, the client must connect to a prespecified default IRC server.
server = factory("irc:///")
self.assertIsInstance(server, IRCServer)
- self.assertEqual(server.host, "localhost")
- self.assertFalse(server.ssl)
+ self.assertEqual(server._sockaddr,
+ socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("ircs:///")
- self.assertIsInstance(server, IRCServer)
- self.assertEqual(server.host, "localhost")
- self.assertTrue(server.ssl)
+ self.assertIsInstance(server, IRCSServer)
+ self.assertEqual(server._sockaddr,
+ socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4])
- server = factory("irc://host1")
+ server = factory("irc://freenode.net")
self.assertIsInstance(server, IRCServer)
- self.assertEqual(server.host, "host1")
- self.assertFalse(server.ssl)
+ self.assertEqual(server._sockaddr,
+ socket.getaddrinfo("freenode.net", 6667, proto=socket.IPPROTO_TCP)[0][4])
- server = factory("irc://host2:6667")
+ server = factory("irc://freenode.org:1234")
self.assertIsInstance(server, IRCServer)
- self.assertEqual(server.host, "host2")
- self.assertEqual(server.port, 6667)
- self.assertFalse(server.ssl)
+ self.assertEqual(server._sockaddr,
+ socket.getaddrinfo("freenode.org", 1234, proto=socket.IPPROTO_TCP)[0][4])
- server = factory("ircs://host3:194/")
- self.assertIsInstance(server, IRCServer)
- self.assertEqual(server.host, "host3")
- self.assertEqual(server.port, 194)
- self.assertTrue(server.ssl)
+ server = factory("ircs://nemunai.re:194/")
+ self.assertIsInstance(server, IRCSServer)
+ self.assertEqual(server._sockaddr,
+ socket.getaddrinfo("nemunai.re", 194, proto=socket.IPPROTO_TCP)[0][4])
+
+ with self.assertRaises(socket.gaierror):
+ factory("irc://_nonexistent.nemunai.re")
if __name__ == '__main__':
diff --git a/nemubot/server/message/IRC.py b/nemubot/server/message/IRC.py
index 4c9e280..5ccd735 100644
--- a/nemubot/server/message/IRC.py
+++ b/nemubot/server/message/IRC.py
@@ -150,7 +150,8 @@ class IRC(Abstract):
"date": self.tags["time"],
"to": receivers,
"to_response": [r if r != srv.nick else self.nick for r in receivers],
- "frm": self.nick
+ "frm": self.nick,
+ "frm_owner": self.nick == srv.owner
}
# If CTCP, remove 0x01
diff --git a/nemubot/server/message/__init__.py b/nemubot/server/message/__init__.py
new file mode 100644
index 0000000..57f3468
--- /dev/null
+++ b/nemubot/server/message/__init__.py
@@ -0,0 +1,15 @@
+# Nemubot is a smart and modulable IM bot.
+# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
diff --git a/nemubot/server/socket.py b/nemubot/server/socket.py
index 13ac9bd..612f4cb 100644
--- a/nemubot/server/socket.py
+++ b/nemubot/server/socket.py
@@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
-# Copyright (C) 2012-2015 Mercier Pierre-Olivier
+# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -14,117 +14,33 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+import os
+import socket
+import ssl
+
import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter
from nemubot.server.abstract import AbstractServer
-class SocketServer(AbstractServer):
+class _Socket(AbstractServer):
- """Concrete implementation of a socket connexion (can be wrapped with TLS)"""
+ """Concrete implementation of a socket connection"""
- def __init__(self, sock_location=None,
- host=None, port=None,
- sock=None,
- ssl=False,
- name=None):
+ def __init__(self, printer=SocketPrinter, **kwargs):
"""Create a server socket
-
- Keyword arguments:
- sock_location -- Path to the UNIX socket
- host -- Hostname of the INET socket
- port -- Port of the INET socket
- sock -- Already connected socket
- ssl -- Should TLS connection enabled
- name -- Convinience name
"""
- import socket
-
- assert(sock is None or isinstance(sock, socket.SocketType))
- assert(port is None or isinstance(port, int))
-
- super().__init__(name=name)
-
- if sock is None:
- if sock_location is not None:
- self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- self.connect_to = sock_location
- elif host is not None:
- for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
- self.socket = socket.socket(af, socktype, proto)
- self.connect_to = sa
- break
- else:
- self.socket = sock
-
- self.ssl = ssl
+ super().__init__(**kwargs)
self.readbuffer = b''
- self.printer = SocketPrinter
-
-
- def fileno(self):
- return self.socket.fileno() if self.socket else None
-
-
- @property
- def closed(self):
- """Indicator of the connection aliveness"""
- return self.socket._closed
-
-
- # Open/close
-
- def open(self):
- if not self.closed:
- return True
-
- try:
- self.socket.connect(self.connect_to)
- self.logger.info("Connected to %s", self.connect_to)
- except:
- self.socket.close()
- self.logger.exception("Unable to connect to %s",
- self.connect_to)
- return False
-
- # Wrap the socket for SSL
- if self.ssl:
- import ssl
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- self.socket = ctx.wrap_socket(self.socket)
-
- return super().open()
-
-
- def close(self):
- import socket
-
- # Flush the sending queue before close
- from nemubot.server import _lock
- _lock.release()
- self._sending_queue.join()
- _lock.acquire()
-
- if not self.closed:
- try:
- self.socket.shutdown(socket.SHUT_RDWR)
- except socket.error:
- pass
-
- self.socket.close()
-
- return super().close()
+ self.printer = printer
# Write
def _write(self, cnt):
- if self.closed:
- return
-
- self.socket.sendall(cnt)
+ self.sendall(cnt)
def format(self, txt):
@@ -136,19 +52,12 @@ class SocketServer(AbstractServer):
# Read
- def read(self):
- if self.closed:
- return []
-
- raw = self.socket.recv(1024)
- temp = (self.readbuffer + raw).split(b'\r\n')
- self.readbuffer = temp.pop()
-
- for line in temp:
- yield line
+ def recv(self, n=1024):
+ return super().recv(n)
def parse(self, line):
+ """Implement a default behaviour for socket"""
import shlex
line = line.strip().decode()
@@ -157,48 +66,107 @@ class SocketServer(AbstractServer):
except ValueError:
args = line.split(' ')
- yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you")
+ if len(args):
+ yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
-class SocketListener(AbstractServer):
-
- def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None):
- super().__init__(name=name)
- self.new_server_cb = new_server_cb
- self.sock_location = sock_location
- self.host = host
- self.port = port
- self.ssl = ssl
- self.nb_son = 0
+ def subparse(self, orig, cnt):
+ for m in self.parse(cnt):
+ m.to = orig.to
+ m.frm = orig.frm
+ m.date = orig.date
+ yield m
- def fileno(self):
- return self.socket.fileno() if self.socket else None
+class _SocketServer(_Socket):
+
+ def __init__(self, host, port, bind=None, **kwargs):
+ (family, type, proto, canonname, sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0]
+
+ if isinstance(self, ssl.SSLSocket) and "server_hostname" not in kwargs:
+ kwargs["server_hostname"] = host
+
+ super().__init__(family=family, type=type, proto=proto, **kwargs)
+
+ self._sockaddr = sockaddr
+ self._bind = bind
- @property
- def closed(self):
- """Indicator of the connection aliveness"""
- return self.socket is None
+ def connect(self):
+ self.logger.info("Connection to %s:%d", *self._sockaddr[:2])
+ super().connect(self._sockaddr)
+
+ if self._bind:
+ super().bind(self._bind)
- def open(self):
- import os
- import socket
+class SocketServer(_SocketServer, socket.socket):
+ pass
- if self.sock_location is not None:
- self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- try:
- os.remove(self.sock_location)
- except FileNotFoundError:
- pass
- self.socket.bind(self.sock_location)
- elif self.host is not None and self.port is not None:
- self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.socket.bind((self.host, self.port))
- self.socket.listen(5)
- return super().open()
+class SecureSocketServer(_SocketServer, ssl.SSLSocket):
+ pass
+
+
+class UnixSocket:
+
+ def __init__(self, location, **kwargs):
+ super().__init__(family=socket.AF_UNIX, **kwargs)
+
+ self._socket_path = location
+
+
+ def connect(self):
+ self.logger.info("Connection to unix://%s", self._socket_path)
+ super().connect(self._socket_path)
+
+
+class SocketClient(_Socket, socket.socket):
+
+ def read(self):
+ return self.recv()
+
+
+class _Listener:
+
+ def __init__(self, new_server_cb, instanciate=SocketClient, **kwargs):
+ super().__init__(**kwargs)
+
+ self._instanciate = instanciate
+ self._new_server_cb = new_server_cb
+
+
+ def read(self):
+ conn, addr = self.accept()
+ fileno = conn.fileno()
+ self.logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
+
+ ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
+ ss.connect = ss._on_connect
+ self._new_server_cb(ss, autoconnect=True)
+
+ return b''
+
+
+class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+
+ def connect(self):
+ self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
+
+ try:
+ os.remove(self._socket_path)
+ except FileNotFoundError:
+ pass
+
+ self.bind(self._socket_path)
+ self.listen(5)
+ self.logger.info("Socket ready for accepting new connections")
+
+ self._on_connect()
def close(self):
@@ -206,25 +174,14 @@ class SocketListener(AbstractServer):
import socket
try:
- self.socket.shutdown(socket.SHUT_RDWR)
- self.socket.close()
- if self.sock_location is not None:
- os.remove(self.sock_location)
+ self.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
- return super().close()
+ super().close()
-
- # Read
-
- def read(self):
- if self.closed:
- return []
-
- conn, addr = self.socket.accept()
- self.nb_son += 1
- ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
- self.new_server_cb(ss)
-
- return []
+ try:
+ if self._socket_path is not None:
+ os.remove(self._socket_path)
+ except:
+ pass
diff --git a/nemubot/tools/web.py b/nemubot/tools/web.py
index d35740c..0852664 100644
--- a/nemubot/tools/web.py
+++ b/nemubot/tools/web.py
@@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from urllib.parse import urlparse, urlsplit, urlunsplit
+from urllib.parse import urljoin, urlparse, urlsplit, urlunsplit
from nemubot.exception import IMException
@@ -108,6 +108,9 @@ def getURLContent(url, body=None, timeout=7, header=None):
elif "User-agent" not in header:
header["User-agent"] = "Nemubot v%s" % __version__
+ if body is not None and "Content-Type" not in header:
+ header["Content-Type"] = "application/x-www-form-urlencoded"
+
import socket
try:
if o.query != '':
@@ -156,21 +159,23 @@ def getURLContent(url, body=None, timeout=7, header=None):
elif ((res.status == http.client.FOUND or
res.status == http.client.MOVED_PERMANENTLY) and
res.getheader("Location") != url):
- return getURLContent(res.getheader("Location"), timeout=timeout)
+ return getURLContent(
+ urljoin(url, res.getheader("Location")),
+ body=body,
+ timeout=timeout,
+ header=header)
else:
raise IMException("A HTTP error occurs: %d - %s" %
(res.status, http.client.responses[res.status]))
-def getXML(url, timeout=7):
+def getXML(*args, **kwargs):
"""Get content page and return XML parsed content
- Arguments:
- url -- the URL to get
- timeout -- maximum number of seconds to wait before returning an exception
+ Arguments: same as getURLContent
"""
- cnt = getURLContent(url, timeout=timeout)
+ cnt = getURLContent(*args, **kwargs)
if cnt is None:
return None
else:
@@ -178,15 +183,13 @@ def getXML(url, timeout=7):
return parseString(cnt)
-def getJSON(url, timeout=7):
+def getJSON(*args, **kwargs):
"""Get content page and return JSON content
- Arguments:
- url -- the URL to get
- timeout -- maximum number of seconds to wait before returning an exception
+ Arguments: same as getURLContent
"""
- cnt = getURLContent(url, timeout=timeout)
+ cnt = getURLContent(*args, **kwargs)
if cnt is None:
return None
else:
diff --git a/nemubot/tools/xmlparser/__init__.py b/nemubot/tools/xmlparser/__init__.py
index abc5bb9..687bf63 100644
--- a/nemubot/tools/xmlparser/__init__.py
+++ b/nemubot/tools/xmlparser/__init__.py
@@ -51,11 +51,13 @@ class XMLParser:
def __init__(self, knodes):
self.knodes = knodes
+ def _reset(self):
self.stack = list()
self.child = 0
def parse_file(self, path):
+ self._reset()
p = xml.parsers.expat.ParserCreate()
p.StartElementHandler = self.startElement
@@ -69,6 +71,7 @@ class XMLParser:
def parse_string(self, s):
+ self._reset()
p = xml.parsers.expat.ParserCreate()
p.StartElementHandler = self.startElement
@@ -126,10 +129,13 @@ class XMLParser:
if hasattr(self.current, "endElement"):
self.current.endElement(None)
+ if hasattr(self.current, "parsedForm") and callable(self.current.parsedForm):
+ self.stack[-1] = self.current.parsedForm()
+
# Don't remove root
if len(self.stack) > 1:
last = self.stack.pop()
- if hasattr(self.current, "addChild"):
+ if hasattr(self.current, "addChild") and callable(self.current.addChild):
if self.current.addChild(name, last):
return
raise TypeError(name + " tag not expected in " + self.display_stack())
diff --git a/nemubot/treatment.py b/nemubot/treatment.py
index 2c1955d..4f629e0 100644
--- a/nemubot/treatment.py
+++ b/nemubot/treatment.py
@@ -15,6 +15,7 @@
# along with this program. If not, see .
import logging
+import types
logger = logging.getLogger("nemubot.treatment")
@@ -108,6 +109,9 @@ class MessageTreater:
msg -- message to treat
"""
+ if hasattr(msg, "frm_owner"):
+ msg.frm_owner = (not hasattr(msg.server, "owner") or msg.server.owner == msg.frm)
+
while hook is not None:
res = hook.run(msg)
@@ -116,10 +120,18 @@ class MessageTreater:
yield r
elif res is not None:
- if not hasattr(res, "server") or res.server is None:
- res.server = msg.server
+ if isinstance(res, types.GeneratorType):
+ for r in res:
+ if not hasattr(r, "server") or r.server is None:
+ r.server = msg.server
- yield res
+ yield r
+
+ else:
+ if not hasattr(res, "server") or res.server is None:
+ res.server = msg.server
+
+ yield res
hook = next(hook_gen, None)
diff --git a/setup.py b/setup.py
index 36dddb4..a400c3c 100755
--- a/setup.py
+++ b/setup.py
@@ -63,6 +63,7 @@ setup(
'nemubot',
'nemubot.config',
'nemubot.datastore',
+ 'nemubot.datastore.nodes',
'nemubot.event',
'nemubot.exception',
'nemubot.hooks',