diff --git a/modules/cve.py b/modules/cve.py
index 23a0302..fa09de5 100644
--- a/modules/cve.py
+++ b/modules/cve.py
@@ -10,7 +10,7 @@ from nemubot.tools.web import getURLContent, striphtml
from more import Response
-BASEURL_NIST = 'https://web.nvd.nist.gov/view/vuln/detail?vulnId='
+BASEURL_NIST = 'https://nvd.nist.gov/vuln/detail/'
# MODULE CORE #########################################################
@@ -19,15 +19,40 @@ def get_cve(cve_id):
search_url = BASEURL_NIST + quote(cve_id.upper())
soup = BeautifulSoup(getURLContent(search_url))
- vuln = soup.body.find(class_="vuln-detail")
- cvss = vuln.findAll('div')[4]
- return [
- "Base score: " + cvss.findAll('div')[0].findAll('a')[0].text.strip(),
- vuln.findAll('p')[0].text, # description
- striphtml(vuln.findAll('div')[0].text).strip(), # publication date
- striphtml(vuln.findAll('div')[1].text).strip(), # last revised
- ]
+ return {
+ "description": soup.body.find(attrs={"data-testid":"vuln-description"}).text.strip(),
+ "published": soup.body.find(attrs={"data-testid":"vuln-published-on"}).text.strip(),
+ "last_modified": soup.body.find(attrs={"data-testid":"vuln-last-modified-on"}).text.strip(),
+ "source": soup.body.find(attrs={"data-testid":"vuln-source"}).text.strip(),
+
+ "base_score": float(soup.body.find(attrs={"data-testid":"vuln-cvssv3-base-score-link"}).text.strip()),
+ "severity": soup.body.find(attrs={"data-testid":"vuln-cvssv3-base-score-severity"}).text.strip(),
+ "impact_score": float(soup.body.find(attrs={"data-testid":"vuln-cvssv3-impact-score"}).text.strip()),
+ "exploitability_score": float(soup.body.find(attrs={"data-testid":"vuln-cvssv3-exploitability-score"}).text.strip()),
+
+ "av": soup.body.find(attrs={"data-testid":"vuln-cvssv3-av"}).text.strip(),
+ "ac": soup.body.find(attrs={"data-testid":"vuln-cvssv3-ac"}).text.strip(),
+ "pr": soup.body.find(attrs={"data-testid":"vuln-cvssv3-pr"}).text.strip(),
+ "ui": soup.body.find(attrs={"data-testid":"vuln-cvssv3-ui"}).text.strip(),
+ "s": soup.body.find(attrs={"data-testid":"vuln-cvssv3-s"}).text.strip(),
+ "c": soup.body.find(attrs={"data-testid":"vuln-cvssv3-c"}).text.strip(),
+ "i": soup.body.find(attrs={"data-testid":"vuln-cvssv3-i"}).text.strip(),
+ "a": soup.body.find(attrs={"data-testid":"vuln-cvssv3-a"}).text.strip(),
+ }
+
+
+def display_metrics(av, ac, pr, ui, s, c, i, a, **kwargs):
+ ret = []
+ if av != "None": ret.append("Attack Vector: \x02%s\x0F" % av)
+ if ac != "None": ret.append("Attack Complexity: \x02%s\x0F" % ac)
+ if pr != "None": ret.append("Privileges Required: \x02%s\x0F" % pr)
+ if ui != "None": ret.append("User Interaction: \x02%s\x0F" % ui)
+ if s != "Unchanged": ret.append("Scope: \x02%s\x0F" % s)
+ if c != "None": ret.append("Confidentiality: \x02%s\x0F" % c)
+ if i != "None": ret.append("Integrity: \x02%s\x0F" % i)
+ if a != "None": ret.append("Availability: \x02%s\x0F" % a)
+ return ', '.join(ret)
# MODULE INTERFACE ####################################################
@@ -42,6 +67,10 @@ def get_cve_desc(msg):
if cve_id[:3].lower() != 'cve':
cve_id = 'cve-' + cve_id
- res.append_message(get_cve(cve_id))
+ cve = get_cve(cve_id)
+ metrics = display_metrics(**cve)
+ res.append_message("{cveid}: Base score: \x02{base_score} {severity}\x0F (impact: \x02{impact_score}\x0F, exploitability: \x02{exploitability_score}\x0F; {metrics}), from \x02{source}\x0F, last modified on \x02{last_modified}\x0F. {description}".format(cveid=cve_id, metrics=metrics, **cve))
return res
+
+print(get_cve("CVE-2017-11108"))
diff --git a/modules/disas.py b/modules/disas.py
new file mode 100644
index 0000000..669ccc1
--- /dev/null
+++ b/modules/disas.py
@@ -0,0 +1,85 @@
+"""The Ultimate Disassembler Module"""
+
+# PYTHON STUFFS #######################################################
+
+import capstone
+
+from nemubot.exception import IMException
+from nemubot.hooks import hook
+
+from more import Response
+
+
+# MODULE CORE #########################################################
+
+ARCHITECTURES = {
+ "arm": capstone.CS_ARCH_ARM,
+ "arm64": capstone.CS_ARCH_ARM64,
+ "mips": capstone.CS_ARCH_MIPS,
+ "ppc": capstone.CS_ARCH_PPC,
+ "sparc": capstone.CS_ARCH_SPARC,
+ "sysz": capstone.CS_ARCH_SYSZ,
+ "x86": capstone.CS_ARCH_X86,
+ "xcore": capstone.CS_ARCH_XCORE,
+}
+
+MODES = {
+ "arm": capstone.CS_MODE_ARM,
+ "thumb": capstone.CS_MODE_THUMB,
+ "mips32": capstone.CS_MODE_MIPS32,
+ "mips64": capstone.CS_MODE_MIPS64,
+ "mips32r6": capstone.CS_MODE_MIPS32R6,
+ "16": capstone.CS_MODE_16,
+ "32": capstone.CS_MODE_32,
+ "64": capstone.CS_MODE_64,
+ "le": capstone.CS_MODE_LITTLE_ENDIAN,
+ "be": capstone.CS_MODE_BIG_ENDIAN,
+ "micro": capstone.CS_MODE_MICRO,
+ "mclass": capstone.CS_MODE_MCLASS,
+ "v8": capstone.CS_MODE_V8,
+ "v9": capstone.CS_MODE_V9,
+}
+
+# MODULE INTERFACE ####################################################
+
+@hook.command("disas",
+ help="Display assembly code",
+ help_usage={"CODE": "Display assembly code corresponding to the given CODE"},
+ keywords={
+ "arch=ARCH": "Specify the architecture of the code to disassemble (default: x86, choose between: %s)" % ', '.join(ARCHITECTURES.keys()),
+ "modes=MODE[,MODE]": "Specify hardware mode of the code to disassemble (default: 32, between: %s)" % ', '.join(MODES.keys()),
+ })
+def cmd_disas(msg):
+ if not len(msg.args):
+ raise IMException("please give me some code")
+
+ # Determine the architecture
+ if "architecture" in msg.kwargs:
+ if msg.kwargs["architecture"] not in ARCHITECTURES:
+ raise IMException("unknown architectures '%s'" % msg.kwargs["architecture"])
+ architecture = ARCHITECTURES[msg.kwargs["architecture"]]
+ else:
+ architecture = capstone.CS_ARCH_X86
+
+ # Determine hardware modes
+ if "modes" in msg.kwargs:
+ modes = 0
+ for mode in msg.kwargs["modes"].split(','):
+ if mode not in MODES:
+ raise IMException("unknown mode '%s'" % mode)
+ modes += MODES[mode]
+ else:
+ modes = capstone.CS_MODE_32
+
+ # Get the code
+ code = bytearray.fromhex(''.join([a.replace("0x", "") for a in msg.args]))
+
+ # Setup capstone
+ md = capstone.Cs(architecture, modes)
+
+ res = Response(channel=msg.channel, nomore="No more instruction")
+
+ for isn in md.disasm(code, 0x1000):
+ res.append_message("%s %s" %(isn.mnemonic, isn.op_str), title="0x%x" % isn.address)
+
+ return res
diff --git a/modules/freetarifs.py b/modules/freetarifs.py
new file mode 100644
index 0000000..b96a30f
--- /dev/null
+++ b/modules/freetarifs.py
@@ -0,0 +1,64 @@
+"""Inform about Free Mobile tarifs"""
+
+# PYTHON STUFFS #######################################################
+
+import urllib.parse
+from bs4 import BeautifulSoup
+
+from nemubot.exception import IMException
+from nemubot.hooks import hook
+from nemubot.tools import web
+
+from more import Response
+
+
+# MODULE CORE #########################################################
+
+ACT = {
+ "ff_toFixe": "Appel vers les fixes",
+ "ff_toMobile": "Appel vers les mobiles",
+ "ff_smsSendedToCountry": "SMS vers le pays",
+ "ff_mmsSendedToCountry": "MMS vers le pays",
+ "fc_callToFrance": "Appel vers la France",
+ "fc_smsToFrance": "SMS vers la france",
+ "fc_mmsSended": "MMS vers la france",
+ "fc_callToSameCountry": "Réception des appels",
+ "fc_callReceived": "Appel dans le pays",
+ "fc_smsReceived": "SMS (Réception)",
+ "fc_mmsReceived": "MMS (Réception)",
+ "fc_moDataFromCountry": "Data",
+}
+
+def get_land_tarif(country, forfait="pkgFREE"):
+ url = "http://mobile.international.free.fr/?" + urllib.parse.urlencode({'pays': country})
+ page = web.getURLContent(url)
+ soup = BeautifulSoup(page)
+
+ fact = soup.find(class_=forfait)
+
+ if fact is None:
+ raise IMException("Country or forfait not found.")
+
+ res = {}
+ for s in ACT.keys():
+ try:
+ res[s] = fact.find(attrs={"data-bind": "text: " + s}).text + " " + fact.find(attrs={"data-bind": "html: " + s + "Unit"}).text
+ except AttributeError:
+ res[s] = "inclus"
+
+ return res
+
+@hook.command("freetarifs",
+ help="Show Free Mobile tarifs for given contries",
+ help_usage={"COUNTRY": "Show Free Mobile tarifs for given CONTRY"},
+ keywords={
+ "forfait=FORFAIT": "Related forfait between Free (default) and 2euro"
+ })
+def get_freetarif(msg):
+ res = Response(channel=msg.channel)
+
+ for country in msg.args:
+ t = get_land_tarif(country.lower().capitalize(), "pkg" + (msg.kwargs["forfait"] if "forfait" in msg.kwargs else "FREE").upper())
+ res.append_message(["\x02%s\x0F : %s" % (ACT[k], t[k]) for k in sorted(ACT.keys(), reverse=True)], title=country)
+
+ return res
diff --git a/modules/openroute.py b/modules/openroute.py
new file mode 100644
index 0000000..440b05a
--- /dev/null
+++ b/modules/openroute.py
@@ -0,0 +1,158 @@
+"""Lost? use our commands to find your way!"""
+
+# PYTHON STUFFS #######################################################
+
+import re
+import urllib.parse
+
+from nemubot.exception import IMException
+from nemubot.hooks import hook
+from nemubot.tools import web
+
+from more import Response
+
+# GLOBALS #############################################################
+
+URL_DIRECTIONS_API = "https://api.openrouteservice.org/directions?api_key=%s&"
+URL_GEOCODE_API = "https://api.openrouteservice.org/geocoding?api_key=%s&"
+
+waytype = [
+ "unknown",
+ "state road",
+ "road",
+ "street",
+ "path",
+ "track",
+ "cycleway",
+ "footway",
+ "steps",
+ "ferry",
+ "construction",
+]
+
+
+# LOADING #############################################################
+
+def load(context):
+ if not context.config or "apikey" not in context.config:
+ raise ImportError("You need an OpenRouteService API key in order to use this "
+ "module. Add it to the module configuration file:\n"
+ "\nRegister at https://developers.openrouteservice.org")
+ global URL_DIRECTIONS_API
+ URL_DIRECTIONS_API = URL_DIRECTIONS_API % context.config["apikey"]
+ global URL_GEOCODE_API
+ URL_GEOCODE_API = URL_GEOCODE_API % context.config["apikey"]
+
+
+# MODULE CORE #########################################################
+
+def approx_distance(lng):
+ if lng > 1111:
+ return "%f km" % (lng / 1000)
+ else:
+ return "%f m" % lng
+
+
+def approx_duration(sec):
+ days = int(sec / 86400)
+ if days > 0:
+ return "%d days %f hours" % (days, (sec % 86400) / 3600)
+ hours = int((sec % 86400) / 3600)
+ if hours > 0:
+ return "%d hours %f minutes" % (hours, (sec % 3600) / 60)
+ minutes = (sec % 3600) / 60
+ if minutes > 0:
+ return "%d minutes" % minutes
+ else:
+ return "%d seconds" % sec
+
+
+def geocode(query, limit=7):
+ obj = web.getJSON(URL_GEOCODE_API + urllib.parse.urlencode({
+ 'query': query,
+ 'limit': limit,
+ }))
+
+ for f in obj["features"]:
+ yield f["geometry"]["coordinates"], f["properties"]
+
+
+def firstgeocode(query):
+ for g in geocode(query, limit=1):
+ return g
+
+
+def where(loc):
+ return "{name} {city} {state} {county} {country}".format(**loc)
+
+
+def directions(coordinates, **kwargs):
+ kwargs['coordinates'] = '|'.join(coordinates)
+
+ print(URL_DIRECTIONS_API + urllib.parse.urlencode(kwargs))
+ return web.getJSON(URL_DIRECTIONS_API + urllib.parse.urlencode(kwargs), decode_error=True)
+
+
+# MODULE INTERFACE ####################################################
+
+@hook.command("geocode",
+ help="Get GPS coordinates of a place",
+ help_usage={
+ "PLACE": "Get GPS coordinates of PLACE"
+ })
+def cmd_geocode(msg):
+ res = Response(channel=msg.channel, nick=msg.frm,
+ nomore="No more geocode", count=" (%s more geocode)")
+
+ for loc in geocode(' '.join(msg.args)):
+ res.append_message("%s is at %s,%s" % (
+ where(loc[1]),
+ loc[0][1], loc[0][0],
+ ))
+
+ return res
+
+
+@hook.command("directions",
+ help="Get routing instructions",
+ help_usage={
+ "POINT1 POINT2 ...": "Get routing instructions to go from POINT1 to the last POINTX via intermediates POINTX"
+ },
+ keywords={
+ "profile=PROF": "One of driving-car, driving-hgv, cycling-regular, cycling-road, cycling-safe, cycling-mountain, cycling-tour, cycling-electric, foot-walking, foot-hiking, wheelchair. Default: foot-walking",
+ "preference=PREF": "One of fastest, shortest, recommended. Default: recommended",
+ "lang=LANG": "default: en",
+ })
+def cmd_directions(msg):
+ drcts = directions(["{0},{1}".format(*firstgeocode(g)[0]) for g in msg.args],
+ profile=msg.kwargs["profile"] if "profile" in msg.kwargs else "foot-walking",
+ preference=msg.kwargs["preference"] if "preference" in msg.kwargs else "recommended",
+ units="m",
+ language=msg.kwargs["lang"] if "lang" in msg.kwargs else "en",
+ geometry=False,
+ instructions=True,
+ instruction_format="text")
+ if "error" in drcts and "message" in drcts["error"] and drcts["error"]["message"]:
+ raise IMException(drcts["error"]["message"])
+
+ if "routes" not in drcts or not drcts["routes"]:
+ raise IMException("No route available for this trip")
+
+ myway = drcts["routes"][0]
+ myway["summary"]["strduration"] = approx_duration(myway["summary"]["duration"])
+ myway["summary"]["strdistance"] = approx_distance(myway["summary"]["distance"])
+ res = Response("Trip summary: {strdistance} in approximate {strduration}; elevation +{ascent} m -{descent} m".format(**myway["summary"]), channel=msg.channel, count=" (%d more steps)", nomore="You have arrived!")
+
+ def formatSegments(segments):
+ for segment in segments:
+ for step in segment["steps"]:
+ step["strtype"] = waytype[step["type"]]
+ step["strduration"] = approx_duration(step["duration"])
+ step["strdistance"] = approx_distance(step["distance"])
+ yield "{instruction} for {strdistance} on {strtype} (approximate time: {strduration})".format(**step)
+
+ if "segments" in myway:
+ res.append_message([m for m in formatSegments(myway["segments"])])
+
+ return res
diff --git a/modules/pkgs.py b/modules/pkgs.py
new file mode 100644
index 0000000..5a7b0a9
--- /dev/null
+++ b/modules/pkgs.py
@@ -0,0 +1,68 @@
+"""Get information about common software"""
+
+# PYTHON STUFFS #######################################################
+
+import portage
+
+from nemubot import context
+from nemubot.exception import IMException
+from nemubot.hooks import hook
+
+from more import Response
+
+DB = None
+
+# MODULE CORE #########################################################
+
+def get_db():
+ global DB
+ if DB is None:
+ DB = portage.db[portage.root]["porttree"].dbapi
+ return DB
+
+
+def package_info(pkgname):
+ pv = get_db().xmatch("match-all", pkgname)
+ if not pv:
+ raise IMException("No package named '%s' found" % pkgname)
+
+ bv = get_db().xmatch("bestmatch-visible", pkgname)
+ pvsplit = portage.catpkgsplit(bv if bv else pv[-1])
+ info = get_db().aux_get(bv if bv else pv[-1], ["DESCRIPTION", "HOMEPAGE", "LICENSE", "IUSE", "KEYWORDS"])
+
+ return {
+ "pkgname": '/'.join(pvsplit[:2]),
+ "category": pvsplit[0],
+ "shortname": pvsplit[1],
+ "lastvers": '-'.join(pvsplit[2:]) if pvsplit[3] != "r0" else pvsplit[2],
+ "othersvers": ['-'.join(portage.catpkgsplit(p)[2:]) for p in pv if p != bv],
+ "description": info[0],
+ "homepage": info[1],
+ "license": info[2],
+ "uses": info[3],
+ "keywords": info[4],
+ }
+
+
+# MODULE INTERFACE ####################################################
+
+@hook.command("eix",
+ help="Get information about a package",
+ help_usage={
+ "NAME": "Get information about a software NAME"
+ })
+def cmd_eix(msg):
+ if not len(msg.args):
+ raise IMException("please give me a package to search")
+
+ def srch(term):
+ try:
+ yield package_info(term)
+ except portage.exception.AmbiguousPackageName as e:
+ for i in e.args[0]:
+ yield package_info(i)
+
+ res = Response(channel=msg.channel, count=" (%d more packages)", nomore="No more package '%s'" % msg.args[0])
+ for pi in srch(msg.args[0]):
+ res.append_message("\x03\x02{pkgname}:\x03\x02 {description} - {homepage} - {license} - last revisions: \x03\x02{lastvers}\x03\x02{ov}".format(ov=(", " + ', '.join(pi["othersvers"])) if pi["othersvers"] else "", **pi))
+ return res
diff --git a/modules/suivi.py b/modules/suivi.py
index a6f6ab4..6ad13e9 100644
--- a/modules/suivi.py
+++ b/modules/suivi.py
@@ -126,6 +126,24 @@ def get_postnl_info(postnl_id):
return (post_status.lower(), post_destination, post_date)
+def get_usps_info(usps_id):
+ usps_parcelurl = "https://tools.usps.com/go/TrackConfirmAction_input?" + urllib.parse.urlencode({'qtc_tLabels1': usps_id})
+
+ usps_data = getURLContent(usps_parcelurl)
+ soup = BeautifulSoup(usps_data)
+ if (soup.find(class_="tracking_history")
+ and soup.find(class_="tracking_history").find(class_="row_notification")
+ and soup.find(class_="tracking_history").find(class_="row_top").find_all("td")):
+ notification = soup.find(class_="tracking_history").find(class_="row_notification").text.strip()
+ date = re.sub(r"\s+", " ", soup.find(class_="tracking_history").find(class_="row_top").find_all("td")[0].text.strip())
+ status = soup.find(class_="tracking_history").find(class_="row_top").find_all("td")[1].text.strip()
+ last_location = soup.find(class_="tracking_history").find(class_="row_top").find_all("td")[2].text.strip()
+
+ print(notification)
+
+ return (notification, date, status, last_location)
+
+
def get_fedex_info(fedex_id, lang="en_US"):
data = urllib.parse.urlencode({
'data': json.dumps({
@@ -156,11 +174,22 @@ def get_fedex_info(fedex_id, lang="en_US"):
if ("TrackPackagesResponse" in fedex_data and
"packageList" in fedex_data["TrackPackagesResponse"] and
- len(fedex_data["TrackPackagesResponse"]["packageList"])
+ len(fedex_data["TrackPackagesResponse"]["packageList"]) and
+ not fedex_data["TrackPackagesResponse"]["errorList"][0]["code"] and
+ not fedex_data["TrackPackagesResponse"]["packageList"][0]["errorList"][0]["code"]
):
return fedex_data["TrackPackagesResponse"]["packageList"][0]
+def get_dhl_info(dhl_id, lang="en"):
+ dhl_parcelurl = "http://www.dhl.com/shipmentTracking?" + urllib.parse.urlencode({'AWB': dhl_id})
+
+ dhl_data = getJSON(dhl_parcelurl)
+
+ if "results" in dhl_data and dhl_data["results"]:
+ return dhl_data["results"][0]
+
+
# TRACKING HANDLERS ###################################################
def handle_tnt(tracknum):
@@ -195,6 +224,13 @@ def handle_postnl(tracknum):
")." % (tracknum, post_status, post_destination, post_date))
+def handle_usps(tracknum):
+ info = get_usps_info(tracknum)
+ if info:
+ notif, last_date, last_status, last_location = info
+ return ("USPS \x02{tracknum}\x0F is {last_status} in \x02{last_location}\x0F as of {last_date}: {notif}".format(tracknum=tracknum, notif=notif, last_date=last_date, last_status=last_status.lower(), last_location=last_location))
+
+
def handle_colissimo(tracknum):
info = get_colissimo_info(tracknum)
if info:
@@ -229,6 +265,12 @@ def handle_fedex(tracknum):
return ("{trackingCarrierDesc}: \x02{statusWithDetails}\x0F: in \x02{statusLocationCity}, {statusLocationCntryCD}\x0F, estimated delivery: {displayEstDeliveryDateTime}.".format(**info))
+def handle_dhl(tracknum):
+ info = get_dhl_info(tracknum)
+ if info:
+ return "DHL {label} {id}: \x02{description}\x0F".format(**info)
+
+
TRACKING_HANDLERS = {
'laposte': handle_laposte,
'postnl': handle_postnl,
@@ -237,6 +279,8 @@ TRACKING_HANDLERS = {
'coliprive': handle_coliprive,
'tnt': handle_tnt,
'fedex': handle_fedex,
+ 'dhl': handle_dhl,
+ 'usps': handle_usps,
}
diff --git a/nemubot/bot.py b/nemubot/bot.py
index b0d3915..aa1cb3e 100644
--- a/nemubot/bot.py
+++ b/nemubot/bot.py
@@ -20,6 +20,7 @@ from multiprocessing import JoinableQueue
import threading
import select
import sys
+import weakref
from nemubot import __version__
from nemubot.consumer import Consumer, EventConsumer, MessageConsumer
@@ -99,15 +100,15 @@ class Bot(threading.Thread):
from more import Response
res = Response(channel=msg.to_response)
if len(msg.args) >= 1:
- if msg.args[0] in self.modules:
- if hasattr(self.modules[msg.args[0]], "help_full"):
- hlp = self.modules[msg.args[0]].help_full()
+ if msg.args[0] in self.modules and self.modules[msg.args[0]]() is not None:
+ if hasattr(self.modules[msg.args[0]](), "help_full"):
+ hlp = self.modules[msg.args[0]]().help_full()
if isinstance(hlp, Response):
return hlp
else:
res.append_message(hlp)
else:
- res.append_message([str(h) for s,h in self.modules[msg.args[0]].__nemubot_context__.hooks], title="Available commands for module " + msg.args[0])
+ res.append_message([str(h) for s,h in self.modules[msg.args[0]]().__nemubot_context__.hooks], title="Available commands for module " + msg.args[0])
elif msg.args[0][0] == "!":
from nemubot.message.command import Command
for h in self.treater._in_hooks(Command(msg.args[0][1:])):
@@ -137,7 +138,7 @@ class Bot(threading.Thread):
res.append_message(title="Pour plus de détails sur un module, "
"envoyez \"!help nomdumodule\". Voici la liste"
" de tous les modules disponibles localement",
- message=["\x03\x02%s\x03\x02 (%s)" % (im, self.modules[im].__doc__) for im in self.modules if self.modules[im].__doc__])
+ message=["\x03\x02%s\x03\x02 (%s)" % (im, self.modules[im]().__doc__) for im in self.modules if self.modules[im]() is not None and self.modules[im]().__doc__])
return res
self.treater.hm.add_hook(nemubot.hooks.Command(_help_msg, "help"), "in", "Command")
@@ -518,18 +519,20 @@ class Bot(threading.Thread):
raise
# Save a reference to the module
- self.modules[module_name] = module
+ self.modules[module_name] = weakref.ref(module)
+ logger.info("Module '%s' successfully loaded.", module_name)
def unload_module(self, name):
"""Unload a module"""
- if name in self.modules:
- self.modules[name].print("Unloading module %s" % name)
+ if name in self.modules and self.modules[name]() is not None:
+ module = self.modules[name]()
+ module.print("Unloading module %s" % name)
# Call the user defined unload method
- if hasattr(self.modules[name], "unload"):
- self.modules[name].unload(self)
- self.modules[name].__nemubot_context__.unload()
+ if hasattr(module, "unload"):
+ module.unload(self)
+ module.__nemubot_context__.unload()
# Remove from the nemubot dict
del self.modules[name]
@@ -566,7 +569,7 @@ class Bot(threading.Thread):
self.event_timer.cancel()
logger.info("Save and unload all modules...")
- for mod in self.modules.items():
+ for mod in [m for m in self.modules.keys()]:
self.unload_module(mod)
logger.info("Close all servers connection...")
diff --git a/nemubot/datastore/xml.py b/nemubot/datastore/xml.py
index 46dca70..025c0c5 100644
--- a/nemubot/datastore/xml.py
+++ b/nemubot/datastore/xml.py
@@ -143,4 +143,15 @@ class XML(Abstract):
if self.rotate:
self._rotate(path)
- return data.save(path)
+ import tempfile
+ _, tmpath = tempfile.mkstemp()
+ with open(tmpath, "w") as f:
+ import xml.sax.saxutils
+ gen = xml.sax.saxutils.XMLGenerator(f, "utf-8")
+ gen.startDocument()
+ data.saveElement(gen)
+ gen.endDocument()
+
+ # Atomic save
+ import shutil
+ shutil.move(tmpath, path)
diff --git a/nemubot/modulecontext.py b/nemubot/modulecontext.py
index 877b8de..bfb1938 100644
--- a/nemubot/modulecontext.py
+++ b/nemubot/modulecontext.py
@@ -14,6 +14,9 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from nemubot.hooks import Abstract as AbstractHook
+from nemubot.server.abstract import AbstractServer
+
class _ModuleContext:
def __init__(self, module=None):
@@ -37,12 +40,10 @@ class _ModuleContext:
return module_state.ModuleState("nemubotstate")
def add_hook(self, hook, *triggers):
- from nemubot.hooks import Abstract as AbstractHook
assert isinstance(hook, AbstractHook), hook
self.hooks.append((triggers, hook))
def del_hook(self, hook, *triggers):
- from nemubot.hooks import Abstract as AbstractHook
assert isinstance(hook, AbstractHook), hook
self.hooks.remove((triggers, hook))
@@ -67,7 +68,9 @@ class _ModuleContext:
self.context.datastore.save(self.module_name, self.data)
def subparse(self, orig, cnt):
- if orig.server in self.context.servers:
+ if isinstance(orig.server, AbstractServer):
+ return orig.server.subparse(orig, cnt)
+ elif orig.server in self.context.servers:
return self.context.servers[orig.server].subparse(orig, cnt)
@property
@@ -115,13 +118,11 @@ class ModuleContext(_ModuleContext):
return self.context.datastore.load(self.module_name)
def add_hook(self, hook, *triggers):
- from nemubot.hooks import Abstract as AbstractHook
assert isinstance(hook, AbstractHook), hook
self.hooks.append((triggers, hook))
return self.context.treater.hm.add_hook(hook, *triggers)
def del_hook(self, hook, *triggers):
- from nemubot.hooks import Abstract as AbstractHook
assert isinstance(hook, AbstractHook), hook
self.hooks.remove((triggers, hook))
return self.context.treater.hm.del_hooks(*triggers, hook=hook)
@@ -136,7 +137,9 @@ class ModuleContext(_ModuleContext):
return self.context.del_event(evt, module_src=self.module)
def send_response(self, server, res):
- if server in self.context.servers:
+ if isinstance(server, AbstractServer):
+ server.send_response(res)
+ elif server in self.context.servers:
if res.server is not None:
return self.context.servers[res.server].send_response(res)
else:
diff --git a/nemubot/server/message/IRC.py b/nemubot/server/message/IRC.py
index 5ccd735..48ef9a4 100644
--- a/nemubot/server/message/IRC.py
+++ b/nemubot/server/message/IRC.py
@@ -146,7 +146,7 @@ class IRC(Abstract):
receivers = self.decode(self.params[0]).split(',')
common_args = {
- "server": srv.name,
+ "server": srv,
"date": self.tags["time"],
"to": receivers,
"to_response": [r if r != srv.nick else self.nick for r in receivers],
diff --git a/nemubot/tools/test_xmlparser.py b/nemubot/tools/test_xmlparser.py
index d7f5a9a..0feda73 100644
--- a/nemubot/tools/test_xmlparser.py
+++ b/nemubot/tools/test_xmlparser.py
@@ -1,5 +1,6 @@
import unittest
+import io
import xml.parsers.expat
from nemubot.tools.xmlparser import XMLParser
@@ -12,6 +13,11 @@ class StringNode():
def characters(self, content):
self.string += content
+ def saveElement(self, store, tag="string"):
+ store.startElement(tag, {})
+ store.characters(self.string)
+ store.endElement(tag)
+
class TestNode():
def __init__(self, option=None):
@@ -22,6 +28,15 @@ class TestNode():
self.mystr = child.string
return True
+ def saveElement(self, store, tag="test"):
+ store.startElement(tag, {"option": self.option})
+
+ strNode = StringNode()
+ strNode.string = self.mystr
+ strNode.saveElement(store)
+
+ store.endElement(tag)
+
class Test2Node():
def __init__(self, option=None):
@@ -33,6 +48,15 @@ class Test2Node():
self.mystrs.append(attrs["value"])
return True
+ def saveElement(self, store, tag="test"):
+ store.startElement(tag, {"option": self.option} if self.option is not None else {})
+
+ for mystr in self.mystrs:
+ store.startElement("string", {"value": mystr})
+ store.endElement("string")
+
+ store.endElement(tag)
+
class TestXMLParser(unittest.TestCase):
@@ -44,9 +68,11 @@ class TestXMLParser(unittest.TestCase):
p.CharacterDataHandler = mod.characters
p.EndElementHandler = mod.endElement
- p.Parse("toto", 1)
+ inputstr = "toto"
+ p.Parse(inputstr, 1)
self.assertEqual(mod.root.string, "toto")
+ self.assertEqual(mod.saveDocument(header=False).getvalue(), inputstr)
def test_parser2(self):
@@ -57,10 +83,12 @@ class TestXMLParser(unittest.TestCase):
p.CharacterDataHandler = mod.characters
p.EndElementHandler = mod.endElement
- p.Parse("toto", 1)
+ inputstr = 'toto'
+ p.Parse(inputstr, 1)
self.assertEqual(mod.root.option, "123")
self.assertEqual(mod.root.mystr, "toto")
+ self.assertEqual(mod.saveDocument(header=False).getvalue(), inputstr)
def test_parser3(self):
@@ -71,12 +99,14 @@ class TestXMLParser(unittest.TestCase):
p.CharacterDataHandler = mod.characters
p.EndElementHandler = mod.endElement
- p.Parse("", 1)
+ inputstr = ''
+ p.Parse(inputstr, 1)
self.assertEqual(mod.root.option, None)
self.assertEqual(len(mod.root.mystrs), 2)
self.assertEqual(mod.root.mystrs[0], "toto")
self.assertEqual(mod.root.mystrs[1], "toto2")
+ self.assertEqual(mod.saveDocument(header=False, short_empty_elements=True).getvalue(), inputstr)
if __name__ == '__main__':
diff --git a/nemubot/tools/web.py b/nemubot/tools/web.py
index 0852664..0394aac 100644
--- a/nemubot/tools/web.py
+++ b/nemubot/tools/web.py
@@ -15,6 +15,7 @@
# along with this program. If not, see .
from urllib.parse import urljoin, urlparse, urlsplit, urlunsplit
+import socket
from nemubot.exception import IMException
@@ -67,13 +68,14 @@ def getPassword(url):
# Get real pages
-def getURLContent(url, body=None, timeout=7, header=None):
+def getURLContent(url, body=None, timeout=7, header=None, decode_error=False):
"""Return page content corresponding to URL or None if any error occurs
Arguments:
url -- the URL to get
body -- Data to send as POST content
timeout -- maximum number of seconds to wait before returning an exception
+ decode_error -- raise exception on non-200 pages or ignore it
"""
o = urlparse(_getNormalizedURL(url), "http")
@@ -123,6 +125,8 @@ def getURLContent(url, body=None, timeout=7, header=None):
o.path,
body,
header)
+ except socket.timeout as e:
+ raise IMException(e)
except OSError as e:
raise IMException(e.strerror)
@@ -163,7 +167,10 @@ def getURLContent(url, body=None, timeout=7, header=None):
urljoin(url, res.getheader("Location")),
body=body,
timeout=timeout,
- header=header)
+ header=header,
+ decode_error=decode_error)
+ elif decode_error:
+ return data.decode(charset).strip()
else:
raise IMException("A HTTP error occurs: %d - %s" %
(res.status, http.client.responses[res.status]))
diff --git a/nemubot/tools/xmlparser/__init__.py b/nemubot/tools/xmlparser/__init__.py
index abc5bb9..c8d393a 100644
--- a/nemubot/tools/xmlparser/__init__.py
+++ b/nemubot/tools/xmlparser/__init__.py
@@ -134,6 +134,21 @@ class XMLParser:
return
raise TypeError(name + " tag not expected in " + self.display_stack())
+ def saveDocument(self, f=None, header=True, short_empty_elements=False):
+ if f is None:
+ import io
+ f = io.StringIO()
+
+ import xml.sax.saxutils
+ gen = xml.sax.saxutils.XMLGenerator(f, "utf-8", short_empty_elements=short_empty_elements)
+ if header:
+ gen.startDocument()
+ self.root.saveElement(gen)
+ if header:
+ gen.endDocument()
+
+ return f
+
def parse_file(filename):
p = xml.parsers.expat.ParserCreate()
diff --git a/nemubot/tools/xmlparser/basic.py b/nemubot/tools/xmlparser/basic.py
index 8456629..f2d9fd5 100644
--- a/nemubot/tools/xmlparser/basic.py
+++ b/nemubot/tools/xmlparser/basic.py
@@ -44,6 +44,13 @@ class ListNode:
return self.items.__repr__()
+ def saveElement(self, store, tag="list"):
+ store.startElement(tag, {})
+ for i in self.items:
+ i.saveElement(store)
+ store.endElement(tag)
+
+
class DictNode:
"""XML node representing a Python dictionnnary
@@ -106,3 +113,10 @@ class DictNode:
def __repr__(self):
return self.items.__repr__()
+
+
+ def saveElement(self, store, tag="dict"):
+ store.startElement(tag, {})
+ for k, v in self.items.items():
+ v.saveElement(store)
+ store.endElement(tag)
diff --git a/nemubot/tools/xmlparser/genericnode.py b/nemubot/tools/xmlparser/genericnode.py
index 9c29a23..425934c 100644
--- a/nemubot/tools/xmlparser/genericnode.py
+++ b/nemubot/tools/xmlparser/genericnode.py
@@ -53,6 +53,14 @@ class ParsingNode:
return item in self.attrs
+ def saveElement(self, store, tag=None):
+ store.startElement(tag if tag is not None else self.tag, self.attrs)
+ for child in self.children:
+ child.saveElement(store)
+ store.characters(self.content)
+ store.endElement(tag if tag is not None else self.tag)
+
+
class GenericNode(ParsingNode):
"""Consider all subtags as dictionnary
diff --git a/nemubot/tools/xmlparser/node.py b/nemubot/tools/xmlparser/node.py
index 965a475..7df255e 100644
--- a/nemubot/tools/xmlparser/node.py
+++ b/nemubot/tools/xmlparser/node.py
@@ -196,7 +196,7 @@ class ModuleState:
if self.index_fieldname is not None:
self.setIndex(self.index_fieldname, self.index_tagname)
- def save_node(self, gen):
+ def saveElement(self, gen):
"""Serialize this node as a XML node"""
from datetime import datetime
attribs = {}
@@ -215,29 +215,9 @@ class ModuleState:
gen.startElement(self.name, attrs)
for child in self.childs:
- child.save_node(gen)
+ child.saveElement(gen)
gen.endElement(self.name)
except:
logger.exception("Error occured when saving the following "
"XML node: %s with %s", self.name, attrs)
-
- def save(self, filename):
- """Save the current node as root node in a XML file
-
- Argument:
- filename -- location of the file to create/erase
- """
-
- import tempfile
- _, tmpath = tempfile.mkstemp()
- with open(tmpath, "w") as f:
- import xml.sax.saxutils
- gen = xml.sax.saxutils.XMLGenerator(f, "utf-8")
- gen.startDocument()
- self.save_node(gen)
- gen.endDocument()
-
- # Atomic save
- import shutil
- shutil.move(tmpath, filename)