Compare commits

...

30 commits

Author SHA1 Message Date
2ac37f79d9 New class storing scope 2017-07-18 00:50:30 +02:00
e3150a6061 Messages now implements Serializable 2017-07-18 00:50:30 +02:00
e0af09f3c5 New printer and parser for bot data, XML-based 2017-07-18 00:50:30 +02:00
2992c13ca7 socket: limit getaddrinfo to TCP connections 2017-07-18 00:48:30 +02:00
49c207a4c9 events: fix help when no event is defined 2017-07-18 00:48:30 +02:00
3c13043ca3 run: recreate the sync_queue on run, it seems to have strange behaviour when created before the fork 2017-07-18 00:48:30 +02:00
5b28828ede event: ensure that enough consumers are launched at the end of an event 2017-07-18 00:48:30 +02:00
2a25f5311a rename module nextstop: ratp to avoid import loop with the inderlying Python module 2017-07-18 00:48:30 +02:00
1db63600e5 main: new option -A to run as daemon 2017-07-18 00:48:30 +02:00
5f89428562 Use getaddrinfo to create the right socket 2017-07-18 00:48:30 +02:00
205a39ad70 Try to restaure frm_owner flag 2017-07-18 00:42:11 +02:00
c51b0a9170 When launched in daemon mode, attach to the socket 2017-07-18 00:42:11 +02:00
eb70fe560b Deamonize later 2017-07-18 00:42:11 +02:00
8b6f72587d Local client now detects when server close the connection 2017-07-18 00:42:11 +02:00
02838658b0 Fix communication over unix socket 2017-07-18 00:42:11 +02:00
1a813083e5 Handle multiple SIGTERM 2017-07-18 00:42:11 +02:00
3fbeb49a6c suivi: add fedex 2017-07-18 00:42:11 +02:00
f5f13202c5 suivi: use getURLContent instead of call to urllib 2017-07-18 00:42:11 +02:00
aeba947877 tools/web: fill a default Content-Type in case of POST 2017-07-18 00:42:11 +02:00
b27e01a196 tools/web: improve redirection reliability 2017-07-18 00:42:11 +02:00
4d68410777 tools/web: forward all arguments passed to getJSON and getXML to getURLContent 2017-07-18 00:42:11 +02:00
384fbc6717 Update weather module: refleting forcastAPI changes 2017-07-18 00:42:11 +02:00
63e65b2659 modulecontext: use inheritance instead of conditional init 2017-07-18 00:42:11 +02:00
bdf8a69ff0 Avoid stack-trace and DOS if event is not well formed 2017-07-18 00:42:11 +02:00
4f1dcb8524 [nextstop] Use as system wide module 2017-07-18 00:42:11 +02:00
6d2f90fe77 Allow module function to be generators 2017-07-18 00:42:10 +02:00
2e5834a89d Parse server urls using parse_qs 2017-07-18 00:42:10 +02:00
26d1f5b6e8 Format and typo 2017-07-18 00:42:10 +02:00
40fc84fcec Implement socket server subparse 2017-07-18 00:42:10 +02:00
449cb684f9 Refactor file/socket management (use poll instead of select) 2017-07-18 00:42:08 +02:00
36 changed files with 1244 additions and 700 deletions

3
.gitmodules vendored
View file

@ -1,3 +0,0 @@
[submodule "modules/nextstop/external"]
path = modules/nextstop/external
url = git://github.com/nbr23/NextStop.git

View file

@ -9,6 +9,8 @@ Requirements
*nemubot* requires at least Python 3.3 to work.
Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629).
Some modules (like `cve`, `nextstop` or `laposte`) require the
[BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/),
but the core and framework has no dependency.

View file

@ -16,7 +16,7 @@ from more import Response
def help_full ():
return "This module store a lot of events: ny, we, " + (", ".join(context.datas.index.keys())) + "\n!eventslist: gets list of timer\n!start /something/: launch a timer"
return "This module store a lot of events: ny, we, " + (", ".join(context.datas.index.keys() if hasattr(context, "datas") else [])) + "\n!eventslist: gets list of timer\n!start /something/: launch a timer"
def load(context):

View file

@ -1,4 +0,0 @@
<?xml version="1.0" ?>
<nemubotmodule name="nextstop">
<message type="cmd" name="ratp" call="ask_ratp" />
</nemubotmodule>

View file

@ -1,55 +0,0 @@
# coding=utf-8
"""Informe les usagers des prochains passages des transports en communs de la RATP"""
from nemubot.exception import IMException
from nemubot.hooks import hook
from more import Response
nemubotversion = 3.4
from .external.src import ratp
def help_full ():
return "!ratp transport line [station]: Donne des informations sur les prochains passages du transport en commun séléctionné à l'arrêt désiré. Si aucune station n'est précisée, les liste toutes."
@hook.command("ratp")
def ask_ratp(msg):
"""Hook entry from !ratp"""
if len(msg.args) >= 3:
transport = msg.args[0]
line = msg.args[1]
station = msg.args[2]
if len(msg.args) == 4:
times = ratp.getNextStopsAtStation(transport, line, station, msg.args[3])
else:
times = ratp.getNextStopsAtStation(transport, line, station)
if len(times) == 0:
raise IMException("la station %s n'existe pas sur le %s ligne %s." % (station, transport, line))
(time, direction, stationname) = times[0]
return Response(message=["\x03\x02%s\x03\x02 direction %s" % (time, direction) for time, direction, stationname in times],
title="Prochains passages du %s ligne %s à l'arrêt %s" % (transport, line, stationname),
channel=msg.channel)
elif len(msg.args) == 2:
stations = ratp.getAllStations(msg.args[0], msg.args[1])
if len(stations) == 0:
raise IMException("aucune station trouvée.")
return Response([s for s in stations], title="Stations", channel=msg.channel)
else:
raise IMException("Mauvais usage, merci de spécifier un type de transport et une ligne, ou de consulter l'aide du module.")
@hook.command("ratp_alert")
def ratp_alert(msg):
if len(msg.args) == 2:
transport = msg.args[0]
cause = msg.args[1]
incidents = ratp.getDisturbance(cause, transport)
return Response(incidents, channel=msg.channel, nomore="No more incidents", count=" (%d more incidents)")
else:
raise IMException("Mauvais usage, merci de spécifier un type de transport et un type d'alerte (alerte, manif, travaux), ou de consulter l'aide du module.")

@ -1 +0,0 @@
Subproject commit 3d5c9b2d52fbd214f5aaad00e5f3952de919b3e5

74
modules/ratp.py Normal file
View file

@ -0,0 +1,74 @@
"""Informe les usagers des prochains passages des transports en communs de la RATP"""
# PYTHON STUFFS #######################################################
from nemubot.exception import IMException
from nemubot.hooks import hook
from more import Response
from nextstop import ratp
@hook.command("ratp",
help="Affiche les prochains horaires de passage",
help_usage={
"TRANSPORT": "Affiche les lignes du moyen de transport donné",
"TRANSPORT LINE": "Affiche les stations sur la ligne de transport donnée",
"TRANSPORT LINE STATION": "Affiche les prochains horaires de passage à l'arrêt donné",
"TRANSPORT LINE STATION DESTINATION": "Affiche les prochains horaires de passage dans la direction donnée",
})
def ask_ratp(msg):
l = len(msg.args)
transport = msg.args[0] if l > 0 else None
line = msg.args[1] if l > 1 else None
station = msg.args[2] if l > 2 else None
direction = msg.args[3] if l > 3 else None
if station is not None:
times = sorted(ratp.getNextStopsAtStation(transport, line, station, direction), key=lambda i: i[0])
if len(times) == 0:
raise IMException("la station %s n'existe pas sur le %s ligne %s." % (station, transport, line))
(time, direction, stationname) = times[0]
return Response(message=["\x03\x02%s\x03\x02 direction %s" % (time, direction) for time, direction, stationname in times],
title="Prochains passages du %s ligne %s à l'arrêt %s" % (transport, line, stationname),
channel=msg.channel)
elif line is not None:
stations = ratp.getAllStations(transport, line)
if len(stations) == 0:
raise IMException("aucune station trouvée.")
return Response(stations, title="Stations", channel=msg.channel)
elif transport is not None:
lines = ratp.getTransportLines(transport)
if len(lines) == 0:
raise IMException("aucune ligne trouvée.")
return Response(lines, title="Lignes", channel=msg.channel)
else:
raise IMException("précise au moins un moyen de transport.")
@hook.command("ratp_alert",
help="Affiche les perturbations en cours sur le réseau")
def ratp_alert(msg):
if len(msg.args) == 0:
raise IMException("précise au moins un moyen de transport.")
l = len(msg.args)
transport = msg.args[0] if l > 0 else None
line = msg.args[1] if l > 1 else None
if line is not None:
d = ratp.getDisturbanceFromLine(transport, line)
if "date" in d and d["date"] is not None:
incidents = "Au {date[date]}, {title}: {message}".format(**d)
else:
incidents = "{title}: {message}".format(**d)
else:
incidents = ratp.getDisturbance(None, transport)
return Response(incidents, channel=msg.channel, nomore="No more incidents", count=" (%d more incidents)")

View file

@ -2,14 +2,14 @@
# PYTHON STUFF ############################################
import urllib.request
import json
import urllib.parse
from bs4 import BeautifulSoup
import re
from nemubot.hooks import hook
from nemubot.exception import IMException
from nemubot.tools.web import getURLContent
from nemubot.tools.web import getURLContent, getJSON
from more import Response
@ -17,8 +17,7 @@ from more import Response
def get_tnt_info(track_id):
values = []
data = getURLContent('www.tnt.fr/public/suivi_colis/recherche/'
'visubontransport.do?bonTransport=%s' % track_id)
data = getURLContent('www.tnt.fr/public/suivi_colis/recherche/visubontransport.do?bonTransport=%s' % track_id)
soup = BeautifulSoup(data)
status_list = soup.find('div', class_='result__content')
if not status_list:
@ -32,8 +31,7 @@ def get_tnt_info(track_id):
def get_colissimo_info(colissimo_id):
colissimo_data = getURLContent("http://www.colissimo.fr/portail_colissimo/"
"suivre.do?colispart=%s" % colissimo_id)
colissimo_data = getURLContent("http://www.colissimo.fr/portail_colissimo/suivre.do?colispart=%s" % colissimo_id)
soup = BeautifulSoup(colissimo_data)
dataArray = soup.find(class_='dataArray')
@ -47,9 +45,8 @@ def get_colissimo_info(colissimo_id):
def get_chronopost_info(track_id):
data = urllib.parse.urlencode({'listeNumeros': track_id})
track_baseurl = "http://www.chronopost.fr/expedier/" \
"inputLTNumbersNoJahia.do?lang=fr_FR"
track_data = urllib.request.urlopen(track_baseurl, data.encode('utf-8'))
track_baseurl = "http://www.chronopost.fr/expedier/inputLTNumbersNoJahia.do?lang=fr_FR"
track_data = getURLContent(track_baseurl, data.encode('utf-8'))
soup = BeautifulSoup(track_data)
infoClass = soup.find(class_='numeroColi2')
@ -65,9 +62,8 @@ def get_chronopost_info(track_id):
def get_colisprive_info(track_id):
data = urllib.parse.urlencode({'numColis': track_id})
track_baseurl = "https://www.colisprive.com/moncolis/pages/" \
"detailColis.aspx"
track_data = urllib.request.urlopen(track_baseurl, data.encode('utf-8'))
track_baseurl = "https://www.colisprive.com/moncolis/pages/detailColis.aspx"
track_data = getURLContent(track_baseurl, data.encode('utf-8'))
soup = BeautifulSoup(track_data)
dataArray = soup.find(class_='BandeauInfoColis')
@ -82,8 +78,7 @@ def get_laposte_info(laposte_id):
data = urllib.parse.urlencode({'id': laposte_id})
laposte_baseurl = "http://www.part.csuivi.courrier.laposte.fr/suivi/index"
laposte_data = urllib.request.urlopen(laposte_baseurl,
data.encode('utf-8'))
laposte_data = getURLContent(laposte_baseurl, data.encode('utf-8'))
soup = BeautifulSoup(laposte_data)
search_res = soup.find(class_='resultat_rech_simple_table').tbody.tr
if (soup.find(class_='resultat_rech_simple_table').thead
@ -112,8 +107,7 @@ def get_postnl_info(postnl_id):
data = urllib.parse.urlencode({'barcodes': postnl_id})
postnl_baseurl = "http://www.postnl.post/details/"
postnl_data = urllib.request.urlopen(postnl_baseurl,
data.encode('utf-8'))
postnl_data = getURLContent(postnl_baseurl, data.encode('utf-8'))
soup = BeautifulSoup(postnl_data)
if (soup.find(id='datatables')
and soup.find(id='datatables').tbody
@ -132,6 +126,42 @@ def get_postnl_info(postnl_id):
return (post_status.lower(), post_destination, post_date)
def get_fedex_info(fedex_id, lang="en_US"):
data = urllib.parse.urlencode({
'data': json.dumps({
"TrackPackagesRequest": {
"appType": "WTRK",
"appDeviceType": "DESKTOP",
"uniqueKey": "",
"processingParameters": {},
"trackingInfoList": [
{
"trackNumberInfo": {
"trackingNumber": str(fedex_id),
"trackingQualifier": "",
"trackingCarrier": ""
}
}
]
}
}),
'action': "trackpackages",
'locale': lang,
'version': 1,
'format': "json"
})
fedex_baseurl = "https://www.fedex.com/trackingCal/track"
fedex_data = getJSON(fedex_baseurl, data.encode('utf-8'))
if ("TrackPackagesResponse" in fedex_data and
"packageList" in fedex_data["TrackPackagesResponse"] and
len(fedex_data["TrackPackagesResponse"]["packageList"]) and
not fedex_data["TrackPackagesResponse"]["packageList"][0]["isInvalid"]
):
return fedex_data["TrackPackagesResponse"]["packageList"][0]
# TRACKING HANDLERS ###################################################
def handle_tnt(tracknum):
@ -189,6 +219,17 @@ def handle_coliprive(tracknum):
return ("Colis Privé: \x02%s\x0F : \x02%s\x0F." % (tracknum, info))
def handle_fedex(tracknum):
info = get_fedex_info(tracknum)
if info:
if info["displayActDeliveryDateTime"] != "":
return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: in \x02{statusLocationCity}, {statusLocationCntryCD}\x0F, delivered on: {displayActDeliveryDateTime}.".format(**info))
elif info["statusLocationCity"] != "":
return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: estimated delivery: {displayEstDeliveryDateTime}.".format(**info))
else:
return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: in \x02{statusLocationCity}, {statusLocationCntryCD}\x0F, estimated delivery: {displayEstDeliveryDateTime}.".format(**info))
TRACKING_HANDLERS = {
'laposte': handle_laposte,
'postnl': handle_postnl,
@ -196,6 +237,7 @@ TRACKING_HANDLERS = {
'chronopost': handle_chronopost,
'coliprive': handle_coliprive,
'tnt': handle_tnt,
'fedex': handle_fedex,
}

View file

@ -1,6 +1,6 @@
# coding=utf-8
"""The weather module"""
"""The weather module. Powered by Dark Sky <https://darksky.net/poweredby/>"""
import datetime
import re
@ -17,7 +17,7 @@ nemubotversion = 4.0
from more import Response
URL_DSAPI = "https://api.forecast.io/forecast/%s/%%s,%%s"
URL_DSAPI = "https://api.darksky.net/forecast/%s/%%s,%%s?lang=%%s&units=%%s"
def load(context):
if not context.config or "darkskyapikey" not in context.config:
@ -30,34 +30,19 @@ def load(context):
URL_DSAPI = URL_DSAPI % context.config["darkskyapikey"]
def help_full ():
return "!weather /city/: Display the current weather in /city/."
def fahrenheit2celsius(temp):
return int((temp - 32) * 50/9)/10
def mph2kmph(speed):
return int(speed * 160.9344)/100
def inh2mmh(size):
return int(size * 254)/10
def format_wth(wth):
return ("%s °C %s; precipitation (%s %% chance) intensity: %s mm/h; relative humidity: %s %%; wind speed: %s km/h %s°; cloud coverage: %s %%; pressure: %s hPa; ozone: %s DU" %
return ("%s °C %s; precipitation (%s %% chance) intensity: %s mm/h; relative humidity: %s %%; wind speed: %s m/s %s°; cloud coverage: %s %%; pressure: %s hPa; visibility: %s km; ozone: %s DU" %
(
fahrenheit2celsius(wth["temperature"]),
wth["temperature"],
wth["summary"],
int(wth["precipProbability"] * 100),
inh2mmh(wth["precipIntensity"]),
wth["precipIntensity"],
int(wth["humidity"] * 100),
mph2kmph(wth["windSpeed"]),
wth["windSpeed"],
wth["windBearing"],
int(wth["cloudCover"] * 100),
int(wth["pressure"]),
int(wth["visibility"]),
int(wth["ozone"])
))
@ -66,7 +51,7 @@ def format_forecast_daily(wth):
return ("%s; between %s-%s °C; precipitation (%s %% chance) intensity: maximum %s mm/h; relative humidity: %s %%; wind speed: %s km/h %s°; cloud coverage: %s %%; pressure: %s hPa; ozone: %s DU" %
(
wth["summary"],
fahrenheit2celsius(wth["temperatureMin"]), fahrenheit2celsius(wth["temperatureMax"]),
wth["temperatureMin"], wth["temperatureMax"],
int(wth["precipProbability"] * 100),
inh2mmh(wth["precipIntensityMax"]),
int(wth["humidity"] * 100),
@ -126,8 +111,8 @@ def treat_coord(msg):
raise IMException("indique-moi un nom de ville ou des coordonnées.")
def get_json_weather(coords):
wth = web.getJSON(URL_DSAPI % (float(coords[0]), float(coords[1])))
def get_json_weather(coords, lang="en", units="auto"):
wth = web.getJSON(URL_DSAPI % (float(coords[0]), float(coords[1]), lang, units))
# First read flags
if wth is None or "darksky-unavailable" in wth["flags"]:
@ -149,10 +134,16 @@ def cmd_coordinates(msg):
return Response("Les coordonnées de %s sont %s,%s" % (msg.args[0], coords["lat"], coords["long"]), channel=msg.channel)
@hook.command("alert")
@hook.command("alert",
keywords={
"lang=LANG": "change the output language of weather sumarry; default: en",
"units=UNITS": "return weather conditions in the requested units; default: auto",
})
def cmd_alert(msg):
loc, coords, specific = treat_coord(msg)
wth = get_json_weather(coords)
wth = get_json_weather(coords,
lang=msg.kwargs["lang"] if "lang" in msg.kwargs else "en",
units=msg.kwargs["units"] if "units" in msg.kwargs else "auto")
res = Response(channel=msg.channel, nomore="No more weather alert", count=" (%d more alerts)")
@ -166,10 +157,20 @@ def cmd_alert(msg):
return res
@hook.command("météo")
@hook.command("météo",
help="Display current weather and previsions",
help_usage={
"CITY": "Display the current weather and previsions in CITY",
},
keywords={
"lang=LANG": "change the output language of weather sumarry; default: en",
"units=UNITS": "return weather conditions in the requested units; default: auto",
})
def cmd_weather(msg):
loc, coords, specific = treat_coord(msg)
wth = get_json_weather(coords)
wth = get_json_weather(coords,
lang=msg.kwargs["lang"] if "lang" in msg.kwargs else "en",
units=msg.kwargs["units"] if "units" in msg.kwargs else "auto")
res = Response(channel=msg.channel, nomore="No more weather information")
@ -243,3 +244,7 @@ def parseask(msg):
context.save()
return Response("ok, j'ai bien noté les coordonnées de %s" % res.group("city"),
msg.channel, msg.nick)
if __name__ == "__main__":
sys.exit(main())

View file

@ -17,9 +17,9 @@
__version__ = '4.0.dev3'
__author__ = 'nemunaire'
from nemubot.modulecontext import ModuleContext
from nemubot.modulecontext import _ModuleContext
context = ModuleContext(None, None)
context = _ModuleContext()
def requires_version(min=None, max=None):
@ -53,41 +53,50 @@ def attach(pid, socketfile):
sys.stderr.write("\n")
return 1
from select import select
import select
mypoll = select.poll()
mypoll.register(sys.stdin.fileno(), select.POLLIN | select.POLLPRI)
mypoll.register(sock.fileno(), select.POLLIN | select.POLLPRI)
try:
while True:
rl, wl, xl = select([sys.stdin, sock], [], [])
for fd, flag in mypoll.poll():
if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
sock.close()
print("Connection closed.")
return 1
if sys.stdin in rl:
line = sys.stdin.readline().strip()
if line == "exit" or line == "quit":
return 0
elif line == "reload":
import os, signal
os.kill(pid, signal.SIGHUP)
print("Reload signal sent. Please wait...")
if fd == sys.stdin.fileno():
line = sys.stdin.readline().strip()
if line == "exit" or line == "quit":
return 0
elif line == "reload":
import os, signal
os.kill(pid, signal.SIGHUP)
print("Reload signal sent. Please wait...")
elif line == "shutdown":
import os, signal
os.kill(pid, signal.SIGTERM)
print("Shutdown signal sent. Please wait...")
elif line == "shutdown":
import os, signal
os.kill(pid, signal.SIGTERM)
print("Shutdown signal sent. Please wait...")
elif line == "kill":
import os, signal
os.kill(pid, signal.SIGKILL)
print("Signal sent...")
return 0
elif line == "kill":
import os, signal
os.kill(pid, signal.SIGKILL)
print("Signal sent...")
return 0
elif line == "stack" or line == "stacks":
import os, signal
os.kill(pid, signal.SIGUSR1)
print("Debug signal sent. Consult logs.")
elif line == "stack" or line == "stacks":
import os, signal
os.kill(pid, signal.SIGUSR1)
print("Debug signal sent. Consult logs.")
else:
sock.send(line.encode() + b'\r\n')
else:
sock.send(line.encode() + b'\r\n')
if fd == sock.fileno():
sys.stdout.write(sock.recv(2048).decode())
if sock in rl:
sys.stdout.write(sock.recv(2048).decode())
except KeyboardInterrupt:
pass
except:
@ -97,13 +106,28 @@ def attach(pid, socketfile):
return 0
def daemonize():
def daemonize(socketfile=None, autoattach=True):
"""Detach the running process to run as a daemon
"""
import os
import sys
if socketfile is not None:
try:
pid = os.fork()
if pid > 0:
if autoattach:
import time
os.waitpid(pid, 0)
time.sleep(1)
sys.exit(attach(pid, socketfile))
else:
sys.exit(0)
except OSError as err:
sys.stderr.write("Unable to fork: %s\n" % err)
sys.exit(1)
try:
pid = os.fork()
if pid > 0:

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2017 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -37,6 +37,9 @@ def main():
default=["./modules/"],
help="directory to use as modules store")
parser.add_argument("-A", "--no-attach", action="store_true",
help="don't attach after fork")
parser.add_argument("-d", "--debug", action="store_true",
help="don't deamonize, keep in foreground")
@ -71,32 +74,10 @@ def main():
args.pidfile = os.path.abspath(os.path.expanduser(args.pidfile))
args.socketfile = os.path.abspath(os.path.expanduser(args.socketfile))
args.logfile = os.path.abspath(os.path.expanduser(args.logfile))
args.files = [ x for x in map(os.path.abspath, args.files)]
args.modules_path = [ x for x in map(os.path.abspath, args.modules_path)]
args.files = [x for x in map(os.path.abspath, args.files)]
args.modules_path = [x for x in map(os.path.abspath, args.modules_path)]
# Check if an instance is already launched
if args.pidfile is not None and os.path.isfile(args.pidfile):
with open(args.pidfile, "r") as f:
pid = int(f.readline())
try:
os.kill(pid, 0)
except OSError:
pass
else:
from nemubot import attach
sys.exit(attach(pid, args.socketfile))
# Daemonize
if not args.debug:
from nemubot import daemonize
daemonize()
# Store PID to pidfile
if args.pidfile is not None:
with open(args.pidfile, "w+") as f:
f.write(str(os.getpid()))
# Setup loggin interface
# Setup logging interface
import logging
logger = logging.getLogger("nemubot")
logger.setLevel(logging.DEBUG)
@ -115,6 +96,18 @@ def main():
fh.setFormatter(formatter)
logger.addHandler(fh)
# Check if an instance is already launched
if args.pidfile is not None and os.path.isfile(args.pidfile):
with open(args.pidfile, "r") as f:
pid = int(f.readline())
try:
os.kill(pid, 0)
except OSError:
pass
else:
from nemubot import attach
sys.exit(attach(pid, args.socketfile))
# Add modules dir paths
modules_paths = list()
for path in args.modules_path:
@ -125,7 +118,7 @@ def main():
# Create bot context
from nemubot import datastore
from nemubot.bot import Bot
from nemubot.bot import Bot, sync_act
context = Bot(modules_paths=modules_paths,
data_store=datastore.XML(args.data_path),
verbosity=args.verbose)
@ -141,7 +134,7 @@ def main():
# Load requested configuration files
for path in args.files:
if os.path.isfile(path):
context.sync_queue.put_nowait(["loadconf", path])
sync_act("loadconf", path)
else:
logger.error("%s is not a readable file", path)
@ -149,6 +142,17 @@ def main():
for module in args.module:
__import__(module)
if args.socketfile:
from nemubot.server.socket import UnixSocketListener
context.add_server(UnixSocketListener(new_server_cb=context.add_server,
location=args.socketfile,
name="master_socket"))
# Daemonize
if not args.debug:
from nemubot import daemonize
daemonize(args.socketfile, not args.no_attach)
# Signals handling
def sigtermhandler(signum, frame):
"""On SIGTERM and SIGINT, quit nicely"""
@ -165,22 +169,27 @@ def main():
# Reload configuration file
for path in args.files:
if os.path.isfile(path):
context.sync_queue.put_nowait(["loadconf", path])
sync_act("loadconf", path)
signal.signal(signal.SIGHUP, sighuphandler)
def sigusr1handler(signum, frame):
"""On SIGHUSR1, display stacktraces"""
import traceback
import threading, traceback
for threadId, stack in sys._current_frames().items():
logger.debug("########### Thread %d:\n%s",
threadId,
thName = "#%d" % threadId
for th in threading.enumerate():
if th.ident == threadId:
thName = th.name
break
logger.debug("########### Thread %s:\n%s",
thName,
"".join(traceback.format_stack(stack)))
signal.signal(signal.SIGUSR1, sigusr1handler)
if args.socketfile:
from nemubot.server.socket import SocketListener
context.add_server(SocketListener(context.add_server, "master_socket",
sock_location=args.socketfile))
# Store PID to pidfile
if args.pidfile is not None:
with open(args.pidfile, "w+") as f:
f.write(str(os.getpid()))
# context can change when performing an hotswap, always join the latest context
oldcontext = None
@ -195,5 +204,6 @@ def main():
sigusr1handler(0, None)
sys.exit(0)
if __name__ == "__main__":
main()

View file

@ -16,7 +16,9 @@
from datetime import datetime, timezone
import logging
from multiprocessing import JoinableQueue
import threading
import select
import sys
from nemubot import __version__
@ -26,6 +28,11 @@ import nemubot.hooks
logger = logging.getLogger("nemubot")
sync_queue = JoinableQueue()
def sync_act(*args):
sync_queue.put(list(args))
class Bot(threading.Thread):
@ -42,7 +49,7 @@ class Bot(threading.Thread):
verbosity -- verbosity level
"""
threading.Thread.__init__(self)
super().__init__(name="Nemubot main")
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)",
__version__,
@ -61,6 +68,7 @@ class Bot(threading.Thread):
self.datastore.open()
# Keep global context: servers and modules
self._poll = select.poll()
self.servers = dict()
self.modules = dict()
self.modules_configuration = dict()
@ -138,60 +146,84 @@ class Bot(threading.Thread):
self.cnsr_queue = Queue()
self.cnsr_thrd = list()
self.cnsr_thrd_size = -1
# Synchrone actions to be treated by main thread
self.sync_queue = Queue()
def run(self):
from select import select
from nemubot.server import _lock, _rlist, _wlist, _xlist
global sync_queue
# Rewrite the sync_queue, as the daemonization process tend to disturb it
old_sync_queue, sync_queue = sync_queue, JoinableQueue()
while not old_sync_queue.empty():
sync_queue.put_nowait(old_sync_queue.get())
self._poll.register(sync_queue._reader, select.POLLIN | select.POLLPRI)
logger.info("Starting main loop")
self.stop = False
while not self.stop:
with _lock:
try:
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1)
except:
logger.error("Something went wrong in select")
fnd_smth = False
# Looking for invalid server
for r in _rlist:
if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
_rlist.remove(r)
logger.error("Found invalid object in _rlist: " + str(r))
fnd_smth = True
for w in _wlist:
if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
_wlist.remove(w)
logger.error("Found invalid object in _wlist: " + str(w))
fnd_smth = True
for x in _xlist:
if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
_xlist.remove(x)
logger.error("Found invalid object in _xlist: " + str(x))
fnd_smth = True
if not fnd_smth:
logger.exception("Can't continue, sorry")
self.quit()
continue
for fd, flag in self._poll.poll():
# Handle internal socket passing orders
if fd != sync_queue._reader.fileno() and fd in self.servers:
srv = self.servers[fd]
for x in xl:
try:
x.exception()
except:
logger.exception("Uncatched exception on server exception")
for w in wl:
try:
w.write_select()
except:
logger.exception("Uncatched exception on server write")
for r in rl:
for i in r.read():
if flag & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
try:
self.receive_message(r, i)
srv.exception(flag)
except:
logger.exception("Uncatched exception on server read")
logger.exception("Uncatched exception on server exception")
if srv.fileno() > 0:
if flag & (select.POLLOUT):
try:
srv.async_write()
except:
logger.exception("Uncatched exception on server write")
if flag & (select.POLLIN | select.POLLPRI):
try:
for i in srv.async_read():
self.receive_message(srv, i)
except:
logger.exception("Uncatched exception on server read")
else:
del self.servers[fd]
# Always check the sync queue
while not sync_queue.empty():
args = sync_queue.get()
action = args.pop(0)
logger.debug("Executing sync_queue action %s%s", action, args)
if action == "sckt" and len(args) >= 2:
try:
if args[0] == "write":
self._poll.modify(int(args[1]), select.POLLOUT | select.POLLIN | select.POLLPRI)
elif args[0] == "unwrite":
self._poll.modify(int(args[1]), select.POLLIN | select.POLLPRI)
elif args[0] == "register":
self._poll.register(int(args[1]), select.POLLIN | select.POLLPRI)
elif args[0] == "unregister":
self._poll.unregister(int(args[1]))
except:
logger.exception("Unhandled excpetion during action:")
elif action == "exit":
self.quit()
elif action == "launch_consumer":
pass # This is treated after the loop
elif action == "loadconf":
for path in args:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
sync_queue.task_done()
# Launch new consumer threads if necessary
@ -202,17 +234,7 @@ class Bot(threading.Thread):
c = Consumer(self)
self.cnsr_thrd.append(c)
c.start()
while self.sync_queue.qsize() > 0:
action = self.sync_queue.get_nowait()
if action[0] == "exit":
self.quit()
elif action[0] == "loadconf":
for path in action[1:]:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
self.sync_queue.task_done()
sync_queue = None
logger.info("Ending main loop")
@ -385,7 +407,13 @@ class Bot(threading.Thread):
self.event_timer.cancel()
if len(self.events):
remaining = self.events[0].time_left.total_seconds()
try:
remaining = self.events[0].time_left.total_seconds()
except:
logger.exception("An error occurs during event time calculation:")
self.events.pop(0)
return self._update_event_timer()
logger.debug("Update timer: next event in %d seconds", remaining)
self.event_timer = threading.Timer(remaining if remaining > 0 else 0, self._end_event_timer)
self.event_timer.start()
@ -400,6 +428,7 @@ class Bot(threading.Thread):
while len(self.events) > 0 and datetime.now(timezone.utc) >= self.events[0].current:
evt = self.events.pop(0)
self.cnsr_queue.put_nowait(EventConsumer(evt))
sync_act("launch_consumer")
self._update_event_timer()
@ -419,7 +448,7 @@ class Bot(threading.Thread):
self.servers[fileno] = srv
self.servers[srv.name] = srv
if autoconnect and not hasattr(self, "noautoconnect"):
srv.open()
srv.connect()
return True
else:
@ -463,7 +492,7 @@ class Bot(threading.Thread):
module.print = prnt
# Create module context
from nemubot.modulecontext import ModuleContext
from nemubot.modulecontext import _ModuleContext, ModuleContext
module.__nemubot_context__ = ModuleContext(self, module)
if not hasattr(module, "logger"):
@ -471,7 +500,7 @@ class Bot(threading.Thread):
# Replace imported context by real one
for attr in module.__dict__:
if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext:
if attr != "__nemubot_context__" and type(module.__dict__[attr]) == _ModuleContext:
module.__dict__[attr] = module.__nemubot_context__
# Register decorated functions
@ -532,28 +561,29 @@ class Bot(threading.Thread):
def quit(self):
"""Save and unload modules and disconnect servers"""
self.datastore.close()
if self.event_timer is not None:
logger.info("Stop the event timer...")
self.event_timer.cancel()
logger.info("Save and unload all modules...")
for mod in self.modules.items():
self.unload_module(mod)
logger.info("Close all servers connection...")
for srv in [self.servers[k] for k in self.servers]:
srv.close()
logger.info("Stop consumers")
k = self.cnsr_thrd
for cnsr in k:
cnsr.stop = True
logger.info("Save and unload all modules...")
k = list(self.modules.keys())
for mod in k:
self.unload_module(mod)
self.datastore.close()
logger.info("Close all servers connection...")
k = list(self.servers.keys())
for srv in k:
self.servers[srv].close()
self.stop = True
if self.stop is False or sync_queue is not None:
self.stop = True
sync_act("end")
sync_queue.join()
# Treatment

View file

@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.config import get_boolean
from nemubot.tools.xmlparser.genericnode import GenericNode
from nemubot.datastore.nodes.generic import GenericNode
class Module(GenericNode):

View file

@ -23,8 +23,7 @@ class Abstract:
"""Initialize a new empty storage tree
"""
from nemubot.tools.xmlparser import module_state
return module_state.ModuleState("nemubotstate")
return None
def open(self):
return

View file

@ -0,0 +1,18 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes.generic import ParsingNode
from nemubot.datastore.nodes.serializable import Serializable

View file

@ -14,11 +14,16 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
class ListNode:
from nemubot.datastore.nodes.serializable import Serializable
class ListNode(Serializable):
"""XML node representing a Python dictionnnary
"""
serializetag = "list"
def __init__(self, **kwargs):
self.items = list()
@ -27,6 +32,9 @@ class ListNode:
self.items.append(child)
return True
def parsedForm(self):
return self.items
def __len__(self):
return len(self.items)
@ -44,11 +52,21 @@ class ListNode:
return self.items.__repr__()
class DictNode:
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
for i in self.items:
node.children.append(ParsingNode.serialize_node(i))
return node
class DictNode(Serializable):
"""XML node representing a Python dictionnnary
"""
serializetag = "dict"
def __init__(self, **kwargs):
self.items = dict()
self._cur = None
@ -56,44 +74,20 @@ class DictNode:
def startElement(self, name, attrs):
if self._cur is None and "key" in attrs:
self._cur = (attrs["key"], "")
return True
self._cur = attrs["key"]
return False
def characters(self, content):
if self._cur is not None:
key, cnt = self._cur
if isinstance(cnt, str):
cnt += content
self._cur = key, cnt
def endElement(self, name):
if name is None or self._cur is None:
return
key, cnt = self._cur
if isinstance(cnt, list) and len(cnt) == 1:
self.items[key] = cnt
else:
self.items[key] = cnt
self._cur = None
return True
def addChild(self, name, child):
if self._cur is None:
return False
key, cnt = self._cur
if not isinstance(cnt, list):
cnt = []
cnt.append(child)
self._cur = key, cnt
self.items[self._cur] = child
self._cur = None
return True
def parsedForm(self):
return self.items
def __getitem__(self, item):
return self.items[item]
@ -106,3 +100,13 @@ class DictNode:
def __repr__(self):
return self.items.__repr__()
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
for k in self.items:
chld = ParsingNode.serialize_node(self.items[k])
chld.attrs["key"] = k
node.children.append(chld)
return node

View file

@ -14,6 +14,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes.serializable import Serializable
class ParsingNode:
"""Allow any kind of subtags, just keep parsed ones
@ -53,6 +56,47 @@ class ParsingNode:
return item in self.attrs
def serialize_node(node, **def_kwargs):
"""Serialize any node or basic data to a ParsingNode instance"""
if isinstance(node, Serializable):
node = node.serialize()
if isinstance(node, str):
from nemubot.datastore.nodes.python import StringNode
pn = StringNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, int):
from nemubot.datastore.nodes.python import IntNode
pn = IntNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, float):
from nemubot.datastore.nodes.python import FloatNode
pn = FloatNode(**def_kwargs)
pn.value = node
return pn
elif isinstance(node, list):
from nemubot.datastore.nodes.basic import ListNode
pn = ListNode(**def_kwargs)
pn.items = node
return pn.serialize()
elif isinstance(node, dict):
from nemubot.datastore.nodes.basic import DictNode
pn = DictNode(**def_kwargs)
pn.items = node
return pn.serialize()
else:
assert isinstance(node, ParsingNode)
return node
class GenericNode(ParsingNode):
"""Consider all subtags as dictionnary

View file

@ -0,0 +1,91 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes.serializable import Serializable
class PythonTypeNode(Serializable):
"""XML node representing a Python simple type
"""
def __init__(self, **kwargs):
self.value = None
self._cnt = ""
def characters(self, content):
self._cnt += content
def endElement(self, name):
raise NotImplemented
def __repr__(self):
return self.value.__repr__()
def parsedForm(self):
return self.value
def serialize(self):
raise NotImplemented
class IntNode(PythonTypeNode):
serializetag = "int"
def endElement(self, name):
self.value = int(self._cnt)
return True
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
class FloatNode(PythonTypeNode):
serializetag = "float"
def endElement(self, name):
self.value = float(self._cnt)
return True
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node
class StringNode(PythonTypeNode):
serializetag = "str"
def endElement(self, name):
self.value = str(self._cnt)
return True
def serialize(self):
from nemubot.datastore.nodes.generic import ParsingNode
node = ParsingNode(tag=self.serializetag)
node.content = str(self.value)
return node

View file

@ -0,0 +1,22 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
class Serializable:
def serialize(self):
# Implementations of this function should return ParsingNode items
return NotImplemented

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -36,17 +36,24 @@ class XML(Abstract):
rotate -- auto-backup files?
"""
self.basedir = basedir
self.basedir = os.path.abspath(basedir)
self.rotate = rotate
self.nb_save = 0
logger.info("Initiate XML datastore at %s, rotation %s",
self.basedir,
"enabled" if self.rotate else "disabled")
def open(self):
"""Lock the directory"""
if not os.path.isdir(self.basedir):
logger.debug("Datastore directory not found, creating: %s", self.basedir)
os.mkdir(self.basedir)
lock_path = os.path.join(self.basedir, ".used_by_nemubot")
lock_path = self._get_lock_file_path()
logger.debug("Locking datastore directory via %s", lock_path)
self.lock_file = open(lock_path, 'a+')
ok = True
@ -64,56 +71,95 @@ class XML(Abstract):
self.lock_file.write(str(os.getpid()))
self.lock_file.flush()
logger.info("Datastore successfuly opened at %s", self.basedir)
return True
def close(self):
"""Release a locked path"""
if hasattr(self, "lock_file"):
self.lock_file.close()
lock_path = os.path.join(self.basedir, ".used_by_nemubot")
lock_path = self._get_lock_file_path()
if os.path.isdir(self.basedir) and os.path.exists(lock_path):
os.unlink(lock_path)
del self.lock_file
logger.info("Datastore successfully closed at %s", self.basedir)
return True
else:
logger.warn("Datastore not open/locked or lock file not found")
return False
def _get_data_file_path(self, module):
"""Get the path to the module data file"""
return os.path.join(self.basedir, module + ".xml")
def load(self, module):
def _get_lock_file_path(self):
"""Get the path to the datastore lock file"""
return os.path.join(self.basedir, ".used_by_nemubot")
def load(self, module, extendsTags={}):
"""Load data for the given module
Argument:
module -- the module name of data to load
"""
logger.debug("Trying to load data for %s%s",
module,
(" with tags: " + ", ".join(extendsTags.keys())) if len(extendsTags) else "")
data_file = self._get_data_file_path(module)
def parse(path):
from nemubot.tools.xmlparser import XMLParser
from nemubot.datastore.nodes import basic as basicNodes
from nemubot.datastore.nodes import python as pythonNodes
from nemubot.message.command import Command
from nemubot.scope import Scope
d = {
basicNodes.ListNode.serializetag: basicNodes.ListNode,
basicNodes.DictNode.serializetag: basicNodes.DictNode,
pythonNodes.IntNode.serializetag: pythonNodes.IntNode,
pythonNodes.FloatNode.serializetag: pythonNodes.FloatNode,
pythonNodes.StringNode.serializetag: pythonNodes.StringNode,
Command.serializetag: Command,
Scope.serializetag: Scope,
}
d.update(extendsTags)
p = XMLParser(d)
return p.parse_file(path)
# Try to load original file
if os.path.isfile(data_file):
from nemubot.tools.xmlparser import parse_file
try:
return parse_file(data_file)
return parse(data_file)
except xml.parsers.expat.ExpatError:
# Try to load from backup
for i in range(10):
path = data_file + "." + str(i)
if os.path.isfile(path):
try:
cnt = parse_file(path)
cnt = parse(path)
logger.warn("Restoring from backup: %s", path)
logger.warn("Restoring data from backup: %s", path)
return cnt
except xml.parsers.expat.ExpatError:
continue
# Default case: initialize a new empty datastore
logger.warn("No data found in store for %s, creating new set", module)
return Abstract.load(self, module)
def _rotate(self, path):
"""Backup given path
@ -130,6 +176,25 @@ class XML(Abstract):
if os.path.isfile(src):
os.rename(src, dst)
def _save_node(self, gen, node):
from nemubot.datastore.nodes.generic import ParsingNode
# First, get the serialized form of the node
node = ParsingNode.serialize_node(node)
assert node.tag is not None, "Undefined tag name"
gen.startElement(node.tag, {k: str(node.attrs[k]) for k in node.attrs})
gen.characters(node.content)
for child in node.children:
self._save_node(gen, child)
gen.endElement(node.tag)
def save(self, module, data):
"""Load data for the given module
@ -139,8 +204,22 @@ class XML(Abstract):
"""
path = self._get_data_file_path(module)
logger.debug("Trying to save data for module %s in %s", module, path)
if self.rotate:
self._rotate(path)
return data.save(path)
import tempfile
_, tmpath = tempfile.mkstemp()
with open(tmpath, "w") as f:
import xml.sax.saxutils
gen = xml.sax.saxutils.XMLGenerator(f, "utf-8")
gen.startDocument()
self._save_node(gen, data)
gen.endDocument()
# Atomic save
import shutil
shutil.move(tmpath, path)
return True

View file

@ -16,12 +16,17 @@
from datetime import datetime, timezone
from nemubot.datastore.nodes import Serializable
class Abstract:
class Abstract(Serializable):
"""This class represents an abstract message"""
def __init__(self, server=None, date=None, to=None, to_response=None, frm=None):
serializetag = "nemubotAMessage"
def __init__(self, server=None, date=None, to=None, to_response=None, frm=None, frm_owner=False):
"""Initialize an abstract message
Arguments:
@ -40,7 +45,7 @@ class Abstract:
else [ to_response ])
self.frm = frm # None allowed when it designate this bot
self.frm_owner = False # Filled later, in consumer
self.frm_owner = frm_owner
@property
@ -65,6 +70,14 @@ class Abstract:
return self.frm
@property
def scope(self):
from nemubot.scope import Scope
return Scope(server=self.server,
channel=self.to_response[0],
nick=self.frm)
def accept(self, visitor):
visitor.visit(self)
@ -78,7 +91,8 @@ class Abstract:
"date": self.date,
"to": self.to,
"to_response": self._to_response,
"frm": self.frm
"frm": self.frm,
"frm_owner": self.frm_owner,
}
for w in without:
@ -86,3 +100,8 @@ class Abstract:
del ret[w]
return ret
def serialize(self):
from nemubot.datastore.nodes import ParsingNode
return ParsingNode(tag=Abstract.serializetag, **self.export_args())

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -21,6 +21,9 @@ class Command(Abstract):
"""This class represents a specialized TextMessage"""
serializetag = "nemubotCommand"
def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs):
super().__init__(*nargs, **kargs)
@ -28,17 +31,35 @@ class Command(Abstract):
self.args = args if args is not None else list()
self.kwargs = kwargs if kwargs is not None else dict()
def __str__(self):
def __repr__(self):
return self.cmd + " @" + ",@".join(self.args)
@property
def cmds(self):
# TODO: this is for legacy modules
return [self.cmd] + self.args
def addChild(self, name, child):
if name == "list":
self.args = child
elif name == "dict":
self.kwargs = child
else:
return False
return True
def serialize(self):
from nemubot.datastore.nodes import ParsingNode
node = ParsingNode(tag=Command.serializetag, cmd=self.cmd)
if len(self.args):
node.children.append(ParsingNode.serialize_node(self.args))
if len(self.kwargs):
node.children.append(ParsingNode.serialize_node(self.kwargs))
return node
class OwnerCommand(Command):
"""This class represents a special command incomming from the owner"""
serializetag = "nemubotOCommand"
pass

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2017 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,105 +14,62 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
class ModuleContext:
class _ModuleContext:
def __init__(self, context, module):
"""Initialize the module context
arguments:
context -- the bot context
module -- the module
"""
def __init__(self, module=None):
self.module = module
if module is not None:
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
self.module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
else:
module_name = ""
# Load module configuration if exists
if (context is not None and
module_name in context.modules_configuration):
self.config = context.modules_configuration[module_name]
else:
from nemubot.config.module import Module
self.config = Module(module_name)
self.module_name = ""
self.hooks = list()
self.events = list()
self.debug = context.verbosity > 0 if context is not None else False
self.extendtags = dict()
self.debug = False
from nemubot.config.module import Module
self.config = Module(self.module_name)
def load_data(self):
return None
def add_hook(self, hook, *triggers):
from nemubot.hooks import Abstract as AbstractHook
assert isinstance(hook, AbstractHook), hook
self.hooks.append((triggers, hook))
# Define some callbacks
if context is not None:
def load_data():
return context.datastore.load(module_name)
def del_hook(self, hook, *triggers):
from nemubot.hooks import Abstract as AbstractHook
assert isinstance(hook, AbstractHook), hook
self.hooks.remove((triggers, hook))
def add_hook(hook, *triggers):
assert isinstance(hook, AbstractHook), hook
self.hooks.append((triggers, hook))
return context.treater.hm.add_hook(hook, *triggers)
def subtreat(self, msg):
return None
def del_hook(hook, *triggers):
assert isinstance(hook, AbstractHook), hook
self.hooks.remove((triggers, hook))
return context.treater.hm.del_hooks(*triggers, hook=hook)
def add_event(self, evt, eid=None):
return self.events.append((evt, eid))
def subtreat(msg):
yield from context.treater.treat_msg(msg)
def add_event(evt, eid=None):
return context.add_event(evt, eid, module_src=module)
def del_event(evt):
return context.del_event(evt, module_src=module)
def del_event(self, evt):
for i in self.events:
e, eid = i
if e == evt:
self.events.remove(i)
return True
return False
def send_response(server, res):
if server in context.servers:
if res.server is not None:
return context.servers[res.server].send_response(res)
else:
return context.servers[server].send_response(res)
else:
module.logger.error("Try to send a message to the unknown server: %s", server)
return False
def send_response(self, server, res):
self.module.logger.info("Send response: %s", res)
else: # Used when using outside of nemubot
def load_data():
from nemubot.tools.xmlparser import module_state
return module_state.ModuleState("nemubotstate")
def add_hook(hook, *triggers):
assert isinstance(hook, AbstractHook), hook
self.hooks.append((triggers, hook))
def del_hook(hook, *triggers):
assert isinstance(hook, AbstractHook), hook
self.hooks.remove((triggers, hook))
def subtreat(msg):
return None
def add_event(evt, eid=None):
return context.add_event(evt, eid, module_src=module)
def del_event(evt):
return context.del_event(evt, module_src=module)
def send_response(server, res):
module.logger.info("Send response: %s", res)
def save():
context.datastore.save(module_name, self.data)
def subparse(orig, cnt):
if orig.server in context.servers:
return context.servers[orig.server].subparse(orig, cnt)
self.load_data = load_data
self.add_hook = add_hook
self.del_hook = del_hook
self.add_event = add_event
self.del_event = del_event
self.save = save
self.send_response = send_response
self.subtreat = subtreat
self.subparse = subparse
def save(self):
# Don't save if no data has been access
if hasattr(self, "_data"):
context.datastore.save(self.module_name, self.data)
def subparse(self, orig, cnt):
if orig.server in self.context.servers:
return self.context.servers[orig.server].subparse(orig, cnt)
@property
def data(self):
@ -120,6 +77,21 @@ class ModuleContext:
self._data = self.load_data()
return self._data
@data.setter
def data(self, value):
assert value is not None
self._data = value
def register_tags(self, **tags):
self.extendtags.update(tags)
def unregister_tags(self, *tags):
for t in tags:
del self.extendtags[t]
def unload(self):
"""Perform actions for unloading the module"""
@ -129,7 +101,62 @@ class ModuleContext:
self.del_hook(h, *s)
# Remove registered events
for e in self.events:
self.del_event(e)
for evt, eid, module_src in self.events:
self.del_event(evt)
self.save()
class ModuleContext(_ModuleContext):
def __init__(self, context, *args, **kwargs):
"""Initialize the module context
arguments:
context -- the bot context
module -- the module
"""
super().__init__(*args, **kwargs)
# Load module configuration if exists
if self.module_name in context.modules_configuration:
self.config = context.modules_configuration[self.module_name]
self.context = context
self.debug = context.verbosity > 0
def load_data(self):
return self.context.datastore.load(self.module_name, extendsTags=self.extendtags)
def add_hook(self, hook, *triggers):
from nemubot.hooks import Abstract as AbstractHook
assert isinstance(hook, AbstractHook), hook
self.hooks.append((triggers, hook))
return self.context.treater.hm.add_hook(hook, *triggers)
def del_hook(self, hook, *triggers):
from nemubot.hooks import Abstract as AbstractHook
assert isinstance(hook, AbstractHook), hook
self.hooks.remove((triggers, hook))
return self.context.treater.hm.del_hooks(*triggers, hook=hook)
def subtreat(self, msg):
yield from self.context.treater.treat_msg(msg)
def add_event(self, evt, eid=None):
return self.context.add_event(evt, eid, module_src=self.module)
def del_event(self, evt):
return self.context.del_event(evt, module_src=self.module)
def send_response(self, server, res):
if server in self.context.servers:
if res.server is not None:
return self.context.servers[res.server].send_response(res)
else:
return self.context.servers[server].send_response(res)
else:
self.module.logger.error("Try to send a message to the unknown server: %s", server)
return False

83
nemubot/scope.py Normal file
View file

@ -0,0 +1,83 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from nemubot.datastore.nodes import Serializable
class Scope(Serializable):
serializetag = "nemubot-scope"
default_limit = "channel"
def __init__(self, server, channel, nick, limit=default_limit):
self._server = server
self._channel = channel
self._nick = nick
self._limit = limit
def sameServer(self, server):
return self._server is None or self._server == server
def sameChannel(self, server, channel):
return self.sameServer(server) and (self._channel is None or self._channel == channel)
def sameNick(self, server, channel, nick):
return self.sameChannel(server, channel) and (self._nick is None or self._nick == nick)
def check(self, scope, limit=None):
return self.checkScope(scope._server, scope._channel, scope._nick, limit)
def checkScope(self, server, channel, nick, limit=None):
if limit is None: limit = self._limit
assert limit == "global" or limit == "server" or limit == "channel" or limit == "nick"
if limit == "server":
return self.sameServer(server)
elif limit == "channel":
return self.sameChannel(server, channel)
elif limit == "nick":
return self.sameNick(server, channel, nick)
else:
return True
def narrow(self, scope):
return scope is None or (
scope._limit == "global" or
(scope._limit == "server" and (self._limit == "nick" or self._limit == "channel")) or
(scope._limit == "channel" and self._limit == "nick")
)
def serialize(self):
from nemubot.datastore.nodes import ParsingNode
args = {}
if self._server is not None:
args["server"] = self._server
if self._channel is not None:
args["channel"] = self._channel
if self._nick is not None:
args["nick"] = self._nick
if self._limit is not None:
args["limit"] = self._limit
return ParsingNode(tag=self.serializetag, **args)

View file

@ -31,7 +31,7 @@ PORTS = list()
class DCC(server.AbstractServer):
def __init__(self, srv, dest, socket=None):
super().__init__(self)
super().__init__(name="Nemubot DCC server")
self.error = False # An error has occur, closing the connection?
self.messages = list() # Message queued before connexion

View file

@ -16,21 +16,22 @@
from datetime import datetime
import re
import socket
from nemubot.channel import Channel
from nemubot.message.printer.IRC import IRC as IRCPrinter
from nemubot.server.message.IRC import IRC as IRCMessage
from nemubot.server.socket import SocketServer
from nemubot.server.socket import SocketServer, SecureSocketServer
class IRC(SocketServer):
class _IRC:
"""Concrete implementation of a connexion to an IRC server"""
def __init__(self, host="localhost", port=6667, ssl=False, owner=None,
def __init__(self, host="localhost", port=6667, owner=None,
nick="nemubot", username=None, password=None,
realname="Nemubot", encoding="utf-8", caps=None,
channels=list(), on_connect=None):
channels=list(), on_connect=None, **kwargs):
"""Prepare a connection with an IRC server
Keyword arguments:
@ -54,7 +55,8 @@ class IRC(SocketServer):
self.owner = owner
self.realname = realname
super().__init__(host=host, port=port, ssl=ssl, name=self.username + "@" + host + ":" + str(port))
super().__init__(name=self.username + "@" + host + ":" + str(port),
host=host, port=port, **kwargs)
self.printer = IRCPrinter
self.encoding = encoding
@ -231,20 +233,19 @@ class IRC(SocketServer):
# Open/close
def open(self):
if super().open():
if self.password is not None:
self.write("PASS :" + self.password)
if self.capabilities is not None:
self.write("CAP LS")
self.write("NICK :" + self.nick)
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
return True
return False
def connect(self):
super().connect()
if self.password is not None:
self.write("PASS :" + self.password)
if self.capabilities is not None:
self.write("CAP LS")
self.write("NICK :" + self.nick)
self.write("USER %s %s bla :%s" % (self.username, socket.getfqdn(), self.realname))
def close(self):
if not self.closed:
if not self._closed:
self.write("QUIT")
return super().close()
@ -253,8 +254,8 @@ class IRC(SocketServer):
# Read
def read(self):
for line in super().read():
def async_read(self):
for line in super().async_read():
# PING should be handled here, so start parsing here :/
msg = IRCMessage(line, self.encoding)
@ -273,3 +274,10 @@ class IRC(SocketServer):
def subparse(self, orig, cnt):
msg = IRCMessage(("@time=%s :%s!user@host.com PRIVMSG %s :%s" % (orig.date.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), orig.frm, ",".join(orig.to), cnt)).encode(self.encoding), self.encoding)
return msg.to_bot_message(self)
class IRC(_IRC, SocketServer):
pass
class IRC_secure(_IRC, SecureSocketServer):
pass

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,57 +14,64 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
_lock = threading.Lock()
# Lists for select
_rlist = []
_wlist = []
_xlist = []
def factory(uri, **init_args):
from urllib.parse import urlparse, unquote
def factory(uri, ssl=False, **init_args):
from urllib.parse import urlparse, unquote, parse_qs
o = urlparse(uri)
srv = None
if o.scheme == "irc" or o.scheme == "ircs":
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
args = init_args
modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:])
if o.scheme == "ircs": args["ssl"] = True
if o.scheme == "ircs": ssl = True
if o.hostname is not None: args["host"] = o.hostname
if o.port is not None: args["port"] = o.port
if o.username is not None: args["username"] = o.username
if o.password is not None: args["password"] = o.password
queries = o.query.split("&")
for q in queries:
if "=" in q:
key, val = tuple(q.split("=", 1))
else:
key, val = q, ""
if key == "msg":
if "on_connect" not in args:
args["on_connect"] = []
args["on_connect"].append("PRIVMSG %s :%s" % (target, unquote(val)))
elif key == "key":
if "channels" not in args:
args["channels"] = []
args["channels"].append((target, unquote(val)))
elif key == "pass":
args["password"] = unquote(val)
elif key == "charset":
args["encoding"] = unquote(val)
if ssl:
try:
from ssl import create_default_context
args["_context"] = create_default_context()
except ImportError:
# Python 3.3 compat
from ssl import SSLContext, PROTOCOL_TLSv1
args["_context"] = SSLContext(PROTOCOL_TLSv1)
modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:])
# Read query string
params = parse_qs(o.query)
if "msg" in params:
if "on_connect" not in args:
args["on_connect"] = []
args["on_connect"].append("PRIVMSG %s :%s" % (target, params["msg"]))
if "key" in params:
if "channels" not in args:
args["channels"] = []
args["channels"].append((target, params["key"]))
if "pass" in params:
args["password"] = params["pass"]
if "charset" in params:
args["encoding"] = params["charset"]
#
if "channels" not in args and "isnick" not in modifiers:
args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer
return IRCServer(**args)
else:
return None
if ssl:
from nemubot.server.IRC import IRC_secure as SecureIRCServer
srv = SecureIRCServer(**args)
else:
from nemubot.server.IRC import IRC as IRCServer
srv = IRCServer(**args)
return srv

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,34 +14,30 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import io
import logging
import queue
from nemubot.server import _lock, _rlist, _wlist, _xlist
from nemubot.bot import sync_act
# Extends from IOBase in order to be compatible with select function
class AbstractServer(io.IOBase):
class AbstractServer:
"""An abstract server: handle communication with an IM server"""
def __init__(self, name=None, send_callback=None):
def __init__(self, name=None, **kwargs):
"""Initialize an abstract server
Keyword argument:
send_callback -- Callback when developper want to send a message
name -- Identifier of the socket, for convinience
"""
self._name = name
super().__init__()
super().__init__(**kwargs)
self.logger = logging.getLogger("nemubot.server." + self.name)
self.logger = logging.getLogger("nemubot.server." + str(self.name))
self._readbuffer = b''
self._sending_queue = queue.Queue()
if send_callback is not None:
self._send_callback = send_callback
else:
self._send_callback = self._write_select
@property
@ -54,40 +50,28 @@ class AbstractServer(io.IOBase):
# Open/close
def __enter__(self):
self.open()
return self
def connect(self, *args, **kwargs):
"""Register the server in _poll"""
self.logger.info("Opening connection")
super().connect(*args, **kwargs)
self._on_connect()
def _on_connect(self):
sync_act("sckt", "register", self.fileno())
def __exit__(self, type, value, traceback):
self.close()
def close(self, *args, **kwargs):
"""Unregister the server from _poll"""
self.logger.info("Closing connection")
def open(self):
"""Generic open function that register the server un _rlist in case
of successful _open"""
self.logger.info("Opening connection to %s", self.id)
if not hasattr(self, "_open") or self._open():
_rlist.append(self)
_xlist.append(self)
return True
return False
if self.fileno() > 0:
sync_act("sckt", "unregister", self.fileno())
def close(self):
"""Generic close function that register the server un _{r,w,x}list in
case of successful _close"""
self.logger.info("Closing connection to %s", self.id)
with _lock:
if not hasattr(self, "_close") or self._close():
if self in _rlist:
_rlist.remove(self)
if self in _wlist:
_wlist.remove(self)
if self in _xlist:
_xlist.remove(self)
return True
return False
super().close(*args, **kwargs)
# Writes
@ -99,13 +83,16 @@ class AbstractServer(io.IOBase):
message -- message to send
"""
self._send_callback(message)
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
sync_act("sckt", "write", self.fileno())
def write_select(self):
"""Internal function used by the select function"""
def async_write(self):
"""Internal function used when the file descriptor is writable"""
try:
_wlist.remove(self)
sync_act("sckt", "unwrite", self.fileno())
while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done()
@ -114,19 +101,6 @@ class AbstractServer(io.IOBase):
pass
def _write_select(self, message):
"""Send a message to the server safely through select
Argument:
message -- message to send
"""
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
if self not in _wlist:
_wlist.append(self)
def send_response(self, response):
"""Send a formated Message class
@ -149,13 +123,39 @@ class AbstractServer(io.IOBase):
# Read
def async_read(self):
"""Internal function used when the file descriptor is readable
Returns:
A list of fully received messages
"""
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
for r in ret:
yield r
def lex(self, buf):
"""Assume lexing in default case is per line
Argument:
buf -- buffer to lex
"""
msgs = buf.split(b'\r\n')
partial = msgs.pop()
return msgs, partial
def parse(self, msg):
raise NotImplemented
# Exceptions
def exception(self):
"""Exception occurs in fd"""
self.logger.warning("Unhandle file descriptor exception on server %s",
self.name)
def exception(self, flags):
"""Exception occurs on fd"""
self.close()

View file

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import socket
import unittest
from nemubot.server import factory
@ -22,34 +23,36 @@ class TestFactory(unittest.TestCase):
def test_IRC1(self):
from nemubot.server.IRC import IRC as IRCServer
from nemubot.server.IRC import IRC_secure as IRCSServer
# <host>: If omitted, the client must connect to a prespecified default IRC server.
server = factory("irc:///")
self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "localhost")
self.assertFalse(server.ssl)
self.assertEqual(server._sockaddr,
socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("ircs:///")
self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "localhost")
self.assertTrue(server.ssl)
self.assertIsInstance(server, IRCSServer)
self.assertEqual(server._sockaddr,
socket.getaddrinfo("localhost", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("irc://host1")
server = factory("irc://freenode.net")
self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host1")
self.assertFalse(server.ssl)
self.assertEqual(server._sockaddr,
socket.getaddrinfo("freenode.net", 6667, proto=socket.IPPROTO_TCP)[0][4])
server = factory("irc://host2:6667")
server = factory("irc://freenode.org:1234")
self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host2")
self.assertEqual(server.port, 6667)
self.assertFalse(server.ssl)
self.assertEqual(server._sockaddr,
socket.getaddrinfo("freenode.org", 1234, proto=socket.IPPROTO_TCP)[0][4])
server = factory("ircs://host3:194/")
self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host3")
self.assertEqual(server.port, 194)
self.assertTrue(server.ssl)
server = factory("ircs://nemunai.re:194/")
self.assertIsInstance(server, IRCSServer)
self.assertEqual(server._sockaddr,
socket.getaddrinfo("nemunai.re", 194, proto=socket.IPPROTO_TCP)[0][4])
with self.assertRaises(socket.gaierror):
factory("irc://_nonexistent.nemunai.re")
if __name__ == '__main__':

View file

@ -150,7 +150,8 @@ class IRC(Abstract):
"date": self.tags["time"],
"to": receivers,
"to_response": [r if r != srv.nick else self.nick for r in receivers],
"frm": self.nick
"frm": self.nick,
"frm_owner": self.nick == srv.owner
}
# If CTCP, remove 0x01

View file

@ -0,0 +1,15 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
# Copyright (C) 2012-2016 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,117 +14,33 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import socket
import ssl
import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter
from nemubot.server.abstract import AbstractServer
class SocketServer(AbstractServer):
class _Socket(AbstractServer):
"""Concrete implementation of a socket connexion (can be wrapped with TLS)"""
"""Concrete implementation of a socket connection"""
def __init__(self, sock_location=None,
host=None, port=None,
sock=None,
ssl=False,
name=None):
def __init__(self, printer=SocketPrinter, **kwargs):
"""Create a server socket
Keyword arguments:
sock_location -- Path to the UNIX socket
host -- Hostname of the INET socket
port -- Port of the INET socket
sock -- Already connected socket
ssl -- Should TLS connection enabled
name -- Convinience name
"""
import socket
assert(sock is None or isinstance(sock, socket.SocketType))
assert(port is None or isinstance(port, int))
super().__init__(name=name)
if sock is None:
if sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.connect_to = sock_location
elif host is not None:
for af, socktype, proto, canonname, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
self.socket = socket.socket(af, socktype, proto)
self.connect_to = sa
break
else:
self.socket = sock
self.ssl = ssl
super().__init__(**kwargs)
self.readbuffer = b''
self.printer = SocketPrinter
def fileno(self):
return self.socket.fileno() if self.socket else None
@property
def closed(self):
"""Indicator of the connection aliveness"""
return self.socket._closed
# Open/close
def open(self):
if not self.closed:
return True
try:
self.socket.connect(self.connect_to)
self.logger.info("Connected to %s", self.connect_to)
except:
self.socket.close()
self.logger.exception("Unable to connect to %s",
self.connect_to)
return False
# Wrap the socket for SSL
if self.ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.socket = ctx.wrap_socket(self.socket)
return super().open()
def close(self):
import socket
# Flush the sending queue before close
from nemubot.server import _lock
_lock.release()
self._sending_queue.join()
_lock.acquire()
if not self.closed:
try:
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.socket.close()
return super().close()
self.printer = printer
# Write
def _write(self, cnt):
if self.closed:
return
self.socket.sendall(cnt)
self.sendall(cnt)
def format(self, txt):
@ -136,19 +52,12 @@ class SocketServer(AbstractServer):
# Read
def read(self):
if self.closed:
return []
raw = self.socket.recv(1024)
temp = (self.readbuffer + raw).split(b'\r\n')
self.readbuffer = temp.pop()
for line in temp:
yield line
def recv(self, n=1024):
return super().recv(n)
def parse(self, line):
"""Implement a default behaviour for socket"""
import shlex
line = line.strip().decode()
@ -157,48 +66,107 @@ class SocketServer(AbstractServer):
except ValueError:
args = line.split(' ')
yield message.Command(cmd=args[0], args=args[1:], server=self.name, to=["you"], frm="you")
if len(args):
yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
class SocketListener(AbstractServer):
def __init__(self, new_server_cb, name, sock_location=None, host=None, port=None, ssl=None):
super().__init__(name=name)
self.new_server_cb = new_server_cb
self.sock_location = sock_location
self.host = host
self.port = port
self.ssl = ssl
self.nb_son = 0
def subparse(self, orig, cnt):
for m in self.parse(cnt):
m.to = orig.to
m.frm = orig.frm
m.date = orig.date
yield m
def fileno(self):
return self.socket.fileno() if self.socket else None
class _SocketServer(_Socket):
def __init__(self, host, port, bind=None, **kwargs):
(family, type, proto, canonname, sockaddr) = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)[0]
if isinstance(self, ssl.SSLSocket) and "server_hostname" not in kwargs:
kwargs["server_hostname"] = host
super().__init__(family=family, type=type, proto=proto, **kwargs)
self._sockaddr = sockaddr
self._bind = bind
@property
def closed(self):
"""Indicator of the connection aliveness"""
return self.socket is None
def connect(self):
self.logger.info("Connection to %s:%d", *self._sockaddr[:2])
super().connect(self._sockaddr)
if self._bind:
super().bind(self._bind)
def open(self):
import os
import socket
class SocketServer(_SocketServer, socket.socket):
pass
if self.sock_location is not None:
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
os.remove(self.sock_location)
except FileNotFoundError:
pass
self.socket.bind(self.sock_location)
elif self.host is not None and self.port is not None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.bind((self.host, self.port))
self.socket.listen(5)
return super().open()
class SecureSocketServer(_SocketServer, ssl.SSLSocket):
pass
class UnixSocket:
def __init__(self, location, **kwargs):
super().__init__(family=socket.AF_UNIX, **kwargs)
self._socket_path = location
def connect(self):
self.logger.info("Connection to unix://%s", self._socket_path)
super().connect(self._socket_path)
class SocketClient(_Socket, socket.socket):
def read(self):
return self.recv()
class _Listener:
def __init__(self, new_server_cb, instanciate=SocketClient, **kwargs):
super().__init__(**kwargs)
self._instanciate = instanciate
self._new_server_cb = new_server_cb
def read(self):
conn, addr = self.accept()
fileno = conn.fileno()
self.logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
ss.connect = ss._on_connect
self._new_server_cb(ss, autoconnect=True)
return b''
class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def connect(self):
self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
try:
os.remove(self._socket_path)
except FileNotFoundError:
pass
self.bind(self._socket_path)
self.listen(5)
self.logger.info("Socket ready for accepting new connections")
self._on_connect()
def close(self):
@ -206,25 +174,14 @@ class SocketListener(AbstractServer):
import socket
try:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
if self.sock_location is not None:
os.remove(self.sock_location)
self.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
return super().close()
super().close()
# Read
def read(self):
if self.closed:
return []
conn, addr = self.socket.accept()
self.nb_son += 1
ss = SocketServer(name=self.name + "#" + str(self.nb_son), socket=conn)
self.new_server_cb(ss)
return []
try:
if self._socket_path is not None:
os.remove(self._socket_path)
except:
pass

View file

@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from urllib.parse import urlparse, urlsplit, urlunsplit
from urllib.parse import urljoin, urlparse, urlsplit, urlunsplit
from nemubot.exception import IMException
@ -108,6 +108,9 @@ def getURLContent(url, body=None, timeout=7, header=None):
elif "User-agent" not in header:
header["User-agent"] = "Nemubot v%s" % __version__
if body is not None and "Content-Type" not in header:
header["Content-Type"] = "application/x-www-form-urlencoded"
import socket
try:
if o.query != '':
@ -156,21 +159,23 @@ def getURLContent(url, body=None, timeout=7, header=None):
elif ((res.status == http.client.FOUND or
res.status == http.client.MOVED_PERMANENTLY) and
res.getheader("Location") != url):
return getURLContent(res.getheader("Location"), timeout=timeout)
return getURLContent(
urljoin(url, res.getheader("Location")),
body=body,
timeout=timeout,
header=header)
else:
raise IMException("A HTTP error occurs: %d - %s" %
(res.status, http.client.responses[res.status]))
def getXML(url, timeout=7):
def getXML(*args, **kwargs):
"""Get content page and return XML parsed content
Arguments:
url -- the URL to get
timeout -- maximum number of seconds to wait before returning an exception
Arguments: same as getURLContent
"""
cnt = getURLContent(url, timeout=timeout)
cnt = getURLContent(*args, **kwargs)
if cnt is None:
return None
else:
@ -178,15 +183,13 @@ def getXML(url, timeout=7):
return parseString(cnt)
def getJSON(url, timeout=7):
def getJSON(*args, **kwargs):
"""Get content page and return JSON content
Arguments:
url -- the URL to get
timeout -- maximum number of seconds to wait before returning an exception
Arguments: same as getURLContent
"""
cnt = getURLContent(url, timeout=timeout)
cnt = getURLContent(*args, **kwargs)
if cnt is None:
return None
else:

View file

@ -51,11 +51,13 @@ class XMLParser:
def __init__(self, knodes):
self.knodes = knodes
def _reset(self):
self.stack = list()
self.child = 0
def parse_file(self, path):
self._reset()
p = xml.parsers.expat.ParserCreate()
p.StartElementHandler = self.startElement
@ -69,6 +71,7 @@ class XMLParser:
def parse_string(self, s):
self._reset()
p = xml.parsers.expat.ParserCreate()
p.StartElementHandler = self.startElement
@ -126,10 +129,13 @@ class XMLParser:
if hasattr(self.current, "endElement"):
self.current.endElement(None)
if hasattr(self.current, "parsedForm") and callable(self.current.parsedForm):
self.stack[-1] = self.current.parsedForm()
# Don't remove root
if len(self.stack) > 1:
last = self.stack.pop()
if hasattr(self.current, "addChild"):
if hasattr(self.current, "addChild") and callable(self.current.addChild):
if self.current.addChild(name, last):
return
raise TypeError(name + " tag not expected in " + self.display_stack())

View file

@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
import types
logger = logging.getLogger("nemubot.treatment")
@ -108,6 +109,9 @@ class MessageTreater:
msg -- message to treat
"""
if hasattr(msg, "frm_owner"):
msg.frm_owner = (not hasattr(msg.server, "owner") or msg.server.owner == msg.frm)
while hook is not None:
res = hook.run(msg)
@ -116,10 +120,18 @@ class MessageTreater:
yield r
elif res is not None:
if not hasattr(res, "server") or res.server is None:
res.server = msg.server
if isinstance(res, types.GeneratorType):
for r in res:
if not hasattr(r, "server") or r.server is None:
r.server = msg.server
yield res
yield r
else:
if not hasattr(res, "server") or res.server is None:
res.server = msg.server
yield res
hook = next(hook_gen, None)

View file

@ -63,6 +63,7 @@ setup(
'nemubot',
'nemubot.config',
'nemubot.datastore',
'nemubot.datastore.nodes',
'nemubot.event',
'nemubot.exception',
'nemubot.hooks',