diff --git a/bot_sample.xml b/bot_sample.xml index ce821d2..ed1a41f 100644 --- a/bot_sample.xml +++ b/bot_sample.xml @@ -1,11 +1,11 @@ - + diff --git a/modules/books.py b/modules/books.py index f532a3b..4a4d5aa 100644 --- a/modules/books.py +++ b/modules/books.py @@ -15,7 +15,7 @@ from more import Response # LOADING ############################################################# def load(context): - if not context.config or not context.config.getAttribute("goodreadskey"): + if not context.config or "goodreadskey" not in context.config: raise ImportError("You need a Goodreads API key in order to use this " "module. Add it to the module configuration file:\n" "\n" diff --git a/modules/mapquest.py b/modules/mapquest.py index 95952ab..f147176 100644 --- a/modules/mapquest.py +++ b/modules/mapquest.py @@ -16,7 +16,7 @@ from more import Response URL_API = "http://open.mapquestapi.com/geocoding/v1/address?key=%s&location=%%s" def load(context): - if not context.config or not context.config.hasAttribute("apikey"): + if not context.config or "apikey" not in context.config: raise ImportError("You need a MapQuest API key in order to use this " "module. Add it to the module configuration file:\n" "\nSample " diff --git a/modules/translate.py b/modules/translate.py index a0d8dc2..7452889 100644 --- a/modules/translate.py +++ b/modules/translate.py @@ -19,7 +19,7 @@ LANG = ["ar", "zh", "cz", "en", "fr", "gr", "it", URL = "http://api.wordreference.com/0.8/%s/json/%%s%%s/%%s" def load(context): - if not context.config or not context.config.hasAttribute("wrapikey"): + if not context.config or "wrapikey" not in context.config: raise ImportError("You need a WordReference API key in order to use " "this module. Add it to the module configuration " "file:\n\n" diff --git a/modules/whois.py b/modules/whois.py index 32c13ea..878d4a2 100644 --- a/modules/whois.py +++ b/modules/whois.py @@ -16,7 +16,7 @@ PASSWD_FILE = None def load(context): global PASSWD_FILE - if not context.config or not context.config.hasAttribute("passwd"): + if not context.config or "passwd" not in context.config: print("No passwd file given") return None PASSWD_FILE = context.config["passwd"] diff --git a/modules/wolframalpha.py b/modules/wolframalpha.py index ef1cc82..7a13200 100644 --- a/modules/wolframalpha.py +++ b/modules/wolframalpha.py @@ -19,7 +19,7 @@ URL_API = "http://api.wolframalpha.com/v2/query?input=%%s&appid=%s" def load(context): global URL_API - if not context.config or not context.config.hasAttribute("apikey"): + if not context.config or "apikey" not in context.config: raise ImportError ("You need a Wolfram|Alpha API key in order to use " "this module. Add it to the module configuration: " "\n 1: - from nemubot.tools.config import load_file - for filename in toks[1:]: - load_file(filename, context) + context.load_file(filename) else: print ("Not enough arguments. `load' takes a filename.") return 1 diff --git a/nemubot/tools/config.py b/nemubot/tools/config.py index 479b96f..33fd3cc 100644 --- a/nemubot/tools/config.py +++ b/nemubot/tools/config.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - # Nemubot is a smart and modulable IM bot. # Copyright (C) 2012-2015 Mercier Pierre-Olivier # @@ -16,123 +14,146 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import logging - -logger = logging.getLogger("nemubot.tools.config") +def get_boolean(s): + if isinstance(s, bool): + return s + else: + return (s and s != "0" and s.lower() != "false" and s.lower() != "off") -def get_boolean(d, k, default=False): - return ((k in d and d[k].lower() != "false" and d[k].lower() != "off") or - (k not in d and default)) +class GenericNode: + + def __init__(self, tag, **kwargs): + self.tag = tag + self.attrs = kwargs + self.content = "" + self.children = [] + self._cur = None + self._deep_cur = 0 -def _load_server(config, xmlnode): - """Load a server configuration - - Arguments: - config -- the global configuration - xmlnode -- the current server configuration node - """ - - opts = { - "host": xmlnode["host"], - "ssl": xmlnode.hasAttribute("ssl") and xmlnode["ssl"].lower() == "true", - - "nick": xmlnode["nick"] if xmlnode.hasAttribute("nick") else config["nick"], - "owner": xmlnode["owner"] if xmlnode.hasAttribute("owner") else config["owner"], - } - - # Optional keyword arguments - for optional_opt in [ "port", "username", "realname", - "password", "encoding", "caps" ]: - if xmlnode.hasAttribute(optional_opt): - opts[optional_opt] = xmlnode[optional_opt] - elif optional_opt in config: - opts[optional_opt] = config[optional_opt] - - # Command to send on connection - if "on_connect" in xmlnode: - def on_connect(): - yield xmlnode["on_connect"] - opts["on_connect"] = on_connect - - # Channels to autojoin on connection - if xmlnode.hasNode("channel"): - opts["channels"] = list() - for chn in xmlnode.getNodes("channel"): - opts["channels"].append((chn["name"], chn["password"]) - if chn["password"] is not None - else chn["name"]) - - # Server/client capabilities - if "caps" in xmlnode or "caps" in config: - capsl = (xmlnode["caps"] if xmlnode.hasAttribute("caps") - else config["caps"]).lower() - if capsl == "no" or capsl == "off" or capsl == "false": - opts["caps"] = None + def startElement(self, name, attrs): + if self._cur is None: + self._cur = GenericNode(name, **attrs) + self._deep_cur = 0 else: - opts["caps"] = capsl.split(',') - else: - opts["caps"] = list() - - # Bind the protocol asked to the corresponding implementation - if "protocol" not in xmlnode or xmlnode["protocol"] == "irc": - from nemubot.server.IRC import IRC as IRCServer - srvcls = IRCServer - else: - raise Exception("Unhandled protocol '%s'" % - xmlnode["protocol"]) - - # Initialize the server - return srvcls(**opts) + self._deep_cur += 1 + self._cur.startElement(name, attrs) + return True -def load_file(filename, context): - """Load the configuration file - - Arguments: - filename -- the path to the file to load - """ - - import os - - if os.path.isfile(filename): - from nemubot.tools.xmlparser import parse_file - - config = parse_file(filename) - - # This is a true nemubot configuration file, load it! - if config.getName() == "nemubotconfig": - # Preset each server in this file - for server in config.getNodes("server"): - srv = _load_server(config, server) - - # Add the server in the context - if context.add_server(srv, get_boolean(server, "autoconnect")): - logger.info("Server '%s' successfully added." % srv.id) - else: - logger.error("Can't add server '%s'." % srv.id) - - # Load module and their configuration - for mod in config.getNodes("module"): - context.modules_configuration[mod["name"]] = mod - if get_boolean(mod, "autoload", default=True): - try: - __import__(mod["name"]) - except: - logger.exception("Exception occurs when loading module" - " '%s'", mod["name"]) - - - # Load files asked by the configuration file - for load in config.getNodes("include"): - load_file(load["path"], context) - - # Other formats + def characters(self, content): + if self._cur is None: + self.content += content else: - logger.error("Can't load `%s'; this is not a valid nemubot " - "configuration file." % filename) + self._cur.characters(content) - # Unexisting file, assume a name was passed, import the module! - else: - context.import_module(filename) + + def endElement(self, name): + if name is None: + return + + if self._deep_cur: + self._deep_cur -= 1 + self._cur.endElement(name) + else: + self.children.append(self._cur) + self._cur = None + return True + + + def hasNode(self, nodename): + return self.getNode(nodename) is not None + + + def getNode(self, nodename): + for c in self.children: + if c is not None and c.tag == nodename: + return c + return None + + + def __getitem__(self, item): + return self.attrs[item] + + def __contains__(self, item): + return item in self.attrs + + +class NemubotConfig: + + def __init__(self, nick="nemubot", realname="nemubot", owner=None, + ip=None, ssl=False, caps=None, encoding="utf-8"): + self.nick = nick + self.realname = realname + self.owner = owner + self.ip = ip + self.caps = caps.split(" ") if caps is not None else [] + self.encoding = encoding + self.servers = [] + self.modules = [] + self.includes = [] + + + def addChild(self, name, child): + if name == "module" and isinstance(child, ModuleConfig): + self.modules.append(child) + return True + elif name == "server" and isinstance(child, ServerConfig): + self.servers.append(child) + return True + elif name == "include" and isinstance(child, IncludeConfig): + self.includes.append(child) + return True + + +class ServerConfig: + + def __init__(self, uri="irc://nemubot@localhost/", autoconnect=True, caps=None, **kwargs): + self.uri = uri + self.autoconnect = autoconnect + self.caps = caps.split(" ") if caps is not None else [] + self.args = kwargs + self.channels = [] + + + def addChild(self, name, child): + if name == "channel" and isinstance(child, Channel): + self.channels.append(child) + return True + + + def server(self, parent): + from nemubot.server import factory + + for a in ["nick", "owner", "realname", "encoding"]: + if a not in self.args: + self.args[a] = getattr(parent, a) + + self.caps += parent.caps + + return factory(self.uri, **self.args) + + +class IncludeConfig: + + def __init__(self, path): + self.path = path + + +class ModuleConfig(GenericNode): + + def __init__(self, name, autoload=True, **kwargs): + super(ModuleConfig, self).__init__(None, **kwargs) + self.name = name + self.autoload = get_boolean(autoload) + +from nemubot.channel import Channel + +config_nodes = { + "nemubotconfig": NemubotConfig, + "server": ServerConfig, + "channel": Channel, + "module": ModuleConfig, + "include": IncludeConfig, +} diff --git a/nemubot/tools/test_xmlparser.py b/nemubot/tools/test_xmlparser.py new file mode 100644 index 0000000..faf5684 --- /dev/null +++ b/nemubot/tools/test_xmlparser.py @@ -0,0 +1,82 @@ +import unittest + +import xml.parsers.expat + +from nemubot.tools.xmlparser import XMLParser + + +class StringNode(): + def __init__(self): + self.string = "" + + def characters(self, content): + self.string += content + + +class TestNode(): + def __init__(self, option=None): + self.option = option + self.mystr = None + + def addChild(self, name, child): + self.mystr = child.string + + +class Test2Node(): + def __init__(self, option=None): + self.option = option + self.mystrs = list() + + def startElement(self, name, attrs): + if name == "string": + self.mystrs.append(attrs["value"]) + return True + + +class TestXMLParser(unittest.TestCase): + + def test_parser1(self): + p = xml.parsers.expat.ParserCreate() + mod = XMLParser({"string": StringNode}) + + p.StartElementHandler = mod.startElement + p.CharacterDataHandler = mod.characters + p.EndElementHandler = mod.endElement + + p.Parse("toto", 1) + + self.assertEqual(mod.root.string, "toto") + + + def test_parser2(self): + p = xml.parsers.expat.ParserCreate() + mod = XMLParser({"string": StringNode, "test": TestNode}) + + p.StartElementHandler = mod.startElement + p.CharacterDataHandler = mod.characters + p.EndElementHandler = mod.endElement + + p.Parse("toto", 1) + + self.assertEqual(mod.root.option, "123") + self.assertEqual(mod.root.mystr, "toto") + + + def test_parser3(self): + p = xml.parsers.expat.ParserCreate() + mod = XMLParser({"string": StringNode, "test": Test2Node}) + + p.StartElementHandler = mod.startElement + p.CharacterDataHandler = mod.characters + p.EndElementHandler = mod.endElement + + p.Parse("", 1) + + self.assertEqual(mod.root.option, None) + self.assertEqual(len(mod.root.mystrs), 2) + self.assertEqual(mod.root.mystrs[0], "toto") + self.assertEqual(mod.root.mystrs[1], "toto2") + + +if __name__ == '__main__': + unittest.main() diff --git a/nemubot/tools/xmlparser/__init__.py b/nemubot/tools/xmlparser/__init__.py index 4617b57..5e546f4 100644 --- a/nemubot/tools/xmlparser/__init__.py +++ b/nemubot/tools/xmlparser/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - # Nemubot is a smart and modulable IM bot. # Copyright (C) 2012-2015 Mercier Pierre-Olivier # @@ -48,9 +46,107 @@ class ModuleStatesFile: self.root = child +class XMLParser: + + def __init__(self, knodes): + self.knodes = knodes + + self.stack = list() + self.child = 0 + + + def parse_file(self, path): + p = xml.parsers.expat.ParserCreate() + + p.StartElementHandler = self.startElement + p.CharacterDataHandler = self.characters + p.EndElementHandler = self.endElement + + with open(path, "rb") as f: + p.ParseFile(f) + + return self.root + + + def parse_string(self, s): + p = xml.parsers.expat.ParserCreate() + + p.StartElementHandler = self.startElement + p.CharacterDataHandler = self.characters + p.EndElementHandler = self.endElement + + p.Parse(s, 1) + + return self.root + + + @property + def root(self): + if len(self.stack): + return self.stack[0] + else: + return None + + + @property + def current(self): + if len(self.stack): + return self.stack[-1] + else: + return None + + + def display_stack(self): + return " in ".join([str(type(s).__name__) for s in reversed(self.stack)]) + + + def startElement(self, name, attrs): + if not self.current or not hasattr(self.current, "startElement") or not self.current.startElement(name, attrs): + if name not in self.knodes: + raise TypeError(name + " is not a known type to decode") + else: + self.stack.append(self.knodes[name](**attrs)) + else: + self.child += 1 + + + def characters(self, content): + if self.current and hasattr(self.current, "characters"): + self.current.characters(content) + + + def endElement(self, name): + if self.child: + self.child -= 1 + + if hasattr(self.current, "endElement"): + self.current.endElement(name) + return + + if hasattr(self.current, "endElement"): + self.current.endElement(None) + + # Don't remove root + if len(self.stack) > 1: + last = self.stack.pop() + if hasattr(self.current, "addChild"): + if self.current.addChild(name, last): + return + raise TypeError(name + " tag not expected in " + self.display_stack()) + + def parse_file(filename): - with open(filename, "r") as f: - return parse_string(f.read()) + p = xml.parsers.expat.ParserCreate() + mod = ModuleStatesFile() + + p.StartElementHandler = mod.startElement + p.EndElementHandler = mod.endElement + p.CharacterDataHandler = mod.characters + + with open(filename, "rb") as f: + p.ParseFile(f) + + return mod.root def parse_string(string): diff --git a/nemubot/tools/xmlparser/node.py b/nemubot/tools/xmlparser/node.py index 5f8a509..fa5d0a5 100644 --- a/nemubot/tools/xmlparser/node.py +++ b/nemubot/tools/xmlparser/node.py @@ -1,5 +1,3 @@ -# coding=utf-8 - # Nemubot is a smart and modulable IM bot. # Copyright (C) 2012-2015 Mercier Pierre-Olivier # @@ -37,7 +35,7 @@ class ModuleState: """Get the name of the current node""" return self.name - def display(self, level = 0): + def display(self, level=0): ret = "" out = list() for k in self.attributes: @@ -51,6 +49,9 @@ class ModuleState: def __str__(self): return self.display() + def __repr__(self): + return self.display() + def __getitem__(self, i): """Return the attribute asked""" return self.getAttribute(i)