Compare commits

...

30 commits

Author SHA1 Message Date
2ac37f79d9 New class storing scope 2017-07-18 00:50:30 +02:00
e3150a6061 Messages now implements Serializable 2017-07-18 00:50:30 +02:00
e0af09f3c5 New printer and parser for bot data, XML-based 2017-07-18 00:50:30 +02:00
2992c13ca7 socket: limit getaddrinfo to TCP connections 2017-07-18 00:48:30 +02:00
49c207a4c9 events: fix help when no event is defined 2017-07-18 00:48:30 +02:00
3c13043ca3 run: recreate the sync_queue on run, it seems to have strange behaviour when created before the fork 2017-07-18 00:48:30 +02:00
5b28828ede event: ensure that enough consumers are launched at the end of an event 2017-07-18 00:48:30 +02:00
2a25f5311a rename module nextstop: ratp to avoid import loop with the inderlying Python module 2017-07-18 00:48:30 +02:00
1db63600e5 main: new option -A to run as daemon 2017-07-18 00:48:30 +02:00
5f89428562 Use getaddrinfo to create the right socket 2017-07-18 00:48:30 +02:00
205a39ad70 Try to restaure frm_owner flag 2017-07-18 00:42:11 +02:00
c51b0a9170 When launched in daemon mode, attach to the socket 2017-07-18 00:42:11 +02:00
eb70fe560b Deamonize later 2017-07-18 00:42:11 +02:00
8b6f72587d Local client now detects when server close the connection 2017-07-18 00:42:11 +02:00
02838658b0 Fix communication over unix socket 2017-07-18 00:42:11 +02:00
1a813083e5 Handle multiple SIGTERM 2017-07-18 00:42:11 +02:00
3fbeb49a6c suivi: add fedex 2017-07-18 00:42:11 +02:00
f5f13202c5 suivi: use getURLContent instead of call to urllib 2017-07-18 00:42:11 +02:00
aeba947877 tools/web: fill a default Content-Type in case of POST 2017-07-18 00:42:11 +02:00
b27e01a196 tools/web: improve redirection reliability 2017-07-18 00:42:11 +02:00
4d68410777 tools/web: forward all arguments passed to getJSON and getXML to getURLContent 2017-07-18 00:42:11 +02:00
384fbc6717 Update weather module: refleting forcastAPI changes 2017-07-18 00:42:11 +02:00
63e65b2659 modulecontext: use inheritance instead of conditional init 2017-07-18 00:42:11 +02:00
bdf8a69ff0 Avoid stack-trace and DOS if event is not well formed 2017-07-18 00:42:11 +02:00
4f1dcb8524 [nextstop] Use as system wide module 2017-07-18 00:42:11 +02:00
6d2f90fe77 Allow module function to be generators 2017-07-18 00:42:10 +02:00
2e5834a89d Parse server urls using parse_qs 2017-07-18 00:42:10 +02:00
26d1f5b6e8 Format and typo 2017-07-18 00:42:10 +02:00
40fc84fcec Implement socket server subparse 2017-07-18 00:42:10 +02:00
449cb684f9 Refactor file/socket management (use poll instead of select) 2017-07-18 00:42:08 +02:00
36 changed files with 1244 additions and 700 deletions

3
.gitmodules vendored
View file

@ -1,3 +0,0 @@
[submodule "modules/nextstop/external"]
path = modules/nextstop/external
url = git://github.com/nbr23/NextStop.git

View file

@ -9,6 +9,8 @@ Requirements
*nemubot* requires at least Python 3.3 to work. *nemubot* requires at least Python 3.3 to work.
Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629).
Some modules (like `cve`, `nextstop` or `laposte`) require the Some modules (like `cve`, `nextstop` or `laposte`) require the
[BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/), [BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/),
but the core and framework has no dependency. but the core and framework has no dependency.

View file

@ -16,7 +16,7 @@ from more import Response
def help_full (): def help_full ():
return "This module store a lot of events: ny, we, " + (", ".join(context.datas.index.keys())) + "\n!eventslist: gets list of timer\n!start /something/: launch a timer" return "This module store a lot of events: ny, we, " + (", ".join(context.datas.index.keys() if hasattr(context, "datas") else [])) + "\n!eventslist: gets list of timer\n!start /something/: launch a timer"
def load(context): def load(context):

View file

@ -1,4 +0,0 @@
<?xml version="1.0" ?>
<nemubotmodule name="nextstop">
<message type="cmd" name="ratp" call="ask_ratp" />
</nemubotmodule>

View file

@ -1,55 +0,0 @@
# coding=utf-8
"""Informe les usagers des prochains passages des transports en communs de la RATP"""
from nemubot.exception import IMException
from nemubot.hooks import hook
from more import Response
nemubotversion = 3.4
from .external.src import ratp
def help_full ():
return "!ratp transport line [station]: Donne des informations sur les prochains passages du transport en commun séléctionné à l'arrêt désiré. Si aucune station n'est précisée, les liste toutes."
@hook.command("ratp")
def ask_ratp(msg):
"""Hook entry from !ratp"""
if len(msg.args) >= 3:
transport = msg.args[0]
line = msg.args[1]
station = msg.args[2]
if len(msg.args) == 4:
times = ratp.getNextStopsAtStation(transport, line, station, msg.args[3])
else:
times = ratp.getNextStopsAtStation(transport, line, station)
if len(times) == 0:
raise IMException("la station %s n'existe pas sur le %s ligne %s." % (station, transport, line))
(time, direction, stationname) = times[0]
return Response(message=["\x03\x02%s\x03\x02 direction %s" % (time, direction) for time, direction, stationname in times],
title="Prochains passages du %s ligne %s à l'arrêt %s" % (transport, line, stationname),
channel=msg.channel)
elif len(msg.args) == 2:
stations = ratp.getAllStations(msg.args[0], msg.args[1])
if len(stations) == 0:
raise IMException("aucune station trouvée.")
return Response([s for s in stations], title="Stations", channel=msg.channel)
else:
raise IMException("Mauvais usage, merci de spécifier un type de transport et une ligne, ou de consulter l'aide du module.")
@hook.command("ratp_alert")
def ratp_alert(msg):
if len(msg.args) == 2:
transport = msg.args[0]
cause = msg.args[1]
incidents = ratp.getDisturbance(cause, transport)
return Response(incidents, channel=msg.channel, nomore="No more incidents", count=" (%d more incidents)")
else:
raise IMException("Mauvais usage, merci de spécifier un type de transport et un type d'alerte (alerte, manif, travaux), ou de consulter l'aide du module.")

@ -1 +0,0 @@
Subproject commit 3d5c9b2d52fbd214f5aaad00e5f3952de919b3e5

74
modules/ratp.py Normal file
View file

@ -0,0 +1,74 @@
"""Informe les usagers des prochains passages des transports en communs de la RATP"""
# PYTHON STUFFS #######################################################
from nemubot.exception import IMException
from nemubot.hooks import hook
from more import Response
from nextstop import ratp
@hook.command("ratp",
help="Affiche les prochains horaires de passage",
help_usage={
"TRANSPORT": "Affiche les lignes du moyen de transport donné",
"TRANSPORT LINE": "Affiche les stations sur la ligne de transport donnée",
"TRANSPORT LINE STATION": "Affiche les prochains horaires de passage à l'arrêt donné",
"TRANSPORT LINE STATION DESTINATION": "Affiche les prochains horaires de passage dans la direction donnée",
})
def ask_ratp(msg):
l = len(msg.args)
transport = msg.args[0] if l > 0 else None
line = msg.args[1] if l > 1 else None
station = msg.args[2] if l > 2 else None
direction = msg.args[3] if l > 3 else None
if station is not None:
times = sorted(ratp.getNextStopsAtStation(transport, line, station, direction), key=lambda i: i[0])
if len(times) == 0:
raise IMException("la station %s n'existe pas sur le %s ligne %s." % (station, transport, line))
(time, direction, stationname) = times[0]
return Response(message=["\x03\x02%s\x03\x02 direction %s" % (time, direction) for time, direction, stationname in times],
title="Prochains passages du %s ligne %s à l'arrêt %s" % (transport, line, stationname),
channel=msg.channel)
elif line is not None:
stations = ratp.getAllStations(transport, line)
if len(stations) == 0:
raise IMException("aucune station trouvée.")
return Response(stations, title="Stations", channel=msg.channel)
elif transport is not None:
lines = ratp.getTransportLines(transport)
if len(lines) == 0:
raise IMException("aucune ligne trouvée.")
return Response(lines, title="Lignes", channel=msg.channel)
else:
raise IMException("précise au moins un moyen de transport.")
@hook.command("ratp_alert",
help="Affiche les perturbations en cours sur le réseau")
def ratp_alert(msg):
if len(msg.args) == 0:
raise IMException("précise au moins un moyen de transport.")
l = len(msg.args)
transport = msg.args[0] if l > 0 else None
line = msg.args[1] if l > 1 else None
if line is not None:
d = ratp.getDisturbanceFromLine(transport, line)
if "date" in d and d["date"] is not None:
incidents = "Au {date[date]}, {title}: {message}".format(**d)
else:
incidents = "{title}: {message}".format(**d)
else:
incidents = ratp.getDisturbance(None, transport)
return Response(incidents, channel=msg.channel, nomore="No more incidents", count=" (%d more incidents)")

View file

@ -2,14 +2,14 @@
# PYTHON STUFF ############################################ # PYTHON STUFF ############################################
import urllib.request import json
import urllib.parse import urllib.parse
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
import re import re
from nemubot.hooks import hook from nemubot.hooks import hook
from nemubot.exception import IMException from nemubot.exception import IMException
from nemubot.tools.web import getURLContent from nemubot.tools.web import getURLContent, getJSON
from more import Response from more import Response
@ -17,8 +17,7 @@ from more import Response
def get_tnt_info(track_id): def get_tnt_info(track_id):
values = [] values = []
data = getURLContent('www.tnt.fr/public/suivi_colis/recherche/' data = getURLContent('www.tnt.fr/public/suivi_colis/recherche/visubontransport.do?bonTransport=%s' % track_id)
'visubontransport.do?bonTransport=%s' % track_id)
soup = BeautifulSoup(data) soup = BeautifulSoup(data)
status_list = soup.find('div', class_='result__content') status_list = soup.find('div', class_='result__content')
if not status_list: if not status_list:
@ -32,8 +31,7 @@ def get_tnt_info(track_id):
def get_colissimo_info(colissimo_id): def get_colissimo_info(colissimo_id):
colissimo_data = getURLContent("http://www.colissimo.fr/portail_colissimo/" colissimo_data = getURLContent("http://www.colissimo.fr/portail_colissimo/suivre.do?colispart=%s" % colissimo_id)
"suivre.do?colispart=%s" % colissimo_id)
soup = BeautifulSoup(colissimo_data) soup = BeautifulSoup(colissimo_data)
dataArray = soup.find(class_='dataArray') dataArray = soup.find(class_='dataArray')
@ -47,9 +45,8 @@ def get_colissimo_info(colissimo_id):
def get_chronopost_info(track_id): def get_chronopost_info(track_id):
data = urllib.parse.urlencode({'listeNumeros': track_id}) data = urllib.parse.urlencode({'listeNumeros': track_id})
track_baseurl = "http://www.chronopost.fr/expedier/" \ track_baseurl = "http://www.chronopost.fr/expedier/inputLTNumbersNoJahia.do?lang=fr_FR"
"inputLTNumbersNoJahia.do?lang=fr_FR" track_data = getURLContent(track_baseurl, data.encode('utf-8'))
track_data = urllib.request.urlopen(track_baseurl, data.encode('utf-8'))
soup = BeautifulSoup(track_data) soup = BeautifulSoup(track_data)
infoClass = soup.find(class_='numeroColi2') infoClass = soup.find(class_='numeroColi2')
@ -65,9 +62,8 @@ def get_chronopost_info(track_id):
def get_colisprive_info(track_id): def get_colisprive_info(track_id):
data = urllib.parse.urlencode({'numColis': track_id}) data = urllib.parse.urlencode({'numColis': track_id})
track_baseurl = "https://www.colisprive.com/moncolis/pages/" \ track_baseurl = "https://www.colisprive.com/moncolis/pages/detailColis.aspx"
"detailColis.aspx" track_data = getURLContent(track_baseurl, data.encode('utf-8'))
track_data = urllib.request.urlopen(track_baseurl, data.encode('utf-8'))
soup = BeautifulSoup(track_data) soup = BeautifulSoup(track_data)
dataArray = soup.find(class_='BandeauInfoColis') dataArray = soup.find(class_='BandeauInfoColis')
@ -82,8 +78,7 @@ def get_laposte_info(laposte_id):
data = urllib.parse.urlencode({'id': laposte_id}) data = urllib.parse.urlencode({'id': laposte_id})
laposte_baseurl = "http://www.part.csuivi.courrier.laposte.fr/suivi/index" laposte_baseurl = "http://www.part.csuivi.courrier.laposte.fr/suivi/index"
laposte_data = urllib.request.urlopen(laposte_baseurl, laposte_data = getURLContent(laposte_baseurl, data.encode('utf-8'))
data.encode('utf-8'))
soup = BeautifulSoup(laposte_data) soup = BeautifulSoup(laposte_data)
search_res = soup.find(class_='resultat_rech_simple_table').tbody.tr search_res = soup.find(class_='resultat_rech_simple_table').tbody.tr
if (soup.find(class_='resultat_rech_simple_table').thead if (soup.find(class_='resultat_rech_simple_table').thead
@ -112,8 +107,7 @@ def get_postnl_info(postnl_id):
data = urllib.parse.urlencode({'barcodes': postnl_id}) data = urllib.parse.urlencode({'barcodes': postnl_id})
postnl_baseurl = "http://www.postnl.post/details/" postnl_baseurl = "http://www.postnl.post/details/"
postnl_data = urllib.request.urlopen(postnl_baseurl, postnl_data = getURLContent(postnl_baseurl, data.encode('utf-8'))
data.encode('utf-8'))
soup = BeautifulSoup(postnl_data) soup = BeautifulSoup(postnl_data)
if (soup.find(id='datatables') if (soup.find(id='datatables')
and soup.find(id='datatables').tbody and soup.find(id='datatables').tbody
@ -132,6 +126,42 @@ def get_postnl_info(postnl_id):
return (post_status.lower(), post_destination, post_date) return (post_status.lower(), post_destination, post_date)
def get_fedex_info(fedex_id, lang="en_US"):
data = urllib.parse.urlencode({
'data': json.dumps({
"TrackPackagesRequest": {
"appType": "WTRK",
"appDeviceType": "DESKTOP",
"uniqueKey": "",
"processingParameters": {},
"trackingInfoList": [
{
"trackNumberInfo": {
"trackingNumber": str(fedex_id),
"trackingQualifier": "",
"trackingCarrier": ""
}
}
]
}
}),
'action': "trackpackages",
'locale': lang,
'version': 1,
'format': "json"
})
fedex_baseurl = "https://www.fedex.com/trackingCal/track"
fedex_data = getJSON(fedex_baseurl, data.encode('utf-8'))
if ("TrackPackagesResponse" in fedex_data and
"packageList" in fedex_data["TrackPackagesResponse"] and
len(fedex_data["TrackPackagesResponse"]["packageList"]) and
not fedex_data["TrackPackagesResponse"]["packageList"][0]["isInvalid"]
):
return fedex_data["TrackPackagesResponse"]["packageList"][0]
# TRACKING HANDLERS ################################################### # TRACKING HANDLERS ###################################################
def handle_tnt(tracknum): def handle_tnt(tracknum):
@ -189,6 +219,17 @@ def handle_coliprive(tracknum):
return ("Colis Privé: \x02%s\x0F : \x02%s\x0F." % (tracknum, info)) return ("Colis Privé: \x02%s\x0F : \x02%s\x0F." % (tracknum, info))
def handle_fedex(tracknum):
info = get_fedex_info(tracknum)
if info:
if info["displayActDeliveryDateTime"] != "":
return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: in \x02{statusLocationCity}, {statusLocationCntryCD}\x0F, delivered on: {displayActDeliveryDateTime}.".format(**info))
elif info["statusLocationCity"] != "":
return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: estimated delivery: {displayEstDeliveryDateTime}.".format(**info))
else:
return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: in \x02{statusLocationCity}, {statusLocationCntryCD}\x0F, estimated delivery: {displayEstDeliveryDateTime}.".format(**info))
TRACKING_HANDLERS = { TRACKING_HANDLERS = {
'laposte': handle_laposte, 'laposte': handle_laposte,
'postnl': handle_postnl, 'postnl': handle_postnl,
@ -196,6 +237,7 @@ TRACKING_HANDLERS = {
'chronopost': handle_chronopost, 'chronopost': handle_chronopost,
'coliprive': handle_coliprive, 'coliprive': handle_coliprive,
'tnt': handle_tnt, 'tnt': handle_tnt,
'fedex': handle_fedex,
} }

View file

@ -1,6 +1,6 @@
# coding=utf-8 # coding=utf-8
"""The weather module""" """The weather module. Powered by Dark Sky <https://darksky.net/poweredby/>"""
import datetime import datetime
import re import re
@ -17,7 +17,7 @@ nemubotversion = 4.0
from more import Response from more import Response
URL_DSAPI = "https://api.forecast.io/forecast/%s/%%s,%%s" URL_DSAPI = "https://api.darksky.net/forecast/%s/%%s,%%s?lang=%%s&units=%%s"
def load(context): def load(context):
if not context.config or "darkskyapikey" not in context.config: if not context.config or "darkskyapikey" not in context.config:
@ -30,34 +30,19 @@ def load(context):
URL_DSAPI = URL_DSAPI % context.config["darkskyapikey"] URL_DSAPI = URL_DSAPI % context.config["darkskyapikey"]
def help_full ():
return "!weather /city/: Display the current weather in /city/."
def fahrenheit2celsius(temp):
return int((temp - 32) * 50/9)/10
def mph2kmph(speed):
return int(speed * 160.9344)/100
def inh2mmh(size):
return int(size * 254)/10
def format_wth(wth): def format_wth(wth):
return ("%s °C %s; precipitation (%s %% chance) intensity: %s mm/h; relative humidity: %s %%; wind speed: %s km/h %s°; cloud coverage: %s %%; pressure: %s hPa; ozone: %s DU" % return ("%s °C %s; precipitation (%s %% chance) intensity: %s mm/h; relative humidity: %s %%; wind speed: %s m/s %s°; cloud coverage: %s %%; pressure: %s hPa; visibility: %s km; ozone: %s DU" %
( (
fahrenheit2celsius(wth["temperature"]), wth["temperature"],
wth["summary"], wth["summary"],
int(wth["precipProbability"] * 100), int(wth["precipProbability"] * 100),
inh2mmh(wth["precipIntensity"]), wth["precipIntensity"],
int(wth["humidity"] * 100), int(wth["humidity"] * 100),
mph2kmph(wth["windSpeed"]), wth["windSpeed"],
wth["windBearing"], wth["windBearing"],
int(wth["cloudCover"] * 100), int(wth["cloudCover"] * 100),
int(wth["pressure"]), int(wth["pressure"]),
int(wth["visibility"]),
int(wth["ozone"]) int(wth["ozone"])
)) ))
@ -66,7 +51,7 @@ def format_forecast_daily(wth):
return ("%s; between %s-%s °C; precipitation (%s %% chance) intensity: maximum %s mm/h; relative humidity: %s %%; wind speed: %s km/h %s°; cloud coverage: %s %%; pressure: %s hPa; ozone: %s DU" % return ("%s; between %s-%s °C; precipitation (%s %% chance) intensity: maximum %s mm/h; relative humidity: %s %%; wind speed: %s km/h %s°; cloud coverage: %s %%; pressure: %s hPa; ozone: %s DU" %
( (
wth["summary"], wth["summary"],
fahrenheit2celsius(wth["temperatureMin"]), fahrenheit2celsius(wth["temperatureMax"]), wth["temperatureMin"], wth["temperatureMax"],
int(wth["precipProbability"] * 100), int(wth["precipProbability"] * 100),
inh2mmh(wth["precipIntensityMax"]), inh2mmh(wth["precipIntensityMax"]),
int(wth["humidity"] * 100), int(wth["humidity"] * 100),
@ -126,8 +111,8 @@ def treat_coord(msg):
raise IMException("indique-moi un nom de ville ou des coordonnées.") raise IMException("indique-moi un nom de ville ou des coordonnées.")
def get_json_weather(coords): def get_json_weather(coords, lang="en", units="auto"):
wth = web.getJSON(URL_DSAPI % (float(coords[0]), float(coords[1]))) wth = web.getJSON(URL_DSAPI % (float(coords[0]), float(coords[1]), lang, units))
# First read flags # First read flags
if wth is None or "darksky-unavailable" in wth["flags"]: if wth is None or "darksky-unavailable" in wth["flags"]:
@ -149,10 +134,16 @@ def cmd_coordinates(msg):
return Response("Les coordonnées de %s sont %s,%s" % (msg.args[0], coords["lat"], coords["long"]), channel=msg.channel) return Response("Les coordonnées de %s sont %s,%s" % (msg.args[0], coords["lat"], coords["long"]), channel=msg.channel)
@hook.command("alert") @hook.command("alert",
keywords={
"lang=LANG": "change the output language of weather sumarry; default: en",
"units=UNITS": "return weather conditions in the requested units; default: auto",
})
def cmd_alert(msg): def cmd_alert(msg):
loc, coords, specific = treat_coord(msg) loc, coords, specific = treat_coord(msg)
wth = get_json_weather(coords) wth = get_json_weather(coords,
lang=msg.kwargs["lang"] if "lang" in msg.kwargs else "en",
units=msg.kwargs["units"] if "units" in msg.kwargs else "auto")
res = Response(channel=msg.channel, nomore="No more weather alert", count=" (%d more alerts)") res = Response(channel=msg.channel, nomore="No more weather alert", count=" (%d more alerts)")
@ -166,10 +157,20 @@ def cmd_alert(msg):
return res return res
@hook.command("météo") @hook.command("météo",
help="Display current weather and previsions",
help_usage={
"CITY": "Display the current weather and previsions in CITY",
},
keywords={
"lang=LANG": "change the output language of weather sumarry; default: en",
"units=UNITS": "return weather conditions in the requested units; default: auto",
})
def cmd_weather(msg): def cmd_weather(msg):
loc, coords, specific = treat_coord(msg) loc, coords, specific = treat_coord(msg)
wth = get_json_weather(coords) wth = get_json_weather(coords,
lang=msg.kwargs["lang"] if "lang" in msg.kwargs else "en",
units=msg.kwargs["units"] if "units" in msg.kwargs else "auto")
res = Response(channel=msg.channel, nomore="No more weather information") res = Response(channel=msg.channel, nomore="No more weather information")
@ -243,3 +244,7 @@ def parseask(msg):
context.save() context.save()
return Response("ok, j'ai bien noté les coordonnées de %s" % res.group("city"), return Response("ok, j'ai bien noté les coordonnées de %s" % res.group("city"),
msg.channel, msg.nick) msg.channel, msg.nick)
if __name__ == "__main__":
sys.exit(main())

View file

@ -17,9 +17,9 @@
__version__ = '4.0.dev3' __version__ = '4.0.dev3'
__author__ = 'nemunaire' __author__ = 'nemunaire'
from nemubot.modulecontext import ModuleContext from nemubot.modulecontext import _ModuleContext
context = ModuleContext(None, None) context = _ModuleContext()
def requires_version(min=None, max=None): def requires_version(min=None, max=None):
@ -53,41 +53,50 @@ def attach(pid, socketfile):
sys.stderr.write("\n") sys.stderr.write("\n")
return 1 return 1
from select import select import select
mypoll = select.poll()
mypoll.register(sys.stdin.fileno(), select.POLLIN | select.POLLPRI)
mypoll.register(sock.fileno(), select.POLLIN | select.POLLPRI)
try: try:
while True: while True:
rl, wl, xl = select([sys.stdin, sock], [], []) for fd, flag in mypoll.poll():
if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
sock.close()
print("Connection closed.")
return 1
if sys.stdin in rl: if fd == sys.stdin.fileno():
line = sys.stdin.readline().strip() line = sys.stdin.readline().strip()
if line == "exit" or line == "quit": if line == "exit" or line == "quit":
return 0 return 0
elif line == "reload": elif line == "reload":
import os, signal import os, signal
os.kill(pid, signal.SIGHUP) os.kill(pid, signal.SIGHUP)
print("Reload signal sent. Please wait...") print("Reload signal sent. Please wait...")
elif line == "shutdown": elif line == "shutdown":
import os, signal import os, signal
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
print("Shutdown signal sent. Please wait...") print("Shutdown signal sent. Please wait...")
elif line == "kill": elif line == "kill":
import os, signal import os, signal
os.kill(pid, signal.SIGKILL) os.kill(pid, signal.SIGKILL)
print("Signal sent...") print("Signal sent...")
return 0 return 0
elif line == "stack" or line == "stacks": elif line == "stack" or line == "stacks":
import os, signal import os, signal
os.kill(pid, signal.SIGUSR1) os.kill(pid, signal.SIGUSR1)
print("Debug signal sent. Consult logs.") print("Debug signal sent. Consult logs.")
else: else:
sock.send(line.encode() + b'\r\n') sock.send(line.encode() + b'\r\n')
if fd == sock.fileno():
sys.stdout.write(sock.recv(2048).decode())
if sock in rl:
sys.stdout.write(sock.recv(2048).decode())
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
except: except:
@ -97,13 +106,28 @@ def attach(pid, socketfile):
return 0 return 0
def daemonize(): def daemonize(socketfile=None, autoattach=True):
"""Detach the running process to run as a daemon """Detach the running process to run as a daemon
""" """
import os import os
import sys import sys
if socketfile is not None:
try:
pid = os.fork()
if pid > 0:
if autoattach:
import time
os.waitpid(pid, 0)
time.sleep(1)
sys.exit(attach(pid, socketfile))
else:
sys.exit(0)
except OSError as err:
sys.stderr.write("Unable to fork: %s\n" % err)
sys.exit(1)
try: try:
pid = os.fork() pid = os.fork()
if pid > 0: if pid > 0:

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2017 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -37,6 +37,9 @@ def main():
default=["./modules/"], default=["./modules/"],
help="directory to use as modules store") help="directory to use as modules store")
parser.add_argument("-A", "--no-attach", action="store_true",
help="don't attach after fork")
parser.add_argument("-d", "--debug", action="store_true", parser.add_argument("-d", "--debug", action="store_true",
help="don't deamonize, keep in foreground") help="don't deamonize, keep in foreground")
@ -71,32 +74,10 @@ def main():
args.pidfile = os.path.abspath(os.path.expanduser(args.pidfile)) args.pidfile = os.path.abspath(os.path.expanduser(args.pidfile))
args.socketfile = os.path.abspath(os.path.expanduser(args.socketfile)) args.socketfile = os.path.abspath(os.path.expanduser(args.socketfile))
args.logfile = os.path.abspath(os.path.expanduser(args.logfile)) args.logfile = os.path.abspath(os.path.expanduser(args.logfile))
args.files = [ x for x in map(os.path.abspath, args.files)] args.files = [x for x in map(os.path.abspath, args.files)]
args.modules_path = [ x for x in map(os.path.abspath, args.modules_path)] args.modules_path = [x for x in map(os.path.abspath, args.modules_path)]
# Check if an instance is already launched # Setup logging interface
if args.pidfile is not None and os.path.isfile(args.pidfile):
with open(args.pidfile, "r") as f:
pid = int(f.readline())
try:
os.kill(pid, 0)
except OSError:
pass
else:
from nemubot import attach
sys.exit(attach(pid, args.socketfile))
# Daemonize
if not args.debug:
from nemubot import daemonize
daemonize()
# Store PID to pidfile
if args.pidfile is not None:
with open(args.pidfile, "w+") as f:
f.write(str(os.getpid()))
# Setup loggin interface
import logging import logging
logger = logging.getLogger("nemubot") logger = logging.getLogger("nemubot")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@ -115,6 +96,18 @@ def main():
fh.setFormatter(formatter) fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
# Check if an instance is already launched
if args.pidfile is not None and os.path.isfile(args.pidfile):
with open(args.pidfile, "r") as f:
pid = int(f.readline())
try:
os.kill(pid, 0)
except OSError:
pass
else:
from nemubot import attach
sys.exit(attach(pid, args.socketfile))
# Add modules dir paths # Add modules dir paths
modules_paths = list() modules_paths = list()
for path in args.modules_path: for path in args.modules_path:
@ -125,7 +118,7 @@ def main():
# Create bot context # Create bot context
from nemubot import datastore from nemubot import datastore
from nemubot.bot import Bot from nemubot.bot import Bot, sync_act
context = Bot(modules_paths=modules_paths, context = Bot(modules_paths=modules_paths,
data_store=datastore.XML(args.data_path), data_store=datastore.XML(args.data_path),
verbosity=args.verbose) verbosity=args.verbose)
@ -141,7 +134,7 @@ def main():
# Load requested configuration files # Load requested configuration files
for path in args.files: for path in args.files:
if os.path.isfile(path): if os.path.isfile(path):
context.sync_queue.put_nowait(["loadconf", path]) sync_act("loadconf", path)
else: else:
logger.error("%s is not a readable file", path) logger.error("%s is not a readable file", path)
@ -149,6 +142,17 @@ def main():
for module in args.module: for module in args.module:
__import__(module) __import__(module)
if args.socketfile:
from nemubot.server.socket import UnixSocketListener
context.add_server(UnixSocketListener(new_server_cb=context.add_server,
location=args.socketfile,
name="master_socket"))
# Daemonize
if not args.debug:
from nemubot import daemonize
daemonize(args.socketfile, not args.no_attach)
# Signals handling # Signals handling
def sigtermhandler(signum, frame): def sigtermhandler(signum, frame):
"""On SIGTERM and SIGINT, quit nicely""" """On SIGTERM and SIGINT, quit nicely"""
@ -165,22 +169,27 @@ def main():
# Reload configuration file # Reload configuration file
for path in args.files: for path in args.files:
if os.path.isfile(path): if os.path.isfile(path):
context.sync_queue.put_nowait(["loadconf", path]) sync_act("loadconf", path)
signal.signal(signal.SIGHUP, sighuphandler) signal.signal(signal.SIGHUP, sighuphandler)
def sigusr1handler(signum, frame): def sigusr1handler(signum, frame):
"""On SIGHUSR1, display stacktraces""" """On SIGHUSR1, display stacktraces"""
import traceback import threading, traceback
for threadId, stack in sys._current_frames().items(): for threadId, stack in sys._current_frames().items():
logger.debug("########### Thread %d:\n%s", thName = "#%d" % threadId
threadId, for th in threading.enumerate():
if th.ident == threadId:
thName = th.name
break
logger.debug("########### Thread %s:\n%s",
thName,
"".join(traceback.format_stack(stack))) "".join(traceback.format_stack(stack)))
signal.signal(signal.SIGUSR1, sigusr1handler) signal.signal(signal.SIGUSR1, sigusr1handler)
if args.socketfile: # Store PID to pidfile
from nemubot.server.socket import SocketListener if args.pidfile is not None:
context.add_server(SocketListener(context.add_server, "master_socket", with open(args.pidfile, "w+") as f:
sock_location=args.socketfile)) f.write(str(os.getpid()))
# context can change when performing an hotswap, always join the latest context # context can change when performing an hotswap, always join the latest context
oldcontext = None oldcontext = None
@ -195,5 +204,6 @@ def main():
sigusr1handler(0, None) sigusr1handler(0, None)
sys.exit(0) sys.exit(0)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View file

@ -16,7 +16,9 @@
from datetime import datetime, timezone from datetime import datetime, timezone
import logging import logging
from multiprocessing import JoinableQueue
import threading import threading
import select
import sys import sys
from nemubot import __version__ from nemubot import __version__
@ -26,6 +28,11 @@ import nemubot.hooks
logger = logging.getLogger("nemubot") logger = logging.getLogger("nemubot")
sync_queue = JoinableQueue()
def sync_act(*args):
sync_queue.put(list(args))
class Bot(threading.Thread): class Bot(threading.Thread):
@ -42,7 +49,7 @@ class Bot(threading.Thread):
verbosity -- verbosity level verbosity -- verbosity level
""" """
threading.Thread.__init__(self) super().__init__(name="Nemubot main")
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)", logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
__version__, __version__,
@ -61,6 +68,7 @@ class Bot(threading.Thread):
self.datastore.open() self.datastore.open()
# Keep global context: servers and modules # Keep global context: servers and modules
self._poll = select.poll()
self.servers = dict() self.servers = dict()
self.modules = dict() self.modules = dict()
self.modules_configuration = dict() self.modules_configuration = dict()
@ -138,60 +146,84 @@ class Bot(threading.Thread):
self.cnsr_queue = Queue() self.cnsr_queue = Queue()
self.cnsr_thrd = list() self.cnsr_thrd = list()
self.cnsr_thrd_size = -1 self.cnsr_thrd_size = -1
# Synchrone actions to be treated by main thread
self.sync_queue = Queue()
def run(self): def run(self):
from select import select global sync_queue
from nemubot.server import _lock, _rlist, _wlist, _xlist
# Rewrite the sync_queue, as the daemonization process tend to disturb it
old_sync_queue, sync_queue = sync_queue, JoinableQueue()
while not old_sync_queue.empty():
sync_queue.put_nowait(old_sync_queue.get())
self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI)
logger.info("Starting main loop") logger.info("Starting main loop")
self.stop = False self.stop = False
while not self.stop: while not self.stop:
with _lock: for fd, flag in self._poll.poll():
try: # Handle internal socket passing orders
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1) if fd != sync_queue._reader.fileno() and fd in self.servers:
except: srv = self.servers[fd]
logger.error("Something went wrong in select")
fnd_smth = False
# Looking for invalid server
for r in _rlist:
if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
_rlist.remove(r)
logger.error("Found invalid object in _rlist: " + str(r))
fnd_smth = True
for w in _wlist:
if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
_wlist.remove(w)
logger.error("Found invalid object in _wlist: " + str(w))
fnd_smth = True
for x in _xlist:
if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
_xlist.remove(x)
logger.error("Found invalid object in _xlist: " + str(x))
fnd_smth = True
if not fnd_smth:
logger.exception("Can't continue, sorry")
self.quit()
continue
for x in xl: if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
try:
x.exception()
except:
logger.exception("Uncatched exception on server exception")
for w in wl:
try:
w.write_select()
except:
logger.exception("Uncatched exception on server write")
for r in rl:
for i in r.read():
try: try:
self.receive_message(r, i) srv.exception(flag)
except: except:
logger.exception("Uncatched exception on server read") logger.exception("Uncatched exception on server exception")
if srv.fileno() > 0:
if flag & (select.POLLOUT):
try:
srv.async_write()
except:
logger.exception("Uncatched exception on server write")
if flag & (select.POLLIN | select.POLLPRI):
try:
for i in srv.async_read():
self.receive_message(srv, i)
except:
logger.exception("Uncatched exception on server read")
else:
del self.servers[fd]
# Always check the sync queue
while not sync_queue.empty():
args = sync_queue.get()
action = args.pop(0)
logger.debug("Executing sync_queue action %s%s", action, args)
if action == "sckt" and len(args) >= 2:
try:
if args[0] == "write":
self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI)
elif args[0] == "unwrite":
self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI)
elif args[0] == "register":
self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI)
elif args[0] == "unregister":
self._poll.unregister(int(args[1]))
except:
logger.exception("Unhandled excpetion during action:")
elif action == "exit":
self.quit()
elif action == "launch_consumer":
pass # This is treated after the loop
elif action == "loadconf":
for path in args:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
sync_queue.task_done()
# Launch new consumer threads if necessary # Launch new consumer threads if necessary
@ -202,17 +234,7 @@ class Bot(threading.Thread):
c = Consumer(self) c = Consumer(self)
self.cnsr_thrd.append(c) self.cnsr_thrd.append(c)
c.start() c.start()
sync_queue = None
while self.sync_queue.qsize() > 0:
action = self.sync_queue.get_nowait()
if action[0] == "exit":
self.quit()
elif action[0] == "loadconf":
for path in action[1:]:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
self.sync_queue.task_done()
logger.info("Ending main loop") logger.info("Ending main loop")
@ -385,7 +407,13 @@ class Bot(threading.Thread):
self.event_timer.cancel() self.event_timer.cancel()
if len(self.events): if len(self.events):
remaining = self.events[0].time_left.total_seconds() try:
remaining = self.events[0].time_left.total_seconds()
except:
logger.exception("An error occurs during event time calculation:")
self.events.pop(0)
return self._update_event_timer()
logger.debug("Update timer: next event in %d seconds", remaining) logger.debug("Update timer: next event in %d seconds", remaining)
self.event_timer = threading.Timer(remaining if remaining > 0 else 0, self._end_event_timer) self.event_timer = threading.Timer(remaining if remaining > 0 else 0, self._end_event_timer)
self.event_timer.start() self.event_timer.start()
@ -400,6 +428,7 @@ class Bot(threading.Thread):
while len(self.events) > 0 and datetime.now(timezone.utc) >= self.events[0].current: while len(self.events) > 0 and datetime.now(timezone.utc) >= self.events[0].current:
evt = self.events.pop(0) evt = self.events.pop(0)
self.cnsr_queue.put_nowait(EventConsumer(evt)) self.cnsr_queue.put_nowait(EventConsumer(evt))
sync_act("launch_consumer")
self._update_event_timer() self._update_event_timer()
@ -419,7 +448,7 @@ class Bot(threading.Thread):
self.servers[fileno] = srv self.servers[fileno] = srv
self.servers[srv.name] = srv self.servers[srv.name] = srv
if autoconnect and not hasattr(self, "noautoconnect"): if autoconnect and not hasattr(self, "noautoconnect"):
srv.open() srv.connect()
return True return True
else: else:
@ -463,7 +492,7 @@ class Bot(threading.Thread):
module.print = prnt module.print = prnt
# Create module context # Create module context
from nemubot.modulecontext import ModuleContext from nemubot.modulecontext import _ModuleContext, ModuleContext
module.__nemubot_context__ = ModuleContext(self, module) module.__nemubot_context__ = ModuleContext(self, module)
if not hasattr(module, "logger"): if not hasattr(module, "logger"):
@ -471,7 +500,7 @@ class Bot(threading.Thread):
# Replace imported context by real one # Replace imported context by real one
for attr in module.__dict__: for attr in module.__dict__:
if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext: if attr != "__nemubot_context__" and type(module.__dict__[attr]) == _ModuleContext:
module.__dict__[attr] = module.__nemubot_context__ module.__dict__[attr] = module.__nemubot_context__
# Register decorated functions # Register decorated functions
@ -532,28 +561,29 @@ class Bot(threading.Thread):
def quit(self): def quit(self):
"""Save and unload modules and disconnect servers""" """Save and unload modules and disconnect servers"""
self.datastore.close()
if self.event_timer is not None: if self.event_timer is not None:
logger.info("Stop the event timer...") logger.info("Stop the event timer...")
self.event_timer.cancel() self.event_timer.cancel()
logger.info("Save and unload all modules...")
for mod in self.modules.items():
self.unload_module(mod)
logger.info("Close all servers connection...")
for srv in [self.servers[k] for k in self.servers]:
srv.close()
logger.info("Stop consumers") logger.info("Stop consumers")
k = self.cnsr_thrd k = self.cnsr_thrd
for cnsr in k: for cnsr in k:
cnsr.stop = True cnsr.stop = True
logger.info("Save and unload all modules...") self.datastore.close()
k = list(self.modules.keys())
for mod in k:
self.unload_module(mod)
logger.info("Close all servers connection...") if self.stop is False or sync_queue is not None:
k = list(self.servers.keys()) self.stop = True
for srv in k: sync_act("end")
self.servers[srv].close() sync_queue.join()
self.stop = True
# Treatment # Treatment

View file

@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.config import get_boolean from nemubot.config import get_boolean
from nemubot.tools.xmlparser.genericnode import GenericNode from nemubot.datastore.nodes.generic import GenericNode
class Module(GenericNode): class Module(GenericNode):

View file

@ -23,8 +23,7 @@ class Abstract:
"""Initialize a new empty storage tree """Initialize a new empty storage tree
""" """
from nemubot.tools.xmlparser import module_state return None
return module_state.ModuleState("nemubotstate")
def open(self): def open(self):
return return

View file

@ -0,0 +1,18 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable

View file

@ -14,11 +14,16 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
class ListNode: from nemubot.datastore.nodes.serializable import Serializable
class ListNode(Serializable):
"""XML node representing a Python dictionnnary """XML node representing a Python dictionnnary
""" """
serializetag = "list"
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.items = list() self.items = list()
@ -27,6 +32,9 @@ class ListNode:
self.items.append(child) self.items.append(child)
return True return True
def parsedForm(self):
return self.items
def __len__(self): def __len__(self):
return len(self.items) return len(self.items)
@ -44,11 +52,21 @@ class ListNode:
return self.items.__repr__() return self.items.__repr__()
class DictNode: def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
for i in self.items:
node.children.append(ParsingNode.serialize_node(i))
return node
class DictNode(Serializable):
"""XML node representing a Python dictionnnary """XML node representing a Python dictionnnary
""" """
serializetag = "dict"
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.items = dict() self.items = dict()
self._cur = None self._cur = None
@ -56,44 +74,20 @@ class DictNode:
def startElement(self, name, attrs): def startElement(self, name, attrs):
if self._cur is None and "key" in attrs: if self._cur is None and "key" in attrs:
self._cur = (attrs["key"], "") self._cur = attrs["key"]
return True
return False return False
def characters(self, content):
if self._cur is not None:
key, cnt = self._cur
if isinstance(cnt, str):
cnt += content
self._cur = key, cnt
def endElement(self, name):
if name is None or self._cur is None:
return
key, cnt = self._cur
if isinstance(cnt, list) and len(cnt) == 1:
self.items[key] = cnt
else:
self.items[key] = cnt
self._cur = None
return True
def addChild(self, name, child): def addChild(self, name, child):
if self._cur is None: if self._cur is None:
return False return False
key, cnt = self._cur self.items[self._cur] = child
if not isinstance(cnt, list): self._cur = None
cnt = []
cnt.append(child)
self._cur = key, cnt
return True return True
def parsedForm(self):
return self.items
def __getitem__(self, item): def __getitem__(self, item):
return self.items[item] return self.items[item]
@ -106,3 +100,13 @@ class DictNode:
def __repr__(self): def __repr__(self):
return self.items.__repr__() return self.items.__repr__()
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
for k in self.items:
chld = ParsingNode.serialize_node(self.items[k])
chld.attrs["key"] = k
node.children.append(chld)
return node

View file

@ -14,6 +14,9 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes.serializable import Serializable
class ParsingNode: class ParsingNode:
"""Allow any kind of subtags, just keep parsed ones """Allow any kind of subtags, just keep parsed ones
@ -53,6 +56,47 @@ class ParsingNode:
return item in self.attrs return item in self.attrs
def serialize_node(node, **def_kwargs):
"""Serialize any node or basic data to a ParsingNode instance"""
if isinstance(node, Serializable):
node = node.serialize()
if isinstance(node, str):
from nemubot.datastore.nodes.python import StringNode
pn = StringNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, int):
from nemubot.datastore.nodes.python import IntNode
pn = IntNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, float):
from nemubot.datastore.nodes.python import FloatNode
pn = FloatNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, list):
from nemubot.datastore.nodes.basic import ListNode
pn = ListNode(**def_kwargs)
pn.items = node
return pn.serialize()
elif isinstance(node, dict):
from nemubot.datastore.nodes.basic import DictNode
pn = DictNode(**def_kwargs)
pn.items = node
return pn.serialize()
else:
assert isinstance(node, ParsingNode)
return node
class GenericNode(ParsingNode): class GenericNode(ParsingNode):
"""Consider all subtags as dictionnary """Consider all subtags as dictionnary

View file

@ -0,0 +1,91 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes.serializable import Serializable
class PythonTypeNode(Serializable):
"""XML node representing a Python simple type
"""
def __init__(self, **kwargs):
self.value = None
self._cnt = ""
def characters(self, content):
self._cnt += content
def endElement(self, name):
raise NotImplemented
def __repr__(self):
return self.value.__repr__()
def parsedForm(self):
return self.value
def serialize(self):
raise NotImplemented
class IntNode(PythonTypeNode):
serializetag = "int"
def endElement(self, name):
self.value = int(self._cnt)
return True
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
class FloatNode(PythonTypeNode):
serializetag = "float"
def endElement(self, name):
self.value = float(self._cnt)
return True
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
class StringNode(PythonTypeNode):
serializetag = "str"
def endElement(self, name):
self.value = str(self._cnt)
return True
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node

View file

@ -0,0 +1,22 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
class Serializable:
def serialize(self):
# Implementations of this function should return ParsingNode items
return NotImplemented

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -36,17 +36,24 @@ class XML(Abstract):
rotate -- auto-backup files? rotate -- auto-backup files?
""" """
self.basedir = basedir self.basedir = os.path.abspath(basedir)
self.rotate = rotate self.rotate = rotate
self.nb_save = 0 self.nb_save = 0
logger.info("Initiate XML datastore at %s, rotation %s",
self.basedir,
"enabled" if self.rotate else "disabled")
def open(self): def open(self):
"""Lock the directory""" """Lock the directory"""
if not os.path.isdir(self.basedir): if not os.path.isdir(self.basedir):
logger.debug("Datastore directory not found, creating: %s", self.basedir)
os.mkdir(self.basedir) os.mkdir(self.basedir)
lock_path = os.path.join(self.basedir, ".used_by_nemubot") lock_path = self._get_lock_file_path()
logger.debug("Locking datastore directory via %s", lock_path)
self.lock_file = open(lock_path, 'a+') self.lock_file = open(lock_path, 'a+')
ok = True ok = True
@ -64,56 +71,95 @@ class XML(Abstract):
self.lock_file.write(str(os.getpid())) self.lock_file.write(str(os.getpid()))
self.lock_file.flush() self.lock_file.flush()
logger.info("Datastore successfuly opened at %s", self.basedir)
return True return True
def close(self): def close(self):
"""Release a locked path""" """Release a locked path"""
if hasattr(self, "lock_file"): if hasattr(self, "lock_file"):
self.lock_file.close() self.lock_file.close()
lock_path = os.path.join(self.basedir, ".used_by_nemubot") lock_path = self._get_lock_file_path()
if os.path.isdir(self.basedir) and os.path.exists(lock_path): if os.path.isdir(self.basedir) and os.path.exists(lock_path):
os.unlink(lock_path) os.unlink(lock_path)
del self.lock_file del self.lock_file
logger.info("Datastore successfully closed at %s", self.basedir)
return True return True
else:
logger.warn("Datastore not open/locked or lock file not found")
return False return False
def _get_data_file_path(self, module): def _get_data_file_path(self, module):
"""Get the path to the module data file""" """Get the path to the module data file"""
return os.path.join(self.basedir, module + ".xml") return os.path.join(self.basedir, module + ".xml")
def load(self, module):
def _get_lock_file_path(self):
"""Get the path to the datastore lock file"""
return os.path.join(self.basedir, ".used_by_nemubot")
def load(self, module, extendsTags={}):
"""Load data for the given module """Load data for the given module
Argument: Argument:
module -- the module name of data to load module -- the module name of data to load
""" """
logger.debug("Trying to load data for %s%s",
module,
(" with tags: " + ", ".join(extendsTags.keys())) if len(extendsTags) else "")
data_file = self._get_data_file_path(module) data_file = self._get_data_file_path(module)
def parse(path):
from nemubot.tools.xmlparser import XMLParser
from nemubot.datastore.nodes import basic as basicNodes
from nemubot.datastore.nodes import python as pythonNodes
from nemubot.message.command import Command
from nemubot.scope import Scope
d = {
basicNodes.ListNode.serializetag: basicNodes.ListNode,
basicNodes.DictNode.serializetag: basicNodes.DictNode,
pythonNodes.IntNode.serializetag: pythonNodes.IntNode,
pythonNodes.FloatNode.serializetag: pythonNodes.FloatNode,
pythonNodes.StringNode.serializetag: pythonNodes.StringNode,
Command.serializetag: Command,
Scope.serializetag: Scope,
}
d.update(extendsTags)
p = XMLParser(d)
return p.parse_file(path)
# Try to load original file # Try to load original file
if os.path.isfile(data_file): if os.path.isfile(data_file):
from nemubot.tools.xmlparser import parse_file
try: try:
return parse_file(data_file) return parse(data_file)
except xml.parsers.expat.ExpatError: except xml.parsers.expat.ExpatError:
# Try to load from backup # Try to load from backup
for i in range(10): for i in range(10):
path = data_file + "." + str(i) path = data_file + "." + str(i)
if os.path.isfile(path): if os.path.isfile(path):
try: try:
cnt = parse_file(path) cnt = parse(path)
logger.warn("Restoring from backup: %s", path) logger.warn("Restoring data from backup: %s", path)
return cnt return cnt
except xml.parsers.expat.ExpatError: except xml.parsers.expat.ExpatError:
continue continue
# Default case: initialize a new empty datastore # Default case: initialize a new empty datastore
logger.warn("No data found in store for %s, creating new set", module)
return Abstract.load(self, module) return Abstract.load(self, module)
def _rotate(self, path): def _rotate(self, path):
"""Backup given path """Backup given path
@ -130,6 +176,25 @@ class XML(Abstract):
if os.path.isfile(src): if os.path.isfile(src):
os.rename(src, dst) os.rename(src, dst)
def _save_node(self, gen, node):
from nemubot.datastore.nodes.generic import ParsingNode
# First, get the serialized form of the node
node = ParsingNode.serialize_node(node)
assert node.tag is not None, "Undefined tag name"
gen.startElement(node.tag, {k: str(node.attrs[k]) for k in node.attrs})
gen.characters(node.content)
for child in node.children:
self._save_node(gen, child)
gen.endElement(node.tag)
def save(self, module, data): def save(self, module, data):
"""Load data for the given module """Load data for the given module
@ -139,8 +204,22 @@ class XML(Abstract):
""" """
path = self._get_data_file_path(module) path = self._get_data_file_path(module)
logger.debug("Trying to save data for module %s in %s", module, path)
if self.rotate: if self.rotate:
self._rotate(path) self._rotate(path)
return data.save(path) import tempfile
_, tmpath = tempfile.mkstemp()
with open(tmpath, "w") as f:
import xml.sax.saxutils
gen = xml.sax.saxutils.XMLGenerator(f, "utf-8")
gen.startDocument()
self._save_node(gen, data)
gen.endDocument()
# Atomic save
import shutil
shutil.move(tmpath, path)
return True

View file

@ -16,12 +16,17 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from nemubot.datastore.nodes import Serializable
class Abstract:
class Abstract(Serializable):
"""This class represents an abstract message""" """This class represents an abstract message"""
def __init__(self, server=None, date=None, to=None, to_response=None, frm=None): serializetag = "nemubotAMessage"
def __init__(self, server=None, date=None, to=None, to_response=None, frm=None, frm_owner=False):
"""Initialize an abstract message """Initialize an abstract message
Arguments: Arguments:
@ -40,7 +45,7 @@ class Abstract:
else [ to_response ]) else [ to_response ])
self.frm = frm # None allowed when it designate this bot self.frm = frm # None allowed when it designate this bot
self.frm_owner = False # Filled later, in consumer self.frm_owner = frm_owner
@property @property
@ -65,6 +70,14 @@ class Abstract:
return self.frm return self.frm
@property
def scope(self):
from nemubot.scope import Scope
return Scope(server=self.server,
channel=self.to_response[0],
nick=self.frm)
def accept(self, visitor): def accept(self, visitor):
visitor.visit(self) visitor.visit(self)
@ -78,7 +91,8 @@ class Abstract:
"date": self.date, "date": self.date,
"to": self.to, "to": self.to,
"to_response": self._to_response, "to_response": self._to_response,
"frm": self.frm "frm": self.frm,
"frm_owner": self.frm_owner,
} }
for w in without: for w in without:
@ -86,3 +100,8 @@ class Abstract:
del ret[w] del ret[w]
return ret return ret
def serialize(self):
from nemubot.datastore.nodes import ParsingNode
return ParsingNode(tag=Abstract.serializetag, **self.export_args())

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -21,6 +21,9 @@ class Command(Abstract):
"""This class represents a specialized TextMessage""" """This class represents a specialized TextMessage"""
serializetag = "nemubotCommand"
def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs): def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs):
super().__init__(*nargs, **kargs) super().__init__(*nargs, **kargs)
@ -28,17 +31,35 @@ class Command(Abstract):
self.args = args if args is not None else list() self.args = args if args is not None else list()
self.kwargs = kwargs if kwargs is not None else dict() self.kwargs = kwargs if kwargs is not None else dict()
def __str__(self):
def __repr__(self):
return self.cmd + " @" + ",@".join(self.args) return self.cmd + " @" + ",@".join(self.args)
@property
def cmds(self): def addChild(self, name, child):
# TODO: this is for legacy modules if name == "list":
return [self.cmd] + self.args self.args = child
elif name == "dict":
self.kwargs = child
else:
return False
return True
def serialize(self):
from nemubot.datastore.nodes import ParsingNode
node = ParsingNode(tag=Command.serializetag, cmd=self.cmd)
if len(self.args):
node.children.append(ParsingNode.serialize_node(self.args))
if len(self.kwargs):
node.children.append(ParsingNode.serialize_node(self.kwargs))
return node
class OwnerCommand(Command): class OwnerCommand(Command):
"""This class represents a special command incomming from the owner""" """This class represents a special command incomming from the owner"""
serializetag = "nemubotOCommand"
pass pass

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2017 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,105 +14,62 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
class ModuleContext: class _ModuleContext:
def __init__(self, context, module): def __init__(self, module=None):
"""Initialize the module context self.module = module
arguments:
context -- the bot context
module -- the module
"""
if module is not None: if module is not None:
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__ self.module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
else: else:
module_name = "" self.module_name = ""
# Load module configuration if exists
if (context is not None and
module_name in context.modules_configuration):
self.config = context.modules_configuration[module_name]
else:
from nemubot.config.module import Module
self.config = Module(module_name)
self.hooks = list() self.hooks = list()
self.events = list() self.events = list()
self.debug = context.verbosity > 0 if context is not None else False self.extendtags = dict()
self.debug = False
from nemubot.config.module import Module
self.config = Module(self.module_name)
def load_data(self):
return None
def add_hook(self, hook, *triggers):
from nemubot.hooks import Abstract as AbstractHook from nemubot.hooks import Abstract as AbstractHook
assert isinstance(hook, AbstractHook), hook
self.hooks.append((triggers, hook))
# Define some callbacks def del_hook(self, hook, *triggers):
if context is not None: from nemubot.hooks import Abstract as AbstractHook
def load_data(): assert isinstance(hook, AbstractHook), hook
return context.datastore.load(module_name) self.hooks.remove((triggers, hook))
def add_hook(hook, *triggers): def subtreat(self, msg):
assert isinstance(hook, AbstractHook), hook return None
self.hooks.append((triggers, hook))
return context.treater.hm.add_hook(hook, *triggers)
def del_hook(hook, *triggers): def add_event(self, evt, eid=None):
assert isinstance(hook, AbstractHook), hook return self.events.append((evt, eid))
self.hooks.remove((triggers, hook))
return context.treater.hm.del_hooks(*triggers, hook=hook)
def subtreat(msg): def del_event(self, evt):
yield from context.treater.treat_msg(msg) for i in self.events:
def add_event(evt, eid=None): e, eid = i
return context.add_event(evt, eid, module_src=module) if e == evt:
def del_event(evt): self.events.remove(i)
return context.del_event(evt, module_src=module) return True
return False
def send_response(server, res): def send_response(self, server, res):
if server in context.servers: self.module.logger.info("Send response: %s", res)
if res.server is not None:
return context.servers[res.server].send_response(res)
else:
return context.servers[server].send_response(res)
else:
module.logger.error("Try to send a message to the unknown server: %s", server)
return False
else: # Used when using outside of nemubot def save(self):
def load_data(): # Don't save if no data has been access
from nemubot.tools.xmlparser import module_state if hasattr(self, "_data"):
return module_state.ModuleState("nemubotstate") context.datastore.save(self.module_name, self.data)
def add_hook(hook, *triggers):
assert isinstance(hook, AbstractHook), hook
self.hooks.append((triggers, hook))
def del_hook(hook, *triggers):
assert isinstance(hook, AbstractHook), hook
self.hooks.remove((triggers, hook))
def subtreat(msg):
return None
def add_event(evt, eid=None):
return context.add_event(evt, eid, module_src=module)
def del_event(evt):
return context.del_event(evt, module_src=module)
def send_response(server, res):
module.logger.info("Send response: %s", res)
def save():
context.datastore.save(module_name, self.data)
def subparse(orig, cnt):
if orig.server in context.servers:
return context.servers[orig.server].subparse(orig, cnt)
self.load_data = load_data
self.add_hook = add_hook
self.del_hook = del_hook
self.add_event = add_event
self.del_event = del_event
self.save = save
self.send_response = send_response
self.subtreat = subtreat
self.subparse = subparse
def subparse(self, orig, cnt):
if orig.server in self.context.servers:
return self.context.servers[orig.server].subparse(orig, cnt)
@property @property
def data(self): def data(self):
@ -120,6 +77,21 @@ class ModuleContext:
self._data = self.load_data() self._data = self.load_data()
return self._data return self._data
@data.setter
def data(self, value):
assert value is not None
self._data = value
def register_tags(self, **tags):
self.extendtags.update(tags)
def unregister_tags(self, *tags):
for t in tags:
del self.extendtags[t]
def unload(self): def unload(self):
"""Perform actions for unloading the module""" """Perform actions for unloading the module"""
@ -129,7 +101,62 @@ class ModuleContext:
self.del_hook(h, *s) self.del_hook(h, *s)
# Remove registered events # Remove registered events
for e in self.events: for evt, eid, module_src in self.events:
self.del_event(e) self.del_event(evt)
self.save() self.save()
class ModuleContext(_ModuleContext):
def __init__(self, context, *args, **kwargs):
"""Initialize the module context
arguments:
context -- the bot context
module -- the module
"""
super().__init__(*args, **kwargs)
# Load module configuration if exists
if self.module_name in context.modules_configuration:
self.config = context.modules_configuration[self.module_name]
self.context = context
self.debug = context.verbosity > 0
def load_data(self):
return self.context.datastore.load(self.module_name, extendsTags=self.extendtags)
def add_hook(self, hook, *triggers):
from nemubot.hooks import Abstract as AbstractHook
assert isinstance(hook, AbstractHook), hook
self.hooks.append((triggers, hook))
return self.context.treater.hm.add_hook(hook, *triggers)
def del_hook(self, hook, *triggers):
from nemubot.hooks import Abstract as AbstractHook
assert isinstance(hook, AbstractHook), hook
self.hooks.remove((triggers, hook))
return self.context.treater.hm.del_hooks(*triggers, hook=hook)
def subtreat(self, msg):
yield from self.context.treater.treat_msg(msg)
def add_event(self, evt, eid=None):
return self.context.add_event(evt, eid, module_src=self.module)
def del_event(self, evt):
return self.context.del_event(evt, module_src=self.module)
def send_response(self, server, res):
if server in self.context.servers:
if res.server is not None:
return self.context.servers[res.server].send_response(res)
else:
return self.context.servers[server].send_response(res)
else:
self.module.logger.error("Try to send a message to the unknown server: %s", server)
return False

83
nemubot/scope.py Normal file
View file

@ -0,0 +1,83 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes import Serializable
class Scope(Serializable):
serializetag = "nemubot-scope"
default_limit = "channel"
def __init__(self, server, channel, nick, limit=default_limit):
self._server = server
self._channel = channel
self._nick = nick
self._limit = limit
def sameServer(self, server):
return self._server is None or self._server == server
def sameChannel(self, server, channel):
return self.sameServer(server) and (self._channel is None or self._channel == channel)
def sameNick(self, server, channel, nick):
return self.sameChannel(server, channel) and (self._nick is None or self._nick == nick)
def check(self, scope, limit=None):
return self.checkScope(scope._server, scope._channel, scope._nick, limit)
def checkScope(self, server, channel, nick, limit=None):
if limit is None: limit = self._limit
assert limit == "global" or limit == "server" or limit == "channel" or limit == "nick"
if limit == "server":
return self.sameServer(server)
elif limit == "channel":
return self.sameChannel(server, channel)
elif limit == "nick":
return self.sameNick(server, channel, nick)
else:
return True
def narrow(self, scope):
return scope is None or (
scope._limit == "global" or
(scope._limit == "server" and (self._limit == "nick" or self._limit == "channel")) or
(scope._limit == "channel" and self._limit == "nick")
)
def serialize(self):
from nemubot.datastore.nodes import ParsingNode
args = {}
if self._server is not None:
args["server"] = self._server
if self._channel is not None:
args["channel"] = self._channel
if self._nick is not None:
args["nick"] = self._nick
if self._limit is not None:
args["limit"] = self._limit
return ParsingNode(tag=self.serializetag, **args)

View file

@ -31,7 +31,7 @@ PORTS = list()
class DCC(server.AbstractServer): class DCC(server.AbstractServer):
def __init__(self, srv, dest, socket=None): def __init__(self, srv, dest, socket=None):
super().__init__(self) super().__init__(name="Nemubot DCC server")
self.error = False # An error has occur, closing the connection? self.error = False # An error has occur, closing the connection?
self.messages = list() # Message queued before connexion self.messages = list() # Message queued before connexion

View file

@ -16,21 +16,22 @@
from datetime import datetime from datetime import datetime
import re import re
import socket
from nemubot.channel import Channel from nemubot.channel import Channel
from nemubot.message.printer.IRC import IRC as IRCPrinter from nemubot.message.printer.IRC import IRC as IRCPrinter
from nemubot.server.message.IRC import IRC as IRCMessage from nemubot.server.message.IRC import IRC as IRCMessage
from nemubot.server.socket import SocketServer from nemubot.server.socket import SocketServer, SecureSocketServer
class IRC(SocketServer): class _IRC:
"""Concrete implementation of a connexion to an IRC server""" """Concrete implementation of a connexion to an IRC server"""
def __init__(self, host="localhost", port=6667, ssl=False, owner=None, def __init__(self, host="localhost", port=6667, owner=None,
nick="nemubot", username=None, password=None, nick="nemubot", username=None, password=None,
realname="Nemubot", encoding="utf-8", caps=None, realname="Nemubot", encoding="utf-8", caps=None,
channels=list(), on_connect=None): channels=list(), on_connect=None, **kwargs):
"""Prepare a connection with an IRC server """Prepare a connection with an IRC server
Keyword arguments: Keyword arguments:
@ -54,7 +55,8 @@ class IRC(SocketServer):
self.owner = owner self.owner = owner
self.realname = realname self.realname = realname
super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port)) super().__init__(name=self.username + "@" + host + ":" + str(port),
host=host, port=port, **kwargs)
self.printer = IRCPrinter self.printer = IRCPrinter
self.encoding = encoding self.encoding = encoding
@ -231,20 +233,19 @@ class IRC(SocketServer):
# Open/close # Open/close
def open(self): def connect(self):
if super().open(): super().connect()
if self.password is not None:
self.write("PASS :" + self.password) if self.password is not None:
if self.capabilities is not None: self.write("PASS :" + self.password)
self.write("CAP LS") if self.capabilities is not None:
self.write("NICK :" + self.nick) self.write("CAP LS")
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) self.write("NICK :" + self.nick)
return True self.write("USER %s %s bla :%s" % (self.username, socket.getfqdn(), self.realname))
return False
def close(self): def close(self):
if not self.closed: if not self._closed:
self.write("QUIT") self.write("QUIT")
return super().close() return super().close()
@ -253,8 +254,8 @@ class IRC(SocketServer):
# Read # Read
def read(self): def async_read(self):
for line in super().read(): for line in super().async_read():
# PING should be handled here, so start parsing here :/ # PING should be handled here, so start parsing here :/
msg = IRCMessage(line, self.encoding) msg = IRCMessage(line, self.encoding)
@ -273,3 +274,10 @@ class IRC(SocketServer):
def subparse(self, orig, cnt): def subparse(self, orig, cnt):
msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding) msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding)
return msg.to_bot_message(self) return msg.to_bot_message(self)
class IRC(_IRC, SocketServer):
pass
class IRC_secure(_IRC, SecureSocketServer):
pass

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,57 +14,64 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
_lock = threading.Lock() def factory(uri, ssl=False, **init_args):
from urllib.parse import urlparse, unquote, parse_qs
# Lists for select
_rlist = []
_wlist = []
_xlist = []
def factory(uri, **init_args):
from urllib.parse import urlparse, unquote
o = urlparse(uri) o = urlparse(uri)
srv = None
if o.scheme == "irc" or o.scheme == "ircs": if o.scheme == "irc" or o.scheme == "ircs":
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt # http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html # http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
args = init_args args = init_args
modifiers = o.path.split(",") if o.scheme == "ircs": ssl = True
target = unquote(modifiers.pop(0)[1:])
if o.scheme == "ircs": args["ssl"] = True
if o.hostname is not None: args["host"] = o.hostname if o.hostname is not None: args["host"] = o.hostname
if o.port is not None: args["port"] = o.port if o.port is not None: args["port"] = o.port
if o.username is not None: args["username"] = o.username if o.username is not None: args["username"] = o.username
if o.password is not None: args["password"] = o.password if o.password is not None: args["password"] = o.password
queries = o.query.split("&") if ssl:
for q in queries: try:
if "=" in q: from ssl import create_default_context
key, val = tuple(q.split("=", 1)) args["_context"] = create_default_context()
else: except ImportError:
key, val = q, "" # Python 3.3 compat
if key == "msg": from ssl import SSLContext, PROTOCOL_TLSv1
if "on_connect" not in args: args["_context"] = SSLContext(PROTOCOL_TLSv1)
args["on_connect"] = []
args["on_connect"].append("PRIVMSG %s :%s" % (target, unquote(val)))
elif key == "key":
if "channels" not in args:
args["channels"] = []
args["channels"].append((target, unquote(val)))
elif key == "pass":
args["password"] = unquote(val)
elif key == "charset":
args["encoding"] = unquote(val)
modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:])
# Read query string
params = parse_qs(o.query)
if "msg" in params:
if "on_connect" not in args:
args["on_connect"] = []
args["on_connect"].append("PRIVMSG %s :%s" % (target, params["msg"]))
if "key" in params:
if "channels" not in args:
args["channels"] = []
args["channels"].append((target, params["key"]))
if "pass" in params:
args["password"] = params["pass"]
if "charset" in params:
args["encoding"] = params["charset"]
#
if "channels" not in args and "isnick" not in modifiers: if "channels" not in args and "isnick" not in modifiers:
args["channels"] = [ target ] args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer if ssl:
return IRCServer(**args) from nemubot.server.IRC import IRC_secure as SecureIRCServer
else: srv = SecureIRCServer(**args)
return None else:
from nemubot.server.IRC import IRC as IRCServer
srv = IRCServer(**args)
return srv

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,34 +14,30 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import io
import logging import logging
import queue import queue
from nemubot.server import _lock, _rlist, _wlist, _xlist from nemubot.bot import sync_act
# Extends from IOBase in order to be compatible with select function
class AbstractServer(io.IOBase): class AbstractServer:
"""An abstract server: handle communication with an IM server""" """An abstract server: handle communication with an IM server"""
def __init__(self, name=None, send_callback=None): def __init__(self, name=None, **kwargs):
"""Initialize an abstract server """Initialize an abstract server
Keyword argument: Keyword argument:
send_callback -- Callback when developper want to send a message name -- Identifier of the socket, for convinience
""" """
self._name = name self._name = name
super().__init__() super().__init__(**kwargs)
self.logger = logging.getLogger("nemubot.server." + self.name) self.logger = logging.getLogger("nemubot.server." + str(self.name))
self._readbuffer = b''
self._sending_queue = queue.Queue() self._sending_queue = queue.Queue()
if send_callback is not None:
self._send_callback = send_callback
else:
self._send_callback = self._write_select
@property @property
@ -54,40 +50,28 @@ class AbstractServer(io.IOBase):
# Open/close # Open/close
def __enter__(self): def connect(self, *args, **kwargs):
self.open() """Register the server in _poll"""
return self
self.logger.info("Opening connection")
super().connect(*args, **kwargs)
self._on_connect()
def _on_connect(self):
sync_act("sckt", "register", self.fileno())
def __exit__(self, type, value, traceback): def close(self, *args, **kwargs):
self.close() """Unregister the server from _poll"""
self.logger.info("Closing connection")
def open(self): if self.fileno() > 0:
"""Generic open function that register the server un _rlist in case sync_act("sckt", "unregister", self.fileno())
of successful _open"""
self.logger.info("Opening connection to %s", self.id)
if not hasattr(self, "_open") or self._open():
_rlist.append(self)
_xlist.append(self)
return True
return False
super().close(*args, **kwargs)
def close(self):
"""Generic close function that register the server un _{r,w,x}list in
case of successful _close"""
self.logger.info("Closing connection to %s", self.id)
with _lock:
if not hasattr(self, "_close") or self._close():
if self in _rlist:
_rlist.remove(self)
if self in _wlist:
_wlist.remove(self)
if self in _xlist:
_xlist.remove(self)
return True
return False
# Writes # Writes
@ -99,13 +83,16 @@ class AbstractServer(io.IOBase):
message -- message to send message -- message to send
""" """
self._send_callback(message) self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
sync_act("sckt", "write", self.fileno())
def write_select(self): def async_write(self):
"""Internal function used by the select function""" """Internal function used when the file descriptor is writable"""
try: try:
_wlist.remove(self) sync_act("sckt", "unwrite", self.fileno())
while not self._sending_queue.empty(): while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait()) self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done() self._sending_queue.task_done()
@ -114,19 +101,6 @@ class AbstractServer(io.IOBase):
pass pass
def _write_select(self, message):
"""Send a message to the server safely through select
Argument:
message -- message to send
"""
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
if self not in _wlist:
_wlist.append(self)
def send_response(self, response): def send_response(self, response):
"""Send a formated Message class """Send a formated Message class
@ -149,13 +123,39 @@ class AbstractServer(io.IOBase):
# Read # Read
def async_read(self):
"""Internal function used when the file descriptor is readable
Returns:
A list of fully received messages
"""
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
for r in ret:
yield r
def lex(self, buf):
"""Assume lexing in default case is per line
Argument:
buf -- buffer to lex
"""
msgs = buf.split(b'\r\n')
partial = msgs.pop()
return msgs, partial
def parse(self, msg): def parse(self, msg):
raise NotImplemented raise NotImplemented
# Exceptions # Exceptions
def exception(self): def exception(self, flags):
"""Exception occurs in fd""" """Exception occurs on fd"""
self.logger.warning("Unhandle file descriptor exception on server %s",
self.name) self.close()

View file

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import socket
import unittest import unittest
from nemubot.server import factory from nemubot.server import factory
@ -22,34 +23,36 @@ class TestFactory(unittest.TestCase):
def test_IRC1(self): def test_IRC1(self):
from nemubot.server.IRC import IRC as IRCServer from nemubot.server.IRC import IRC as IRCServer
from nemubot.server.IRC import IRC_secure as IRCSServer
# <host>: If omitted, the client must connect to a prespecified default IRC server. # <host>: If omitted, the client must connect to a prespecified default IRC server.
server = factory("irc:///") server = factory("irc:///")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "localhost") self.assertEqual(server._sockaddr,
self.assertFalse(server.ssl) socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("ircs:///") server = factory("ircs:///")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "localhost") self.assertEqual(server._sockaddr,
self.assertTrue(server.ssl) socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("irc://host1") server = factory("irc://freenode.net")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host1") self.assertEqual(server._sockaddr,
self.assertFalse(server.ssl) socket.getaddrinfo("freenode.net", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("irc://host2:6667") server = factory("irc://freenode.org:1234")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host2") self.assertEqual(server._sockaddr,
self.assertEqual(server.port, 6667) socket.getaddrinfo("freenode.org", 1234, proto=socket.IPPROTO_TCP)[0][4])
self.assertFalse(server.ssl)
server = factory("ircs://host3:194/") server = factory("ircs://nemunai.re:194/")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "host3") self.assertEqual(server._sockaddr,
self.assertEqual(server.port, 194) socket.getaddrinfo("nemunai.re", 194, proto=socket.IPPROTO_TCP)[0][4])
self.assertTrue(server.ssl)
with self.assertRaises(socket.gaierror):
factory("irc://_nonexistent.nemunai.re")
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -150,7 +150,8 @@ class IRC(Abstract):
"date": self.tags["time"], "date": self.tags["time"],
"to": receivers, "to": receivers,
"to_response": [r if r != srv.nick else self.nick for r in receivers], "to_response": [r if r != srv.nick else self.nick for r in receivers],
"frm": self.nick "frm": self.nick,
"frm_owner": self.nick == srv.owner
} }
# If CTCP, remove 0x01 # If CTCP, remove 0x01

View file

@ -0,0 +1,15 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,117 +14,33 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import socket
import ssl
import nemubot.message as message import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter from nemubot.message.printer.socket import Socket as SocketPrinter
from nemubot.server.abstract import AbstractServer from nemubot.server.abstract import AbstractServer
class SocketServer(AbstractServer): class _Socket(AbstractServer):
"""Concrete implementation of a socket connexion (can be wrapped with TLS)""" """Concrete implementation of a socket connection"""
def __init__(self, sock_location=None, def __init__(self, printer=SocketPrinter, **kwargs):
host=None, port=None,
sock=None,
ssl=False,
name=None):
"""Create a server socket """Create a server socket
Keyword arguments:
sock_location -- Path to the UNIX socket
host -- Hostname of the INET socket
port -- Port of the INET socket
sock -- Already connected socket
ssl -- Should TLS connection enabled
name -- Convinience name
""" """
import socket super().__init__(**kwargs)
assert(sock is None or isinstance(sock, socket.SocketType))
assert(port is None or isinstance(port, int))
super().__init__(name=name)
if sock is None:
if sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.connect_to = sock_location
elif host is not None:
for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
self.socket = socket.socket(af, socktype, proto)
self.connect_to = sa
break
else:
self.socket = sock
self.ssl = ssl
self.readbuffer = b'' self.readbuffer = b''
self.printer = SocketPrinter self.printer = printer
def fileno(self):
return self.socket.fileno() if self.socket else None
@property
def closed(self):
"""Indicator of the connection aliveness"""
return self.socket._closed
# Open/close
def open(self):
if not self.closed:
return True
try:
self.socket.connect(self.connect_to)
self.logger.info("Connected to %s", self.connect_to)
except:
self.socket.close()
self.logger.exception("Unable to connect to %s",
self.connect_to)
return False
# Wrap the socket for SSL
if self.ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.socket = ctx.wrap_socket(self.socket)
return super().open()
def close(self):
import socket
# Flush the sending queue before close
from nemubot.server import _lock
_lock.release()
self._sending_queue.join()
_lock.acquire()
if not self.closed:
try:
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.socket.close()
return super().close()
# Write # Write
def _write(self, cnt): def _write(self, cnt):
if self.closed: self.sendall(cnt)
return
self.socket.sendall(cnt)
def format(self, txt): def format(self, txt):
@ -136,19 +52,12 @@ class SocketServer(AbstractServer):
# Read # Read
def read(self): def recv(self, n=1024):
if self.closed: return super().recv(n)
return []
raw = self.socket.recv(1024)
temp = (self.readbuffer + raw).split(b'\r\n')
self.readbuffer = temp.pop()
for line in temp:
yield line
def parse(self, line): def parse(self, line):
"""Implement a default behaviour for socket"""
import shlex import shlex
line = line.strip().decode() line = line.strip().decode()
@ -157,48 +66,107 @@ class SocketServer(AbstractServer):
except ValueError: except ValueError:
args = line.split(' ') args = line.split(' ')
yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you") if len(args):
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
class SocketListener(AbstractServer): def subparse(self, orig, cnt):
for m in self.parse(cnt):
def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None): m.to = orig.to
super().__init__(name=name) m.frm = orig.frm
self.new_server_cb = new_server_cb m.date = orig.date
self.sock_location = sock_location yield m
self.host = host
self.port = port
self.ssl = ssl
self.nb_son = 0
def fileno(self): class _SocketServer(_Socket):
return self.socket.fileno() if self.socket else None
def __init__(self, host, port, bind=None, **kwargs):
(family, type, proto, canonname, sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0]
if isinstance(self, ssl.SSLSocket) and "server_hostname" not in kwargs:
kwargs["server_hostname"] = host
super().__init__(family=family, type=type, proto=proto, **kwargs)
self._sockaddr = sockaddr
self._bind = bind
@property def connect(self):
def closed(self): self.logger.info("Connection to %s:%d", *self._sockaddr[:2])
"""Indicator of the connection aliveness""" super().connect(self._sockaddr)
return self.socket is None
if self._bind:
super().bind(self._bind)
def open(self): class SocketServer(_SocketServer, socket.socket):
import os pass
import socket
if self.sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
os.remove(self.sock_location)
except FileNotFoundError:
pass
self.socket.bind(self.sock_location)
elif self.host is not None and self.port is not None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.bind((self.host, self.port))
self.socket.listen(5)
return super().open() class SecureSocketServer(_SocketServer, ssl.SSLSocket):
pass
class UnixSocket:
def __init__(self, location, **kwargs):
super().__init__(family=socket.AF_UNIX, **kwargs)
self._socket_path = location
def connect(self):
self.logger.info("Connection to unix://%s", self._socket_path)
super().connect(self._socket_path)
class SocketClient(_Socket, socket.socket):
def read(self):
return self.recv()
class _Listener:
def __init__(self, new_server_cb, instanciate=SocketClient, **kwargs):
super().__init__(**kwargs)
self._instanciate = instanciate
self._new_server_cb = new_server_cb
def read(self):
conn, addr = self.accept()
fileno = conn.fileno()
self.logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
ss.connect = ss._on_connect
self._new_server_cb(ss, autoconnect=True)
return b''
class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def connect(self):
self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
try:
os.remove(self._socket_path)
except FileNotFoundError:
pass
self.bind(self._socket_path)
self.listen(5)
self.logger.info("Socket ready for accepting new connections")
self._on_connect()
def close(self): def close(self):
@ -206,25 +174,14 @@ class SocketListener(AbstractServer):
import socket import socket
try: try:
self.socket.shutdown(socket.SHUT_RDWR) self.shutdown(socket.SHUT_RDWR)
self.socket.close()
if self.sock_location is not None:
os.remove(self.sock_location)
except socket.error: except socket.error:
pass pass
return super().close() super().close()
try:
# Read if self._socket_path is not None:
os.remove(self._socket_path)
def read(self): except:
if self.closed: pass
return []
conn, addr = self.socket.accept()
self.nb_son += 1
ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
self.new_server_cb(ss)
return []

View file

@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from urllib.parse import urlparse, urlsplit, urlunsplit from urllib.parse import urljoin, urlparse, urlsplit, urlunsplit
from nemubot.exception import IMException from nemubot.exception import IMException
@ -108,6 +108,9 @@ def getURLContent(url, body=None, timeout=7, header=None):
elif "User-agent" not in header: elif "User-agent" not in header:
header["User-agent"] = "Nemubot v%s" % __version__ header["User-agent"] = "Nemubot v%s" % __version__
if body is not None and "Content-Type" not in header:
header["Content-Type"] = "application/x-www-form-urlencoded"
import socket import socket
try: try:
if o.query != '': if o.query != '':
@ -156,21 +159,23 @@ def getURLContent(url, body=None, timeout=7, header=None):
elif ((res.status == http.client.FOUND or elif ((res.status == http.client.FOUND or
res.status == http.client.MOVED_PERMANENTLY) and res.status == http.client.MOVED_PERMANENTLY) and
res.getheader("Location") != url): res.getheader("Location") != url):
return getURLContent(res.getheader("Location"), timeout=timeout) return getURLContent(
urljoin(url, res.getheader("Location")),
body=body,
timeout=timeout,
header=header)
else: else:
raise IMException("A HTTP error occurs: %d - %s" % raise IMException("A HTTP error occurs: %d - %s" %
(res.status, http.client.responses[res.status])) (res.status, http.client.responses[res.status]))
def getXML(url, timeout=7): def getXML(*args, **kwargs):
"""Get content page and return XML parsed content """Get content page and return XML parsed content
Arguments: Arguments: same as getURLContent
url -- the URL to get
timeout -- maximum number of seconds to wait before returning an exception
""" """
cnt = getURLContent(url, timeout=timeout) cnt = getURLContent(*args, **kwargs)
if cnt is None: if cnt is None:
return None return None
else: else:
@ -178,15 +183,13 @@ def getXML(url, timeout=7):
return parseString(cnt) return parseString(cnt)
def getJSON(url, timeout=7): def getJSON(*args, **kwargs):
"""Get content page and return JSON content """Get content page and return JSON content
Arguments: Arguments: same as getURLContent
url -- the URL to get
timeout -- maximum number of seconds to wait before returning an exception
""" """
cnt = getURLContent(url, timeout=timeout) cnt = getURLContent(*args, **kwargs)
if cnt is None: if cnt is None:
return None return None
else: else:

View file

@ -51,11 +51,13 @@ class XMLParser:
def __init__(self, knodes): def __init__(self, knodes):
self.knodes = knodes self.knodes = knodes
def _reset(self):
self.stack = list() self.stack = list()
self.child = 0 self.child = 0
def parse_file(self, path): def parse_file(self, path):
self._reset()
p = xml.parsers.expat.ParserCreate() p = xml.parsers.expat.ParserCreate()
p.StartElementHandler = self.startElement p.StartElementHandler = self.startElement
@ -69,6 +71,7 @@ class XMLParser:
def parse_string(self, s): def parse_string(self, s):
self._reset()
p = xml.parsers.expat.ParserCreate() p = xml.parsers.expat.ParserCreate()
p.StartElementHandler = self.startElement p.StartElementHandler = self.startElement
@ -126,10 +129,13 @@ class XMLParser:
if hasattr(self.current, "endElement"): if hasattr(self.current, "endElement"):
self.current.endElement(None) self.current.endElement(None)
if hasattr(self.current, "parsedForm") and callable(self.current.parsedForm):
self.stack[-1] = self.current.parsedForm()
# Don't remove root # Don't remove root
if len(self.stack) > 1: if len(self.stack) > 1:
last = self.stack.pop() last = self.stack.pop()
if hasattr(self.current, "addChild"): if hasattr(self.current, "addChild") and callable(self.current.addChild):
if self.current.addChild(name, last): if self.current.addChild(name, last):
return return
raise TypeError(name + " tag not expected in " + self.display_stack()) raise TypeError(name + " tag not expected in " + self.display_stack())

View file

@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging import logging
import types
logger = logging.getLogger("nemubot.treatment") logger = logging.getLogger("nemubot.treatment")
@ -108,6 +109,9 @@ class MessageTreater:
msg -- message to treat msg -- message to treat
""" """
if hasattr(msg, "frm_owner"):
msg.frm_owner = (not hasattr(msg.server, "owner") or msg.server.owner == msg.frm)
while hook is not None: while hook is not None:
res = hook.run(msg) res = hook.run(msg)
@ -116,10 +120,18 @@ class MessageTreater:
yield r yield r
elif res is not None: elif res is not None:
if not hasattr(res, "server") or res.server is None: if isinstance(res, types.GeneratorType):
res.server = msg.server for r in res:
if not hasattr(r, "server") or r.server is None:
r.server = msg.server
yield res yield r
else:
if not hasattr(res, "server") or res.server is None:
res.server = msg.server
yield res
hook = next(hook_gen, None) hook = next(hook_gen, None)

View file

@ -63,6 +63,7 @@ setup(
'nemubot', 'nemubot',
'nemubot.config', 'nemubot.config',
'nemubot.datastore', 'nemubot.datastore',
'nemubot.datastore.nodes',
'nemubot.event', 'nemubot.event',
'nemubot.exception', 'nemubot.exception',
'nemubot.hooks', 'nemubot.hooks',