Compare commits

...

23 commits

Author SHA1 Message Date
c2f7606d1e Refactor 2016-11-10 18:39:52 +01:00
334d342b40 Implement socket server subparse 2016-08-02 18:15:10 +02:00
410fa6ded1 Refactor file/socket management (use poll instead of select) 2016-08-02 18:15:10 +02:00
559a2a8adc Use fileno instead of name to index existing servers 2016-08-02 18:14:33 +02:00
16facd949d Use super() instead of parent class name 2016-07-30 21:11:46 +02:00
b7b1a92161 Documentation 2016-07-30 21:11:46 +02:00
0c4144ce20 In debug mode, display running thread at exit 2016-07-30 21:11:45 +02:00
20f3e8b215 Handle case where frm and to have not been filled 2016-07-30 21:11:45 +02:00
56b8e8a260 Review consumer errors 2016-07-30 21:11:45 +02:00
4d0260b5fc Remove reload feature
As reload shoudl be done in a particular order, to keep valid types, and because maintaining such system is too complex (currently, it doesn't work for a while), now, a reload is just reload configuration file (and possibly modules)
2016-07-30 21:11:44 +02:00
fb670b0777 New keywords class that accepts any keywords 2016-07-30 21:11:44 +02:00
226785af19 [rnd] Add new function choiceres which pick a random response returned by a given subcommand 2016-07-30 21:11:43 +02:00
7eaf39f850 Can attach to the main process 2016-07-30 21:11:43 +02:00
9b68b9e217 Remove legacy prompt 2016-07-30 21:11:42 +02:00
db58319af5 Fix and improve reload process 2016-07-30 21:11:42 +02:00
dfc929f3e5 New argument: --socketfile that create a socket for internal communication 2016-07-30 21:11:41 +02:00
66341461a3 New CLI argument: --pidfile, path to store the daemon PID 2016-07-30 21:11:41 +02:00
fb3c715acd Catch SIGUSR1: log threads stack traces 2016-07-30 21:11:41 +02:00
1f82b17219 Extract deamonize to a dedicated function that can be called from anywhere 2016-07-30 21:11:40 +02:00
82e50061b1 Catch SIGHUP: deep reload 2016-07-30 21:11:40 +02:00
6c44df1482 Do a proper close on SIGINT and SIGTERM 2016-07-30 21:11:39 +02:00
ca8434a476 Remove prompt at launch 2016-07-30 21:11:39 +02:00
16b7024f62 Introducing daemon mode 2016-07-30 21:11:38 +02:00
34 changed files with 605 additions and 1612 deletions

View file

@ -9,6 +9,8 @@ Requirements
*nemubot* requires at least Python 3.3 to work. *nemubot* requires at least Python 3.3 to work.
Connecting to SSL server requires [this patch](http://bugs.python.org/issue27629).
Some modules (like `cve`, `nextstop` or `laposte`) require the Some modules (like `cve`, `nextstop` or `laposte`) require the
[BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/), [BeautifulSoup module](http://www.crummy.com/software/BeautifulSoup/),
but the core and framework has no dependency. but the core and framework has no dependency.

View file

@ -1,202 +0,0 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import traceback
import sys
from nemubot.hooks import hook
nemubotversion = 3.4
NODATA = True
def getserver(toks, context, prompt, mandatory=False, **kwargs):
"""Choose the server in toks or prompt.
This function modify the tokens list passed as argument"""
if len(toks) > 1 and toks[1] in context.servers:
return context.servers[toks.pop(1)]
elif not mandatory or prompt.selectedServer:
return prompt.selectedServer
else:
from nemubot.prompt.error import PromptError
raise PromptError("Please SELECT a server or give its name in argument.")
@hook("prompt_cmd", "close")
def close(toks, context, **kwargs):
"""Disconnect and forget (remove from the servers list) the server"""
srv = getserver(toks, context=context, mandatory=True, **kwargs)
if srv.close():
del context.servers[srv.id]
return 0
return 1
@hook("prompt_cmd", "connect")
def connect(toks, **kwargs):
"""Make the connexion to a server"""
srv = getserver(toks, mandatory=True, **kwargs)
return not srv.open()
@hook("prompt_cmd", "disconnect")
def disconnect(toks, **kwargs):
"""Close the connection to a server"""
srv = getserver(toks, mandatory=True, **kwargs)
return not srv.close()
@hook("prompt_cmd", "discover")
def discover(toks, context, **kwargs):
"""Discover a new bot on a server"""
srv = getserver(toks, context=context, mandatory=True, **kwargs)
if len(toks) > 1 and "!" in toks[1]:
bot = context.add_networkbot(srv, name)
return not bot.connect()
else:
print(" %s is not a valid fullname, for example: "
"nemubot!nemubotV3@bot.nemunai.re" % ''.join(toks[1:1]))
return 1
@hook("prompt_cmd", "join")
@hook("prompt_cmd", "leave")
@hook("prompt_cmd", "part")
def join(toks, **kwargs):
"""Join or leave a channel"""
srv = getserver(toks, mandatory=True, **kwargs)
if len(toks) <= 2:
print("%s: not enough arguments." % toks[0])
return 1
if toks[0] == "join":
if len(toks) > 2:
srv.write("JOIN %s %s" % (toks[1], toks[2]))
else:
srv.write("JOIN %s" % toks[1])
elif toks[0] == "leave" or toks[0] == "part":
if len(toks) > 2:
srv.write("PART %s :%s" % (toks[1], " ".join(toks[2:])))
else:
srv.write("PART %s" % toks[1])
return 0
@hook("prompt_cmd", "save")
def save_mod(toks, context, **kwargs):
"""Force save module data"""
if len(toks) < 2:
print("save: not enough arguments.")
return 1
wrn = 0
for mod in toks[1:]:
if mod in context.modules:
context.modules[mod].save()
print("save: module `%s´ saved successfully" % mod)
else:
wrn += 1
print("save: no module named `%s´" % mod)
return wrn
@hook("prompt_cmd", "send")
def send(toks, **kwargs):
"""Send a message on a channel"""
srv = getserver(toks, mandatory=True, **kwargs)
# Check the server is connected
if not srv.connected:
print ("send: server `%s' not connected." % srv.id)
return 2
if len(toks) <= 3:
print ("send: not enough arguments.")
return 1
if toks[1] not in srv.channels:
print ("send: channel `%s' not authorized in server `%s'."
% (toks[1], srv.id))
return 3
from nemubot.message import Text
srv.send_response(Text(" ".join(toks[2:]), server=None,
to=[toks[1]]))
return 0
@hook("prompt_cmd", "zap")
def zap(toks, **kwargs):
"""Hard change connexion state"""
srv = getserver(toks, mandatory=True, **kwargs)
srv.connected = not srv.connected
@hook("prompt_cmd", "top")
def top(toks, context, **kwargs):
"""Display consumers load information"""
print("Queue size: %d, %d thread(s) running (counter: %d)" %
(context.cnsr_queue.qsize(),
len(context.cnsr_thrd),
context.cnsr_thrd_size))
if len(context.events) > 0:
print("Events registered: %d, next in %d seconds" %
(len(context.events),
context.events[0].time_left.seconds))
else:
print("No events registered")
for th in context.cnsr_thrd:
if th.is_alive():
print(("#" * 15 + " Stack trace for thread %u " + "#" * 15) %
th.ident)
traceback.print_stack(sys._current_frames()[th.ident])
@hook("prompt_cmd", "netstat")
def netstat(toks, context, **kwargs):
"""Display sockets in use and many other things"""
if len(context.network) > 0:
print("Distant bots connected: %d:" % len(context.network))
for name, bot in context.network.items():
print("# %s:" % name)
print(" * Declared hooks:")
lvl = 0
for hlvl in bot.hooks:
lvl += 1
for hook in (hlvl.all_pre + hlvl.all_post + hlvl.cmd_rgxp +
hlvl.cmd_default + hlvl.ask_rgxp +
hlvl.ask_default + hlvl.msg_rgxp +
hlvl.msg_default):
print(" %s- %s" % (' ' * lvl * 2, hook))
for kind in ["irc_hook", "cmd_hook", "ask_hook", "msg_hook"]:
print(" %s- <%s> %s" % (' ' * lvl * 2, kind,
", ".join(hlvl.__dict__[kind].keys())))
print(" * My tag: %d" % bot.my_tag)
print(" * Tags in use (%d):" % bot.inc_tag)
for tag, (cmd, data) in bot.tags.items():
print(" - %11s: %s « %s »" % (tag, cmd, data))
else:
print("No distant bot connected")

View file

@ -8,7 +8,6 @@ import shlex
from nemubot import context from nemubot import context
from nemubot.exception import IMException from nemubot.exception import IMException
from nemubot.hooks import hook from nemubot.hooks import hook
from nemubot.message import Command
from more import Response from more import Response
@ -32,8 +31,24 @@ def cmd_choicecmd(msg):
choice = shlex.split(random.choice(msg.args)) choice = shlex.split(random.choice(msg.args))
return [x for x in context.subtreat(Command(choice[0][1:], return [x for x in context.subtreat(context.subparse(msg, choice))]
choice[1:],
to_response=msg.to_response,
frm=msg.frm, @hook.command("choiceres")
server=msg.server))] def cmd_choiceres(msg):
if not len(msg.args):
raise IMException("indicate some command to pick a message from!")
rl = [x for x in context.subtreat(context.subparse(msg, " ".join(msg.args)))]
if len(rl) <= 0:
return rl
r = random.choice(rl)
if isinstance(r, Response):
for i in range(len(r.messages) - 1, -1, -1):
if isinstance(r.messages[i], list):
r.messages = [ random.choice(random.choice(r.messages)) ]
elif isinstance(r.messages[i], str):
r.messages = [ random.choice(r.messages) ]
return r

View file

@ -18,7 +18,8 @@ __version__ = '4.0.dev3'
__author__ = 'nemunaire' __author__ = 'nemunaire'
from nemubot.modulecontext import ModuleContext from nemubot.modulecontext import ModuleContext
context = ModuleContext(None, None)
context = ModuleContext(None, None, None)
def requires_version(min=None, max=None): def requires_version(min=None, max=None):
@ -38,62 +39,98 @@ def requires_version(min=None, max=None):
"but this is nemubot v%s." % (str(max), __version__)) "but this is nemubot v%s." % (str(max), __version__))
def reload(): def attach(pid, socketfile):
"""Reload code of all Python modules used by nemubot import socket
import sys
print("nemubot is already launched with PID %d. Attaching to Unix socket at: %s" % (pid, socketfile))
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.connect(socketfile)
except socket.error as e:
sys.stderr.write(str(e))
sys.stderr.write("\n")
return 1
from select import select
try:
print("Connection established.")
while True:
rl, wl, xl = select([sys.stdin, sock], [], [])
if sys.stdin in rl:
line = sys.stdin.readline().strip()
if line == "exit" or line == "quit":
return 0
elif line == "reload":
import os, signal
os.kill(pid, signal.SIGHUP)
print("Reload signal sent. Please wait...")
elif line == "shutdown":
import os, signal
os.kill(pid, signal.SIGTERM)
print("Shutdown signal sent. Please wait...")
elif line == "kill":
import os, signal
os.kill(pid, signal.SIGKILL)
print("Signal sent...")
return 0
elif line == "stack" or line == "stacks":
import os, signal
os.kill(pid, signal.SIGUSR1)
print("Debug signal sent. Consult logs.")
else:
sock.send(line.encode() + b'\r\n')
if sock in rl:
sys.stdout.write(sock.recv(2048).decode())
except KeyboardInterrupt:
pass
except:
return 1
finally:
sock.close()
return 0
def daemonize():
"""Detach the running process to run as a daemon
""" """
import imp import os
import sys
import nemubot.channel try:
imp.reload(nemubot.channel) pid = os.fork()
if pid > 0:
sys.exit(0)
except OSError as err:
sys.stderr.write("Unable to fork: %s\n" % err)
sys.exit(1)
import nemubot.config os.setsid()
imp.reload(nemubot.config) os.umask(0)
os.chdir('/')
nemubot.config.reload() try:
pid = os.fork()
if pid > 0:
sys.exit(0)
except OSError as err:
sys.stderr.write("Unable to fork: %s\n" % err)
sys.exit(1)
import nemubot.consumer sys.stdout.flush()
imp.reload(nemubot.consumer) sys.stderr.flush()
si = open(os.devnull, 'r')
so = open(os.devnull, 'a+')
se = open(os.devnull, 'a+')
import nemubot.datastore os.dup2(si.fileno(), sys.stdin.fileno())
imp.reload(nemubot.datastore) os.dup2(so.fileno(), sys.stdout.fileno())
os.dup2(se.fileno(), sys.stderr.fileno())
nemubot.datastore.reload()
import nemubot.event
imp.reload(nemubot.event)
import nemubot.exception
imp.reload(nemubot.exception)
nemubot.exception.reload()
import nemubot.hooks
imp.reload(nemubot.hooks)
nemubot.hooks.reload()
import nemubot.importer
imp.reload(nemubot.importer)
import nemubot.message
imp.reload(nemubot.message)
nemubot.message.reload()
import nemubot.prompt
imp.reload(nemubot.prompt)
nemubot.prompt.reload()
import nemubot.server
rl, wl, xl = nemubot.server._rlist, nemubot.server._wlist, nemubot.server._xlist
imp.reload(nemubot.server)
nemubot.server._rlist, nemubot.server._wlist, nemubot.server._xlist = rl, wl, xl
nemubot.server.reload()
import nemubot.tools
imp.reload(nemubot.tools)
nemubot.tools.reload()

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -15,7 +15,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
def main(): def main():
import functools
import os import os
import signal
import sys import sys
# Parse command line arguments # Parse command line arguments
@ -36,6 +38,15 @@ def main():
default=["./modules/"], default=["./modules/"],
help="directory to use as modules store") help="directory to use as modules store")
parser.add_argument("-d", "--debug", action="store_true",
help="don't deamonize, keep in foreground")
parser.add_argument("-P", "--pidfile", default="./nemubot.pid",
help="Path to the file where store PID")
parser.add_argument("-S", "--socketfile", default="./nemubot.sock",
help="path where open the socket for internal communication")
parser.add_argument("-l", "--logfile", default="./nemubot.log", parser.add_argument("-l", "--logfile", default="./nemubot.log",
help="Path to store logs") help="Path to store logs")
@ -58,11 +69,35 @@ def main():
# Resolve relatives paths # Resolve relatives paths
args.data_path = os.path.abspath(os.path.expanduser(args.data_path)) args.data_path = os.path.abspath(os.path.expanduser(args.data_path))
args.pidfile = os.path.abspath(os.path.expanduser(args.pidfile))
args.socketfile = os.path.abspath(os.path.expanduser(args.socketfile))
args.logfile = os.path.abspath(os.path.expanduser(args.logfile)) args.logfile = os.path.abspath(os.path.expanduser(args.logfile))
args.files = [ x for x in map(os.path.abspath, args.files)] args.files = [x for x in map(os.path.abspath, args.files)]
args.modules_path = [ x for x in map(os.path.abspath, args.modules_path)] args.modules_path = [x for x in map(os.path.abspath, args.modules_path)]
# Setup loggin interface # Check if an instance is already launched
if args.pidfile is not None and os.path.isfile(args.pidfile):
with open(args.pidfile, "r") as f:
pid = int(f.readline())
try:
os.kill(pid, 0)
except OSError:
pass
else:
from nemubot import attach
sys.exit(attach(pid, args.socketfile))
# Daemonize
if not args.debug:
from nemubot import daemonize
daemonize()
# Store PID to pidfile
if args.pidfile is not None:
with open(args.pidfile, "w+") as f:
f.write(str(os.getpid()))
# Setup logging interface
import logging import logging
logger = logging.getLogger("nemubot") logger = logging.getLogger("nemubot")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@ -70,11 +105,12 @@ def main():
formatter = logging.Formatter( formatter = logging.Formatter(
'%(asctime)s %(name)s %(levelname)s %(message)s') '%(asctime)s %(name)s %(levelname)s %(message)s')
ch = logging.StreamHandler() if args.debug:
ch.setFormatter(formatter) ch = logging.StreamHandler()
if args.verbose < 2: ch.setFormatter(formatter)
ch.setLevel(logging.INFO) if args.verbose < 2:
logger.addHandler(ch) ch.setLevel(logging.INFO)
logger.addHandler(ch)
fh = logging.FileHandler(args.logfile) fh = logging.FileHandler(args.logfile)
fh.setFormatter(formatter) fh.setFormatter(formatter)
@ -90,64 +126,89 @@ def main():
# Create bot context # Create bot context
from nemubot import datastore from nemubot import datastore
from nemubot.bot import Bot from nemubot.bot import Bot#, sync_act
context = Bot(modules_paths=modules_paths, context = Bot()
data_store=datastore.XML(args.data_path),
verbosity=args.verbose)
if args.no_connect: if args.no_connect:
context.noautoconnect = True context.noautoconnect = True
# Load the prompt
import nemubot.prompt
prmpt = nemubot.prompt.Prompt()
# Register the hook for futur import # Register the hook for futur import
from nemubot.importer import ModuleFinder #from nemubot.importer import ModuleFinder
sys.meta_path.append(ModuleFinder(context.modules_paths, context.add_module)) #module_finder = ModuleFinder(modules_paths, context.add_module)
#sys.meta_path.append(module_finder)
# Load requested configuration files # Load requested configuration files
for path in args.files: #for path in args.files:
if os.path.isfile(path): # if os.path.isfile(path):
context.sync_queue.put_nowait(["loadconf", path]) # sync_act("loadconf", path)
else: # else:
logger.error("%s is not a readable file", path) # logger.error("%s is not a readable file", path)
if args.module: if args.module:
for module in args.module: for module in args.module:
__import__(module) __import__(module)
print ("Nemubot v%s ready, my PID is %i!" % (nemubot.__version__, # Signals handling ################################################
os.getpid()))
while True:
from nemubot.prompt.reset import PromptReset
try:
context.start()
if prmpt.run(context):
break
except PromptReset as e:
if e.type == "quit":
break
try: # SIGINT and SIGTERM
import imp
# Reload all other modules
imp.reload(nemubot)
imp.reload(nemubot.prompt)
nemubot.reload()
import nemubot.bot
context = nemubot.bot.hotswap(context)
prmpt = nemubot.prompt.hotswap(prmpt)
print("\033[1;32mContext reloaded\033[0m, now in Nemubot %s" %
nemubot.__version__)
except:
logger.exception("\033[1;31mUnable to reload the prompt due to "
"errors.\033[0m Fix them before trying to reload "
"the prompt.")
context.quit() def ask_exit(signame):
print("Waiting for other threads shuts down...") """On SIGTERM and SIGINT, quit nicely"""
context.stop()
for sig in (signal.SIGINT, signal.SIGTERM):
context._loop.add_signal_handler(sig, functools.partial(ask_exit, sig))
# SIGHUP
def ask_reload():
"""Perform a deep reload"""
nonlocal context
logger.debug("SIGHUP receive, iniate reload procedure...")
# Reload configuration file
#for path in args.files:
# if os.path.isfile(path):
# sync_act("loadconf", path)
context._loop.add_signal_handler(signal.SIGHUP, ask_reload)
# SIGUSR1
def ask_debug_info():
"""Display debug informations and stacktraces"""
import threading, traceback
for threadId, stack in sys._current_frames().items():
thName = "#%d" % threadId
for th in threading.enumerate():
if th.ident == threadId:
thName = th.name
break
logger.debug("########### Thread %s:\n%s",
thName,
"".join(traceback.format_stack(stack)))
context._loop.add_signal_handler(signal.SIGUSR1, ask_debug_info)
#if args.socketfile:
# from nemubot.server.socket import UnixSocketListener
# context.add_server(UnixSocketListener(new_server_cb=context.add_server,
# location=args.socketfile,
# name="master_socket"))
# context can change when performing an hotswap, always join the latest context
oldcontext = None
while oldcontext != context:
oldcontext = context
context.run()
# Wait for consumers
logger.info("Waiting for other threads shuts down...")
if args.debug:
ask_debug_info()
sys.exit(0) sys.exit(0)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View file

@ -14,568 +14,34 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime, timezone
import logging import logging
import threading
import sys import sys
from nemubot import __version__
from nemubot.consumer import Consumer, EventConsumer, MessageConsumer
from nemubot import datastore
import nemubot.hooks
logger = logging.getLogger("nemubot") class Bot:
class Bot(threading.Thread):
"""Class containing the bot context and ensuring key goals""" """Class containing the bot context and ensuring key goals"""
def __init__(self, ip="127.0.0.1", modules_paths=list(), logger = logging.getLogger("nemubot")
data_store=datastore.Abstract(), verbosity=0):
"""Initialize the bot context
Keyword arguments: def __init__(self):
ip -- The external IP of the bot (default: 127.0.0.1) """Initialize the bot context
modules_paths -- Paths to all directories where looking for module
data_store -- An instance of the nemubot datastore for bot's modules
""" """
threading.Thread.__init__(self) from nemubot import __version__
Bot.logger.info("Initiate nemubot v%s, running on Python %s",
__version__, sys.version.split("\n")[0])
logger.info("Initiate nemubot v%s (running on Python %s.%s.%s)", import asyncio
__version__, self._loop = asyncio.get_event_loop()
sys.version_info.major, sys.version_info.minor, sys.version_info.micro)
self.verbosity = verbosity
self.stop = None
# External IP for accessing this bot
import ipaddress
self.ip = ipaddress.ip_address(ip)
# Context paths
self.modules_paths = modules_paths
self.datastore = data_store
self.datastore.open()
# Keep global context: servers and modules
self.servers = dict()
self.modules = dict()
self.modules_configuration = dict()
# Events
self.events = list()
self.event_timer = None
# Own hooks
from nemubot.treatment import MessageTreater
self.treater = MessageTreater()
import re
def in_ping(msg):
return msg.respond("pong")
self.treater.hm.add_hook(nemubot.hooks.Message(in_ping,
match=lambda msg: re.match("^ *(m[' ]?entends?[ -]+tu|h?ear me|do you copy|ping)",
msg.message, re.I)),
"in", "DirectAsk")
def in_echo(msg):
from nemubot.message import Text
return Text(msg.nick + ": " + " ".join(msg.args), to=msg.to_response)
self.treater.hm.add_hook(nemubot.hooks.Command(in_echo, "echo"), "in", "Command")
def _help_msg(msg):
"""Parse and response to help messages"""
from more import Response
res = Response(channel=msg.to_response)
if len(msg.args) >= 1:
if msg.args[0] in self.modules:
if hasattr(self.modules[msg.args[0]], "help_full"):
hlp = self.modules[msg.args[0]].help_full()
if isinstance(hlp, Response):
return hlp
else:
res.append_message(hlp)
else:
res.append_message([str(h) for s,h in self.modules[msg.args[0]].__nemubot_context__.hooks], title="Available commands for module " + msg.args[0])
elif msg.args[0][0] == "!":
from nemubot.message.command import Command
for h in self.treater._in_hooks(Command(msg.args[0][1:])):
if h.help_usage:
lp = ["\x03\x02%s%s\x03\x02: %s" % (msg.args[0], (" " + k if k is not None else ""), h.help_usage[k]) for k in h.help_usage]
jp = h.keywords.help()
return res.append_message(lp + ([". Moreover, you can provides some optional parameters: "] + jp if len(jp) else []), title="Usage for command %s" % msg.args[0])
elif h.help:
return res.append_message("Command %s: %s" % (msg.args[0], h.help))
else:
return res.append_message("Sorry, there is currently no help for the command %s. Feel free to make a pull request at https://github.com/nemunaire/nemubot/compare" % msg.args[0])
res.append_message("Sorry, there is no command %s" % msg.args[0])
else:
res.append_message("Sorry, there is no module named %s" % msg.args[0])
else:
res.append_message("Pour me demander quelque chose, commencez "
"votre message par mon nom ; je réagis "
"également à certaine commandes commençant par"
" !. Pour plus d'informations, envoyez le "
"message \"!more\".")
res.append_message("Mon code source est libre, publié sous "
"licence AGPL (http://www.gnu.org/licenses/). "
"Vous pouvez le consulter, le dupliquer, "
"envoyer des rapports de bogues ou bien "
"contribuer au projet sur GitHub : "
"http://github.com/nemunaire/nemubot/")
res.append_message(title="Pour plus de détails sur un module, "
"envoyez \"!help nomdumodule\". Voici la liste"
" de tous les modules disponibles localement",
message=["\x03\x02%s\x03\x02 (%s)" % (im, self.modules[im].__doc__) for im in self.modules if self.modules[im].__doc__])
return res
self.treater.hm.add_hook(nemubot.hooks.Command(_help_msg, "help"), "in", "Command")
from queue import Queue
# Messages to be treated
self.cnsr_queue = Queue()
self.cnsr_thrd = list()
self.cnsr_thrd_size = -1
# Synchrone actions to be treated by main thread
self.sync_queue = Queue()
def run(self): def run(self):
from select import select self._loop.run_forever()
from nemubot.server import _lock, _rlist, _wlist, _xlist self._loop.close()
self.stop = False
while not self.stop:
with _lock:
try:
rl, wl, xl = select(_rlist, _wlist, _xlist, 0.1)
except:
logger.error("Something went wrong in select")
fnd_smth = False
# Looking for invalid server
for r in _rlist:
if not hasattr(r, "fileno") or not isinstance(r.fileno(), int) or r.fileno() < 0:
_rlist.remove(r)
logger.error("Found invalid object in _rlist: " + str(r))
fnd_smth = True
for w in _wlist:
if not hasattr(w, "fileno") or not isinstance(w.fileno(), int) or w.fileno() < 0:
_wlist.remove(w)
logger.error("Found invalid object in _wlist: " + str(w))
fnd_smth = True
for x in _xlist:
if not hasattr(x, "fileno") or not isinstance(x.fileno(), int) or x.fileno() < 0:
_xlist.remove(x)
logger.error("Found invalid object in _xlist: " + str(x))
fnd_smth = True
if not fnd_smth:
logger.exception("Can't continue, sorry")
self.quit()
continue
for x in xl: def stop(self):
try:
x.exception()
except:
logger.exception("Uncatched exception on server exception")
for w in wl:
try:
w.write_select()
except:
logger.exception("Uncatched exception on server write")
for r in rl:
for i in r.read():
try:
self.receive_message(r, i)
except:
logger.exception("Uncatched exception on server read")
# Launch new consumer threads if necessary
while self.cnsr_queue.qsize() > self.cnsr_thrd_size:
# Next launch if two more items in queue
self.cnsr_thrd_size += 2
c = Consumer(self)
self.cnsr_thrd.append(c)
c.start()
while self.sync_queue.qsize() > 0:
action = self.sync_queue.get_nowait()
if action[0] == "exit":
self.quit()
elif action[0] == "loadconf":
for path in action[1:]:
logger.debug("Load configuration from %s", path)
self.load_file(path)
logger.info("Configurations successfully loaded")
self.sync_queue.task_done()
# Config methods
def load_file(self, filename):
"""Load a configuration file
Arguments:
filename -- the path to the file to load
"""
import os
# Unexisting file, assume a name was passed, import the module!
if not os.path.isfile(filename):
return self.import_module(filename)
from nemubot.channel import Channel
from nemubot import config
from nemubot.tools.xmlparser import XMLParser
try:
p = XMLParser({
"nemubotconfig": config.Nemubot,
"server": config.Server,
"channel": Channel,
"module": config.Module,
"include": config.Include,
})
config = p.parse_file(filename)
except:
logger.exception("Can't load `%s'; this is not a valid nemubot "
"configuration file." % filename)
return False
# Preset each server in this file
for server in config.servers:
srv = server.server(config)
# Add the server in the context
if self.add_server(srv, server.autoconnect):
logger.info("Server '%s' successfully added." % srv.id)
else:
logger.error("Can't add server '%s'." % srv.id)
# Load module and their configuration
for mod in config.modules:
self.modules_configuration[mod.name] = mod
if mod.autoload:
try:
__import__(mod.name)
except:
logger.exception("Exception occurs when loading module"
" '%s'", mod.name)
# Load files asked by the configuration file
for load in config.includes:
self.load_file(load.path)
# Events methods
def add_event(self, evt, eid=None, module_src=None):
"""Register an event and return its identifiant for futur update
Return:
None if the event is not in the queue (eg. if it has been executed during the call) or
returns the event ID.
Argument:
evt -- The event object to add
Keyword arguments:
eid -- The desired event ID (object or string UUID)
module_src -- The module to which the event is attached to
"""
if hasattr(self, "stop") and self.stop:
logger.warn("The bot is stopped, can't register new events")
return
import uuid
# Generate the event id if no given
if eid is None:
eid = uuid.uuid1()
# Fill the id field of the event
if type(eid) is uuid.UUID:
evt.id = str(eid)
else:
# Ok, this is quite useless...
try:
evt.id = str(uuid.UUID(eid))
except ValueError:
evt.id = eid
# TODO: mutex here plz
# Add the event in its place
t = evt.current
i = 0 # sentinel
for i in range(0, len(self.events)):
if self.events[i].current > t:
break
self.events.insert(i, evt)
if i == 0:
# First event changed, reset timer
self._update_event_timer()
if len(self.events) <= 0 or self.events[i] != evt:
# Our event has been executed and removed from queue
return None
# Register the event in the source module
if module_src is not None:
module_src.__nemubot_context__.events.append(evt.id)
evt.module_src = module_src
logger.info("New event registered in %d position: %s", i, t)
return evt.id
def del_event(self, evt, module_src=None):
"""Find and remove an event from list
Return:
True if the event has been found and removed, False else
Argument:
evt -- The ModuleEvent object to remove or just the event identifier
Keyword arguments:
module_src -- The module to which the event is attached to (ignored if evt is a ModuleEvent)
"""
logger.info("Removing event: %s from %s", evt, module_src)
from nemubot.event import ModuleEvent
if type(evt) is ModuleEvent:
id = evt.id
module_src = evt.module_src
else:
id = evt
if len(self.events) > 0 and id == self.events[0].id:
self.events.remove(self.events[0])
self._update_event_timer()
if module_src is not None:
module_src.__nemubot_context__.events.remove(id)
return True
for evt in self.events:
if evt.id == id:
self.events.remove(evt)
if module_src is not None:
module_src.__nemubot_context__.events.remove(evt.id)
return True
return False
def _update_event_timer(self):
"""(Re)launch the timer to end with the closest event"""
# Reset the timer if this is the first item
if self.event_timer is not None:
self.event_timer.cancel()
if len(self.events):
remaining = self.events[0].time_left.total_seconds()
logger.debug("Update timer: next event in %d seconds", remaining)
self.event_timer = threading.Timer(remaining if remaining > 0 else 0, self._end_event_timer)
self.event_timer.start()
else:
logger.debug("Update timer: no timer left")
def _end_event_timer(self):
"""Function called at the end of the event timer"""
while len(self.events) > 0 and datetime.now(timezone.utc) >= self.events[0].current:
evt = self.events.pop(0)
self.cnsr_queue.put_nowait(EventConsumer(evt))
self._update_event_timer()
# Consumers methods
def add_server(self, srv, autoconnect=True):
"""Add a new server to the context
Arguments:
srv -- a concrete AbstractServer instance
autoconnect -- connect after add?
"""
if srv.id not in self.servers:
self.servers[srv.id] = srv
if autoconnect and not hasattr(self, "noautoconnect"):
srv.open()
return True
else:
return False
# Modules methods
def import_module(self, name):
"""Load a module
Argument:
name -- name of the module to load
"""
if name in self.modules:
self.unload_module(name)
__import__(name)
def add_module(self, module):
"""Add a module to the context, if already exists, unload the
old one before"""
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
if hasattr(self, "stop") and self.stop:
logger.warn("The bot is stopped, can't register new modules")
return
# Check if the module already exists
if module_name in self.modules:
self.unload_module(module_name)
# Overwrite print built-in
def prnt(*args):
if hasattr(module, "logger"):
module.logger.info(" ".join([str(s) for s in args]))
else:
logger.info("[%s] %s", module_name, " ".join([str(s) for s in args]))
module.print = prnt
# Create module context
from nemubot.modulecontext import ModuleContext
module.__nemubot_context__ = ModuleContext(self, module)
if not hasattr(module, "logger"):
module.logger = logging.getLogger("nemubot.module." + module_name)
# Replace imported context by real one
for attr in module.__dict__:
if attr != "__nemubot_context__" and type(module.__dict__[attr]) == ModuleContext:
module.__dict__[attr] = module.__nemubot_context__
# Register decorated functions
import nemubot.hooks
for s, h in nemubot.hooks.hook.last_registered:
module.__nemubot_context__.add_hook(h, *s if isinstance(s, list) else s)
nemubot.hooks.hook.last_registered = []
# Launch the module
if hasattr(module, "load"):
try:
module.load(module.__nemubot_context__)
except:
module.__nemubot_context__.unload()
raise
# Save a reference to the module
self.modules[module_name] = module
def unload_module(self, name):
"""Unload a module"""
if name in self.modules:
self.modules[name].print("Unloading module %s" % name)
# Call the user defined unload method
if hasattr(self.modules[name], "unload"):
self.modules[name].unload(self)
self.modules[name].__nemubot_context__.unload()
# Remove from the nemubot dict
del self.modules[name]
# Remove from the Python dict
del sys.modules[name]
for mod in [i for i in sys.modules]:
if mod[:len(name) + 1] == name + ".":
logger.debug("Module '%s' also removed from system modules list.", mod)
del sys.modules[mod]
logger.info("Module `%s' successfully unloaded.", name)
return True
return False
def receive_message(self, srv, msg):
"""Queued the message for treatment
Arguments:
srv -- The server where the message comes from
msg -- The message not parsed, as simple as possible
"""
self.cnsr_queue.put_nowait(MessageConsumer(srv, msg))
def quit(self):
"""Save and unload modules and disconnect servers""" """Save and unload modules and disconnect servers"""
self.datastore.close() self._loop.stop()
if self.event_timer is not None:
logger.info("Stop the event timer...")
self.event_timer.cancel()
logger.info("Stop consumers")
k = self.cnsr_thrd
for cnsr in k:
cnsr.stop = True
logger.info("Save and unload all modules...")
k = list(self.modules.keys())
for mod in k:
self.unload_module(mod)
logger.info("Close all servers connection...")
k = list(self.servers.keys())
for srv in k:
self.servers[srv].close()
self.stop = True
# Treatment
def check_rest_times(self, store, hook):
"""Remove from store the hook if it has been executed given time"""
if hook.times == 0:
if isinstance(store, dict):
store[hook.name].remove(hook)
if len(store) == 0:
del store[hook.name]
elif isinstance(store, list):
store.remove(hook)
def hotswap(bak):
bak.stop = True
if bak.event_timer is not None:
bak.event_timer.cancel()
bak.datastore.close()
new = Bot(str(bak.ip), bak.modules_paths, bak.datastore)
new.servers = bak.servers
new.modules = bak.modules
new.modules_configuration = bak.modules_configuration
new.events = bak.events
new.hooks = bak.hooks
new._update_event_timer()
return new

View file

@ -24,24 +24,3 @@ from nemubot.config.include import Include
from nemubot.config.module import Module from nemubot.config.module import Module
from nemubot.config.nemubot import Nemubot from nemubot.config.nemubot import Nemubot
from nemubot.config.server import Server from nemubot.config.server import Server
def reload():
global Include, Module, Nemubot, Server
import imp
import nemubot.config.include
imp.reload(nemubot.config.include)
Include = nemubot.config.include.Include
import nemubot.config.module
imp.reload(nemubot.config.module)
Module = nemubot.config.module.Module
import nemubot.config.nemubot
imp.reload(nemubot.config.nemubot)
Nemubot = nemubot.config.nemubot.Nemubot
import nemubot.config.server
imp.reload(nemubot.config.server)
Server = nemubot.config.server.Server

View file

@ -33,13 +33,11 @@ class Server:
return True return True
def server(self, parent): def server(self, caps=[], **kwargs):
from nemubot.server import factory from nemubot.server import factory
for a in ["nick", "owner", "realname", "encoding"]: caps += self.caps
if a not in self.args:
self.args[a] = getattr(parent, a)
self.caps += parent.caps kwargs.update(self.args)
return factory(self.uri, caps=self.caps, channels=self.channels, **self.args) return factory(self.uri, caps=caps, channels=self.channels, **kwargs)

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -33,47 +33,47 @@ class MessageConsumer:
def run(self, context): def run(self, context):
"""Create, parse and treat the message""" """Create, parse and treat the message"""
# Dereference weakptr
srv = self.srv()
if srv is None:
return
from nemubot.bot import Bot from nemubot.bot import Bot
assert isinstance(context, Bot) assert isinstance(context, Bot)
msgs = [] msgs = []
# Parse the message # Parse message
try: try:
for msg in self.srv.parse(self.orig): for msg in srv.parse(self.orig):
msgs.append(msg) msgs.append(msg)
except: except:
logger.exception("Error occurred during the processing of the %s: " logger.exception("Error occurred during the processing of the %s: "
"%s", type(self.msgs[0]).__name__, self.msgs[0]) "%s", type(self.orig).__name__, self.orig)
if len(msgs) <= 0: # Treat message
return
# Qualify the message
if not hasattr(msg, "server") or msg.server is None:
msg.server = self.srv.id
if hasattr(msg, "frm_owner"):
msg.frm_owner = (not hasattr(self.srv, "owner") or self.srv.owner == msg.frm)
# Treat the message
for msg in msgs: for msg in msgs:
for res in context.treater.treat_msg(msg): for res in context.treater.treat_msg(msg):
# Identify the destination # Identify destination
to_server = None to_server = None
if isinstance(res, str): if isinstance(res, str):
to_server = self.srv to_server = srv
elif not hasattr(res, "server"):
logger.error("No server defined for response of type %s: %s", type(res).__name__, res)
continue
elif res.server is None: elif res.server is None:
to_server = self.srv to_server = srv
res.server = self.srv.id res.server = srv.fileno()
elif isinstance(res.server, str) and res.server in context.servers: elif res.server in context.servers:
to_server = context.servers[res.server] to_server = context.servers[res.server]
else:
to_server = res.server
if to_server is None: if to_server is None or not hasattr(to_server, "send_response") or not callable(to_server.send_response):
logger.error("The server defined in this response doesn't " logger.error("The server defined in this response doesn't exist: %s", res.server)
"exist: %s", res.server)
continue continue
# Sent the message only if treat_post authorize it # Sent message
to_server.send_response(res) to_server.send_response(res)
@ -94,12 +94,7 @@ class EventConsumer:
# Reappend the event in the queue if it has next iteration # Reappend the event in the queue if it has next iteration
if self.evt.next is not None: if self.evt.next is not None:
context.add_event(self.evt, eid=self.evt.id) context.add_event(self.evt)
# Or remove reference of this event
elif (hasattr(self.evt, "module_src") and
self.evt.module_src is not None):
self.evt.module_src.__nemubot_context__.events.remove(self.evt.id)
@ -108,20 +103,23 @@ class Consumer(threading.Thread):
"""Dequeue and exec requested action""" """Dequeue and exec requested action"""
def __init__(self, context): def __init__(self, context):
super().__init__(name="Nemubot consumer")
self.context = context self.context = context
self.stop = False
threading.Thread.__init__(self)
def run(self): def run(self):
try: try:
while not self.stop: while True:
stm = self.context.cnsr_queue.get(True, 1) context = self.context()
stm.run(self.context) if context is None:
self.context.cnsr_queue.task_done() break
stm = context.cnsr_queue.get(True, 1)
stm.run(context)
context.cnsr_queue.task_done()
except queue.Empty: except queue.Empty:
pass pass
finally: finally:
self.context.cnsr_thrd_size -= 2 if self.context() is not None:
self.context.cnsr_thrd.remove(self) self.context().cnsr_thrd.remove(self)

View file

@ -16,16 +16,3 @@
from nemubot.datastore.abstract import Abstract from nemubot.datastore.abstract import Abstract
from nemubot.datastore.xml import XML from nemubot.datastore.xml import XML
def reload():
global Abstract, XML
import imp
import nemubot.datastore.abstract
imp.reload(nemubot.datastore.abstract)
Abstract = nemubot.datastore.abstract.Abstract
import nemubot.datastore.xml
imp.reload(nemubot.datastore.xml)
XML = nemubot.datastore.xml.XML

View file

@ -32,10 +32,3 @@ class IMException(Exception):
from nemubot.message import Text from nemubot.message import Text
return Text(*self.args, return Text(*self.args,
server=msg.server, to=msg.to_response) server=msg.server, to=msg.to_response)
def reload():
import imp
import nemubot.exception.Keyword
imp.reload(nemubot.exception.printer.IRC)

View file

@ -49,23 +49,3 @@ class hook:
def pre(*args, store=["pre"], **kwargs): def pre(*args, store=["pre"], **kwargs):
return hook._add(store, Abstract, *args, **kwargs) return hook._add(store, Abstract, *args, **kwargs)
def reload():
import imp
import nemubot.hooks.abstract
imp.reload(nemubot.hooks.abstract)
import nemubot.hooks.command
imp.reload(nemubot.hooks.command)
import nemubot.hooks.message
imp.reload(nemubot.hooks.message)
import nemubot.hooks.keywords
imp.reload(nemubot.hooks.keywords)
nemubot.hooks.keywords.reload()
import nemubot.hooks.manager
imp.reload(nemubot.hooks.manager)

View file

@ -26,11 +26,22 @@ class NoKeyword(Abstract):
return super().check(mkw) return super().check(mkw)
def reload(): class AnyKeyword(Abstract):
import imp
import nemubot.hooks.keywords.abstract def __init__(self, h):
imp.reload(nemubot.hooks.keywords.abstract) """Class that accepts any passed keywords
import nemubot.hooks.keywords.dict Arguments:
imp.reload(nemubot.hooks.keywords.dict) h -- Help string
"""
super().__init__()
self.h = h
def check(self, mkw):
return super().check(mkw)
def help(self):
return self.h

View file

@ -19,27 +19,3 @@ from nemubot.message.text import Text
from nemubot.message.directask import DirectAsk from nemubot.message.directask import DirectAsk
from nemubot.message.command import Command from nemubot.message.command import Command
from nemubot.message.command import OwnerCommand from nemubot.message.command import OwnerCommand
def reload():
global Abstract, Text, DirectAsk, Command, OwnerCommand
import imp
import nemubot.message.abstract
imp.reload(nemubot.message.abstract)
Abstract = nemubot.message.abstract.Abstract
imp.reload(nemubot.message.text)
Text = nemubot.message.text.Text
imp.reload(nemubot.message.directask)
DirectAsk = nemubot.message.directask.DirectAsk
imp.reload(nemubot.message.command)
Command = nemubot.message.command.Command
OwnerCommand = nemubot.message.command.OwnerCommand
import nemubot.message.visitor
imp.reload(nemubot.message.visitor)
import nemubot.message.printer
imp.reload(nemubot.message.printer)
nemubot.message.printer.reload()

View file

@ -22,7 +22,7 @@ class Command(Abstract):
"""This class represents a specialized TextMessage""" """This class represents a specialized TextMessage"""
def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs): def __init__(self, cmd, args=None, kwargs=None, *nargs, **kargs):
Abstract.__init__(self, *nargs, **kargs) super().__init__(*nargs, **kargs)
self.cmd = cmd self.cmd = cmd
self.args = args if args is not None else list() self.args = args if args is not None else list()

View file

@ -28,7 +28,7 @@ class DirectAsk(Text):
designated -- the user designated by the message designated -- the user designated by the message
""" """
Text.__init__(self, *args, **kargs) super().__init__(*args, **kargs)
self.designated = designated self.designated = designated

View file

@ -22,4 +22,4 @@ class IRC(SocketPrinter):
def visit_Text(self, msg): def visit_Text(self, msg):
self.pp += "PRIVMSG %s :" % ",".join(msg.to) self.pp += "PRIVMSG %s :" % ",".join(msg.to)
SocketPrinter.visit_Text(self, msg) super().visit_Text(msg)

View file

@ -13,12 +13,3 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
def reload():
import imp
import nemubot.message.printer.IRC
imp.reload(nemubot.message.printer.IRC)
import nemubot.message.printer.socket
imp.reload(nemubot.message.printer.socket)

View file

@ -35,7 +35,7 @@ class Socket(AbstractVisitor):
others = [to for to in msg.to if to != msg.designated] others = [to for to in msg.to if to != msg.designated]
# Avoid nick starting message when discussing on user channel # Avoid nick starting message when discussing on user channel
if len(others) != len(msg.to): if len(others) == 0 or len(others) != len(msg.to):
res = Text(msg.message, res = Text(msg.message,
server=msg.server, date=msg.date, server=msg.server, date=msg.date,
to=msg.to, frm=msg.frm) to=msg.to, frm=msg.frm)

View file

@ -1,5 +1,3 @@
# coding=utf-8
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2015 Mercier Pierre-Olivier
# #
@ -16,8 +14,21 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
class PromptReset(Exception): from nemubot.message.abstract import Abstract
def __init__(self, type):
super(PromptReset, self).__init__("Prompt reset asked") class Response(Abstract):
self.type = type
def __init__(self, cmd, args=None, *nargs, **kargs):
super().__init__(*nargs, **kargs)
self.cmd = cmd
self.args = args if args is not None else list()
def __str__(self):
return self.cmd + " @" + ",@".join(self.args)
@property
def cmds(self):
# TODO: this is for legacy modules
return [self.cmd] + self.args

View file

@ -28,7 +28,7 @@ class Text(Abstract):
message -- the parsed message message -- the parsed message
""" """
Abstract.__init__(self, *args, **kargs) super().__init__(*args, **kargs)
self.message = message self.message = message

View file

@ -16,19 +16,15 @@
class ModuleContext: class ModuleContext:
def __init__(self, context, module): def __init__(self, context, module_name, logger):
"""Initialize the module context """Initialize the module context
arguments: arguments:
context -- the bot context context -- the bot context
module -- the module module_name -- the module name
logger -- a logger
""" """
if module is not None:
module_name = module.__spec__.name if hasattr(module, "__spec__") else module.__name__
else:
module_name = ""
# Load module configuration if exists # Load module configuration if exists
if (context is not None and if (context is not None and
module_name in context.modules_configuration): module_name in context.modules_configuration):
@ -60,10 +56,16 @@ class ModuleContext:
def subtreat(msg): def subtreat(msg):
yield from context.treater.treat_msg(msg) yield from context.treater.treat_msg(msg)
def add_event(evt, eid=None): def add_event(call=None, **kwargs):
return context.add_event(evt, eid, module_src=module) evt = context.add_event(call, **kwargs)
if evt is not None:
self.events.append(evt)
return evt
def del_event(evt): def del_event(evt):
return context.del_event(evt, module_src=module) if context.del_event(evt):
self._clean_events()
return True
return False
def send_response(server, res): def send_response(server, res):
if server in context.servers: if server in context.servers:
@ -72,7 +74,7 @@ class ModuleContext:
else: else:
return context.servers[server].send_response(res) return context.servers[server].send_response(res)
else: else:
module.logger.error("Try to send a message to the unknown server: %s", server) logger.error("Try to send a message to the unknown server: %s", server)
return False return False
else: # Used when using outside of nemubot else: # Used when using outside of nemubot
@ -88,13 +90,13 @@ class ModuleContext:
self.hooks.remove((triggers, hook)) self.hooks.remove((triggers, hook))
def subtreat(msg): def subtreat(msg):
return None return None
def add_event(evt, eid=None): def add_event(evt):
return context.add_event(evt, eid, module_src=module) return context.add_event(evt)
def del_event(evt): def del_event(evt):
return context.del_event(evt, module_src=module) return context.del_event(evt, module_src=module)
def send_response(server, res): def send_response(server, res):
module.logger.info("Send response: %s", res) logger.info("Send response: %s", res)
def save(): def save():
context.datastore.save(module_name, self.data) context.datastore.save(module_name, self.data)
@ -121,6 +123,14 @@ class ModuleContext:
return self._data return self._data
def _clean_events(self):
"""Look for None weakref in the events list"""
for i in range(len(self.events), 0, -1):
if self.events[i-1]() is None:
self.events.remove()
def unload(self): def unload(self):
"""Perform actions for unloading the module""" """Perform actions for unloading the module"""

View file

@ -1,142 +0,0 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import shlex
import sys
import traceback
from nemubot.prompt import builtins
class Prompt:
def __init__(self):
self.selectedServer = None
self.lastretcode = 0
self.HOOKS_CAPS = dict()
self.HOOKS_LIST = dict()
def add_cap_hook(self, name, call, data=None):
self.HOOKS_CAPS[name] = lambda t, c: call(t, data=data,
context=c, prompt=self)
def add_list_hook(self, name, call):
self.HOOKS_LIST[name] = call
def lex_cmd(self, line):
"""Return an array of tokens
Argument:
line -- the line to lex
"""
try:
cmds = shlex.split(line)
except:
exc_type, exc_value, _ = sys.exc_info()
sys.stderr.write(traceback.format_exception_only(exc_type,
exc_value)[0])
return
bgn = 0
# Separate commands (command separator: ;)
for i in range(0, len(cmds)):
if cmds[i][-1] == ';':
if i != bgn:
yield cmds[bgn:i]
bgn = i + 1
# Return rest of the command (that not end with a ;)
if bgn != len(cmds):
yield cmds[bgn:]
def exec_cmd(self, toks, context):
"""Execute the command
Arguments:
toks -- lexed tokens to executes
context -- current bot context
"""
if toks[0] in builtins.CAPS:
self.lastretcode = builtins.CAPS[toks[0]](toks, context, self)
elif toks[0] in self.HOOKS_CAPS:
self.lastretcode = self.HOOKS_CAPS[toks[0]](toks, context)
else:
print("Unknown command: `%s'" % toks[0])
self.lastretcode = 127
def getPS1(self):
"""Get the PS1 associated to the selected server"""
if self.selectedServer is None:
return "nemubot"
else:
return self.selectedServer.id
def run(self, context):
"""Launch the prompt
Argument:
context -- current bot context
"""
from nemubot.prompt.error import PromptError
from nemubot.prompt.reset import PromptReset
while True: # Stopped by exception
try:
line = input("\033[0;33m%s\033[0;%d\033[0m " %
(self.getPS1(), 31 if self.lastretcode else 32))
cmds = self.lex_cmd(line.strip())
for toks in cmds:
try:
self.exec_cmd(toks, context)
except PromptReset:
raise
except PromptError as e:
print(e.message)
self.lastretcode = 128
except:
exc_type, exc_value, exc_traceback = sys.exc_info()
traceback.print_exception(exc_type, exc_value,
exc_traceback)
except KeyboardInterrupt:
print("")
except EOFError:
print("quit")
return True
def hotswap(bak):
p = Prompt()
p.HOOKS_CAPS = bak.HOOKS_CAPS
p.HOOKS_LIST = bak.HOOKS_LIST
return p
def reload():
import imp
import nemubot.prompt.builtins
imp.reload(nemubot.prompt.builtins)
import nemubot.prompt.error
imp.reload(nemubot.prompt.error)
import nemubot.prompt.reset
imp.reload(nemubot.prompt.reset)

View file

@ -1,128 +0,0 @@
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
def end(toks, context, prompt):
"""Quit the prompt for reload or exit"""
from nemubot.prompt.reset import PromptReset
if toks[0] == "refresh":
raise PromptReset("refresh")
elif toks[0] == "reset":
raise PromptReset("reset")
raise PromptReset("quit")
def liste(toks, context, prompt):
"""Show some lists"""
if len(toks) > 1:
for l in toks[1:]:
l = l.lower()
if l == "server" or l == "servers":
for srv in context.servers.keys():
print (" - %s (state: %s) ;" % (srv,
"connected" if context.servers[srv].connected else "disconnected"))
if len(context.servers) == 0:
print (" > No server loaded")
elif l == "mod" or l == "mods" or l == "module" or l == "modules":
for mod in context.modules.keys():
print (" - %s ;" % mod)
if len(context.modules) == 0:
print (" > No module loaded")
elif l in prompt.HOOKS_LIST:
f, d = prompt.HOOKS_LIST[l]
f(d, context, prompt)
else:
print (" Unknown list `%s'" % l)
return 2
return 0
else:
print (" Please give a list to show: servers, ...")
return 1
def load(toks, context, prompt):
"""Load an XML configuration file"""
if len(toks) > 1:
for filename in toks[1:]:
context.load_file(filename)
else:
print ("Not enough arguments. `load' takes a filename.")
return 1
def select(toks, context, prompt):
"""Select the current server"""
if (len(toks) == 2 and toks[1] != "None" and
toks[1] != "nemubot" and toks[1] != "none"):
if toks[1] in context.servers:
prompt.selectedServer = context.servers[toks[1]]
else:
print ("select: server `%s' not found." % toks[1])
return 1
else:
prompt.selectedServer = None
def unload(toks, context, prompt):
"""Unload a module"""
if len(toks) == 2 and toks[1] == "all":
for name in context.modules.keys():
context.unload_module(name)
elif len(toks) > 1:
for name in toks[1:]:
if context.unload_module(name):
print (" Module `%s' successfully unloaded." % name)
else:
print (" No module `%s' loaded, can't unload!" % name)
return 2
else:
print ("Not enough arguments. `unload' takes a module name.")
return 1
def debug(toks, context, prompt):
"""Enable/Disable debug mode on a module"""
if len(toks) > 1:
for name in toks[1:]:
if name in context.modules:
context.modules[name].DEBUG = not context.modules[name].DEBUG
if context.modules[name].DEBUG:
print (" Module `%s' now in DEBUG mode." % name)
else:
print (" Debug for module module `%s' disabled." % name)
else:
print (" No module `%s' loaded, can't debug!" % name)
return 2
else:
print ("Not enough arguments. `debug' takes a module name.")
return 1
# Register build-ins
CAPS = {
'quit': end, # Disconnect all server and quit
'exit': end, # Alias for quit
'reset': end, # Reload the prompt
'refresh': end, # Reload the prompt but save modules
'load': load, # Load a servers or module configuration file
'unload': unload, # Unload a module and remove it from the list
'select': select, # Select a server
'list': liste, # Show lists
'debug': debug, # Pass a module in debug mode
}

View file

@ -31,7 +31,7 @@ PORTS = list()
class DCC(server.AbstractServer): class DCC(server.AbstractServer):
def __init__(self, srv, dest, socket=None): def __init__(self, srv, dest, socket=None):
server.Server.__init__(self) super().__init__(name="Nemubot DCC server")
self.error = False # An error has occur, closing the connection? self.error = False # An error has occur, closing the connection?
self.messages = list() # Message queued before connexion self.messages = list() # Message queued before connexion

View file

@ -20,17 +20,17 @@ import re
from nemubot.channel import Channel from nemubot.channel import Channel
from nemubot.message.printer.IRC import IRC as IRCPrinter from nemubot.message.printer.IRC import IRC as IRCPrinter
from nemubot.server.message.IRC import IRC as IRCMessage from nemubot.server.message.IRC import IRC as IRCMessage
from nemubot.server.socket import SocketServer from nemubot.server.socket import SocketServer, SecureSocketServer
class IRC(SocketServer): class IRC():
"""Concrete implementation of a connexion to an IRC server""" """Concrete implementation of a connection to an IRC server"""
def __init__(self, host="localhost", port=6667, ssl=False, owner=None, def __init__(self, host="localhost", port=6667, owner=None,
nick="nemubot", username=None, password=None, nick="nemubot", username=None, password=None,
realname="Nemubot", encoding="utf-8", caps=None, realname="Nemubot", encoding="utf-8", caps=None,
channels=list(), on_connect=None): channels=list(), on_connect=None, **kwargs):
"""Prepare a connection with an IRC server """Prepare a connection with an IRC server
Keyword arguments: Keyword arguments:
@ -54,8 +54,8 @@ class IRC(SocketServer):
self.owner = owner self.owner = owner
self.realname = realname self.realname = realname
self.id = self.username + "@" + host + ":" + str(port) super().__init__(name=self.username + "@" + host + ":" + str(port),
SocketServer.__init__(self, host=host, port=port, ssl=ssl) host=host, port=port, **kwargs)
self.printer = IRCPrinter self.printer = IRCPrinter
self.encoding = encoding self.encoding = encoding
@ -232,29 +232,29 @@ class IRC(SocketServer):
# Open/close # Open/close
def _open(self): def connect(self):
if SocketServer._open(self): super().connect()
if self.password is not None:
self.write("PASS :" + self.password) if self.password is not None:
if self.capabilities is not None: self.write("PASS :" + self.password)
self.write("CAP LS") if self.capabilities is not None:
self.write("NICK :" + self.nick) self.write("CAP LS")
self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname)) self.write("NICK :" + self.nick)
return True self.write("USER %s %s bla :%s" % (self.username, self.host, self.realname))
return False
def _close(self): def close(self):
if self.connected: self.write("QUIT") if not self._closed:
return SocketServer._close(self) self.write("QUIT")
return super().close()
# Writes: as inherited # Writes: as inherited
# Read # Read
def read(self): def async_read(self):
for line in SocketServer.read(self): for line in super().async_read():
# PING should be handled here, so start parsing here :/ # PING should be handled here, so start parsing here :/
msg = IRCMessage(line, self.encoding) msg = IRCMessage(line, self.encoding)

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,75 +14,65 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
_lock = threading.Lock() def factory(uri, ssl=False, **init_args):
from urllib.parse import urlparse, unquote, parse_qs
# Lists for select
_rlist = []
_wlist = []
_xlist = []
def factory(uri, **init_args):
from urllib.parse import urlparse, unquote
o = urlparse(uri) o = urlparse(uri)
srv = None
if o.scheme == "irc" or o.scheme == "ircs": if o.scheme == "irc" or o.scheme == "ircs":
# http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt # http://www.w3.org/Addressing/draft-mirashi-url-irc-01.txt
# http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html # http://www-archive.mozilla.org/projects/rt-messaging/chatzilla/irc-urls.html
args = init_args args = init_args
modifiers = o.path.split(",") if o.scheme == "ircs": ssl = True
target = unquote(modifiers.pop(0)[1:])
if o.scheme == "ircs": args["ssl"] = True
if o.hostname is not None: args["host"] = o.hostname if o.hostname is not None: args["host"] = o.hostname
if o.port is not None: args["port"] = o.port if o.port is not None: args["port"] = o.port
if o.username is not None: args["username"] = o.username if o.username is not None: args["username"] = o.username
if o.password is not None: args["password"] = o.password if o.password is not None: args["password"] = o.password
queries = o.query.split("&") if ssl:
for q in queries: try:
if "=" in q: from ssl import create_default_context
key, val = tuple(q.split("=", 1)) args["_context"] = create_default_context()
else: except ImportError:
key, val = q, "" # Python 3.3 compat
if key == "msg": from ssl import SSLContext, PROTOCOL_TLSv1
if "on_connect" not in args: args["_context"] = SSLContext(PROTOCOL_TLSv1)
args["on_connect"] = [] args["server_hostname"] = o.hostname
args["on_connect"].append("PRIVMSG %s :%s" % (target, unquote(val)))
elif key == "key":
if "channels" not in args:
args["channels"] = []
args["channels"].append((target, unquote(val)))
elif key == "pass":
args["password"] = unquote(val)
elif key == "charset":
args["encoding"] = unquote(val)
modifiers = o.path.split(",")
target = unquote(modifiers.pop(0)[1:])
# Read query string
params = parse_qs(o.query)
if "msg" in params:
if "on_connect" not in args:
args["on_connect"] = []
args["on_connect"].append("PRIVMSG %s :%s" % (target, params["msg"]))
if "key" in params:
if "channels" not in args:
args["channels"] = []
args["channels"].append((target, params["key"]))
if "pass" in params:
args["password"] = params["pass"]
if "charset" in params:
args["encoding"] = params["charset"]
#
if "channels" not in args and "isnick" not in modifiers: if "channels" not in args and "isnick" not in modifiers:
args["channels"] = [ target ] args["channels"] = [ target ]
from nemubot.server.IRC import IRC as IRCServer if ssl:
return IRCServer(**args) from nemubot.server.IRC import IRC_secure as SecureIRCServer
else: srv = SecureIRCServer(**args)
return None else:
from nemubot.server.IRC import IRC as IRCServer
srv = IRCServer(**args)
return srv
def reload():
import imp
import nemubot.server.abstract
imp.reload(nemubot.server.abstract)
import nemubot.server.socket
imp.reload(nemubot.server.socket)
import nemubot.server.IRC
imp.reload(nemubot.server.IRC)
import nemubot.server.message
imp.reload(nemubot.server.message)
nemubot.server.message.reload()

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,71 +14,64 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import io
import logging import logging
import queue import queue
from nemubot.server import _lock, _rlist, _wlist, _xlist #from nemubot.bot import sync_act
# Extends from IOBase in order to be compatible with select function
class AbstractServer(io.IOBase): class AbstractServer:
"""An abstract server: handle communication with an IM server""" """An abstract server: handle communication with an IM server"""
def __init__(self, send_callback=None): def __init__(self, name=None, **kwargs):
"""Initialize an abstract server """Initialize an abstract server
Keyword argument: Keyword argument:
send_callback -- Callback when developper want to send a message name -- Identifier of the socket, for convinience
""" """
if not hasattr(self, "id"): self._name = name
raise Exception("No id defined for this server. Please set one!")
self.logger = logging.getLogger("nemubot.server." + self.id) super().__init__(**kwargs)
self.logger = logging.getLogger("nemubot.server." + str(self.name))
self._readbuffer = b''
self._sending_queue = queue.Queue() self._sending_queue = queue.Queue()
if send_callback is not None:
self._send_callback = send_callback
@property
def name(self):
if self._name is not None:
return self._name
else: else:
self._send_callback = self._write_select return self.fileno()
# Open/close # Open/close
def __enter__(self): def connect(self, *args, **kwargs):
self.open() """Register the server in _poll"""
return self
self.logger.info("Opening connection")
super().connect(*args, **kwargs)
self._on_connect()
def _on_connect(self):
sync_act("sckt", "register", self.fileno())
def __exit__(self, type, value, traceback): def close(self, *args, **kwargs):
self.close() """Unregister the server from _poll"""
self.logger.info("Closing connection")
def open(self): if self.fileno() > 0:
"""Generic open function that register the server un _rlist in case sync_act("sckt", "unregister", self.fileno())
of successful _open"""
self.logger.info("Opening connection to %s", self.id)
if not hasattr(self, "_open") or self._open():
_rlist.append(self)
_xlist.append(self)
return True
return False
super().close(*args, **kwargs)
def close(self):
"""Generic close function that register the server un _{r,w,x}list in
case of successful _close"""
self.logger.info("Closing connection to %s", self.id)
with _lock:
if not hasattr(self, "_close") or self._close():
if self in _rlist:
_rlist.remove(self)
if self in _wlist:
_wlist.remove(self)
if self in _xlist:
_xlist.remove(self)
return True
return False
# Writes # Writes
@ -90,13 +83,16 @@ class AbstractServer(io.IOBase):
message -- message to send message -- message to send
""" """
self._send_callback(message) self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to write queue", message)
sync_act("sckt", "write", self.fileno())
def write_select(self): def async_write(self):
"""Internal function used by the select function""" """Internal function used when the file descriptor is writable"""
try: try:
_wlist.remove(self) sync_act("sckt", "unwrite", self.fileno())
while not self._sending_queue.empty(): while not self._sending_queue.empty():
self._write(self._sending_queue.get_nowait()) self._write(self._sending_queue.get_nowait())
self._sending_queue.task_done() self._sending_queue.task_done()
@ -105,19 +101,6 @@ class AbstractServer(io.IOBase):
pass pass
def _write_select(self, message):
"""Send a message to the server safely through select
Argument:
message -- message to send
"""
self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to Queue", message)
if self not in _wlist:
_wlist.append(self)
def send_response(self, response): def send_response(self, response):
"""Send a formated Message class """Send a formated Message class
@ -140,13 +123,39 @@ class AbstractServer(io.IOBase):
# Read # Read
def async_read(self):
"""Internal function used when the file descriptor is readable
Returns:
A list of fully received messages
"""
ret, self._readbuffer = self.lex(self._readbuffer + self.read())
for r in ret:
yield r
def lex(self, buf):
"""Assume lexing in default case is per line
Argument:
buf -- buffer to lex
"""
msgs = buf.split(b'\r\n')
partial = msgs.pop()
return msgs, partial
def parse(self, msg): def parse(self, msg):
raise NotImplemented raise NotImplemented
# Exceptions # Exceptions
def exception(self): def exception(self, flags):
"""Exception occurs in fd""" """Exception occurs on fd"""
self.logger.warning("Unhandle file descriptor exception on server %s",
self.id) self.close()

View file

@ -22,34 +22,30 @@ class TestFactory(unittest.TestCase):
def test_IRC1(self): def test_IRC1(self):
from nemubot.server.IRC import IRC as IRCServer from nemubot.server.IRC import IRC as IRCServer
from nemubot.server.IRC import IRC_secure as IRCSServer
# <host>: If omitted, the client must connect to a prespecified default IRC server. # <host>: If omitted, the client must connect to a prespecified default IRC server.
server = factory("irc:///") server = factory("irc:///")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "localhost") self.assertEqual(server.host, "localhost")
self.assertFalse(server.ssl)
server = factory("ircs:///") server = factory("ircs:///")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "localhost") self.assertEqual(server.host, "localhost")
self.assertTrue(server.ssl)
server = factory("irc://host1") server = factory("irc://host1")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host1") self.assertEqual(server.host, "host1")
self.assertFalse(server.ssl)
server = factory("irc://host2:6667") server = factory("irc://host2:6667")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCServer)
self.assertEqual(server.host, "host2") self.assertEqual(server.host, "host2")
self.assertEqual(server.port, 6667) self.assertEqual(server.port, 6667)
self.assertFalse(server.ssl)
server = factory("ircs://host3:194/") server = factory("ircs://host3:194/")
self.assertIsInstance(server, IRCServer) self.assertIsInstance(server, IRCSServer)
self.assertEqual(server.host, "host3") self.assertEqual(server.host, "host3")
self.assertEqual(server.port, 194) self.assertEqual(server.port, 194)
self.assertTrue(server.ssl)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -146,7 +146,7 @@ class IRC(Abstract):
receivers = self.decode(self.params[0]).split(',') receivers = self.decode(self.params[0]).split(',')
common_args = { common_args = {
"server": srv.id, "server": srv.name,
"date": self.tags["time"], "date": self.tags["time"],
"to": receivers, "to": receivers,
"to_response": [r if r != srv.nick else self.nick for r in receivers], "to_response": [r if r != srv.nick else self.nick for r in receivers],

View file

@ -13,9 +13,3 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
class PromptError(Exception):
def __init__(self, message):
super(PromptError, self).__init__(message)
self.message = message

View file

@ -1,5 +1,5 @@
# Nemubot is a smart and modulable IM bot. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2015 Mercier Pierre-Olivier # Copyright (C) 2012-2016 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -14,99 +14,33 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import socket
import ssl
import nemubot.message as message
from nemubot.message.printer.socket import Socket as SocketPrinter from nemubot.message.printer.socket import Socket as SocketPrinter
from nemubot.server.abstract import AbstractServer from nemubot.server.abstract import AbstractServer
class SocketServer(AbstractServer): class _Socket(asyncio.Protocol):
"""Concrete implementation of a socket connexion (can be wrapped with TLS)""" """Concrete implementation of a socket connection"""
def __init__(self, sock_location=None, host=None, port=None, ssl=False, socket=None, id=None): def __init__(self, printer=SocketPrinter, **kwargs):
if id is not None: """Create a server socket
self.id = id """
AbstractServer.__init__(self)
if sock_location is not None: super().__init__(**kwargs)
self.filename = sock_location
elif host is not None:
self.host = host
self.port = int(port)
self.ssl = ssl
self.socket = socket
self.readbuffer = b'' self.readbuffer = b''
self.printer = SocketPrinter self.printer = printer
def fileno(self):
return self.socket.fileno() if self.socket else None
@property
def connected(self):
"""Indicator of the connection aliveness"""
return self.socket is not None
# Open/close
def _open(self):
import os
import socket
if self.connected:
return True
try:
if hasattr(self, "filename"):
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.socket.connect(self.filename)
self.logger.info("Connected to %s", self.filename)
else:
self.socket = socket.create_connection((self.host, self.port))
self.logger.info("Connected to %s:%d", self.host, self.port)
except socket.error as e:
self.socket = None
self.logger.critical("Unable to connect to %s:%d: %s",
self.host, self.port,
os.strerror(e.errno))
return False
# Wrap the socket for SSL
if self.ssl:
import ssl
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.socket = ctx.wrap_socket(self.socket)
return True
def _close(self):
import socket
from nemubot.server import _lock
_lock.release()
self._sending_queue.join()
_lock.acquire()
if self.connected:
try:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
except socket.error:
pass
self.socket = None
return True
# Write # Write
def _write(self, cnt): def _write(self, cnt):
if not self.connected: self.sendall(cnt)
return
self.socket.send(cnt)
def format(self, txt): def format(self, txt):
@ -118,80 +52,134 @@ class SocketServer(AbstractServer):
# Read # Read
def read(self): def recv(self, n=1024):
if not self.connected: return super().recv(n)
return []
raw = self.socket.recv(1024)
temp = (self.readbuffer + raw).split(b'\r\n')
self.readbuffer = temp.pop()
for line in temp:
yield line
class SocketListener(AbstractServer): def parse(self, line):
"""Implement a default behaviour for socket"""
import shlex
def __init__(self, new_server_cb, id, sock_location=None, host=None, port=None, ssl=None): line = line.strip().decode()
self.id = id try:
AbstractServer.__init__(self) args = shlex.split(line)
self.new_server_cb = new_server_cb except ValueError:
self.sock_location = sock_location args = line.split(' ')
self.host = host
self.port = port if len(args):
self.ssl = ssl yield message.Command(cmd=args[0], args=args[1:], server=self.fileno(), to=["you"], frm="you")
self.nb_son = 0
def fileno(self): def subparse(self, orig, cnt):
return self.socket.fileno() if self.socket else None for m in self.parse(cnt):
m.to = orig.to
m.frm = orig.frm
m.date = orig.date
yield m
class _SocketServer(_Socket):
def __init__(self, host, port, bind=None, **kwargs):
super().__init__(family=socket.AF_INET, **kwargs)
assert(host is not None)
assert(isinstance(port, int))
self._host = host
self._port = port
self._bind = bind
@property @property
def connected(self): def host(self):
"""Indicator of the connection aliveness""" return self._host
return self.socket is not None
def _open(self): def connect(self):
import os self.logger.info("Connection to %s:%d", self._host, self._port)
import socket super().connect((self._host, self._port))
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) if self._bind:
if self.sock_location is not None: super().bind(self._bind)
try:
os.remove(self.sock_location)
except FileNotFoundError:
pass
self.socket.bind(self.sock_location)
elif self.host is not None and self.port is not None:
self.socket.bind((self.host, self.port))
self.socket.listen(5)
return True
def _close(self): class SocketServer(_SocketServer, socket.socket):
pass
class SecureSocketServer(_SocketServer, ssl.SSLSocket):
pass
class UnixSocket:
def __init__(self, location, **kwargs):
super().__init__(family=socket.AF_UNIX, **kwargs)
self._socket_path = location
def connect(self):
self.logger.info("Connection to unix://%s", self._socket_path)
super().connect(self._socket_path)
class _Listener:
def __init__(self, new_server_cb, instanciate=_Socket, **kwargs):
super().__init__(**kwargs)
self._instanciate = instanciate
self._new_server_cb = new_server_cb
def read(self):
conn, addr = self.accept()
fileno = conn.fileno()
self.logger.info("Accept new connection from %s (fd=%d)", addr, fileno)
ss = self._instanciate(name=self.name + "#" + str(fileno), fileno=conn.detach())
ss.connect = ss._on_connect
self._new_server_cb(ss, autoconnect=True)
return b''
class UnixSocketListener(_Listener, UnixSocket, _Socket, socket.socket):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def connect(self):
self.logger.info("Creating Unix socket at unix://%s", self._socket_path)
try:
os.remove(self._socket_path)
except FileNotFoundError:
pass
self.bind(self._socket_path)
self.listen(5)
self.logger.info("Socket ready for accepting new connections")
self._on_connect()
def close(self):
import os import os
import socket import socket
try: try:
self.socket.shutdown(socket.SHUT_RDWR) self.shutdown(socket.SHUT_RDWR)
self.socket.close()
if self.sock_location is not None:
os.remove(self.sock_location)
except socket.error: except socket.error:
pass pass
# Read super().close()
def read(self): try:
if not self.connected: if self._socket_path is not None:
return [] os.remove(self._socket_path)
except:
conn, addr = self.socket.accept() pass
self.nb_son += 1
ss = SocketServer(id=self.id + "#" + str(self.nb_son), socket=conn)
self.new_server_cb(ss)
return []

View file

@ -13,29 +13,3 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
def reload():
import imp
import nemubot.tools.config
imp.reload(nemubot.tools.config)
import nemubot.tools.countdown
imp.reload(nemubot.tools.countdown)
import nemubot.tools.feed
imp.reload(nemubot.tools.feed)
import nemubot.tools.date
imp.reload(nemubot.tools.date)
import nemubot.tools.human
imp.reload(nemubot.tools.human)
import nemubot.tools.web
imp.reload(nemubot.tools.web)
import nemubot.tools.xmlparser
imp.reload(nemubot.tools.xmlparser)
import nemubot.tools.xmlparser.node
imp.reload(nemubot.tools.xmlparser.node)

View file

@ -69,7 +69,6 @@ setup(
'nemubot.hooks.keywords', 'nemubot.hooks.keywords',
'nemubot.message', 'nemubot.message',
'nemubot.message.printer', 'nemubot.message.printer',
'nemubot.prompt',
'nemubot.server', 'nemubot.server',
'nemubot.server.message', 'nemubot.server.message',
'nemubot.tools', 'nemubot.tools',