Change add_server behaviour, fix IRC parameters parsing, can use with Python statement for managing server scope

This commit is contained in:
nemunaire 2014-10-09 07:37:52 +02:00
parent f9ee1fe898
commit 4dd837cf4b
8 changed files with 249 additions and 68 deletions

19
bot.py
View File

@ -25,7 +25,7 @@ import threading
import time import time
import uuid import uuid
__version__ = '3.4.dev1' __version__ = '3.4.dev2'
__author__ = 'nemunaire' __author__ = 'nemunaire'
from consumer import Consumer, EventConsumer, MessageConsumer from consumer import Consumer, EventConsumer, MessageConsumer
@ -33,8 +33,6 @@ from event import ModuleEvent
from hooks.messagehook import MessageHook from hooks.messagehook import MessageHook
from hooks.manager import HooksManager from hooks.manager import HooksManager
from networkbot import NetworkBot from networkbot import NetworkBot
from server.IRC import IRC as IRCServer
from server.DCC import DCC
logger = logging.getLogger("nemubot.bot") logger = logging.getLogger("nemubot.bot")
@ -312,13 +310,20 @@ class Bot(threading.Thread):
c.start() c.start()
def add_server(self, node, nick, owner, realname): def add_server(self, srv, autoconnect=False):
"""Add a new server to the context""" """Add a new server to the context
srv = IRCServer(node, nick, owner, realname)
Arguments:
srv -- a concrete AbstractServer instance
autoconnect -- connect after add?
"""
if srv.id not in self.servers: if srv.id not in self.servers:
self.servers[srv.id] = srv self.servers[srv.id] = srv
srv.open() if autoconnect:
srv.open()
return True return True
else: else:
return False return False

View File

@ -19,8 +19,9 @@
import traceback import traceback
import sys import sys
from networkbot import NetworkBot
from hooks import hook from hooks import hook
from message import TextMessage
from networkbot import NetworkBot
nemubotversion = 3.4 nemubotversion = 3.4
NODATA = True NODATA = True
@ -198,7 +199,7 @@ def send(data, toks, context, prompt):
print ("send: not enough arguments.") print ("send: not enough arguments.")
return return
srv.send_msg_final(chan, toks[rd]) srv.send_response(TextMessage(" ".join(toks[rd:]), server=None, to=[chan]))
return "done" return "done"
@hook("prompt_cmd", "zap") @hook("prompt_cmd", "zap")

View File

@ -1,8 +1,8 @@
#!/usr/bin/python3 #!/usr/bin/env python3.2
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Nemubot is a modulable IRC bot, built around XML configuration files. # Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012 Mercier Pierre-Olivier # Copyright (C) 2012-2014 Mercier Pierre-Olivier
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -21,7 +21,6 @@ import imp
import logging import logging
import os import os
import sys import sys
import traceback
import bot import bot
import prompt import prompt

View File

@ -19,10 +19,12 @@
import imp import imp
import logging import logging
import os import os
import xmlparser
logger = logging.getLogger("nemubot.prompt.builtins") logger = logging.getLogger("nemubot.prompt.builtins")
from server.IRC import IRC as IRCServer
import xmlparser
def end(toks, context, prompt): def end(toks, context, prompt):
"""Quit the prompt for reload or exit""" """Quit the prompt for reload or exit"""
if toks[0] == "refresh": if toks[0] == "refresh":
@ -67,16 +69,58 @@ def load_file(filename, context):
or config.getName() == "nemubotconfig"): or config.getName() == "nemubotconfig"):
# Preset each server in this file # Preset each server in this file
for server in config.getNodes("server"): for server in config.getNodes("server"):
ip = server["ip"] if server.hasAttribute("ip") else config["ip"] opts = {
nick = server["nick"] if server.hasAttribute("nick") else config["nick"] "host": server["host"],
owner = server["owner"] if server.hasAttribute("owner") else config["owner"] "ssl": server.hasAttribute("ssl") and server["ssl"].lower() == "true",
realname = server["realname"] if server.hasAttribute("realname") else config["realname"]
if context.add_server(server, nick, owner, realname): "nick": server["nick"] if server.hasAttribute("nick") else config["nick"],
print("Server `%s:%s' successfully added." % "owner": server["owner"] if server.hasAttribute("owner") else config["owner"],
(server["host"], server["port"])) }
# Optional keyword arguments
for optional_opt in [ "port", "realname", "password", "encoding", "caps" ]:
if server.hasAttribute(optional_opt):
opts[optional_opt] = server[optional_opt]
elif optional_opt in config:
opts[optional_opt] = config[optional_opt]
# Command to send on connection
if "on_connect" in server:
def on_connect():
yield server["on_connect"]
opts["on_connect"] = on_connect
# Channels to autojoin on connection
if server.hasNode("channel"):
opts["channels"] = list()
for chn in server.getNodes("channel"):
opts["channels"].append((chn["name"], chn["password"]) if chn["password"] is not None else chn["name"])
# Server/client capabilities
if "caps" in server or "caps" in config:
capsl = (server["caps"] if server.hasAttribute("caps") else config["caps"]).lower()
if capsl == "no" or capsl == "off" or capsl == "false":
opts["caps"] = None
else:
opts["caps"] = capsl.split(',')
else: else:
print("Server `%s:%s' already added, skiped." % opts["caps"] = list()
(server["host"], server["port"]))
# Bind the protocol asked to the corresponding implementation
if "protocol" not in server or server["protocol"] == "irc":
srvcls = IRCServer
else:
raise Exception("Unhandled protocol '%s'" % server["protocol"])
# Initialize the server
srv = srvcls(**opts)
# Add the server in the context
if context.add_server(srv,
"autoconnect" in server and server["autoconnect"].lower() != "false"):
print("Server '%s' successfully added." % srv.id)
else:
print("Can't add server '%s'." % srv.id)
# Load module and their configuration # Load module and their configuration
for mod in config.getNodes("module"): for mod in config.getNodes("module"):

View File

@ -26,37 +26,46 @@ from channel import Channel
import message import message
from message.printer.IRC import IRC as IRCPrinter from message.printer.IRC import IRC as IRCPrinter
from server.socket import SocketServer from server.socket import SocketServer
import tools
class IRC(SocketServer): class IRC(SocketServer):
def __init__(self, node, nick, owner, realname): def __init__(self, owner, nick="nemubot", host="localhost", port=6667,
self.id = nick + "@" + node["host"] + ":" + node["port"] ssl=False, password=None, realname="Nemubot",
self.printer = IRCPrinter encoding="utf-8", caps=None, channels=list(),
SocketServer.__init__(self, on_connect=None):
node["host"], """Prepare a connection with an IRC server
node["port"],
node["password"],
node.hasAttribute("ssl") and node["ssl"].lower() == "true")
Keyword arguments:
owner -- bot's owner
nick -- bot's nick
host -- host to join
port -- port on the host to reach
ssl -- is this server using a TLS socket
password -- if a password is required to connect to the server
realname -- the bot's realname
encoding -- the encoding used on the whole server
caps -- client capabilities to register on the server
channels -- list of channels to join on connection (if a channel is password protected, give a tuple: (channel_name, password))
on_connect -- generator to call when connection is done
"""
self.id = nick + "@" + host + ":" + port
self.printer = IRCPrinter
SocketServer.__init__(self, host=host, port=port, ssl=ssl)
self.password = password
self.nick = nick self.nick = nick
self.owner = owner self.owner = owner
self.realname = realname self.realname = realname
#Keep a list of connected channels self.encoding = encoding
# Keep a list of joined channels
self.channels = dict() self.channels = dict()
if node.hasAttribute("encoding"): # Server/client capabilities
self.encoding = node["encoding"] self.capabilities = caps
else:
self.encoding = "utf-8"
if node.hasAttribute("caps"):
if node["caps"].lower() == "no":
self.capabilities = None
else:
self.capabilities = node["caps"].split(",")
else:
self.capabilities = list()
# Register CTCP capabilities # Register CTCP capabilities
self.ctcp_capabilities = dict() self.ctcp_capabilities = dict()
@ -68,7 +77,7 @@ class IRC(SocketServer):
def _ctcp_dcc(msg, cmds): def _ctcp_dcc(msg, cmds):
"""Response to DCC CTCP message""" """Response to DCC CTCP message"""
try: try:
ip = srv.toIP(int(cmds[3])) ip = tools.toIP(int(cmds[3]))
port = int(cmds[4]) port = int(cmds[4])
conn = DCC(srv, msg.sender) conn = DCC(srv, msg.sender)
except: except:
@ -98,6 +107,7 @@ class IRC(SocketServer):
self.logger.debug("CTCP capabilities setup: %s", ", ".join(self.ctcp_capabilities)) self.logger.debug("CTCP capabilities setup: %s", ", ".join(self.ctcp_capabilities))
# Register hooks on some IRC CMD # Register hooks on some IRC CMD
self.hookscmd = dict() self.hookscmd = dict()
@ -109,14 +119,15 @@ class IRC(SocketServer):
# Respond to 001 # Respond to 001
def _on_connect(msg): def _on_connect(msg):
# First, send user defined command # First, send user defined command
if node.hasAttribute("on_connect"): if on_connect is not None:
self.write(node["on_connect"]) for oc in on_connect():
self.write(oc)
# Then, JOIN some channels # Then, JOIN some channels
for chn in node.getNodes("channel"): for chn in channels:
if chn["password"] is not None: if isinstance(chn, tuple):
self.write("JOIN %s %s" % (chn["name"], chn["password"])) self.write("JOIN %s %s" % chn)
else: else:
self.write("JOIN %s" % chn["name"]) self.write("JOIN %s" % chn)
self.hookscmd["001"] = _on_connect self.hookscmd["001"] = _on_connect
# Respond to ERROR # Respond to ERROR
@ -141,9 +152,9 @@ class IRC(SocketServer):
def _on_join(msg): def _on_join(msg):
if len(msg.params) == 0: return if len(msg.params) == 0: return
for chname in msg.params[0].split(b","): for chname in msg.decode(msg.params[0]).split(","):
# Register the channel # Register the channel
chan = Channel(msg.decode(chname)) chan = Channel(chname)
self.channels[chname] = chan self.channels[chname] = chan
self.hookscmd["JOIN"] = _on_join self.hookscmd["JOIN"] = _on_join
# Respond to PART # Respond to PART
@ -197,6 +208,8 @@ class IRC(SocketServer):
self.hookscmd["PRIVMSG"] = _on_ctcp self.hookscmd["PRIVMSG"] = _on_ctcp
# Open/close
def _open(self): def _open(self):
if SocketServer._open(self): if SocketServer._open(self):
if self.password is not None: if self.password is not None:
@ -214,6 +227,10 @@ class IRC(SocketServer):
return SocketServer._close(self) return SocketServer._close(self)
# Writes: as inherited
# Read
def read(self): def read(self):
for line in SocketServer.read(self): for line in SocketServer.read(self):
msg = IRCMessage(line, self.encoding) msg = IRCMessage(line, self.encoding)
@ -226,6 +243,8 @@ class IRC(SocketServer):
yield mes yield mes
# Parsing stuff
mgx = re.compile(b'''^(?:@(?P<tags>[^ ]+)\ )? mgx = re.compile(b'''^(?:@(?P<tags>[^ ]+)\ )?
(?::(?P<prefix> (?::(?P<prefix>
(?P<nick>[^!@ ]+) (?P<nick>[^!@ ]+)
@ -269,7 +288,7 @@ class IRCMessage:
self.cmd = self.decode(p.group("command")) self.cmd = self.decode(p.group("command"))
# Parse params # Parse params
if p.group("params") is not None: if p.group("params") is not None and p.group("params") != b'':
for param in p.group("params").strip().split(b' '): for param in p.group("params").strip().split(b' '):
self.params.append(param) self.params.append(param)
@ -278,7 +297,13 @@ class IRCMessage:
def add_tag(self, key, value=None): def add_tag(self, key, value=None):
"""Add an IRCv3.2 Message Tags""" """Add an IRCv3.2 Message Tags
Arguments:
key -- tag identifier (unique for the message)
value -- optional value for the tag
"""
# Treat special tags # Treat special tags
if key == "time": if key == "time":
value = datetime.fromtimestamp(calendar.timegm(time.strptime(value, "%Y-%m-%dT%H:%M:%S.%fZ")), timezone.utc) value = datetime.fromtimestamp(calendar.timegm(time.strptime(value, "%Y-%m-%dT%H:%M:%S.%fZ")), timezone.utc)
@ -289,11 +314,17 @@ class IRCMessage:
@property @property
def is_ctcp(self): def is_ctcp(self):
"""Analyze a message, to determine if this is a CTCP one"""
return self.cmd == "PRIVMSG" and len(self.params) == 2 and len(self.params[1]) > 1 and (self.params[1][0] == 0x01 or self.params[1][1] == 0x01) return self.cmd == "PRIVMSG" and len(self.params) == 2 and len(self.params[1]) > 1 and (self.params[1][0] == 0x01 or self.params[1][1] == 0x01)
def decode(self, s): def decode(self, s):
"""Decode the content string usign a specific encoding""" """Decode the content string usign a specific encoding
Argument:
s -- string to decode
"""
if isinstance(s, bytes): if isinstance(s, bytes):
try: try:
s = s.decode() s = s.decode()
@ -326,6 +357,12 @@ class IRCMessage:
def to_message(self, srv): def to_message(self, srv):
"""Convert to one of concrete implementation of AbstractMessage
Argument:
srv -- the server from the message was received
"""
if self.cmd == "PRIVMSG" or self.cmd == "NOTICE": if self.cmd == "PRIVMSG" or self.cmd == "NOTICE":
receivers = self.decode(self.params[0]).split(',') receivers = self.decode(self.params[0]).split(',')
@ -344,6 +381,13 @@ class IRCMessage:
else: else:
text = self.decode(self.params[1]) text = self.decode(self.params[1])
if text.find(srv.nick) == 0 and len(text) > len(srv.nick) + 2 and text[len(srv.nick)] == ":":
designated = srv.nick
text = text[len(srv.nick) + 1:].strip()
else:
designated = None
# Is this a command?
if len(text) > 1 and text[0] == '!': if len(text) > 1 and text[0] == '!':
text = text[1:].strip() text = text[1:].strip()
@ -355,10 +399,11 @@ class IRCMessage:
return message.Command(cmd=args[0], args=args[1:], **common_args) return message.Command(cmd=args[0], args=args[1:], **common_args)
elif text.find(srv.nick) == 0 and len(text) > len(srv.nick) + 2 and text[len(srv.nick)] == ":": # Is this an ask for this bot?
text = text[len(srv.nick) + 1:].strip() elif designated is not None:
return message.DirectAsk(designated=srv.nick, message=text, **common_args) return message.DirectAsk(designated=designated, message=text, **common_args)
# Normal message
else: else:
return message.TextMessage(message=text, **common_args) return message.TextMessage(message=text, **common_args)

View File

@ -38,6 +38,9 @@ class AbstractServer(io.IOBase):
send_callback -- Callback when developper want to send a message send_callback -- Callback when developper want to send a message
""" """
if not hasattr(self, "id"):
raise Exception("No id defined for this server. Please set one!")
self.logger = logging.getLogger("nemubot.server." + self.id) self.logger = logging.getLogger("nemubot.server." + self.id)
self._sending_queue = queue.Queue() self._sending_queue = queue.Queue()
if send_callback is not None: if send_callback is not None:
@ -46,8 +49,20 @@ class AbstractServer(io.IOBase):
self._send_callback = self._write_select self._send_callback = self._write_select
# Open/close
def __enter__(self):
self.open()
return self
def __exit__(self, type, value, traceback):
self.close()
def open(self): def open(self):
"""Generic open function that register the server un _rlist in case of successful _open""" """Generic open function that register the server un _rlist in case of successful _open"""
self.logger.info("Opening connection to %s", self.id)
if self._open(): if self._open():
_rlist.append(self) _rlist.append(self)
_xlist.append(self) _xlist.append(self)
@ -55,6 +70,7 @@ class AbstractServer(io.IOBase):
def close(self): def close(self):
"""Generic close function that register the server un _{r,w,x}list in case of successful _close""" """Generic close function that register the server un _{r,w,x}list in case of successful _close"""
self.logger.info("Closing connection to %s", self.id)
if self._close(): if self._close():
if self in _rlist: if self in _rlist:
_rlist.remove(self) _rlist.remove(self)
@ -64,10 +80,18 @@ class AbstractServer(io.IOBase):
_xlist.remove(self) _xlist.remove(self)
# Writes
def write(self, message): def write(self, message):
"""Send a message to the server using send_callback""" """Asynchronymously send a message to the server using send_callback
Argument:
message -- message to send
"""
self._send_callback(message) self._send_callback(message)
def write_select(self): def write_select(self):
"""Internal function used by the select function""" """Internal function used by the select function"""
try: try:
@ -79,20 +103,27 @@ class AbstractServer(io.IOBase):
except queue.Empty: except queue.Empty:
pass pass
def _write_select(self, message): def _write_select(self, message):
"""Send a message to the server safely through select""" """Send a message to the server safely through select
Argument:
message -- message to send
"""
self._sending_queue.put(self.format(message)) self._sending_queue.put(self.format(message))
self.logger.debug("Message '%s' appended to Queue", message) self.logger.debug("Message '%s' appended to Queue", message)
if self not in _wlist: if self not in _wlist:
_wlist.append(self) _wlist.append(self)
def exception(self):
"""Exception occurs in fd"""
print("Unhandle file descriptor exception on server " + self.id)
def send_response(self, response): def send_response(self, response):
"""Send a formated Message class""" """Send a formated Message class
Argument:
response -- message to send
"""
if response is None: if response is None:
return return
@ -104,3 +135,11 @@ class AbstractServer(io.IOBase):
vprnt = self.printer() vprnt = self.printer()
response.accept(vprnt) response.accept(vprnt)
self.write(vprnt.pp) self.write(vprnt.pp)
# Exceptions
def exception(self):
"""Exception occurs in fd"""
self.logger.warning("Unhandle file descriptor exception on server %s",
self.id)

View File

@ -23,23 +23,28 @@ from server import AbstractServer
class SocketServer(AbstractServer): class SocketServer(AbstractServer):
def __init__(self, host, port=6667, password=None, ssl=False): def __init__(self, host, port, ssl=False):
AbstractServer.__init__(self) AbstractServer.__init__(self)
self.host = host self.host = host
self.port = int(port) self.port = int(port)
self.password = password
self.ssl = ssl self.ssl = ssl
self.socket = None self.socket = None
self.readbuffer = b'' self.readbuffer = b''
def fileno(self): def fileno(self):
return self.socket.fileno() if self.socket else None return self.socket.fileno() if self.socket else None
@property @property
def connected(self): def connected(self):
"""Indicator of the connection aliveness"""
return self.socket is not None return self.socket is not None
# Open/close
def _open(self): def _open(self):
# Create the socket # Create the socket
self.socket = socket.socket() self.socket = socket.socket()
@ -60,6 +65,7 @@ class SocketServer(AbstractServer):
return True return True
def _close(self): def _close(self):
self._sending_queue.join() self._sending_queue.join()
if self.connected: if self.connected:
@ -71,18 +77,27 @@ class SocketServer(AbstractServer):
self.socket = None self.socket = None
return True return True
# Write
def _write(self, cnt): def _write(self, cnt):
if not self.connected: return if not self.connected: return
self.socket.send(cnt) self.socket.send(cnt)
def format(self, txt): def format(self, txt):
if isinstance(txt, bytes): if isinstance(txt, bytes):
return txt + b'\r\n' return txt + b'\r\n'
else: else:
return txt.encode() + b'\r\n' return txt.encode() + b'\r\n'
# Read
def read(self): def read(self):
if not self.connected: return if not self.connected: return
raw = self.socket.recv(1024) raw = self.socket.recv(1024)
temp = (self.readbuffer + raw).split(b'\r\n') temp = (self.readbuffer + raw).split(b'\r\n')
self.readbuffer = temp.pop() self.readbuffer = temp.pop()

View File

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Nemubot is a smart and modulable IM bot.
# Copyright (C) 2012-2014 nemunaire
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import imp
def intToIP(n):
ip = ""
for i in range(0,4):
mod = n % 256
ip = "%d.%s" % (mod, ip)
n = (n - mod) / 256
return ip[:len(ip) - 1]
def ipToInt(ip):
sum = 0
for b in ip.split("."):
sum = 256 * sum + int(b)
return sum