diff --git a/nemubot/datastore/abstract.py b/nemubot/datastore/abstract.py index 96e2c0d..aeaecc6 100644 --- a/nemubot/datastore/abstract.py +++ b/nemubot/datastore/abstract.py @@ -32,16 +32,20 @@ class Abstract: def close(self): return - def load(self, module): + def load(self, module, knodes): """Load data for the given module Argument: module -- the module name of data to load + knodes -- the schema to use to load the datas Return: The loaded data """ + if knodes is not None: + return None + return self.new() def save(self, module, data): diff --git a/nemubot/datastore/xml.py b/nemubot/datastore/xml.py index 025c0c5..aa6cbd0 100644 --- a/nemubot/datastore/xml.py +++ b/nemubot/datastore/xml.py @@ -83,27 +83,38 @@ class XML(Abstract): return os.path.join(self.basedir, module + ".xml") - def load(self, module): + def load(self, module, knodes): """Load data for the given module Argument: module -- the module name of data to load + knodes -- the schema to use to load the datas """ data_file = self._get_data_file_path(module) + if knodes is None: + from nemubot.tools.xmlparser import parse_file + def _true_load(path): + return parse_file(path) + + else: + from nemubot.tools.xmlparser import XMLParser + p = XMLParser(knodes) + def _true_load(path): + return p.parse_file(path) + # Try to load original file if os.path.isfile(data_file): - from nemubot.tools.xmlparser import parse_file try: - return parse_file(data_file) + return _true_load(data_file) except xml.parsers.expat.ExpatError: # Try to load from backup for i in range(10): path = data_file + "." + str(i) if os.path.isfile(path): try: - cnt = parse_file(path) + cnt = _true_load(path) logger.warn("Restoring from backup: %s", path) @@ -112,7 +123,7 @@ class XML(Abstract): continue # Default case: initialize a new empty datastore - return Abstract.load(self, module) + return super().load(module, knodes) def _rotate(self, path): """Backup given path @@ -143,6 +154,9 @@ class XML(Abstract): if self.rotate: self._rotate(path) + if data is None: + return + import tempfile _, tmpath = tempfile.mkstemp() with open(tmpath, "w") as f: diff --git a/nemubot/modulecontext.py b/nemubot/modulecontext.py index f39934c..c7fa3d4 100644 --- a/nemubot/modulecontext.py +++ b/nemubot/modulecontext.py @@ -16,7 +16,7 @@ class _ModuleContext: - def __init__(self, module=None): + def __init__(self, module=None, knodes=None): self.module = module if module is not None: @@ -30,12 +30,16 @@ class _ModuleContext: from nemubot.config.module import Module self.config = Module(self.module_name) + self._knodes = knodes def load_data(self): from nemubot.tools.xmlparser import module_state return module_state.ModuleState("nemubotstate") + def set_knodes(self, knodes): + self._knodes = knodes + def add_hook(self, hook, *triggers): from nemubot.hooks import Abstract as AbstractHook assert isinstance(hook, AbstractHook), hook @@ -112,7 +116,7 @@ class ModuleContext(_ModuleContext): def load_data(self): - return self.context.datastore.load(self.module_name) + return self.context.datastore.load(self.module_name, self._knodes) def add_hook(self, hook, *triggers): from nemubot.hooks import Abstract as AbstractHook